In [1]:
import os

# Move from the notebook's folder up to (presumably) the project root so that the
# `src.*` imports and relative paths such as lightning_logs/ used below resolve.
# Guarded so that re-running this cell does not keep climbing the directory tree.
if not os.path.isdir('src'):
    os.chdir('..')

In [2]:
# Reload imported modules (e.g. the project's src package) automatically before
# each cell runs, so edits to project code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [3]:
import torch
import torchvision
from pytorch_lightning.callbacks import ModelCheckpoint
from src.model.lit_module import LitModule
from src.data.dataset import VideoLabelDataset
import src.constants as const
from torch.utils.data import DataLoader
from src.data.dataset import (VideoLabelDataset,
                              VideoFolderPathToTensor,
                              VideoResize)
import plotly
import numpy as np
import pandas as pd

In [5]:
# Per-sample transform: turn a video folder path into a tensor, then resize it
# to the configured image size.
video_transform = torchvision.transforms.Compose([
    VideoFolderPathToTensor(),
    VideoResize(const.IMG_SIZE),
])

# QA labels table -> (video, question, answer, hidden state, folder) samples.
dataset = VideoLabelDataset(const.LABELS_TABLE_QA_PATH, img_transform=video_transform)

In [6]:
# A single large batch feeds every plot below; shuffle stays at its default
# (False), so the batch is drawn in dataset order.
EVAL_BATCH_SIZE = 500
LOADER_WORKERS = 6
dataloader = DataLoader(dataset, batch_size=EVAL_BATCH_SIZE, num_workers=LOADER_WORKERS)

In [7]:
from src.model.lit_module import LitModule

In [8]:
import yaml

In [9]:
# Hyper-parameters logged for training run version_16.
# NOTE(review): yaml.Loader is PyYAML's unrestricted loader and can instantiate
# arbitrary Python objects from tags in the file. Only use it on hparams files
# you produced yourself; prefer yaml.safe_load / yaml.full_load otherwise.
with open("lightning_logs/version_16/hparams.yaml", 'r') as hparams_file:
    hparams = yaml.load(stream=hparams_file, Loader=yaml.Loader)

In [10]:
# Re-create the LightningModule from the loaded hyper-parameters and ask it for
# its validation dataloader.
# NOTE(review): neither `lit_mod` nor `val_dataloader` is used by any later cell
# shown in this notebook; the evaluated model is loaded from a checkpoint further
# down. Consider dropping this cell (and the hparams load) if nothing needs it.
lit_mod = LitModule(**hparams)
val_dataloader = lit_mod.val_dataloader()

In [11]:
# Take only the first batch of the dataloader, drawn in dataset order because
# shuffle is left at its default.
# Use the builtin next(): DataLoader iterators implement __next__, and the old
# Python-2 style `.next()` alias is not available in newer PyTorch releases.
videos, questions, answers, hidden_states, vid_folder = next(iter(dataloader))

In [12]:
# List the saved checkpoints. `ll` is IPython's built-in shell alias for a long
# `ls` listing (POSIX only). NOTE(review): in the output below,
# 'epoch=150-step=4528.ckpt' is 0 bytes, i.e. an incomplete save.
ll lightning_logs/version_16/checkpoints/

total 263344
-rw-rw-r-- 1 ubuntu 135112663 Jan 10 20:15 'epoch=149-step=4499.ckpt'
-rw-rw-r-- 1 ubuntu         0 Jan 10 20:17 'epoch=150-step=4528.ckpt'
-rw-rw-r-- 1 ubuntu 134520832 Jan 10 20:17 'epoch=150-step=4529.ckpt'


In [13]:
# Checkpoint evaluated below. NOTE(review): presumably epoch 149 is chosen because
# both epoch=150 files in the listing above look incomplete (one is 0 bytes, the
# other is smaller than this one). Confirm before switching checkpoints.
checkpoint_path = './lightning_logs/version_16/checkpoints/epoch=149-step=4499.ckpt'

In [14]:
# Restore the trained module from the chosen checkpoint file.
model = LitModule.load_from_checkpoint(checkpoint_path=checkpoint_path)

In [None]:
# Run the restored model on the batch: the model maps the videos to predicted
# hidden states, then each of model.decoding_agents maps (hidden states,
# questions) to answers. The per-agent answers are concatenated along dim 1.
model.eval()
# Inference only: no_grad avoids building (and keeping alive through the outputs)
# an autograd graph for the whole batch.
with torch.no_grad():
    hidden_states_pred = model(videos)
    answers_pred = torch.cat(
        [dec(hidden_states_pred, questions) for dec in model.decoding_agents],
        dim=1,
    )

In [33]:
def _tensor_to_frame(tensor: torch.Tensor, columns) -> pd.DataFrame:
    """Return `tensor` as a DataFrame with the given column names.

    The tensor is detached from autograd and moved to CPU first, so this also
    works for outputs that still require grad or live on a GPU.
    """
    return pd.DataFrame(tensor.detach().cpu().numpy(), columns=columns)


# Model predictions: answers get a `_pred` suffix; the predicted latent state has
# one column per entry of const.HIDDEN_STATE_COLS, named lat_neuron_<i>.
df_answers_pred = _tensor_to_frame(answers_pred, [f'{c}_pred' for c in const.ANSWER_COLS])
df_hidden_states_pred = _tensor_to_frame(
    hidden_states_pred, [f'lat_neuron_{c}' for c in range(len(const.HIDDEN_STATE_COLS))])

# Ground truth from the same batch.
df_answers = _tensor_to_frame(answers, const.ANSWER_COLS)
df_hidden_states = _tensor_to_frame(hidden_states, const.HIDDEN_STATE_COLS)

# The scatter plots below pair rows positionally, so the frames must line up.
assert len(df_answers_pred) == len(df_answers) == len(df_hidden_states_pred) == len(df_hidden_states)

In [34]:
import plotly.graph_objects as go

# Predicted vs. true value for every answer column; points near the diagonal
# y = x are accurate predictions.
fig = go.Figure()
for c in const.ANSWER_COLS:
    fig.add_trace(go.Scatter(x=df_answers[c], y=df_answers_pred[f'{c}_pred'],
                             mode='markers',
                             name=c))
fig.update_layout(title='Predicted vs. true answers',
                  xaxis_title='true answer',
                  yaxis_title='predicted answer')
fig.show()

In [36]:
# Peek at the first few rows of the predicted latent activations.
df_hidden_states_pred.head(n=5)

Unnamed: 0,lat_neuron_0,lat_neuron_1,lat_neuron_2
0,3.413215,3.978119,-3.656238
1,3.584043,2.175408,-1.059583
2,4.767894,1.794965,-3.731052
3,3.209562,2.345822,-3.231331
4,3.199042,3.277331,-2.692829


In [41]:
# Each latent neuron's activation plotted against one ground-truth hidden-state
# column, to see which neuron (if any) tracks it.
feature = 'box_x'
fig = go.Figure()
for c in df_hidden_states_pred.columns:
    fig.add_trace(go.Scatter(x=df_hidden_states[feature], y=df_hidden_states_pred[c],
                             mode='markers',
                             name=f'activation {c} over {feature}'))
fig.update_layout(title=f'Latent activations vs. {feature}',
                  xaxis_title=feature,
                  yaxis_title='latent activation')
fig.show()

In [43]:
# Each latent neuron's activation plotted against one ground-truth hidden-state
# column, to see which neuron (if any) tracks it.
feature = 'pipe_x'
fig = go.Figure()
for c in df_hidden_states_pred.columns:
    fig.add_trace(go.Scatter(x=df_hidden_states[feature], y=df_hidden_states_pred[c],
                             mode='markers',
                             name=f'activation {c} over {feature}'))
fig.update_layout(title=f'Latent activations vs. {feature}',
                  xaxis_title=feature,
                  yaxis_title='latent activation')
fig.show()

In [44]:
# Each latent neuron's activation plotted against one ground-truth hidden-state
# column, to see which neuron (if any) tracks it.
feature = 'enemy_speed'
fig = go.Figure()
for c in df_hidden_states_pred.columns:
    fig.add_trace(go.Scatter(x=df_hidden_states[feature], y=df_hidden_states_pred[c],
                             mode='markers',
                             name=f'activation {c} over {feature}'))
fig.update_layout(title=f'Latent activations vs. {feature}',
                  xaxis_title=feature,
                  yaxis_title='latent activation')
fig.show()