# Section 7: Experiment Analysis

This notebook pulls training results from W&B and generates plots/tables for the writeup.

In [None]:
import wandb
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# W&B project that holds every run analysed in this notebook.
PROJECT = "ece496b-lm"
# Public-API client; the cells below query runs through it.
api = wandb.Api()

# Notebook-wide matplotlib defaults. Individual figures may still pass their
# own figsize (the LR sweep figure does).
plt.rcParams["figure.figsize"] = (10, 6)
plt.rcParams["font.size"] = 12

## 7.2a Learning Rate Sweep
Plot training and validation loss curves for each learning rate.

In [None]:
# Fetch all LR sweep runs (display name contains "ts-lr-").
runs = api.runs(PROJECT, filters={"display_name": {"$regex": "ts-lr-"}})


def fetch_metric(run, key, samples=5000):
    """Return the sampled history of a single metric for one W&B run.

    Each metric is requested on its own: asking for several keys in one
    `run.history(keys=[...])` call appears to return only the rows in which
    *all* keys were logged (NOTE(review): confirm against the W&B docs), which
    would thin the training curve down to the validation steps.

    Args:
        run: a W&B public-API run object.
        key: metric name, e.g. "train/loss".
        samples: maximum number of sampled points to request.

    Returns:
        DataFrame with the rows where `key` is not null. If the run never
        logged `key`, an empty frame with columns `_step` and `key`.
    """
    history = run.history(keys=[key], samples=samples)
    # A run that crashed before logging `key` yields a frame without that
    # column; dropna(subset=[key]) would raise KeyError on it.
    if key not in history.columns:
        return pd.DataFrame(columns=["_step", key])
    return history.dropna(subset=[key])


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Plot runs in increasing-LR order so the legend reads low -> high.
for run in sorted(runs, key=lambda r: r.config.get("max_learning_rate", 0)):
    lr = run.config.get("max_learning_rate", "?")
    label = f"lr={lr}"

    # Training loss
    train = fetch_metric(run, "train/loss")
    if not train.empty:
        ax1.plot(train["_step"], train["train/loss"], label=label, alpha=0.8)

    # Validation loss (logged sparsely, hence the markers)
    val = fetch_metric(run, "val/loss")
    if not val.empty:
        ax2.plot(val["_step"], val["val/loss"], label=label, marker="o", markersize=3)

# Both panels share the same formatting; only the split name differs.
for ax, split in ((ax1, "Training"), (ax2, "Validation")):
    ax.set_xlabel("Step")
    ax.set_ylabel(f"{split} Loss")
    ax.set_title(f"{split} Loss — LR Sweep")
    ax.legend()
    ax.set_ylim(0, 12)  # clip the y-axis so diverged runs do not flatten the rest
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("../outputs/lr_sweep.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# Summary table: final val loss for each LR.
# Relies on `runs` from the LR sweep cell above.
rows = []
for run in sorted(runs, key=lambda r: r.config.get("max_learning_rate", 0)):
    lr = run.config.get("max_learning_rate", "?")
    val_history = run.history(keys=["val/loss"], samples=5000)

    # A run that stopped before its first validation pass returns a frame with
    # no "val/loss" column; dropna(subset=...) would raise KeyError on it.
    has_val = "val/loss" in val_history.columns and not val_history.empty
    if has_val:
        # Coerce to float so non-numeric entries become NaN instead of breaking
        # the comparison below.
        # NOTE(review): W&B may serialise NaN as the string "NaN" — confirm.
        val_losses = pd.to_numeric(val_history["val/loss"], errors="coerce")
        # Use the last *logged* value. Dropping NaN rows first would report the
        # last finite loss of a run whose loss later went to NaN.
        final_val = float(val_losses.iloc[-1])
    else:
        final_val = float("nan")

    diverged = (
        run.state == "crashed"
        or (has_val and not np.isfinite(final_val))
        or final_val > 10
    )
    status = "diverged" if diverged else f"{final_val:.4f}"
    rows.append({"Learning Rate": lr, "Final Val Loss": status, "State": run.state})

df = pd.DataFrame(rows)
# Markdown output is pasted into the writeup.
print(df.to_markdown(index=False))

## 7.2b Edge of Stability
Analysis of how the divergence threshold (the lowest learning rate at which training diverges) relates to the best learning rate.

In [None]:
# TODO: After sweep completes, identify:
# - Best LR (lowest final val loss)
# - Highest non-divergent LR
# - First divergent LR
# Analyze relationship between best LR and edge of stability

## Batch Size Experiment

In [None]:
# TODO: Fetch batch size sweep runs and plot loss curves
# runs = api.runs(PROJECT, filters={"display_name": {"$regex": "ts-bs-"}})

## Text Generation
Load best checkpoint and generate 256+ tokens.

In [None]:
import sys
sys.path.insert(0, "..")
import pickle
import torch
from ece496b_basics import TransformerLM, Tokenizer, generate, load_checkpoint, AdamW

# Load tokenizer
# NOTE(security): pickle.load can execute arbitrary code. Only load vocab/merges
# files produced by this project's own tokenizer training.
with open("../outputs/ts_vocab_10k.pkl", "rb") as f:
    vocab = pickle.load(f)
with open("../outputs/ts_merges_10k.pkl", "rb") as f:
    merges = pickle.load(f)
tokenizer = Tokenizer(vocab=vocab, merges=merges, special_tokens=["<|endoftext|>"])

# Load model from best checkpoint
# TODO: update path to best checkpoint
CKPT_PATH = "../checkpoints/lr-1e-3/ckpt_final.pt"
# Hyperparameters must match the ones the checkpoint was trained with.
model = TransformerLM(
    vocab_size=10_000, context_length=256, d_model=512,
    num_layers=8, num_heads=8, d_ff=1088, rope_theta=10000.0,
)
# load_checkpoint is called with an optimizer, so a throwaway AdamW is built
# here; it is not used afterwards.
optimizer = AdamW(model.parameters())
load_checkpoint(CKPT_PATH, model, optimizer)
model.eval()

# Generate
# Fix the seed so the sample in the writeup can be regenerated.
# NOTE(review): assumes generate() samples via torch's global RNG — confirm.
torch.manual_seed(0)
eos_id = tokenizer.encode("<|endoftext|>", special_tokens=True)[0]
prompt = "Once upon a time"
prompt_ids = tokenizer.encode(prompt)
# Inference only: skip autograd bookkeeping while sampling.
with torch.no_grad():
    gen_ids = generate(model, prompt_ids, max_tokens=256, temperature=0.8, top_p=0.9, eos_token_id=eos_id)
print(tokenizer.decode(gen_ids))

## 7.3 Ablations
Compare baseline vs ablation runs.

In [None]:
# TODO: After ablation runs complete, fetch and plot
# ablation_names = ["baseline", "no-rmsnorm", "post-norm", "no-rope", "ffn-silu"]
# runs = api.runs(PROJECT, filters={"display_name": {"$regex": "ts-ablation-"}})

## 7.4 OWT Training

In [None]:
# TODO: Fetch OWT run and plot loss curves + generate text