---
# Modelling hippocampal neurons of animals navigating in VR with recurrent neural networks
### Marco P. Abrate, Daniel Liu
---

##### Outline
Rat simulation:
- Motion model (RatInABox)
- Environment design (Blender)
- Simulated rat vision (ratvision)

Vision autoencoder

Hippocampus model (RNN):
- RNN definition
- Data loading
- Training

Hidden state representations analysis:
- Rate maps
- Polar maps
- Quantitative metrics
- Comparison with real data

---
## **Part 2: Training a Vision Autoencoder**

In this notebook, we will write code to train an **Autoencoder**. An autoencoder is a pair of artificial neural networks that compresses information into a low-dimensional embedding through the first module (aka encoder) and reconstructs it to its original form through the second module (aka decoder).

Neuroscientists use vision autoencoders to model how neurons might represent visual stimuli in the brain. The visual cortex receives complex images and extracts key features, such as edges, motion and shapes, into more compact forms (low-dimensional embeddings). This non-linear dimensionality reduction process, along with the reconstruction of the original image, can be compared to an autoencoder.
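
As a rough, framework-free illustration of the compress-then-reconstruct idea, the sketch below pushes one fake frame through a random, untrained encoder and decoder using NumPy. The sizes (a 64 x 128 frame, a 100-dimensional embedding), the weight names and the `tanh` non-linearity are assumptions chosen for illustration; this is not the architecture trained later in the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for one grayscale frame, flattened to pixel values in [0, 1].
height, width, emb_dim = 64, 128, 100
frame = rng.random(height * width).astype(np.float32)

# Random, untrained weights (hypothetical names):
# the encoder maps pixels -> embedding, the decoder maps embedding -> pixels.
w_enc = rng.normal(scale=0.01, size=(emb_dim, height * width)).astype(np.float32)
w_dec = rng.normal(scale=0.01, size=(height * width, emb_dim)).astype(np.float32)

embedding = np.tanh(w_enc @ frame)   # non-linear compression to 100 numbers
reconstruction = w_dec @ embedding   # attempt to rebuild all 8192 pixels

# Training would adjust w_enc and w_dec to shrink this reconstruction error.
l1_error = np.abs(frame - reconstruction).mean()
print(embedding.shape, reconstruction.shape, l1_error)
```

With untrained weights the error is large; the point is only the shape of the computation: many pixels in, a short embedding in the middle, many pixels out.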

We will use the **PyTorch** package for this tutorial.

Before starting this notebook, make sure you have:
- video recordings from part 1.

### **0. Install and import dependencies**

In [1]:
!pip install torch torchvision torchaudio
!pip install scikit-learn
!pip install matplotlib



In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
import itertools
from sklearn.model_selection import train_test_split

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
    
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)


### **1. Visualize example frame**

In [None]:
d = '/media/data/vrtopc/box/run' # '/Users/marco/Downloads/vrtopc/box/run'

trial_paths = [p for p in Path(d).iterdir() if 'exp' in p.name]

trial_paths


In [None]:
frame_example = trial_paths[0] / 'box_messy' / 'frame0001.png'
plt.imshow(np.array(Image.open(frame_example)), cmap='gray')
plt.show()


In [None]:
IMG_DIM = (64, 128) # (height, width) of the input images
GS = True # whether to use grayscale images


### **2. Load frames**

The ```preprocess_frame``` function accepts the path of a frame, loads it as RGB values and normalises them, so that we have a ```(channels, height, width)``` array of numbers between 0 and 1.

Optionally, the frame can be converted to grayscale, so that the channel dimension is 1.

In [None]:
def preprocess_frame(frame, grayscale):
    img = Image.open(frame)
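    # force a known mode so the channel count is predictable (e.g. drops an alpha channel)
    img = img.convert('L' if grayscale else 'RGB')
    img = np.array(img, dtype=np.float32)

    # normalise 8-bit pixel values to the [0, 1] range
    img = img / 255.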
    
    if grayscale:
        img = img[None, ...] # (1, h, w) if grayscale
    else:
        img = np.moveaxis(img, -1, 0) # reshape to (3, h, w) if RGB

    return img

def preprocess_frame_batch(all_frames, batch_indices, grayscale):
    # preprocess a batch of frames
    imgs = np.array([
        preprocess_frame(all_frames[idx], grayscale)
        for idx in batch_indices
    ])
    return imgs
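
A quick way to sanity-check the array layout without reading any files is to run the same normalisation and axis handling on synthetic pixel data. The sketch below is a stand-alone illustration; `to_channels_first` is a hypothetical helper name, not part of the notebook's code.

```python
import numpy as np

def to_channels_first(pixels, grayscale):
    """Scale uint8 pixels to [0, 1] and return a (channels, height, width) array."""
    img = pixels.astype(np.float32) / 255.0
    if grayscale:
        return img[None, ...]        # (h, w) -> (1, h, w)
    return np.moveaxis(img, -1, 0)   # (h, w, 3) -> (3, h, w)

rng = np.random.default_rng(0)
gray_pixels = rng.integers(0, 256, size=(64, 128), dtype=np.uint8)
rgb_pixels = rng.integers(0, 256, size=(64, 128, 3), dtype=np.uint8)

gray = to_channels_first(gray_pixels, grayscale=True)
rgb = to_channels_first(rgb_pixels, grayscale=False)
print(gray.shape, rgb.shape)
```

Both outputs are channels-first, which is the layout PyTorch convolutional layers expect for image batches.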


In [None]:
globs = [(tp/'box_messy').glob('*.png') for tp in trial_paths]
all_frames = [f for g in globs for f in g]


In [None]:
# Split dataset into train and test sets
BATCH_SIZE = 32
TEST_SET_PROP = 0.1 # 10%

train_indices, test_indices = train_test_split(
    np.arange(len(all_frames)), test_size=TEST_SET_PROP, shuffle=True, random_state=SEED
)


In [None]:
train_imgs = preprocess_frame_batch(
    all_frames, train_indices, grayscale=GS
)
test_imgs = preprocess_frame_batch(
    all_frames, test_indices, grayscale=GS
)

print(f"Train set shape: {train_imgs.shape}")
print(f"Test set shape: {test_imgs.shape}")


In [None]:
from autoencoder.datasets import UnlabeledDataset

dataloader_train = torch.utils.data.DataLoader(
    UnlabeledDataset(torch.from_numpy(train_imgs)),
    batch_size=BATCH_SIZE, shuffle=True
)
dataloader_test = torch.utils.data.DataLoader(
    UnlabeledDataset(torch.from_numpy(test_imgs)),
    batch_size=BATCH_SIZE, shuffle=True
)
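
`UnlabeledDataset` comes from this project's own `autoencoder` package, and its source is not shown here. For readers without that package, the sketch below shows one plausible minimal equivalent: a map-style PyTorch dataset that returns samples with no labels. The class name `TensorOnlyDataset` and the fake data are assumptions for illustration and may differ from the real implementation.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class TensorOnlyDataset(Dataset):
    """Hypothetical stand-in: serves samples from a tensor, with no labels."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Ten fake grayscale frames of shape (1, 64, 128).
fake_imgs = torch.rand(10, 1, 64, 128)
loader = DataLoader(TensorOnlyDataset(fake_imgs), batch_size=4, shuffle=False)

# Each batch is a single tensor of frames; the last batch holds the remainder.
batch_shapes = [tuple(batch.shape) for batch in loader]
print(batch_shapes)
```

Because each item is a bare tensor, iterating the loader yields tensors directly, which matches how the training loop later calls `batch.to(DEVICE)`.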


### **3. Vision autoencoder definition**

As we mentioned before, an autoencoder is a type of neural network that learns to compress data into a smaller representation (encoding) and then reconstruct it back (decoding).

There is no restriction on the structure of the encoder and the decoder, and they need not be symmetric. 

However, since we are processing image frames, **convolutional layers** will be helpful because they can:
- capture spatial features like edges, textures, shapes, etc.
- preserve local patterns and share weights across the image
- detect features efficiently regardless of position
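
To make the weight-sharing point concrete, the sketch below compares the number of learnable parameters in a small convolutional layer with a fully connected layer that produces an output of the same size. It assumes one bias per output channel (or unit) and uses a first-layer setting of 1 input channel, 8 output channels and a 3x3 kernel on a 64 x 128 frame; the helper names are hypothetical.

```python
def conv_param_count(in_channels, out_channels, kernel_size):
    """Weights plus one bias per output channel for a 2D convolution."""
    k_h, k_w = kernel_size
    return out_channels * (in_channels * k_h * k_w + 1)

def dense_param_count(n_inputs, n_outputs):
    """Weights plus one bias per output unit for a fully connected layer."""
    return n_outputs * (n_inputs + 1)

# 1 grayscale channel -> 8 channels with a 3x3 kernel.
conv_params = conv_param_count(in_channels=1, out_channels=8, kernel_size=(3, 3))

# A dense layer producing the same output volume (8 x 62 x 126) from a 64 x 128 frame.
dense_params = dense_param_count(n_inputs=64 * 128, n_outputs=8 * 62 * 126)

print(conv_params, dense_params)
```

The convolution reuses the same few kernel weights at every image position, so its parameter count does not grow with the image size, whereas the dense layer needs a separate weight for every pixel-output pair.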

In [None]:
from autoencoder.vision_ae import VisualEncoder, VisualDecoder

# KERNEL_SIZES = [(4,4), (5,5), (3,3), (3,3)] # kernel sizes for the convolutional layers
# KERNEL_STRIDES = [2, 2, 1, 1] # strides
# CHANNELS = [8, 8, 16, 16] # number of channels

KERNEL_SIZES = [(3,3), (4,4), (4,4), (3,3)] # kernel sizes for the convolutional layers
KERNEL_STRIDES = [1, 2, 2, 1] # strides
CHANNELS = [8, 8, 16, 16] # number of channels

EMB_DIM = 100 # the number of neurons in the latent space (or number of latent features)

img_dim_out = IMG_DIM

print('Need to make sure all numbers are INTEGERS!\n')
print(f'Input dimension:\t\t{1 if GS else 3}x {IMG_DIM}')
for i in range(len(KERNEL_SIZES)):
    ksize = KERNEL_SIZES[i]
    stride = KERNEL_STRIDES[i]

    # d indexes the spatial dimension (0: height, 1: width); no padding is assumed
    img_dim_out = [
        (img_dim_out[d] - ksize[d])/stride + 1
        for d in range(len(IMG_DIM))
    ]
    
    print(f'Intermediate dimension {i+1}:\t{CHANNELS[i]}x {img_dim_out}')

print(f'Flattens to:\t\t\t{np.prod(img_dim_out)*CHANNELS[-1]}\n')
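
The integer check above is done by eye. A small helper can make it explicit by raising an error when a layer does not divide evenly. This is a sketch that assumes unpadded convolutions with dilation 1, i.e. the same `(in - kernel) / stride + 1` rule used in the cell above; `conv_output_dims` is a hypothetical name.

```python
def conv_output_dims(img_dim, kernel_sizes, strides):
    """Return the (height, width) after each unpadded convolution.

    Uses out = (in - kernel) / stride + 1 and raises if a result is not whole.
    """
    dims = []
    current = tuple(img_dim)
    for layer, (ksize, stride) in enumerate(zip(kernel_sizes, strides), start=1):
        next_dim = []
        for size, k in zip(current, ksize):
            if (size - k) % stride != 0:
                raise ValueError(
                    f"layer {layer}: ({size} - {k}) is not divisible by stride {stride}"
                )
            next_dim.append((size - k) // stride + 1)
        current = tuple(next_dim)
        dims.append(current)
    return dims

dims = conv_output_dims((64, 128), [(3, 3), (4, 4), (4, 4), (3, 3)], [1, 2, 2, 1])
flat_size = dims[-1][0] * dims[-1][1] * 16  # 16 channels in the last layer
print(dims, flat_size)
```

A non-integer size does not necessarily stop a convolution from running, since frameworks typically round down, but it usually means a mirrored decoder built from transposed convolutions will not recover the input size exactly.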



In [None]:
encoder = VisualEncoder(
    visual_embedding_dim = EMB_DIM,
    img_dim = IMG_DIM,
    grayscale = GS,
    kernel_sizes = KERNEL_SIZES,
    kernel_strides = KERNEL_STRIDES,
    channels = CHANNELS
).to(DEVICE)

decoder = VisualDecoder(
    visual_embedding_dim = EMB_DIM,
    img_dim_out = encoder.img_dim_out,
    grayscale = GS,
    kernel_sizes = KERNEL_SIZES,
    kernel_strides = KERNEL_STRIDES,
    channels = CHANNELS
).to(DEVICE)


Now that we have defined our encoder and decoder, let's see an example of the reconstruction.

In [None]:
# visualise examples
with torch.no_grad():
    for example_batch in dataloader_test:
        example_batch = example_batch[np.random.choice(len(example_batch), size=3, replace=False)]
        example_batch = example_batch.to(DEVICE)
        example_batch_recon = decoder(encoder(example_batch))

        fig, axs = plt.subplots(len(example_batch), 2, figsize=(6, 1.5*len(example_batch)))
        axs.flat[0].set_title('Original Images')
        axs.flat[1].set_title('Reconstructed Images')
        for i, (frame_example_img, frame_example_recon) in enumerate(zip(example_batch, example_batch_recon)):
            axs[i, 0].imshow(frame_example_img.cpu().numpy().squeeze(), cmap='gray')
            axs[i, 1].imshow(frame_example_recon.cpu().numpy().squeeze(), cmap='gray')
            axs[i, 0].axis('off')
            axs[i, 1].axis('off')
        plt.show()

        break


This looks like just noise. This is because the autoencoder is untrained. But the good news is: at least we got the dimensions correct!

Now, let's train the pair of networks.

### **4. Autoencoder training**

* **Train the model**. This includes a forward pass, computing the loss, backpropagation and updating the weights.

* If needed, **validate** on the validation set to tune hyperparameters.

* **Test** the trained model on unseen data to evaluate performance.

* Once the model has reached satisfactory performance, it is ready for **deployment**.
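
This notebook only splits the frames into train and test sets. If a separate validation set is wanted for hyperparameter tuning, one option is to shuffle the frame indices once and cut them into three disjoint parts. The sketch below uses NumPy only; the function name `three_way_split` and the 80/10/10 proportions are assumptions for illustration.

```python
import numpy as np

def three_way_split(n_samples, val_prop, test_prop, seed):
    """Shuffle indices once and cut them into disjoint train/val/test sets."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(n_samples)
    n_test = int(round(n_samples * test_prop))
    n_val = int(round(n_samples * val_prop))
    test_idx = indices[:n_test]
    val_idx = indices[n_test:n_test + n_val]
    train_idx = indices[n_test + n_val:]
    return train_idx, val_idx, test_idx

train_idx, val_idx, test_idx = three_way_split(
    1000, val_prop=0.1, test_prop=0.1, seed=42
)
print(len(train_idx), len(val_idx), len(test_idx))
```

Decisions such as learning-rate scheduling or early stopping can then be driven by the validation loss, keeping the test set untouched until the final evaluation.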

In [None]:
# Define the train and test functions

def train_epoch(
    enc, dec,
    dataloader_train,
    loss_fn, optimizer
):
    enc.train()
    dec.train()
    epoch_loss = 0
    
    for batch in dataloader_train:
        optimizer.zero_grad()

        batch = batch.to(DEVICE)
        
        # YOUR CODE HERE: forward pass
        batch_recon = dec(enc(batch))
        loss = loss_fn(batch_recon, batch)  # PyTorch convention: (prediction, target)
        
        # YOUR CODE HERE: backward pass and optimisation step
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.detach().item()
            
    return epoch_loss / len(dataloader_train)

def test_epoch(
    enc, dec,
    dataloader_test,
    loss_fn,
):
    enc.eval()
    dec.eval()
    epoch_loss = 0
    
    with torch.no_grad():
        for batch in dataloader_test:
            batch = batch.to(DEVICE)
            
            # YOUR CODE HERE: forward pass
            batch_recon = dec(enc(batch))
            loss = loss_fn(batch_recon, batch)  # PyTorch convention: (prediction, target)
            
            epoch_loss += loss.item()
    
    return epoch_loss / len(dataloader_test)


In [None]:
# Putting it all together

n_epochs = 100
loss_fn = torch.nn.L1Loss()
learning_rate = 1e-4

optimizer = torch.optim.RMSprop(
    itertools.chain(encoder.parameters(), decoder.parameters()),
    lr=learning_rate
)

LR_REDUCTION_FACTOR = 0.1
LR_SCHED_PATIENCE = 20
LR_SCHED_TH = 1e-3
# optional: use a learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=LR_REDUCTION_FACTOR,
    patience=LR_SCHED_PATIENCE, threshold=LR_SCHED_TH
)


In [None]:
train_loss_list = []
test_loss_list = []

for epoch in range(n_epochs):
    train_loss = train_epoch(
        encoder, decoder,
        dataloader_train,
        loss_fn, optimizer
    )
    test_loss = test_epoch(
        encoder, decoder,
        dataloader_test,
        loss_fn
    )
    scheduler.step(test_loss)

    train_loss_list.append(train_loss)
    test_loss_list.append(test_loss)

    # visualize examples
    if epoch%10 == 0:
        with torch.no_grad():
            for example_batch in dataloader_test:
                example_batch = example_batch[
                    np.random.choice(len(example_batch), size=3, replace=False)
                ]
                example_batch = example_batch.to(DEVICE)
                example_batch_recon = decoder(encoder(example_batch))

                fig, axs = plt.subplots(len(example_batch), 2, figsize=(6, 1.5*len(example_batch)))
                axs.flat[0].set_title('Original Images')
                axs.flat[1].set_title('Reconstructed Images')
                for i, (frame_example, frame_recon) in enumerate(zip(example_batch, example_batch_recon)):
                    axs[i, 0].imshow(frame_example.cpu().numpy().squeeze(), cmap='gray')
                    axs[i, 1].imshow(frame_recon.cpu().numpy().squeeze(), cmap='gray')
                    axs[i, 0].axis('off')
                    axs[i, 1].axis('off')
                plt.show()
                break

    print(f'Epoch {epoch + 1}/{n_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')



### **What have we achieved in this tutorial?**

We trained an autoencoder: a pair of networks (an encoder and a decoder) that can compress images into a latent vector and reconstruct them from it.

With a biologically motivated constraint on the latent activations (such as non-negativity or a sigmoid), we can interpret the latent vector for each frame, much like activity in the visual cortex, as the population activity of visual neurons encoding that visual scene.
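
As a minimal sketch of what such a constraint could look like, the snippet below applies two common choices to a batch of fake latent vectors: a rectification that keeps activations non-negative, and a sigmoid that bounds them between 0 and 1. The latent values are random placeholders, and how (or whether) the `VisualEncoder` used above applies a constraint is not shown in this notebook, so treat this purely as an illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Fake latent vectors: 5 frames, 100 latent units each (placeholder values).
latent = rng.normal(size=(5, 100))

# Option 1: rectify, so every unit has a non-negative "firing rate".
non_negative = np.maximum(latent, 0.0)

# Option 2: squash through a sigmoid, so every unit lies strictly between 0 and 1.
bounded = 1.0 / (1.0 + np.exp(-latent))

print(non_negative.min(), bounded.min(), bounded.max())
```

Either transform makes the latent units easier to read as firing rates, since real neurons cannot fire at negative rates.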

These encodings of the visual scene can now be readily fed into a **Recurrent Neural Network** to generate spatial cell tunings.