In [2]:
import math, os, time, random, sys
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

try:
    from datasets import load_dataset  # optional
    HAS_DATASETS = True
except Exception:
    HAS_DATASETS = False

# Reproducibility
def set_seed(seed: int = 1337):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Make CUDA deterministic where possible (slight perf hit)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(1337)

device = "cuda" if torch.cuda.is_available() else "cpu"
device


'cuda'

In [3]:
@dataclass
class Config:
    # Data / batching
    dataset_name: str = "tiny_shakespeare"   # options: 'tiny_shakespeare', 'wikitext2', 'ptb'
    block_size: int = 128                    # context length
    batch_size: int = 32
    num_workers: int = 0                     # set >0 for faster dataload (if CPU not overloaded)

    # Model
    d_model: int = 192
    n_layers: int = 4
    n_heads: int = 4
    d_ff: int = 4 * 192
    dropout: float = 0.1

    # Optimization
    max_steps: int = 1000                    # increase for better results
    eval_every: int = 200
    learning_rate: float = 3e-4
    weight_decay: float = 0.1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0
    warmup_steps: int = 100
    compile_model: bool = False              # PyTorch 2.x compile (if available)

    # Misc
    amp: bool = True                         # mixed precision on GPU
    ckpt_path: str = "transformer_lm.pt"

cfg = Config()
cfg


Config(dataset_name='tiny_shakespeare', block_size=128, batch_size=32, num_workers=0, d_model=192, n_layers=4, n_heads=4, d_ff=768, dropout=0.1, max_steps=1000, eval_every=200, learning_rate=0.0003, weight_decay=0.1, beta1=0.9, beta2=0.95, grad_clip=1.0, warmup_steps=100, compile_model=False, amp=True, ckpt_path='transformer_lm.pt')

In [4]:
def load_tiny_shakespeare_text():
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    try:
        import requests
        txt = requests.get(url, timeout=15).text
        if len(txt) < 1000:
            raise RuntimeError("Downloaded text seems too small.")
        return txt
    except Exception:
        # Small fallback so the notebook still runs offline
        return (
            "From fairest creatures we desire increase,\n"
            "That thereby beauty's rose might never die,\n"
            "But as the riper should by time decease,\n"
            "His tender heir might bear his memory:\n"
        )

def load_text_splits(name: str):
    """
    Returns (train_text, val_text, test_text) as plain strings.
    """
    if name == "tiny_shakespeare":
        text = load_tiny_shakespeare_text()
        n = len(text)
        # 90/5/5 split
        train_text = text[: int(0.9*n)]
        val_text   = text[int(0.9*n): int(0.95*n)]
        test_text  = text[int(0.95*n):]
        return train_text, val_text, test_text

    if HAS_DATASETS and name in {"wikitext2", "ptb"}:
        ds_name = "wikitext" if name == "wikitext2" else "ptb_text_only"
        ds_conf = "wikitext-2-raw-v1" if name == "wikitext2" else "penn_treebank"
        ds = load_dataset(ds_name, ds_conf)
        # Join lines into a single string per split
        def merge(split):
            txts = ds[split]["text"]
            return "\n".join(t for t in txts if t is not None)
        return merge("train"), merge("validation"), merge("test")

    print(f"[INFO] Falling back to Tiny Shakespeare (datasets available: {HAS_DATASETS})")
    return load_text_splits("tiny_shakespeare")

class ByteTokenizer:
    """Simple byte-level tokenizer (vocab size 256)."""
    def __init__(self):
        self.vocab_size = 256
    def encode(self, s: str):
        return torch.tensor(list(s.encode("utf-8")), dtype=torch.long)
    def decode(self, ids: torch.Tensor):
        if isinstance(ids, torch.Tensor):
            ids = ids.detach().cpu().tolist()
        return bytes(ids).decode("utf-8", errors="ignore")

tokenizer = ByteTokenizer()

train_text, val_text, test_text = load_text_splits(cfg.dataset_name)

train_ids = tokenizer.encode(train_text)
val_ids   = tokenizer.encode(val_text)
test_ids  = tokenizer.encode(test_text)

vocab_size = tokenizer.vocab_size
vocab_size


256

In [5]:
class LMStreamDataset(torch.utils.data.Dataset):
    def __init__(self, ids: torch.Tensor, block_size: int):
        self.ids = ids
        self.block_size = block_size

    def __len__(self):
        # number of possible starting positions (we sample randomly anyway)
        return max(1, len(self.ids) - self.block_size - 1)

    def __getitem__(self, idx):
        # Allow DataLoader to ask for a specific index; we still do random sampling
        # so epochs are not strictly deterministic across __getitem__ orders.
        i = random.randint(0, len(self.ids) - self.block_size - 2)
        x = self.ids[i: i + self.block_size]
        y = self.ids[i+1: i + self.block_size + 1]
        return x, y

def make_eval_loader(ids: torch.Tensor, block_size: int, batch_size: int, num_batches: int = 50):
    """
    Make a fixed evaluation loader by deterministic slicing (no randomness),
    to keep eval stable across steps.
    """
    chunks = []
    stride = max(1, (len(ids) - (block_size + 1)) // (num_batches * batch_size))
    if stride == 0:
        stride = 1
    start = 0
    for _ in range(num_batches * batch_size):
        if start + block_size + 1 > len(ids):
            start = 0
        x = ids[start: start + block_size]
        y = ids[start+1: start + block_size + 1]
        chunks.append((x, y))
        start += stride

    def collate(batch):
        xs = torch.stack([b[0] for b in batch], dim=0)
        ys = torch.stack([b[1] for b in batch], dim=0)
        return xs, ys

    ds = chunks
    return DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=collate)

train_ds = LMStreamDataset(train_ids, cfg.block_size)
train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers, drop_last=True)

val_loader  = make_eval_loader(val_ids,  cfg.block_size, cfg.batch_size, num_batches=50)
test_loader = make_eval_loader(test_ids, cfg.block_size, cfg.batch_size, num_batches=50)

len(train_ds), len(val_loader), len(test_loader)


(1003725, 50, 50)

In [6]:
class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout, block_size):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)

        # Causal mask cached as buffer
        mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", mask.view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.qkv(x)
        q, k, v = qkv.split(C, dim=2)

        # reshape to (B, n_heads, T, head_dim)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * self.scale
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)

        y = att @ v  # (B, n_heads, T, head_dim)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.proj(y))
        return y

class MLP(nn.Module):
    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, dropout, block_size)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, d_ff, dropout)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(nn.Module):
    def __init__(self, vocab_size, block_size, d_model, n_layers, n_heads, d_ff, dropout):
        super().__init__()
        self.block_size = block_size
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb   = nn.Parameter(torch.zeros(1, block_size, d_model))
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            Block(d_model, n_heads, d_ff, dropout, block_size)
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.block_size, "Sequence length > block size"
        tok = self.token_emb(idx)                  # (B,T,C)
        pos = self.pos_emb[:, :T, :]               # (1,T,C)
        x = self.drop(tok + pos)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.head(x)                      # (B,T,V)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )
        return logits, loss

model = GPT(
    vocab_size=vocab_size,
    block_size=cfg.block_size,
    d_model=cfg.d_model,
    n_layers=cfg.n_layers,
    n_heads=cfg.n_heads,
    d_ff=cfg.d_ff,
    dropout=cfg.dropout
).to(device)

if cfg.compile_model and hasattr(torch, "compile"):
    model = torch.compile(model)

num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params/1e6:.2f}M")


Model parameters: 1.90M


In [7]:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=cfg.learning_rate,
    betas=(cfg.beta1, cfg.beta2),
    weight_decay=cfg.weight_decay,
)

def get_lr(step):
    # simple linear warmup, then cosine decay to 10% of base LR
    if step < cfg.warmup_steps:
        return cfg.learning_rate * (step + 1) / cfg.warmup_steps
    # cosine
    progress = (step - cfg.warmup_steps) / max(1, (cfg.max_steps - cfg.warmup_steps))
    min_lr = cfg.learning_rate * 0.1
    return min_lr + 0.5 * (cfg.learning_rate - min_lr) * (1 + math.cos(math.pi * progress))

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    tot_loss, tot_tokens, tot_correct = 0.0, 0, 0
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        logits, loss = model(x, targets=y)
        # Accumulate
        token_count = y.numel()
        tot_loss += loss.item() * token_count
        tot_tokens += token_count
        preds = logits.argmax(dim=-1)
        tot_correct += (preds == y).sum().item()
    mean_loss = tot_loss / tot_tokens
    ppl = math.exp(mean_loss)
    acc = tot_correct / tot_tokens
    return {"loss": mean_loss, "perplexity": ppl, "accuracy": acc}


In [8]:
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda" and cfg.amp))

global_step = 0
t0 = time.time()

model.train()
for step, (x, y) in enumerate(iter(train_loader)):
    if global_step >= cfg.max_steps:
        break

    x = x.to(device)
    y = y.to(device)

    # LR schedule
    for param_group in optimizer.param_groups:
        param_group["lr"] = get_lr(global_step)

    # Forward/backward
    with torch.cuda.amp.autocast(enabled=(device == "cuda" and cfg.amp)):
        logits, loss = model(x, targets=y)

    scaler.scale(loss).backward()
    # Gradient clipping
    if cfg.grad_clip is not None:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)

    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
    global_step += 1

    if global_step % 50 == 0 or global_step == 1:
        ppl = math.exp(loss.item())
        print(f"step {global_step:5d} | train loss {loss.item():.4f} | train ppl {ppl:.2f} | lr {optimizer.param_groups[0]['lr']:.2e}")

    if global_step % cfg.eval_every == 0 or global_step == cfg.max_steps:
        val_metrics = evaluate(model, val_loader)
        print(f"[Eval @ step {global_step}] val loss {val_metrics['loss']:.4f} | "
              f"val ppl {val_metrics['perplexity']:.2f} | val acc {val_metrics['accuracy']:.4f}")

print(f"Training finished in {time.time() - t0:.1f}s")

# Save checkpoint
torch.save({
    "model_state": model.state_dict(),
    "cfg": cfg.__dict__,
    "vocab_size": vocab_size,
}, cfg.ckpt_path)
print(f"Saved checkpoint to {cfg.ckpt_path}")


  scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda" and cfg.amp))
  with torch.cuda.amp.autocast(enabled=(device == "cuda" and cfg.amp)):


step     1 | train loss 5.5796 | train ppl 264.98 | lr 3.00e-06
step    50 | train loss 3.9733 | train ppl 53.16 | lr 1.50e-04
step   100 | train loss 2.7831 | train ppl 16.17 | lr 3.00e-04
step   150 | train loss 2.5812 | train ppl 13.21 | lr 2.98e-04
step   200 | train loss 2.5008 | train ppl 12.19 | lr 2.92e-04
[Eval @ step 200] val loss 2.4667 | val ppl 11.78 | val acc 0.2864
step   250 | train loss 2.3981 | train ppl 11.00 | lr 2.82e-04
step   300 | train loss 2.3737 | train ppl 10.74 | lr 2.69e-04
step   350 | train loss 2.2819 | train ppl 9.80 | lr 2.52e-04
step   400 | train loss 2.2425 | train ppl 9.42 | lr 2.33e-04
[Eval @ step 400] val loss 2.2805 | val ppl 9.78 | val acc 0.3296
step   450 | train loss 2.2318 | train ppl 9.32 | lr 2.12e-04
step   500 | train loss 2.1678 | train ppl 8.74 | lr 1.89e-04
step   550 | train loss 2.1373 | train ppl 8.48 | lr 1.65e-04
step   600 | train loss 2.1047 | train ppl 8.20 | lr 1.42e-04
[Eval @ step 600] val loss 2.1283 | val ppl 8.40 | va

In [9]:
test_metrics = evaluate(model, test_loader)
print(f"TEST: loss {test_metrics['loss']:.4f} | ppl {test_metrics['perplexity']:.2f} | acc {test_metrics['accuracy']:.4f}")


TEST: loss 2.0341 | ppl 7.65 | acc 0.3988


In [10]:
@torch.no_grad()
def generate(model, prompt, max_new_tokens=200, temperature=1.0, top_k=0):
    model.eval()
    ids = tokenizer.encode(prompt).unsqueeze(0).to(device)
    for _ in range(max_new_tokens):
        idx_cond = ids[:, -cfg.block_size:]
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / max(1e-8, temperature)
        if top_k > 0:
            v, ix = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("inf")
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        ids = torch.cat([ids, next_id], dim=1)
    return tokenizer.decode(ids[0])

sample = generate(model, prompt="ROMEO:\n", max_new_tokens=200, temperature=0.8, top_k=50)
print(sample)


ROMEO:
MIget say, and dears the wi'll if lose,
O whe conie she mart and the would some: a ill phoursm wasch count
to hencer ladop.

DUCENGLIULE:
Golt Goodst I I youll Wherere of he, ame,
Whonk the to vill ha
