In [1]:
#| default_exp convLSTM
#| default_cls_lvl 3

In [2]:
#| hide
from nbdev import *

# Convolutional Long Short-Term Memory Network (ConvLSTM)

This notebook implements the model proposed by Shi et al. (2015) in *Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting*. The idea is to extend the fully connected LSTM (FC-LSTM) with convolutional structures in both the input-to-state and state-to-state transitions (input and hidden state). The resulting model is called the convolutional LSTM (ConvLSTM).

## The model

Although the FC-LSTM has proven powerful for handling temporal correlations, it contains too much redundancy for spatial data. To address this, the authors propose adding convolutional structures to the recurrent transitions. By stacking multiple ConvLSTM layers, they were able to predict spatiotemporal sequences. The **major drawback** of the FC-LSTM in handling spatiotemporal data is its use of full connections in the input-to-state and state-to-state transitions, in which **no** spatial information is encoded.
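
To make the redundancy concrete, the following back-of-the-envelope sketch counts the parameters of one FC-LSTM layer and one ConvLSTM layer on the same grid. The function names are hypothetical. The FC-LSTM is assumed to flatten each frame and to keep one hidden unit per grid cell and hidden channel, and peephole weights are ignored in both counts.

```python
def fc_lstm_params(in_channels, hidden_channels, height, width):
    """Weights and biases of one FC-LSTM layer on flattened frames (assumed layout)."""
    n_in = in_channels * height * width
    n_hidden = hidden_channels * height * width
    # Four gates, each a dense map from [input, hidden] to hidden, plus a bias.
    return 4 * (n_hidden * (n_in + n_hidden) + n_hidden)


def conv_lstm_params(in_channels, hidden_channels, kernel_size):
    """Weights and biases of one ConvLSTM layer; independent of the grid size."""
    k_h, k_w = kernel_size
    # Four gates, each a convolution from [input, hidden] channels to hidden channels.
    return 4 * (hidden_channels * (in_channels + hidden_channels) * k_h * k_w
                + hidden_channels)


fc = fc_lstm_params(1, 16, 128, 128)
conv = conv_lstm_params(1, 16, (3, 3))
print(f"FC-LSTM:  {fc:,} parameters")
print(f"ConvLSTM: {conv:,} parameters")
```

Under these assumptions the convolutional layer needs 9,856 parameters regardless of the image size, while the dense count grows quadratically with the number of grid cells.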

The ConvLSTM determines the future state of a certain cell in the grid by the inputs and past states of its local neighbors. This can easily be achieved by using a convolution operator in the state-to-state and input-to-state transitions.
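
The locality property can be illustrated with a small pure-Python sketch. The helper `conv2d_same` is a hypothetical stand-in for a zero-padded, size-preserving convolution (computed as cross-correlation, as deep-learning libraries usually do), not the layer used in this notebook. Changing a value outside the 3x3 neighbourhood of a cell leaves that cell's result untouched, while changing a neighbour does not.

```python
def conv2d_same(grid, kernel):
    """Zero-padded 2D cross-correlation that preserves the grid size (odd kernels only)."""
    h, w = len(grid), len(grid[0])
    kh, kw = len(kernel), len(kernel[0])
    ph, pw = kh // 2, kw // 2
    out = [[0.0] * w for _ in range(h)]
    for r in range(h):
        for c in range(w):
            acc = 0.0
            for dr in range(kh):
                for dc in range(kw):
                    rr, cc = r + dr - ph, c + dc - pw
                    if 0 <= rr < h and 0 <= cc < w:  # outside the grid counts as zero
                        acc += kernel[dr][dc] * grid[rr][cc]
            out[r][c] = acc
    return out


kernel = [[1.0] * 3 for _ in range(3)]  # 3x3 box filter
grid = [[float(r * 5 + c) for c in range(5)] for r in range(5)]
base = conv2d_same(grid, kernel)

far = [row[:] for row in grid]
far[4][4] += 100.0   # two cells away from (2, 2): outside its 3x3 neighbourhood
near = [row[:] for row in grid]
near[2][3] += 100.0  # direct neighbour of (2, 2)

center_unchanged = conv2d_same(far, kernel)[2][2] == base[2][2]
center_changed = conv2d_same(near, kernel)[2][2] != base[2][2]
print(center_unchanged, center_changed)  # True True
```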

Here are the key equations, where `*` denotes the convolution operator and $\circ$ denotes the Hadamard (element-wise) product:

$$
i_t = \sigma(W_{xi}*X_t+W_{hi}*H_{t-1}+W_{ci}\circ C_{t-1}+b_i) \\
f_t = \sigma(W_{xf}*X_t+W_{hf}*H_{t-1}+W_{cf}\circ C_{t-1}+b_f) \\
C_t = f_t \circ C_{t-1} + i_t \circ \tanh(W_{xc}*X_t+W_{hc}*H_{t-1}+b_c) \\
o_t = \sigma(W_{xo}*X_t+W_{ho}*H_{t-1}+W_{co}\circ C_t+b_o) \\
H_t = o_t \circ \tanh(C_t)
$$

![](https://miro.medium.com/v2/resize:fit:942/1*u8neecA4w6b_F1NgnyPP0Q.png)
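
For a single grid cell with 1x1 kernels, every convolution in the equations above reduces to a scalar multiplication, so one update step can be traced by hand. The sketch below does this with plain Python. The function name and the dictionaries `w` and `peephole` are hypothetical containers chosen for illustration. The implementation in this notebook leaves out the peephole terms ($W_{ci}$, $W_{cf}$, $W_{co}$), which corresponds to `peephole=None` here.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def conv_lstm_step_scalar(x, h_prev, c_prev, w, peephole=None):
    """One ConvLSTM update for a single cell with 1x1 kernels (scalars only).

    `w` maps gate name -> (w_x, w_h, b); `peephole` maps gate name -> w_c.
    """
    p = peephole or {"i": 0.0, "f": 0.0, "o": 0.0}
    i = sigmoid(w["i"][0] * x + w["i"][1] * h_prev + p["i"] * c_prev + w["i"][2])
    f = sigmoid(w["f"][0] * x + w["f"][1] * h_prev + p["f"] * c_prev + w["f"][2])
    g = math.tanh(w["c"][0] * x + w["c"][1] * h_prev + w["c"][2])
    c = f * c_prev + i * g
    # The output gate peeks at the updated cell state, as in Shi et al. (2015).
    o = sigmoid(w["o"][0] * x + w["o"][1] * h_prev + p["o"] * c + w["o"][2])
    h = o * math.tanh(c)
    return h, c


# With all weights and biases at zero, every gate is sigmoid(0) = 0.5 and g = tanh(0) = 0,
# so the cell state is simply halved: c1 = 0.5 * c_prev and h1 = 0.5 * tanh(c1).
zero = {name: (0.0, 0.0, 0.0) for name in ("i", "f", "c", "o")}
h1, c1 = conv_lstm_step_scalar(x=1.0, h_prev=0.0, c_prev=1.0, w=zero)
print(c1, h1)
```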

In [3]:
#| export

import torch.nn as nn
import torch

class ConvLSTMCell(nn.Module):

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_dim: int
            Number of channels of input tensor.
        hidden_dim: int
            Number of channels of hidden state.
        kernel_size: (int, int)
            Size of the convolutional kernel.
        bias: bool
            Whether or not to add the bias.
        """

        super(ConvLSTMCell, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size,
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state

        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
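
The cell pads with `kernel_size // 2` on each axis, which preserves the spatial size only for odd kernel sizes. The sketch below applies the standard convolution output-size formula to check this. The helper `conv_output_size` is hypothetical and not part of the exported module.

```python
def conv_output_size(size, kernel, padding, stride=1, dilation=1):
    """Spatial output size of a convolution along one axis (standard formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


for k in (3, 4, 5):
    out = conv_output_size(128, kernel=k, padding=k // 2)
    print(f"kernel={k}, padding={k // 2}: 128 -> {out}")
```

With an even kernel the gates come out one pixel larger than `c_cur`, so the element-wise product `f * c_cur` in `forward` would be expected to fail with a shape mismatch. Odd kernel sizes such as `(3, 3)` or `(5, 5)` avoid this.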

In [4]:
show_doc(ConvLSTMCell)

---

### ConvLSTMCell

>      ConvLSTMCell (input_dim, hidden_dim, kernel_size, bias)

In [5]:
#| export

class ConvLSTM(nn.Module):

    """

    Parameters:
        input_dim: Number of channels in input
        hidden_dim: Number of hidden channels (int, or list with one int per layer)
        kernel_size: Size of kernel in convolutions (tuple, or list with one tuple per layer)
        num_layers: Number of LSTM layers stacked on each other
        batch_first: Whether dimension 0 is the batch dimension
        bias: Bias or no bias in Convolution
        return_all_layers: Return the list of computations for all layers
        stateful: Reuse the last state of the previous call as the initial state
        Note: Will do same padding (for odd kernel sizes).

    Input:
        A tensor of size B, T, C, H, W or T, B, C, H, W
    Output:
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
            0 - layer_output_list is the list of output tensors, each of size B, T, hidden_dim, H, W
            1 - last_state_list is the list of last states
                    each element of the list is a tuple (h, c) for hidden state and memory
    Example:
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convlstm = ConvLSTM(64, 16, (3, 3), 1, True, True, False)
        >> _, last_states = convlstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False, stateful=False):
        super(ConvLSTM, self).__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]

            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))

        self.cell_list = nn.ModuleList(cell_list)

        self.last_state = None
        self.stateful = stateful

    def forward(self, input_tensor, hidden_state=None):
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        b, _, _, h, w = input_tensor.size()

        if hidden_state is None:
            # Initialize hidden state if it's the first call or if not stateful
            if not self.stateful or self.last_state is None:
                hidden_state = self._init_hidden(batch_size=b, image_size=(h, w))
            else:
                # If the model is stateful and last_state is not None, use last_state
                hidden_state = self.last_state
        else:
            # If hidden_state was provided as an input, we use it directly
            self.last_state = hidden_state

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):
            # Fetching the hidden state for the current layer
            h, c = hidden_state[layer_idx]
            output_inner = []

            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](
                    input_tensor=cur_layer_input[:, t, :, :, :],
                    cur_state=[h, c]
                )
                output_inner.append(h)

            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append((h, c))  # Save the last state as a tuple for consistency

        if self.stateful:
            # Save the states of all layers before the truncation below, so that
            # the next call finds one (h, c) pair per layer.
            self.last_state = list(last_state_list)

        if not self.return_all_layers:
            layer_output_list = [layer_output_list[-1]]
            last_state_list = [last_state_list[-1]]

        return layer_output_list, last_state_list



    def _init_hidden(self, batch_size, image_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        if not isinstance(param, list):
            param = [param] * num_layers
        return param
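
The way the constructor wires stacked layers together can be summarised in a few lines of plain Python. This sketch mirrors `_extend_for_multilayer` and the `cur_input_dim` logic of the class above. The function `layer_channels` is a hypothetical helper that only reports the channel counts of each layer's fused gate convolution.

```python
def extend_for_multilayer(param, num_layers):
    """Repeat a single setting for every layer; lists are passed through unchanged."""
    return param if isinstance(param, list) else [param] * num_layers


def layer_channels(input_dim, hidden_dim, num_layers):
    """Return (in_channels, out_channels) of the fused gate convolution per layer."""
    hidden = extend_for_multilayer(hidden_dim, num_layers)
    if len(hidden) != num_layers:
        raise ValueError("Inconsistent list length.")
    plan = []
    for i, hid in enumerate(hidden):
        # The first layer sees the input frames, later layers see the previous layer's h.
        cur_input = input_dim if i == 0 else hidden[i - 1]
        # Each cell convolves [input, hidden] channels into four gates.
        plan.append((cur_input + hid, 4 * hid))
    return plan


print(layer_channels(1, 16, 2))           # same width in both layers
print(layer_channels(1, [16, 32, 8], 3))  # one width per layer
```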

In [6]:
show_doc(ConvLSTM)

---

### ConvLSTM

>      ConvLSTM (input_dim, hidden_dim, kernel_size, num_layers,
>                batch_first=False, bias=True, return_all_layers=False,
>                stateful=False)

Parameters:
    input_dim: Number of channels in input
    hidden_dim: Number of hidden channels (int, or list with one int per layer)
    kernel_size: Size of kernel in convolutions (tuple, or list with one tuple per layer)
    num_layers: Number of LSTM layers stacked on each other
    batch_first: Whether dimension 0 is the batch dimension
    bias: Bias or no bias in Convolution
    return_all_layers: Return the list of computations for all layers
    stateful: Reuse the last state of the previous call as the initial state
    Note: Will do same padding (for odd kernel sizes).

Input:
    A tensor of size B, T, C, H, W or T, B, C, H, W
Output:
    A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
        0 - layer_output_list is the list of output tensors, each of size B, T, hidden_dim, H, W
        1 - last_state_list is the list of last states
                each element of the list is a tuple (h, c) for hidden state and memory
Example:
    >> x = torch.rand((32, 10, 64, 128, 128))
    >> convlstm = ConvLSTM(64, 16, (3, 3), 1, True, True, False)
    >> _, last_states = convlstm(x)
    >> h = last_states[0][0]  # 0 for layer index, 0 for h index

In [33]:
x = torch.rand((32, 10, 1, 128, 128))


In [34]:
convlstm = ConvLSTM(1, 16, (3,3), 1, batch_first=True, bias=True, return_all_layers=False)


In [35]:
convlstm.input_dim

1

In [37]:
_, last_states = convlstm(x)
h = last_states[0][0]

In [38]:
h.shape

torch.Size([32, 16, 128, 128])
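
The shape printed above follows directly from the input layout. The sketch below does the same bookkeeping without running the model. `conv_lstm_output_shapes` is a hypothetical helper and assumes odd kernel sizes, so that height and width are preserved. Note that `forward` stacks the outputs along dimension 1, so the output sequence is batch-first even when `batch_first=False`.

```python
def conv_lstm_output_shapes(input_shape, hidden_dim, batch_first=True):
    """Shapes of the last layer's output sequence and of its final h (and c) state."""
    if batch_first:
        b, t, _, h, w = input_shape
    else:
        t, b, _, h, w = input_shape
    return (b, t, hidden_dim, h, w), (b, hidden_dim, h, w)


seq_shape, state_shape = conv_lstm_output_shapes((32, 10, 1, 128, 128), hidden_dim=16)
print(seq_shape)    # (32, 10, 16, 128, 128)
print(state_shape)  # (32, 16, 128, 128)
```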