# Lightning Trainer for MaskSimVP

In [1]:
#| default_exp trainer

In [4]:
#| export
import torch
import os

import lightning as pl
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor

import matplotlib.pyplot as plt
from torchvision import transforms
import wandb
import random
import numpy as np
from lightning.pytorch.utilities import grad_norm

from maskpredformer.mask_simvp import MaskSimVP
from maskpredformer.simvp_dataset import DLDataset

In [6]:
#| export
class MaskSimVPTrainer(pl.LightningModule):
    """LightningModule that trains a ``MaskSimVP`` model on ``DLDataset`` splits.

    The model maps an input clip ``x`` to per-pixel class logits for the
    target frames ``y``; the loss is cross-entropy between those logits and
    ``y``.

    Args:
        in_shape, hid_S, hid_T, N_S, N_T, model_type: forwarded unchanged to
            ``MaskSimVP``.
        batch_size: batch size of both the train and val dataloaders.
        lr: peak learning rate (``max_lr``) of the OneCycle schedule.
        weight_decay: Adam weight decay.
        max_epochs: number of epochs the OneCycle schedule is spread over.
        data_root: root directory handed to ``DLDataset`` for both splits.
    """

    # Size of the class axis (dim 2) of the model's output logits.
    num_classes = 49

    def __init__(self, 
                 in_shape, hid_S, hid_T, N_S, N_T, model_type,
                 batch_size, lr, weight_decay, max_epochs,
                 data_root):
        super().__init__()
        self.save_hyperparameters()
        self.model = MaskSimVP(
            in_shape, hid_S, hid_T, N_S, N_T, model_type
        )
        self.train_set = DLDataset(data_root, "train")
        self.val_set = DLDataset(data_root, "val")
        self.criterion = torch.nn.CrossEntropyLoss()
    
    def train_dataloader(self):
        """Shuffled loader over the ``"train"`` split."""
        return torch.utils.data.DataLoader(
            self.train_set, batch_size=self.hparams.batch_size, 
            num_workers=8, shuffle=True, pin_memory=True
        )

    def val_dataloader(self):
        """Unshuffled loader over the ``"val"`` split."""
        return torch.utils.data.DataLoader(
            self.val_set, batch_size=self.hparams.batch_size, 
            num_workers=8, shuffle=False, pin_memory=True
        )

    def step(self, x, y):
        """Run the model on ``x`` and check its logits line up with ``y``.

        Returns:
            Logits shaped ``(B, T, num_classes, H, W)``, where ``B``/``T``
            equal ``y``'s first two dims and ``H``/``W`` its last two.
        """
        y_hat_logits = self.model(x)
        # Compare against y's leading (B, T) and trailing (H, W) dims, so the
        # check holds whether y is (B, T, H, W) or carries an extra axis.
        expected = (*y.shape[:2], self.num_classes, *y.shape[-2:])
        assert y_hat_logits.shape == expected, (
            f"logits shape {tuple(y_hat_logits.shape)} != expected {expected}"
        )
        return y_hat_logits

    def _loss(self, y_hat_logits, y):
        """Cross-entropy with the class axis moved to dim 1."""
        # nn.CrossEntropyLoss expects input (N, C, d1, ...); the logits are
        # (B, T, C, H, W), so swap T and C -> (B, C, T, H, W).
        # NOTE(review): assumes y holds integer class indices shaped
        # (B, T, H, W) — TODO confirm against DLDataset.
        return self.criterion(y_hat_logits.transpose(1, 2), y)
    
    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        x, y = batch
        y_hat_logits = self.step(x, y)
        loss = self._loss(y_hat_logits, y)
        self.log("train_loss", loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        x, y = batch
        y_hat_logits = self.step(x, y)
        loss = self._loss(y_hat_logits, y)
        # Logged under its own key so it does not clobber "train_loss".
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        """Adam plus a OneCycle LR schedule stepped once per optimizer step."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.hparams.lr, 
            weight_decay=self.hparams.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=self.hparams.lr,
            total_steps=self.hparams.max_epochs*len(self.train_dataloader()),
            final_div_factor=1e4
        )
        # total_steps counts batches, so the scheduler must step every batch;
        # Lightning's default interval ("epoch") would leave it in warm-up.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

# Callbacks

## Sample Video Callback

> Callback that, at the end of each validation epoch, logs a random validation sample to Weights & Biases as a side-by-side GIF (ground truth vs. model prediction).

In [8]:
#| export
class SampleVideoCallback(pl.Callback):
    """Log one random validation sample as a side-by-side GIF.

    At the end of each validation epoch a sample is drawn from ``val_set``
    and run through ``pl_module.step``. The left half of each video frame
    shows the input frames followed by the ground-truth frames. The right
    half shows the same input frames followed by the argmax predictions. The
    result is logged via ``trainer.logger.experiment.log`` under
    ``"val_video"``.
    """

    def __init__(self, val_set):
        super().__init__()
        self.val_set = val_set

    def apply_cm(self, x, vmin=None, vmax=None):
        """Colour-map a 2-D array to a float ``(3, H, W)`` RGB image.

        Args:
            x: 2-D numpy array to colour.
            vmin, vmax: fixed colour range. When omitted, each defaults to
                ``x``'s own min/max.
        """
        cm = plt.get_cmap()
        norm = plt.Normalize(
            vmin=x.min() if vmin is None else vmin,
            vmax=x.max() if vmax is None else vmax,
        )
        # The colormap yields RGBA (H, W, 4); drop alpha and go channels-first.
        return cm(norm(x))[:, :, :3].transpose(2, 0, 1)

    def generate_video(self, pl_module):
        """Build a uint8 ``(T_x + T_y, 3, H, 2*W)`` video for one random val sample."""
        was_training = pl_module.training
        pl_module.eval()
        try:
            sample_idx = random.randint(0, len(self.val_set) - 1)
            x, y = self.val_set[sample_idx]
            x = x.unsqueeze(0).to(pl_module.device)
            y = y.unsqueeze(0).to(pl_module.device)

            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                y_hat_logits = pl_module.step(x, y).squeeze(0)  # (T, 49, H, W)
            y_hat = torch.argmax(y_hat_logits, dim=1)  # (T, H, W)
        finally:
            # Leave the module in the mode it was in before this call.
            pl_module.train(was_training)

        # convert to numpy
        x = x.squeeze(0).cpu().numpy()
        y = y.squeeze(0).cpu().numpy()
        y_hat = y_hat.cpu().numpy()

        # Ground truth and prediction share one colour range, so a given value
        # gets the same colour in both halves of the video.
        tgt_min = min(y.min(), y_hat.min())
        tgt_max = max(y.max(), y_hat.max())
        # Input frames share their own range across time.
        ctx_min, ctx_max = x.min(), x.max()

        # Input frames are prepended to both the ground-truth and prediction
        # sequences.
        context = [self.apply_cm(frame, ctx_min, ctx_max) for frame in x]
        gt_imgs = np.stack(
            context + [self.apply_cm(frame, tgt_min, tgt_max) for frame in y],
            axis=0,
        )
        pred_imgs = np.stack(
            context + [self.apply_cm(frame, tgt_min, tgt_max) for frame in y_hat],
            axis=0,
        )
        # Concatenate along width: ground truth on the left, prediction on the right.
        video = (np.concatenate([gt_imgs, pred_imgs], axis=-1) * 255).astype(np.uint8)
        return video
    
    def on_validation_epoch_end(self, trainer, pl_module):
        """Render a sample video and log it as a 4 fps GIF."""
        video = self.generate_video(pl_module)
        trainer.logger.experiment.log({
            "val_video": wandb.Video(video, fps=4, format="gif")
        })

