In [2]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import os
import pandas as pd
import configs
import torch

import warnings
from sklearn.exceptions import ConvergenceWarning
warnings.filterwarnings('ignore', category=ConvergenceWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

from sklearn.decomposition import PCA
from scipy.spatial.distance import pdist, squareform
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import RidgeCV
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import numpy as np

from src.utils import find_ckpt_file, convert_to_tensor
import h5py
import random
from src.evals.eval_trees import EvalTrees
from src.evals.eval_trees import EvalCntrees

# Arguments 

In [3]:
corr = 0.25
seq_length = 1200
start_idx = 900  # where to start drawing tokens from, until seq_length

# Load Model

In [4]:
engram_dir = "/n/holylfs06/LABS/krajan_lab/Lab/cfang/icl-maze/"
wandb_project = "tree_maze"
env_name = f"cntree_layers7_bprob0.9_corr{corr}_state_dim10_envs300000_H800_explore"
if corr == 0.25:
    model_name = "transformer_end_query_embd512_layer3_head4_lr0.0001_drop0.2_initseed0_batch512"
elif corr == 0.:
    model_name = "transformer_end_query_embd512_layer3_head4_lr1e-05_drop0_initseed1_batch256"
else:
    raise ValueError(f"Unknown correlation value: {corr}")
model_path = os.path.join(engram_dir, wandb_project, env_name, "models", model_name)
ckpt_name = find_ckpt_file(model_path, "best")
print(ckpt_name)
path_to_pkl = os.path.join(model_path, ckpt_name)

eval_dset_path = f"/n/holylfs06/LABS/krajan_lab/Lab/cfang/icl-maze/cntree/cntree_layers7_bprob1.0_corr{corr}_state_dim10_envs1000_H1600_explore/datasets/eval.pkl"

epoch=31-val_loss=0.000400.ckpt


In [5]:
# Extract parameters using regex
import re

n_embd = int(re.search(r'embd(\d+)', model_name).group(1))
n_layer = int(re.search(r'layer(\d+)', model_name).group(1))
n_head = int(re.search(r'head(\d+)', model_name).group(1))
dropout = float(re.search(r'drop(\d*\.?\d*)', model_name).group(1))

# Extract correlation and state_dim from eval dataset path
state_dim = int(re.search(r'state_dim(\d+)', eval_dset_path).group(1))

model_config = {
    "n_embd": n_embd,
    "n_layer": n_layer,
    "n_head": n_head,
    "state_dim": 10,
    "action_dim": 4,
    "dropout": dropout,
    "train_on_last_pred_only": False,
    "test": True,
    "name": "transformer_end_query",
    "optimizer_config": None,
    "linear_attention": False,
}

In [6]:
from src.models.transformer_end_query import Transformer
model_config['initialization_seed'] = 0
model = Transformer(**model_config)
checkpoint = torch.load(path_to_pkl)
model.load_state_dict(checkpoint['state_dict'])
model.eval()
model = model.to('cuda')

# Load Dataset and Create Environment

In [7]:
n_eval_envs = -1 #50

is_h5_file = eval_dset_path.endswith('.h5')
if is_h5_file:
    eval_trajs = h5py.File(eval_dset_path, 'r')
    traj_indices = list(eval_trajs.keys())
    n_eval_envs = min(n_eval_envs, len(traj_indices))
    random.seed(0)
    traj_indices = random.sample(traj_indices, n_eval_envs)
    random.seed()
    if n_eval_envs != -1:
        eval_trajs = [eval_trajs[i] for i in traj_indices]
    else:
        n_eval_envs = len(traj_indices)
else:  # Pickle file
    with open(eval_dset_path, 'rb') as f:
        eval_trajs = pickle.load(f)
    n_eval_envs = min(n_eval_envs, len(eval_trajs))
    if n_eval_envs != -1:
        random.seed(0)
        eval_trajs = random.sample(eval_trajs, n_eval_envs)
        random.seed()
    else:
        n_eval_envs = len(eval_trajs)


In [8]:
def run_model(traj, model, seq_length=1200, start_idx=800):
    hidden_states = []

    batch = {
        'context_states': convert_to_tensor([np.array(traj['context_states'])]),
        'context_actions': convert_to_tensor([np.array(traj['context_actions'])]),
        'context_next_states': convert_to_tensor([np.array(traj['context_next_states'])]),
        'context_rewards': convert_to_tensor([np.array(traj['context_rewards'])[:, None]]),
        'query_states': convert_to_tensor([np.array(traj['query_state'])]),  # Ignored
        }
    batch['zeros'] = torch.zeros(1, 10 ** 2 + 4 + 1).float()
    for k in batch.keys():
        if 'context' in k:
            batch[k] = batch[k][:,:seq_length]
        batch[k] = batch[k].to(model.device)
    model.save_activations = True
    with torch.no_grad():
        out = model(batch)
    _hidden_states = model.activations['hidden_states'][1:] # Tuple over layers of (1, seq, dim)
    state_features = batch['context_states'][0][start_idx:].to('cpu').numpy()
    next_state_features = batch['context_next_states'][0][start_idx:].to('cpu').numpy()
    actions = batch['context_actions'][0].argmax(dim=1)[start_idx:].to('cpu').numpy()
    for i_layer in range(model.n_layer):
        hidden_states.append(_hidden_states[i_layer][0,start_idx:-1])
    return hidden_states, state_features, next_state_features, actions

# Across context decoding
(within-context is not that good)

In [9]:
def get_subtree_location(layer, pos, subtree):
    midpt = 2**(layer-1)
    quarter_pt = midpt//2
    eighth_pt = quarter_pt//2
    if layer == 0:
        return 0
    if subtree == 'half':
        return 1 if pos < midpt else 2
    elif subtree == 'quarter':
        if layer == 1:
            return 0
        bins = np.arange(0, 2**layer, quarter_pt)
        return np.digitize([pos], bins)[0]
    elif subtree == 'eighth':
        if (layer == 1) or (layer == 2):
            return 0
        bins = np.arange(0, 2**layer, eighth_pt)
        return np.digitize([pos], bins)[0]

    

In [14]:
train_envs = np.arange(1, int(n_eval_envs*0.9))
test_envs = np.arange(int(n_eval_envs*0.9), n_eval_envs)

def make_train_test_matrices():
    X_train = [[] for _ in range(model.n_layer)]
    X_test = [[] for _ in range(model.n_layer)]
    Ys_dict = {
        "dist_from_goal": {"Y_train": [], "Y_test": []},
        "layer": {"Y_train": [], "Y_test": []},
        "node_identity": {"Y_train": [], "Y_test": []},
        "maze_half": {"Y_train": [], "Y_test": []},
        "maze_quarter": {"Y_train": [], "Y_test": []},
        "maze_eighth": {"Y_train": [], "Y_test": []},
        'is_goal': {"Y_train": [], "Y_test": []},
        'same_half_as_goal': {"Y_train": [], "Y_test": []},
        'opt_action': {"Y_train": [], "Y_test": []},
        'state_feature': {"Y_train": [], "Y_test": []},
        'next_state_feature': {"Y_train": [], "Y_test": []},
        "on_path": {"Y_train": [], "Y_test": []},
        "on_lr_path": {"Y_train": [], "Y_test": []},
    }
    for i_eval in range(n_eval_envs):
        traj = eval_trajs[i_eval]
        first_reward = np.argwhere(np.array(traj['context_rewards'])>0)
        if (first_reward.size == 0) or (first_reward[0] > start_idx):
            continue

        env_config = {
            'max_layers': 7,
            'horizon': 1600,
            'branching_prob': 1.0,
            'node_encoding_corr': corr,
            'state_dim': state_dim,
            'initialization_seed': np.array(traj['initialization_seed']).item()
        }
        env = EvalCntrees().create_env(env_config, np.array(traj['goal']), i_eval)
        opt_action_map, dist_from_goal = env.make_opt_action_dict()


        hidden_states, state_features, next_state_features, actions = run_model(traj, model, seq_length, start_idx)
        goal_node = env.node_map[tuple(env.goal.tolist())]
        goal_layer = goal_node.layer
        goal_pos = goal_node.pos
        for state_idx, state_feature in enumerate(state_features):
            next_state_feature = next_state_features[state_idx]
            state_feature_tuple = tuple(state_feature.tolist())
            d = dist_from_goal[state_feature_tuple]
            layer = env.node_map[state_feature_tuple].layer
            pos = env.node_map[state_feature_tuple].pos
            node_identity = 2**layer + pos
            maze_half = get_subtree_location(layer, pos, 'half')
            maze_quarter = get_subtree_location(layer, pos, 'quarter')
            maze_eighth = get_subtree_location(layer, pos, 'eighth')
            
            Y_key = "Y_train" if i_eval in train_envs else "Y_test"
            Ys_dict["dist_from_goal"][Y_key].append(d)
            Ys_dict["layer"][Y_key].append(layer)
            Ys_dict["node_identity"][Y_key].append(node_identity)
            Ys_dict["maze_half"][Y_key].append(maze_half)
            Ys_dict["maze_quarter"][Y_key].append(maze_quarter)
            Ys_dict["maze_eighth"][Y_key].append(maze_eighth)
            Ys_dict["is_goal"][Y_key].append(state_feature_tuple == tuple(env.goal.tolist()))
            Ys_dict["same_half_as_goal"][Y_key].append(maze_half == get_subtree_location(goal_layer, goal_pos, 'half'))
            Ys_dict["opt_action"][Y_key].append(opt_action_map[state_feature_tuple])
            Ys_dict["state_feature"][Y_key].append(state_feature)
            Ys_dict["next_state_feature"][Y_key].append(next_state_feature)
            Ys_dict["on_path"][Y_key].append(actions[state_idx] == opt_action_map[state_feature_tuple])
            Ys_dict["on_lr_path"][Y_key].append(
                (actions[state_idx] == opt_action_map[state_feature_tuple]) and
                (actions[state_idx] in [1, 2])
                )

            for layer in range(model.n_layer):
                hidden_state = hidden_states[layer][state_idx]
                if i_eval in train_envs:
                    X_train[layer].append(hidden_state.to('cpu').numpy())
                else:
                    X_test[layer].append(hidden_state.to('cpu').numpy())
        torch.cuda.empty_cache()
    return X_train, X_test, Ys_dict

In [15]:
import warnings

from sklearn.metrics import balanced_accuracy_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import numpy as np

def fit_and_evaluate_regression(X_train, Y_train, X_test, Y_test, print_scores=True):
    from joblib import Parallel, delayed
    from sklearn.model_selection import KFold
    from sklearn.linear_model import Ridge
    
    X_train_np = [np.array([_x for _x in x]) for x in X_train]
    X_test_np = [np.array([_x for _x in x]) for x in X_test]
    Y_train_np = np.array(Y_train)
    Y_test_np = np.array(Y_test)

    alphas = np.logspace(0, 4, 10)
    n_splits = 5
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    def evaluate_fold(X, y, train_idx, val_idx, alpha):
        # Train on this fold
        pipeline = Pipeline([
            ('scaler', StandardScaler()),
            ('ridge', Ridge(alpha=alpha))
        ])
        pipeline.fit(X[train_idx], y[train_idx])
        # Get validation score
        val_score = pipeline.score(X[val_idx], y[val_idx])
        return val_score

    pipelines = []
    test_scores = []
    
    for layer in range(len(X_train)-1):
        # Parallel CV for each alpha
        cv_scores = {alpha: [] for alpha in alphas}
        for alpha in alphas:
            scores = Parallel(n_jobs=-1)(
                delayed(evaluate_fold)(
                    X_train_np[layer], Y_train_np, 
                    train_idx, val_idx, alpha
                )
                for train_idx, val_idx in kf.split(X_train_np[layer])
            )
            cv_scores[alpha] = np.mean(scores)
        
        # Find best alpha
        best_alpha = max(cv_scores.items(), key=lambda x: x[1])[0]
        
        # Train final model with best alpha
        pipeline = Pipeline([
            ('scaler', StandardScaler()),
            ('ridge', Ridge(alpha=best_alpha))
        ])
        pipeline.fit(X_train_np[layer], Y_train_np)
        
        train_score = pipeline.score(X_train_np[layer], Y_train_np)
        test_score = pipeline.score(X_test_np[layer], Y_test_np)
        
        pipelines.append(pipeline)
        test_scores.append(test_score)
        
        if print_scores:
            print(f"Layer {layer}:")
            print(f"Best alpha: {best_alpha:.3f}")
            print(f"Train R2: {train_score:.3f}")
            print(f"Test R2: {test_score:.3f}")
            print()
            
    return pipelines, test_scores

def fit_and_evaluate_classification(X_train, Y_train, X_test, Y_test, print_scores=True):
    import warnings
    from sklearn.exceptions import ConvergenceWarning
    warnings.filterwarnings('ignore', category=ConvergenceWarning)
    warnings.filterwarnings('ignore', category=FutureWarning)
    from joblib import Parallel, delayed
    from sklearn.model_selection import KFold
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score, f1_score

    X_train_np = [np.array([_x for _x in x]) for x in X_train]
    X_test_np = [np.array([_x for _x in x]) for x in X_test]
    Y_train_np = np.array(Y_train)
    Y_test_np = np.array(Y_test)

    Cs = np.logspace(-4, 4, 10)
    n_splits = 5
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    def evaluate_fold(X, y, train_idx, val_idx, C):
        pipeline = Pipeline([
            ('scaler', StandardScaler()),
            ('classifier', LogisticRegression(
                C=C, 
                max_iter=3000,
                class_weight='balanced',  # Add class weighting
                random_state=42
            ))
        ])
        pipeline.fit(X[train_idx], y[train_idx])
        y_val_pred = pipeline.predict(X[val_idx])
        # Use balanced accuracy score instead of regular accuracy
        return balanced_accuracy_score(y[val_idx], y_val_pred)

    pipelines = []
    test_scores = []
    
    for layer in range(len(X_train)-1):
        # Parallel CV for each C value
        cv_scores = {C: [] for C in Cs}
        for C in Cs:
            scores = Parallel(n_jobs=-1)(
                delayed(evaluate_fold)(
                    X_train_np[layer], Y_train_np, 
                    train_idx, val_idx, C
                )
                for train_idx, val_idx in kf.split(X_train_np[layer])
            )
            cv_scores[C] = np.mean(scores)
        
        # Find best C
        best_C = max(cv_scores.items(), key=lambda x: x[1])[0]
        
        # Train final model with best C
        pipeline = Pipeline([
            ('scaler', StandardScaler()),
            ('classifier', LogisticRegression(
                C=best_C, 
                max_iter=3000,
                class_weight='balanced',  # Add class weighting
                random_state=42
            ))
        ])
        pipeline.fit(X_train_np[layer], Y_train_np)
        
        y_train_pred = pipeline.predict(X_train_np[layer])
        y_test_pred = pipeline.predict(X_test_np[layer])
        
        # Use balanced metrics
        train_accuracy = balanced_accuracy_score(Y_train_np, y_train_pred)
        test_accuracy = balanced_accuracy_score(Y_test_np, y_test_pred)
        train_f1 = f1_score(Y_train_np, y_train_pred, average='weighted')
        test_f1 = f1_score(Y_test_np, y_test_pred, average='weighted')

        if print_scores:
            print(f"Layer {layer}:")
            print(f"Best C: {best_C:.3f}")
            print(f"Train Balanced Accuracy: {train_accuracy:.3f}")
            print(f"Test Balanced Accuracy: {test_accuracy:.3f}")
            print(f"Train Weighted F1: {train_f1:.3f}")
            print(f"Test Weighted F1: {test_f1:.3f}")
            # Add class distribution information
            print("Class distribution:")
            for cls in np.unique(Y_train_np):
                print(f"Class {cls}: {np.sum(Y_train_np == cls)} samples")
            print()

    return pipelines, test_scores

In [16]:
X_train, X_test, Ys_dict = make_train_test_matrices()

In [17]:
pipeline, test_score = fit_and_evaluate_classification(X_train, Ys_dict["on_lr_path"]["Y_train"], X_test, Ys_dict["on_lr_path"]["Y_test"])

Layer 0:
Best C: 0.046
Train Balanced Accuracy: 0.960
Test Balanced Accuracy: 0.681
Train Weighted F1: 0.947
Test Weighted F1: 0.932
Class distribution:
Class False: 18000 samples
Class True: 300 samples

Layer 1:
Best C: 0.006
Train Balanced Accuracy: 0.973
Test Balanced Accuracy: 0.879
Train Weighted F1: 0.964
Test Weighted F1: 0.956
Class distribution:
Class False: 18000 samples
Class True: 300 samples



In [65]:
pipeline, test_score = fit_and_evaluate_classification(X_train, Ys_dict["on_path"]["Y_train"], X_test, Ys_dict["on_path"]["Y_test"])

Layer 0:
Best C: 0.359
Train Balanced Accuracy: 0.978
Test Balanced Accuracy: 0.957
Train Weighted F1: 0.980
Test Weighted F1: 0.961
Class distribution:
Class False: 18555 samples
Class True: 8245 samples

Layer 1:
Best C: 2.783
Train Balanced Accuracy: 0.986
Test Balanced Accuracy: 0.956
Train Weighted F1: 0.986
Test Weighted F1: 0.962
Class distribution:
Class False: 18555 samples
Class True: 8245 samples



In [35]:
pipeline, test_score = fit_and_evaluate_classification(X_train, Ys_dict["opt_action"]["Y_train"], X_test, Ys_dict["opt_action"]["Y_test"])

Layer 0:
Best C: 21.544
Train Balanced Accuracy: 0.957
Test Balanced Accuracy: 0.638
Train Weighted F1: 0.905
Test Weighted F1: 0.839
Class distribution:
Class 0: 25072 samples
Class 1: 682 samples
Class 2: 803 samples
Class 3: 243 samples

Layer 1:
Best C: 2.783
Train Balanced Accuracy: 0.988
Test Balanced Accuracy: 0.682
Train Weighted F1: 0.960
Test Weighted F1: 0.901
Class distribution:
Class 0: 25072 samples
Class 1: 682 samples
Class 2: 803 samples
Class 3: 243 samples



In [31]:
pipeline, test_score = fit_and_evaluate_classification(X_train, Ys_dict["is_goal"]["Y_train"], X_test, Ys_dict["is_goal"]["Y_test"])

Layer 0:
Best C: 0.046
Train Balanced Accuracy: 1.000
Test Balanced Accuracy: 1.000
Train Weighted F1: 1.000
Test Weighted F1: 1.000
Class distribution:
Class False: 26557 samples
Class True: 243 samples

Layer 1:
Best C: 0.046
Train Balanced Accuracy: 1.000
Test Balanced Accuracy: 1.000
Train Weighted F1: 1.000
Test Weighted F1: 1.000
Class distribution:
Class False: 26557 samples
Class True: 243 samples



In [16]:
pipeline, test_score = fit_and_evaluate_classification(X_train, Ys_dict["same_half_as_goal"]["Y_train"], X_test, Ys_dict["same_half_as_goal"]["Y_test"])

Layer 0:
Best C: 10000.000
Train Balanced Accuracy: 0.744
Test Balanced Accuracy: 0.556
Train Weighted F1: 0.744
Test Weighted F1: 0.591
Class distribution:
Class False: 13375 samples
Class True: 13425 samples

Layer 1:
Best C: 21.544
Train Balanced Accuracy: 0.809
Test Balanced Accuracy: 0.549
Train Weighted F1: 0.809
Test Weighted F1: 0.596
Class distribution:
Class False: 13375 samples
Class True: 13425 samples



In [13]:
pipeline, test_score = fit_and_evaluate_classification(
    X_train,
    [y <= 3 for y in Ys_dict["dist_from_goal"]["Y_train"]],
    X_test,
    [y <= 3 for y in Ys_dict["dist_from_goal"]["Y_test"]]
    )

Layer 0:
Best C: 2.783
Train Balanced Accuracy: 0.950
Test Balanced Accuracy: 0.737
Train Weighted F1: 0.943
Test Weighted F1: 0.857
Class distribution:
Class False: 17212 samples
Class True: 1088 samples

Layer 1:
Best C: 2.783
Train Balanced Accuracy: 0.985
Test Balanced Accuracy: 0.858
Train Weighted F1: 0.978
Test Weighted F1: 0.922
Class distribution:
Class False: 17212 samples
Class True: 1088 samples



In [18]:
pipeline, test_score = fit_and_evaluate_classification(
    X_train,
    [y == 6 for y in Ys_dict["layer"]["Y_train"]],
    X_test,
    [y == 6 for y in Ys_dict["layer"]["Y_test"]])

Layer 0:
Best C: 1291.550
Train Balanced Accuracy: 0.913
Test Balanced Accuracy: 0.865
Train Weighted F1: 0.912
Test Weighted F1: 0.865
Class distribution:
Class False: 14821 samples
Class True: 11979 samples

Layer 1:
Best C: 2.783
Train Balanced Accuracy: 0.983
Test Balanced Accuracy: 0.958
Train Weighted F1: 0.983
Test Weighted F1: 0.960
Class distribution:
Class False: 14821 samples
Class True: 11979 samples



In [26]:
pipeline, test_score = fit_and_evaluate_regression(X_train, Ys_dict["layer"]["Y_train"], X_test, Ys_dict["layer"]["Y_test"])

Layer 0:
Best alpha: 21.544
Train R2: 0.645
Test R2: 0.549

Layer 1:
Best alpha: 21.544
Train R2: 0.845
Test R2: 0.743

Layer 2:
Best alpha: 21.544
Train R2: 0.866
Test R2: 0.763



In [None]:
pipeline, test_score = fit_and_evaluate_classification(X_train, Ys_dict["node_identity"]["Y_train"], X_test, Ys_dict["node_identity"]["Y_test"])

In [15]:
pipeline, test_score = fit_and_evaluate_classification(X_train, Ys_dict["maze_half"]["Y_train"], X_test, Ys_dict["maze_half"]["Y_test"])

Layer 0:
Best C: 0.359
Train Balanced Accuracy: 0.816
Test Balanced Accuracy: 0.399
Train Weighted F1: 0.730
Test Weighted F1: 0.453
Class distribution:
Class 0: 227 samples
Class 1: 14722 samples
Class 2: 11851 samples

Layer 1:
Best C: 0.046
Train Balanced Accuracy: 0.858
Test Balanced Accuracy: 0.463
Train Weighted F1: 0.790
Test Weighted F1: 0.461
Class distribution:
Class 0: 227 samples
Class 1: 14722 samples
Class 2: 11851 samples



In [None]:
pipeline, test_score = fit_and_evaluate_classification(X_train, Ys_dict["maze_quarter"]["Y_train"], X_test, Ys_dict["maze_quarter"]["Y_test"])

In [None]:
pipeline, test_score = fit_and_evaluate_classification(X_train, Ys_dict["maze_eighth"]["Y_train"], X_test, Ys_dict["maze_eighth"]["Y_test"])

In [None]:
pipeline, test_score = fit_and_evaluate_classification(X_train, Ys_dict["action"]["Y_train"], X_test, Ys_dict["action"]["Y_test"])