In [1]:
import torch
import torch.nn as nn
import numpy as np


In [2]:
class SpatioTemporalLSTMCell(nn.Module):
    """Convolutional LSTM cell that carries two memories: ``c_t`` and ``m_t``.

    Every gate pre-activation is produced by a 2-D convolution. The input is
    projected to seven gate maps, the hidden state to four and the memory
    ``m_t`` to three; each projection is optionally followed by ``LayerNorm``.

    Args:
        in_channel: Number of channels of the input ``x_t``.
        num_hidden: Number of channels of ``h_t``, ``c_t`` and ``m_t``.
        width: Spatial size used for the ``LayerNorm`` shape
            ``[channels, width, width]`` (only used when ``layer_norm`` is
            truthy, and implies square feature maps).
        filter_size: Kernel size of the gate convolutions. Padding is
            ``filter_size // 2``.
        stride: Stride of the gate convolutions.
        layer_norm: When truthy, append ``LayerNorm`` after each gate
            convolution.

    Note:
        ``forward`` multiplies gate maps element-wise with ``c_t`` and ``m_t``,
        so the gate convolutions must preserve the spatial size. With padding
        ``filter_size // 2`` that holds for ``stride == 1`` and an odd
        ``filter_size``.
    """

    def __init__(self, in_channel, num_hidden, width, filter_size, stride, layer_norm):
        super(SpatioTemporalLSTMCell, self).__init__()

        self.num_hidden = num_hidden
        self.padding = filter_size // 2
        self._forget_bias = 1.0

        def make(in_ch, out_ch):
            return self._make_conv(
                in_ch, out_ch, width, filter_size, stride, self.padding, layer_norm
            )

        # Creation order (x, h, m, o, last) is kept so seeded initialisation
        # draws random numbers in the same order as before.
        self.conv_x = make(in_channel, num_hidden * 7)
        self.conv_h = make(num_hidden, num_hidden * 4)
        self.conv_m = make(num_hidden, num_hidden * 3)
        self.conv_o = make(num_hidden * 2, num_hidden)
        self.conv_last = nn.Conv2d(
            num_hidden * 2, num_hidden, kernel_size=1, stride=1, padding=0, bias=False
        )

    @staticmethod
    def _make_conv(in_ch, out_ch, width, filter_size, stride, padding, layer_norm):
        """Build one bias-free gate convolution, optionally followed by LayerNorm.

        The result is always an ``nn.Sequential`` with the convolution at
        index 0 (and the norm at index 1), so parameter names stay stable.
        """
        layers = [
            nn.Conv2d(
                in_ch,
                out_ch,
                kernel_size=filter_size,
                stride=stride,
                padding=padding,
                bias=False,
            )
        ]
        if layer_norm:
            layers.append(nn.LayerNorm([out_ch, width, width]))
        return nn.Sequential(*layers)

    def forward(self, x_t, h_t, c_t, m_t):
        """Advance the cell by one step.

        Args:
            x_t: Input with ``in_channel`` channels on dim 1.
            h_t: Previous hidden state with ``num_hidden`` channels on dim 1.
            c_t: Previous cell state, same shape as ``h_t``.
            m_t: Incoming memory state, same shape as ``h_t``.

        Returns:
            Tuple ``(h_new, c_new, m_new)``.
        """
        x_concat = self.conv_x(x_t)
        h_concat = self.conv_h(h_t)
        m_concat = self.conv_m(m_t)

        # Channel dim is split into equal chunks of num_hidden: 7 from the
        # input, 4 from the hidden state and 3 from the memory.
        i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = torch.split(
            x_concat, self.num_hidden, dim=1
        )
        i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)
        i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1)

        # Update of c_t, driven by the input and the hidden state.
        i_t = torch.sigmoid(i_x + i_h)
        f_t = torch.sigmoid(f_x + f_h + self._forget_bias)
        g_t = torch.tanh(g_x + g_h)

        c_new = f_t * c_t + i_t * g_t

        # Update of m_t, driven by the input and the memory itself.
        i_t_prime = torch.sigmoid(i_x_prime + i_m)
        f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias)
        g_t_prime = torch.tanh(g_x_prime + g_m)

        m_new = f_t_prime * m_t + i_t_prime * g_t_prime

        # The output gate and the new hidden state both see the two updated
        # memories concatenated along the channel dim.
        mem = torch.cat((c_new, m_new), 1)
        o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem))
        h_new = o_t * torch.tanh(self.conv_last(mem))

        return h_new, c_new, m_new



In [3]:
class RNN(nn.Module):
    """Stack of ``SpatioTemporalLSTMCell`` layers that predicts the next frame.

    Args:
        num_layers: Number of stacked cells.
        num_hidden: Sequence with the hidden channel count of each layer. The
            memory state is created with ``num_hidden[0]`` channels and handed
            through every layer, so all entries are expected to be equal.
        configs: Object providing the attributes ``patch_size``,
            ``img_channel``, ``img_width``, ``filter_size``, ``stride`` and
            ``layer_norm`` (read at construction) plus ``total_length`` and
            ``input_length`` (read in ``forward``). An optional
            ``reverse_scheduled_sampling`` attribute equal to ``1`` selects
            the reverse sampling branch; when absent it is treated as off.
    """

    def __init__(self, num_layers, num_hidden, configs):
        super(RNN, self).__init__()

        self.configs = configs
        self.frame_channel = (
            configs.patch_size * configs.patch_size * configs.img_channel
        )
        self.num_layers = num_layers
        self.num_hidden = num_hidden

        # Spatial size handed to each cell for its LayerNorm shape.
        width = configs.img_width // configs.patch_size
        self.MSE_criterion = nn.MSELoss()

        cell_list = []
        for i in range(num_layers):
            # Layer 0 consumes frames; deeper layers consume the hidden state
            # of the layer below.
            in_channel = self.frame_channel if i == 0 else num_hidden[i - 1]
            cell_list.append(
                SpatioTemporalLSTMCell(
                    in_channel,
                    num_hidden[i],
                    width,
                    configs.filter_size,
                    configs.stride,
                    configs.layer_norm,
                )
            )
        self.cell_list = nn.ModuleList(cell_list)

        # 1x1 projection from the top hidden state back to frame channels.
        self.conv_last = nn.Conv2d(
            num_hidden[num_layers - 1],
            self.frame_channel,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )

    def forward(self, frames_tensor, mask_true):
        """Run the network over a sequence and score it against the targets.

        Args:
            frames_tensor: Frames laid out as
                ``[batch, length, height, width, channel]``.
            mask_true: Sampling mask in the same layout. Index
                ``t - input_length`` is read in the standard branch and
                ``t - 1`` in the reverse branch; a value of 1 selects the
                ground-truth frame and 0 the previous prediction.

        Returns:
            Tuple ``(next_frames, loss)``. ``next_frames`` holds
            ``total_length - 1`` predictions in
            ``[batch, length, height, width, channel]`` layout and ``loss`` is
            the MSE against ``frames_tensor[:, 1:]``.
        """
        # [batch, length, height, width, channel] -> [batch, length, channel, height, width]
        frames = frames_tensor.permute(0, 1, 4, 2, 3).contiguous()
        mask_true = mask_true.permute(0, 1, 4, 2, 3).contiguous()

        batch = frames.shape[0]
        height = frames.shape[3]
        width = frames.shape[4]

        # Allocate states directly on the input's device and dtype. This
        # removes the dependency on a `configs.device` attribute and avoids a
        # CPU allocation followed by a copy.
        h_t = [
            frames.new_zeros(batch, self.num_hidden[i], height, width)
            for i in range(self.num_layers)
        ]
        c_t = [
            frames.new_zeros(batch, self.num_hidden[i], height, width)
            for i in range(self.num_layers)
        ]
        memory = frames.new_zeros(batch, self.num_hidden[0], height, width)

        # A config without the flag means standard scheduled sampling.
        reverse_sampling = (
            getattr(self.configs, "reverse_scheduled_sampling", 0) == 1
        )

        next_frames = []
        for t in range(self.configs.total_length - 1):
            if reverse_sampling:
                # Reverse scheduled sampling: only the first frame is always
                # ground truth; later inputs are blended by the mask.
                if t == 0:
                    net = frames[:, t]
                else:
                    net = (
                        mask_true[:, t - 1] * frames[:, t]
                        + (1 - mask_true[:, t - 1]) * x_gen
                    )
            else:
                # Standard: ground truth for the first input_length frames,
                # mask-blended afterwards.
                if t < self.configs.input_length:
                    net = frames[:, t]
                else:
                    net = (
                        mask_true[:, t - self.configs.input_length] * frames[:, t]
                        + (1 - mask_true[:, t - self.configs.input_length]) * x_gen
                    )

            h_t[0], c_t[0], memory = self.cell_list[0](net, h_t[0], c_t[0], memory)

            # `memory` is threaded upwards through the stack and carried over
            # to layer 0 of the next time step.
            for i in range(1, self.num_layers):
                h_t[i], c_t[i], memory = self.cell_list[i](
                    h_t[i - 1], h_t[i], c_t[i], memory
                )

            x_gen = self.conv_last(h_t[self.num_layers - 1])
            next_frames.append(x_gen)

        # [length, batch, channel, height, width] -> [batch, length, height, width, channel]
        next_frames = (
            torch.stack(next_frames, dim=0).permute(1, 0, 3, 4, 2).contiguous()
        )
        loss = self.MSE_criterion(next_frames, frames_tensor[:, 1:])
        return next_frames, loss


### Experimenting


In [4]:
# in_channel = 3
# num_hidden = 6
# width = 32
# filter_size = 5
# stride = 1  # Does not work when stride != 1: the strided convs shrink the gate maps, which then no longer match c_t / m_t
# layer_norm = False

# lstm = SpatioTemporalLSTMCell(
#     in_channel, num_hidden, width, filter_size, stride, layer_norm
# )



In [5]:
# batch_size = 11

# X = torch.rand(batch_size, in_channel, width, width)
# h = torch.rand(batch_size, num_hidden, width, width)
# c = torch.rand(batch_size, num_hidden, width, width)
# m = torch.rand(batch_size, num_hidden, width, width)

# h, c, m = lstm(X, h, c, m)

# print(h.size(), c.size(), m.size())


### Experimenting 2


In [6]:
# --- Network architecture ---
num_layers = 3
# Hidden channels per layer. Keep every entry identical: this code does not
# support a different number of hidden filters for each layer.
num_hidden = [30] * num_layers
filter_size = 5
stride = 1
layer_norm = True

# --- Frame geometry ---
img_width = 32
img_channel = 3
patch_size = 1

# --- Sequence lengths ---
# RNN.forward iterates over total_length - 1 steps and, in its standard
# sampling branch, feeds ground-truth frames while t < input_length. The two
# are set equal for now, so that branch never blends in generated frames.
input_length = 10  # Sequence length
total_length = input_length


class Configs:
    """Plain attribute container for the settings consumed by ``RNN``.

    Args:
        patch_size: Side length of the patches frames are split into.
        stride: Stride of the gate convolutions.
        img_width: Frame width in pixels.
        img_channel: Number of image channels.
        layer_norm: Whether the cells apply LayerNorm after their convolutions.
        filter_size: Kernel size of the gate convolutions.
        input_length: Number of leading frames always fed from ground truth.
        total_length: Total number of frames per sequence.
        device: Target torch device for consumers of this config. Defaults to
            ``"cpu"``.
        reverse_scheduled_sampling: ``1`` selects the reverse scheduled
            sampling branch of ``RNN.forward``; the default ``0`` selects the
            standard branch.
    """

    def __init__(
        self,
        patch_size,
        stride,
        img_width,
        img_channel,
        layer_norm,
        filter_size,
        input_length,
        total_length,
        device="cpu",
        reverse_scheduled_sampling=0,
    ):
        self.patch_size = patch_size
        self.stride = stride
        self.img_width = img_width
        self.img_channel = img_channel
        self.layer_norm = layer_norm
        self.filter_size = filter_size
        self.input_length = input_length
        self.total_length = total_length
        # Both attributes are read by RNN.forward; the defaults keep the
        # original eight-argument call working.
        self.device = device
        self.reverse_scheduled_sampling = reverse_scheduled_sampling


# Bundle the hyperparameters into the settings object; keyword arguments keep
# each value tied to its field name.
configs = Configs(
    patch_size=patch_size,
    stride=stride,
    img_width=img_width,
    img_channel=img_channel,
    layer_norm=layer_norm,
    filter_size=filter_size,
    input_length=input_length,
    total_length=total_length,
)



In [7]:
# Instantiate the stacked network from the settings defined above.
model = RNN(
    num_layers=num_layers,
    num_hidden=num_hidden,
    configs=configs,
)


In [8]:
# Manual walkthrough of the first two time steps of the stacked cells, printing
# tensor sizes at every stage to check that the shapes line up.
batch_size = 11
seq_length = 10
in_channel = 3

# Random clip in channels-last layout: [batch, length, height, width, channel].
frames_tensor = torch.rand(batch_size, seq_length, img_width, img_width, in_channel)
print(frames_tensor.size())
# -> [batch, length, channel, height, width], the layout the cells consume.
frames = frames_tensor.permute(0, 1, 4, 2, 3).contiguous()
print(frames.size())

batch = frames.shape[0]
height = frames.shape[3]
width = frames.shape[4]
print(batch, height, width)

next_frames = []
h_t = []
c_t = []

# Zero-initialised hidden and cell state for every layer.
for i in range(num_layers):
    zeros = torch.zeros([batch, num_hidden[i], height, width])
    h_t.append(zeros)
    c_t.append(zeros)

for h in h_t:
    print("h", h.size())

# One memory tensor shared by all layers.
memory = torch.zeros([batch, num_hidden[0], height, width])

print("m", memory.size())

for t in range(2):
    print(f"------t={t}--------")
    net = frames[:, t]
    print("net", net.size())

    print(f"i={0}")
    h_t[0], c_t[0], memory = model.cell_list[0](net, h_t[0], c_t[0], memory)
    print(h_t[0].size(), c_t[0].size(), memory.size())
    for i in range(1, num_layers):
        print(f"i={i}")
        h_t[i], c_t[i], memory = model.cell_list[i](h_t[i - 1], h_t[i], c_t[i], memory)
        # Report the layer that was just updated (was hard-coded to layer 0).
        print(h_t[i].size(), c_t[i].size(), memory.size())

    # Project the top hidden state back to frame channels.
    x_gen = model.conv_last(h_t[num_layers - 1])
    next_frames.append(x_gen)
    print("x_gen", x_gen.size())



torch.Size([11, 10, 32, 32, 3])
torch.Size([11, 10, 3, 32, 32])
11 32 32
h torch.Size([11, 30, 32, 32])
h torch.Size([11, 30, 32, 32])
h torch.Size([11, 30, 32, 32])
m torch.Size([11, 30, 32, 32])
------t=0--------
net torch.Size([11, 3, 32, 32])
i=0
torch.Size([11, 30, 32, 32]) torch.Size([11, 30, 32, 32]) torch.Size([11, 30, 32, 32])
i=1
torch.Size([11, 30, 32, 32]) torch.Size([11, 30, 32, 32]) torch.Size([11, 30, 32, 32])
i=2
torch.Size([11, 30, 32, 32]) torch.Size([11, 30, 32, 32]) torch.Size([11, 30, 32, 32])
x_gen torch.Size([11, 3, 32, 32])
------t=1--------
net torch.Size([11, 3, 32, 32])
i=0
torch.Size([11, 30, 32, 32]) torch.Size([11, 30, 32, 32]) torch.Size([11, 30, 32, 32])
i=1
torch.Size([11, 30, 32, 32]) torch.Size([11, 30, 32, 32]) torch.Size([11, 30, 32, 32])
i=2
torch.Size([11, 30, 32, 32]) torch.Size([11, 30, 32, 32]) torch.Size([11, 30, 32, 32])
x_gen torch.Size([11, 3, 32, 32])


In [9]:
# Display the module tree of the first cell to inspect its layer shapes.
model.cell_list[0]


SpatioTemporalLSTMCell(
  (conv_x): Sequential(
    (0): Conv2d(3, 210, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
    (1): LayerNorm((210, 32, 32), eps=1e-05, elementwise_affine=True)
  )
  (conv_h): Sequential(
    (0): Conv2d(30, 120, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
    (1): LayerNorm((120, 32, 32), eps=1e-05, elementwise_affine=True)
  )
  (conv_m): Sequential(
    (0): Conv2d(30, 90, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
    (1): LayerNorm((90, 32, 32), eps=1e-05, elementwise_affine=True)
  )
  (conv_o): Sequential(
    (0): Conv2d(60, 30, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
    (1): LayerNorm((30, 32, 32), eps=1e-05, elementwise_affine=True)
  )
  (conv_last): Conv2d(60, 30, kernel_size=(1, 1), stride=(1, 1), bias=False)
)