In [None]:
import os

# Presumably this notebook lives one level below the project root; move there so the
# `src.*` imports and relative `./lightning_logs/...` paths used later resolve.
# Guarded so that re-running this cell does not keep climbing the directory tree.
if not os.path.isdir('src'):
    os.chdir('..')

In [None]:
# Import from the public `IPython.display` module: importing `display` from
# `IPython.core.display` has been deprecated since IPython 7.14.
from IPython.display import display, HTML

# Inject CSS that stretches the notebook's `.container` element to the full page width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
import torch
import torchvision
from pytorch_lightning.callbacks import ModelCheckpoint
from src.model.lit_module import LitModule
from src.data.dataset import VideoLabelDataset
import src.constants as const
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from torch.utils.data import DataLoader
from src.data.dataset import (VideoLabelDataset,
                              VideoFolderPathToTensor,
                              VideoResize)
import plotly
import numpy as np
import pandas as pd
import yaml
import os
%load_ext autoreload
%autoreload 2

In [None]:
# Per-sample image pipeline: convert the video folder path to a tensor
# (presumably by loading its frames), then resize to const.IMG_SIZE.
video_transform = torchvision.transforms.Compose([
    VideoFolderPathToTensor(),
    VideoResize(const.IMG_SIZE),
])
dataset = VideoLabelDataset(const.LABELS_TABLE_QA_PATH, img_transform=video_transform)

In [None]:
# Only the first batch is evaluated below. DataLoader does not shuffle by default,
# so that batch is always the same leading samples of the dataset.
BATCH_SIZE = 200
NUM_WORKERS = 6
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [None]:
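# Optional: to evaluate on the validation split instead of the QA table, uncomment the
# lines below. They rebuild the LitModule from run version_16's hparams.yaml and
# replace `dataloader` with its val_dataloader(). Note: yaml.Loader can construct
# arbitrary Python objects, so only use it on trusted hparams files.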
# with open("lightning_logs/version_16/hparams.yaml", 'r') as stream:
#     hparams = yaml.load(stream, Loader=yaml.Loader)
# lit_mod = LitModule(**hparams)
# dataloader = lit_mod.val_dataloader()

In [None]:
# Pull a single batch to evaluate. Use the builtin `next()`: DataLoader iterators in
# recent PyTorch releases no longer provide a Python-2-style `.next()` method.
videos, questions, answers, hidden_states, vid_folder = next(iter(dataloader))

In [None]:
# Checkpoint to evaluate: training run `version_16`, saved at epoch 149.
checkpoint_path = (
    './lightning_logs/version_16/checkpoints/epoch=149-step=4499.ckpt'
)

In [None]:
# Restore the trained LitModule from the checkpoint selected above.
model = LitModule.load_from_checkpoint(checkpoint_path=checkpoint_path)

In [None]:
# Inference only: switch to eval mode and disable autograd. The outputs are only
# converted to DataFrames and plotted, so there is no need to keep a graph for the batch.
model.eval()
with torch.no_grad():
    hidden_states_pred = model(videos)
    # Each decoding agent answers the questions from the predicted latent state;
    # their outputs are concatenated side by side along dim 1.
    answers_pred = torch.cat(
        [dec(hidden_states_pred, questions) for dec in model.decoding_agents],
        dim=1,
    )

In [None]:
def _to_frame(tensor, columns):
    """Detach `tensor`, copy it to host memory and wrap it in a DataFrame with `columns`."""
    return pd.DataFrame(tensor.detach().cpu().numpy(), columns=columns)


df_answers_pred = _to_frame(answers_pred, [f'{c}_pred' for c in const.ANSWER_COLS])
# Name latent columns from the model output's actual width. The latent size is not
# guaranteed to equal len(const.HIDDEN_STATE_COLS).
df_hidden_states_pred = _to_frame(
    hidden_states_pred,
    [f'lat_neuron_{c}' for c in range(hidden_states_pred.shape[1])],
)
df_answers = _to_frame(answers, const.ANSWER_COLS)
df_hidden_states = _to_frame(hidden_states, const.HIDDEN_STATE_COLS)
df_questions = _to_frame(questions, [const.QUESTION_COL])

In [None]:
# One scatter panel per answer column: dataset answer (x) vs. model prediction (y).
# The number of panels follows const.ANSWER_COLS instead of being fixed at 3.
fig = make_subplots(rows=1, cols=len(const.ANSWER_COLS))
for col_idx, c in enumerate(const.ANSWER_COLS, start=1):
    fig.add_trace(go.Scatter(x=df_answers[c], y=df_answers_pred[f'{c}_pred'],
                             mode='markers',
                             marker_color='#1f77b4',
                             name=c), row=1, col=col_idx)
    fig.update_xaxes(title_text=c, row=1, col=col_idx)
fig.update_layout(title_text="Predicted answer over optimal answer", width=1200, showlegend=False)
fig.update_yaxes(title_text="Predicted answers", col=1)
fig.show()

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Scatter grid: row r plots true hidden state r on x; column c plots latent neuron c's
# activation on y. The grid size is derived from the data rather than fixed at 3x3.
n_rows = len(df_hidden_states.columns)
n_cols = len(df_hidden_states_pred.columns)
fig = make_subplots(rows=n_rows, cols=n_cols)

for row, hs in enumerate(df_hidden_states.columns, start=1):
    for col, hs_pred in enumerate(df_hidden_states_pred.columns, start=1):
        # The trace name (shown on hover) names the actual hidden state,
        # not a hard-coded one.
        fig.add_trace(go.Scatter(x=df_hidden_states[hs], y=df_hidden_states_pred[hs_pred],
                                 mode='markers', name=f'activation {hs_pred} over {hs}',
                                 marker_color='#1f77b4'),
                      row=row, col=col)

# Human-readable x titles, one per row. They assume const.HIDDEN_STATE_COLS is ordered
# (coin, pipe, enemy) — TODO confirm. Any extra row falls back to its column name.
x_titles = ["Coin position", "Pipe position", "Enemy speed"]
for row, hs in enumerate(df_hidden_states.columns, start=1):
    x_title = x_titles[row - 1] if row <= len(x_titles) else hs
    for col in range(1, n_cols + 1):
        fig.update_xaxes(title_text=x_title, row=row, col=col)
        # y titles index the latent neuron shown in that column.
        fig.update_yaxes(title_text=f"Latent neuron {col - 1} activation", row=row, col=col)

fig.update_layout(height=1000, width=1200, title_text="Latent neuron activations vs. hidden states", showlegend=False)
fig.show()