In [None]:
import os
os.chdir('..')
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
from ipywidgets import interact

from comm_agents.data.data_handler import RefExpDataset
from comm_agents.models.model_single_enc import SingleEncModel

In [None]:
# Load the reference-expression dataset with RefExpDataset's default arguments;
# later cells read its observations, questions, opt_answers and hidden_states.
dataset = RefExpDataset()

In [None]:
# Checkpoint locations: one saved after pre-training and one presumably saved
# after the final training run, both stamped with the same date.
MODEL_DIR = './models'
CHECKPOINT_DATE = '2020-10-26'
MODEL_PATH_PRE = f'{MODEL_DIR}/single_enc_model_pretrain_{CHECKPOINT_DATE}.pt'
# NOTE: no '_' between 'model' and the date here -- kept exactly as the original path.
MODEL_PATH_POST = f'{MODEL_DIR}/single_enc_model{CHECKPOINT_DATE}.pt'

In [None]:
device = torch.device('cpu')

# Architecture hyper-parameters; these must match the ones the checkpoints were
# saved with, otherwise load_state_dict fails on shape mismatches.
# `observantion_size` (sic) is spelled as in the constructor call this replaces.
model_kwargs = dict(observantion_size=40,
                    lat_space_size=3,
                    question_size=2,
                    enc_num_hidden_layers=10,
                    enc_hidden_size=100,
                    dec_num_hidden_layers=10,
                    dec_hidden_size=100,
                    num_decoding_agents=4)


def _load_model(path):
    """Build a SingleEncModel on `device`, load the state dict at `path`, return it in eval mode."""
    model = SingleEncModel(**model_kwargs, device=device)
    # map_location follows the configured device instead of hard-coding the CPU.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(path, map_location=device))
    # The models are only used for inference/plotting below; eval() switches off
    # training-mode behaviour (e.g. dropout) if the model has any.
    model.eval()
    return model


models = [_load_model(path) for path in (MODEL_PATH_PRE, MODEL_PATH_POST)]
model_pre, model_post = models

In [None]:
def _prediction_frames(model, data, samples):
    """Run `model` on the first `samples` examples of `data` and tabulate the answers.

    Parameters
    ----------
    model : callable
        Called as ``model(observations, questions)``; must return a 3-tuple whose
        first element holds the predicted answers (4 columns).
    data : object
        Exposes ``observations``, ``questions`` and ``opt_answers`` tensors
        (e.g. ``RefExpDataset``).
    samples : int
        Number of leading examples to use (fewer if the dataset is smaller).

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        Predicted answers (alpha1, alpha2, phi1, phi2) and optimal answers
        (alpha1_star, ...), computed on the same examples so rows line up.
    """
    observations = data.observations[:samples]
    questions = data.questions[:samples]
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        answers, _lat_spaces, _selection_biases = model(observations, questions)
    df_answers = pd.DataFrame(answers.detach().cpu().numpy(),
                              columns=['alpha1', 'alpha2', 'phi1', 'phi2'])
    df_opt_answers = pd.DataFrame(data.opt_answers[:samples].detach().cpu().numpy(),
                                  columns=['alpha1_star', 'alpha2_star', 'phi1_star', 'phi2_star'])
    return df_answers, df_opt_answers


def create_scatter_plot(answer, opt_answer, samples=1000, pretrain=True):
    """Scatter one predicted answer component against one optimal answer component.

    Parameters
    ----------
    answer : str
        Predicted column for the y axis: 'alpha1', 'alpha2', 'phi1' or 'phi2'.
    opt_answer : str
        Optimal column for the x axis: 'alpha1_star', 'alpha2_star', 'phi1_star'
        or 'phi2_star'.
    samples : int
        Number of leading dataset examples to plot (both axes use the same ones).
    pretrain : bool
        Plot the pre-training checkpoint (`model_pre`) if True, else `model_post`.
    """
    model = model_pre if pretrain else model_post
    df_answers, df_opt_answers = _prediction_frames(model, dataset, samples)

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        y=df_answers[answer],
        x=df_opt_answers[opt_answer],
        mode='markers',
        opacity=.5))
    fig.update_layout(title_text='True labels vs prediction')
    fig.update_xaxes(title_text='Optimal answer')
    fig.update_yaxes(title_text='Predicted answer')
    fig.show()
# Interactive explorer: dropdowns choose the predicted / optimal answer columns and
# the number of samples; the bool `pretrain` default is rendered as a checkbox.
# The trailing ';' hides the repr of the function that `interact` returns.
interact(create_scatter_plot,
         answer=['alpha1', 'alpha2', 'phi1', 'phi2'],
         opt_answer=['alpha1_star', 'alpha2_star', 'phi1_star', 'phi2_star'],
         samples=[100, 1000, 10000]);

In [None]:
# Inspect the `selection_bias` attribute of the post-training model
# (defined by SingleEncModel, which is not visible in this notebook).
model_post.selection_bias

In [None]:
# Does the latent space encode the hidden-state product q1*q2?  Compare it with the
# product of the first two latent dimensions over the first `n_points` examples.
# NOTE(review): assumes the post-training model is the one of interest -- confirm.
n_points = 1000
with torch.no_grad():
    _, lat_spaces_post, _ = model_post(dataset.observations[:n_points],
                                       dataset.questions[:n_points])
df_hidden_states = pd.DataFrame(dataset.hidden_states[:n_points].detach().cpu().numpy(),
                                columns=['m1', 'm2', 'q1', 'q2'])
lat_np = lat_spaces_post.detach().cpu().numpy()
# Name latent columns l1..lN from the actual width (assumes a 2-D (examples, dims) output).
df_lat_space = pd.DataFrame(lat_np,
                            columns=[f'l{i}' for i in range(1, lat_np.shape[1] + 1)])
px.scatter(x=df_hidden_states.q1 * df_hidden_states.q2,
           y=df_lat_space.l1 * df_lat_space.l2,
           labels={'x': 'q1 * q2 (hidden state)', 'y': 'l1 * l2 (latent space)'},
           title='Hidden-state product vs latent-space product (post-training model)')

In [None]:
# Ground-truth hidden states of the dataset; show only the first rows to keep the output small.
pd.DataFrame(dataset.hidden_states.detach().cpu().numpy(),
             columns=['m1', 'm2', 'q1', 'q2']).head()