# CS336 Assignment 1 (basics): Building a Transformer LM

This notebook follows **Stanford CS336 Spring 2025 Assignment 1** (cs336_spring2025_assignment1_basics.pdf). It uses this repo's JAX implementation (BPE tokenizer, Transformer LM, AdamW, training, generation).

**What you will implement (per PDF):** §2 BPE tokenizer, §3 Transformer LM, §4 cross-entropy & AdamW, §5 training loop & checkpointing, §6 decoding.

**End-to-end (no manual steps):** This notebook downloads TinyStories only, trains BPE on it, runs tokenizer experiments, trains the Transformer LM, generates ≥256 tokens, and runs LR + batch-size experiments. Run all cells in order.

## Setup (Colab: clone + install; Local: set cwd or LM_PROJECT_ROOT)

In [None]:
import sys, os, tempfile
from pathlib import Path

IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    repo = Path("/content/JAX-Transformer")
    if not (repo / "src" / "tokenizer.py").exists():
        get_ipython().system("git clone --depth 1 https://github.com/ns-1456/JAX-Transformer.git /content/JAX-Transformer")
    ROOT = repo
    get_ipython().run_line_magic("pip", "install -q -r " + str(ROOT / "requirements.txt"))
else:
    marker = Path("src") / "tokenizer.py"
    ROOT = Path(os.environ.get("LM_PROJECT_ROOT", "")).resolve() if os.environ.get("LM_PROJECT_ROOT") else None
    if not ROOT or not (ROOT / marker).exists():
        cwd = Path.cwd().resolve()
        for _ in range(8):
            if (cwd / marker).exists(): ROOT = cwd; break
            if (cwd / "project-2-lm-jax" / marker).exists(): ROOT = cwd / "project-2-lm-jax"; break
            if cwd == cwd.parent: break
            cwd = cwd.parent
    if not ROOT or not (ROOT / marker).exists():
        raise FileNotFoundError("Set LM_PROJECT_ROOT or run from project-2-lm-jax.")
sys.path.insert(0, str(ROOT / "src"))

import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from tqdm import tqdm
from tokenizer import train_bpe, Tokenizer
from model import TransformerConfig, TransformerLM
from optimizer import init_adamw, adamw_step
from train import get_batch, loss_fn, save_checkpoint
from generate import generate

DEMO_DIR = ROOT / "notebooks" / "demo_outputs"
DATA_DIR = DEMO_DIR / "data"
DEMO_DIR.mkdir(parents=True, exist_ok=True)
DATA_DIR.mkdir(parents=True, exist_ok=True)
print("ROOT:", ROOT)
print("JAX devices:", jax.devices())

## Download datasets (TinyStories only)

Uses the same TinyStories data as [stanford-cs336/assignment1-basics](https://github.com/stanford-cs336/assignment1-basics): TinyStoriesV2-GPT4 train + valid. No manual wget or file placement.

In [None]:
import shutil
from huggingface_hub import hf_hub_download

EOT = "<|endoftext|>"

# TinyStories (same as assignment: roneneldan/TinyStories, TinyStoriesV2-GPT4-train/valid.txt)
print("Downloading TinyStoriesV2-GPT4 (train + valid)...")
for split in ("train", "valid"):
    fn = f"TinyStoriesV2-GPT4-{split}.txt"
    cached = hf_hub_download(repo_id="roneneldan/TinyStories", filename=fn, repo_type="dataset")
    dest = DATA_DIR / fn
    shutil.copy(cached, dest)
    print(f"  {fn} -> {dest}")
path_ts = DATA_DIR / "TinyStoriesV2-GPT4-train.txt"

print("Done. Using path_ts for BPE and training below.")

---
## §2 Byte-Pair Encoding (BPE) Tokenizer

### §2.1 The Unicode Standard

**Problem (unicode1):** (a) What character does `chr(0)` return? (b) How does `__repr__` differ from printed representation? (c) What happens when it occurs in text?

**Deliverable:** One-sentence each in writeup. Optional exploration below.

In [None]:
# Optional: explore chr(0) for writeup
c = chr(0)
print("repr:", repr(c)); print("print:", repr(print(c)))
s = "this is a test" + chr(0) + "string"
print("repr(s):", repr(s)); print("print(s):"); print(s)

### §2.2 Unicode Encodings

**Problem (unicode2):** (a) Why prefer UTF-8 over UTF-16/32? (b) Why is `decode_utf8_bytes_to_str_wrong` incorrect? Example? (c) Two-byte sequence that does not decode.

**Deliverable:** Short answers in writeup.

### §2.4 BPE Tokenizer Training

**Problem (train_bpe):** Train byte-level BPE: `input_path`, `vocab_size`, `special_tokens` → `(vocab, merges)`. GPT-2 regex pre-tokenization; lexicographic tie-break; strip special tokens before pre-tokenization.

**Implementation:** `src/tokenizer.py` → `train_bpe()`. Merges here are `list[tuple[int,int]]` (IDs); PDF uses `tuple[bytes,bytes]`—convert via vocab if needed.

In [None]:
SAMPLE = (
    "The little dog ran in the park. The cat sat on the mat. "
    "Once upon a time there was a girl. She had a big red ball. "
) * 80
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
    f.write(SAMPLE)
    _path = f.name
try:
    vocab, merges = train_bpe(_path, vocab_size=1024, special_tokens=["<|endoftext|>"])
finally:
    os.unlink(_path)
tokenizer = Tokenizer(vocab=vocab, merges=merges, special_tokens=["<|endoftext|>"])
tokenizer.save(str(DEMO_DIR / "tokenizer.json"))
vocab_size = len(vocab)
print("Vocab size:", vocab_size)
print("Encode 'The dog ran':", tokenizer.encode("The dog ran.")[:8])

### §2.5 BPE on TinyStories

**Problem (train_bpe_tinystories):** Train BPE on TinyStories, vocab_size=10000, special token `<|endoftext|>`. Report: time, memory, longest token.

**Deliverable:** Run with dataset path; report in writeup. Example: `train_bpe(path, 10000, ["<|endoftext|>"])` then inspect vocab/merges.

In [None]:
# §2.5: Train BPE on downloaded TinyStories only.
import time
import tracemalloc

path_ts = DATA_DIR / "TinyStoriesV2-GPT4-train.txt"
if not path_ts.exists():
    raise FileNotFoundError("Run the 'Download datasets' cell first. It writes to demo_outputs/data/.")
EOT = "<|endoftext|>"  # in case this cell is run without the download cell

def run_bpe_section(name, path, vocab_size, special_tokens, out_json):
    """Train BPE, report time/memory/longest token, save tokenizer."""
    path = Path(path)
    print(f"--- {name} (vocab_size={vocab_size}) ---")
    tracemalloc.start()
    t0 = time.perf_counter()
    # use_fast=True uses Hugging Face tokenizers (Rust); max_bytes=5M for ~1-2 s demo.
    vocab, merges = train_bpe(str(path), vocab_size=vocab_size, special_tokens=special_tokens, max_bytes=5_000_000, use_fast=True)
    elapsed = time.perf_counter() - t0
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    tok = Tokenizer(vocab=vocab, merges=merges, special_tokens=special_tokens)
    tok.save(str(out_json))
    longest = max((len(v) for v in vocab.values()), default=0)
    longest_tok = next((v for v in vocab.values() if len(v) == longest), b"")
    print(f"  Time: {elapsed:.1f}s  |  Peak RAM: {peak / 1024**3:.2f} GB  |  Longest token: {len(longest_tok)} bytes (e.g. {longest_tok[:50]!r})")
    return tok

tok_ts = run_bpe_section("TinyStories", path_ts, 10000, [EOT], DEMO_DIR / "tokenizer_tinystories.json")
tokenizer = tok_ts
vocab_size = len(tok_ts.vocab)
print("Using TinyStories tokenizer for LM.")

### §2.6 BPE Tokenizer: Encoding and Decoding

**Problem (tokenizer):** Implement Tokenizer: `encode`, `decode`, `encode_iterable` (for large files), `save`/`load`. Decode uses `errors='replace'` for malformed bytes.

**Implementation:** `src/tokenizer.py` → `Tokenizer` class.

In [None]:
text = "The cat sat on the mat."
ids = tokenizer.encode(text)
decoded = tokenizer.decode(ids)
print("ids:", ids)
print("decode:", decoded)
print("roundtrip ok:", decoded == text)

### §2.7 Tokenizer Experiments

**Problem (tokenizer_experiments):** (a) Sample 10 docs; compression ratio (bytes/token) for TinyStories tokenizer. (b) Throughput (bytes/s); time for 825GB. (c) Encode train to uint16; why uint16?

**Deliverable:** Report in writeup. Below: compression ratio and throughput on TinyStories.

In [None]:
# §2.7: Tokenizer experiments on TinyStories (10 docs, compression, throughput, 825GB, uint16).
with open(path_ts) as f:
    full_ts = f.read()
docs_ts = full_ts.split(EOT)[:10]

def compression_ratio(tok, docs):
    bytes_ = sum(len(d.encode("utf-8")) for d in docs)
    tokens = sum(len(tok.encode(d)) for d in docs)
    return bytes_ / max(1, tokens)

cr_ts = compression_ratio(tok_ts, docs_ts)
print(f"(a) Compression ratio (bytes/token) on 10 TinyStories docs: {cr_ts:.2f}")

sample = full_ts[:100000]
t0 = time.perf_counter()
for _ in range(5):
    tok_ts.encode(sample)
throughput = (5 * len(sample.encode("utf-8"))) / (time.perf_counter() - t0)
print(f"(b) Throughput: ~{throughput:.0f} bytes/s  ->  time for 825GB: ~{825e9/throughput/3600:.1f} hours")
print("(c) uint16: vocab <=65536 so token IDs fit in uint16; saves memory.")

ids_ts = np.array(tok_ts.encode(full_ts), dtype=np.uint16)
np.save(DEMO_DIR / "tinystories_ids.npy", ids_ts)
print(f"Encoded TinyStories to uint16: {ids_ts.nbytes/1e6:.1f} MB saved to tinystories_ids.npy")

---
## §3 Transformer Language Model

**Problems:** linear, embedding, RMSNorm, positionwise_feedforward (SwiGLU), RoPE, softmax, scaled_dot_product_attention, multihead_self_attention, transformer_block, transformer_lm.

**Implementation:** `src/model.py` — pre-norm Transformer block, RoPE on Q/K, SwiGLU FFN, causal masking, tied LM head.

**Problem (transformer_accounting):** FLOPs, params, memory for GPT-2 XL/small/medium/large; which parts dominate. **Deliverable:** Writeup.

In [None]:
batch_size, seq_len = 2, 64
cfg = TransformerConfig(vocab_size=vocab_size, d_model=128, num_layers=4, num_heads=4, d_ff=512, max_seq_len=seq_len, dropout_rate=0.1)
model = TransformerLM(cfg)
rng = jax.random.PRNGKey(0)
params = model.init(rng, jnp.ones((batch_size, seq_len), dtype=jnp.int32), deterministic=True)
x = jnp.array(tokenizer.encode("The dog ran.")[:seq_len], dtype=jnp.int32)
x = jnp.broadcast_to(x, (batch_size, seq_len))
logits = model.apply(params, x, deterministic=True)
print("logits shape:", logits.shape)
nparams = sum(np.size(v) for _, v in jax.tree_util.tree_leaves(params))
print("Param count:", nparams)

---
## §4 Training a Transformer LM

**Problems:** cross_entropy, learning_rate_tuning (SGD toy), adamw, adamwAccounting, learning_rate_schedule (cosine+warmup), gradient_clipping, data_loading, checkpointing.

**Implementation:** `src/train.py` (cross_entropy_loss, get_batch, save/load_checkpoint), `src/optimizer.py` (AdamW).

**Deliverables (writeup):** learning_rate_tuning (SGD 1e1, 1e2, 1e3), adamwAccounting (peak memory, max batch 80GB, FLOPs/step, MFU, days 400K steps).

In [None]:
data = np.load(DEMO_DIR / "tinystories_ids.npy").astype(np.int32)
batch_size, seq_len, num_steps = 16, 128, 800
cfg = TransformerConfig(vocab_size=vocab_size, d_model=256, num_layers=4, num_heads=8, d_ff=1024, max_seq_len=seq_len, dropout_rate=0.1)
model = TransformerLM(cfg)
rng = jax.random.PRNGKey(42)
params = model.init(rng, jnp.ones((1, seq_len), dtype=jnp.int32), deterministic=True)
opt_state = init_adamw(params)
np_rng = np.random.default_rng(42)
grad_fn = jax.jit(jax.grad(lambda p, b: loss_fn(p, model, b)))
jit_loss = jax.jit(lambda p, b: loss_fn(p, model, b))
lr, wd = 3e-4, 0.01
losses = []
for step in tqdm(range(1, num_steps + 1), desc="Training"):
    batch = jnp.array(get_batch(data, batch_size, seq_len, np_rng))
    grads = grad_fn(params, batch)
    params, opt_state = adamw_step(params, grads, opt_state, lr=lr, weight_decay=wd)
    if step % 50 == 0:
        losses.append((step, float(jit_loss(params, batch))))
save_checkpoint(str(DEMO_DIR / "ckpt.pkl"), params, opt_state, num_steps, cfg)
print("Checkpoint saved.")

In [None]:
steps, vals = zip(*losses)
plt.figure(figsize=(8, 4))
plt.plot(steps, vals, "b-", label="Loss")
plt.plot(steps, np.exp(vals), "g-", alpha=0.7, label="Perplexity")
plt.xlabel("Step"); plt.ylabel("Loss / Perplexity"); plt.legend(); plt.title("Training on TinyStories"); plt.tight_layout(); plt.show()

---
## §5 Training Loop

**Problem (training_together):** Script with configurable hyperparameters, memory-efficient data loading (e.g. memmap), checkpointing, logging.

**Implementation:** `src/train.py`; CLI: `python src/train.py --data ... --tokenizer ... --config small`. Notebook above is a short in-memory version.

---
## §6 Generating Text

**Problem (decoding):** Generate from prompt until `<|endoftext|>` or max_tokens; temperature scaling; top-p (nucleus) sampling.

**Implementation:** `src/generate.py` — temperature, top_k; top_p can be added.

In [None]:
out = generate(model, params, tokenizer, "Once upon a time", max_tokens=280, temperature=0.8, top_k=40, seed=99)
print("Generated (>=256 tokens):")
print(out)
print(f"\nLength: {len(out)} chars, ~{len(tokenizer.encode(out))} tokens.")

---
## §7 Experiments

**experiment_log:** Logging + experiment log document.

**learning_rate:** Sweep LR on TinyStories; curves; val loss ≤1.45 (or ≤2.0 low-resource).

**batch_size_experiment:** Vary batch size; curves + commentary.

**generate:** ≥256 tokens dump + comment on fluency + factors (temperature, top-p, etc.).

**Ablations (writeup + curves):** layer_norm_ablation, pre_norm_ablation, no_pos_emb (NoPE vs RoPE), swiglu_ablation (SwiGLU vs SiLU, matched params).

**main_experiment:** Train on OpenWebText; learning curve; generated text; compare to TinyStories.

**leaderboard (optional):** Submit validation perplexity; see course repo.

**Full checklist:** `docs/CS336_ASSIGNMENT1_CHECKLIST.md` maps every A1 problem to this codebase and writeup.

In [None]:
# §7 Learning rate sweep (run all experiments; no manual steps)
exp_steps = 120
lrs = [1e-4, 5e-4, 2e-3]
lr_curves = {lr: [] for lr in lrs}
for lr in lrs:
    p = model.init(jax.random.PRNGKey(0), jnp.ones((1, seq_len), dtype=jnp.int32), deterministic=True)
    o = init_adamw(p)
    r = np.random.default_rng(int(lr * 1e6))
    for step in range(1, exp_steps + 1):
        batch = jnp.array(get_batch(data, 8, seq_len, r))
        g = grad_fn(p, batch)
        p, o = adamw_step(p, g, o, lr=lr, weight_decay=wd)
        if step % 30 == 0:
            lr_curves[lr].append((step, float(jit_loss(p, batch))))
plt.figure(figsize=(8, 4))
for lr in lrs:
    s, v = zip(*lr_curves[lr])
    plt.plot(s, v, label=f"lr={lr}")
plt.xlabel("Step"); plt.ylabel("Loss"); plt.legend(); plt.title("Learning rate sweep"); plt.tight_layout(); plt.show()

In [None]:
# §7 Batch size experiment
batch_sizes = [4, 16, 64]
bs_curves = {bs: [] for bs in batch_sizes}
for bs in batch_sizes:
    p = model.init(jax.random.PRNGKey(1), jnp.ones((1, seq_len), dtype=jnp.int32), deterministic=True)
    o = init_adamw(p)
    r = np.random.default_rng(42 + bs)
    for step in range(1, exp_steps + 1):
        batch = jnp.array(get_batch(data, bs, seq_len, r))
        g = grad_fn(p, batch)
        p, o = adamw_step(p, g, o, lr=5e-4, weight_decay=wd)
        if step % 30 == 0:
            bs_curves[bs].append((step, float(jit_loss(p, batch))))
plt.figure(figsize=(8, 4))
for bs in batch_sizes:
    s, v = zip(*bs_curves[bs])
    plt.plot(s, v, label=f"batch={bs}")
plt.xlabel("Step"); plt.ylabel("Loss"); plt.legend(); plt.title("Batch size experiment"); plt.tight_layout(); plt.show()
print("Done. All assignment runs completed end-to-end.")