# Transformers In Neuroscience

### Author: Domenick Mifsud
### Date: 5/21/23

---


## Overview

This tutorial covers the following sections:

1. [Introduction & Setup](#Section-1:-Introduction-&-Setup)

2. [What are Transformers?](#Section-2:-What-are-Transformers?)
    <br>
    <br> 2.1&nbsp;&nbsp;[Background](#Sec-2.1:-Background)
    <br> 2.2&nbsp;&nbsp;[Model Architecture](#Model-Architecture)
    <br> 2.3&nbsp;&nbsp;[Self-Attention Overview](#Self-Attention-Overview)
    <br>
3. [Applications to Neuroscience](#section-1-introduction--setup)
    <br>
    <br> 3.1&nbsp;&nbsp;[FMRI](#Background)
    <br> 3.2&nbsp;&nbsp;[Calcium](#Model-Architecture)
    <br> 3.3&nbsp;&nbsp;[Ephys](#Self-Attention-Overview)
    <br>
4. [Create a Neural Data Transformer](#Create-a-Neural-Data-Transformer)
    <br>
    <br> 4.1&nbsp;&nbsp;[Feed Forward Layer](#Feed-Forward-Layer)
    <br> 4.2&nbsp;&nbsp;[Multi-head Attention Layer](#Multi-head-Attention-Layer)
    <br> 4.3&nbsp;&nbsp;[Encoder Layer](#Encoder-Layer)
    <br> 4.4&nbsp;&nbsp;[The Full Model](#The-Full-Model)
    <br>
5. [Model Training & Evaluation](#Model-Training-&-Evaluation)
    <br>
    <br> 5.1&nbsp;&nbsp;[Download the Data](#Download-the-Data)
    <br> 5.2&nbsp;&nbsp;[Model Training](#Model-Training)
    <br> 5.3&nbsp;&nbsp;[Model Evaluation](#Model-Evaluation)
    <br>
---

## Section 1: Introduction & Setup

### Introduction

> This is just an example introduction. This tutorial will be about what transformers are, specifically focusing on self-attention.

### Setup

#### Package Installs 
**(⚠️ Only run this once! ⚠️)**

To set up your environment for the first time, uncomment (delete the `#`) and run the following code to install the necessary packages:

In [None]:
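# !pip install numpy
# !pip install matplotlib
# !pip install torch
# !pip install plotly  # used by the 3D plotting helper below
# !pip install dandi  # provides the `dandi download` command used below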

#### Download Data
**(⚠️ Only run this once! ⚠️)**

Uncomment the following lines to download the dataset into your Colab notebook. We'll be using the "Maze Dataset", collected by Mark Churchland and first published in 2008. It is publicly available as part of the Neural Latents Benchmark contest (https://neurallatents.github.io/).

In [None]:
# !dandi download DANDI:000140/0.220113.0408
# !dandi download https://dandiarchive.org/dandiset/000138
# !mv 000140 ../../data/
# !mv 000138 ../../data/

#### Package Imports
Now run the following cells to import the required packages and set up the helper code:

In [2]:
import torch  # For tensor operations
import numpy as np  # For array operations
import torch.nn as nn  # For neural network layers
import matplotlib.pyplot as plt  # For plotting
import torch.nn.functional as F  # For functional operations
import plotly.graph_objects as go # For 3D plotting

#### Functions

These are helper functions

In [94]:
# function to plot a 3D scatterplot using plotly
def create_3d_scatterplot(data_tuples, xaxis_title, yaxis_title, zaxis_title, fig_size=(700, 500)):
    data = [go.Scatter3d(x=[v[0]], y=[v[1]], z=[v[2]], mode='markers', name=n) for v, n in data_tuples]
    data.append(go.Scatter3d(x=[0], y=[0], z=[0], mode='markers', marker=dict(size=2, color='black'), name='center (0,0,0)'))
    layout_dict = dict(xaxis=dict(title=xaxis_title, range=[-1, 1]),
                       yaxis=dict(title=yaxis_title, range=[-1, 1]),
                       zaxis=dict(title=zaxis_title, range=[-1, 1]),
                       camera=dict(eye=dict(x=1.3, y=-1.3, z=1.3), center=dict(x=0.065, y=0.0, z=-0.075)),
                       aspectmode='cube')
    layout = go.Layout(scene=layout_dict, margin=dict(l=0,r=0,b=0,t=0), width=fig_size[0], height=fig_size[1])
    return go.Figure(data=data, layout=layout).show(config={'displayModeBar': False})

---

## Section 2: What are Transformers?

### Sec 2.1: **Background**

Transformer neural networks are sequence-to-sequence models: they take in a set of inputs (or *tokens*) and return a set of outputs.

They are highly parallelizable (can train models faster), as opposed to the sequential nature of previous models like RNNs and LSTMs. They have been widely used in natural language processing (NLP), but have also shown promise in many other domains, such as: computer vision 🖼️, time series modeling 📈, and even neuroscience 🧠.

<img src="./transformer_inputs.png" alt="inputs" width="650"/>

What makes these models so special is the way that they move information throughout the inputs... 

In an RNN, this is done with a hidden state that keeps track of the important information from previously seen inputs and is updated for each new input. The issue is that the model can only fit so much into its hidden state, so eventually it has to forget some things.

In a transformer, however, information is not routed through a hidden state but is exchanged directly between inputs. This can be seen in the example below, where we are trying to predict the next word in the sequence. With the RNN, we only have the information in the hidden state to predict the next word. With the transformer, we can pull information from any word in the sequence!

<img src="./rnn_v_transformer.png" alt="inputs" width="600"/>

This routing of information across the tokens is accomplished through ***Self-Attention***, more specifically ***Scaled Dot Product Attention***.

---

### Sec 2.2: **Model Architecture**
The original transformer, introduced in the paper "Attention Is All You Need", followed the encoder-decoder architecture.

---

### Sec 2.3: **Self-Attention Overview**


To understand how attention can "route" information across tokens, we need to first understand the mechanism by which the tokens interact. In the example below, the tokens are words. To feed these words into a neural network we must first convert them into numbers so the network can process them. This is accomplished through the use of *word embeddings*.

Word embeddings are vectors that represent individual words. Each dimension in the vector represents how much the word relates to some abstract concept. To illustrate this, let's take 3 example words:

* **`Fish`** 🐟
* **`Boat`** 🚢
* **`Hunt`** 🔫

and represent them as 3D embeddings. We will manually set the dimensions to represent 3 arbitrary abstract concepts:

1. Is it an **activity**?   (-1 is never, +1 is always)

2. Is it more closely related to the **sea** or to **land**?   (-1 is sea, +1 is land)
3. Is it related to **animals**?   (-1 is never, +1 is always)

In [22]:
#      [activity,  sea/land,  animals]
fish = [0.2,       -0.8,      0.9] 
boat = [-0.2,      -0.9,      -0.1]
hunt = [0.8,       0.9,       0.8]

Now that we have our 3 word embedding vectors, let's look at them in 3D space!

In [95]:
data_tuples = [(fish, 'fish'), (boat, 'boat'), (hunt, 'hunt')]
create_3d_scatterplot(data_tuples, 'Activity', 'Sea vs Land', 'Animals')

We can visually see that boat and fish are closely related in this space!

A cool example (which allows us to look at different low-d representations of high-d embeddings) is this: https://projector.tensorflow.org/


Now let's take a look at the **dot product**, which is how self-attention quantifies the similarity between these vectors:

$$ A \cdot B = \sum_{i=1}^{n} a_{i} b_{i} $$

***Should I convert this part to be a student activity?***


In [104]:
print(f'Similarity between fish and boat: {np.dot(fish, boat):.2f}')
print(f'Similarity between fish and hunt: {np.dot(fish, hunt):.2f}')
print(f'Similarity between boat and hunt: {np.dot(boat, hunt):.2f}')

Similarity between fish and boat: 0.59
Similarity between fish and hunt: 0.16
Similarity between boat and hunt: -1.05


We can see that it correctly found `fish` and `boat` highly related! It also determined that `fish` and `hunt` are more closely related than `boat` and `hunt`.

> TODO: add transition into actual self attention

#### Self Attention Formula

The actual formula for SA is...

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
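
As a minimal sketch of the formula above, the cell below applies scaled dot-product attention to the three toy word embeddings from earlier. It assumes, purely for illustration, that the queries, keys, and values are all the raw embeddings themselves (a real transformer first passes the embeddings through learned linear layers to get Q, K, and V). The function name `scaled_dot_product_attention` is our own choice, not a library call.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Hypothetical helper: softmax(Q K^T / sqrt(d_k)) V for 2D arrays."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)  # similarity of every query with every key
    scores = scores - scores.max(axis=-1, keepdims=True)  # for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over the keys
    return weights @ V, weights

#                      [activity, sea/land, animals]
embeddings = np.array([[0.2, -0.8, 0.9],     # fish
                       [-0.2, -0.9, -0.1],   # boat
                       [0.8, 0.9, 0.8]])     # hunt

# Assumption for illustration only: Q = K = V = the raw embeddings
output, weights = scaled_dot_product_attention(embeddings, embeddings, embeddings)
print(np.round(weights, 2))  # row i: how much token i attends to fish, boat, hunt
```

Each row of `weights` sums to 1, and each row of `output` is a weighted average of the value vectors. In this toy setup `boat` puts more weight on `fish` than on `hunt`, matching the dot products computed earlier.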

go into K, Q, V as a filing system analogy...

**(⚠️ Need to recreate these images! ⚠️)**

<img src="https://jalammar.github.io/images/gpt2/self-attention-example-folders-3.png" alt="inputs" width="650"/>
<br></br>After Attention & Softmax:<br></br>
<img src="https://jalammar.github.io/images/gpt2/self-attention-example-folders-scores-3.png" alt="inputs" width="650"/>

***Should this have a section where they project the embeddings created earlier into multiple 2-d embeddings (MHA example)??***

I think that I need to incorporate a section about 


---

## Section 3: Create a Neural Data Transformer

In this section we will be creating a simplified version of the Neural Data Transformer ([pdf](https://arxiv.org/pdf/2108.01210.pdf)) ([repo](https://github.com/snel-repo/neural-data-transformers)), training it on data from the Neural Latents Benchmark, and ...


##### ***Introductory Terms***:

* ***[Subclass](https://uwpce-pythoncert.github.io/ProgrammingInPython/modules/SubclassingAndInheritance.html)***: A class that gets to use all the functions defined in its "superclass". By creating a subclass you are saying: this object is a *superclass*, but here are even more things that it can do that a *superclass* can't.

* ***[Batch](https://www.linkedin.com/advice/3/what-benefits-drawbacks-batch-processing-ml-skills-batch-processing)***: A group of inputs processed at once to speed up neural network training. Think of it like an airport shuttle: if multiple people need to make the same trip, it makes sense to take everyone together!

* ***[Linear Layer](https://pytorch.org/docs/stable/nn.html#linear-layers)***: A.K.A. a Fully-Connected Layer, is a matrix multiplication on the incoming data with a learned "weight" matrix and then addition with a learned "bias" term. 

* ***[Normalization](https://pytorch.org/docs/stable/nn.html#normalization-layers)***: is a technique that rescales the inputs to make their distribution more consistent. This helps the model learn more efficiently as it ensures all inputs are on a similar scale. 
    * Transformers often use [Layer Normalization](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html) (LayerNorm), where the mean and variance are computed independently for each input across all its features, and the activations within the current layer are shifted and scaled accordingly. This is different from other normalization techniques such as Batch Normalization, which computes the mean and variance of each feature across all of the inputs in the batch.
    <br></br>
* ***[Dropout](https://pytorch.org/docs/stable/nn.html#dropout-layers)***: is a "regularization" technique where a proportion of the features in the inputs are randomly "dropped out", or set to zero, during training. This keeps the model from relying on the presence of any particular feature and forces it to learn more robust features that are useful in conjunction with many different random subsets of the other features. This improves model robustness and the ability to generalize to unseen data.
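
To make LayerNorm and Dropout from the list above more concrete, here is a small sketch with made-up sizes (4 inputs with 8 features each, and a dropout probability of 0.5). The numbers are arbitrary choices for illustration, not values used later in the tutorial.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A batch of 4 inputs with 8 features each, deliberately shifted and scaled
x = torch.randn(4, 8) * 5 + 3

# LayerNorm normalizes each input across its own 8 features
layer_norm = nn.LayerNorm(8)
normed = layer_norm(x)
print(normed.mean(dim=-1))                 # close to 0 for every input
print(normed.std(dim=-1, unbiased=False))  # close to 1 for every input

# Dropout zeroes random elements during training and rescales the survivors
dropout = nn.Dropout(p=0.5)
dropout.train()
dropped = dropout(torch.ones(1000))
frac_zero = (dropped == 0).float().mean()
print(frac_zero)                           # roughly 0.5

# In evaluation mode dropout does nothing
dropout.eval()
unchanged = dropout(torch.ones(1000))
```

Note that the elements that survive dropout are scaled by `1 / (1 - p)` during training, so the expected value of each element stays the same.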

##### ***Understanding the PyTorch Docs***:

As an example, here is the first line from the [linear layer docs](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear):
```python
    class torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
```
In the above line, the parameters are the items in the comma-separated list within the parentheses () after `torch.nn.Linear`.
<br></br>
A linear layer object would be created by providing the class with parameters in the order requested. The parameters that have an equal sign next to them do *not* need to be provided because the equal sign means, "here is what the value will be if you don't give me anything". 
<br>

To understand more about the parameters it is requesting, look for the unordered list under the bold text: "**Parameters:**"

Here is how you would create an object from the above class (inside the `__init__` method of your own module) that takes in a vector with 3 dimensions and outputs a vector with 4 dimensions:
```python
    in_dim = 3
    out_dim = 4
    self.example_layer = nn.Linear(in_dim, out_dim)
```
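
As a quick check of what such a layer does, the sketch below creates the same 3-to-4 linear layer outside of a class and applies it to a single vector and to a batch of 5 vectors. The variable names and the batch size are our own, chosen only for illustration.

```python
import torch
import torch.nn as nn

in_dim = 3
out_dim = 4
example_layer = nn.Linear(in_dim, out_dim)

single_vector = torch.randn(in_dim)  # one input with 3 dimensions
batch = torch.randn(5, in_dim)       # a batch of 5 inputs

single_out = example_layer(single_vector)
batch_out = example_layer(batch)
print(single_out.shape)  # torch.Size([4])
print(batch_out.shape)   # torch.Size([5, 4])

# The layer is a matrix multiplication with the weights plus the bias
manual = batch @ example_layer.weight.T + example_layer.bias
matches = torch.allclose(manual, batch_out, atol=1e-6)
print(matches)
```

The layer only transforms the last dimension, so any leading dimensions (such as the batch) pass through unchanged.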


---

### Sec 3.1: **Feed Forward Layer**

The [Feed-Forward Network](https://en.wikipedia.org/wiki/Feedforward_neural_network) (FFN) is a crucial part of the transformer architecture. It is applied independently to each token in the sequence, so this operation is highly parallelizable (it speeds up training and execution of the model because all tokens can be processed in one operation). The FFN consists of two linear layers with a non-linearity in between. It is essential for introducing non-linearity to the system and allowing the model to learn more complex representations. Despite its name, it doesn't carry information forward along the sequence but instead transforms each token's representation within each individual layer. 

* Transformers often use the [rectified linear unit](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html) (ReLU) function as their choice of non-linearity, which just sets negative values to 0. Simple, huh! 

    <img src="https://pytorch.org/docs/stable/_images/ReLU.png" alt="inputs" width="400"/>
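
To see ReLU in action, here is a tiny sketch on a hand-picked set of values (the values themselves are arbitrary):

```python
import torch
import torch.nn as nn

relu = nn.ReLU()
values = torch.tensor([-2.0, -0.5, 0.0, 1.5])

# Negative values become 0, everything else passes through unchanged
activated = relu(values)
print(activated.tolist())  # [0.0, 0.0, 0.0, 1.5]
```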

##### **Exercise**
Now let's implement the FFN using PyTorch. You will need to create a class called `FeedForwardLayer`, which is a subclass of `nn.Module`. 

**You need to provide the class with code to do two things:**
1. First to initialize the various PyTorch layers/functions to the requested sizes within the `__init__` method of your class.
2. Then to transform the incoming data using those layers you created within the `forward` method of your class.

**The FFN should consist of** two linear layers ([nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)), a normalization layer ([nn.LayerNorm](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html#layernorm)), a dropout layer ([nn.Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html#dropout)), and a non-linear activation function ([nn.ReLU](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html)). 
* The LayerNorm should be named `self.norm` 
    * it will normalize the data with shape `input_dim`
    <br></br>
* The first linear layer should be named `self.linear1`
    * it will transform the input dimensions from `input_dim` to `hidden_dim`
    <br></br>
* The ReLU function should be named `self.relu`
<br></br>
* The dropout layer should be named `self.dropout` 
    * it will have a `dropout_p` probability of an element to be zeroed
    <br></br>
* The second linear layer should be named `self.linear2` 
    * it will transform the input dimensions from `hidden_dim` to `input_dim`
    <br></br>

! ***MAKE SURE YOU USE THE VARIABLE NAMES DEFINED ABOVE*** !

Next, within the forward method, you are to define the forward pass for this network. 
<br>
**This method should:** 
1. Pass the input x through the LayerNorm
2. Pass the result through the first linear layer
3. Pass the result through the ReLU activation
4. Pass the result through the dropout layer
5. Pass the result through the second linear layer

Use the following code snippet as your starting point:

In [3]:
class FeedForwardLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim, dropout_p):
        super().__init__()
        # TODO: Define the LayerNorm
        # TODO: Define the first linear layer
        # TODO: Define the ReLU activation function
        # TODO: Define the dropout layer
        # TODO: Define the second linear layer

    def forward(self, x):
        x = self.norm(x) # pass through LayerNorm
        # TODO: Define the rest of the forward pass
        return out # return the output of the second linear layer
    
# Reference solution (this definition replaces the starting point above)
class FeedForwardLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim, dropout_p):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_p)
        self.linear2 = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        x = self.norm(x) # pass through LayerNorm
        x = self.linear1(x) # pass through first linear layer
        x = self.relu(x) # pass through ReLU activation function
        x = self.dropout(x) # pass through dropout layer
        out = self.linear2(x) # pass through second linear layer
        return out # return the output of the second linear layer

After you finish creating the class, run the cell and continue to the next section.



---

### Sec 3.2: **Multi-head Attention Layer**

The Multi-head Attention (MHA) layer is where the tokens exchange information with each other. It normalizes the input with a LayerNorm, passes it through three linear layers to get the queries, keys, and values, splits these into `num_heads` heads of size `head_dims`, applies scaled dot product attention within each head, and then merges the heads back together with a final linear layer.


In [8]:
class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, num_heads, head_dims, dropout_p):
        super().__init__()
        self.head_dims = head_dims
        self.num_heads = num_heads
        self.all_head_dims = head_dims * num_heads

        self.norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout_p)

        self.linear_q = nn.Linear(input_dim, self.all_head_dims)
        self.linear_k = nn.Linear(input_dim, self.all_head_dims)
        self.linear_v = nn.Linear(input_dim, self.all_head_dims)

        self.linear_out = nn.Linear(self.all_head_dims, input_dim)

    def forward(self, x):
        # Input shape: (batch_size, seq_len, input_dim)
        batch_size, seq_len, input_dim = x.size()

        # 1) Pass the input through the LayerNorm
        x = self.norm(x)

        # 2) Pass the input through the linear layers to get the queries, keys, and values.
        query = self.linear_q(x)
        key = self.linear_k(x)
        value = self.linear_v(x)

        # 3) Split into heads and reshape the queries, keys, and values to shape:
        #   (batch_size, self.num_heads, seq_len, self.head_dims)
        split_shape = (batch_size, seq_len, self.num_heads, self.head_dims)
        query = query.view(*split_shape).transpose(1, 2)
        key = key.view(*split_shape).transpose(1, 2)
        value = value.view(*split_shape).transpose(1, 2)

        # 4) Pass the queries, keys, and values through the attention function
        attn_scores = torch.matmul(query, key.transpose(-2, -1)) # shape: (batch_size, num_heads, seq_len, seq_len)
        attn_scores = attn_scores / (self.head_dims ** 0.5) # head_dims is an int, so torch.sqrt cannot be used here
        attn = F.softmax(attn_scores, dim=-1)
        attn = self.dropout(attn)
        x = torch.matmul(attn, value) # shape: (batch_size, num_heads, seq_len, head_dims)

        # 5) Swap the num_heads and seq_len dimensions
        x = x.transpose(1, 2).contiguous()  # shape: (batch_size, seq_len, num_heads, head_dims)

        # 6) Merge the heads to prepare the data for the final linear layer
        x = x.view(batch_size, seq_len, self.all_head_dims) # shape: (batch_size, seq_len, self.all_head_dims)

        # 7) Pass the output through the final linear layer
        x = self.linear_out(x)  # shape: (batch_size, seq_len, input_dim)

        return x
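
The step that splits the projected data into heads is easy to get wrong, so here is a small sketch of just that reshape. It assumes the `(batch_size, num_heads, seq_len, head_dims)` layout that is commonly used for batched attention, and the sizes are arbitrary. Splitting and then merging the heads again should give back exactly the original tensor.

```python
import torch

batch_size, seq_len, num_heads, head_dims = 2, 5, 4, 8
x = torch.randn(batch_size, seq_len, num_heads * head_dims)

# Split the last dimension into heads, then move the heads in front of seq_len
heads = x.view(batch_size, seq_len, num_heads, head_dims).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 4, 5, 8])

# Head 0 holds the first head_dims features of every token
first_head_matches = torch.equal(heads[:, 0], x[:, :, :head_dims])

# Undo the split: move the heads back and flatten them into one dimension
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, num_heads * head_dims)
round_trip_matches = torch.equal(merged, x)
print(first_head_matches, round_trip_matches)  # True True
```

The call to `.contiguous()` is needed because `transpose` only changes how the tensor is indexed, and `view` requires the data to be laid out contiguously in memory.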


---

### Sec 3.3: **Encoder Layer**

***Have them fill out some parts of the encoder based on diagrams / formulas / instructions***

In [None]:
class EncoderLayer(nn.Module):
    def __init__(self, input_dim, num_heads, head_dims, ffn_dim, mha_dropout_p, ffn_dropout_p):
        super().__init__()
        self.MHA = MultiHeadAttention(input_dim, num_heads, head_dims, mha_dropout_p)
        self.FFN = FeedForwardLayer(input_dim, ffn_dim, ffn_dropout_p)

        self.post_mha_dropout = nn.Dropout(mha_dropout_p)
        self.post_ffn_dropout = nn.Dropout(ffn_dropout_p)

    def forward(self, x):
        residual = x
        x = self.MHA(x)
        x = self.post_mha_dropout(x)
        x = x + residual

        residual = x
        x = self.FFN(x)
        x = self.post_ffn_dropout(x)
        x = x + residual

        return x
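
The pattern used twice in the encoder layer (apply a sublayer, apply dropout, add the input back) is called a residual connection. The sketch below isolates that pattern with a hypothetical `ResidualBlock` helper and a stand-in linear sublayer whose weights are set to zero. Because the sublayer then contributes nothing, the block returns its input unchanged, which shows how the residual path lets information flow straight through a layer.

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Hypothetical helper: output = x + dropout(sublayer(x))."""
    def __init__(self, sublayer, dropout_p):
        super().__init__()
        self.sublayer = sublayer
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        return x + self.dropout(self.sublayer(x))

# Stand-in for the MHA or FFN: a linear layer that outputs all zeros
stand_in = nn.Linear(16, 16)
nn.init.zeros_(stand_in.weight)
nn.init.zeros_(stand_in.bias)

block = ResidualBlock(stand_in, dropout_p=0.1)
x = torch.randn(2, 10, 16)  # (batch_size, seq_len, input_dim)
out = block(x)

passes_through = torch.equal(out, x)
print(passes_through)  # True
```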

---

### Sec 3.4: **The Full Model**

There is no student exercise in this section. It puts all of the layers together into the full model.


In [None]:
class NeuralDataTransformer(nn.Module):
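    def __init__(self, config_dict):
        # Expected keys: n_layers, n_electrodes, factor_dim, seq_len, num_heads, head_dims, ffn_dim,
        # mha_dropout_p, ffn_dropout_p, post_readin_dropout_p, post_encoder_dropout_p
        super().__init__()
        # The encoder runs on the read-in factors, so its input_dim is factor_dim.
        # Only the EncoderLayer arguments are passed on, since it does not accept the other keys.
        encoder_args = dict(input_dim=config_dict['factor_dim'],
                            num_heads=config_dict['num_heads'],
                            head_dims=config_dict['head_dims'],
                            ffn_dim=config_dict['ffn_dim'],
                            mha_dropout_p=config_dict['mha_dropout_p'],
                            ffn_dropout_p=config_dict['ffn_dropout_p'])
        self.encoder = nn.ModuleList([EncoderLayer(**encoder_args) for _ in range(config_dict['n_layers'])])
        self.readin = nn.Linear(config_dict['n_electrodes'], config_dict['factor_dim'])
        self.readout = nn.Linear(config_dict['factor_dim'], config_dict['n_electrodes'])

        self.norm = nn.LayerNorm(config_dict['factor_dim'])
        self.post_readin_dropout = nn.Dropout(config_dict['post_readin_dropout_p'])
        self.post_encoder_dropout = nn.Dropout(config_dict['post_encoder_dropout_p'])

        self.scale = config_dict['factor_dim'] ** 0.5  # plain float, so it works on any device
        self.classifier = nn.PoissonNLLLoss(reduction='none')
        self.pos_embedding = nn.Parameter(torch.randn(config_dict['seq_len'], config_dict['factor_dim']))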

    def forward(self, x, labels=None):
        x = self.readin(x)

        x = self.post_readin_dropout(x)

        x *= self.scale

        x = x + self.pos_embedding

        for layer in self.encoder:
            x = layer(x)

        x = self.norm(x)

        x = self.post_encoder_dropout(x)

        x = self.readout(x)

        if labels is None:
            return x
        
        loss = self.classifier(x, labels)

        masked_loss = loss[labels != -100]

        masked_loss = masked_loss.mean()

        return masked_loss, x
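
The last few lines of `forward` compute a masked loss: positions whose label is `-100` are ignored. The sketch below shows that step in isolation on hand-made numbers. It relies on the default behaviour of `nn.PoissonNLLLoss`, which treats the model output as log-rates (`log_input=True`), and the tiny tensors are made up for illustration.

```python
import torch
import torch.nn as nn

loss_fn = nn.PoissonNLLLoss(reduction='none')  # keep one loss value per element

# Model outputs are log-rates: log-rate 0 means a predicted rate of exp(0) = 1
log_rates = torch.zeros(2, 3)

# Spike counts, with -100 marking the positions that should be ignored
labels = torch.tensor([[1.0, -100.0, 0.0],
                       [2.0, 1.0, -100.0]])

loss = loss_fn(log_rates, labels)  # shape: (2, 3)
masked_loss = loss[labels != -100]  # keeps only the 4 labelled positions
print(masked_loss.shape)            # torch.Size([4])
print(masked_loss.mean())
```

With `reduction='none'` the loss keeps the shape of the labels, which is what makes it possible to index it with the boolean mask before averaging.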

---

## Section 4: Model Training & Evaluation

### Sec 4.1: **Process the Data**

Using the MC Maze small dataset and predefined data functions, break the NWB file down into trials and get the train, val, and test tensors.

Functions needed: NWB processing


---

### Sec 4.2: **Model Training**

Load pretrained parameters into the model that they just created.

***Have them fill out some parts of the training loop***

Have the model finetune for X epochs.

Functions needed: Training function?

---

### Sec 4.3: **Model Evaluation**

***Have them fill out some parts of the model eval setup***

Show a PCA of the rates vs the smoothed spikes... aligned w/ trial start? Movement may be too much.

Decoding of smoothed spikes vs rates.

Functions needed: PCA, plotting for PCA, decoding

