
---

## Modelling hippocampal neurons of animals navigating in VR with Recurrent Neural Networks

##### *Tutorial, COS 2025*
### *Part 4: Extracting Neural Representations from an RNN*
##### made by: Daniel Liu, Marco Abrate, UCL
---
In this notebook, we will write code to extract representations from the **RNN hidden states**. These include:

* Rate maps

* Polar maps

* Quantitative metrics

* Comparison with _in vivo_ data


Prerequisites:

* Completed Notebook 2

Before starting this notebook, make sure you have:

* All frames, processed into embeddings and saved in a ```.npy``` file by the Autoencoder we trained in the last tutorial.

* Trajectory file, including speed and rotational velocity at each discretised time step.

* The accompanying `utils.py` helper functions.


### **0. Install and import dependencies**

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from IPython.display import clear_output

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
    
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

### **1. Loading test data and the trained RNN**

In the first step, we will re-define the RNN from Part 3 of this tutorial and load its trained weights. We will also need the test embeddings and auxiliary variables (velocities, positions and headings), which we will use to compute the rate maps.

In [None]:
# Run this cell, do not modify it.

def create_multiple_subsampling(data, stride, is_velocity=False):
    new_length = data.shape[0]//stride if not is_velocity else data.shape[0]//stride-1
    data_multisubs = np.zeros(
        (stride, new_length, data.shape[1]),
        dtype=np.float32
    )
    for start_idx in range(stride):
        if is_velocity:
            if start_idx < stride-1:
                data_multisubs[start_idx] = data[start_idx+1:start_idx-stride+1].reshape(
                    new_length, stride, -1
                ).sum(axis=1)
            else:
                data_multisubs[start_idx] = data[start_idx+1:].reshape(
                    new_length, stride, -1
                ).sum(axis=1)
        else:
            data_multisubs[start_idx] = data[start_idx::stride]
    return data_multisubs

class RNNCell(torch.nn.Module):
    def __init__(self, n_inputs, n_hidden, input_bias, hidden_bias):
        super(RNNCell, self).__init__()

        self.in2hidden = torch.nn.Linear(n_inputs, n_hidden, bias=input_bias)
        self.hidden2hidden = torch.nn.Linear(n_hidden, n_hidden, bias=hidden_bias)

        self.activation_fn = torch.nn.Sigmoid()

    def forward(self, x, hidden):
        igates = self.in2hidden(x)
        hgates = self.hidden2hidden(hidden)
        return self.activation_fn(igates + hgates)


class RNNModule(torch.nn.Module):
    def __init__(
        self, device, n_inputs, n_hidden,
        input_bias, hidden_bias
    ):
        super(RNNModule, self).__init__()

        self.rnn_cell = RNNCell(n_inputs, n_hidden, input_bias, hidden_bias)
        self.n_hidden = n_hidden

        self.device = device

    def forward(self, x, hidden=None):
        # x: [BATCH SIZE, TIME, N_FEATURES]
        # hidden: [BATCH SIZE, N_HIDDEN]
        
        output = torch.zeros(x.shape[0], x.shape[1], self.n_hidden).to(self.device)

        if hidden is None:
            h_out = torch.zeros(x.shape[0], self.n_hidden) # initialize hidden state
            h_out = h_out.to(self.device)
        else:
            h_out = hidden

        window_size = x.shape[1]

        # loop over time
        for t in range(window_size):
            x_t = x[:,t,...]
            h_out = self.rnn_cell(x_t, h_out)
            output[:,t,...] = h_out

        # return all outputs, and the last hidden state
        return output, h_out

class PredictiveRNN(torch.nn.Module):
    def __init__(self,
        device, n_inputs, n_hidden, n_outputs, bias=False
    ):
        super().__init__()

        self.rnn = RNNModule(
            device, n_inputs, n_hidden,
            input_bias=bias, hidden_bias=bias
        )

        self.linear_layer = torch.nn.Linear(n_hidden, n_outputs, bias=bias)

    def inputs2hidden(self, inputs, hidden):
        """ Encodes the input tensor into a latent representation (the hidden states at every time step).

        Args:
            inputs: [BATCH SIZE, TIME, N_FEATURES]
            hidden: initial hidden state, or None to start from zeros
        """
        
        if hidden is not None:
            return self.rnn(inputs, hidden[None, ...])[0]
        else:
            return self.rnn(inputs)[0]

    def hidden2outputs(self, hidden):
        return self.linear_layer(hidden)
    
    def forward(self, inputs, hidden=None):
        hidden_new = self.inputs2hidden(inputs, hidden)

        output = self.hidden2outputs(hidden_new)

        return output, hidden_new[:,-1,:]

class SensoryDataset(torch.utils.data.Dataset):
    def __init__(self, embs, vels, rot_vels, pos, hds, tsteps=9):
        '''
        The initialisation function for the SensoryDataset class.
        At initialisation, all embeddings are converted to tensors.
        Args:
            embs: The visual embeddings of shape (N, T, D)
            vels: The speed signals of shape (N, T-1, 1)
            rot_vels: The rotational velocities of shape (N, T-1, 1)
            pos: The positions of shape (N, T, 2)
            hds: The headings of shape (N, T, 1)
            tsteps: The number of time steps for each batch.
                By default, this is set to 9 i.e. we use the sensory input from steps 1 to 9          
        '''
        self.embs = torch.from_numpy(embs)
        self.vels = torch.from_numpy(vels)
        self.rot_vels = torch.from_numpy(rot_vels)
        self.pos = torch.from_numpy(pos)
        self.hds = torch.from_numpy(hds)
        
        self.tsteps = tsteps
    
    def __len__(self):
        # COMPLETE THE CODE HERE: how many samples are in the dataset?
        return self.embs.shape[1] // self.tsteps - 1
    
    def __getitem__(self, idx):
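        '''
        This function returns a batch of sensory inputs and the corresponding future sensory inputs.
        Args:
            idx: The index of the sample to return. idx will be automatically generated by the DataLoader.
        Returns:
            embs, vels, rot_vels, pos, hds: The inputs and auxiliary variables for
                time steps idx*tsteps to (idx+1)*tsteps - 1.
            embs_labels: The embeddings shifted one step ahead (the future sensory inputs).
        '''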
        vels, rot_vels, pos, hds, embs_labels = [], [], [], [], []

        start_idx, end_idx = idx*self.tsteps, (idx + 1)*self.tsteps

        embs = self.embs[:, start_idx:end_idx]

        vels = self.vels[:, start_idx:end_idx]
        rot_vels = self.rot_vels[:, start_idx:end_idx]
        pos = self.pos[:, start_idx:end_idx]
        hds = self.hds[:, start_idx:end_idx]

        embs_labels = self.embs[:, start_idx+1 : end_idx+1]
        
        return embs, vels, rot_vels, pos, hds, embs_labels
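
The way `create_multiple_subsampling` treats positions and velocities differently can be checked on toy data. The sketch below is a simplified, hypothetical re-implementation (the helper names `subsample_positions` and `subsample_velocities` are not part of the tutorial code). It assumes the convention that the velocity at step `i` is the displacement from step `i-1` to step `i`, which is an assumption about the data rather than something stated here. Under that assumption, summing velocities over each block of `stride` steps should reproduce the displacement between consecutive subsampled positions.

```python
import numpy as np

def subsample_positions(pos, stride, start):
    # Keep every `stride`-th sample, beginning at offset `start`.
    return pos[start::stride]

def subsample_velocities(vel, stride, start):
    # Sum the per-step velocities over each block of `stride` steps after `start`.
    n_blocks = vel.shape[0] // stride - 1
    block = vel[start + 1 : start + 1 + n_blocks * stride]
    return block.reshape(n_blocks, stride, -1).sum(axis=1)

rng = np.random.default_rng(0)
n_steps, stride = 40, 10
vel = rng.normal(size=(n_steps, 2))
vel[0] = 0.0
pos = np.cumsum(vel, axis=0)  # toy convention (assumed): pos[i] - pos[i-1] == vel[i]

max_error = 0.0
for start in range(stride):
    p = subsample_positions(pos, stride, start)
    v = subsample_velocities(vel, stride, start)
    # Displacement between consecutive subsampled positions.
    displacement = np.diff(p, axis=0)[: len(v)]
    max_error = max(max_error, float(np.abs(displacement - v).max()))

print(f"largest mismatch between summed velocities and displacements: {max_error:.2e}")
```

Note that both the sketch and the original function assume the sequence length is a multiple of `stride`.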


    

In [None]:
# let's load the test data and process it into a dataloader, just like we did in the last part.

STRIDE = 10
d = './data/adult'
trial_paths = sorted([p for p in Path(d).iterdir() if 'exp' in p.name])
trial_paths


test_embeddings = []
test_vel, test_rotvel, test_pos, test_hds = [], [], [], []

for idx in range(20, 23):
    tp = trial_paths[idx]
    test_embeddings.append(
        create_multiple_subsampling(np.load(tp / 'vision_embeddings.npy'), stride=STRIDE)
    )
    test_vel.append(
        create_multiple_subsampling(
            np.load(tp / 'riab_simulation' / 'velocities.npy'), stride=STRIDE, is_velocity=True
        )
    )
    test_rotvel.append(
        create_multiple_subsampling(
            np.load(tp / 'riab_simulation' / 'rot_velocities.npy')[..., None],
            stride=STRIDE, is_velocity=True
        )
    )
    test_pos.append(
        create_multiple_subsampling(np.load(tp / 'riab_simulation' / 'positions.npy'), stride=STRIDE)
    )
    test_hds.append(
        create_multiple_subsampling(np.load(tp / 'riab_simulation' / 'thetas.npy')[..., None], stride=STRIDE)
    )

test_embeddings = np.concatenate(test_embeddings, axis=0)
test_vel = np.concatenate(test_vel, axis=0)
test_rotvel = np.concatenate(test_rotvel, axis=0)
test_pos = np.concatenate(test_pos, axis=0)
test_hds = np.concatenate(test_hds, axis=0)

test_dataloader = torch.utils.data.DataLoader(
    SensoryDataset(
        test_embeddings, test_vel, test_rotvel, test_pos, test_hds
    ), shuffle=False
)

In [None]:
# let's recreate the PredictiveRNN model, just like we did in the last part.
visual_embedding_dim = test_embeddings.shape[-1]
motion_signal_dim = test_vel.shape[-1] + test_rotvel.shape[-1]
trained_rnn_weights = './rnn.pth'

rnn = PredictiveRNN(
    DEVICE,
    n_inputs=visual_embedding_dim + motion_signal_dim,
    n_hidden=500,
    n_outputs=visual_embedding_dim
).to(DEVICE)

loss_fn = torch.nn.L1Loss()

# load the trained weights
rnn.load_state_dict(torch.load(trained_rnn_weights, map_location=DEVICE))

### **2. Obtaining hidden states from the trained RNN**

We have written a function named ```evaluate_rnn()``` for you. When its parameter ```for_ratemaps``` is set to ```True```, it returns a dictionary containing the hidden states, positions, velocities, etc. at each step, for the convenience of our analysis.

In [None]:
from utils import evaluate_rnn
d = evaluate_rnn(DEVICE, rnn, test_dataloader, loss_fn, for_ratemaps=True)

d['hidden_states'].shape

### **3. Rate maps**

The **rate map** of a neuron is essentially the average activity of that neuron at each location in an environment. In principle, it is a continuous function of position.

Computationally, this requires us to discretise the environment into smaller 'bins', then compute the average activity in each bin. The function below computes this average.

In [None]:
def compute_spatial_tuning_curves(hidden_states, pos, grid_size=(10,10)):
    '''
    Computes the spatial tuning curves of the neurons in the integrator RNN
    :params:
        hidden_states:      np.array, shape (num_instances, num_neurons), hidden states of the integrator RNN
        pos:                np.array, shape (num_instances, 2), spatial position of the agent
        grid_size:          tuple, discretisation size of the spatial field
    '''

    # Initialize arrays to hold tuning curves
    cumulative_activation_at_each_bin = np.zeros((hidden_states.shape[1], grid_size[0], grid_size[1]))
    occupancy_at_each_bin = np.ones((grid_size[0], grid_size[1]))

    # aggregate neuron activity by spatial position
    for i in range(pos.shape[0]):
        cumulative_activation_at_each_bin = None
        occupancy_at_each_bin = None
        ### COMPLETE THE CODE HERE: convert the position to grid coordinates
        # (the hint below assumes positions are normalised to [0, 1))
        # x, y = int(pos[i, 0] * (grid_size[0])), int(pos[i, 1] * (grid_size[1]))
        # cumulative_activation_at_each_bin[:, x, y] += hidden_states[i]
        # occupancy_at_each_bin[x, y] += 1
    
    # COMPLETE THE CODE HERE: Compute tuning curve as the average activity at each spatial position
    rate_map = None
    # rate_map = cumulative_activation_at_each_bin / occupancy_at_each_bin

    return rate_map, occupancy_at_each_bin

rate_map, occupancy = compute_spatial_tuning_curves(
    d['hidden_states'], d['pos'], grid_size=(10, 10)
)
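
For reference, the same binning can be written without a Python loop. The sketch below is one possible vectorised variant, not the tutorial's reference solution. It assumes positions are normalised to `[0, 1]` on both axes, and the function name `rate_maps_vectorised` is hypothetical. Unlike the function above, which starts the occupancy at one, it starts from zero and leaves unvisited bins as `NaN`.

```python
import numpy as np

def rate_maps_vectorised(hidden_states, pos, grid_size=(10, 10)):
    """Average activity per spatial bin; assumes pos lies in [0, 1] on both axes."""
    n_neurons = hidden_states.shape[1]
    gx, gy = grid_size
    # Map positions to integer bin indices; clip so pos == 1.0 falls in the last bin.
    ix = np.clip((pos[:, 0] * gx).astype(int), 0, gx - 1)
    iy = np.clip((pos[:, 1] * gy).astype(int), 0, gy - 1)

    sums = np.zeros((gx, gy, n_neurons))
    counts = np.zeros((gx, gy))
    # np.add.at accumulates correctly when the same bin index appears many times.
    np.add.at(sums, (ix, iy), hidden_states)
    np.add.at(counts, (ix, iy), 1)

    # Unvisited bins become NaN rather than being reported as zero activity.
    with np.errstate(invalid="ignore", divide="ignore"):
        maps = sums / counts[..., None]
    # Return shape (num_neurons, gx, gy) to match the loop version.
    return np.moveaxis(maps, -1, 0), counts

# Toy data, only to illustrate the shapes involved.
rng = np.random.default_rng(0)
toy_pos = rng.uniform(0, 1, size=(1000, 2))
toy_hidden = rng.uniform(0, 1, size=(1000, 5))
toy_maps, toy_counts = rate_maps_vectorised(toy_hidden, toy_pos)
print(toy_maps.shape, int(toy_counts.sum()))  # (5, 10, 10) 1000
```

A plain `sums[ix, iy] += hidden_states` would silently drop repeated visits to the same bin, which is why `np.add.at` is used here.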

Let's now visualise a single rate map of the first neuron.

In [None]:
one_rate_map = rate_map[0]
plt.figure(figsize=(10, 10))
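# transpose so that x bins (first array axis) run along the horizontal axis, matching the axis labels
plt.imshow(one_rate_map.T, cmap='hot', origin='lower', interpolation='nearest')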
plt.colorbar(label='Neuron activity')
plt.title('Spatial Tuning Curve for Neuron 0')
plt.xlabel('X Position (Grid)')
plt.ylabel('Y Position (Grid)')
plt.show()

In the real brain, neuronal firing is stochastic. The process is therefore noisy, and the average activity in each bin is a crude (but unbiased) estimate of the real tuning curve, whose accuracy depends on the amount of data available.

Often, we assume that a neuron's activity does not change too much between neighbouring bins. To make our rate maps look smoother, it is customary to apply **Gaussian smoothing** to them. This is easily done with SciPy's `gaussian_filter`.

In [None]:
from scipy.ndimage import gaussian_filter

smoothened_rate_map = gaussian_filter(one_rate_map, sigma=2) # this means apply a Gaussian filter with a standard deviation of 2 bins

plt.figure(figsize=(10, 10))
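# transpose so that x bins (first array axis) run along the horizontal axis, matching the axis labels
plt.imshow(smoothened_rate_map.T, cmap='hot', origin='lower', interpolation='nearest')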
plt.colorbar(label='Neuron activity')
plt.title('Smoothened Spatial Tuning Curve for Neuron 0')
plt.xlabel('X Position (Grid)')
plt.ylabel('Y Position (Grid)')
plt.show()
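
To see why smoothing helps when the estimate is noisy, here is a small self-contained check on synthetic data. Everything in it is made up for illustration: the "true" tuning curve is a hypothetical Gaussian bump, and the noise level of 0.3 is arbitrary. The sketch compares how far the raw and the smoothed maps are from that known ground truth.

```python
import numpy as np
from scipy.ndimage import gaussian_filter

rng = np.random.default_rng(0)

# Hypothetical ground truth: a smooth bump centred on a 10x10 grid.
xs, ys = np.meshgrid(np.arange(10), np.arange(10), indexing="ij")
true_map = np.exp(-((xs - 4.5) ** 2 + (ys - 4.5) ** 2) / (2 * 3.0 ** 2))

# Noisy estimate, as if each bin had been averaged over only a few samples.
noisy_map = true_map + rng.normal(scale=0.3, size=true_map.shape)

# Smooth with a standard deviation of 1 bin.
smoothed_map = gaussian_filter(noisy_map, sigma=1)

def rmse(a, b):
    # Root-mean-square error between two maps of the same shape.
    return float(np.sqrt(np.mean((a - b) ** 2)))

print(f"RMSE of noisy map:    {rmse(noisy_map, true_map):.3f}")
print(f"RMSE of smoothed map: {rmse(smoothed_map, true_map):.3f}")
```

The choice of `sigma` is a trade-off: a larger value removes more noise but also blurs genuine spatial structure, so it is worth trying a few values on your own rate maps.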

In [None]:
# optional: we can plot 100 rate maps at once here
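
One possible way to fill in this optional cell is sketched below. The helper name `plot_rate_map_grid` and its defaults are illustrative assumptions, and the random `toy_maps` array only stands in for the real `rate_map` so that the snippet runs on its own.

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter

def plot_rate_map_grid(rate_maps, n_rows=10, n_cols=10, sigma=1.0):
    """Plot the first n_rows * n_cols rate maps in a grid and return the figure."""
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), squeeze=False)
    for idx, ax in enumerate(axes.ravel()):
        ax.axis("off")
        if idx >= len(rate_maps):
            continue  # leave spare panels blank if there are fewer maps than panels
        smoothed = gaussian_filter(rate_maps[idx], sigma=sigma)
        # Transpose so the first array axis (x bins) runs horizontally.
        ax.imshow(smoothed.T, cmap="hot", origin="lower", interpolation="nearest")
    fig.tight_layout()
    return fig

# Stand-in data with the same layout as rate_map: (num_neurons, grid_x, grid_y).
toy_maps = np.random.default_rng(0).uniform(size=(100, 10, 10))
fig = plot_rate_map_grid(toy_maps)
```

In the notebook, calling `plot_rate_map_grid(rate_map)` on the real rate maps would show the first 100 neurons. Each panel uses its own colour scale, so brightness is not comparable across neurons.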

### **4. Polar maps**

Analogous to rate maps, **polar maps** show how a neuron's activity varies with a specific angle (the head direction).
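
As a rough sketch of the idea, the same binning used for rate maps can be applied to a single circular variable. The function name `angular_tuning_curves` and the choice of 36 bins are hypothetical, and the code assumes headings are given in radians, which should be checked against the trajectory files.

```python
import numpy as np

def angular_tuning_curves(hidden_states, headings, n_bins=36):
    """Mean activity of each unit per head-direction bin.

    Assumes headings are in radians; they are wrapped into [0, 2*pi).
    """
    theta = np.mod(np.asarray(headings).ravel(), 2 * np.pi)
    # Guard against theta landing exactly on 2*pi through floating-point rounding.
    bin_idx = np.minimum((theta / (2 * np.pi) * n_bins).astype(int), n_bins - 1)

    sums = np.zeros((n_bins, hidden_states.shape[1]))
    counts = np.zeros(n_bins)
    np.add.at(sums, bin_idx, hidden_states)
    np.add.at(counts, bin_idx, 1)

    with np.errstate(invalid="ignore", divide="ignore"):
        curves = sums / counts[:, None]  # NaN where a direction was never sampled
    centres = (np.arange(n_bins) + 0.5) * 2 * np.pi / n_bins
    return curves.T, centres  # curves: (num_neurons, n_bins)

# Toy unit whose activity peaks at a heading of 1.0 rad.
toy_theta = np.linspace(0, 2 * np.pi, 3600, endpoint=False)
toy_activity = (1 + np.cos(toy_theta - 1.0))[:, None]
toy_curves, toy_centres = angular_tuning_curves(toy_activity, toy_theta)
preferred = toy_centres[np.nanargmax(toy_curves[0])]
print(f"estimated preferred direction: {preferred:.2f} rad")
```

The resulting curve can be drawn on polar axes, for example with `plt.subplot(projection='polar')` and `ax.plot(centres, curve)`.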