# Transformers


<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/soup.png" width = 500>

*Images created by DALL-E 2 when prompted with "a bowl of soup that is a portal to another dimension as digital art"*

## Transformers in the news


<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/chatgpt.png" width = 500>

## Transformers in the news


<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/dalle3.png" width = 500>

## Transformers in the news

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/alphafold.jpg" width = 500>

## Overview

1. Data as collections of "tokens"
2. Recap of neural architectures
3. Self-attention & The Transformer
4. Transformer tricks
5. Applications: ChatGPT, AlphaFold, other biological applications

## Data as collections of "tokens"

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/tokens.png" width = 500>

## Text as collections of "tokens"

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/tokens_text.png" width = 500>

## Images as collections of "tokens"

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/tokens_img.png" width = 500>
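One common way to turn an image into tokens, used for example by Vision Transformers, is to cut it into non-overlapping square patches and flatten each patch into a vector. The sketch below assumes a 3-channel 32×32 image and a patch size of 8. Both numbers are arbitrary, and the figure above may use a different scheme.

```python
import torch

torch.manual_seed(0)
img = torch.randn(3, 32, 32)  # (channels, height, width)
p = 8                         # patch size (assumed)

# (3, 32, 32) -> (3, 4, 4, 8, 8): a 4 x 4 grid of 8 x 8 patches per channel
patches = img.unfold(1, p, p).unfold(2, p, p)
# Put the grid dimensions first and flatten each patch (all channels) into one vector
tokens = patches.permute(1, 2, 0, 3, 4).reshape(-1, 3 * p * p)
print(tokens.shape)  # torch.Size([16, 192]): 16 tokens of dimension 192
```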

## Molecules as collections of "tokens"
<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/tokens_mol.png" width = 350>

## (Mass) spectra as collections of "tokens"


<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/tokens_spectra.png" width = 400>

## Tabular data as collections of "tokens"

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/tokens_tabular.png" width = 500>

## PyTorch toy example of tokens

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_printoptions(precision=2)
```

```python
x = torch.randint(low = 0, high = 4, size = (1, 6))
print(x.shape)
print(x)
```
```
torch.Size([1, 6])
tensor([[3, 2, 3, 2, 2, 1]])
```


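```python
# Fix the vocabulary size explicitly instead of inferring it from x.max()
x_onehot = F.one_hot(x, num_classes = 4)
print(x_onehot.shape)
print(x_onehot.transpose(1,2))
```
```
torch.Size([1, 6, 4])
tensor([[[0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1],
         [0, 1, 0, 1, 1, 0],
         [1, 0, 1, 0, 0, 0]]])
```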


```python
print(x)
print(x_onehot.transpose(1,2))
```
```
tensor([[3, 2, 3, 2, 2, 1]])
tensor([[[0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1],
         [0, 1, 0, 1, 1, 0],
         [1, 0, 1, 0, 0, 0]]])
```


```python
embedder = nn.Embedding(num_embeddings=4, embedding_dim = 4)
x_embed = embedder(x)
print(x_embed.shape)
print(x_embed.transpose(1,2))
```
```
torch.Size([1, 6, 4])
tensor([[[-1.18,  0.50, -1.18,  0.50,  0.50,  2.00],
         [ 0.40, -0.62,  0.40, -0.62, -0.62, -0.03],
         [-1.56,  0.51, -1.56,  0.51,  0.51,  0.77],
         [-2.31, -1.87, -2.31, -1.87, -1.87, -1.20]]],
       grad_fn=<TransposeBackward0>)
```
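An `nn.Embedding` layer is a learnable lookup table: selecting row `i` of its weight matrix gives the same result as multiplying the one-hot vector of token `i` with that matrix. A small check of this equivalence, reusing the token values printed above (the seed is arbitrary):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.tensor([[3, 2, 3, 2, 2, 1]])  # 6 tokens from a vocabulary of 4
embedder = nn.Embedding(num_embeddings=4, embedding_dim=4)

via_lookup = embedder(x)  # (1, 6, 4): row lookup in the weight table
via_onehot = F.one_hot(x, num_classes=4).float() @ embedder.weight  # (1, 6, 4)
print(torch.allclose(via_lookup, via_onehot))  # True
```

The lookup is what is used in practice, since it avoids building the mostly-zero one-hot matrix.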


## Recap of neural architectures

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/cnn.png" width = 300>

### CNN advantages
- Takes advantage of locality of patterns
- Applicable for 1/2/3/... dimensional data
- Efficient

However:
- The receptive field of a convolution is fixed in advance
- Hard to learn long-range interactions

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/rnn.png" width = 500>

### RNN advantages
- Takes advantage of the sequential nature of the data
- Intuitive to use in many-to-many, one-to-many, many-to-one, ... settings

However:
- Inefficient: tokens are processed one after another instead of in parallel
- Only really applicable to 1D data
- Prone to forgetting information over long sequences (the hidden state vector is a bottleneck)

## Where does the MLP fall in this?

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/mlp.png" width = 500>

- Fixed input size: cannot handle variable-length data
- Not applicable to many-to-many, one-to-many, ... settings

### Quick reminder

<img src="http://karpathy.github.io/assets/rnn/diags.jpeg" width = 600>

### The holy grail

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/whatwewant.png" width = 500>

- Efficient
- Compares all inputs vs all inputs?
- Also for 1D/2D/...
- Also for variable-length data.

## The holy grail = self-attention

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/token_mixing.png" width = 500>

- Comparing all tokens?

### Similarity matrices via dot products

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/dotproduct.png" width = 500>

For one vector:
$\boldsymbol{x} \cdot \boldsymbol{x}^\top$, with $\boldsymbol{x} \in \mathbb{R}^{1 \times d}$

For a matrix:
$\boldsymbol{X} \cdot \boldsymbol{X}^\top$, with $\boldsymbol{X} \in \mathbb{R}^{n \times d}$
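A minimal PyTorch sketch of the two products above, with arbitrary sizes $n = 5$ and $d = 3$:

```python
import torch

torch.manual_seed(0)
n, d = 5, 3
X = torch.randn(n, d)  # rows of X are the token vectors

x = X[:1]              # a single token as a 1 x d row vector
print((x @ x.T).shape) # torch.Size([1, 1]): a single dot product

A = X @ X.T            # all pairwise dot products
print(A.shape)         # torch.Size([5, 5])
print(torch.allclose(A, A.T))  # True: X X^T is symmetric
```

The diagonal of $\boldsymbol{X}\boldsymbol{X}^\top$ holds the squared norms of the tokens; the off-diagonal entries compare different tokens.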

### From similarity matrix back to $\mathbb{R}^{n \times d}$

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/XXtX.png" width = 700>

$\boldsymbol{A} = \boldsymbol{X} \boldsymbol{X}^\top$, with $\boldsymbol{A} \in \mathbb{R}^{n \times n}$

$\boldsymbol{Z} = \boldsymbol{A} \boldsymbol{X}$, with $\boldsymbol{Z} \in \mathbb{R}^{n \times d}$

$\boldsymbol{Z} = (\boldsymbol{X}\boldsymbol{X}^\top) \boldsymbol{X}$

### Intuition behind the equations

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/intuitionxxtx.png" width = 500>

Output vector $i$ is a weighted sum of all input vectors, where input $j$ is weighted by its similarity $\boldsymbol{A}_{ij}$ to input $i$.
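The weighted-sum view can be checked numerically. The sketch below (arbitrary sizes, random $\boldsymbol{X}$) computes $\boldsymbol{Z} = (\boldsymbol{X}\boldsymbol{X}^\top)\boldsymbol{X}$ once as a matrix product and once as an explicit loop over $\sum_j \boldsymbol{A}_{ij}\boldsymbol{X}_j$:

```python
import torch

torch.manual_seed(0)
X = torch.randn(4, 3)  # n = 4 tokens, d = 3
A = X @ X.T            # n x n similarity matrix
Z = A @ X              # n x d: one output vector per token

# Row i of Z is the sum over all tokens j of A[i, j] * X[j]
n = X.shape[0]
Z_loop = torch.stack([sum(A[i, j] * X[j] for j in range(n)) for i in range(n)])

print(Z.shape)                               # torch.Size([4, 3])
print(torch.allclose(Z, Z_loop, atol=1e-5))  # True
```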

### Rounding out the equations

Learning?

$(\boldsymbol{X}\boldsymbol{X}^\top) \boldsymbol{X}$

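$[(\boldsymbol{X}\boldsymbol{W}_q)(\boldsymbol{X}\boldsymbol{W}_k)^\top] \boldsymbol{X}\boldsymbol{W}_v$, with learnable $\boldsymbol{W}_q, \boldsymbol{W}_k, \boldsymbol{W}_v \in \mathbb{R}^{d \times d}$

$(\boldsymbol{Q}\boldsymbol{K}^\top) \boldsymbol{V}$, with $\boldsymbol{Q} = \boldsymbol{X}\boldsymbol{W}_q$, $\boldsymbol{K} = \boldsymbol{X}\boldsymbol{W}_k$, $\boldsymbol{V} = \boldsymbol{X}\boldsymbol{W}_v$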
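Besides making the operation learnable, separate query and key projections also break the symmetry of $\boldsymbol{X}\boldsymbol{X}^\top$: token $i$ can attend to token $j$ more strongly than $j$ attends to $i$. A sketch in which randomly initialised matrices stand in for learned weights:

```python
import torch

torch.manual_seed(0)
n, d = 4, 3
X = torch.randn(n, d)

# Projection matrices, each d x d (random here, learned in practice)
W_q, W_k, W_v = torch.randn(d, d), torch.randn(d, d), torch.randn(d, d)
Q, K, V = X @ W_q, X @ W_k, X @ W_v

A = Q @ K.T  # n x n, no longer forced to be symmetric
Z = A @ V    # n x d
print(Z.shape)                 # torch.Size([4, 3])
print(torch.allclose(A, A.T))  # generally False for random W_q, W_k
```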

### Rounding out the equations

Normalization?

$(\boldsymbol{Q}\boldsymbol{K}^\top) \boldsymbol{V}$

$\texttt{softmax}(\frac{\boldsymbol{Q}\boldsymbol{K}^\top}{\sqrt{d}}) \boldsymbol{V}$
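The softmax is applied row by row, so each row of the attention matrix becomes a set of non-negative weights that sum to 1. Dividing by $\sqrt{d}$ keeps the scores from growing with the dimension: for random vectors with unit-variance entries, a dot product has a standard deviation of about $\sqrt{d}$, which would push the softmax towards one-hot outputs. A sketch with an arbitrary $d = 64$:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 4, 64
Q, K, V = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)

scores = Q @ K.T              # raw dot products; their spread grows with d
scaled = scores / d ** 0.5    # divide by sqrt(d), as in the formula
A = F.softmax(scaled, dim=-1) # row-wise: each row becomes a probability distribution
Z = A @ V                     # n x d output

print(A.sum(dim=-1))                             # each row sums to 1
print(scores.std().item(), scaled.std().item())  # spread shrinks by sqrt(64) = 8
```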

### Rounding out the equation

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/xxtx_to_attn.png" width = 600>

$\texttt{softmax}(\frac{\boldsymbol{Q}\boldsymbol{K}^\top}{\sqrt{d}}) \boldsymbol{V}$

### A recap, what have we achieved?

From convs and RNNs to self-attention

Self-Attention properties:
- Parallel execution
- Variable input length
- Everything interacts with everything else
- Information flow depends on the content of the tokens

==> Generic and elegant mechanism

==> Has to learn all structure in the data from scratch, but also has more freedom to do so

### Code example

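The cells below come from a different run of the notebook, so `x` (and therefore `x_embed`) differs from the values printed earlier.

```python
print(x)
```
```
tensor([[2, 3, 3, 1, 0, 0]])
```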

```python
print(x_embed.shape)
print(x_embed.transpose(1,2))
```
```
torch.Size([1, 6, 4])
tensor([[[-2.02, -1.30, -1.30,  1.26,  0.34,  0.34],
         [ 1.05, -1.40, -1.40,  0.32,  0.75,  0.75],
         [-1.71,  1.33,  1.33,  0.18, -1.14, -1.14],
         [ 0.84, -0.93, -0.93, -0.32,  1.98,  1.98]]],
       grad_fn=<TransposeBackward0>)
```

```python
W_q = nn.Linear(4, 4)
Q = W_q(x_embed)
print(Q.shape)
print(Q.transpose(1,2))
```
```
torch.Size([1, 6, 4])
tensor([[[-0.48, -1.20, -1.20, -0.03,  0.42,  0.42],
         [ 0.50, -0.67, -0.67,  0.37,  0.14,  0.14],
         [-0.19,  1.20,  1.20,  0.53, -0.33, -0.33],
         [ 0.32,  1.37,  1.37,  0.02, -0.66, -0.66]]],
       grad_fn=<TransposeBackward0>)
```

$\texttt{softmax}(\frac{\boldsymbol{Q}\boldsymbol{K}^\top}{\sqrt{d}}) \boldsymbol{V}$

```python
W_q, W_k, W_v = nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4)
Q, K, V = W_q(x_embed), W_k(x_embed), W_v(x_embed)

A = Q @ K.transpose(2,1)  # (1, 6, 6): dot product of every query with every key
print(A.shape)
print(A)
```

$\texttt{softmax}(\frac{\boldsymbol{Q}\boldsymbol{K}^\top}{\sqrt{d}}) \boldsymbol{V}$

```python
print(A)

print(F.softmax(A, dim = -1))
```
```
tensor([[[-0.09,  0.95,  0.95, -0.09, -0.19, -0.19],
         [ 1.53, -2.10, -2.10,  0.42,  2.10,  2.10],
         [ 1.53, -2.10, -2.10,  0.42,  2.10,  2.10],
         [-0.57,  0.25,  0.25,  0.34, -0.40, -0.40],
         [-1.02,  0.77,  0.77, -0.26, -1.35, -1.35],
         [-1.02,  0.77,  0.77, -0.26, -1.35, -1.35]]],
       grad_fn=<UnsafeViewBackward0>)
tensor([[[0.11, 0.30, 0.30, 0.11, 0.09, 0.09],
         [0.20, 0.01, 0.01, 0.07, 0.36, 0.36],
         [0.20, 0.01, 0.01, 0.07, 0.36, 0.36],
         [0.10, 0.22, 0.22, 0.24, 0.11, 0.11],
         [0.06, 0.36, 0.36, 0.13, 0.04, 0.04],
         [0.06, 0.36, 0.36, 0.13, 0.04, 0.04]]], grad_fn=<SoftmaxBackward0>)
```


```python
# Scale by sqrt(d), as in the formula, then apply the softmax to each row
A_normalized = F.softmax(A / Q.shape[-1] ** 0.5, dim = -1)
```

```python
print(A_normalized.shape, V.shape)
Z = A_normalized @ V
print(Z.shape)
```
```
torch.Size([1, 6, 6]) torch.Size([1, 6, 4])
torch.Size([1, 6, 4])
```


```python
print(x_embed.shape)
```
```
torch.Size([1, 6, 4])
```


```python
class SelfAttention(nn.Module):
    def __init__(self, d):
        super().__init__()
        # Learnable query, key and value projections
        self.W_q, self.W_k, self.W_v = (
            nn.Linear(d, d), nn.Linear(d, d), nn.Linear(d, d)
        )
    def forward(self, x):
        # x: (batch, n_tokens, d)
        Q, K, V = self.W_q(x), self.W_k(x), self.W_v(x)
        A = Q @ K.transpose(2,1)                                    # (batch, n, n) similarity matrix
        A_normalized = F.softmax(A / Q.shape[-1] ** 0.5, dim = -1)  # scale by sqrt(d), softmax per row
        return A_normalized @ V                                     # (batch, n, d)

selfattn = SelfAttention(4)
z = selfattn(x_embed)
print(x_embed.shape, z.shape)
print(x_embed.transpose(2,1))
print(z.transpose(2,1))
```
```
torch.Size([1, 6, 4]) torch.Size([1, 6, 4])
tensor([[[-2.02, -1.30, -1.30,  1.26,  0.34,  0.34],
         [ 1.05, -1.40, -1.40,  0.32,  0.75,  0.75],
         [-1.71,  1.33,  1.33,  0.18, -1.14, -1.14],
         [ 0.84, -0.93, -0.93, -0.32,  1.98,  1.98]]],
       grad_fn=<TransposeBackward0>)
tensor([[[-0.13, -0.28, -0.28, -0.41, -0.39, -0.39],
         [ 0.59,  0.26,  0.26,  0.12,  0.15,  0.15],
         [-0.64, -0.47, -0.47, -0.31, -0.34, -0.34],
         [-0.07,  0.24,  0.24,  0.50,  0.44,  0.44]]],
       grad_fn=<TransposeBackward0>)
```


## From self-attention to the transformer

### Multiple heads

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/heads.png" width = 700>
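A minimal multi-head self-attention sketch, assuming the common design in which the model dimension $d$ is split evenly over the heads, each head applies the attention formula to its own slice, and an output projection (`W_o` here) mixes the concatenated head outputs. Class and attribute names are illustrative, not taken from the slides:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadSelfAttention(nn.Module):
    """Sketch: d is split evenly over n_heads attention heads."""

    def __init__(self, d, n_heads):
        super().__init__()
        assert d % n_heads == 0, "d must be divisible by n_heads"
        self.n_heads, self.d_head = n_heads, d // n_heads
        self.W_q, self.W_k, self.W_v = nn.Linear(d, d), nn.Linear(d, d), nn.Linear(d, d)
        self.W_o = nn.Linear(d, d)  # mixes the concatenated head outputs

    def split_heads(self, t):
        # (batch, n, d) -> (batch, n_heads, n, d_head)
        b, n, _ = t.shape
        return t.view(b, n, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, x):
        Q, K, V = (self.split_heads(W(x)) for W in (self.W_q, self.W_k, self.W_v))
        # One (n x n) attention matrix per head, scaled by sqrt(d_head)
        A = F.softmax(Q @ K.transpose(-2, -1) / self.d_head ** 0.5, dim=-1)
        Z = A @ V  # (batch, n_heads, n, d_head)
        b, h, n, d_head = Z.shape
        Z = Z.transpose(1, 2).reshape(b, n, h * d_head)  # concatenate the heads
        return self.W_o(Z)


torch.manual_seed(0)
mhsa = MultiHeadSelfAttention(d=8, n_heads=2)
x = torch.randn(1, 6, 8)
print(mhsa(x).shape)  # torch.Size([1, 6, 8])
```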

### The Transformer block

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/layernorm.jpg" width = 250>

Many of these "blocks" are stacked.

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/transformerblock.png" width = 700>

Code example

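```python
class TransformerLayer(nn.Module):
    ...  # placeholder: self-attention and an MLP, each wrapped in a residual connection and layer norm
```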
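One possible completion of the `TransformerLayer` cell, written as a sketch: it uses PyTorch's built-in `nn.MultiheadAttention` for the attention part and the *pre-norm* arrangement (layer norm before each sub-layer). The original Transformer paper instead applies layer norm after each residual addition, and the figure above may follow either convention. The MLP width of $4d$ and the GELU activation are assumptions.

```python
import torch
import torch.nn as nn


class TransformerLayer(nn.Module):
    """Sketch of one pre-norm Transformer block."""

    def __init__(self, d, n_heads, d_ff=None):
        super().__init__()
        d_ff = d_ff or 4 * d  # MLP hidden size (4 * d is an assumed, common choice)
        self.norm1, self.norm2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.attn = nn.MultiheadAttention(d, n_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(d, d_ff), nn.GELU(), nn.Linear(d_ff, d))

    def forward(self, x, attn_mask=None):
        h = self.norm1(x)
        # Token mixing: self-attention plus residual connection
        x = x + self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)[0]
        # Feature mixing: the same MLP applied to every token, plus residual connection
        x = x + self.mlp(self.norm2(x))
        return x


torch.manual_seed(0)
# Many such blocks are stacked; input and output shapes are identical
model = nn.Sequential(*[TransformerLayer(d=8, n_heads=2) for _ in range(3)])
x = torch.randn(1, 6, 8)
print(model(x).shape)  # torch.Size([1, 6, 8])
```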

## Transformer "tricks"

### Permutation invariance and positional information

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/permutation_invariance.png" width = 700>
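Without positional information, self-attention is *permutation equivariant*: shuffling the input tokens simply shuffles the output tokens in the same way. A quick numerical check with a bare-bones attention function (random weights, arbitrary sizes):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d, n = 4, 6
W_q, W_k, W_v = nn.Linear(d, d), nn.Linear(d, d), nn.Linear(d, d)

def self_attention(x):
    Q, K, V = W_q(x), W_k(x), W_v(x)
    A = F.softmax(Q @ K.transpose(-2, -1) / d ** 0.5, dim=-1)
    return A @ V

x = torch.randn(1, n, d)
perm = torch.randperm(n)  # a random reordering of the tokens

out_then_perm = self_attention(x)[:, perm]  # attend, then shuffle
perm_then_out = self_attention(x[:, perm])  # shuffle, then attend
print(torch.allclose(out_then_perm, perm_then_out, atol=1e-5))  # True
```

The model therefore cannot tell token orders apart unless position is added to the input or to the attention scores.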

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/positionalenc.png" width = 700>
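One classic choice is the fixed sinusoidal encoding from the original Transformer paper, $\mathrm{PE}_{pos, 2i} = \sin(pos / 10000^{2i/d})$ and $\mathrm{PE}_{pos, 2i+1} = \cos(pos / 10000^{2i/d})$, which is added to the token embeddings. The figure may show a different (for example learned) encoding; this sketch assumes an even embedding dimension $d$:

```python
import math
import torch

def sinusoidal_positional_encoding(n, d):
    """Return an (n, d) tensor of sinusoidal position encodings; d must be even."""
    pos = torch.arange(n, dtype=torch.float32).unsqueeze(1)            # (n, 1)
    freqs = torch.exp(-math.log(10000.0) * torch.arange(0, d, 2) / d) # (d/2,) = 10000^(-2i/d)
    pe = torch.zeros(n, d)
    pe[:, 0::2] = torch.sin(pos * freqs)  # even dimensions
    pe[:, 1::2] = torch.cos(pos * freqs)  # odd dimensions
    return pe

torch.manual_seed(0)
x_embed = torch.randn(1, 6, 4)                         # token embeddings: (batch, n, d)
x_in = x_embed + sinusoidal_positional_encoding(6, 4)  # broadcast over the batch dimension
print(x_in.shape)  # torch.Size([1, 6, 4])
```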

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/relative_positional_encoding.png" width = 700>
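Relative schemes instead inject position directly into the attention scores, as a function of the offset $j - i$ between two tokens. A simple hypothetical variant learns one scalar bias per offset (similar in spirit to the relative biases used in T5, but without its bucketing). The class name and the `max_len` parameter are illustrative, not taken from the slides:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RelativePositionBias(nn.Module):
    """Hypothetical sketch: one learnable scalar per relative offset j - i."""

    def __init__(self, max_len):
        super().__init__()
        self.max_len = max_len
        # Offsets range from -(max_len - 1) to max_len - 1
        self.bias = nn.Embedding(2 * max_len - 1, 1)

    def forward(self, n):
        pos = torch.arange(n)
        rel = pos[None, :] - pos[:, None]                     # rel[i, j] = j - i
        return self.bias(rel + self.max_len - 1).squeeze(-1)  # (n, n) bias matrix

torch.manual_seed(0)
n, d = 6, 4
Q, K, V = torch.randn(1, n, d), torch.randn(1, n, d), torch.randn(1, n, d)
rel_bias = RelativePositionBias(max_len=16)

# The bias is added to the scaled scores before the softmax
A = F.softmax(Q @ K.transpose(-2, -1) / d ** 0.5 + rel_bias(n), dim=-1)
Z = A @ V
print(Z.shape)  # torch.Size([1, 6, 4])
```

Because the bias depends only on $j - i$, every diagonal of the bias matrix holds a single value, so the same learned parameters apply at any absolute position.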

### Causal attention = decoder = autoregressive

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/causal.png" width = 700>

Code example

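```python
def add_mask(A):
    # Set the scores of future positions (column j > row i) to -inf so the softmax gives them zero weight
    future = torch.triu(torch.ones(A.shape[-2:], device=A.device), diagonal=1).bool()
    return A.masked_fill(future, float("-inf"))
```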

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/causal_mask.png" width = 700>
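A sketch of causal self-attention, with the learned projections omitted ($\boldsymbol{Q} = \boldsymbol{K} = \boldsymbol{V} = \boldsymbol{X}$) for brevity. It checks that the resulting attention matrix is lower-triangular and that changing the last token leaves the outputs of all earlier tokens untouched, which is what makes autoregressive, token-by-token generation possible:

```python
import torch
import torch.nn.functional as F

def causal_attention(Q, K, V):
    n, d = Q.shape[-2], Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / d ** 0.5
    # True above the diagonal: token i may not look at tokens j > i
    future = torch.triu(torch.ones(n, n), diagonal=1).bool()
    scores = scores.masked_fill(future, float("-inf"))  # softmax turns -inf into weight 0
    A = F.softmax(scores, dim=-1)
    return A @ V, A

torch.manual_seed(0)
x = torch.randn(1, 6, 4)
Z, A = causal_attention(x, x, x)
print(A[0])  # lower-triangular: row i only has weight on columns 0..i

# Changing the last token leaves the outputs of all earlier tokens unchanged
x2 = x.clone()
x2[:, -1] = torch.randn(4)
Z2, _ = causal_attention(x2, x2, x2)
print(torch.allclose(Z[:, :-1], Z2[:, :-1]))  # True
```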

### Sparse attention (masks)

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/custommasks.png" width = 700>
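The same masking trick can encode other sparsity patterns. As one example (an assumption, not necessarily a pattern from the figure), a local window mask lets each token attend only to tokens at most `w` positions away. Note that masking a dense $n \times n$ matrix like this saves no computation; actual efficiency gains require specialised sparse-attention implementations.

```python
import torch
import torch.nn.functional as F

def local_window_mask(n, w):
    """True where attention is NOT allowed: tokens more than w positions apart."""
    pos = torch.arange(n)
    return (pos[None, :] - pos[:, None]).abs() > w

torch.manual_seed(0)
n, d = 8, 4
Q, K, V = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)

mask = local_window_mask(n, w=1)
scores = (Q @ K.T / d ** 0.5).masked_fill(mask, float("-inf"))
A = F.softmax(scores, dim=-1)
Z = A @ V
print((A > 0).int())  # banded: each token attends to itself and its direct neighbours
```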

### Cross-attention

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/crossattn.png" width = 700>
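In cross-attention, the queries come from one set of tokens and the keys and values from another, so the attention matrix is $n \times m$ rather than $n \times n$. A sketch with arbitrary sizes $n = 6$ and $m = 9$:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d = 4
x = torch.randn(1, 6, d)  # sequence providing the queries: n = 6 tokens
y = torch.randn(1, 9, d)  # sequence providing keys and values: m = 9 tokens

W_q, W_k, W_v = nn.Linear(d, d), nn.Linear(d, d), nn.Linear(d, d)
Q, K, V = W_q(x), W_k(y), W_v(y)

A = F.softmax(Q @ K.transpose(-2, -1) / d ** 0.5, dim=-1)  # (1, 6, 9): n x m
Z = A @ V                                                  # (1, 6, 4): one output per query token
print(A.shape, Z.shape)
```

The output has one vector per token of `x`, each a weighted mix of information from `y`.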

## Modern Transformer libraries

- PyTorch
- Hugging Face
- ....
- ....
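As an illustration of the PyTorch entry, its built-in `nn.TransformerEncoderLayer` and `nn.TransformerEncoder` provide a ready-made stack of Transformer blocks. Their defaults differ from some of the sketches above: post-norm (`norm_first=False`), ReLU activation and a dropout of 0.1. The sizes here are arbitrary:

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=8, nhead=2, dim_feedforward=32, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=3)  # stacks 3 independent copies of `layer`

x = torch.randn(1, 6, 8)  # (batch, tokens, d_model)
z = encoder(x)
print(z.shape)  # torch.Size([1, 6, 8])
```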

## Applications

### ChatGPT

### AlphaFold

### Other biological applications