# Emergence Lab - Autonomous Field OFF Training

This notebook runs **ALL 30 Field OFF seeds automatically** in a single execution.

**No manual intervention needed** - just run all cells and let it complete.

## Configuration

- **FIELD_ENABLED = False** (hardcoded for the ablation study; the field itself is disabled via `decay_rate=1.0`, `diffusion_rate=0.0`, `write_strength=0.0`)
- **10 batches x 3 seeds = 30 seeds total**
- **PROVEN 64-agent config** (grid=32, food=40, max_agents=64)

## Setup Instructions

1. Open this notebook in Google Colab
2. Runtime > Change runtime type > **TPU v5e** (or v6e) + **High-RAM**
3. Run all cells (Ctrl+F9 or Runtime > Run all)
4. Come back when complete (~10+ hours for 30 seeds)

## Cell 1: Setup

Mount Google Drive, clone the repo, and install dependencies.

In [None]:
# Checkpoints are written to Google Drive, so mount it first.
import os
from google.colab import drive

drive.mount('/content/drive')

# Root folder for every run's checkpoints; created on first use, no-op afterwards.
CHECKPOINT_BASE = '/content/drive/MyDrive/emergence-lab'
os.makedirs(CHECKPOINT_BASE, exist_ok=True)
print(f"Checkpoint base: {CHECKPOINT_BASE}")

In [None]:
# Clone the repo if not already present; otherwise pull the latest changes.
# NOTE: the `!` lines are IPython shell escapes and `{VAR}` is interpolated from
# the Python namespace, so this cell only runs inside Jupyter/Colab.
# Relies on `os` having been imported by the Drive-mount cell above.
REPO_DIR = '/content/emergence-lab'

# CHANGE THIS to your GitHub username
GITHUB_USERNAME = "imashishkh"

if not os.path.exists(REPO_DIR):
    !git clone https://github.com/{GITHUB_USERNAME}/emergence-lab.git {REPO_DIR}
else:
    print(f"Repo already exists at {REPO_DIR}")
    # Pull latest changes
    !cd {REPO_DIR} && git pull

# Change to repo directory so the `src.*` imports in later cells resolve.
# NOTE(review): a failed clone is expected to surface here, because `!` does not
# itself raise on a non-zero exit status -- confirm.
os.chdir(REPO_DIR)
print(f"Working directory: {os.getcwd()}")

In [None]:
# Install dependencies
!pip install -e ".[dev]" -q

# Verify JAX sees TPU
import jax
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"Device count: {jax.device_count()}")

# Check if TPU is available
if 'tpu' in str(jax.devices()[0]).lower():
    print("TPU detected!")
else:
    print("WARNING: TPU not detected. Training will be slower on GPU/CPU.")

## Cell 2: Autonomous Field OFF Training

This section first runs a short verification cell (2 test seeds). The main cell then runs ALL 10 batches (30 seeds) automatically.

**FIELD_ENABLED = False** is hardcoded - this is the ablation condition.

In [None]:
# =============================================================================
# TEST RUN - Verify parallel training works before full experiment
# =============================================================================
# Runs TEST_SEEDS seeds for a short training pass to catch errors early.
# TEST_STEPS is rounded down to whole iterations (minimum 1). With the config
# below one iteration is already 32 * 128 * 64 = 262,144 steps, so the test
# runs a single iteration; the actual step count is printed below.
# If this succeeds, the full 30-seed run should work.
# =============================================================================

print("="*70)
print("TEST MODE: Running quick verification (Field OFF)...")
print("="*70)

import time
from src.configs import Config
from src.training.parallel_train import ParallelTrainer

# Minimal test config - MATCHES PROVEN 64-AGENT CONFIG.
# Uses the same env/evolution values as build_config() in the main training
# cell; only the amount of training differs.
test_config = Config()
test_config.env.grid_size = 32
test_config.env.num_agents = 16
test_config.env.num_food = 40
test_config.evolution.enabled = True
test_config.evolution.max_agents = 64
test_config.evolution.starting_energy = 200
test_config.evolution.food_energy = 100
test_config.evolution.energy_per_step = 1
test_config.evolution.reproduce_threshold = 120
test_config.evolution.reproduce_cost = 40
test_config.evolution.mutation_std = 0.01
test_config.train.num_envs = 32
test_config.train.num_steps = 128
test_config.log.wandb = False

# FIELD OFF settings for test
test_config.field.decay_rate = 1.0
test_config.field.diffusion_rate = 0.0
test_config.field.write_strength = 0.0

# Test parameters
TEST_STEPS = 100_000
TEST_SEEDS = 2
# Seed IDs start at 100 so they never collide with the real seeds 0-29.
TEST_SEED_IDS = list(range(100, 100 + TEST_SEEDS))
TEST_CHECKPOINT_DIR = f'{CHECKPOINT_BASE}/test_field_off'

# Derive the iteration size from the config instead of hard-coding it, so it
# stays in sync if any of the values above change.
steps_per_iter = (
    test_config.train.num_envs
    * test_config.train.num_steps
    * test_config.evolution.max_agents
)
test_iterations = max(1, TEST_STEPS // steps_per_iter)

print(f"Test config (Field OFF):")
print(f"  Seeds: {TEST_SEEDS}")
print(f"  Steps: {TEST_STEPS:,} requested, {test_iterations * steps_per_iter:,} actual")
print(f"  Iterations: {test_iterations}")
print(f"  Field: DISABLED (decay=1.0, diffusion=0.0, write=0.0)")
print()

try:
    test_trainer = ParallelTrainer(
        config=test_config,
        num_seeds=TEST_SEEDS,
        seed_ids=TEST_SEED_IDS,
        checkpoint_dir=TEST_CHECKPOINT_DIR,
        master_seed=9999,
    )

    t0 = time.time()
    test_metrics = test_trainer.train(
        num_iterations=test_iterations,
        checkpoint_interval_minutes=60,
        resume=False,
        print_interval=5,
    )
    elapsed = time.time() - t0

    print()
    print("="*70)
    print("TEST PASSED!")
    print("="*70)
    print(f"Time: {elapsed:.1f}s")
    print(f"Final rewards: {test_metrics.get('mean_reward', 'N/A')}")
    print(f"Final alive: {test_metrics.get('alive_count', 'N/A')}")
    print()
    print("Proceed to run the main autonomous training cell below.")
    print("="*70)

except Exception as e:
    print()
    print("="*70)
    print("TEST FAILED!")
    print("="*70)
    print(f"Error: {e}")
    print()
    print("DO NOT proceed with full training until this is fixed.")
    # Bare `raise` re-raises the active exception with its original traceback.
    raise

In [None]:
# =============================================================================
# AUTONOMOUS FIELD OFF TRAINING - ALL 10 BATCHES
# =============================================================================
# This cell runs ALL 30 Field OFF seeds automatically.
# No manual intervention needed - just run and wait.
#
# FIELD_ENABLED = False (hardcoded for ablation study)
# =============================================================================

import os
import time
import pickle
import numpy as np
from datetime import datetime, timedelta

from src.configs import Config
from src.training.parallel_train import ParallelTrainer

# =============================================================================
# CONFIGURATION (HARDCODED FOR FIELD OFF ABLATION)
# =============================================================================

# FIELD ABLATION - THIS IS THE KEY SETTING.
# HARDCODED: Field OFF for ablation study. This flag is printed and saved in the
# results summary; the field itself is disabled by the field.* overrides in
# build_config().
FIELD_ENABLED = False

# Training parameters
TOTAL_STEPS = 10_000_000   # 10M steps per seed
NUM_ENVS = 32
NUM_STEPS = 128
SEEDS_PER_BATCH = 3        # 3 seeds per batch for 64-agent config
TOTAL_BATCHES = 10         # 10 batches = 30 seeds total

# Checkpoint settings
CHECKPOINT_DIR_BASE = f'{CHECKPOINT_BASE}/field_off'
CHECKPOINT_INTERVAL_MINUTES = 30
RESUME = True              # Resume from existing checkpoints if available

# =============================================================================
# BUILD CONFIG (PROVEN 64-AGENT CONFIG)
# =============================================================================

def build_config():
    """Return a fresh Config for the PROVEN 64-agent setup with the field OFF.

    The training section reads the module-level TOTAL_STEPS, NUM_ENVS and
    NUM_STEPS constants; every other value is fixed here.
    """
    cfg = Config()
    env, evo, train, field, log = (
        cfg.env, cfg.evolution, cfg.train, cfg.field, cfg.log
    )

    # Environment: PROVEN 64-AGENT CONFIG
    env.grid_size = 32    # Larger grid for 64 agents
    env.num_agents = 16   # Starting population
    env.num_food = 40     # More food for 64 agents

    # Evolution
    evo.enabled = True
    evo.max_agents = 64           # PROVEN: 64 achieved UTOPIA
    evo.starting_energy = 200
    evo.food_energy = 100
    evo.energy_per_step = 1
    evo.reproduce_threshold = 120
    evo.reproduce_cost = 40       # PROVEN: 40 not 50!
    evo.mutation_std = 0.01

    # Training schedule
    train.total_steps = TOTAL_STEPS
    train.num_envs = NUM_ENVS
    train.num_steps = NUM_STEPS
    train.seed = 42

    # FIELD OFF ABLATION - zero out the field
    field.decay_rate = 1.0        # Field decays completely each step
    field.diffusion_rate = 0.0    # No diffusion
    field.write_strength = 0.0    # Agents can't write to field

    # Logging
    log.wandb = False
    log.save_interval = 0

    return cfg

# =============================================================================
# RUN ALL BATCHES
# =============================================================================

print("="*70)
print("AUTONOMOUS FIELD OFF TRAINING")
print("="*70)
print(f"Field Enabled: {FIELD_ENABLED} (ABLATION STUDY)")
print(f"Total batches: {TOTAL_BATCHES}")
print(f"Seeds per batch: {SEEDS_PER_BATCH}")
print(f"Total seeds: {TOTAL_BATCHES * SEEDS_PER_BATCH}")
print(f"Steps per seed: {TOTAL_STEPS:,}")
print(f"Checkpoint base: {CHECKPOINT_DIR_BASE}")
print(f"Resume: {RESUME}")
print("="*70)
print()

# Build config once
config = build_config()

# Calculate iterations
steps_per_iter = NUM_ENVS * NUM_STEPS * config.evolution.max_agents
num_iterations = TOTAL_STEPS // steps_per_iter

print(f"Steps per iteration: {steps_per_iter:,}")
print(f"Iterations per seed: {num_iterations:,}")
print()

# Track results across all batches: exactly one record per batch, success or not.
import traceback
from collections.abc import Mapping

all_results = []
start_time = time.time()

# Run batches sequentially. A failing batch is recorded and skipped so the
# remaining batches still run.
for batch_number in range(TOTAL_BATCHES):
    batch_start = time.time()

    # Seed IDs are contiguous: batch b owns [b*SEEDS_PER_BATCH, (b+1)*SEEDS_PER_BATCH)
    seed_ids = list(range(
        batch_number * SEEDS_PER_BATCH,
        (batch_number + 1) * SEEDS_PER_BATCH
    ))

    # Checkpoint directory for this batch
    checkpoint_dir = f'{CHECKPOINT_DIR_BASE}/batch_{batch_number}'

    print("\n" + "="*70)
    print(f"BATCH {batch_number + 1}/{TOTAL_BATCHES}")
    print(f"Seeds: {seed_ids}")
    print(f"Checkpoint dir: {checkpoint_dir}")
    print("="*70)

    # Only trainer construction and training are guarded. Previously the summary
    # printing lived in the same `try`, so an error there appended a second,
    # failed record for a batch that had already been recorded as successful.
    try:
        trainer = ParallelTrainer(
            config=config,
            num_seeds=SEEDS_PER_BATCH,
            seed_ids=seed_ids,
            checkpoint_dir=checkpoint_dir,
            master_seed=42 + batch_number * 1000,
        )

        metrics = trainer.train(
            num_iterations=num_iterations,
            checkpoint_interval_minutes=CHECKPOINT_INTERVAL_MINUTES,
            resume=RESUME,
            print_interval=100,  # Print every 100 iterations
        )
    except Exception as e:
        print(f"\nERROR in batch {batch_number}: {e}")
        all_results.append({
            'batch': batch_number,
            'seed_ids': seed_ids,
            'error': str(e),
            'traceback': traceback.format_exc(),  # keep full context for post-mortem
            'elapsed_seconds': time.time() - batch_start,
            'success': False,
        })
        # Continue to next batch instead of stopping
        continue

    batch_elapsed = time.time() - batch_start

    all_results.append({
        'batch': batch_number,
        'seed_ids': seed_ids,
        'metrics': metrics,
        'elapsed_seconds': batch_elapsed,
        'success': True,
    })

    # Print batch summary. Only look up keys when metrics is a mapping, so an
    # unexpected return type cannot abort the whole run.
    print(f"\nBatch {batch_number} complete!")
    print(f"  Time: {batch_elapsed/3600:.1f} hours")
    if isinstance(metrics, Mapping):
        if 'mean_reward' in metrics:
            print(f"  Final rewards: {metrics['mean_reward']}")
        if 'alive_count' in metrics:
            print(f"  Final alive: {metrics['alive_count']}")

    # Progress estimate (averages over all batches so far, including failed ones)
    total_elapsed = time.time() - start_time
    batches_done = batch_number + 1
    batches_remaining = TOTAL_BATCHES - batches_done
    avg_time_per_batch = total_elapsed / batches_done
    estimated_remaining = avg_time_per_batch * batches_remaining

    print(f"\nProgress: {batches_done}/{TOTAL_BATCHES} batches")
    print(f"Elapsed: {total_elapsed/3600:.1f} hours")
    print(f"Estimated remaining: {estimated_remaining/3600:.1f} hours")
    eta = datetime.now() + timedelta(seconds=estimated_remaining)
    print(f"ETA: {eta.strftime('%Y-%m-%d %H:%M')}")

# =============================================================================
# FINAL SUMMARY
# =============================================================================

total_time = time.time() - start_time
successful_batches = sum(1 for r in all_results if r['success'])

print("\n" + "="*70)
print("ALL BATCHES COMPLETE!")
print("="*70)
print(f"Total time: {total_time/3600:.1f} hours")
print(f"Successful batches: {successful_batches}/{TOTAL_BATCHES}")
if successful_batches < TOTAL_BATCHES:
    failed_batches = [r['batch'] for r in all_results if not r['success']]
    print(f"WARNING: failed batches: {failed_batches} (see 'error' in the saved summary)")
print()

# The base directory may not exist yet if no batch created its checkpoint
# directory (e.g. every batch failed early). Create it so the summary is not lost.
os.makedirs(CHECKPOINT_DIR_BASE, exist_ok=True)

# Save results summary
results_path = f'{CHECKPOINT_DIR_BASE}/training_summary.pkl'
with open(results_path, 'wb') as f:
    pickle.dump({
        'all_results': all_results,
        'total_time_seconds': total_time,
        'field_enabled': FIELD_ENABLED,
        'total_steps': TOTAL_STEPS,
        'config': {
            'grid_size': config.env.grid_size,
            'num_food': config.env.num_food,
            'max_agents': config.evolution.max_agents,
            'reproduce_threshold': config.evolution.reproduce_threshold,
            'reproduce_cost': config.evolution.reproduce_cost,
        },
    }, f)
print(f"Results saved to: {results_path}")

## Cell 3: Final Summary

Print a per-seed status report and aggregate statistics for all 30 seeds, then compare against Field ON results if they are available.

In [None]:
# =============================================================================
# FINAL SUMMARY OF ALL 30 FIELD OFF SEEDS
# =============================================================================

import glob
import pickle
import numpy as np

print("="*70)
print("FIELD OFF TRAINING - FINAL SUMMARY")
print("="*70)
print()

# Collect results from all seeds
all_rewards = []
all_alive = []
seed_status = []

for batch_number in range(10):
    batch_dir = f'{CHECKPOINT_DIR_BASE}/batch_{batch_number}'
    seed_ids = list(range(batch_number * 3, (batch_number + 1) * 3))
    
    print(f"Batch {batch_number} (seeds {seed_ids[0]}-{seed_ids[-1]}):")
    
    for seed_id in seed_ids:
        seed_dir = os.path.join(batch_dir, f'seed_{seed_id}')
        latest_path = os.path.join(seed_dir, 'latest.pkl')
        
        if os.path.exists(latest_path):
            try:
                with open(latest_path, 'rb') as f:
                    data = pickle.load(f)
                
                step = data.get('step', 0)
                progress = (step / TOTAL_STEPS) * 100
                
                # Get final metrics if available
                reward = data.get('mean_reward', 'N/A')
                alive = data.get('alive_count', 'N/A')
                
                status = 'COMPLETE' if progress >= 99 else f'{progress:.0f}%'
                print(f"  Seed {seed_id}: {status}, step {step:,}")
                
                if progress >= 99:
                    if isinstance(reward, (int, float)):
                        all_rewards.append(reward)
                    if isinstance(alive, (int, float)):
                        all_alive.append(alive)
                    seed_status.append((seed_id, 'complete'))
                else:
                    seed_status.append((seed_id, 'partial'))
                    
            except Exception as e:
                print(f"  Seed {seed_id}: ERROR - {e}")
                seed_status.append((seed_id, 'error'))
        else:
            print(f"  Seed {seed_id}: NOT FOUND")
            seed_status.append((seed_id, 'missing'))
    print()

# Print aggregate statistics
print("="*70)
print("AGGREGATE STATISTICS")
print("="*70)

from collections import Counter

# The collection loop records exactly one status per scanned seed, so the
# denominator follows the actual seed count instead of a hard-coded 30.
status_counts = Counter(status for _, status in seed_status)
num_seeds = len(seed_status)

complete_count = status_counts['complete']
partial_count = status_counts['partial']
missing_count = status_counts['missing']
error_count = status_counts['error']

print(f"Complete: {complete_count}/{num_seeds} seeds")
print(f"Partial: {partial_count}/{num_seeds} seeds")
print(f"Missing: {missing_count}/{num_seeds} seeds")
print(f"Errors: {error_count}/{num_seeds} seeds")
print()

if all_rewards:
    print(f"Mean reward (complete seeds): {np.mean(all_rewards):.4f} +/- {np.std(all_rewards):.4f}")
    print(f"  Min: {np.min(all_rewards):.4f}")
    print(f"  Max: {np.max(all_rewards):.4f}")

if all_alive:
    print(f"Mean alive count: {np.mean(all_alive):.1f} +/- {np.std(all_alive):.1f}")
    print(f"  Min: {np.min(all_alive):.0f}")
    print(f"  Max: {np.max(all_alive):.0f}")

# Closing instructions: what to do once both ablation arms have finished.
print("\n".join([
    "",
    "="*70,
    "NEXT STEPS",
    "="*70,
    "1. Ensure Field ON training is also complete (30 seeds)",
    "2. Download results to local machine",
    "3. Run analysis:",
    "   python scripts/generate_multi_seed_report.py \\",
    "       --checkpoint-dir /path/to/field_on \\",
    "       --compare-dir /path/to/field_off",
]))

In [None]:
# Quick comparison with Field ON (if available)
print("="*70)
print("COMPARISON WITH FIELD ON (if available)")
print("="*70)

import os
import pickle
import numpy as np

field_on_dir = f'{CHECKPOINT_BASE}/field_on'

# NOTE(review): assumes the Field ON run used the same 10 batches x 3 seeds
# layout and directory naming as this notebook -- confirm against
# colab_parallel_training.ipynb.
FIELD_ON_BATCHES = 10
FIELD_ON_SEEDS_PER_BATCH = 3
field_on_total = FIELD_ON_BATCHES * FIELD_ON_SEEDS_PER_BATCH
# Field OFF totals come from the summary cell above (one status per scanned seed).
field_off_total = len(seed_status)

if os.path.exists(field_on_dir):
    field_on_rewards = []
    field_on_alive = []
    field_on_complete = 0

    for batch_number in range(FIELD_ON_BATCHES):
        batch_dir = f'{field_on_dir}/batch_{batch_number}'
        first_seed = batch_number * FIELD_ON_SEEDS_PER_BATCH
        for seed_id in range(first_seed, first_seed + FIELD_ON_SEEDS_PER_BATCH):
            latest_path = os.path.join(batch_dir, f'seed_{seed_id}', 'latest.pkl')
            if not os.path.exists(latest_path):
                continue
            try:
                with open(latest_path, 'rb') as f:
                    data = pickle.load(f)
                step = data.get('step', 0)
                # Same 99% completeness threshold as the Field OFF summary.
                if step >= TOTAL_STEPS * 0.99:
                    field_on_complete += 1
                    if 'mean_reward' in data:
                        field_on_rewards.append(data['mean_reward'])
                    if 'alive_count' in data:
                        field_on_alive.append(data['alive_count'])
            except Exception as e:
                # Report unreadable checkpoints instead of silently skipping them.
                print(f"WARNING: could not read {latest_path}: {e}")

    print(f"Field ON complete seeds: {field_on_complete}/{field_on_total}")
    print(f"Field OFF complete seeds: {complete_count}/{field_off_total}")
    print()

    if field_on_rewards and all_rewards:
        print("REWARD COMPARISON:")
        print(f"  Field ON:  {np.mean(field_on_rewards):.4f} +/- {np.std(field_on_rewards):.4f}")
        print(f"  Field OFF: {np.mean(all_rewards):.4f} +/- {np.std(all_rewards):.4f}")
        diff = np.mean(field_on_rewards) - np.mean(all_rewards)
        # A tie was previously reported as "worse".
        if diff > 0:
            verdict = 'Field ON better'
        elif diff < 0:
            verdict = 'Field ON worse'
        else:
            verdict = 'no difference'
        print(f"  Difference: {diff:+.4f} ({verdict})")

    if field_on_alive and all_alive:
        print("\nALIVE COUNT COMPARISON:")
        print(f"  Field ON:  {np.mean(field_on_alive):.1f} +/- {np.std(field_on_alive):.1f}")
        print(f"  Field OFF: {np.mean(all_alive):.1f} +/- {np.std(all_alive):.1f}")
else:
    print("Field ON results not found.")
    print(f"Expected location: {field_on_dir}")
    print("Run Field ON training using colab_parallel_training.ipynb")