In [1]:
import os
os.chdir('..')

In [2]:
%load_ext autoreload
%autoreload 2

In [45]:
import torch
import torchvision
from pytorch_lightning.callbacks import ModelCheckpoint
from src.model.lit_module import LitModule
from src.data.dataset import VideoLabelDataset
import src.constants as const
from torch.utils.data import DataLoader
from src.data.dataset import (VideoLabelDataset,
                              VideoFolderPathToTensor,
                              VideoResize)
import plotly
import numpy as np

In [41]:
%run scripts/question_and_optimal_answer.py

In [42]:
dataset = VideoLabelDataset(
            const.LABELS_TABLE_QA_PATH,
            img_transform=torchvision.transforms.Compose([
                VideoFolderPathToTensor(),
                VideoResize(const.IMG_SIZE)]))

In [43]:
df = dataset.dataframe
df.columns

Index(['Unnamed: 0', 'imgs_folder_path', 'box_x', 'pipe_x', 'enemy_speed',
       'mario_speed', 'answer_box', 'answer_pipe', 'answer_enemy'],
      dtype='object')

In [48]:
import plotly.graph_objects as go
cols = ['answer_box', 'answer_pipe', 'answer_enemy']

fig = go.Figure()
for c in cols:
    fig.add_trace(go.Histogram(x=df[c].values))

# Overlay both histograms
fig.update_layout(barmode='overlay')
# Reduce opacity to see both histograms
fig.update_traces(opacity=0.75)
fig.show()


In [40]:
import plotly.express as px
fig = px.histogram(df, x="answer_enemy")
fig.show()

In [5]:
dataloader = DataLoader(dataset, batch_size=20, num_workers=2)

In [6]:
videos, questions, answers, hidden_states, vid_folder  = iter(dataloader).next()

In [7]:
dataset.dataframe.head(10)

Unnamed: 0.1,Unnamed: 0,imgs_folder_path,box_x,pipe_x,enemy_speed,mario_speed,answer_box,answer_pipe,answer_enemy
0,0,data/imgs_series/00001,0.156313,0.18593,1.0,53.610005,0.222669,0.642245,0.772056
1,1,data/imgs_series/00002,0.214429,0.145729,0.25,90.388381,0.119453,0.087789,0.012569
2,2,data/imgs_series/00003,0.018036,0.130653,0.666667,80.306984,0.033699,0.18302,0.045223
3,3,data/imgs_series/00004,0.104208,1.0,0.291667,88.281578,0.066662,0.291999,0.016097
4,4,data/imgs_series/00005,0.262525,0.396985,0.125,70.172598,0.215911,0.383197,0.038661
5,5,data/imgs_series/00006,0.797595,0.477387,0.041667,73.116969,0.543775,0.36067,0.029216
6,6,data/imgs_series/00007,0.038076,0.376884,0.083333,52.924301,0.122878,0.728766,0.095075
7,7,data/imgs_series/00008,0.689379,0.095477,0.75,98.627947,0.324979,0.012256,0.01689
8,8,data/imgs_series/00009,0.390782,0.20603,0.708333,71.602496,0.292921,0.310474,0.076882
9,9,data/imgs_series/00010,0.150301,0.366834,0.416667,62.993787,0.169477,0.497407,0.084251


In [8]:
ll lightning_logs/version_16/checkpoints/

total 131948
-rw-rw-r-- 1 ubuntu 135111612 Jan  7 18:14 'epoch=60-step=1816.ckpt'


In [11]:
checkpoint_path = './lightning_logs/version_16/checkpoints/epoch=60-step=1816.ckpt'

In [12]:
model = LitModule.load_from_checkpoint(checkpoint_path)

In [17]:
predictions = model.eval()(videos)
predictions = [dec(predictions, questions) for dec in model.decoding_agents]
predictions = torch.cat(predictions, axis=1)
predictions

tensor([[0.2968, 0.6125, 0.2432],
        [0.1720, 0.1149, 0.0360],
        [0.1885, 0.1899, 0.0454],
        [0.1482, 0.2744, 0.0365],
        [0.1898, 0.3400, 0.0680],
        [0.5539, 0.3379, 0.0517],
        [0.2760, 0.6661, 0.2422],
        [0.3884, 0.0687, 0.0310],
        [0.3320, 0.2890, 0.0596],
        [0.2194, 0.4536, 0.1135],
        [0.1664, 0.1533, 0.0387],
        [0.4595, 0.2207, 0.0368],
        [0.2504, 0.7921, 0.2152],
        [0.3242, 0.1864, 0.0345],
        [0.4935, 0.4198, 0.0449],
        [0.3478, 0.2031, 0.0351],
        [0.4105, 0.3111, 0.0506],
        [0.3046, 0.1068, 0.0342],
        [0.4931, 0.4169, 0.0500],
        [0.2450, 0.4129, 0.1121]], grad_fn=<CatBackward>)

In [14]:
answers

tensor([[0.2227, 0.6422, 0.7721],
        [0.1195, 0.0878, 0.0126],
        [0.0337, 0.1830, 0.0452],
        [0.0667, 0.2920, 0.0161],
        [0.2159, 0.3832, 0.0387],
        [0.5438, 0.3607, 0.0292],
        [0.1229, 0.7288, 0.0951],
        [0.3250, 0.0123, 0.0169],
        [0.2929, 0.3105, 0.0769],
        [0.1695, 0.4974, 0.0843],
        [0.0321, 0.1668, 0.0435],
        [0.4351, 0.2393, 0.0168],
        [0.1477, 0.8709, 0.1408],
        [0.2635, 0.2030, 0.0370],
        [0.4482, 0.4178, 0.0513],
        [0.3035, 0.2227, 0.0091],
        [0.3721, 0.3494, 0.0412],
        [0.2732, 0.0569, 0.0085],
        [0.4221, 0.4421, 0.0293],
        [0.2300, 0.3867, 0.0945]])

In [18]:
mse_loss = torch.nn.MSELoss(reduction='sum')
mse_hidden = mse_loss(predictions[0:2,:].type(torch.float32),
                      hidden_states[0:2,:].type(torch.float32))
mse_hidden

tensor(0.2628, grad_fn=<MseLossBackward>)

73323.68388475002

In [None]:
model.eval()(videos)

In [22]:
predictions[0:2,:]

tensor([[543.4559, 996.8821,  37.2706],
        [543.4633, 996.8958,  37.2711]], grad_fn=<SliceBackward>)