In [1]:
import evaluate
import random
import os
import pickle

import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader 
from tqdm import tqdm

# Logging
import wandb 

# Visualization Tools
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns

# Our code
from dataloader import *
from trainer import POCMLTrainer
from model import POCML, sim
from visualizer import * 
from utils import *

import pickle


# Loading data

In [2]:
def dataset_loader(directory):
    """
    A generator function that yields the contents of each .pickle file in the given directory.

    Files are visited in sorted filename order so that repeated runs process
    the datasets in the same sequence (os.listdir itself makes no ordering
    guarantee). Each file is closed before its contents are yielded, so no
    file handle stays open while the caller works on the yielded data.

    Parameters:
    - directory (str): The path to the directory containing the .pickle files.

    Yields:
    - data: The content of each .pickle file, one at a time.
    """
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.pickle'):
            continue
        file_path = os.path.join(directory, filename)
        # NOTE(security): pickle.load can execute arbitrary code while
        # unpickling; only point this at files you produced or trust.
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
        # Yield only after the `with` block has closed the file: the generator
        # may stay paused here for a long time while the consumer runs.
        yield data


In [3]:
# Example data usage:

# for data in dataset_loader('./data'):
#     # Extract datasets and metadata
#     train_dataset = data['train_dataset']
#     test_dataset = data['test_dataset']
#     metadata = data['metadata']
#     env = data['env']
#     print(metadata)

# Automate training

In [4]:
# torch.autograd.set_detect_anomaly(True)

# n_nodes = 9
# #batch_size = 16        # Note: in og CML trajectory length == batch_size; in POCML this should be decoupled
# n_obs = 9
# trajectory_length = 16  # number of node visits in a trajectory
# num_desired_trajectories= 30

# # env = GraphEnv( n_items=n_nodes,                     # number of possible observations
# #                 env='grid', 
# #                 batch_size=trajectory_length, 
# #                 num_desiresd_trajectories=num_desired_trajectories, 
# #                 device=None, 
# #                 unique=True,                         # each state is assigned a unique observation if true
# #                 args = {"rows": 3, "cols": 3}
# #             )

# env = GraphEnv( n_items=n_nodes,                     # number of possible observations
#                 env='tree', 
#                 batch_size=trajectory_length, 
#                 num_desired_trajectorie=num_desired_trajectories, 
#                 device=None, 
#                 unique=True,                         # each state is assigned a unique observation if true
#                 args = {"levels": 4}
#             )

# train_dataset = env.gen_dataset()
# test_dataset = env.gen_dataset()

# train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
# test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

In [5]:


def generate_run_name(params):
    """
    Build the run name for one trial from its hyperparameters.

    The name concatenates, in this order and separated by underscores: the
    environment type, the environment arguments, the state dimension
    ("sdim"), the random feature dimension ("rfdim"), alpha ("a"), the data
    seed ("dseed") and the training seed ("seed").

    Parameters:
    - params (dict): must contain the keys 'env_type', 'env_args',
      'state_dim', 'random_feature_dim', 'alpha', 'dseed' and 'seed'.

    Returns:
    - str: the assembled run name.
    """
    segments = [
        f"{params['env_type']}",
        f"{params['env_args']}",
        f"sdim_{params['state_dim']}",
        f"rfdim_{params['random_feature_dim']}",
        f"a_{params['alpha']}",
        f"dseed_{params['dseed']}",
        f"seed_{params['seed']}",
    ]
    return "_".join(segments)

In [6]:
def run_trial(params, train_dataloader, test_dataloader, env, gconfig, debug=False, show = True, log=False):
    """
    Train one POCML model for a single hyperparameter setting and evaluate it.

    The function optionally opens a wandb run, seeds the RNGs through
    set_random_seed, builds the model and trainer from the entries of
    `params` that filter_param selects for each class, trains for
    params["epochs"] epochs, computes the evaluation metrics, produces the
    visualizations and, when logging, writes the metrics to the wandb summary.

    If anything after wandb.init raises (including a keyboard interrupt), the
    wandb run is finished with a non-zero exit code before the exception is
    re-raised, so a failed trial does not leave a run open for the next one.

    Parameters:
    - params (dict): hyperparameters for this trial. Read directly here:
      "seed", "epochs", "num_desired_trajectories", "trajectory_length",
      plus the keys used by generate_run_name. The whole dict is also passed
      to wandb as the run config.
    - train_dataloader: loader handed to the trainer and used for train accuracy.
    - test_dataloader: loader used for test accuracy.
    - env: environment passed to evaluate.state_transition_consistency.
    - gconfig (dict): global config; "visual_methods" is forwarded to batch_visualize.
    - debug (bool): forwarded to the trainer; also prints kernel similarities
      and metrics when True.
    - show (bool): forwarded to the visualization helpers.
    - log (bool): when True, log the run to wandb (project "POCML").

    Returns:
    - trainer: the POCMLTrainer instance used for this trial (note: the
      trainer, not the model).
    """

    run_name = generate_run_name(params)

    if log:
        wandb.init(
            # Set the project where this run will be logged
            project="POCML",
            # We pass a run name (otherwise it'll be randomly assigned, like sunshine-lollypop-10)
            name=run_name,
            # Track hyperparameters and run metadata
            config = params,
            )

    try:
        # Set random seed
        seed = params["seed"]
        set_random_seed(seed)

        # Filter parameters to match the model & trainer's __init__ signature
        trainer_params = filter_param(params, POCMLTrainer)
        model_params = filter_param(params, POCML)

        # Instantiate the model & trainer using the filtered dictionary
        model = POCML(**model_params)
        trainer = POCMLTrainer(model = model, train_loader = train_dataloader, log = log, debug =debug, **trainer_params)
        # Train the model; train() returns the loss record together with the model
        loss_record, model = trainer.train(params["epochs"])

        ## Analytics
        # get state & action kernel similarities
        phi_Q = model.get_state_kernel()
        phi_V = model.get_action_kernel()
        k_sim_Q = sim(phi_Q, phi_Q)
        k_sim_V = sim(phi_V, phi_V)

        ## Evaluations
        train_acc, train_confidences = evaluate.accuracy(model, train_dataloader)
        test_acc, test_confidences = evaluate.accuracy(model, test_dataloader)
        sa_acc, sa_confidences, sa_distance_ratios = evaluate.state_transition_consistency(model, env)

        if debug:
            print("State kernel similarities:\n", k_sim_Q)
            print("Action kernel similarities:\n", k_sim_V)

            print("Train obs accuracy/confidence:", train_acc, np.mean(train_confidences))
            print("Test obs accuracy/confidence:", test_acc, np.mean(test_confidences))
            print("State-action accuracy/confidence/distance ratio:", sa_acc, np.mean(sa_confidences), np.mean(sa_distance_ratios))

        ## Visualization
        num_desired_trajectories = params['num_desired_trajectories']
        trajectory_length = params['trajectory_length']
        batch_visualize(model.get_state_differences().numpy(), legend = "State", methods = gconfig["visual_methods"], show = show, log = log)
        batch_visualize(model.get_action_differences().numpy(), legend = "Action", methods = gconfig["visual_methods"], show = show, log = log)
        # The loss curve is drawn twice: once with per_epoch=False, once with per_epoch=True.
        visualize_loss(loss_record, num_desired_trajectories, trajectory_length, show = show, per_epoch=False)
        visualize_loss(loss_record, num_desired_trajectories, trajectory_length, show = show, per_epoch=True)

        # TODO: log train and validation metrics to wandb during training, e.g.
        # wandb.log({"val/val_loss": val_loss, "val/val_accuracy": accuracy})

        # Log the results
        if log:
            wandb.summary['train_acc'] = train_acc
            wandb.summary['train_conf'] = np.mean(train_confidences)
            wandb.summary['test_accuracy'] = test_acc
            wandb.summary['test_conf'] = np.mean(test_confidences)
            wandb.summary['sa_accuracy'] = sa_acc
            wandb.summary['sa_conf'] = np.mean(sa_confidences)
            wandb.summary['sa_dist_ratio'] = np.mean(sa_distance_ratios)
    except BaseException:
        # Close the wandb run as failed so it does not stay active and absorb
        # the next trial's wandb.init; BaseException also covers a kernel
        # interrupt during a long training run.
        if log:
            wandb.finish(exit_code=1)
        raise

    if log:
        wandb.finish()

    # TODO log models
    # # beta_obs, beta_state, clean up rate
    # torch.save(model.state_dict(), "model/model_12_12_1.ckpt")

    return trainer

# Set config & hyperparameter pools for wandb

In [7]:
def matches_filter(allowed_values_dict, input_values_dict):
    """
    Check whether a record satisfies every constraint of a filter.

    Each key of `allowed_values_dict` maps to a container of acceptable
    values. The record passes only if it contains every one of those keys and
    its value for that key is a member of the corresponding container. Keys
    that appear only in the record are ignored, and an empty filter accepts
    every record.

    Parameters:
    - allowed_values_dict (dict): key -> container of allowed values.
    - input_values_dict (dict): the record to test.

    Returns:
    - bool: True if all constraints hold, False otherwise.
    """
    return all(
        key in input_values_dict and input_values_dict[key] in allowed
        for key, allowed in allowed_values_dict.items()
    )

# Filter data to avoid reruns.
# Maps a metadata key to the list of values to keep; it is checked against each
# dataset's metadata with matches_filter in the main loop, and datasets that do
# not match are skipped. While it is empty, every dataset passes.
data_filter = {
    #'seed' : [65, 70, 75,]
}

In [8]:
# Convention: for each hyperparameter key, set the value to a list if you want to try multiple values
# NOTE(review): scalars and lists are mixed below; presumably
# generate_combinations_loader treats a scalar as a single fixed value - confirm in utils.
param_pool = {  
    # data-related config; can't be automated for now
    # (these keys are filled in per dataset by the main loop instead)
    # "n_obs" : env.n_items,
    # "n_states" : env.size,
    # "n_actions" : env.n_actions,
    # "trajectory_length" : trajectory_length,  # number of node visits in a trajectory
    # "num_desired_trajectories" : num_desired_trajectories,
    # Experiments
    "seed": [66],
    # model 
    #"state_dim" : [25, 50, 100, 200],
    #"state_dim" : [50],
    "state_dim" : [25], # best param.
    #
    "random_feature_dim" : [2000], # best param.
    #"random_feature_dim" : [250, 500, 1000, 2000],
    #"alpha" : 4, 
    "alpha": [1, 2, 4, 8],
    "memory_bypass" : False,
    # trainer
    "lr_Q" : 0.1, 
    "lr_V" : [0.08], 
    "lr_all" : 0.005,
    "normalize" : False,
    "reset_every" : 1,#[1, 5, 10],
    "update_state_given_obs": [True],
    # training / optimizer 
    "epochs" : 600,
}

# Global (non-swept) configuration passed to run_trial.
gconfig = {
    # Visualization
    "visual_methods": "all"
}

# Flags forwarded to run_trial for every trial in the sweep.
debug = False  # print kernel similarities and metrics
show = False   # forwarded to the visualization helpers
log = True     # log each trial to wandb

In [None]:
# Main sweep: for every pickled dataset in ./data that passes data_filter,
# run one trial per hyperparameter combination drawn from param_pool.
for data in dataset_loader('./data'):

    # Extract datasets and metadata
    train_dataset = data['train_dataset']
    test_dataset = data['test_dataset']
    metadata = data['metadata']
    env = data['env']

    # Skip datasets whose metadata does not satisfy data_filter
    if not matches_filter(data_filter, metadata):
        continue

    # One trajectory per batch; both loaders shuffle.
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)
    
    print(metadata)

    # Use the loader to generate combinations one at a time
    param_loader = generate_combinations_loader(param_pool)

    for params in param_loader:

        # Extend the swept hyperparameters with the values fixed by this
        # dataset (environment sizes, env description, data seed, trajectory shape).
        full_params = params.copy()
        full_params.update({
            "n_obs" : env.n_items,
            "n_states" : env.size,
            "n_actions" : env.n_actions,
            'env_type' : metadata['env_config']['env_type'],
            'env_args': metadata['env_config']['args'],
            'dseed': metadata['seed'],
            "trajectory_length" : metadata['trajectory_length'],  # number of node visits in a trajectory
            "num_desired_trajectories" : metadata['num_desired_trajectories'],
        })

        # NOTE: run_trial returns the POCMLTrainer, so despite its name `model`
        # holds the trainer of the most recent trial.
        model = run_trial(full_params, train_dataloader, test_dataloader, env, gconfig = gconfig, debug=debug, show = show, log=log)
