In [None]:
# Colab-only setup: attach Google Drive under /gdrive/ so files written there
# outlive the Colab session. force_remount=False leaves an existing mount alone.
from google.colab import drive
drive.mount('/gdrive/', force_remount=False)
# Sanity check: list the mount point.
!ls /gdrive

In [None]:
import os

BASE_PATH = '/gdrive/My Drive/colab_files/framelstm/'
if not os.path.exists(BASE_PATH):
    os.makedirs(BASE_PATH)

!pwd
!ls
!echo
os.chdir(BASE_PATH)
if not os.path.exists(BASE_PATH + 'MovingMNIST.py'):
    !wget https://raw.githubusercontent.com/tychovdo/MovingMNIST/master/MovingMNIST.py
!pwd
!ls -al

In [None]:
import torch
import torch.nn as nn

import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Make modules stored in BASE_PATH importable, then load the pt_util helper
# used by LSTM.save_model / load_model / load_last_model.
# NOTE(review): this notebook never downloads pt_util; presumably pt_util.py
# must already exist in BASE_PATH -- confirm before a fresh run.
import sys
sys.path.append(BASE_PATH)
import pt_util

In [None]:
# basic code reference from https://stackabuse.com/time-series-prediction-using-lstm-with-pytorch-in-python/

In [None]:
from MovingMNIST import MovingMNIST

# MovingMNIST dataset from http://www.cs.toronto.edu/~nitish/unsup_video.pdf paper
# MovingMNIST code from https://github.com/tychovdo/MovingMNIST
# More general reference from http://www.cs.toronto.edu/~nitish/unsupervised_video/
# Train / test splits of Moving-MNIST, rooted at data/mnist (download=True).
train_set = MovingMNIST(root='data/mnist', train=True, download=True)
test_set = MovingMNIST(root='data/mnist', train=False, download=True)

# The flattening done in the training and evaluation cells assumes one clip per batch.
batch_size = 1
# Frames per input clip; believed to be fixed by the dataset (not verified here).
mnist_seq_length = 10

# Shuffle only the training clips; keep the test order deterministic.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

print(f'==>>> total training batch number: {len(train_loader)}')
print(f'==>>> total testing batch number: {len(test_loader)}')

# Peek at a single training batch to report the tensor shapes.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    seq, seq_target = first_batch
    print('--- Sample tensor sizes... ---')
    print('Input:  ', seq.shape)
    print('Target: ', seq_target.shape)

In [None]:
class LSTM(nn.Module):
    """Single-layer LSTM with a linear read-out that predicts the next step.

    The network consumes a sequence of flattened vectors one clip at a time
    (batch dimension fixed to 1) and returns the prediction made at the last
    time step only.

    Args:
        input_size: Number of features per time step.
        hidden_layer_size: Width of the LSTM hidden and cell state.
        output_size: Number of features in each prediction.

    Attributes:
        hidden_cell: ``(h, c)`` tuple fed to and overwritten by every
            ``forward`` call. Callers reset it between independent sequences.
        accuracy: Best score passed to ``save_best_model`` so far, or ``None``
            if it has not been called yet.
    """

    def __init__(self, input_size=1, hidden_layer_size=100, output_size=1):
        super().__init__()
        self.hidden_layer_size = hidden_layer_size
        self.lstm = nn.LSTM(input_size, hidden_layer_size)
        self.linear = nn.Linear(hidden_layer_size, output_size)
        # Must exist before save_best_model() reads it; nn.Module raises
        # AttributeError for attributes that were never assigned.
        self.accuracy = None
        self.reset_hidden_cell()

    def reset_hidden_cell(self):
        """Zero the recurrent state so the next forward pass starts fresh."""
        self.hidden_cell = (
            torch.zeros(1, 1, self.hidden_layer_size),
            torch.zeros(1, 1, self.hidden_layer_size)
        )

    def forward(self, input_seq):
        """Run the sequence through the LSTM and predict from the final step.

        Args:
            input_seq: Tensor whose first dimension is time; the remaining
                dimensions are flattened into the feature axis.

        Returns:
            Tensor of shape ``(1, output_size)``: the prediction at the last
            time step. ``self.hidden_cell`` is replaced by the state the LSTM
            returns.
        """
        steps = len(input_seq)
        # nn.LSTM expects (seq_len, batch, features); the batch size is 1 here.
        lstm_out, self.hidden_cell = self.lstm(input_seq.view(steps, 1, -1), self.hidden_cell)
        predictions = self.linear(lstm_out.view(steps, -1))
        # Slice (rather than index) so the leading dimension is kept.
        return predictions[-1:]

    def loss(self, loss_function, prediction, labels):
        """Return ``loss_function(prediction, labels)``."""
        loss_val = loss_function(prediction, labels)
        return loss_val

    def save_model(self, file_path, num_to_keep=1):
        """Persist the model through ``pt_util.save``.

        ``num_to_keep`` is forwarded unchanged; presumably it caps how many
        checkpoints are retained -- see pt_util for the exact semantics.
        """
        pt_util.save(self, file_path, num_to_keep)

    def save_best_model(self, accuracy, file_path, num_to_keep=1):
        """Save only when ``accuracy`` beats the best score recorded so far.

        The first call always saves, because no score has been recorded yet.
        """
        if self.accuracy is None or accuracy > self.accuracy:
            self.accuracy = accuracy
            self.save_model(file_path, num_to_keep)

    def load_model(self, file_path):
        """Restore this model from ``file_path`` through ``pt_util.restore``."""
        pt_util.restore(self, file_path)

    def load_last_model(self, dir_path):
        """Restore via ``pt_util.restore_latest`` and return whatever it returns."""
        return pt_util.restore_latest(self, dir_path)

In [None]:
# Number of passes over the training loader in the training cell.
epochs = 10
# Length of one flattened frame (the animation cell reshapes frames to 64x64);
# used as both the input and the output width of the model.
data_size = 64*64

In [None]:
# Build the model and resume from the earlier checkpoint when one is present.
# Guarding the load keeps "Restart & Run All" working on a fresh Drive folder.
model = LSTM(input_size=data_size, output_size=data_size)
resume_path = BASE_PATH+'10e_saved.pt'
if os.path.exists(resume_path):
    model.load_model(resume_path)
else:
    print(f'No checkpoint found at {resume_path}; training from scratch.')
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()

losses = []
for i in range(epochs):
    # Per-batch losses of the current epoch only; the running mean is reported.
    losses = []
    for batch_idx, (seq, seq_target) in enumerate(train_loader):
        # One row per input frame -- assumes batch_size == 1 so that the
        # permute + view collapses each frame to a flat vector (TODO confirm).
        flat_seq = seq.permute(1,2,3,0).view(mnist_seq_length,-1).float()
        # The training target is the first frame of the target clip.
        flat_seq_target = seq_target[:, 0, :, :].view(1,-1).float()
        # initialize batch: clear gradients and start from a zero LSTM state
        optimizer.zero_grad()
        model.hidden_cell = (
            torch.zeros(1, 1, model.hidden_layer_size),
            torch.zeros(1, 1, model.hidden_layer_size)
        )
        # forward
        y_pred = model(flat_seq)
        # backward
        single_loss = model.loss(loss_function, y_pred, flat_seq_target)
        losses.append(single_loss.item())
        single_loss.backward()
        # and step :)
        optimizer.step()
        if batch_idx % 1000 == 0:
            print(f'  {batch_idx:5} : {np.mean(losses):10.8f}')
    print(f'epoch: {i:3} loss: {np.mean(losses):10.8f}')

# Final summary at higher precision; skipped when no batch was processed
# (e.g. epochs == 0), where `i` would be undefined and the mean meaningless.
if losses:
    print(f'epoch: {i:3} loss: {np.mean(losses):10.10f}')
model.save_model(BASE_PATH+'20e_saved.pt')

In [None]:
# model = LSTM(input_size=data_size, output_size=data_size)
# model.load_model(BASE_PATH+'20e_saved.pt')

from matplotlib.animation import FuncAnimation

model.eval()

# Number of frames to generate beyond the seed clip.
fut_pred = 500

# Seed the roll-out with the first test clip only; next(iter(...)) avoids
# loading the whole test set just to take one batch.
seq, seq_target = next(iter(test_loader))
flat_seq = seq.permute(1,2,3,0).view(mnist_seq_length,-1).float()
# Kept as a plain list of flattened frames; create_animation reads it later.
test_inputs = flat_seq.tolist()

# Autoregressive roll-out: predict one frame from the latest
# mnist_seq_length frames, append it, and slide the window forward.
with torch.no_grad():
    for i in range(fut_pred):
        window = torch.as_tensor(test_inputs[-mnist_seq_length:])
        # Reset the state attribute that forward() actually reads
        # (hidden_cell), so every window starts from zeros as in training.
        model.hidden_cell = (
            torch.zeros(1, 1, model.hidden_layer_size),
            torch.zeros(1, 1, model.hidden_layer_size)
        )
        test_inputs.append(model(window).view(data_size).tolist())

# Reference point https://brushingupscience.com/2016/06/21/matplotlib-animations-the-easy-way/
def create_animation(num_images, filename):
    """Animate the first frames of ``test_inputs`` and save the result.

    Reads the notebook-level ``test_inputs`` list (seed frames followed by
    predicted frames), shows the figure, and writes the animation to
    ``BASE_PATH + filename``.

    Args:
        num_images: Maximum number of leading frames to include. If fewer
            frames are available, only those are animated.
        filename: Output file name, appended to ``BASE_PATH``.
    """
    # show input and following predictions
    test_input_preview = torch.as_tensor(test_inputs[:num_images])

    fig, ax = plt.subplots(figsize=(5, 3))

    def animate(i):
        # Clear first: otherwise each frame stacks another image artist on the
        # axes, which grows memory and slows rendering as the animation runs.
        ax.clear()
        ax.set_title('Frame ' + str(i))
        ax.imshow(test_input_preview[i, :].view(64,64), interpolation='nearest')

    # Use the number of frames actually available so animate() never indexes
    # past the end of the preview tensor.
    anim = FuncAnimation(
        fig, animate, interval=150, frames=len(test_input_preview))

    plt.draw()
    plt.show()

    anim.save(BASE_PATH+filename)

# Render a short and a longer preview of the seed clip plus predictions.
for frame_count, video_name in ((20, '20e_preds_quick.mp4'), (60, '20e_preds_long.mp4')):
    create_animation(frame_count, video_name)