In [2]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import os
import pandas as pd
import configs
import torch
from sklearn.decomposition import PCA
from scipy.spatial.distance import pdist, squareform

from src.utils import find_ckpt_file, convert_to_tensor
import h5py
import random
from src.envs.darkroom import DarkroomEnv

# Load Model

In [9]:
# Root of the shared storage and the W&B project directory that holds the runs.
engram_dir = "/n/holylfs06/LABS/krajan_lab/Lab/cfang/icl-maze/"
wandb_project = "darkroom_simple"

# (env_name, model_name) pairs this notebook has been pointed at. Keeping them
# in one list replaces a stack of reassignments where only the last one took
# effect; change the index below to analyse a different run.
available_runs = [
    (
        "darkroom_dim5_corr0.0_state_dim10_envs900000_H200_explore",
        "transformer_end_query_embd256_layer4_head4_lr0.0001_drop0_initseed0_batch1024",
    ),
    (
        "darkroom_dim5_corr0.0_state_dim10_envs900000_H200_explore",
        "transformer_end_query_embd512_layer4_head4_lr0.0001_drop0_initseed0_batch1024",
    ),
    (
        "darkroom_dim5_corr0.25_state_dim10_envs900000_H200_explore",
        "transformer_end_query_embd512_layer4_head4_lr0.0001_drop0.0_initseed0_batch1024",
    ),
    (
        # NOTE(review): this pair was commented out in the original cell;
        # whether a checkpoint exists for it has not been verified.
        "darkroom_dim5_corr0.25_state_dim10_envs900000_H200_explore",
        "transformer_end_query_embd512_layer3_head4_lr0.0001_drop0.0_initseed0_batch1024",
    ),
    (
        "darkroom_dim5_corr0.25_state_dim10_envs1500000_H150_explore",
        "transformer_end_query_embd512_layer3_head4_lr0.0001_drop0.0_initseed0_batch1024",
    ),
]
env_name, model_name = available_runs[-1]

# Checkpoints live under <root>/<project>/<env>/models/<model>/.
model_path = os.path.join(engram_dir, wandb_project, env_name, "models", model_name)
# presumably picks the best-validation checkpoint in that directory — see src.utils.find_ckpt_file
ckpt_name = find_ckpt_file(model_path, "best")
print(ckpt_name)
# Despite the name, this is the path to the checkpoint file returned above.
path_to_pkl = os.path.join(model_path, ckpt_name)
# Built from engram_dir so the storage root is defined in exactly one place.
eval_dset_path = os.path.join(engram_dir, wandb_project, env_name, "datasets", "eval.pkl")

epoch=11-val_loss=0.942141.ckpt


In [10]:
# Recover hyperparameters from the run names instead of hard-coding them, so
# the config always matches the checkpoint selected above.
import re


def _parse_field(pattern, text, cast):
    """Return ``cast`` applied to the first capture group of ``pattern`` in ``text``.

    Raises:
        ValueError: if ``pattern`` does not occur in ``text``. Without this
            guard a non-matching run name would surface as an opaque
            ``AttributeError`` on ``None``.
    """
    match = re.search(pattern, text)
    if match is None:
        raise ValueError(f"Could not find pattern {pattern!r} in {text!r}")
    return cast(match.group(1))


# Architecture fields are encoded in the model directory name.
n_embd = _parse_field(r'embd(\d+)', model_name, int)
n_layer = _parse_field(r'layer(\d+)', model_name, int)
n_head = _parse_field(r'head(\d+)', model_name, int)
dropout = _parse_field(r'drop(\d*\.?\d*)', model_name, float)

# Environment fields are encoded in the dataset path (via env_name).
state_dim = _parse_field(r'state_dim(\d+)', eval_dset_path, int)
maze_dim = _parse_field(r'_dim(\d+)_corr', eval_dset_path, int)
node_encoding_corr = _parse_field(r'corr(\d*\.?\d*)', eval_dset_path, float)

# Keyword arguments for the Transformer constructor.
model_config = {
    "n_embd": n_embd,
    "n_layer": n_layer,
    "n_head": n_head,
    "state_dim": state_dim,
    "action_dim": 5,  # matches the 5-way one-hot actions built later in the notebook
    "dropout": dropout,
    "train_on_last_pred_only": False,
    "test": True,
    "name": "transformer_end_query",
    "optimizer_config": None,  # no optimizer is configured; the model is only evaluated here
    "linear_attention": False,
}

In [11]:
from src.models.transformer_end_query import Transformer

# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_config['initialization_seed'] = 0
model = Transformer(**model_config)
# map_location='cpu' lets a checkpoint saved from a GPU run be read anywhere;
# the weights are moved to `device` together with the model below.
# weights_only=False spells out the full-unpickling behaviour this call relied
# on implicitly. It executes pickled code, so only load trusted checkpoints.
checkpoint = torch.load(path_to_pkl, map_location='cpu', weights_only=False)
model.load_state_dict(checkpoint['state_dict'])
model.eval()  # inference mode
model = model.to(device)

  checkpoint = torch.load(path_to_pkl)


# Load Dataset and Create Environment

In [12]:
# Upper bound on how many evaluation trajectories to keep.
n_eval_envs = 50

# Private, fixed-seed RNG: draws the same subset as seeding the global
# `random` module with 0, without disturbing (or re-randomising) the global
# generator that other cells may rely on.
subset_rng = random.Random(0)

is_h5_file = eval_dset_path.endswith('.h5')
if is_h5_file:
    # The handle is deliberately kept open and referenced: the selected entries
    # are h5py objects backed by this file.
    eval_h5_file = h5py.File(eval_dset_path, 'r')
    traj_indices = list(eval_h5_file.keys())
    n_eval_envs = min(n_eval_envs, len(traj_indices))
    traj_indices = subset_rng.sample(traj_indices, n_eval_envs)
    eval_trajs = [eval_h5_file[i] for i in traj_indices]
else:  # Pickle file
    # NOTE: pickle.load can execute arbitrary code; only open trusted datasets.
    with open(eval_dset_path, 'rb') as f:
        all_eval_trajs = pickle.load(f)
    n_eval_envs = min(n_eval_envs, len(all_eval_trajs))
    eval_trajs = subset_rng.sample(all_eval_trajs, n_eval_envs)


In [13]:
# Pick which sampled evaluation trajectory to analyse; its stored
# 'initialization_seed' and 'goal' are used to rebuild the environment.
i_eval = 0
traj = eval_trajs[i_eval]

In [14]:
# Recreate the environment that the selected trajectory was generated in,
# reusing its stored seed and goal.
# NOTE(review): horizon is fixed at 200 here, while the selected env_name
# contains "H150" — confirm whether the horizon should follow the run name.
env_config = dict(
    maze_dim=maze_dim,
    horizon=200,
    state_dim=state_dim,
    node_encoding_corr=node_encoding_corr,
    initialization_seed=np.array(traj['initialization_seed']).item(),
    goal=np.array(traj['goal']),
)
env = DarkroomEnv(**env_config)


# Run Model

In [15]:
# Enumerate every state encoding known to the environment together with its
# grid position, keeping the two tables row-aligned.
state_features = list(env.node_map_encoding_to_pos.keys())
grid_positions = [env.node_map_encoding_to_pos[encoding] for encoding in state_features]
xs = [position[0] for position in grid_positions]
ys = [position[1] for position in grid_positions]

# Row index of the goal state; .item() raises unless exactly one encoding matches.
reward_idx = np.flatnonzero([np.all(s == env.goal) for s in state_features]).item()
xys = np.column_stack([xs, ys])
state_features = np.array(state_features)

In [16]:
# Grid (x, y) position of the goal/reward state for the selected trajectory.
xys[reward_idx]

array([3, 0])

In [50]:
# Hand-written probe paths through the maze, as [x, y] grid positions where
# consecutive entries are at most one step apart. Both base paths end at
# [3, 0], the reward position displayed above for the selected trajectory.
row3_to_goal = [[0, 3], [1, 3], [2, 3], [3, 3], [4, 3], [4, 2], [4, 1], [4, 0], [3, 0]]
row2_to_goal = [[0, 2], [1, 2], [2, 2], [3, 2], [4, 2], [4, 1], [4, 0], [3, 0]]


def _round_trip(path):
    """Return `path` followed by the same positions in reverse order.

    The turning point appears twice, which produces one "stay" transition.
    """
    return path + [list(xy) for xy in reversed(path)]


# All paths that have been tried in this notebook, named so that switching
# between them is a one-word change instead of editing overwritten literals.
candidate_paths = {
    'row3_forward': row3_to_goal,
    'row2_forward': row2_to_goal,
    'row2_round_trip': _round_trip(row2_to_goal),
    'row3_round_trip': _round_trip(row3_to_goal),
}

test_xys = candidate_paths['row2_round_trip']

In [51]:
def get_onestep_action(xy1, xy2):
    """Return the action index that moves from grid position ``xy1`` to ``xy2``.

    Action indices:
        0: x decreases by one
        1: y increases by one
        2: x increases by one
        3: y decreases by one
        4: stay in place

    Args:
        xy1: Start position as an ``(x, y)`` pair (list, tuple or array).
        xy2: End position as an ``(x, y)`` pair (list, tuple or array).

    Returns:
        int: The action index described above.

    Raises:
        AssertionError: If the positions are more than one step apart.
        ValueError: If no single action connects the two positions.
    """
    x1, y1 = xy1
    x2, y2 = xy2
    dx, dy = x2 - x1, y2 - y1
    assert abs(dx) + abs(dy) <= 1
    # Compare coordinates rather than the containers themselves: `xy1 == xy2`
    # is False for a tuple versus an equal list and is ambiguous for arrays.
    if dx == 0 and dy == 0:
        return 4
    if dy == 0:
        return 2 if dx > 0 else 0
    if dx == 0:
        return 1 if dy > 0 else 3
    raise ValueError(f"Invalid action: {xy1} -> {xy2}")

In [52]:
# Turn the hand-written path into (state, action, next_state, reward)
# transitions expressed in the environment's state encodings.
context_states = []
context_actions = []
context_next_states = []
context_rewards = []
query_state = []  # placeholder; assigned in the rollout cell
reward_xy = xys[reward_idx]
for xy, next_xy in zip(test_xys[:-1], test_xys[1:]):
    # Boolean row masks that select the encoding for each grid position.
    state_feature_idx = np.all(xys == xy, axis=-1)
    next_state_feature_idx = np.all(xys == next_xy, axis=-1)
    action = np.zeros(5)
    action[get_onestep_action(xy, next_xy)] = 1
    # A transition is rewarded when it lands on (or stays at) the goal position.
    reward = 1 if np.all(next_xy == reward_xy) else 0
    context_states.append(state_features[state_feature_idx])
    context_actions.append(action)
    context_next_states.append(state_features[next_state_feature_idx])
    context_rewards.append(reward)
# Explicit reshape instead of squeeze(): squeeze() would also drop the time
# axis when the path contains a single transition.
encoding_dim = state_features.shape[-1]
context_states = np.array(context_states).reshape(-1, encoding_dim)
context_actions = np.array(context_actions).reshape(-1, 5)
context_next_states = np.array(context_next_states).reshape(-1, encoding_dim)
context_rewards = np.array(context_rewards)


In [53]:
# Index of the first rewarded transition in the recorded trajectory.
# flatnonzero always returns a 1-D index array, so this also works when the
# trajectory has exactly one reward (where argwhere(...).squeeze() becomes
# 0-d and cannot be indexed). np.asarray makes the comparison valid for
# non-ndarray containers as well. Raises IndexError if no reward was received.
reward_steps = np.flatnonzero(np.asarray(traj['context_rewards']) > 0)
first_reward = reward_steps[0]

In [54]:
# True: feed the model the hand-written context built above.
# False: feed it a window of the recorded trajectory around its first reward.
make_substitution = True

if make_substitution:
    batch = {
        'context_states': convert_to_tensor([context_states]),
        'context_actions': convert_to_tensor([context_actions]),
        'context_next_states': convert_to_tensor([context_next_states]),
        'context_rewards': convert_to_tensor([context_rewards[:, None]]),
        }
else:
    batch = {
        'context_states': convert_to_tensor([np.array(traj['context_states'])]),
        'context_actions': convert_to_tensor([np.array(traj['context_actions'])]),
        'context_next_states': convert_to_tensor([np.array(traj['context_next_states'])]),
        'context_rewards': convert_to_tensor([np.array(traj['context_rewards'])[:, None]]),
        'query_states': convert_to_tensor([np.array(traj['query_state'])]),
        }
    # Keep up to 5 steps before and 5 steps from the first reward onwards.
    # The start is clamped at 0: a negative start would index from the end of
    # the sequence and silently select the wrong (usually empty) window.
    window_start = max(0, first_reward - 5)
    for k in batch.keys():
        if 'context' in k:
            batch[k] = batch[k][:, window_start:first_reward+5]
# NOTE(review): the width state_dim ** 2 + 5 + 1 is taken as-is from the
# original cell; confirm it against what the model expects for 'zeros'.
batch['zeros'] = torch.zeros(1, state_dim ** 2 + 5 + 1).float()
# Move every entry to the device the model lives on.
for k in batch.keys():
    batch[k] = batch[k].to(model.device)

In [55]:
# Sanity check: decode each context transition back to grid positions and
# print it next to the index of its one-hot action.
for i in range(batch['context_states'].shape[1]):
    start_encoding = batch['context_states'][:, i].cpu().numpy().squeeze()
    next_encoding = batch['context_next_states'][:, i].cpu().numpy().squeeze()
    action = np.argmax(batch['context_actions'][:, i].cpu().numpy())
    # Named *_pos rather than `next`, which would shadow the builtin next()
    # for every later cell in the notebook.
    start_pos = env.node_map_encoding_to_pos[tuple(start_encoding.tolist())]
    next_pos = env.node_map_encoding_to_pos[tuple(next_encoding.tolist())]
    print(f'{start_pos} to {next_pos} with action {action}')
    

(0, 2) to (1, 2) with action 2
(1, 2) to (2, 2) with action 2
(2, 2) to (3, 2) with action 2
(3, 2) to (4, 2) with action 2
(4, 2) to (4, 1) with action 3
(4, 1) to (4, 0) with action 3
(4, 0) to (3, 0) with action 0
(3, 0) to (3, 0) with action 4
(3, 0) to (4, 0) with action 2
(4, 0) to (4, 1) with action 1
(4, 1) to (4, 2) with action 1
(4, 2) to (3, 2) with action 0
(3, 2) to (2, 2) with action 0
(2, 2) to (1, 2) with action 0
(1, 2) to (0, 2) with action 0


In [56]:
# Greedy rollout: repeatedly query the model from the current state, apply its
# argmax action in the environment, and append the resulting transition to the
# context used for the next query.
# NOTE(review): this cell rebinds `batch` and the `context_*` arrays, so
# re-running it without first re-running the batch-construction cell starts
# from the already-extended context.
n_steps = 10
curr_xy_state = test_xys[0]
query_state = env.node_map_pos_to_encoding[tuple(curr_xy_state)]
xy_path = [curr_xy_state]
chosen_actions = []

# Pull the starting context back to numpy so new transitions can be stacked on.
context_states, context_actions, context_next_states, context_rewards = (
    batch[key].cpu().numpy().squeeze()
    for key in ('context_states', 'context_actions', 'context_next_states', 'context_rewards')
)

for step in range(n_steps):
    model_inputs = {
        'context_states': context_states,
        'context_actions': context_actions,
        'context_next_states': context_next_states,
        'context_rewards': context_rewards[:, None],
        'query_states': np.array(query_state),
    }
    # Add a leading batch axis of size 1 and place everything on the model's device.
    batch = {
        name: convert_to_tensor([value]).to(model.device)
        for name, value in model_inputs.items()
    }
    batch['zeros'] = torch.zeros(1, state_dim ** 2 + 5 + 1).float().to(model.device)

    with torch.no_grad():
        out = model(batch)

    # Greedy action, re-encoded as a 5-way one-hot vector for the environment.
    pred_action = torch.argmax(out.squeeze()).item()
    action_encoding = np.eye(5)[pred_action]
    next_state_encoding, reward = env.transit(np.array(query_state), action_encoding)

    # Grow the context with the transition that was just executed.
    context_states = np.vstack([context_states, query_state])
    context_actions = np.vstack([context_actions, action_encoding])
    context_next_states = np.vstack([context_next_states, next_state_encoding])
    context_rewards = np.append(context_rewards, reward)

    xy_path.append(env.node_map_encoding_to_pos[tuple(next_state_encoding)])
    chosen_actions.append(pred_action)
    query_state = next_state_encoding

In [57]:
xy_path

[[0, 2],
 (1, 2),
 (2, 2),
 (2, 1),
 (2, 0),
 (3, 0),
 (3, 0),
 (3, 0),
 (3, 0),
 (3, 0),
 (3, 0)]

In [41]:
# chosen_actions