# Transformers


<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/soup.png" width = 500>

*Images created by DALL-E 2 when prompted with "a bowl of soup that is a portal to another dimension as digital art"*

## Transformers in the news


<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/chatgpt.png" width = 500>

## Transformers in the news


<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/dalle3.png" width = 500>

## Transformers in the news

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/alphafold.jpg" width = 500>

## Overview

1. Data as collections of "tokens"
2. Recap of neural architectures
3. Self-attention & The Transformer
4. Transformer tricks
5. Applications: ChatGPT, AlphaFold, other biological applications

## Data as collections of "tokens"

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/tokens.png" width = 500>

## Text as collections of "tokens"

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/tokens_text.png" width = 500>

## Images as collections of "tokens"

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/tokens_img.png" width = 500>

## Molecules as collections of "tokens"
<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/tokens_mol.png" width = 350>

## (mass) spectra as collections of "tokens"


<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/tokens_spectra.png" width = 400>

## Tabular data as collections of "tokens"

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/tokens_tabular.png" width = 500>

## PyTorch toy example of tokens

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_printoptions(precision=2)

In [2]:
x = torch.randint(low = 0, high = 4, size = (1, 6))
print(x.shape)
print(x)

torch.Size([1, 6])
tensor([[1, 0, 3, 3, 0, 0]])


In [3]:
x_onehot = F.one_hot(x)
print(x_onehot.shape)
print(x_onehot.transpose(1,2))

torch.Size([1, 6, 4])
tensor([[[0, 1, 0, 0, 1, 1],
         [1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 1, 1, 0, 0]]])


In [4]:
print(x)
print(x_onehot.transpose(1,2))

tensor([[1, 0, 3, 3, 0, 0]])
tensor([[[0, 1, 0, 0, 1, 1],
         [1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 1, 1, 0, 0]]])


In [5]:
embedder = nn.Embedding(num_embeddings=4, embedding_dim = 4)
x_embed = embedder(x)
print(x_embed.shape)
print(x_embed.transpose(1,2))

torch.Size([1, 6, 4])
tensor([[[-0.39,  1.22,  0.98,  0.98,  1.22,  1.22],
         [-0.57,  0.68, -0.25, -0.25,  0.68,  0.68],
         [ 0.73,  0.31,  0.04,  0.04,  0.31,  0.31],
         [-0.22,  0.62, -0.76, -0.76,  0.62,  0.62]]],
       grad_fn=<TransposeBackward0>)


## Recap of neural architectures

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/cnn.png" width = 300>

### CNN advantages
- Takes advantage of locality of patterns
- Applicable for 1/2/3/... dimensional data
- Efficient

However ..
- Receptive field of convolutions is pre-determined
- Hard to learn long-term interactions

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/rnn.png" width = 500>

### RNN advantages
- Takes advantage of sequentiality of patterns.
- Intuitive to use in many-to-many, one-to-many, many-to-one, ...

However ..
- Inefficient (sequential vs parallel)
- Only really applicable for 1D data
- Can be liable to forgetting data over long term. (hidden state vector is a bottleneck).

## Where does the MLP fall in this?

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/mlp.png" width = 500>

- Fixed structure, no variable-length data
- Not applicable for many-to-many, etc ...

### Quick reminder

<img src="http://karpathy.github.io/assets/rnn/diags.jpeg" width = 600>

### The holy grail

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/whatwewant.png" width = 500>

- Efficient
- Compares all inputs vs all inputs?
- Also for 1D/2D/...
- Also for variable-length data.

## The holy grail = self-attention

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/token_mixing.png" width = 500>

- Comparing all tokens?

### Similarity matrices via dot products

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/dotproduct.png" width = 500>

For one vector:
$\boldsymbol{x} \cdot \boldsymbol{x}^\top$, with $\boldsymbol{x} \in \mathbb{R}^{1 \times d}$

For a matrix:
$\boldsymbol{X} \cdot \boldsymbol{X}^\top$, with $\boldsymbol{X} \in \mathbb{R}^{n \times d}$

### From similarity matrix back to $\mathbb{R}^{n \times d}$

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/XXtX.png" width = 700>

$A = \boldsymbol{X} \boldsymbol{X}^\top$, with $\boldsymbol{A} \in \mathbb{R}^{n \times n}$

$Z = A \boldsymbol{X}$, with $\boldsymbol{Z} \in \mathbb{R}^{n \times d}$

$Z = (\boldsymbol{X}\boldsymbol{X}^\top) \boldsymbol{X} $

### Intuition behind the equations

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/intuitionxxtx.png" width = 500>

Every output vector is a weighted sum of input vectors, weighted by a similarity measure

### Rounding out the equations

Learning?

$(\boldsymbol{X}\boldsymbol{X}^\top) \boldsymbol{X}$

$[(\boldsymbol{X}\boldsymbol{W}_q)(\boldsymbol{X}\boldsymbol{W}_k)^\top] {\boldsymbol{X}\boldsymbol{W}_v}$, with $\boldsymbol{W}_q, \boldsymbol{W}_k, \boldsymbol{W}_v \in \mathbb{R}^{h \times h}$

$(\boldsymbol{Q}\boldsymbol{K}^\top) \boldsymbol{V}$

### Rounding out the equations

Normalization?

$(\boldsymbol{Q}\boldsymbol{K}^\top) \boldsymbol{V}$

$\texttt{softmax}(\frac{\boldsymbol{Q}\boldsymbol{K}^\top}{\sqrt{d}}) \boldsymbol{V}$

### Rounding out the equation

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/xxtx_to_attn.png" width = 600>

$\texttt{softmax}(\frac{\boldsymbol{Q}\boldsymbol{K}^\top}{\sqrt{d}}) \boldsymbol{V}$

### A recap, what have we achieved?

From convs and RNNs to self-attention

Self-Attention properties:
- Parallel execution
- Variable input length
- Everything interacts with eachother
- Information flow depends on content of tokens

==> Generic and elegant mechanism

==> Has to learn all the structure in data from scratch, but has the freedom to do so to a better extent

### Code example

In [6]:
print(x_embed.shape)
print(x_embed.transpose(1,2))

torch.Size([1, 6, 4])
tensor([[[-0.39,  1.22,  0.98,  0.98,  1.22,  1.22],
         [-0.57,  0.68, -0.25, -0.25,  0.68,  0.68],
         [ 0.73,  0.31,  0.04,  0.04,  0.31,  0.31],
         [-0.22,  0.62, -0.76, -0.76,  0.62,  0.62]]],
       grad_fn=<TransposeBackward0>)


In [7]:
W_q = nn.Linear(4, 4)
Q = W_q(x_embed)
print(Q.shape)
print(Q.transpose(1,2))

torch.Size([1, 6, 4])
tensor([[[-0.89, -0.07,  0.30,  0.30, -0.07, -0.07],
         [ 0.09, -1.42, -0.23, -0.23, -1.42, -1.42],
         [-0.20,  0.18,  0.70,  0.70,  0.18,  0.18],
         [-0.06,  0.71,  0.56,  0.56,  0.71,  0.71]]],
       grad_fn=<TransposeBackward0>)


$\texttt{softmax}(\frac{\boldsymbol{Q}\boldsymbol{K}^\top}{\sqrt{d}}) \boldsymbol{V}$

In [8]:
W_q, W_k, W_v = nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4)
Q, K, V = W_q(x_embed), W_k(x_embed), W_v(x_embed)

A = Q @ K.transpose(2,1)
print(A.shape)
print(A)

torch.Size([1, 6, 6])
tensor([[[-0.10,  0.20,  0.23,  0.23,  0.20,  0.20],
         [ 0.41,  0.25,  1.21,  1.21,  0.25,  0.25],
         [ 0.22,  0.16,  0.55,  0.55,  0.16,  0.16],
         [ 0.22,  0.16,  0.55,  0.55,  0.16,  0.16],
         [ 0.41,  0.25,  1.21,  1.21,  0.25,  0.25],
         [ 0.41,  0.25,  1.21,  1.21,  0.25,  0.25]]],
       grad_fn=<UnsafeViewBackward0>)


$\texttt{softmax}(\frac{\boldsymbol{Q}\boldsymbol{K}^\top}{\sqrt{d}}) \boldsymbol{V}$

In [9]:
print(A)

print(F.softmax(A, dim = -1))

tensor([[[-0.10,  0.20,  0.23,  0.23,  0.20,  0.20],
         [ 0.41,  0.25,  1.21,  1.21,  0.25,  0.25],
         [ 0.22,  0.16,  0.55,  0.55,  0.16,  0.16],
         [ 0.22,  0.16,  0.55,  0.55,  0.16,  0.16],
         [ 0.41,  0.25,  1.21,  1.21,  0.25,  0.25],
         [ 0.41,  0.25,  1.21,  1.21,  0.25,  0.25]]],
       grad_fn=<UnsafeViewBackward0>)
tensor([[[0.13, 0.17, 0.18, 0.18, 0.17, 0.17],
         [0.12, 0.11, 0.28, 0.28, 0.11, 0.11],
         [0.15, 0.14, 0.21, 0.21, 0.14, 0.14],
         [0.15, 0.14, 0.21, 0.21, 0.14, 0.14],
         [0.12, 0.11, 0.28, 0.28, 0.11, 0.11],
         [0.12, 0.11, 0.28, 0.28, 0.11, 0.11]]], grad_fn=<SoftmaxBackward0>)


In [10]:
A_normalized = F.softmax(A / Q.shape[-1], dim = -1)

In [11]:
print(A_normalized.shape, V.shape)
Z = A_normalized @ V
print(Z.shape)

torch.Size([1, 6, 6]) torch.Size([1, 6, 4])
torch.Size([1, 6, 4])


In [12]:
print(x_embed.shape)

torch.Size([1, 6, 4])


In [13]:
class SelfAttention(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.W_q, self.W_k, self.W_v = (
            nn.Linear(d, d), nn.Linear(d, d), nn.Linear(d, d)
        )
    def forward(self, x):
        Q, K, V = self.W_q(x), self.W_k(x), self.W_v(x)
        A = Q @ K.transpose(2,1)
        A_normalized = F.softmax(A / Q.shape[-1], dim = -1)
        return A_normalized @ V

selfattn = SelfAttention(4)
z = selfattn(x_embed)
print(x_embed.shape, z.shape)
print(x_embed.transpose(2,1))
print(z.transpose(2,1))

torch.Size([1, 6, 4]) torch.Size([1, 6, 4])
tensor([[[-0.39,  1.22,  0.98,  0.98,  1.22,  1.22],
         [-0.57,  0.68, -0.25, -0.25,  0.68,  0.68],
         [ 0.73,  0.31,  0.04,  0.04,  0.31,  0.31],
         [-0.22,  0.62, -0.76, -0.76,  0.62,  0.62]]],
       grad_fn=<TransposeBackward0>)
tensor([[[-0.12, -0.13, -0.12, -0.12, -0.13, -0.13],
         [ 0.07,  0.07,  0.07,  0.07,  0.07,  0.07],
         [ 0.26,  0.25,  0.24,  0.24,  0.25,  0.25],
         [-0.31, -0.32, -0.33, -0.33, -0.32, -0.32]]],
       grad_fn=<TransposeBackward0>)


## From self-attention to the transformer

### Multiple heads

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/heads.png" width = 700>

### The Transformer block

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/layernorm.jpg" width = 250>

Many of these "blocks" are stacked to make up a full transformer

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/transformerblock.png" width = 700>

Code example

In [14]:
class TransformerLayer(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.norm1 = nn.LayerNorm(d)
        self.attn = SelfAttention(d)
        self.norm2 = nn.LayerNorm(d)
        self.ff = nn.Sequential(
            nn.Linear(d, d*4), nn.ReLU(), nn.Dropout(0.2), nn.Linear(d*4, d)
            )
    def forward(self, x):
        x = self.attn(self.norm1(x)) + x
        x = self.ff(self.norm2(x)) + x
        return x

In [15]:
layer = TransformerLayer(4)
z = layer(x_embed)
print(x_embed.shape, z.shape)
print(x_embed.transpose(2,1))
print(z.transpose(2,1))

torch.Size([1, 6, 4]) torch.Size([1, 6, 4])
tensor([[[-0.39,  1.22,  0.98,  0.98,  1.22,  1.22],
         [-0.57,  0.68, -0.25, -0.25,  0.68,  0.68],
         [ 0.73,  0.31,  0.04,  0.04,  0.31,  0.31],
         [-0.22,  0.62, -0.76, -0.76,  0.62,  0.62]]],
       grad_fn=<TransposeBackward0>)
tensor([[[-0.65,  1.61,  1.12,  1.10,  1.15,  1.37],
         [-0.37,  1.02, -0.15, -0.09,  0.64,  0.93],
         [ 0.45,  0.28,  0.10,  0.20,  0.51,  0.44],
         [-1.08,  0.16, -1.35, -1.29, -0.08, -0.03]]],
       grad_fn=<TransposeBackward0>)


## Transformer "tricks"

### Permutation invariance and positional information

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/permutequivar.png.png" width = 600>

Due to its design, shuffling inputs just shuffles outputs => The order in the sequence does not matter.

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/positionalenc.png" width = 500>

Positional encodings give a "signal" to the transformer signifying position.

Note: positional encodings can be made for 2D, 3D inputs ...

### Positional encodings for vision transformers

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/vision_transformer.png" width = 500>

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/relative_positional_encoding.png" width = 550>

Alternative way of adding positional encodings is through the attention matrix itself

### Causal attention = decoder = autoregressive

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/causal.png" width = 600>

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/causalmask.png" width = 750>

Masking with a very large negative number pre-softmax effectively makes self-attention causal

Code example

In [16]:
W_q, W_k, W_v = nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4)
Q, K, V = W_q(x_embed), W_k(x_embed), W_v(x_embed)
A = Q @ K.transpose(2,1)
print(A)

tensor([[[ 0.57,  0.57,  0.84,  0.84,  0.57,  0.57],
         [ 0.21,  0.38,  0.24,  0.24,  0.38,  0.38],
         [ 0.37,  0.11, -0.07, -0.07,  0.11,  0.11],
         [ 0.37,  0.11, -0.07, -0.07,  0.11,  0.11],
         [ 0.21,  0.38,  0.24,  0.24,  0.38,  0.38],
         [ 0.21,  0.38,  0.24,  0.24,  0.38,  0.38]]],
       grad_fn=<UnsafeViewBackward0>)


In [17]:
mask = torch.ones_like(A).triu(diagonal = 1)
print(mask)

tensor([[[0., 1., 1., 1., 1., 1.],
         [0., 0., 1., 1., 1., 1.],
         [0., 0., 0., 1., 1., 1.],
         [0., 0., 0., 0., 1., 1.],
         [0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0.]]])


In [18]:
A.masked_fill(mask.bool(), float("-inf"))

tensor([[[ 0.57,  -inf,  -inf,  -inf,  -inf,  -inf],
         [ 0.21,  0.38,  -inf,  -inf,  -inf,  -inf],
         [ 0.37,  0.11, -0.07,  -inf,  -inf,  -inf],
         [ 0.37,  0.11, -0.07, -0.07,  -inf,  -inf],
         [ 0.21,  0.38,  0.24,  0.24,  0.38,  -inf],
         [ 0.21,  0.38,  0.24,  0.24,  0.38,  0.38]]],
       grad_fn=<MaskedFillBackward0>)

In [19]:
A_masked = A.masked_fill(mask.bool(), float("-inf"))
F.softmax(A_masked, dim = -1)

tensor([[[1.00, 0.00, 0.00, 0.00, 0.00, 0.00],
         [0.46, 0.54, 0.00, 0.00, 0.00, 0.00],
         [0.42, 0.32, 0.27, 0.00, 0.00, 0.00],
         [0.33, 0.25, 0.21, 0.21, 0.00, 0.00],
         [0.19, 0.22, 0.19, 0.19, 0.22, 0.00],
         [0.15, 0.18, 0.16, 0.16, 0.18, 0.18]]], grad_fn=<SoftmaxBackward0>)

### Sparse attention (masks)

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/custommasks.png" width = 700>

### From self-attention to cross-attention

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/machine_translation.png" width = 700>

$\texttt{softmax}(\frac{\boldsymbol{Q}\boldsymbol{K}^\top}{\sqrt{d}}) \boldsymbol{V}$

To go from one "source" set of data tokens to another "target" set of data tokens:
- The keys $K$ and values $V$ come from the source domain
- The queries $Q$ come from the target domain

### Cross-attention

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/crossattn.png" width = 700>

$\texttt{softmax}(\frac{\boldsymbol{Q}\boldsymbol{K}^\top}{\sqrt{d}}) \boldsymbol{V}$

To go from one "source" set of data tokens to another "target" set of data tokens:
- The keys $K$ and values $V$ come from the source domain (red)
- The queries $Q$ come from the target domain (blue)

Code example

In [20]:
x_source = torch.randint(low = 0, high = 4, size = (1, 5))
x_source_embed = embedder(x_source)
print(x_source)
print(x_source_embed.transpose(2,1))
print(x_source_embed.shape)

x_target = torch.randint(low = 0, high = 4, size = (1, 2))
x_target_embed = embedder(x_target)
print(x_target)
print(x_target_embed.transpose(2,1))
print(x_target_embed.shape)

tensor([[2, 3, 1, 1, 2]])
tensor([[[-1.20,  0.98, -0.39, -0.39, -1.20],
         [-0.33, -0.25, -0.57, -0.57, -0.33],
         [ 0.74,  0.04,  0.73,  0.73,  0.74],
         [-0.83, -0.76, -0.22, -0.22, -0.83]]], grad_fn=<TransposeBackward0>)
torch.Size([1, 5, 4])
tensor([[0, 2]])
tensor([[[ 1.22, -1.20],
         [ 0.68, -0.33],
         [ 0.31,  0.74],
         [ 0.62, -0.83]]], grad_fn=<TransposeBackward0>)
torch.Size([1, 2, 4])


In [21]:
Q = W_q(x_target_embed)
K, V = W_k(x_source_embed), W_v(x_source_embed)
A = Q @ K.transpose(2,1)
print(A.shape)
print(A)

torch.Size([1, 2, 5])
tensor([[[0.09, 0.24, 0.21, 0.21, 0.09],
         [0.54, 0.57, 0.51, 0.51, 0.54]]], grad_fn=<UnsafeViewBackward0>)


In [22]:
Z = torch.softmax(A / Q.shape[-1], dim = -1) @ V
print(Z.shape)
print(Z)

torch.Size([1, 2, 4])
tensor([[[-0.17, -0.47, -0.31, -0.86],
         [-0.17, -0.47, -0.31, -0.86]]], grad_fn=<UnsafeViewBackward0>)


## Modern Transformer libraries

- Included in PyTorch
- Huggingface
- ...
<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/hugginface.png" width = 700>

## Applications

### ChatGPT

1. Scrape the internet (essentially) + train a huge autoregressive transformer
2. Fine-tune it for answering prompts

### 1. Scrape the internet (essentially) + train a huge autoregressive transformer

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/gpt.png" width = 400>

Bottleneck: $$$

### 2. Fine-tune it for answering prompts

<img src="https://i.kym-cdn.com/entries/icons/original/000/044/025/shoggothhh_header.jpg" width = 700>

The smiley says: "As an AI language model, I have trained to generate responses that are intended to be helpful, informative, and objective ..."

Bottleneck: Data collection (again, $$$)

#### Reinforcement learning from human feedback

<img src="https://miro.medium.com/v2/resize:fit:1400/1*yCzfUi2CgSl-yW_gYAjDMw.png" width = 600>

1. Collect various completions from prompts, let humans label what they like.
1. Fit a model to predict what humans like (Reward model).
2. Use the reward model to fine-tune the model to say what humans like.

### AlphaFold

<img src="https://lh3.googleusercontent.com/pL18FAkwzN55iHvMt2W4XRGjueHWe0ILqX1Qm2e4qlPsK3yjDSott3LZIgSg2uqPPn7Zvu3hfxUtYtjDs3bM27zcF8AO_jYnfk8q=w1440" width = 700>

- Flexibility of self-attention: learns interactions between the far end and the close end of proteins
- Positional encodings: Grounding in 3D space

### Other biological applications

#### CpG Transformer

<img src="https://www.biorxiv.org/content/biorxiv/early/2021/09/17/2021.06.08.447547/F1.large.jpg" width = 600>

#### TIS Transformer

<img src="https://www.biorxiv.org/content/biorxiv/early/2021/11/19/2021.11.18.468957/F1.large.jpg" width = 600>

### Take home message: the transformer

<img src="https://raw.githubusercontent.com/gdewael/teaching/main/presentations/transformers/img/whatwewant.png" width = 500>

As opposed to CNNs, RNNs, ...
- Learn interactions between all inputs
- "Freedom" to learn any pattern: high skill ceiling, but need lots of data
- Because of this: the basis for almost all big-news advancements in AI