In [None]:
import os
import json
import tqdm
import torch
import pickle
import numpy as np
import pandas as pd
from utils import timestamp
from evaluation import *
from sklearn.linear_model import Ridge, RidgeCV
from sklearn.model_selection import train_test_split

# Random seed passed as random_state to train_test_split in probe_experiment,
# so the train/test partition is reproducible across runs.
SEED: int = 0

## activationの読み込み

In [None]:
# npz key holding the activations for each probed layer index.
ACTIVATION_KEYS = {
    0: "deter",
    1: "stoch",
    2: "logit",
    3: "enc",
    4: "dec",
    5: "pol",
}


def load_activation_probing_dataset(layer, path='idx0.npz'):
    """Load one layer's activations from an ``.npz`` archive.

    Args:
        layer: layer index 0-5, mapped to the npz keys
            0 "deter", 1 "stoch", 2 "logit", 3 "enc", 4 "dec", 5 "pol".
        path: path of the ``.npz`` archive. Defaults to ``'idx0.npz'``, the file
            that was previously hard-coded.

    Returns:
        The NumPy array stored under the layer's key.

    Raises:
        ValueError: if ``layer`` is not one of the indices above. This is checked
            before the file is opened, so a bad index fails fast.
    """
    # TODO: build the real npz path (expected under ./dataset/act/FILE_NAME) and
    # confirm later how the per-step arrays are stacked.
    key = ACTIVATION_KEYS.get(layer)
    if key is None:
        raise ValueError("[original] There is no layer with activations.")

    with np.load(path) as archive:
        # Indexing an NpzFile reads the array into memory, so the returned array
        # stays valid after the archive is closed.
        return archive[key]

## probeの保存

In [None]:
def save_probe_results(is_place, probe_results, probe_name='probe.p'):
    """Pickle probe results under ``$RESULTS_DIR/{place|time}/<probe_name>``.

    Args:
        is_place: True writes into the ``place`` subdirectory, False into ``time``.
        probe_results: any picklable object (the results dict built by the
            probing loop).
        probe_name: output file name. Defaults to ``'probe.p'`` as before.

    Returns:
        The path of the written pickle file.
    """
    save_dir = os.path.join(
        os.getenv('RESULTS_DIR', 'results'),
        'place' if is_place else 'time'
    )
    os.makedirs(save_dir, exist_ok=True)

    save_file = os.path.join(save_dir, probe_name)
    # Use a context manager so the file handle is flushed and closed even if
    # pickling fails; the original left the handle open.
    with open(save_file, 'wb') as fh:
        pickle.dump(probe_results, fh)
    return save_file

## 時空間情報のprobing関数

In [None]:
def probe_experiment(activations, target, is_place, probe=None, test_size=0.2):
    """Fit a linear probe from activations to a place or time target and score it.

    Args:
        activations: 2-D array with one row per sample. ``activations.shape[1]``
            is used as the default Ridge alpha.
        target: per-sample targets with the same number of rows as
            ``activations``. For ``is_place=True``, columns 0 and 1 are treated as
            x and y. Otherwise the target is flattened to 1-D, so both (n,) and
            (n, 1) are accepted.
        is_place: True for the position probe, False for the time/step probe.
        probe: optional unfitted regressor with ``fit``/``predict``. Defaults to
            ``Ridge(alpha=n_features)``.
        test_size: fraction of samples held out for testing. The split is seeded
            with ``SEED``.

    Returns:
        A tuple ``(probe, scores, projection_df)``:
        - ``probe``: the fitted probe.
        - ``scores``: a dict keyed by ``(split, metric)`` holding the
          score_place_probe / score_time_probe metrics plus ``'prox_error'``.
        - ``projection_df``: a DataFrame with one row per sample holding the
          prediction, the ``is_test`` flag, the signed error and the combined
          proximity error.

    Raises:
        ValueError: if ``activations`` and ``target`` have different row counts.
    """
    activations = np.asarray(activations)
    target = np.asarray(target)
    if not is_place:
        # The time projection is stored as a single DataFrame column, so keep
        # the time target 1-D.
        target = target.reshape(-1)

    n_samples = activations.shape[0]
    if target.shape[0] != n_samples:
        raise ValueError(
            f'activations have {n_samples} rows but target has {target.shape[0]}'
        )

    # Split row *indices* so activations and targets stay aligned. The previous
    # version had two problems:
    # - It paired the shuffled activation rows returned by train_test_split with
    #   unshuffled target rows.
    # - It rebuilt the test mask with np.isin, which tests element-wise value
    #   membership rather than row membership.
    _, test_idx = train_test_split(
        np.arange(n_samples), test_size=test_size, random_state=SEED
    )
    is_test = np.zeros(n_samples, dtype=bool)
    is_test[test_idx] = True

    train_activations = activations[~is_test]
    train_target = target[~is_test]
    test_target = target[is_test]

    # Standardise the target with train-split statistics only (per column for
    # the 2-D place target). A constant column is left unscaled rather than
    # divided by zero.
    target_mean = train_target.mean(axis=0)
    target_std = train_target.std(axis=0)
    target_std = np.where(target_std == 0, 1.0, target_std)

    if probe is None:
        # TODO: check activations.shape (alpha scales with the feature count).
        probe = Ridge(alpha=activations.shape[1])

    probe.fit(train_activations, (train_target - target_mean) / target_std)

    # Predict every sample once and map back to target units. The train/test
    # predictions are slices of this, since sklearn regressors predict row-wise.
    projection = probe.predict(activations) * target_std + target_mean
    train_pred = projection[~is_test]
    test_pred = projection[is_test]

    score_probe = score_place_probe if is_place else score_time_probe
    train_scores = score_probe(train_target, train_pred)
    test_scores = score_probe(test_target, test_pred)

    scores = {
        **{('train', k): v for k, v in train_scores.items()},
        **{('test', k): v for k, v in test_scores.items()},
    }

    error_matrix = compute_proximity_error_matrix(target, projection, is_place)

    train_error, test_error, combined_error = proximity_scores(error_matrix, is_test)
    scores['train', 'prox_error'] = train_error.mean()
    scores['test', 'prox_error'] = test_error.mean()

    if is_place:
        projection_df = pd.DataFrame({
            'x': projection[:, 0],
            'y': projection[:, 1],
            'is_test': is_test,
            'x_error': projection[:, 0] - target[:, 0],
            'y_error': projection[:, 1] - target[:, 1],
            'prox_error': combined_error,
        })
    else:
        projection_df = pd.DataFrame({
            'projection': projection,
            'is_test': is_test,
            'error': projection - target,
            'prox_error': combined_error,
        })

    return probe, scores, projection_df

In [None]:
def get_target_values(is_place, file_path=None):
    """Read probing targets from a JSON-lines file.

    Args:
        is_place: True collects each record's ``"pos"`` value. Presumably this is
            ``[x, y]``; probe_experiment reads columns 0 and 1. False collects
            each record's ``"step"`` value.
        file_path: path of the ``.jsonl`` file. If None, the
            ``TARGET_JSONL_PATH`` environment variable is used.

    Returns:
        A float array of the collected values, in file order. For place targets
        this is one row per record (shape (n, 2) for [x, y] positions). For time
        targets it is 1-D with shape (n,). Records without the requested key and
        blank lines are skipped.

    Raises:
        ValueError: if no path is given and ``TARGET_JSONL_PATH`` is unset.
    """
    if file_path is None:
        # TODO: point this at the real jsonl file (the path was never filled in).
        file_path = os.getenv('TARGET_JSONL_PATH')
    if file_path is None:
        raise ValueError(
            'target jsonl path is not set: pass file_path or set TARGET_JSONL_PATH'
        )

    key = 'pos' if is_place else 'step'
    values = []
    # JSON text is UTF-8, so don't depend on the platform's default encoding.
    with open(file_path, 'r', encoding='utf-8') as fh:
        for raw_line in fh:
            raw_line = raw_line.strip()
            if not raw_line:
                continue
            record = json.loads(raw_line)
            if key in record:
                values.append(record[key])

    # Build the array directly from the raw values. Wrapping each value in a dict
    # and going through DataFrame(...).values produced an (n, 1) object array of
    # lists for "pos" and an (n, 1) column for "step".
    return np.asarray(values, dtype=float)

## メインprobing関数

In [None]:
def _make_ridge_cv(alphas):
    """Build a RidgeCV that keeps per-sample leave-one-out errors, across sklearn versions."""
    try:
        # scikit-learn >= 1.5 spelling.
        return RidgeCV(alphas=alphas, store_cv_results=True)
    except TypeError:
        # Older scikit-learn only accepts the pre-1.5 keyword.
        return RidgeCV(alphas=alphas, store_cv_values=True)


def main_probe_experiment(is_place):
    """Probe every layer's activations for place or time and pickle the results.

    Args:
        is_place: True probes the ``"pos"`` target, False the ``"step"`` target.

    Returns:
        The results dict. Each sub-dict is keyed by layer index and holds:
        - ``'scores'``, ``'projections'``: outputs of probe_experiment.
        - ``'probe_directions'``: the probe coefficients (transposed, float16).
        - ``'probe_biases'``: the probe intercepts.
        - ``'probe_alphas'``: the mean cross-validation error per candidate alpha.
        The same dict is also written via save_probe_results.

    Raises:
        ValueError: if a layer's activations and the target have different
            numbers of rows.
    """
    # Layer indices 0-5 map to the npz keys "deter", "stoch", "logit" (RSSM),
    # "enc" (encoder), "dec" (decoder) and "pol" (policy).
    n_layers = 6

    results = {
        'scores': {},
        'projections': {},
        'probe_directions': {},
        'probe_biases': {},
        'probe_alphas': {},
    }

    # The target does not depend on the layer, so read it once.
    target = get_target_values(is_place)

    for layer in tqdm.tqdm(range(n_layers)):
        # np.load yields NumPy arrays, which have no .dequantize()/.isnan()
        # (those are torch.Tensor methods), so convert with NumPy instead.
        activations = np.asarray(load_activation_probing_dataset(layer), dtype=np.float32)

        if np.isnan(activations).any():
            print(timestamp(), 'WARNING: nan activations, skipping layer', layer)
            continue

        if len(activations) != len(target):
            raise ValueError(
                f'layer {layer}: {len(activations)} activation rows '
                f'but {len(target)} target rows'
            )

        # TODO: tune the alpha grid.
        probe = _make_ridge_cv(np.logspace(0.8, 4.1, 12))

        probe, scores, projection = probe_experiment(activations, target, is_place, probe=probe)

        probe_direction = probe.coef_.T.astype(np.float16)
        # Per-sample CV errors: (n_samples, n_targets, n_alphas) for a 2-D target,
        # (n_samples, n_alphas) for a 1-D target. Averaging gives the error per alpha.
        cv_errors = getattr(probe, 'cv_results_', None)
        if cv_errors is None:
            cv_errors = probe.cv_values_
        probe_alphas = cv_errors.mean(axis=(0, 1) if is_place else 0)

        results['scores'][layer] = scores
        results['projections'][layer] = projection
        results['probe_directions'][layer] = probe_direction
        results['probe_biases'][layer] = probe.intercept_
        results['probe_alphas'][layer] = probe_alphas

    # save_probe_results needs is_place to choose the output subdirectory.
    save_probe_results(is_place, results)
    return results