In [100]:

import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [101]:

# Read the train / test arrays from .npy files in the working directory.
# NOTE(review): presumably image sequences, judging by the file names - confirm the array layout.
train_set, test_set = (
    np.load(npy_path) for npy_path in ("mnist_train_seq.npy", "mnist_test_seq.npy")
)


In [102]:

class SpatialTemoralLSTMCell(torch.nn.Module):
    """Convolutional LSTM cell that keeps two memories: a cell state ``c`` and a second memory ``m``.

    The gates are built from convolutions (each followed by LayerNorm) of the
    input ``x``, the previous hidden state ``h_prev`` and the previous memory
    ``m_prev``; a learnable scalar bias is added to every gate.

    Args:
        image_shape: Sequence whose first two entries are the height and width
            of the feature maps; they size the LayerNorm layers.
        in_channel: Number of channels of the input ``x``.
        hidden_channels: Number of hidden feature maps (channels of ``h``, ``c`` and ``m``).
        kernel_size: Kernel size of the gate convolutions; padding is ``kernel_size // 2``.
        stride: Stride of the gate convolutions.

    Note:
        The LayerNorm shapes are fixed to ``image_shape[0] x image_shape[1]``, so
        the convolutions have to preserve the spatial size (true for
        ``stride=1`` with an odd ``kernel_size``).
    """

    # Scalar gate biases, in the order their initial values are drawn.
    _BIAS_NAMES = ("g_bias", "i_bias", "f_bias", "o_bias",
                   "g_prime_bias", "i_prime_bias", "f_prime_bias")

    def __init__(self, image_shape, in_channel, hidden_channels, kernel_size, stride=1):
        super(SpatialTemoralLSTMCell, self).__init__()
        self.hidden_channels = hidden_channels
        self.padding = kernel_size//2
        self.stride = stride

        # Every bias has to be its own nn.Parameter. Unpacking one
        # Parameter(torch.rand(7)) yields plain, non-leaf tensors that are not
        # registered on the module: they would be missing from parameters() /
        # state_dict(), never updated by an optimizer and not moved by .to().
        for bias_name, init_value in zip(self._BIAS_NAMES, torch.rand(len(self._BIAS_NAMES))):
            setattr(self, bias_name, torch.nn.Parameter(init_value.clone()))

        def conv_norm(in_channels, gate_count):
            # Conv producing `gate_count` stacked gate pre-activations, then LayerNorm.
            out_channels = hidden_channels*gate_count
            return torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                padding=self.padding, stride=self.stride),
                torch.nn.LayerNorm([out_channels, image_shape[0], image_shape[1]])
            )

        # Input feeds all seven gates: g, i, f, o and g', i', f'.
        self.conv_x = conv_norm(in_channel, 7)
        # Previous hidden state feeds g, i, f, o.
        self.conv_h_prev = conv_norm(self.hidden_channels, 4)
        # Previous memory feeds g', i', f'.
        self.conv_m_prev = conv_norm(self.hidden_channels, 3)
        # New cell state and new memory both contribute to the output gate.
        self.conv_c = conv_norm(self.hidden_channels, 1)
        self.conv_m = conv_norm(self.hidden_channels, 1)

        # 1x1 convolution that fuses the concatenated [c, m] back to hidden_channels.
        self.conv_c_m = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=self.hidden_channels*2, out_channels=self.hidden_channels, kernel_size=1, padding=0, stride=1)
        )

    def forward(self, x, h_prev, c_prev, m_prev):
        """Run one step of the cell.

        Args:
            x: Input of shape (batch, in_channel, H, W).
            h_prev: Previous hidden state, (batch, hidden_channels, H, W).
            c_prev: Previous cell state, (batch, hidden_channels, H, W).
            m_prev: Previous memory, (batch, hidden_channels, H, W).

        Returns:
            Tuple ``(h, c, m)`` with the new hidden state, cell state and memory,
            each with ``hidden_channels`` channels.
        """
        conv_x = self.conv_x(x)
        g_x, i_x, f_x, o_x, g_x_prime, i_x_prime, f_x_prime = torch.split(tensor=conv_x, split_size_or_sections=self.hidden_channels, dim=1)

        # Cell-state update driven by x and h_prev.
        conv_h_prev = self.conv_h_prev(h_prev)
        g_h, i_h, f_h, o_h = torch.split(tensor=conv_h_prev, split_size_or_sections=self.hidden_channels, dim=1)
        g = torch.tanh(g_x + g_h + self.g_bias)
        i = torch.sigmoid(i_x + i_h + self.i_bias)
        f = torch.sigmoid(f_x + f_h + self.f_bias)
        c = f * c_prev + i * g

        # Memory update driven by x and m_prev.
        conv_m_prev = self.conv_m_prev(m_prev)
        g_m_prime, i_m_prime, f_m_prime = torch.split(tensor=conv_m_prev, split_size_or_sections=self.hidden_channels, dim=1)
        g_prime = torch.tanh(g_x_prime + g_m_prime + self.g_prime_bias)
        i_prime = torch.sigmoid(i_x_prime + i_m_prime + self.i_prime_bias)
        f_prime = torch.sigmoid(f_x_prime + f_m_prime + self.f_prime_bias)
        m = f_prime * m_prev + i_prime * g_prime

        # Output gate sees the input, the previous hidden state and both updated memories.
        o_c = self.conv_c(c)
        o_m = self.conv_m(m)
        o = torch.sigmoid(o_x + o_h + o_c + o_m + self.o_bias)

        # Hidden state: gated 1x1 fusion of the two memories.
        c_m_cat = torch.cat((c,m), dim=1)
        h = o * torch.tanh(self.conv_c_m(c_m_cat))

        return h, c, m
        


In [103]:


class PredRNN(torch.nn.Module):
    """Stack of SpatialTemoralLSTMCell layers that predicts one frame from a sequence.

    Each layer keeps its own hidden state ``h`` and cell state ``c``. A single
    ``memory`` tensor is shared by all layers: it is passed upward through the
    stack within a time step, and the value left by the top layer is fed to the
    bottom layer at the next time step.

    Args:
        nb_layers: Number of stacked cells (at least 1).
        image_shape: Forwarded to every cell; its first two entries are the
            height and width of the frames.
        in_channel: Number of channels of the input frames (and of the prediction).
        hidden_layer_dim: Number of hidden channels, shared by every layer.
        kernel_size: Kernel size of the cells' gate convolutions.
        stride: Stride of the cells' gate convolutions.
        device: Unused; kept so existing calls keep working. States are created
            on the device of the tensor given to ``forward``.

    Raises:
        ValueError: If ``nb_layers`` is smaller than 1.
    """

    def __init__(self, nb_layers, image_shape, in_channel, hidden_layer_dim, kernel_size, stride=1, device="cuda"):
        super(PredRNN, self).__init__()

        if nb_layers < 1:
            # An empty stack could only fail later, with an IndexError inside forward().
            raise ValueError("nb_layers must be at least 1")

        self.nb_layers = nb_layers
        self.hidden_layer_dim = hidden_layer_dim
        # Fixed hidden layer dim for every layer ==> Fix later by changing param hidden_layer_dim from int to list of int
        # Only the first cell sees the raw frames; the others consume the hidden state of the layer below.
        self.cell_list = torch.nn.ModuleList([
            SpatialTemoralLSTMCell(image_shape=image_shape,
                                   in_channel=in_channel if layer == 0 else hidden_layer_dim,
                                   hidden_channels=hidden_layer_dim, kernel_size=kernel_size, stride=stride)
            for layer in range(nb_layers)
        ])
        # 1x1 convolution mapping the top hidden state back to frame channels.
        self.output_conv = torch.nn.Conv2d(in_channels=hidden_layer_dim, out_channels=in_channel, kernel_size=1, stride=1)

    def forward(self, batch_of_sequences):
        """Consume a batch of sequences and return the predicted frame.

        Args:
            batch_of_sequences: Tensor of shape (batch, length, channels, height, width).

        Returns:
            Tensor of shape (batch, in_channel, height, width), computed from the
            top layer's hidden state after the last time step.
        """
        batch, length, _, height, width = batch_of_sequences.shape
        state_shape = (batch, self.hidden_layer_dim, height, width)

        # Allocate the states next to the input (same device and dtype) rather
        # than on a notebook-level global, which breaks as soon as the model
        # and the data live on a different device than that global names.
        h_list = [batch_of_sequences.new_zeros(state_shape) for _ in range(self.nb_layers)]
        c_list = [batch_of_sequences.new_zeros(state_shape) for _ in range(self.nb_layers)]
        memory = batch_of_sequences.new_zeros(state_shape)

        # Recurrent flow (For each timestep, perform vertical flow first)
        for t in range(length):
            layer_input = batch_of_sequences[:, t]
            for layer, cell in enumerate(self.cell_list):
                h_list[layer], c_list[layer], memory = cell(layer_input, h_list[layer], c_list[layer], memory)
                # The freshly computed hidden state is the input of the next layer up.
                layer_input = h_list[layer]

        pred = self.output_conv(h_list[-1])

        return pred
            



In [104]:
# # Test cell
# batch_size = 1
# test_tensor = torch.randn((batch_size, 3, 10, 10))
# test_h_prev = torch.randn((batch_size, 5, 10, 10))
# test_c_prev = torch.randn((batch_size, 5, 10, 10))
# test_m_prev = torch.randn((batch_size, 5, 10, 10))

# test_cell = SpatialTemoralLSTMCell(image_shape=(10,10,3), in_channel=3, hidden_channels=5, kernel_size=5, stride=1)

# h, c, m = test_cell(test_tensor, test_h_prev, test_c_prev, test_m_prev)


In [105]:
# # Test PredRNN
# batch_size, length, channels, height, width = 16, 10, 3, 64, 64
# test_tensor = torch.randn(batch_size, length, channels, height, width)

# pred_rnn = PredRNN(nb_layers=3, image_shape=(height, width, channels), in_channel=channels, hidden_layer_dim=8, kernel_size=5, stride=1)
# pred = pred_rnn(test_tensor)


In [106]:
# Smoke test for PredRNN. `pred` used to come from a test cell that is now
# commented out, so on a fresh kernel this cell raised NameError; the
# prediction is rebuilt here so the notebook survives Restart & Run All.
batch_size, length, channels, height, width = 16, 10, 3, 64, 64
test_tensor = torch.randn(batch_size, length, channels, height, width, device=device)

pred_rnn = PredRNN(nb_layers=3, image_shape=(height, width, channels), in_channel=channels,
                   hidden_layer_dim=8, kernel_size=5, stride=1).to(device)
with torch.no_grad():  # only the output shape is checked, no autograd graph needed
    pred = pred_rnn(test_tensor)

assert pred.shape == (batch_size, channels, height, width)
print(pred.shape)

torch.Size([16, 3, 64, 64])
