In [1]:
import torch
import torch.nn.functional as F
from model import POCML
from dataloader import GraphEnv, DataLoader

import pickle
import evaluate

## Load environment, then generate and save the zero-shot dataset

In [73]:
env_type = "tree"  # one of the keys of ENV_DATA_PATHS below

# Pickled training data for each supported environment; only its "env" entry is used here.
ENV_DATA_PATHS = {
    "grid": "data/data_n_nodes_9_env_grid_traj_len_12_n_traj_1536_args_{'rows': 3, 'cols': 3}_seed_65.pickle",
    "tree": "data/data_n_nodes_9_env_tree_traj_len_12_n_traj_1536_args_{'levels': 3}_seed_65.pickle",
}
if env_type not in ENV_DATA_PATHS:
    # Fail loudly here instead of with a NameError on `p` below.
    raise ValueError(f"Unknown env_type {env_type!r}; expected one of {sorted(ENV_DATA_PATHS)}")

# NOTE: pickle.load executes arbitrary code — only load files produced by this project.
with open(ENV_DATA_PATHS[env_type], "rb") as f:
    p = pickle.load(f)

env = p["env"]
dataset = env.gen_zero_shot_dataset(trajectory_length=20, num_environments=20)

In [74]:
# Persist the generated zero-shot dataset so later sessions can skip regeneration.
zero_shot_path = f"data/zero_shot_dataset_{env_type}.pickle"
with open(zero_shot_path, "wb") as out_file:
    pickle.dump(dataset, out_file)

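## Load zero-shot dataset

Reloads the dataset saved above. `env_type` and `env` still come from the generation cell, so run that cell first.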

In [None]:
# Reload the previously saved zero-shot dataset for the selected environment type.
zero_shot_path = f"data/zero_shot_dataset_{env_type}.pickle"
with open(zero_shot_path, "rb") as in_file:
    dataset = pickle.load(in_file)

## Load model

In [76]:
# Per-environment hyperparameters. They must match the checkpoint being loaded:
# the values agree with the `sdim_*` / `rfdim_*` fields of each checkpoint filename.
# (Previously the grid branch compared against "grid:", which never matched, so a
# grid run silently got the tree settings and a mismatched state dict.)
if env_type == "grid":
    state_dim = 20
    batch_size = 64
    random_feature_dim = 2000
    ckpt_path = "model/grid_{'rows': 3, 'cols': 3}_sdim_20_rfdim_2000_lrV_0.1_seed_69.ckpt"
elif env_type == "tree":
    state_dim = 1000
    batch_size = 64
    random_feature_dim = 500
    ckpt_path = "model/tree_{'levels': 3}_sdim_1000_rfdim_500_lrV_0.04_seed_66.ckpt"
else:
    raise ValueError(f"Unknown env_type {env_type!r}; expected 'grid' or 'tree'")

model = POCML(
    n_obs = env.n_items,
    n_states = env.size,
    n_actions = env.n_actions,
    state_dim = state_dim,
    batch_size=batch_size,
    random_feature_dim = random_feature_dim,
    alpha = 4,
    memory_bias=False,
)

model.load_state_dict(torch.load(ckpt_path))

# Construction used the training batch size; the evaluation below runs with batch size 1.
model.batch_size = 1

In [77]:
import numpy as np


def theoretical_best(dataset):
    """Mean fraction of test steps whose column-3 value was already seen.

    For each ``(t1, t2)`` pair, a step of ``t2`` counts as "correct" iff its
    column-3 value appeared anywhere in ``t1`` or at an earlier step of ``t2``.
    The per-pair fractions are averaged over pairs. This presumably serves as an
    upper bound for a predictor that can only recall previously seen
    observations. Column 3 is assumed to hold the observation id — TODO confirm
    against ``gen_zero_shot_dataset``.

    Args:
        dataset: iterable of ``(t1, t2)`` pairs of 2-D tensors/arrays whose rows
            are steps (anything supporting ``x[:, 3].tolist()``).

    Returns:
        The mean per-pair fraction as a float. Pairs with an empty ``t2`` are
        skipped (they would otherwise divide by zero). Returns ``nan`` if no pair
        has a non-empty ``t2``.
    """
    accs = []
    for t1, t2 in dataset:
        targets = t2[:, 3].tolist()
        if not targets:
            continue  # empty test trajectory: no steps to score
        seen = set(t1[:, 3].tolist())
        correct = 0
        for obs in targets:
            if obs in seen:
                correct += 1
            seen.add(obs)  # later steps may recall this observation
        accs.append(correct / len(targets))
    return float(np.mean(accs)) if accs else float("nan")

In [78]:
# Keyword arguments shared by the test and train evaluations; the two runs
# differ only in `lr` and `test_acc`.
eval_kwargs = dict(
    update_state_given_obs=True,
    update_memory=True,
    softmax=False,
    beta=1000,
    max_iter=1,
    eps=1e-3,
)

acc_te = evaluate.zero_shot_accuracy(model, dataset, lr=1, test_acc=True, **eval_kwargs)
best = theoretical_best(dataset)
print("Test", acc_te)
print("Best", best)
print("Adj. Test", acc_te / best)  # test accuracy relative to the recall-only ceiling

# NOTE(review): lr here (2) differs from the test run above (1) — confirm this is intentional.
# NOTE(review): both runs share `model`; if update_memory=True mutates it, this run
# starts from the state left by the test run — TODO confirm in evaluate.zero_shot_accuracy.
acc = evaluate.zero_shot_accuracy(model, dataset, lr=2, test_acc=False, **eval_kwargs)
print("Train", acc)

100%|██████████| 20/20 [00:06<00:00,  3.15it/s]


Test 0.7157894736842105
Best 0.9605263157894737
Adj. Test 0.7452054794520548


100%|██████████| 20/20 [00:06<00:00,  2.96it/s]

Train 0.781578947368421



