In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import hypertools as hyp
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pickle
import itertools
from collections import defaultdict

from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import AgglomerativeClustering, KMeans
from tommas.viz.embedding_plot import extract_embeddings_and_labels
from sklearn.metrics import silhouette_samples, silhouette_score

from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, GPT2ForSequenceClassification, PreTrainedTokenizerFast
from tokenizers import Tokenizer, models, normalizers, pre_tokenizers, processors
import ecco
from ecco import LM, pack_tokenizer_config
import ecco.analysis as analysis

import torch
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import seed_everything

from tommas.agent_modellers.iterative_action_tommas_transformer import IterativeActionTOMMASTransformer
from tommas.agents.create_iterative_action_agents import get_random_iterative_action_agent, RandomStrategySampler
from tommas.data.iterative_action_dataset import IterativeActionTrajectory, play_episode
from tommas.data.gridworld_transforms import IterativeActionFullPastCurrentSplit
from tommas.data.datamodule_factory import make_datamodule
from experiments.experiment_base import load_modeller

from tommas.viz.embedding_responsibility import plot_embedding_responsibility
from tommas.viz.embedding_plot import plot_embeddings, hyp_plot

from tommas.analysis.representation_metrics import lstm_ttx_boxplot_comparison, \
    get_representation_similarity_score, lstm_ttx_agent_param_boxplot_comparison, calculate_df_corr, \
    load_representation_data, calculate_and_combine_model_corr, calculate_model_param_corr, \
    calculate_model_specific_corr, load_all_multi_strat_models, create_models_output_df


In [16]:
models_dict = load_all_multi_strat_models()

In [17]:
model_names = list(models_dict.keys())
for name in model_names:
    if "lstm" not in name or "seed4" in name or "seed5" in name:
        del models_dict[name]


In [18]:
models_dict.keys()

dict_keys(['lstm[112,1]_lstm[64,2]_seed1', 'lstm[112,1]_lstm[64,2]_seed2', 'lstm[112,1]_lstm[64,2]_seed3', 'lstm[128,1]_lstm[64,2]_seed1', 'lstm[128,1]_lstm[64,2]_seed2', 'lstm[128,1]_lstm[64,2]_seed3', 'lstm[160,1]_lstm[64,2]_seed1', 'lstm[160,1]_lstm[64,2]_seed2', 'lstm[160,1]_lstm[64,2]_seed3', 'lstm[200,1]_lstm[64,2]_seed1', 'lstm[200,1]_lstm[64,2]_seed2', 'lstm[200,1]_lstm[64,2]_seed3', 'lstm[256,1]_lstm[64,2]_seed1', 'lstm[256,1]_lstm[64,2]_seed2', 'lstm[256,1]_lstm[64,2]_seed3', 'lstm[48,1]_lstm[64,2]_seed1', 'lstm[48,1]_lstm[64,2]_seed2', 'lstm[48,1]_lstm[64,2]_seed3', 'lstm[512,1]_lstm[64,2]_seed1', 'lstm[512,1]_lstm[64,2]_seed2', 'lstm[512,1]_lstm[64,2]_seed3', 'lstm[64,1]_lstm[64,2]_seed1', 'lstm[64,1]_lstm[64,2]_seed2', 'lstm[64,1]_lstm[64,2]_seed3', 'lstm[80,1]_lstm[64,2]_seed1', 'lstm[80,1]_lstm[64,2]_seed2', 'lstm[80,1]_lstm[64,2]_seed3', 'lstm[96,1]_lstm[64,2]_seed1', 'lstm[96,1]_lstm[64,2]_seed2', 'lstm[96,1]_lstm[64,2]_seed3'])

In [19]:
output_df = create_models_output_df(models_dict, num_agents_per_cluster=20)

In [46]:
create_models_output_df(models_dict, num_agents_per_cluster=20)

tensor([[-0.0521,  0.0016, -0.0079,  ...,  0.0192, -0.0918,  0.0786],
        [-0.0996,  0.0048,  0.0170,  ..., -0.0125, -0.0007,  0.1565],
        [ 0.0150, -0.0966,  0.0177,  ..., -0.0695, -0.0866,  0.1563],
        ...,
        [-0.1240, -0.1866,  0.0205,  ..., -0.0512,  0.0218, -0.6142],
        [ 0.0043, -0.2675, -0.1548,  ..., -0.1059, -0.0417, -0.5873],
        [ 0.1471, -0.2106, -0.1412,  ..., -0.1269,  0.0498, -0.6647]],
       device='cuda:0')


ValueError: 

In [20]:
p0_df = output_df[output_df["n_past"] == 0]
p1_df = output_df[output_df["n_past"] == 1]
p5_df = output_df[output_df["n_past"] == 5]

In [21]:
p0_df.groupby(["model_name"])[[("loss"), "acc"]].agg(["describe"])[[("loss", "describe", "mean"), ("acc", "describe", "mean")]]

Unnamed: 0_level_0,loss,acc
Unnamed: 0_level_1,describe,describe
Unnamed: 0_level_2,mean,mean
model_name,Unnamed: 1_level_3,Unnamed: 2_level_3
"lstm[112,1]_lstm[64,2]_seed1",0.381665,0.815417
"lstm[112,1]_lstm[64,2]_seed2",0.386633,0.811629
"lstm[112,1]_lstm[64,2]_seed3",0.37312,0.818258
"lstm[128,1]_lstm[64,2]_seed1",0.373407,0.799697
"lstm[128,1]_lstm[64,2]_seed2",0.386367,0.810379
"lstm[128,1]_lstm[64,2]_seed3",0.3808,0.816061
"lstm[160,1]_lstm[64,2]_seed1",0.38399,0.811364
"lstm[160,1]_lstm[64,2]_seed2",0.392729,0.807652
"lstm[160,1]_lstm[64,2]_seed3",0.379313,0.814545
"lstm[200,1]_lstm[64,2]_seed1",0.386602,0.811174


In [22]:
p5_df.groupby(["model_name"])[[("loss"), "acc"]].agg(["describe"])[[("loss", "describe", "mean"), ("acc", "describe", "mean")]]

Unnamed: 0_level_0,loss,acc
Unnamed: 0_level_1,describe,describe
Unnamed: 0_level_2,mean,mean
model_name,Unnamed: 1_level_3,Unnamed: 2_level_3
"lstm[112,1]_lstm[64,2]_seed1",12.939002,0.421061
"lstm[112,1]_lstm[64,2]_seed2",11.690869,0.426439
"lstm[112,1]_lstm[64,2]_seed3",6.538598,0.372652
"lstm[128,1]_lstm[64,2]_seed1",3.286201,0.481818
"lstm[128,1]_lstm[64,2]_seed2",11.443441,0.373409
"lstm[128,1]_lstm[64,2]_seed3",12.777044,0.328068
"lstm[160,1]_lstm[64,2]_seed1",12.29281,0.398674
"lstm[160,1]_lstm[64,2]_seed2",11.586091,0.352614
"lstm[160,1]_lstm[64,2]_seed3",8.120497,0.469924
"lstm[200,1]_lstm[64,2]_seed1",9.524651,0.383106
