# Transformer LM Demo: Train & Generate

Short training demo and text generation with the JAX/XLA Transformer LM (BPE, RoPE, RMSNorm, SwiGLU, KV cache, MLA/DSA).

**Google Colab:** Run the first cell to clone the repo and install dependencies, then run all.  
**Local:** Run from `project-2-lm-jax/` (or set kernel cwd there).

In [13]:
import sys
import os
import pickle
import tempfile
import numpy as np
from pathlib import Path

# --- Colab: clone repo and install deps; local: use cwd ---
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    repo_dir = Path("/content/JAX-XLA-LM")
    if not (repo_dir / "src" / "tokenizer.py").exists():
        get_ipython().system("git clone --depth 1 https://github.com/ns-1456/JAX-XLA-LM.git /content/JAX-XLA-LM")
    ROOT = repo_dir
    get_ipython().run_line_magic("pip", "install -q -r " + str(ROOT / "requirements.txt"))
else:
    ROOT = Path.cwd()
    if not (ROOT / "src" / "model.py").exists() and (ROOT.parent / "src" / "model.py").exists():
        ROOT = ROOT.parent

sys.path.insert(0, str(ROOT / "src"))

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from tqdm import tqdm

from tokenizer import train_bpe, Tokenizer
from model import TransformerConfig, TransformerLM, TINY_CONFIG
from optimizer import init_adamw, adamw_step
from train import cross_entropy_loss, get_batch, loss_fn, save_checkpoint
from generate import generate, init_mla_params

DEMO_DIR = ROOT / "notebooks" / "demo_outputs"
DEMO_DIR.mkdir(parents=True, exist_ok=True)
print("ROOT:", ROOT)
print("JAX devices:", jax.devices())
print("Colab:", IN_COLAB)

ModuleNotFoundError: No module named 'tokenizer'

## 1. Train BPE tokenizer

In [None]:
SAMPLE = (
    "The little dog ran in the park. The cat sat on the mat. "
    "Once upon a time there was a girl. She had a big red ball. "
    "The sun was bright. The bird flew high in the sky. "
) * 100

with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
    f.write(SAMPLE)
    temp_path = f.name
try:
    vocab, merges = train_bpe(temp_path, vocab_size=1024, special_tokens=["<|endoftext|>"])
finally:
    os.unlink(temp_path)

tokenizer = Tokenizer(vocab=vocab, merges=merges, special_tokens=["<|endoftext|>"])
tokenizer.save(str(DEMO_DIR / "tokenizer.json"))
vocab_size = len(vocab)
print("Vocab size:", vocab_size)
print("Encode 'The dog ran':", tokenizer.encode("The dog ran.")[:8])

## 2. Prepare data and train model (short run)

In [None]:
data = np.array(tokenizer.encode(SAMPLE), dtype=np.int32)
batch_size, seq_len, num_steps = 8, 64, 200
cfg = TransformerConfig(
    vocab_size=vocab_size,
    d_model=128,
    num_layers=4,
    num_heads=4,
    d_ff=512,
    max_seq_len=seq_len,
    dropout_rate=0.1,
)

model = TransformerLM(cfg)
rng = jax.random.PRNGKey(42)
np_rng = np.random.default_rng(42)
params = model.init(rng, jnp.ones((1, seq_len), dtype=jnp.int32), deterministic=True)
opt_state = init_adamw(params)

grad_fn = jax.jit(jax.grad(lambda p, b: loss_fn(p, model, b)))
jit_loss = jax.jit(lambda p, b: loss_fn(p, model, b))
lr, wd = 3e-4, 0.01

losses = []
for step in tqdm(range(1, num_steps + 1), desc="Training"):
    batch = jnp.array(get_batch(data, batch_size, seq_len, np_rng))
    grads = grad_fn(params, batch)
    params, opt_state = adamw_step(params, grads, opt_state, lr=lr, weight_decay=wd)
    if step % 10 == 0:
        loss_val = float(jit_loss(params, batch))
        losses.append((step, loss_val))

save_checkpoint(str(DEMO_DIR / "ckpt.pkl"), params, opt_state, num_steps, cfg)
print("Checkpoint saved.")

## 3. Loss curve

In [None]:
steps, vals = zip(*losses)
plt.figure(figsize=(8, 4))
plt.plot(steps, vals, "b-", label="Loss")
plt.plot(steps, np.exp(vals), "g-", alpha=0.7, label="Perplexity")
plt.xlabel("Step")
plt.ylabel("Loss / Perplexity")
plt.legend()
plt.title("Training loss (demo)")
plt.tight_layout()
plt.show()

## 4. Generate text (standard KV cache)

In [None]:
out = generate(model, params, tokenizer, "The dog", max_tokens=40, temperature=0.8, top_k=40, seed=123)
print("Generated (KV cache):")
print(out)

## 5. Generate with MLA + DSA (sparse top-k)

In [None]:
mla_params = init_mla_params(jax.random.PRNGKey(0), cfg, d_latent=32)
out_mla = generate(model, params, tokenizer, "The cat", max_tokens=40, temperature=0.8, top_k=40,
                   seed=123, use_mla=True, mla_params=mla_params, sparse_top_k=16)
print("Generated (MLA + DSA top-16):")
print(out_mla)

## 6. Tokenizer experiments (CS336 A1 §2.7)

Compression ratio (bytes/token) and rough throughput. For full experiments use 10 sampled documents and report in writeup.

In [None]:
sample_text = SAMPLE[: 5000]
ids = tokenizer.encode(sample_text)
bytes_per_token = len(sample_text.encode("utf-8")) / max(1, len(ids))
print(f"Compression ratio (bytes/token): {bytes_per_token:.2f}")
import time
t0 = time.perf_counter()
for _ in range(10):
    tokenizer.encode(sample_text)
elapsed = time.perf_counter() - t0
throughput_bps = (10 * len(sample_text.encode("utf-8"))) / elapsed
print(f"Throughput (bytes/s): ~{throughput_bps:.0f}")
print("Perplexity (exp of mean loss) on last batch:", float(np.exp(np.mean(vals))))

## CS336 Assignment 1 alignment

This notebook runs: **BPE training**, **data loading**, **training loop**, **checkpointing**, **loss/perplexity curve**, **decoding** (temperature + top-k), and **tokenizer experiments** (compression ratio, throughput, perplexity). It does **not** implement the written problems (Unicode, accounting, ablations, etc.).

- **Full checklist:** [docs/CS336_ASSIGNMENT1_CHECKLIST.md](../docs/CS336_ASSIGNMENT1_CHECKLIST.md) — maps every A1 problem to this codebase and to writeup deliverables.
- **Written deliverables** (for writeup.pdf): §2.1 unicode1, §2.2 unicode2, §2.5 train_bpe_tinystories/owt, §2.7 tokenizer_experiments, §3 resource accounting, §4 learning_rate_tuning & adamwAccounting, §5 checkpointing, §7 learning_rate, batch_size_experiment, generate, ablations (layer_norm, pre_norm, no_pos_emb, swiglu), main_experiment (OWT).