In [42]:
from collections import OrderedDict
import torch
import torch.nn as nn
import logging


In [43]:
class CLSTM_cell(nn.Module):
    """Convolutional LSTM cell that unrolls itself over a sequence of frames.

    One convolution followed by GroupNorm produces all four LSTM gates from the
    channel-wise concatenation of the current frame and the previous hidden
    state.

    Args:
        shape: ``(H, W)`` spatial size used for the zero-initialised states and
            placeholder frames.
        input_channels: number of channels of each input frame.
        filter_size: side of the square convolution kernel. Padding is
            ``(filter_size - 1) // 2``, so odd kernels keep H and W unchanged.
        num_features: number of channels of the hidden and cell states.
            GroupNorm is built with ``4 * num_features // 32`` groups over
            ``4 * num_features`` channels, so that group count must be at
            least 1 and divide the channel count.
    """

    def __init__(self, shape, input_channels, filter_size, num_features):
        super(CLSTM_cell, self).__init__()

        self.shape = shape  # H, W
        self.input_channels = input_channels
        self.filter_size = filter_size
        self.num_features = num_features
        # in this way the output has the same size
        self.padding = (filter_size - 1) // 2
        self.conv = nn.Sequential(
            nn.Conv2d(
                self.input_channels + self.num_features,
                4 * self.num_features,
                self.filter_size,
                1,
                self.padding,
            ),
            nn.GroupNorm(4 * self.num_features // 32, 4 * self.num_features),
        )

    def forward(self, inputs=None, hidden_state=None, seq_len=10):
        """Unroll the cell for ``seq_len`` steps.

        Args:
            inputs: tensor indexed as ``(S, B, C, H, W)``; step ``t`` reads
                ``inputs[t]``. Pass ``None`` to feed all-zero frames (in which
                case ``hidden_state`` must be given). When provided, it must
                hold at least ``seq_len`` frames; frames beyond ``seq_len``
                are not read.
            hidden_state: ``(h, c)`` tuple to start from, or ``None`` to start
                from zeros sized from ``inputs.size(1)`` and ``self.shape``.
            seq_len: number of steps to unroll (10 for moving_mnist).

        Returns:
            ``(outputs, (h, c))`` where ``outputs`` stacks the hidden state of
            every step along a new leading dimension and ``(h, c)`` are the
            hidden and cell states after the last step.

        Raises:
            ValueError: if both ``inputs`` and ``hidden_state`` are ``None``.
        """
        if inputs is None and hidden_state is None:
            raise ValueError(
                "CLSTM_cell.forward needs at least one of `inputs` or `hidden_state`"
            )

        # Allocate zero tensors next to the layer's parameters so the cell keeps
        # working after `.to(device)` / `.cuda()` or a dtype change.
        weight = self.conv[0].weight
        factory = {"device": weight.device, "dtype": weight.dtype}
        height, width = self.shape[0], self.shape[1]

        if hidden_state is None:
            batch = inputs.size(1)
            hx = torch.zeros(batch, self.num_features, height, width, **factory)
            cx = torch.zeros(batch, self.num_features, height, width, **factory)
        else:
            hx, cx = hidden_state

        # The placeholder frame is never mutated, so build it once up front.
        zero_frame = None
        if inputs is None:
            zero_frame = torch.zeros(
                hx.size(0), self.input_channels, height, width, **factory
            )

        output_inner = []
        for index in range(seq_len):
            x = zero_frame if inputs is None else inputs[index, ...]

            combined = torch.cat((x, hx), 1)
            gates = self.conv(combined)  # gates: B, num_features*4, H, W
            # split into the four gate pre-activations: i, f, g, o
            ingate, forgetgate, cellgate, outgate = torch.split(
                gates, self.num_features, dim=1
            )
            ingate = torch.sigmoid(ingate)
            forgetgate = torch.sigmoid(forgetgate)
            cellgate = torch.tanh(cellgate)
            outgate = torch.sigmoid(outgate)

            cy = (forgetgate * cx) + (ingate * cellgate)
            hy = outgate * torch.tanh(cy)
            output_inner.append(hy)
            hx = hy
            cx = cy
        return torch.stack(output_inner), (hy, cy)


In [44]:
def make_layers(block):
    """Build an ``nn.Sequential`` from an ordered layer specification.

    ``block`` maps layer names to parameter lists. The layer type is chosen by
    substring match on the name, checked in this order:

    * ``"pool"``   -> ``nn.MaxPool2d(kernel_size, stride, padding)``
    * ``"deconv"`` -> ``nn.ConvTranspose2d(in, out, kernel_size, stride, padding)``
    * ``"conv"``   -> ``nn.Conv2d(in, out, kernel_size, stride, padding)``

    For (de)convolutions, a name containing ``"relu"`` appends an in-place
    ReLU and otherwise a name containing ``"leaky"`` appends an in-place
    LeakyReLU with slope 0.2. Any other name raises ``NotImplementedError``.
    """
    named_modules = OrderedDict()
    for name, spec in block.items():
        if "pool" in name:
            # Pooling layers never get an activation appended.
            named_modules[name] = nn.MaxPool2d(spec[0], spec[1], spec[2])
            continue

        # "deconv" must be tested first because it also contains "conv".
        if "deconv" in name:
            conv_cls = nn.ConvTranspose2d
        elif "conv" in name:
            conv_cls = nn.Conv2d
        else:
            raise NotImplementedError

        named_modules[name] = conv_cls(spec[0], spec[1], spec[2], spec[3], spec[4])

        if "relu" in name:
            named_modules["relu_" + name] = nn.ReLU(inplace=True)
        elif "leaky" in name:
            named_modules["leaky_" + name] = nn.LeakyReLU(
                negative_slope=0.2, inplace=True
            )
    return nn.Sequential(named_modules)


In [45]:
class Encoder(nn.Module):
    """Stack of (sub-network, recurrent cell) stages applied to a frame sequence.

    Stage ``i`` (1-based) is registered as the attributes ``stage<i>`` (built
    from ``subnets[i - 1]`` with ``make_layers``) and ``rnn<i>``
    (``rnns[i - 1]``).

    Args:
        subnets: sequence of layer specifications accepted by ``make_layers``.
        rnns: sequence of recurrent modules, one per sub-network. Each is
            called as ``rnn(inputs, None, seq_len=S)`` and must return
            ``(outputs, state)``.
    """

    def __init__(self, subnets, rnns):
        super().__init__()
        assert len(subnets) == len(rnns), "subnets and rnns must have the same length"
        self.blocks = len(subnets)

        for index, (params, rnn) in enumerate(zip(subnets, rnns), 1):
            # stage numbering starts at 1
            setattr(self, "stage" + str(index), make_layers(params))
            setattr(self, "rnn" + str(index), rnn)

    def forward_by_stage(self, inputs, subnet, rnn):
        """Apply one stage to a ``(S, B, C, H, W)`` tensor.

        The sequence and batch dimensions are merged so ``subnet`` sees a
        plain 4-D batch, then restored before the recurrent module is run
        from a fresh (``None``) state.

        Returns:
            ``(outputs_stage, state_stage)`` exactly as returned by ``rnn``.
        """
        seq_number, batch_size, input_channel, height, width = inputs.size()
        inputs = torch.reshape(inputs, (-1, input_channel, height, width))
        inputs = subnet(inputs)
        inputs = torch.reshape(
            inputs,
            (seq_number, batch_size, inputs.size(1), inputs.size(2), inputs.size(3)),
        )
        # Unroll over every frame actually present instead of relying on the
        # recurrent module's default sequence length.
        outputs_stage, state_stage = rnn(inputs, None, seq_len=seq_number)
        return outputs_stage, state_stage

    def forward(self, inputs):
        """Encode a ``(B, S, C, H, W)`` batch of sequences.

        Each stage consumes the previous stage's outputs.

        Returns:
            Tuple with the final state of every stage, in stage order.
        """
        inputs = inputs.transpose(0, 1)  # B,S,C,H,W -> S,B,C,H,W
        hidden_states = []
        logging.debug(inputs.size())
        for i in range(1, self.blocks + 1):
            inputs, state_stage = self.forward_by_stage(
                inputs, getattr(self, "stage" + str(i)), getattr(self, "rnn" + str(i))
            )
            hidden_states.append(state_stage)
        return tuple(hidden_states)



In [46]:
# Encoder configuration as [stage sub-network specs, recurrent cells].
# Each spec maps a layer name to
# [in_channels, out_channels, kernel_size, stride, padding] for make_layers;
# each cell's input_channels equals the out_channels of its stage's conv.
convlstm_encoder_params = [
    [
        OrderedDict([(layer_name, conv_spec)])
        for layer_name, conv_spec in (
            ("conv1_leaky_1", [1, 16, 3, 1, 1]),
            ("conv2_leaky_1", [64, 64, 3, 2, 1]),
            ("conv3_leaky_1", [96, 96, 3, 2, 1]),
        )
    ],
    [
        CLSTM_cell(
            shape=(side, side),
            input_channels=in_channels,
            filter_size=5,
            num_features=hidden_channels,
        )
        for side, in_channels, hidden_channels in (
            (64, 16, 64),
            (32, 64, 96),
            (16, 96, 96),
        )
    ],
]


### Experimenting with `CLSTM_cell`

In [47]:
# conv = CLSTM_cell((32, 32), 3, 5, 24)
# print(conv)

# batch_size = 11
# in_channels = 3
# frame_size = (32, 32)
# seq_length = 12
# X = torch.rand(batch_size, seq_length, in_channels, *frame_size)

# # Need to switch the batch_size and seq_length dimensions
# outputs, (h, c) = conv(X.transpose(0, 1)) 

# print(outputs.size())
# print(h.size())
# print(c.size())

# # Outputs are the collection of hidden states in the sequence
# print((outputs[-1] - h).norm())


In [49]:
# Index 0 holds the stage layer specs, index 1 the recurrent cells.
encoder = Encoder(
    subnets=convlstm_encoder_params[0],
    rnns=convlstm_encoder_params[1],
)

In [56]:
# Display the first stage's sub-network, which Encoder built with make_layers.
encoder.stage1

Sequential(
  (conv1_leaky_1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (leaky_conv1_leaky_1): LeakyReLU(negative_slope=0.2, inplace=True)
)