# Tunix RT - Kaggle Submission Notebook

**Competition:** Google Tunix Hack - Train a model to show its work  
**Version:** `m36_v4`

This notebook provides a single-session workflow for the Tunix Hack competition.

**Workflow:**
1. **Clone repository** (required on Kaggle)
2. Install dependencies
3. Configure training parameters
4. Build/load dataset
5. Train model (JAX/Flax) - Smoke or Full mode
6. Generate predictions
7. Evaluate and score (eval_v2: 100 items with scorecard)
8. Display submission summary with RESULT SUMMARY block

**Runtime:** Kaggle TPU or GPU (recommended)  
**Time:** ~5 min (smoke) / ~1-2 hours (full)

**⚠️ Important:** Run the "Clone Repository" cell before any other cell!


## 0. Clone Repository (Required on Kaggle)

**Run this cell first!** It clones the tunix-rt repository so all training scripts and tools are available.


In [None]:
# Clone the tunix-rt repository
# This provides all training scripts, tools, and configurations
# Uses absolute paths to prevent nested directory issues on re-run

import os

REPO_URL = "https://github.com/m-cahill/tunix-rt.git"
KAGGLE_WORKING = "/kaggle/working"
REPO_DIR = f"{KAGGLE_WORKING}/tunix-rt"  # Absolute path

# Check if already cloned (for re-running cells)
if os.path.exists(REPO_DIR):
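    print(f"📁 Repository already exists at {REPO_DIR}")
else:
    print(f"📥 Cloning repository from {REPO_URL}...")
    os.chdir(KAGGLE_WORKING)
    !git clone {REPO_URL}
    # Confirm the clone actually produced the directory before reporting success
    if not os.path.isdir(REPO_DIR):
        raise RuntimeError(f"Clone failed: {REPO_DIR} was not created")
    print("✅ Repository cloned successfully!")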

# Always cd to the repo directory (idempotent - safe to re-run)
os.chdir(REPO_DIR)

# Verify we're in the right directory
print(f"\n📍 Working directory: {os.getcwd()}")
print(f"📂 Contents: {os.listdir('.')[:10]}...")  # Show first 10 items


## 1. Setup

Install dependencies and verify JAX is working.
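
After installation, it can help to confirm that every required package is actually importable before moving on to the heavier cells. The following is a minimal sketch using only the standard library. The helper name `missing_modules` is hypothetical and not part of the tunix-rt repository, and the module names are assumed to be the top-level import names of the pip packages this notebook installs (for example, `pyyaml` is imported as `yaml`).

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be located."""
    missing = []
    for name in names:
        try:
            found = importlib.util.find_spec(name) is not None
        except (ImportError, ValueError):
            # find_spec can raise for broken or partially initialised modules
            found = False
        if not found:
            missing.append(name)
    return missing


# Assumed import names for the packages installed in this section
REQUIRED = ["jax", "flax", "optax", "orbax", "transformers", "datasets", "yaml"]
print("Missing modules:", missing_modules(REQUIRED) or "none")
```

An empty result only shows that the modules can be found. It does not show that JAX sees an accelerator, which `jax.devices()` reports separately.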


In [None]:
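# Install dependencies (Kaggle environment)
# Note: JAX with TPU support is pre-installed on Kaggle TPU runtimes;
# the cuda12 extra below targets GPU runtimes.
# The extra is quoted so the shell does not treat the brackets as a glob pattern.
!pip install -q "jax[cuda12]" flax optax orbax-checkpoint transformers datasets pyyaml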

# Verify JAX installation
import jax
print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")
print("\n✅ Setup complete")


## 2. Configuration

Configure training parameters below. The notebook supports two modes:
- **Smoke Mode:** Quick validation (2 steps, ~5 min)
- **Full Mode:** Complete training run (~1-2 hours)

**Note:** Paths are relative to the cloned repository root.
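
Because the paths are relative, a cell run from the wrong working directory fails with a less obvious error later on. One way to catch this early is to resolve each path against the repository root and check that it exists. This is a minimal sketch, not code from the repository: `resolve_repo_path` is a hypothetical helper, and the demo uses a temporary directory in place of the real checkout.

```python
import tempfile
from pathlib import Path


def resolve_repo_path(relative_path, repo_root):
    """Join a repo-relative path onto repo_root and fail early if it is missing."""
    candidate = Path(repo_root) / relative_path
    if not candidate.exists():
        raise FileNotFoundError(f"{relative_path} not found under {repo_root}")
    return candidate


# Demo: a throwaway directory stands in for the cloned repository root
with tempfile.TemporaryDirectory() as root:
    cfg = Path(root) / "training" / "configs" / "sft_tiny.yaml"
    cfg.parent.mkdir(parents=True)
    cfg.write_text("model:\n  name: demo\n")
    print(resolve_repo_path("training/configs/sft_tiny.yaml", root).name)
```

In the notebook, the second argument would be the repository directory set in the clone cell.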


In [None]:
import subprocess
import sys
import json
from pathlib import Path

# ============================================================
# CONFIGURATION - Modify these values as needed
# ============================================================

# Config file selection (model name is inside the config)
# Options:
#   - training/configs/submission_gemma3_1b.yaml (Gemma 3 1B-it, recommended)
#   - training/configs/submission_gemma2_2b.yaml (Gemma 2 2B)
#   - training/configs/sft_tiny.yaml (for testing only)
CONFIG_PATH = "training/configs/submission_gemma3_1b.yaml"

# Dataset selection
# Options: dev-reasoning-v2 (550 traces, recommended), golden-v2 (100 traces, quick sanity)
DATASET = "dev-reasoning-v2"

# Training parameters
SMOKE_STEPS = 2        # Smoke run: 2 steps for validation

# Device selection
DEVICE = "auto"  # auto-detect GPU/TPU, or "cpu" for testing

# Output directories
OUTPUT_DIR = "./output/kaggle_run"
SMOKE_OUTPUT_DIR = "./output/smoke_run"

# Evaluation (M36: eval_v2 with 100 items and scorecard support)
# Options: eval_v2.jsonl (100 items, recommended), eval_v1.jsonl (50 items, legacy)
EVAL_SET = "training/evalsets/eval_v2.jsonl"

# ============================================================

# Load config to display model name
import yaml
with open(CONFIG_PATH) as f:
    config = yaml.safe_load(f)
MODEL_NAME = config.get('model', {}).get('name', 'unknown')
MAX_STEPS = config.get('training', {}).get('num_steps', 100)
SEED = config.get('training', {}).get('seed', 42)

print("Configuration:")
print(f"  Config:      {CONFIG_PATH}")
print(f"  Model:       {MODEL_NAME}")
print(f"  Dataset:     {DATASET}")
print(f"  Max Steps:   {MAX_STEPS} (from config)")
print(f"  Smoke Steps: {SMOKE_STEPS}")
print(f"  Seed:        {SEED}")
print(f"  Device:      {DEVICE}")
print(f"  Eval Set:    {EVAL_SET}")
print(f"  Output:      {OUTPUT_DIR}")


## 3. Build Dataset

Seed scripts are located in `backend/tools/` and write to `backend/datasets/`.
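
The name-to-script mapping and the manifest check can also be written as a lookup table plus a small reader. This is a sketch under stated assumptions: the script paths and the `trace_count` key are taken from the build cell in this notebook, `read_trace_count` is a hypothetical helper, and the demo writes a stand-in manifest to a temporary directory rather than touching the real dataset.

```python
import json
import tempfile
from pathlib import Path

# Script paths as used by the build cell in this notebook
SEED_SCRIPTS = {
    "dev-reasoning-v2": "backend/tools/seed_dev_reasoning_v2.py",
    "golden-v2": "backend/tools/seed_golden_v2.py",
    "dev-reasoning-v1": "backend/tools/seed_dev_reasoning_v1.py",
}


def read_trace_count(dataset_dir):
    """Return manifest['trace_count'], or None if the manifest or key is absent."""
    manifest_path = Path(dataset_dir) / "manifest.json"
    if not manifest_path.exists():
        return None
    with open(manifest_path) as f:
        return json.load(f).get("trace_count")


# Demo against a throwaway directory with a stand-in manifest
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "manifest.json").write_text(json.dumps({"trace_count": 550}))
    print(SEED_SCRIPTS.get("golden-v2"), read_trace_count(tmp))
```

`SEED_SCRIPTS.get(name)` returns `None` for an unrecognized dataset, which corresponds to the "assume it already exists" branch of the build cell.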


In [None]:
# Build the selected dataset
# Note: Datasets are deterministically seeded (seed=42)

if DATASET == "dev-reasoning-v2":
    subprocess.run([sys.executable, "backend/tools/seed_dev_reasoning_v2.py"])
elif DATASET == "golden-v2":
    subprocess.run([sys.executable, "backend/tools/seed_golden_v2.py"])
elif DATASET == "dev-reasoning-v1":
    subprocess.run([sys.executable, "backend/tools/seed_dev_reasoning_v1.py"])
else:
    print(f"⚠️  Dataset {DATASET} not recognized, assuming it already exists")

# Verify dataset exists
dataset_path = Path(f"backend/datasets/{DATASET}")
if dataset_path.exists():
    manifest_path = dataset_path / "manifest.json"
    if manifest_path.exists():
        with open(manifest_path) as f:
            manifest = json.load(f)
        print(f"\n✅ Dataset ready: {DATASET}")
        print(f"   Traces: {manifest.get('trace_count', 'N/A')}")
    else:
        print(f"\n⚠️  Manifest not found at {manifest_path}")
else:
    print(f"\n❌ Dataset directory not found: {dataset_path}")


## 4a. Smoke Run (Quick Validation)

Run this cell before the full training run to validate that the pipeline works end to end.

**Recommended:** Always run the smoke test first to verify the environment before committing to full training.
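
The training cells in this notebook print a message when a command fails but do not stop execution, so "Run All" would carry on to the next cell. If a hard stop is preferred, the command can be wrapped in a helper that raises on a non-zero exit code. This is a minimal sketch: `run_step` is a hypothetical helper, and the command shown is a stand-in for the real training command.

```python
import subprocess
import sys


def run_step(cmd, label):
    """Run cmd, stream its output, and raise if it exits non-zero."""
    print(f"Command: {' '.join(cmd)}")
    result = subprocess.run(cmd)
    if result.returncode != 0:
        raise RuntimeError(f"{label} failed with exit code {result.returncode}")
    return result


# Stand-in command; in the notebook this would be the smoke run command list
run_step([sys.executable, "-c", "print('smoke ok')"], "Smoke run")
```

Raising instead of printing means a failed smoke run halts "Run All" before the full training cell starts.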


In [None]:
# SMOKE RUN - Quick validation (2 steps)
# This confirms imports, dataset loading, and basic training work correctly

print(f"🔥 Starting Smoke Run ({SMOKE_STEPS} steps)...")
print("=" * 60)

smoke_cmd = [
    sys.executable, "training/train_jax.py",
    "--config", CONFIG_PATH,
    "--output", SMOKE_OUTPUT_DIR,
    "--dataset", DATASET,
    "--device", DEVICE,
    "--smoke_steps", str(SMOKE_STEPS),
]

print(f"Command: {' '.join(smoke_cmd)}\n")

result = subprocess.run(smoke_cmd, capture_output=False)

if result.returncode == 0:
    print("\n" + "=" * 60)
    print("✅ Smoke run completed successfully!")
    print("   Pipeline validated. Ready for full training.")
else:
    print(f"\n❌ Smoke run failed with exit code {result.returncode}")


## 4b. Full Training Run

Run this cell for the complete training.

**Time budget:** ~1-2 hours for 100 steps on TPU/GPU.


In [None]:
# FULL TRAINING RUN
# This runs the complete training pipeline with the configured parameters

print("🚀 Starting Full Training Run...")
print("=" * 60)
print(f"Config:    {CONFIG_PATH}")
print(f"Model:     {MODEL_NAME}")
print(f"Dataset:   {DATASET}")
print(f"Steps:     {MAX_STEPS} (from config)")
print(f"Output:    {OUTPUT_DIR}")
print("=" * 60 + "\n")

train_cmd = [
    sys.executable, "training/train_jax.py",
    "--config", CONFIG_PATH,
    "--output", OUTPUT_DIR,
    "--dataset", DATASET,
    "--device", DEVICE,
    "--save_every_steps", "50",
]

print(f"Command: {' '.join(train_cmd)}\n")

result = subprocess.run(train_cmd, capture_output=False)

if result.returncode == 0:
    print("\n" + "=" * 60)
    print("✅ Training completed successfully!")
else:
    print(f"\n❌ Training failed with exit code {result.returncode}")


## 5. Generate Predictions

Generate predictions on the evaluation set using the trained checkpoint.
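
A quick sanity check after generation is to compare the number of records in the predictions file with the number of items in the evaluation set. The sketch below assumes the generation script writes one JSON line per evaluation item, which this notebook does not confirm. Both helpers are hypothetical, and the demo records use placeholder fields.

```python
import tempfile
from pathlib import Path


def count_jsonl_records(path):
    """Count the non-blank lines in a JSONL file (one JSON object per line)."""
    with open(path, encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())


def prediction_gap(eval_path, predictions_path):
    """Return eval item count minus prediction count (0 means the counts match)."""
    return count_jsonl_records(eval_path) - count_jsonl_records(predictions_path)


# Demo with throwaway files; the record fields here are placeholders
with tempfile.TemporaryDirectory() as tmp:
    eval_file = Path(tmp) / "eval.jsonl"
    pred_file = Path(tmp) / "predictions.jsonl"
    eval_file.write_text('{"id": 1}\n{"id": 2}\n', encoding="utf-8")
    pred_file.write_text('{"id": 1}\n', encoding="utf-8")
    print("Missing predictions:", prediction_gap(eval_file, pred_file))
```

A non-zero gap suggests that generation stopped early or skipped items, which is worth checking before scoring.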


In [None]:
# Generate predictions on the evaluation set

predictions_file = f"{OUTPUT_DIR}/predictions.jsonl"

print("📊 Generating predictions...")
print("=" * 60)

eval_cmd = [
    sys.executable, "training/eval_generate.py",
    "--checkpoint", OUTPUT_DIR,
    "--eval_set", EVAL_SET,
    "--output", predictions_file,
]

print(f"Command: {' '.join(eval_cmd)}\n")

result = subprocess.run(eval_cmd, capture_output=False)

if result.returncode == 0:
    print("\n✅ Predictions generated successfully!")
else:
    print(f"\n❌ Prediction generation failed with exit code {result.returncode}")


## 6. Evaluate & Score

Score predictions using the evaluation script with scorecard breakdown.


In [None]:
# Score predictions using the evaluation script

print("📈 Scoring predictions...")
print("=" * 60)

score_cmd = [
    sys.executable, "training/eval_report.py",
    "--predictions", predictions_file,
    "--eval_set", EVAL_SET,
]

print(f"Command: {' '.join(score_cmd)}\n")

result = subprocess.run(score_cmd, capture_output=False)

if result.returncode == 0:
    print("\n✅ Evaluation complete!")
else:
    print(f"\n❌ Evaluation failed with exit code {result.returncode}")


## 7. Submission Summary

Displays final results with a **RESULT SUMMARY** block for easy evidence capture.
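
The summary reads the tail of `metrics.jsonl` to report recent losses. A sketch of that step in isolation is shown below. The `step` and `loss` field names match those read by the summary cell, `tail_losses` is a hypothetical helper, and the sample lines are made up for illustration. Unlike a plain `json.loads` per line, this version skips blank lines.

```python
import json


def tail_losses(lines, n=5):
    """Parse the last n non-blank JSONL lines into (step, loss) pairs."""
    if n <= 0:
        return []
    records = [json.loads(line) for line in lines if line.strip()]
    return [(r.get("step"), r.get("loss")) for r in records[-n:]]


# Made-up sample lines, including a blank line that is skipped
sample = [
    '{"step": 1, "loss": 2.5}\n',
    "\n",
    '{"step": 2, "loss": 2.0}\n',
]
print(tail_losses(sample))  # [(1, 2.5), (2, 2.0)]
```

The loss in the last pair corresponds to the `final_loss` value printed in the RESULT SUMMARY block.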


In [None]:
# Display final submission summary with RESULT SUMMARY block for evidence capture

print("\n" + "=" * 60)
print("         SUBMISSION SUMMARY")
print("=" * 60)

output_path = Path(OUTPUT_DIR)

# Model info
print(f"\n📦 Model ID: {MODEL_NAME}")
print(f"📁 Dataset:  {DATASET}")
print(f"📋 Eval Set: {EVAL_SET}")
print(f"🔢 Steps:    {MAX_STEPS}")
print(f"🎲 Seed:     {SEED}")

# Training metrics
metrics_file = output_path / "metrics.jsonl"
final_loss = None
if metrics_file.exists():
    print("\n📊 Training Metrics (last 5 steps):")
    with open(metrics_file, "r") as f:
        lines = f.readlines()
        for line in lines[-5:]:
            metric = json.loads(line)
            step = metric.get('step', '?')
            loss = metric.get('loss', '?')
            if isinstance(loss, (int, float)):
                print(f"   Step {step}: loss={loss:.4f}")
                final_loss = loss
            else:
                print(f"   Step {step}: loss={loss}")
else:
    print(f"\n⚠️  Metrics file not found at {metrics_file}")

# Eval score and scorecard
eval_results_file = output_path / "eval_results.json"
primary_score = None
scorecard_info = {}
if eval_results_file.exists():
    with open(eval_results_file, "r") as f:
        results = json.load(f)
        primary_score = results.get('primary_score', results.get('answer_correctness'))
        scorecard_info = results.get('scorecard', {})
        if isinstance(primary_score, float):
            print(f"\n🎯 Primary Score: {primary_score:.4f} ({primary_score * 100:.1f}%)")
        else:
            print(f"\n🎯 Primary Score: {primary_score}")
        
        # M36: Display scorecard if available
        if scorecard_info:
            n_items = scorecard_info.get('n_items', '?')
            n_scored = scorecard_info.get('n_scored', '?')
            print(f"📊 Scorecard: {n_scored}/{n_items} items scored")
            section_scores = scorecard_info.get('section_scores', {})
            for section, score in section_scores.items():
                if score is not None:
                    print(f"   {section}: {score:.2f}")
else:
    print("\n⚠️  Eval results not found (run evaluation cell first)")

# Artifact paths
print("\n📂 Artifact Paths:")
if output_path.exists():
    checkpoints = list(output_path.glob("checkpoint*"))
    for ckpt in checkpoints:
        print(f"   {ckpt}")
if metrics_file.exists():
    print(f"   {metrics_file}")
preds_path = Path(predictions_file)
if preds_path.exists():
    print(f"   {preds_path}")

# M36: Print RESULT SUMMARY block for evidence capture
print("\n" + "=" * 60)
print("         RESULT SUMMARY (copy to evidence files)")
print("=" * 60)
print(f"model_id: {MODEL_NAME}")
print(f"dataset: {DATASET}")
print(f"eval_set: {EVAL_SET}")
print(f"primary_score: {primary_score}")
print(f"final_loss: {final_loss}")
print(f"n_items: {scorecard_info.get('n_items', 'N/A')}")
print(f"n_scored: {scorecard_info.get('n_scored', 'N/A')}")
print("=" * 60)

print("\n✅ Submission package ready!")
print("   See docs/submission_checklist.md for next steps.")
print("   See docs/M36_KAGGLE_RUN.md for evidence capture instructions.")
print("=" * 60)
