上から順番に実行

In [None]:
import os
import json
import tqdm
import torch
import pickle
import numpy as np
import pandas as pd
from mlp import *
from evaluation import *
from utils import timestamp
from collections import namedtuple
from sklearn.linear_model import Ridge, RidgeCV
from sklearn.model_selection import train_test_split

# Fixed random seed; passed as ``random_state`` to ``train_test_split`` in
# ``probe_experiment`` so the train/test partition is reproducible.
SEED = 0

## activationの読み込み

In [None]:
def load_activation_probing_dataset(layer, is_place):
    """Read the stored activations of one layer from the probing archive.

    Args:
        layer: Layer selector; ``0`` reads the ``"deters"`` array and ``1``
            reads the ``"stochs"`` array.
        is_place: When true the space archive is read, otherwise the time
            archive.

    Returns:
        The array stored in the archive under the selected key.

    Raises:
        ValueError: If ``layer`` is neither ``0`` nor ``1``.
    """
    file_path = "./dataset/act/space.npz" if is_place else "./dataset/act/time.npz"

    with np.load(file_path) as archive:
        # Only two keys are enabled. Layers 2-5 ("logit", "enc", "dec",
        # "pol") used to be selectable here but are currently disabled.
        if layer == 0:
            return archive["deters"]
        if layer == 1:
            return archive["stochs"]
        raise ValueError("[original] There is no layer with activations.")

## probeの保存

In [None]:
def save_probe_results(is_place, is_linear, probe_results):
    """Pickle probe results into the results directory.

    The directory is taken from the ``RESULTS_DIR`` environment variable
    (default ``'results'``) and is created if it does not exist. The file is
    named ``probing_{linear|mlp}_{place|time}.p``.

    Args:
        is_place: Selects ``'place'`` (true) or ``'time'`` (false) for the
            file name.
        is_linear: Selects ``'linear'`` (true) or ``'mlp'`` (false) for the
            file name.
        probe_results: Any picklable object to store.
    """
    save_path = os.getenv('RESULTS_DIR', 'results')
    os.makedirs(save_path, exist_ok=True)

    model_name = 'linear' if is_linear else 'mlp'
    objective = 'place' if is_place else 'time'

    probe_name = f'probing_{model_name}_{objective}.p'

    # Use a context manager so the handle is flushed and closed even if
    # pickling fails part-way through.
    with open(os.path.join(save_path, probe_name), 'wb') as out_file:
        pickle.dump(probe_results, out_file)

## 時空間情報のprobing関数

In [None]:
def probe_experiment(activations, target, is_place, probe=None):
    """Fit a probe on a train split and score it on train and test data.

    The data is split 80/20 with ``train_test_split`` (seeded by ``SEED``).
    The probe is fitted on targets standardized with the training mean and
    standard deviation, and its predictions are mapped back to the original
    target scale before scoring.

    Args:
        activations: 2-D array of probe inputs, one row per sample.
        target: Array of regression targets aligned with ``activations``.
            The place branch indexes columns 0 and 1; the time branch
            flattens it to 1-D.
        is_place: Selects place scoring/unnormalization (per-column
            statistics) when true, time scoring (global statistics) when
            false.
        probe: Estimator exposing ``fit`` and ``predict``. Defaults to
            ``Ridge`` with ``alpha`` equal to the number of activation
            columns.

    Returns:
        A tuple ``(probe, scores, projection_df)``: the fitted probe, a dict
        keyed by ``('train' | 'test', metric_name)``, and a DataFrame with one
        row per sample holding the prediction, the test-membership flag, the
        signed error and the proximity error.
    """
    indices = np.arange(len(activations))
    (
        train_activations,
        test_activations,
        train_target,
        test_target,
        _train_indices,  # not needed: the test mask alone identifies the split
        test_indices,
    ) = train_test_split(activations, target, indices, test_size=0.2, random_state=SEED)

    # Boolean mask over all samples marking the held-out rows.
    is_test = np.full(activations.shape[0], False)
    is_test[test_indices] = True

    # The probe is trained on standardized targets (training statistics only).
    norm_train_target = (train_target - train_target.mean(axis=0)) / train_target.std(axis=0)

    if probe is None:
        probe = Ridge(alpha=activations.shape[1])

    probe.fit(train_activations, norm_train_target)

    train_pred = probe.predict(train_activations)
    test_pred = probe.predict(test_activations)

    # Statistics used to map predictions back to the original scale:
    # per-column for place targets, global for time targets.
    if is_place:
        offset = train_target.mean(axis=0)
        scale = train_target.std(axis=0)
        score_fn = score_place_probe
    else:
        offset = train_target.mean()
        scale = train_target.std()
        score_fn = score_time_probe

    train_pred_unnorm = train_pred * scale + offset
    test_pred_unnorm = test_pred * scale + offset

    # Prediction for every sample (train and test) on the original scale.
    projection = probe.predict(activations) * scale + offset

    train_scores = score_fn(train_target, train_pred_unnorm)
    test_scores = score_fn(test_target, test_pred_unnorm)

    scores = {
        **{('train', k): v for k, v in train_scores.items()},
        **{('test', k): v for k, v in test_scores.items()},
    }

    error_matrix = compute_proximity_error_matrix(target, projection, is_place)

    train_error, test_error, combined_error = proximity_scores(error_matrix, is_test)
    scores['train', 'prox_error'] = train_error.mean()
    scores['test', 'prox_error'] = test_error.mean()

    if is_place:
        projection_df = pd.DataFrame({
            'x': projection[:, 0],
            'y': projection[:, 1],
            'is_test': is_test,
            'x_error': projection[:, 0] - target[:, 0],
            'y_error': projection[:, 1] - target[:, 1],
            'prox_error': combined_error,
        })
    else:
        # Flatten both sides. A probe fitted on (n, 1) targets predicts
        # (n, 1), and subtracting a raveled (n,) target from that would
        # broadcast to (n, n) and break the DataFrame construction.
        flat_target = target.ravel()
        flat_projection = projection.ravel()
        projection_df = pd.DataFrame({
            'projection': flat_projection,
            'is_test': is_test,
            'error': flat_projection - flat_target,
            'prox_error': combined_error,
        })

    return probe, scores, projection_df

In [None]:
def get_target_values(is_place):
    """Load the probing targets as a 2-D array.

    Args:
        is_place: When true, read ``"positions"`` from the space archive and
            reshape to two columns; otherwise read ``"episodes"`` from the
            time archive and reshape to one column.

    Returns:
        A 2-D NumPy array of shape ``(-1, 2)`` for place or ``(-1, 1)`` for
        time.
    """
    if is_place:
        file_path, key, n_columns = "./dataset/act/space.npz", "positions", 2
    else:
        file_path, key, n_columns = "./dataset/act/time.npz", "episodes", 1

    with np.load(file_path) as archive:
        raw = archive[key]

    # Collapse all leading dimensions into rows via a torch view, then pass
    # the result through a DataFrame to obtain a plain values array.
    reshaped = torch.from_numpy(raw).view(-1, n_columns).numpy()
    return pd.DataFrame(reshaped).values

## メインprobing関数

In [None]:
# Layer-index constants used by the probing loops.
# N_LAYERS = 2: of encoder, rssm (deter, stoch, logits), decoder and policy,
# only the rssm deter (index 0) and stoch (index 1) activations are probed.
D = namedtuple("Def", ["N_LAYERS", "DETERS", "STOCHS"])(2, 0, 1)

線形回帰モデル

In [None]:
def linear_probe_experiment(is_place):
    """Run a RidgeCV probe on every layer and save the collected results.

    For each layer index in ``range(D.N_LAYERS)`` the activations are loaded,
    flattened to 2-D, and passed to ``probe_experiment`` together with the
    targets. Layers whose activations contain NaN are skipped with a warning.
    Results are written with ``save_probe_results(is_place, True, ...)``.

    Args:
        is_place: True to probe the place targets, false to probe the time
            targets.
    """
    results = {
        section: {}
        for section in (
            'scores',
            'projections',
            'probe_directions',
            'probe_biases',
            'probe_alphas',
        )
    }

    for layer in tqdm.tqdm(range(D.N_LAYERS)):
        raw = load_activation_probing_dataset(layer, is_place)

        # Feature width after flattening: dim 2 for deters, dims 2*3 for stochs.
        if layer == D.DETERS:
            size = raw.shape[2]
        elif layer == D.STOCHS:
            size = raw.shape[2] * raw.shape[3]

        flat = torch.from_numpy(raw).dequantize().view(-1, size)

        if flat.isnan().any():
            print(timestamp(), 'WARNING: nan activations, skipping layer', layer)
            continue

        features = flat.numpy()
        target = get_target_values(is_place)

        # TODO: tune the alpha grid.
        # NOTE(review): newer scikit-learn releases rename store_cv_values /
        # cv_values_ -- confirm against the installed version.
        ridge = RidgeCV(alphas=np.logspace(0.8, 4.1, 12), store_cv_values=True)

        fitted, scores, projection = probe_experiment(features, target, is_place, probe=ridge)

        direction = fitted.coef_.T.astype(np.float16)
        # Average the stored CV values so one value per alpha remains.
        cv_axes = (0, 1) if is_place else 0
        alpha_scores = fitted.cv_values_.mean(axis=cv_axes)

        per_layer = {
            'scores': scores,
            'projections': projection,
            'probe_directions': direction,
            'probe_biases': fitted.intercept_,
            'probe_alphas': alpha_scores,
        }
        for section, value in per_layer.items():
            results[section][layer] = value

    save_probe_results(is_place, True, results)

MLP

In [None]:
# Hyper-parameter grid for the MLP probe: candidate weight-decay values that
# mlp_experiment tries one after another.
MLP_PARAM_DICT = dict(weight_decay=[0.01, 0.03, 0.1, 0.3])

def mlp_experiment(activations, target, is_place):
    """Compare a RidgeCV baseline with MLP probes over a weight-decay grid.

    A RidgeCV probe is fitted first. Then one ``MLPRegressor`` is fitted per
    weight decay in ``MLP_PARAM_DICT['weight_decay']``, and the one whose
    minimum validation score is lowest is reported.

    Args:
        activations: 2-D array of probe inputs; its last dimension is the MLP
            input size.
        target: Regression targets aligned with ``activations``.
        is_place: True for place targets (2 outputs), false for time targets
            (1 output).

    Returns:
        A dict with the ridge and best-MLP scores and prediction frames, the
        averaged ridge CV values, and the list of per-weight-decay minimum
        validation scores.
    """
    ridge_fitted, ridge_scores, ridge_projection_df = probe_experiment(
        activations,
        target,
        is_place,
        probe=RidgeCV(alphas=np.logspace(3, 4.5, 12), store_cv_values=True),
    )
    cv_axes = (0, 1) if is_place else 0
    ridge_cv_values = ridge_fitted.cv_values_.mean(axis=cv_axes)

    weight_decays = MLP_PARAM_DICT['weight_decay']
    n_outputs = 2 if is_place else 1

    outcome_by_wd = {}
    val_scores = []
    for wd in weight_decays:
        candidate = MLPRegressor(
            input_size=activations.shape[-1],
            output_size=n_outputs,
            hidden_size=256,
            patience=3,
            learning_rate=1e-3,
            weight_decay=wd,
        )
        mlp_fitted, wd_scores, wd_projection_df = probe_experiment(
            activations, target, is_place, probe=candidate)

        # Best validation score for this weight decay (lower is better here,
        # since the winner is picked with argmin).
        val_scores.append(min(mlp_fitted.validation_scores))
        outcome_by_wd[wd] = (wd_scores, wd_projection_df)

    best_wd = weight_decays[np.argmin(val_scores)]
    best_scores, best_projection_df = outcome_by_wd[best_wd]

    return {
        'ridge_scores': ridge_scores,
        'mlp_scores': best_scores,
        'ridge_prediction_df': ridge_projection_df,
        'mlp_prediction_df': best_projection_df,
        'ridge_cv_values': ridge_cv_values,
        'mlp_validation_scores': val_scores,
    }

In [None]:
def mlp_probe_experiment(is_place):
    """Run ``mlp_experiment`` on every layer and save the collected results.

    For each layer index in ``range(D.N_LAYERS)`` the activations are loaded
    and flattened to 2-D. Layers containing NaN are skipped with a warning.
    The per-layer results are written with
    ``save_probe_results(is_place, False, ...)``.

    Args:
        is_place: True to probe the place targets, false to probe the time
            targets.
    """
    results = {}

    for layer in tqdm.tqdm(range(D.N_LAYERS)):
        raw = load_activation_probing_dataset(layer, is_place)

        # Feature width after flattening: dim 2 for deters, dims 2*3 for stochs.
        if layer == D.DETERS:
            size = raw.shape[2]
        elif layer == D.STOCHS:
            size = raw.shape[2] * raw.shape[3]

        flat = torch.from_numpy(raw).dequantize().view(-1, size)
        if flat.isnan().any():
            print(timestamp(), 'WARNING: nan activations, skipping layer', layer)
            continue

        features = flat.numpy()
        target = get_target_values(is_place)

        results[layer] = mlp_experiment(features, target, is_place)

    save_probe_results(is_place, False, results)

## probeの実行

In [None]:
# Linear regression model: probe for spatial information.
is_place = True
print("|=|=|=|=|=|===[Linear Model]Probing for Space Information===|=|=|=|=|=|")
linear_probe_experiment(is_place)

In [None]:
# Linear regression model: probe for temporal information.
is_place = False
print("|=|=|=|=|=|===[Linear Model]Probing for Time Information===|=|=|=|=|=|")
linear_probe_experiment(is_place)

In [None]:
# MLP: probe for spatial information.
is_place = True
print("|=|=|=|=|=|===[MLP]Probing for Space Information===|=|=|=|=|=|")
mlp_probe_experiment(is_place)

In [None]:
# MLP: probe for temporal information.
is_place = False
print("|=|=|=|=|=|===[MLP]Probing for Time Information===|=|=|=|=|=|")
mlp_probe_experiment(is_place)