In [1]:
import os
import random

import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Visualization Tools
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns

# Our code
import evaluate
from dataloader import *
from trainer import POCMLTrainer
from model import POCML, sim
from visualizer import *
from utils import *


# Environment & Data Configuration

In [2]:
# Debugging aid: makes autograd report the forward op behind a failing backward.
# It adds overhead, so consider disabling it once training is stable.
torch.autograd.set_detect_anomaly(True)

# Seed every RNG that is in scope *before* any data is generated. The per-trial
# seed is only applied later inside run_trial, so without this the train/test
# trajectories would differ on every kernel restart.
data_seed = 0
random.seed(data_seed)
np.random.seed(data_seed)
torch.manual_seed(data_seed)

n_nodes = 9
n_obs = 9                      # only referenced by the disabled alternative below
trajectory_length = 12         # number of node visits in a trajectory
num_desired_trajectories = 20

env = GraphEnv(
    n_items=n_nodes,                                    # number of possible observations
    env='grid',
    batch_size=trajectory_length,                       # NOTE(review): the trajectory length is passed as batch_size
    num_desired_trajectories=num_desired_trajectories,
    device=None,
    unique=True,                                        # each state is assigned a unique observation if true
    args={"rows": 3, "cols": 3},
)

# Disabled alternative: construct the dataset directly instead of going through env.
#dataset = RandomWalkDataset(env.adj_matrix, trajectory_length, num_desired_trajectories, n_obs, env.items)

# Training split: the dataset available on env right after construction.
train_dataset = env.dataset.data

# Test split: read again after asking env to generate a dataset.
# NOTE(review): the two splits are only distinct if gen_dataset() rebinds
# env.dataset (or its .data) rather than mutating it in place - confirm in GraphEnv.
env.gen_dataset()
test_dataset = env.dataset.data

train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# Set config & hyperparameter pools for wandb

In [3]:
# Convention: give a key a list value to try several values for it;
# any other value is used as-is.

# data-related config; can't be automated for now
pool_data = {
    "n_obs": env.n_items,
    "n_states": env.size,
    "n_actions": env.n_actions,
}

# Experiments
pool_experiment = {
    "seed": [68, 70],
}

# model
pool_model = {
    "state_dim": 500,
    "random_feature_dim": [1000, 2000],
    "alpha": 1,
    "beta_obs": 8,
    "beta_state": 8,
    "memory_bypass": False,
    "mem_reweight_rate": "adaptive",
    "decay": "adaptive",
}

# trainer
pool_trainer = {
    "lr_Q_o": 0.1,
    "lr_V_o": 0.01,
    "lr_Q_s": 0.,
    "lr_V_s": 0.,
    "lr_all": 1,
    "normalize": False,
    "reset_every": 4,
    "refactor_memory": True,
}

# training / optimizer
pool_training = {
    "epochs": 5,
}

# Merge the groups into the single flat pool; key order follows the groups above.
param_pool = {
    **pool_data,
    **pool_experiment,
    **pool_model,
    **pool_trainer,
    **pool_training,
}

# Flags handed to run_trial by the sweep cell.
debug = True
log = True

In [4]:
def run_trial(params, train_dataloader, test_dataloader, debug=False, log=False):
    """Train one POCML model for a single hyperparameter combination and evaluate it.

    Parameters
    ----------
    params : dict
        One concrete hyperparameter combination. Must contain "seed" and
        "epochs"; the remaining entries are filtered with ``filter_param`` and
        forwarded to ``POCML`` and ``POCMLTrainer``.
    train_dataloader : DataLoader
        Loader used for training and for the train-accuracy evaluation.
    test_dataloader : DataLoader
        Loader used only for the test-accuracy evaluation.
    debug : bool, optional
        When true, print the kernel similarity matrices and evaluation metrics.
    log : bool, optional
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    POCMLTrainer
        The trainer that was fitted. Note this is the trainer, not the bare model.

    Raises
    ------
    KeyError
        If ``params`` lacks "seed" or "epochs".

    Notes
    -----
    Reads the notebook globals ``env``, ``num_desired_trajectories`` and
    ``trajectory_length`` for the state-transition evaluation and the loss plots.
    """
    # Fail early with a descriptive message instead of a bare KeyError: 'seed'.
    missing = [key for key in ("seed", "epochs") if key not in params]
    if missing:
        raise KeyError(
            f"run_trial: params is missing required key(s) {missing}; "
            f"available keys: {list(params)}"
        )

    # Set random seed
    set_random_seed(params["seed"])
    n_epochs = params["epochs"]

    # Filter parameters to match the model & trainer's __init__ signature
    trainer_params = filter_param(params, POCMLTrainer)
    model_params = filter_param(params, POCML)

    # Instantiate the model & trainer using the filtered dictionaries
    model = POCML(**model_params)
    trainer = POCMLTrainer(model=model, train_loader=train_dataloader, **trainer_params)

    # Train the model and arrange the recorded losses as one row per epoch
    loss_record = np.array(trainer.train(n_epochs)).reshape(n_epochs, -1)

    ## Analytics
    # get state & action kernel similarities
    phi_Q = model.get_state_kernel()
    phi_V = model.get_action_kernel()
    k_sim_Q = sim(phi_Q.T, phi_Q.T)
    k_sim_V = sim(phi_V.T, phi_V.T)

    ## Evaluations
    train_acc, train_confidences = evaluate.accuracy(model, train_dataloader)
    test_acc, test_confidences = evaluate.accuracy(model, test_dataloader)
    sa_acc, sa_confidences, sa_distance_ratios = evaluate.state_transition_consistency(model, env)

    if debug:
        print("State kernel similarities:\n", k_sim_Q)
        print("Action kernel similarities:\n", k_sim_V)

        print("Train obs accuracy/confidence:", train_acc, np.mean(train_confidences))
        print("Test obs accuracy/confidence:", test_acc, np.mean(test_confidences))
        print("State-action accuracy/confidence/distance ratio:", sa_acc, np.mean(sa_confidences), np.mean(sa_distance_ratios))

    ## Visualization
    # detach().cpu() first: Tensor.numpy() raises for tensors that require grad
    # or that live on a non-CPU device.
    state_diffs = model.get_state_differences().detach().cpu().numpy()
    action_diffs = model.get_action_differences().detach().cpu().numpy()
    visualize(state_diffs, legend="State", title="MDS State Differences")
    visualize(action_diffs, legend="Action", title="MDS Action Differences")

    # Plot the loss twice: first with per_epoch=False, then with per_epoch=True.
    for per_epoch in (False, True):
        visualize_loss(loss_record, num_desired_trajectories, trajectory_length, per_epoch=per_epoch)

    return trainer

In [5]:
# Use the loader to generate combinations one at a time
param_loader = generate_combinations_loader(param_pool)

# Keep every trial's result so runs can be compared afterwards; previously each
# iteration overwrote the last one.
trainers = []
for params in param_loader:
    trainer = run_trial(params, train_dataloader, test_dataloader, debug=debug, log=log)
    trainers.append(trainer)
    # Legacy alias: a later cell reads `model`. Note that run_trial returns the
    # trainer object, so `model` is bound to a trainer here.
    model = trainer
    

KeyError: 'seed'

In [None]:
# beta_obs, beta_state, clean up rate
# NOTE(review): presumably the numbers in the file name encode the settings
# listed above - confirm and keep the name in sync with the config.
checkpoint_path = "model/model_12_12_1.ckpt"

# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# `model` is whatever the sweep cell left behind, which is the object returned
# by run_trial (a trainer). Use it directly when it exposes state_dict();
# otherwise fall back to a wrapped `.model` attribute if one exists.
# NOTE(review): confirm how POCMLTrainer exposes its model.
checkpoint_source = model if hasattr(model, "state_dict") else getattr(model, "model", model)
torch.save(checkpoint_source.state_dict(), checkpoint_path)