## Name: Raffaello Baluyot
## Course: DT8058

<center><h1 style="font-size:40px;">Regression using Recurrent Neural Networks</h1></center>

Welcome to this lab session! Today, we will be diving into the fascinating world of Recurrent Neural Networks (RNNs) and exploring their use in regression tasks.

By the end of this lab session, you will be able to:

* Understand the basic concepts behind Recurrent Neural Networks and how they work.
* Implement different architectures to include recurrent neural networks.
* Apply these RNN architectures to solve regression problems.
* Evaluate the performance of your models and understand their strengths and limitations.

In [None]:
# select a gpu
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# autoreload imports
%load_ext autoreload
%autoreload 2

In [None]:
import random
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import numpy as np
import cv2
from collections import OrderedDict
import copy
from sklearn.metrics import mean_squared_error

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Make `device` the default placement for tensors/modules created from here on.
torch.set_default_device(device)
# Mini-batch size for the first set of data loaders (Tasks 1-3).
batch_size = 128

## dataset 

In [None]:
# Convert angles to 2D vectors
def angle_to_vector_np(angles_rad):
    """Return the ``(cos, sin)`` components of angle(s) given in radians.

    The input is passed straight to ``np.cos`` / ``np.sin``, so scalars and
    NumPy arrays are both accepted. The result is a 2-tuple ``(x, y)``.
    """
    x_component = np.cos(angles_rad)
    y_component = np.sin(angles_rad)
    return x_component, y_component


# Convert degrees to radians
def deg_to_rad(deg):
    """Convert an angle from degrees to radians.

    Only plain arithmetic is used, so any operand that supports ``*`` and
    ``/`` with floats (Python numbers, NumPy arrays, torch tensors) works.
    """
    radians = deg * np.pi / 180
    return radians


# Convert radians to degrees
def rad_to_deg(rad):
    """Convert an angle from radians to degrees.

    Only plain arithmetic is used, so any operand that supports ``*`` and
    ``/`` with floats (Python numbers, NumPy arrays, torch tensors) works.
    """
    degrees = rad * 180 / np.pi
    return degrees


# Convert angles to 2D vectors
def angle_to_vector(angles):
    """Return the ``(cos, sin)`` components of torch angle(s) given in degrees.

    The angles are first converted to radians with ``deg_to_rad`` and then
    fed to ``torch.cos`` / ``torch.sin``, so ``angles`` must be a tensor.
    """
    radians = deg_to_rad(angles)
    x_component = torch.cos(radians)
    y_component = torch.sin(radians)
    return x_component, y_component


# Convert 2D vectors to angles
def vector_to_angle(vectors):
    """Return the angle in degrees of a torch direction vector.

    ``vectors[0]`` is read as the x (cosine) component and ``vectors[1]`` as
    the y (sine) component; the result comes from ``torch.atan2`` converted
    to degrees.
    """
    x_component = vectors[0]
    y_component = vectors[1]
    angle_rad = torch.atan2(y_component, x_component)
    return rad_to_deg(angle_rad)


# Convert 2D vectors to angles
def vector_to_angle_np(vectors):
    """Return the angle in degrees of a NumPy/array-like direction vector.

    ``vectors[0]`` is read as the x (cosine) component and ``vectors[1]`` as
    the y (sine) component; the result comes from ``np.arctan2`` converted
    to degrees.
    """
    x_component = vectors[0]
    y_component = vectors[1]
    angle_rad = np.arctan2(y_component, x_component)
    return rad_to_deg(angle_rad)

In [None]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML


def animate_datapoint(x, y):
    """
    Build a looping animation of the frames of one sample.

    Parameters:
    x (iterable of array-like or torch.Tensor): The frames to be animated. Tensor frames are
        moved to the CPU, permuted with (1, 2, 0) and multiplied by 255 before display;
        other frames are displayed as given. Every frame is cast to uint8.
    y (array-like): The direction vector; only used to print the angle in the title.

    Returns:
    matplotlib.animation.ArtistAnimation: The resulting animation object.
    """
    fig = plt.figure(figsize=(5, 5))
    plt.axis("off")
    plt.title("Angle: {:.1f}".format(vector_to_angle_np(y)))
    ims = []
    for frame in x:
        if isinstance(frame, torch.Tensor):
            # Tensor.numpy() raises for CUDA tensors and for tensors that track
            # gradients, so detach and copy to host memory first.
            frame = frame.detach().cpu().permute(1, 2, 0).numpy() * 255
        im = plt.imshow(frame.astype(np.uint8))
        ims.append([im])
    ani = animation.ArtistAnimation(fig, ims, interval=100, repeat_delay=1000)
    return ani


def animate_datapoint_with_direction(x, y, center, pred=None, pred_color="red"):
    """
    This function creates an animation of a data point with a direction arrow. Use this function to visualize the data points in the training, validation, and test sets.
    You can also use this function to visualize the predictions of your model, by providing the predicted direction vector as the `pred` parameter.
    In case the object color is interfering with the visibility of the direction arrow, you can change the color of the predicted direction arrow using the `pred_color` parameter.

    Parameters:
    x (array-like or torch.Tensor): The frames to be animated. Tensor frames are moved to the CPU, permuted with (1, 2, 0) and multiplied by 255 before display.
    y (array-like): The direction vector for the arrow in the animation.
    center (tuple): The starting point (center) of the arrow.
    pred (array-like, optional): The predicted direction vector for the arrow in the animation. If provided, an additional arrow showing the predicted direction will be drawn. Defaults to None.
    pred_color (str, optional): The color of the predicted direction arrow. Defaults to "red".

    Returns:
    matplotlib.animation.ArtistAnimation: The resulting animation object.
    """

    fig = plt.figure(figsize=(5, 5))
    plt.axis("off")
    plt.title("Angle: {:.1f}".format(vector_to_angle_np(y)))

    ims = []
    for frame in x:
        if isinstance(frame, torch.Tensor):
            # Tensor.numpy() raises for CUDA tensors and for tensors that track
            # gradients, so detach and copy to host memory first.
            frame = frame.detach().cpu().permute(1, 2, 0).numpy() * 255
        im = plt.imshow(frame.astype(np.uint8))
        # The y component is negated because image rows grow downwards.
        arrow = plt.arrow(
            center[0],
            center[1],
            y[0] * 4,
            y[1] * -4,
            head_width=3,
            head_length=5,
            fc="black",
            ec="black",
        )
        if pred is not None:
            arrow_pred = plt.arrow(
                center[0],
                center[1],
                pred[0] * 4,
                pred[1] * -4,
                head_width=3,
                head_length=5,
                fc=pred_color,
                ec=pred_color,
            )
            ims.append([im, arrow, arrow_pred])
        else:
            ims.append([im, arrow])

    ani = animation.ArtistAnimation(fig, ims, interval=100, repeat_delay=1000)
    return ani

In [None]:
def generate_dataset(
    num_samples,
    canvas_size=(64, 64),
    shapes=None,
    colors=None,
):
    """
    Draws the random parameters (not the frames) of `num_samples` moving-shape samples.

    Parameters:
    num_samples (int): The number of samples to generate.
    canvas_size (tuple): The size of the canvas; start positions are sampled uniformly in
        [0, canvas_size - 1] per axis. Defaults to (64, 64).
    shapes (list, optional): The shape names to choose from. Defaults to ["circle", "triangle", "square"].
    colors (dict, optional): Mapping from color name to color tuple. Defaults to six named
        colors (blue, yellow, red, gray, green, purple).

    Returns:
    tuple: Six lists, each of length num_samples:
        - labels_list: (cos, sin) tuples of the movement angle.
        - centers_list: start positions as float arrays of length 2.
        - shapes_list: the chosen shape names.
        - angles_deg_list: the movement angles in degrees, drawn from [0, 360).
        - sizes_list: integer sizes drawn from [4, 10).
        - colors_list: the chosen color tuples.
    """
    # Defaults are built per call so that no mutable object is shared between calls.
    if shapes is None:
        shapes = ["circle", "triangle", "square"]
    if colors is None:
        colors = {
            "blue": (255, 0, 0),
            "yellow": (0, 255, 255),
            "red": (0, 0, 255),
            "gray": (128, 128, 128),
            "green": (0, 255, 0),
            "purple": (255, 0, 255),
        }

    color_names = list(colors.keys())
    max_position = np.array(canvas_size) - 1

    labels_list = []
    centers_list = []
    shapes_list = []
    angles_deg_list = []
    sizes_list = []
    colors_list = []

    for _ in range(num_samples):
        # The order of the random draws is kept fixed so that seeded runs
        # produce the same samples as before.
        shape = np.random.choice(shapes)
        position = np.random.rand(2) * max_position
        size = np.random.randint(4, 10)
        angle_deg = np.random.uniform(0, 360)
        angle_rad = np.radians(angle_deg)
        color = colors[np.random.choice(color_names)]

        labels_list.append(angle_to_vector_np(angle_rad))
        centers_list.append(position)
        shapes_list.append(shape)
        angles_deg_list.append(angle_deg)
        sizes_list.append(size)
        colors_list.append(color)

    return (
        labels_list,
        centers_list,
        shapes_list,
        angles_deg_list,
        sizes_list,
        colors_list,
    )


def generate_frames(
    num_frames,
    center,
    angle_deg,
    canvas_size,
    shape,
    size,
    color,
    step_size=0.5,
):
    """
    Renders a shape moving in a straight line, one frame per step.

    Parameters:
    num_frames (int): The number of frames to render.
    center (numpy.ndarray): The start position; it is copied, so the caller's array is not modified.
    angle_deg (float): The movement direction in degrees.
    canvas_size (tuple): The first two dimensions of every frame.
    shape (str): "circle", "triangle" or "square"; any other value leaves the canvas blank.
    size (int): The radius / edge length used when drawing the shape.
    color (tuple): The color passed to the OpenCV drawing calls.
    step_size (float): The distance moved between consecutive frames. Defaults to 0.5.

    Returns:
    numpy.ndarray: Float array of shape (num_frames, canvas_size[0], canvas_size[1], 3)
        holding the rendered frames on a white (255) background.
    """
    rows, cols = canvas_size[0], canvas_size[1]
    theta = np.radians(angle_deg)
    # The sine is negated, presumably so that positive angles move "up" in
    # image coordinates (matches the negated y in the arrow plotting helper).
    step = np.array([np.cos(theta), -np.sin(theta)]) * step_size

    position = center.copy()
    frames = np.zeros((num_frames, rows, cols, 3))

    for frame_idx in range(num_frames):
        canvas = np.full((rows, cols, 3), 255, dtype=np.uint8)
        cx, cy = position[0], position[1]

        if shape == "circle":
            cv2.circle(canvas, tuple(position.astype(int)), size, color, -1)
        elif shape == "triangle":
            corners = np.array(
                [
                    [cx, cy],
                    [cx + size, cy],
                    [cx + size // 2, cy - size],
                ]
            )
            corners = corners.reshape((-1, 1, 2)).astype(int)
            cv2.fillPoly(canvas, [corners], color=color)
        elif shape == "square":
            half = size // 2
            top_left = (int(cx - half), int(cy - half))
            bottom_right = (int(cx + half), int(cy + half))
            cv2.rectangle(canvas, top_left, bottom_right, color, -1)

        frames[frame_idx] = canvas

        # Advance along the movement direction for the next frame.
        position += step

    return frames

In [None]:
class MovingShapesDataset(Dataset):
    """Dataset of short clips showing one shape moving in a straight line.

    The per-sample parameters (label, start position, shape, angle, size,
    color) are drawn once in ``__init__``; the frames themselves are rendered
    on demand in ``__getitem__``.
    """

    def __init__(self, num_samples, num_frames, transform=None, canvas_size=(64, 64)):
        """
        Parameters:
        num_samples (int): Number of samples in the dataset.
        num_frames (int): Number of frames rendered per sample.
        transform (callable, optional): Called with ``(frames, label)`` and
            expected to return a ``(frames, label)`` pair.
        canvas_size (tuple): Size of the canvas. Defaults to (64, 64).
        """
        self.num_samples = num_samples
        self.num_frames = num_frames
        self.canvas_size = canvas_size
        # Forward canvas_size so start positions are sampled for the same
        # canvas the frames are rendered on (previously always 64x64).
        (
            self.labels,
            self.centers,
            self.shapes,
            self.angles_deg,
            self.sizes,
            self.colors,
        ) = generate_dataset(num_samples, canvas_size=canvas_size)

        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Render sample ``idx``.

        Returns:
        tuple: ``(frames, label, center)`` where ``frames`` is a float tensor
        scaled to [0, 1] and permuted to (num_frames, 3, H, W), ``label`` is a
        float32 array holding the (cos, sin) direction and ``center`` is the
        start position.
        """
        label, center, shape, angle_deg, size, color = (
            self.labels[idx],
            self.centers[idx],
            self.shapes[idx],
            self.angles_deg[idx],
            self.sizes[idx],
            self.colors[idx],
        )
        frames = generate_frames(
            self.num_frames, center, angle_deg, self.canvas_size, shape, size, color
        )
        frames = torch.from_numpy(frames).float() / 255.0
        if self.transform:
            frames, label = self.transform((frames, label))
        return (
            frames.permute(0, 3, 1, 2),
            np.array(label, dtype=np.float32),
            center,
        )  # center is added for plotting purposes, you are not allowed to use it in the network

In [None]:
# Three independently sampled splits of 20-frame clips. Only the sample
# parameters are drawn here; frames are rendered on access in __getitem__.
train_dataset = MovingShapesDataset(num_samples=100000, num_frames=20)
val_dataset = MovingShapesDataset(num_samples=1000, num_frames=20)
test_dataset = MovingShapesDataset(num_samples=1000, num_frames=20)

In [None]:
index = random.randint(0, len(train_dataset) - 1)
sample_data, sample_label, sample_initial_position = train_dataset[index]

In [None]:
ani = animate_datapoint_with_direction(
    sample_data, sample_label, sample_initial_position, pred=(0.5, 0)
)
HTML(ani.to_jshtml())

## TASK 1

#TODO Define a CNN model to predict the angle of movement. You are encouraged to use the models from torchvision. You can also use the models from the previous lab. Do not forget to change the output layer dimension. Then, using the proper training scripts from the previous labs, train the model. How is the model doing?

## Answer

The model is still able to learn something from the training data. This makes sense since even before the popularity of recurrent models, there has been some success in using traditional modeling techniques for data with time dependence or order.

In [None]:
def train_epoch(
        optimizer: torch.optim.Optimizer,
        loss: torch.nn.Module,
        model: torch.nn.Module,
        train_loader: DataLoader
):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to the global ``device``, and the optimizer takes one
    step per batch. The third element of every batch is ignored.

    Returns:
        The sample-weighted mean of the per-batch loss values.
    """
    model.train(True)
    running_loss = 0
    samples_seen = 0

    for batch_inputs, batch_labels, _ in train_loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = loss(model(batch_inputs), batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean.
        batch_count = len(batch_inputs)
        running_loss += batch_loss.item() * batch_count
        samples_seen += batch_count

    return running_loss / samples_seen

def validate_epoch(
        loss: torch.nn.Module,
        model: torch.nn.Module,
        val_loader: DataLoader
):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    The model is switched to eval mode and each batch is moved to the global
    ``device``. The third element of every batch is ignored.

    Returns:
        The sample-weighted mean of the per-batch loss values.
    """
    model.eval()
    running_loss = 0
    samples_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels, _ in val_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            batch_loss = loss(model(batch_inputs), batch_labels)

            # Weight by batch size so a smaller final batch does not skew the mean.
            batch_count = len(batch_inputs)
            running_loss += batch_loss.item() * batch_count
            samples_seen += batch_count

    return running_loss / samples_seen


def training_loop(num_epoch, model, optimizer, loss, train_loader, val_loader):
    """Train for ``num_epoch`` epochs and keep the best model seen on validation.

    After every epoch the training and validation losses are recorded and
    printed. Whenever the validation loss strictly improves, a deep copy of
    the model is stored.

    Returns:
        tuple: ``(best_model, train_losses, val_losses)``; ``best_model`` is
        ``None`` if the validation loss never dropped below infinity.
    """
    train_losses = []
    val_losses = []
    best_model = None
    best_val_loss = np.inf

    for epoch in range(num_epoch):
        train_loss = train_epoch(optimizer, loss, model, train_loader)
        val_loss = validate_epoch(loss, model, val_loader)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        improved = val_loss < best_val_loss
        if improved:
            # Snapshot the weights; `model` keeps changing in later epochs.
            best_model = copy.deepcopy(model)
            best_val_loss = val_loss

        print(f"epoch {epoch + 1}: loss: {train_loss:0.4f} val loss: {val_loss:0.4f}")

    return best_model, train_losses, val_losses


def predict(model: torch.nn.Module, test_loader: DataLoader):
    """Run ``model`` over ``test_loader`` and collect targets and predictions.

    The model is evaluated in eval mode without gradient tracking; its
    previous training/eval mode is restored afterwards. Each batch is moved
    to the global ``device``. The third element of every batch is ignored.

    Returns:
        tuple: ``(true, pred)`` as NumPy arrays, concatenated over all batches.
    """
    # Inference must not depend on whatever mode the caller left the model in
    # (matters for layers such as dropout or batch norm).
    was_training = model.training
    model.eval()

    true = []
    pred = []
    try:
        with torch.no_grad():
            for inputs, labels, _ in test_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                true.append(labels)
                pred.append(model(inputs))
    finally:
        model.train(was_training)

    return torch.cat(true).cpu().numpy(), torch.cat(pred).cpu().numpy()

In [None]:
# Only the training split is shuffled. The shuffling generator is created on
# `device` — presumably required because torch.set_default_device(device) was
# called earlier; confirm before removing.
train_loader = DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True,
    generator=torch.Generator(device=device)
)
# Validation and test data are iterated in a fixed order.
valid_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [None]:
class CNNModel(torch.nn.Module):
    """Frame-wise CNN whose per-frame features are concatenated and regressed.

    Every frame goes through four conv + LeakyReLU + 2x2 max-pool stages; the
    features of all frames of a clip are flattened together and mapped to
    ``num_outputs`` values by a single linear layer.

    Parameters:
        num_channels: Channels per input frame.
        num_outputs: Size of the regression output.
        n_frames: Number of frames per clip (fixes the linear layer's input size).
        input_shape: ``(height, width)`` of each frame.
    """

    def __init__(self, num_channels=3, num_outputs=2, n_frames=20, input_shape=(64, 64)):
        super().__init__()

        n_conv_layers = 4
        # Each stage halves height and width (floor division), so four stages
        # divide each spatial dimension by 16.
        max_pool_divisor = 2 ** n_conv_layers

        self.conv_layer1 = self._conv_layer_set(num_channels, 16)
        self.conv_layer2 = self._conv_layer_set(16, 32)
        self.conv_layer3 = self._conv_layer_set(32, 64)
        self.conv_layer4 = self._conv_layer_set(64, 128)
        # Each dimension must be floor-divided on its own; without the
        # parentheses `a * H // d * W // d` is evaluated left to right and is
        # wrong whenever H or W is not a multiple of the divisor.
        flat_shape = (
            128 * n_frames
            * (input_shape[0] // max_pool_divisor)
            * (input_shape[1] // max_pool_divisor)
        )
        self.fc1 = nn.Linear(flat_shape, num_outputs)

    def _conv_layer_set(self, in_c, out_c):
        """Return one conv(3x3, padding 1) -> LeakyReLU -> max-pool(2) stage."""
        conv_layer = nn.Sequential(OrderedDict([
            ('conv',nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)),
            ('leakyrelu',nn.LeakyReLU()),
            ('maxpool',nn.MaxPool2d(2)),
        ]))
        return conv_layer
    
    def forward(self, x):
        """Map clips of shape (batch, frames, C, H, W) to (batch, num_outputs)."""
        batch_size, num_frames, C, H, W = x.shape
        # Fold the frame axis into the batch axis; reshape (unlike view) also
        # accepts inputs whose first two axes cannot be merged without a copy.
        x = x.reshape(batch_size * num_frames, C, H, W)

        out = self.conv_layer1(x)
        out = self.conv_layer2(out)
        out = self.conv_layer3(out)
        out = self.conv_layer4(out)

        # Gather the features of all frames of a clip into one vector.
        out = out.reshape((batch_size, -1))

        out = self.fc1(out)
        return out


In [None]:
model = CNNModel()
critereon = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

In [None]:
torch.cuda.empty_cache()
best_model, train_losses, val_losses = training_loop(20, model, optimizer, critereon, train_loader, valid_loader)

In [None]:
plt.plot(train_losses, label="train")
plt.plot(val_losses, label="val")
plt.title("MSE Loss")
plt.legend()
plt.show()

In [None]:
test_true, test_pred = predict(best_model, test_loader)

In [None]:
mean_squared_error(test_true, test_pred)

## TASK 2

#TODO Define a recurrent neural network model to predict the angle of movement. Then, using the proper training scripts from the previous labs, train the model. How is the model doing?

You can have a look at the following skeleton of a basic recurrent model in PyTorch. Keep in mind that this is just a basic starting point. Have a look at the [PyTorch LSTM documentation](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html) and consider adding more layers, dropout, bidirectionality, and so on.

## Answer

The model seems to struggle to learn. I tried different configurations, even though the configuration below is very simple. All of these configurations fail. It seems that learning the spatial characteristics of the images is more important than learning the time dependencies between the frames.

In [None]:
import torch
from torch import nn


class LSTMModel(nn.Module):
    """Bidirectional LSTM applied to flattened frames.

    Every frame is flattened to a vector of ``C * H * W`` values, the sequence
    is processed by a stacked bidirectional LSTM, and the final state of each
    direction is concatenated and mapped to ``output_size`` values.

    Parameters:
        num_channels: Channels per input frame.
        hidden_size: Hidden size of each LSTM direction.
        output_size: Size of the regression output.
        num_layers: Number of stacked LSTM layers.
        input_shape: ``(height, width)`` of each frame.
    """

    def __init__(self, num_channels=3, hidden_size=128, output_size=2, num_layers=3, input_shape=(64, 64)):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        input_size = num_channels * input_shape[0] * input_shape[1]
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        # Forward and backward states are concatenated, hence 2 * hidden_size.
        self.fc = nn.Linear(2 * hidden_size, output_size)

    def forward(self, x):
        """Map clips of shape (batch, frames, C, H, W) to (batch, output_size)."""
        batch_size, num_frames, C, H, W = x.shape
        x = x.reshape(batch_size, num_frames, C * H * W)

        out, _ = self.lstm(x)
        # In a bidirectional LSTM the forward direction has seen the whole
        # sequence at the last time step, but the backward direction has only
        # seen it at the FIRST time step; out[:, -1] alone would use a backward
        # state that has processed a single frame.
        forward_last = out[:, -1, : self.hidden_size]
        backward_last = out[:, 0, self.hidden_size:]
        out = self.fc(torch.cat((forward_last, backward_last), dim=1))

        return out


In [None]:
model = LSTMModel(hidden_size=256)
critereon = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

In [None]:
torch.cuda.empty_cache()
best_model, train_losses, val_losses = training_loop(20, model, optimizer, critereon, train_loader, valid_loader)

In [None]:
test_true, test_pred = predict(best_model, test_loader)

In [None]:
plt.plot(train_losses, label="train")
plt.plot(val_losses, label="val")
plt.title("MSE Loss")
plt.legend()
plt.show()

In [None]:
mean_squared_error(test_true, test_pred)

## TASK 3

#TODO Define a CNN + LSTM model to predict the angle of movement. You should do so by passing the features learned by the CNN model you have defined to an LSTM module. In other words, you can insert an LSTM layer between the conv layers and the fully connected layers of the CNN model you have defined. This should learn the temporal features between the frames. Does this model achieve better results?

## Answer

The model is able to achieve better performance than the first two models. Being able to extract spatial features and being able to correlate the time dependency between these spatial features allows the network to be able to learn more about the input features.

In [None]:
class CNNLSTMModel(torch.nn.Module):
    """Frame-wise CNN feature extractor followed by a bidirectional LSTM.

    Every frame goes through four conv + LeakyReLU + 2x2 max-pool stages; the
    flattened per-frame features form a sequence for the LSTM, and the final
    state of each direction is concatenated and mapped to ``num_outputs``.

    Parameters:
        num_channels: Channels per input frame.
        num_outputs: Size of the regression output.
        hidden_size: Hidden size of each LSTM direction.
        lstm_layers: Number of stacked LSTM layers.
        input_shape: ``(height, width)`` of each frame.
    """

    def __init__(self, num_channels=3, num_outputs=2, hidden_size=128, lstm_layers=1, input_shape=(64, 64)):
        super().__init__()

        n_conv_layers = 4
        # Each stage halves height and width (floor division), so four stages
        # divide each spatial dimension by 16.
        max_pool_divisor = 2 ** n_conv_layers

        self.conv_layer1 = self._conv_layer_set(num_channels, 16)
        self.conv_layer2 = self._conv_layer_set(16, 32)
        self.conv_layer3 = self._conv_layer_set(32, 64)
        self.conv_layer4 = self._conv_layer_set(64, 128)
        # Each dimension must be floor-divided on its own; without the
        # parentheses `a * H // d * W // d` is evaluated left to right and is
        # wrong whenever H or W is not a multiple of the divisor.
        flat_shape = (
            128
            * (input_shape[0] // max_pool_divisor)
            * (input_shape[1] // max_pool_divisor)
        )
        self.lstm = nn.LSTM(flat_shape, hidden_size, lstm_layers, batch_first=True, bidirectional=True)
        # Forward and backward states are concatenated, hence hidden_size * 2.
        self.fc1 = nn.Linear(hidden_size * 2, num_outputs)

    def _conv_layer_set(self, in_c, out_c):
        """Return one conv(3x3, padding 1) -> LeakyReLU -> max-pool(2) stage."""
        conv_layer = nn.Sequential(OrderedDict([
            ('conv',nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)),
            ('leakyrelu',nn.LeakyReLU()),
            ('maxpool',nn.MaxPool2d(2)),
        ]))
        return conv_layer
    
    def forward(self, x):
        """Map clips of shape (batch, frames, C, H, W) to (batch, num_outputs)."""
        batch_size, num_frames, C, H, W = x.shape
        # Fold the frame axis into the batch axis; reshape (unlike view) also
        # accepts inputs whose first two axes cannot be merged without a copy.
        x = x.reshape(batch_size * num_frames, C, H, W)

        out = self.conv_layer1(x)
        out = self.conv_layer2(out)
        out = self.conv_layer3(out)
        out = self.conv_layer4(out)

        # Restore the frame axis: one feature vector per frame.
        out = out.reshape((batch_size, num_frames, -1))
        out, _ = self.lstm(out)

        # The forward direction is complete at the last time step, the
        # backward direction at the first one; out[:, -1] alone would use a
        # backward state that has processed a single frame.
        hidden = self.lstm.hidden_size
        forward_last = out[:, -1, :hidden]
        backward_last = out[:, 0, hidden:]
        out = self.fc1(torch.cat((forward_last, backward_last), dim=1))
        return out


In [None]:
model = CNNLSTMModel()
critereon = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

In [None]:
torch.cuda.empty_cache()
best_model, train_losses, val_losses = training_loop(20, model, optimizer, critereon, train_loader, valid_loader)

In [None]:
test_true, test_pred = predict(best_model, test_loader)

In [None]:
plt.plot(train_losses, label="train")
plt.plot(val_losses, label="val")
plt.title("MSE Loss")
plt.legend()
plt.show()

In [None]:
mean_squared_error(test_true, test_pred)

## TASK 4

#TODO Define a ConvLSTM model to predict the angle of movement. You should have a look at the lecture to refresh your memory on the mechanisms behind this. Does this model achieve better results? 

## Answer

Similar to the 3rd model, being able to have spatial features and time dependence enables the model to learn the features well.

We have included an [implementation of the ConvLSTM](https://github.com/ndrplz/ConvLSTM_pytorch/tree/master) for convenience. You can use it as a building block in your ConvLSTM model.

In [None]:
class ConvLSTMCell(nn.Module):
    """Single ConvLSTM step: LSTM gating computed by one 2-D convolution."""

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Set up the cell.

        Parameters
        ----------
        input_dim: int
            Channel count of the input tensor.
        hidden_dim: int
            Channel count of the hidden (and cell) state.
        kernel_size: (int, int)
            Height and width of the convolution kernel.
        bias: bool
            Whether the convolution has a bias term.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # Half the kernel size on each side keeps the spatial size for odd kernels.
        self.padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.bias = bias

        # One convolution produces all four gate pre-activations at once.
        self.conv = nn.Conv2d(
            input_dim + hidden_dim,
            4 * hidden_dim,
            kernel_size=kernel_size,
            padding=self.padding,
            bias=bias,
        )

    def forward(self, input_tensor, cur_state):
        """Advance one time step and return ``(h_next, c_next)``."""
        prev_hidden, prev_cell = cur_state

        # Input and previous hidden state are stacked along the channel axis.
        stacked = torch.cat((input_tensor, prev_hidden), dim=1)
        gate_maps = self.conv(stacked)

        # Channel chunks are, in order: input, forget, output, candidate.
        raw_in, raw_forget, raw_out, raw_candidate = torch.split(
            gate_maps, self.hidden_dim, dim=1
        )
        in_gate = torch.sigmoid(raw_in)
        forget_gate = torch.sigmoid(raw_forget)
        out_gate = torch.sigmoid(raw_out)
        candidate = torch.tanh(raw_candidate)

        next_cell = forget_gate * prev_cell + in_gate * candidate
        next_hidden = out_gate * torch.tanh(next_cell)

        return next_hidden, next_cell

    def init_hidden(self, batch_size, image_size):
        """Return zero ``(h, c)`` states on the same device as the conv weights."""
        height, width = image_size
        state_shape = (batch_size, self.hidden_dim, height, width)
        target_device = self.conv.weight.device
        hidden = torch.zeros(state_shape, device=target_device)
        cell = torch.zeros(state_shape, device=target_device)
        return hidden, cell


class ConvLSTM(nn.Module):

    """
    Stack of ConvLSTM layers unrolled over the time axis.

    Parameters:
        input_dim: Number of channels of the input
        hidden_dim: Number of hidden channels (one int, or a list with one entry per layer)
        kernel_size: Convolution kernel size (one tuple, or a list of tuples, one per layer)
        num_layers: Number of stacked ConvLSTM layers
        batch_first: Whether dimension 0 of the input is the batch dimension
        bias: Whether the convolutions use a bias
        return_all_layers: Return the results of every layer instead of only the last one
        Note: padding is half the kernel size on each side.

    Input:
        A tensor of size B, T, C, H, W (batch_first=True) or T, B, C, H, W
    Output:
        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
            0 - layer_output_list: per layer, the hidden states of all T steps stacked along dim 1
            1 - last_state_list: per layer, the pair [h, c] after the last step
    Example:
        >> x = torch.rand((32, 10, 64, 128, 128))
        >> convlstm = ConvLSTM(64, 16, (3, 3), 1, True, True, False)
        >> _, last_states = convlstm(x)
        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        kernel_size,
        num_layers,
        batch_first=False,
        bias=True,
        return_all_layers=False,
    ):
        super().__init__()

        self._check_kernel_size_consistency(kernel_size)

        # Broadcast scalar settings so there is one entry per layer.
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        lengths_match = len(kernel_size) == len(hidden_dim) == num_layers
        if not lengths_match:
            raise ValueError("Inconsistent list length.")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        # Layer 0 consumes the input; every later layer consumes the hidden
        # channels of the layer below it.
        cells = []
        for layer_idx in range(self.num_layers):
            if layer_idx == 0:
                channels_in = self.input_dim
            else:
                channels_in = self.hidden_dim[layer_idx - 1]
            cells.append(
                ConvLSTMCell(
                    input_dim=channels_in,
                    hidden_dim=self.hidden_dim[layer_idx],
                    kernel_size=self.kernel_size[layer_idx],
                    bias=self.bias,
                )
            )

        self.cell_list = nn.ModuleList(cells)

    def forward(self, input_tensor, hidden_state=None):
        """
        Run all layers over the whole sequence.

        Parameters
        ----------
        input_tensor:
            5-D tensor of shape (t, b, c, h, w) or, with batch_first, (b, t, c, h, w)
        hidden_state:
            Must be None; passing a state raises NotImplementedError.

        Returns
        -------
        layer_output_list, last_state_list
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        batch, seq_len, _, height, width = input_tensor.size()

        # Stateful operation is not supported.
        if hidden_state is not None:
            raise NotImplementedError()
        # States are created here because the image size is only known now.
        hidden_state = self._init_hidden(batch_size=batch, image_size=(height, width))

        layer_output_list = []
        last_state_list = []
        layer_input = input_tensor

        for cell, (hidden, memory) in zip(self.cell_list, hidden_state):
            step_outputs = []
            for t in range(seq_len):
                hidden, memory = cell(
                    input_tensor=layer_input[:, t, :, :, :], cur_state=[hidden, memory]
                )
                step_outputs.append(hidden)

            # The hidden-state sequence of this layer feeds the next layer.
            layer_input = torch.stack(step_outputs, dim=1)

            layer_output_list.append(layer_input)
            last_state_list.append([hidden, memory])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, image_size):
        """Return one zero ``(h, c)`` pair per layer."""
        return [
            self.cell_list[layer_idx].init_hidden(batch_size, image_size)
            for layer_idx in range(self.num_layers)
        ]

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        """Raise ValueError unless `kernel_size` is a tuple or a list of tuples."""
        if isinstance(kernel_size, tuple):
            return
        if isinstance(kernel_size, list) and all(
            isinstance(elem, tuple) for elem in kernel_size
        ):
            return
        raise ValueError("`kernel_size` must be tuple or list of tuples")

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        """Return `param` unchanged if it is a list, else repeat it `num_layers` times."""
        if isinstance(param, list):
            return param
        return [param] * num_layers

In [None]:
class ConvLSTMModel(torch.nn.Module):
    """ConvLSTM over the clip followed by a linear read-out of the last step.

    Parameters:
        num_channels: Channels per input frame.
        num_outputs: Size of the regression output.
        hidden_size: Hidden channels of every ConvLSTM layer.
        num_layers: Number of stacked ConvLSTM layers.
        input_shape: ``(height, width)`` of each frame.
    """

    def __init__(self, num_channels=3, num_outputs=2, hidden_size=32, num_layers=1, input_shape=(64, 64)):
        super().__init__()

        self.convlstm = ConvLSTM(
            num_channels,
            hidden_size,
            (3, 3),
            num_layers,
            batch_first=True,
            bias=True,
            return_all_layers=False,
        )
        # The spatial size is unchanged by the ConvLSTM here, so the flattened
        # hidden state has hidden_size * H * W values.
        height, width = input_shape[0], input_shape[1]
        self.fc1 = nn.Linear(hidden_size * height * width, num_outputs)
    
    def forward(self, x):
        """Map clips of shape (batch, frames, C, H, W) to (batch, num_outputs)."""
        layer_outputs, _ = self.convlstm(x)
        # return_all_layers=False: exactly one entry, the top layer's sequence.
        (sequence,) = layer_outputs
        last_step = sequence[:, -1]
        flattened = last_step.view(sequence.size(0), -1)
        return self.fc1(flattened)


In [None]:
# Rebuild the loaders with a smaller batch size (50) for the ConvLSTM model.
# The shuffling generator is created on `device` — presumably required because
# torch.set_default_device(device) was called earlier; confirm before removing.
train_loader = DataLoader(
    train_dataset, 
    batch_size=50, 
    shuffle=True,
    generator=torch.Generator(device=device)
)
# Validation and test data are iterated in a fixed order.
valid_loader = DataLoader(val_dataset, batch_size=50)
test_loader = DataLoader(test_dataset, batch_size=50)

In [None]:
model = ConvLSTMModel(hidden_size=16, num_layers=3)
critereon = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

In [None]:
torch.cuda.empty_cache()
best_model, train_losses, val_losses = training_loop(20, model, optimizer, critereon, train_loader, valid_loader)

In [None]:
test_true, test_pred = predict(best_model, test_loader)

In [None]:
plt.plot(train_losses, label="train")
plt.plot(val_losses, label="val")
plt.title("MSE Loss")
plt.legend()
plt.show()

In [None]:
mean_squared_error(test_true, test_pred)