# Self-Attention Next-Word Prediction with CUDA

This notebook demonstrates a single self-attention head predicting the next word in a tiny vocabulary, **with training** and **CUDA** support if available.

In [1]:
import torch
import torch.nn.functional as F

# Choose device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [2]:
# Define small vocabulary and mappings
vocab = ['I', 'love', 'cats', 'dogs', '<pad>', '<eos>']
vocab_size = len(vocab)
word_to_idx = {w: i for i, w in enumerate(vocab)}
idx_to_word = {i: w for i, w in enumerate(vocab)}

In [3]:
# Toy input: 'I love'
input_words = ['I', 'love']
input_idxs = torch.tensor([word_to_idx[w] for w in input_words], device=device)

In [4]:
# Model parameters and embedding
d_model, d_k, d_v = 8, 8, 8
emb = torch.nn.Embedding(vocab_size, d_model).to(device)
Wq = torch.nn.Linear(d_model, d_k).to(device)
Wk = torch.nn.Linear(d_model, d_k).to(device)
Wv = torch.nn.Linear(d_model, d_v).to(device)
Wo = torch.nn.Linear(d_v, d_model).to(device)
classifier = torch.nn.Linear(d_model, vocab_size).to(device)

In [5]:
# Training setup
optimizer = torch.optim.Adam(
    list(emb.parameters()) + list(Wq.parameters()) + list(Wk.parameters()) +
    list(Wv.parameters()) + list(Wo.parameters()) + list(classifier.parameters()),
    lr=0.01
)
loss_fn = torch.nn.CrossEntropyLoss()
# Single target: next word 'cats'
targets = torch.tensor([word_to_idx['cats']], device=device)

In [6]:
# Training loop
for epoch in range(200):
    X = emb(input_idxs)        # (2, d_model)
    Q = Wq(X); K = Wk(X); V = Wv(X)
    scores = Q @ K.T / (d_k ** 0.5)
    attn = F.softmax(scores, dim=-1)
    context = attn @ V         # (2, d_v)
    out = Wo(context)          # (2, d_model)
    logits = classifier(out[-1]).unsqueeze(0)
    loss = loss_fn(logits, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 50 == 0:
        print(f"Epoch {epoch}: loss {loss.item():.4f}")

Epoch 0: loss 1.6159
Epoch 50: loss 0.0000
Epoch 100: loss 0.0000
Epoch 150: loss 0.0000


In [7]:
# Evaluate after training
with torch.no_grad():
    X = emb(input_idxs)
    Q = Wq(X); K = Wk(X); V = Wv(X)
    scores = Q @ K.T / (d_k ** 0.5)
    attn = F.softmax(scores, dim=-1)
    context = attn @ V
    out = Wo(context)
    logits = classifier(out[-1])
    probs = F.softmax(logits, dim=-1)
    print('Next-word probabilities after training:')
    for i, p in enumerate(probs):
        print(f"{idx_to_word[i]}: {p.item():.4f}")

Next-word probabilities after training:
I: 0.0000
love: 0.0000
cats: 1.0000
dogs: 0.0000
<pad>: 0.0000
<eos>: 0.0000
