In [1]:
import os
os.chdir('..')
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from comm_agents.data.data_handler import RefExpDataset
from comm_agents.models.model_multi_enc import MultiEncModel
import pandas as pd
import plotly.graph_objects as go
from ipywidgets import interact

In [3]:
dataset = RefExpDataset()

In [4]:
MODEL_PATH_PRE = './models/multi_enc_model_pre2020-10-30.ckpt'
MODEL_PATH_POST = './models/multi_enc_model_post2020-10-30-v0.ckpt'

In [5]:
model_pre, model_post = [MultiEncModel.load_from_checkpoint(p) for p in [MODEL_PATH_PRE, MODEL_PATH_POST]]

In [6]:
model_pre.selection_bias, model_post.selection_bias

(Parameter containing:
 tensor([[-17.2264, -15.9174, -12.9825, -12.6447],
         [-14.4348, -16.7241, -11.4729, -11.5596],
         [-16.1852, -16.2504, -16.1708, -16.2905],
         [-15.8370, -16.1113, -16.4024, -15.9794]], requires_grad=True),
 Parameter containing:
 tensor([[-3.0471, 16.7799, 16.7432, 16.6688],
         [16.9789, -2.5843, 16.9288, 16.9864],
         [-3.9637, -3.4496, -2.6675, -4.8138],
         [-3.9446, -3.4891, -3.1365, -5.1670]], requires_grad=True))

In [7]:
def get_a_ls_sb(model, obs, qs):
    lat_space = model.encode(obs)

    # filter
    s0, s1, s2, s3 = model.filter(lat_space, model.selection_bias)

    # decode
    answers = model.decode(s0, s1, s2, s3, qs)
    
    return answers, lat_space, model.selection_bias

In [8]:
df_opt_answers = pd.DataFrame(dataset.opt_answers.detach().numpy(),
                              columns=['alpha1_star', 'alpha2_star', 'phi1_star', 'phi2_star'])
df_hidden_states = pd.DataFrame(dataset.hidden_states.detach().numpy(),
                               columns=['m1', 'm2', 'q1', 'q2'])
df_question = pd.DataFrame(dataset.questions.detach().numpy(),
                       columns=['m_ref1', 'v_ref1', 'm_ref1', 'v_ref2'])
    

In [12]:
def create_answers_plot(answer, opt_answer, samples=1000, pretrain=True):
    
    model =  model_pre if pretrain else model_post
    answers, lat_spaces, selection_biases = get_a_ls_sb(model, dataset.observations[0:1000],
                                                  dataset.questions[0:1000])
    df_answers = pd.DataFrame(answers.detach().numpy(), columns=['alpha1', 'alpha2', 'phi1', 'phi2'])
    df_lat_space = pd.DataFrame(lat_spaces.detach().numpy(),
                                   columns=['l1', 'l2', 'l3', 'l4'])
    
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        y=df_answers[answer][0:samples],
        x=df_opt_answers[opt_answer][0:samples],
        mode='markers',
        opacity=.5))
    fig.update_layout(title_text='True labels vs prediction')
    fig.update_xaxes(title_text='Optimal answer')
    fig.update_yaxes(title_text='Predicted answer')
    fig.show()
interact(create_answers_plot, answer=['alpha1', 'alpha2', 'phi1', 'phi2'],
        opt_answer=['alpha1_star', 'alpha2_star', 'phi1_star', 'phi2_star'],
        samples=[100, 1000, 10000])

interactive(children=(Dropdown(description='answer', options=('alpha1', 'alpha2', 'phi1', 'phi2'), value='alph…

<function __main__.create_answers_plot(answer, opt_answer, samples=1000, pretrain=True)>

In [13]:
def create_lat_space_plot(lat_neuron, hidden_state, samples=1000, pretrain=True):
    
    model =  model_pre if pretrain else model_post
    lat_spaces = model(dataset.observations[0:1000])
    df_lat_space = pd.DataFrame(lat_spaces.detach().numpy(),
                                   columns=['l1', 'l2', 'l3', 'l4'])
    
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        y=df_lat_space[lat_neuron][0:samples],
        x=df_hidden_states[hidden_state][0:samples],
        mode='markers',
        opacity=.5))
    fig.update_layout(title_text='Latent neuron activation vs. hidden states')
    fig.update_xaxes(title_text=f'Hiddenstate {hidden_state}')
    fig.update_yaxes(title_text=f'Latent neuron activation {lat_neuron}')
    fig.show()
interact(create_lat_space_plot, lat_neuron=['l1', 'l2', 'l3', 'l4'],
        hidden_state=['m1', 'm2', 'q1', 'q2'],
        samples=[100, 1000, 10000])

interactive(children=(Dropdown(description='lat_neuron', options=('l1', 'l2', 'l3', 'l4'), value='l1'), Dropdo…

<function __main__.create_lat_space_plot(lat_neuron, hidden_state, samples=1000, pretrain=True)>

In [25]:
df_hidden_states

Unnamed: 0,m1,m2,q1,q2
0,2.886907e-20,3.190157e-20,3.881887e-16,-5.268712e-16
1,4.716424e-20,4.255824e-20,3.297433e-16,-4.814518e-16
2,4.007451e-20,2.639950e-20,5.569538e-16,-4.749995e-16
3,1.249199e-20,4.477453e-20,3.100031e-16,-3.586409e-16
4,1.201942e-20,4.804323e-20,7.033183e-16,-4.712755e-16
...,...,...,...,...
1041817,3.934335e-20,3.914088e-20,4.882463e-16,-5.192551e-16
1041818,3.934335e-20,3.914088e-20,4.882463e-16,-5.192551e-16
1041819,3.934335e-20,3.914088e-20,4.882463e-16,-5.192551e-16
1041820,3.934335e-20,3.914088e-20,4.882463e-16,-5.192551e-16
