In [None]:
import os

# Step up to the parent directory; presumably this is the project root that
# holds the `src` package imported further down.
# The check keeps the cell idempotent: an unconditional chdir('..') would climb
# one more level every time the cell is re-run.
if not os.path.isdir('src'):
    os.chdir('..')

In [None]:
# `display` and `HTML` belong to the public IPython.display module; importing
# them from IPython.core.display is deprecated.
from IPython.display import display, HTML

# Inject CSS so the notebook container uses the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
import torch
import torchvision
import src.constants as const
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from torch.utils.data import DataLoader
from src.data.dataset import (VideoLabelDataset,
                              VideoFolderPathToTensor,
                              VideoResize)
import plotly
import numpy as np
import pandas as pd
import yaml
import os
%load_ext autoreload
%autoreload 2

In [None]:
# Dataset built from the labels table at LABELS_TABLE_QA_PATH, with no image transform.
dataset = VideoLabelDataset(const.LABELS_TABLE_QA_PATH, img_transform=None)

In [None]:
# Loader used for ad-hoc inspection below: batches of 100 samples, 6 workers.
dataloader = DataLoader(dataset, batch_size=100, num_workers=6)

In [None]:
# Pull one batch for interactive inspection.
# Use the built-in next(): the DataLoader iterator's legacy `.next()` alias is
# a Python 2 leftover and is not available in recent PyTorch releases.
videos, answers, hidden_states, vid_folder = next(iter(dataloader))

In [None]:
# --- Data loading (read by LitModule) ---
VALIDATION_SPLIT = 0.05  # fraction of the dataset held out for validation
BATCH_SIZE = 2000        # batch size of the train/val loaders
NUM_WORKERS = 4          # worker count of the train/val loaders

# --- Formula model (read by LitModule) ---
FORMULA_EXPONENTS = ['-1', '0', '1']  # passed to FormulaFeatureGenerator
ENC_DIM_LAT_SPACE = 4                 # passed as enc_dim_lat_space
LAMBDA = 0.1                          # weight of the L1 penalty in the loss

In [None]:
from src.model.agents import FormulaFeatureGenerator, FormulaDecoder
import pytorch_lightning as pl



class LitModule(pl.LightningModule):
    """Lightning module that fits a formula decoder on generated features.

    The pipeline is ``hidden_states -> FormulaFeatureGenerator -> FormulaDecoder``.
    The loss is the MSE against column 0 of the answers plus an L1 penalty,
    weighted by ``LAMBDA``, on the decoder's ``lc`` weights.

    Configuration is read from the notebook-level constants
    ``FORMULA_EXPONENTS``, ``ENC_DIM_LAT_SPACE``, ``LAMBDA``,
    ``VALIDATION_SPLIT``, ``BATCH_SIZE`` and ``NUM_WORKERS``; the module also
    builds its own train/validation datasets and loaders.
    """

    def __init__(self):
        super().__init__()
        dataset = VideoLabelDataset(
            const.LABELS_TABLE_QA_PATH,
            img_transform=None)

        # Hold out floor(VALIDATION_SPLIT * N) samples for validation; the
        # remainder is used for training.
        dataset_size = len(dataset)
        len_val = int(np.floor(dataset_size * VALIDATION_SPLIT))
        len_train = dataset_size - len_val
        self.dataset_train, self.dataset_val = torch.utils.data.random_split(
            dataset=dataset, lengths=[len_train, len_val],
            generator=torch.Generator())

        self.feature_generator = FormulaFeatureGenerator(
            formula_exponents=FORMULA_EXPONENTS,
            enc_dim_lat_space=ENC_DIM_LAT_SPACE
        )

        # The decoder consumes exactly the features the generator produces.
        self.formula_decoder = FormulaDecoder(
            dec_num_features=self.feature_generator.num_features,
            dec_out_dim=1)
        self.batch_size = BATCH_SIZE
        self.dl_num_workers = NUM_WORKERS

    def loss_function(self, dec_outs, answers):
        """Return ``MSE(dec_outs, answers[:, 0]) + LAMBDA * L1(decoder weights)``.

        Args:
            dec_outs: Decoder output for the batch.
            answers: Target tensor; only column 0 is used.

        Returns:
            Scalar loss tensor.
        """
        mse_loss = torch.nn.MSELoss()
        # NOTE(review): the target is 1-D (``answers[:, 0]``). If the decoder
        # returns shape (N, 1), MSELoss broadcasts to (N, N) — confirm the
        # decoder's output shape.
        answer_loss = mse_loss(dec_outs, answers[:, 0])
        # L1 penalty on the decoder's `lc` weights.
        dec_params = self.formula_decoder.lc.weight
        param_loss = torch.sum(torch.abs(dec_params))
        return answer_loss + LAMBDA * param_loss

    def forward(self, x):
        """Map a batch of hidden states ``x`` to decoder outputs.

        Uses the argument ``x``; the previous version read the notebook-global
        ``hidden_states`` and silently ignored its input.
        """
        lsp_trans = self.feature_generator(x)
        return self.formula_decoder(lsp_trans)

    def _shared_step(self, batch, tag):
        """Compute the loss for one batch and log it in the "losses" group as ``tag``."""
        _, answers, hidden_states, _ = batch
        loss = self.loss_function(self(hidden_states), answers)
        # Assumes the logger's experiment exposes `add_scalars` (TensorBoard
        # SummaryWriter). Passing global_step gives every point its own x
        # position instead of leaving the step unset.
        self.logger.experiment.add_scalars(
            "losses", {tag: loss}, global_step=self.global_step)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` and log it as ``train_loss``."""
        return self._shared_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` and log it as ``val_loss``."""
        return self._shared_step(batch, "val_loss")

    def train_dataloader(self):
        """Return the loader over the training split."""
        return DataLoader(self.dataset_train,
                          batch_size=self.batch_size,
                          num_workers=self.dl_num_workers,
                          pin_memory=True)

    def val_dataloader(self):
        """Return the loader over the validation split."""
        return DataLoader(self.dataset_val,
                          batch_size=self.batch_size,
                          num_workers=self.dl_num_workers,
                          pin_memory=True)

    def configure_optimizers(self):
        """Return plain SGD over all module parameters with lr=0.001."""
        optimizer = torch.optim.SGD(self.parameters(), lr=0.001)
        return optimizer

In [None]:
# NOTE(review): `lit_model` is only assigned in a later cell
# (`lit_model = LitModule()`), so on a fresh kernel run top to bottom this cell
# raises NameError.
# Without call parentheses this displays the bound `parameters` method rather
# than the parameter values.
lit_model.parameters

In [None]:
# Restore a module from a saved checkpoint (path relative to the working directory).
# NOTE(review): cells further down read `lit_module.dec_0`, `lit_module.dec_1`
# and `lit_module.decoding_agents`, which the LitModule class defined in this
# notebook never creates. Presumably the checkpoint belongs to the class in
# the commented-out import below — confirm which LitModule is intended here.
LAST_CKP = 'lightning_logs/version_40/checkpoints/epoch=32-step=32.ckpt'
# from src.model.lit_module_formula_only import LitModule
lit_module = LitModule.load_from_checkpoint(LAST_CKP)

In [None]:
# Train a fresh module with default Trainer settings. No loaders are passed to
# fit(); the module supplies them via train_dataloader()/val_dataloader().
lit_model = LitModule()
trainer = pl.Trainer()
trainer.fit(lit_model)

In [None]:
# Draw one batch from the inspection loader for the checks below.
# Use the built-in next(): the DataLoader iterator's legacy `.next()` alias is
# a Python 2 leftover and is not available in recent PyTorch releases.
videos, answers, hidden_states, vid_folder = next(iter(dataloader))

In [None]:
# Shape of the hidden-state tensor in the inspection batch.
hidden_states.shape

In [None]:
# Largest value among the features generated from this batch.
lit_model.feature_generator(hidden_states).max()

In [None]:
# Largest raw hidden-state value in the batch.
hidden_states.max()

In [None]:
# Feature names as reported by the feature generator.
lit_model.feature_generator.get_feature_names()

In [None]:
# First 10 outputs of the trained module for this batch.
lit_model(hidden_states)[0:10]

In [None]:
# First 10 targets: column 0 of `answers`, the column the loss is computed against.
answers[:, 0][0:10]

In [None]:
# Data-fit (MSE) part of the loss on this batch, without the L1 weight penalty.
# NOTE(review): the target here is 1-D; if the module output is (N, 1) the MSE
# broadcasts to (N, N) — confirm the output shape.
mse_loss = torch.nn.MSELoss()
mse_loss(lit_model(hidden_states), answers[:,0])

In [None]:
# Weights of the decoder's `lc` layer — the tensor the L1 penalty is applied to.
dec_params = lit_model.formula_decoder.lc.weight
dec_params
# param_loss = torch.sum(torch.abs(dec_params))

In [None]:
# Displays the whole hidden-state tensor of the batch.
# NOTE(review): consider slicing (e.g. the first few rows) to keep the output small.
hidden_states

In [None]:
# Run the checkpoint-restored module on the batch and show the first 10 rows.
output = lit_module(hidden_states)
output[0:10, :]

In [None]:
# alt — NOTE(review): identical to the previous cell; looks like a leftover
# duplicate that could be removed.
output = lit_module(hidden_states)
output[0:10, :]

In [None]:
# First 10 rows of `answers`, all columns.
answers[0:10,:]

In [None]:
# Smallest answer value in the batch.
answers.min()

In [None]:
# `lc` weights of decoder `dec_0` on the checkpoint-restored module.
lit_module.dec_0.lc.weight

In [None]:
# alt — NOTE(review): identical to the previous cell; looks like a leftover
# duplicate that could be removed.
lit_module.dec_0.lc.weight

In [None]:
# `lc` weights of decoder `dec_1` on the checkpoint-restored module.
lit_module.dec_1.lc.weight

In [None]:
# L1 norm of the `lc` weights of all decoding agents, concatenated into one tensor.
dec_params = torch.cat(
    [agent.lc.weight for agent in lit_module.decoding_agents]
)
param_loss = dec_params.abs().sum()
param_loss

In [None]:
# MSE between the restored module's output and the full `answers` tensor (all columns).
mse_loss = torch.nn.MSELoss()
answer_loss = mse_loss(output, answers)
answer_loss