In [33]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import os
import pandas as pd
import configs
import torch
from sklearn.decomposition import PCA
from scipy.spatial.distance import pdist, squareform

from src.utils import find_ckpt_file, convert_to_tensor
import h5py
import random
from src.evals.eval_trees import EvalTrees


# Load Model

In [35]:
engram_dir = "/n/holylfs06/LABS/krajan_lab/Lab/cfang/icl-maze/"
wandb_project = "random_tree"
env_name = "tree_layers7_bprob0.9_envs300000_H800_explore"
model_name = "transformer_end_query_embd512_layer4_head4_lr0.0001_drop0_batch256"
model_path = os.path.join(engram_dir, wandb_project, env_name, "models", model_name)
ckpt_name = find_ckpt_file(model_path, "best")
path_to_pkl = os.path.join(model_path, ckpt_name)

In [36]:
# Extract parameters using regex
import re
n_embd = int(re.search(r'embd(\d+)', model_name).group(1))
n_layer = int(re.search(r'layer(\d+)', model_name).group(1))
n_head = int(re.search(r'head(\d+)', model_name).group(1))
dropout = float(re.search(r'drop(\d*\.?\d*)', model_name).group(1))


model_config = {
    "n_embd": n_embd,
    "n_layer": n_layer,
    "n_head": n_head,
    "state_dim": 10,
    "action_dim": 4,
    "dropout": dropout,
    "train_on_last_pred_only": False,
    "test": True,
    "name": "transformer_end_query",
    "optimizer_config": None,
    "linear_attention": False,
}

In [37]:
from src.models.transformer_end_query import Transformer
model = Transformer(**model_config)
checkpoint = torch.load(path_to_pkl)
model.load_state_dict(checkpoint['state_dict'])
model.eval()
model = model.to('cuda')


  checkpoint = torch.load(path_to_pkl)


# Load Dataset and Create Environment

In [52]:
#eval_dataset_path = "/n/holylfs06/LABS/krajan_lab/Lab/cfang/icl-maze/lazyload/tree_layers7_bprob1.0_envs600000_H1600_explore/datasets/eval.h5"
#eval_dset_path = "/n/holylfs06/LABS/krajan_lab/Lab/cfang/icl-maze/random_tree/tree_layers7_bprob0.9_envs300000_H800_explore/datasets/eval.pkl"
eval_dset_path = "/n/holylfs06/LABS/krajan_lab/Lab/cfang/icl-maze/datasets/tree_layers7_bprob0.9_envs300000_H800_explore/eval.pkl"
n_eval_envs = 2000

is_h5_file = eval_dset_path.endswith('.h5')
if is_h5_file:
    eval_trajs = h5py.File(eval_dset_path, 'r')
    traj_indices = list(eval_trajs.keys())
    n_eval_envs = min(n_eval_envs, len(traj_indices))
    random.seed(0)
    traj_indices = random.sample(traj_indices, n_eval_envs)
    random.seed()
    eval_trajs = [eval_trajs[i] for i in traj_indices]
else:  # Pickle file
    with open(eval_dset_path, 'rb') as f:
        eval_trajs = pickle.load(f)
    n_eval_envs = min(n_eval_envs, len(eval_trajs))
    random.seed(0)
    eval_trajs = random.sample(eval_trajs, n_eval_envs)
    random.seed()


# Run Model

In [54]:
matches = []
optimal_actions = []
rs = []
for i_eval in range(n_eval_envs):
    traj = eval_trajs[i_eval]
    env_config = {
        'max_layers': 7,
        'horizon': 800,
        'branching_prob': 0.9,
        'node_encoding': 'random',
        'initialization_seed': np.array(traj['initialization_seed']).item()
    }
    #env = EvalTrees().create_env(env_config, np.array(traj['goal']), i_eval)

    batch = {
        'context_states': convert_to_tensor([np.array(traj['context_states'])]),
        'context_actions': convert_to_tensor([np.array(traj['context_actions'])]),
        'context_next_states': convert_to_tensor([np.array(traj['context_next_states'])]),
        'context_rewards': convert_to_tensor([np.array(traj['context_rewards'])[:, None]]),
        'query_states': convert_to_tensor([np.array(traj['query_state'])]),
        }
    rs.append(batch['context_rewards'].sum().item())

    #assert env.root.encoding_vector == tuple(traj['context_states'][0])
    continue


    batch['zeros'] = torch.zeros(1, 10 ** 2 + 4 + 1).float()
    for k in batch.keys():
        if 'context' in k:
            batch[k] = batch[k]
        batch[k] = batch[k].to(model.device)
    with torch.no_grad():
        out = model(batch)
    print(out)

    # Get predicted and optimal actions
    pred_action = torch.argmax(out.squeeze()).item()
    optimal_action = np.argmax(traj['optimal_action'])

    print(f"\nPredicted action: {pred_action}")
    print(f"Optimal action: {optimal_action}")
    print(f"Match: {pred_action == optimal_action}")
    matches.append(pred_action == optimal_action)
    optimal_actions.append(optimal_action)


In [55]:
(np.array(rs)==0).sum()/len(rs)

np.float64(0.239)

In [47]:
(np.array(optimal_actions)==0).sum()/len(optimal_actions)

  (np.array(optimal_actions)==0).sum()/len(optimal_actions)


np.float64(nan)

In [31]:
np.array(matches).sum()/len(matches)

np.float64(0.896)

In [32]:
ckpt_name

'epoch=26-val_loss=0.000552.ckpt'