# In [1]:
import math
import os
import time
import urllib.request
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


# In [None]:
import math
import os
import time
import urllib.request
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


# =========================
# 0) Config
# =========================
@dataclass
class CFG:
    """All hyperparameters for data, training and the TinyGPT model.

    ``d_ff`` is derived as ``4 * d_model`` in ``__post_init__`` when left at
    its default (or set to any non-positive value), so overriding ``d_model``
    keeps the 4x MLP expansion instead of silently staying at 1024.
    """

    # data
    data_path: str = "tinyshakespeare.txt"
    url: str = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

    # training
    # NOTE: the device default is evaluated once, when this class is defined.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size: int = 64
    block_size: int = 128          # context length (N)
    max_iters: int = 3000
    eval_every: int = 300          # evaluate + sample every this many iterations
    lr: float = 3e-4
    weight_decay: float = 0.01
    grad_clip: float = 1.0         # max global gradient norm

    # model
    n_layers: int = 4
    n_heads: int = 4
    d_model: int = 256             # D
    d_ff: int = 0                  # MLP hidden; <= 0 means 4 * d_model
    dropout: float = 0.1

    def __post_init__(self):
        # Derive the MLP width from d_model unless it was set explicitly.
        if self.d_ff <= 0:
            self.d_ff = 4 * self.d_model


cfg = CFG()


# =========================
# 1) Get dataset
# =========================
def ensure_data(cfg: CFG):
    """Download the corpus to ``cfg.data_path`` unless that file already exists.

    The download is written to a sibling ``<data_path>.part`` file and then
    atomically renamed into place. An interrupted download therefore never
    leaves a truncated file that a later run would mistake for the finished
    dataset (the ``exists`` check below would otherwise skip it forever).

    Args:
        cfg: config supplying ``data_path`` (destination) and ``url`` (source).
    """
    if os.path.exists(cfg.data_path):
        return

    parent = os.path.dirname(cfg.data_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    tmp_path = cfg.data_path + ".part"
    print(f"Downloading dataset to {cfg.data_path} ...")
    try:
        urllib.request.urlretrieve(cfg.url, tmp_path)
        os.replace(tmp_path, cfg.data_path)
    except BaseException:
        # Drop the partial download (also on Ctrl-C) so the next run retries.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print("Done.")


def load_text(path: str) -> str:
    """Return the entire contents of the UTF-8 text file at ``path``."""
    handle = open(path, mode="r", encoding="utf-8")
    with handle:
        contents = handle.read()
    return contents


# =========================
# 2) Char tokenizer
# =========================
class CharVocab:
    """Character-level tokenizer: each distinct character of ``text`` is one token.

    Ids are assigned in sorted character order, so the mapping is
    deterministic for a given text.
    """

    def __init__(self, text: str):
        alphabet = sorted(set(text))
        self.stoi = dict(zip(alphabet, range(len(alphabet))))   # char -> id
        self.itos = dict(enumerate(alphabet))                   # id -> char
        self.vocab_size = len(self.stoi)

    def encode(self, s: str):
        """Map each character of ``s`` to its id (KeyError on unseen characters)."""
        lookup = self.stoi
        return list(map(lookup.__getitem__, s))

    def decode(self, ids):
        """Map an iterable of ids back to the string they encode."""
        pieces = [self.itos[token] for token in ids]
        return "".join(pieces)


# =========================
# 3) Data batching
# =========================
def make_splits(data_ids: torch.Tensor, train_frac=0.9):
    """Split a token sequence into contiguous ``(train, val)`` parts.

    The first ``int(len(data_ids) * train_frac)`` items form the train part and
    the remainder the validation part; nothing is shuffled.
    """
    boundary = int(len(data_ids) * train_frac)
    train_part = data_ids[:boundary]
    val_part = data_ids[boundary:]
    return train_part, val_part


def get_batch(data_ids: torch.Tensor, cfg: CFG):
    """Sample a random batch of (input, next-token target) windows.

    Args:
        data_ids: 1-D tensor of token ids to sample windows from.
        cfg: supplies ``batch_size``, ``block_size`` and ``device``.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, block_size)`` on ``cfg.device``,
        where ``y[:, t] == x[:, t + 1]`` (targets are inputs shifted by one).

    Raises:
        ValueError: if ``data_ids`` has fewer than ``block_size + 1`` tokens.
    """
    # A window starting at i needs ids[i : i + block_size + 1], so valid starts
    # are 0 .. len - block_size - 1; randint's upper bound is exclusive.
    n_starts = len(data_ids) - cfg.block_size
    if n_starts <= 0:
        raise ValueError(
            f"data has {len(data_ids)} tokens; need at least "
            f"block_size + 1 = {cfg.block_size + 1}"
        )
    ix = torch.randint(0, n_starts, (cfg.batch_size,))
    offsets = ix.unsqueeze(1) + torch.arange(cfg.block_size)  # (B, N) gather indices
    x = data_ids[offsets]
    y = data_ids[offsets + 1]
    return x.to(cfg.device), y.to(cfg.device)


@torch.no_grad()
def estimate_loss(model, train_ids, val_ids, cfg: CFG, iters=50):
    """Estimate mean cross-entropy on the train and val splits.

    Each split is scored on ``iters`` random batches from ``get_batch``, with
    the model in eval mode (dropout off). The model's previous train/eval
    mode is restored afterwards, even if an exception is raised.

    Returns:
        ``{"train": float, "val": float}`` mean losses.

    Raises:
        ValueError: if ``iters < 1`` (there would be nothing to average).
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    was_training = model.training
    model.eval()
    out = {}
    try:
        for split, ids in (("train", train_ids), ("val", val_ids)):
            total = 0.0
            for _ in range(iters):
                x, y = get_batch(ids, cfg)
                logits = model(x)
                loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
                total += loss.item()
            out[split] = total / iters
    finally:
        # Restore whatever mode the caller had, rather than forcing train mode.
        model.train(was_training)
    return out


# =========================
# 4) Transformer building blocks (from scratch)
# =========================
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where position t attends only to positions <= t.

    Args:
        d_model: model width D; must be divisible by ``n_heads``.
        n_heads: number of heads H; each head has width D // H.
        block_size: longest sequence supported; sizes the precomputed mask.
        dropout: probability used on both the attention weights and the output.
    """

    def __init__(self, d_model, n_heads, block_size, dropout):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Q, K and V come out of one fused projection (a single matmul).
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        # One Dropout module, applied twice in forward().
        self.dropout = nn.Dropout(dropout)

        # 0/1 lower-triangular matrix of shape (1, 1, block_size, block_size),
        # kept as a buffer: it moves with .to(device) but is not trained.
        causal = torch.ones(block_size, block_size).tril()
        self.register_buffer("mask", causal[None, None])

    def forward(self, x):
        # (B, N, D) -> (B, N, D)
        batch, seq_len, width = x.shape
        heads, head_dim = self.n_heads, self.d_head

        fused = self.qkv(x)                                        # (B, N, 3D)
        q, k, v = (
            chunk.view(batch, seq_len, heads, head_dim).transpose(1, 2)  # (B, H, N, Dh)
            for chunk in fused.split(width, dim=2)
        )

        scores = q @ k.transpose(-2, -1)
        scores = scores / math.sqrt(head_dim)                      # (B, H, N, N)
        future = self.mask[:, :, :seq_len, :seq_len] == 0          # True strictly above the diagonal
        scores = scores.masked_fill(future, float("-inf"))

        weights = self.dropout(F.softmax(scores, dim=-1))
        merged = (weights @ v).transpose(1, 2).reshape(batch, seq_len, width)  # concat heads
        return self.dropout(self.proj(merged))


class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d_model->d_ff) -> GELU -> Linear(d_ff->d_model) -> Dropout."""

    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        # Built in the same order as the Sequential indices (net.0 .. net.3).
        expand = nn.Linear(d_model, d_ff)
        activation = nn.GELU()
        contract = nn.Linear(d_ff, d_model)
        self.net = nn.Sequential(expand, activation, contract, nn.Dropout(dropout))

    def forward(self, x):
        # Acts on the last dimension only: (..., d_model) -> (..., d_model).
        out = self.net(x)
        return out


class TransformerBlock(nn.Module):
    """Pre-LayerNorm Transformer block.

    Each sub-layer reads a normalised copy of the residual stream and its
    output is added back:
        x = x + Attn(LN1(x))
        x = x + MLP(LN2(x))
    """

    def __init__(self, d_model, n_heads, block_size, d_ff, dropout):
        super().__init__()
        # Sub-modules are created in the order ln1, attn, ln2, mlp, which fixes
        # parameter registration order and the sequence of random initialisations.
        self.ln1 = nn.LayerNorm(normalized_shape=d_model)
        self.attn = CausalSelfAttention(
            d_model=d_model, n_heads=n_heads, block_size=block_size, dropout=dropout
        )
        self.ln2 = nn.LayerNorm(normalized_shape=d_model)
        self.mlp = FeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)

    def forward(self, x):
        attn_out = self.attn(self.ln1(x))
        hidden = x + attn_out
        mlp_out = self.mlp(self.ln2(hidden))
        return hidden + mlp_out


class TinyGPT(nn.Module):
    """Decoder-only, GPT-style language model.

    Token + learned position embeddings -> ``cfg.n_layers`` Transformer
    blocks -> final LayerNorm -> linear head producing next-token logits.
    """

    def __init__(self, vocab_size, cfg: CFG):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = cfg.block_size  # longest context forward() accepts

        self.tok_emb = nn.Embedding(vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)

        self.blocks = nn.ModuleList([
            TransformerBlock(cfg.d_model, cfg.n_heads, cfg.block_size, cfg.d_ff, cfg.dropout)
            for _ in range(cfg.n_layers)
        ])
        self.ln_f = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, vocab_size, bias=False)

        # Re-initialise every Linear/Embedding submodule (see _init_weights).
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """``Module.apply`` hook: N(0, 0.02) weights for Linear/Embedding, zero Linear biases."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        """Return next-token logits of shape (B, N, vocab_size) for ids ``idx`` of shape (B, N).

        Raises:
            ValueError: if N exceeds ``block_size`` (no position embedding exists for it).
        """
        B, N = idx.shape
        if N > self.block_size:
            raise ValueError(f"Sequence length {N} exceeds block_size {self.block_size}")

        pos = torch.arange(0, N, device=idx.device).unsqueeze(0)  # (1, N)
        x = self.tok_emb(idx) + self.pos_emb(pos)                 # (B, N, D)
        x = self.drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.ln_f(x)
        logits = self.head(x)  # (B, N, vocab)
        return logits

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=200, temperature=1.0):
        """Autoregressively append ``max_new_tokens`` tokens to ``idx`` of shape (B, N).

        Sampling always runs in eval mode so dropout does not perturb the
        distribution. The module's previous train/eval mode is restored
        afterwards. Only the last ``block_size`` tokens are fed as context.

        Args:
            idx: (B, N) tensor of token ids to continue.
            max_new_tokens: number of tokens to append.
            temperature: softmax temperature; ``0`` selects greedy (argmax) decoding.

        Returns:
            (B, N + max_new_tokens) tensor of token ids.

        Raises:
            ValueError: if ``temperature`` is negative.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")

        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -self.block_size:]
                logits = self(idx_cond)[:, -1, :]  # (B, vocab) for the last position
                if temperature == 0:
                    next_id = logits.argmax(dim=-1, keepdim=True)  # (B, 1)
                else:
                    probs = F.softmax(logits / temperature, dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1)  # (B, 1)
                idx = torch.cat([idx, next_id], dim=1)
        finally:
            self.train(was_training)
        return idx


# =========================
# 5) Train
# =========================
def _sample_text(model, vocab, max_new_tokens, temperature, device):
    """Generate and decode ``max_new_tokens`` characters from start token id 0.

    The model is put in eval mode for sampling (dropout off), and its previous
    train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        context = torch.zeros((1, 1), dtype=torch.long, device=device)  # arbitrary start token (id 0)
        ids = model.generate(context, max_new_tokens=max_new_tokens, temperature=temperature)[0].tolist()
    finally:
        model.train(was_training)
    return vocab.decode(ids)


def main():
    """Train TinyGPT on the dataset described by the module-level ``cfg``.

    Prints train/val loss and a short sample at iteration 1 and every
    ``cfg.eval_every`` iterations, then prints a longer final sample.
    """
    torch.manual_seed(1337)

    # data
    ensure_data(cfg)
    text = load_text(cfg.data_path)
    vocab = CharVocab(text)
    data = torch.tensor(vocab.encode(text), dtype=torch.long)

    train_ids, val_ids = make_splits(data)

    print(f"Dataset chars: {len(text):,}")
    print(f"Vocab size: {vocab.vocab_size}")
    print(f"Train tokens: {len(train_ids):,} | Val tokens: {len(val_ids):,}")
    print(f"Device: {cfg.device}")

    # model
    model = TinyGPT(vocab.vocab_size, cfg).to(cfg.device)
    print(f"Params: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    t0 = time.perf_counter()  # monotonic clock for elapsed-time reporting
    for it in range(1, cfg.max_iters + 1):
        x, y = get_batch(train_ids, cfg)
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        # Clip the global gradient norm to cfg.grad_clip before stepping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        optimizer.step()

        if it % cfg.eval_every == 0 or it == 1:
            losses = estimate_loss(model, train_ids, val_ids, cfg, iters=30)
            dt = time.perf_counter() - t0
            print(f"iter {it:5d} | train loss {losses['train']:.3f} | val loss {losses['val']:.3f} | time {dt:.1f}s")

            # sample generation (dropout disabled, training mode restored afterwards)
            print("---- sample ----")
            print(_sample_text(model, vocab, 300, 1.0, cfg.device))
            print("--------------\n")

    # final sample
    print(_sample_text(model, vocab, 800, 0.9, cfg.device))


if __name__ == "__main__":
    main()
