In [30]:
from collections import OrderedDict
import torch
import torch.nn as nn


In [31]:
class CLSTM_cell(nn.Module):
    """Convolutional LSTM cell that unrolls itself over a sequence of frames.

    All four gates are produced by one convolution over the channel-wise
    concatenation of the current input frame and the previous hidden state,
    followed by a GroupNorm layer.

    Args:
        shape: ``(H, W)`` spatial size of the frames and of the hidden/cell state.
        input_channels: number of channels in each input frame.
        filter_size: side of the square convolution kernel. The padding is
            ``(filter_size - 1) // 2``, which keeps H and W unchanged for odd
            kernel sizes.
        num_features: number of channels of the hidden and cell states.
    """

    def __init__(self, shape, input_channels, filter_size, num_features):
        super().__init__()

        self.shape = shape  # H, W
        self.input_channels = input_channels
        self.filter_size = filter_size
        self.num_features = num_features
        # "Same" padding so the hidden state keeps the frame's spatial size.
        self.padding = (filter_size - 1) // 2

        gate_channels = 4 * self.num_features
        # One group per 32 gate channels. Clamp to 1 so that small cells
        # (num_features < 8) do not ask GroupNorm for zero groups.
        num_groups = max(1, gate_channels // 32)
        self.conv = nn.Sequential(
            nn.Conv2d(
                self.input_channels + self.num_features,
                gate_channels,
                self.filter_size,
                1,
                self.padding,
            ),
            nn.GroupNorm(num_groups, gate_channels),
        )

    def forward(self, inputs=None, hidden_state=None, seq_len=10):
        """Run the cell for ``seq_len`` steps.

        Args:
            inputs: time-major tensor; ``inputs[t]`` is the frame batch for step
                ``t`` with shape ``(B, input_channels, H, W)``. When ``None``,
                an all-zero frame is fed at every step.
            hidden_state: optional ``(hx, cx)`` pair, each
                ``(B, num_features, H, W)``. When ``None``, both start at zero
                and the batch size is read from ``inputs``.
            seq_len: number of steps to unroll. Only the first ``seq_len``
                frames of ``inputs`` are read; any later frames are ignored.
                The default of 10 matches moving_mnist.

        Returns:
            ``(outputs, (h, c))`` where ``outputs`` stacks the hidden state of
            every step, ``(seq_len, B, num_features, H, W)``, and ``(h, c)`` are
            the hidden and cell state after the last step.

        Raises:
            ValueError: if both ``inputs`` and ``hidden_state`` are ``None``.
            IndexError: if ``inputs`` holds fewer than ``seq_len`` frames.
        """
        if inputs is None and hidden_state is None:
            raise ValueError(
                "CLSTM_cell.forward needs `inputs`, `hidden_state`, or both"
            )
        if inputs is not None and inputs.size(0) < seq_len:
            raise IndexError(
                "seq_len={} but inputs only has {} time steps".format(
                    seq_len, inputs.size(0)
                )
            )

        # Allocate helper tensors where the convolution lives so the cell also
        # works after `.to(device)` / `.double()`; plain torch.zeros would
        # always be CPU float32.
        reference = self.conv[0].weight
        height, width = self.shape[0], self.shape[1]

        if hidden_state is None:
            state_size = (inputs.size(1), self.num_features, height, width)
            hx = torch.zeros(
                state_size, dtype=reference.dtype, device=reference.device
            )
            cx = torch.zeros(
                state_size, dtype=reference.dtype, device=reference.device
            )
        else:
            hx, cx = hidden_state

        # The all-zero frame used when there is no input is the same at every
        # step, so build it once instead of once per iteration.
        zero_frame = None
        if inputs is None:
            zero_frame = torch.zeros(
                hx.size(0),
                self.input_channels,
                height,
                width,
                dtype=reference.dtype,
                device=reference.device,
            )

        output_inner = []
        for index in range(seq_len):
            x = zero_frame if inputs is None else inputs[index, ...]

            combined = torch.cat((x, hx), 1)
            gates = self.conv(combined)  # gates: B, 4 * num_features, H, W
            # Channel-wise split into the four gates: i, f, g, o.
            ingate, forgetgate, cellgate, outgate = torch.split(
                gates, self.num_features, dim=1
            )
            ingate = torch.sigmoid(ingate)
            forgetgate = torch.sigmoid(forgetgate)
            cellgate = torch.tanh(cellgate)
            outgate = torch.sigmoid(outgate)

            cx = (forgetgate * cx) + (ingate * cellgate)
            hx = outgate * torch.tanh(cx)
            output_inner.append(hx)
        return torch.stack(output_inner), (hx, cx)


In [32]:
# Encoder configuration: two parallel lists of three entries each -- named
# conv-stage specs first, then one ConvLSTM cell per stage.
# NOTE(review): the five numbers per conv spec are presumably
# [in_channels, out_channels, kernel, stride, padding] -- confirm against the
# code that consumes this structure.
convlstm_encoder_params = [
    [
        OrderedDict([("conv1_leaky_1", [1, 16, 3, 1, 1])]),
        OrderedDict([("conv2_leaky_1", [64, 64, 3, 2, 1])]),
        OrderedDict([("conv3_leaky_1", [96, 96, 3, 2, 1])]),
    ],
    [
        # (frame side, input channels, hidden channels) for each cell, built in
        # the listed order; every cell uses a 5x5 kernel.
        CLSTM_cell(
            shape=(side, side),
            input_channels=in_ch,
            filter_size=5,
            num_features=hidden_ch,
        )
        for side, in_ch, hidden_ch in ((64, 16, 64), (32, 64, 96), (16, 96, 96))
    ],
]



In [33]:
# Stand-alone cell for the shape checks below: 32x32 frames with 3 channels,
# a 5x5 kernel and 24 hidden channels. Printing it lists its layers.
conv = CLSTM_cell(shape=(32, 32), input_channels=3, filter_size=5, num_features=24)
print(conv)


CLSTM_cell(
  (conv): Sequential(
    (0): Conv2d(27, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): GroupNorm(3, 96, eps=1e-05, affine=True)
  )
)


In [34]:
# Dimensions of the random demo batch.
batch_size, in_channels, seq_length = 11, 3, 12
frame_size = (32, 32)
# Batch-first clip tensor: (batch, time, channels, height, width).
X = torch.rand((batch_size, seq_length, in_channels) + frame_size)


In [35]:
# The cell indexes time on the leading axis, so swap batch and time first.
# NOTE(review): forward() runs with its default seq_len=10 here, so only the
# first 10 of the 12 frames in X are consumed -- pass seq_len explicitly if the
# whole clip should be processed.
outputs, (h, c) = conv(X.transpose(0, 1))

# One size per line: stacked per-step outputs, final hidden state, final cell state.
print(outputs.size(), h.size(), c.size(), sep="\n")

torch.Size([10, 11, 24, 32, 32])
torch.Size([11, 24, 32, 32])
torch.Size([11, 24, 32, 32])


In [None]:
# Last time step of the stacked per-step outputs; the forward call above binds
# them to `outputs` (no variable named `output` exists).
outputs[-1, ...]

In [36]:
# Shape of the clip tensor once time is moved in front of batch.
X.transpose(0, 1).shape

torch.Size([12, 11, 3, 32, 32])

In [37]:
# Shape of the first time step in time-major layout: one frame per batch item.
X.transpose(0, 1)[0].shape

torch.Size([11, 3, 32, 32])