In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import os
import pandas as pd
import configs
import torch
from sklearn.decomposition import PCA
from scipy.spatial.distance import pdist, squareform

from src.utils import find_ckpt_file, convert_to_tensor
import h5py
import random
from src.evals.eval_trees import EvalTrees
from src.evals.eval_trees import EvalCntrees

# Load Model

In [2]:
engram_dir = "/n/holylfs06/LABS/krajan_lab/Lab/cfang/icl-maze/"
wandb_project = "cntree"
corr = 0.0
env_name = f"cntree_layers7_bprob0.9_corr{corr}_state_dim10_envs300000_H800_explore"
#model_name = "transformer_end_query_embd512_layer4_head4_lr1e-05_drop0_batch256"
#model_name = "transformer_end_query_embd800_layer3_head4_lr0.0001_drop0_batch256"
#model_name = "transformer_end_query_embd512_layer4_head4_lr1e-05_drop0_batch256"
model_name = "transformer_end_query_embd512_layer3_head4_lr0.0001_drop0_batch256"
#model_name = "transformer_end_query_embd512_layer3_head4_lr0.0001_drop0.2_batch256"
model_path = os.path.join(engram_dir, wandb_project, env_name, "models", model_name)
ckpt_name = find_ckpt_file(model_path, "best")
print(ckpt_name)
path_to_pkl = os.path.join(model_path, ckpt_name)

eval_dset_path = f"/n/holylfs06/LABS/krajan_lab/Lab/cfang/icl-maze/cntree/cntree_layers7_bprob1.0_corr{corr}_state_dim10_envs1000_H1600_explore/datasets/eval.pkl"

epoch=30-val_loss=0.000777.ckpt


In [2]:
engram_dir = "/n/holylfs06/LABS/krajan_lab/Lab/cfang/icl-maze/"
wandb_project = "tree_maze" #"cntree"
corr = 0.0
env_name = f"cntree_layers7_bprob0.9_corr{corr}_state_dim10_envs300000_H800_explore"
#model_name = "transformer_end_query_embd512_layer4_head4_lr1e-05_drop0_batch256"
#model_name = "transformer_end_query_embd800_layer3_head4_lr0.0001_drop0_batch256"
#model_name = "transformer_end_query_embd512_layer4_head4_lr1e-05_drop0_batch256"
model_name = "transformer_end_query_embd512_layer3_head4_lr0.0001_drop0_batch256"
#model_name = "transformer_end_query_embd512_layer3_head4_lr0.0001_drop0.2_batch256"
model_name = "transformer_end_query_embd512_layer3_head4_lr1e-05_drop0_initseed1_batch512_nosched"
model_path = os.path.join(engram_dir, wandb_project, env_name, "models", model_name)
ckpt_name = find_ckpt_file(model_path, "best")
print(ckpt_name)
path_to_pkl = os.path.join(model_path, ckpt_name)

eval_dset_path = f"/n/holylfs06/LABS/krajan_lab/Lab/cfang/icl-maze/cntree/cntree_layers7_bprob1.0_corr{corr}_state_dim10_envs1000_H1600_explore/datasets/eval.pkl"

epoch=36-val_loss=0.000377.ckpt


In [3]:
# Extract parameters using regex
import re

n_embd = int(re.search(r'embd(\d+)', model_name).group(1))
n_layer = int(re.search(r'layer(\d+)', model_name).group(1))
n_head = int(re.search(r'head(\d+)', model_name).group(1))
dropout = float(re.search(r'drop(\d*\.?\d*)', model_name).group(1))

# Extract correlation and state_dim from eval dataset path
state_dim = int(re.search(r'state_dim(\d+)', eval_dset_path).group(1))

model_config = {
    "n_embd": n_embd,
    "n_layer": n_layer,
    "n_head": n_head,
    "state_dim": 10,
    "action_dim": 4,
    "dropout": dropout,
    "train_on_last_pred_only": False,
    "test": True,
    "name": "transformer_end_query",
    "optimizer_config": None,
    "linear_attention": False,
}

In [4]:
from src.models.transformer_end_query import Transformer
model_config['initialization_seed'] = 0
model = Transformer(**model_config)
checkpoint = torch.load(path_to_pkl)
model.load_state_dict(checkpoint['state_dict'])
model.eval()
model = model.to('cuda')

  checkpoint = torch.load(path_to_pkl)


# Load Dataset and Create Environment

In [5]:
n_eval_envs = 50

is_h5_file = eval_dset_path.endswith('.h5')
if is_h5_file:
    eval_trajs = h5py.File(eval_dset_path, 'r')
    traj_indices = list(eval_trajs.keys())
    n_eval_envs = min(n_eval_envs, len(traj_indices))
    random.seed(0)
    traj_indices = random.sample(traj_indices, n_eval_envs)
    random.seed()
    eval_trajs = [eval_trajs[i] for i in traj_indices]
else:  # Pickle file
    with open(eval_dset_path, 'rb') as f:
        eval_trajs = pickle.load(f)
    n_eval_envs = min(n_eval_envs, len(eval_trajs))
    random.seed(0)
    eval_trajs = random.sample(eval_trajs, n_eval_envs)
    random.seed()


In [6]:
def run_model(traj, model, state_features, seq_length, zero_reward=False):
    hidden_states = [[] for _ in range(model.n_layer)]
    
    for state_feature in state_features: 
        batch = {
            'context_states': convert_to_tensor([np.array(traj['context_states'])]),
            'context_actions': convert_to_tensor([np.array(traj['context_actions'])]),
            'context_next_states': convert_to_tensor([np.array(traj['context_next_states'])]),
            'context_rewards': convert_to_tensor([np.array(traj['context_rewards'])[:, None]]),
            'query_states': convert_to_tensor([np.array(state_feature)]),
            }
        batch['zeros'] = torch.zeros(1, 10 ** 2 + 4 + 1).float()
        for k in batch.keys():
            if 'context' in k:
                batch[k] = batch[k][:,:seq_length]
            batch[k] = batch[k].to(model.device)
        model.save_activations = True
        if zero_reward:
            batch['context_rewards'] *= 0
        with torch.no_grad():
            out = model(batch)
        _hidden_states = model.activations['hidden_states'][1:] # Tuple over layers
        for i_layer in range(model.n_layer):
            hidden_states[i_layer].append(_hidden_states[i_layer])
    return hidden_states

In [7]:
train_envs = np.arange(1, int(n_eval_envs*0.9))
test_envs = np.arange(int(n_eval_envs*0.9), n_eval_envs)
X_train = [[] for _ in range(model.n_layer)]
Y_train = []
X_test = [[] for _ in range(model.n_layer)]
Y_test = []
for i_eval in range(n_eval_envs):
    traj = eval_trajs[i_eval]

    first_reward = np.argwhere(np.array(traj['context_rewards'])>0)
    if (first_reward.size == 0) or (first_reward[0] > 1000):
        continue

    print(i_eval)

    env_config = {
        'max_layers': 7,
        'horizon': 1600,
        'branching_prob': 1.0,
        'node_encoding_corr': corr,
        'state_dim': state_dim,
        'initialization_seed': np.array(traj['initialization_seed']).item()
    }
    env = EvalCntrees().create_env(env_config, np.array(traj['goal']), i_eval)
    state_features = []
    state_features = list(env.node_map.keys())
    _, dist_from_goal = env.make_opt_action_dict()

    seq_length = 1000
    hidden_states = run_model(traj, model, state_features, seq_length)
    for state_idx, state_feature in enumerate(state_features):
        d = dist_from_goal[state_feature]
        gamma = 0.7
        val = gamma**d
        if i_eval in train_envs:
            Y_train.append(val)
        else:
            Y_test.append(val)
        for layer in range(model.n_layer):
            hidden_state = hidden_states[layer][state_idx][0, -1]
            if i_eval in train_envs:
                X_train[layer].append(hidden_state)
            else:
                X_test[layer].append(hidden_state)




0
1
3
6
7
9
13
14
17
18
19
22
23
24
28
31
33
35
38
39
42
44
47
49


In [8]:
from sklearn.linear_model import LinearRegression
import numpy as np

# Convert lists to numpy arrays
X_train_np = [np.array([_x.cpu().numpy() for _x in x]) for x in X_train]
X_test_np = [np.array([_x.cpu().numpy() for _x in x]) for x in X_test]
Y_train_np = np.array(Y_train)
Y_test_np = np.array(Y_test)

# Fit and evaluate regression for each layer
for layer in range(len(X_train)):
    reg = LinearRegression()
    reg.fit(X_train_np[layer], Y_train_np)
    
    train_score = reg.score(X_train_np[layer], Y_train_np)
    test_score = reg.score(X_test_np[layer], Y_test_np)
    
    print(f"Layer {layer}:")
    print(f"Train R2: {train_score:.3f}")
    print(f"Test R2: {test_score:.3f}")
    print()


Layer 0:
Train R2: 0.525
Test R2: -0.650

Layer 1:
Train R2: 0.702
Test R2: 0.322

Layer 2:
Train R2: 0.737
Test R2: 0.511



In [9]:
train_envs = np.arange(1, int(n_eval_envs*0.9))
test_envs = np.arange(int(n_eval_envs*0.9), n_eval_envs)
X_train = [[] for _ in range(model.n_layer)]
Y_train = []
X_test = [[] for _ in range(model.n_layer)]
Y_test = []
for i_eval in range(n_eval_envs):
    traj = eval_trajs[i_eval]

    first_reward = np.argwhere(np.array(traj['context_rewards'])>0)
    if (first_reward.size == 0) or (first_reward[0] > 1000):
        continue

    print(i_eval)

    env_config = {
        'max_layers': 7,
        'horizon': 1600,
        'branching_prob': 1.0,
        'node_encoding_corr': corr,
        'state_dim': state_dim,
        'initialization_seed': np.array(traj['initialization_seed']).item()
    }
    env = EvalCntrees().create_env(env_config, np.array(traj['goal']), i_eval)
    state_features = []
    state_features = list(env.node_map.keys())
    _, dist_from_goal = env.make_opt_action_dict()

    seq_length = 1000
    hidden_states = run_model(traj, model, state_features, seq_length, zero_reward=True)
    for state_idx, state_feature in enumerate(state_features):
        d = dist_from_goal[state_feature]
        gamma = 0.8
        val = gamma**d
        if i_eval in train_envs:
            Y_train.append(val)
        else:
            Y_test.append(val)
        for layer in range(model.n_layer):
            hidden_state = hidden_states[layer][state_idx][0, -1]
            if i_eval in train_envs:
                X_train[layer].append(hidden_state)
            else:
                X_test[layer].append(hidden_state)

X_train_np = [np.array([_x.cpu().numpy() for _x in x]) for x in X_train]
X_test_np = [np.array([_x.cpu().numpy() for _x in x]) for x in X_test]
Y_train_np = np.array(Y_train)
Y_test_np = np.array(Y_test)

# Fit and evaluate regression for each layer
for layer in range(len(X_train)):
    reg = LinearRegression()
    reg.fit(X_train_np[layer], Y_train_np)
    
    train_score = reg.score(X_train_np[layer], Y_train_np)
    test_score = reg.score(X_test_np[layer], Y_test_np)
    
    print(f"Layer {layer}:")
    print(f"Train R2: {train_score:.3f}")
    print(f"Test R2: {test_score:.3f}")
    print()


0


1
3
6
7
9
13
14
17
18
19
22
23
24
28
31
33
35
38
39
42
44
47
49
Layer 0:
Train R2: 0.208
Test R2: -0.170

Layer 1:
Train R2: 0.258
Test R2: -0.148

Layer 2:
Train R2: 0.276
Test R2: -0.222

