# Hidden Food Specialization Training

Train 30 seeds x 2 conditions to test specialization with hidden food.
Config values from 34-config sweep (Run #007): pure FREEZE_EVOLVE won — diversity rewards didn't help.

## Conditions
| Condition | Description |
|-----------|-------------|
| **freeze_evolve** | Alternating GRADIENT/EVOLVE phases — sweep winner (D1_fe_ratio_50_50) |
| **baseline** | Standard GRADIENT mode — control (no specialization pressure) |

## Config
| Parameter | Value |
|-----------|-------|
| Grid | 32x32 |
| Starting agents | 16 |
| Max agents | 64 |
| Regular food | 40 |
| Hidden food items | 3 |
| Required agents to reveal | 3 |
| Reveal distance | 3 (Chebyshev) |
| Hidden food value | 5x (500 energy) |
| Steps per seed | 10M |
| Seeds per condition | 30 (10 batches x 3) |

## Runtime
- TPU v6e + High-RAM
- Resume-safe: re-run after a disconnect and full training picks up from the latest checkpoints (the short test run always starts fresh)
- Run all cells (Ctrl+F9), come back when done

## Setup

Mount Google Drive, clone repo, install dependencies.

In [None]:
# Mount Google Drive; checkpoints are written under it (see CHECKPOINT_BASE).
from google.colab import drive
drive.mount('/content/drive')

import os
# Root directory on Drive that all checkpoint directories are created under.
CHECKPOINT_BASE = '/content/drive/MyDrive/emergence-lab'
os.makedirs(CHECKPOINT_BASE, exist_ok=True)

# Local checkout location and the GitHub account the repo is cloned from.
REPO_DIR = '/content/emergence-lab'
GITHUB_USERNAME = "imashishkh21"

# Clone when the checkout does not exist yet; otherwise update it in place.
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/{GITHUB_USERNAME}/emergence-lab.git {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull

# Work from inside the repo so the editable install and the later
# `from src...` imports resolve against this checkout.
os.chdir(REPO_DIR)
!pip install -e ".[dev]" -q
!pip install rliable -q

# Report the JAX version and the accelerator devices visible to this runtime.
import jax
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"Device count: {jax.device_count()}")

In [None]:
# =============================================================================
# TEST RUN - Verify both conditions before full training
# =============================================================================
import copy
import gc
import time
import jax
import jax.numpy as jnp
from src.configs import Config, TrainingMode
from src.training.parallel_train import ParallelTrainer

def make_test_config():
    """Return a fresh hidden-food Config (16 starting agents, up to 64) for the test runs.

    The training mode and the specialization knobs are left for the caller
    to set per condition.
    """
    config = Config()

    # World layout: 32x32 grid, 16 starting agents, 40 regular food items.
    world = config.env
    world.grid_size = 32
    world.num_agents = 16
    world.num_food = 40

    # Hidden food: 3 items, each needing 3 agents within distance 3 to reveal.
    hidden = world.hidden_food
    hidden.enabled = True
    hidden.num_hidden = 3
    hidden.required_agents = 3
    hidden.reveal_distance = 3

    # Evolution: population may grow up to 64 agents.
    evo = config.evolution
    evo.enabled = True
    evo.max_agents = 64
    evo.starting_energy = 200
    evo.food_energy = 100
    evo.reproduce_threshold = 120
    evo.reproduce_cost = 40

    # Rollout size per iteration; W&B logging stays off for test runs.
    rollout = config.train
    rollout.num_envs = 32
    rollout.num_steps = 128
    config.log.wandb = False

    return config

# Agent-steps consumed per training iteration, derived from the test config so
# it cannot drift from make_test_config():
# num_envs * num_steps * max_agents = 32 * 128 * 64 = 262,144.
_size_cfg = make_test_config()
steps_per_iter = (
    _size_cfg.train.num_envs
    * _size_cfg.train.num_steps
    * _size_cfg.evolution.max_agents
)
# At least 4 iterations so FREEZE_EVOLVE (1 iter gradient + 1 iter evolve)
# goes through its phase transition; with 262,144 steps/iter the floor of 4
# applies, i.e. ~1.05M agent-steps rather than the nominal 500K.
test_iters = max(4, 500_000 // steps_per_iter)
test_steps = test_iters * steps_per_iter


def _run_smoke_test(cfg, seed_ids, checkpoint_name, master_seed):
    """Train `cfg` for `test_iters` iterations and release the trainer afterwards.

    Args:
        cfg: Config for the condition under test.
        seed_ids: Seed ids to train in parallel (one trainer seed per id).
        checkpoint_name: Sub-directory of CHECKPOINT_BASE for test checkpoints.
        master_seed: Master seed passed to ParallelTrainer.

    Returns:
        (metrics, elapsed_seconds) where metrics is whatever
        ParallelTrainer.train() returns and elapsed covers train() only.
    """
    trainer = ParallelTrainer(
        config=cfg, num_seeds=len(seed_ids), seed_ids=seed_ids,
        checkpoint_dir=f'{CHECKPOINT_BASE}/{checkpoint_name}', master_seed=master_seed,
    )
    try:
        started = time.time()
        metrics = trainer.train(
            num_iterations=test_iters, checkpoint_interval_minutes=60, resume=False,
        )
        elapsed = time.time() - started
    finally:
        # Release this trainer before the next one is built, so the two test
        # trainers are never alive at the same time.
        del trainer
        gc.collect()
        jax.clear_caches()
    return metrics, elapsed


# --- Test 1: Freeze-Evolve (sweep winner) ---
print(f"[1/2] Testing FREEZE_EVOLVE (2 seeds, {test_iters} iters = {test_steps:,} steps)...")
cfg1 = make_test_config()
cfg1.train.training_mode = TrainingMode.FREEZE_EVOLVE
cfg1.specialization.diversity_bonus = 0.0
cfg1.specialization.niche_pressure = 0.0
cfg1.freeze_evolve.gradient_steps = steps_per_iter     # 1 iter gradient
cfg1.freeze_evolve.evolve_steps = steps_per_iter       # 1 iter evolve
cfg1.freeze_evolve.evolve_mutation_boost = 5.0

m1, _elapsed1 = _run_smoke_test(cfg1, [200, 201], 'test_spec_freeze_evolve', 8888)
print(f"   PASS: Freeze-evolve test done in {_elapsed1:.1f}s")
print(f"   Reward: {m1.get('mean_reward', 'N/A')}")

# --- Test 2: Baseline (GRADIENT control) ---
print(f"\n[2/2] Testing BASELINE (2 seeds, {test_iters} iters = {test_steps:,} steps)...")
cfg2 = make_test_config()
cfg2.train.training_mode = TrainingMode.GRADIENT
cfg2.specialization.diversity_bonus = 0.0
cfg2.specialization.niche_pressure = 0.0

m2, _elapsed2 = _run_smoke_test(cfg2, [100, 101], 'test_spec_baseline', 9999)
print(f"   PASS: Baseline test done in {_elapsed2:.1f}s")
print(f"   Reward: {m2.get('mean_reward', 'N/A')}")

# Reaching this point means neither test raised.
print("\n" + "="*70)
print("BOTH TESTS PASSED! Proceed to full training.")
print("="*70)

## Configuration

Build configs for both conditions with hidden food enabled.
Values from 34-config sweep (Run #007).

In [None]:
import copy
import os
import time
import pickle
import numpy as np
from datetime import datetime, timedelta
from src.configs import Config, TrainingMode
from src.training.parallel_train import ParallelTrainer

# --- Run size (agent-steps per seed, and the per-iteration rollout shape) ---
TOTAL_STEPS = 10_000_000
NUM_ENVS = 32
NUM_STEPS = 128

# --- Seed batching: 10 batches x 3 seeds = 30 seeds per condition ---
SEEDS_PER_BATCH = 3
TOTAL_BATCHES = 10

# --- Checkpointing ---
CHECKPOINT_INTERVAL_MINUTES = 30
RESUME = True

# --- Conditions, each with its own checkpoint root on Drive ---
CONDITIONS = ['freeze_evolve', 'baseline']
CHECKPOINT_DIRS = {
    name: f'{CHECKPOINT_BASE}/hidden_food_specialization_{name}'
    for name in CONDITIONS
}

# Iteration arithmetic (all quantities are agent-steps, like TOTAL_STEPS):
#   steps_per_iter = num_envs * num_steps * max_agents = 32 * 128 * 64 = 262,144
#   total_iters    = 10M // 262,144 = 38
# Freeze-evolve uses gradient_steps = evolve_steps = 262,144, i.e. 1 iteration
# per phase, so a cycle is 2 iterations -> ~19 full cycles over 38 iterations.


def build_config(condition: str) -> Config:
    """Build the hidden-food Config (16 starting agents, up to 64) for one condition.

    Both conditions share the same environment, evolution and training
    settings and keep the diversity bonus and niche pressure at 0; they
    differ only in training mode.

    Args:
        condition: 'freeze_evolve' (alternating GRADIENT/EVOLVE phases of one
            iteration each, the sweep winner D1_fe_ratio_50_50) or 'baseline'
            (plain GRADIENT mode, the control).

    Returns:
        A new Config instance for the requested condition.

    Raises:
        ValueError: If `condition` is not one of the two known names. Without
            this check a typo would silently yield a config left at the
            default training mode.
    """
    if condition not in ('freeze_evolve', 'baseline'):
        raise ValueError(
            f"Unknown condition {condition!r}; expected 'freeze_evolve' or 'baseline'"
        )

    config = Config()

    # PROVEN 64-AGENT CONFIG
    config.env.grid_size = 32
    config.env.num_agents = 16
    config.env.num_food = 40
    config.evolution.enabled = True
    config.evolution.max_agents = 64
    config.evolution.starting_energy = 200
    config.evolution.food_energy = 100
    config.evolution.energy_per_step = 1
    config.evolution.reproduce_threshold = 120
    config.evolution.reproduce_cost = 40
    config.evolution.mutation_std = 0.01

    # HIDDEN FOOD - the coordination task
    config.env.hidden_food.enabled = True
    config.env.hidden_food.num_hidden = 3
    config.env.hidden_food.required_agents = 3
    config.env.hidden_food.reveal_distance = 3

    # Training
    config.train.total_steps = TOTAL_STEPS
    config.train.num_envs = NUM_ENVS
    config.train.num_steps = NUM_STEPS
    config.train.seed = 42
    config.log.wandb = False
    config.log.save_interval = 0

    # Neither condition uses diversity rewards or niche pressure.
    config.specialization.diversity_bonus = 0.0
    config.specialization.niche_pressure = 0.0

    if condition == 'freeze_evolve':
        # Sweep winner: D1_fe_ratio_50_50 -- one iteration per phase. The
        # phase length is derived from the rollout shape (32 * 128 * 64 =
        # 262,144 agent-steps today) so it stays at exactly one iteration if
        # NUM_ENVS, NUM_STEPS or max_agents change.
        steps_per_iter = NUM_ENVS * NUM_STEPS * config.evolution.max_agents
        config.train.training_mode = TrainingMode.FREEZE_EVOLVE
        config.freeze_evolve.gradient_steps = steps_per_iter
        config.freeze_evolve.evolve_steps = steps_per_iter
        config.freeze_evolve.evolve_mutation_boost = 5.0
    else:
        # Control: standard GRADIENT mode, no specialization pressure
        config.train.training_mode = TrainingMode.GRADIENT

    return config


# Print config summary.
# steps_per_iter is derived from the built config (not a literal 64) so it
# matches the value the training loop computes from config.evolution.max_agents.
_summary_cfg = build_config(CONDITIONS[0])
steps_per_iter = NUM_ENVS * NUM_STEPS * _summary_cfg.evolution.max_agents
total_iters = TOTAL_STEPS // steps_per_iter
print(f"steps_per_iter={steps_per_iter:,}, total_iters={total_iters}\n")

for cond in CONDITIONS:
    cfg = build_config(cond)
    print(f"{cond}: mode={cfg.train.training_mode.value}, "
          f"div_bonus={cfg.specialization.diversity_bonus}, "
          f"niche={cfg.specialization.niche_pressure}, "
          f"hidden_food={cfg.env.hidden_food.enabled}")
    if cfg.train.training_mode == TrainingMode.FREEZE_EVOLVE:
        # Each phase is shown as at least one iteration, so a cycle is always
        # >= 2 iterations and the division below cannot be by zero.
        g_iters = max(1, cfg.freeze_evolve.gradient_steps // steps_per_iter)
        e_iters = max(1, cfg.freeze_evolve.evolve_steps // steps_per_iter)
        cycle = g_iters + e_iters
        n_cycles = total_iters // cycle
        print(f"  freeze_evolve: grad_steps={cfg.freeze_evolve.gradient_steps:,} ({g_iters} iters), "
              f"evolve_steps={cfg.freeze_evolve.evolve_steps:,} ({e_iters} iters), "
              f"boost={cfg.freeze_evolve.evolve_mutation_boost}")
        print(f"  cycle={cycle} iters, ~{n_cycles} full cycles over {total_iters} iters")

## Autonomous Training — Both Conditions

Runs all 10 batches for each condition sequentially.
Total: 60 seeds (30 per condition) at 10M steps each.

Resume-safe: re-run and it picks up from latest checkpoints.

In [None]:
import gc

# Run both conditions sequentially. For each condition, train TOTAL_BATCHES
# batches of SEEDS_PER_BATCH seeds; a failing batch is recorded and skipped so
# the remaining batches still run. A per-condition summary is pickled at the end.
for condition_name in CONDITIONS:
    config = build_config(condition_name)
    checkpoint_dir_base = CHECKPOINT_DIRS[condition_name]
    os.makedirs(checkpoint_dir_base, exist_ok=True)

    steps_per_iter = NUM_ENVS * NUM_STEPS * config.evolution.max_agents
    num_iterations = TOTAL_STEPS // steps_per_iter

    print("\n" + "="*70)
    print(f"CONDITION: {condition_name.upper()}")
    print(f"Mode: {config.train.training_mode.value}")
    print(f"Diversity bonus: {config.specialization.diversity_bonus}")
    print(f"Niche pressure: {config.specialization.niche_pressure}")
    print(f"Hidden food: enabled={config.env.hidden_food.enabled}")
    print(f"Batches: {TOTAL_BATCHES}, Seeds/batch: {SEEDS_PER_BATCH}")
    print(f"Iterations: {num_iterations}, Steps/iter: {steps_per_iter:,}")
    print("="*70)

    all_results = []
    cond_start = time.time()

    for batch_number in range(TOTAL_BATCHES):
        # Batch b owns the contiguous seed ids [b*SEEDS_PER_BATCH, (b+1)*SEEDS_PER_BATCH).
        seed_ids = list(range(batch_number * SEEDS_PER_BATCH, (batch_number + 1) * SEEDS_PER_BATCH))
        checkpoint_dir = f'{checkpoint_dir_base}/batch_{batch_number}'

        print(f"\n--- Batch {batch_number+1}/{TOTAL_BATCHES} | Seeds: {seed_ids} ---")

        # Bind the name up front: if the ParallelTrainer constructor itself
        # raises, `trainer` would otherwise be unbound during cleanup and the
        # resulting NameError would abort every remaining batch.
        trainer = None
        try:
            trainer = ParallelTrainer(
                config=config, num_seeds=SEEDS_PER_BATCH, seed_ids=seed_ids,
                checkpoint_dir=checkpoint_dir, master_seed=42 + batch_number * 1000,
            )
            metrics = trainer.train(
                num_iterations=num_iterations,
                checkpoint_interval_minutes=CHECKPOINT_INTERVAL_MINUTES,
                resume=RESUME, print_interval=100,
            )
            all_results.append({
                'batch': batch_number, 'seed_ids': seed_ids,
                'metrics': metrics, 'success': True,
            })
            if 'mean_reward' in metrics:
                print(f"  Rewards: {metrics['mean_reward']}")
        except Exception as e:
            # Deliberately broad: record the failure and move on to the next batch.
            print(f"  ERROR: {e}")
            import traceback
            traceback.print_exc()
            all_results.append({
                'batch': batch_number, 'seed_ids': seed_ids,
                'error': str(e), 'success': False,
            })
        finally:
            # Free trainer memory to avoid OOM across the TOTAL_BATCHES
            # trainers created per condition.
            trainer = None
            gc.collect()
            jax.clear_caches()

        # Progress estimate (failed batches count toward the per-batch average).
        batches_done = batch_number + 1
        total_elapsed = time.time() - cond_start
        avg_per_batch = total_elapsed / batches_done
        remaining = avg_per_batch * (TOTAL_BATCHES - batches_done)
        eta = datetime.now() + timedelta(seconds=remaining)
        print(f"  Progress: {batches_done}/{TOTAL_BATCHES} | "
              f"Elapsed: {total_elapsed/3600:.1f}h | ETA: {eta.strftime('%H:%M')}")

    # Save training summary for this condition
    cond_time = time.time() - cond_start
    summary_path = f'{checkpoint_dir_base}/training_summary.pkl'
    with open(summary_path, 'wb') as f:
        pickle.dump({
            'all_results': all_results,
            'total_time_seconds': cond_time,
            'condition': condition_name,
            # Taken from the config actually used rather than hard-coded.
            'hidden_food_enabled': config.env.hidden_food.enabled,
            'total_steps': TOTAL_STEPS,
            'config': {
                'grid_size': config.env.grid_size,
                'num_food': config.env.num_food,
                'max_agents': config.evolution.max_agents,
                'training_mode': config.train.training_mode.value,
                'diversity_bonus': config.specialization.diversity_bonus,
                'niche_pressure': config.specialization.niche_pressure,
                'hidden_food_num_hidden': config.env.hidden_food.num_hidden,
                'hidden_food_required_agents': config.env.hidden_food.required_agents,
            },
        }, f)
    print(f"\n{condition_name} complete! Time: {cond_time/3600:.1f}h")
    print(f"Summary saved: {summary_path}")

print("\n" + "="*70)
print("ALL TRAINING COMPLETE!")
print("="*70)

## Final Summary

Scan all checkpoints across both conditions.

In [None]:
import glob as glob_mod
import pickle

import re


def _checkpoint_sort_key(path):
    """Sort key that orders 'step_<N>.pkl' checkpoints by the integer N.

    Plain string sorting ranks 'step_9.pkl' after 'step_10.pkl' whenever the
    step number is not zero-padded, which would report a stale file as the
    latest. Names without a purely numeric step keep their old position:
    after all numbered files, ordered by name.
    """
    name = os.path.basename(path)
    numbered = re.fullmatch(r'step_(\d+)\.pkl', name)
    if numbered:
        return (0, int(numbered.group(1)), name)
    return (1, 0, name)


print("="*70)
print("CHECKPOINT SUMMARY")
print("="*70)

for condition_name in CONDITIONS:
    checkpoint_dir_base = CHECKPOINT_DIRS[condition_name]
    print(f"\n{condition_name.upper()}:")
    print("-"*40)

    # Counter meanings, as computed below:
    #   complete - seed directories holding at least one step_*.pkl file
    #              (this does not verify that the final step was reached)
    #   partial  - seed directories that exist but hold no checkpoint
    #   missing  - seeds of batches whose directory does not exist at all
    total_complete = 0
    total_partial = 0
    total_missing = 0

    for batch_idx in range(TOTAL_BATCHES):
        seed_ids = list(range(batch_idx * SEEDS_PER_BATCH, (batch_idx + 1) * SEEDS_PER_BATCH))
        batch_dir = os.path.join(checkpoint_dir_base, f'batch_{batch_idx}')
        if not os.path.exists(batch_dir):
            print(f"  Batch {batch_idx} (seeds {seed_ids}): NOT STARTED")
            total_missing += SEEDS_PER_BATCH
            continue

        batch_status = []
        for seed_dir in sorted(os.listdir(batch_dir)):
            seed_path = os.path.join(batch_dir, seed_dir)
            if not os.path.isdir(seed_path):
                continue
            pkl_files = glob_mod.glob(os.path.join(seed_path, 'step_*.pkl'))
            if pkl_files:
                # Highest step number, compared numerically.
                latest = max(pkl_files, key=_checkpoint_sort_key)
                size_mb = os.path.getsize(latest) / (1024 * 1024)
                batch_status.append(f"{seed_dir}: {os.path.basename(latest)} ({size_mb:.1f}MB)")
                total_complete += 1
            else:
                batch_status.append(f"{seed_dir}: no checkpoints")
                total_partial += 1

        print(f"  Batch {batch_idx} (seeds {seed_ids}): {len(batch_status)} seeds")
        for s in batch_status:
            print(f"    {s}")

    # Load training summary if available.
    # NOTE: pickle.load executes arbitrary code from the file; only safe
    # because this file is written by the training cell of this notebook.
    summary_path = os.path.join(checkpoint_dir_base, 'training_summary.pkl')
    if os.path.exists(summary_path):
        with open(summary_path, 'rb') as f:
            summary = pickle.load(f)
        n_success = sum(1 for r in summary['all_results'] if r.get('success', False))
        n_fail = sum(1 for r in summary['all_results'] if not r.get('success', True))
        print(f"\n  Summary: {n_success} successful batches, {n_fail} failed")
        if 'total_time_seconds' in summary:
            print(f"  Total time: {summary['total_time_seconds']/3600:.1f}h")

    print(f"\n  Complete: {total_complete} | Partial: {total_partial} | Missing: {total_missing}")

print("\n" + "="*70)
print("Next: Run the analysis notebook")
print("="*70)