# Tunix RT - Kaggle Submission Notebook

**Competition:** Google Tunix Hack - Train a model to show its work  
**Version:** `m38_v1` - TPU Training with HBM OOM Fix

This notebook provides a single-session workflow for the Tunix Hack competition.

**Workflow:**
1. **Clone repository** (required on Kaggle)
2. Install dependencies
3. **Authenticate with HuggingFace** (required for gated Gemma models)
4. Configure training parameters
5. Build/load dataset
6. **Smoke test** (tiny model, validates pipeline on GPU)
7. **Full TPU training** (Gemma 2B, **requires TPU v3-8 or v5e-8**)
8. Generate predictions
9. Evaluate and score (eval_v2: 100 items with scorecard)
10. Display submission summary with RESULT SUMMARY block

**Runtime:**
- **Smoke test:** Any GPU (uses tiny model)
- **Full training:** **TPU v3-8 or v5e-8 REQUIRED** (Gemma 2B will NOT fit on GPU)

**Time:** ~1 min (smoke) / ~30-60 min (200 steps on TPU)

**⚠️ Important (M38):**
- Run the "Clone Repository" cell first before any other cells!
- For full training, you **MUST** switch to the **TPU** accelerator (Settings → Accelerator → TPU v3-8)
- GPU training with Gemma 2B is **blocked** by the training script (will exit with error)
- **M38 Fix:** Runs the training script in-process with `runpy.run_path()` instead of `subprocess` to avoid TPU VFIO conflicts


## 0. Clone Repository (Required on Kaggle)

**Run this cell first!** It clones the tunix-rt repository so all training scripts and tools are available.


In [None]:
# Clone the tunix-rt repository
# This provides all training scripts, tools, and configurations
# Uses absolute paths to prevent nested directory issues on re-run

import os

REPO_URL = "https://github.com/m-cahill/tunix-rt.git"
KAGGLE_WORKING = "/kaggle/working"
REPO_DIR = f"{KAGGLE_WORKING}/tunix-rt"  # Absolute path

# Check if already cloned (for re-running cells)
if os.path.exists(REPO_DIR):
    print(f"📁 Repository already exists at {REPO_DIR}")
else:
    print(f"📥 Cloning repository from {REPO_URL}...")
    os.chdir(KAGGLE_WORKING)
    !git clone {REPO_URL}
    print("✅ Repository cloned successfully!")

# Always cd to the repo directory (idempotent - safe to re-run)
os.chdir(REPO_DIR)

# Verify we're in the right directory
print(f"\n📍 Working directory: {os.getcwd()}")
print(f"📂 Contents: {os.listdir('.')[:10]}...")  # Show first 10 items


## 1. Setup

Install dependencies and verify JAX is working.
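
The setup cell checks for JAX by shelling out to `pip show`. As an alternative sketch, the standard library can answer the same question without a subprocess. The helper name `is_installed` is hypothetical and is not part of the tunix-rt repository.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located on the import path."""
    # find_spec returns None when the module cannot be found; for a
    # top-level name it does not import the module itself.
    return importlib.util.find_spec(module_name) is not None


print(f"jax installed: {is_installed('jax')}")
```

This only tells you that the package is importable, not which accelerator backend it was built for, so the `jax.default_backend()` check in the cell is still needed.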


In [None]:
# Install dependencies (Kaggle environment)
# Note: JAX with TPU support is pre-installed on Kaggle TPU runtimes.
#       For GPU, we install jax[cuda12]. Pin transformers<5 for Flax support.

# Check whether JAX is already installed (Kaggle TPU runtimes ship with it)
import subprocess
import sys
result = subprocess.run([sys.executable, "-m", "pip", "show", "jax"], capture_output=True, text=True)
has_jax = result.returncode == 0

if not has_jax:
    print("Installing JAX with CUDA support...")
    !pip install -q "jax[cuda12]"

# Install other dependencies (transformers v4 for Flax support)
!pip install -q flax optax orbax-checkpoint "transformers>=4.40,<5" datasets pyyaml huggingface_hub

# Verify JAX installation
import jax
print(f"\nJAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")
print(f"Default backend: {jax.default_backend()}")

# Warn about device type
backend = jax.default_backend()
if backend == "gpu":
    print("\n⚠️  Running on GPU. Smoke tests will work, but full Gemma training may OOM.")
    print("   For full training, switch to TPU (Settings → Accelerator → TPU v3-8)")
elif backend == "tpu":
    print("\n✅ Running on TPU. Full Gemma training is supported.")
else:
    print(f"\n⚠️  Running on {backend}. Performance may be limited.")

print("\n✅ Setup complete")


## 1.5 HuggingFace Authentication (Required)

**Gemma models are gated.** You must authenticate with HuggingFace before loading.

**Prerequisites:**
1. Accept the Gemma license at https://huggingface.co/google/gemma-2b-flax
2. Add your HuggingFace token as a Kaggle Secret named `HF_TOKEN`
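
If you also want to run the notebook outside Kaggle, one possible approach is to fall back to an environment variable when Kaggle Secrets are unavailable. This is a sketch under that assumption: the helper name `resolve_hf_token` and the environment-variable fallback are not part of the original notebook.

```python
import os


def resolve_hf_token(secret_name: str = "HF_TOKEN"):
    """Return the token from Kaggle Secrets if possible, else from the environment."""
    try:
        # kaggle_secrets only exists inside Kaggle notebooks.
        from kaggle_secrets import UserSecretsClient
        return UserSecretsClient().get_secret(secret_name)
    except Exception:
        # Not on Kaggle, or the secret is not attached to this notebook:
        # fall back to an environment variable of the same name (assumption).
        return os.environ.get(secret_name)


token = resolve_hf_token()
print("HF token found" if token else "HF token missing: add the HF_TOKEN secret")
```

Checking for a missing token before calling `login()` gives a clearer message than the authentication error that would otherwise follow.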


In [None]:
# Authenticate with HuggingFace (required for gated Gemma models)
from kaggle_secrets import UserSecretsClient
from huggingface_hub import login

user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("HF_TOKEN")
login(token=hf_token)

# Verify authentication
from huggingface_hub import whoami
try:
    user_info = whoami()
    print(f"✅ Logged in to HuggingFace as: {user_info['name']}")
except Exception as e:
    print(f"❌ HuggingFace authentication failed: {e}")
    print("   Make sure HF_TOKEN secret is set in Kaggle Secrets")


## 2. Configuration

Configure training parameters below. The notebook supports two modes:
- **Smoke Mode:** Quick validation (2 steps, ~1 min)
- **Full Mode:** Complete training run (200 steps, ~30-60 min on TPU)

**Note:** Paths are relative to the cloned repository root.
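
The configuration cell reads nested keys with chained `.get()` calls, which raise `AttributeError` if a section such as `training:` is present in the YAML but empty (it loads as `None`). A minimal sketch of a more forgiving lookup follows; the helper `get_nested` and the example dictionary are hypothetical, and only the key names (`model.name`, `training.num_steps`) come from the notebook.

```python
from typing import Any


def get_nested(config: Any, *keys: str, default: Any = None) -> Any:
    """Walk nested dicts; return `default` if any level is missing or not a dict."""
    current = config
    for key in keys:
        if not isinstance(current, dict) or key not in current:
            return default
        current = current[key]
    # Treat an explicit null value the same as a missing key.
    return default if current is None else current


# Hypothetical config shaped like the keys the notebook reads.
example = {"model": {"name": "demo-model"}, "training": None}
print(get_nested(example, "model", "name", default="unknown"))
print(get_nested(example, "training", "num_steps", default=100))
```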


In [None]:
import subprocess
import sys
import json
from pathlib import Path

# ============================================================
# CONFIGURATION - Modify these values as needed
# ============================================================

# M38: TPU Training Config (200 steps, memory-optimized for Gemma 256K vocab)
# This is the PRIMARY config for submission training.
# ⚠️  DO NOT use this on GPU - it will block with an error.
# M38 fix: seq_len=128, batch=1, grad_accum=8, bfloat16 to avoid HBM OOM
CONFIG_PATH = "training/configs/submission_tpu.yaml"

# Config file for SMOKE TESTS (tiny model that fits on any GPU)
# This validates the pipeline works without OOM issues
SMOKE_CONFIG_PATH = "training/configs/smoke_tiny.yaml"

# Dataset selection
# Options: dev-reasoning-v2 (550 traces, recommended), golden-v2 (100 traces, quick sanity)
DATASET = "dev-reasoning-v2"

# Training parameters
SMOKE_STEPS = 2        # Smoke run: 2 steps for validation

# Device selection
# M37: Use "tpu" for full training, "auto" for smoke tests
DEVICE_SMOKE = "auto"  # auto-detect GPU/TPU for smoke tests
DEVICE_FULL = "tpu"    # Explicit TPU for full training (M37)

# Output directories
OUTPUT_DIR = "./output/tpu_run"      # M37: TPU training output
SMOKE_OUTPUT_DIR = "./output/smoke_run"

# Evaluation (M36+: eval_v2 with 100 items and scorecard support)
# Options: eval_v2.jsonl (100 items, recommended), eval_v1.jsonl (50 items, legacy)
EVAL_SET = "training/evalsets/eval_v2.jsonl"

# ============================================================

# Load configs to display model names
import yaml
with open(CONFIG_PATH) as f:
    config = yaml.safe_load(f)
MODEL_NAME = config.get('model', {}).get('name', 'unknown')
MAX_STEPS = config.get('training', {}).get('num_steps', 100)
SEED = config.get('training', {}).get('seed', 42)

with open(SMOKE_CONFIG_PATH) as f:
    smoke_config = yaml.safe_load(f)
SMOKE_MODEL_NAME = smoke_config.get('model', {}).get('name', 'unknown')

print("Configuration:")
print(f"  Full Config:  {CONFIG_PATH}")
print(f"  Full Model:   {MODEL_NAME}")
print(f"  Smoke Config: {SMOKE_CONFIG_PATH}")
print(f"  Smoke Model:  {SMOKE_MODEL_NAME}")
print(f"  Dataset:      {DATASET}")
print(f"  Max Steps:    {MAX_STEPS} (from full config)")
print(f"  Smoke Steps:  {SMOKE_STEPS}")
print(f"  Seed:         {SEED}")
print(f"  Device (Smoke): {DEVICE_SMOKE}")
print(f"  Device (Full):  {DEVICE_FULL}")
print(f"  Eval Set:     {EVAL_SET}")
print(f"  Output:       {OUTPUT_DIR}")


## 3. Build Dataset

Seed scripts are located in `backend/tools/` and write to `backend/datasets/`.
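
The mapping from dataset name to seed script can also be written as a lookup table, which keeps the supported names in one place. This is only a sketch: the script paths are the ones used in this notebook, while `SEED_SCRIPTS` and `seed_script_for` are hypothetical names.

```python
from typing import Optional

# Dataset name -> seed script, using the paths that appear in this notebook.
SEED_SCRIPTS = {
    "dev-reasoning-v2": "backend/tools/seed_dev_reasoning_v2.py",
    "golden-v2": "backend/tools/seed_golden_v2.py",
    "dev-reasoning-v1": "backend/tools/seed_dev_reasoning_v1.py",
}


def seed_script_for(dataset: str) -> Optional[str]:
    """Return the seed script for a dataset, or None if it is not recognized."""
    return SEED_SCRIPTS.get(dataset)


print(seed_script_for("dev-reasoning-v2"))
```

A `None` result corresponds to the "not recognized, assuming it already exists" branch in the cell.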


In [None]:
# Build the selected dataset
# Note: Datasets are deterministically seeded (seed=42)

if DATASET == "dev-reasoning-v2":
    subprocess.run([sys.executable, "backend/tools/seed_dev_reasoning_v2.py"])
elif DATASET == "golden-v2":
    subprocess.run([sys.executable, "backend/tools/seed_golden_v2.py"])
elif DATASET == "dev-reasoning-v1":
    subprocess.run([sys.executable, "backend/tools/seed_dev_reasoning_v1.py"])
else:
    print(f"⚠️  Dataset {DATASET} not recognized, assuming it already exists")

# Verify dataset exists
dataset_path = Path(f"backend/datasets/{DATASET}")
if dataset_path.exists():
    manifest_path = dataset_path / "manifest.json"
    if manifest_path.exists():
        with open(manifest_path) as f:
            manifest = json.load(f)
        print(f"\n✅ Dataset ready: {DATASET}")
        print(f"   Traces: {manifest.get('trace_count', 'N/A')}")
    else:
        print(f"\n⚠️  Manifest not found at {manifest_path}")
else:
    print(f"\n❌ Dataset directory not found: {dataset_path}")


## 4a. Smoke Run (Quick Validation)

Run this cell before the full training run to validate that the pipeline works.

**Recommended:** Always run smoke first to verify environment before committing to full training.


In [None]:
# SMOKE RUN - Quick validation (2 steps)
# This confirms imports, dataset loading, and basic training work correctly
#
# NOTE: We use --smoke_config to load a tiny model (sshleifer/tiny-gpt2) that
# fits on any GPU. This validates the PIPELINE without OOM issues.
# The full config (CONFIG_PATH) is still passed but not used during smoke.

print("🔥 Starting Smoke Run (2 steps)...")
print("=" * 60)
print(f"   Using smoke config: {SMOKE_CONFIG_PATH}")
print(f"   Smoke model: {SMOKE_MODEL_NAME}")
print(f"   Device: {DEVICE_SMOKE}")
print("=" * 60)

smoke_cmd = [
    sys.executable, "training/run_train_jax.py",  # Use launcher for XLA env vars
    "--config", CONFIG_PATH,
    "--smoke_config", SMOKE_CONFIG_PATH,  # Tiny model for smoke
    "--output", SMOKE_OUTPUT_DIR,
    "--dataset", DATASET,
    "--device", DEVICE_SMOKE,  # M37: Use DEVICE_SMOKE for smoke tests
    "--smoke_steps", str(SMOKE_STEPS),
]

print(f"Command: {' '.join(smoke_cmd)}\n")

result = subprocess.run(smoke_cmd, capture_output=False)

if result.returncode == 0:
    print("\n" + "=" * 60)
    print("✅ Smoke run completed successfully!")
    print("   Pipeline validated. Ready for full TPU training.")
    print("   ⚠️  Ensure you have TPU v3-8 selected for the next cell!")
else:
    print(f"\n❌ Smoke run failed with exit code {result.returncode}")


## 4b. Full TPU Training Run (M38)

Run this cell for the complete training with Gemma 2B on TPU.

**⚠️ REQUIRES TPU v3-8 or v5e-8:**
- Gemma 2B does NOT fit on Kaggle T4 GPU (verified in M36)
- The training script will **exit with error** if you try to run Gemma on GPU
- Go to Settings → Accelerator → TPU v3-8 **before** running this cell

**M38 Changes:**
- Runs the training script in-process with `runpy.run_path()` instead of `subprocess` to avoid TPU VFIO device conflicts
- Config reduced to seq_len=128, batch=1, grad_accum=8 to avoid HBM OOM
- Uses bfloat16 (native TPU support, saves ~50% memory)

**Time budget:** ~30-60 min for 200 steps on TPU.
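
To see roughly why these settings help, the arithmetic below estimates the size of a single logits tensor and the effective batch size. It is a back-of-the-envelope sketch, not a measurement: it assumes "256K vocab" means 256,000 entries and a logits tensor of shape `(batch, seq_len, vocab)`, and it ignores parameters, optimizer state, and other activations, which also occupy HBM.

```python
# Values taken from the M38 config description.
BATCH = 1
GRAD_ACCUM = 8
SEQ_LEN = 128
VOCAB = 256_000  # assumption: "256K vocab" read as 256,000 entries

BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2}


def logits_bytes(batch: int, seq_len: int, vocab: int, dtype: str) -> int:
    """Bytes needed for one (batch, seq_len, vocab) logits tensor."""
    return batch * seq_len * vocab * BYTES_PER_ELEMENT[dtype]


# Gradient accumulation keeps the per-step batch small while the
# optimizer still sees the gradient of batch * grad_accum examples.
effective_batch = BATCH * GRAD_ACCUM
tokens_per_update = effective_batch * SEQ_LEN

f32_mib = logits_bytes(BATCH, SEQ_LEN, VOCAB, "float32") / 2**20
bf16_mib = logits_bytes(BATCH, SEQ_LEN, VOCAB, "bfloat16") / 2**20

print(f"effective batch:   {effective_batch}")
print(f"tokens per update: {tokens_per_update}")
print(f"logits float32:    {f32_mib:.1f} MiB")
print(f"logits bfloat16:   {bf16_mib:.1f} MiB")
```

The halving from float32 to bfloat16 applies per tensor, which is where the "~50%" figure comes from; the logits size also scales linearly with `seq_len` and `batch`, which is why both were reduced.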


In [None]:
# FULL TPU TRAINING RUN (M38)
# This runs the complete training pipeline with the configured parameters
#
# M38: Uses runpy.run_path() to avoid TPU VFIO device conflicts.
# After JAX initializes TPU, subprocess.run() can cause "device busy" errors.
# runpy executes in the same Python process, avoiding this issue.
#
# ⚠️  REQUIRES TPU: This will exit with error if you try to run on GPU.

import jax
import sys
import runpy

backend = jax.default_backend()
if backend != "tpu":
    print("=" * 60)
    print("❌ ERROR: TPU NOT DETECTED")
    print("=" * 60)
    print(f"   Current backend: {backend}")
    print(f"   Expected: tpu")
    print("")
    print("   To fix:")
    print("   1. Go to Settings (gear icon in top right)")
    print("   2. Scroll to 'Accelerator'")
    print("   3. Select 'TPU v3-8' or 'TPU v5e-8'")
    print("   4. Click 'Save' and wait for session restart")
    print("   5. Re-run this notebook from the first cell (Clone Repository)")
    print("=" * 60)
    raise RuntimeError("TPU required for full training. See instructions above.")

print("🚀 Starting Full TPU Training Run (M38)...")
print("=" * 60)
print(f"Config:    {CONFIG_PATH}")
print(f"Model:     {MODEL_NAME}")
print(f"Dataset:   {DATASET}")
print(f"Steps:     {MAX_STEPS} (from config)")
print(f"Device:    {DEVICE_FULL} (TPU)")
print(f"Output:    {OUTPUT_DIR}")
print("=" * 60 + "\n")

# M38: Use runpy.run_path() instead of %run to properly catch failures
# This runs the training script in the same Python process
train_args = [
    "training/run_train_jax.py",
    "--config", CONFIG_PATH,
    "--output", OUTPUT_DIR,
    "--dataset", DATASET,
    "--device", DEVICE_FULL,
    "--save_every_steps", "50"
]
print(f"Command: python {' '.join(train_args)}\n")

# Save original sys.argv and replace with training args
original_argv = sys.argv
sys.argv = train_args
training_success = False

try:
    runpy.run_path("training/run_train_jax.py", run_name="__main__")
    training_success = True
except SystemExit as e:
    if e.code in (0, None):  # sys.exit() with no argument also signals success
        training_success = True
    else:
        print("\n" + "=" * 60)
        print(f"❌ Training FAILED with exit code {e.code}")
        print("=" * 60)
        print("   Check the error messages above for details.")
        print("   Common issues:")
        print("   • HBM OOM: reduce max_length or batch_size in config")
        print("   • TPU not available: check accelerator settings")
        print("=" * 60)
except Exception as e:
    print("\n" + "=" * 60)
    print(f"❌ Training FAILED with exception: {type(e).__name__}")
    print("=" * 60)
    print(f"   {e}")
finally:
    sys.argv = original_argv

if training_success:
    print("\n" + "=" * 60)
    print("✅ TPU Training completed successfully!")
    print("   Evidence artifacts saved to:", OUTPUT_DIR)
    print("=" * 60)
else:
    print("\n⚠️  Training did not complete. See errors above.")

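
The pattern used in the cell above (swap `sys.argv`, run the script in-process with `runpy.run_path()`, translate `SystemExit` into an exit code, restore `sys.argv`) can be isolated into a small helper. This is a sketch only: `run_script_in_process` and the throwaway demo script are hypothetical stand-ins and do not touch `training/run_train_jax.py`.

```python
import runpy
import sys
import tempfile
from pathlib import Path


def run_script_in_process(script_path, args):
    """Run a script as __main__ in this process and return its exit code."""
    original_argv = sys.argv
    sys.argv = [str(script_path), *args]
    try:
        runpy.run_path(str(script_path), run_name="__main__")
        return 0
    except SystemExit as exc:
        # sys.exit() with no argument carries code None, which means success.
        if exc.code is None:
            return 0
        # Non-integer codes (e.g. an error message string) count as failure.
        return exc.code if isinstance(exc.code, int) else 1
    finally:
        # Always restore argv so later cells see the notebook's own arguments.
        sys.argv = original_argv


# Hypothetical stand-in script: exits with the code passed as its first argument.
demo_source = "import sys\nsys.exit(int(sys.argv[1]))\n"
with tempfile.TemporaryDirectory() as tmp:
    demo_script = Path(tmp) / "demo_script.py"
    demo_script.write_text(demo_source)
    print(run_script_in_process(demo_script, ["0"]))
    print(run_script_in_process(demo_script, ["3"]))
```

Because the script runs in the same interpreter, it shares the already-initialized JAX runtime instead of trying to open the TPU from a second process.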

## 5. Generate Predictions

Generate predictions on the evaluation set using the trained checkpoint.


In [None]:
# Generate predictions on the evaluation set

predictions_file = f"{OUTPUT_DIR}/predictions.jsonl"

print("📊 Generating predictions...")
print("=" * 60)

eval_cmd = [
    sys.executable, "training/eval_generate.py",
    "--checkpoint", OUTPUT_DIR,
    "--eval_set", EVAL_SET,
    "--output", predictions_file,
]

print(f"Command: {' '.join(eval_cmd)}\n")

result = subprocess.run(eval_cmd, capture_output=False)

if result.returncode == 0:
    print("\n✅ Predictions generated successfully!")
else:
    print(f"\n❌ Prediction generation failed with exit code {result.returncode}")


## 6. Evaluate & Score

Score predictions using the evaluation script with scorecard breakdown.


In [None]:
# Score predictions using the evaluation script

print("📈 Scoring predictions...")
print("=" * 60)

score_cmd = [
    sys.executable, "training/eval_report.py",
    "--predictions", predictions_file,
    "--eval_set", EVAL_SET,
]

print(f"Command: {' '.join(score_cmd)}\n")

result = subprocess.run(score_cmd, capture_output=False)

if result.returncode == 0:
    print("\n✅ Evaluation complete!")
else:
    print(f"\n❌ Evaluation failed with exit code {result.returncode}")


## 7. Submission Summary

Displays final results with a **RESULT SUMMARY** block for easy evidence capture.
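
The summary cell reads the last few records of `metrics.jsonl`. A minimal sketch of that step, written to tolerate blank lines in the file, is shown here. The helper name `read_last_records` is hypothetical, the `step` and `loss` keys are the ones the cell reads, and the demo numbers are made up rather than real training results.

```python
import json
import tempfile
from pathlib import Path


def read_last_records(path, n=5):
    """Return the last `n` JSON objects from a JSONL file, skipping blank lines."""
    with open(path, "r") as f:
        lines = [line for line in f if line.strip()]
    return [json.loads(line) for line in lines[-n:]]


# Demo with made-up numbers (not real training results).
with tempfile.TemporaryDirectory() as tmp:
    demo = Path(tmp) / "metrics.jsonl"
    demo.write_text('{"step": 1, "loss": 2.5}\n\n{"step": 2, "loss": 2.0}\n')
    for record in read_last_records(demo, n=5):
        print(f"Step {record['step']}: loss={record['loss']:.4f}")
```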


In [None]:
# Display final submission summary with RESULT SUMMARY block for evidence capture

print("\n" + "=" * 60)
print("         SUBMISSION SUMMARY")
print("=" * 60)

output_path = Path(OUTPUT_DIR)

# Model info
print(f"\n📦 Model ID: {MODEL_NAME}")
print(f"📁 Dataset:  {DATASET}")
print(f"📋 Eval Set: {EVAL_SET}")
print(f"🔢 Steps:    {MAX_STEPS}")
print(f"🎲 Seed:     {SEED}")

# Training metrics
metrics_file = output_path / "metrics.jsonl"
final_loss = None
if metrics_file.exists():
    print("\n📊 Training Metrics (last 5 steps):")
    with open(metrics_file, "r") as f:
        lines = f.readlines()
        for line in lines[-5:]:
            metric = json.loads(line)
            step = metric.get('step', '?')
            loss = metric.get('loss', '?')
            if isinstance(loss, float):
                print(f"   Step {step}: loss={loss:.4f}")
                final_loss = loss
            else:
                print(f"   Step {step}: loss={loss}")
else:
    print(f"\n⚠️  Metrics file not found at {metrics_file}")

# Eval score and scorecard
eval_results_file = output_path / "eval_results.json"
primary_score = None
scorecard_info = {}
if eval_results_file.exists():
    with open(eval_results_file, "r") as f:
        results = json.load(f)
        primary_score = results.get('primary_score', results.get('answer_correctness'))
        scorecard_info = results.get('scorecard', {})
        if isinstance(primary_score, float):
            print(f"\n🎯 Primary Score: {primary_score:.4f} ({primary_score * 100:.1f}%)")
        else:
            print(f"\n🎯 Primary Score: {primary_score}")
        
        # M36: Display scorecard if available
        if scorecard_info:
            n_items = scorecard_info.get('n_items', '?')
            n_scored = scorecard_info.get('n_scored', '?')
            print(f"📊 Scorecard: {n_scored}/{n_items} items scored")
            section_scores = scorecard_info.get('section_scores', {})
            for section, score in section_scores.items():
                if score is not None:
                    print(f"   {section}: {score:.2f}")
else:
    print("\n⚠️  Eval results not found (run evaluation cell first)")

# Artifact paths
print("\n📂 Artifact Paths:")
if output_path.exists():
    checkpoints = list(output_path.glob("checkpoint*"))
    for ckpt in checkpoints:
        print(f"   {ckpt}")
if metrics_file.exists():
    print(f"   {metrics_file}")
preds_path = Path(predictions_file)
if preds_path.exists():
    print(f"   {preds_path}")

# M36: Print RESULT SUMMARY block for evidence capture
print("\n" + "=" * 60)
print("         RESULT SUMMARY (copy to evidence files)")
print("=" * 60)
print(f"model_id: {MODEL_NAME}")
print(f"dataset: {DATASET}")
print(f"eval_set: {EVAL_SET}")
print(f"primary_score: {primary_score}")
print(f"final_loss: {final_loss}")
print(f"n_items: {scorecard_info.get('n_items', 'N/A')}")
print(f"n_scored: {scorecard_info.get('n_scored', 'N/A')}")
print("=" * 60)

print("\n✅ Submission package ready!")
print("   See docs/submission_checklist.md for next steps.")
print("   Evidence folder: submission_runs/m37_v1/")
print("=" * 60)
