# Lightning Trainer for MaskSimVP

In [33]:
#| default_exp trainer

In [34]:
#| export
import torch
import os

import lightning as pl

import matplotlib.pyplot as plt
import wandb
import random

from maskpredformer.mask_simvp import MaskSimVP, DEFAULT_MODEL_CONFIG
from maskpredformer.simvp_dataset import DLDataset
from maskpredformer.vis_utils import show_gif
from maskpredformer.simvp_dataset import DLDataset, DEFAULT_DATA_PATH

In [35]:
#| export
class MaskSimVPModule(pl.LightningModule):
    """Lightning module that trains a `MaskSimVP` model with cross-entropy loss.

    The module builds its own train/val `DLDataset` splits from `data_root`
    and serves them through `train_dataloader` / `val_dataloader`.

    Args:
        in_shape, hid_S, hid_T, N_S, N_T, model_type: forwarded unchanged to
            `MaskSimVP`.
        batch_size: batch size used by both dataloaders.
        lr: Adam learning rate; also the peak (`max_lr`) of the one-cycle
            schedule.
        weight_decay: Adam weight decay.
        max_epochs: number of epochs used to size the LR schedule when the
            attached trainer cannot provide a step estimate.
        data_root: dataset root handed to `DLDataset`.
        unlabeled: forwarded to the *training* `DLDataset` only.
        downsample: forwarded to `MaskSimVP`.
    """

    def __init__(self, 
                 in_shape, hid_S, hid_T, N_S, N_T, model_type,
                 batch_size, lr, weight_decay, max_epochs,
                 data_root, unlabeled=False, downsample=False):
        super().__init__()
        # Records every constructor argument under `self.hparams`.
        self.save_hyperparameters()
        self.model = MaskSimVP(
            in_shape, hid_S, hid_T, N_S, N_T, model_type, downsample=downsample
        )
        self.train_set = DLDataset(data_root, "train", unlabeled=unlabeled)
        self.val_set = DLDataset(data_root, "val")
        self.criterion = torch.nn.CrossEntropyLoss()
    
    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return torch.utils.data.DataLoader(
            self.train_set, batch_size=self.hparams.batch_size, 
            num_workers=8, shuffle=True, pin_memory=True
        )

    def val_dataloader(self):
        """Return the (unshuffled) validation dataloader."""
        return torch.utils.data.DataLoader(
            self.val_set, batch_size=self.hparams.batch_size, 
            num_workers=8, shuffle=False, pin_memory=True
        )

    def step(self, x, y):
        """Run the model on `x` and return its raw logits.

        `y` is accepted for interface symmetry but is not used here.
        """
        y_hat_logits = self.model(x)
        return y_hat_logits

    def _compute_loss(self, batch):
        """Forward one `(x, y)` batch and return the cross-entropy loss."""
        x, y = batch
        y_hat_logits = self.step(x, y)

        # Merge the two leading dims (batch and time) so the criterion sees
        # one sample per frame. `flatten` is used instead of `view` because
        # it also works when a tensor is not contiguous in memory.
        y_hat_logits = y_hat_logits.flatten(0, 1)
        y = y.flatten(0, 1)

        return self.criterion(y_hat_logits, y)
    
    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self._compute_loss(batch)
        # sync_dist averages the metric across processes in distributed runs.
        self.log("val_loss", loss, sync_dist=True)
        return loss

    def _scheduler_total_steps(self):
        """Return the number of optimizer steps the LR schedule should span.

        Prefers the attached trainer's estimate, which accounts for the number
        of devices and gradient accumulation. Falls back to
        `max_epochs * len(train_dataloader())` when no trainer is attached or
        the estimate is not a finite positive integer.
        """
        try:
            trainer = self.trainer
        except RuntimeError:
            # Raised by Lightning when the module is not attached to a Trainer.
            trainer = None

        if trainer is not None:
            estimated = trainer.estimated_stepping_batches
            # The estimate is `inf` (a float) for open-ended training runs.
            if isinstance(estimated, int) and estimated > 0:
                return estimated

        return self.hparams.max_epochs * len(self.train_dataloader())

    def configure_optimizers(self):
        """Create the Adam optimizer and a per-step one-cycle LR schedule."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.hparams.lr, 
            weight_decay=self.hparams.weight_decay
        )
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=self.hparams.lr,
            total_steps=self._scheduler_total_steps(),
            final_div_factor=1e4
        )
        opt_dict = {
            "optimizer": optimizer,
            "lr_scheduler":{
                "scheduler": lr_scheduler,
                # OneCycleLR must be stepped after every optimizer step.
                "interval": "step",
                "frequency": 1
            } 
        }

        return opt_dict

**Test out the MaskSimVP Lightning Trainer**

In [36]:
# Instantiate the module with the default architecture config and a batch
# size of 1 so the cells below can exercise it.
pl_module = MaskSimVPModule(
    data_root=DEFAULT_DATA_PATH,
    downsample=True,
    max_epochs=10,
    weight_decay=0.0,
    lr=1e-3,
    batch_size=1,
    **DEFAULT_MODEL_CONFIG,
)

INFO: Loading masks from /home/enes/dev/maskpredformer/data/DL/train_masks.pt
INFO: Loading masks from /home/enes/dev/maskpredformer/data/DL/val_masks.pt


In [37]:
def test_pl_module(pl_module):
    """Smoke test: run `training_step` on the first validation sample and print the loss."""
    device = pl_module.device
    inputs, target = pl_module.val_set[0]

    # Add a leading batch dimension of size 1 and move to the module's device.
    batch = (
        inputs.unsqueeze(0).to(device),
        target.unsqueeze(0).to(device),
    )

    print(pl_module.training_step(batch, 0))


test_pl_module(pl_module)

tensor(3.8254, grad_fn=<NllLoss2DBackward0>)


/home/enes/miniforge3/lib/python3.10/site-packages/lightning/pytorch/core/module.py:420: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


# Callbacks

## Sample Video Callback

> Callback that renders a prediction GIF for a random validation sample at the end of each validation epoch and logs it to Weights & Biases.

In [38]:
#| export
class SampleVideoCallback(pl.Callback):
    """Render a prediction GIF for a random validation sample after each validation epoch.

    Args:
        val_set: indexable dataset yielding `(x, y)` tensor pairs.
        video_path: directory the GIFs are written to; created if missing.
    """

    def __init__(self, val_set, video_path="./val_videos/"):
        super().__init__()
        self.val_set = val_set
        self.val_count = 0  # number of GIFs written so far; used in file names
        self.val_path = video_path
        # exist_ok avoids the check-then-create race when several processes
        # construct the callback at the same time.
        os.makedirs(self.val_path, exist_ok=True)

    def generate_video(self, pl_module):
        """Predict on one random validation sample and save the result as a GIF.

        The module is switched to eval mode for the forward pass and returned
        to whatever mode it was in before the call.

        Returns:
            Path of the GIF that was written.
        """
        sample_idx = random.randint(0, len(self.val_set)-1)

        x, y = self.val_set[sample_idx]
        x = x.unsqueeze(0).to(pl_module.device)
        y = y.unsqueeze(0).to(pl_module.device)

        was_training = pl_module.training
        pl_module.eval()
        try:
            # Visualisation only: no autograd graph is needed.
            with torch.no_grad():
                y_hat_logits = pl_module.step(x, y).squeeze(0) # (T, 49, H, W)
        finally:
            pl_module.train(was_training)
        y_hat = torch.argmax(y_hat_logits, dim=1) # (T, H, W)

        # convert to numpy
        x = x.squeeze(0).cpu().numpy()
        y = y.squeeze(0).cpu().numpy()
        y_hat = y_hat.cpu().numpy()

        gif_path = os.path.join(self.val_path, f"val_video_{self.val_count}.gif")

        show_gif(x, y, y_hat, out_path=gif_path)
        self.val_count += 1

        return gif_path
    
    def on_validation_epoch_end(self, trainer, pl_module):
        """Generate a sample GIF and log it through the trainer's logger."""
        # In multi-process runs every rank would otherwise write the same file.
        if not trainer.is_global_zero:
            return

        gif_path = self.generate_video(pl_module)

        # A trainer created with logger=False has no logger to send the video to.
        if trainer.logger is None:
            return

        trainer.logger.experiment.log({
            "val_video": wandb.Video(gif_path, fps=4, format="gif")
        })



In [39]:
# Smoke-test the callback outside of a Trainer: write one GIF for a randomly
# chosen validation sample and keep its path.
sample_video_cb = SampleVideoCallback(val_set=pl_module.val_set)
gif_path = sample_video_cb.generate_video(pl_module=pl_module)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()