In [1]:

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
import torchvision
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

# from google.colab import drive
# drive.mount('/content/drive')

# Select the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# torch.cuda.get_device_name(0)  # uncomment to print which GPU is in use



In [2]:

class Multiplicative_Unit(torch.nn.Module):
    """Multiplicative Unit (MU): a gated, shape-preserving convolutional block.

    Computes ``g1 * tanh(g2 * h + g3 * u)`` where ``g1, g2, g3`` are sigmoid
    gates and ``u`` is a tanh update, all produced by a single convolution of
    the input ``h``.

    Args:
        in_channels: number of channels of the input ``h``; the output has the
            same number of channels.
        kernel_size: spatial size of the convolution. Should be odd so the
            symmetric padding keeps the spatial size unchanged.
        stride: convolution stride. Must be 1 in practice, because the gates
            are multiplied element-wise with the (un-strided) input ``h``.
        dilation: convolution dilation. The padding is scaled by it, so the
            spatial size is preserved for any dilation.
    """

    def __init__(self, in_channels, kernel_size=3, stride=1, dilation=1):
        super(Multiplicative_Unit, self).__init__()

        self.in_channels = in_channels
        # "Same" padding for an odd kernel must grow with the dilation: the
        # effective kernel extent is dilation * (kernel_size - 1) + 1. Using
        # kernel_size // 2 alone shrinks the output when dilation > 1, which
        # then breaks the element-wise product with h in forward().
        padding = dilation * (kernel_size - 1) // 2
        # One convolution produces all four maps (g1, g2, g3, u) stacked on dim 1.
        self.mu_conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=in_channels*4,
                                       kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)

    def forward(self, h):
        """Apply the MU to ``h``.

        Args:
            h: input tensor with ``in_channels`` channels on dim 1 (the 2-D
               convolution expects an NCHW layout).

        Returns:
            Tensor with the same shape as ``h``.
        """
        mu_conv = self.mu_conv(h)
        g1, g2, g3, u = torch.split(tensor=mu_conv, split_size_or_sections=self.in_channels, dim=1)
        g1 = torch.sigmoid(g1)
        g2 = torch.sigmoid(g2)
        g3 = torch.sigmoid(g3)
        u = torch.tanh(u)
        mu = g1 * torch.tanh(g2*h + g3*u)
        return mu


In [3]:

class Residual_Multiplicative_Blocks(torch.nn.Module):
    """Residual Multiplicative Block (RMB).

    ``h -> conv (in_channels -> in_channels // 2) -> MU -> MU ->
    conv (in_channels // 2 -> in_channels)``, with the input ``h`` added back
    as a residual connection, so the output has the same shape as ``h``.

    Args:
        in_channels: channels of the input and output. Must be >= 2 because
            the inner path uses ``in_channels // 2`` channels.
        kernel_size, stride, dilation: parameters of the two outer
            convolutions (kernel_size 1 by default). ``stride`` must be 1 for
            the residual addition to line up; odd kernels keep the size.
        unit_kernel_size, unit_stride, unit_dilation: forwarded to both
            Multiplicative Units.

    Raises:
        ValueError: if ``in_channels < 2``.
    """

    def __init__(self, in_channels, kernel_size=1, stride=1, dilation=1,
                 unit_kernel_size=3, unit_stride=1, unit_dilation=1):
        super(Residual_Multiplicative_Blocks, self).__init__()

        hidden_channels = in_channels // 2
        if hidden_channels < 1:
            raise ValueError(f"in_channels must be >= 2, got {in_channels}")
        # Dilation-aware "same" padding (odd kernels) keeps h4 the same
        # spatial size as h so the residual sum in forward() is well-defined.
        padding = dilation * (kernel_size - 1) // 2
        self.h1_conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels,
                                       kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.h2_mu = Multiplicative_Unit(hidden_channels, kernel_size=unit_kernel_size,
                                         stride=unit_stride, dilation=unit_dilation)
        self.h3_mu = Multiplicative_Unit(hidden_channels, kernel_size=unit_kernel_size,
                                         stride=unit_stride, dilation=unit_dilation)
        self.h4_conv = torch.nn.Conv2d(in_channels=hidden_channels, out_channels=in_channels,
                                       kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)

    def forward(self, h):
        """Return ``h + conv(MU(MU(conv(h))))``."""
        h1 = self.h1_conv(h)
        h2 = self.h2_mu(h1)
        h3 = self.h3_mu(h2)
        h4 = self.h4_conv(h3)
        return h + h4
    


In [4]:

# Test residual blocks
# residual_blocks = Residual_Multiplicative_Blocks(in_channels=3, kernel_size=1, stride=1, dilation=2)
# test_tensor = torch.randn(8, 3, 16, 16)
# test_tensor.shape



In [5]:


class Encoder(torch.nn.Module):
    """Sequential stack of ``nb_rmb`` Residual Multiplicative Blocks.

    The Multiplicative Units of successive blocks use dilations taken
    cyclically from ``dilation_scheme``. The default ``(1, 2, 4, 8)`` gives
    two repetitions of ``[1, 2, 4, 8]`` for the default 8 blocks, which is
    the scheme the paper describes for its CNN encoders.

    Args:
        in_channels: channels of the input; every block preserves it.
        nb_rmb: number of residual blocks.
        kernel_size, stride, dilation: outer-convolution parameters of each block.
        unit_kernel_size, unit_stride: MU convolution parameters of each block.
        unit_dilation: MU dilation used for every block when
            ``dilation_scheme`` is ``None`` or empty; ignored otherwise.
        dilation_scheme: sequence of MU dilations cycled over the blocks.
    """

    def __init__(self, in_channels, nb_rmb=8, kernel_size=1, stride=1, dilation=1,
                 unit_kernel_size=3, unit_stride=1, unit_dilation=1,
                 dilation_scheme=(1, 2, 4, 8)):
        super(Encoder, self).__init__()

        # The former formula (i // 2 * 2 for i in 2..nb_rmb+1) produced
        # [2, 2, 4, 4, 6, 6, 8, 8] instead of the intended repeated [1, 2, 4, 8],
        # and silently discarded the unit_dilation argument.
        if dilation_scheme:
            unit_dilations = [dilation_scheme[i % len(dilation_scheme)] for i in range(nb_rmb)]
        else:
            unit_dilations = [unit_dilation] * nb_rmb

        self.rmb_list = torch.nn.ModuleList([
            Residual_Multiplicative_Blocks(in_channels=in_channels, kernel_size=kernel_size,
                                           stride=stride, dilation=dilation,
                                           unit_kernel_size=unit_kernel_size,
                                           unit_stride=unit_stride, unit_dilation=block_dilation)
            for block_dilation in unit_dilations
        ])

    def forward(self, h):
        """Apply every residual block to ``h`` in order and return the result."""
        for rmb in self.rmb_list:
            h = rmb(h)
        return h



In [6]:

# # Test encoder
# encoder = Encoder(in_channels=3, nb_rmb=8, kernel_size=1, stride=1, dilation=2,
#                  unit_kernel_size=1, unit_stride=1, unit_dilation=3)
# test_tensor = torch.randn(8, 3, 16, 16)
# h = encoder(test_tensor)



In [7]:

class Conv_LSTM_Cell(torch.nn.Module):
    """Single ConvLSTM cell with element-wise (peephole) cell connections.

    With ``*`` a convolution and ``.`` an element-wise product::

        i     = sigmoid(Wxi*x + Whi*h + Wic . c_prev + b_i)
        f     = sigmoid(Wxf*x + Whf*h + Wfc . c_prev + b_f)
        new_c = f . c_prev + i . tanh(Wxc*x + Whc*h + b_c)
        o     = sigmoid(Wxo*x + Who*h + Woc . new_c + b_o)
        new_h = o . tanh(new_c)

    Args:
        image_shape: ``(channels, height, width)``; height and width size the
            per-position peephole weights, so inputs must have that spatial size.
        in_channels: channels of the input ``x``.
        hidden_channels: channels of the hidden state ``h`` and cell state ``c``.
        kernel_size: convolution kernel size (odd keeps the spatial size).
        stride: must be 1 in practice, otherwise the conv outputs no longer
            line up with ``c_prev`` and the peephole weights.
        dilation: convolution dilation; the padding is scaled by it.
    """

    def __init__(self, image_shape, in_channels, hidden_channels, kernel_size, stride=1, dilation=1):
        super(Conv_LSTM_Cell, self).__init__()

        _, height, width = image_shape
        # Dilation-aware "same" padding for odd kernels.
        padding = dilation * (kernel_size - 1) // 2
        self.hidden_channels = hidden_channels
        # One learnable scalar bias per gate. Each must be its own nn.Parameter
        # attribute: unpacking a single Parameter yields plain (non-leaf)
        # tensors that the module does not register, so they would be missing
        # from parameters()/state_dict() and never updated by an optimizer.
        self.i_bias = torch.nn.Parameter(torch.randn(()))
        self.f_bias = torch.nn.Parameter(torch.randn(()))
        self.c_bias = torch.nn.Parameter(torch.randn(()))
        self.o_bias = torch.nn.Parameter(torch.randn(()))
        # Each conv emits the input/forget/candidate/output pre-activations
        # stacked along the channel dim (split again in forward()).
        self.conv_x = torch.nn.Conv2d(in_channels=in_channels, out_channels=hidden_channels*4,
                                      kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.conv_h_prev = torch.nn.Conv2d(in_channels=hidden_channels, out_channels=hidden_channels*4,
                                           kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
        # Peephole weights, broadcast over the batch dimension.
        self.W_i_c = torch.nn.Parameter(torch.randn(hidden_channels, height, width))
        self.W_f_c = torch.nn.Parameter(torch.randn(hidden_channels, height, width))
        self.W_o_c = torch.nn.Parameter(torch.randn(hidden_channels, height, width))

    def forward(self, x, h_prev, c_prev):
        """Advance the cell by one time step.

        Args:
            x: input with ``in_channels`` channels (NCHW, spatial size
               matching ``image_shape``).
            h_prev, c_prev: previous hidden and cell states with
               ``hidden_channels`` channels and the same spatial size.

        Returns:
            ``(new_h, new_c)``, each shaped like ``c_prev``.
        """
        conv_x = self.conv_x(x)
        i_x, f_x, c_x, o_x = torch.split(tensor=conv_x, split_size_or_sections=self.hidden_channels, dim=1)
        conv_h = self.conv_h_prev(h_prev)
        i_h, f_h, c_h, o_h = torch.split(tensor=conv_h, split_size_or_sections=self.hidden_channels, dim=1)
        i = torch.sigmoid(i_x + i_h + self.W_i_c*c_prev + self.i_bias)
        f = torch.sigmoid(f_x + f_h + self.W_f_c*c_prev + self.f_bias)
        new_c = f*c_prev + i*torch.tanh(c_x + c_h + self.c_bias)
        # The output gate peeks at the *updated* cell state.
        o = torch.sigmoid(o_x + o_h + self.W_o_c*new_c + self.o_bias)
        new_h = o * torch.tanh(new_c)
        return new_h, new_c
    


In [48]:

class Conv_LSTM(torch.nn.Module):
    """Stack of ``nb_layers`` ConvLSTM cells unrolled over a sequence.

    Args:
        nb_layers: number of stacked cells; must be >= 1.
        image_shape: ``(channels, height, width)`` passed to every cell.
        in_channels: channels of the input sequence (read by layer 0).
        hidden_dim_channels: hidden/cell-state channels of every layer.
        kernel_size: convolution kernel size of every cell.
        stride: accepted for interface compatibility but unused; every cell
            runs with stride 1.

    Raises:
        ValueError: if ``nb_layers < 1``.
    """

    def __init__(self, nb_layers, image_shape, in_channels, hidden_dim_channels, kernel_size, stride=1):
        super(Conv_LSTM, self).__init__()
        if nb_layers < 1:
            raise ValueError(f"nb_layers must be >= 1, got {nb_layers}")
        self.nb_layers = nb_layers
        self.hidden_dim_channels = hidden_dim_channels
        # Layer 0 reads the input; every deeper layer reads the previous layer's h.
        self.cell_list = torch.nn.ModuleList([
            Conv_LSTM_Cell(image_shape=image_shape,
                           in_channels=in_channels if layer == 0 else hidden_dim_channels,
                           hidden_channels=hidden_dim_channels, kernel_size=kernel_size,
                           stride=1, dilation=1)
            for layer in range(nb_layers)
        ])

    def forward(self, x, device=None):
        """Run the stacked cells over every time step of ``x``.

        Args:
            x: input sequence shaped ``(batch, channels, length, height, width)``.
            device: device for the zero-initialised states. Defaults to ``x``'s
                device (previously hard-coded to ``"cuda"``, which failed on
                CPU-only machines).

        Returns:
            ``(h, c)``: final hidden and cell state of the last layer, each
            shaped ``(batch, hidden_dim_channels, height, width)``.
        """
        batch, _, length, height, width = x.shape
        state_shape = (batch, self.hidden_dim_channels, height, width)
        # new_zeros matches x's dtype and, unless overridden, its device.
        h_list = [x.new_zeros(state_shape, device=device) for _ in range(self.nb_layers)]
        c_list = [x.new_zeros(state_shape, device=device) for _ in range(self.nb_layers)]

        for time_step in range(length):
            layer_input = x[:, :, time_step]
            for layer, cell in enumerate(self.cell_list):
                h_list[layer], c_list[layer] = cell(layer_input, h_list[layer], c_list[layer])
                # The freshly updated h of this layer feeds the next layer.
                layer_input = h_list[layer]

        return h_list[-1], c_list[-1]
            


In [50]:

# # Test ConvLSTM
# x = torch.randn(4, 3, 5, 16, 16)
# h_prev = torch.randn(4, 8, 16, 16)
# c_prev = torch.randn(4, 8, 16, 16)

# conv_lstm = Conv_LSTM(nb_layers=4, image_shape=(3, 16, 16), in_channels=3, 
#                       hidden_dim_channels=8, kernel_size=5, stride=1)

# h, c = conv_lstm(x, device=device)
# h.shape

