In [1]:
import evaluate
import random
import os
import pickle

import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader 
from tqdm import tqdm

# Logging
import wandb 

# Visualization Tools
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns

# Our code
from dataloader import *
from trainer import POCMLTrainer
from model import POCML, sim
from visualizer import * 
from utils import *

import pickle


# Loading data

In [2]:
def dataset_loader(directory):
    """
    A generator function that yields the contents of each .pickle file in the given directory.

    Files are visited in sorted filename order so that repeated runs process
    the datasets in the same sequence (os.listdir returns entries in
    arbitrary order). Each file is closed before its content is yielded, so
    no file handle stays open while the caller works with the data.

    Parameters:
    - directory (str): The path to the directory containing the .pickle files.

    Yields:
    - data: The content of each .pickle file, one at a time.
    """
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.pickle'):
            continue
        file_path = os.path.join(directory, filename)
        # NOTE(review): pickle.load can execute arbitrary code while
        # unpickling; only point this at directories whose contents you trust.
        with open(file_path, 'rb') as f:
            data = pickle.load(f)  # Load the .pickle file
        # Yield after leaving the ``with`` block so the handle is already
        # released while the consumer (e.g. a long training run) uses the data.
        yield data


In [3]:
# Example data usage:

# for data in dataset_loader('./data'):
#     # Extract datasets and metadata
#     train_dataloader = data['train_dataloader']
#     test_dataloader = data['test_dataloader']
#     metadata = data['metadata']
#     env = data['env']
#     print(metadata)

# Automate training

In [4]:
def generate_run_name(params):
    """Build the identifier string of a run from its configuration.

    The name concatenates, in this order and separated by underscores: the
    environment type and arguments, the state dimension, the random feature
    dimension, the learning rate ``lr_V`` and the seed.

    Parameters:
    - params (dict): Must contain 'env_type', 'env_args', 'state_dim',
      'random_feature_dim', 'lr_V' and 'seed'.

    Returns:
    - str: The run name.
    """
    env_part = f"{params['env_type']}_{params['env_args']}"
    model_part = f"sdim_{params['state_dim']}_rfdim_{params['random_feature_dim']}"
    optim_part = f"lrV_{params['lr_V']}_seed_{params['seed']}"
    return "_".join((env_part, model_part, optim_part))

In [5]:
def run_trial(params, train_dataloader, test_dataloader, env, debug=False, show = True, log=False, save=True):
    """Train, checkpoint, evaluate and visualize one POCML model for one configuration.

    Parameters:
    - params (dict): Full configuration of the run. Besides the keys read by
      generate_run_name, this function reads "seed", "epochs",
      'num_desired_trajectories' and 'trajectory_length'. The dictionary is
      also passed through filter_param to build the keyword arguments of
      POCML and POCMLTrainer.
    - train_dataloader: Training data; given to the trainer and to
      evaluate.accuracy.
    - test_dataloader: Held-out data; given to evaluate.accuracy.
    - env: Environment; given to evaluate.state_transition_consistency and
      pca_visualize.
    - debug (bool): Forwarded to the trainer. When True, the kernel
      similarities and the evaluation numbers are also printed.
    - show (bool): Forwarded to the visualization helpers.
    - log (bool): When True, the run is tracked in the "POCML" wandb project:
      params is stored as the run config and the evaluation results are
      written to the run summary. If the trial raises, the wandb run is
      closed with a failing exit code before the exception propagates.
    - save (bool): When True, the trained weights are written to
      model/<run_name>.ckpt (the directory is created if missing).

    Returns:
    - The POCMLTrainer instance used for this run.
    """
    run_name = generate_run_name(params)

    if log:
        wandb.init(
            # Set the project where this run will be logged
            project="POCML",
            # Explicit run name (wandb assigns a random one otherwise)
            name=run_name,
            # Track hyperparameters and run metadata
            config=params,
        )

    try:
        # Set random seed
        seed = params["seed"]
        set_random_seed(seed)

        # Filter parameters to match the model & trainer's __init__ signature
        trainer_params = filter_param(params, POCMLTrainer)
        model_params = filter_param(params, POCML)

        # Instantiate the model & trainer using the filtered dictionaries
        model = POCML(**model_params)
        trainer = POCMLTrainer(model = model, train_loader = train_dataloader, log = log, debug =debug, **trainer_params)

        # Train; the trainer hands back the loss record and the trained model
        loss_record, model = trainer.train(params["epochs"])

        ## Save
        if save:
            checkpoint_path = f"model/{run_name}.ckpt"
            # Create the target directory first: without it torch.save fails
            # only after the whole training run has already been paid for.
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)

        ## Analytics
        # get state & action kernel similarities
        phi_Q = model.get_state_kernel()
        phi_V = model.get_action_kernel()
        k_sim_Q = sim(phi_Q, phi_Q)
        k_sim_V = sim(phi_V, phi_V)

        ## Evaluations
        train_acc, train_confidences = evaluate.accuracy(model, train_dataloader)
        test_acc, test_confidences = evaluate.accuracy(model, test_dataloader)
        sa_acc, sa_confidences, sa_distance_ratios = evaluate.state_transition_consistency(model, env)

        if debug:
            print("State kernel similarities:\n", k_sim_Q)
            print("Action kernel similarities:\n", k_sim_V)

            print("Train obs accuracy/confidence:", train_acc, np.mean(train_confidences))
            print("Test obs accuracy/confidence:", test_acc, np.mean(test_confidences))
            print("State-action accuracy/confidence/distance ratio:", sa_acc, np.mean(sa_confidences), np.mean(sa_distance_ratios))

        ## Visualization
        num_desired_trajectories = params['num_desired_trajectories']
        trajectory_length = params['trajectory_length']
        pca_visualize(model, env, log=log, show=show)
        visualize_loss(loss_record, num_desired_trajectories, trajectory_length, show = show, per_epoch=False)
        visualize_loss(loss_record, num_desired_trajectories, trajectory_length, show = show, per_epoch=True)

        # TODO: log train/validation metrics and model artifacts to wandb

        # Log the results
        if log:
            wandb.summary['train_acc'] = train_acc
            wandb.summary['train_conf'] = np.mean(train_confidences)
            wandb.summary['test_accuracy'] = test_acc
            wandb.summary['test_conf'] = np.mean(test_confidences)
            wandb.summary['sa_accuracy'] = sa_acc
            wandb.summary['sa_conf'] = np.mean(sa_confidences)
            wandb.summary['sa_dist_ratio'] = np.mean(sa_distance_ratios)
    except BaseException:
        # Close the wandb run as failed so an aborted or interrupted trial
        # does not leave a dangling active run behind, then re-raise.
        if log:
            wandb.finish(exit_code=1)
        raise

    if log:
        wandb.finish()

    return trainer

# Set config & hyperparameter pools for wandb

In [6]:
def matches_filter(allowed_values_dict, input_values_dict):
    """Check whether ``input_values_dict`` satisfies every constraint of the filter.

    Parameters:
    - allowed_values_dict (dict): Maps a key to a container of accepted
      values for that key.
    - input_values_dict (dict): The values to test.

    Returns:
    - bool: True only if every key of ``allowed_values_dict`` is present in
      ``input_values_dict`` with a value contained in the accepted container.
      An empty filter therefore accepts any input; keys of
      ``input_values_dict`` that the filter does not mention are ignored.
    """
    return all(
        key in input_values_dict and input_values_dict[key] in accepted
        for key, accepted in allowed_values_dict.items()
    )

# Filter data to avoid reruns: maps a metadata key to the list of values a
# dataset may have for that key. The dataset loop checks each dataset's
# metadata against it with matches_filter() and skips the ones that do not
# match. Left empty, every dataset passes.
data_filter = {
    # Example: only keep datasets whose metadata 'seed' is one of these.
    #'seed' : [65, 70, 75,]
}

In [7]:
# Convention: for each hyperparameter key, set the value to a list if you want to try multiple values
# The pool is handed to generate_combinations_loader() in the sweep loop, and
# each resulting combination is completed with dataset-derived settings there.
param_pool = {  
    # data-related config; can't be automated for now
    # (these are filled in per dataset inside the sweep loop)
    # "n_obs" : env.n_items,
    # "n_states" : env.size,
    # "n_actions" : env.n_actions,
    # "trajectory_length" : trajectory_length,  # number of node visits in a trajectory
    # "num_desired_trajectories" : num_desired_trajectories,
    # Experiments
    "seed": [66, 67, 68, 69],
    # model 
    "state_dim" : [20, 100, 200, 1000],
    #"state_dim" : [50],
    #"state_dim" : [20], # best param.
    #
    "random_feature_dim" : [500], # best param.
    #"random_feature_dim" : [200, 500, 1000],
    "alpha" : 4, 
    #"alpha": [1, 2, 4, 8],
    "memory_bias" : True,
    "batch_size": 64,
    # trainer
    "lr_Q" : 0.1, 
    #"lr_V" : [0.02, 0.04, 0.08, 0.1], 
    "lr_V" : [0.04], 
    "lr_all" : 0.1,
    "lr_M": 1,
    "reg_Q": 0, # l2 reg to prevent manifold overfitting
    "reg_V": 0,
    "reg_M": 0,
    "eps_M": 1e-3,
    "max_iter_M": 1,
    "normalize": False,
    "reset_every" : 1,#[1, 5, 10],
    "update_state_given_obs": [True],
    # training / optimizer 
    "epochs" : 40,
}

# Flags passed to every run_trial() call in the sweep loop:
# debug -> forwarded to the trainer and enables printing of evaluation numbers
# show  -> forwarded to the visualization helpers
# log   -> track each run with wandb
debug = False
show = False
log = True

In [8]:
for data in dataset_loader('./data'):

    # Unpack the pickled bundle: data loaders, generation metadata, environment
    train_dataloader = data['train_dataloader']
    test_dataloader = data['test_dataloader']
    metadata = data['metadata']
    env = data['env']

    # Skip datasets excluded by the rerun filter
    if not matches_filter(data_filter, metadata):
        continue

    # Only tree environments are swept here
    env_config = metadata["env_config"]
    if env_config["env_type"] != "tree":
        continue

    print(metadata)

    # Use the loader to generate combinations one at a time
    param_loader = generate_combinations_loader(param_pool)

    for params in param_loader:

        # Swept hyperparameters completed with the dataset-derived settings
        full_params = {
            **params,
            "n_obs" : env.n_items,
            "n_states" : env.size,
            "n_actions" : env.n_actions,
            'env_type' : env_config['env_type'],
            'env_args': env_config['args'],
            'dseed': metadata['seed'],
            "trajectory_length" : metadata['trajectory_length'],  # number of node visits in a trajectory
            "num_desired_trajectories" : metadata['num_desired_trajectories'],
        }

        # NOTE: run_trial returns the trainer object; the variable name is
        # kept as-is because it is a notebook-level global.
        model = run_trial(
            full_params,
            train_dataloader,
            test_dataloader,
            env,
            debug=debug,
            show = show,
            log=log,
        )


{'n_nodes': 9, 'trajectory_length': 12, 'num_desired_trajectories': 1536, 'batch_size': 64, 'env_config': {'n_items': 9, 'env_type': 'tree', 'trajectory_length': 12, 'num_desired_trajectories': 1536, 'unique': True, 'args': {'levels': 3}}, 'seed': 65}


[34m[1mwandb[0m: Currently logged in as: [33mchyeung[0m ([33mevanjeong[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Epochs:  65%|██████▌   | 26/40 [03:08<01:41,  7.23s/it]

mean loss is nan





0,1
train/epoch_ct,▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
train/loss,▂▁▁▂▁▁▁▁▂▁▁▁▃▁█▃▁▁▁▁▂▂▂▂▂▁▁█▂▁▂▂▂▁█▁▁▃▂▂
train/mloss_p_epoch,▃▃▃▃▃▃▂▁▁▁▁▂▂▁▄▂▁▁▂▂▁▂▃▃▄█
train/mloss_p_traj,▃▃▃▃▃▃▃▃▂▁▁▁▂▁▁▁▂▃▂▂▃▂▁▁▁▂▂▂▁▁▂▃▁▂▃▃▃▅█▆
train/step_ct,▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇█
train/traj_ct,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▅▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇▇████

0,1
sa_accuracy,0.83333
sa_conf,0.43241
sa_dist_ratio,0.53591
test_accuracy,0.40086
test_conf,0.29988
train/epoch_ct,26.0
train/loss,
train/mloss_p_epoch,3.80078
train/mloss_p_traj,
train/step_ct,7128.0


Epochs: 100%|██████████| 40/40 [06:15<00:00,  9.39s/it]


0,1
train/epoch_ct,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/loss,▂▂▁▁▂▂▂▂▂▂▁▂▁▁▁▂▁▁▂██▃▃▃▂█▂▂▂▂▄▂█▂█▂▂▂▂▂
train/mloss_p_epoch,▃▃▃▃▃▃▂▁▁▁▁▁▆▇█▇▆▆▇▆▅▄▄▄▄▄▄▅▅▄▃▂▂▂▂▃▆▄▄▄
train/mloss_p_traj,▃▃▃▃▃▃▂▁▁▁▁▁▇█▇▅▇▅▆▅▄▄▅▄▄▄▄▄▄▄▃▃▃▃▂▂▂▂▂▄
train/step_ct,▁▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/traj_ct,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇█

0,1
sa_accuracy,0.83333
sa_conf,0.49223
sa_dist_ratio,0.5406
test_accuracy,0.47461
test_conf,0.33878
train/epoch_ct,40.0
train/loss,1.57557
train/mloss_p_epoch,3.07292
train/mloss_p_traj,2.78904
train/step_ct,10560.0


Epochs:  82%|████████▎ | 33/40 [07:39<01:37, 13.93s/it]

mean loss is nan





0,1
train/epoch_ct,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇███
train/loss,▁▁▁▂▂▂▁▁▁▁▂▁▁▁▁█▁▂▁▂▂▂▂▂▃▂▁█▂▁▂▂▁▁▁▃▂▁▂▂
train/mloss_p_epoch,▄▄▄▄▄▃▃▂▁▁▂▃▃▃▃▄▅▃▄▅▅▃▃▄▄▃▃▄▆▄▃▇█
train/mloss_p_traj,▃▃▄▃▃▃▂▃▂▂▁▂▃▃▂▃▄▄▄▄▂▃▃▅▄▂▅▄▃▄▂▂▄▅▄▃▄▆█
train/step_ct,▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▅▅▅▆▆▇▇▇▇████
train/traj_ct,▁▁▁▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇███

0,1
sa_accuracy,0.75
sa_conf,0.50354
sa_dist_ratio,0.51583
test_accuracy,0.50041
test_conf,0.357
train/epoch_ct,33.0
train/loss,
train/mloss_p_epoch,3.88083
train/mloss_p_traj,
train/step_ct,8976.0


Epochs: 100%|██████████| 40/40 [12:40<00:00, 19.01s/it]


0,1
train/epoch_ct,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/loss,▂▄▂▂▂▂▁▁▁▁▂▂▂▂▂▁▂█▂▂█▂▂▁▂▁▂▁▁█▂▂▂▂▂▃▃▃▄▄
train/mloss_p_epoch,▃▃▃▃▃▃▃▂▁▁▁▂▂▂▄▄▂▃▂▂▂▂▂▂▂▂▂▁▁▂▂▃▄▅▄▄████
train/mloss_p_traj,▃▃▃▃▃▃▃▁▂▁▁▁▃▁▁▃▃▃▂▂▂▂▃▂▂▂▁▂▁▁▂▂▃▃▃█▇▇█▇
train/step_ct,▁▁▂▂▂▂▂▂▂▃▃▄▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
train/traj_ct,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███

0,1
sa_accuracy,0.91667
sa_conf,0.69054
sa_dist_ratio,0.41789
test_accuracy,0.62861
test_conf,0.50147
train/epoch_ct,40.0
train/loss,3.31783
train/mloss_p_epoch,4.23366
train/mloss_p_traj,4.3083
train/step_ct,10560.0


Epochs: 100%|██████████| 40/40 [05:10<00:00,  7.77s/it]


0,1
train/epoch_ct,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/loss,▁▂▂▂▁▁█▁▁▁▂▁▂▂▁▂▁▁▂▁▁▁▂▃▂▂▂▂▂▂▂▃▂▂█▂▂▂▂▂
train/mloss_p_epoch,▅▅▅▅▅▄▃▂▂▂▁▃▃▄▃▂▃▃▅▃▃▄▆▅▅█▅▅▄▅▇▅▄▅▄█▅▆█▆
train/mloss_p_traj,▄▄▄▄▄▄▄▄▂▂▁▁▂▂▃▂▂▄▂▂▄▅▅▄▄▆▅▅▅▃▃▆▄▅█▆▅▄▆▃
train/step_ct,▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇███
train/traj_ct,▁▁▁▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇███

0,1
sa_accuracy,0.75
sa_conf,0.53458
sa_dist_ratio,0.49745
test_accuracy,0.53504
test_conf,0.37883
train/epoch_ct,40.0
train/loss,2.52692
train/mloss_p_epoch,3.04434
train/mloss_p_traj,3.40904
train/step_ct,10560.0


Epochs:  48%|████▊     | 19/40 [03:07<03:26,  9.85s/it]

mean loss is nan





0,1
train/epoch_ct,▁▁▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇██
train/loss,▂▂▂▄▁█▂▂▁▁▂▂▂▂▂▁▁▁▁▁▂▂▁▁▃▁▁▂▁██▂▂▂▃▂▃▂▂
train/mloss_p_epoch,▄▄▄▄▄▃▃▂▁▁▁▂▂▃▄▃▄▆█
train/mloss_p_traj,▄▃▃▄▃▄▃▃▃▄▃▃▃▃▃▁▁▁▁▁▂▁▁▂▂▂▃▄▄▄▄▃▂▃▅▇▇█
train/step_ct,▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇████
train/traj_ct,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███

0,1
sa_accuracy,0.83333
sa_conf,0.54997
sa_dist_ratio,0.51133
test_accuracy,0.50758
test_conf,0.37375
train/epoch_ct,19.0
train/loss,
train/mloss_p_epoch,3.85367
train/mloss_p_traj,
train/step_ct,5280.0


Epochs:  35%|███▌      | 14/40 [02:45<05:07, 11.81s/it]

mean loss is nan





0,1
train/epoch_ct,▁▂▂▃▃▄▄▅▅▆▆▇▇█
train/loss,▂▁▂▂▂▁▄▄▂▄▄▁▄▄▂▂▂▂▂█▂▂▂▂▂█▂▂▂▂▂▁▁▂▁▁▂▁▂
train/mloss_p_epoch,▃▃▃▃▃▃▃▂▁▁▁▁▃█
train/mloss_p_traj,▄▄▄▄▄▄▄▄▄▄▄▄▄▃▄▃▃▃▃▂▂▂▂▁▂▁▁▁▁▂▁▂▃▂▂▇▇███
train/step_ct,▁▁▁▁▁▁▂▂▂▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇██
train/traj_ct,▁▁▂▂▂▃▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████

0,1
sa_accuracy,0.83333
sa_conf,0.49711
sa_dist_ratio,0.51424
test_accuracy,0.48745
test_conf,0.34281
train/epoch_ct,14.0
train/loss,
train/mloss_p_epoch,4.01261
train/mloss_p_traj,
train/step_ct,3960.0


Epochs: 100%|██████████| 40/40 [13:59<00:00, 20.99s/it]


0,1
train/epoch_ct,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/loss,▂▂▃▁▁▁▁▂▁▁▁▁█▁▂▃▃▃█▂▃▂▂▂▂▃▂▂▂█▂▂▃▃▂▁▁▁▃▂
train/mloss_p_epoch,▃▃▃▃▃▃▂▂▁▁▁▄▃▁▁▂▄█▇▇▆▅▅▅▅▃▄▃▄▄▃▃▃▄▆▆▅▅▄▅
train/mloss_p_traj,▃▃▃▃▃▂▂▂▂▁▁▁▂▂▂██▇▆▆▆▆▅▅▄▄▄▃▃▄▄▃▃▃▃▅▅▅▅▄
train/step_ct,▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇██
train/traj_ct,▁▁▁▁▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████

0,1
sa_accuracy,0.83333
sa_conf,0.53538
sa_dist_ratio,0.49367
test_accuracy,0.53664
test_conf,0.39754
train/epoch_ct,40.0
train/loss,2.54053
train/mloss_p_epoch,3.4426
train/mloss_p_traj,3.93535
train/step_ct,10560.0


Epochs:  38%|███▊      | 15/40 [02:21<03:55,  9.42s/it]

mean loss is nan





0,1
train/epoch_ct,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train/loss,▂█▂▁▂▂▁▂▄▄▂▂▂▂▂█▂▄▁▂▂▁▁▁▁▂▂▁▂█▁█▁▁▁█▁▂▃▂
train/mloss_p_epoch,▅▅▅▅▄▄▃▁▁▁▂▂▂▁█
train/mloss_p_traj,▄▄▄▄▃▄▄▃▄▄▄▄▄▃▃▃▃▂▂▂▁▁▁▁▁▁▁▂▁▂▂▁▁▂▂▂▁▁▃█
train/step_ct,▁▁▁▁▁▁▂▂▂▂▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
train/traj_ct,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇▇▇▇███

0,1
sa_accuracy,0.83333
sa_conf,0.5219
sa_dist_ratio,0.48853
test_accuracy,0.55996
test_conf,0.38476
train/epoch_ct,15.0
train/loss,
train/mloss_p_epoch,3.52664
train/mloss_p_traj,
train/step_ct,4224.0


Epochs: 100%|██████████| 40/40 [06:42<00:00, 10.07s/it]


0,1
train/epoch_ct,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/loss,▁▁█▁▂▁▁▁▁▁▁▁▂▂▂█▁▁▂▂▂▂▁▁▂▁▁▂▁▂██▁▂▁▂▁▂▁▁
train/mloss_p_epoch,▅▅▅▅▅▄▃▂▁▁▃▆▆▆▅▄▂▃▅▅▆▅▄▅▅▆▅▇▇▅▅▇▇▇▆▆█▆▅▄
train/mloss_p_traj,▄▄▅▄▄▁▁▄▅▅▁▂▄▅▆▄▆▄▃▅▅▄▄▃▅▆▂▆▅▅▄▇▅▄▃█▇▅▃▄
train/step_ct,▁▁▁▁▁▁▁▂▂▃▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇██
train/traj_ct,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███

0,1
sa_accuracy,0.75
sa_conf,0.36701
sa_dist_ratio,0.58411
test_accuracy,0.35529
test_conf,0.26704
train/epoch_ct,40.0
train/loss,1.60282
train/mloss_p_epoch,2.83261
train/mloss_p_traj,2.80933
train/step_ct,10560.0


Epochs: 100%|██████████| 40/40 [17:45<00:00, 26.63s/it]


0,1
train/epoch_ct,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/loss,█▂▂▂▁▂▂▄▁▁▁▃▂▁▂▁▃▁▁█▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂█▂▁▂▂
train/mloss_p_epoch,▄▄▄▄▄▄▃▁▁▁▂▂▂▂▂▂▂▂▃▃▃▄▃▃▆▇▆▅▅▅▆▅▆██▆▆█▆▅
train/mloss_p_traj,▄▄▄▄▃▃▃▁▁▁▁▂▂▂▃▂▂▂▂▂▂▁▄▃▃▃▄▃▁▁█▄▄▅▆▇▅▅█▅
train/step_ct,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇███
train/traj_ct,▁▁▁▁▁▂▃▃▃▃▃▃▃▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇█████

0,1
sa_accuracy,0.83333
sa_conf,0.48483
sa_dist_ratio,0.54131
test_accuracy,0.52012
test_conf,0.36593
train/epoch_ct,40.0
train/loss,2.09536
train/mloss_p_epoch,3.16813
train/mloss_p_traj,3.49797
train/step_ct,10560.0


Epochs:  90%|█████████ | 36/40 [16:04<01:47, 26.80s/it]

mean loss is nan





0,1
train/epoch_ct,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
train/loss,▁▁▁▂▂▁▁▁▁▂▁▆█▆█▆▇▆▇▆▅▆▄▅▅▆▄▇▄▃▆▇▄▃▄▂▄▃▃▅
train/mloss_p_epoch,▂▂▂▂▂▂▂▁▁▁▁▁▁▃▂▂▅██████▇▇▇▆▆▆▆▅▃▃▃▄▄
train/mloss_p_traj,▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▂▁▃▆▇▇███▇▇█▇▇▆▇▅▆▆▅▃▃▃▃
train/step_ct,▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▇▇▇▇▇▇▇█
train/traj_ct,▁▁▁▁▁▂▂▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇███

0,1
sa_accuracy,0.83333
sa_conf,0.48511
sa_dist_ratio,0.51975
test_accuracy,0.50183
test_conf,0.34617
train/epoch_ct,36.0
train/loss,
train/mloss_p_epoch,3.84353
train/mloss_p_traj,
train/step_ct,9768.0


Epochs: 100%|██████████| 40/40 [08:33<00:00, 12.84s/it]


0,1
train/epoch_ct,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/loss,▁▁▁█▂▁▁▁▁▁▁▂▂▁▂▁▂▂▂▂▁▂▂▂█▁▂▁█▂▁▁▁▂▁▁▁▃▂▁
train/mloss_p_epoch,▄▄▄▄▄▄▄▃▂▂▂▂▂▃▄██▆▅▆▅▄▄▄▅▆▆▅▅▄▃▂▃▄▂▁▂▃▃▄
train/mloss_p_traj,▄▄▄▄▄▂▂▁▂▄▂▂▃▅▅█▆▅▅▄▄▃▆▅▆▅▅▅▄▄▁▄▃▃▂▂▄▃▄▃
train/step_ct,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇████
train/traj_ct,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇█

0,1
sa_accuracy,0.75
sa_conf,0.63651
sa_dist_ratio,0.45613
test_accuracy,0.63364
test_conf,0.50021
train/epoch_ct,40.0
train/loss,1.37003
train/mloss_p_epoch,2.88226
train/mloss_p_traj,2.49894
train/step_ct,10560.0


Epochs:  62%|██████▎   | 25/40 [07:34<04:32, 18.19s/it]

mean loss is nan





0,1
train/epoch_ct,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
train/loss,▂▁▁▁▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁█▂▁▂▁▂▂▂▂▂▁▁▁▂▂▂▃▂▁
train/mloss_p_epoch,▃▃▃▃▃▃▃▁▁▁▁▁▃▃▃▄▄▅▂▃▃▆▅▆█
train/mloss_p_traj,▄▄▄▄▄▄▄▄▄▄▂▂▂▁▂▂▁▂▁▁▄▄▄▂▅▅▃▂▃▃▂█▅▆▆▆▅▄▆
train/step_ct,▁▁▁▁▁▂▂▂▂▂▂▂▂▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇███
train/traj_ct,▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇█

0,1
sa_accuracy,0.83333
sa_conf,0.46775
sa_dist_ratio,0.5648
test_accuracy,0.46609
test_conf,0.32494
train/epoch_ct,25.0
train/loss,
train/mloss_p_epoch,3.86261
train/mloss_p_traj,
train/step_ct,6864.0


Epochs: 100%|██████████| 40/40 [16:44<00:00, 25.11s/it] 


0,1
train/epoch_ct,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/loss,▂▁▂▁▁▂▁▂▁▁▁▁▁█▁▂▁▂▂▂▂▃▂▁▁█▁▁▂▁▁█▂▃▃▃▂▃▃▃
train/mloss_p_epoch,▃▃▃▃▃▃▃▂▁▁▁▂▂▁▂▂▃▄▃▃▄▇▃▂▂▂▂▂▂▂▆▇█▇▇▇▇▇▇▇
train/mloss_p_traj,▄▄▄▄▄▄▃▁▂▁▁▃▂▃▂▂▂▂▂▄▅▂▂▄▃▂▂▁▃▃▂▂▃▂▁▆██▇▇
train/step_ct,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▇▇███
train/traj_ct,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▆▆▆▆▇▇▇▇▇▇▇██

0,1
sa_accuracy,0.83333
sa_conf,0.5494
sa_dist_ratio,0.48444
test_accuracy,0.56209
test_conf,0.39973
train/epoch_ct,40.0
train/loss,2.55439
train/mloss_p_epoch,3.70156
train/mloss_p_traj,3.73685
train/step_ct,10560.0


Epochs:  35%|███▌      | 14/40 [05:55<11:00, 25.41s/it]

mean loss is nan





0,1
train/epoch_ct,▁▂▂▃▃▄▄▅▅▆▆▇▇█
train/loss,▂▂▁▂▂▁▁▂▁▂▂▁▄▄▂▂▁▃▁▁▁▁▁▁▁▂▁▁▂▂▂▂▂▂▃▃▂▂▆█
train/mloss_p_epoch,▃▃▃▃▃▃▂▁▁▁▁▂▆█
train/mloss_p_traj,▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▁▁▁▁▁▁▁▆██▇▆▅▇▇
train/step_ct,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇██████
train/traj_ct,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇██

0,1
sa_accuracy,0.83333
sa_conf,0.51089
sa_dist_ratio,0.51442
test_accuracy,0.53883
test_conf,0.3695
train/epoch_ct,14.0
train/loss,
train/mloss_p_epoch,3.9437
train/mloss_p_traj,
train/step_ct,3960.0
