# Introduction

Data nowadays often comes as a sequence of images. The most typical example is video on social networks such as YouTube, Facebook and Instagram; other classic examples are video calls, movies and trailers, satellite imagery and security-camera footage.
We will show you how to code a ConvLSTM model for frame prediction using the Moving MNIST dataset.

## Libraries

In [1]:
# import libraries
import os
import torch

from torch.utils.data import DataLoader

from lightning import Trainer
from multiprocessing import Process
import numpy as np
from torchvision import datasets, transforms
from torchts.utils.start_tensorboard import run_tensorboard
from torchts.nn.models.ConvLSTM import Seq2SeqConvLSTM, MovingMNISTConvLSTM

## Dataloader
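The dataset class below (adapted from the SVG repository linked in the code) generates Moving MNIST sequences on the fly. For every sample it picks random MNIST digits and moves them across a 64×64 canvas, bouncing them off the borders.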

In [2]:
# from: https://github.com/edenton/svg/blob/master/data/moving_mnist.py

class MovingMNIST(object):
    """Data Handler that creates Bouncing MNIST dataset on the fly.

    Each item is a float32 array of shape
    ``(seq_len, image_size, image_size, channels)`` with values clipped to
    ``[0, 1]``. In it, ``num_digits`` randomly chosen MNIST digits move with
    a random integer velocity and bounce off the canvas borders.

    Args:
        train: Use the MNIST training split if True, the test split otherwise.
        data_root: Directory handed to ``torchvision.datasets.MNIST``.
        seq_len: Number of frames per sequence.
        num_digits: Number of digits drawn into every sequence.
        image_size: Side length (pixels) of the square canvas.
        deterministic: If True, a digit hitting a border reverses that velocity
            component; otherwise a fresh random velocity is drawn.
        download: Forwarded to ``datasets.MNIST`` (default False, as before).
    """

    def __init__(self, train, data_root, seq_len=20, num_digits=2, image_size=64,
                 deterministic=True, download=False):
        self.seq_len = seq_len
        self.num_digits = num_digits
        self.image_size = image_size
        self.step_length = 0.1  # kept from the upstream code; not read by this class
        self.digit_size = 32
        self.deterministic = deterministic
        self.seed_is_set = False  # multi thread loading
        self.channels = 1

        # Digits are resized to digit_size so they can be pasted as square patches.
        self.data = datasets.MNIST(
            data_root,
            train=train,
            download=download,
            transform=transforms.Compose(
                [transforms.Resize(self.digit_size),
                 transforms.ToTensor()]))

        self.N = len(self.data)

    def set_seed(self, seed):
        """Seed NumPy's global RNG the first time this is called; later calls are no-ops.

        NOTE(review): presumably meant to give each data-loading worker (which
        holds its own copy of this object) a different seed — confirm.
        """
        if not self.seed_is_set:
            self.seed_is_set = True
            np.random.seed(seed)

    def __len__(self):
        """Return the number of MNIST images in the chosen split."""
        return self.N

    def __getitem__(self, index):
        """Generate one random sequence; ``index`` is only used to seed the RNG (once)."""
        self.set_seed(index)
        image_size = self.image_size
        digit_size = self.digit_size
        # Top-left coordinates must stay below this so the digit fits on the canvas.
        max_pos = image_size - digit_size
        x = np.zeros((self.seq_len,
                      image_size,
                      image_size,
                      self.channels),
                     dtype=np.float32)
        for _ in range(self.num_digits):
            idx = np.random.randint(self.N)
            digit, _label = self.data[idx]

            sx = np.random.randint(max_pos)
            sy = np.random.randint(max_pos)
            dx = np.random.randint(-4, 5)
            dy = np.random.randint(-4, 5)
            # The digit image is identical in every frame; convert it only once.
            digit_img = digit.numpy().squeeze()
            for t in range(self.seq_len):
                # Bounce off the top / bottom border.
                if sy < 0:
                    sy = 0
                    if self.deterministic:
                        dy = -dy
                    else:
                        dy = np.random.randint(1, 5)
                        dx = np.random.randint(-4, 5)
                elif sy >= max_pos:
                    sy = max_pos - 1
                    if self.deterministic:
                        dy = -dy
                    else:
                        dy = np.random.randint(-4, 0)
                        dx = np.random.randint(-4, 5)

                # Bounce off the left / right border.
                if sx < 0:
                    sx = 0
                    if self.deterministic:
                        dx = -dx
                    else:
                        dx = np.random.randint(1, 5)
                        dy = np.random.randint(-4, 5)
                elif sx >= max_pos:
                    sx = max_pos - 1
                    if self.deterministic:
                        dx = -dx
                    else:
                        dx = np.random.randint(-4, 0)
                        dy = np.random.randint(-4, 5)

                x[t, sy:sy + digit_size, sx:sx + digit_size, 0] += digit_img
                sy += dy
                sx += dx

        # Overlapping digits can sum above 1; clip back into [0, 1].
        x[x > 1] = 1.
        return x

In [3]:
class Parameters:
    """Hyper-parameters and paths shared by the cells of this notebook.

    Args:
        path: Directory holding the MNIST data. When falsy (None or ''),
            it defaults to ``data`` under the current working directory.
        lr: Learning rate.
        beta_1, beta_2: Optimizer beta coefficients (presumably Adam's — the
            optimizer is configured outside this notebook; confirm).
        batch_size: Samples per batch.
        epochs: Number of training epochs.
        use_amp: Mixed-precision flag (consumer not visible here).
        n_gpus: Requested GPU count (consumer not visible here).
        n_hidden_dim: Hidden feature width of the ConvLSTM.
        n_layers: Number of ConvLSTM layers (consumer not visible here).
        n_steps_ahead: Number of frames to predict.
        n_steps_past: Number of observed input frames.
    """

    def __init__(self, path=None, lr=1e-4, beta_1=0.9, beta_2=0.98, batch_size=64,
                 epochs=300, use_amp=False, n_gpus=1, n_hidden_dim=64, n_layers=1,
                 n_steps_ahead=10, n_steps_past=10):
        # Build the default path portably instead of concatenating strings.
        self.path = path if path else os.path.join(os.getcwd(), 'data')
        self.lr = lr
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.batch_size = batch_size
        self.epochs = epochs
        self.use_amp = use_amp
        self.n_gpus = n_gpus
        self.n_hidden_dim = n_hidden_dim
        self.n_layers = n_layers
        self.n_steps_ahead = n_steps_ahead
        self.n_steps_past = n_steps_past
        
# Default hyper-parameters. The data-loading and training cells below read
# their settings (path, batch size, sequence lengths, epochs) from this `opt`.
opt = Parameters()

In [4]:
print(f"{opt.path}")  # directory later passed to MovingMNIST as data_root

/Users/kiddycharles/torchTS/examples/ConvLSTM/data


In [5]:

# NOTE(review): with num_workers > 0 and the "spawn" multiprocessing start
# method (the default on macOS/Windows), workers must be able to import
# MovingMNIST; a class defined inside a notebook may fail to unpickle there.
# Confirm, or lower num_workers to 0.
num_workers = 4


def _make_moving_mnist_loader(train, shuffle):
    """Create the MovingMNIST dataset for one split and wrap it in a DataLoader."""
    dataset = MovingMNIST(train=train,
                          data_root=opt.path,
                          seq_len=opt.n_steps_past + opt.n_steps_ahead,
                          image_size=64,
                          deterministic=True,
                          num_digits=2)
    loader = DataLoader(dataset=dataset,
                        batch_size=opt.batch_size,
                        num_workers=num_workers,
                        shuffle=shuffle)
    return dataset, loader


train_data, train_loader = _make_moving_mnist_loader(train=True, shuffle=True)
# Evaluation keeps a fixed sample order. MovingMNIST seeds its RNG from the
# first index it receives, so this makes the generated test sequences reproducible.
test_data, test_loader = _make_moving_mnist_loader(train=False, shuffle=False)


In [None]:
def run_trainer():
    """Build the Seq2Seq ConvLSTM and fit it on ``train_loader`` for ``opt.epochs`` epochs.

    Accelerator and device selection are left to Lightning ('auto').
    """
    model = MovingMNISTConvLSTM(
        opt=opt,
        model=Seq2SeqConvLSTM(nf=opt.n_hidden_dim, in_channel=1),
    )
    trainer = Trainer(max_epochs=opt.epochs, accelerator='auto', devices='auto')
    trainer.fit(model, train_dataloaders=train_loader)
    

In [None]:
# Train in one process and run TensorBoard in another. The callable and its
# arguments are passed separately. Writing `target=run_tensorboard(new_run=True)`
# would call run_tensorboard right here in the parent process and hand its
# return value to Process as the target.
# NOTE(review): under the "spawn" start method, functions defined in a notebook
# may not be importable by the child process. Confirm on macOS/Windows.
# NOTE(review): if run_tensorboard serves until killed, p2.join() blocks after
# training ends. Confirm this is intended.
if __name__ == "__main__":
    p1 = Process(target=run_trainer)
    p1.start()
    p2 = Process(target=run_tensorboard, kwargs={"new_run": True})
    p2.start()
    p1.join()
    p2.join()