# Tutorial: PyTorch Code Walkthrough + a small Demo

We’ll focus on three practical takeaways from PyTorch’s `nn.MultiheadAttention`, then run a tiny demo to interpret its attention maps.



## PyTorch `nn.MultiheadAttention` Implementation


Source (pytorch/torch/nn/modules/activation.py):
- https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/activation.py#L1089-L1566

### Three engineering takeaways
1) Supporting both input layouts: `(batch, seq, dim)` and `(seq, batch, dim)`.
2) Supporting three masking controls: `key_padding_mask`, `attn_mask`, and `is_causal`.
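3) `need_weights` indirectly selects the compute path: with `need_weights=True`, PyTorch uses the explicit math implementation so it can return the attention weights; with `need_weights=False`, it can dispatch to the fused `scaled_dot_product_attention` (SDPA) kernels.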

**For a quick look at how these are implemented (in `F.multi_head_attention_forward`), see `f_mha_key.ipynb`.**

In [None]:
# Abridged excerpt of PyTorch's MultiheadAttention module. Lines marked "# ..."
# stand for code elided from the upstream source, so this listing is meant for
# reading rather than running as-is.
class MultiheadAttention(Module):
    # ...
    __constants__ = ["batch_first"]
    # Optional bias vectors; set to Parameters only when add_bias_kv=True
    # (see the end of __init__), otherwise None.
    bias_k: torch.Tensor | None
    bias_v: torch.Tensor | None

    # Build the projection parameters for multi-head attention.
    # Raises ValueError for non-positive embed_dim/num_heads and AssertionError
    # when embed_dim is not divisible by num_heads.
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        device=None,
        dtype=None,
    ) -> None:
        if embed_dim <= 0 or num_heads <= 0:
            raise ValueError(
                f"embed_dim and num_heads must be greater than 0,"
                f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
            )
        # Forwarded to every tensor/submodule created below.
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        # Key/value feature sizes default to embed_dim when not given.
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        # True when q, k and v all share embed_dim; selects the packed
        # in_proj_weight layout below and the call variant in forward().
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        # Each head gets an equal slice of the embedding; reject uneven splits.
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != self.embed_dim:
            raise AssertionError("embed_dim must be divisible by num_heads")

        if not self._qkv_same_embed_dim:
            # ...
        else:
            # One packed weight of shape (3 * embed_dim, embed_dim) holds the
            # q, k and v input projections; the separate per-projection
            # weights are registered as None in this configuration.
            self.in_proj_weight = Parameter(
                torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
            )
            self.register_parameter("q_proj_weight", None)
            self.register_parameter("k_proj_weight", None)
            self.register_parameter("v_proj_weight", None)


        # The same `bias` flag controls both the packed input-projection bias
        # and the bias of the output projection.
        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter("in_proj_bias", None)
        self.out_proj = NonDynamicallyQuantizableLinear(
            embed_dim, embed_dim, bias=bias, **factory_kwargs
        )

        # Optional learned bias vectors of shape (1, 1, embed_dim) for k and v.
        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None


        self.add_zero_attn = add_zero_attn

        # Parameters above are created with torch.empty, so they must be
        # initialized before use.
        self._reset_parameters()

    # Initialize the parameters allocated (uninitialized) in __init__.
    def _reset_parameters(self) -> None:
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            # ...

        # out_proj.bias is zeroed under the in_proj_bias check because both
        # biases exist exactly when __init__ was called with bias=True.
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)


    # Run attention by delegating to F.multi_head_attention_forward.
    # Returns (attn_output, attn_output_weights); per the return annotation
    # the weights may be None.
    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Tensor | None = None,
        need_weights: bool = True,
        attn_mask: Tensor | None = None,
        average_attn_weights: bool = True,
        is_causal: bool = False,
    ) -> tuple[Tensor, Tensor | None]:
        # ...
        # A 3-D query is treated as batched input.
        is_batched = query.dim() == 3

        # NOTE(review): the body is elided in this excerpt; presumably it
        # transposes batch-first inputs to the (seq, batch, dim) layout that
        # the functional call expects -- confirm against the upstream source.
        if self.batch_first and is_batched:
            # ...

        # The two calls below differ only in the projection weights: the first
        # additionally passes use_separate_proj_weight=True together with the
        # individual q/k/v projection weights.
        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal,
            )
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
                is_causal=is_causal,
            )
        # For batched, batch-first modules, swap the first two axes of the
        # output before returning it; the weights are returned untouched.
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights

## A Demo Using `nn.MultiheadAttention`

In this demo, we train a tiny cross-attention model on three toy tasks and visualize **attention weights**.

**How to read the heatmap**
- x-axis: input position (key/value index)
- y-axis: output position (query index)
- brighter = higher attention weight

These synthetic tasks are designed so that “correct behavior” corresponds to a clear geometric pattern in the attention map.
(We set `need_weights=True` for visualization.)



### Setup

**What this model is doing (high level)**
- Input: a sequence of tokens (used to form keys/values)
- Queries: positions in the output sequence (or a single query token for the pointer task)
- Output: predicted tokens, plus attention maps for inspection





In [None]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

def set_seed(seed=0):
    """Seed the Python, NumPy, and PyTorch random generators with `seed`."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# Make the notebook reproducible, cap the CPU thread count at four (or fewer
# when the machine reports fewer cores), and run everything on the CPU.
set_seed(0)
torch.set_num_threads(min(4, os.cpu_count() or 1))
device = torch.device('cpu')

class CrossAttnSeqClassifier(nn.Module):
    """Single cross-attention layer followed by a linear token classifier.

    Keys and values are token embeddings plus an input-position embedding.
    Queries are output-position embeddings, optionally shifted by the
    embedding of a per-example query token.
    """

    def __init__(self, vocab_size, d_model, num_heads, max_in_len, max_out_len):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_in_len = max_in_len
        self.max_out_len = max_out_len
        # Submodules are created in a fixed order so seeded initialization is
        # reproducible.
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.in_pos_emb = nn.Embedding(max_in_len, d_model)
        self.out_pos_emb = nn.Embedding(max_out_len, d_model)
        self.attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=0.0, batch_first=True
        )
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, key_tokens, value_tokens, out_len, query_token=None):
        """Attend from `out_len` query positions over the input tokens.

        Args:
            key_tokens: 2-D integer tensor (batch, input length) of key tokens.
            value_tokens: integer tensor of value tokens; it is combined with
                the same input-position embeddings as the keys.
            out_len: number of output (query) positions.
            query_token: optional 1-D integer tensor with one token per
                example, whose embedding is added to every query position.

        Returns:
            (logits, attn_w): logits from the output layer and the per-head
            attention weights (`average_attn_weights=False`).

        Raises:
            ValueError: if the input or output length exceeds the size of the
                corresponding positional-embedding table.
        """
        batch, in_len = key_tokens.shape
        too_long = in_len > self.max_in_len or out_len > self.max_out_len
        if too_long:
            raise ValueError('sequence too long for positional embeddings')

        dev = key_tokens.device
        in_pos_vec = self.in_pos_emb(torch.arange(in_len, device=dev)).unsqueeze(0)
        out_pos_vec = self.out_pos_emb(torch.arange(out_len, device=dev)).unsqueeze(0)

        keys = self.tok_emb(key_tokens) + in_pos_vec
        values = self.tok_emb(value_tokens) + in_pos_vec
        queries = out_pos_vec.expand(batch, -1, -1)
        if query_token is not None:
            queries = queries + self.tok_emb(query_token).unsqueeze(1)

        context, weights = self.attn(
            queries, keys, values, need_weights=True, average_attn_weights=False
        )
        return self.out(context), weights
    
def make_batch(task, B, L, V, M=12, K=12, Val=12):
    """Generate one random batch for a toy attention task.

    Args:
        task: 'ptr' selects the key/value lookup task, 'copy' the copy task;
            any other value is treated as the reverse task.
        B: batch size.
        L: sequence length (copy/reverse only; ignored for 'ptr').
        V: vocabulary size; tokens are drawn from [1, V) (copy/reverse only;
            ignored for 'ptr').
        M: number of (key, value) slots per example ('ptr' only).
        K: number of distinct key tokens; keys lie in [1, K] ('ptr' only).
        Val: number of distinct value tokens; values lie in [K + 1, K + Val]
            ('ptr' only).

    Returns:
        A tuple (key_tokens, value_tokens, target, query_token, out_len).
        For copy/reverse the keys and values are the same input sequence,
        query_token is None and out_len is L. For 'ptr', query_token is the
        key of one randomly chosen slot, target is that slot's value and
        out_len is 1.
    """
    if task != 'ptr':
        x = torch.randint(1, V, (B, L), dtype=torch.long)
        y = x if task == 'copy' else x.flip(1)
        return x, x, y, None, L

    if M <= K:
        # Draw keys WITHOUT replacement: take the first M entries of a random
        # permutation of 1..K per row. Duplicate keys would make the lookup
        # ambiguous (one query matching several slots with different values).
        keys = torch.rand(B, K).argsort(dim=1)[:, :M] + 1
    else:
        # More slots than distinct keys: uniqueness is impossible, so fall
        # back to sampling with replacement.
        keys = torch.randint(1, K + 1, (B, M), dtype=torch.long)
    vals = torch.randint(K + 1, K + Val + 1, (B, M), dtype=torch.long)

    # Pick one slot per example; its key is the query, its value the target.
    rows = torch.arange(B)
    idx = torch.randint(0, M, (B,), dtype=torch.long)
    q = keys[rows, idx]
    y = vals[rows, idx]
    return keys, vals, y, q, 1

def train(task, steps=250, L=16, V=32, M=12, K=12, Val=12, d_model=32, heads=4, lr=3e-3, batch_size=128):
    """Train a fresh CrossAttnSeqClassifier on one toy task and return it.

    Args:
        task: 'ptr' for the lookup task; any other value is a sequence task
            handled by `make_batch` (copy or reverse).
        steps: number of optimizer steps, each on a newly sampled batch.
        L: sequence length for the sequence tasks.
        V: vocabulary size for the sequence tasks.
        M: number of (key, value) slots for 'ptr'.
        K: number of distinct key tokens for 'ptr'.
        Val: number of distinct value tokens for 'ptr'.
        d_model: embedding width of the model.
        heads: number of attention heads.
        lr: Adam learning rate.
        batch_size: examples per training batch.

    Returns:
        The trained model, located on the module-level `device`.
    """
    is_ptr = task == 'ptr'
    # 'ptr' uses token 0 (unused), K key tokens and Val value tokens.
    vocab = 1 + K + Val if is_ptr else V
    max_in, max_out = (M, 1) if is_ptr else (L, L)

    model = CrossAttnSeqClassifier(vocab, d_model, heads, max_in, max_out).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    for _ in range(steps):
        kt, vt, y, q, out_len = make_batch(task, batch_size, L, vocab, M, K, Val)
        kt, vt, y = kt.to(device), vt.to(device), y.to(device)
        q = None if q is None else q.to(device)

        logits, _ = model(kt, vt, out_len, q)
        # Flatten every (example, position) pair into one classification row.
        # This covers both multi-position targets of shape (B, out_len) and
        # the single-position 'ptr' target of shape (B,), and also works when
        # a sequence task is run with L == 1 (target shape (B, 1)).
        loss = F.cross_entropy(logits.reshape(-1, vocab), y.reshape(-1))

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
    return model

def plot_heatmap(attn_w, title):
    """Show the first example's attention map, averaged over its leading axis.

    Args:
        attn_w: attention weights; `attn_w[0]` is taken and then averaged over
            its first dimension (the heads, when the weights come from
            `CrossAttnSeqClassifier`, which requests per-head weights).
        title: figure title.
    """
    first_example = attn_w[0]
    grid = first_example.mean(dim=0).detach().cpu().numpy()

    plt.figure(figsize=(6, 5))
    plt.imshow(grid, aspect='auto', interpolation='nearest')
    plt.colorbar()
    plt.title(title)
    plt.xlabel('input position (key/value index)')
    plt.ylabel('output position (query index)')
    plt.tight_layout()
    plt.show()




### 1) Copy (expect diagonal)

**Task:** reproduce the input sequence at the output (copy).

**What to expect in attention:** a bright **diagonal**  
Each output position should mainly attend to the same input position.


In [None]:
# Train on the copy task, then inspect one freshly sampled example.
L, V = 16, 32
copy_model = train('copy', steps=200, L=L, V=V)
copy_model.eval()

kt, vt, y, q, out_len = make_batch('copy', 1, L, V)
with torch.no_grad():
    logits, attn = copy_model(kt, vt, out_len, q)

# Compare the input, the expected output and the greedy prediction.
for label, tokens in (
    ('input :', kt[0]),
    ('target:', y[0]),
    ('pred  :', logits.argmax(dim=-1)[0]),
):
    print(label, tokens.tolist())
plot_heatmap(attn, 'Copy: mean over heads (expect diagonal)')


### 2) Reverse (expect anti-diagonal)

**Task:** output the input sequence in reverse order.

**What to expect in attention:** a bright **anti-diagonal**  
Output position *t* should attend to input position *(L−1−t)*.


In [None]:
# Define the task size here so this cell does not depend on globals left
# behind by an earlier cell.
L, V = 16, 32

# Train on the reverse task, then inspect one freshly sampled example.
reverse_model = train('reverse', steps=250, L=L, V=V)
reverse_model.eval()
kt, vt, y, q, out_len = make_batch('reverse', 1, L, V)
with torch.no_grad():
    logits, attn = reverse_model(kt, vt, out_len, q)

# Compare the input, the expected (reversed) output and the greedy prediction.
print('input :', kt[0].tolist())
print('target:', y[0].tolist())
print('pred  :', logits.argmax(dim=-1)[0].tolist())
plot_heatmap(attn, 'Reverse: mean over heads (expect anti-diagonal)')


### 3) Pointer / lookup (expect sharp peak)

**Task:** given a set of (key, value) pairs, use the query token to “look up” the corresponding value.

**What to expect in attention:** a **sharp peak** (almost one-hot)  
The single output query should attend strongly to exactly one input slot (the matched key).


In [None]:
# Train on the pointer/lookup task, then inspect one freshly sampled example.
M, K, Val = 12, 12, 12
ptr_model = train('ptr', steps=350, M=M, K=K, Val=Val)
ptr_model.eval()

# make_batch ignores its length argument for 'ptr', so pass M rather than
# relying on a global `L` defined by another cell.
kt, vt, y, q, out_len = make_batch('ptr', 1, M, 1 + K + Val, M=M, K=K, Val=Val)
with torch.no_grad():
    logits, attn = ptr_model(kt, vt, out_len, q)

# Show the key/value slots, the query, the expected value and the prediction
# at the single output position.
print('keys  :', kt[0].tolist())
print('values:', vt[0].tolist())
print('query :', int(q[0]))
print('target:', int(y[0]))
print('pred  :', int(logits[:, 0, :].argmax(dim=-1)[0]))
plot_heatmap(attn, 'Pointer: attention over pair index (expect 1-hot-ish)')
