In [1]:

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
import torchvision
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

# from google.colab import drive
# drive.mount('/content/drive')

device = "cuda" if torch.cuda.is_available() else "cpu"
# torch.cuda.get_device_name(0)


In [2]:

class ED3LSTMCell(torch.nn.Module):
    
    def __init__(self, input_shape, in_channels, hidden_channels, kernel_size, stride):
        super(ED3LSTMCell, self).__init__()
        
        self.hidden_channels = hidden_channels
        self.padding = kernel_size//2
        self.stride = stride
        
        self.r_bias, self.i_bias, self.g_bias, self.i_prime_bias, self.g_prime_bias, self.f_prime_bias, self.o_bias = torch.nn.Parameter(torch.randn(7))
        
        channel, length, width, height = input_shape

        self.conv_x = torch.nn.Sequential(
            torch.nn.Conv3d(in_channels=in_channels, out_channels=hidden_channels*7, kernel_size=kernel_size, padding=self.padding, stride=stride),
            torch.nn.LayerNorm([hidden_channels*7, length, width, height])
        )
        
        self.conv_h_prev = torch.nn.Sequential(
            torch.nn.Conv3d(in_channels=hidden_channels, out_channels=hidden_channels*4, kernel_size=kernel_size, padding=self.padding, stride=stride),
            torch.nn.LayerNorm([hidden_channels*4, length, width, height])
        )
        
        self.conv_m_prev = torch.nn.Sequential(
            torch.nn.Conv3d(in_channels=hidden_channels, out_channels=hidden_channels*3, kernel_size=kernel_size, padding=self.padding, stride=stride),
            torch.nn.LayerNorm([hidden_channels*3, length, width, height])
        )
        
        self.conv_c = torch.nn.Sequential(
            torch.nn.Conv3d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=kernel_size, padding=self.padding, stride=stride),
            torch.nn.LayerNorm([hidden_channels, length, width, height])
        )
        
        self.conv_m = torch.nn.Sequential(
            torch.nn.Conv3d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=kernel_size, padding=self.padding, stride=stride),
            torch.nn.LayerNorm([hidden_channels, length, width, height])
        )
        
        self.layer_norm_c = torch.nn.LayerNorm([hidden_channels, length, width, height])
        
        self.conv_c_m = torch.nn.Conv3d(in_channels=hidden_channels*2, out_channels=hidden_channels, kernel_size=1, padding=0, stride=1)
        
    def forward(self, x, h_prev, c_history, m_prev, tau=-1):
        batch, channels, length, height, width = x.shape
        c_prev = c_history[-1]
        if (tau==-1 or tau > len(c_history)):
            tau = len(c_history)
        c_history = torch.stack(c_history[-tau:])
        c_history = c_history.permute(1,0,2,3,4,5)
#         c_history = torch.cat([c_history[-i].reshape(batch, 1, self.hidden_channels, length, height, width) for i in range(tau, 0, -1)], dim=1)
#         print(c_history.shape)
        
        conv_x = self.conv_x(x)
        r_x, i_x, g_x, i_x_prime, g_x_prime, f_x_prime, o_x = torch.split(tensor=conv_x, split_size_or_sections=self.hidden_channels, dim=1)
        conv_h_prev = self.conv_h_prev(h_prev)
        r_h, i_h, g_h, o_h = torch.split(tensor=conv_h_prev, split_size_or_sections=self.hidden_channels, dim=1)
        
        r = torch.sigmoid(r_x + r_h + self.r_bias)
        i = torch.sigmoid(i_x + i_h + self.i_bias)
        g = torch.tanh(g_x + g_h + self.g_bias)
        
        # Self attention with query=R and key=value= historical memories c (c_history)
        r = r.reshape(batch, length*height*width, self.hidden_channels)
        c_history = c_history.reshape(batch, tau*length*height*width, self.hidden_channels)
        
        print(r.shape)
        print(c_history.shape)
        # !!!!!!!!!!!!!!! Error here: Tesla V100-SXM2-16GB does not have enough memory to perform softmax !!!!!!!!!!!!!!!
        recall = torch.softmax(r @ c_history.permute(0, 2, 1), dim=1) @ c_history 
        print(recall.shape)
        recall = recall.reshape(i.shape)
        print(recall.shape)
        
        c_new = (i * g) + self.layer_norm_c(c_prev + recall)
        
        conv_m_prev = self.conv_m_prev(m_prev)
        i_m_prime, g_m_prime, f_m_prime = torch.split(tensor=conv_m_prev, split_size_or_sections=self.hidden_channels, dim=1)
        
        i_prime = torch.sigmoid(i_x_prime + i_m_prime + self.i_prime_bias)
        g_prime = torch.tanh(g_x_prime + g_m_prime + self.g_prime_bias)
        f_prime = torch.sigmoid(f_x_prime + f_m_prime + self.f_prime_bias)
        m_new = i_prime*g_prime + f_prime*m_prev
        
        o_c = self.conv_c(c_new)
        o_m = self.conv_m(m_new)
        
        o = torch.sigmoid(o_x + o_h + o_c + o_m + self.o_bias)
        c_m_cat = torch.cat((c_new, m_new), dim = 1)
        h_new = o * torch.tanh(self.conv_c_m(c_m_cat))
        
        return h_new, c_new, m_new
        
        

In [3]:

class ED3LSTM(torch.nn.Module):
    
    def __init__(self, nb_layers, encoder_hidden_layer_dim, input_shape, in_channel, hidden_layer_dim, kernel_size, stride):
        super(ED3LSTM, self).__init__()
        
        self.nb_layers = nb_layers
        self.hidden_layer_dim = hidden_layer_dim
        
        channels, length, height, width = input_shape
        
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv3d(in_channels=channels, out_channels=encoder_hidden_layer_dim, kernel_size=kernel_size, padding=kernel_size//2, stride=1)
        )
        
        ed3_lstm_cells = []
        for i in range(nb_layers):
            if i == 0:
                new_cell = ED3LSTMCell(input_shape=input_shape, in_channels=encoder_hidden_layer_dim, hidden_channels=hidden_layer_dim, kernel_size=kernel_size, stride=stride)
            else:
                new_cell = ED3LSTMCell(input_shape=input_shape, in_channels=hidden_layer_dim, hidden_channels=hidden_layer_dim, kernel_size=kernel_size, stride=stride)
            ed3_lstm_cells.append(new_cell)
            
        self.ed3_lstm_cells = torch.nn.ModuleList(ed3_lstm_cells)

        
        
        self.decoder = torch.nn.Sequential(
            torch.nn.Conv3d(in_channels=hidden_layer_dim, out_channels=channels, kernel_size=(window_size, 1, 1), padding=0, stride=(window_size, 1, 1))
        )
        
    
    def forward(self, input_sequence, device="cuda", tau=-1):
        """batch_sequence: tensor with shape (batch, channel, nb_element, window_size, height, width) containing batch of slided consecutive frames."""
        batch, channel, nb_element, window_size, height, width = input_sequence.shape
        
        
        # 1-D list with length nb_layers including h of each layers
        h_list = []
        # 2-D list. The first dimension represent list of cell memory of a specific layer.
        # Length of first dimention is nb_layers. Length of second dimension is nb_element (i.e number of timesteps)
        # The element in the second dimension is a tensor representating a cell memeory at a specific layer and a specific timestep
        c_list = []
        # Store list of prediction
        prediction = []
        for layer in range(self.nb_layers):
            h_list.append(torch.zeros(batch, self.hidden_layer_dim, window_size, height, width, device=device))
            c_list.append([])
            for time_step in range(nb_element):
                c_list[layer].append(torch.zeros(batch, self.hidden_layer_dim, window_size, height, width, device=device))
        
        memory = torch.zeros(batch, self.hidden_layer_dim, window_size, height, width, device=device)
        for time_step in range(nb_element):
            print("time_step: ", time_step)
            encoder_output = self.encoder(input_sequence[:,:,time_step])
            c_history = c_list[0][:time_step+1]
            print("layer: ", 0)
            h_list[0], c_list[0][time_step], memory = self.ed3_lstm_cells[0](encoder_output, h_list[0], c_history, memory, tau=tau)
            
            for layer in range(1, self.nb_layers):
                print("layer: ", layer)
                c_history = c_list[layer][:time_step+1]
                h_list[layer], c_list[layer][time_step], memory = self.ed3_lstm_cells[layer](h_list[layer-1], h_list[layer], c_history, memory, tau=tau)
            print()
            timestep_prediction = self.decoder(h_list[-1])
            prediction.append(timestep_prediction)
#         return prediction
        prediction = torch.stack(prediction).squeeze(dim=2).permute(1,0,2,3,4)
        return prediction



In [4]:

def extract_slided_sequence(batch_sequences, window_size=2, window_stride=1):
    batch, channels, length, height, width  = batch_sequences.shape
    nb_elements = (length - window_size) // window_stride + 1
    slided_sequences = torch.zeros(batch, channels, nb_elements, window_size, height, width)
    
    for i in range(0, length - window_size+1, window_stride):
        element = batch_sequences[:, :, i:i+window_size, ...]
        slided_sequences[:,:,i//window_stride,:,:,:] = element
    return slided_sequences


In [5]:
# from pynvml import *
# nvmlInit()
# h = nvmlDeviceGetHandleByIndex(0)
# info = nvmlDeviceGetMemoryInfo(h)
# print(f'total    : {info.total}')
# print(f'free     : {info.free}')
# print(f'used     : {info.used}')

In [6]:

# # !!! X train of shape batch x channel x nb_element x window_size x height x width
window_size = 8
window_stride = 1

# # Load dataset
training_set =  np.load("./mnist_train_seq.npy")
training_set = training_set.reshape(80000, 20, 1, 64, 64)
x_train, y_train = torch.tensor(training_set[:10000, :10, ...]).float(), torch.tensor(training_set[:10000, window_size:10+window_size, ...]).float()
x_train = x_train.permute(0, 2, 1, 3, 4)
y_train = y_train.permute(0, 2, 1, 3, 4)

testing_set = np.load("./mnist_test_seq.npy")
testing_set = testing_set.reshape(10000, 20, 1, 64, 64)
x_val, y_val = torch.tensor(testing_set[:5000, :10, ...]).float(), torch.tensor(testing_set[:5000, window_size:10+window_size, ...]).float()
x_val = x_val.permute(0, 2, 1, 3, 4)
y_val = y_val.permute(0, 2, 1, 3, 4)

x_test, y_test = torch.tensor(testing_set[5000:, :10, ...]).float(), torch.tensor(testing_set[5000:, window_size:10+window_size, ...]).float()
x_test = x_test.permute(0, 2, 1, 3, 4)
y_test = y_test.permute(0, 2, 1, 3, 4)

del testing_set
del training_set

# Extract training, validation, testing set
x_train = extract_slided_sequence(x_train, window_size, window_stride)
x_val = extract_slided_sequence(x_val, window_size, window_stride)
x_test = extract_slided_sequence(x_test, window_size, window_stride)

training_set = DataLoader(TensorDataset(x_train, y_train), batch_size=2, shuffle=True)
validation_set = DataLoader(TensorDataset(x_val, y_val), batch_size=2, shuffle=False)
testing_set = DataLoader(TensorDataset(x_test, y_test), batch_size=2, shuffle=False)


In [7]:

channels = 1
height = 64
width = 64
nb_layers = 4
encoder_hidden_layer_dim = 3
hidden_layer_dim = 64
kernel_size = 5
stride = 1
tau = -1

# Init model, parameters
ed3_lstm = ED3LSTM(nb_layers=nb_layers, encoder_hidden_layer_dim=encoder_hidden_layer_dim, input_shape=(channels, window_size, height, width), 
                   in_channel=channels, hidden_layer_dim=hidden_layer_dim, kernel_size=kernel_size, stride=stride)
ed3_lstm.to(device=device)
optim = torch.optim.Adam(ed3_lstm.parameters())
l1_loss = torch.nn.L1Loss()
l2_loss = torch.nn.MSELoss()
train_loss = []
val_loss = []
current_epoch = 0
epochs = 20



In [None]:


for sequence, target in tqdm(training_set):
    sequence = sequence.to(device=device)
    target = target.to(device=device)
    pred = ed3_lstm(sequence, device=device, tau=tau)
    loss = l2_loss(pred, target) + l1_loss(pred, target)
    optim.zero_grad()
    loss.backward()
    optim.step()
    print(loss.item())



HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))

time_step:  0
layer:  0
torch.Size([2, 32768, 8])
torch.Size([2, 32768, 8])
torch.Size([2, 32768, 8])
torch.Size([2, 8, 8, 64, 64])
layer:  1
torch.Size([2, 32768, 8])
torch.Size([2, 32768, 8])


In [None]:

# a = torch.randn(4, 3, 5, 64, 64)
# conv = torch.nn.Conv3d(in_channels=3, out_channels=4, kernel_size=(2,5,5), stride=1, padding=(1,2,2))
# b = conv(a)
# b.shape
