# Structure of a ConvLSTM model

### Import Libraries

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

### Define the ConvLSTM Model

The ConvLSTMCell class implements a single ConvLSTM unit, which processes input frame sequences one time step at a time.

In [6]:
### 1 future step prediction
class ConvLSTMCell(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, kernel_size, bias): #input_dim = Number of channels in input, hidden_dim = Number of channels in hidden state, kernel_size = Size of the convolutional kernel, bias = If False, then the layer does not use bias weights b_ih and b_hh. Default
        super(ConvLSTMCell, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                              out_channels=4 * self.hidden_dim, # To compute all gates at once
                              padding=self.padding,
                              bias=self.bias)

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state # Unpack the current hidden and cell states
        combined = torch.cat([input_tensor, h_cur], dim=1) # Concatenate input with previous hidden state
        combined_conv = self.conv(combined) # Apply convolutional layer
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) # The output of the convolution is split into four equal parts, each corresponding to one of the LSTM gates.
        i = torch.sigmoid(cc_i) # Input gate
        f = torch.sigmoid(cc_f) # Forget gate
        o = torch.sigmoid(cc_o) # Output gate
        g = torch.tanh(cc_g) # Candidate cell state
        c_next = f * c_cur + i * g # Compute the next cell state
        h_next = o * torch.tanh(c_next) # Compute the next hidden state
        return h_next, c_next

    def init_hidden(self, batch_size, image_size): # Creates tensors filled with zeros for both h_0 and c_0
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device), 
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))


class ConvLSTM(nn.Module): # This class stacks multiple ConvLSTMCell layers to process entire sequences.
    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                 batch_first=False, bias=True, return_all_layers=False): # num_layers = Number of layers in the network
        super(ConvLSTM, self).__init__()
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) # Allows specifying different kernel sizes and hidden dimensions per layer.
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers
        cell_list = []
        for i in range(0, self.num_layers): #Creates a list of ConvLSTMCell layers. The first layer receives input_dim, while subsequent layers receive the previous layer’s hidden_dim.
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))
        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        if not self.batch_first: #Ensures the input shape follows (batch, seq, channels, height, width).
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
        b, _, _, h, w = input_tensor.size()
        if hidden_state is None:
            hidden_state = self._init_hidden(batch_size=b, image_size=(h, w)) #If no initial state is given, it initializes zeros.
        layer_output_list = []
        last_state_list = []
        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor
        for layer_idx in range(self.num_layers):
            h, c = hidden_state[layer_idx] # Get initial hidden and cell states for the current layer
            output_inner = []
            for t in range(seq_len):
                h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :], cur_state=[h, c])
                output_inner.append(h)
            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output
            layer_output_list.append(layer_output)
            last_state_list.append([h, c])
        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]
        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size): #Calls init_hidden for each layer, ensuring they are properly initialized.
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states

    @staticmethod
    def _extend_for_multilayer(param, num_layers): #Ensures that the kernel size and hidden dimensions are lists of length num_layers.
        if not isinstance(param, list):
            param = [param] * num_layers
        return param

In [None]:
### 10 future step prediction
class ConvLSTMCell(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, kernel_size, bias): #input_dim = Number of channels in input, hidden_dim = Number of channels in hidden state, kernel_size = Size of the convolutional kernel, bias = If False, then the layer does not use bias weights b_ih and b_hh. Default
        super(ConvLSTMCell, self).__init__()
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        self.conv = nn.Conv2d(
            in_channels=self.input_dim + self.hidden_dim,
            out_channels=4 * self.hidden_dim, # To compute all gates at once
            padding=self.padding,
            kernel_size=self.kernel_size,
            bias=self.bias
        )

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state # Unpack the current hidden and cell states

        combined = torch.cat([input_tensor, h_cur], dim=1) # Concatenate input with previous hidden state
        combined_conv = self.conv(combined) # Apply convolutional layer

        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) # The output of the convolution is split into four equal parts, each corresponding to one of the LSTM gates.

        i = torch.sigmoid(cc_i) # Input gate
        f = torch.sigmoid(cc_f) # Forget gate
        o = torch.sigmoid(cc_o) # Output gate
        g = torch.tanh(cc_g) # Candidate cell state
        c_next = f * c_cur + i * g # Compute the next cell state
        h_next = o * torch.tanh(c_next) # Compute the next hidden state
        
        return h_next, c_next

    def init_hidden(self, batch_size, image_size): # Creates tensors filled with zeros for both h_0 and c_0
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device), 
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))


class ConvLSTM(nn.Module): # This class stacks multiple ConvLSTMCell layers to process entire sequences.
    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers, batch_first=False, bias=True, return_all_layers=False): # num_layers = Number of layers in the network
        super(ConvLSTM, self).__init__()

        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) # Allows specifying different kernel sizes and hidden dimensions per layer.
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers
        cell_list = []
        for i in range(0, self.num_layers): #Creates a list of ConvLSTMCell layers. The first layer receives input_dim, while subsequent layers receive the previous layer’s hidden_dim.
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                          hidden_dim=self.hidden_dim[i],
                                          kernel_size=self.kernel_size[i],
                                          bias=self.bias))
            
        self.cell_list = nn.ModuleList(cell_list)

        self.conv_out = nn.Conv2d(in_channels=self.hidden_dim[-1],  # Hidden dim of the last layer
                                    out_channels=input_dim,          # Output channels (e.g., 1 for grayscale)
                                    kernel_size=1,                   # 1x1 convolution
                                    padding=0,
                                    bias=bias)

    def forward(self, input_tensor, future_frames=10):
      #print("Input tensor shape:", input_tensor.shape)  # Debugging statement
      if not self.batch_first:
          input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
      b, _, _, h, w = input_tensor.size()
      #print("Batch size:", b, "Time steps:", _, "Height:", h, "Width:", w)

      # Initialize hidden states
      hidden_state = self._init_hidden(batch_size=b, image_size=(h, w))

      # Process input sequence
      layer_output_list = []
      last_state_list = []
      cur_layer_input = input_tensor

      for layer_idx in range(self.num_layers):
          h, c = hidden_state[layer_idx]
          output_inner = []
          for t in range(input_tensor.size(1)):
              h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :], cur_state=[h, c])
              output_inner.append(h)
          layer_output = torch.stack(output_inner, dim=1)
          cur_layer_input = layer_output
          layer_output_list.append(layer_output)
          last_state_list.append([h, c])

      # Predict future frames
      outputs = []
      last_output = layer_output[:, -1]
      #print("Last output shape (before prediction):", last_output.shape)
      for t in range(future_frames):
          last_output, c = self.cell_list[-1](input_tensor=last_output, cur_state=[last_output, c])
          #print("Last output shape (during prediction):", last_output.shape)
          outputs.append(self.conv_out(last_output))
      outputs = torch.stack(outputs, dim=1)
      #print("Outputs shape:", outputs.shape)

      return outputs

    def _init_hidden(self, batch_size, image_size): #Calls init_hidden for each layer, ensuring they are properly initialized.
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
        return init_states

    @staticmethod
    def _extend_for_multilayer(param, num_layers): #Ensures that the kernel size and hidden dimensions are lists of length num_layers.
        if not isinstance(param, list):
            param = [param] * num_layers
        return param

In [None]:
## Including Self-Attention

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=8):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # Learnable scaling factor

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        N = height * width  # Number of spatial locations

        # Map input to Query, Key, and Value
        query = self.query_conv(x).view(batch_size, -1, N).permute(0, 2, 1)  # (B, N, C')
        key = self.key_conv(x).view(batch_size, -1, N)                      # (B, C', N)
        value = self.value_conv(x).view(batch_size, -1, N)                  # (B, C, N)

        # Compute similarity scores
        energy = torch.bmm(query, key)  # (B, N, N)
        attention = F.softmax(energy, dim=-1)  # Normalize along columns

        # Weighted aggregation
        out = torch.bmm(value, attention.permute(0, 2, 1))  # (B, C, N)
        out = out.view(batch_size, channels, height, width)  # Reshape back to spatial dimensions

        # Add residual connection and scale
        out = self.gamma * out + x
        return out

class ConvLSTMCellWithAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True):
        super(ConvLSTMCellWithAttention, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.bias = bias

        # Convolutional layers for gates
        self.conv = nn.Conv2d(
            in_channels=input_dim + hidden_dim,
            out_channels=4 * hidden_dim,
            kernel_size=kernel_size,
            padding=kernel_size[0] // 2,
            bias=bias
        )

        # Self-attention modules
        self.input_attention = SelfAttention(input_dim)
        self.hidden_attention = SelfAttention(hidden_dim)

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state

        # Apply self-attention to input and hidden state
        input_tensor = self.input_attention(input_tensor)
        h_cur = self.hidden_attention(h_cur)

        # Concatenate input and hidden state
        combined = torch.cat([input_tensor, h_cur], dim=1)

        # Compute gates
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)

        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        # Update cell and hidden states
        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)

        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return (
            torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
            torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)
        )
    
class ConvLSTMWithAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers, batch_first=True):
        super(ConvLSTMWithAttention, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first

        # Stack multiple ConvLSTM cells with attention
        cell_list = []
        for i in range(num_layers):
            cur_input_dim = input_dim if i == 0 else hidden_dim
            cell_list.append(ConvLSTMCellWithAttention(cur_input_dim, hidden_dim, kernel_size))
        self.cell_list = nn.ModuleList(cell_list)

        # Output layer
        self.conv_out = nn.Conv2d(hidden_dim, input_dim, kernel_size=1)

    def forward(self, input_tensor, future_frames=0):
        if self.batch_first:
            input_tensor = input_tensor.permute(0, 1, 4, 2, 3)  # (B, T, H, W, C) -> (B, T, C, H, W)

        b, seq_len, _, h, w = input_tensor.size()

        # Initialize hidden states
        hidden_state = [cell.init_hidden(b, (h, w)) for cell in self.cell_list]

        outputs = []
        for t in range(seq_len):
            cur_layer_input = input_tensor[:, t]
            for layer_idx in range(self.num_layers):
                h, c = self.cell_list[layer_idx](cur_layer_input, hidden_state[layer_idx])
                hidden_state[layer_idx] = (h, c)
                cur_layer_input = h
            outputs.append(h)

        # Autoregressive prediction for future frames
        for t in range(future_frames):
            cur_layer_input = outputs[-1]
            for layer_idx in range(self.num_layers):
                h, c = self.cell_list[layer_idx](cur_layer_input, hidden_state[layer_idx])
                hidden_state[layer_idx] = (h, c)
                cur_layer_input = h
            outputs.append(h)

        outputs = torch.stack(outputs, dim=1)  # (B, T, C, H, W)
        outputs = self.conv_out(outputs.view(-1, self.hidden_dim, h, w)).view(b, -1, 1, h, w)
        return outputs

## Prepare the Dataset

In [8]:
# Load dataset (example: MNIST-like sequences)
!wget "https://github.com/felipeart25/Coastal_Vision/raw/main/data/Data/mnist_test_seq.npy" -O mnist_test_seq.npy
data = np.load("mnist_test_seq.npy")  # Shape: (num_sequences, time_steps, channels, height, width)
data = torch.tensor(data, dtype=torch.float32) / 255.0  # Normalize to [0, 1]
data = data.unsqueeze(2)
data = data.permute(1, 0, 2, 3, 4)  # Swap axes 

#print shape
print(data.shape)
# Split into train and test
train_data, test_data = data[:8000], data[8000:]

# Input: first T-10 frames, Target: next 10 frames
train_dataset = TensorDataset(train_data[:, :10], train_data[:, -10:])  # Input: T-10, Target: 10
print(train_data.shape)
test_dataset = TensorDataset(test_data[:, :10], test_data[:, -10:])
print(test_data.shape)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

for inputs, targets in train_loader:
    print("Inputs shape:", inputs.shape)  # Should be (B, T-10, 1, 64, 64)
    print("Targets shape:", targets.shape)  # Should be (B, 10, 1, 64, 64)
    break

for inputs, targets in test_loader:
    print("Inputs shape:", inputs.shape)  # Should be (B, T-10, 1, 64, 64)
    print("Targets shape:", targets.shape)  # Should be (B, 10, 1, 64, 64)
    break




'wget' is not recognized as an internal or external command,
operable program or batch file.


FileNotFoundError: [Errno 2] No such file or directory: 'mnist_test_seq.npy'

### Set Up Training

This code initializes and trains a ConvLSTM model using PyTorch

In [None]:
# Initialize model, loss, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvLSTM(input_dim=1, hidden_dim=64, kernel_size=(3, 3), num_layers=2, batch_first=True).to(device)
criterion = nn.MSELoss() # MSE is commonly used for regression tasks, which is suitable for predicting pixel intensities in images
optimizer = optim.Adam(model.parameters(), lr=0.001) # Uses the Adam optimizer with a learning rate (lr) of 0.001 to update model weights
print("Input tensor shape:", inputs.shape)
# Training loop The model is trained for 10 epochs. Each epoch means the model sees the entire training dataset once.
num_epochs = 10
for epoch in range(num_epochs):
    model.train() #Sets the model to training mode (enables dropout, batch normalization updates, etc.)
    total_loss = 0 #Initializes a variable to accumulate the total loss for the epoch
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad() #Resets the gradients before backpropagation. By default, gradients accumulate in PyTorch, so this step prevents incorrect updates.
        outputs = model(inputs, future_frames=10)  # Predict 10 frames
        loss = criterion(outputs, targets)  # Compute loss across all 10 frames
        loss.backward() #Computes gradients of the loss with respect to model parameters
        optimizer.step() #Updates the model’s parameters using the gradients. This step is where the actual learning happens.
        total_loss += loss.item() #Accumulates the loss for the current batch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}")

### Validation


In [None]:
from skimage.metrics import structural_similarity as ssim

def validate(model, loader, criterion, device, future_frames=10):
    model.eval()
    total_loss, total_mae, total_rmse, total_ssim = 0, 0, 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs, future_frames=future_frames)

            loss = criterion(outputs, targets)
            total_loss += loss.item()

            mae = torch.abs(outputs - targets).mean().item()
            rmse = torch.sqrt(((outputs - targets) ** 2).mean()).item()
            
            # Convert to numpy for SSIM
            output_np = outputs.cpu().numpy()
            target_np = targets.cpu().numpy()
            ssim_score = ssim(output_np, target_np, data_range=1.0, multichannel=True)
            
            total_mae += mae
            total_rmse += rmse
            total_ssim += ssim_score

    avg_loss = total_loss / len(loader)
    avg_mae = total_mae / len(loader)
    avg_rmse = total_rmse / len(loader)
    avg_ssim = total_ssim / len(loader)

    print(f"Validation Loss: {avg_loss:.4f}, MAE: {avg_mae:.4f}, RMSE: {avg_rmse:.4f}, SSIM: {avg_ssim:.4f}")
    return avg_loss, avg_mae, avg_rmse, avg_ssim

# Validate after each epoch
validate(model, test_loader, criterion, device, future_frames=10)

### Testing and Visualization

In [None]:
# Test the model
model.eval()
with torch.no_grad():
    inputs, targets = next(iter(test_loader))
    inputs, targets = inputs.to(device), targets.to(device)
    outputs = model(inputs, future_frames=10).cpu().numpy()

# Visualize results
fig, axes = plt.subplots(3, 10, figsize=(20, 6))
for i in range(10):
    axes[0, i].imshow(targets[0, i].cpu().squeeze(), cmap="gray")
    axes[0, i].set_title("Ground Truth")
    axes[0, i].axis("off")
    axes[1, i].imshow(outputs[0, i].squeeze(), cmap="gray")
    axes[1, i].set_title("Prediction")
    axes[1, i].axis("off")
plt.show()