# Attention Principle (Course Pre Demo)

Goal: show that attention acts as a learnable **alignment / addressing** mechanism, visualized as attention-weight heatmaps (copy → diagonal, reverse → anti-diagonal, pointer lookup → a single sharp peak).


## Setup
This notebook uses PyTorch's built-in `torch.nn.MultiheadAttention`; it contains no custom attention implementation.
To read PyTorch's own implementation (trimmed only by deleting code, nothing rewritten), see `pytorch_mha_reading.ipynb`.


In [None]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

def set_seed(seed=0):
    """Seed the `random`, NumPy and PyTorch global generators with one value."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

set_seed(0)
# Cap intra-op CPU threads at 4, or fewer when the machine reports fewer cores
# (os.cpu_count() may return None, hence the fallback to 1).
torch.set_num_threads(min(os.cpu_count() or 1, 4))
device = torch.device('cpu')


In [None]:
class CrossAttnSeqClassifier(nn.Module):
    """One cross-attention layer that maps input tokens to per-position logits.

    Queries are built only from learned output-position embeddings (optionally
    plus a query-token embedding). Keys and values are token embeddings plus
    learned input-position embeddings. The queries carry no input content, so
    everything the output knows about the input must flow through the attention
    weights. Those weights are what this notebook visualizes.
    """

    def __init__(self, vocab_size, d_model, num_heads, max_in_len, max_out_len):
        super().__init__()
        self.vocab_size, self.max_in_len, self.max_out_len = vocab_size, max_in_len, max_out_len
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.in_pos_emb = nn.Embedding(max_in_len, d_model)
        self.out_pos_emb = nn.Embedding(max_out_len, d_model)
        # Dropout is disabled so the returned weights are exactly the ones used.
        self.attn = nn.MultiheadAttention(d_model, num_heads, dropout=0.0, batch_first=True)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, key_tokens, value_tokens, out_len, query_token=None):
        """Run one cross-attention step.

        Args:
            key_tokens: (B, Tk) token ids used to build the keys.
            value_tokens: token ids used to build the values; must have the
                same shape as ``key_tokens``.
            out_len: number of output (query) positions to produce.
            query_token: optional (B,) token ids whose embedding is added to
                every query position.

        Returns:
            logits: (B, out_len, vocab_size).
            attn_w: (B, num_heads, out_len, Tk) per-head attention weights.

        Raises:
            ValueError: if key/value shapes differ, or a length exceeds the
                positional-embedding tables.
        """
        if value_tokens.shape != key_tokens.shape:
            raise ValueError(
                f'value_tokens shape {tuple(value_tokens.shape)} does not match '
                f'key_tokens shape {tuple(key_tokens.shape)}'
            )
        B, Tk = key_tokens.shape
        if Tk > self.max_in_len or out_len > self.max_out_len:
            raise ValueError('sequence too long for positional embeddings')
        dev = key_tokens.device
        # The same input-position embedding is added to both keys and values.
        in_pos = self.in_pos_emb(torch.arange(Tk, device=dev))[None, :, :]
        K = self.tok_emb(key_tokens) + in_pos
        V = self.tok_emb(value_tokens) + in_pos
        Q = self.out_pos_emb(torch.arange(out_len, device=dev))[None, :, :].expand(B, -1, -1)
        if query_token is not None:
            Q = Q + self.tok_emb(query_token)[:, None, :]
        # average_attn_weights=False keeps one attention map per head.
        ctx, attn_w = self.attn(Q, K, V, need_weights=True, average_attn_weights=False)
        return self.out(ctx), attn_w


In [None]:
def make_batch(task, B, L, V, M=12, K=12, Val=12):
    """Sample one synthetic batch for the given task.

    Args:
        task: 'copy', 'reverse' or 'ptr'.
        B: batch size.
        L: sequence length for copy/reverse (ignored for 'ptr').
        V: vocabulary size for copy/reverse; tokens are drawn from [1, V)
            (ignored for 'ptr').
        M: number of key/value pairs per example ('ptr' only).
        K: number of key ids; keys use ids 1..K ('ptr' only).
        Val: number of value ids; values use ids K+1..K+Val ('ptr' only).

    Returns:
        (key_tokens, value_tokens, target, query, out_len):
        - copy/reverse: (x, x, y, None, L), where x and y are (B, L) and y is
          x or x reversed along the sequence.
        - ptr: (keys, vals, y, q, 1), where keys and vals are (B, M) and q and
          y are (B,). y is the value paired with key q.

    Raises:
        ValueError: for an unknown task, or for 'ptr' with M > K (keys must
            be distinct within a row).
    """
    if task in ('copy', 'reverse'):
        x = torch.randint(1, V, (B, L), dtype=torch.long)
        y = x if task == 'copy' else x.flip(1)
        return x, x, y, None, L
    if task != 'ptr':
        raise ValueError(f"unknown task {task!r}; expected 'copy', 'reverse' or 'ptr'")
    if M > K:
        raise ValueError(f'ptr task needs M <= K for distinct keys (got M={M}, K={K})')
    # Sample keys without replacement within each row. With duplicate keys, a
    # query would match several slots and the target would be ambiguous.
    keys = torch.rand(B, K).argsort(dim=1)[:, :M] + 1
    vals = torch.randint(K + 1, K + Val + 1, (B, M), dtype=torch.long)
    idx = torch.randint(0, M, (B,), dtype=torch.long)
    rows = torch.arange(B)
    q = keys[rows, idx]
    y = vals[rows, idx]
    return keys, vals, y, q, 1

def train(task, steps=250, L=16, V=32, M=12, K=12, Val=12, d_model=32, heads=4, lr=3e-3,
          batch_size=128):
    """Train a fresh CrossAttnSeqClassifier on one synthetic task and return it.

    Args:
        task: 'copy', 'reverse' or 'ptr'.
        steps: number of Adam updates.
        L, V: sequence length and vocabulary size for copy/reverse.
        M, K, Val: pair count, key-id count and value-id count for 'ptr'.
            The vocabulary is then 1 + K + Val.
        d_model, heads: embedding width and number of attention heads.
        lr: Adam learning rate.
        batch_size: examples sampled per step (default 128).

    Returns:
        The trained model. It is still in training mode; call .eval() before
        inspecting it.

    Raises:
        ValueError: for an unknown task name.
    """
    if task not in ('copy', 'reverse', 'ptr'):
        raise ValueError(f"unknown task {task!r}; expected 'copy', 'reverse' or 'ptr'")
    is_ptr = task == 'ptr'
    vocab = 1 + K + Val if is_ptr else V
    # The pointer task reads M pairs and emits a single answer position.
    max_in, max_out = (M, 1) if is_ptr else (L, L)
    model = CrossAttnSeqClassifier(vocab, d_model, heads, max_in, max_out).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(steps):
        kt, vt, y, q, out_len = make_batch(task, batch_size, L, vocab, M, K, Val)
        kt, vt, y = kt.to(device), vt.to(device), y.to(device)
        if q is not None:
            q = q.to(device)
        logits, _ = model(kt, vt, out_len, q)
        # logits is (B, out_len, vocab). y is (B, out_len), or (B,) when
        # out_len == 1. Flattening both covers every task with one call.
        loss = F.cross_entropy(logits.reshape(-1, vocab), y.reshape(-1))
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
    return model

def plot_heatmap(attn_w, title):
    """Display the head-averaged attention map of the first example in a batch.

    ``attn_w`` is indexed as (batch, head, query, key). This matches the
    per-head weights returned by this notebook's model.
    """
    first_example = attn_w[0]
    grid = first_example.mean(dim=0).detach().cpu().numpy()
    fig, ax = plt.subplots(figsize=(6, 5))
    image = ax.imshow(grid, aspect='auto', interpolation='nearest')
    fig.colorbar(image, ax=ax)
    ax.set_title(title)
    ax.set_xlabel('input position (key/value index)')
    ax.set_ylabel('output position (query index)')
    fig.tight_layout()
    plt.show()


## 1) Copy (expect diagonal)

In [None]:
L, V = 16, 32
copy_model = train('copy', steps=200, L=L, V=V)
copy_model.eval()
kt, vt, y, q, out_len = make_batch('copy', 1, L, V)
# Put the sample on the model's device, as train() does for its batches.
kt, vt, y = kt.to(device), vt.to(device), y.to(device)
with torch.no_grad():
    logits, attn = copy_model(kt, vt, out_len, q)
print('input :', kt[0].tolist())
print('target:', y[0].tolist())
print('pred  :', logits.argmax(dim=-1)[0].tolist())
plot_heatmap(attn, 'Copy: mean over heads (expect diagonal)')


## 2) Reverse (expect anti-diagonal)

In [None]:
# Same sizes as the copy cell, redefined so this cell also runs on its own.
L, V = 16, 32
reverse_model = train('reverse', steps=250, L=L, V=V)
reverse_model.eval()
kt, vt, y, q, out_len = make_batch('reverse', 1, L, V)
# Put the sample on the model's device, as train() does for its batches.
kt, vt, y = kt.to(device), vt.to(device), y.to(device)
with torch.no_grad():
    logits, attn = reverse_model(kt, vt, out_len, q)
print('input :', kt[0].tolist())
print('target:', y[0].tolist())
print('pred  :', logits.argmax(dim=-1)[0].tolist())
plot_heatmap(attn, 'Reverse: mean over heads (expect anti-diagonal)')


## 3) Pointer / lookup (expect sharp peak)

In [None]:
M, K, Val = 12, 12, 12
ptr_model = train('ptr', steps=350, M=M, K=K, Val=Val)
ptr_model.eval()
# make_batch ignores its sequence-length argument for 'ptr'. Passing M avoids
# depending on L from the copy cell, so this cell also runs on its own.
kt, vt, y, q, out_len = make_batch('ptr', 1, M, 1 + K + Val, M=M, K=K, Val=Val)
# Put the sample on the model's device, as train() does for its batches.
kt, vt, y, q = kt.to(device), vt.to(device), y.to(device), q.to(device)
with torch.no_grad():
    logits, attn = ptr_model(kt, vt, out_len, q)
print('keys  :', kt[0].tolist())
print('values:', vt[0].tolist())
print('query :', int(q[0]))
print('target:', int(y[0]))
print('pred  :', int(logits[:, 0, :].argmax(dim=-1)[0]))
plot_heatmap(attn, 'Pointer: attention over pair index (expect 1-hot-ish)')
