In [None]:
import os
# Move to the parent of the current working directory (presumably the repository
# root when the kernel starts in the notebook's folder) so that './models/...' and
# the `comm_agents` package resolve. NOTE(review): not idempotent -- re-running this
# cell moves up one more directory each time.
os.chdir('..')
# Automatically reload edited project modules before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from comm_agents.data.data_handler import RefExpDataset
from comm_agents.models.model_single_enc_1 import SingleEncModel
import pandas as pd
import plotly.graph_objects as go
from ipywidgets import interact

In [None]:
# Instantiate the project dataset; later cells read its `observations`, `questions`,
# `opt_answers` and `hidden_states` tensors.
dataset = RefExpDataset()

In [None]:
# Shell listing of the checkpoint directory (relative to the current working directory).
!ls -la models/

In [None]:
MODEL_PATH_PRE = './models/single_enc_model_pretrain_2020-10-31.ckpt'
MODEL_PATH_POST = './models/single_enc_model_2020-10-31.ckpt'

In [None]:
model_pre, model_post = [SingleEncModel.load_from_checkpoint(p) for p in [MODEL_PATH_PRE, MODEL_PATH_POST]]

In [None]:
model_pre.selection_bias, model_post.selection_bias

In [None]:
def get_a_ls_sb(model, obs, qs):
    """Run the model stage by stage; return answers, latent space and selection bias.

    The pass mirrors the model's three stages: ``encode`` the observations into
    the latent space, ``filter`` that latent space with ``model.selection_bias``
    into four tensors, and ``decode`` those together with the questions.

    Parameters
    ----------
    model : SingleEncModel
        Object providing ``encode``, ``filter``, ``decode`` and ``selection_bias``.
    obs : torch.Tensor
        Observations passed to ``model.encode``.
    qs : torch.Tensor
        Questions passed to ``model.decode`` (presumably one per observation --
        TODO confirm).

    Returns
    -------
    tuple
        ``(answers, lat_space, model.selection_bias)``. ``answers`` and
        ``lat_space`` are computed without autograd tracking.
    """
    # Inference only: don't build an autograd graph (callers detach the results).
    with torch.no_grad():
        lat_space = model.encode(obs)
        s0, s1, s2, s3 = model.filter(lat_space, model.selection_bias)
        answers = model.decode(s0, s1, s2, s3, qs)
    return answers, lat_space, model.selection_bias

In [None]:
# Wrap the dataset tensors in labelled DataFrames for plotting; the column lists
# label the tensor columns positionally.
df_opt_answers = pd.DataFrame(dataset.opt_answers.detach().numpy(),
                              columns=['alpha1_star', 'alpha2_star', 'phi1_star', 'phi2_star'])

# Derived column built with `.assign` (not in-place mutation) so re-running the cell
# is idempotent. NOTE: despite its name, 'q0_t_q1' holds the product q1 * q2; the
# name is kept because the latent-space widget's dropdown refers to it.
df_hidden_states = (
    pd.DataFrame(dataset.hidden_states.detach().numpy(),
                 columns=['m1', 'm2', 'q1', 'q2'])
    .assign(q0_t_q1=lambda frame: frame.q1 * frame.q2)
)

# Column labels must be unique: a repeated label makes df_question['m_ref1'] return
# a DataFrame instead of a Series. The fourth column is presumably the second
# reference's value -- hence the pairing ('m_ref1', 'v_ref1'), ('m_ref2', 'v_ref2').
df_question = pd.DataFrame(dataset.questions.detach().numpy(),
                           columns=['m_ref1', 'v_ref1', 'm_ref2', 'v_ref2'])
    

In [None]:
def create_answers_plot(answer, opt_answer, samples=1000, pretrain=True):
    """Scatter predicted answers against optimal answers for the first ``samples`` rows.

    Parameters
    ----------
    answer : str
        Column of the model output to plot on the y axis
        ('alpha1', 'alpha2', 'phi1' or 'phi2').
    opt_answer : str
        Column of ``df_opt_answers`` to plot on the x axis.
    samples : int, default 1000
        Number of leading dataset rows to run through the model and plot.
        Values larger than the dataset simply use every row.
    pretrain : bool, default True
        Plot ``model_pre`` if True, otherwise ``model_post``.
    """
    model = model_pre if pretrain else model_post
    # Slice the inputs by `samples` (not a fixed row count) so the predictions and
    # the optimal answers plotted against them cover exactly the same rows.
    answers, _, _ = get_a_ls_sb(model,
                                dataset.observations[:samples],
                                dataset.questions[:samples])
    df_answers = pd.DataFrame(answers.detach().numpy(),
                              columns=['alpha1', 'alpha2', 'phi1', 'phi2'])
    # Slicing stops at the dataset length, so align x on the rows actually predicted.
    n_rows = len(df_answers)

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        y=df_answers[answer],
        x=df_opt_answers[opt_answer].iloc[:n_rows],
        mode='markers',
        opacity=.5))
    fig.update_layout(title_text='True labels vs prediction')
    fig.update_xaxes(title_text='Optimal answer')
    fig.update_yaxes(title_text='Predicted answer')
    fig.show()


# Trailing ';' suppresses the repr of the function object `interact` returns.
interact(create_answers_plot, answer=['alpha1', 'alpha2', 'phi1', 'phi2'],
         opt_answer=['alpha1_star', 'alpha2_star', 'phi1_star', 'phi2_star'],
         samples=[100, 1000, 10000]);

In [None]:
def create_lat_space_plot(lat_neuron, hidden_state, samples=1000, pretrain=True):
    """Scatter one latent-neuron activation against one hidden-state column.

    Parameters
    ----------
    lat_neuron : str
        Latent column to plot on the y axis ('l0', 'l1' or 'l2').
    hidden_state : str
        Column of ``df_hidden_states`` to plot on the x axis.
    samples : int, default 1000
        Number of leading dataset rows to encode and plot.
        Values larger than the dataset simply use every row.
    pretrain : bool, default True
        Plot ``model_pre`` if True, otherwise ``model_post``.
    """
    model = model_pre if pretrain else model_post
    # Get the latent space from the encoder explicitly (the same call `get_a_ls_sb`
    # uses), for exactly the rows that are plotted; plotting needs no gradients.
    with torch.no_grad():
        lat_spaces = model.encode(dataset.observations[:samples])
    df_lat_space = pd.DataFrame(lat_spaces.detach().numpy(),
                                columns=['l0', 'l1', 'l2'])
    # Slicing stops at the dataset length, so align x on the rows actually encoded.
    n_rows = len(df_lat_space)

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        y=df_lat_space[lat_neuron],
        x=df_hidden_states[hidden_state].iloc[:n_rows],
        mode='markers',
        opacity=.5))
    fig.update_layout(title_text='Latent neuron activation vs. hidden states')
    fig.update_xaxes(title_text=f'Hiddenstate {hidden_state}')
    fig.update_yaxes(title_text=f'Latent neuron activation {lat_neuron}')
    fig.show()


# Trailing ';' suppresses the repr of the function object `interact` returns.
interact(create_lat_space_plot, lat_neuron=['l0', 'l1', 'l2'],
         hidden_state=['m1', 'm2', 'q0_t_q1', 'q1', 'q2'],
         samples=[100, 1000, 10000]);