# Tunix RT - Kaggle Submission Notebook

**Competition:** Google Tunix Hack - Train a model to show its work  
**Version:** `m33_v1`

This notebook provides a single-session workflow for the Tunix Hack competition.

**Workflow:**
1. Install dependencies
2. Configure training parameters
3. Build/load dataset
4. Train model (JAX/Flax) - Smoke or Full mode
5. Generate predictions
6. Evaluate and score
7. Display submission summary

**Runtime:** Kaggle TPU or GPU (recommended)  
**Time:** ~5 min (smoke) / ~1-2 hours (full)


## 1. Setup


In [None]:
# Install dependencies (Kaggle environment)
# Note: JAX with TPU support is pre-installed on Kaggle TPU runtimes
!pip install -q jax[cuda12] flax optax orbax-checkpoint transformers datasets

# Verify JAX installation
import jax
print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")
print("\n‚úÖ Setup complete")


## 2. Configuration

Configure training parameters below. The notebook supports two modes:
- **Smoke Mode:** Quick validation (2 steps, ~5 min)
- **Full Mode:** Complete training run (~1-2 hours)


In [None]:
import subprocess
import sys
import json
from pathlib import Path

# ------------------------------------------------------------
# Notebook-wide settings. Edit the values in this cell; the
# later cells read these names when they build their commands.
# ------------------------------------------------------------

# Base model (competition requirement: Gemma 2 2B or Gemma 3 1B).
# Alternative: "google/gemma-2-2b"
MODEL_NAME = "google/gemma-3-1b-it"

# Training dataset.
#   dev-reasoning-v2 -> 550 traces (recommended)
#   golden-v2        -> 100 traces (quick sanity check)
DATASET = "dev-reasoning-v2"

# Step budgets and RNG seed.
MAX_STEPS = 100  # full run: 100-1000 steps
SMOKE_STEPS = 2  # smoke run: just enough steps to validate the pipeline
SEED = 42

# "auto" detects GPU/TPU; use "cpu" for testing.
DEVICE = "auto"

# Where each run writes its artifacts.
OUTPUT_DIR = "./output/kaggle_run"
SMOKE_OUTPUT_DIR = "./output/smoke_run"

# Evaluation set used for prediction and scoring.
EVAL_SET = "training/evalsets/eval_v1.jsonl"

# ------------------------------------------------------------

# Echo the effective settings so they are recorded in the cell output.
print(
    "\n".join(
        [
            "Configuration:",
            f"  Model:       {MODEL_NAME}",
            f"  Dataset:     {DATASET}",
            f"  Max Steps:   {MAX_STEPS}",
            f"  Smoke Steps: {SMOKE_STEPS}",
            f"  Seed:        {SEED}",
            f"  Device:      {DEVICE}",
            f"  Output:      {OUTPUT_DIR}",
        ]
    )
)


## 3. Build Dataset


In [None]:
# Build the selected dataset, then verify that it exists on disk.
# Note: Datasets are deterministically seeded (seed=42)

# Seeding script for each dataset this notebook knows how to build.
seed_scripts = {
    "dev-reasoning-v2": "backend/tools/seed_dev_reasoning_v2.py",
    "golden-v2": "backend/tools/seed_golden_v2.py",
    "dev-reasoning-v1": "backend/tools/seed_dev_reasoning_v1.py",
}

seed_script = seed_scripts.get(DATASET)
if seed_script is None:
    print(f"‚ö†Ô∏è  Dataset {DATASET} not recognized, assuming it already exists")
else:
    seed_result = subprocess.run([sys.executable, seed_script])
    if seed_result.returncode != 0:
        # Report the failure explicitly: a dataset directory left over from an
        # earlier run would otherwise be reported as ready by the check below.
        print(f"\nDataset seeding failed with exit code {seed_result.returncode}")

# Verify dataset exists
dataset_path = Path(f"backend/datasets/{DATASET}")
manifest_path = dataset_path / "manifest.json"
if not dataset_path.exists():
    print(f"\n‚ùå Dataset directory not found: {dataset_path}")
elif not manifest_path.exists():
    print(f"\n‚ö†Ô∏è  Manifest not found at {manifest_path}")
else:
    # Read as UTF-8 explicitly rather than relying on the platform's
    # default text encoding.
    with open(manifest_path, encoding="utf-8") as f:
        manifest = json.load(f)
    print(f"\n‚úÖ Dataset ready: {DATASET}")
    print(f"   Traces: {manifest.get('trace_count', 'N/A')}")


## 4a. Smoke Run (Quick Validation)

Run this cell first to validate that the pipeline works before starting the full training run.


In [None]:
# SMOKE RUN - Quick validation (SMOKE_STEPS steps; 2 by default)
# This confirms imports, dataset loading, and basic training work correctly

# Report the configured step count rather than a hard-coded "2", so the
# banner stays accurate when SMOKE_STEPS is changed in the config cell.
print(f"üî• Starting Smoke Run ({SMOKE_STEPS} steps)...")
print("=" * 60)

# Same training entry point as the full run, but limited via --smoke_steps
# and written to its own output directory.
smoke_cmd = [
    sys.executable, "training/train_jax.py",
    "--dataset", DATASET,
    "--model_name", MODEL_NAME,
    "--device", DEVICE,
    "--output_dir", SMOKE_OUTPUT_DIR,
    "--seed", str(SEED),
    "--smoke_steps", str(SMOKE_STEPS),
]

print(f"Command: {' '.join(smoke_cmd)}\n")

# Output is not captured, so the child's stdout/stderr are inherited from
# the kernel process.
result = subprocess.run(smoke_cmd, capture_output=False)

if result.returncode == 0:
    print("\n" + "=" * 60)
    print("‚úÖ Smoke run completed successfully!")
    print("   Pipeline validated. Ready for full training.")
else:
    print(f"\n‚ùå Smoke run failed with exit code {result.returncode}")


## 4b. Full Training Run

Run this cell for the complete training run. Time budget: ~1-2 hours for 100 steps.


In [None]:
# FULL TRAINING RUN
# Launches training/train_jax.py as a child process with the settings from
# the configuration cell and reports whether it exited cleanly.

# Banner: one print call, one line per argument.
print(
    "üöÄ Starting Full Training Run...",
    "=" * 60,
    f"Model:     {MODEL_NAME}",
    f"Dataset:   {DATASET}",
    f"Steps:     {MAX_STEPS}",
    f"Output:    {OUTPUT_DIR}",
    "=" * 60 + "\n",
    sep="\n",
)

# Assemble the command one flag/value pair at a time.
train_cmd = [sys.executable, "training/train_jax.py"]
train_cmd += ["--dataset", DATASET]
train_cmd += ["--model_name", MODEL_NAME]
train_cmd += ["--max_steps", str(MAX_STEPS)]
train_cmd += ["--device", DEVICE]
train_cmd += ["--output_dir", OUTPUT_DIR]
train_cmd += ["--seed", str(SEED)]
train_cmd += ["--save_every_steps", "50"]

print(f"Command: {' '.join(train_cmd)}\n")

# Child output is not captured (subprocess.run's default).
result = subprocess.run(train_cmd)

if result.returncode != 0:
    print(f"\n‚ùå Training failed with exit code {result.returncode}")
else:
    print("\n" + "=" * 60)
    print("‚úÖ Training completed successfully!")


## 5. Generate Predictions


In [None]:
# Run training/eval_generate.py against the evaluation set, using the full
# run's output directory as the checkpoint location.

# Later cells read predictions_file as well, so it is defined up front.
predictions_file = f"{OUTPUT_DIR}/predictions.jsonl"

print("üìä Generating predictions...", "=" * 60, sep="\n")

eval_cmd = [sys.executable, "training/eval_generate.py"]
eval_cmd += ["--checkpoint", OUTPUT_DIR]
eval_cmd += ["--eval_set", EVAL_SET]
eval_cmd += ["--output", predictions_file]

print(f"Command: {' '.join(eval_cmd)}\n")

# Child output is not captured (subprocess.run's default).
result = subprocess.run(eval_cmd)

if result.returncode != 0:
    print(f"\n‚ùå Prediction generation failed with exit code {result.returncode}")
else:
    print("\n‚úÖ Predictions generated successfully!")


## 6. Evaluate & Score


In [None]:
# Run training/eval_report.py on the predictions file written by the
# prediction cell, against the same evaluation set.

print("üìà Scoring predictions...", "=" * 60, sep="\n")

score_cmd = [sys.executable, "training/eval_report.py"]
score_cmd += ["--predictions", predictions_file]
score_cmd += ["--eval_set", EVAL_SET]

print(f"Command: {' '.join(score_cmd)}\n")

# Child output is not captured (subprocess.run's default).
result = subprocess.run(score_cmd)

if result.returncode != 0:
    print(f"\n‚ùå Evaluation failed with exit code {result.returncode}")
else:
    print("\n‚úÖ Evaluation complete!")


## 7. Submission Summary


In [None]:
# Display final submission summary: run configuration, the tail of the
# training metrics log, the eval score, and the artifacts found on disk.

print("\n" + "=" * 60)
print("         SUBMISSION SUMMARY")
print("=" * 60)

output_path = Path(OUTPUT_DIR)

# Model info
print(f"\nüì¶ Model ID: {MODEL_NAME}")
print(f"üìÅ Dataset:  {DATASET}")
print(f"üî¢ Steps:    {MAX_STEPS}")
print(f"üé≤ Seed:     {SEED}")

# Training metrics (metrics.jsonl is parsed as one JSON object per line)
metrics_file = output_path / "metrics.jsonl"
if metrics_file.exists():
    print(f"\nüìä Training Metrics (last 5 steps):")
    with open(metrics_file, "r", encoding="utf-8") as f:
        # Drop blank lines (e.g. a trailing empty line) so they neither crash
        # json.loads nor use up one of the five displayed slots.
        lines = [line for line in f if line.strip()]
    for line in lines[-5:]:
        metric = json.loads(line)
        step = metric.get('step', '?')
        loss = metric.get('loss', '?')
        if isinstance(loss, float):
            print(f"   Step {step}: loss={loss:.4f}")
        else:
            # Missing or non-float loss: print it as-is.
            print(f"   Step {step}: loss={loss}")
else:
    print(f"\n‚ö†Ô∏è  Metrics file not found at {metrics_file}")

# Eval score
eval_results_file = output_path / "eval_results.json"
if eval_results_file.exists():
    with open(eval_results_file, "r", encoding="utf-8") as f:
        results = json.load(f)
    score = results.get('answer_correctness', 'N/A')
    if isinstance(score, float):
        print(f"\nüéØ Eval Score: {score:.2f}")
    else:
        print(f"\nüéØ Eval Score: {score}")
else:
    print(f"\n‚ö†Ô∏è  Eval results not found (run evaluation cell first)")

# Artifact paths
print(f"\nüìÇ Artifact Paths:")
if output_path.exists():
    # Sorted so the listing is stable; glob order is filesystem-dependent.
    checkpoints = sorted(output_path.glob("checkpoint*"))
    for ckpt in checkpoints:
        print(f"   {ckpt}")
if metrics_file.exists():
    print(f"   {metrics_file}")
try:
    preds_path = Path(predictions_file)
except NameError:
    # The prediction cell was not run in this session; fall back to the
    # location that cell writes to instead of aborting the summary.
    preds_path = Path(f"{OUTPUT_DIR}/predictions.jsonl")
if preds_path.exists():
    print(f"   {preds_path}")

print("\n" + "=" * 60)
print("‚úÖ Submission package ready!")
print("   See docs/submission_checklist.md for next steps.")
print("=" * 60)
