# TinyStories speed-run

Minimal notebook: download TinyStories → train BPE → encode → train LM (fixed steps or full run) → plot loss → generate. Target: efficient run on A100/H100, validation loss ≤ 1.45. See `docs/SPEEDRUN_BENCHMARK.md`.

**Colab-friendly:** Open this notebook in Google Colab ("Open in Colab" from GitHub). If you opened a blank Colab, set `COLAB_CLONE_URL` in the Setup cell to your repo and run it to clone + install. Enable a GPU: **Runtime → Change runtime type → T4 GPU** (or A100 if available).

## 1. Setup

In [None]:
import sys, os
from pathlib import Path

# True when the google.colab module is already loaded in this interpreter.
IN_COLAB = "google.colab" in sys.modules
# Relative path whose presence identifies a directory as the project root.
marker = Path("src", "tokenizer.py")


def find_root():
    """Locate the project root by walking up from the current working directory.

    The resolved working directory and its ancestors are inspected, nearest
    first, for at most 10 levels. A level matches when it contains ``marker``
    directly, or when its ``project-2-lm-jax`` subdirectory does.

    Returns:
        The matching directory as a ``Path``, or ``None`` if no level matched.
    """
    start = Path.cwd().resolve()
    # [start, *start.parents] ends at the filesystem root; cap the walk at 10 levels.
    for candidate in [start, *start.parents][:10]:
        if (candidate / marker).exists():
            return candidate
        nested = candidate / "project-2-lm-jax"
        if (nested / marker).exists():
            return nested
    return None

# Resolve ROOT (the project directory) and put its src/ folder on sys.path.
#   * Colab: reuse a checkout reachable from the working directory, otherwise
#     clone the repo under /content; then install requirements and chdir there.
#   * Elsewhere: prefer the LM_PROJECT_ROOT environment variable, then fall
#     back to searching upward from the working directory.
if IN_COLAB:
    ROOT = find_root()
    if ROOT is None:
        # Clone repo: set this to your repo URL if you opened a blank Colab
        COLAB_CLONE_URL = "https://github.com/ns-1456/JAX-XLA-LM.git"
        clone_dir = Path("/content") / "JAX-XLA-LM"
        # Accept the same two layouts find_root() accepts: the project at the
        # top of the checkout, or inside a "project-2-lm-jax" subdirectory.
        candidates = (clone_dir, clone_dir / "project-2-lm-jax")
        if not any((c / marker).exists() for c in candidates):
            get_ipython().system(f"git clone --depth 1 {COLAB_CLONE_URL} {clone_dir}")
        ROOT = next((c for c in candidates if (c / marker).exists()), None)
        if ROOT is None:
            # Fail here with a clear message instead of letting a bad clone
            # surface later as a pip or import error.
            raise FileNotFoundError(
                f"Could not find {marker} under {clone_dir} after cloning "
                f"{COLAB_CLONE_URL}. Check COLAB_CLONE_URL and network access."
            )
    get_ipython().run_line_magic("pip", "install -q -r " + str(ROOT / "requirements.txt"))
    # Ensure we're in repo root for relative paths
    os.chdir(ROOT)
else:
    env_root = os.environ.get("LM_PROJECT_ROOT")
    ROOT = Path(env_root).resolve() if env_root else None
    if not ROOT or not (ROOT / marker).exists():
        ROOT = find_root()
    if not ROOT or not (ROOT / marker).exists():
        raise FileNotFoundError("Set LM_PROJECT_ROOT or run from project-2-lm-jax.")

# Make the project's modules (tokenizer, model, ...) importable by bare name.
sys.path.insert(0, str(ROOT / "src"))

import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
from tokenizer import train_bpe, Tokenizer
from model import TransformerConfig, TransformerLM
from optimizer import init_adamw, adamw_step
from train import get_batch, loss_fn, save_checkpoint
from generate import generate

# Output locations for this notebook (tokenizer, token ids, checkpoints, raw data).
OUT_DIR = ROOT / "notebooks" / "demo_outputs"
DATA_DIR = OUT_DIR / "data"
OUT_DIR.mkdir(parents=True, exist_ok=True)
DATA_DIR.mkdir(parents=True, exist_ok=True)
# Special token passed to the tokenizer as a special_tokens entry.
EOT = "<|endoftext|>"

print("ROOT:", ROOT)
print("Colab:", IN_COLAB)
devices = jax.devices()
print("JAX devices:", devices)
# jax.devices() falls back to the CPU device(s) when no accelerator backend is
# available, so an emptiness check never fires; inspect the platform instead.
if IN_COLAB and all(d.platform == "cpu" for d in devices):
    print("No GPU detected. Runtime → Change runtime type → T4 GPU (or A100).")

## 2. Download TinyStories

In [None]:
import shutil
from huggingface_hub import hf_hub_download

# TinyStories (public); on Colab this caches under ~/.cache/huggingface
# Fetch both splits through the Hugging Face cache and place a copy in DATA_DIR.
for split in ("train", "valid"):
    fn = f"TinyStoriesV2-GPT4-{split}.txt"
    cached = hf_hub_download(repo_id="roneneldan/TinyStories", filename=fn, repo_type="dataset")
    dest = DATA_DIR / fn
    # Re-running the cell should not re-copy the corpus: skip the copy when the
    # destination already exists with the same size as the cached file.
    if not dest.exists() or dest.stat().st_size != Path(cached).stat().st_size:
        shutil.copy(cached, dest)
path_ts = DATA_DIR / "TinyStoriesV2-GPT4-train.txt"
print("TinyStories train:", path_ts, "| size:", path_ts.stat().st_size // 1_000_000, "MB")

## 3. Train BPE and encode corpus

In [None]:
# Train a BPE tokenizer on the training split (train_bpe is given max_bytes=5_000_000)
# and persist it next to the other notebook outputs.
vocab, merges = train_bpe(str(path_ts), vocab_size=10000, special_tokens=[EOT], max_bytes=5_000_000, use_fast=True)
tokenizer = Tokenizer(vocab=vocab, merges=merges, special_tokens=[EOT])
tokenizer.save(str(OUT_DIR / "tokenizer_tinystories.json"))
vocab_size = len(vocab)

# Read with an explicit encoding so decoding does not depend on the platform's
# locale default.
with open(path_ts, encoding="utf-8") as f:
    full_text = f.read()
# Stored on disk as uint16.
# NOTE(review): assumes every token id fits in 16 bits — confirm if vocab_size is raised.
ids = np.array(tokenizer.encode(full_text), dtype=np.uint16)
np.save(OUT_DIR / "tinystories_ids.npy", ids)
# Widen the in-memory copy to int32 for training; no need to reload the file just written.
data = ids.astype(np.int32)
print(f"Tokenizer: vocab_size={vocab_size}. Data: {len(data):,} tokens.")

## 4. Train LM (speed-run or full)

Set `NUM_STEPS` to a fixed step count for a speed run, or to a large value (e.g. 100_000+) for full pretraining. **Colab:** on the free-tier T4, use `BATCH_SIZE = 32` to avoid OOM; on an A100 (Colab Pro), 64–128 is fine.

In [None]:
# --- Run hyperparameters ---
NUM_STEPS = 30_000   # speed-run; use 100_000+ for full pretraining
BATCH_SIZE = 32      # 32 for Colab T4; 64–128 for A100/H100
SEQ_LEN = 256
LR = 3e-4
WEIGHT_DECAY = 0.01
LOG_EVERY = 200
SAVE_EVERY = 5000
SEED = 42

# --- Model, parameters and optimizer state ---
cfg = TransformerConfig(
    vocab_size=vocab_size,
    d_model=256,
    num_layers=6,
    num_heads=8,
    d_ff=1024,
    max_seq_len=SEQ_LEN,
    dropout_rate=0.1,
)
model = TransformerLM(cfg)
rng = jax.random.PRNGKey(SEED)
# Parameters are initialised from a dummy (1, SEQ_LEN) int32 batch.
dummy_tokens = jnp.ones((1, SEQ_LEN), dtype=jnp.int32)
params = model.init(rng, dummy_tokens, deterministic=True)
opt_state = init_adamw(params)
# NumPy generator handed to get_batch by the training loop.
np_rng = np.random.default_rng(SEED)


def _batch_loss(p, b):
    """Return loss_fn evaluated for parameters ``p`` on batch ``b`` with the notebook's model."""
    return loss_fn(p, model, b)


# Jitted gradient and loss evaluators, both taking (params, batch).
grad_fn = jax.jit(jax.grad(_batch_loss))
jit_loss = jax.jit(_batch_loss)

nparams = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"Model: {cfg.num_layers}L-{cfg.d_model}D, {nparams:,} params. Steps: {NUM_STEPS}, batch={BATCH_SIZE}, seq_len={SEQ_LEN}")

In [None]:
# Training loop: NUM_STEPS optimizer updates with periodic logging and checkpoints.
losses = []  # (step, loss) pairs, one per LOG_EVERY steps; read by the plotting cell
t0 = time.time()
for step in tqdm(range(1, NUM_STEPS + 1), desc="Training"):
    batch = jnp.array(get_batch(data, BATCH_SIZE, SEQ_LEN, np_rng))
    grads = grad_fn(params, batch)
    params, opt_state = adamw_step(params, grads, opt_state, lr=LR, weight_decay=WEIGHT_DECAY)
    if step % LOG_EVERY == 0:
        # NOTE(review): the loss is evaluated after the update, on the batch that
        # was just trained on — confirm this is the intended metric.
        # float() waits for the device result, so the throughput below covers
        # finished work.
        loss_val = float(jit_loss(params, batch))
        losses.append((step, loss_val))
        elapsed = time.time() - t0
        tok_per_sec = (step * BATCH_SIZE * SEQ_LEN) / elapsed
        tqdm.write(f"step {step} | loss {loss_val:.4f} | ppl {np.exp(loss_val):.2f} | {tok_per_sec:.0f} tok/s")
    if step % SAVE_EVERY == 0:
        save_checkpoint(str(OUT_DIR / f"ckpt_step_{step}.pkl"), params, opt_state, step, cfg)

# JAX dispatches device work asynchronously: wait for the last update to finish
# before reading the clock, otherwise the reported duration can be too short.
jax.block_until_ready(params)
elapsed = time.time() - t0
save_checkpoint(str(OUT_DIR / "ckpt_final.pkl"), params, opt_state, NUM_STEPS, cfg)
print(f"Done in {elapsed/60:.1f} min. Final checkpoint: ckpt_final.pkl")

## 5. Loss curve

In [None]:
# Plot the logged loss against the training step; nothing is drawn when no
# loss was recorded (e.g. NUM_STEPS < LOG_EVERY).
if losses:
    steps, vals = zip(*losses)
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(steps, vals, "b-", label="Loss")
    ax.set_xlabel("Step")
    ax.set_ylabel("Loss")
    ax.set_title("TinyStories speed-run")
    ax.legend()
    fig.tight_layout()
    plt.show()
    print(f"Final logged loss: {vals[-1]:.4f}")

## 6. Generate

In [None]:
# Generate a continuation of a fixed prompt with the trained parameters, then
# report the token count of the returned text after re-encoding it.
prompt = "Once upon a time"
out = generate(
    model,
    params,
    tokenizer,
    prompt,
    max_tokens=256,
    temperature=0.8,
    top_k=40,
    seed=99,
)
print(out)
n_out_tokens = len(tokenizer.encode(out))
print(f"\n~{n_out_tokens} tokens")