# Deep Learning course 2022-2023
## Project: Video Prediction on Moving MNIST
### Project Contributors
* Mattia Castelmare, 1815675
* Andrea Giuseppe Di Francesco, 1836928
* Enrico Fazzi, 2003876

#### Installation cells

In [None]:
# !git clone https://github.com/tychovdo/MovingMNIST.git
# !pip3 install pytorch-lightning==1.5.10
# !pip3 install torchvision
# !pip3 install matplotlib
# !pip install wandb

### Importing Libraries

In [1]:
### wandb codes ###
from MovingMNIST.MovingMNIST import *
import wandb
#####################

import torch
import os
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning import Trainer
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib import rc
rc('animation', html='jshtml')




## Useful Functions

In [74]:
def show_video(tensor):
    ''' Build a matplotlib animation that plays a video stored in a torch tensor.
        Adapted from https://stackoverflow.com/questions/67261108/how-to-display-a-video-in-colab-using-a-pytorch-tensor-of-rgb-image-arrays
        INPUT: tensor (Frames x Channels x Height x Width)
        OUTPUT: a matplotlib ArtistAnimation with one grayscale image per frame '''

    fig, ax = plt.subplots()

    # Move the channel axis last, (T, C, H, W) -> (T, H, W, C), the layout imshow expects.
    frames_hwc = tensor.permute(0, 2, 3, 1)

    artists = []
    for frame in frames_hwc:
        artists.append([ax.imshow(frame, cmap='gray')])

    return animation.ArtistAnimation(fig, artists)


def collate(batch):
    ''' Collate function for the PyTorch DataLoader.
        INPUT: batch: list of (input_video, target_video) tuples; each video is a tensor of shape (T x H x W)
        OUTPUT: a tuple (input, gt), both float tensors of shape (B x T x 1 x H x W)

        All samples are stacked in a single call. The previous version concatenated them one by one,
        re-copying the growing batch tensor on every iteration.'''

    inputs = torch.stack([sample[0] for sample in batch], dim=0)
    targets = torch.stack([sample[1] for sample in batch], dim=0)

    # Each video must be (T, H, W); otherwise the channel axis inserted below would be wrong.
    if inputs.dim() != 4 or targets.dim() != 4:
        raise ValueError('collate expects every video to have shape (T, H, W), got input '
                         f'{tuple(inputs.shape[1:])} and target {tuple(targets.shape[1:])}')

    # Insert the singleton channel axis: (B, T, H, W) -> (B, T, 1, H, W).
    inputs = inputs.unsqueeze(2)
    targets = targets.unsqueeze(2)

    return (inputs.type(torch.FloatTensor), targets.type(torch.FloatTensor))


def save_model(model, loss, path):
    ''' Serialize a model's weights together with a loss value.
        INPUT: model (nn.Module), loss (any picklable value), path (destination file)
        OUTPUT: None; writes a dict {'model_state': model.state_dict(), 'loss': loss} to `path` '''
    torch.save({'model_state': model.state_dict(), 'loss': loss}, path)


def load_model(path, model, device):
    ''' Restore weights written by save_model into `model` (in place).
        INPUT: path (checkpoint file), model (nn.Module with a matching state_dict),
               device (forwarded to torch.load as map_location)
        OUTPUT: the full checkpoint dict (keys 'model_state' and 'loss') '''
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt['model_state'])
    return ckpt


def kronecker_product(dyn, static):
    ''' Combine the dynamic (per-channel) and static (spatial) attention maps.
        INPUT: dyn    - one weight per (sample, channel), e.g. shape (B x TC x 1 x 1)
               static - attention map of shape (B x TC x H x W)
        OUTPUT: tensor with static's shape, where channel c of sample b is static[b, c] * dyn[b, c]

        Every dynamic factor is a single value, so kron(a, b) reduces to a * b. One broadcast
        multiplication therefore gives the same result as the previous per-(sample, channel)
        torch.kron loop. It keeps the output on the inputs' device (the old CPU-allocated buffer
        broke GPU execution) and is much cheaper.'''
    if static.dim() < 2:
        raise ValueError(f'static must have shape (B, TC, ...), got {tuple(static.shape)}')
    batch, channels = static.shape[:2]
    if tuple(dyn.shape[:2]) != (batch, channels) or dyn.numel() != batch * channels:
        raise ValueError('dyn must hold exactly one weight per (sample, channel) of static: '
                         f'got dyn {tuple(dyn.shape)} and static {tuple(static.shape)}')

    # Reshape to (B, TC, 1, ..., 1) so the weights broadcast over the spatial axes.
    weights = dyn.reshape(batch, channels, *([1] * (static.dim() - 2)))
    return weights * static


### Useful variables 

In [None]:
shuffle_train = True   # shuffle the training DataLoader
shuffle_test = False   # keep the test/validation DataLoader in a fixed order
download = False       # forwarded to MovingMNIST(download=...)

### wandb codes ###
wb = True              # enable Weights & Biases login/logging
project_name = "VP_on_MMNIST"
if wb:
    wandb.login()
###################

SEED = 2812023
pl.seed_everything(SEED)

root = './data'
# exist_ok replaces the exists-then-mkdir check (no race, no error if it already
# exists), and makedirs also creates any missing parent directories.
os.makedirs(root, exist_ok=True)


### PyTorch Lightning DataModule
* Sets up the Moving MNIST (MMNIST) dataset objects and their dataloaders

In [None]:
class pl_Dataset(pl.LightningDataModule):
    ''' LightningDataModule wrapping the Moving MNIST train and test splits.

        Reads the module-level settings `root`, `download`, `shuffle_train` and
        `shuffle_test`; batches are assembled with `collate`. The test split is
        also served as the validation set (see val_dataloader).'''

    def __init__(self, batch_size):
        # The base-class initialiser sets up Lightning's internal state; it was
        # previously skipped.
        super(pl_Dataset, self).__init__()
        self.bs = batch_size

    def setup(self, stage=None):
        ''' Build the dataset(s) required by `stage`.

            stage=None builds both splits. val_dataloader serves the test split,
            so 'fit' and 'validate' build it as well; trainer.fit therefore no
            longer depends on setup('test') having been called first. Splits that
            are already loaded are reused instead of being read again.'''
        if stage in (None, 'fit') and getattr(self, 'train_set', None) is None:
            self.train_set = MovingMNIST(root=root,
                                         train=True,
                                         download=download)
        if stage in (None, 'fit', 'validate', 'test') and getattr(self, 'test_set', None) is None:
            self.test_set = MovingMNIST(root=root,
                                        train=False,
                                        download=download)

    def train_dataloader(self, *args, **kwargs):
        ''' DataLoader over the training split (shuffled according to `shuffle_train`).'''
        return torch.utils.data.DataLoader(
            dataset=self.train_set,
            batch_size=self.bs,
            shuffle=shuffle_train,
            collate_fn=collate)

    def val_dataloader(self, *args, **kwargs):
        ''' DataLoader over the test split, used for validation.'''
        return torch.utils.data.DataLoader(
            dataset=self.test_set,
            batch_size=self.bs,
            shuffle=shuffle_test,
            collate_fn=collate)


### Hyperparameters

In [None]:
model_conf = 'simvp'  # 'random' (vanilla settings only) or 'simvp'


batch_size = 16
lr = 1e-2
wd = 5e-2

tau = 1e-1  # Temperature dividing frame differences before the softmax in compute_KLloss


### Vanilla hyperparameters ###
# Shared by every configuration; this is the complete set for 'random'.
hyperparameters = {'lr': lr,
                   'wd': wd,
                   'epochs': 100,
                   'batch_size': batch_size
                   }

if model_conf == 'simvp':
    # Architecture sizes read by the SimVP building blocks.
    hyperparameters['CNN'] = {'input': 1,
                              'hidden_S': 64,  # Output channels in Encoder / Input channel in Decoder
                              'output_T': 512,  # Output channels in Inception, Hidden C. == Output channels // 2
                              'ksize': 3,
                              'Ns': 4,
                              'Nt': 3}
elif model_conf != 'random':
    # Fail here rather than with a NameError on `hyperparameters` later on.
    raise ValueError(f"Unknown model_conf {model_conf!r}: expected 'random' or 'simvp'")


data = pl_Dataset(batch_size)

data.setup(stage='fit')
data.setup(stage='test')


### Models

### Our VanillaED (Enry and Matty)
* Implementation of a vanilla encoder–decoder architecture with a skip connection

In [None]:
class VEncoder(nn.Module):
    ''' Stack of Conv2d layers returning the final features and the first layer's output.

        Layout: one Conv2d (input_channels -> output_channels) followed by `n_layers`
        blocks of Conv2d (output_channels -> output_channels) + ReLU. All convolutions
        share kernel_size, padding and stride.'''

    def __init__(self, input_channels, output_channels, kernel_size, padding, stride, n_layers):
        super(VEncoder, self).__init__()
        self.conv_layers = []

        # Registered both as `self.conv` and as enc[0]; both names refer to the same
        # module (kept so existing state_dict keys stay valid).
        self.conv = nn.Conv2d(in_channels=input_channels, out_channels=output_channels,
                              kernel_size=kernel_size, padding=padding, stride=stride)

        self.conv_layers.append(self.conv)

        for _ in range(n_layers):
            self.conv_layers.append(nn.Conv2d(in_channels=output_channels, out_channels=output_channels,
                                    kernel_size=kernel_size, padding=padding, stride=stride))
            self.conv_layers.append(nn.ReLU())

        self.enc = nn.Sequential(*self.conv_layers)

    def forward(self, input):
        ''' Return (output of the whole stack, output of the first conv used as skip).'''
        # Run the first conv once and feed it to the remaining layers; previously it
        # was evaluated twice (once for the skip and again inside self.enc).
        skip = self.enc[0](input)
        output = skip
        for layer in self.enc[1:]:
            output = layer(output)
        return output, skip


class VDecoder(nn.Module):
    ''' Stack of `n_layers` ConvTranspose2d + ReLU blocks that keep the channel count.

        Every transposed conv maps input_channels -> input_channels; the
        `output_channels` argument is accepted but not used.'''

    def __init__(self, input_channels, output_channels, kernel_size, padding, stride, n_layers, output_padding):
        super(VDecoder, self).__init__()

        self.trans_conv_layers = []
        for _ in range(n_layers):
            upsample = nn.ConvTranspose2d(in_channels=input_channels, out_channels=input_channels,
                                          kernel_size=kernel_size, padding=padding, stride=stride,
                                          output_padding=output_padding)
            self.trans_conv_layers += [upsample, nn.ReLU()]

        self.dec = nn.Sequential(*self.trans_conv_layers)

    def forward(self, input):
        ''' Apply the whole transposed-conv stack.'''
        decoded = self.dec(input)
        return decoded


class VEncoderDecoder(nn.Module):
    ''' Vanilla encoder-decoder with a single skip connection.

        The encoder returns its final features plus the output of its first conv (the
        skip). The decoder is applied once, its output is concatenated with the skip
        along the channel axis, and a final ConvTranspose2d maps the
        2 * output_channels channels back to input_channels.'''

    def __init__(self, input_channels, output_channels, kernel_size, padding, stride, n_layers, output_padding):
        super(VEncoderDecoder, self).__init__()

        self.encoder = VEncoder(input_channels=input_channels, output_channels=output_channels,
                                kernel_size=kernel_size, padding=padding, stride=stride, n_layers=n_layers)
        self.decoder = VDecoder(input_channels=output_channels, output_channels=input_channels, kernel_size=kernel_size,
                                padding=padding, stride=stride, n_layers=n_layers, output_padding=output_padding)

        # Kept for compatibility; the decoder depth itself already equals n_layers.
        self.num_layers = n_layers

        self.final = nn.ConvTranspose2d(in_channels=2*output_channels, out_channels=input_channels, kernel_size=kernel_size, padding=padding, stride=stride, output_padding=output_padding)

    def forward(self, input):
        output, skip = self.encoder(input)
        # The decoder already stacks `num_layers` transposed convs, so it is applied
        # exactly once. Applying it num_layers times (as before) upsampled past the
        # skip's resolution whenever stride > 1 and num_layers > 1, which broke the
        # concatenation below.
        output = self.decoder(output)

        output = torch.cat((skip, output), dim=1)
        return self.final(output)


#### SimVP
* Implementation of the SimVP architecture; details are available at https://arxiv.org/pdf/2206.05099.pdf
* Encoder (CNN) + Translator (Inception) + Decoder (CNN)

In [None]:
def stride_generator(N, reverse=False):
    ''' Return the first N entries of the alternating stride pattern 1, 2, 1, 2, ...
        The pattern holds 20 entries, so at most 20 strides are returned.
        With reverse=True the selected strides come back in reverse order.'''
    selected = ([1, 2] * 10)[:N]
    return selected[::-1] if reverse else selected


class ConvBlock(nn.Module):
    ''' Conv2d (or ConvTranspose2d) -> GroupNorm -> LeakyReLU(0.2).

        Args:
            inputC, outputC: input / output channel counts.
            stride: convolution stride; None is treated as 1 (the previous None default
                could not build a working layer).
            padding: Conv2d padding ('same' or an int). In the transposed branch an int
                is used as given and 'same' falls back to kernel // 2.
            transpose: build a ConvTranspose2d with output_padding = stride // 2.
            groups: number of GroupNorm groups (must divide outputC).
            kernel: kernel size; the default is read from hyperparameters['CNN']['ksize']
                when the class is defined.'''

    def __init__(self, inputC, outputC, stride=None, padding='same', transpose=False, groups=8, kernel=hyperparameters['CNN']['ksize']):
        super(ConvBlock, self).__init__()
        if stride is None:
            stride = 1
        if not transpose:
            self.conv = nn.Conv2d(
                inputC, outputC, kernel_size=kernel, padding=padding, stride=stride)
        else:
            # This branch used to hard-code padding=1 and silently ignore `padding`.
            # 'same' is not accepted by ConvTranspose2d, so it maps to kernel // 2
            # (which equals the old value 1 for the default 3x3 kernel).
            transpose_padding = kernel // 2 if padding == 'same' else padding
            self.conv = nn.ConvTranspose2d(
                inputC, outputC, kernel, stride=stride, padding=transpose_padding, output_padding=stride // 2)

        # Despite the name this is a GroupNorm; the attribute name is kept so existing
        # checkpoints still load.
        self.layernorm = nn.GroupNorm(groups, outputC)
        self.leaky = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        return self.leaky(self.layernorm(self.conv(x)))


class EncoderSimVP(nn.Module):
    ''' Frame-wise CNN encoder of SimVP.

        Builds n_layers ConvBlocks with strides from stride_generator(n_layers). The first
        maps hyperparameters['CNN']['input'] channels to hidden_S; the others keep hidden_S.'''

    def __init__(self, n_layers):
        super(EncoderSimVP, self).__init__()
        strides = stride_generator(n_layers)
        layers = [ConvBlock(hyperparameters['CNN']['input'], hyperparameters['CNN']
                            ['hidden_S'], groups=2, stride=strides[0], padding=1)]

        for layer in range(1, n_layers):
            stride = strides[layer]
            layers.append(ConvBlock(
                hyperparameters['CNN']['hidden_S'], hyperparameters['CNN']['hidden_S'], groups=2, stride=stride, padding=1))

        self.enc = nn.Sequential(*layers)

    def forward(self, x):
        ''' x: (B, T, C, H, W). Frames are folded into the batch axis, and the result is
            (latent, skip): the outputs of the last and of the first layer, each with
            B*T rows.'''
        B, T, C, H, W = x.shape

        x = x.reshape((B*T, C, H, W))

        # Compute the first layer once and reuse it as the skip; previously enc[0]
        # ran twice (once for the skip, once inside self.enc).
        skip = self.enc[0](x)
        latent = skip
        for layer in self.enc[1:]:
            latent = layer(latent)

        return latent, skip


class DecoderSimVP(nn.Module):
    ''' Frame-wise CNN decoder of SimVP.

        n_layers - 1 transposed ConvBlocks (hidden_S -> hidden_S) are followed by one
        transposed ConvBlock that consumes the features concatenated with the encoder
        skip (2 * hidden_S -> hidden_S). A 1x1 Conv2d then projects to
        hyperparameters['CNN']['input'] channels. Strides come from
        stride_generator(n_layers, reverse=True).'''

    def __init__(self, n_layers):
        super(DecoderSimVP, self).__init__()
        cnn_cfg = hyperparameters['CNN']
        hidden = cnn_cfg['hidden_S']
        strides = stride_generator(n_layers, reverse=True)

        blocks = [ConvBlock(hidden, hidden, transpose=True, groups=2, stride=strides[i], padding=1)
                  for i in range(n_layers - 1)]
        blocks.append(ConvBlock(2 * hidden, hidden, transpose=True, groups=2, stride=strides[-1], padding=1))

        self.proj = nn.Conv2d(hidden, cnn_cfg['input'], 1)

        self.dec = nn.Sequential(*blocks)

    def forward(self, x, skip):
        ''' Upsample x, fuse it with `skip` in the last block, then project to output channels.'''
        *upsamplers, fuse = self.dec
        for block in upsamplers:
            x = block(x)

        fused = fuse(torch.cat((x, skip), dim=1))
        return self.proj(fused)


class InceptionModule(nn.Module):
    ''' Inception block used by the SimVP translator.

        A 1x1 Conv2d reduces inputC to outputC // 2 channels. One ConvBlock per kernel size
        in `kernel_list` (padding kernel // 2, GroupNorm with 4 groups) then maps back to
        outputC, and the branch outputs are summed.
        Note: the `groups` argument is accepted but unused; the branches always use 4 groups.'''

    def __init__(self, inputC, outputC, kernel_list=[3, 5, 7, 11], groups=8):
        super(InceptionModule, self).__init__()

        reducedC = outputC // 2
        self.Conv = nn.Conv2d(inputC, reducedC, kernel_size=1, stride=1, padding=0)

        branches = [ConvBlock(reducedC, outputC, groups=4, stride=1, padding=k // 2, kernel=k)
                    for k in kernel_list]
        self.ConvParallel = nn.Sequential(*branches)

    def forward(self, input):
        reduced = self.Conv(input)
        return sum(branch(reduced) for branch in self.ConvParallel)


class Translator(nn.Module):
    ''' U-shaped stack of InceptionModules operating on frames stacked along channels.

        For n_layers >= 1:
          encoder: n_layers + 1 InceptionModules (inputC -> output_T, then output_T -> output_T);
          decoder: n_layers + 1 InceptionModules. All but the first take the previous decoder
                   output concatenated with the mirrored encoder feature (2 * output_T
                   channels), and the last maps back to inputC.
        Input/output: (B, T, C, H, W) with T * C == inputC.'''

    def __init__(self, n_layers, inputC):
        super(Translator, self).__init__()

        self.n_layers = n_layers
        enc_layers = [InceptionModule(
            inputC, hyperparameters['CNN']['output_T'])]

        for inc_layer in range(n_layers-1):
            enc_layers.append(InceptionModule(
                hyperparameters['CNN']['output_T'], hyperparameters['CNN']['output_T']))

        enc_layers.append(InceptionModule(
            hyperparameters['CNN']['output_T'], hyperparameters['CNN']['output_T']))

        self.encoder = nn.Sequential(*enc_layers)

        dec_layers = [InceptionModule(
            hyperparameters['CNN']['output_T'], hyperparameters['CNN']['output_T'])]

        for inc_layers in range(n_layers-1):
            dec_layers.append(InceptionModule(
                2*hyperparameters['CNN']['output_T'], hyperparameters['CNN']['output_T']))
        dec_layers.append(InceptionModule(
            2*hyperparameters['CNN']['output_T'], inputC))

        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, input):
        B, T, C, H, W = input.shape
        x = input.reshape((B, T*C, H, W))

        enc_feat = []
        mid_feat = x

        for encoder_layer in self.encoder:
            mid_feat = encoder_layer(mid_feat)
            enc_feat.append(mid_feat)

        # The last encoder output is the bottleneck fed to decoder[0], so it is not reused
        # as a skip. Previously the skips were enc_feat[-1], ..., enc_feat[-n_layers]: the
        # bottleneck was concatenated with its own transform and enc_feat[0] was never used.
        # Decoder layer l now receives the mirrored encoder feature, matching the reference
        # SimVP translator, where the skips exclude the bottleneck.
        skips = enc_feat[:-1]

        output = self.decoder[0](mid_feat)

        for l_idx in range(1, self.n_layers+1):
            output = self.decoder[l_idx](torch.cat((output, skips[-l_idx]), dim=1))

        return output.reshape((B, T, C, H, W))


class SimVP(nn.Module):
    ''' SimVP: frame-wise CNN encoder -> Inception translator -> frame-wise CNN decoder.

        All sizes come from hyperparameters['CNN']. The translator is built for
        10 * hidden_S input channels, i.e. clips with T == 10 frames.
        Input and output: (B, T, C, H, W).'''

    def __init__(self):
        super(SimVP, self).__init__()
        cnn_cfg = hyperparameters['CNN']
        self.encoder = EncoderSimVP(cnn_cfg['Ns'])
        self.translator = Translator(cnn_cfg['Nt'], 10 * cnn_cfg['hidden_S'])
        self.decoder = DecoderSimVP(cnn_cfg['Ns'])

    def forward(self, input):
        B, T, C, H, W = input.shape

        latent, skip = self.encoder(input)
        _, C_, H_, W_ = latent.shape

        # Regroup frames per clip for the translator, then fold them back for the decoder.
        translated = self.translator(latent.view(B, T, C_, H_, W_))
        Y = self.decoder(translated.reshape((B * T, C_, H_, W_)), skip)

        return Y.reshape((B, T, C, H, W))


#### Ground Truth

In [None]:
import torch
from torch import nn


class BasicConv2d(nn.Module):
    ''' Conv2d or ConvTranspose2d, optionally followed by GroupNorm(2) + LeakyReLU(0.2).

        With transpose=True the output_padding is stride // 2. The norm and activation
        modules are always created but only applied when act_norm is True.'''

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, transpose=False, act_norm=False):
        super(BasicConv2d, self).__init__()
        self.act_norm = act_norm
        shared = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        if transpose:
            self.conv = nn.ConvTranspose2d(in_channels, out_channels, output_padding=stride // 2, **shared)
        else:
            self.conv = nn.Conv2d(in_channels, out_channels, **shared)
        self.norm = nn.GroupNorm(2, out_channels)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        y = self.conv(x)
        return self.act(self.norm(y)) if self.act_norm else y


class ConvSC(nn.Module):
    ''' 3x3 BasicConv2d (padding 1) with GroupNorm + LeakyReLU enabled by default.

        A transposed conv is only built when stride != 1; with stride 1 the
        `transpose` flag is ignored and a regular Conv2d is used.'''

    def __init__(self, C_in, C_out, stride, transpose=False, act_norm=True):
        super(ConvSC, self).__init__()
        use_transpose = transpose and stride != 1
        self.conv = BasicConv2d(C_in, C_out, kernel_size=3, stride=stride,
                                padding=1, transpose=use_transpose, act_norm=act_norm)

    def forward(self, x):
        return self.conv(x)


class GroupConv2d(nn.Module):
    ''' Grouped Conv2d, optionally followed by GroupNorm + LeakyReLU(0.2).

        `groups` is used both as the conv's groups and as the number of GroupNorm groups.
        If it does not divide both in_channels and out_channels, it falls back to 1.'''

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups, act_norm=False):
        super(GroupConv2d, self).__init__()
        self.act_norm = act_norm
        # Conv2d needs groups to divide both channel counts, and GroupNorm needs it to divide
        # out_channels. Only in_channels used to be checked, so e.g. (in=4, out=6, groups=4)
        # raised at construction time.
        if in_channels % groups != 0 or out_channels % groups != 0:
            groups = 1
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, groups=groups)
        self.norm = nn.GroupNorm(groups, out_channels)
        self.activate = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        y = self.conv(x)
        if self.act_norm:
            y = self.activate(self.norm(y))
        return y


class Inception(nn.Module):
    ''' 1x1 Conv2d (C_in -> C_hid) followed by parallel GroupConv2d branches, one per kernel
        size in `incep_ker` (C_hid -> C_out, padding ker // 2, GroupNorm + LeakyReLU).
        The branch outputs are summed.'''

    def __init__(self, C_in, C_hid, C_out, incep_ker=[3, 5, 7, 11], groups=8):
        super(Inception, self).__init__()
        self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)
        self.layers = nn.Sequential(*(
            GroupConv2d(C_hid, C_out, kernel_size=k, stride=1, padding=k // 2, groups=groups, act_norm=True)
            for k in incep_ker
        ))

    def forward(self, x):
        reduced = self.conv1(x)
        return sum(branch(reduced) for branch in self.layers)


def stride_generator(N, reverse=False):
    ''' Alternating strides [1, 2, 1, 2, ...] truncated to N entries (at most 20),
        reversed when `reverse` is True.'''
    pattern = [1, 2] * 10
    strides = pattern[:N]
    if reverse:
        strides.reverse()
    return strides


class Encoder(nn.Module):
    ''' N_S ConvSC layers with strides from stride_generator(N_S).

        forward returns (latent, enc1): the output of the last layer and the output of the
        first layer, which the Decoder uses as a skip connection.'''

    def __init__(self, C_in, C_hid, N_S):
        super(Encoder, self).__init__()
        strides = stride_generator(N_S)
        first = ConvSC(C_in, C_hid, stride=strides[0])
        rest = [ConvSC(C_hid, C_hid, stride=s) for s in strides[1:]]
        self.enc = nn.Sequential(first, *rest)

    def forward(self, x):  # x: a batch of frames; SimVP.forward passes (B*T, C, H, W)
        enc1 = self.enc[0](x)
        latent = enc1
        for layer in self.enc[1:]:
            latent = layer(latent)
        return latent, enc1


class Decoder(nn.Module):
    ''' Mirror of Encoder.

        N_S - 1 transposed ConvSC layers are followed by a ConvSC that consumes the features
        concatenated with the encoder's first-layer output (2 * C_hid channels), then a 1x1
        Conv2d readout to C_out channels. `enc1` must be supplied despite its None default.'''

    def __init__(self, C_hid, C_out, N_S):
        super(Decoder, self).__init__()
        strides = stride_generator(N_S, reverse=True)
        upsamplers = [ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides[:-1]]
        fuse = ConvSC(2 * C_hid, C_hid, stride=strides[-1], transpose=True)
        self.dec = nn.Sequential(*upsamplers, fuse)
        self.readout = nn.Conv2d(C_hid, C_out, 1)

    def forward(self, hid, enc1=None):
        *upsamplers, fuse = self.dec
        for layer in upsamplers:
            hid = layer(hid)
        Y = fuse(torch.cat([hid, enc1], dim=1))
        return self.readout(Y)


class Mid_Xnet(nn.Module):
    ''' Inception-based U-Net translator of the reference SimVP.

        For N_T >= 2 the encoder holds N_T Inception blocks (channel_in -> channel_hid, then
        channel_hid -> channel_hid), and the outputs of its first N_T - 1 blocks are kept as
        skips. The decoder holds N_T blocks: each block after the first takes the previous
        output concatenated with the mirrored skip, and the last maps back to channel_in.
        Input/output: (B, T, C, H, W) with T * C == channel_in.'''

    def __init__(self, channel_in, channel_hid, N_T, incep_ker=[3, 5, 7, 11], groups=8):
        super(Mid_Xnet, self).__init__()

        self.N_T = N_T
        half = channel_hid // 2

        def block(c_in, c_out):
            return Inception(c_in, half, c_out, incep_ker=incep_ker, groups=groups)

        enc_layers = [block(channel_in, channel_hid)]
        enc_layers += [block(channel_hid, channel_hid) for _ in range(1, N_T - 1)]
        enc_layers.append(block(channel_hid, channel_hid))

        dec_layers = [block(channel_hid, channel_hid)]
        dec_layers += [block(2 * channel_hid, channel_hid) for _ in range(1, N_T - 1)]
        dec_layers.append(block(2 * channel_hid, channel_in))

        self.enc = nn.Sequential(*enc_layers)
        self.dec = nn.Sequential(*dec_layers)

    def forward(self, x):
        B, T, C, H, W = x.shape
        z = x.reshape(B, T * C, H, W)

        # encoder: keep every intermediate output except the bottleneck (the last one)
        feats = []
        for i in range(self.N_T):
            z = self.enc[i](z)
            feats.append(z)
        skips = feats[:-1]

        # decoder: block i consumes the skip from the mirrored encoder depth
        z = self.dec[0](z)
        for i, skip in enumerate(reversed(skips), start=1):
            z = self.dec[i](torch.cat([z, skip], dim=1))

        return z.reshape(B, T, C, H, W)


class SimVP(nn.Module):
    ''' Reference SimVP: Encoder -> Mid_Xnet -> Decoder.

        shape_in is (T, C, H, W) of one input clip; input/output are (B, T, C, H, W).
        NOTE(review): this class reuses the name SimVP, so once this cell runs it replaces
        the SimVP defined earlier (whose constructor takes no arguments). Confirm which one
        later cells expect.'''

    def __init__(self, shape_in, hid_S=64, hid_T=256, N_S=4, N_T=8, incep_ker=[3, 5, 7, 11], groups=8):
        super(SimVP, self).__init__()
        T, C, H, W = shape_in
        self.enc = Encoder(C, hid_S, N_S)
        self.hid = Mid_Xnet(T * hid_S, hid_T, N_T, incep_ker, groups)
        self.dec = Decoder(hid_S, C, N_S)

    def forward(self, x_raw):
        B, T, C, H, W = x_raw.shape

        embed, skip = self.enc(x_raw.view(B * T, C, H, W))
        _, C_, H_, W_ = embed.shape

        # Translate with frames grouped per clip, then fold them back into the batch axis.
        hid = self.hid(embed.view(B, T, C_, H_, W_)).reshape(B * T, C_, H_, W_)

        Y = self.dec(hid, skip)
        return Y.reshape(B, T, C, H, W)


### TAU Module
* Implementation of the Temporal Attention Unit (TAU) architecture

In [90]:
class Dyn_attention(nn.Module):
    ''' Dynamic (channel) attention branch of the TAU module.

        The input is average-pooled to (H_final, W_final) and each sample is flattened.
        An MLP (Linear layers with optional ReLU, final Sigmoid) then outputs one weight
        per pooled feature.

        Args:
            input: example tensor of shape (B, TC, H, W), used only to size the MLP
                (TC * H_final * W_final features per sample); its batch size does not matter.
            n_layers: number of hidden Linear(hidden_dim, hidden_dim) layers.
            hidden_dim: MLP hidden width.
            H_final, W_final: output size of the adaptive average pooling.
            hidden_act: insert a ReLU after every Linear except the last.'''

    def __init__(self, input, n_layers, hidden_dim, H_final, W_final, hidden_act=True):
        super(Dyn_attention, self).__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d((H_final, W_final))
        layers = []

        with torch.no_grad():
            pooled = self.avg_pool(input)
            # Features per sample. The batch axis used to be flattened too, which tied the
            # MLP width to the example batch size and mixed different samples together.
            input_ftrs = torch.flatten(pooled, start_dim=1).shape[1]

        layers.append(nn.Linear(
            in_features=input_ftrs, out_features=hidden_dim, bias=True))

        if hidden_act:
            layers.append(nn.ReLU(inplace=True))

        for _ in range(n_layers):
            layers.append(nn.Linear(in_features=hidden_dim,
                          out_features=hidden_dim, bias=True))
            if hidden_act:
                layers.append(nn.ReLU(inplace=True))

        layers.append(nn.Linear(in_features=hidden_dim,
                      out_features=input_ftrs, bias=True))
        layers.append(nn.Sigmoid())

        self.MLP = nn.Sequential(*layers)

    def forward(self, H):
        ''' Return per-sample weights in (0, 1), shaped like the pooled input:
            (B, TC, H_final, W_final), i.e. (B, TC, 1, 1) with the defaults used by TAU.'''
        pooled = self.avg_pool(H)
        weights = self.MLP(torch.flatten(pooled, start_dim=1))
        return weights.view(pooled.shape)


class Static_attention(nn.Module):
    ''' Static (spatial) attention branch of the TAU module.

        With TC = input.shape[1], it builds n_layers depthwise Conv2d (kernel small_kern),
        n_layers depthwise dilated Conv2d (kernel small_kern, dilation `dilatation`) and
        n_layers 1x1 Conv2d, all TC -> TC with 'same' padding. forward applies every
        depthwise conv, then every dilated conv, then every 1x1 conv.'''

    def __init__(self, input, small_kern, dilatation, n_layers):
        super(Static_attention, self).__init__()

        _, channels, _, _ = input.shape

        depthwise, dilated, pointwise = [], [], []
        for _ in range(n_layers):
            depthwise.append(nn.Conv2d(channels, channels, kernel_size=small_kern,
                                       padding='same', groups=channels))
            dilated.append(nn.Conv2d(channels, channels, kernel_size=small_kern,
                                     padding='same', dilation=dilatation, groups=channels))
            pointwise.append(nn.Conv2d(channels, channels, kernel_size=1, padding='same'))

        self.depth_conv = nn.Sequential(*depthwise)
        self.dilat_conv = nn.Sequential(*dilated)
        self.last_conv = nn.Sequential(*pointwise)

    def forward(self, x):
        for stage in (self.depth_conv, self.dilat_conv, self.last_conv):
            x = stage(x)
        return x


class TAU(nn.Module):
    ''' Temporal Attention Unit.

        Scales H element-wise by the combination (kronecker_product) of a dynamic
        per-channel attention (Dyn_attention) and a static spatial attention
        (Static_attention). `input` is an example tensor (B, TC, H, W) used to size
        both branches.'''

    def __init__(self, input, hidden_dim, small_kern, dilatation, num_layers, H_final=1, W_final=1):
        super(TAU, self).__init__()

        self.DA_module = Dyn_attention(input=input, H_final=H_final, W_final=W_final,
                                       hidden_dim=hidden_dim, n_layers=num_layers, hidden_act=True)
        self.SA_module = Static_attention(input=input, small_kern=small_kern,
                                          dilatation=dilatation, n_layers=num_layers)

    def forward(self, H):
        attention = kronecker_product(self.DA_module(H), self.SA_module(H))
        return attention * H


In [91]:
# Smoke test: run TAU on a random (16, 32, 64, 64) tensor and print the output shape
# (the module multiplies its attention by the input, so the shape is preserved).
a = torch.randn(16,32,64,64)
t = TAU(a,hidden_dim=3,small_kern=3,dilatation=2,num_layers=3)
b = t(a)
print(b.shape)

torch.Size([16, 32, 64, 64])


# Training (trial run)

In [None]:
class plTrainingModule(pl.LightningModule):
    ''' LightningModule training `model` with an MSE loss between predicted and target frames.

        Batches are (input_frames, target_frames) tuples. Per-epoch means are logged through
        self.log(..., on_epoch=True) as 'TOT. Loss on train', 'TOT. Loss on test' and
        'MSE loss on test'. Lightning resets these aggregates every epoch; the previous
        on_epoch_end hook averaged over every batch since the start of training because its
        lists were never cleared. The raw per-batch values are still collected in
        tot_loss_tr / tot_loss / mse_list for inspection.'''

    def __init__(self, model):
        super(plTrainingModule, self).__init__()
        self.model = model
        self.MSE = nn.MSELoss()
        # 'batchmean' is the reduction that matches the mathematical KL divergence; the
        # default 'mean' also averages over non-batch elements.
        self.KL = nn.KLDivLoss(reduction='batchmean')
        self.tot_loss_tr = []
        self.tot_loss = []

        self.KL_list = []
        self.mse_list = []

    def _shared_step(self, batch):
        ''' Run the model on batch[0] and return (loss, mse) against the target batch[1].'''
        prediction = self.model(batch[0])
        mse = self.MSE(prediction, batch[1])
        # KL term currently disabled; to enable: loss = mse + self.compute_KLloss(prediction, batch[1])
        loss = mse
        return loss, mse

    def training_step(self, batch, batch_idx):
        loss, _ = self._shared_step(batch)
        self.tot_loss_tr.append(loss.item())
        self.log('TOT. Loss on train', loss, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, mse = self._shared_step(batch)
        self.tot_loss.append(loss.item())
        self.mse_list.append(mse.item())
        self.log('TOT. Loss on test', loss, on_step=False, on_epoch=True)
        self.log('MSE loss on test', mse, on_step=False, on_epoch=True)
        return loss

    def compute_KLloss(self, predicted, ground_truth):
        ''' KL divergence between softmax distributions of consecutive-frame differences.

            Differences are divided by the module-level temperature `tau` and normalised
            with a softmax over dim 3, the height axis of (B, T-1, C, H, W) tensors.'''
        delta_pred = (predicted[:, 1:] - predicted[:, :-1]) / tau
        delta_gt = (ground_truth[:, 1:] - ground_truth[:, :-1]) / tau

        # KLDivLoss expects log-probabilities as input and probabilities as target. The raw
        # differences used to be passed directly, leaving the computed softmaxes unused.
        log_soft_pred = F.log_softmax(delta_pred, dim=3)
        soft_gt = F.softmax(delta_gt, dim=3)

        return self.KL(log_soft_pred, soft_gt)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr, weight_decay=wd)


In [None]:
# Use a single GPU when CUDA is available, otherwise train on the CPU.
num_gpu = 1 if torch.cuda.is_available() else 0

# Run name: "<model_conf> <epochs> <batch_size>".
exp_name = ' '.join([model_conf, str(hyperparameters['epochs']), str(batch_size)])


In [None]:
# NOTE(review): SimVP() matches the first SimVP class (no constructor arguments); if the
# reference-implementation cell has been run afterwards, SimVP requires shape_in.
model = SimVP()  # VanillaED(1, 64)


pl_training_MDL = plTrainingModule(model)


### WANDB CODE ###
# The two configurations differ only in the logger, so it is passed conditionally.
logger_kwargs = {}
if wb:
    wandb_logger = WandbLogger(
        project=project_name, name=exp_name, config=hyperparameters, entity='deepl_wizards')
    logger_kwargs['logger'] = wandb_logger

trainer = pl.Trainer(
    max_epochs=hyperparameters['epochs'],  # maximum number of epochs.
    gpus=num_gpu,  # the number of gpus we have at our disposal.
    # , overfit_batches = 1
    default_root_dir="",
    callbacks=[TQDMProgressBar(refresh_rate=20)],
    **logger_kwargs
)


In [None]:
try:
    trainer.fit(model=pl_training_MDL, datamodule=data)
finally:
    # WANDB code: close the run even if training raises, and only when wandb is enabled.
    if wb:
        wandb.finish()


In [None]:
# SOFTMAX ISSUES
# Debug cell: reads `output` and `gt` (and `tau`) produced by other cells; run the cell that
# fetches a validation batch first.

delta_pred = (output[:, 1:, :, :, :] - output[:, :-1, :, :, :])/tau
# delta_pred += 1200
delta_gt = (gt[:, 1:, :, :, :] - gt[:, :-1, :, :, :])/tau
# delta_gt += 1200
soft_pred = F.softmax(delta_pred, dim=3)  # dim 3 is the height axis of (B, T-1, C, H, W), not the channel axis
soft_gt = F.softmax(delta_gt, dim=3)


In [None]:
# Reload saved weights from 'simvp.pt' onto the GPU.
# NOTE(review): neither SimVP class in this notebook accepts a single int (the first takes no
# arguments, the reference one expects shape_in=(T, C, H, W)), so SimVP(16) raises a
# TypeError. TODO: confirm which model 'simvp.pt' was saved from and construct it accordingly.
mod = SimVP(16)
load_model('simvp.pt', mod, 'cuda')
mod.to('cuda')


In [None]:
# Check whether the predicted and ground-truth frame differences (and their softmaxes) are
# exactly equal. Jupyter only displays the last expression, so only the second result is shown.
torch.equal(delta_pred.to('cuda'), delta_gt.to('cuda'))
torch.equal(soft_pred.to('cuda'), soft_gt.to('cuda'))


In [None]:
# Fetch the first validation batch and run the model on it for visual inspection.
# The device is taken from the model itself, because Lightning may move the model back to the
# CPU after fit. no_grad avoids building an autograd graph for this inference-only pass.
eval_device = next(model.parameters()).device
batch = next(iter(data.val_dataloader()))
gt = batch[1].to(eval_device)
with torch.no_grad():
    output = model(batch[0].to(eval_device))

In [None]:
# Animate the ground-truth frames of the second clip in the batch.
show_video(gt[1].cpu())

In [None]:
# Animate the model's prediction for the same (second) clip.
show_video(output[1].cpu().detach())

In [None]:
# Animate the input frames (batch[0]) of the first clip.
# NOTE(review): this shows clip 0, while the two cells above show clip 1. Confirm this is intended.
show_video(batch[0][0].detach())
