---
# Modelling hippocampal neurons of animals navigating in VR with recurrent neural networks
### Marco P. Abrate, Daniel Liu
---

##### Outline
Rat simulation:
- Motion model (RatInABox)
- Environment design (Blender)
- Simulated rat vision (ratvision)

Vision autoencoder

Hippocampus model (RNN):
- RNN definition
- Data loading
- Training

Hidden state representations analysis:
- Rate maps
- Polar maps
- Quantitative metrics
- Comparison with real data

---
## **Part 2: Training a Vision Autoencoder**

In this notebook, we will write code to train an **Autoencoder**. An autoencoder is a pair of artificial neural networks that compresses information into a low-dimensional embedding through the first module (aka encoder) and reconstructs it to its original form through the second module (aka decoder).

Neuroscientists use vision autoencoders to model how neurons might represent visual stimuli in the brain. The visual cortex receives complex images and it is able to extract key features - such as edges, motions and shapes - into more compact forms (low-dimensional embedding). This non-linear dimensionality reduction process, along with the reconstruction of the original image, can be compared to an autoencoder.

We will use the **PyTorch** package for this tutorial.

Before starting this notebook, make sure you have video recordings from the previous part.

### **0. Import and install dependencies**

In [None]:
!pip install torch torchvision torchaudio
!pip install scikit-learn

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from utils import *
from PIL import Image
from pathlib import Path
import itertools
from sklearn.model_selection import train_test_split

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
    
SEED = 42

### **1. Visualize example frame**

In [None]:
trial_paths = [p for p in Path('/Users/marco/Downloads/vrtopc/box/run').iterdir() if 'exp' in p.name]

trial_paths

In [None]:
frame_example = trial_paths[0] / 'box_messy' / 'frame0001.png'
plt.imshow(np.array(Image.open(frame_example)), cmap='gray')
plt.show()

### **2. Writing an autoencoder**

We begin by coding up our autoencoder.

As we mentioned before, an autoencoder is a type of neural network that learns to compress data into a smaller representation (encoding) and then reconstruct it back (decoding).

There is no restriction on the structure of the encoder and the decoder, and they need not be symmetric. 

However, since we are processing image frames, **convolutional layers** will be helpful because they can:
* capture spatial features like edges, textures, shapes, etc.

* preserve local patterns and share weights across the image

* detect features efficiently regardless of position
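
To make the weight-sharing and position-independence points concrete, here is a minimal sketch. The layer sizes are illustrative only and are not the ones used later in this notebook. Circular padding is an assumption chosen so that a shift of the input wraps around cleanly; with zero padding, the shifted outputs would only match away from the image border.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
h, w = 64, 128  # illustrative frame size (height, width)

# Weight sharing: the same 3x3 kernels are applied at every position,
# so the parameter count does not depend on the image size.
conv = nn.Conv2d(1, 16, kernel_size=3, padding=1, padding_mode="circular")
n_conv_params = sum(p.numel() for p in conv.parameters())  # 16*1*3*3 weights + 16 biases

# A fully connected layer with the same input and output sizes would need
# this many parameters (computed arithmetically, not instantiated).
n_dense_params = (h * w) * (16 * h * w) + 16 * h * w

# Position independence: shifting the input shifts the feature maps by the
# same amount, so a feature is detected wherever it appears.
x = torch.rand(1, 1, h, w)
shift = 5
with torch.no_grad():
    shifted_then_conv = conv(torch.roll(x, shifts=shift, dims=-1))
    conv_then_shifted = torch.roll(conv(x), shifts=shift, dims=-1)
is_equivariant = torch.allclose(shifted_then_conv, conv_then_shifted, atol=1e-5)

print(f"conv parameters:  {n_conv_params}")
print(f"dense parameters: {n_dense_params}")
print(f"shift-equivariant: {is_equivariant}")
```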

First, let's define the encoder and the decoder.

In [None]:
# class VisualEncoder(nn.Module):
#     def __init__(self, visual_embedding_dim, img_dim, grayscale):
#         super(VisualEncoder, self).__init__()

#         self.c_channel = 1 if grayscale else 3
        
#         self.encoder_layers = nn.Sequential(
#             # WRITE YOUR CODE HERE
            
#             nn.Conv2d(self.c_channel, 16, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2),
#             nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(kernel_size=2, stride=2),
#         )
        
#         # now we flatten and project linearly to the target embedding dimension.
#         self.threshold_dims = self.get_threshold_dims(img_dim)

#         self.linear_projection = nn.Sequential(
#             nn.Flatten(),
#             nn.Linear(np.prod(self.threshold_dims), visual_embedding_dim),
#             nn.Sigmoid()
#         )
    
#     def get_threshold_dims(self, img_dim):
#         # a helper function to obtain the CNN-processed dimension.
#         h, w = img_dim
#         with torch.no_grad():
#             return self.encoder_layers.cpu()(torch.rand(1, self.c_channel, h, w)).shape[1:]

#     def forward(self, img):
#         img_conv = self.encoder_layers(img)
#         emb = self.linear_projection(img_conv)
#         return emb


# # define a decoder
# class VisualDecoder(nn.Module):
#     def __init__(self, visual_embedding_dim, threshold_dims, img_dim, grayscale):
#         super(VisualDecoder, self).__init__()

#         self.c_channel = 1 if grayscale else 3
#         self.threshold_dims = threshold_dims
#         self.img_dim = img_dim

#         self.linear_projection = nn.Sequential(
#             nn.Linear(visual_embedding_dim, np.prod(threshold_dims)),
#             nn.ReLU(),
#         )
        
#         self.decoder_layers = nn.Sequential(
#             # WRITE YOUR CODE HERE
            
#             # nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
#             # nn.ReLU(),
#             nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
#             nn.ReLU(),
#             nn.ConvTranspose2d(16, self.c_channel, kernel_size=3, stride=1, padding=1),
#             nn.Sigmoid() # projects to (0, 1) scale
#         )

#     def forward(self, emb):
#         emb = self.linear_projection(emb)
#         emb = emb.reshape(emb.shape[0], self.threshold_dims[0], self.threshold_dims[1], self.threshold_dims[2])
        
#         img_dec = self.decoder_layers(emb)
#         return img_dec
    

In [None]:
class VisualEncoder(torch.nn.Module):
    def __init__(self,
        visual_embedding_dim, img_dim, grayscale,
        kernel_sizes: list[tuple[int]] = [(4,5)],
        kernel_strides: list[int] = [3],
        channels: list[int] = [8],
    ):
        super().__init__()

        inc = 1 if grayscale else 3
        img_dim_out = img_dim

        encoder_conv_layers = []
        for i in range(len(kernel_sizes)):
            ksize = kernel_sizes[i]
            stride = kernel_strides[i]
            encoder_conv_layers.append(torch.nn.Conv2d(
                in_channels = (inc if i == 0 else channels[i-1]),
                out_channels = channels[i],
                kernel_size = ksize,
                stride = stride,
            ))

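            # output size per spatial dimension for an unpadded convolution;
            # `d` indexes (height, width) and is distinct from the layer index `i`
            img_dim_out = [
                (img_dim_out[d] - ksize[d]) // stride + 1
                for d in range(len(img_dim))
            ]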
            
            encoder_conv_layers.append(torch.nn.ReLU())

        self.encoder_conv = torch.nn.Sequential(*encoder_conv_layers)
        
        self.img_dim_out = img_dim_out
        
        print(f'Encoder convolution layer: {self.encoder_conv}')
        print(f'Final convolution size: {channels[-1]}x{img_dim_out}')
        print(f'Flattens to {channels[-1]*np.prod(img_dim_out)}\n')
            
        self.flatten = torch.nn.Flatten(start_dim=1)

        self.encoder_lin = torch.nn.Sequential(
            torch.nn.Linear(channels[-1]*np.prod(img_dim_out), visual_embedding_dim),
            torch.nn.Sigmoid()
        )
        print(f'Encoder linear layer: {self.encoder_lin}')

    def forward(self, img):
        
        img_conv = self.encoder_conv(img)
        img_conv_flatten = self.flatten(img_conv)

        embeddings = self.encoder_lin(img_conv_flatten)

        return embeddings
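
The encoder tracks the image size after each convolution using the standard formula for an unpadded convolution, `out = floor((in - kernel) / stride) + 1`. The sketch below checks that formula against PyTorch. The helper `conv_out_dim` is a hypothetical name introduced here, and the kernel sizes, strides and channels are example values only.

```python
import torch

def conv_out_dim(img_dim, ksize, stride):
    """Output (height, width) of an unpadded, undilated convolution."""
    return [(d - k) // stride + 1 for d, k in zip(img_dim, ksize)]

# Example values only (assumed for illustration)
img_dim = (64, 128)
kernel_sizes = [(2, 2), (3, 3)]
strides = [1, 2]
channels = [16, 32]

dims = list(img_dim)
layers = []
in_ch = 1  # grayscale input
for k, s, c in zip(kernel_sizes, strides, channels):
    dims = conv_out_dim(dims, k, s)
    layers.append(torch.nn.Conv2d(in_ch, c, kernel_size=k, stride=s))
    in_ch = c

# Compare the predicted size with what PyTorch actually produces
with torch.no_grad():
    out = torch.nn.Sequential(*layers)(torch.rand(1, 1, *img_dim))

print(f"predicted spatial size: {dims}")
print(f"actual output shape:    {tuple(out.shape)}")
```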


In [None]:
class VisualDecoder(torch.nn.Module):
    def __init__(self,
        visual_embedding_dim, img_dim_out, grayscale,
        kernel_sizes: list[tuple[int]] = [(4,5)],
        kernel_strides: list[int] = [3],
        channels: list[int] = [8],
    ):
        super().__init__()

        inc = 1 if grayscale else 3

        self.decoder_lin = torch.nn.Linear(visual_embedding_dim, channels[-1]*np.prod(img_dim_out))
        print(f'\nDecoder linear layer: {self.decoder_lin}\n')

        self.unflatten = torch.nn.Unflatten(
            dim=1,
            unflattened_size=channels[-1:]+img_dim_out
        )
            
        decoder_conv_layers = []
        for i in range(len(kernel_sizes)-1, -1, -1):
            decoder_conv_layers.append(torch.nn.ConvTranspose2d(
                in_channels = channels[i],
                out_channels = (inc if i == 0 else channels[i-1]),
                kernel_size = kernel_sizes[i],
                stride = kernel_strides[i],
            ))
            decoder_conv_layers.append(
                torch.nn.Sigmoid() if i == 0 else torch.nn.ReLU()
            )

        self.decoder_conv = torch.nn.Sequential(*decoder_conv_layers)
        print(f'Decoder convolution layer: {self.decoder_conv}')

    def forward(self, embeddings):
        img_conv = self.decoder_lin(embeddings)

        img_conv = self.unflatten(img_conv)
        
        img_reconstructed = self.decoder_conv(img_conv)
        
        return img_reconstructed
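
A transposed convolution without padding produces an output of size `(in - 1) * stride + kernel`. Mirroring the encoder's kernels and strides therefore recovers the original image size only when `(in - kernel)` is divisible by the stride at every encoder layer. The sketch below illustrates this with hypothetical helper functions and example values.

```python
import torch

def conv_out(n, k, s):
    """Size after an unpadded convolution."""
    return (n - k) // s + 1

def conv_transpose_out(n, k, s):
    """Size after an unpadded transposed convolution."""
    return (n - 1) * s + k

def round_trip(n, kernels, strides):
    """Size after encoding with the given layers and decoding with their mirror."""
    for k, s in zip(kernels, strides):
        n = conv_out(n, k, s)
    for k, s in zip(reversed(kernels), reversed(strides)):
        n = conv_transpose_out(n, k, s)
    return n

# Example configuration (assumed): kernels 2 then 3, strides 1 then 2
height_back = round_trip(64, [2, 3], [1, 2])
width_back = round_trip(128, [2, 3], [1, 2])

# A single stride-2 layer on an even-sized input loses one pixel
lossy_back = round_trip(64, [3], [2])

# Check the transposed-convolution formula against PyTorch
up = torch.nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2)
with torch.no_grad():
    y = up(torch.rand(1, 32, 31, 63))

print(height_back, width_back, lossy_back, tuple(y.shape))
```

If a configuration does not round-trip exactly, the reconstruction and the input will have different shapes and the loss computation will fail, so it is worth checking this before training.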
    

Now that we have defined our encoder and decoder, let's initialise them and see an example of the reconstruction.

We will use two pre-defined functions, ```preprocess_frame()``` and ```preprocess_frame_batch()```, from ```utils.py```. They are reproduced below for reference.

The ```preprocess_frame``` function accepts the path of a frame, loads its pixel values and normalises them, so that we have a ```(channels, height, width)``` array of numbers between 0 and 1.

Optionally, the frame can be converted to grayscale, so that the channel dimension is 1. We can also decide to add Gaussian noise to the frame.

In [None]:
def preprocess_frame(frame, grayscale, img_noise_std):
    img = Image.open(frame)
    if grayscale: 
        img = img.convert('L')
    img = np.array(img, dtype=np.float32)

    # add noise to the image if img_noise_std > 0
    if img_noise_std > 0:
        img += np.random.normal(0, img_noise_std, size=img.shape)

    # normalise RGB to (0, 1) scale
    img = (img - img.min()) / (img.max() - img.min())
    
    if grayscale:
        img = img[None, ...] # (1, h, w) if grayscale
    else:
        img = np.moveaxis(img, -1, 0) # reshape to (3, h, w) if RGB

    return img

def preprocess_frame_batch(all_frames, batch_indices, grayscale, img_noise_std):
    # preprocess a batch of frames
    imgs = np.array([
        preprocess_frame(all_frames[idx], grayscale, img_noise_std)
        for idx in batch_indices
    ])
    return imgs
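
One edge case worth knowing about: min-max normalisation divides by `max - min`, which is zero for a perfectly uniform frame and then yields `NaN` values. The sketch below shows one possible guard. The helper `minmax_normalise` and its `eps` threshold are hypothetical and are not part of `utils.py`.

```python
import numpy as np

def minmax_normalise(img, eps=1e-8):
    """Scale an image to [0, 1]; return zeros if the image is (almost) uniform."""
    img = np.asarray(img, dtype=np.float32)
    span = img.max() - img.min()
    if span < eps:
        # uniform frame: avoid dividing by zero
        return np.zeros_like(img)
    return (img - img.min()) / span

frame = np.array([[0, 128], [255, 64]])
normalised = minmax_normalise(frame)

uniform = np.full((2, 2), 200)
normalised_uniform = minmax_normalise(uniform)

print(normalised.min(), normalised.max())
print(normalised_uniform)
```

Note that this normalisation is applied per frame, so two frames with different overall brightness are both stretched to the full 0 to 1 range.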

In [None]:
EMB_DIM = 50 # the number of neurons in the latent space (or number of latent features)
IMG_DIM = (64, 128) # (height, width) of the input images
GS = True # whether to use grayscale images
IMG_NOISE_STD = 0 # standard deviation of the noise to be added to the images

KERNEL_SIZES = [(2, 2), (3, 3)] # kernel sizes for the convolutional layers
KERNEL_STRIDES = [1, 2] # strides
CHANNELS = [16, 32] # number of channels

encoder = VisualEncoder(
    visual_embedding_dim = EMB_DIM,
    img_dim = IMG_DIM,
    grayscale = GS,
    kernel_sizes = KERNEL_SIZES,
    kernel_strides = KERNEL_STRIDES,
    channels = CHANNELS
).to(device)

decoder = VisualDecoder(
    visual_embedding_dim = EMB_DIM,
    img_dim_out = encoder.img_dim_out,
    grayscale = GS,
    kernel_sizes = KERNEL_SIZES,
    kernel_strides = KERNEL_STRIDES,
    channels = CHANNELS
).to(device)


In [None]:
# visualise an example
with torch.no_grad():
    frame_example_img = preprocess_frame(frame_example, grayscale=GS, img_noise_std=IMG_NOISE_STD)[None, ...]
    frame_example_recon = decoder(encoder(
        torch.from_numpy(frame_example_img).to(device)
    ))
    
frame_example_recon = frame_example_recon.cpu().numpy()

fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].set_title('Original Frame')
axs[0].imshow(frame_example_img[0, 0], cmap='gray')
axs[1].set_title('Reconstructed Frame')
axs[1].imshow(frame_example_recon[0, 0], cmap='gray')
plt.show()

The reconstruction looks like pure noise because the autoencoder is still untrained. The good news is that the dimensions are correct!

Now, let's train the pair of networks.

### **3. Training an autoencoder**

The process of training a neural network is well established. In general, we follow these steps:

* **Collect and pre-process data**: in the case of image frames, we need to consider normalising the RGB values.

* **Split the dataset** into train, test (and optionally validation) sets

* **Initialise model weights**, which we have just done!

* **Train the model**. This includes a forward pass, computing the loss, backpropagation and updating the weights.

* Where needed, **validate** on the validation set to tune hyperparameters.

* **Test** the trained model on unseen data to evaluate performance.

* Once the model has reached satisfactory performance, it is ready for **deployment**.

In [None]:
globs = [(tp/'box_messy').glob('*.png') for tp in trial_paths]
all_frames = [f for g in globs for f in g][:20_480]

In [None]:
# Split dataset into train and test sets
BATCH_SIZE = 2048
TEST_SET_PROP = 0.1 # 10%

train_indices, test_indices = train_test_split(
    np.arange(len(all_frames)), test_size=TEST_SET_PROP, shuffle=True, random_state=SEED
)

In [None]:
train_batches = np.split(
    train_indices[:(len(train_indices)//BATCH_SIZE*BATCH_SIZE)],
    len(train_indices)//BATCH_SIZE
)
test_batches = np.split(
    test_indices[:(len(test_indices)//BATCH_SIZE*BATCH_SIZE)],
    len(test_indices)//BATCH_SIZE
)
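
The batching above drops any leftover frames that do not fill a complete batch, and `np.split` raises an error if a set contains fewer frames than `BATCH_SIZE` (zero sections). Below is a minimal alternative sketch. The helper `make_batches` and its `drop_last` flag are hypothetical names, and the set sizes assume 20,480 frames with a 10% test split.

```python
import numpy as np

def make_batches(indices, batch_size, drop_last=True):
    """Split indices into consecutive batches, optionally keeping a smaller final batch."""
    indices = np.asarray(indices)
    n_full = len(indices) // batch_size
    batches = [
        indices[i * batch_size:(i + 1) * batch_size]
        for i in range(n_full)
    ]
    if not drop_last and len(indices) % batch_size:
        batches.append(indices[n_full * batch_size:])
    return batches

# Assumed sizes: 18,432 training frames and 2,048 test frames
train_demo = make_batches(np.arange(18_432), batch_size=2048)
test_demo = make_batches(np.arange(2_048), batch_size=2048)

# A small dataset still yields one (partial) batch when drop_last=False
small_demo = make_batches(np.arange(500), batch_size=2048, drop_last=False)

print(len(train_demo), len(test_demo), len(small_demo), len(small_demo[0]))
```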

In [None]:
# Define the train and test functions

def train_epoch(
    enc, dec,
    all_frames, train_batches,
    loss_fn, optimizer, scheduler=None
):
    enc.train()
    dec.train()
    epoch_loss = 0
    
    for batch_indices in train_batches:
        optimizer.zero_grad()

        imgs = preprocess_frame_batch(
            all_frames, batch_indices, grayscale=GS, img_noise_std=IMG_NOISE_STD
        )
        imgs = torch.from_numpy(imgs).to(device)
        
        # YOUR CODE HERE: forward pass
        imgs_recon = dec(enc(imgs))
        loss = loss_fn(imgs_recon, imgs)  # PyTorch convention: (prediction, target)
        
        # YOUR CODE HERE: backward pass and optimisation step
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        
        epoch_loss += loss.detach().item()
            
    return epoch_loss / len(train_batches)

def test_epoch(
    enc, dec,
    all_frames, test_batches,
    loss_fn,
):
    enc.eval()
    dec.eval()
    epoch_loss = 0
    
    with torch.no_grad():
        for batch_indices in test_batches:
            imgs = preprocess_frame_batch(
                all_frames, batch_indices, grayscale=GS, img_noise_std=IMG_NOISE_STD
            )
            imgs = torch.from_numpy(imgs).to(device)
            
            # YOUR CODE HERE: forward pass
            imgs_recon = dec(enc(imgs))
            loss = loss_fn(imgs_recon, imgs)  # PyTorch convention: (prediction, target)
            
            epoch_loss += loss.item()
    
    return epoch_loss / len(test_batches)

In [None]:
# Putting it all together

n_epochs = 10
loss_fn = nn.L1Loss()
learning_rate = 1e-3

optimizer = torch.optim.Adam(
    itertools.chain(encoder.parameters(), decoder.parameters()),
    lr=learning_rate
)

# optional: use a learning rate scheduler
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)
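
`ExponentialLR` multiplies the learning rate by `gamma` each time `scheduler.step()` is called. Because `train_epoch` steps the scheduler after every batch, the decay here is per batch rather than per epoch. The sketch below checks the closed form `lr_n = lr_0 * gamma**n` on a dummy parameter, assuming 9 batches per epoch for 10 epochs.

```python
import torch

lr0, gamma = 1e-3, 0.999
steps = 9 * 10  # assumed: 9 batches per epoch x 10 epochs

# A dummy parameter, only there so that the optimizer has something to hold
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=lr0)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=gamma)

for _ in range(steps):
    opt.step()    # optimizer first, then scheduler (PyTorch's expected order)
    sched.step()

final_lr = sched.get_last_lr()[0]
expected_lr = lr0 * gamma ** steps

print(f"final lr:    {final_lr:.6e}")
print(f"closed form: {expected_lr:.6e}")
```

With these values the learning rate ends roughly 9% below its starting value, so the schedule is very gentle. A smaller `gamma` or more steps would decay it faster.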


In [None]:
for epoch in range(n_epochs):
    train_loss = train_epoch(
        encoder, decoder,
        all_frames, train_batches,
        loss_fn, optimizer, scheduler
    )
    test_loss = test_epoch(
        encoder, decoder,
        all_frames, test_batches,
        loss_fn
    )
    
    with torch.no_grad():
        frame_example_recon = decoder(encoder(
            torch.from_numpy(frame_example_img).to(device)
        ))
    frame_example_recon = frame_example_recon.cpu().numpy()
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].set_title('Original Frame')
    axs[0].imshow(frame_example_img[0, 0], cmap='gray')
    axs[1].set_title('Reconstructed Frame')
    axs[1].imshow(frame_example_recon[0, 0], cmap='gray')
    plt.show()

    print(f'Epoch {epoch + 1}/{n_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')


### **What have we achieved in this tutorial?**

We trained an autoencoder (an encoder-decoder pair) that can compress images into a latent vector.

If the latent vector is subject to a biologically motivated constraint (such as non-negativity, or a sigmoid that bounds activity), we can interpret the latent vector for each frame as the population activity of visual neurons encoding that visual scene, much like in the visual cortex.

These encodings of the visual scene can now be readily fed into a **Recurrent Neural Network** to generate spatial cell tunings.
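
As a rough sketch of that hand-over, the snippet below encodes a sequence of frames in chunks and stacks the results into a `(time, embedding)` array. The random `frames` tensor and the small `stand_in_encoder` are placeholders assumed for illustration; in practice you would use the preprocessed frames and the trained encoder from this notebook.

```python
import torch

EMB_DIM = 50
CHUNK = 128  # number of frames encoded at once (assumed value)

# Placeholder for preprocessed frames of shape (time, channels, height, width)
frames = torch.rand(300, 1, 64, 128)

# Hypothetical stand-in for the trained encoder: maps each frame to EMB_DIM values in (0, 1)
stand_in_encoder = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(64 * 128, EMB_DIM),
    torch.nn.Sigmoid(),
)
stand_in_encoder.eval()

# Encode chunk by chunk without tracking gradients, preserving frame order
chunks = []
with torch.no_grad():
    for start in range(0, len(frames), CHUNK):
        chunks.append(stand_in_encoder(frames[start:start + CHUNK]))

embeddings = torch.cat(chunks)  # one embedding vector per time step
print(embeddings.shape)
```

For a recurrent model the frames must be in temporal order. `Path.glob` does not guarantee any ordering, so sorting the frame paths within each trial before encoding is advisable.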