In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import requests, math



In [2]:
# =============================
# 1. Load Tiny Shakespeare text
# =============================
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
# Bound the wait so a stalled connection cannot hang the notebook, and fail
# loudly on an HTTP error instead of silently tokenizing an error page.
response = requests.get(url, timeout=30)
response.raise_for_status()
text = response.text

In [3]:
# =============================
# 2. Word-level tokenization
# =============================
# Whitespace tokenization: every run of non-space characters is one token,
# so punctuation stays attached to its word (e.g. "Citizen:").
words = text.split()

# Vocabulary = sorted unique tokens, plus lookup tables in both directions.
vocab = sorted(set(words))
stoi = {word: index for index, word in enumerate(vocab)}   # word → index
itos = dict(enumerate(vocab))                              # index → word
vocab_size = len(stoi)

print("Total words:", len(words))
print("Vocab size:", vocab_size)
print("Sample tokens:", words[:20])

# =============================
# 3. Convert entire corpus to integer IDs
# =============================
# Encode every token as its vocabulary index.
data = torch.tensor(list(map(stoi.__getitem__, words)), dtype=torch.long)

# Hold out the final 10% of the corpus for validation.
split = int(0.9 * len(data))
train_data, val_data = data[:split], data[split:]

print("Train words:", len(train_data))
print("Val words:", len(val_data))

Total words: 202651
Vocab size: 25670
Sample tokens: ['First', 'Citizen:', 'Before', 'we', 'proceed', 'any', 'further,', 'hear', 'me', 'speak.', 'All:', 'Speak,', 'speak.', 'First', 'Citizen:', 'You', 'are', 'all', 'resolved', 'rather']
Train words: 182385
Val words: 20266


In [4]:
# =============================
# 4. Batching function for training
# =============================
def get_batch(split, block_size, batch_size, device):
    """Sample a random batch of (input, next-token target) windows.

    Args:
        split: "train" selects ``train_data``; any other value selects ``val_data``.
        block_size: Number of tokens in each window.
        batch_size: Number of windows to draw.
        device: Device the returned tensors are moved to.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, block_size)``; ``y`` holds the
        same windows as ``x`` shifted one token to the right.

    Raises:
        ValueError: If the chosen split has fewer than ``block_size + 1`` tokens.
    """
    source = train_data if split == "train" else val_data
    # A start i is valid while the target window source[i+1 : i+block_size+1]
    # still fits, i.e. i <= len(source) - block_size - 1.  torch.randint's upper
    # bound is exclusive, so it has to be len(source) - block_size.
    n_starts = len(source) - block_size
    if n_starts < 1:
        raise ValueError(
            f"split has {len(source)} tokens; need at least {block_size + 1} "
            f"for block_size={block_size}"
        )
    # pick random starting word positions
    starts = torch.randint(0, n_starts, (batch_size,)).tolist()
    x = torch.stack([source[i:i + block_size] for i in starts]).to(device)
    y = torch.stack([source[i + 1:i + block_size + 1] for i in starts]).to(device)
    return x, y

In [5]:
# ============================================================
# 5. Attention variants (simple)
# ============================================================

class ScaledDotProductAttention(nn.Module):
    """Causal multi-head self-attention written out step by step.

    Computes softmax(Q K^T / sqrt(d)) V per head under a lower-triangular mask,
    then mixes the heads with a final linear projection.

    Args:
        n_embd: Embedding width C of the input; must be divisible by ``n_head``.
        n_head: Number of attention heads.
    """
    def __init__(self, n_embd, n_head):
        super().__init__()
        # Fail fast, as LocalAttention, LinearAttention and MultiQueryAttention
        # do, instead of hitting an opaque reshape error in the first forward.
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.d = n_embd // n_head                      # per-head width
        self.qkv = nn.Linear(n_embd, 3*n_embd, bias=False)
        self.out = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        """Map x of shape (B, T, C) to attended features of the same shape."""
        B, T, C = x.shape
        # One fused projection, then split into q/k/v of shape (B, H, T, Dh).
        qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.d).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        att = (q @ k.transpose(-2,-1)) / math.sqrt(self.d)   # (B, H, T, T)
        # Causal mask: query position i may only look at key positions j <= i.
        pos = torch.arange(T, device=x.device)
        mask = pos[None, :] <= pos[:, None]
        att = att.masked_fill(~mask, float('-inf'))
        att = att.softmax(dim=-1)

        out = att @ v                                  # (B, H, T, Dh)
        out = out.transpose(1,2).reshape(B,T,C)        # merge heads back into C
        return self.out(out)

class FlashAttention(nn.Module):
    """Uses PyTorch 2 scaled_dot_product_attention (Flash when available).

    Args:
        n_embd: Embedding width C of the input; must be divisible by ``n_head``.
        n_head: Number of attention heads.
    """
    def __init__(self, n_embd, n_head):
        super().__init__()
        # Fail fast, as LocalAttention, LinearAttention and MultiQueryAttention
        # do, instead of hitting an opaque reshape error in the first forward.
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.d = n_embd // n_head                      # per-head width
        self.qkv = nn.Linear(n_embd, 3*n_embd, bias=False)
        self.out = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        """Map x of shape (B, T, C) to attended features of the same shape."""
        B, T, C = x.shape
        # One fused projection, then split into q/k/v of shape (B, H, T, Dh).
        qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.d).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # is_causal=True applies the lower-triangular mask inside the fused op.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1,2).reshape(B,T,C)        # merge heads back into C
        return self.out(out)

# ============================
# Local (sliding-window) attention
# ============================
class LocalAttention(nn.Module):
    """
    Sliding-window causal self-attention.
    A query at position i sees the keys at positions i-window .. i (inclusive),
    i.e. itself plus at most `window` earlier tokens.
    """
    def __init__(self, n_embd, n_head, window: int = 64):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.d = n_embd // n_head
        self.window = window
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.out = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        B, T, C = x.shape
        fused = self.qkv(x).reshape(B, T, 3, self.n_head, self.d)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)      # each (B, H, T, Dh)

        # scaled dot-product scores for every (query, key) pair
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d)  # (B, H, T, T)

        # lag[i, j] = i - j: how far key j lies behind query i
        positions = torch.arange(T, device=x.device)
        lag = positions[:, None] - positions[None, :]
        # hide future keys (lag < 0) and keys older than the window
        hidden = (lag < 0) | (lag > self.window)              # (T, T) bool

        weights = torch.softmax(scores.masked_fill(hidden, float('-inf')), dim=-1)
        context = torch.matmul(weights, v)                    # (B, H, T, Dh)
        merged = context.transpose(1, 2).reshape(B, T, C)
        return self.out(merged)

# ============================
# Linear (kernelized) attention (pedagogical)
# ============================
class LinearAttention(nn.Module):
    """
    Causal linear attention using the positive feature map φ(x) = ELU(x) + 1.
    Walks the sequence one step at a time, keeping running sums of φ(k) and of
    the outer products φ(k)·vᵀ (written for clarity, not speed).
    """
    def __init__(self, n_embd, n_head):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.d = n_embd // n_head
        self.Wq = nn.Linear(n_embd, n_embd, bias=False)
        self.Wk = nn.Linear(n_embd, n_embd, bias=False)
        self.Wv = nn.Linear(n_embd, n_embd, bias=False)
        self.out = nn.Linear(n_embd, n_embd)
        self.eps = 1e-6

    @staticmethod
    def phi(x):
        # ELU(x) > -1 everywhere, so the shifted features are strictly positive
        return F.elu(x) + 1.0

    def forward(self, x):
        B, T, C = x.shape
        H, Dh = self.n_head, self.d

        def split_heads(proj):
            # (B,T,C) -> (B,H,T,Dh)
            return proj.reshape(B, T, H, Dh).transpose(1, 2)

        q_feat = self.phi(split_heads(self.Wq(x)))
        k_feat = self.phi(split_heads(self.Wk(x)))
        v = split_heads(self.Wv(x))

        # running prefix sums, created with x's dtype and device
        k_sum = x.new_zeros(B, H, Dh)
        kv_sum = x.new_zeros(B, H, Dh, Dh)

        per_step = []
        for q_t, k_t, v_t in zip(q_feat.unbind(2), k_feat.unbind(2), v.unbind(2)):
            # fold step t in first, so each query also sees its own key/value
            kv_sum = kv_sum + torch.einsum("bhd,bhe->bhde", k_t, v_t)
            k_sum = k_sum + k_t
            numerator = torch.einsum("bhd,bhde->bhe", q_t, kv_sum)       # (B,H,Dh)
            denominator = torch.einsum("bhd,bhd->bh", q_t, k_sum).unsqueeze(-1) + self.eps
            per_step.append(numerator / denominator)

        # stacking on dim=1 yields (B,T,H,Dh), ready to merge heads into C
        y = torch.stack(per_step, dim=1).reshape(B, T, C)
        return self.out(y)

# ============================
# Multi-Query Attention (MQA)
# ============================
class MultiQueryAttention(nn.Module):
    """
    Multi-Query Attention: every head has its own queries, but all heads read
    from one shared key projection and one shared value projection, so K/V
    are Dh wide instead of H*Dh.
    """
    def __init__(self, n_embd, n_head):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.d = n_embd // n_head
        self.Wq = nn.Linear(n_embd, n_embd, bias=False)  # per-head queries: C -> H*Dh
        self.Wk = nn.Linear(n_embd, self.d, bias=False)  # single shared key head: C -> Dh
        self.Wv = nn.Linear(n_embd, self.d, bias=False)  # single shared value head: C -> Dh
        self.out = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        B, T, C = x.shape
        H, Dh = self.n_head, self.d

        def share_across_heads(single_head):
            # (B,T,Dh) -> (B,H,T,Dh): the same tensor repeated for every head
            return single_head[:, None].expand(B, H, T, Dh).contiguous()

        queries = self.Wq(x).reshape(B, T, H, Dh).transpose(1, 2)   # (B,H,T,Dh)
        keys = share_across_heads(self.Wk(x))
        values = share_across_heads(self.Wv(x))

        # fused causal attention from PyTorch (Flash kernel when available)
        attended = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)
        merged = attended.transpose(1, 2).reshape(B, T, C)          # (B,T,C)
        return self.out(merged)


In [6]:
# ============================================================
# 6. Tiny Transformer
# ============================================================

class Block(nn.Module):
    """Pre-norm transformer block: attention, then a 4x-wide GELU MLP.

    Each sub-layer reads a LayerNorm of its input and is added back onto it
    (residual connection). `attn_class` is instantiated as
    `attn_class(n_embd, n_head)`.
    """
    def __init__(self, n_embd, n_head, attn_class):
        super().__init__()
        hidden = 4 * n_embd
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = attn_class(n_embd, n_head)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ff = nn.Sequential(
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
        )

    def forward(self, x):
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.ff(self.ln2(after_attn))

class TinyTransformer(nn.Module):
    """Small decoder-only transformer language model.

    Args:
        vocab_size: Number of distinct token ids (embedding rows / logit width).
        block_size: Maximum sequence length; sizes the position embedding table.
        n_layer: Number of stacked ``Block`` layers.
        n_embd: Embedding width.
        n_head: Attention heads per block.
        attn_class: Attention module class, built as ``attn_class(n_embd, n_head)``.
    """
    def __init__(self, vocab_size, block_size, n_layer=4, n_embd=256, n_head=4, attn_class=ScaledDotProductAttention):
        super().__init__()
        self.block_size = block_size
        self.token = nn.Embedding(vocab_size, n_embd)
        self.pos   = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, attn_class) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits, and the loss when targets are given.

        Args:
            idx: Token ids of shape (B, T) with T <= ``block_size``.
            targets: Optional token ids of shape (B, T).

        Returns:
            ``logits`` of shape (B, T, vocab_size) when ``targets`` is None,
            otherwise the tuple ``(logits, loss)`` with a scalar cross-entropy loss.
        """
        B, T = idx.shape
        # The position table only has block_size rows; check here rather than
        # failing inside the embedding lookup.
        assert T <= self.block_size, f"sequence length {T} exceeds block_size {self.block_size}"
        pos = torch.arange(0, T, device=idx.device)
        x = self.token(idx) + self.pos(pos)[None,:,:]   # (B,T,C); positions broadcast over batch
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)
        if targets is None:
            return logits
        # Flatten using the model's own logit width, not the notebook-global
        # `vocab_size`, so the loss stays correct for any vocab_size argument.
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss



In [None]:
# ============================================================
# 7. Training
# ============================================================

device = "cuda" if torch.cuda.is_available() else "cpu"
block_size = 256
batch_size = 64

# Attention implementation used inside every block. Any class from the
# attention cell can be plugged in: ScaledDotProductAttention, FlashAttention,
# LocalAttention, LinearAttention or MultiQueryAttention.
attn_type = FlashAttention   # change this line

model = TinyTransformer(vocab_size, block_size, attn_class=attn_type).to(device)
optimizer = torch.optim.AdamW(model.parameters())

steps = 1000
print("Training...")

for step in range(steps):
    xb, yb = get_batch("train", block_size, batch_size, device)

    # clear old gradients, then forward → backward → parameter update
    optimizer.zero_grad()
    logits, loss = model(xb, yb)
    loss.backward()
    optimizer.step()

    # report the current batch loss every 100th step (starting at step 0)
    if not step % 100:
        print(f"step {step} | loss {loss.item():.3f}")



Training...
step 0 | loss 10.314
step 100 | loss 6.919
step 200 | loss 5.653
step 300 | loss 4.630
step 400 | loss 3.747


In [None]:
# ============================================================
# 8. Sampling
# ============================================================


def sample(model, start="ROMEO:", steps=100):
    """Greedily generate `steps` words that continue the prompt `start`.

    Args:
        model: Language model exposing `block_size` and returning logits of
            shape (B, T, vocab) when called with token ids of shape (B, T).
        start: Prompt text; split on whitespace into words.
        steps: Number of words to generate.

    Returns:
        The prompt words followed by the generated words, joined by spaces.

    Raises:
        ValueError: If `start` contains no words.
    """
    # Map words → IDs (use 0 for unknown words to avoid KeyErrors)
    start_ids = [stoi.get(w, 0) for w in start.split()]
    if not start_ids:
        # An empty prompt would build a float tensor of shape (1, 0) that the
        # embedding lookup cannot consume; reject it with a clear message.
        raise ValueError("start must contain at least one word")

    # Put the prompt where the model lives; fall back to the notebook-level
    # `device` only for a model that has no parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    idx = torch.tensor([start_ids], dtype=torch.long, device=target_device)

    # Generate in eval mode, then restore whatever mode the caller had so that
    # sampling in the middle of training does not silently change it.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(steps):
                # Crop context to block_size so position embeddings never overflow
                idx_cond = idx[:, -model.block_size:]
                logits = model(idx_cond)[:, -1, :]  # final word's logits
                # Greedy decoding: pick highest-probability word
                next_id = torch.argmax(logits, dim=-1, keepdim=True)
                idx = torch.cat([idx, next_id], dim=1)
    finally:
        model.train(was_training)

    # Convert token IDs back to words
    return " ".join(itos[i] for i in idx[0].tolist())




# Print a greedy continuation from `model`, using sample()'s default prompt
# ("ROMEO:") and default length (100 generated words).
print("\n=== SAMPLE ===")
print(sample(model))
