
# Attention Is All You Need — Exploratory Notebook

This notebook walks through the core ideas of Vaswani et al.'s *Attention Is All You Need* paper. Rather than jumping straight to a finished library, we incrementally build reusable pieces — much like the fast.ai approach of exploratory programming — and compose them into a functioning miniature Transformer.



## Notebook Roadmap

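We will build the architecture bottom-up, starting from the smallest pieces and finishing with the full model (the paper itself presents the encoder–decoder stacks first and the individual components afterwards):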

1. Build toy tokenisation utilities for a miniature sequence-to-sequence task.
2. Implement embedding lookup and sinusoidal positional encodings.
3. Explore scaled dot-product attention and attention masking.
4. Compose multi-head attention from the primitive operation.
5. Add the position-wise feed-forward network and residual + layer norm glue.
6. Stack encoder and decoder blocks to obtain a full Transformer.
7. Run a bite-sized experiment to learn a synthetic translation task, demonstrating the model's end-to-end behaviour.


In [None]:

import math
from typing import Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F

# Seed PyTorch's global RNG so randomly initialised weights and the random
# demo tensors created below are repeatable from run to run.
torch.manual_seed(42)



## 1. Toy tokenisation utilities

The original paper trains on large corpora, but we only need a tiny dataset to exercise the components. We define a mini parallel corpus and helper functions that map between text and token IDs.


In [None]:

# Miniature English -> German parallel corpus: (source sentence, target sentence).
toy_pairs = [
    ("i like deep learning", "ich mag tiefes lernen"),
    ("this is a tiny dataset", "dies ist ein winziger datensatz"),
    ("attention helps models focus", "aufmerksamkeit hilft modellen fokus"),
    ("transformers replace recurrence", "transformer ersetzen rekurrenz"),
    ("we build modules stepwise", "wir bauen module schrittweise"),
    ("layers communicate with attention", "schichten kommunizieren mit aufmerksamkeit"),
]

# Reserved tokens occupy ids 0, 1 and 2 in BOTH vocabularies.
SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>"]

src_vocab = dict(zip(SPECIAL_TOKENS, range(len(SPECIAL_TOKENS))))
tgt_vocab = dict(src_vocab)

# Assign ids in order of first appearance; a token seen before keeps its id
# because setdefault only inserts when the key is missing.
for src, tgt in toy_pairs:
    for token in src.split():
        src_vocab.setdefault(token, len(src_vocab))
    for token in tgt.split():
        tgt_vocab.setdefault(token, len(tgt_vocab))

# Reverse lookups (id -> token) used when turning predictions back into text.
inv_src_vocab = {index: word for word, index in src_vocab.items()}
inv_tgt_vocab = {index: word for word, index in tgt_vocab.items()}

src_vocab_size = len(src_vocab)
tgt_vocab_size = len(tgt_vocab)
print(f"Source vocab: {src_vocab_size} tokens | Target vocab: {tgt_vocab_size} tokens")


In [None]:

def encode(sentence: str, vocab: dict) -> torch.Tensor:
    """Convert a whitespace-separated sentence into a 1-D tensor of token ids.

    Args:
        sentence: Text to tokenise with ``str.split()``.
        vocab: Mapping from token string to integer id.

    Returns:
        A ``torch.long`` tensor holding one id per token, in sentence order.

    Raises:
        KeyError: If a token is not present in ``vocab``.
    """
    token_ids = []
    for word in sentence.split():
        token_ids.append(vocab[word])
    return torch.tensor(token_ids, dtype=torch.long)


def decode(ids: torch.Tensor, inv_vocab: dict) -> str:
    """Convert a 1-D tensor of token ids back into a space-joined sentence.

    Args:
        ids: Tensor of integer token ids.
        inv_vocab: Mapping from integer id to token string.

    Returns:
        The looked-up tokens joined by single spaces.

    Raises:
        KeyError: If an id is not present in ``inv_vocab``.
    """
    return " ".join(inv_vocab[token_id] for token_id in ids.tolist())


# Round-trip sanity check: encode the first source sentence to ids, then
# decode the ids back to text. `encoded_example` is reused by later cells.
encoded_example = encode(toy_pairs[0][0], src_vocab)
print(encoded_example)
print(decode(encoded_example, inv_src_vocab))



## 2. Embeddings with sinusoidal positional encoding

The Transformer contains no recurrence or convolution, so the paper injects token-order information through explicit positional signals. We'll implement the sinusoidal encoding, which is added to the learned token embeddings.


In [None]:

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a batch of embeddings.

    A ``(1, max_len, d_model)`` table is precomputed and stored as the
    non-trainable buffer ``pe``. Even feature columns hold sines and odd
    feature columns hold cosines of ``position * 10000 ** (-2i / d_model)``.

    Args:
        d_model: Size of the feature (last) dimension. Odd values are supported.
        max_len: Number of positions to precompute; inputs passed to
            ``forward`` must not have more than this many positions.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd (cosine) column than even
        # (sine) columns, so trim the angles to the number of cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # A buffer follows the module across devices and into state_dict
        # without being treated as a trainable parameter.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor indexed as (batch, position, d_model).
        """
        return x + self.pe[:, : x.size(1)]


# Model width shared by every demo cell below.
embedding_dim = 32
# Stand-alone learned embedding tables, one per language, used for shape checks.
src_embed = nn.Embedding(src_vocab_size, embedding_dim)
tgt_embed = nn.Embedding(tgt_vocab_size, embedding_dim)
positional_encoding = PositionalEncoding(embedding_dim)

# unsqueeze(0) adds a leading batch axis of size 1 to the encoded sentence.
sample = src_embed(encoded_example.unsqueeze(0))
print("Embedding shape:", sample.shape)
print("With positional encoding:", positional_encoding(sample).shape)



## 3. Scaled dot-product attention

Scaled dot-product attention maps queries, keys, and values to contextualised representations. We'll implement it directly from the paper and probe it with a small example.


In [None]:

def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        query: Tensor of shape ``(..., q_len, d_k)``.
        key: Tensor of shape ``(..., k_len, d_k)``.
        value: Tensor of shape ``(..., k_len, d_v)``.
        mask: Optional tensor broadcastable to the ``(..., q_len, k_len)``
            score matrix. Entries equal to 0 mark keys a query may NOT
            attend to; any other value allows attention.

    Returns:
        ``(output, attn_weights)``: the attended values and the attention
        weights. Each weight row sums to 1, except rows whose keys are all
        masked; those rows are all zeros and so produce a zero output.
    """
    d_k = query.size(-1)
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)
    if mask is None:
        attn_weights = F.softmax(scores, dim=-1)
    else:
        blocked = mask == 0
        # A query with every key blocked would softmax over all -inf and
        # yield NaN (in the forward pass and in the gradients). Leave such
        # rows unmasked for the softmax and zero their weights afterwards.
        dead_rows = blocked.all(dim=-1, keepdim=True)
        scores = scores.masked_fill(blocked & ~dead_rows, float("-inf"))
        attn_weights = F.softmax(scores, dim=-1).masked_fill(dead_rows, 0.0)
    output = attn_weights @ value
    return output, attn_weights


# Random query/key/value tensors: batch of 1, 4 positions, embedding_dim features.
q = torch.randn(1, 4, embedding_dim)
k = torch.randn(1, 4, embedding_dim)
v = torch.randn(1, 4, embedding_dim)

# No mask is passed, so every position may attend to every other position.
context, weights = scaled_dot_product_attention(q, k, v)
print("Context shape:", context.shape)
# The weights come from a softmax over the last axis, so each row should sum to 1.
print("Attention weights row sums:", weights.sum(-1))



Masks prevent information leakage in the decoder. We can reuse the same primitive with different masks for padding and future tokens.


In [None]:

def subsequent_mask(size: int) -> torch.Tensor:
    """Build a causal (no-peeking) attention mask.

    Args:
        size: Sequence length.

    Returns:
        Boolean tensor of shape ``(1, 1, size, size)`` where entry
        ``[0, 0, i, j]`` is True exactly when ``j <= i``, i.e. position ``i``
        may attend to itself and to earlier positions only.
    """
    # The lower triangle (diagonal included) marks the allowed positions.
    allowed = torch.tril(torch.ones(size, size, dtype=torch.bool))
    return allowed[None, None]


# Show the 5x5 causal pattern as 0/1: row i has ones up to and including column i.
mask = subsequent_mask(5)
print(mask[0, 0].int())



## 4. Multi-head attention as a composition

Now that we trust the scaled attention primitive, we wrap it into the multi-head structure with learned projections.


In [None]:

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``scaled_dot_product_attention``.

    Queries, keys and values are linearly projected, split into
    ``num_heads`` heads of width ``d_model // num_heads``, attended
    independently, concatenated and projected once more.

    Args:
        d_model: Input/output feature size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: ``(batch, q_len, d_model)``.
            key: ``(batch, k_len, d_model)``.
            value: ``(batch, k_len, d_model)``.
            mask: Optional mask where 0 means "do not attend". Either 3-D
                ``(batch, q_len or 1, k_len)``, in which case a head axis is
                inserted here, or already broadcastable to the 4-D score
                tensor ``(batch, num_heads, q_len, k_len)``.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.

        Raises:
            ValueError: If ``mask`` has more than 4 dimensions.
        """
        batch_size = query.size(0)
        # (batch, len, d_model) -> (batch, heads, len, d_k)
        q = self.w_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.w_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.w_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        if mask is not None:
            if mask.dim() > 4:
                raise ValueError(
                    f"mask must have at most 4 dimensions, got {mask.dim()}"
                )
            # Only a 3-D mask lacks the head axis. Unsqueezing a mask that is
            # already 4-D would make it 5-D, and masked_fill would then
            # broadcast the scores to rank 5 and corrupt the head merge below.
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)

        attn_output, _ = scaled_dot_product_attention(q, k, v, mask)
        # (batch, heads, q_len, d_k) -> (batch, q_len, heads * d_k)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out_proj(attn_output)


# Self-attention smoke test: the same tensor serves as query, key and value,
# and the printed output shape should match the (2, 6, embedding_dim) input.
mha = MultiHeadAttention(d_model=embedding_dim, num_heads=4)
x = torch.randn(2, 6, embedding_dim)
print(mha(x, x, x).shape)



## 5. Position-wise feed-forward network and residual glue

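Each encoder/decoder layer applies a two-layer MLP with a ReLU in between, plus residual connections and layer normalisation. Note that the `ResidualConnection` below normalises the sublayer *input*, i.e. `x + Dropout(Sublayer(LayerNorm(x)))` ("pre-norm"), whereas the paper normalises the sum, `LayerNorm(x + Sublayer(x))` ("post-norm").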


In [None]:

class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension: Linear -> ReLU -> Linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Width of the hidden layer.
    """

    def __init__(self, d_model: int, d_ff: int = 128):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand to ``d_ff``, apply ReLU, and project back to ``d_model``."""
        hidden = self.linear1(x)
        activated = F.relu(hidden)
        return self.linear2(activated)


class ResidualConnection(nn.Module):
    """Residual wrapper computing ``x + dropout(sublayer(norm(x)))``.

    Layer normalisation is applied to the sublayer's input (pre-norm); the
    skip path carries ``x`` unchanged.

    Args:
        size: Feature size handed to ``nn.LayerNorm``.
        dropout: Dropout probability applied to the sublayer output.
    """

    def __init__(self, size: int, dropout: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, sublayer) -> torch.Tensor:
        """Apply ``sublayer`` (a callable on tensors) inside the residual branch."""
        normalised = self.norm(x)
        branch = self.dropout(sublayer(normalised))
        return x + branch


# Wrap the feed-forward block in a residual connection and confirm that the
# printed output shape equals the (3, 4, embedding_dim) input shape.
tensor = torch.randn(3, 4, embedding_dim)
ff = PositionwiseFeedForward(embedding_dim)
residual = ResidualConnection(embedding_dim)
print(residual(tensor, ff).shape)



## 6. Encoder and decoder blocks

We can now assemble encoder and decoder layers, mirroring the left (encoder) and right (decoder) halves of Figure 1 in the paper.


In [None]:

class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each inside a
    ``ResidualConnection``.

    Args:
        d_model: Feature size of the layer input and output.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout probability used by both residual wrappers.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
        wrappers = [
            ResidualConnection(d_model, dropout),
            ResidualConnection(d_model, dropout),
        ]
        self.sublayers = nn.ModuleList(wrappers)

    def forward(
        self, x: torch.Tensor, src_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run ``x`` through self-attention and then the feed-forward network.

        ``src_mask`` is forwarded unchanged to the self-attention module.
        """
        attn_wrapper, ffn_wrapper = self.sublayers

        def attend(normed: torch.Tensor) -> torch.Tensor:
            # Query, key and value are all the (normalised) layer input.
            return self.self_attn(normed, normed, normed, src_mask)

        attended = attn_wrapper(x, attend)
        return ffn_wrapper(attended, self.feed_forward)


class DecoderLayer(nn.Module):
    """One decoder block: self-attention, attention over the encoder memory,
    then feed-forward, each inside a ``ResidualConnection``.

    Args:
        d_model: Feature size of the layer input and output.
        num_heads: Number of attention heads in both attention modules.
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout probability used by all three residual wrappers.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.src_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
        wrappers = [
            ResidualConnection(d_model, dropout),
            ResidualConnection(d_model, dropout),
            ResidualConnection(d_model, dropout),
        ]
        self.sublayers = nn.ModuleList(wrappers)

    def forward(
        self,
        x: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Decode ``x`` while attending to the encoder output ``memory``.

        ``tgt_mask`` is forwarded to the self-attention module and
        ``memory_mask`` to the attention over ``memory``.
        """
        self_wrapper, cross_wrapper, ffn_wrapper = self.sublayers

        def attend_to_self(normed: torch.Tensor) -> torch.Tensor:
            return self.self_attn(normed, normed, normed, tgt_mask)

        def attend_to_memory(normed: torch.Tensor) -> torch.Tensor:
            # Queries come from the decoder; keys and values from the encoder.
            return self.src_attn(normed, memory, memory, memory_mask)

        hidden = self_wrapper(x, attend_to_self)
        hidden = cross_wrapper(hidden, attend_to_memory)
        return ffn_wrapper(hidden, self.feed_forward)


# Shape smoke test for a single encoder layer and a single decoder layer.
encoder_layer = EncoderLayer(embedding_dim, num_heads=4, d_ff=128)
decoder_layer = DecoderLayer(embedding_dim, num_heads=4, d_ff=128)

# Encoder input has 5 positions, decoder input has 6; no masks are supplied.
memory = encoder_layer(torch.randn(2, 5, embedding_dim))
output = decoder_layer(torch.randn(2, 6, embedding_dim), memory)
print(memory.shape, output.shape)



## 7. Stacking layers into a mini Transformer

With layers in place, we construct encoder and decoder stacks plus the final linear generator.


In [None]:

class Encoder(nn.Module):
    """Stack of ``N`` encoder layers followed by a final ``LayerNorm``.

    Args:
        layer: Prototype layer. It is used directly as the first layer; the
            remaining ``N - 1`` layers are freshly constructed (independently
            initialised) with the prototype's ``d_model``, head count,
            feed-forward width and dropout rate.
        N: Number of layers in the stack.
    """

    def __init__(self, layer: EncoderLayer, N: int):
        super().__init__()
        d_model = layer.self_attn.d_model
        num_heads = layer.self_attn.num_heads
        d_ff = layer.feed_forward.linear1.out_features
        # Recover the prototype's dropout rate so that every layer in the
        # stack uses it, rather than silently falling back to the default.
        dropout = layer.sublayers[0].dropout.p
        self.layers = nn.ModuleList(
            [
                layer if i == 0 else EncoderLayer(d_model, num_heads, d_ff, dropout)
                for i in range(N)
            ]
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply every layer in order with the same ``mask``, then normalise."""
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class Decoder(nn.Module):
    """Stack of ``N`` decoder layers followed by a final ``LayerNorm``.

    Args:
        layer: Prototype layer. It is used directly as the first layer; the
            remaining ``N - 1`` layers are freshly constructed (independently
            initialised) with the prototype's ``d_model``, head count,
            feed-forward width and dropout rate.
        N: Number of layers in the stack.
    """

    def __init__(self, layer: DecoderLayer, N: int):
        super().__init__()
        d_model = layer.self_attn.d_model
        num_heads = layer.self_attn.num_heads
        d_ff = layer.feed_forward.linear1.out_features
        # Recover the prototype's dropout rate so that every layer in the
        # stack uses it, rather than silently falling back to the default.
        dropout = layer.sublayers[0].dropout.p
        self.layers = nn.ModuleList(
            [
                layer if i == 0 else DecoderLayer(d_model, num_heads, d_ff, dropout)
                for i in range(N)
            ]
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply every layer in order against the same ``memory``, then normalise."""
        for layer in self.layers:
            x = layer(x, memory, tgt_mask, memory_mask)
        return self.norm(x)


class Transformer(nn.Module):
    """Encoder-decoder Transformer producing target-vocabulary logits.

    Token embeddings are used as-is (no ``sqrt(d_model)`` scaling is applied)
    and share one ``PositionalEncoding`` module between source and target.

    Args:
        src_vocab: Number of entries in the source vocabulary (an int size,
            not a token-to-id mapping).
        tgt_vocab: Number of entries in the target vocabulary (an int size).
        d_model: Embedding and hidden feature size.
        num_heads: Attention heads per attention module.
        d_ff: Hidden width of each feed-forward network.
        num_layers: Number of layers in both the encoder and decoder stacks.
        dropout: Dropout probability handed to the prototype layers.
    """

    def __init__(
        self,
        src_vocab: int,
        tgt_vocab: int,
        d_model: int = 64,
        num_heads: int = 4,
        d_ff: int = 128,
        num_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        encoder_layer = EncoderLayer(d_model, num_heads, d_ff, dropout)
        decoder_layer = DecoderLayer(d_model, num_heads, d_ff, dropout)
        self.encoder = Encoder(encoder_layer, num_layers)
        self.decoder = Decoder(decoder_layer, num_layers)
        # Projects decoder features to one logit per target-vocabulary entry.
        self.generator = nn.Linear(d_model, tgt_vocab)

    def encode(
        self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Embed ``src`` token ids, add positions and run the encoder stack."""
        return self.encoder(self.positional_encoding(self.src_embed(src)), src_mask)

    def decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``tgt`` token ids, add positions and run the decoder stack.

        Returns decoder features; ``generator`` is NOT applied here.
        """
        return self.decoder(
            self.positional_encoding(self.tgt_embed(tgt)),
            memory,
            tgt_mask,
            memory_mask,
        )

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode ``src``, decode ``tgt`` against it, and return logits.

        All masks are optional and default to ``None`` (no masking), matching
        the defaults of the encoder and decoder stacks.
        """
        memory = self.encode(src, src_mask)
        output = self.decode(tgt, memory, tgt_mask, memory_mask)
        return self.generator(output)


# Shape smoke test for the full model (untrained, no masks).
model = Transformer(src_vocab_size, tgt_vocab_size, d_model=embedding_dim)
# The toy sentences have different lengths, so torch.stack would raise on
# them; right-pad every sentence with <pad> to form rectangular batches.
src_batch = nn.utils.rnn.pad_sequence(
    [encode(pair[0], src_vocab) for pair in toy_pairs],
    batch_first=True,
    padding_value=src_vocab["<pad>"],
)
tgt_batch = nn.utils.rnn.pad_sequence(
    [encode(pair[1], tgt_vocab) for pair in toy_pairs],
    batch_first=True,
    padding_value=tgt_vocab["<pad>"],
)
print(model(src_batch, tgt_batch, None, None, None).shape)



## 8. Training on a synthetic translation task

To verify the notebook components, we run a tiny training loop. The goal is simply to overfit the six toy sentence pairs so that the decoder learns to predict the next token.


In [None]:

def make_batch(pairs):
    """Turn (source, target) sentence pairs into two padded id tensors.

    Each target sequence is wrapped as ``<bos> ... <eos>``; source sequences
    are left unwrapped. Both batches are right-padded with ``<pad>``.

    Args:
        pairs: Iterable of ``(source_sentence, target_sentence)`` strings.

    Returns:
        ``(src_pad, tgt_pad)``, each laid out batch-first.
    """

    def with_boundaries(token_ids):
        # Frame the target ids with the begin/end-of-sequence markers.
        bos = torch.tensor([tgt_vocab["<bos>"]])
        eos = torch.tensor([tgt_vocab["<eos>"]])
        return torch.cat([bos, token_ids, eos])

    encoded = [
        (encode(src_text, src_vocab), with_boundaries(encode(tgt_text, tgt_vocab)))
        for src_text, tgt_text in pairs
    ]
    src_seqs = [src_ids for src_ids, _ in encoded]
    tgt_seqs = [tgt_ids for _, tgt_ids in encoded]

    pad_sequence = nn.utils.rnn.pad_sequence
    src_pad = pad_sequence(src_seqs, batch_first=True, padding_value=src_vocab["<pad>"])
    tgt_pad = pad_sequence(tgt_seqs, batch_first=True, padding_value=tgt_vocab["<pad>"])
    return src_pad, tgt_pad


# Build the padded training batch from all toy pairs; this rebinds
# src_batch / tgt_batch for the training cells below.
src_batch, tgt_batch = make_batch(toy_pairs)
print("Source batch shape:", src_batch.shape)
print("Target batch shape:", tgt_batch.shape)


In [None]:

def create_masks(src: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build padding and causal attention masks for a padded batch.

    Args:
        src: Source token ids, assumed 2-D ``(batch, src_len)``.
        tgt: Target token ids, assumed 2-D ``(batch, tgt_len)``.

    Returns:
        ``(src_mask, tgt_mask, memory_mask)`` of boolean tensors in which
        True means "may attend". For 2-D inputs, ``src_mask`` is
        ``(batch, 1, 1, src_len)`` and hides ``<pad>`` keys; ``tgt_mask`` is
        ``(batch, 1, tgt_len, tgt_len)`` and hides both ``<pad>`` keys and
        future positions; ``memory_mask`` is the same object as ``src_mask``.
    """
    src_mask = (src != src_vocab["<pad>"]).unsqueeze(1).unsqueeze(2)
    tgt_pad_mask = (tgt != tgt_vocab["<pad>"]).unsqueeze(1).unsqueeze(2)
    # subsequent_mask builds its tensor on the default device; move it next to
    # `tgt` so the `&` below also works for batches living on an accelerator.
    causal_mask = subsequent_mask(tgt.size(1)).to(tgt.device)
    tgt_mask = tgt_pad_mask & causal_mask
    memory_mask = src_mask
    return src_mask, tgt_mask, memory_mask


# Masks for the full padded batch; the training cell below reuses these names.
src_mask, tgt_mask, memory_mask = create_masks(src_batch, tgt_batch)
print(src_mask.shape, tgt_mask.shape, memory_mask.shape)


In [None]:

model = Transformer(
    src_vocab_size,
    tgt_vocab_size,
    d_model=embedding_dim,
    num_heads=4,
    d_ff=128,
    num_layers=2,
)
# <pad> positions in the labels contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab["<pad>"])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Teacher forcing: the decoder reads tokens [0, T-1) and is scored against
# tokens [1, T), i.e. at every position it must predict the next token.
decoder_input = tgt_batch[:, :-1]
decoder_labels = tgt_batch[:, 1:]

# MultiHeadAttention.forward inserts the head axis for 3-D masks, so hand it
# (batch, query, key) masks by dropping the size-1 head axis that
# create_masks already added. The target mask was built for the full length
# T and must be cropped on BOTH the query and the key axis to match the
# (T-1)-long decoder input.
train_src_mask = src_mask.squeeze(1)
train_tgt_mask = tgt_mask[:, 0, :-1, :-1]
train_memory_mask = memory_mask.squeeze(1)

num_epochs = 200
for epoch in range(1, num_epochs + 1):
    optimizer.zero_grad()
    logits = model(
        src_batch,
        decoder_input,
        train_src_mask,
        train_tgt_mask,
        train_memory_mask,
    )
    # Flatten (batch, position) so each position is one classification example.
    loss = criterion(
        logits.reshape(-1, tgt_vocab_size),
        decoder_labels.reshape(-1),
    )
    loss.backward()
    optimizer.step()
    if epoch % 50 == 0:
        print(f"Epoch {epoch}: loss={loss.item():.4f}")


In [None]:

def greedy_decode(model: Transformer, src_sentence: str, max_len: int = 20) -> str:
    """Translate one sentence by repeatedly picking the highest-scoring token.

    Args:
        model: Trained model; it is switched to eval mode and left there.
        src_sentence: Whitespace-tokenised source sentence whose tokens must
            all be present in ``src_vocab``.
        max_len: Maximum number of tokens to generate.

    Returns:
        The generated target tokens joined by spaces, without ``<bos>`` and
        without ``<eos>``. If ``<eos>`` is never produced, all ``max_len``
        generated tokens are returned.
    """
    model.eval()
    device = next(model.parameters()).device
    eos_id = tgt_vocab["<eos>"]

    src = encode(src_sentence, src_vocab).unsqueeze(0).to(device)
    # 3-D (batch, 1, src_len) mask: MultiHeadAttention inserts the head axis.
    src_mask = (src != src_vocab["<pad>"]).unsqueeze(1)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        memory = model.encode(src, src_mask)
        ys = torch.tensor([[tgt_vocab["<bos>"]]], device=device)
        for _ in range(max_len):
            # Drop the leading axis of the (1, 1, len, len) causal mask to get
            # the 3-D (batch, query, key) layout.
            tgt_mask = subsequent_mask(ys.size(1))[0].to(device)
            out = model.decode(ys, memory, tgt_mask, src_mask)
            # Only the newest position decides the next token.
            logits = model.generator(out[:, -1])
            next_word = int(logits.argmax(dim=1).item())
            if next_word == eos_id:
                break
            ys = torch.cat([ys, torch.tensor([[next_word]], device=device)], dim=1)

    # <eos> is never appended, so only the leading <bos> has to be stripped.
    # (Slicing off the last element unconditionally would drop a real token
    # whenever max_len is reached without <eos>.)
    return decode(ys[0, 1:], inv_tgt_vocab)


# Compare the greedy translation with the reference for every toy pair.
for src, tgt in toy_pairs:
    prediction = greedy_decode(model, src)
    # Line breaks must be written as \n escapes: a single-quoted f-string
    # cannot span several physical lines.
    print(f"SRC: {src}\nPRED: {prediction}\nTGT: {tgt}\n")
