In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import os
import pandas as pd
import configs
import torch
from sklearn.decomposition import PCA
from scipy.spatial.distance import pdist, squareform

from src.utils import find_ckpt_file, convert_to_tensor
import h5py
import random
from src.evals.eval_trees import EvalTrees


# Load Model

In [9]:
engram_dir = "/n/holylfs06/LABS/krajan_lab/Lab/cfang/icl-maze/"
wandb_project = "darkroom_simple"
env_name = "darkroom_dim5_corr0.0_state_dim5_envs500000_H200_explore"
dataset_storage_dir = f'{engram_dir}/{wandb_project}/{env_name}/datasets'
use_h5 = os.path.exists(os.path.join(dataset_storage_dir, 'test.h5'))
file_suffix = '.h5' if use_h5 else '.pkl'
test_dset_path = os.path.join(dataset_storage_dir, 'test' + file_suffix)

In [10]:
print(dataset_storage_dir)

/n/holylfs06/LABS/krajan_lab/Lab/cfang/icl-maze//darkroom_simple/darkroom_fixedxy_dim5_corr0.0_state_dim2_envs500000_H200_explore/datasets


# Load Dataset and Create Environment

In [11]:
is_h5_file = test_dset_path.endswith('.h5')
if is_h5_file:
    eval_trajs = h5py.File(test_dset_path, 'r')
    traj_indices = list(eval_trajs.keys())
else:  # Pickle file
    with open(test_dset_path, 'rb') as f:
        eval_trajs = pickle.load(f)
n_eval_envs = len(eval_trajs)

# Get max accuracy

In [12]:
test_dset_path

'/n/holylfs06/LABS/krajan_lab/Lab/cfang/icl-maze//darkroom_simple/darkroom_fixedxy_dim5_corr0.0_state_dim2_envs500000_H200_explore/datasets/test.pkl'

In [17]:
query_every = 10
max_acc = []
for eval_traj in eval_trajs:
    seq_length = eval_traj['context_states'].shape[0]
    eval_idxs = np.concatenate([np.arange(0, seq_length, query_every), [seq_length-1]])
    opt_accuracy = np.ones(len(eval_idxs))*0.25
    first_reward = np.argwhere(eval_traj['context_rewards']>0)
    if first_reward.size > 0:
        first_reward = first_reward[0, 0]
        opt_accuracy[eval_idxs>first_reward] = 1
    max_acc.append(opt_accuracy.mean())

In [18]:
np.mean(max_acc)

np.float64(0.733836607142857)

In [15]:
eval_traj = eval_trajs[0]

In [16]:
states = eval_traj['context_states']
actions = eval_traj['context_actions']
next_states = eval_traj['context_next_states']
rewards = eval_traj['context_rewards']

for i in range(len(states)):
    print(states[i], actions[i].argmax(), next_states[i], rewards[i])


[2 4] 2 [3 4] 0
[3 4] 3 [3 3] 0
[3 3] 4 [3 3] 0
[3 3] 3 [3 2] 0
[3 2] 1 [3 3] 0
[3 3] 3 [3 2] 0
[3 2] 4 [3 2] 0
[3 2] 2 [4 2] 0
[4 2] 3 [4 1] 0
[4 1] 1 [4 2] 0
[4 2] 1 [4 3] 0
[4 3] 0 [3 3] 0
[3 3] 0 [2 3] 0
[2 3] 0 [1 3] 0
[1 3] 0 [0 3] 0
[0 3] 2 [1 3] 0
[1 3] 1 [1 4] 0
[1 4] 3 [1 3] 0
[1 3] 4 [1 3] 0
[1 3] 1 [1 4] 0
[1 4] 0 [0 4] 0
[0 4] 4 [0 4] 0
[0 4] 0 [0 4] 0
[0 4] 1 [0 4] 0
[0 4] 3 [0 3] 0
[0 3] 2 [1 3] 0
[1 3] 0 [0 3] 0
[0 3] 3 [0 2] 0
[0 2] 0 [0 2] 0
[0 2] 2 [1 2] 0
[1 2] 2 [2 2] 0
[2 2] 1 [2 3] 0
[2 3] 3 [2 2] 0
[2 2] 3 [2 1] 0
[2 1] 3 [2 0] 0
[2 0] 0 [1 0] 0
[1 0] 1 [1 1] 0
[1 1] 1 [1 2] 0
[1 2] 1 [1 3] 0
[1 3] 0 [0 3] 0
[0 3] 1 [0 4] 0
[0 4] 2 [1 4] 0
[1 4] 0 [0 4] 0
[0 4] 3 [0 3] 0
[0 3] 1 [0 4] 0
[0 4] 0 [0 4] 0
[0 4] 2 [1 4] 0
[1 4] 3 [1 3] 0
[1 3] 2 [2 3] 0
[2 3] 3 [2 2] 0
[2 2] 4 [2 2] 0
[2 2] 2 [3 2] 0
[3 2] 4 [3 2] 0
[3 2] 0 [2 2] 0
[2 2] 1 [2 3] 0
[2 3] 4 [2 3] 0
[2 3] 1 [2 4] 0
[2 4] 2 [3 4] 0
[3 4] 1 [3 4] 0
[3 4] 3 [3 3] 0
[3 3] 3 [3 2] 0
[3 2] 2 [4 2] 0
[4 2] 3 

In [11]:
first_rewards = []
for eval_traj in eval_trajs:
    first_reward = np.argwhere(eval_traj['context_rewards']>0)
    if first_reward.size > 0:
        first_rewards.append(first_reward[0, 0])
    else:
        first_rewards.append(np.nan)


In [12]:
first_rewards

[np.int64(0),
 np.int64(0),
 np.int64(77),
 np.int64(60),
 np.int64(43),
 np.int64(105),
 np.int64(1),
 np.int64(64),
 np.int64(145),
 np.int64(50),
 np.int64(71),
 np.int64(132),
 np.int64(34),
 np.int64(37),
 np.int64(106),
 np.int64(93),
 np.int64(36),
 np.int64(105),
 np.int64(67),
 np.int64(34),
 np.int64(5),
 np.int64(25),
 np.int64(181),
 np.int64(17),
 np.int64(86),
 np.int64(142),
 nan,
 np.int64(70),
 np.int64(106),
 np.int64(0),
 np.int64(9),
 np.int64(66),
 np.int64(21),
 np.int64(159),
 np.int64(4),
 np.int64(1),
 np.int64(157),
 np.int64(11),
 np.int64(43),
 np.int64(17),
 np.int64(60),
 np.int64(3),
 np.int64(97),
 np.int64(14),
 np.int64(3),
 np.int64(17),
 np.int64(52),
 np.int64(28),
 np.int64(61),
 np.int64(8),
 nan,
 np.int64(60),
 np.int64(69),
 np.int64(65),
 np.int64(65),
 np.int64(19),
 nan,
 np.int64(46),
 np.int64(0),
 np.int64(64),
 np.int64(199),
 np.int64(7),
 np.int64(159),
 np.int64(2),
 np.int64(22),
 np.int64(8),
 np.int64(117),
 np.int64(33),
 np.int64