# Deep Learning course 2022-2023
## Project: Video Prediction on Moving MNIST
### Project Contributors
* Mattia Castelmare, 1815675
* Andrea Giuseppe Di Francesco, 1836928
* Enrico Fazzi, 2003876

#### Installation cells

In [None]:
# !git clone https://github.com/tychovdo/MovingMNIST.git
# !pip3 install pytorch-lightning==1.5.10
# !pip3 install torchvision
# !pip3 install matplotlib
# !pip install wandb

### Importing Libraries

In [None]:
### wandb codes ###
import wandb
#####################

import torch
import os
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning import Trainer
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib import rc
rc('animation', html='jshtml')


from MovingMNIST.MovingMNIST import *

## Useful Functions

In [None]:
def show_video(tensor):
    ''' Display a video given as a torch tensor (adapted from: https://stackoverflow.com/questions/67261108/how-to-display-a-video-in-colab-using-a-pytorch-tensor-of-rgb-image-arrays)
        INPUT: tensor (Frames x Channels x Height x Width); it may require grad
               or live on the GPU, a detached CPU copy is used for plotting.
               Single-channel frames are drawn with the 'gray' colormap.
        OUTPUT: a matplotlib ArtistAnimation (rendered as JS/HTML in the notebook) '''

    frames = tensor.detach().cpu().permute(0, 2, 3, 1).numpy()  # -> T x H x W x C
    if frames.shape[-1] == 1:
        # imshow expects an H x W array for colormapped (grayscale) images.
        frames = frames[..., 0]

    # One colour range for the whole clip: with per-frame auto-scaling every
    # frame is stretched to full contrast, hiding brightness changes over time.
    vmin, vmax = float(frames.min()), float(frames.max())

    fig, ax = plt.subplots()
    artists = [[ax.imshow(frame, cmap='gray', vmin=vmin, vmax=vmax)] for frame in frames]
    ani = animation.ArtistAnimation(fig, artists)

    # Close the pyplot figure so the notebook shows only the animation and not
    # an extra static image of the last frame.
    plt.close(fig)
    return ani


def collate(batch):
    ''' Collate function used by the pytorch DataLoader.
        INPUT: batch: list of (input_video, target_video) tuples, each video a
               tensor of shape (T x H x W)
        OUTPUT: a tuple (input, gt) of float32 CPU tensors, both of shape
                (B x T x C x H x W) with C = 1 '''

    # Insert the singleton channel axis, (T, H, W) -> (T, 1, H, W), and stack all
    # samples along a new batch axis in a single call; concatenating one sample
    # at a time in a loop copies the growing tensor every iteration.
    inputs = torch.stack([sample[0].unsqueeze(1) for sample in batch])
    targets = torch.stack([sample[1].unsqueeze(1) for sample in batch])

    return inputs.type(torch.FloatTensor), targets.type(torch.FloatTensor)
def save_model(checkpoint, path):
    """Serialise ``checkpoint`` (for example a dict holding ``'model_state'``) to ``path`` via ``torch.save``."""
    torch.save(obj=checkpoint, f=path)

def load_model(path, model, device):
    """Restore ``model`` from a checkpoint file and return the checkpoint.

    Args:
        path: file written by ``save_model``; it must contain a ``'model_state'`` entry.
        model: module whose parameters are overwritten in place via ``load_state_dict``.
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The full checkpoint object that was loaded.
    """
    restored = torch.load(path, map_location=device)
    model.load_state_dict(restored['model_state'])
    return restored


### Useful variables 

In [None]:
# DataLoader behaviour: shuffle only the training split.
shuffle_train, shuffle_test = True, False
# Passed to MovingMNIST as its `download` flag.
download = False

### wandb codes ###
wb = True  # toggle Weights & Biases logging for the whole notebook
project_name = "VP_on_MMNIST"
if wb:
    wandb.login()
###################

# Seed the random number generators through Lightning for reproducibility.
SEED = 2812023
pl.seed_everything(SEED)

# Directory holding the Moving MNIST files; created if it does not exist yet.
root = './data'
if not os.path.exists(root):
    os.mkdir(root)


### PyTorch Lightning DataModule
* Sets up the Moving MNIST (MMNIST) datasets and their DataLoaders

In [None]:
class pl_Dataset(pl.LightningDataModule):
    """LightningDataModule wrapping the Moving MNIST train and held-out splits.

    The held-out split (``train=False``) doubles as the validation set, so it is
    loaded for the ``'fit'`` and ``'validate'`` stages as well as ``'test'``.
    Dataset location, download flag and shuffling come from the module-level
    ``root``, ``download``, ``shuffle_train`` and ``shuffle_test`` variables;
    batches are assembled with ``collate``.

    Args:
        batch_size: number of clips per batch for every dataloader.
    """

    def __init__(self, batch_size):
        # LightningDataModule.__init__ initialises internal state used by the
        # Trainer; subclasses must call it.
        super().__init__()
        self.bs = batch_size

    def setup(self, stage = None):
        """Load the split(s) needed for ``stage``.

        ``stage`` is one of ``'fit'``, ``'validate'``, ``'test'`` or ``None``
        (load everything). A split that has already been loaded is reused, so
        calling ``setup`` manually before ``trainer.fit`` does not reload data.
        """
        if stage in (None, 'fit') and getattr(self, 'train_set', None) is None:
            self.train_set = MovingMNIST(root=root,
                                         train=True,
                                         download=download)
        # val_dataloader reads test_set, so fitting needs it too.
        if stage in (None, 'fit', 'validate', 'test') and getattr(self, 'test_set', None) is None:
            self.test_set = MovingMNIST(root=root,
                                        train=False,
                                        download=download)

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using this module's batch size and ``collate``."""
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.bs,
            shuffle=shuffle,
            collate_fn=collate)

    def train_dataloader(self, *args, **kwargs):
        """DataLoader over the training split (shuffled according to ``shuffle_train``)."""
        return self._make_loader(self.train_set, shuffle_train)

    def val_dataloader(self, *args, **kwargs):
        """DataLoader over the held-out split, used for validation."""
        return self._make_loader(self.test_set, shuffle_test)
    




### Hyperparameters

In [None]:
model_conf = 'simvp'  # 'random' (vanilla encoder-decoder baseline) or 'simvp'


batch_size = 16
lr = 1e-2   # AdamW learning rate (read by plTrainingModule.configure_optimizers)
wd = 5e-2   # AdamW weight decay

tau = 1e-1  # temperature dividing frame differences in the KL-divergence term


# Training settings shared by every configuration; also logged to wandb.
hyperparameters = {'lr': lr,
                   'wd': wd,
                   'epochs': 100,
                   'batch_size': batch_size
                   }

if model_conf == 'simvp':
    # Spatial CNN settings read by ConvBlock, EncoderSimVP, DecoderSimVP and SimVP.
    CNN = {'input': 1,
           'hidden': 16,
           'ksize': 3,
           'Ns': 4}
elif model_conf != 'random':
    # Fail here instead of with a NameError on `hyperparameters` much later.
    raise ValueError(f"Unknown model_conf {model_conf!r}; expected 'random' or 'simvp'")


data = pl_Dataset(batch_size)

data.setup(stage = 'fit')
data.setup(stage = 'test')



### Models

#### VanillaED_1conv
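* Implementation of a simple vanilla encoder-decoder baseline that processes each frame independently (two strided convolutions down, two transposed convolutions up)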

In [None]:
class VEncoder(nn.Module):
    """Two stride-2 3x3 convolutions, each followed by ReLU.

    Spatial size shrinks to ceil(H/2) and then ceil(H/4) (likewise for W), and
    the channel count goes ``in_channels -> out_channels -> 2 * out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        widened = out_channels * 2
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(out_channels, widened, 3, stride=2, padding=1), nn.ReLU(),
        )

    def forward(self, x):
        return self.conv(x)

class VDecoder(nn.Module):
    """Mirror of VEncoder: two stride-2 3x3 transposed convolutions with ReLU.

    Each layer exactly doubles height and width; the output always has a single
    channel, and the final ReLU keeps it non-negative.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, **upsample),
            nn.ReLU(),
            nn.ConvTranspose2d(out_channels, 1, **upsample),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.conv(x)

class VanillaED(nn.Module):
    """Frame-wise convolutional encoder-decoder baseline (no temporal modelling).

    Batch and time axes are merged so every frame is encoded and decoded on its
    own, then split back. The decoder always emits one channel, so the final
    reshape assumes single-channel input frames.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.encoder = VEncoder(in_channels, out_channels)
        # VEncoder doubles the channel count, so the decoder starts from 2x.
        self.decoder = VDecoder(out_channels * 2, out_channels)

    def forward(self, x):
        batch, time, channels, height, width = x.shape
        frames = x.reshape((batch * time, channels, height, width))
        decoded = self.decoder(self.encoder(frames))
        return decoded.reshape((batch, time, channels, height, width))
        
    

#### SimVP
* Implementation of the SimVP architecture; details are available at https://arxiv.org/pdf/2206.05099.pdf
* Encoder (CNN) + Translator (Inception) + Decoder (CNN)

In [None]:
class ConvBlock(nn.Module):
    """Size-preserving convolution -> GroupNorm -> LeakyReLU(0.2).

    Args:
        inputC: number of input channels.
        outputC: number of output channels (must be divisible by ``groups``).
        transpose: use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.
        groups: number of GroupNorm groups.
        kernel_size: odd kernel size; defaults to ``CNN['ksize']``.
    """

    def __init__(self, inputC, outputC, transpose = False, groups = 8, kernel_size = None):
        super(ConvBlock, self).__init__()
        if kernel_size is None:
            kernel_size = CNN['ksize']
        if not transpose:
            self.conv = nn.Conv2d(inputC, outputC, kernel_size, padding = 'same')
        else:
            # ConvTranspose2d has no 'same' mode. With stride 1 the output size
            # is H - 1 - 2*padding + k, so padding = k // 2 keeps H unchanged for
            # any odd k (a fixed padding of 1 is only right for k == 3).
            self.conv = nn.ConvTranspose2d(inputC, outputC, kernel_size, padding = kernel_size // 2)

        self.layernorm = nn.GroupNorm(groups, outputC)
        self.leaky = nn.LeakyReLU(0.2, inplace = True)

    def forward(self, x):
        return self.leaky(self.layernorm(self.conv(x)))

class EncoderSimVP(nn.Module):
    """Spatial encoder of SimVP: size-preserving ConvBlocks applied to every frame.

    One block maps ``CNN['input']`` to ``CNN['hidden']`` channels, followed by
    ``n_layers`` more ``CNN['hidden'] -> CNN['hidden']`` blocks (``n_layers + 1``
    blocks in total). Height and width are unchanged.
    """

    def __init__(self, n_layers):
        super(EncoderSimVP, self).__init__()
        hidden = CNN['hidden']
        blocks = [ConvBlock(CNN['input'], hidden)]
        blocks += [ConvBlock(hidden, hidden) for _ in range(n_layers)]
        self.enc = nn.Sequential(*blocks)

    def forward(self, x):
        # Fold time into the batch axis so each frame is encoded independently.
        batch, time, channels, height, width = x.shape
        features = self.enc(x.reshape((batch * time, channels, height, width)))
        return features.reshape((batch, time, CNN['hidden'], height, width))
    

class DecoderSimVP(nn.Module):
    """Spatial decoder of SimVP, applied to every frame.

    ``n_layers`` transposed ``CNN['hidden'] -> CNN['hidden']`` ConvBlocks are
    followed by one transposed ConvBlock down to ``CNN['input']`` channels.
    """

    def __init__(self, n_layers):
        super(DecoderSimVP, self).__init__()
        hidden = CNN['hidden']
        blocks = [ConvBlock(hidden, hidden, transpose = True) for _ in range(n_layers)]
        # NOTE(review): the output block applies GroupNorm with groups == CNN['input']
        # (1 here) and a LeakyReLU to the predicted frames; confirm this is the
        # intended readout.
        blocks.append(ConvBlock(hidden, CNN['input'], transpose = True, groups = CNN['input']))
        self.dec = nn.Sequential(*blocks)

    def forward(self, x):
        batch, time, channels, height, width = x.shape
        frames = self.dec(x.reshape((batch * time, channels, height, width)))
        return frames.reshape((batch, time, CNN['input'], height, width))


# class InceptionModule(nn.Module):

#     def __init__(self, inputC, hiddenC, kernel_list = [3, 5, 7, 11], groups = 8):
#         super(InceptionModule, self).__init__()

#         outputC = hiddenC // 2


#         self.conv1 = nn.Conv2d()

class Inception(nn.Module):
    """Inception block of the SimVP translator.

    A 1x1 convolution maps ``C_in -> C_hid`` channels. Then one
    Conv2d -> GroupNorm -> LeakyReLU(0.2) branch per kernel size in
    ``incep_ker`` maps ``C_hid -> C_out``. All branches keep the spatial size,
    and their outputs are summed.

    Args:
        C_in, C_hid, C_out: input, bottleneck and output channel counts.
        incep_ker: kernel sizes of the parallel branches.
        groups: GroupNorm groups (``C_out`` must be divisible by it).
    """
    def __init__(self, C_in, C_hid, C_out, incep_ker=(3, 5, 7, 11), groups=8):
        super(Inception, self).__init__()
        self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)
        layers = []
        for ker in incep_ker:
            # Build each branch explicitly so its kernel size is honoured:
            # passing `ker` positionally to ConvBlock binds it to `transpose`,
            # silently producing a 3x3 transposed conv in every branch.
            # Submodule names mirror ConvBlock's (conv / layernorm / leaky).
            branch = nn.Sequential()
            branch.add_module('conv', nn.Conv2d(C_hid, C_out, ker, padding='same'))
            branch.add_module('layernorm', nn.GroupNorm(groups, C_out))
            branch.add_module('leaky', nn.LeakyReLU(0.2, inplace=True))
            layers.append(branch)
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        y = 0
        for layer in self.layers:
            y += layer(x)
        return y
    
class Mid_Xnet(nn.Module):
    """U-Net-like temporal translator of SimVP, built from Inception blocks.

    The T frame embeddings are stacked along the channel axis (``T*C``
    channels), passed through ``N_T`` encoder Inceptions and then ``N_T``
    decoder Inceptions. Decoder layers ``1..N_T-1`` additionally receive the
    matching encoder output (skip connection) concatenated along channels.

    Args:
        channel_in: ``T * C`` of the input; also the output channel count.
        channel_hid: hidden width of the translator.
        N_T: Inception layers per half; must be at least 2.
        incep_ker: kernel sizes of the Inception branches.
        groups: GroupNorm groups inside each Inception.

    Raises:
        ValueError: if ``N_T < 2``.
    """
    def __init__(self, channel_in, channel_hid, N_T, incep_ker=(3, 5, 7, 11), groups=8):
        super(Mid_Xnet, self).__init__()
        if N_T < 2:
            # With N_T < 2 the last decoder layer (the only one producing
            # channel_in channels) is never run, so forward cannot reshape back.
            raise ValueError(f"Mid_Xnet requires N_T >= 2, got {N_T}")

        self.N_T = N_T
        enc_layers = [Inception(channel_in, channel_hid//2, channel_hid, incep_ker=incep_ker, groups=groups)]
        for i in range(1, N_T-1):
            enc_layers.append(Inception(channel_hid, channel_hid//2, channel_hid, incep_ker=incep_ker, groups=groups))
        enc_layers.append(Inception(channel_hid, channel_hid//2, channel_hid, incep_ker=incep_ker, groups=groups))

        dec_layers = [Inception(channel_hid, channel_hid//2, channel_hid, incep_ker=incep_ker, groups=groups)]
        for i in range(1, N_T-1):
            # Inputs are [decoder features, encoder skip] -> 2 * channel_hid.
            dec_layers.append(Inception(2*channel_hid, channel_hid//2, channel_hid, incep_ker=incep_ker, groups=groups))
        dec_layers.append(Inception(2*channel_hid, channel_hid//2, channel_in, incep_ker=incep_ker, groups=groups))

        self.enc = nn.Sequential(*enc_layers)
        self.dec = nn.Sequential(*dec_layers)

    def forward(self, x):
        B, T, C, H, W = x.shape
        # Stack time into channels so the Inceptions mix information across frames.
        x = x.reshape(B, T*C, H, W)

        # encoder: keep every output except the last as a skip connection
        skips = []
        z = x
        for i in range(self.N_T):
            z = self.enc[i](z)
            if i < self.N_T - 1:
                skips.append(z)

        # decoder: consume the skips in reverse order (deepest first)
        z = self.dec[0](z)
        for i in range(1, self.N_T):
            z = self.dec[i](torch.cat([z, skips[-i]], dim=1))

        y = z.reshape(B, T, C, H, W)
        return y


class SimVP(nn.Module):
    """SimVP video predictor: per-frame CNN encoder, Inception-based temporal
    translator over the stacked frame embeddings, per-frame CNN decoder.

    Args:
        hid_S: channels produced by the spatial encoder. It must equal
            ``CNN['hidden']``, because EncoderSimVP reads its width from there.
        hid_T: hidden width of the temporal translator (Mid_Xnet).
        N_S: extra ConvBlocks in both the encoder and the decoder. ``None``
            (default) uses ``CNN['Ns']``.
        N_T: Inception layers per half of the translator.
        incep_ker: kernel sizes of the Inception branches.
        groups: GroupNorm groups used in the translator.
        T: frames per input clip; the translator sees ``T * hid_S`` channels.

    Raises:
        ValueError: if ``hid_S != CNN['hidden']``.
    """
    def __init__(self, hid_S, hid_T=256, N_S=None, N_T=8, incep_ker=(3, 5, 7, 11), groups=8, T=10):
        super(SimVP, self).__init__()
        if hid_S != CNN['hidden']:
            # Fail early rather than with a channel-mismatch error in forward().
            raise ValueError(f"hid_S ({hid_S}) must equal CNN['hidden'] ({CNN['hidden']})")

        # Same depth for encoder and decoder so the two halves stay symmetric.
        n_spatial = CNN['Ns'] if N_S is None else N_S
        self.enc = EncoderSimVP(n_spatial)
        self.hid = Mid_Xnet(T * hid_S, hid_T, N_T, incep_ker, groups)
        self.dec = DecoderSimVP(n_spatial)

    def forward(self, x_raw):
        """Map a clip of shape (B, T, CNN['input'], H, W) to a predicted clip of the same shape."""
        embed = self.enc(x_raw)   # (B, T, hid_S, H, W)
        hid = self.hid(embed)     # same shape, mixed across time
        return self.dec(hid)      # (B, T, CNN['input'], H, W)

In [None]:
# Quick shape smoke test for SimVP (disabled):
# smoke_model = SimVP(16).to('cpu')
# dummy_clip = torch.randn((12, 10, 1, 64, 64), device = 'cpu')

# smoke_out = smoke_model(dummy_clip)

#dummy_clip.shape

In [None]:
class plTrainingModule(pl.LightningModule):
    """LightningModule that trains ``model`` with a per-pixel MSE loss.

    Per-step losses are collected in plain lists. In ``on_epoch_end`` they are
    averaged, logged once and then cleared, so each logged value covers a
    single epoch. An optional KL term on temporal frame differences
    (``compute_KLloss``) is implemented but currently not added to the loss.
    Optimiser settings come from the module-level ``lr`` and ``wd``, and the
    KL temperature from ``tau``.
    """

    def __init__(self, model):
        super(plTrainingModule, self).__init__()
        self.model = model
        self.MSE = nn.MSELoss()
        # 'batchmean' matches the mathematical KL definition; the default
        # 'mean' also divides by the number of elements.
        self.KL = nn.KLDivLoss(reduction='batchmean')
        self.tot_loss_tr = []
        self.tot_loss = []

        self.KL_list = []
        self.mse_list = []

    def training_step(self, batch, batch_idx):
        """Return the MSE between predicted and ground-truth frames for one batch."""
        inputs, ground_truth = batch

        prediction = self.model(inputs)
        mse = self.MSE(prediction, ground_truth)

        loss = mse  # + self.compute_KLloss(prediction, ground_truth)

        self.tot_loss_tr.append(loss.item())

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and record the validation MSE for one batch."""
        inputs, ground_truth = batch

        prediction = self.model(inputs)
        mse = self.MSE(prediction, ground_truth)

        loss = mse  # + self.compute_KLloss(prediction, ground_truth)

        self.tot_loss.append(loss.item())
        self.mse_list.append(mse.item())

        return loss

    def compute_KLloss(self, predicted, ground_truth):
        """KL divergence between softmax-normalised temporal differences.

        Differences between consecutive frames are divided by ``tau`` and
        turned into distributions with a softmax over dim 3 (the height axis,
        assuming (B, T, C, H, W) inputs as produced by ``collate``).
        ``nn.KLDivLoss`` expects log-probabilities as input and probabilities
        as target.
        """
        delta_pred = (predicted[:, 1:] - predicted[:, :-1]) / tau
        delta_gt = (ground_truth[:, 1:] - ground_truth[:, :-1]) / tau

        log_soft_pred = F.log_softmax(delta_pred, dim = 3)
        soft_gt = F.softmax(delta_gt, dim = 3)

        return self.KL(log_soft_pred, soft_gt)

    def _reset_train_losses(self):
        """Forget the training losses collected since the last log."""
        self.tot_loss_tr.clear()

    def _reset_val_losses(self):
        """Forget the validation losses collected since the last log."""
        self.tot_loss.clear()
        self.mse_list.clear()
        self.KL_list.clear()

    def on_epoch_end(self):
        """Log the mean of the losses collected since the last call, then reset them.

        Each group is logged only when it has entries, which avoids a division
        by zero when, for example, no validation ran in this epoch.
        """
        if self.tot_loss_tr:
            self.log(name = 'TOT. Loss on train', value = sum(self.tot_loss_tr) / len(self.tot_loss_tr))
            self._reset_train_losses()
        if self.tot_loss:
            self.log(name = 'TOT. Loss on test', value = sum(self.tot_loss) / len(self.tot_loss))
            self.log(name = 'MSE loss on test', value = sum(self.mse_list) / len(self.mse_list))
            self._reset_val_losses()

    def configure_optimizers(self):
        """AdamW over the wrapped model's parameters, using the module-level ``lr`` and ``wd``."""
        return torch.optim.AdamW(self.model.parameters(), lr, weight_decay = wd)

In [None]:
# Train on a single GPU when CUDA is available, otherwise on CPU.
num_gpu = int(torch.cuda.is_available())

# Run name shown in wandb: "<model_conf> <epochs> <batch_size>".
exp_name = f"{model_conf} {hyperparameters['epochs']} {batch_size}"

In [None]:
model = SimVP(16)  # alternative baseline: VanillaED(1, 64)

pl_training_MDL = plTrainingModule(model)

# Trainer settings shared by the wandb and non-wandb runs.
trainer_kwargs = dict(
    max_epochs=hyperparameters['epochs'],  # maximum number of epochs
    gpus=num_gpu,  # number of GPUs at our disposal (0 -> CPU)
    default_root_dir="",
    callbacks=[TQDMProgressBar(refresh_rate=20)],  # , overfit_batches = 1
)

### WANDB CODE ###
if wb:
    wandb_logger = WandbLogger(project=project_name, name=exp_name, config=hyperparameters, entity='deepl_wizards')
    trainer_kwargs['logger'] = wandb_logger

trainer = pl.Trainer(**trainer_kwargs)



In [None]:
try:
    trainer.fit(model = pl_training_MDL, datamodule = data)
finally:
    ### WANDB code
    # Close the wandb run even if training raises; nothing to close when wandb is off.
    if wb:
        wandb.finish()

In [None]:
# SOFTMAX ISSUES: debugging cell for the KL term (plTrainingModule.compute_KLloss).
# NOTE(review): uses `output` and `gt`, which are only defined by the
# first-validation-batch cell further below; run that cell first.

# Differences between consecutive frames, divided by the temperature tau.
delta_pred = (output[:, 1:, :, :, :] - output[:, :-1, : , :, :])/tau 
# delta_pred += 1200
delta_gt = (gt[:, 1:, :, :, :] - gt[:, :-1, :, :, :])/tau
# delta_gt += 1200
# Softmax over dim 3, which is the height axis assuming (B, T-1, C, H, W) tensors; the channel axis is dim 2.
soft_pred = F.softmax(delta_pred, dim = 3) # dim 3 = H, not the channel axis
soft_gt = F.softmax(delta_gt, dim = 3) 

In [None]:
# Rebuild SimVP and restore the trained weights. Tensors are mapped to the GPU
# only when one exists, so the cell also runs on CPU-only machines.
mod = SimVP(16)
mod.eval()  # inference mode for the visualisation cells below
checkpoint = load_model('simvp.pt', mod, 'cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Do predicted and ground-truth temporal differences (and their softmaxes) coincide?
# Both results are returned as one tuple so the notebook displays them (a bare
# expression on its own line would be discarded). Everything is compared on the
# CPU, so no GPU is required.
torch.equal(delta_pred.cpu(), delta_gt.cpu()), torch.equal(soft_pred.cpu(), soft_gt.cpu())


In [None]:
# Take the first validation batch and run the loaded model on it (inference only,
# so no autograd graph is built).
batch = next(iter(data.val_dataloader()))
gt = batch[1]
with torch.no_grad():
    output = mod(batch[0])

In [None]:
show_video(tensor=gt[1].cpu())  # ground-truth frames of clip 1 in the batch

In [None]:
show_video(tensor=output[1].detach().cpu())  # predicted frames for the same clip

In [None]:
show_video(batch[0][1].detach())  # input frames of clip 1, matching the gt/output cells above
