In [None]:
# Mount Google Drive at /gdrive/ (without forcing a remount) and list the
# mount point as a quick check that the mount is visible.
from google.colab import drive
drive.mount('/gdrive/', force_remount=False)
!ls /gdrive

In [None]:
import os

# Project folder on the mounted Drive; created if it does not exist yet.
BASE_PATH = '/gdrive/My Drive/colab_files/framelstm/'
if not os.path.exists(BASE_PATH):
    os.makedirs(BASE_PATH)

# Show the current directory and its contents before switching.
!pwd
!ls
!echo
# Work from BASE_PATH so that the relative paths used later in the notebook
# (the wget download below, 'data/mnist', the `!cp` of checkpoints) resolve
# inside the project folder.
os.chdir(BASE_PATH)
# Download the MovingMNIST dataset module only when it is not already present.
if not os.path.exists(BASE_PATH + 'MovingMNIST.py'):
    !wget https://raw.githubusercontent.com/tychovdo/MovingMNIST/master/MovingMNIST.py
# Confirm the directory switch and show what the project folder contains.
!pwd
!ls -al

In [None]:
import torch
import torch.nn as nn

import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import sys
sys.path.append(BASE_PATH)
import pt_util

In [None]:
# Basic LSTM code adapted from https://stackabuse.com/time-series-prediction-using-lstm-with-pytorch-in-python/

In [None]:
from MovingMNIST import MovingMNIST

# MovingMNIST from http://www.cs.toronto.edu/~nitish/unsupervised_video/
# MovingMNIST dataset specifically from http://www.cs.toronto.edu/~nitish/unsup_video.pdf paper
# MovingMNIST code specifically from https://github.com/tychovdo/MovingMNIST
# Build the train and test splits (in that order) from the same root folder.
train_set, test_set = (
    MovingMNIST(root='data/mnist', train=is_train, download=True)
    for is_train in (True, False)
)

# NOTE: train_model / eval_model flatten every batch to `mnist_seq_length` rows
# and use a recurrent state of shape (1, 1, hidden), so they expect
# batch_size == 1.
batch_size = 1
# Frames per input sequence; believed to be constant for this dataset (TODO confirm).
mnist_seq_length = 10

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

print('==>>> total training batch number: {}'.format(len(train_loader)))
print('==>>> total testing batch number: {}'.format(len(test_loader)))

# Peek at a single training batch to show the tensor shapes, then stop.
for seq, seq_target in train_loader:
    print('--- Sample tensor sizes... ---')
    print('Input:  ', seq.shape)
    print('Target: ', seq_target.shape)
    break

In [None]:
class LSTM(nn.Module):
    """Single-layer LSTM followed by a linear layer, processing one sequence at a time.

    The recurrent state lives in ``self.hidden_cell`` as an ``(h, c)`` pair of
    shape ``(1, 1, hidden_layer_size)`` and is overwritten by every forward
    pass. Callers that want independent sequences must reset ``hidden_cell``
    to zeros before each call.

    Args:
        input_size: Number of features per time step.
        hidden_layer_size: Size of the LSTM hidden state.
        output_size: Number of features produced by the linear layer.
    """

    def __init__(self, input_size=1, hidden_layer_size=100, output_size=1):
        super().__init__()
        self.hidden_layer_size = hidden_layer_size
        self.lstm = nn.LSTM(input_size, hidden_layer_size)
        self.hidden_cell = (
            torch.zeros(1, 1, self.hidden_layer_size),
            torch.zeros(1, 1, self.hidden_layer_size),
        )
        self.linear = nn.Linear(hidden_layer_size, output_size)
        # Best score seen by save_best_model; None until its first call.
        # Must exist up front: nn.Module raises AttributeError on unknown attributes.
        self.accuracy = None

    def forward(self, input_seq):
        """Run the whole sequence through the LSTM and return the last step's prediction.

        Args:
            input_seq: Tensor whose first dimension is the sequence length; the
                remaining dimensions are flattened into the feature dimension.

        Returns:
            Tensor of shape ``(1, output_size)`` holding the prediction for the
            final time step.
        """
        seq_len = len(input_seq)
        # nn.LSTM expects (seq_len, batch, features); the batch dimension is fixed to 1.
        lstm_out, self.hidden_cell = self.lstm(
            input_seq.view(seq_len, 1, -1), self.hidden_cell
        )
        predictions = self.linear(lstm_out.view(seq_len, -1))
        # Slice (rather than index) so the result keeps a leading dimension of 1.
        return predictions[-1:]

    def loss(self, loss_function, prediction, labels):
        """Return ``loss_function(prediction, labels)``."""
        return loss_function(prediction, labels)

    def save_model(self, file_path, num_to_keep=1):
        """Save this model to ``file_path`` via ``pt_util.save``.

        ``num_to_keep`` is forwarded unchanged to ``pt_util.save``.
        """
        pt_util.save(self, file_path, num_to_keep)

    def save_best_model(self, accuracy, file_path, num_to_keep=1):
        """Save the model only when ``accuracy`` beats the best value seen so far.

        The first call always saves. The best value is kept in ``self.accuracy``.
        """
        if self.accuracy is None or accuracy > self.accuracy:
            self.accuracy = accuracy
            self.save_model(file_path, num_to_keep)

    def load_model(self, file_path):
        """Restore this model from ``file_path`` via ``pt_util.restore``."""
        pt_util.restore(self, file_path)

    def load_last_model(self, dir_path):
        """Restore from ``dir_path`` via ``pt_util.restore_latest`` and return its result."""
        return pt_util.restore_latest(self, dir_path)

In [None]:
def train_model(data_size, epochs, start_save_path=None, end_save_path=None):
    """Train an LSTM on ``train_loader`` and measure the test loss after every epoch.

    Args:
        data_size: Number of values per flattened frame; used as both the
            input and the output size of the model.
        epochs: Number of passes over ``train_loader``.
        start_save_path: Optional checkpoint loaded into the model before training.
        end_save_path: Optional path the model is saved to after the last epoch.

    Uses the notebook-level ``train_loader``, ``test_loader`` and
    ``mnist_seq_length``. Every batch is flattened to ``mnist_seq_length`` rows
    and the recurrent state has shape (1, 1, hidden), so each batch is treated
    as a single sequence. The training target is the first frame of
    ``seq_target``.
    """
    training_print_interval = 4000  # batches between running-loss printouts
    epoch_print_interval = 1        # epochs between summary printouts

    def blank_state(net):
        # Fresh zeroed (h, c) pair so no state carries over between sequences.
        dims = (1, 1, net.hidden_layer_size)
        return torch.zeros(*dims), torch.zeros(*dims)

    def unpack(frames, targets):
        # Flatten the input frames to (mnist_seq_length, features) and the
        # first target frame to (1, features), both as float tensors.
        rows = frames.permute(1, 2, 3, 0).view(mnist_seq_length, -1).float()
        goal = targets[:, 0, :, :].view(1, -1).float()
        return rows, goal

    model = LSTM(input_size=data_size, output_size=data_size)
    if start_save_path is not None:
        model.load_model(start_save_path)
    loss_function = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for i in range(epochs):
        show_epoch = (i % epoch_print_interval == 0)

        # --- training pass ---
        model.train()
        train_losses = []
        for batch_idx, (seq, seq_target) in enumerate(train_loader):
            rows, goal = unpack(seq, seq_target)
            optimizer.zero_grad()
            model.hidden_cell = blank_state(model)
            batch_loss = model.loss(loss_function, model(rows), goal)
            train_losses.append(batch_loss.item())
            batch_loss.backward()
            optimizer.step()
            if batch_idx % training_print_interval == 0:
                # Running mean over the batches seen so far in this epoch.
                print(f'  {batch_idx:5} : {np.mean(train_losses):10.8f}')
        if show_epoch:
            print(f'epoch: {i:3} train loss: {np.mean(train_losses):10.8f}')

        # --- evaluation pass (no gradients) ---
        model.eval()
        test_losses = []
        with torch.no_grad():
            for seq, seq_target in test_loader:
                rows, goal = unpack(seq, seq_target)
                model.hidden_cell = blank_state(model)
                test_losses.append(
                    model.loss(loss_function, model(rows), goal).item()
                )
        if show_epoch:
            print(f'epoch: {i:3}  test loss: {np.mean(test_losses):10.8f}')

    # Persist the trained weights when a destination was given.
    if end_save_path is not None:
        model.save_model(end_save_path)

In [None]:
from matplotlib.animation import FuncAnimation

def generate_pred_animations(img_dim, ckpt_path, anim_length, anim_path):
    """Roll a trained model forward from the first test sequence and save an animation.

    The first batch of ``test_loader`` seeds a sliding window of
    ``mnist_seq_length`` frames. ``anim_length`` further frames are predicted
    one at a time, each from the most recent ``mnist_seq_length`` frames
    (seed frames and earlier predictions). The first ``anim_length`` frames of
    the combined sequence (seed frames first, then predictions) are animated,
    shown, and written to ``anim_path``.

    Args:
        img_dim: Side length of the square frames; frames hold ``img_dim**2`` values.
        ckpt_path: Checkpoint to load into the model.
        anim_length: Number of prediction steps and number of animated frames.
        anim_path: Output file for the saved animation.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    data_size = img_dim * img_dim
    model = LSTM(input_size=data_size, output_size=data_size)
    model.load_model(ckpt_path)
    model.eval()

    # Seed with the first test batch only; no need to materialise the whole loader.
    try:
        seq, _ = next(iter(test_loader))
    except StopIteration:
        raise ValueError('test_loader is empty; cannot seed the animation') from None

    # One flat tensor of size data_size per frame.
    frames = list(seq.permute(1, 2, 3, 0).view(mnist_seq_length, -1).float())

    with torch.no_grad():
        for _ in range(anim_length):
            window = torch.stack(frames[-mnist_seq_length:])
            # Reset the recurrent state before every window, as training and
            # evaluation do. The attribute the model reads is `hidden_cell`.
            model.hidden_cell = (
                torch.zeros(1, 1, model.hidden_layer_size),
                torch.zeros(1, 1, model.hidden_layer_size),
            )
            frames.append(model(window).view(data_size))

    # Reference point https://brushingupscience.com/2016/06/21/matplotlib-animations-the-easy-way/
    # Frames to animate: the seed frames followed by the earliest predictions.
    test_input_preview = torch.stack(frames[:anim_length])
    fig, ax = plt.subplots(figsize=(5, 3))

    def animate(i):
        """Draw frame ``i`` of the preview onto the shared axes."""
        # Clear first so images do not pile up on the axes frame after frame.
        ax.clear()
        ax.set_title('Frame ' + str(i))
        ax.imshow(test_input_preview[i].view(img_dim, img_dim), interpolation='nearest')

    anim = FuncAnimation(
        fig, animate, interval=150, frames=anim_length)

    plt.draw()
    plt.show()

    anim.save(anim_path)

In [None]:
# Epochs 1-10: no starting checkpoint is given, so training starts from a newly
# constructed model; the result is saved as 10e_saved_model.pt.
train_model(data_size=64*64, epochs=10, start_save_path=None, end_save_path=BASE_PATH+'10e_saved_model.pt')
# Copy the checkpoint to a .ckpt name. The paths are relative, so this relies
# on the earlier os.chdir(BASE_PATH).
!cp 10e_saved_model.pt 10e_saved.ckpt
# Render a short (20-frame) and a long (60-frame) animation from the copy.
generate_pred_animations(img_dim=64, ckpt_path=BASE_PATH+'10e_saved.ckpt', anim_length=20, anim_path=BASE_PATH+'10e_preds_quick.mp4')
generate_pred_animations(img_dim=64, ckpt_path=BASE_PATH+'10e_saved.ckpt', anim_length=60, anim_path=BASE_PATH+'10e_preds_long.mp4')

In [None]:
# Epochs 11-20: resume from the 10-epoch checkpoint and train 10 more epochs.
# train_model builds a new Adam optimizer on every call, so only the model
# weights carry over, not the optimizer state.
train_model(data_size=64*64, epochs=10, start_save_path=BASE_PATH+'10e_saved_model.pt', end_save_path=BASE_PATH+'20e_saved_model.pt')
# Copy the checkpoint to a .ckpt name (relative paths; relies on the earlier os.chdir(BASE_PATH)).
!cp 20e_saved_model.pt 20e_saved.ckpt
# Render a short (20-frame) and a long (60-frame) animation from the copy.
generate_pred_animations(img_dim=64, ckpt_path=BASE_PATH+'20e_saved.ckpt', anim_length=20, anim_path=BASE_PATH+'20e_preds_quick.mp4')
generate_pred_animations(img_dim=64, ckpt_path=BASE_PATH+'20e_saved.ckpt', anim_length=60, anim_path=BASE_PATH+'20e_preds_long.mp4')

In [None]:
# Epochs 21-30: resume from the 20-epoch checkpoint and train 10 more epochs.
train_model(data_size=64*64, epochs=10, start_save_path=BASE_PATH+'20e_saved_model.pt', end_save_path=BASE_PATH+'30e_saved_model.pt')
# Copy the checkpoint to a .ckpt name (relative paths; relies on the earlier os.chdir(BASE_PATH)).
!cp 30e_saved_model.pt 30e_saved.ckpt
# Render a short (20-frame) and a long (60-frame) animation from the copy.
generate_pred_animations(img_dim=64, ckpt_path=BASE_PATH+'30e_saved.ckpt', anim_length=20, anim_path=BASE_PATH+'30e_preds_quick.mp4')
generate_pred_animations(img_dim=64, ckpt_path=BASE_PATH+'30e_saved.ckpt', anim_length=60, anim_path=BASE_PATH+'30e_preds_long.mp4')

In [None]:
# Epochs 31-40: resume from the 30-epoch checkpoint and train 10 more epochs.
train_model(data_size=64*64, epochs=10, start_save_path=BASE_PATH+'30e_saved_model.pt', end_save_path=BASE_PATH+'40e_saved_model.pt')
# Copy the checkpoint to a .ckpt name (relative paths; relies on the earlier os.chdir(BASE_PATH)).
!cp 40e_saved_model.pt 40e_saved.ckpt
# Render a short (20-frame) and a long (60-frame) animation from the copy.
generate_pred_animations(img_dim=64, ckpt_path=BASE_PATH+'40e_saved.ckpt', anim_length=20, anim_path=BASE_PATH+'40e_preds_quick.mp4')
generate_pred_animations(img_dim=64, ckpt_path=BASE_PATH+'40e_saved.ckpt', anim_length=60, anim_path=BASE_PATH+'40e_preds_long.mp4')

In [None]:
# Epochs 41-50: resume from the 40-epoch checkpoint and train 10 more epochs.
train_model(data_size=64*64, epochs=10, start_save_path=BASE_PATH+'40e_saved_model.pt', end_save_path=BASE_PATH+'50e_saved_model.pt')
# Copy the checkpoint to a .ckpt name (relative paths; relies on the earlier os.chdir(BASE_PATH)).
!cp 50e_saved_model.pt 50e_saved.ckpt
# Render a short (20-frame) and a long (60-frame) animation from the copy.
generate_pred_animations(img_dim=64, ckpt_path=BASE_PATH+'50e_saved.ckpt', anim_length=20, anim_path=BASE_PATH+'50e_preds_quick.mp4')
generate_pred_animations(img_dim=64, ckpt_path=BASE_PATH+'50e_saved.ckpt', anim_length=60, anim_path=BASE_PATH+'50e_preds_long.mp4')

In [None]:
import sys, os

def eval_model(data_size, save_path):
    """Load a checkpoint and print its mean MSE loss over ``test_loader``.

    Args:
        data_size: Number of values per flattened frame; used as both the
            input and the output size of the model.
        save_path: Checkpoint to load into the model.

    Uses the notebook-level ``test_loader`` and ``mnist_seq_length``. Each
    batch is flattened to ``mnist_seq_length`` rows and scored against the
    first frame of ``seq_target``.
    """
    model = LSTM(input_size=data_size, output_size=data_size)
    model.load_model(save_path)
    loss_function = nn.MSELoss()
    state_shape = (1, 1, model.hidden_layer_size)

    model.eval()
    test_losses = []
    with torch.no_grad():
        for seq, seq_target in test_loader:
            frames = seq.permute(1, 2, 3, 0).view(mnist_seq_length, -1).float()
            wanted = seq_target[:, 0, :, :].view(1, -1).float()
            # Zero the recurrent state so every sequence is scored independently.
            model.hidden_cell = (torch.zeros(state_shape), torch.zeros(state_shape))
            batch_loss = model.loss(loss_function, model(frames), wanted)
            test_losses.append(batch_loss.item())

    print(f'model at {save_path} ...')
    print(f'         test loss: {np.mean(test_losses):10.8f}')

In [None]:
# Print the test loss of every copied checkpoint, in training order.
for ckpt_epochs in (10, 20, 30, 40, 50):
    eval_model(data_size=64*64, save_path=BASE_PATH + '{}e_saved.ckpt'.format(ckpt_epochs))