# Model: transformer architecture

We will try to learn modular addition and subtraction with a transformer.
Many codebases with efficient implementations can be found online, e.g., [`nanoGPT`](https://github.com/karpathy/nanoGPT) or [`xFormers`](https://github.com/facebookresearch/xformers).

I recommend watching Andrej Karpathy's [lectures](https://www.youtube.com/playlist?list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ) to ground yourself in the basics of deep learning and transformers.

In [1]:
%load_ext autoreload
%autoreload 2

Let us start by loading the data.

In [2]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from llmtuto.config import DATA_DIR

problem = 'multi_base'
save_dir = DATA_DIR / problem

x_train = torch.load(save_dir / 'x_train.pt')
y_train = torch.load(save_dir / 'y_train.pt')
x_test = torch.load(save_dir / 'x_test.pt')
y_test = torch.load(save_dir / 'y_test.pt')

# # Check for data correctness
# assert (x_train.sum(axis=1) % 60 == y_train).all()
# assert (x_test.sum(axis=1) % 60 == y_test).all()



## Token embeddings

Rather than discrete tokens, transformers work with sequences of vectors in $\mathbb{R}^d$.
Therefore, tokens are embedded in $\mathbb{R}^d$ through a look-up table (`one_hot(x_train) @ token_emb`).

In [3]:
import torch.nn as nn

vocab_size = int(torch.max(x_train)) + 1  # plain int rather than a 0-dim tensor
emb_dim = 8

token_emb = nn.Embedding(vocab_size, emb_dim) # word token embedding

# Comments abbreviations:
# N: batch size
# L: sequence length
# V: vocabulary size
# E: embedding dimension

x_emb = token_emb(x_train)  # one_hot(x_train) @ token_emb: (N, L, V) @ (V, E) -> (N, L, E)

In [4]:
x_emb.shape

torch.Size([28773, 4, 8])
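
The look-up table interpretation can be checked on a toy example. The following standalone sketch uses an arbitrary small vocabulary and embedding dimension (both are assumptions, unrelated to the data above) and compares `nn.Embedding` with the explicit one-hot matrix product.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
V, E = 5, 3                                          # toy vocabulary size and embedding dimension
toy_emb = nn.Embedding(V, E)
tokens = torch.tensor([[0, 2, 4], [1, 1, 3]])        # (N, L) = (2, 3)

lookup = toy_emb(tokens)                             # (N, L) -> (N, L, E)
one_hot = F.one_hot(tokens, num_classes=V).float()   # (N, L, V)
matmul = one_hot @ toy_emb.weight                    # (N, L, V) @ (V, E) -> (N, L, E)
same = torch.allclose(lookup, matmul)
print('Look-up equals one-hot product:', same)
```

The look-up avoids materializing the `(N, L, V)` one-hot tensor, which matters once the vocabulary is large.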

#### Weight tying

We will predict a token `y` through a score (which will be interpreted as a logit by the cross-entropy loss):

```s(x, y) = g(token_emb(x)) @ token_emb(y).T = g(token_emb(x)) @ token_emb.T @ one_hot(y)```.

This corresponds to adding a linear layer at the end of the network and tying weights between the embedding and the output "un-embedding" layer.

In [5]:
unemb = nn.Linear(emb_dim, vocab_size, bias=False)    # un-embedding layer (E, V)
unemb.weight = token_emb.weight                       # tie weights
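
To see what tying does in practice, here is a standalone sketch on toy dimensions (the sizes and the random stand-in for `g(x)` are assumptions). It checks that both modules hold the same parameter, that the logits are `g @ W.T`, and that an in-place change of the embedding is visible in the output layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
V, E = 5, 3
emb = nn.Embedding(V, E)
head = nn.Linear(E, V, bias=False)   # its weight has shape (V, E), like emb.weight
head.weight = emb.weight             # both modules now hold the same Parameter

shared = head.weight is emb.weight
g = torch.randn(2, E)                # stand-in for the network output g(x)
logits = head(g)                     # (2, E) @ (E, V) -> (2, V)
manual = g @ emb.weight.T

with torch.no_grad():
    emb.weight[0].zero_()            # in-place change of the embedding table
head_row_is_zero = bool((head.weight[0] == 0).all())   # also seen by the output layer
```

Because the parameter is shared, the gradients from the embedding and from the output layer accumulate on a single tensor during training.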

## Attention block

We start with a simple implementation of a single attention head.
You can check Andrej Karpathy's videos for a more detailed explanation.

In [6]:
q_mat = nn.Linear(emb_dim, emb_dim, bias=False)       # query matrix (E, E)
k_mat = nn.Linear(emb_dim, emb_dim, bias=False)       # key matrix
v_mat = nn.Linear(emb_dim, emb_dim, bias=False)       # value matrix

q = q_mat(x_emb)                                      # query (N, L, E) @ (E, E) -> (N, L, E)
k = k_mat(x_emb)                                      # key 
v = v_mat(x_emb)                                      # value

attn = q @ k.transpose(-1, -2) / math.sqrt(emb_dim)   # attention (N, L, E) @ (N, E, L) -> (N, L, L)
# When attention is causal, a token should not attend to future tokens
causal = True
if causal:
    L = x_emb.shape[1]
    mask = torch.tril(torch.ones(L, L)) == 0
    attn = attn.masked_fill(mask, float('-inf'))
attn = torch.softmax(attn, dim=-1)                    # softmax over last dimension

z = attn @ v                                          # (N, L, L) @ (N, L, E) -> (N, L, E)
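
The effect of the causal mask can be verified on a small example. This standalone sketch (toy sizes and random inputs are assumptions) builds the same mask with `torch.triu` and checks that each position puts zero weight on later positions.

```python
import math
import torch

torch.manual_seed(0)
L, E = 4, 8
q, k = torch.randn(L, E), torch.randn(L, E)

scores = q @ k.T / math.sqrt(E)                                        # (L, L)
future = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)    # True strictly above the diagonal
weights = torch.softmax(scores.masked_fill(future, float('-inf')), dim=-1)
print(weights)
```

Row `i` of `weights` sums to one and is supported on positions `0, ..., i`; in particular the first token can only attend to itself.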

Let us now implement several heads. To do so, we split the query, key, and value matrices across heads, so that all heads are computed in a single "fused" matrix multiplication.

In [7]:
# Comments abbreviations:
# H: number of heads

N, L, E = x_emb.size()
H = 2

q_heads = q.view(N, L, H, E // H).transpose(1, 2)               # (N, L, E) -> (N, L, H, E / H) -> (N, H, L, E / H)
k_heads = k.view(N, L, H, E // H).transpose(1, 2)
v_heads = v.view(N, L, H, E // H).transpose(1, 2)

attn = q_heads @ k_heads.transpose(-1, -2) / math.sqrt(E // H)  # (N, H, L, E / H) @ (N, H, E / H, L) -> (N, H, L, L)
if causal:
    mask = torch.tril(torch.ones(L, L)).view(1, 1, L, L) == 0
    attn = attn.masked_fill(mask, float('-inf'))
attn = F.softmax(attn, dim=-1)
z = attn @ v_heads                                             # (N, H, L, L) @ (N, H, L, E / H) -> (N, H ,L, E / H)
z = z.transpose(1, 2).contiguous().view(N, L, E)               # (N, H, L, E / H) -> (N, L, H, E / H) -> (N, L, E)
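
As a sanity check of the reshaping, the standalone sketch below (toy sizes and random inputs are assumptions) compares the batched computation with an explicit loop in which head `h` only sees the channels `h * E / H` to `(h + 1) * E / H` of the queries, keys, and values.

```python
import math
import torch

torch.manual_seed(0)
N, L, E, H = 2, 4, 8, 2
D = E // H                                          # dimension per head
q, k, v = (torch.randn(N, L, E) for _ in range(3))

def split(t):
    return t.view(N, L, H, D).transpose(1, 2)       # (N, L, E) -> (N, H, L, D)

attn = torch.softmax(split(q) @ split(k).transpose(-1, -2) / math.sqrt(D), dim=-1)
fused = (attn @ split(v)).transpose(1, 2).contiguous().view(N, L, E)

per_head = []
for h in range(H):
    s = slice(h * D, (h + 1) * D)                   # channels owned by head h
    a = torch.softmax(q[..., s] @ k[..., s].transpose(-1, -2) / math.sqrt(D), dim=-1)
    per_head.append(a @ v[..., s])
looped = torch.cat(per_head, dim=-1)                # concatenate heads along channels
print('Max difference:', (fused - looped).abs().max().item())
```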

We now turn to a faster implementation based on the "fusion" of many operations, as detailed in the [flash attention paper](https://arxiv.org/abs/2205.14135) and implemented by PyTorch.

In [8]:
attn_mat = nn.Linear(emb_dim, 3 * emb_dim, bias=False)          # fused query, key, value matrix (E, 3 * E)

N, L, E = x_emb.size()
training = True
causal = False
dropout = 0
n_head = 2

# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v  = attn_mat(x_emb).split(emb_dim, dim=2)
q = q.view(N, L, H, E // H).transpose(1, 2)
k = k.view(N, L, H, E // H).transpose(1, 2)
v = v.view(N, L, H, E // H).transpose(1, 2)

# efficient attention using Flash Attention CUDA kernels
z = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=dropout if training else 0, is_causal=causal)
z1 = z.clone()
z = z.transpose(1, 2).contiguous().view(N, L, E)
z2 = z.clone()

# Check computation
attn = q @ k.transpose(-1, -2) / math.sqrt(E // H)
if causal:
    L = x_emb.shape[1]
    mask = torch.tril(torch.ones(L, L)) == 0
    attn = attn.masked_fill(mask, float('-inf'))
attn = torch.softmax(attn, dim=-1)
z_bis = attn @ v
z_bis = z_bis.transpose(1, 2).contiguous().view(N, L, E)
print('Testing correctness:', (z_bis - z).abs().max())

Testing correctness: tensor(2.3842e-07, grad_fn=<MaxBackward1>)
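
The check above runs with `causal = False`. A similar standalone sketch for the causal case, assuming PyTorch 2.0 or later (where `F.scaled_dot_product_attention` is available) and toy random inputs, compares `is_causal=True` with an explicit mask.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, H, L, D = 2, 2, 4, 4
q, k, v = (torch.randn(N, H, L, D) for _ in range(3))

fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

scores = q @ k.transpose(-1, -2) / math.sqrt(D)                        # (N, H, L, L)
future = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)    # broadcast over N and H
manual = torch.softmax(scores.masked_fill(future, float('-inf')), dim=-1) @ v
max_err = (fused - manual).abs().max().item()
print('Max difference:', max_err)
```

The two results should agree up to floating-point rounding, since the fused kernel reorders the operations but computes the same quantity.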


We are now ready to write a self-attention module.
For readability, we will pass all arguments into a `config` object.

In [9]:
from llmtuto.model.transformer import SelfAttention, TransformerConfig


config = TransformerConfig(
    vocab_size = vocab_size, 
    emb_dim = emb_dim, 
    n_head = n_head, 
    attn_dropout = dropout,
    causal = causal
)

state_dict = {
    'query.weight': attn_mat.weight[:emb_dim],
    'key.weight': attn_mat.weight[emb_dim:2*emb_dim],
    'value.weight': attn_mat.weight[2*emb_dim:],
    'output.weight': torch.eye(emb_dim, emb_dim),
}

self_att = SelfAttention(config)
self_att.load_state_dict(state_dict)

with torch.no_grad():
    z_tres = self_att(x_emb)
    print('Testing correctness:', (z_tres - z).abs().max())

Testing correctness: tensor(3.5763e-07)


Similarly, we can write a cross-attention module.
Let us consider the test data to create a new sequence to attend to.

In [10]:
from llmtuto.model.transformer import CrossAttention

x2_emb = token_emb(x_test[:, :3])
x2_emb.shape, x_emb.shape

xattn = nn.Linear(E, E, bias=False)      # query
yattn = nn.Linear(E, 2 * E, bias=False)  # key, value

N_new = min(x_emb.size(0), x2_emb.size(0))
x, y = x_emb[:N_new], x2_emb[:N_new]
S = y.size(1)

# Query, key, value
# (N, L, E) @ (E, E) -> (N, L, E)
q  = xattn(x)
# (N, S, E) @ (E, 2 * E) -> (N, S, 2 * E) -> (N, S, E) x 2
k, v  = yattn(y).split(E, dim=2)
# reformatting (with L for q, S for k and v): (N, L|S, E) -> (N, L|S, H, E / H) -> (N, H, L|S, E / H)
q = q.view(N_new, L, H, E // H).transpose(1, 2)
k = k.view(N_new, S, H, E // H).transpose(1, 2)
v = v.view(N_new, S, H, E // H).transpose(1, 2)

# Attention with (q, k): (N, H, L, E / H) @ (N, H, E / H, S) -> (N, H, L, S)
# Value with v:          (N, H, L, S) @ (N, H, S, E / H) -> (N, H ,L, E / H)
z = F.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=dropout if training else 0, is_causal=causal
)
# reformatting:          (N, H, L, E / H) -> (N, L, H, E / H) -> (N, L, E)
z = z.transpose(1, 2).contiguous().view(N_new, L, E)


config = TransformerConfig(
    vocab_size = vocab_size, 
    emb_dim = emb_dim, 
    n_head = n_head, 
    attn_dropout = dropout,
    causal = causal
)

state_dict = {
    'query.weight': xattn.weight,
    'key.weight': yattn.weight[:emb_dim],
    'value.weight': yattn.weight[emb_dim:],
    'output.weight': torch.eye(emb_dim, emb_dim),
}
cross_att = CrossAttention(config)
cross_att.load_state_dict(state_dict)

with torch.no_grad():
    z_tres = cross_att(x, y)
    print('Testing correctness:', (z_tres - z).abs().max())

Testing correctness: tensor(2.3842e-07)


#### Positional encodings

In LLaMA and Mistral implementations, a [RoPE](https://arxiv.org/pdf/2104.09864.pdf) position encoding is applied inside each attention block.
It rotates each pair of query and key coordinates by an angle proportional to the token position, with a different frequency for each pair, so that attention scores depend on the relative position of tokens, which helps queries attend to the keys and values of tokens that are close in the sequence.


In [11]:
def get_rope_freqs(seq_len, head_emb_dim, theta: float = 10_000):
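    """
    Returns the complex rotations used by the rotary positional encoding.

    Parameters
    ----------
    seq_len: int
        The length of the sequence.
    head_emb_dim: int
        The dimension of the head embedding (E / H).
    theta: float
        The base of the geometric sequence of frequencies.

    Returns
    -------
    out: torch.Tensor of size (seq_len, head_emb_dim // 2)
        Unit complex numbers exp(i * position * frequency).
    """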
    freqs = 1.0 / (theta ** (torch.arange(0, head_emb_dim - 1, 2).float() / head_emb_dim))
    t = torch.arange(seq_len, device=freqs.device).float()
    out = (t.unsqueeze(-1) * freqs.unsqueeze(0))
    out = torch.polar(torch.ones_like(out), out)
    return out


def rope(qk, angles):
    """
    Applies the rotary embeddings to queries or keys.

    Parameters
    ----------
    qk: torch.Tensor of size (N, H, L, E / H)
        The queries or keys interpreted as complex numbers (with contiguous real and imaginary parts).
    angles: torch.Tensor of size (L, (E / H) / 2)
        The complex rotations returned by `get_rope_freqs`.
    """
    qk_complex = torch.view_as_complex(qk.reshape(*qk.shape[:-1], -1, 2))
    qk_rot = torch.view_as_real(qk_complex * angles).flatten(-2)
    return qk_rot.type_as(qk)
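
The key property of rotary embeddings is that the score between a rotated query and a rotated key depends only on the offset between their positions. The standalone sketch below illustrates this on single vectors; the helper `rotate` and the toy sizes are assumptions made for the illustration, not part of `llmtuto`.

```python
import torch

torch.manual_seed(0)
D, L, theta = 8, 6, 10_000.0                  # head dimension, sequence length, RoPE base

freqs = 1.0 / theta ** (torch.arange(0, D, 2).float() / D)       # (D / 2,)
angles = torch.arange(L).float()[:, None] * freqs[None, :]       # (L, D / 2)
rot = torch.polar(torch.ones_like(angles), angles)               # exp(i * position * freq)

def rotate(vec, pos):
    """Rotate one vector of size D as if it sat at position `pos`."""
    z = torch.view_as_complex(vec.reshape(-1, 2))                # (D / 2,) complex
    return torch.view_as_real(z * rot[pos]).flatten()            # back to (D,) real

q, k = torch.randn(D), torch.randn(D)
score_a = rotate(q, 1) @ rotate(k, 0)         # query at 1, key at 0: offset 1
score_b = rotate(q, 5) @ rotate(k, 4)         # query at 5, key at 4: same offset
print(score_a.item(), score_b.item())
```

The two scores agree up to rounding, and each rotation preserves the norm of the vector it acts on, so only the relative angle between query and key changes.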

In [12]:
config = TransformerConfig(
    vocab_size=vocab_size,
    emb_dim=emb_dim,
    n_head=n_head,
    ffn_dropout=dropout,
    causal=causal,
    attn_bias=False,
    attn_dropout=0,
    rope=True,
    rope_theta=10_000,
    seq_len=4,
)

module = CrossAttention(config)
z = module(x, y)

module = SelfAttention(config)
z = module(x)

Let us test whether our implementation works with other dtypes and devices.

In [13]:
module = CrossAttention(config)
module = module.half()
module = module.to('cuda')

x_, y_ = x.half(), y.half()
x_, y_ = x_.to('cuda'), y_.to('cuda')

z = module(x_, y_)
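
The cell above assumes that a CUDA device is present. On a machine without a GPU, one possible pattern is to fall back to the CPU and keep `float32` there; the standalone sketch below shows it on a single linear layer (the layer and its sizes are assumptions used as a stand-in for the attention module).

```python
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float16 if device == 'cuda' else torch.float32   # keep float32 on CPU

layer = nn.Linear(8, 8, bias=False).to(device=device, dtype=dtype)
inputs = torch.randn(2, 4, 8, device=device, dtype=dtype)
outputs = layer(inputs)
print(outputs.device, outputs.dtype)
```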

## Feedforward block

In a transformer, attention layers are followed by a multi-layer perceptron with one hidden layer.
Several activation functions can be used.

In [14]:
class SwiGLU(nn.Module):
    """
    Swish-Gated Linear Unit (SwiGLU) activation function.

    Parameters
    ----------
    fan_in: int
        input dimension
    """
    def __init__(self, fan_in):
        super().__init__()
        self.fc = nn.Linear(fan_in, fan_in, bias=False)

    def forward(self, x):
        return F.silu(x) * self.fc(x)


hidden_dim = 4 * emb_dim

fc1 = nn.Linear(emb_dim, hidden_dim, bias=False)
fc2 = nn.Linear(hidden_dim, emb_dim, bias=False)

activation_name = "swiglu"
match activation_name:
    case "relu":
        activation = F.relu
    case "gelu":
        activation = F.gelu
    case "swiglu":
        activation = SwiGLU(hidden_dim)

In [15]:
out = fc1(x)
out = activation(out)
out = F.dropout(out, p=dropout, training=training)
out = fc2(out)

Let us write this in a class.

In [16]:
from llmtuto.model.transformer import FeedForward, TransformerConfig


model = FeedForward(TransformerConfig(emb_dim=E, ffn_dim=8 * E))
z = model(x)

## Transformer block

We are now ready to build a transformer block. It consists of a self-attention layer followed by a feedforward layer, together with normalization and residual connections.

There are two main variants, depending on whether layer normalization is applied before or after the attention and feedforward layers.
If residual connections help by letting each layer parameterize a small change to its input, then the skip path should be left untouched and the normalization applied to the input of each layer, inside the residual branch. This corresponds to the "pre-norm" implementation.


In [17]:
config = TransformerConfig(
    emb_dim = E,

    # Attention parameters
    n_head = H,
    causal = True,
    attn_bias = False,
    attn_dropout = 0.0,
    rope = True,
    seq_len = L,
    rope_theta = 10_000,

    # Feed-forward parameters
    activation = "swiglu",
    ffn_dim = None,
    ffn_bias = False,
    ffn_dropout = 0.0,

    # Transformer block parameter
    pre_norm = True,
)


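# PyTorch 2.0.1 does not have LayerNorm without bias
# Compare (major, minor) as integers: comparing version strings breaks for, e.g., '10.0' < '2.1'
if tuple(int(v) for v in torch.__version__.split('.')[:2]) < (2, 1):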
    class LayerNorm(nn.Module):
        def __init__(self, fan_in, bias):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(fan_in))
            self.bias = nn.Parameter(torch.zeros(fan_in)) if bias else None

        def forward(self, x):
            return F.layer_norm(x, normalized_shape=self.weight.shape, weight=self.weight, bias=self.bias, eps=1e-5)
else:
    LayerNorm = nn.LayerNorm


ln_1 = LayerNorm(config.emb_dim, bias=False)
attn = SelfAttention(config)
ln_2 = LayerNorm(config.emb_dim, bias=False)
ffn = FeedForward(config)

In [18]:
if config.pre_norm:
    out = x + attn(ln_1(x))
    out = out + ffn(ln_2(out))
else:
    out = x + ln_1(attn(x))
    out = out + ln_2(ffn(out))
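
The "small change" argument can be illustrated with a sublayer that outputs zero. In the standalone sketch below, a zero-initialized linear layer stands in for attention or feedforward (an assumption made for the illustration). The pre-norm block is then exactly the identity, whereas the post-norm of the original Transformer paper, `norm(x + sublayer(x))`, still rescales its input. Note that this post-norm differs from the `else` branch above, which normalizes the sublayer output only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
E = 8
x_demo = 3.0 * torch.randn(2, 4, E) + 1.0      # deliberately not normalized
norm = nn.LayerNorm(E)
sublayer = nn.Linear(E, E, bias=False)         # stand-in for attention or feedforward
nn.init.zeros_(sublayer.weight)                # the sublayer contributes nothing

with torch.no_grad():
    pre = x_demo + sublayer(norm(x_demo))      # pre-norm: the skip path is untouched
    post = norm(x_demo + sublayer(x_demo))     # original post-norm: the sum is normalized

print('pre-norm is identity: ', torch.equal(pre, x_demo))
print('post-norm is identity:', torch.allclose(post, x_demo, atol=1e-2))
```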

## Transformer architecture

Let us wrap everything into a (GPT-like) decoder-only architecture.
We have not yet discussed position embeddings at the input of the network.
We can define them as learned embeddings, and add the position and token embeddings together.

In [19]:
from llmtuto.model.transformer import TransformerBlock

config = TransformerConfig(
    # Embedding parameters
    vocab_size = 32_768,
    emb_dim = 512,

    # Attention parameters
    n_head = 16,
    causal = True,
    attn_bias = False,
    attn_dropout = 0.0,
    rope = True,
    seq_len = 1024,
    rope_theta = 10_000,

    # Feed-forward parameters
    activation = "swiglu",
    ffn_dim = None,
    ffn_bias = False,
    ffn_dropout = 0.0,

    # Transformer block parameter
    norm = 'layer',
    norm_bias = False,
    pre_norm = True,

    # Transformer parameters
    n_layer = 12,
    emb_dropout = 0.0,
    pos_emb = True,
)


token_emb = nn.Embedding(config.vocab_size, config.emb_dim)
pos_emb = nn.Embedding(config.seq_len, config.emb_dim)

transformer = nn.Sequential(
    *(TransformerBlock(config) for _ in range(config.n_layer))
)
output = nn.Linear(config.emb_dim, config.vocab_size, bias=False)
output.weight = token_emb.weight

In [20]:
L = config.seq_len
N = 4
x = torch.randint(0, config.vocab_size, (N, L))

xte = token_emb(x)
xpe = pos_emb(torch.arange(L))
z = xte + xpe
z = F.dropout(z, p=config.emb_dropout, training=True)
z_out = transformer(z)
out = output(z_out)
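
One way to see why the position term matters: without it, non-causal self-attention treats its input as a set, so reordering the tokens simply reorders the outputs. The standalone sketch below uses toy sizes and a bare attention without projections (both are assumptions) to compare the two cases.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
V, L, E = 10, 4, 8
tok = nn.Embedding(V, E)
pos = nn.Embedding(L, E)
ids = torch.tensor([[1, 2, 3, 4]])             # (N, L) = (1, 4)
perm = torch.tensor([3, 0, 2, 1])              # a reordering of the sequence

def attend(z):
    # non-causal self-attention with q = k = v = z
    return torch.softmax(z @ z.transpose(-1, -2) / math.sqrt(E), dim=-1) @ z

with torch.no_grad():
    positions = pos(torch.arange(L))                       # (L, E), broadcast over the batch
    plain = attend(tok(ids))
    plain_perm = attend(tok(ids[:, perm]))
    with_pos = attend(tok(ids) + positions)                # (1, L, E) + (L, E) -> (1, L, E)
    with_pos_perm = attend(tok(ids[:, perm]) + positions)

print('Order-blind without positions:', torch.allclose(plain[:, perm], plain_perm, atol=1e-6))
print('Order-blind with positions:   ', torch.allclose(with_pos[:, perm], with_pos_perm, atol=1e-3))
```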

Let us try our implementation on GPU with float16.

In [21]:
device = 'cuda'
x = x.to(device)
token_emb = token_emb.half().to(device)
pos_emb = pos_emb.half().to(device)
transformer = transformer.half().to(device)
output = output.half().to(device)

xte = token_emb(x)
xpe = pos_emb(torch.arange(L, device=device))
z = F.dropout(xte + xpe, p=config.emb_dropout, training=True)
z_out = transformer(z)
out = output(z_out)

We end by wrapping everything into a class.

In [22]:
from llmtuto.model.transformer import CausalTransformer


model = CausalTransformer(config).to(device)
out = model(x)