In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
class ConvLSTMCell(nn.Module):
    """
    A single convolutional LSTM cell.

    The four gate pre-activations are produced by one 2-D convolution over
    the channel-wise concatenation of the current input and the previous
    hidden state.

    Shapes:
        x_t    -> (B, C_in, H, W)
        h_prev -> (B, C_hidden, H, W)
        c_prev -> (B, C_hidden, H, W)

    Args:
        in_channels: number of channels in ``x_t``.
        hidden_channels: number of channels in the hidden and cell states.
        kernel_size: an odd positive int, or a pair ``(kH, kW)`` of odd
            positive ints.

    Raises:
        ValueError: if ``kernel_size`` is not one or two positive odd sizes.
    """

    def __init__(self, in_channels, hidden_channels, kernel_size=3):
        super().__init__()

        # Accept either a single size or a (kH, kW) pair.
        try:
            kernel_dims = tuple(kernel_size)
        except TypeError:
            kernel_dims = (kernel_size, kernel_size)

        # Padding of k // 2 keeps H and W unchanged only when k is odd. With
        # an even kernel the conv output grows by one pixel and no longer
        # lines up with c_prev, so reject it here with a clear message
        # instead of failing later inside forward().
        if len(kernel_dims) != 2 or any(k < 1 or k % 2 == 0 for k in kernel_dims):
            raise ValueError(
                "kernel_size must be one or two positive odd integers, "
                f"got {kernel_size!r}"
            )

        padding = tuple(k // 2 for k in kernel_dims)
        self.hidden_channels = hidden_channels

        # One convolution for all four gates
        self.conv = nn.Conv2d(
            in_channels + hidden_channels,
            4 * hidden_channels,
            kernel_size=kernel_dims,
            padding=padding,
            bias=True
        )

    def forward(self, x_t, h_prev, c_prev):
        """
        Advance the cell by one time step.

        Args:
            x_t: input at the current step, (B, C_in, H, W).
            h_prev: previous hidden state, (B, C_hidden, H, W).
            c_prev: previous cell state, (B, C_hidden, H, W).

        Returns:
            Tuple ``(h_t, c_t)``, each of shape (B, C_hidden, H, W).
        """
        # Concatenate along channel dimension
        combined = torch.cat([x_t, h_prev], dim=1)

        # Compute all gates at once, then split the 4 * C_hidden output
        # channels into four equal groups, in the order i, f, o, g.
        gates = self.conv(combined)
        i, f, o, g = torch.chunk(gates, 4, dim=1)

        i = torch.sigmoid(i)        # input gate
        f = torch.sigmoid(f)        # forget gate
        o = torch.sigmoid(o)        # output gate
        g = torch.tanh(g)           # candidate

        # Cell update: keep part of the old state, add gated candidate
        c_t = f * c_prev + i * g

        # Hidden state
        h_t = o * torch.tanh(c_t)

        return h_t, c_t


In [3]:
class ConvLSTM(nn.Module):
    """
    Runs a single ConvLSTMCell over the time axis of a 5-D input, starting
    from all-zero hidden and cell states, and returns the hidden state of
    every time step.

    Args:
        in_channels: number of channels of each input frame.
        hidden_channels: number of channels of the hidden and cell states.
        kernel_size: kernel size forwarded to the cell's convolution.
    """

    def __init__(self, in_channels, hidden_channels, kernel_size=3):
        super().__init__()
        self.cell = ConvLSTMCell(
            in_channels, hidden_channels, kernel_size
        )

    def forward(self, x):
        """
        Args:
            x: input sequence of shape (B, T, C, H, W).

        Returns:
            Tensor of shape (B, T, C_hidden, H, W) holding the hidden state
            after each of the T steps.
        """
        B, T, C, H, W = x.shape

        # new_zeros takes both dtype and device from x, so the initial state
        # matches the input for float64 / half / bfloat16 inputs as well,
        # rather than always being created in the default dtype.
        h = x.new_zeros((B, self.cell.hidden_channels, H, W))
        c = torch.zeros_like(h)

        outputs = []

        for t in range(T):
            h, c = self.cell(x[:, t], h, c)
            outputs.append(h)

        # Stack outputs over time
        return torch.stack(outputs, dim=1)  # (B, T, C_hidden, H, W)


In [4]:
# Smoke test: push a random batch through ConvLSTM and check the output shape.
torch.manual_seed(0)  # reproducible random input and weight initialisation

B, T, C, H, W = 2, 5, 3, 64, 64
HIDDEN_CHANNELS = 16
x = torch.randn(B, T, C, H, W)

# Derive in_channels from C so the input and the model cannot drift apart.
model = ConvLSTM(in_channels=C, hidden_channels=HIDDEN_CHANNELS)
y = model(x)

# One hidden state per time step, same batch and spatial size as the input.
assert y.shape == (B, T, HIDDEN_CHANNELS, H, W)

print(y.shape)
# -> (2, 5, 16, 64, 64)


torch.Size([2, 5, 16, 64, 64])
