---
# Modelling hippocampal neurons of animals navigating in VR with recurrent neural networks
### Marco P. Abrate, Daniel Liu
---

##### Outline
Rat simulation:
- Motion model (RatInABox)
- Environment design (Blender)
- Simulated rat vision (ratvision)

Vision autoencoder

Hippocampus model (RNN):
- RNN definition
- Data loading
- Training

Hidden state representations analysis:
- Rate maps
- Polar maps
- Quantitative metrics
- Comparison with real data


### *Part 2: Training a Vision Autoencoder*
---
In this notebook, we will write code to train an **autoencoder**. An autoencoder is a pair of artificial neural networks, an encoder and a decoder, that together compress information into a low-dimensional embedding. The pair learns this compression by being trained to reconstruct the original input from the embedding.
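Before moving to PyTorch, here is a minimal sketch of the same principle in plain NumPy. It is only an illustration, not the model used in this notebook: the toy data, the linear encoder and decoder matrices `W_e` and `W_d`, the mean-squared-error loss and the learning rate are all assumptions chosen for the example. The data are 20-dimensional but are generated from only 3 latent variables, so a 3-unit embedding is enough to reconstruct them.

```python
import numpy as np

# Illustrative linear autoencoder trained with plain gradient descent (not the notebook's model).
rng = np.random.default_rng(0)

# Toy data: 200 samples of 20-dim vectors that lie on a 3-dim subspace.
n, d, k = 200, 20, 3
latent_true = rng.normal(size=(n, k))
mixing = rng.normal(size=(k, d))
X = latent_true @ mixing  # shape (n, d)

# Encoder W_e compresses d -> k numbers; decoder W_d expands k -> d numbers.
W_e = rng.normal(scale=0.1, size=(d, k))
W_d = rng.normal(scale=0.1, size=(k, d))

def reconstruction_loss(X, W_e, W_d):
    Z = X @ W_e        # encode: the low-dimensional embedding
    X_hat = Z @ W_d    # decode: reconstruction of the input
    return np.mean((X_hat - X) ** 2)

lr = 0.05
initial_loss = reconstruction_loss(X, W_e, W_d)
for step in range(3000):
    Z = X @ W_e
    X_hat = Z @ W_d
    err = (X_hat - X) * (2.0 / X.size)  # gradient of the mean squared error w.r.t. X_hat
    grad_W_d = Z.T @ err                # chain rule through X_hat = Z @ W_d
    grad_W_e = X.T @ (err @ W_d.T)      # chain rule through Z = X @ W_e
    W_d -= lr * grad_W_d
    W_e -= lr * grad_W_e
final_loss = reconstruction_loss(X, W_e, W_d)
print(f"MSE before training: {initial_loss:.4f}, after training: {final_loss:.6f}")
```

The network is never told what the 3 latent variables are; it discovers a 3-number summary of each sample purely because that summary must be sufficient to rebuild the input.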

Researchers use autoencoders to model how neurons might represent visual stimuli in the brain. The visual cortex receives complex images but extracts key features, such as edges, motion and shapes, into more compact representations. This non-linear dimensionality reduction is loosely analogous to what an autoencoder learns.

Prerequisites:

* Basic Python syntax

* Basic understanding of deep learning

* We will use the **PyTorch** package for this tutorial. If you have used another deep learning framework, you will be fine: simply refer to the PyTorch documentation.

Before starting this notebook, make sure you have:

* A video recording from the previous part.

* The accompanying `utils.py` helper functions.

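* If you run this locally, you will need a PyTorch build with CUDA support and an NVIDIA GPU, or an Apple M-series chip (used through PyTorch's MPS backend). The code falls back to the CPU otherwise, but training will be much slower.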

### **0 Import and install dependencies**

In [None]:
# imports
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from utils import *
from PIL import Image
import pathlib
import itertools
from sklearn.model_selection import train_test_split

# install any packages used in the utils.py function here

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
    
SEED = 42

### **1 Visualize example frame**

In [None]:
# folder of PNG frames recorded in the previous part (replace with your own path)
frames_path = pathlib.Path('/Users/marco/Downloads/vrtopc/box/run/exp_dim0.635_fps10_s720_seed01/box_messy')

In [None]:
frame_example = frames_path / 'frame0001.png'
plt.imshow(np.array(Image.open(frame_example)))

### **2 Writing an autoencoder**

We begin by coding up our autoencoder.

As we mentioned before, an autoencoder is a type of neural network that learns to compress data into a smaller representation (encoding) and then reconstruct it back (decoding).

There is no restriction on the structure of the encoder and the decoder, and they need not be symmetric.

However, since we are processing image frames, **convolutional layers** will be helpful because they can:
* capture spatial features like edges, textures, shapes, etc.

* preserve local patterns and share weights across the image

* detect features efficiently regardless of position
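The last two points can be checked with a tiny NumPy sketch, written for illustration rather than taken from PyTorch: one 3x3 kernel (a hand-picked vertical-edge detector, an assumption of this example) is reused at every position of a toy image, and shifting the input shifts the response map by the same amount. Like `nn.Conv2d`, the helper computes a cross-correlation (the kernel is not flipped).

```python
import numpy as np

def conv2d_valid(image, kernel):
    """Slide one kernel over every position of a 2-D image (no padding, stride 1).

    The same weights are reused at every location: this is weight sharing.
    """
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# Vertical-edge detector: responds where intensity changes between left and right.
edge_kernel = np.array([[-1.0, 0.0, 1.0],
                        [-1.0, 0.0, 1.0],
                        [-1.0, 0.0, 1.0]])

# Toy 8x8 image with a bright vertical stripe at column 3.
image = np.zeros((8, 8))
image[:, 3] = 1.0
response = conv2d_valid(image, edge_kernel)

# Move the stripe two columns to the right: the response map moves by the same amount.
shifted = np.roll(image, 2, axis=1)
response_shifted = conv2d_valid(shifted, edge_kernel)
print(np.allclose(response[:, :-2], response_shifted[:, 2:]))  # True
```

Because the detector is the same wherever it is applied, a feature learned in one part of the frame is automatically recognised in every other part, with only 9 weights instead of one weight per pixel.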

First, let's define some parameters.

In [None]:
visual_embedding_dim = 50 # the number of neurons in the latent space (or number of latent features)
img_dim = (128, 64) # (width, height) of the input images
grayscale = True

In [None]:
class VisualEncoder(nn.Module):
    def __init__(self, visual_embedding_dim, img_dim, grayscale):
        super(VisualEncoder, self).__init__()

        self.c_channel = 1 if grayscale else 3

        self.encoder_layers = nn.Sequential(
            # WRITE YOUR CODE HERE (the layers below are one possible solution)
            nn.Conv2d(self.c_channel, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # halves width and height
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # halves width and height again
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
        )

        # shape of the convolutional output, e.g. (32, 32, 16) for 128 x 64 inputs
        self.threshold_dims = self.get_threshold_dims(img_dim)

        # now we flatten and project linearly to the target embedding dimension
        self.linear_projection = nn.Sequential(
            nn.Flatten(),
            nn.Linear(int(np.prod(self.threshold_dims)), visual_embedding_dim),
            nn.Sigmoid()  # bounds each latent unit to (0, 1)
        )

    def get_threshold_dims(self, img_dim):
        # helper: pass a dummy image through the CNN to obtain its output dimension
        w, l = img_dim
        with torch.no_grad():
            return self.encoder_layers.cpu()(torch.rand(1, self.c_channel, w, l)).shape[1:]

    def forward(self, img):
        # img: (batch, channels, width, height)
        img_conv = self.encoder_layers(img)
        emb = self.linear_projection(img_conv)
        return emb
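`get_threshold_dims()` finds the size of the convolutional output empirically by pushing a dummy image through the layers. The same number can be worked out by hand with the output-size formula from the PyTorch documentation for `nn.Conv2d` and `nn.MaxPool2d` (with the default `dilation=1` and `ceil_mode=False`). The sketch below assumes the reference encoder layers (two Conv2d/ReLU/MaxPool2d stages followed by a third Conv2d with 32 output channels) and 128 x 64 inputs; `conv_out` and `encoder_output_shape` are illustrative helpers, not part of `utils.py`.

```python
def conv_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis for Conv2d / MaxPool2d (ceil_mode=False)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def encoder_output_shape(width, height, channels_out=32):
    # Mirrors the reference encoder: conv -> pool -> conv -> pool -> conv
    dims = []
    for size in (width, height):
        size = conv_out(size, 3, stride=1, padding=1)  # 3x3 conv with padding 1 keeps the size
        size = conv_out(size, 2, stride=2)             # 2x2 max-pool halves it
        size = conv_out(size, 3, stride=1, padding=1)
        size = conv_out(size, 2, stride=2)
        size = conv_out(size, 3, stride=1, padding=1)
        dims.append(size)
    return (channels_out, *dims)

shape = encoder_output_shape(128, 64)
print(shape)                           # (32, 32, 16)
print(shape[0] * shape[1] * shape[2])  # 16384 inputs to the nn.Linear projection
```

This is why the linear projection is large: about 16k inputs mapped to 50 latent units. Adding another pooling stage would shrink it by a factor of four.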
    

In [None]:
# define a decoder
class VisualDecoder(nn.Module):
    def __init__(self, visual_embedding_dim, threshold_dims, img_dim, grayscale):
        super(VisualDecoder, self).__init__()

        self.c_channel = 1 if grayscale else 3
        self.threshold_dims = tuple(threshold_dims)
        self.img_dim = img_dim

        self.linear_projection = nn.Sequential(
            nn.Linear(visual_embedding_dim, int(np.prod(threshold_dims))),
            nn.ReLU(),
        )

        self.decoder_layers = nn.Sequential(
            # WRITE YOUR CODE HERE (the layers below are one possible solution)
            # forward() already reshapes the flat vector to threshold_dims, so no nn.Unflatten is needed;
            # the two stride-2 layers undo the encoder's two max-pools
            nn.ConvTranspose2d(threshold_dims[0], 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, self.c_channel, kernel_size=3, stride=1, padding=1),

            nn.Sigmoid() # projects to (0, 1) scale
        )

    def forward(self, emb):
        emb = self.linear_projection(emb)
        # un-flatten to (batch, channels, width, height) of the encoder's convolutional output
        emb = emb.reshape(emb.shape[0], *self.threshold_dims)

        img_dec = self.decoder_layers(emb)
        return img_dec
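The decoder has to grow the 32 x 16 feature map back to the full 128 x 64 frame. The PyTorch documentation for `nn.ConvTranspose2d` gives its output size as `(in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1`; the sketch below applies it, assuming the reference decoder layers (two stride-2 transposed convolutions with `output_padding=1`, then a stride-1 one). The helper names are illustrative only.

```python
def conv_transpose_out(size, kernel_size, stride=1, padding=0, output_padding=0, dilation=1):
    """Output length along one axis for nn.ConvTranspose2d."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

def decoder_output_shape(threshold_dims, channels_out=1):
    # Mirrors the reference decoder: two upsampling layers, then one size-preserving layer
    _, w, h = threshold_dims
    dims = []
    for size in (w, h):
        size = conv_transpose_out(size, 3, stride=2, padding=1, output_padding=1)  # doubles
        size = conv_transpose_out(size, 3, stride=2, padding=1, output_padding=1)  # doubles
        size = conv_transpose_out(size, 3, stride=1, padding=1)                    # keeps
        dims.append(size)
    return (channels_out, *dims)

print(decoder_output_shape((32, 32, 16)))  # (1, 128, 64)
```

Without `output_padding=1`, each stride-2 layer would produce `2 * size - 1` instead of `2 * size`, and the reconstruction would no longer match the input frame.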

Now that we have defined our encoder and decoder, let's initialise them and see an example of the reconstruction.

In [None]:
# Let's implement a function to preprocess each frame
def preprocess_frame(frame, grayscale):
    img = Image.open(frame)
    if grayscale:
        img = img.convert('L')
    else:
        img = img.convert('RGB')  # drop any alpha channel
    img = np.array(img, dtype=np.float32)  # (height, width) or (height, width, 3)
    if grayscale:
        img = np.expand_dims(img, axis=-1)

    # YOUR CODE HERE: normalise pixel values from 0..255 to the (0, 1) scale
    img = img / 255.0

    img = np.swapaxes(img, 0, 2) # convert to (3, w, l) if RGB, (1, w, l) if grayscale
    return img

def preprocess_frame_batch(frames_glob, batch_idxs, grayscale, img_noise=0.):
    # preprocess a batch of frames, with indices given by batch_idxs,
    # optionally adding Gaussian noise with standard deviation img_noise
    imgs = np.array([preprocess_frame(frames_glob[idx], grayscale=grayscale) for idx in batch_idxs])
    return np.clip(imgs + np.random.normal(0, img_noise, size=imgs.shape), a_min=0, a_max=1)

The same two functions, ```preprocess_frame()``` and ```preprocess_frame_batch()```, are also pre-defined in ```utils.py```, and the next cell imports them from there (replacing the definitions above).

```preprocess_frame()``` accepts the path of a frame, loads its pixel values and normalises them to lie between 0 and 1, returning a ```(channels, width, height)``` array. Optionally, the frame can be converted to grayscale, so that the channel dimension is 1.

```preprocess_frame_batch()``` applies ```preprocess_frame()``` to a batch of frames. It can also add Gaussian noise to the frames, clipping the result back to [0, 1].
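The array manipulations inside these helpers can be tried without any PNG files. In this sketch a random `uint8` array stands in for `np.array(Image.open(frame))` (PIL returns RGB images as `(height, width, 3)`), and `to_network_layout` is a hypothetical helper that mirrors the normalise, swap-axes and noise-plus-clip steps.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for an RGB frame loaded with PIL: (height, width, channels) uint8 values.
frame_rgb = rng.integers(0, 256, size=(64, 128, 3), dtype=np.uint8)

def to_network_layout(frame, img_noise=0.0, rng=rng):
    img = frame.astype(np.float32) / 255.0   # normalise 0..255 to 0..1
    img = np.swapaxes(img, 0, 2)             # (H, W, C) -> (C, W, H)
    noisy = img + rng.normal(0.0, img_noise, size=img.shape)
    return np.clip(noisy, 0.0, 1.0)          # keep values in the valid pixel range

clean = to_network_layout(frame_rgb)
noisy = to_network_layout(frame_rgb, img_noise=0.05)
print(clean.shape)                               # (3, 128, 64)
print(clean.min() >= 0.0, noisy.max() <= 1.0)    # True True
```

The `(channels, width, height)` layout matches the dummy tensor `torch.rand(1, c_channel, w, l)` used in `get_threshold_dims()` with `img_dim = (128, 64)`, which is why the swap is needed before the frames reach the encoder.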

In [None]:
from utils import preprocess_frame, preprocess_frame_batch

encoder = VisualEncoder(
    visual_embedding_dim = visual_embedding_dim, # the number of neurons in the latent space (or number of latent features)
    img_dim = img_dim, # (width, height) of the input images
    grayscale = grayscale
).to(device)

decoder = VisualDecoder(
    visual_embedding_dim = visual_embedding_dim,
    threshold_dims = encoder.threshold_dims,
    img_dim = img_dim,
    grayscale = grayscale
).to(device)

# visualise an example
with torch.no_grad():
    frame_example_img = preprocess_frame(frame_example, grayscale = grayscale)
    # add a batch dimension: (1, channels, width, height)
    frame_example_tensor = torch.as_tensor(frame_example_img, dtype=torch.float32).unsqueeze(0).to(device)
    frame_example_recon = decoder(encoder(frame_example_tensor))

# drop the batch dimension and swap back to (height, width[, channels]) for plotting
plt.imshow(np.swapaxes(frame_example_recon[0].cpu().numpy(), 0, 2).squeeze(), cmap='gray')

This looks like noise because the autoencoder is still untrained. The good news is that the reconstruction has the same shape as the input, so at least we got the dimensions right!

Now, let's train the pair of networks.

### **3 Training an autoencoder**

The process of training a neural network is well established. In general, we follow these steps:

* **Collect and pre-process data**: in the case of image frames, we need to consider normalising the RGB values.

* **Split the dataset** into train, test (and optionally validation) sets

* **Initialise model weights**, which we have just done!

* **Train the model**. Each training step consists of a forward pass, computing the loss, backpropagation and a weight update.

* Where needed, **validate** on the validation set to tune hyperparameters.

* **Test** the trained model on unseen data to evaluate performance.

* Once the model has reached satisfactory performance, it is ready for **deployment**.
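The split step can also produce the optional validation set. The next cell uses scikit-learn's `train_test_split` for a two-way split; the sketch below is an alternative in plain NumPy that shuffles the frame indices once and carves off test and validation sets. The function name `split_indices`, the 5% validation fraction and the 7,200-frame recording length are hypothetical values chosen for illustration.

```python
import numpy as np

def split_indices(n_frames, test_frac=0.05, val_frac=0.05, seed=42):
    """Shuffle frame indices once, then carve off test and validation sets."""
    rng = np.random.default_rng(seed)
    idxs = rng.permutation(n_frames)
    n_test = int(round(n_frames * test_frac))
    n_val = int(round(n_frames * val_frac))
    test_idxs = idxs[:n_test]
    val_idxs = idxs[n_test:n_test + n_val]
    train_idxs = idxs[n_test + n_val:]
    return train_idxs, val_idxs, test_idxs

train_idxs, val_idxs, test_idxs = split_indices(7200)
print(len(train_idxs), len(val_idxs), len(test_idxs))  # 6480 360 360
```

The three sets are disjoint, so the test loss measures performance on frames the model never saw, while the validation set can be used to tune hyperparameters such as the learning rate or `visual_embedding_dim`.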

In [None]:
# First, split dataset
batch_size = 2048
test_set_proportion = 0.05 # 5%

# create a sorted list of all frame paths (sorting keeps the frames in time order)
all_frames_glob = sorted(frames_path.glob('*.png'))
train_idxs, test_idxs = train_test_split(
    np.arange(len(all_frames_glob)), test_size=test_set_proportion, shuffle=True, random_state=SEED
)

In [None]:
# Next, we can define the train and test functions
def train_epoch(
    enc, dec,
    train_idxs, batch_size, all_frames_glob,
    loss_fn, optim,
    grayscale, scheduler=None, img_noise=0.0
):
    enc.train()
    dec.train()
    epoch_loss = 0.0
    n_train_batches = 0

    # iterate over all batches, including a final partial batch
    for start in range(0, len(train_idxs), batch_size):
        batch_idxs = train_idxs[start : start + batch_size]
        imgs = preprocess_frame_batch(all_frames_glob, batch_idxs, grayscale=grayscale, img_noise=img_noise)
        imgs = torch.as_tensor(imgs, dtype=torch.float32).to(device)

        # YOUR CODE HERE: forward pass
        imgs_recon = dec(enc(imgs))
        loss = loss_fn(imgs_recon, imgs)

        # YOUR CODE HERE: backward pass and optimisation step
        optim.zero_grad()
        loss.backward()
        optim.step()

        epoch_loss += loss.item()
        n_train_batches += 1
        if scheduler is not None:
            scheduler.step()  # ExponentialLR takes no loss argument

    return epoch_loss / n_train_batches

def test_epoch(
    enc, dec,
    test_idxs, batch_size, all_frames_glob,
    loss_fn,
    grayscale
):
    enc.eval()
    dec.eval()
    epoch_loss = 0.0
    n_test_batches = 0

    with torch.no_grad():
        # iterate over all batches, including a final partial batch
        for start in range(0, len(test_idxs), batch_size):
            batch_idxs = test_idxs[start : start + batch_size]
            imgs = preprocess_frame_batch(all_frames_glob, batch_idxs, grayscale=grayscale, img_noise=0.0)
            imgs = torch.as_tensor(imgs, dtype=torch.float32).to(device)

            # YOUR CODE HERE: forward pass
            imgs_recon = dec(enc(imgs))
            loss = loss_fn(imgs_recon, imgs)

            epoch_loss += loss.item()
            n_test_batches += 1

    return epoch_loss / n_test_batches
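Computing the number of batches as `len(idxs) // batch_size` silently drops the final partial batch, and if a split is smaller than `batch_size` it yields zero batches, so averaging the loss divides by zero. With `batch_size = 2048` and a 5% test split this happens for any recording shorter than about 41,000 frames. The plain-Python helper below is an illustrative sketch (the name `batch_slices` and the 7,200-frame example are hypothetical) of batching that covers every item.

```python
def batch_slices(n_items, batch_size):
    """Yield (start, stop) pairs covering all items; the last batch may be smaller."""
    for start in range(0, n_items, batch_size):
        yield start, min(start + batch_size, n_items)

n_test = 360       # e.g. 5% of a hypothetical 7,200-frame recording
batch_size = 2048
print(n_test // batch_size)                      # 0 full batches
print(list(batch_slices(n_test, batch_size)))    # [(0, 360)]
print(list(batch_slices(5000, 2048)))            # [(0, 2048), (2048, 4096), (4096, 5000)]
```

Each `(start, stop)` pair can be used directly as `idxs[start:stop]`; the same result comes from `range(0, n_items, batch_size)` with slicing, since NumPy slices past the end are simply truncated.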

In [None]:
# Now, putting it all together: the training loop

n_epochs = 10
loss_fn = nn.L1Loss()
learning_rate = 1e-3

optim = torch.optim.Adam(itertools.chain(encoder.parameters(), decoder.parameters()), lr=learning_rate)

# optional: use a learning rate scheduler
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.999)
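Both choices in this cell reduce to simple arithmetic. With its default `reduction='mean'`, `nn.L1Loss()` averages the absolute pixel error over every element, and `ExponentialLR` multiplies the learning rate by `gamma` on every `scheduler.step()` call. Since `train_epoch` steps the scheduler once per batch, the number of steps counts batches, not epochs. The NumPy re-implementation below is an illustrative sketch using made-up pixel values.

```python
import numpy as np

# Mean absolute error over every element, as nn.L1Loss() computes with reduction='mean'.
def l1_loss(pred, target):
    return np.mean(np.abs(pred - target))

pred = np.array([[0.25, 0.5], [0.75, 0.0]])
target = np.array([[0.0, 0.5], [1.0, 0.5]])
print(l1_loss(pred, target))  # 0.25, i.e. (0.25 + 0 + 0.25 + 0.5) / 4

# Learning rate after n_steps calls to ExponentialLR.step().
def exponential_lr(initial_lr, gamma, n_steps):
    return initial_lr * gamma ** n_steps

print(exponential_lr(1e-3, 0.999, 1000))  # about 3.68e-4
```

With `gamma = 0.999` the decay is gentle: the learning rate only falls to roughly a third of its initial value after 1,000 batches.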


In [None]:
for epoch in range(n_epochs):
    train_loss = train_epoch(
        encoder, decoder, train_idxs, batch_size, all_frames_glob,
        loss_fn, optim, grayscale=True, scheduler=scheduler
    )
    test_loss = test_epoch(
        encoder, decoder, test_idxs, batch_size, all_frames_glob,
        loss_fn, grayscale=True
    )
    
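    # visualise the reconstruction of the example frame after each epoch
    with torch.no_grad():
        frame_example_tensor = torch.as_tensor(frame_example_img, dtype=torch.float32).unsqueeze(0).to(device)
        frame_example_recon = decoder(encoder(frame_example_tensor))

    plt.imshow(np.swapaxes(frame_example_recon[0].cpu().numpy(), 0, 2).squeeze(), cmap='gray')
    plt.title(f'Epoch {epoch + 1}')
    plt.show()
    print(f'Epoch {epoch + 1}/{n_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')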


### **What have we achieved in this tutorial?**

We trained an autoencoder, a pair of encoder and decoder networks, that compresses each image frame into a latent vector.

Because the latent units pass through a sigmoid, their activity is non-negative and bounded, much like neuronal firing rates. With this biological constraint, we can interpret the latent vector for each frame as the population activity of visual neurons encoding that visual scene, loosely analogous to the visual cortex.

These encodings of the visual scene can now be fed, frame by frame, into a **Recurrent Neural Network** to model spatially tuned hippocampal cells.
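A minimal sketch of what that hand-off might look like, assuming the embeddings of all frames are stacked in time order into an `(n_frames, visual_embedding_dim)` array. Random pre-activations stand in for the trained encoder's outputs here, and the 600-frame length is hypothetical. The `(batch, time, features)` layout is the one `torch.nn.RNN` expects when constructed with `batch_first=True`.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
n_frames, emb_dim = 600, 50  # hypothetical: 60 s at 10 fps; emb_dim = visual_embedding_dim

# Stand-in for the encoder's pre-sigmoid outputs, one row per video frame in time order.
pre_activations = rng.normal(size=(n_frames, emb_dim))
latents = sigmoid(pre_activations)  # "firing rates" bounded to (0, 1)

# One trajectory as a single sequence: (batch, time, features).
rnn_input = latents[np.newaxis, :, :]
print(rnn_input.shape)                           # (1, 600, 50)
print(latents.min() > 0.0, latents.max() < 1.0)  # True True
```

Keeping the frames in recording order matters here: unlike the autoencoder, which treats frames independently, an RNN integrates information across consecutive time steps.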