**Chapter 17 – Speeding Up Transformers**

_This notebook contains all the sample code and solutions to the exercises in Chapter 17._

<table align="left">
  <td>
    <a href="https://colab.research.google.com/github/ageron/handson-mlp/blob/main/17_advanced_transformer_techniques.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
  </td>
  <td>
    <a target="_blank" href="https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-mlp/blob/main/17_advanced_transformer_techniques.ipynb"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" /></a>
  </td>
</table>

# Setup

This project requires Python 3.10 or above:

In [1]:
import sys

assert sys.version_info >= (3, 10)

Are we using Colab or Kaggle?

In [2]:
IS_COLAB = "google.colab" in sys.modules
IS_KAGGLE = "kaggle_secrets" in sys.modules

We also need PyTorch ≥ 2.6.0:

In [3]:
from packaging.version import Version
import torch

assert Version(torch.__version__) >= Version("2.6.0")

This chapter can be very slow without a hardware accelerator, so if we can find one, let's use it:

In [4]:
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

device

'cuda'

Let's issue a warning if there's no hardware accelerator available:

In [5]:
if device == "cpu":
    print("Neural nets can be very slow without a hardware accelerator.")
    if IS_COLAB:
        print("Go to Runtime > Change runtime and select a GPU hardware "
              "accelerator.")
    if IS_KAGGLE:
        print("Go to Settings > Accelerator and select GPU.")

As we did in earlier chapters, let's define the default font sizes to make the figures prettier:

In [6]:
import matplotlib.pyplot as plt

plt.rc('font', size=14)
plt.rc('axes', labelsize=14, titlesize=14)
plt.rc('legend', fontsize=14)
plt.rc('xtick', labelsize=10)
plt.rc('ytick', labelsize=10)

# Faster Decoding at Inference Time

## Key/Value Caching

In [7]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Once upon a time there lived"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
for use_cache in (True, False):
    print(f"{use_cache=}")
    %time model.generate(**inputs, max_new_tokens=500, do_sample=False, use_cache=use_cache)

use_cache=True
CPU times: user 3.91 s, sys: 150 ms, total: 4.06 s
Wall time: 4.62 s
use_cache=False
CPU times: user 10.6 s, sys: 13.9 ms, total: 10.6 s
Wall time: 10.7 s
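The speedup comes from the fact that in a decoder, the keys and values of past tokens never change, so they can be computed once and stored. Here is a minimal sketch of that idea for a single toy attention head. The sizes and the `attend()` helper are hypothetical and are not part of Transformers. At each step we project only the newest token and append its key and value to the cache, then we check that this gives the same outputs as re-projecting the whole prefix at every step:

```python
import torch

torch.manual_seed(42)
d = 16  # toy embedding size (single attention head)
Wq, Wk, Wv = (torch.randn(d, d) / d ** 0.5 for _ in range(3))

def attend(q, K, V):
    """Output of one query vector q attending to all the keys/values in K and V."""
    weights = torch.softmax(K @ q / d ** 0.5, dim=0)
    return weights @ V

X = torch.randn(10, d)  # embeddings of 10 tokens, generated one at a time
K_cache, V_cache = [], []
cached_outputs, recomputed_outputs = [], []
for t in range(len(X)):
    # With a cache: project only the newest token and append its key/value
    K_cache.append(X[t] @ Wk)
    V_cache.append(X[t] @ Wv)
    q = X[t] @ Wq
    cached_outputs.append(attend(q, torch.stack(K_cache), torch.stack(V_cache)))
    # Without a cache: re-project every token seen so far, at every step
    recomputed_outputs.append(attend(q, X[:t + 1] @ Wk, X[:t + 1] @ Wv))

print(torch.allclose(torch.stack(cached_outputs),
                     torch.stack(recomputed_outputs), atol=1e-6))
```

In this single-layer toy, only the key/value projections are redundant without the cache. In a real model, every layer recomputes the keys and values of the whole prefix at each step, which is why disabling the cache more than doubled the generation time above.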


## Speculative Decoding

In [8]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import set_seed

set_seed(42)
target_model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m",
                                                    device_map="auto")
draft_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m",
                                                   device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
prompt = "Once upon a time there lived"
inputs = tokenizer(prompt, return_tensors="pt").to(target_model.device)
outputs = target_model.generate(**inputs, max_new_tokens=100, do_sample=True,
                                temperature=1, assistant_model=draft_model)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)

Once upon a time there lived in an aldermanic family the eldest daughter. Her father was a very rich man of considerable wealth.

She was very wise and had developed a good marriage with her husband. He never slept on the floor while she slept, but on the table above the bed he never seemed to be awake.

The eldest daughter knew what was important to her husband and her parents, and if he refused to do as she told him, she would not accept a roomful of guests for a
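With `do_sample=True`, the target model must decide whether to keep each token proposed by the draft model. The sketch below shows the accept/reject rule from the original speculative sampling papers for a single token position and toy distributions. The function name and the distributions are made up for illustration, and a real implementation verifies several draft tokens per forward pass of the target model. A draft token _x_ is accepted with probability min(1, p(x)/q(x)); otherwise a replacement is sampled from the normalized residual max(0, p − q). The resulting tokens follow the target distribution _p_, even though most of them were proposed by the draft:

```python
import torch

def speculative_sample(p, q, n_samples, generator=None):
    """Toy single-position speculative sampling: draft tokens come from q, and the
    accept/reject step corrects them so the returned tokens follow p."""
    draft = torch.multinomial(q, n_samples, replacement=True, generator=generator)
    # Accept each draft token x with probability min(1, p(x) / q(x))
    accept_prob = torch.clamp(p[draft] / q[draft], max=1.0)
    accepted = torch.rand(n_samples, generator=generator) < accept_prob
    # On rejection, sample from the normalized residual distribution max(0, p - q)
    residual = torch.clamp(p - q, min=0.0)
    if residual.sum() == 0:  # p == q: every draft token is accepted anyway
        residual = p
    resampled = torch.multinomial(residual / residual.sum(), n_samples,
                                  replacement=True, generator=generator)
    return torch.where(accepted, draft, resampled), accepted

p = torch.tensor([0.5, 0.3, 0.2])  # target model's next-token distribution
q = torch.tensor([0.2, 0.3, 0.5])  # draft model's next-token distribution
g = torch.Generator().manual_seed(42)
tokens, accepted = speculative_sample(p, q, 200_000, generator=g)
print(torch.bincount(tokens, minlength=3) / len(tokens))  # close to p
print(accepted.float().mean())  # close to sum(min(p, q)) = 0.7
```

The acceptance rate, Σ min(p, q), measures how closely the draft model agrees with the target model: the closer it gets to 1, the more draft tokens survive and the bigger the speedup.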


# Boosting Multi-Head Attention

### BigBird

In [9]:
from transformers import pipeline

model_id = "google/bigbird-roberta-base"
fill_mask = pipeline(task="fill-mask", model=model_id)  # don't shadow pipeline()
fill_mask("She was feeling unwell so she took some [MASK] medicine.")

BigBirdForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly defined. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.

Device set to use cuda:0
Attention type 'block_sparse' is not possible if sequence_length: 14 <= num global tokens: 2 * config.block_size + min. num sliding tokens: 3 * config.block_size + config.num_random_blocks * config.block_size + additional buffer: config.num_random_blocks * config.block_size = 704 with config.block_size = 64, config.num_random_blocks = 3. Changing attention type to 'original_full'...


[{'score': 0.2818744480609894,
  'token': 2457,
  'token_str': 'pain',
  'sequence': 'She was feeling unwell so she took some pain medicine.'},
 {'score': 0.21908044815063477,
  'token': 4793,
  'token_str': 'cold',
  'sequence': 'She was feeling unwell so she took some cold medicine.'},
 {'score': 0.11271445453166962,
  'token': 36298,
  'token_str': 'allergy',
  'sequence': 'She was feeling unwell so she took some allergy medicine.'},
 {'score': 0.06932126730680466,
  'token': 22195,
  'token_str': 'cough',
  'sequence': 'She was feeling unwell so she took some cough medicine.'},
 {'score': 0.045178256928920746,
  'token': 35968,
  'token_str': 'herbal',
  'sequence': 'She was feeling unwell so she took some herbal medicine.'}]

Note: you can safely ignore the warnings above.
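The log above lists the ingredients of BigBird's block-sparse attention: a sliding window of 3 blocks, 2 blocks' worth of global tokens, and `num_random_blocks` random blocks. A 14-token input is too short for this pattern, so the model falls back to full attention. Here's a simplified sketch of a block-level BigBird-style mask. The `bigbird_block_mask()` function is a hypothetical helper, much simpler than the actual Transformers implementation, and it assumes the global blocks are the first and last ones:

```python
import torch

def bigbird_block_mask(n_blocks, n_random=3, generator=None):
    """Simplified BigBird-style block mask: mask[i, j] is True if query block i
    may attend to key block j (sliding window + global + random blocks)."""
    idx = torch.arange(n_blocks)
    # Sliding window: each block attends to itself and its 2 neighbors (3 blocks)
    mask = (idx[:, None] - idx[None, :]).abs() <= 1
    # Global blocks (assumed here to be the first and last ones) attend to
    # every block, and every block attends to them
    for g in (0, n_blocks - 1):
        mask[g, :] = True
        mask[:, g] = True
    # Random blocks: each non-global query block also attends to n_random
    # randomly chosen key blocks (possibly overlapping the ones above)
    for i in range(1, n_blocks - 1):
        random_blocks = torch.randperm(n_blocks, generator=generator)[:n_random]
        mask[i, random_blocks] = True
    return mask

mask = bigbird_block_mask(16, n_random=3, generator=torch.Generator().manual_seed(42))
print(mask.int())
print(f"{mask.float().mean().item():.0%} of block pairs are attended")
```

In this sketch, each non-global query block attends to at most 3 + 2 + 3 = 8 key blocks whatever the sequence length, so the cost grows linearly with the number of blocks instead of quadratically.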

## Approximate Attention

### Reformer

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def angular_lsh(vectors, k):
    # k must be even: we project onto k/2 random directions and their negations
    R = torch.randn(vectors.size(-1), k // 2, device=vectors.device)
    normalized_vectors = F.normalize(vectors, p=2.0, dim=-1)
    V_proj = normalized_vectors @ R
    V_concat = torch.cat([V_proj, -V_proj], dim=-1)
    return torch.argmax(V_concat, dim=-1)  # bucket index in [0, k)

In [11]:
torch.manual_seed(42)
vectors = torch.rand(16, 512)
angular_lsh(vectors, k=4)

tensor([2, 2, 0, 3, 0, 2, 2, 2, 2, 1, 1, 3, 3, 1, 2, 1])
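Angular LSH puts vectors that point in similar directions into the same bucket with high probability. Two properties follow directly from the code: rescaling a vector by a positive factor never changes its bucket, and negating it shifts the bucket by _k_/2 (mod _k_). The sketch below checks both, then uses the buckets in a toy LSH attention where each position only attends to positions in the same bucket, with shared queries and keys as in the Reformer. The `lsh_attention()` helper is hypothetical: it still computes the full _L_ × _L_ score matrix and only masks it, whereas the Reformer sorts positions by bucket and attends within chunks to actually save compute:

```python
import torch
import torch.nn.functional as F

def angular_lsh(vectors, k):
    # Same as above: project onto k/2 random directions and their negations
    R = torch.randn(vectors.size(-1), k // 2, device=vectors.device)
    V_proj = F.normalize(vectors, p=2.0, dim=-1) @ R
    return torch.argmax(torch.cat([V_proj, -V_proj], dim=-1), dim=-1)

def lsh_attention(X, V, k):
    """Toy single-head LSH attention with shared queries and keys (X):
    each position only attends to positions that land in the same bucket."""
    buckets = angular_lsh(X, k)
    same_bucket = buckets[:, None] == buckets[None, :]  # (L, L) boolean mask
    # Every position shares a bucket with itself, so no row is fully masked
    return F.scaled_dot_product_attention(X, X, V, attn_mask=same_bucket), buckets

torch.manual_seed(42)
x = torch.randn(512)
# Buckets of x, 2x and -x: the first two match, the third is shifted by k/2
print(angular_lsh(torch.stack([x, 2 * x, -x]), k=4))

X = torch.randn(16, 64)
V = torch.randn(16, 64)
out, buckets = lsh_attention(X, V, k=4)
print(buckets)
```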

### Performer

The expected value of $\exp(\mathbf{w}\cdot\mathbf{x})$ for a given vector **x** and a random vector **w** sampled from the standard Gaussian distribution $\mathcal{N}(\mathbf{0}, \mathbf{I})$ is $\exp\left(\frac{1}{2}\|\mathbf{x}\|^2\right)$. We can test this using PyTorch:

In [12]:
torch.manual_seed(42)
d, m = 64, 1024
W = torch.randn(d, m)
X = torch.randn(5, d) / d ** 0.5
R = torch.exp(X @ W)
print(R.mean(dim=-1))
print(torch.exp(0.5 * (X.norm(dim=-1)**2)))

tensor([1.5851, 1.6703, 1.6728, 1.9729, 1.7934])
tensor([1.6516, 1.7578, 1.6594, 1.8050, 1.7751])


Pretty close!
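Here's why this holds: since **w** follows $\mathcal{N}(\mathbf{0}, \mathbf{I})$, the dot product $\mathbf{w}\cdot\mathbf{x}$ is a scalar Gaussian with mean 0 and variance $\|\mathbf{x}\|^2$, and for any $z \sim \mathcal{N}(0, \sigma^2)$ we have $\mathbb{E}[\exp(z)] = \exp(\sigma^2/2)$ (the Gaussian moment-generating function evaluated at 1). Applying the same identity to $\mathbf{q}+\mathbf{k}$ shows why the feature map subtracts $\frac{1}{2}\|\mathbf{x}\|^2$: the correction terms cancel out, leaving exactly $\exp(\mathbf{q}\cdot\mathbf{k})$ in expectation.

$$
\begin{aligned}
\mathbb{E}\left[\exp(\mathbf{w}\cdot\mathbf{x})\right] &= \exp\left(\tfrac{1}{2}\|\mathbf{x}\|^2\right) \\
\mathbb{E}\left[e^{\mathbf{w}\cdot\mathbf{q} - \frac{1}{2}\|\mathbf{q}\|^2}\, e^{\mathbf{w}\cdot\mathbf{k} - \frac{1}{2}\|\mathbf{k}\|^2}\right]
&= e^{-\frac{1}{2}\|\mathbf{q}\|^2 - \frac{1}{2}\|\mathbf{k}\|^2}\,\mathbb{E}\left[e^{\mathbf{w}\cdot(\mathbf{q}+\mathbf{k})}\right] \\
&= e^{\frac{1}{2}\left(\|\mathbf{q}+\mathbf{k}\|^2 - \|\mathbf{q}\|^2 - \|\mathbf{k}\|^2\right)} = \exp(\mathbf{q}\cdot\mathbf{k})
\end{aligned}
$$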

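Now let's implement the positive random feature map $\phi(\mathbf{x}) = \dfrac{\exp\left(\mathbf{x}\mathbf{W} - \frac{1}{2}\|\mathbf{x}\|^2\right)}{\sqrt{m}}$, where each of the _m_ columns of **W** is a random Gaussian vector. For numerical stability, `phi()` also subtracts the maximum projection before computing the exponential: this multiplies the features by a constant factor, which cancels out once the attention weights are normalized.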

In [13]:
def phi(X, W, dim_subtract_max=(-2, -1)):
    squared_norms = X.square().sum(dim=-1, keepdim=True)
    X_proj = X @ W
    max_vals = X_proj.amax(dim=dim_subtract_max, keepdim=True)
    return torch.exp(X_proj - max_vals - squared_norms / 2) / W.size(-1) ** 0.5

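It's possible to prove that the expected value of $\phi(\mathbf{Q})\phi(\mathbf{K})^\intercal$ is equal to $\exp(\mathbf{Q}\mathbf{K}^\intercal)$, where exp is applied element-wise (here only up to a constant factor, since `phi()` subtracts the max projection). Let's check this by using these features to approximate softmax attention weights: we scale both **Q** and **K** by $d_\text{head}^{-1/4}$ (i.e., `scale ** 0.5`) so that their product gets the usual $1/\sqrt{d_\text{head}}$ scaling, then we normalize each row so it sums to 1: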

In [14]:
torch.manual_seed(42)
batch_size = 32
d_model = 512
n_heads = 8
Lq = 200
Lk = 100
m = 256
d_head = d_model // n_heads
W = torch.randn(n_heads, d_head, m)
Q = torch.randn(batch_size, n_heads, Lq, d_head)
K = torch.randn(batch_size, n_heads, Lk, d_head)
scale = 1 / d_head ** 0.5
Qp = phi(Q * scale ** 0.5, W, dim_subtract_max=-1)
Kp = phi(K * scale ** 0.5, W)
A = Qp @ Kp.transpose(-2, -1)
D = A.sum(dim=-1, keepdim=True)
result = A / (D + 1e-6)
expected = torch.softmax(Q @ K.transpose(-2, -1) * scale, dim=-1)
rmse = F.mse_loss(result, expected) ** 0.5
rmse

tensor(0.0171)

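That's a pretty good approximation! We can improve it further by orthogonalizing the random vectors in **W** using QR decomposition. Since a _d_-dimensional space can contain at most _d_ mutually orthogonal vectors, we orthogonalize each chunk of _d_ random vectors separately, then multiply the result by $\sqrt{d}$ so that the vectors' norms stay close to those of the original Gaussian vectors.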

In [15]:
def orthogonalize(W):
    d_head = W.size(-2)
    W_orth = torch.cat([torch.linalg.qr(W_chunk)[0]
                        for W_chunk in W.split(d_head, dim=-1)], dim=-1)
    return W_orth * d_head ** 0.5

In [16]:
W_orth = orthogonalize(W)
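As a quick sanity check, within each chunk the columns of `W_orth` should now be mutually orthogonal with norm $\sqrt{d}$ (here $\sqrt{64} = 8$), so each chunk's Gram matrix should be 64 times the identity. This standalone snippet recreates a **W** with the same shape and seed as above. Note that columns from *different* chunks are generally not orthogonal to each other:

```python
import torch

def orthogonalize(W):  # same function as above
    d_head = W.size(-2)
    W_orth = torch.cat([torch.linalg.qr(W_chunk)[0]
                        for W_chunk in W.split(d_head, dim=-1)], dim=-1)
    return W_orth * d_head ** 0.5

torch.manual_seed(42)
W = torch.randn(8, 64, 256)  # (n_heads, d_head, m), as above
W_orth = orthogonalize(W)
identity = torch.eye(64)
for chunk in W_orth.split(64, dim=-1):  # 4 chunks of 64 columns each
    gram = chunk.transpose(-2, -1) @ chunk  # (n_heads, 64, 64)
    print(torch.allclose(gram, 64 * identity, atol=1e-3))
```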

Let's compute the RMSE once again:

In [17]:
Qp2 = phi(Q * scale ** 0.5, W_orth, dim_subtract_max=-1)
Kp2 = phi(K * scale ** 0.5, W_orth)
A2 = Qp2 @ Kp2.transpose(-2, -1)
D2 = A2.sum(dim=-1, keepdim=True)
result2 = A2 / (D2 + 1e-6)
rmse2 = F.mse_loss(result2, expected) ** 0.5
rmse2

tensor(0.0160)

That's a bit better. Increasing the number of random features _m_ generally reduces the error further, at the cost of more compute and memory.

Now we're ready to implement FAVOR+ attention:

In [18]:
class FavorAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_features):
        super().__init__()
        self.d_head = d_model // n_heads
        W = torch.randn(n_heads, self.d_head, n_features)  # h, d, m
        W = orthogonalize(W)
        self.register_buffer("W", W)

    def forward(self, Q, K, V):
        scale = self.d_head ** -0.25
        Qp = phi(Q * scale, self.W, dim_subtract_max=-1)
        Kp = phi(K * scale, self.W)
        D = Qp @ Kp.sum(dim=-2).unsqueeze(-1)  # B, h, Lq, 1
        Kp_T_V = Kp.transpose(-2, -1) @ V  # B, h, m, d
        return (Qp @ Kp_T_V) / (D + 1e-6)

In [19]:
torch.manual_seed(42)
Q = torch.randn(batch_size, n_heads, Lq, d_head)
K = torch.randn(batch_size, n_heads, Lk, d_head)
V = torch.randn(batch_size, n_heads, Lk, d_head)
favor_attn = FavorAttention(d_model, n_heads, 256)
approx_attn = favor_attn(Q, K, V)

In [20]:
import torch.nn.functional as F

attn = F.scaled_dot_product_attention(Q, K, V)
attn_rmse = F.mse_loss(approx_attn, attn) ** 0.5
attn_rmse

tensor(0.1599)
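The key trick in `FavorAttention.forward()` is the order of the matrix multiplications. Matrix multiplication is associative, so $(\phi(\mathbf{Q})\phi(\mathbf{K})^\intercal)\mathbf{V} = \phi(\mathbf{Q})(\phi(\mathbf{K})^\intercal\mathbf{V})$, but the right-hand side never builds the $L_Q \times L_K$ attention matrix, so its cost grows linearly with the sequence length. This snippet checks the equality on random positive features standing in for $\phi(\mathbf{Q})$ and $\phi(\mathbf{K})$ (the sizes are arbitrary), and counts the multiply-adds for each order:

```python
import torch

torch.manual_seed(42)
L, m, d = 1000, 64, 32  # sequence length, number of random features, head size
Qp = torch.rand(L, m, dtype=torch.float64)  # stand-in for phi(Q) (positive features)
Kp = torch.rand(L, m, dtype=torch.float64)  # stand-in for phi(K)
V = torch.randn(L, d, dtype=torch.float64)

quadratic = (Qp @ Kp.T) @ V  # builds an L x L matrix first
linear = Qp @ (Kp.T @ V)     # builds only an m x d matrix, like FavorAttention
print(torch.allclose(quadratic, linear))  # True
print(f"quadratic order: {L * L * (m + d):,} multiply-adds")  # 96,000,000
print(f"linear order: {2 * L * m * d:,} multiply-adds")  # 4,096,000
```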

## Sharing Projections in Multi-Head Attention

### MQA

In [21]:
batch_size, Lq, Lk, d_head = 32, 100, 90, 64
n_heads = 8
n_groups = 1
query = torch.randn(batch_size, n_heads, Lq, d_head)
key = torch.randn(batch_size, n_groups, Lk, d_head)
value = torch.randn(batch_size, n_groups, Lk, d_head)
attn = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)

### GQA

In [22]:
batch_size, Lq, Lk, d_head = 32, 100, 90, 64
n_heads = 8
n_groups = 2
query = torch.randn(batch_size, n_heads, Lq, d_head)
key = torch.randn(batch_size, n_groups, Lk, d_head)
value = torch.randn(batch_size, n_groups, Lk, d_head)
attn = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)

## FlashAttention

Here's a toy Python implementation of FlashAttention, to get an idea of how it works under the hood:

In [23]:
def flash_attention(Q, K, V, block_size_q, block_size_k):
    Lq, d = Q.shape
    Lk, _ = K.shape
    O = torch.zeros_like(Q)
    scale = d ** -0.5

    for i_start in range(0, Lq, block_size_q):  # iterate over query blocks
        i_end = min(i_start + block_size_q, Lq)
        Q_block = Q[i_start:i_end]

        O_block = torch.zeros(block_size_q, d, device=Q.device)
        l_block = torch.zeros(block_size_q, 1, device=Q.device)
        m_block = -torch.inf * torch.ones(block_size_q, 1, device=Q.device)

        for j_start in range(0, Lk, block_size_k):  # iterate over K/V blocks
            j_end = min(j_start + block_size_k, Lk)
            K_block = K[j_start:j_end]
            V_block = V[j_start:j_end]

            # Core attention calculation for the block
            S_ij = Q_block @ K_block.T * scale  # (block_size_q, block_size_k)

            # Find the new maximum score for the combined blocks so far
            m_ij_new, _ = torch.max(S_ij, dim=1, keepdim=True)
            m_block_new = torch.maximum(m_block, m_ij_new)

            # Rescale previous accumulator values based on the new max
            P_ij = torch.exp(S_ij - m_block_new)
            correction_factor = torch.exp(m_block - m_block_new)

            # Update the denominator (l) and output (O) accumulators
            l_block_new = ((l_block * correction_factor)
                           + torch.sum(P_ij, dim=1, keepdim=True))
            O_block = (O_block * correction_factor) + (P_ij @ V_block)

            # Update the state for the next inner loop iteration
            l_block = l_block_new
            m_block = m_block_new

        # Rescale the output block and write it to the final output matrix
        O[i_start:i_end] = O_block / l_block

    return O

Let's check that this works (note that this simple implementation only handles the case where _L_<sub>Q</sub> is a multiple of `block_size_q`, since each query block's accumulators have a fixed size of `block_size_q` rows):

In [24]:
torch.manual_seed(42)
block_size_q, block_size_k = 64, 64
Lq, Lk, d = 1280, 1152, 512
Q = torch.randn(Lq, d)
K = torch.randn(Lk, d)
V = torch.randn(Lk, d)

In [25]:
R1 = flash_attention(Q, K, V, block_size_q, block_size_k)
R2 = F.scaled_dot_product_attention(Q, K, V)

In [26]:
mse = F.mse_loss(R1, R2)
mse

tensor(8.0755e-16)
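At the heart of this algorithm is the "online softmax" trick: a softmax-weighted sum can be computed in a single pass over chunks, keeping only a running maximum, a running denominator and a running weighted sum, and rescaling the accumulators by $\exp(m_\text{old} - m_\text{new})$ whenever the maximum increases. Here's the same logic stripped down to a single query, checked against a regular softmax with a chunk size that doesn't divide the sequence length. The `online_softmax_weighted_sum()` function is just for illustration:

```python
import torch

def online_softmax_weighted_sum(scores, values, chunk_size):
    """Compute softmax(scores) @ values in one pass over chunks, keeping only a
    running max (m), running denominator (l) and running weighted sum (acc)."""
    m = torch.tensor(float("-inf"), dtype=scores.dtype)
    l = torch.zeros((), dtype=scores.dtype)
    acc = torch.zeros(values.size(-1), dtype=values.dtype)
    for start in range(0, len(scores), chunk_size):
        s = scores[start:start + chunk_size]
        v = values[start:start + chunk_size]
        m_new = torch.maximum(m, s.max())
        correction = torch.exp(m - m_new)  # rescales what was accumulated so far
        p = torch.exp(s - m_new)
        l = l * correction + p.sum()
        acc = acc * correction + p @ v
        m = m_new
    return acc / l

torch.manual_seed(42)
scores = torch.randn(1000, dtype=torch.float64) * 10  # wide range of scores
values = torch.randn(1000, 8, dtype=torch.float64)
result = online_softmax_weighted_sum(scores, values, chunk_size=64)
expected = torch.softmax(scores, dim=0) @ values
print(torch.allclose(result, expected))
```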

# Scaling Up with Mixture of Experts
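Here's a minimal sketch of the core idea behind a sparse mixture-of-experts (MoE) layer, assuming a top-_k_ router: a small linear router scores every expert for each token, only the _k_ best experts actually process that token, and their outputs are combined using the softmax of the selected router scores. The `TopKMoE` class is a hypothetical toy; production MoE layers typically also add a load-balancing loss, expert capacity limits, and efficient batched dispatch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKMoE(nn.Module):
    """Toy sparse MoE feedforward layer: a linear router picks the top-k experts
    for each token, and the token's output is the gate-weighted sum of the
    chosen experts' outputs (the other experts are never run on that token)."""
    def __init__(self, d_model, d_hidden, n_experts, k=2):
        super().__init__()
        self.k = k
        self.router = nn.Linear(d_model, n_experts)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_hidden), nn.GELU(),
                          nn.Linear(d_hidden, d_model))
            for _ in range(n_experts))

    def forward(self, x):  # x: (n_tokens, d_model)
        top_logits, top_idx = self.router(x).topk(self.k, dim=-1)  # (n_tokens, k)
        gates = F.softmax(top_logits, dim=-1)  # normalize over the chosen experts
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            # Which tokens picked expert e, and in which of their k slots
            token_idx, slot = (top_idx == e).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue  # no token was routed to this expert
            weight = gates[token_idx, slot].unsqueeze(-1)  # (n_routed, 1)
            out[token_idx] += weight * expert(x[token_idx])
        return out

torch.manual_seed(42)
moe = TopKMoE(d_model=16, d_hidden=64, n_experts=4, k=2)
x = torch.randn(10, 16)  # 10 token embeddings
y = moe(x)
print(y.shape)  # torch.Size([10, 16])
```

With 4 experts and _k_ = 2, the layer has 4 experts' worth of parameters, but each token only goes through 2 of them, which is how MoE models grow their capacity without a proportional increase in compute per token.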
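
The next cell uses the `peft` library to add LoRA (low-rank adaptation) adapters to the query and value projections of a small GPT-Neo model, so that only a tiny fraction of its parameters are trainable:
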
In [27]:
from peft import LoraConfig, get_peft_model

model_id = "EleutherAI/gpt-neo-125M"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto",
                                             dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()

trainable params: 589,824 || all params: 125,788,416 || trainable%: 0.4689
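To see where these 589,824 trainable parameters come from, here's a minimal from-scratch sketch of a LoRA layer. The `LoRALinear` class is hypothetical and is not how `peft` is implemented internally, though it follows the same recipe: the pretrained weights are frozen, and a low-rank update **BA**, scaled by `lora_alpha / r`, is added to the layer's output. Since **B** starts at zero, the wrapped layer initially behaves exactly like the original one. With GPT-Neo 125M's hidden size of 768 and _r_ = 16, each adapted projection adds 16 × (768 + 768) = 24,576 parameters, and 12 layers × 2 projections × 24,576 = 589,824:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen linear layer plus a trainable low-rank update:
    y = base(x) + (alpha / r) * B(A(dropout(x)))."""
    def __init__(self, base: nn.Linear, r=16, alpha=32, dropout=0.05):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # freeze the pretrained weights
        self.lora_A = nn.Linear(base.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)  # B = 0, so the update starts at zero
        self.dropout = nn.Dropout(dropout)
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(self.dropout(x)))

torch.manual_seed(42)
base = nn.Linear(768, 768, bias=False)  # like GPT-Neo's q_proj or v_proj
lora = LoRALinear(base, r=16, alpha=32)
n_trainable = sum(p.numel() for p in lora.parameters() if p.requires_grad)
print(n_trainable)  # 24576
print(12 * 2 * n_trainable)  # 589824
```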


In [28]:
import torch
from datasets import load_dataset
from transformers import (TrainingArguments, Trainer,
                          DataCollatorForLanguageModeling)


# Faster Training

## Gradient Accumulation

In [29]:
model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
data_loader = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(100)]
model.train()

accumulation_steps = 4
optimizer.zero_grad()  # reset gradients before starting
for batch_index, (X_batch, y_batch) in enumerate(data_loader):
    X_batch, y_batch = X_batch.to(device), y_batch.to(device)
    y_pred = model(X_batch)
    loss = criterion(y_pred, y_batch)
    loss = loss / accumulation_steps
    loss.backward()
    if (batch_index + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
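Dividing each loss by `accumulation_steps` makes the accumulated gradient equal to the average of the micro-batch gradients. When the loss is a mean and all micro-batches have the same size, this is exactly the gradient of the full batch, as this standalone check shows (4 micro-batches of 8 samples versus 1 batch of 32 samples). The equivalence breaks down for layers whose behavior depends on the batch, such as batch normalization:

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
toy_model = nn.Linear(10, 1)
criterion = nn.MSELoss()  # mean over the samples in the batch
X, y = torch.randn(32, 10), torch.randn(32, 1)

# Gradient computed on the full batch of 32 samples
toy_model.zero_grad()
criterion(toy_model(X), y).backward()
full_batch_grad = toy_model.weight.grad.clone()

# Gradient accumulated over 4 micro-batches of 8 samples each
accumulation_steps = 4
toy_model.zero_grad()
for X_micro, y_micro in zip(X.split(8), y.split(8)):
    loss = criterion(toy_model(X_micro), y_micro) / accumulation_steps
    loss.backward()  # adds to toy_model.weight.grad
accumulated_grad = toy_model.weight.grad.clone()

print(torch.allclose(full_batch_grad, accumulated_grad, atol=1e-6))
```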

# Exercise Solutions

## Work in progress

I'm working on the exercise solutions, hoping to finish them by December 2025. Thanks for your patience!