# Transformers

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset, random_split

In [3]:
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
DEVICE

'mps'

## 1. Attention mechanism

### 1.1 Motivation

[WIP]

### 1.2 Details

#### 1.2.1 Token by token

For each token, we compute a triplet of vectors: a query, a key, and a value. Keep in mind that
- The **query** is the **question.** Given a token, we compare **its query vector** with the keys of all tokens (including itself) to **ask** "how much attention should we pay to each token?"
- The **key** is the **identity.** When our token (let's call it A) is queried by another token B, we use **A's key** and **B's query** to calculate how much attention B pays to A.
- The **value** is the **contribution.** Once the attention has been calculated, A's **contribution** to B's output is that attention weight times **A's value vector.**

To compute these vectors, we need 3 learnable weight matrices: one for the query, one for the key, and one for the value. They are often called the QKV matrices, and their dimensions are
$$
W_q \in \mathbb{R}^{d \times d_k}, \quad
W_k \in \mathbb{R}^{d \times d_k}, \quad
W_v \in \mathbb{R}^{d \times d_k}.
$$
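For simplicity, the value vectors here share the key dimension $d_k$ (in general, the value dimension can differ). Then, for each token $x_i \in \mathbb{R}^d$, we compute the three QKV vectors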
$$
q_i = W_q^T x_i, \quad
k_i = W_k^T x_i, \quad
v_i = W_v^T x_i.
$$
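A minimal sketch of the per-token computation, to make the shapes concrete. The sizes `d = 4`, `d_k = 2`, `n = 5` and the random weights are arbitrary choices for illustration, not trained parameters; tokens are stored as 1-D tensors, so `W_q.T @ x` plays the role of $W_q^T x_i$.

```python
import torch

torch.manual_seed(0)

d, d_k, n = 4, 2, 5  # embedding dim, key dim, number of tokens (illustrative sizes)

# Weight matrices with the shapes given above: (d, d_k); random, not trained
W_q = torch.randn(d, d_k)
W_k = torch.randn(d, d_k)
W_v = torch.randn(d, d_k)

# n token embeddings x_i in R^d, stored as 1-D tensors
xs = [torch.randn(d) for _ in range(n)]

# q_i = W_q^T x_i, k_i = W_k^T x_i, v_i = W_v^T x_i
qs = [W_q.T @ x for x in xs]
ks = [W_k.T @ x for x in xs]
vs = [W_v.T @ x for x in xs]

print(qs[0].shape)  # torch.Size([2])
```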
The dot product between a query and a key measures how "similar" the two vectors are. Given a query, the attention mechanism computes the dot product of that query with every key, scales the results, and then applies a softmax function to get a probability distribution over the keys
$$
a_i = p(k | q_i) = \text{softmax} \left(
	\frac{[q_i^T k_1, q^T_i k_2, \dots, q^T_i k_n]}{\sqrt{d_k}}
\right) \in \mathbb{R}^{1 \times n}.
$$
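Here, the factor of $1/\sqrt{d_k}$ keeps the dot products from growing too large as $d_k$ increases; very large scores push the softmax toward a one-hot distribution, where its gradients become vanishingly small. The output for token $i$ is then the attention-weighted sum of the value vectors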
$$
z_i = \sum_j p(k_j | q_i) v_j.
$$
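One way to see why large dot products arise: if (as an illustrative assumption) the entries of $q_i$ and $k_j$ are independent with mean 0 and variance 1, then $q_i^T k_j$ is a sum of $d_k$ terms of variance 1, so its variance is $d_k$. Dividing by $\sqrt{d_k}$ brings it back to about 1. The sketch below checks this empirically with random vectors; the raw variance should land near $d_k$ and the scaled one near 1.

```python
import torch

torch.manual_seed(0)

num_samples = 10_000
results = {}
for d_k in [4, 64, 1024]:
    # Random queries and keys with i.i.d. standard normal entries (illustrative assumption)
    q = torch.randn(num_samples, d_k)
    k = torch.randn(num_samples, d_k)

    dots = (q * k).sum(dim=-1)  # raw scores q^T k, one per sample
    scaled = dots / d_k**0.5    # scaled scores q^T k / sqrt(d_k)

    results[d_k] = (dots.var().item(), scaled.var().item())
    print(f"d_k={d_k:5d}  var(q^T k)={results[d_k][0]:8.2f}  var(scaled)={results[d_k][1]:.2f}")
```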

#### 1.2.2 Doing it all at once

We start by stacking our inputs into a single matrix $X$
$$
X = \begin{bmatrix}
	x_1^T  \\
	x_2^T  \\
	\vdots \\
	x_n^T  \\
\end{bmatrix} \in \mathbb{R}^{n \times d}.
$$
Then, we can stack the QKV vectors into $Q, K, V$ defined as
$$
Q = \begin{bmatrix}
	q_1^T  \\
	q_2^T  \\
	\vdots \\
	q_n^T  \\
\end{bmatrix}
= XW_q \in \mathbb{R}^{n \times d_k}, \quad
K = \begin{bmatrix}
	k_1^T  \\
	k_2^T  \\
	\vdots \\
	k_n^T  \\
\end{bmatrix}
= XW_k \in \mathbb{R}^{n \times d_k}, \quad
V = \begin{bmatrix}
	v_1^T  \\
	v_2^T  \\
	\vdots \\
	v_n^T  \\
\end{bmatrix}
= XW_v \in \mathbb{R}^{n \times d_k}.
$$
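As a quick numerical check of this stacking, here is a minimal sketch (random weights and arbitrary sizes, chosen only for illustration) confirming that row $i$ of $XW_q$, $XW_k$, $XW_v$ matches the per-token vectors $W_q^T x_i$, $W_k^T x_i$, $W_v^T x_i$:

```python
import torch

torch.manual_seed(0)

n, d, d_k = 5, 4, 2
X = torch.randn(n, d)  # row i is x_i^T
W_q = torch.randn(d, d_k)
W_k = torch.randn(d, d_k)
W_v = torch.randn(d, d_k)

# Matrix form: one matmul per projection, each result has shape (n, d_k)
Q, K, V = X @ W_q, X @ W_k, X @ W_v

# Row i of each matrix should equal the per-token vector for x_i
for i in range(n):
    assert torch.allclose(Q[i], W_q.T @ X[i], atol=1e-6)
    assert torch.allclose(K[i], W_k.T @ X[i], atol=1e-6)
    assert torch.allclose(V[i], W_v.T @ X[i], atol=1e-6)

print(Q.shape, K.shape, V.shape)
```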
Note again that **row $i$ of each of $Q$, $K$, $V$ stores the corresponding QKV vector of token $i$**. The attention matrix is thus given by
$$
A = \text{softmax} \left(
	\frac{1}{\sqrt{d_k}}
	\begin{bmatrix}
		q_1^T k_1 & q^T_1 k_2 & \dots  & q^T_1 k_n \\
		q_2^T k_1 & q^T_2 k_2 & \dots  & q^T_2 k_n \\
		\vdots	  & \vdots    & \ddots & \vdots    \\
		q_n^T k_1 & q^T_n k_2 & \dots & q^T_n k_n  \\
	\end{bmatrix}
\right) = \text{softmax} \left(
	\frac{QK^T}{\sqrt{d_k}}
\right) \in \mathbb{R}^{n \times n}.
$$
Here, softmax is applied row-wise. Row $i$ of $A$ contains the weights $p(k_j | q_i)$ on the value vectors needed to construct the $i$-th attention output $z_i$. It follows that the full output is given by
$$
Z = \begin{bmatrix}
	z_1^T  \\
	z_2^T  \\
	\vdots \\
	z_n^T  \\
\end{bmatrix}
= AV \in \mathbb{R}^{n \times d_k}.
$$
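The matrix form should reproduce the token-by-token computation from 1.2.1 (up to floating-point error). A minimal sketch, with random weights and small illustrative sizes, that computes $Z = AV$ in one shot and again with an explicit loop over tokens, and also checks that each row of $A$ sums to 1:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

n, d, d_k = 5, 4, 2
X = torch.randn(n, d)
W_q, W_k, W_v = (torch.randn(d, d_k) for _ in range(3))
Q, K, V = X @ W_q, X @ W_k, X @ W_v

# Matrix form: A = softmax(Q K^T / sqrt(d_k)) applied row-wise, then Z = A V
A = F.softmax(Q @ K.T / d_k**0.5, dim=-1)  # (n, n); dim=-1 normalizes each row
Z = A @ V                                   # (n, d_k)

# Token-by-token form: a_i = softmax of scaled scores, z_i = sum_j a_ij v_j
Z_loop = torch.empty(n, d_k)
for i in range(n):
    scores = torch.stack([Q[i] @ K[j] for j in range(n)]) / d_k**0.5  # (n,)
    a_i = F.softmax(scores, dim=0)                                    # p(k_j | q_i)
    Z_loop[i] = sum(a_i[j] * V[j] for j in range(n))                  # z_i

print(torch.allclose(Z, Z_loop, atol=1e-5))
print(A.sum(dim=-1))  # each row sums to (approximately) 1
```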

#### 1.2.3 Implementation

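We will implement a simple attention layer as described above, with a fully connected layer at the end to map our outputs back to the original dimension $d$. The projections use `nn.Linear`, which by default also adds a bias vector, so each one computes $XW + b$ rather than exactly $XW$.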
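One detail when mapping the equations onto `nn.Linear`: `nn.Linear(in_features, out_features)` stores its `weight` with shape `(out_features, in_features)` and computes `x @ weight.T + bias`. So the $W_q \in \mathbb{R}^{d \times d_k}$ of the equations corresponds to `weight.T`, not `weight`. A small sketch checking this (sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d, d_k, n = 4, 2, 5
X = torch.randn(n, d)

# nn.Linear(in_features, out_features) stores weight as (out_features, in_features)
proj = nn.Linear(d, d_k, bias=False)
print(proj.weight.shape)  # torch.Size([2, 4])

# The W_q of the equations, shape (d, d_k), is the transpose of proj.weight
W_q = proj.weight.T
assert torch.allclose(proj(X), X @ W_q, atol=1e-6)

# With the default bias=True, the layer computes X W_q + b instead
proj_b = nn.Linear(d, d_k)
assert torch.allclose(proj_b(X), X @ proj_b.weight.T + proj_b.bias, atol=1e-6)
```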

In [None]:
class Attention(nn.Module):
	"""
	Implementation of a simple attention layer.
	"""
	def __init__(self, embed_dim, key_dim):
		"""
		Inputs:
		- embed_dim: dimension of the token embedding
		- key_dim:   dimension of the query, key, and value vectors
		"""
		super().__init__()
		self.d_embed = embed_dim
		self.d_key   = key_dim

		# The QKV matrices
		self.Wq = nn.Linear(embed_dim, key_dim)
		self.Wk = nn.Linear(embed_dim, key_dim)
		self.Wv = nn.Linear(embed_dim, key_dim)

		# Fully connected layer at the end
		self.Wc = nn.Linear(key_dim, embed_dim)
	
	def forward(self, x):
		"""
		Input:
		- x: (batch_size, seq_len, embed_dim)

		Output:
		- x: (batch_size, seq_len, embed_dim)
		- A: the attention matrices (batch_size, seq_len, seq_len)
		"""
		# Note: x has shape (B, n, d)
		Q = self.Wq(x)  # (B, n, d_k)
		K = self.Wk(x)  # (B, n, d_k)
		V = self.Wv(x)  # (B, n, d_k)

		A = F.softmax(Q @ K.transpose(-2, -1) / self.d_key**0.5, dim=-1)  # (B, n, n)

		# Forward pass
		x = A @ V       # (B, n, d_k)
		x = self.Wc(x)  # (B, n, d)

		return x, A

We can do a quick test to check that the shapes work out:

In [18]:
# Setting up the layer
embed_dim = 4
key_dim = 2

attention_layer = Attention(embed_dim, key_dim)
attention_layer.to(DEVICE)

# Dummy data
batch_size = 10
seq_len = 20
x = torch.randn(batch_size, seq_len, embed_dim, device=DEVICE)
print(f"The shape of our input is {tuple(x.cpu().shape)}")

# Forward pass
x, A = attention_layer(x)
print(f"The shape of our output is {tuple(x.cpu().shape)}")
print(f"The shape of our attention matrices is {tuple(A.cpu().shape)}")

The shape of our input is (10, 20, 4)
The shape of our output is (10, 20, 4)
The shape of our attention matrices is (10, 20, 20)
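As a further check on the attention formula itself, PyTorch 2.0 and later ship a fused `F.scaled_dot_product_attention`, which by default scales by $1/\sqrt{d_k}$ with $d_k$ taken from the last dimension of the query. A minimal sketch comparing it against the computation in `Attention.forward` (before the final `Wc` layer), using random $Q, K, V$ with the same batch and sequence sizes as the test above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

B, n, d_k = 10, 20, 2
Q = torch.randn(B, n, d_k)
K = torch.randn(B, n, d_k)
V = torch.randn(B, n, d_k)

# Same computation as in Attention.forward, before the final Wc layer
A = F.softmax(Q @ K.transpose(-2, -1) / d_k**0.5, dim=-1)  # (B, n, n)
Z_manual = A @ V                                            # (B, n, d_k)

# PyTorch's built-in version (requires PyTorch >= 2.0); no mask, no dropout by default
Z_builtin = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(Z_manual, Z_builtin, atol=1e-5))
```

Note that the built-in function returns only the output and not the attention matrix $A$, so the explicit version above is still handy when we want to inspect the attention weights.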


[WIP]