In [None]:

class ED3LSTMCell(torch.nn.Module):
    """One recurrent cell with a self-attention "recall" of past memories.

    The gate structure appears to follow the E3D-LSTM cell (Wang et al., ICLR 2019):
    a temporal memory ``c`` updated with a recall over its history, a spatio-temporal
    memory ``m`` with an LSTM-style update, and a hidden state ``h``.
    Volumes are laid out as (batch, channels, length, height, width).

    Args:
        input_shape: ``(length, channels, height, width)`` of ``x`` without the batch
            dim; only length/height/width are used, to size the LayerNorms.
        in_channels: number of channels of ``x``.
        hidden_channels: number of channels of ``h``, ``c`` and ``m``.
        kernel_size: 3-D kernel size; padding ``kernel_size // 2`` keeps the spatial
            size for odd kernels with stride 1.
        stride: convolution stride. NOTE(review): the LayerNorm shapes assume the conv
            output keeps the input size, so only ``stride=1`` currently works.
    """

    def __init__(self, input_shape, in_channels, hidden_channels, kernel_size, stride):
        super().__init__()

        self.hidden_channels = hidden_channels
        self.padding = kernel_size // 2
        self.stride = stride

        # One scalar bias per gate, in the order (r, i, g, i', g', f', o). It must be a
        # registered Parameter: unpacking a Parameter into plain tensor attributes leaves
        # it out of .parameters() (never optimized) and out of .to(device).
        self.gate_bias = torch.nn.Parameter(torch.randn(7))

        # Trailing dims of every conv output are (length, height, width), matching x.
        length, _, height, width = input_shape

        def conv_norm(in_ch, out_ch):
            """Conv3d followed by a LayerNorm over the whole (C, L, H, W) output."""
            return torch.nn.Sequential(
                torch.nn.Conv3d(in_channels=in_ch, out_channels=out_ch, kernel_size=kernel_size,
                                padding=self.padding, stride=stride),
                torch.nn.LayerNorm([out_ch, length, height, width]),
            )

        # x feeds all 7 gates; h_prev feeds r, i, g, o; m_prev feeds i', g', f'.
        self.conv_x = conv_norm(in_channels, hidden_channels * 7)
        self.conv_h_prev = conv_norm(hidden_channels, hidden_channels * 4)
        self.conv_m_prev = conv_norm(hidden_channels, hidden_channels * 3)
        # Output-gate contributions of the freshly updated memories.
        self.conv_c = conv_norm(hidden_channels, hidden_channels)
        self.conv_m = conv_norm(hidden_channels, hidden_channels)

        self.layer_norm_c = torch.nn.LayerNorm([hidden_channels, length, height, width])

        # 1x1x1 conv fusing [c_new, m_new] before the final tanh.
        self.conv_c_m = torch.nn.Conv3d(in_channels=hidden_channels * 2, out_channels=hidden_channels,
                                        kernel_size=1, padding=0, stride=1)

    @staticmethod
    def _recall(r, c_window):
        """Recall attention ``softmax(R C^T) C`` with queries R and keys = values = C.

        Each (length, height, width) position is a token whose features are its channel
        vector. Channels are moved last before flattening so positions and channels are
        never mixed, and the softmax runs over the keys (last dim).
        Memory cost: a (L*H*W) x (T*L*H*W) score matrix per sample.

        Args:
            r: (batch, C, L, H, W) recall gate.
            c_window: (batch, T, C, L, H, W) past memories.

        Returns:
            Tensor with the same shape as ``r``.
        """
        query = r.flatten(2).transpose(1, 2)                          # (B, L*H*W, C)
        memory = c_window.transpose(1, 2).flatten(2).transpose(1, 2)  # (B, T*L*H*W, C)
        weights = torch.softmax(query @ memory.transpose(1, 2), dim=-1)
        recalled = weights @ memory                                   # (B, L*H*W, C)
        return recalled.transpose(1, 2).reshape(r.shape)

    def forward(self, x, h_prev, c_history, m_prev, tau):
        """Advance the cell by one step.

        Args:
            x: input, (batch, in_channels, length, height, width).
            h_prev: previous hidden state, (batch, hidden_channels, length, height, width).
            c_history: past temporal memories stacked on dim 1,
                (batch, steps, hidden_channels, length, height, width); the last entry
                is used as c_{t-1}.
            m_prev: previous spatio-temporal memory, same shape as ``h_prev``.
            tau: number of most recent memories attended by the recall (all available
                memories if there are fewer than ``tau``).

        Returns:
            ``(h_new, c_new, m_new)``, each shaped like ``h_prev``.

        Raises:
            ValueError: if ``tau < 1``.
        """
        if tau < 1:
            raise ValueError(f"tau must be >= 1, got {tau}")

        c_prev = c_history[:, -1]
        c_window = c_history[:, -tau:]

        (r_bias, i_bias, g_bias,
         i_prime_bias, g_prime_bias, f_prime_bias, o_bias) = self.gate_bias

        r_x, i_x, g_x, i_x_prime, g_x_prime, f_x_prime, o_x = torch.split(
            self.conv_x(x), self.hidden_channels, dim=1)
        r_h, i_h, g_h, o_h = torch.split(self.conv_h_prev(h_prev), self.hidden_channels, dim=1)

        r = torch.sigmoid(r_x + r_h + r_bias)
        i = torch.sigmoid(i_x + i_h + i_bias)
        g = torch.tanh(g_x + g_h + g_bias)

        # Temporal memory: gated input plus normalised recall of past memories.
        c_new = i * g + self.layer_norm_c(c_prev + self._recall(r, c_window))

        # Spatio-temporal memory: LSTM-style update driven by x and m_prev.
        i_m_prime, g_m_prime, f_m_prime = torch.split(
            self.conv_m_prev(m_prev), self.hidden_channels, dim=1)
        i_prime = torch.sigmoid(i_x_prime + i_m_prime + i_prime_bias)
        g_prime = torch.tanh(g_x_prime + g_m_prime + g_prime_bias)
        f_prime = torch.sigmoid(f_x_prime + f_m_prime + f_prime_bias)
        m_new = i_prime * g_prime + f_prime * m_prev

        o = torch.sigmoid(o_x + o_h + self.conv_c(c_new) + self.conv_m(m_new) + o_bias)
        h_new = o * torch.tanh(self.conv_c_m(torch.cat((c_new, m_new), dim=1)))

        return h_new, c_new, m_new
        
        

In [None]:

# Test ED3LSTMCell: one forward / backward / optimizer step on random tensors.
# The recall attention builds a (length*H*W) x (tau*length*H*W) score matrix per sample,
# so memory grows with the 4th power of the frame side: at 64x64 that matrix alone is
# ~3.4 GB per sample in float32. A small frame keeps the smoke test cheap.
torch.manual_seed(0)
hidden_layer = 2   # >1 channel so channel/position ordering is actually exercised
batch = 2
length = 5
tau = 2
frame = 16         # frame height and width
in_channels = 3

x = torch.randn(batch, in_channels, length, frame, frame)
h_prev = torch.randn(batch, hidden_layer, length, frame, frame)
m_prev = torch.randn(batch, hidden_layer, length, frame, frame)
c_history = torch.randn(batch, tau, hidden_layer, length, frame, frame)

cell = ED3LSTMCell(input_shape=(length, in_channels, frame, frame), in_channels=in_channels,
                   hidden_channels=hidden_layer, kernel_size=5, stride=1)
h_new, c_new, m_new = cell(x, h_prev, c_history, m_prev, tau)

expected_shape = (batch, hidden_layer, length, frame, frame)
assert h_new.shape == c_new.shape == m_new.shape == expected_shape, (h_new.shape, c_new.shape, m_new.shape)

# Random target; h_new plays the role of the prediction.
target = torch.randn(expected_shape)
loss = torch.nn.functional.mse_loss(h_new, target)
assert torch.isfinite(loss)

# Named `optimizer` so it cannot shadow a `torch.optim` alias from another cell.
optimizer = torch.optim.Adam(cell.parameters())
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss.item()


In [None]:

window_size = 3
window_stride = 1

# Split configuration.
DATA_SEED = 0                       # fixes the shuffle so the splits are reproducible
N_SEQUENCES, N_FRAMES = 80000, 20   # layout assumed for ./mnist_train_seq.npy
N_USED = 20000                      # only the first N_USED sequences are used (shuffled)
N_INPUT_FRAMES = 10                 # frames 0..9 are the input; frame 10 is the target
BATCH_SIZE = 3

# !!! X train of shape batch x channel x nb_element x window_size x height x width
# (final layout is produced by `extract_slided_sequence`, defined in another cell).


def _make_split(sequences, start, stop):
    """Return (windowed inputs, next-frame targets) for sequences[start:stop].

    Inputs are the first N_INPUT_FRAMES frames, permuted from (N, T, C, H, W) to
    (N, C, T, H, W) and cut into windows by `extract_slided_sequence`.
    Targets are frame N_INPUT_FRAMES, shape (N, C, H, W).
    """
    chunk = sequences[start:stop]
    x = torch.tensor(chunk[:, :N_INPUT_FRAMES], dtype=torch.float32).permute(0, 2, 1, 3, 4)
    y = torch.tensor(chunk[:, N_INPUT_FRAMES], dtype=torch.float32)
    return extract_slided_sequence(x, window_size, window_stride), y


# Memory-map the file so only the N_USED sequences actually kept are read into RAM.
raw_sequences = np.load("./mnist_train_seq.npy", mmap_mode="r")
# NOTE(review): the standard Moving-MNIST file is stored as (frames, sequences, H, W);
# if this file follows that layout, a transpose(1, 0, 2, 3) is needed rather than a
# plain reshape. Verify against how mnist_train_seq.npy was generated.
raw_sequences = raw_sequences.reshape(N_SEQUENCES, N_FRAMES, 1, 64, 64)

shuffled_indices = np.random.default_rng(DATA_SEED).permutation(N_USED)
sequences = np.asarray(raw_sequences[shuffled_indices])
del raw_sequences

# Validation and test sets are carved from the same file; the official
# mnist_test_seq.npy is not used.
x_train, y_train = _make_split(sequences, 0, 10000)
x_val, y_val = _make_split(sequences, 10000, 15000)
x_test, y_test = _make_split(sequences, 15000, 20000)
del sequences

training_set = DataLoader(TensorDataset(x_train, y_train), batch_size=BATCH_SIZE, shuffle=True)
validation_set = DataLoader(TensorDataset(x_val, y_val), batch_size=BATCH_SIZE, shuffle=False)
testing_set = DataLoader(TensorDataset(x_test, y_test), batch_size=BATCH_SIZE, shuffle=False)
