# Emergence Lab - Kaggle Training

Autonomous training notebook for running emergence experiments on Kaggle GPUs.

**What this does:**
1. Installs JAX with CUDA support and verifies GPU access
2. Clones the repo and installs dependencies
3. Configures hyperparameters
4. Resumes from checkpoint if available, otherwise starts fresh
5. Runs training with progress display and periodic checkpointing

**Usage:**
- Upload to Kaggle as a new notebook
- Enable GPU accelerator (Settings > Accelerator > GPU T4 x2 or P100)
- Run all cells
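- Checkpoints are saved to `/kaggle/working/checkpoints/`
- Between sessions, make sure `/kaggle/working/checkpoints/` is carried over (e.g. enable file persistence in the notebook settings, or copy the previous session's output back into place). Cell 4 only looks for `latest.pkl` in that directory, so if it is empty the run starts fresh.
- Download checkpoints after the run completes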

## Cell 1: Install JAX with CUDA and Verify GPU

In [None]:
# Cell 1: Install JAX with CUDA and verify GPU
#
# Installs the CUDA 12 build of JAX into the running kernel's interpreter,
# then lists the visible devices and warns loudly if no GPU backend is present.
import subprocess
import sys

# Install JAX with CUDA 12 support (Kaggle provides CUDA 12.x)
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "-q",
    "jax[cuda12]",
])

# Verify GPU is available
import jax
print(f"JAX version: {jax.__version__}")
devices = jax.devices()
print(f"Devices: {devices}")

# Classify devices by their backend platform rather than by substring-matching
# their repr: `Device.platform` is "gpu" for GPU devices, while the repr text
# is not a stable API and has changed between JAX releases.
gpu_devices = [d for d in devices if d.platform == "gpu"]
if gpu_devices:
    print(f"GPU available: {len(gpu_devices)} device(s) - {gpu_devices}")
else:
    print("WARNING: No GPU found! Training will be slow on CPU.")
    print("Make sure GPU accelerator is enabled in Kaggle notebook settings.")

## Cell 2: Clone Repo and Install Dependencies

In [None]:
# Cell 2: Clone repo and install dependencies
#
# Clones (or updates) the Emergence Lab repo under /kaggle/working, installs it
# in editable mode WITHOUT dependencies so the CUDA-enabled JAX from Cell 1 is
# not replaced, installs the remaining dependencies explicitly, and makes the
# repo importable from this already-running kernel.
import os
import subprocess
import sys

REPO_URL = "https://github.com/imashishkh21/emergence-lab.git"
REPO_DIR = "/kaggle/working/emergence-lab"
BRANCH = "main"

# Clone or update repo
if os.path.exists(REPO_DIR):
    print(f"Repo exists at {REPO_DIR}, pulling latest...")
    subprocess.check_call(["git", "-C", REPO_DIR, "pull", "origin", BRANCH])
else:
    print(f"Cloning repo to {REPO_DIR}...")
    # Shallow clone: only the tip of BRANCH is needed to run training.
    subprocess.check_call([
        "git", "clone", "--branch", BRANCH, "--depth", "1", REPO_URL, REPO_DIR,
    ])

# Install the package (skip JAX since we installed CUDA version above)
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "-q",
    "-e", REPO_DIR,
    "--no-deps",  # Skip deps to avoid overwriting CUDA JAX
])

# Install non-JAX dependencies separately. flax/optax/chex themselves depend on
# jax; without --upgrade, pip leaves the already-installed jax in place as long
# as it satisfies their version requirements.
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "-q",
    "flax>=0.8.0", "optax>=0.1.7", "chex>=0.1.8",
    "numpy>=1.24.0", "scipy>=1.10.0", "scikit-learn>=1.3.0",
    "pyyaml>=6.0", "tqdm>=4.65.0", "matplotlib>=3.7.0",
    "wandb>=0.16.0", "imageio>=2.31.0", "imageio-ffmpeg>=0.4.8",
    "tyro>=0.6.0", "msgpack>=1.0.0",
])

# Add repo to Python path. An editable install's path entry is normally only
# processed at interpreter startup, so the running kernel needs it explicitly;
# the membership check keeps re-runs of this cell from stacking duplicates.
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)
os.chdir(REPO_DIR)

# Verify import
from src.configs import Config
print(f"\nEmergence Lab imported successfully from {REPO_DIR}")
print(f"Working directory: {os.getcwd()}")

## Cell 3: Configuration

In [None]:
# Cell 3: Configuration
#
# MAXIMUM TRAINING RUN — uses full Kaggle weekly GPU budget (29 hours)
# across multiple 12-hour sessions with auto-resume.
#
# This produces ~384,000 gradient updates (160x more than the initial test run).
# Run the notebook 3 times: Session 1 (12h) → Session 2 (12h) → Session 3 (5h)
# Each session auto-resumes from the latest checkpoint.
#
# Later cells read CHECKPOINT_DIR and `config`; everything else defined here is
# temporary and removed at the end of the cell.

from src.configs import (
    AgentConfig,
    AnalysisConfig,
    Config,
    EnvConfig,
    EvolutionConfig,
    FieldConfig,
    LogConfig,
    SpecializationConfig,
    TrainConfig,
)

# --- Paths ---
CHECKPOINT_DIR = "/kaggle/working/checkpoints"
# Leave as None to let Cell 4 pick up CHECKPOINT_DIR/latest.pkl automatically;
# set a path here to force a specific resume point.
RESUME_FROM = None

# --- Experiment Config ---
# Sub-configs are built one at a time (in the same order Config receives them)
# and then assembled into the single `config` object used by Cells 4 and 5.
_env_cfg = EnvConfig(
    grid_size=20,
    num_agents=8,
    num_food=20,  # More food for robust population
    max_steps=500,
)
_field_cfg = FieldConfig(num_channels=4, diffusion_rate=0.1, decay_rate=0.05)
_agent_cfg = AgentConfig(hidden_dims=(64, 64))
_train_cfg = TrainConfig(
    seed=42,
    total_steps=1_600_000_000,  # 1.6B env steps = ~384K gradient updates = ~27 hours
    num_envs=32,
    num_steps=128,
    learning_rate=3e-4,
    resume_from=RESUME_FROM,
)
_log_cfg = LogConfig(
    wandb=False,
    save_interval=10_000_000,  # Save every 10M steps (~10 min, minimize data loss)
    checkpoint_dir=CHECKPOINT_DIR,
    server=False,
)
_analysis_cfg = AnalysisConfig(
    emergence_check_interval=10_000,
    specialization_check_interval=20_000,
)
_evolution_cfg = EvolutionConfig(
    enabled=True,
    starting_energy=200,
    food_energy=100,
    reproduce_threshold=120,
    reproduce_cost=50,
    mutation_std=0.01,
    max_agents=32,
)
_specialization_cfg = SpecializationConfig(diversity_bonus=0.1, niche_pressure=0.05)

config = Config(
    env=_env_cfg,
    field=_field_cfg,
    agent=_agent_cfg,
    train=_train_cfg,
    log=_log_cfg,
    analysis=_analysis_cfg,
    evolution=_evolution_cfg,
    specialization=_specialization_cfg,
)

# --- Run summary ---
_rule = "=" * 60
_summary = [
    _rule,
    "MAXIMUM TRAINING RUN — Phase 5 Data Collection",
    _rule,
    f"  Grid: {config.env.grid_size}x{config.env.grid_size}",
    f"  Agents: {config.env.num_agents} (max: {config.evolution.max_agents})",
    f"  Food: {config.env.num_food}",
    f"  Total steps: {config.train.total_steps:,} (~384K gradient updates)",
    f"  Save interval: every {config.log.save_interval:,} steps (~10 min)",
    f"  Checkpoint dir: {config.log.checkpoint_dir}",
    f"  Evolution: enabled (mutation_std={config.evolution.mutation_std})",
    f"  Diversity bonus: {config.specialization.diversity_bonus}",
    f"  Niche pressure: {config.specialization.niche_pressure}",
    "  Estimated time: ~27 hours across 3 sessions",
    f"  Resume from: {config.train.resume_from or 'Auto-detect in Cell 4'}",
    _rule,
]
for _line in _summary:
    print(_line)

# Drop the cell-local temporaries so only the public names remain.
del _env_cfg, _field_cfg, _agent_cfg, _train_cfg, _log_cfg
del _analysis_cfg, _evolution_cfg, _specialization_cfg
del _rule, _summary, _line

## Cell 4: Resume or Start Fresh

In [None]:
# Cell 4: Resume-or-start logic
#
# Automatically detects existing checkpoints and resumes if available.
# If no checkpoint exists, starts fresh training.
#
# Resolution order for `config.train.resume_from`:
#   1. An explicit path set in Cell 3 (RESUME_FROM) — kept only if it exists,
#      otherwise cleared so training starts fresh.
#   2. CHECKPOINT_DIR/latest.pkl, if present.
#   3. None (fresh start).

import os

# Check for existing checkpoints
latest_checkpoint = os.path.join(CHECKPOINT_DIR, "latest.pkl")

# Create the checkpoint directory up front so it exists on every branch below,
# including "explicit resume path" and "explicit path missing, starting fresh".
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

if config.train.resume_from is not None:
    # Explicit resume path set in Cell 3
    if os.path.exists(config.train.resume_from):
        print(f"Will resume from explicit path: {config.train.resume_from}")
    else:
        print(f"WARNING: Resume path not found: {config.train.resume_from}")
        print("Will start fresh training.")
        config.train.resume_from = None
elif os.path.exists(latest_checkpoint):
    # Auto-detect latest checkpoint
    config.train.resume_from = latest_checkpoint
    print(f"Found existing checkpoint: {latest_checkpoint}")
    print("Will resume from latest checkpoint.")

    # Show checkpoint info
    from src.training.checkpointing import load_checkpoint
    ckpt = load_checkpoint(latest_checkpoint)
    ckpt_step = ckpt.get("step")
    # The checkpoint is loaded only to read its step; release it so it is not
    # held in the notebook namespace for the entire training run (it presumably
    # contains full model/optimizer state).
    del ckpt

    if ckpt_step is None:
        # Without a step we cannot compute progress; don't pretend it is 0.
        print("  Checkpoint step: unknown")
        print("  Remaining steps: unknown")
    else:
        # Coerce to a plain int: the stored step may be a numpy/JAX scalar.
        ckpt_step = int(ckpt_step)
        print(f"  Checkpoint step: {ckpt_step}")
        remaining = config.train.total_steps - ckpt_step
        print(f"  Remaining steps: {remaining:,}")
        if remaining <= 0:
            print("  Training already complete! Increase total_steps to continue.")
else:
    print("No existing checkpoint found. Starting fresh training.")
    print(f"Checkpoint directory created: {CHECKPOINT_DIR}")

print(f"\nReady to train for {config.train.total_steps:,} total steps.")

## Cell 5: Run Training

In [None]:
# Cell 5: Run training with progress display
#
# This cell runs the full training loop. Checkpoints are saved
# every save_interval steps and on completion.
#
# If the Kaggle session times out, re-run the notebook.
# Cell 4 will automatically detect the latest checkpoint and resume.

import glob
import re
import time

from src.training.train import train


def _checkpoint_sort_key(path):
    """Order ``step_<N>.pkl`` paths by their numeric step N.

    A plain lexicographic sort would place e.g. ``step_100.pkl`` before
    ``step_20.pkl`` unless step numbers are zero-padded. Names that do not
    parse as ``step_<digits>.pkl`` sort after all numeric ones, by path.
    """
    match = re.search(r"step_(\d+)\.pkl$", os.path.basename(path))
    if match:
        return (0, int(match.group(1)), path)
    return (1, 0, path)


print("Starting training...")
print(f"Checkpoints will be saved to: {config.log.checkpoint_dir}")
print(f"Save interval: every {config.log.save_interval:,} steps")
print()

start_time = time.time()
try:
    final_state = train(config)
    elapsed = time.time() - start_time
    print(f"\nTraining complete! Elapsed: {elapsed / 3600:.1f} hours")
except KeyboardInterrupt:
    elapsed = time.time() - start_time
    print(f"\nTraining interrupted after {elapsed / 3600:.1f} hours")
    # NOTE(review): assumes train() writes a checkpoint when interrupted —
    # confirm in src.training.train.
    print("Emergency checkpoint should have been saved.")
except Exception as e:
    elapsed = time.time() - start_time
    print(f"\nTraining failed after {elapsed / 3600:.1f} hours: {e}")
    print("Check the latest checkpoint in the checkpoint directory.")
    raise

# List saved checkpoints in training order (numeric step, not string order).
checkpoints = sorted(
    glob.glob(os.path.join(CHECKPOINT_DIR, "step_*.pkl")),
    key=_checkpoint_sort_key,
)
print(f"\nSaved checkpoints ({len(checkpoints)}):")
for cp in checkpoints:
    size_mb = os.path.getsize(cp) / (1024 * 1024)
    print(f"  {os.path.basename(cp)} ({size_mb:.1f} MB)")

print(f"\nDownload checkpoints from: {CHECKPOINT_DIR}")
print("Use Kaggle Output tab or kaggle_download.sh to retrieve them.")