# Transformers In Neuroscience

### Author: Domenick Mifsud
### Date: 5/21/23

---


## Overview

This tutorial covers the following sections:

1. [Introduction & Setup](#Section-1:-Introduction-&-Setup)

2. [What are Transformers?](#Section-2:-What-are-Transformers?)
    - [Background](#Background)
    - [Model Architecture](#Model-Architecture)
    - [Self-Attention Overview](#Self-Attention-Overview)
    <br></br>
3. [Create a Neural Data Transformer](#Create-a-Neural-Data-Transformer)
    - [Encoder Layer](#Encoder-Layer)
    - [Feed Forward Layer](#Feed-Forward-Layer)
    - [Multi-head Attention Layer](#Multi-head-Attention-Layer)
    <br></br>
4. [Model Training & Evaluation](#Model-Training-&-Evaluation)
    - [Process the Data](#Process-the-Data)
    - [Model Training](#Model-Training)
    - [Model Evaluation](#Model-Evaluation)
    <br></br>
---

## Section 1: Introduction & Setup

### Introduction

> This is just an example introduction. This tutorial will be about what transformers are, specifically focusing on self-attention.

### Setup

#### Package Installs 
**(⚠️ Only run this once! ⚠️)**

To set up your environment for the first time, uncomment (delete the `#`) and run the following code to install the necessary packages:

In [None]:
# !pip install numpy
# !pip install matplotlib
# !pip install plotly
# !pip install torch
# !pip install dandi

#### Download Data
**(⚠️ Only run this once! ⚠️)**

Uncomment the following lines to download the dataset into your Colab notebook (this requires the `dandi` command-line client, e.g. `pip install dandi`). We'll be using the "Maze" dataset, collected by Mark Churchland and first published in 2008. It is publicly available as part of the Neural Latents Benchmark (https://neurallatents.github.io/).

In [None]:
# !dandi download DANDI:000140/0.220113.0408
# !dandi download https://dandiarchive.org/dandiset/000138
# !mv 000140 ../../data/
# !mv 000138 ../../data/

#### Package Imports
Now run the following cells to import the required packages and set up the helper code:

In [5]:
import numpy as np  # For array operations
import matplotlib.pyplot as plt  # For plotting
import plotly.graph_objects as go # For 3D plotting

#### Functions

These are helper functions used later in this tutorial.

In [94]:
# Plot labeled 3-D points (e.g., word embeddings) as an interactive plotly scatterplot.
# data_tuples: list of (vector, name) pairs, where each vector has 3 coordinates in [-1, 1]
def create_3d_scatterplot(data_tuples, xaxis_title, yaxis_title, zaxis_title, fig_size=(700, 500)):
    # One trace per point so each point gets its own legend entry
    data = [go.Scatter3d(x=[v[0]], y=[v[1]], z=[v[2]], mode='markers', name=n) for v, n in data_tuples]
    # Mark the origin with a small black reference point
    data.append(go.Scatter3d(x=[0], y=[0], z=[0], mode='markers', marker=dict(size=2, color='black'), name='center (0,0,0)'))
    # Fix every axis to [-1, 1] so distances are visually comparable
    layout_dict = dict(xaxis=dict(title=xaxis_title, range=[-1, 1]),
                       yaxis=dict(title=yaxis_title, range=[-1, 1]),
                       zaxis=dict(title=zaxis_title, range=[-1, 1]),
                       camera=dict(eye=dict(x=1.3, y=-1.3, z=1.3), center=dict(x=0.065, y=0.0, z=-0.075)),
                       aspectmode='cube')
    layout = go.Layout(scene=layout_dict, margin=dict(l=0, r=0, b=0, t=0), width=fig_size[0], height=fig_size[1])
    # Display the figure; .show() returns None, so there is nothing useful to return
    go.Figure(data=data, layout=layout).show(config={'displayModeBar': False})

---

## Section 2: What are Transformers?

### Sec 2.1: **Background**

Transformer neural networks are sequence-to-sequence models: they take in a set of inputs (or *tokens*) and return a set of outputs.

Unlike earlier sequential models such as RNNs and LSTMs, which process inputs one step at a time, transformers process all tokens in parallel, which makes them much faster to train. They are best known from natural language processing (NLP), but have also shown promise in many other domains, such as computer vision 🖼️, time series modeling 📈, and even neuroscience 🧠.

<img src="transformer_inputs.png" alt="inputs" width="650"/>

What makes these models so special is the way they move information between the inputs...

In an RNN, this is done with a hidden state that keeps track of the important information from previously seen inputs and is updated with each new input. The issue is that the hidden state has a fixed size, so the model can only fit so much into it, and eventually it must forget some information.

In a transformer, however, information is not routed through a hidden state but is exchanged directly between inputs. The example below illustrates this for predicting the next word in a sequence: the RNN only has the information in its hidden state to predict the next word, whereas the transformer can pull information from any word in the sequence!

<img src="rnn_v_transformer.png" alt="RNN vs. transformer" width="600"/>

This routing of information across tokens is accomplished through ***Self-Attention***, or more specifically, ***Scaled Dot-Product Attention***.

---

### Sec 2.2: **Model Architecture**
The original transformer, introduced in the paper *Attention Is All You Need* (Vaswani et al., 2017), follows an encoder-decoder architecture.

---

### Sec 2.3: **Self-Attention Overview**


To understand how attention can "route" information across tokens, we need to first understand the mechanism by which the tokens interact. In the example below, the tokens are words. To feed these words into a neural network we must first convert them into numbers so the network can process them. This is accomplished through the use of *word embeddings*.

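Word embeddings are vectors that represent individual words. In this toy example, each dimension of the vector represents how strongly the word relates to some abstract concept (in learned embeddings, individual dimensions are usually not this interpretable). To illustrate this, let's take 3 example words: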

* **`Fish`** 🐟
* **`Boat`** 🚢
* **`Hunt`** 🔫

and represent them as 3-D embeddings. We will manually set the dimensions to represent 3 arbitrary abstract concepts:

1. Is it an **activity**?   (-1 is never, +1 is always)
2. Is it more closely related to the **sea** or to **land**?   (-1 is sea, +1 is land)
3. Is it related to **animals**?   (-1 is never, +1 is always)

In [22]:
#      [activity, sea/land, animals]
fish = [ 0.2,     -0.8,      0.9]
boat = [-0.2,     -0.9,     -0.1]
hunt = [ 0.8,      0.9,      0.8]

Now that we have our 3 word embedding vectors, let's look at them in 3-D space!

In [95]:
data_tuples = [(fish, 'fish'), (boat, 'boat'), (hunt, 'hunt')]
create_3d_scatterplot(data_tuples, 'Activity', 'Sea vs Land', 'Animals')

We can see that `boat` and `fish` are the closest pair in this space!

For an interactive example of low-dimensional projections of real high-dimensional word embeddings, see the TensorFlow Embedding Projector: https://projector.tensorflow.org/


Now let's take a look at the **dot product**, which is how self-attention quantifies the similarity between these vectors:

$$ A \cdot B = \sum_{i=1}^{n} a_{i} b_{i} $$

***Should I convert this part to be a student activity?***


In [104]:
print(f'Similarity between fish and boat: {np.dot(fish, boat):.2f}')
print(f'Similarity between fish and hunt: {np.dot(fish, hunt):.2f}')
print(f'Similarity between boat and hunt: {np.dot(boat, hunt):.2f}')

Similarity between fish and boat: 0.59
Similarity between fish and hunt: 0.16
Similarity between boat and hunt: -1.05


The dot product correctly rates `fish` and `boat` as the most similar pair! It also rates `fish` and `hunt` as more closely related than `boat` and `hunt`.
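Computing dot products one pair at a time does not scale to long sequences. A minimal sketch, assuming the same toy embeddings: stack them as rows of a matrix `E` and compute every pairwise dot product at once with `E @ E.T`. This matrix product is the same operation that appears as $QK^T$ in self-attention.

```python
import numpy as np

# Toy embeddings from above, one row per word: [activity, sea/land, animals]
words = ["fish", "boat", "hunt"]
E = np.array([
    [0.2, -0.8, 0.9],    # fish
    [-0.2, -0.9, -0.1],  # boat
    [0.8, 0.9, 0.8],     # hunt
])

# Entry (i, j) is the dot product between word i and word j
similarity = E @ E.T
print(np.round(similarity, 2))
```

The off-diagonal entries match the three values printed above. Note that the scores are unbounded and each word also scores itself (the diagonal), which is one reason attention passes these scores through a softmax.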

> TODO: add transition into actual self attention

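#### Self-Attention Formula

The formula for scaled dot-product attention is:

$$ \text{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

where $Q$, $K$, and $V$ are the query, key, and value matrices (one row per token) and $d_k$ is the dimension of the key vectors. Dividing by $\sqrt{d_k}$ keeps the dot products from growing large as the dimension increases, which would otherwise push the softmax into regions with very small gradients.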
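A minimal NumPy sketch of this formula, assuming a single sequence with no batching or masking. The helper names `softmax` and `scaled_dot_product_attention` are our own, not from a library. For illustration we skip the learned projections and set $Q = K = V$ to the toy embeddings.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - np.max(x, axis=axis, keepdims=True)  # subtract the max to avoid overflow
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """Q: (n_queries, d_k), K: (n_keys, d_k), V: (n_keys, d_v).
    Returns the attended values (n_queries, d_v) and weights (n_queries, n_keys)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)     # similarity of every query with every key
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ V, weights

# 3 tokens (fish, boat, hunt) with 3-D embeddings; Q = K = V for this illustration
X = np.array([[0.2, -0.8, 0.9],
              [-0.2, -0.9, -0.1],
              [0.8, 0.9, 0.8]])
out, w = scaled_dot_product_attention(X, X, X)
print(np.round(w, 2))
```

Each row of `w` says how much one token draws from every token, and each row of `out` is the corresponding weighted average of the value vectors.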

go into K, Q, V as a filing system analogy...

**(⚠️ Need to recreate these images! ⚠️)**

<img src="https://jalammar.github.io/images/gpt2/self-attention-example-folders-3.png" alt="inputs" width="650"/>
<br></br>After Attention & Softmax:<br></br>
<img src="https://jalammar.github.io/images/gpt2/self-attention-example-folders-scores-3.png" alt="inputs" width="650"/>
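In the filing-system analogy, each token's query is compared against every token's key (the folder labels), and the softmax scores decide how much of each token's value (the folder contents) to pull in. A minimal sketch, assuming randomly initialized matrices `W_q`, `W_k`, and `W_v` stand in for the learned projections and a head dimension of 2 chosen only for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy embeddings, one row per token: fish, boat, hunt
X = np.array([[0.2, -0.8, 0.9],
              [-0.2, -0.9, -0.1],
              [0.8, 0.9, 0.8]])
d_model, d_k = X.shape[1], 2

# Hypothetical projection matrices; in a real transformer these are learned
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_k))

Q = X @ W_q  # queries: what each token is looking for
K = X @ W_k  # keys: the label on each token's folder
V = X @ W_v  # values: the contents of each token's folder

# Scaled dot-product scores followed by a row-wise, numerically stable softmax
scores = Q @ K.T / np.sqrt(d_k)
scores -= scores.max(axis=-1, keepdims=True)
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
output = weights @ V  # each output row is a weighted mix of the value vectors
print(np.round(weights, 2))
```

Because the projections are random here, the weights carry no meaning; training is what shapes `W_q`, `W_k`, and `W_v` so that useful tokens receive high scores.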

***Should this have a section where they project the embeddings created earlier into multiple 2-d embeddings (MHA example)??***

---

## Section 3: Create a Neural Data Transformer

### Sec 3.1: **Encoder Layer**

***Have them fill out some parts of the encoder based on diagrams / formulas / instructions***
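Before filling in the encoder, it may help to see the overall data flow. The sketch below is a simplified NumPy stand-in, assuming a single attention head without an output projection, no dropout, and layer norm without learned parameters. It follows the post-norm ordering of the original transformer (sublayer, residual add, then layer norm); the Neural Data Transformer implementation may order or configure these pieces differently. All function names and sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def layer_norm(x, eps=1e-5):
    """Normalize each token (row) to zero mean and unit variance (no learned gain/bias)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention."""
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    weights = softmax(Q @ K.T / np.sqrt(K.shape[-1]))
    return weights @ V

def feed_forward(x, W1, b1, W2, b2):
    """Position-wise FFN: ReLU(x W1 + b1) W2 + b2."""
    return np.maximum(0, x @ W1 + b1) @ W2 + b2

def encoder_layer(x, params):
    """Post-norm encoder layer: sublayer -> residual add -> layer norm, twice."""
    x = layer_norm(x + self_attention(x, *params["attn"]))
    x = layer_norm(x + feed_forward(x, *params["ffn"]))
    return x

# Hypothetical sizes: 10 time bins, 8 features per bin, FFN hidden size 32
T, d_model, d_ff = 10, 8, 32
params = {
    "attn": [rng.normal(scale=0.1, size=(d_model, d_model)) for _ in range(3)],
    "ffn": [rng.normal(scale=0.1, size=(d_model, d_ff)), np.zeros(d_ff),
            rng.normal(scale=0.1, size=(d_ff, d_model)), np.zeros(d_model)],
}
x = rng.normal(size=(T, d_model))
y = encoder_layer(x, params)
print(y.shape)  # (10, 8)
```

The residual connections require every sublayer to return the same shape it receives, which is why `W_v` and the FFN output projection both map back to `d_model`.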

---

### Sec 3.2: **Feed Forward Layer**

***Have them implement the FFN based on diagrams / formulas / instructions***
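In the original transformer, the feed-forward layer is two linear maps with a ReLU in between, applied to each token independently:

$$ \mathrm{FFN}(x) = \max(0,\, xW_1 + b_1)W_2 + b_2 $$

A minimal NumPy sketch of that formula, with hypothetical sizes and randomly initialized weights standing in for learned ones:

```python
import numpy as np

def feed_forward(x, W1, b1, W2, b2):
    """Position-wise feed-forward network applied to each token (row) independently.
    x: (T, d_model), W1: (d_model, d_ff), W2: (d_ff, d_model)."""
    hidden = np.maximum(0.0, x @ W1 + b1)  # expand to d_ff and apply ReLU
    return hidden @ W2 + b2                # project back to d_model

rng = np.random.default_rng(1)
T, d_model, d_ff = 5, 4, 16  # hypothetical sizes
W1, b1 = rng.normal(size=(d_model, d_ff)), np.zeros(d_ff)
W2, b2 = rng.normal(size=(d_ff, d_model)), np.zeros(d_model)
x = rng.normal(size=(T, d_model))
y = feed_forward(x, W1, b1, W2, b2)
print(y.shape)  # (5, 4)
```

"Position-wise" means a token's output depends only on that token: running the FFN on a single row gives the same result as the corresponding row of the full output. All mixing across tokens happens in the attention layer.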


---

### Sec 3.3: **Multi-head Attention Layer**

***Have them fill out some parts of the MHA layer based on diagrams / formulas / instructions***
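Multi-head attention runs several smaller attention operations in parallel: the projected queries, keys, and values are split into `n_heads` chunks of size `d_model / n_heads`, each head attends independently, and the heads are concatenated and passed through an output projection `W_o`. A minimal NumPy sketch, assuming a single unbatched sequence, no masking, and random weights in place of learned ones (all names are hypothetical):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, W_q, W_k, W_v, W_o, n_heads):
    """x: (T, d_model); all W_*: (d_model, d_model).
    Returns the output (T, d_model) and attention weights (n_heads, T, T)."""
    T, d_model = x.shape
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    d_head = d_model // n_heads

    def split_heads(m):
        # (T, d_model) -> (n_heads, T, d_head)
        return m.reshape(T, n_heads, d_head).transpose(1, 0, 2)

    Q, K, V = split_heads(x @ W_q), split_heads(x @ W_k), split_heads(x @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)  # (n_heads, T, T)
    weights = softmax(scores, axis=-1)
    heads = weights @ V                                    # (n_heads, T, d_head)
    concat = heads.transpose(1, 0, 2).reshape(T, d_model)  # concatenate the heads
    return concat @ W_o, weights

rng = np.random.default_rng(0)
T, d_model, n_heads = 6, 8, 2  # hypothetical sizes
x = rng.normal(size=(T, d_model))
W_q, W_k, W_v, W_o = (rng.normal(scale=0.3, size=(d_model, d_model)) for _ in range(4))
out, weights = multi_head_attention(x, W_q, W_k, W_v, W_o, n_heads)
print(out.shape, weights.shape)
```

With `n_heads=1` and `W_o` set to the identity, this reduces to plain single-head scaled dot-product attention, which is a handy sanity check when filling out the real layer.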


---

## Section 4: Model Training & Evaluation

### Sec 4.1: **Process the Data**

Using the MC Maze small dataset and predefined data functions, break the NWB file down into trials and get the train, validation, and test tensors.

Functions needed: NWB processing
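The NWB processing functions are not written yet, so the sketch below only illustrates the shape bookkeeping. It assumes spikes have already been binned into a continuous `(time bins, neurons)` array and that trial start bins are known; it uses synthetic Poisson counts instead of the real dataset and does not reproduce any particular NWB reader. `trialize` and `split_trials` are hypothetical helpers, and the 70/15/15 random split is an arbitrary choice.

```python
import numpy as np

def trialize(spikes, trial_starts, n_bins):
    """Cut a continuous (time_bins, neurons) array into (trials, n_bins, neurons),
    one fixed-length window per trial start index."""
    return np.stack([spikes[s:s + n_bins] for s in trial_starts])

def split_trials(trials, frac_train=0.7, frac_val=0.15, seed=0):
    """Randomly split trials into train / validation / test sets."""
    idx = np.random.default_rng(seed).permutation(len(trials))
    n_train = int(frac_train * len(trials))
    n_val = int(frac_val * len(trials))
    return (trials[idx[:n_train]],
            trials[idx[n_train:n_train + n_val]],
            trials[idx[n_train + n_val:]])

# Synthetic stand-in for binned spikes: 2000 bins x 30 neurons
rng = np.random.default_rng(0)
spikes = rng.poisson(0.2, size=(2000, 30))
trial_starts = np.arange(0, 1900, 100)  # hypothetical trial start bins
trials = trialize(spikes, trial_starts, n_bins=70)
train, val, test = split_trials(trials)
print(trials.shape, train.shape, val.shape, test.shape)
```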


---

### Sec 4.2: **Model Training**

Load pretrained parameters into the model that they just created.

***Have them fill out some parts of the training loop***

Have the model fine-tune for X epochs.
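One objective used for NDT-style models is masked modeling: hide a random subset of time bins from the input and score the model's predicted firing rates on those hidden bins with a Poisson negative log-likelihood. A minimal sketch of that loss, assuming the model outputs log firing rates, masked bins are simply zeroed out, and a 25% mask ratio; the helper names are hypothetical and a constant prediction stands in for the model.

```python
import numpy as np

def make_mask(n_trials, n_bins, mask_ratio, rng):
    """Boolean mask of shape (trials, bins); True marks time bins hidden from the model."""
    return rng.random((n_trials, n_bins)) < mask_ratio

def masked_poisson_nll(log_rates, spikes, mask):
    """Poisson NLL (dropping the constant log(k!) term), averaged over masked bins only.
    log_rates, spikes: (trials, bins, neurons); mask: (trials, bins)."""
    nll = np.exp(log_rates) - spikes * log_rates  # per-element Poisson NLL
    return nll[mask].mean()

rng = np.random.default_rng(0)
spikes = rng.poisson(0.3, size=(8, 50, 20))
mask = make_mask(8, 50, mask_ratio=0.25, rng=rng)
masked_input = spikes * ~mask[..., None]  # zero out masked bins before the forward pass

# Stand-in for model output: predict the overall mean firing rate everywhere
log_rates = np.full(spikes.shape, np.log(spikes.mean() + 1e-8))
loss = masked_poisson_nll(log_rates, spikes, mask)
print(f"masked Poisson NLL: {loss:.3f}")
```

Scoring only the masked bins forces the model to infer hidden activity from the surrounding context rather than copying its input.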

Functions needed: Training function?

---

### Sec 4.3: **Model Evaluation**

***Have them fill out some parts of the model eval setup***

Show a PCA of the inferred rates vs. the smoothed spikes, aligned to trial start? (Aligning to movement may be too much.)
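The smoothing and PCA helpers are not written yet. A minimal sketch, assuming binned spikes as a `(time bins, neurons)` array, a Gaussian kernel with a hypothetical width in bins, and PCA computed via `np.linalg.svd`. Synthetic counts stand in for the real data; the same `pca` call can be applied to the model's inferred rates for the comparison. Plotting is left out.

```python
import numpy as np

def gaussian_smooth(spikes, sigma_bins):
    """Smooth each neuron's spike train (time, neurons) with a normalized Gaussian kernel.
    Assumes the number of time bins is at least the kernel length (6 * sigma + 1)."""
    half = int(np.ceil(3 * sigma_bins))
    t = np.arange(-half, half + 1)
    kernel = np.exp(-0.5 * (t / sigma_bins) ** 2)
    kernel /= kernel.sum()
    return np.stack([np.convolve(spikes[:, n], kernel, mode="same")
                     for n in range(spikes.shape[1])], axis=1)

def pca(X, n_components):
    """Project X (samples, features) onto its top principal components via SVD.
    Returns the projections and the fraction of variance each component explains."""
    Xc = X - X.mean(axis=0)
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    explained = S**2 / np.sum(S**2)
    return Xc @ Vt[:n_components].T, explained[:n_components]

rng = np.random.default_rng(0)
spikes = rng.poisson(0.5, size=(200, 40))  # synthetic (time bins, neurons)
smoothed = gaussian_smooth(spikes.astype(float), sigma_bins=3)
pcs, var_explained = pca(smoothed, n_components=3)
print(pcs.shape, np.round(var_explained, 3))
```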

Decoding from the smoothed spikes vs. the inferred rates.
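A minimal sketch of the decoding comparison, assuming a linear ridge-regression decoder from neural features (smoothed spikes or inferred rates) to a behavioral variable, evaluated with held-out $R^2$. The behavioral target, the regularization strength `alpha`, and the helper names are hypothetical, and synthetic data stand in for the real features. Running it once per feature type gives the comparison.

```python
import numpy as np

def fit_ridge(X, Y, alpha=1.0):
    """Closed-form ridge regression with an intercept; returns weights (features + 1, outputs)."""
    Xb = np.hstack([X, np.ones((len(X), 1))])  # append a bias column
    reg = alpha * np.eye(Xb.shape[1])
    reg[-1, -1] = 0.0                           # do not penalize the intercept
    return np.linalg.solve(Xb.T @ Xb + reg, Xb.T @ Y)

def predict(X, W):
    return np.hstack([X, np.ones((len(X), 1))]) @ W

def r2_score(Y, Y_hat):
    """Coefficient of determination, averaged over output dimensions."""
    ss_res = ((Y - Y_hat) ** 2).sum(axis=0)
    ss_tot = ((Y - Y.mean(axis=0)) ** 2).sum(axis=0)
    return np.mean(1 - ss_res / ss_tot)

# Synthetic example: a 2-D "behavior" that is a noisy linear readout of 30 features
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 30))
Y = X @ rng.normal(size=(30, 2)) + 0.1 * rng.normal(size=(500, 2))
W = fit_ridge(X[:400], Y[:400], alpha=1.0)
r2 = r2_score(Y[400:], predict(X[400:], W))
print(f"held-out R^2: {r2:.3f}")
```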

Functions needed: PCA, plotting for PCA, decoding

