
---

## Modelling hippocampal neurons of animals navigating in VR with Recurrent Neural Networks

##### *Tutorial, COS 2025*
### *Part 3: Training an RNN*
##### made by: Daniel Liu, Marco Abrate, UCL
---
In this notebook, we will write code to define the **RNN**. 

Recurrent Neural Networks (RNNs) are neural models designed to process sequential data by retaining memory over time. Specifically, in this tutorial, we train an RNN to perform a **next-step prediction task**.

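**What is next-step prediction, and why use it?** At every time step, the network receives the current visual embedding together with the animal's motion signals (speed and rotational velocity), and must predict the visual embedding at the *next* time step. Doing this well requires keeping track of where the animal is and which way it is facing, so the network is encouraged to integrate its motion over time, i.e. to perform a form of path integration. We can then ask whether its hidden units behave like hippocampal neurons.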
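As a toy illustration (a 5-step sequence of 3-dimensional embeddings, with the motion signals left out for brevity), the inputs and targets of next-step prediction are the same sequence offset by one step:

```python
import numpy as np

# Toy sequence: 5 time steps of a 3-dimensional "visual embedding"
embs = np.arange(15, dtype=np.float32).reshape(5, 3)

# Next-step prediction: the input at step t is paired with the target at step t + 1
inputs = embs[:-1]   # steps 0..3
targets = embs[1:]   # steps 1..4

for t in range(len(inputs)):
    print(f"input step {t} -> target step {t + 1}")
```

The last step has no successor, so a sequence of length T yields T - 1 input/target pairs.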



Prerequisites:

* Completed Notebook 1

Before starting this notebook, make sure you have:

* All frames, processed into embeddings by the autoencoder we trained in the previous tutorial and saved as a ```.npy``` file.

* Trajectory file, including the speed and rotational velocity at each discretised time step.

* The accompanying `utils.py` helper functions.

* If you run this locally, you will want a PyTorch build with CUDA (NVIDIA GPU) or MPS (Apple M-series chip) support for reasonable training speed; the device-selection code below falls back to the CPU otherwise.

In [4]:
# imports
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from utils import *
from PIL import Image
import pathlib
import itertools
from sklearn.model_selection import train_test_split

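# install any packages used by the utils.py helpers here

# select the fastest available device: NVIDIA GPU (CUDA), Apple M-series GPU (MPS), or CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# fix the random seeds so that weight initialisation is reproducible
# (pass random_state=seed to train_test_split for a reproducible split)
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)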

In [None]:
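# load the embeddings and trajectory data
# COMPLETE THE CODE HERE: load the arrays (e.g. with np.load) so that
#   train_embeddings has shape (N, visual_embedding_dim)
#   train_speed and train_rot_vels each have shape (N, 1)
train_embeddings = None
train_speed = None
train_rot_vels = None

# every time step needs an embedding, a speed and a rotational velocity
assert train_embeddings.shape[0] == train_speed.shape[0] == train_rot_vels.shape[0]

visual_embedding_dim = train_embeddings.shape[-1]
motion_signal_dim = train_speed.shape[-1] + train_rot_vels.shape[-1]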
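The cell above takes `motion_signal_dim` from `shape[-1]`, which assumes each motion signal is stored as a 2D array of shape (N, 1); a 1D array of shape (N,) would contribute N instead of 1. Below is a minimal sketch of a defensive loader. The file names and random arrays are hypothetical placeholders, written to a temporary folder so the example runs on its own, and the helper `load_signal` is likewise hypothetical (it is not part of `utils.py`).

```python
import numpy as np
import pathlib
import tempfile

# Hypothetical files standing in for your own embeddings/trajectory data
tmp = pathlib.Path(tempfile.mkdtemp())
np.save(tmp / "embeddings.npy", np.random.rand(100, 64).astype(np.float32))
np.save(tmp / "speed.npy", np.random.rand(100))          # saved as shape (N,)
np.save(tmp / "rot_vels.npy", np.random.rand(100, 1))    # saved as shape (N, 1)

def load_signal(path):
    """Load a .npy array and ensure shape (N, k), adding a trailing axis to 1D arrays."""
    arr = np.load(path)
    return arr[:, None] if arr.ndim == 1 else arr

train_embeddings = load_signal(tmp / "embeddings.npy")
train_speed = load_signal(tmp / "speed.npy")
train_rot_vels = load_signal(tmp / "rot_vels.npy")

visual_embedding_dim = train_embeddings.shape[-1]
motion_signal_dim = train_speed.shape[-1] + train_rot_vels.shape[-1]
print(visual_embedding_dim, motion_signal_dim)
```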


#### Defining an RNN

Let's define a recurrent neural network. It uses the pre-defined ```RNNModule``` class, which we provide in `utils.py`, to map the current input and the previous hidden state to a new hidden state through a sigmoid activation function. A linear layer then projects each hidden state to a prediction of the next sensory state.

The sigmoid squashes each hidden unit into the range (0, 1), so its value can be interpreted as the (scaled) activity of a neuron.
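Since the source of `RNNModule` is not shown here, the sketch below is a hypothetical stand-in that assumes a standard Elman-style update, $h_t = \sigma(W_{in} x_t + W_{rec} h_{t-1} + b)$, followed by a linear readout. It only illustrates the tensor shapes and the (0, 1) range of the hidden units; it is not the provided class. The name `SigmoidRNNSketch` and the sizes (66 inputs, 50 hidden units, 64-dimensional output) are illustrative.

```python
import torch
import torch.nn as nn

class SigmoidRNNSketch(nn.Module):
    """Hypothetical stand-in for RNNModule: h_t = sigmoid(W_in x_t + W_rec h_{t-1} + b)."""
    def __init__(self, n_inputs, n_hidden):
        super().__init__()
        self.input_to_hidden = nn.Linear(n_inputs, n_hidden)               # W_in x_t + b
        self.hidden_to_hidden = nn.Linear(n_hidden, n_hidden, bias=False)  # W_rec h_{t-1}

    def forward(self, x, hidden):
        # x: (batch, seq_len, n_inputs); hidden: (batch, n_hidden)
        outputs = []
        for t in range(x.shape[1]):
            hidden = torch.sigmoid(self.input_to_hidden(x[:, t]) + self.hidden_to_hidden(hidden))
            outputs.append(hidden)
        # per-step hidden states stacked over time: (batch, seq_len, n_hidden)
        return torch.stack(outputs, dim=1), hidden

torch.manual_seed(0)
rnn_sketch = SigmoidRNNSketch(n_inputs=66, n_hidden=50)
x = torch.randn(4, 9, 66)
h0 = torch.zeros(4, 50)
states, h_last = rnn_sketch(x, h0)

# linear readout from each hidden state to a predicted next-step embedding
readout = nn.Linear(50, 64)(states)
print(states.shape, readout.shape)
```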

In [None]:
class PathIntegrationRNN(nn.Module):
    '''
    This class implements a simple RNN that uses the current sensory input, the motion signals
    and the previous hidden state to predict the sensory input at the next step.
    To do so, the RNN is expected to learn to integrate the motion signals over time,
    effectively performing path integration.
    '''
    def __init__(self, input_dim, hidden_dim, output_dim, device):
        super().__init__()

        self.device = device
        # recurrent layer with sigmoid hidden units (RNNModule comes from utils.py)
        self.rnn = RNNModule(n_inputs=input_dim, n_hidden=hidden_dim, nonlinearity='sigmoid', device=device)
        # linear readout from the hidden state to the predicted next-step embedding
        self.fc = nn.Linear(hidden_dim, output_dim)

        self.to(device)
        
    def forward(self, x, hidden):
        
        # COMPLETE THE CODE HERE: 
        # x should be of shape (batch_size, seq_length, input_dim)
        raise NotImplementedError
        
        # out, hidden = self.rnn(x, hidden)
        # out = self.fc(out)
        # return out, hidden

Note that the input dimension is the sum of the visual embedding dimension and the motion signal dimension, since both are concatenated and fed to the RNN at each time step.
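For example, with illustrative sizes (a batch of 4 sequences of 9 steps, 64-dimensional embeddings, and speed and rotational velocity as one feature each), the per-step inputs can be built by concatenating along the last dimension. This sketch assumes all three tensors are already laid out as (batch, seq_len, features):

```python
import torch

batch_size, seq_len, visual_embedding_dim = 4, 9, 64   # illustrative sizes
embs = torch.randn(batch_size, seq_len, visual_embedding_dim)
speed = torch.rand(batch_size, seq_len, 1)
rot_vels = torch.randn(batch_size, seq_len, 1)

# Concatenate along the feature (last) dimension: one input vector per time step
rnn_input = torch.cat([embs, speed, rot_vels], dim=-1)
print(rnn_input.shape)  # torch.Size([4, 9, 66])
```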

In [None]:
rnn = PathIntegrationRNN(input_dim=visual_embedding_dim + motion_signal_dim,
                         hidden_dim=50,
                         output_dim=visual_embedding_dim,
                         device=device)

#### Defining a DataLoader

To allow our RNN to perform the next-step prediction task, we need a **Dataset** and a **DataLoader**. The **Dataset** slices the sensory embeddings and motion signals into individual samples of paired inputs and labels; the **DataLoader** groups these samples into batches and iterates over them during training and testing.

Our classes build on PyTorch's built-in ```torch.utils.data.Dataset``` and ```torch.utils.data.DataLoader``` classes.
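As a minimal illustration of this division of labour, independent of `SensoryDataset`, the sketch below wraps toy tensors in PyTorch's built-in `TensorDataset` and lets `DataLoader` batch them. The sizes are arbitrary; with 10 samples and `batch_size=4`, the final batch holds only 2 samples (pass `drop_last=True` to discard it).

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy dataset: 10 samples, each a 9-step sequence with 66 input features and 64-d targets
inputs = torch.randn(10, 9, 66)
labels = torch.randn(10, 9, 64)
dataset = TensorDataset(inputs, labels)

# The DataLoader collates individual samples into batches along a new first dimension
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch_shapes = [(x.shape[0], y.shape[0]) for x, y in loader]
print(batch_shapes)  # [(4, 4), (4, 4), (2, 2)]
```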

In [None]:
class SensoryDataset(torch.utils.data.Dataset):
    def __init__(self, embs, vels, rot_vels, hds, thetas, pos, tsteps=9, n_future_pred=1):
        '''
        The initialisation function for the SensoryDataset class.
        At initialisation, all arrays are converted to float32 tensors.
        Args:
            embs: The visual embeddings of shape (N, D)
            vels: The speed signals of shape (N, 1)
            rot_vels: The rotational velocities of shape (N, 1)
            hds: The headings of shape (N, 1)
            thetas: The theta angles of shape (N, 1)
            pos: The positions, one row per time step (first dimension N)
            tsteps: The number of time steps in each sample (sequence).
                    By default, this is set to 9, i.e. we use the sensory inputs from steps 1 to 9
            n_future_pred: The number of steps into the future to predict.
                    By default, this is set to 1, i.e. we predict the sensory inputs at steps 2 to 10
        '''

        self.embs = torch.tensor(embs, dtype=torch.float32)
        self.vels = torch.tensor(vels, dtype=torch.float32)
        self.rot_vels = torch.tensor(rot_vels, dtype=torch.float32)
        self.hds = torch.tensor(hds, dtype=torch.float32)
        self.thetas = torch.tensor(thetas, dtype=torch.float32)
        self.pos = torch.tensor(pos, dtype=torch.float32)

        self.tsteps = tsteps
        self.n_future_pred = n_future_pred
    
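    def __len__(self):
        # COMPLETE THE CODE HERE: how many samples are in the dataset?
        # Each sample is a non-overlapping window of tsteps steps; its labels extend up to
        # n_future_pred steps beyond the window, so they must still fit inside the recording
        return (self.embs.shape[0] - self.n_future_pred) // self.tsteps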
    
    def __getitem__(self, idx):
        '''
        This function returns one sample: a window of tsteps consecutive time steps
        together with the corresponding next-step targets.
        Args:
            idx: The index of the sample to return.
        Returns:
            A tuple (embs, vels, rot_vels, pos, hds, theta, embs_labels) where:
                - embs: the visual embeddings of the input window, shape (tsteps, D)
                - vels, rot_vels, pos, hds, theta: the signals of the window shifted by
                  0, ..., n_future_pred - 1 steps, stacked along dim 1,
                  e.g. vels has shape (tsteps, n_future_pred, 1)
                - embs_labels: the embeddings shifted by 1, ..., n_future_pred steps,
                  stacked along dim 1, shape (tsteps, n_future_pred, D)
        '''
        start_idx, end_idx = idx * self.tsteps, (idx + 1) * self.tsteps
        embs = self.embs[start_idx:end_idx]

        vels, rot_vels, pos, hds, theta, embs_labels = [], [], [], [], [], []

        for future_step in range(self.n_future_pred):
            # COMPLETE THE CODE HERE:
            # Get the motion signals for this window and the embeddings one step ahead of them
            pass

            # all tensors have shape (N, ...), so we slice along the first (time) dimension
            # vels.append(self.vels[start_idx + future_step:end_idx + future_step])
            # rot_vels.append(self.rot_vels[start_idx + future_step:end_idx + future_step])
            # pos.append(self.pos[start_idx + future_step:end_idx + future_step])
            # hds.append(self.hds[start_idx + future_step:end_idx + future_step])
            # theta.append(self.thetas[start_idx + future_step:end_idx + future_step])

            # the labels are the embeddings one step further ahead
            # embs_labels.append(self.embs[start_idx + future_step + 1:end_idx + future_step + 1])

        # stack the n_future_pred windows along a new dimension 1
        vels, rot_vels, pos, hds, theta, embs_labels = torch.stack(vels, dim=1), \
                                                        torch.stack(rot_vels, dim=1), \
                                                        torch.stack(pos, dim=1), \
                                                        torch.stack(hds, dim=1), \
                                                        torch.stack(theta, dim=1), \
                                                        torch.stack(embs_labels, dim=1)

        return embs, vels, rot_vels, pos, hds, theta, embs_labels
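To sanity-check the window arithmetic, here is a small hypothetical helper (not part of the tutorial code) that lists the input and label ranges for each sample. For sample `idx`, the inputs cover steps `[idx * tsteps, (idx + 1) * tsteps)` and the labels for future step `k` cover the same range shifted by `k + 1`, so the furthest label window ends at `(idx + 1) * tsteps + n_future_pred`. Requiring this to be at most N gives at most `(N - n_future_pred) // tsteps` samples, which you can compare with the value returned by `__len__`:

```python
def window_indices(n_steps, tsteps=9, n_future_pred=1):
    """Hypothetical helper: (input range, furthest label range) for each non-overlapping window."""
    n_samples = (n_steps - n_future_pred) // tsteps
    windows = []
    for idx in range(n_samples):
        start, end = idx * tsteps, (idx + 1) * tsteps
        # label window for the furthest future step, shifted n_future_pred steps ahead
        windows.append(((start, end), (start + n_future_pred, end + n_future_pred)))
    return windows

print(window_indices(20))  # [((0, 9), (1, 10)), ((9, 18), (10, 19))]
```

With 18 steps only one window fits, because a second window `[9, 18)` would need labels up to step 19.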