# **IMPORTS**

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.utils.data as data
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
from tqdm import tqdm
from numpy import *
from numpy.linalg import *
from scipy.special import factorial
from functools import reduce
import random
from torchvision import transforms
import matplotlib.pyplot as plt
import time
import gzip
import cv2
import math
import os
from PIL import Image
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import structural_similarity as ssim
from skimage.transform import resize
import argparse
# !pip install lpips
import codecs
# import lpips
#!pip install pynvml
#import pynvml

# **UTILITY**

In [None]:
def reshape_patch(img_tensor, patch_size):
    """Fold square spatial patches of a frame sequence into the channel axis.

    Args:
        img_tensor: 4-D array indexed as (sequence, height, width, channels).
        patch_size: Side length of the square patches.

    Returns:
        Array of shape (sequence, height // patch_size, width // patch_size,
        patch_size * patch_size * channels).
    """
    assert img_tensor.ndim == 4
    n_frames, n_rows, n_cols, n_chan = np.shape(img_tensor)
    grid_rows = n_rows // patch_size
    grid_cols = n_cols // patch_size
    # Split each spatial axis into (grid index, offset inside the patch).
    split = np.reshape(
        img_tensor,
        [n_frames, grid_rows, patch_size, grid_cols, patch_size, n_chan])
    # Bring both grid indices in front of both in-patch offsets.
    grouped = np.transpose(split, [0, 1, 3, 2, 4, 5])
    return np.reshape(
        grouped,
        [n_frames, grid_rows, grid_cols, patch_size * patch_size * n_chan])


def reshape_patch_back(patch_tensor, patch_size):
    """Unfold patch channels back into full-resolution frames (NumPy version).

    Args:
        patch_tensor: 5-D array indexed as
            (batch, sequence, patch_rows, patch_cols, channels).
        patch_size: Side length of the square patches.

    Returns:
        Array of shape (batch, sequence, patch_rows * patch_size,
        patch_cols * patch_size, channels // patch_size ** 2).
    """
    assert patch_tensor.ndim == 5
    n_batch, n_frames, grid_rows, grid_cols, packed = np.shape(patch_tensor)
    n_chan = packed // (patch_size * patch_size)
    # Expose the in-patch (row, col) offsets hidden in the channel axis.
    unpacked = np.reshape(
        patch_tensor,
        [n_batch, n_frames, grid_rows, grid_cols, patch_size, patch_size, n_chan])
    # Interleave each grid index with its in-patch offset.
    interleaved = np.transpose(unpacked, [0, 1, 2, 4, 3, 5, 6])
    return np.reshape(
        interleaved,
        [n_batch, n_frames, grid_rows * patch_size, grid_cols * patch_size, n_chan])


def reshape_patch_back_tensor(patch_tensor, patch_size):
    """Unfold patch channels back into full-resolution frames (torch version).

    The shape is read directly from the tensor, so no host copy or NumPy
    conversion is needed; this also works for dtypes NumPy cannot represent.

    Args:
        patch_tensor: 5-D tensor indexed as
            (batch, sequence, patch_rows, patch_cols, channels).
        patch_size: Side length of the square patches.

    Returns:
        Tensor of shape (batch, sequence, channels // patch_size ** 2,
        patch_rows * patch_size, patch_cols * patch_size), i.e. channels-first.
    """
    assert 5 == patch_tensor.ndim
    batch_size, seq_length, patch_height, patch_width, channels = patch_tensor.shape
    img_channels = channels // (patch_size * patch_size)
    # Expose the in-patch (row, col) offsets hidden in the channel axis.
    a = torch.reshape(patch_tensor, [batch_size, seq_length,
                                     patch_height, patch_width,
                                     patch_size, patch_size,
                                     img_channels])
    # Interleave each grid index with its in-patch offset.
    b = a.permute(0, 1, 2, 4, 3, 5, 6)
    img_tensor = torch.reshape(b, [batch_size, seq_length,
                                   patch_height * patch_size,
                                   patch_width * patch_size,
                                   img_channels])
    # Channels-last -> channels-first.
    return img_tensor.permute(0, 1, 4, 2, 3)


def reshape_patch_tensor(img_tensor, patch_size):
    """Fold square spatial patches into the channel axis (torch version).

    Args:
        img_tensor: 4-D tensor indexed as (sequence, height, width, channels).
        patch_size: Side length of the square patches.

    Returns:
        Tensor of shape (sequence, patch_size * patch_size * channels,
        height // patch_size, width // patch_size), i.e. channels-first.
    """
    assert img_tensor.ndim == 4
    n_frames, n_rows, n_cols, n_chan = img_tensor.shape
    grid_rows = n_rows // patch_size
    grid_cols = n_cols // patch_size
    # Split each spatial axis into (grid index, offset inside the patch).
    split = img_tensor.reshape(
        n_frames, grid_rows, patch_size, grid_cols, patch_size, n_chan)
    # Bring both grid indices in front of both in-patch offsets.
    grouped = split.permute(0, 1, 3, 2, 4, 5)
    packed = grouped.reshape(
        n_frames, grid_rows, grid_cols, patch_size * patch_size * n_chan)
    # Channels-last -> channels-first.
    return packed.permute(0, 3, 1, 2)

In [None]:
def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad`` set are counted.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        Plain Python ``int`` with the total element count.
    """
    # An explicit loop is used on purpose: this file does `from numpy import *`,
    # which rebinds `sum` to numpy.sum, and feeding that a generator is deprecated.
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# **DATA LOADER**

In [None]:
class Norm(object):
    """Callable transform that divides a sample by a fixed maximum value."""

    def __init__(self, max=255):
        # Divisor applied to every sample passed to __call__.
        self.max = max

    def __call__(self, sample):
        """Return ``sample`` divided by ``self.max``."""
        return sample / self.max


class ToTensor(object):
    """Callable transform turning a channels-last array into a float tensor."""

    def __call__(self, sample):
        """Move axis 3 to position 1 and return the result as a float32 tensor.

        Args:
            sample: 4-D array; axes are reordered as (0, 3, 1, 2).
        """
        channels_first = sample.transpose((0, 3, 1, 2))
        # Materialise the transposed view as its own array before handing it to torch.
        materialised = np.array(channels_first)
        tensor = torch.from_numpy(materialised)
        return tensor.float()
    

class Resize(object):
    """Callable transform resizing every frame to the configured image size."""

    def __call__(self, sample):
        """Resize each ``sample[i]`` to (img_height, img_width, channels).

        The target height and width are read from the module-level ``configs``.

        Args:
            sample: 4-D array indexed as (frame, height, width, channels).

        Returns:
            New array of shape (frames, configs.img_height, configs.img_width,
            channels) holding the resized frames.
        """
        n_frames = sample.shape[0]
        frame_shape = (configs.img_height, configs.img_width, sample.shape[3])
        resized = np.zeros((n_frames,) + frame_shape)
        for idx in range(n_frames):
            resized[idx] = resize(sample[idx], frame_shape)
        return resized

In [None]:
class TimeSeriesDataset(data.Dataset):
    """Dataset serving frame sequences stored in a single ``.npy`` file.

    The file is loaded fully into memory. Its first two axes are swapped so
    that indexing the dataset selects along the file's second axis, and a
    singleton channel axis is inserted at position 2.
    """

    def __init__(self, root_dir, n_frames_input=10, n_frames_output=10):
        """
        Args:
            root_dir: Path of the ``.npy`` file (4-D array).
            n_frames_input: Stored as ``n_frames_in``; not used for slicing here.
            n_frames_output: Stored as ``n_frames_out``; not used for slicing here.
        """
        self.n_frames_in = n_frames_input
        self.n_frames_out = n_frames_output
        # NOTE: reseeds Python's global `random` generator as a side effect.
        random.seed(420)

        raw = np.load(root_dir)
        sequence_major = raw.transpose(1, 0, 2, 3)
        # Insert the singleton channel axis right after the frame axis.
        self.file = np.expand_dims(sequence_major, axis=2)

    def __len__(self):
        """Number of sequences available."""
        return len(self.file)

    def __getitem__(self, index):
        """Return sequence ``index`` as a float32 tensor divided by 255."""
        clip = torch.from_numpy(self.file[index]).type(torch.float32)
        return clip / 255

In [None]:
# Load the sequence file into a dataset and wrap it in a shuffling loader that
# yields one sequence per batch using two worker processes.
td = TimeSeriesDataset(root_dir='../input/moving-mnist/mnist_test_seq.npy', n_frames_input=10, n_frames_output=10)
train_loader = torch.utils.data.DataLoader(dataset=td, batch_size=1, shuffle=True, num_workers=2)

In [None]:
# Sanity check: pull a single batch and inspect its shape and value range
# (the dataset divides by 255, so values are expected to lie in [0, 1]).
z = next(iter(train_loader))
print(z.shape)
torch.max(z), torch.min(z)

# **MODELS**

## MAU-CELL

In [None]:
class MAUCell(nn.Module):
    """Recurrent cell that updates a temporal state ``T`` and a spatial state ``S``.

    The cell attends over the last ``tau`` stored states: attention scores come
    from the stored spatial states, and the softmax-weighted sum of the stored
    temporal states is fused with the current temporal state before the gated
    update.

    Args:
        in_channel: Channel count of the temporal input ``T_t``.
        num_hidden: Channel count of the spatial input and of both outputs.
        height, width: Spatial size of the feature maps (needed by LayerNorm).
        filter_size: ``(kh, kw)`` convolution kernel size; padding is ``k // 2``.
        stride: Convolution stride.
        tau: Number of past states attended over.
        cell_mode: ``'residual'`` (adds ``S_t`` to the new spatial state) or
            ``'normal'``.

    Raises:
        AssertionError: If ``cell_mode`` is not one of the supported values.
    """

    def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, tau, cell_mode):
        super().__init__()

        self.num_hidden = num_hidden
        self.padding = (filter_size[0] // 2, filter_size[1] // 2)
        self.cell_mode = cell_mode
        # Element count of one hidden feature map; scales the attention scores.
        self.d = num_hidden * height * width
        self.tau = tau
        self.states = ['residual', 'normal']
        if self.cell_mode not in self.states:
            raise AssertionError(
                f'unknown cell_mode {cell_mode!r}; expected one of {self.states}')

        def conv_norm(channels_in, channels_out):
            """Convolution followed by LayerNorm over (channels, height, width)."""
            return nn.Sequential(
                nn.Conv2d(channels_in, channels_out, kernel_size=filter_size,
                          stride=stride, padding=self.padding),
                nn.LayerNorm([channels_out, height, width]),
            )

        # Attribute names and creation order are kept stable so that existing
        # checkpoints and seeded initialisation stay compatible.
        self.conv_t = conv_norm(in_channel, 3 * num_hidden)
        self.conv_t_next = conv_norm(in_channel, num_hidden)
        self.conv_s = conv_norm(num_hidden, 3 * num_hidden)
        self.conv_s_next = conv_norm(num_hidden, num_hidden)
        self.softmax = nn.Softmax(dim=0)

    def forward(self, T_t, S_t, t_att, s_att):
        """Run one recurrent step.

        Args:
            T_t: Current temporal state.
            S_t: Current spatial state.
            t_att: Stacked past temporal states; axis 0 indexes the ``tau`` steps.
            s_att: Stacked past spatial states; axis 0 indexes the ``tau`` steps.

        Returns:
            Tuple ``(T_new, S_new)`` with the updated temporal and spatial states.
        """
        s_next = self.conv_s_next(S_t)
        t_next = self.conv_t_next(T_t)

        # One scaled dot-product score per past step and per batch element.
        scale = math.sqrt(self.d)
        scores = torch.stack(
            [(s_att[i] * s_next).sum(dim=(1, 2, 3)) / scale for i in range(self.tau)],
            dim=0)
        # Add three trailing singleton axes so the weights broadcast over feature maps.
        scores = torch.reshape(scores, (*scores.shape, 1, 1, 1))
        weights = self.softmax(scores)  # normalised over the tau axis
        T_trend = (t_att * weights).sum(dim=0)

        # Blend the current temporal state with the attended trend.
        t_att_gate = torch.sigmoid(t_next)
        T_fusion = T_t * t_att_gate + (1 - t_att_gate) * T_trend

        T_concat = self.conv_t(T_fusion)
        S_concat = self.conv_s(S_t)
        t_g, t_t, t_s = torch.split(T_concat, self.num_hidden, dim=1)
        s_g, s_t, s_s = torch.split(S_concat, self.num_hidden, dim=1)
        T_gate = torch.sigmoid(t_g)
        S_gate = torch.sigmoid(s_g)
        # Each new state mixes its own candidate with the other stream's candidate.
        T_new = T_gate * t_t + (1 - T_gate) * s_t
        S_new = S_gate * s_s + (1 - S_gate) * t_s
        if self.cell_mode == 'residual':
            S_new = S_new + S_t
        return T_new, S_new

## MAU

In [None]:
class RNN(nn.Module):
    """Stacked-MAUCell frame predictor with a strided encoder and a mirrored decoder.

    Each input frame is encoded (a 1x1 convolution, then ``log2(sr_size)``
    stride-2 convolutions), passed through ``num_layers`` MAU cells, decoded
    with the same number of stride-2 transposed convolutions and projected
    back to ``frame_channel`` channels by a 1x1 convolution.

    Args:
        num_layers: Number of stacked MAU cells.
        num_hidden: Per-layer hidden channel counts (sequence of length
            ``num_layers``).
        configs: Settings object; reads ``patch_size``, ``img_channel``,
            ``img_width``, ``img_height``, ``sr_size``, ``tau``, ``cell_mode``,
            ``model_mode``, ``filter_size``, ``stride``, ``device``,
            ``total_length`` and ``input_length``.

    Raises:
        AssertionError: If ``configs.model_mode`` is not ``'recall'`` or ``'normal'``.
    """

    def __init__(self, num_layers, num_hidden, configs):
        super().__init__()
        self.configs = configs
        self.frame_channel = configs.patch_size * configs.patch_size * configs.img_channel
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.tau = configs.tau
        self.cell_mode = configs.cell_mode
        self.states = ['recall', 'normal']
        if self.configs.model_mode not in self.states:
            raise AssertionError(
                f'unknown model_mode {self.configs.model_mode!r}; expected one of {self.states}')

        width = configs.img_width // configs.patch_size // configs.sr_size
        height = configs.img_height // configs.patch_size // configs.sr_size

        # NOTE(review): for i == 0 the index i - 1 wraps to the last entry of
        # num_hidden; harmless while all layers share one width — confirm intent.
        self.cell_list = nn.ModuleList([
            MAUCell(num_hidden[i - 1], num_hidden[i], height, width, configs.filter_size,
                    configs.stride, self.tau, self.cell_mode)
            for i in range(num_layers)
        ])

        # Encoder: one 1x1 projection, then n stride-2 stages (each halves H and W).
        n = int(math.log2(configs.sr_size))
        encoders = []
        encoder = nn.Sequential()
        encoder.add_module(name='encoder_t_conv-1',
                           module=nn.Conv2d(in_channels=self.frame_channel,
                                            out_channels=self.num_hidden[0],
                                            stride=1,
                                            padding=0,
                                            kernel_size=1))
        encoder.add_module(name='relu_t_-1',
                           module=nn.LeakyReLU(0.2))
        encoders.append(encoder)
        for i in range(n):
            encoder = nn.Sequential()
            encoder.add_module(name=f'encoder_t{i}',
                               module=nn.Conv2d(in_channels=self.num_hidden[0],
                                                out_channels=self.num_hidden[0],
                                                stride=(2, 2),
                                                padding=(1, 1),
                                                kernel_size=(3, 3)))
            encoder.add_module(name=f'encoder_t_relu{i}',
                               module=nn.LeakyReLU(0.2))
            encoders.append(encoder)
        self.encoders = nn.ModuleList(encoders)

        # Decoder: n stride-2 transposed convolutions (each doubles H and W);
        # every stage except the last is followed by a LeakyReLU.
        decoders = []
        for i in range(n):
            decoder = nn.Sequential()
            decoder.add_module(name=f'c_decoder{i}',
                               module=nn.ConvTranspose2d(in_channels=self.num_hidden[-1],
                                                         out_channels=self.num_hidden[-1],
                                                         stride=(2, 2),
                                                         padding=(1, 1),
                                                         kernel_size=(3, 3),
                                                         output_padding=(1, 1)))
            if i < n - 1:
                decoder.add_module(name=f'c_decoder_relu{i}',
                                   module=nn.LeakyReLU(0.2))
            decoders.append(decoder)
        self.decoders = nn.ModuleList(decoders)

        self.srcnn = nn.Sequential(
            nn.Conv2d(self.num_hidden[-1], self.frame_channel, kernel_size=1, stride=1, padding=0)
        )
        # `merge` and `conv_last_sr` are not used by forward(); they are kept so
        # that parameter counts and checkpoint keys stay unchanged.
        self.merge = nn.Conv2d(self.num_hidden[-1] * 2, self.num_hidden[-1], kernel_size=1, stride=1, padding=0)
        self.conv_last_sr = nn.Conv2d(self.frame_channel * 2, self.frame_channel, kernel_size=1, stride=1, padding=0)

    def forward(self, frames, mask_true):
        """Predict frames 1 .. total_length - 1 from the given sequence.

        Args:
            frames: 5-D tensor indexed as (batch, time, channel, height, width).
            mask_true: 5-D tensor indexed as (batch, step, height, width,
                channel). Where it is 1 the true frame is fed after the input
                window; where it is 0 the previous prediction is fed instead.

        Returns:
            Tensor indexed as (batch, total_length - 1, channel, height, width)
            holding one prediction per processed time step.
        """
        mask_true = mask_true.permute(0, 1, 4, 2, 3).contiguous()
        batch_size = frames.shape[0]
        height = frames.shape[3] // self.configs.sr_size
        width = frames.shape[4] // self.configs.sr_size
        device = self.configs.device

        def zero_state(channels):
            # Allocate directly on the target device instead of CPU + copy.
            return torch.zeros([batch_size, channels, height, width], device=device)

        next_frames = []
        T_pre = []
        S_pre = []
        x_gen = None
        # Seed each layer's history with tau zero states so attention always
        # has tau entries to look at.
        for layer_idx in range(self.num_layers):
            if layer_idx == 0:
                in_channel = self.num_hidden[layer_idx]
            else:
                in_channel = self.num_hidden[layer_idx - 1]
            T_pre.append([zero_state(in_channel) for _ in range(self.tau)])
            S_pre.append([zero_state(in_channel) for _ in range(self.tau)])

        # Initial temporal state of every layer.
        T_t = [zero_state(self.num_hidden[i]) for i in range(self.num_layers)]

        for t in range(self.configs.total_length - 1):
            if t < self.configs.input_length:
                net = frames[:, t]
            else:
                # Scheduled sampling: mix the true frame with the last prediction.
                time_diff = t - self.configs.input_length
                net = mask_true[:, time_diff] * frames[:, t] + (1 - mask_true[:, time_diff]) * x_gen

            frames_feature = net
            frames_feature_encoded = []
            for encoder in self.encoders:
                frames_feature = encoder(frames_feature)
                frames_feature_encoded.append(frames_feature)

            S_t = frames_feature
            for i in range(self.num_layers):
                t_att = torch.stack(T_pre[i][-self.tau:], dim=0)
                s_att = torch.stack(S_pre[i][-self.tau:], dim=0)
                S_pre[i].append(S_t)
                T_t[i], S_t = self.cell_list[i](T_t[i], S_t, t_att, s_att)
                T_pre[i].append(T_t[i])

            out = S_t
            for i, decoder in enumerate(self.decoders):
                out = decoder(out)
                if self.configs.model_mode == 'recall':
                    # Skip connection from the encoder stage of matching resolution.
                    out = out + frames_feature_encoded[-2 - i]

            x_gen = self.srcnn(out)
            next_frames.append(x_gen)

        # (time, batch, ...) -> (batch, time, ...)
        next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 2, 3, 4).contiguous()
        return next_frames

# **TRAINER TESTER**

In [None]:
def train(model, ims, real_input_flag, configs, itr, val):
    """Run one optimisation step through ``model.train`` and report the losses.

    Args:
        model: Object whose ``train(ims, real_input_flag, itr, val)`` returns a
            3-tuple; the last two items are taken as the L1 and L2 losses.
        ims: Batch of frame sequences.
        real_input_flag: Scheduled-sampling mask for this batch.
        configs: Settings object; ``display_interval`` controls logging.
        itr: Current iteration number.
        val: Validation batch forwarded to ``model.train``.

    Returns:
        Tuple ``(loss_l1, loss_l2)``.
    """
    _frames, l1_value, l2_value = model.train(ims, real_input_flag, itr, val)
    should_report = (itr % configs.display_interval) == 0
    if should_report:
        # `!s` forces str(), matching plain string concatenation of the values.
        step_msg = f'Step: {itr!s}'
        l1_msg = f'Training L1 loss: {l1_value!s}'
        l2_msg = f'Training L2 loss: {l2_value!s}'
        print(step_msg, l1_msg, l2_msg)
    return l1_value, l2_value

In [None]:
def evaluation_proper(model, test_loader, configs, out_len=10):
    """Evaluate ``model`` on ``test_loader`` and print per-frame metrics.

    For every batch the model predicts the frames after the input window with
    no ground-truth feedback (all-zero sampling mask). MSE, MAE, SSIM, PSNR
    and LPIPS are collected per predicted frame and averaged over batches.
    SSIM is computed on the first sample and first channel of each batch only.

    Args:
        model: Object exposing ``test(data, mask)`` that returns a NumPy array
            indexed as (batch, time, channel, height, width).
        test_loader: Sized iterable of frame tensors indexed the same way.
        configs: Settings object; reads ``device``, ``total_length``,
            ``input_length``, ``img_height``, ``img_width``, ``patch_size``
            and ``img_channel``.
        out_len: Number of predicted frames to score per batch.

    Returns:
        Mean MSE over all batches and scored frames.
    """
    print('Evaluating...')

    # LPIPS is an optional third-party package; degrade gracefully without it
    # instead of failing with a NameError at the end of a long training run.
    try:
        import lpips
    except ImportError:
        lpips = None
        print('lpips is not installed; LPIPS will be reported as NaN')
    loss_fn = None
    if lpips is not None:
        loss_fn = lpips.LPIPS(net='alex', spatial=True).to(configs.device)

    n_batches = len(test_loader)
    mse_list = np.empty((n_batches, out_len))
    mae_list = np.empty((n_batches, out_len))
    ssim_list = np.empty((n_batches, out_len))
    psnr_list = np.empty((n_batches, out_len))
    lpips_list = np.full((n_batches, out_len), np.nan)

    total_mse = 0
    total_mae = 0

    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(test_loader, 0), total=n_batches):
            batch_size = batch.shape[0]
            # All-zero mask: the model never sees ground truth past the input window.
            real_input_flag = np.zeros(
                (batch_size,
                 configs.total_length - configs.input_length - 1,
                 configs.img_height // configs.patch_size,
                 configs.img_width // configs.patch_size,
                 configs.patch_size ** 2 * configs.img_channel))

            img_gen = model.test(batch, real_input_flag)
            # Channels-first -> channels-last for the NumPy patch helpers.
            img_gen = img_gen.transpose(0, 1, 3, 4, 2)
            test_ims = batch.detach().cpu().numpy().transpose(0, 1, 3, 4, 2)
            output_length = configs.total_length - configs.input_length
            output_length = min(output_length, configs.total_length - 1)
            test_ims = reshape_patch_back(test_ims, configs.patch_size)
            img_gen = reshape_patch_back(img_gen, configs.patch_size)
            # Target and predictions go through the same un-patching so their
            # layouts always agree.
            target = test_ims[:, configs.input_length:]
            predictions = img_gen[:, -output_length:]

            if (batch_idx + 1) % 500 == 0:
                print(target[0, 1, 40:42, 40:42, 0])
                print(predictions[0, 1, 40:42, 40:42, 0])
                fig, ax = plt.subplots(2, out_len, figsize=(25, 7))
                # Dedicated loop names: reusing the batch index here would
                # misdirect the metric writes below.
                for col in range(out_len):
                    ax[0][col].imshow(target[0][col])
                    ax[0][col].set_title('V Ground Truth')
                    ax[0][col].axis('off')
                    ax[1][col].imshow(predictions[0][col])
                    ax[1][col].set_title('V Generated')
                    ax[1][col].axis('off')
                plt.show()

            mse_batch = np.mean((predictions - target) ** 2, axis=(0, 1, 4)).sum()
            mae_batch = np.mean(np.abs(predictions - target), axis=(0, 1, 4)).sum()
            total_mse += mse_batch
            total_mae += mae_batch

            for j in range(out_len):
                pred_j = predictions[:, j, :, :, :]
                target_j = target[:, j, :, :, :]
                diff = pred_j - target_j
                mse_list[batch_idx][j] = np.square(diff).mean()
                mae_list[batch_idx][j] = np.abs(diff).mean()
                # Frames are scaled to [0, 1]; float inputs need an explicit data_range.
                ssim_list[batch_idx][j] = ssim(target[0, j, :, :, 0],
                                               predictions[0, j, :, :, 0],
                                               data_range=1.0)
                psnr_list[batch_idx][j] = 20 * np.log10(1 / np.sqrt(mse_list[batch_idx][j]))
                if loss_fn is not None:
                    # Rescale [0, 1] -> [-1, 1] and move channels first for LPIPS.
                    t1 = torch.from_numpy((pred_j - 0.5) / 0.5).to(configs.device).permute((0, 3, 1, 2))
                    t2 = torch.from_numpy((target_j - 0.5) / 0.5).to(configs.device).permute((0, 3, 1, 2))
                    d = loss_fn.forward(t1, t2)
                    lpips_list[batch_idx][j] = d.mean().item() * 100

    avg_mse_frame = mse_list.mean(axis=0)
    avg_mae_frame = mae_list.mean(axis=0)
    avg_ssim_frame = ssim_list.mean(axis=0)
    avg_psnr_frame = psnr_list.mean(axis=0)
    avg_lpips_frame = lpips_list.mean(axis=0)

    avg_mse = mse_list.mean()
    avg_mae = mae_list.mean()
    avg_ssim = ssim_list.mean()
    avg_psnr = psnr_list.mean()
    avg_lpips = lpips_list.mean()

    print('Eval MSE: ', total_mse/len(test_loader))
    print('Eval MAE: ', total_mae/len(test_loader))

    print(f'Avg-MSE: {avg_mse}\nMSE/Frame: {avg_mse_frame}')
    print(f'Avg-MAE: {avg_mae}\nMAE/Frame: {avg_mae_frame}')
    print(f'Avg-SSIM: {avg_ssim}\nSSIM/Frame: {avg_ssim_frame}')
    print(f'Avg-PSNR: {avg_psnr}\nPSNR/Frame: {avg_psnr_frame}')
    print(f'Avg-LPIPS: {avg_lpips}\nLPIPS/Frame: {avg_lpips_frame}')

    return avg_mse

# **TRAIN TEST WRAPPER**

In [None]:
def schedule_sampling(eta, itr, channel, batch_size):
    """Decay the sampling probability and draw a scheduled-sampling mask.

    Settings are read from the module-level ``args``.

    Args:
        eta: Current probability of feeding a ground-truth frame.
        itr: Current training iteration.
        channel: Number of image channels.
        batch_size: Number of sequences in the batch.

    Returns:
        Tuple ``(eta, real_input_flag)``. ``real_input_flag`` is a float array
        of shape (batch_size, total_length - input_length - 1,
        img_height // patch_size, img_width // patch_size,
        patch_size ** 2 * channel); a frame is all ones where the true frame
        should be fed and all zeros otherwise. With scheduled sampling
        disabled the result is ``(0.0, all-zero mask)``.
    """
    n_steps = args.total_length - args.input_length - 1
    frame_shape = (args.img_height // args.patch_size,
                   args.img_width // args.patch_size,
                   args.patch_size ** 2 * channel)
    # Start from an all-zero mask; it is returned as-is when sampling is off.
    real_input_flag = np.zeros((batch_size, n_steps) + frame_shape)
    if not args.scheduled_sampling:
        return 0.0, real_input_flag

    if itr < args.sampling_stop_iter:
        eta -= args.sampling_changing_rate
    else:
        eta = 0.0

    # One Bernoulli(eta) draw per (sample, step).
    random_flip = np.random.random_sample((batch_size, n_steps))
    true_token = (random_flip < eta)
    # Boolean indexing on the first two axes switches whole frames to ones.
    real_input_flag[true_token] = 1.0
    return eta, real_input_flag


def train_wrapper(model):
    """Train ``model`` on a 70/30 split of the dataset, then evaluate it.

    Settings come from the module-level ``args`` / ``configs``. Optionally
    loads pretrained weights, trains for up to ``max_epoches`` epochs or until
    the iteration counter passes ``max_iterations``, plots the loss curves
    every ``plot_interval`` steps, saves a checkpoint every
    ``snapshot_interval`` steps, and finally runs ``evaluation_proper`` on the
    validation split.

    Args:
        model: ``Model`` wrapper exposing ``load``, ``train`` and ``save``.
    """
    begin = 0

    if args.pretrained_model:
        model.load(args.pretrained_model)

    # DATASET
    dataset = TimeSeriesDataset(root_dir=configs.data_train_path, n_frames_input=10, n_frames_output=10)

    # DATA LOADER + SPLIT
    validation_split = .3
    shuffle_dataset = True
    random_seed = 1000

    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    if shuffle_dataset:
        # Fixed seed so the train/validation split is the same on every run.
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    # Creating PT data samplers and loaders:
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)
    train_input_handle = torch.utils.data.DataLoader(dataset, args.batch_size, sampler=train_sampler, num_workers=2, pin_memory=True)
    val_input_handle = torch.utils.data.DataLoader(dataset, batch_size=1, sampler=valid_sampler)

    losses_l1 = []
    losses_l2 = []

    eta = args.sampling_start_value
    eta -= (begin * args.sampling_changing_rate)
    itr = begin
    val_batch = None
    reached_limit = False
    for epoch in range(0, args.max_epoches):

        for ims in tqdm(train_input_handle, total=len(train_input_handle)):
            if itr > args.max_iterations:
                reached_limit = True
                break
            batch_size = ims.shape[0]
            if configs.verbose:
                print('IMS shape: ', ims.shape)
                print('Stuff input to schedule sampling: ', eta, itr, args.img_channel, batch_size)
            eta, real_input_flag = schedule_sampling(eta, itr, args.img_channel, batch_size)
            if configs.verbose:
                print('Stuff output from schedule sampling: ', eta, real_input_flag.shape)

            # The validation batch is only visualised on plotting steps, so a
            # fresh one is drawn just for those steps and reused in between
            # rather than building a new loader iterator every iteration.
            if val_batch is None or itr % configs.plot_interval == 0:
                val_batch = next(iter(val_input_handle))

            l1, l2 = train(model, ims, real_input_flag, args, itr, val_batch)
            losses_l1.append(l1.item())
            losses_l2.append(l2.item())

            if itr % configs.plot_interval == 0:
                fig, ax = plt.subplots(2, 1, figsize=(13, 5))
                a = ax.flatten()
                a[0].plot(losses_l1, 'r')
                a[0].set_title('Loss L1')
                a[1].plot(losses_l2, 'r')
                a[1].set_title('Loss L2')
                plt.show()

            if itr % args.snapshot_interval == 0 and itr > begin:
                model.save(itr)
            itr += 1
        print(f'Epoch: [{epoch}/{args.max_epoches}]')
        if reached_limit:
            # Iteration budget exhausted: further epochs would do no training.
            break

    evaluation_proper(model, val_input_handle, configs, out_len=10)


def test_wrapper(model, val_ds=None):
    """Load the pretrained checkpoint and evaluate ``model``.

    Args:
        model: ``Model`` wrapper exposing ``load`` and ``test``.
        val_ds: Optional loader of evaluation batches. When omitted, a loader
            is built over a 30% split of ``args.data_train_path`` using seed
            1000, batch size 1.

    Raises:
        ValueError: If ``args.pretrained_model`` is empty.
    """
    if not args.pretrained_model:
        raise ValueError('args.pretrained_model must point to a checkpoint for testing')
    model.load(args.pretrained_model)

    test_input_handle = val_ds
    if test_input_handle is None:
        dataset = TimeSeriesDataset(root_dir=args.data_train_path, n_frames_input=10, n_frames_output=10)
        indices = list(range(len(dataset)))
        split = int(np.floor(.3 * len(dataset)))
        # Same seed and fraction as the split used for training.
        np.random.seed(1000)
        np.random.shuffle(indices)
        test_input_handle = torch.utils.data.DataLoader(
            dataset, batch_size=1, sampler=SubsetRandomSampler(indices[:split]))

    evaluation_proper(model, test_input_handle, args,
                      out_len=args.total_length - args.input_length)

# **MODEL FACTORY**

In [None]:
class Model(object):
    """Training/inference wrapper around the prediction network.

    Builds the network named by ``configs.model_name`` and owns its Adam
    optimizer, exponential LR scheduler and the L1/L2 criteria.

    Args:
        configs: Settings object; reads ``img_height``, ``img_width``,
            ``patch_size``, ``img_channel``, ``num_layers``, ``num_hidden``,
            ``model_name``, ``device``, ``lr`` and ``lr_decay`` here, plus
            further fields in the methods below.

    Raises:
        ValueError: If ``configs.model_name`` is not a known network.
    """

    def __init__(self, configs):
        self.configs = configs
        self.patch_height = configs.img_height // configs.patch_size
        self.patch_width = configs.img_width // configs.patch_size
        self.patch_channel = configs.img_channel * (configs.patch_size ** 2)
        self.num_layers = configs.num_layers
        networks_map = {
            'mau': RNN,
        }
        # Every layer uses the same hidden width.
        self.num_hidden = [configs.num_hidden] * configs.num_layers
        if configs.model_name not in networks_map:
            raise ValueError('Name of network unknown %s' % configs.model_name)
        Network = networks_map[configs.model_name]
        self.network = Network(self.num_layers, self.num_hidden, configs).to(configs.device)

        self.optimizer = Adam(self.network.parameters(), lr=configs.lr)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=configs.lr_decay)

        self.MSE_criterion = nn.MSELoss()
        self.L1_loss = nn.L1Loss()

    def save(self, itr):
        """Write the network weights to ``<save_dir>/model.ckpt-<itr>``."""
        stats = {'net_param': self.network.state_dict()}
        os.makedirs(self.configs.save_dir, exist_ok=True)
        checkpoint_path = os.path.join(self.configs.save_dir, 'model.ckpt' + '-' + str(itr))
        torch.save(stats, checkpoint_path)
        print("save predictive model to %s" % checkpoint_path)

    def load(self, pm_checkpoint_path):
        """Load network weights saved by :meth:`save` from ``pm_checkpoint_path``."""
        print('load predictive model:', pm_checkpoint_path)
        stats = torch.load(pm_checkpoint_path, map_location=torch.device(self.configs.device))
        self.network.load_state_dict(stats['net_param'])

    def _plot_rollout(self, frames, generated, mask, prefix):
        """Draw a 4-row figure: inputs, sampling mask, ground truth, predictions.

        Args:
            frames: One sequence indexed as (time, channel, height, width).
            generated: Network output for that sequence; entry ``t`` is the
                prediction for frame ``t + 1``.
            mask: Sampling mask for that sequence; axis 0 indexes the steps.
            prefix: Title prefix (``'T'`` for training, ``'V'`` for validation).
        """
        n_cols = self.configs.input_length
        inputs = frames[:n_cols]
        truth = frames[n_cols:]
        preds = generated[n_cols - 1:]
        # squeeze=False keeps `ax` two-dimensional even for a single column.
        fig, ax = plt.subplots(4, n_cols, figsize=(25, 10), squeeze=False)
        for j in range(n_cols):
            ax[0][j].imshow(inputs[j].to('cpu').permute(1, 2, 0), cmap='gray')
            ax[0][j].set_title(prefix + ' Input')
            # The mask has one step fewer than the output window; leave the
            # remaining cell blank.
            if j < mask.shape[0]:
                ax[1][j].imshow(mask[j].to('cpu'))
                ax[1][j].set_title(prefix + ' Mask')
            if j < truth.shape[0]:
                ax[2][j].imshow(truth[j].to('cpu').permute(1, 2, 0), cmap='gray')
                ax[2][j].set_title(prefix + ' Ground Truth')
            if j < preds.shape[0]:
                ax[3][j].imshow(preds[j].detach().to('cpu').permute(1, 2, 0), cmap='gray')
                ax[3][j].set_title(prefix + ' Generated')
            for row in range(4):
                ax[row][j].axis('off')

    def train(self, data, mask, itr, val):
        """Run one optimisation step on the L2 loss.

        Every ``plot_interval`` iterations the current training sample and a
        validation roll-out (all-zero mask) are drawn as figures.

        Args:
            data: Batch of frame sequences (array or tensor).
            mask: Scheduled-sampling mask for the batch (array or tensor).
            itr: Current iteration number.
            val: Validation batch; only used on plotting iterations.

        Returns:
            Tuple ``(next_frames, loss_l1, loss_l2)``: the predictions and the
            two losses as 0-d NumPy arrays.
        """
        configs = self.configs
        device = configs.device
        self.network.train()
        frames_tensor = torch.as_tensor(data, dtype=torch.float32, device=device)
        mask_tensor = torch.as_tensor(mask, dtype=torch.float32, device=device)

        if configs.verbose:
            print('FT', frames_tensor.shape)
            print('MT', mask_tensor.shape)

        next_frames = self.network(frames_tensor, mask_tensor)
        if configs.verbose:
            print('Next Frames', next_frames.shape)

        # The prediction at step t is compared with the true frame t + 1.
        ground_truth = frames_tensor
        if configs.verbose:
            print('Ground', ground_truth[:, 1:].shape)

        if itr % configs.plot_interval == 0:
            with torch.no_grad():
                self.network.eval()
                self._plot_rollout(frames_tensor[0], next_frames[0], mask_tensor[0], 'T')

                val_tensor = torch.as_tensor(val, dtype=torch.float32, device=device)
                # All-zero mask: the validation roll-out never sees ground
                # truth after the input window.
                val_mask = torch.zeros((val_tensor.shape[0],) + tuple(mask_tensor.shape[1:]),
                                       device=device)
                val_frames = self.network(val_tensor, val_mask)
                self._plot_rollout(val_tensor[0], val_frames[0], val_mask[0], 'V')

                self.network.train()

        self.optimizer.zero_grad()
        # L1 is reported only, so it does not need a graph.
        with torch.no_grad():
            loss_l1 = self.L1_loss(next_frames, ground_truth[:, 1:])
        loss_l2 = self.MSE_criterion(next_frames, ground_truth[:, 1:])
        loss_gen = loss_l2
        loss_gen.backward()
        self.optimizer.step()

        if itr >= configs.sampling_stop_iter and itr % configs.delay_interval == 0:
            self.scheduler.step()
            print('LR decay to:%.8f' % self.optimizer.param_groups[0]['lr'])
        return next_frames, loss_l1.detach().cpu().numpy(), loss_l2.detach().cpu().numpy()

    def test(self, data, mask):
        """Run the network in eval mode and return its output as a NumPy array.

        Args:
            data: Batch of frame sequences (array or tensor).
            mask: Sampling mask for the batch (array or tensor).
        """
        self.network.eval()
        device = self.configs.device
        frames_tensor = torch.as_tensor(data, dtype=torch.float32, device=device)
        mask_tensor = torch.as_tensor(mask, dtype=torch.float32, device=device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            next_frames = self.network(frames_tensor, mask_tensor)
        return next_frames.detach().cpu().numpy()

# **CONFIG**

In [None]:
class Configuration:
    """Plain container holding every setting of the experiment as an attribute."""

    def __init__(self):
        super().__init__()

        # --- runtime / data locations ---
        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
        self.data_train_path = '../input/moving-mnist/mnist_test_seq.npy'
        self.data_test_path = '../input/kthextract-to-jpg/data/test'

        # --- sequence and image geometry ---
        self.input_length = 10
        self.real_length = 20
        self.total_length = 20
        self.img_height = 64
        self.img_width = 64
        self.sr_size = 4
        self.img_channel = 1
        self.patch_size = 1
        self.alpha = 1

        # --- model ---
        self.model_name = 'mau'
        self.dataset = 'mmnist'
        self.cell_mode = 'normal'
        self.model_mode = 'recall'
        self.num_workers = 2
        self.num_hidden = 64
        self.num_layers = 4
        self.num_heads = 4
        self.filter_size = (5, 5)
        self.stride = 1
        self.time = 2
        self.time_stride = 1
        self.tau = 5

        # --- optimisation ---
        self.is_training = True
        self.lr = 0.001
        self.lr_decay = 0.9
        self.delay_interval = 2000
        self.batch_size = 32
        self.max_iterations = 150000
        self.max_epoches = 80

        # --- logging / checkpointing ---
        self.display_interval = 100
        self.plot_interval = 100
        self.test_interval = 1010
        self.snapshot_interval = 1000
        self.num_save_samples = 3
        self.n_gpu = 1
        self.pretrained_model = ''
        self.perforamnce_dir = 'results/mmnist'
        self.save_dir = 'saves/mmnist'
        self.gen_frm_dir = 'results/mmnist/'

        # --- scheduled sampling ---
        self.scheduled_sampling = True
        self.sampling_stop_iter = 50000
        self.sampling_start_value = 1.0
        self.sampling_changing_rate = 5e-06
        self.verbose = False
        
# Single settings instance read as a global by the functions in this file;
# `args` is a second name bound to the very same object.
configs = Configuration()
args = configs

## TEST CASE

In [None]:
def test_model(configs):
    """Smoke test: push one random sequence through a fresh ``RNN`` and print the output shape.

    Input sizes are derived from ``configs`` so they agree with the feature-map
    sizes the network's normalisation layers are built for.

    Args:
        configs: Settings object passed to ``RNN``; also supplies the frame
            size, sequence lengths and device.
    """
    nl = 4
    nh = [64, 64, 64, 64]
    frame_channel = configs.patch_size ** 2 * configs.img_channel
    height = configs.img_height // configs.patch_size
    width = configs.img_width // configs.patch_size
    z = torch.randn(1, configs.total_length, frame_channel, height, width,
                    device=configs.device)
    # One mask entry per predicted step after the input window (channels-last).
    m = torch.zeros(1, configs.total_length - configs.input_length - 1,
                    height, width, frame_channel, device=configs.device)
    model = RNN(nl, nh, configs).to(configs.device)
    g = model(z, m)
    print(g.shape)

#test_model(configs)

In [None]:
#model = Model(args)

In [None]:
#count_parameters(model.network)

# **TRAIN**

In [None]:
#pynvml.nvmlInit()

print('Initializing models')

model = Model(args)

# Both modes need the output directory for generated frames; exist_ok avoids
# the check-then-create race of a separate os.path.exists() test.
os.makedirs(args.gen_frm_dir, exist_ok=True)
if args.is_training:
    os.makedirs(args.save_dir, exist_ok=True)
    train_wrapper(model)
else:
    # NOTE(review): test_wrapper is called with the model only, so it must be
    # able to obtain its evaluation data without a second argument.
    test_wrapper(model)