# Self-Attention Next-Word Prediction Example

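This notebook demonstrates a single self-attention head predicting the next word over a tiny vocabulary. The forward pass shown here runs on randomly initialized weights, so the printed probabilities illustrate the mechanics rather than a learned prediction.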

In [None]:
import torch
import torch.nn.functional as F

In [None]:
# Toy vocabulary plus lookup tables in both directions (word <-> index).
# A word's index is simply its position in `vocab`.
vocab = ['I', 'love', 'cats', 'dogs', '<pad>', '<eos>']
vocab_size = len(vocab)
idx_to_word = dict(enumerate(vocab))
word_to_idx = {word: idx for idx, word in idx_to_word.items()}

In [None]:
# Toy prompt 'I love', encoded as a 1-D tensor of vocabulary indices.
# A word missing from `word_to_idx` raises KeyError.
input_words = ['I', 'love']
input_idxs = torch.tensor(list(map(word_to_idx.__getitem__, input_words)))

In [None]:
# Model dimensions: embedding width and the query/key/value projection
# widths (all 8 in this toy setup).
d_model = d_k = d_v = 8

# Token embedding table: one d_model-wide vector per vocabulary entry.
emb = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

# Query / key / value projections for the single attention head.
Wq = torch.nn.Linear(in_features=d_model, out_features=d_k)
Wk = torch.nn.Linear(in_features=d_model, out_features=d_k)
Wv = torch.nn.Linear(in_features=d_model, out_features=d_v)

# Output projection back to d_model, then a linear classifier that
# produces one logit per vocabulary entry.
Wo = torch.nn.Linear(in_features=d_v, out_features=d_model)
classifier = torch.nn.Linear(in_features=d_model, out_features=vocab_size)

In [None]:
# Single-head self-attention forward pass, followed by next-word prediction
X = emb(input_idxs)                # token embeddings, (2, d_model)
Q, K, V = Wq(X), Wk(X), Wv(X)      # queries/keys (2, d_k), values (2, d_v)

# Scaled dot-product attention: query-key similarities divided by
# sqrt(d_k), then normalized with a softmax over the last axis (the keys).
scores = (Q @ K.transpose(0, 1)) / d_k ** 0.5
attn = scores.softmax(dim=-1)
context = attn @ V                 # attention-weighted values, (2, d_v)
out = Wo(context)                  # (2, d_model)

# Only the last position's output is fed to the classifier.
logits = classifier(out[-1])
probs = logits.softmax(dim=-1)

print('Next-word probabilities:')
for i, p in enumerate(probs.unbind(dim=0)):
    print(f"{idx_to_word[i]}: {p.item():.4f}")