# Self-Attention Next-Word Prediction Example

This notebook demonstrates a single self-attention head predicting the next word over a tiny six-word vocabulary, **including a training loop**: given the prompt "I love", the model is trained to predict "cats".

In [None]:
import torch
import torch.nn.functional as F

In [None]:
# Tiny vocabulary and lookup tables in both directions (word <-> integer id).
# A word's id is its position in `vocab`.
vocab = ['I', 'love', 'cats', 'dogs', '<pad>', '<eos>']
vocab_size = len(vocab)
idx_to_word = dict(enumerate(vocab))
word_to_idx = {word: idx for idx, word in idx_to_word.items()}

In [None]:
# Toy prompt 'I love', encoded as a 1-D tensor of vocabulary ids.
input_words = ['I', 'love']
prompt_ids = [word_to_idx[token] for token in input_words]
input_idxs = torch.tensor(prompt_ids)

In [None]:
# Model hyperparameters and learnable layers for a single attention head.
# Seed the global RNG before creating any layer so that the random weight
# initialisation -- and therefore the training log and final probabilities --
# is reproducible from run to run.
torch.manual_seed(0)
d_model, d_k, d_v = 8, 8, 8
emb = torch.nn.Embedding(vocab_size, d_model)      # token id -> d_model vector
Wq = torch.nn.Linear(d_model, d_k)                 # query projection
Wk = torch.nn.Linear(d_model, d_k)                 # key projection
Wv = torch.nn.Linear(d_model, d_v)                 # value projection
Wo = torch.nn.Linear(d_v, d_model)                 # head output projection
classifier = torch.nn.Linear(d_model, vocab_size)  # d_model -> vocabulary logits

In [None]:
# Training setup: one Adam optimizer over the parameters of every learnable
# layer, plus cross-entropy loss for next-word classification.
# Parameters are gathered in a fixed module order: emb, Wq, Wk, Wv, Wo, classifier.
trainable_modules = (emb, Wq, Wk, Wv, Wo, classifier)
optimizer = torch.optim.Adam(
    [param for module in trainable_modules for param in module.parameters()],
    lr=0.01,
)
loss_fn = torch.nn.CrossEntropyLoss()
# Single target: the word that should follow the prompt is 'cats'.
# Shape (1,): one class index for a batch of one example.
targets = torch.tensor([word_to_idx['cats']])

In [None]:
# Training loop: full-batch gradient descent on the single (prompt, target) pair.
num_epochs = 200
log_every = 50
for epoch in range(num_epochs):
    # Forward pass: one scaled dot-product self-attention head over the prompt.
    X = emb(input_idxs)                 # (seq_len, d_model); input_idxs is 1-D
    Q, K, V = Wq(X), Wk(X), Wv(X)
    scores = Q @ K.T / (d_k ** 0.5)     # (seq_len, seq_len), scaled by sqrt(d_k)
    attn = F.softmax(scores, dim=-1)    # each row is a distribution over positions
    context = attn @ V                  # (seq_len, d_v)
    out = Wo(context)                   # (seq_len, d_model)
    # Predict the next word from the last position only; unsqueeze adds a batch
    # dimension of 1 because CrossEntropyLoss expects (N, C) logits.
    logits = classifier(out[-1]).unsqueeze(0)
    loss = loss_fn(logits, targets)

    # Backward pass and parameter update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Log periodically, and always on the last epoch so the final loss is shown
    # (with 200 epochs and a 50-epoch interval, epoch 199 was never printed).
    if epoch % log_every == 0 or epoch == num_epochs - 1:
        print(f"Epoch {epoch}: loss {loss.item():.4f}")

In [None]:
# Evaluate after training: repeat the forward pass with autograd disabled and
# print the model's next-word distribution for the prompt.
with torch.no_grad():
    X = emb(input_idxs)
    Q, K, V = (proj(X) for proj in (Wq, Wk, Wv))
    scale = d_k ** 0.5
    scores = Q @ K.T / scale
    attn = F.softmax(scores, dim=-1)
    context = attn @ V
    out = Wo(context)
    # Only the final position is used to predict the word that follows.
    logits = classifier(out[-1])
    probs = F.softmax(logits, dim=-1)

    print('Next-word probabilities after training:')
    for i, p in enumerate(probs):
        word = idx_to_word[i]
        print(f"{word}: {p.item():.4f}")