In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import os
import pandas as pd
import configs
import torch

from src.utils import find_ckpt_file, convert_to_tensor
import h5py
from src.evals.eval_trees import EvalCntrees
from decoding_utils import fit_and_evaluate_classification, fit_and_evaluate_regression

# Arguments 

In [2]:
# Node-encoding correlation: selects which trained model / dataset is loaded and
# is forwarded to the environment config as 'node_encoding_corr'.
corr = 0.25
# Number of context transitions kept when running the model (longer contexts
# are truncated to this length in run_model).
seq_length = 800

# Load Model

In [3]:
# Look up the trained model for this correlation on the "tree_maze" task.
# The three returned values are used later as: a name string that encodes the
# architecture hyperparameters, the checkpoint path handed to torch.load, and
# the path of the evaluation dataset.
model_name, path_to_pkl, eval_dset_path = configs.get_model_paths(corr, "tree_maze")

epoch=34-val_loss=0.174876.ckpt


In [4]:
# Extract parameters using regex
import re

# Architecture hyperparameters are encoded in the model name as
# "embd<N>", "layer<N>", "head<N>" and "drop<P>" tokens.
n_embd = int(re.search(r'embd(\d+)', model_name).group(1))
n_layer = int(re.search(r'layer(\d+)', model_name).group(1))
n_head = int(re.search(r'head(\d+)', model_name).group(1))
dropout = float(re.search(r'drop(\d*\.?\d*)', model_name).group(1))

# Extract state_dim from the eval dataset path ("state_dim<N>" token).
state_dim = int(re.search(r'state_dim(\d+)', eval_dset_path).group(1))

model_config = {
    "n_embd": n_embd,
    "n_layer": n_layer,
    "n_head": n_head,
    # Use the value parsed from the dataset path (previously a hard-coded 10)
    # so the model and the environment built from `state_dim` cannot disagree.
    "state_dim": state_dim,
    "action_dim": 4,
    "dropout": dropout,
    "test": True,
    "name": "transformer_end_query",
    "optimizer_config": None,
}

In [5]:
from src.models.transformer_end_query import Transformer

# Prefer the GPU, but fall back to CPU so the notebook still runs on machines
# without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_config['initialization_seed'] = 0
model = Transformer(**model_config)
# map_location remaps tensors saved from a GPU onto the selected device;
# without it a GPU-saved checkpoint cannot be deserialized on a CPU-only host.
checkpoint = torch.load(path_to_pkl, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
model.eval()  # inference mode (e.g. disables dropout)
model = model.to(device)


# Load Dataset and Create Environment

In [6]:
def load_eval_trajs(eval_dset_path):
    """Load an evaluation dataset from disk.

    Paths ending in '.h5' are opened read-only with h5py (the file handle is
    returned open); any other path is treated as a pickle file and fully
    deserialized.

    Returns:
        (eval_trajs, n_eval_envs) where n_eval_envs is len(eval_trajs).
    """
    if not eval_dset_path.endswith('.h5'):
        # Pickle file. Only load pickles from trusted sources.
        with open(eval_dset_path, 'rb') as handle:
            trajs = pickle.load(handle)
    else:
        trajs = h5py.File(eval_dset_path, 'r')
    return trajs, len(trajs)

# Load the dataset at `eval_dset_path` plus the one at the path obtained by
# swapping 'eval' -> 'test', then pool them into a single collection.
# NOTE(review): str.replace swaps *every* 'eval' occurrence in the path
# (directory names included) — confirm the path contains it only once.
# NOTE(review): the `+` below assumes both loads returned lists (pickle
# branch); it presumably will not work for the h5py.File branch — confirm.
eval_trajs_1, n_eval_envs_1 = load_eval_trajs(eval_dset_path)
eval_trajs_2, n_eval_envs_2 = load_eval_trajs(eval_dset_path.replace('eval', 'test'))
eval_trajs = eval_trajs_1 + eval_trajs_2
n_eval_envs = n_eval_envs_1 + n_eval_envs_2

In [7]:
def run_model(traj, model, seq_length=1200, start_idx=800, state_dim=10, action_dim=4):
    """Run one trajectory through the model and collect per-layer activations.

    Args:
        traj: mapping providing 'context_states', 'context_actions',
            'context_next_states', 'context_rewards' and 'query_state'.
        model: callable on a batch dict. It must expose `device`, honor the
            `save_activations` flag and fill `activations['hidden_states']`.
        seq_length: context entries are truncated to this many transitions.
        start_idx: first context index whose features/activations are returned.
        state_dim: state size used for the 'zeros' placeholder.
        action_dim: action size used for the 'zeros' placeholder.
            The defaults reproduce the previously hard-coded `10 ** 2 + 4 + 1`.

    Returns:
        (hidden_states, state_features, next_state_features, actions):
        `hidden_states` is a list with one tensor per entry of
        `model.activations['hidden_states']`, sliced to `[start_idx:-1]` along
        the sequence axis; the other three are numpy arrays taken from the
        (truncated) context from `start_idx` onward, with actions converted
        from one-hot rows to integer indices via argmax.

    Side effects:
        Sets `model.save_activations = True` and leaves it set.
    """
    batch = {
        'context_states': convert_to_tensor([np.array(traj['context_states'])]),
        'context_actions': convert_to_tensor([np.array(traj['context_actions'])]),
        'context_next_states': convert_to_tensor([np.array(traj['context_next_states'])]),
        'context_rewards': convert_to_tensor([np.array(traj['context_rewards'])[:, None]]),
        'query_states': convert_to_tensor([np.array(traj['query_state'])]),  # Ignored
        }
    batch['zeros'] = torch.zeros(1, state_dim ** 2 + action_dim + 1).float()
    for k in batch.keys():
        if 'context' in k:
            # Only the context entries have a sequence axis to truncate.
            batch[k] = batch[k][:, :seq_length]
        batch[k] = batch[k].to(model.device)

    model.save_activations = True
    with torch.no_grad():
        # The forward pass is run only for its side effect of filling
        # model.activations; the output itself is not needed here.
        model(batch)
    layer_activations = model.activations['hidden_states']  # Tuple over layers of (1, seq, dim)

    state_features = batch['context_states'][0][start_idx:].to('cpu').numpy()
    next_state_features = batch['context_next_states'][0][start_idx:].to('cpu').numpy()
    actions = batch['context_actions'][0].argmax(dim=1)[start_idx:].to('cpu').numpy()
    # Drop the final sequence position (presumably the query token — TODO
    # confirm) so activations line up one-to-one with the context transitions.
    hidden_states = [layer_act[0, start_idx:-1] for layer_act in layer_activations]
    return hidden_states, state_features, next_state_features, actions

# Across context decoding

In [8]:
def get_subtree_location(layer, pos, subtree):
    """Return which half / quarter / eighth of its tree layer a node falls in.

    Args:
        layer: depth of the node; layer `k` is treated as having 2**k positions.
        pos: position of the node within its layer.
        subtree: 'half', 'quarter' or 'eighth'.

    Returns:
        A 1-based partition index, or 0 when the layer is too shallow for the
        requested partition (layer 0 always; layer 1 for quarters; layers 1-2
        for eighths). Any other `subtree` value yields None on non-root layers.
    """
    # The root belongs to no partition, whatever granularity is requested.
    if layer == 0:
        return 0

    half_width = 2 ** (layer - 1)
    if subtree == 'half':
        if pos < half_width:
            return 1
        return 2

    if subtree == 'quarter':
        too_shallow = (1,)
        bin_width = half_width // 2
    elif subtree == 'eighth':
        too_shallow = (1, 2)
        bin_width = (half_width // 2) // 2
    else:
        return None

    if layer in too_shallow:
        return 0
    # Left edges of equal-width bins across the layer; digitize is 1-based.
    edges = np.arange(0, 2 ** layer, bin_width)
    return np.digitize([pos], edges)[0]

    

In [9]:
def make_train_test_matrices():
    """Collect per-layer activations and decoding targets over all environments.

    Reads the notebook globals `n_eval_envs`, `eval_trajs`, `model`,
    `seq_length`, `corr` and `state_dim`. For every environment whose context
    contains a positive reward at or before `seq_length`, the model is run and
    each distinct (state, next_state) transition seen from the first rewarded
    step onward contributes one sample.

    Returns:
        X_train, X_test: lists of length `model.n_layer + 1`; entry `l` is a
            list of hidden-state vectors (numpy arrays) taken from layer `l`.
        Ys_dict: maps each target name to {"Y_train": [...], "Y_test": [...]},
            aligned sample-for-sample with the X lists.
    """
    # NOTE(review): the range starts at 1, so environment 0 is not in
    # `train_envs` and therefore lands in the test split — confirm intended.
    train_envs = np.arange(1, int(n_eval_envs*0.9))

    X_train = [[] for _ in range(model.n_layer+1)]
    X_test = [[] for _ in range(model.n_layer+1)]
    target_names = [
        "dist_from_goal",
        "layer",
        "node_identity",
        "maze_half",
        "maze_quarter",
        "maze_eighth",
        'is_goal',
        'same_half_as_goal',
        'same_quarter_as_goal',
        'opt_action',
        'state_feature',
        'next_state_feature',
        "on_path",
        "on_lr_path",
        "intersects_lr_path",
        "inverse_action",
        "action",
    ]
    Ys_dict = {name: {"Y_train": [], "Y_test": []} for name in target_names}

    for i_eval in range(n_eval_envs):
        traj = eval_trajs[i_eval]
        first_reward = np.argwhere(np.array(traj['context_rewards'])>0)
        # Skip environments where no reward is found within the context window.
        if (first_reward.size == 0) or (first_reward[0] > seq_length):
            continue
        start_idx = first_reward[0].item()

        env_config = {
            'max_layers': 7,
            'horizon': 1600,
            'branching_prob': 1.0,
            'node_encoding_corr': corr,
            'state_dim': state_dim,
            'initialization_seed': np.array(traj['initialization_seed']).item()
        }
        env = EvalCntrees().create_env(env_config, np.array(traj['goal']), i_eval)
        opt_action_map, dist_from_goal = env.make_opt_action_dict()

        # Follow the optimal action from the root until the goal is reached.
        # The root is normalized to a tuple like every later entry; a list-valued
        # root would never compare equal to the tuple keys tested below.
        s = env.root.encoding()
        states_on_path_from_root_to_goal = [tuple(s)]
        while True:
            action = np.zeros(4)
            action[opt_action_map[tuple(s)]] = 1
            s, _ = env.transit(np.array(s), action)
            states_on_path_from_root_to_goal.append(tuple(s))
            if np.array_equal(s, env.goal):
                break

        hidden_states, state_features, next_state_features, actions = run_model(traj, model, seq_length, start_idx)

        # Per-environment values, hoisted out of the per-transition loop.
        goal_tuple = tuple(env.goal.tolist())
        goal_node = env.node_map[goal_tuple]
        goal_half = get_subtree_location(goal_node.layer, goal_node.pos, 'half')
        goal_quarter = get_subtree_location(goal_node.layer, goal_node.pos, 'quarter')
        is_train_env = i_eval in train_envs
        Y_key = "Y_train" if is_train_env else "Y_test"
        X_split = X_train if is_train_env else X_test

        # Walk backwards so that, for a repeated (state, next_state) pair, only
        # its most recent occurrence in the context is kept.
        seen_combos = set()
        for state_idx in reversed(range(len(state_features))):
            state_feature = state_features[state_idx]
            next_state_feature = next_state_features[state_idx]
            state_feature_tuple = tuple(state_feature.tolist())
            next_state_feature_tuple = tuple(next_state_feature.tolist())
            action = actions[state_idx]
            combo = state_feature_tuple + next_state_feature_tuple
            if combo in seen_combos:
                continue
            seen_combos.add(combo)

            # Location-type targets describe the *next* state of the transition.
            d = dist_from_goal[next_state_feature_tuple]
            next_node = env.node_map[next_state_feature_tuple]
            layer = next_node.layer
            pos = next_node.pos
            node_identity = 2**layer + pos
            maze_half = get_subtree_location(layer, pos, 'half')
            maze_quarter = get_subtree_location(layer, pos, 'quarter')
            maze_eighth = get_subtree_location(layer, pos, 'eighth')

            # Action that would undo this transition. Presumably 0 = to parent,
            # 1/2 = to left/right child, 3 = stay — TODO confirm against the env.
            if action == 0:
                state_node = env.node_map[state_feature_tuple]
                if next_node.left == state_node:
                    inverse_action = 1
                elif next_node.right == state_node:
                    inverse_action = 2
                else:
                    # The next state has the current state as neither child.
                    inverse_action = -1
            elif action == 1 or action == 2:
                inverse_action = 0
            else:
                inverse_action = 3

            state_on_lr_path = state_feature_tuple in states_on_path_from_root_to_goal
            next_on_lr_path = next_state_feature_tuple in states_on_path_from_root_to_goal
            on_lr_path = state_on_lr_path and next_on_lr_path
            intersects_lr_path = state_on_lr_path or next_on_lr_path

            Ys_dict["dist_from_goal"][Y_key].append(d)
            Ys_dict["layer"][Y_key].append(layer)
            Ys_dict["node_identity"][Y_key].append(node_identity)
            Ys_dict["maze_half"][Y_key].append(maze_half)
            Ys_dict["maze_quarter"][Y_key].append(maze_quarter)
            Ys_dict["maze_eighth"][Y_key].append(maze_eighth)
            Ys_dict["is_goal"][Y_key].append(state_feature_tuple == goal_tuple)
            Ys_dict["same_half_as_goal"][Y_key].append(maze_half == goal_half)
            Ys_dict["same_quarter_as_goal"][Y_key].append(maze_quarter == goal_quarter)
            Ys_dict["opt_action"][Y_key].append(opt_action_map[state_feature_tuple])
            Ys_dict["state_feature"][Y_key].append(state_feature)
            Ys_dict["next_state_feature"][Y_key].append(next_state_feature)
            Ys_dict["on_path"][Y_key].append(action == opt_action_map[state_feature_tuple])
            Ys_dict["on_lr_path"][Y_key].append(on_lr_path)
            Ys_dict["intersects_lr_path"][Y_key].append(intersects_lr_path)
            Ys_dict["inverse_action"][Y_key].append(inverse_action)
            Ys_dict["action"][Y_key].append(action)

            # Separate index name so the node's `layer` label is not shadowed.
            for i_layer in range(len(hidden_states)):
                hidden_state = hidden_states[i_layer][state_idx].to('cpu').numpy()
                X_split[i_layer].append(hidden_state)
        torch.cuda.empty_cache()

    return X_train, X_test, Ys_dict


# Collect data and fit decoders

In [10]:
# Build the decoding dataset: runs the model once per usable environment and
# gathers per-layer activations together with all decoding targets.
X_train, X_test, Ys_dict = make_train_test_matrices()



In [11]:
# Decode, from the activations in X_train / X_test, whether a transition lies
# on the root-to-goal path.
pipeline, test_score, test_y, test_pred = fit_and_evaluate_classification(
    X_train, Ys_dict['on_lr_path']["Y_train"],
    X_test, Ys_dict['on_lr_path']["Y_test"], print_scores=False)

# One mean score per entry of test_score (presumably one per layer).
print([np.mean(_test_score) for _test_score in test_score])

results = {
    'test_score': test_score,
    'test_y': test_y,
    'test_pred': test_pred,
    # Distance-to-goal of each test sample, saved alongside the predictions.
    'dist_from_goal': Ys_dict['dist_from_goal']["Y_test"],
}
# Create the output directory if needed so a fresh checkout does not fail
# after the expensive fit.
os.makedirs('pickles', exist_ok=True)
with open('pickles/09_buffer_token_decoding_on_lr_path.pkl', 'wb') as f:
    pickle.dump(results, f)

[np.float64(0.5868448098663926), np.float64(0.7929085303186023), np.float64(0.9275436793422405), np.float64(0.9696813977389517)]


In [17]:
# Restrict to samples whose inverse action is 1 or 2, so the decoder solves a
# two-class problem.
train_indices = [i for i, a in enumerate(Ys_dict['inverse_action']['Y_train']) if a in [1,2]]
test_indices = [i for i, a in enumerate(Ys_dict['inverse_action']['Y_test']) if a in [1,2]]
_Y_train = [Ys_dict['inverse_action']["Y_train"][i] for i in train_indices]
_Y_test = [Ys_dict['inverse_action']["Y_test"][i] for i in test_indices]
# Apply the same sample selection to every per-layer activation list.
_X_train = [[X[i] for i in train_indices] for X in X_train]
_X_test = [[X[i] for i in test_indices] for X in X_test]

pipeline, test_score, test_y, test_pred = fit_and_evaluate_classification(
    _X_train, _Y_train, _X_test, _Y_test,
    print_scores=False
)

# One mean score per entry of test_score (presumably one per layer).
print([np.mean(_test_score) for _test_score in test_score])

results = {
    'test_score': test_score,
    'test_y': test_y,
    'test_pred': test_pred,
    # Keep distances aligned with the filtered test samples.
    'dist_from_goal': [Ys_dict['dist_from_goal']["Y_test"][i] for i in test_indices],
}
# Create the output directory if needed so a fresh checkout does not fail
# after the expensive fit.
os.makedirs('pickles', exist_ok=True)
with open('pickles/09_buffer_token_decoding_inverse_action.pkl', 'wb') as f:
    pickle.dump(results, f)

[np.float64(0.4912536443148688), np.float64(0.8498542274052479), np.float64(0.9927113702623906), np.float64(0.9927113702623906)]
