In [6]:
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
import os

import evaluate
from model import LSTM, Transformer
from trainer import BenchmarkTrainer
from dataloader import GraphEnv, DataLoader, preprocess_zero_shot_dataset

from utils import dataset_loader, set_random_seed, filter_param, matches_filter

import wandb

In [7]:
def generate_run_name_benchmark(params):
    """Build a descriptive run name from the trial parameters.

    The name concatenates the model type, environment description, hidden
    size, initial-state flag and the data/model seeds. Transformer runs
    additionally encode ``d_model`` and ``n_heads``.

    Args:
        params: dict of trial parameters. Reads ``model``, ``env_type``,
            ``env_args``, ``hidden_size``, ``include_init_state_info``,
            ``dseed`` and ``seed``; for Transformer also ``d_model`` and
            ``n_heads``.

    Returns:
        The run name as a string.

    Raises:
        ValueError: if ``params['model']`` is neither ``'Transformer'`` nor
            ``'LSTM'``.
    """
    model_name = params['model']
    if model_name not in ('Transformer', 'LSTM'):
        raise ValueError("Sorry, unknown model")

    # Fields shared by every model type.
    parts = [
        f"{model_name}_{params['env_type']}_{params['env_args']}",
        f"hdim_{params['hidden_size']}",
        f"stateinfo_{params['include_init_state_info']}",
    ]

    # Transformer-only architecture fields (the previous version emitted a
    # stray ')' after n_heads, which ended up in every Transformer run name).
    if model_name == 'Transformer':
        parts.append(f"dmodel_{params['d_model']}_nheads_{params['n_heads']}")

    parts.append(f"dseed_{params['dseed']}_seed_{params['seed']}")
    return "_".join(parts)

In [8]:
# Metadata filter used to avoid re-running datasets: each dataset's metadata
# is checked against it with `matches_filter` in the dataset loop, and
# non-matching datasets are skipped. An empty dict applies no restriction.
# Example: {'seed': [65, 70, 75]}
data_filter: dict = {}

# Switches forwarded to run_trial.
debug: bool = False  # extra accuracy/confidence prints
show: bool = False   # only referenced by the commented-out visualization calls
log: bool = True     # open a wandb run and record summary metrics

In [9]:
def run_trial(params, train_dataloader, test_dataloader, env, debug=False, show = True, log=False):
    """Train and evaluate one benchmark model (LSTM or Transformer).

    Args:
        params: dict of trial parameters. This function reads ``model``,
            ``seed``, ``include_init_state_info``, ``reset_every`` and
            ``epochs``; the run name and the model constructor consume further
            keys (see ``generate_run_name_benchmark`` / ``filter_param``).
        train_dataloader: loader handed to ``BenchmarkTrainer`` for training.
        test_dataloader: loader handed to ``BenchmarkTrainer`` as ``test_loader``.
        env: environment object; currently unused here, kept for interface
            compatibility.
        debug: if True, print an additional accuracy/confidence report.
        show: currently unused (only referenced by the disabled visualization
            calls below).
        log: if True, open a wandb run for this trial and record the final
            train/test accuracy and mean confidence in its summary.

    Returns:
        The ``BenchmarkTrainer`` instance used for training.

    Raises:
        ValueError: if ``params["model"]`` is not a supported model name.
    """
    run_name = generate_run_name_benchmark(params)

    # Resolve the model class up front so an unsupported name fails with a
    # clear error before a wandb run is opened (previously `model` stayed None
    # and the failure surfaced later as an AttributeError).
    model_classes = {"Transformer": Transformer, "LSTM": LSTM}
    model_cls = model_classes.get(params["model"])
    if model_cls is None:
        raise ValueError(f"Unknown model: {params['model']!r}")

    if log:
        wandb.init(
            # Set the project where this run will be logged
            project="POCML",
            # We pass a run name (otherwise it'll be randomly assigned, like sunshine-lollypop-10)
            name=run_name,
            # Track hyperparameters and run metadata
            config=params,
        )

    try:
        # Set random seed
        set_random_seed(params["seed"])

        # filter_param presumably keeps only the keys the constructor
        # accepts -- confirm in utils.
        model = model_cls(**filter_param(params, model_cls))
        # Number of trainable parameters.
        print(sum(p.numel() for p in model.parameters() if p.requires_grad))

        trainer = BenchmarkTrainer(
            model,
            train_dataloader,
            torch.optim.Adam(model.parameters()),
            torch.nn.CrossEntropyLoss(),
            test_loader=test_dataloader,
            include_init_state_info=params["include_init_state_info"],
            reset_every=params["reset_every"],
            log=log,
        )

        losses, model = trainer.train(params["epochs"])

        ## Evaluations
        train_acc, train_conf = evaluate.benchmark_accuracy(model, trainer.train_dataset)
        test_acc, test_conf = evaluate.benchmark_accuracy(model, trainer.test_dataset)
        train_conf_mean = np.mean(train_conf)
        test_conf_mean = np.mean(test_conf)
        print("Train acc/conf: ", train_acc, train_conf_mean)
        print("Test acc/conf: ", test_acc, test_conf_mean)
        # state-action_acc + conf

        if debug:
            print("Train obs accuracy/confidence:", train_acc, train_conf_mean)
            print("Test obs accuracy/confidence:", test_acc, test_conf_mean)
            # s.a._acc

        ## Visualization (disabled)
        # num_desired_trajectories = params['num_desired_trajectories']
        # trajectory_length = params['trajectory_length']
        # batch_visualize(model.get_state_differences().numpy(), legend = "State", methods = gconfig["visual_methods"], show = show, log = log)
        # batch_visualize(model.get_action_differences().numpy(), legend = "Action", methods = gconfig["visual_methods"], show = show, log = log)
        # visualize_loss(losses, num_desired_trajectories, trajectory_length, show = show, per_epoch=False)
        # visualize_loss(losses, num_desired_trajectories, trajectory_length, show = show, per_epoch=True)

        # Log the results
        if log:
            wandb.summary['train_acc'] = train_acc
            wandb.summary['train_conf'] = train_conf_mean
            wandb.summary['test_accuracy'] = test_acc
            wandb.summary['test_conf'] = test_conf_mean
            #wandb.summary['sa_accuracy'] = sa_acc
            #wandb.summary['sa_conf'] = np.mean(sa_conf)
            #wandb.summary['sa_dist_ratio'] = np.mean(sa_distance_ratios)
    except BaseException:
        # Close the wandb run even when training fails or is interrupted, so
        # the next trial does not inherit a still-active run. A non-zero exit
        # code marks the run as failed.
        if log:
            wandb.finish(exit_code=1)
        raise

    if log:
        wandb.finish()

    return trainer

In [10]:
# Trial hyperparameters. They do not depend on the dataset, so they are
# defined once instead of being rebuilt on every loop iteration.
params = {
    "model" : 'LSTM',
    'seed' : 66,
    'hidden_size' : 16,
    'epochs' : 400,
    'include_init_state_info' : True,
    'd_model' : 32,
    'n_heads' : 4,
    'reset_every' : 20,
}

# Run one trial per dataset found under ./data that passes `data_filter`.
for data in dataset_loader('./data'):

    # Extract datasets and metadata
    train_dataset = data['train_dataset']
    test_dataset = data['test_dataset']
    metadata = data['metadata']
    env = data['env']

    if not matches_filter(data_filter, metadata):
        continue

    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

    print(metadata)

    # Combine the fixed hyperparameters with dataset/environment-specific values.
    full_params = {
        **params,
        "n_obs" : env.n_items,
        "n_states" : env.size,
        "n_actions" : env.n_actions,
        'env_type' : metadata['env_config']['env_type'],
        'env_args': metadata['env_config']['args'],
        'dseed': metadata['seed'],
        "trajectory_length" : metadata['trajectory_length'],  # number of node visits in a trajectory
        "num_desired_trajectories" : metadata['num_desired_trajectories'],
    }

    # run_trial returns the BenchmarkTrainer, not the model.
    trainer = run_trial(full_params, train_dataloader, test_dataloader, env, debug=debug, show = show, log=log)
    # Earlier revisions bound the trainer to `model`; keep that name available
    # for any later cells that still refer to it.
    model = trainer


  return torch.load(io.BytesIO(b))


{'n_nodes': 9, 'trajectory_length': 16, 'num_desired_trajectories': 30, 'env_config': {'n_items': 9, 'env_type': 'tree', 'batch_size': 16, 'num_desired_trajectories': 30, 'unique': True, 'args': {'levels': 3}}, 'seed': 80}


2359


Epochs:   0%|          | 0/400 [00:00<?, ?it/s]