# Lesson 4 — Build a Tiny Transformer (Character-level)
**Goal:** Implement self-attention plus a mini Transformer block and train it on a small corpus so it can generate themed text.

**Big picture**
- A Transformer is like an upgraded N-gram model that can look at *all* previous tokens in its context window (`block_size` characters here), not just the last few.
- Self-attention decides, for every position, which earlier characters are most relevant when predicting the next one.
- Residual connections and layer normalization keep the network stable while it stacks multiple attention + feedforward layers.

**Vocabulary check**
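- **Self-attention:** For each position, compute Query (Q), Key (K), and Value (V) vectors. Attention weights = `softmax(Q·Kᵀ / √d)`, where `d` is the per-head dimension; the output at each position is the weighted sum of the V vectors.
- **Masking:** Prevents the model from peeking at future tokens by setting their attention scores to `-∞` before softmax, so their weights come out as exactly 0.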
- **Residual connection:** Adds the input of a block back to its output so information can flow easily.
- **Layer normalization:** Rescales each token’s activation vector to zero mean and unit variance, then applies a learned scale and shift, which keeps training stable.
- **Positional encoding:** Injects information about token order since self-attention alone doesn’t know positions.
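
To make the self-attention and masking entries concrete, here is a minimal single-head sketch on random toy tensors. The sizes `T = 4` and `d = 8` and the variable names are illustrative assumptions, not values used by the notebook's model.

```python
import math
import torch

torch.manual_seed(0)

T, d = 4, 8                        # toy sequence length and head dimension (assumed)
q = torch.randn(T, d)              # one query vector per position
k = torch.randn(T, d)              # one key vector per position
v = torch.randn(T, d)              # one value vector per position

# Raw similarity scores between every pair of positions, scaled by sqrt(d).
scores = (q @ k.T) / math.sqrt(d)                          # (T, T)

# Causal mask: True wherever the column (key) lies in the future of the row (query).
future = torch.triu(torch.ones(T, T), diagonal=1).bool()
scores = scores.masked_fill(future, float("-inf"))

# Softmax turns each row into weights that sum to 1; masked entries become 0.
weights = torch.softmax(scores, dim=-1)                    # (T, T)

# Each output row is a weighted sum of the value vectors it is allowed to see.
out = weights @ v                                          # (T, d)

print(weights)
```

The printed matrix is lower-triangular: row 0 puts all of its weight on position 0, while the last row spreads weight across every position.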

**Build order for this notebook**
1. **Tokenizer + dataset:** Reuse Lesson 1 tokens or stick with characters. Create input/output sequences of length `block_size`.
2. **Implement attention head:** Write code for Q, K, V, scaling, masking, and weighted sums.
3. **Stack into a Transformer block:** Combine multi-head attention, feedforward network, residuals, and layer norms.
4. **Training loop:** Use cross-entropy loss, an optimizer (AdamW), and mini-batches to train.
5. **Sampling:** Start with a short prompt and repeatedly feed the model’s own predictions back in.
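
Step 1 can be sketched on a toy string to show how inputs and targets line up. The corpus `"hello world"` and `block_size = 4` are assumptions chosen for readability; the notebook itself uses a larger corpus and context.

```python
import torch

text = "hello world"                                   # toy corpus (assumed)
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
data = torch.tensor([stoi[c] for c in text], dtype=torch.long)

block_size = 4                                         # toy context length (assumed)
i = 0                                                  # start offset of one window
x = data[i : i + block_size]                           # inputs
y = data[i + 1 : i + block_size + 1]                   # targets: same window shifted by one

# Position t of the window is one training example:
# given x[0..t], the model should predict y[t].
pairs = [(x[: t + 1].tolist(), int(y[t])) for t in range(block_size)]

for context, target in pairs:
    print(repr("".join(itos[c] for c in context)), "->", repr(itos[target]))
```

One window of length `block_size` therefore yields `block_size` prediction tasks, which is why the targets are simply the inputs shifted by one character.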

> 🔍 Analogy: Self-attention is like a group project where every student (token) scans everyone else’s notes and decides whose work to copy a little from.

In [None]:

import math, torch, torch.nn as nn
from pathlib import Path
import re, random

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# Load corpus
data_dir = Path("data")
text = ""
for fname in ["space.txt","animals.txt","minecraft.txt"]:
    text += (data_dir / fname).read_text(encoding="utf-8") + "\n"

# Character vocabulary
chars = sorted(list(set(text)))
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for ch,i in stoi.items()}

def encode(s): return torch.tensor([stoi[c] for c in s], dtype=torch.long)
def decode(t): return "".join([itos[int(i)] for i in t])

data_t = encode(text)


In [None]:

# Train/val split
n = int(0.9 * len(data_t))
train_data = data_t[:n]
val_data = data_t[n:]

block_size = 64
batch_size = 32

def get_batch(split="train"):
    src = train_data if split=="train" else val_data
    ix = torch.randint(len(src)-block_size-1, (batch_size,))
    x = torch.stack([src[i:i+block_size] for i in ix])
    y = torch.stack([src[i+1:i+block_size+1] for i in ix])
    return x.to(device), y.to(device)

@torch.no_grad()
def estimate_loss(model, iters=50):
    model.eval()
    out = {}
    for split in ["train","val"]:
        losses = torch.zeros(iters)
        for k in range(iters):
            xb,yb = get_batch(split)
            logits, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out


In [None]:

class SelfAttention(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3*n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x):
        B,T,C = x.shape
        qkv = self.qkv(x)  # (B,T,3C)
        q,k,v = qkv.chunk(3, dim=-1)
        # reshape for heads
        q = q.view(B,T,self.n_head,self.head_dim).transpose(1,2)
        k = k.view(B,T,self.n_head,self.head_dim).transpose(1,2)
        v = v.view(B,T,self.n_head,self.head_dim).transpose(1,2)
        # scaled dot-product attention with causal mask
        att = (q @ k.transpose(-2,-1)) / math.sqrt(self.head_dim)
        mask = torch.tril(torch.ones(T,T, device=x.device)) == 0
        att.masked_fill_(mask, float('-inf'))
        att = torch.softmax(att, dim=-1)
        out = att @ v  # (B,head,T,dim)
        out = out.transpose(1,2).contiguous().view(B,T,C)
        return self.proj(out)

class Block(nn.Module):
    def __init__(self, n_embd, n_head, dropout=0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.sa = SelfAttention(n_embd, n_head)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ff = nn.Sequential(
            nn.Linear(n_embd, 4*n_embd),
            nn.GELU(),
            nn.Linear(4*n_embd, n_embd),
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = x + self.drop(self.sa(self.ln1(x)))
        x = x + self.drop(self.ff(self.ln2(x)))
        return x

class TinyTransformer(nn.Module):
    def __init__(self, vocab_size, n_embd=128, n_head=4, n_layer=2, block_size=64):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
        self.ln = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size)
        self.block_size = block_size

    def forward(self, idx, targets=None):
        B,T = idx.shape
        tok = self.token_emb(idx)
        pos = self.pos_emb(torch.arange(T, device=idx.device))
        x = tok + pos
        x = self.blocks(x)
        x = self.ln(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

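    @torch.no_grad()
    def generate(self, idx, max_new_tokens=100):
        # idx: (B, T) tensor of token ids used as the prompt
        for _ in range(max_new_tokens):
            # crop to the last block_size tokens; pos_emb has no entries beyond that
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # keep only the prediction for the final position
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)  # sample one id per row
            idx = torch.cat([idx, next_id], dim=1)  # append and feed back in
        return idx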

vocab_size = len(stoi)
model = TinyTransformer(vocab_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)


In [None]:

# Quick train (few steps to keep it fast). Increase steps for better results.
steps = 300
for step in range(steps):
    xb,yb = get_batch("train")
    logits, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 50 == 0:
        print(step, estimate_loss(model))

# Sample
context = torch.tensor([[stoi.get('T',0)]], dtype=torch.long, device=device)
out = model.generate(context, max_new_tokens=200)[0].tolist()
print("=== SAMPLE ===")
print(decode(out))


### Challenges
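- **Architectural tweaks:** Increase `n_layer` or `n_head` and observe training stability and text quality. `n_embd` must stay divisible by `n_head`, or the `assert` in `SelfAttention` fails.
- **Longer context:** Try `block_size = 128` (needs more VRAM/time) and see if the model captures longer phrases. Pass `block_size=128` to `TinyTransformer` as well, since the position embedding table is sized by the constructor argument, not the global variable.
- **Add dropout:** `Block` already takes a `dropout` argument, but `TinyTransformer` leaves it at the default `0.0`. Pass a nonzero rate (0.1 is a common starting point) and consider adding dropout on the attention weights to reduce overfitting.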
- **Checkpointing:** Save model weights and experiment with resuming training after changing the corpus.
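
For the checkpointing challenge, one possible starting point is sketched below. It uses a small `nn.Linear` as a stand-in for the notebook's `TinyTransformer` so that it runs on its own; the file name `tiny_transformer.pt` and the dictionary keys are hypothetical choices, not part of the lesson.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

model = nn.Linear(8, 8)                    # stand-in for TinyTransformer (assumption)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Hypothetical checkpoint location; a temp dir keeps the example side-effect free.
ckpt_path = Path(tempfile.mkdtemp()) / "tiny_transformer.pt"

# Save weights, optimizer state, and the step counter together.
torch.save(
    {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": 300},
    ckpt_path,
)

# Later: rebuild the same architecture, then load the saved state into it.
restored = nn.Linear(8, 8)
restored_opt = torch.optim.AdamW(restored.parameters(), lr=3e-4)
ckpt = torch.load(ckpt_path, map_location="cpu")
restored.load_state_dict(ckpt["model"])
restored_opt.load_state_dict(ckpt["optimizer"])
start_step = ckpt["step"]                  # resume the training loop from here
```

One thing to watch when changing the corpus: a character-level vocabulary is derived from the text, so new characters change `vocab_size`, and `load_state_dict` will then raise a shape mismatch for the embedding and output layers. Saving `stoi` alongside the weights and reusing it is one way around that.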