# Hidden Food Coordination Training — Field ON vs Field OFF

Train 30 seeds x 2 conditions with hidden food enabled.
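Hidden food is revealed only when K=3 agents are within distance 3 of it at the same time. The hypothesis under test is that agents trained with the field (Field ON) achieve this coordination more readily than agents trained without it (Field OFF).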

## Config
| Parameter | Value |
|-----------|-------|
| Grid | 32x32 |
| Starting agents | 16 |
| Max agents | 64 |
| Regular food | 40 |
| Hidden food items | 3 |
| Required agents to reveal | 3 |
| Reveal distance | 3 (Chebyshev) |
| Hidden food value | 5x (500 energy) |
| Steps per seed | ~10M (38 iterations x 262,144 steps = 9,961,472) |
| Seeds per condition | 30 (10 batches x 3) |

## Runtime
- TPU v6e + High-RAM
- Expected: ~8-12 hours per condition on TPU v6e
- Run all cells (Ctrl+F9), come back when done
- Resume-safe: re-run after disconnect and it picks up from checkpoints

## Setup

Mount Google Drive, clone repo, install dependencies.

In [None]:
# Mount Google Drive; the checkpoint directory below lives under MyDrive.
from google.colab import drive
drive.mount('/content/drive')

import os
# Root directory for every checkpoint folder this notebook writes.
CHECKPOINT_BASE = '/content/drive/MyDrive/emergence-lab'
os.makedirs(CHECKPOINT_BASE, exist_ok=True)  # no-op when it already exists

# Local clone location and the GitHub account the repository is cloned from.
REPO_DIR = '/content/emergence-lab'
GITHUB_USERNAME = "imashishkh21"

# Lines starting with `!` are IPython shell escapes (not plain Python); {NAME}
# is substituted from the Python variable of that name.
# Clone on the first run, otherwise update the existing clone.
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/{GITHUB_USERNAME}/emergence-lab.git {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull

# Switch into the clone so the editable install of "." targets it.
os.chdir(REPO_DIR)
!pip install -e ".[dev,phase5]" -q

# Imported after the install step; print the JAX version and visible devices
# so the runtime type can be checked before training starts.
import jax
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"Device count: {jax.device_count()}")

In [None]:
# =============================================================================
# TEST RUN - Verify hidden food mechanics + parallel training before full run
# =============================================================================
import copy
import time
import jax
import jax.numpy as jnp
from src.configs import Config
from src.training.parallel_train import ParallelTrainer
from src.environment.env import reset, step
from src.environment.obs import get_observations

# --- Test 1: Hidden food env mechanics ---
# Build a config with the same environment/evolution values as the full run and
# check that reset() and step() expose the hidden-food state and info keys.
print("[1/3] Testing hidden food environment mechanics...")
test_cfg = Config()
test_cfg.env.grid_size = 32
test_cfg.env.num_agents = 16
test_cfg.env.num_food = 40
test_cfg.evolution.enabled = True
test_cfg.evolution.max_agents = 64
test_cfg.evolution.starting_energy = 200
test_cfg.evolution.food_energy = 100
test_cfg.evolution.reproduce_threshold = 120
test_cfg.evolution.reproduce_cost = 40

# Enable hidden food
test_cfg.env.hidden_food.enabled = True
test_cfg.env.hidden_food.num_hidden = 3
test_cfg.env.hidden_food.required_agents = 3
test_cfg.env.hidden_food.reveal_distance = 3

key = jax.random.PRNGKey(42)
state = reset(key, test_cfg)
# Expected shape is derived from the config instead of a literal (3, 2), so the
# check stays valid if num_hidden is changed above.
expected_hidden_shape = (test_cfg.env.hidden_food.num_hidden, 2)
assert state.hidden_food_positions is not None, "hidden_food_positions is None!"
assert state.hidden_food_positions.shape == expected_hidden_shape, f"Wrong shape: {state.hidden_food_positions.shape}"
assert state.hidden_food_revealed is not None

# Do one step; the action vector has one entry per agent slot (max_agents).
actions = jnp.zeros(test_cfg.evolution.max_agents, dtype=jnp.int32)  # all stay
state2, rewards, done, info = step(state, actions, test_cfg)
assert 'hidden_food_collected_this_step' in info
assert 'food_collected_this_step' in info
assert 'births_this_step' in info
print("   PASS: Hidden food env works. Shapes verified.")

# --- Test 2: Field ON parallel training ---
field_on_cfg = copy.deepcopy(test_cfg)
field_on_cfg.train.num_envs = 32
field_on_cfg.train.num_steps = 128
field_on_cfg.log.wandb = False
# Field ON defaults: diffusion=0.1, decay=0.05, write=1.0

# Steps consumed by one training iteration, computed from the config with the
# same formula the full-run cell uses (num_envs * num_steps * max_agents).
TEST_TARGET_STEPS = 100_000
steps_per_iter = (
    field_on_cfg.train.num_envs
    * field_on_cfg.train.num_steps
    * field_on_cfg.evolution.max_agents
)
# A single iteration already exceeds the target, so the floor division gives 0;
# max() guarantees at least one iteration. Report the real step count rather
# than the nominal target.
test_iters = max(1, TEST_TARGET_STEPS // steps_per_iter)
test_steps = test_iters * steps_per_iter
print(f"\n[2/3] Testing Field ON parallel training (2 seeds, {test_iters} iteration(s) = {test_steps:,} steps)...")

trainer_on = ParallelTrainer(
    config=field_on_cfg, num_seeds=2, seed_ids=[100, 101],
    checkpoint_dir=f'{CHECKPOINT_BASE}/test_hidden_food_on', master_seed=9999,
)
t0 = time.time()
# resume=False: the smoke test always starts fresh instead of reusing old test checkpoints.
metrics_on = trainer_on.train(num_iterations=test_iters, checkpoint_interval_minutes=60, resume=False)
print(f"   PASS: Field ON test done in {time.time()-t0:.1f}s, reward={metrics_on.get('mean_reward', 'N/A')}")

# --- Test 3: Field OFF parallel training ---
# Same train settings as Field ON, so test_iters / test_steps apply unchanged.
print(f"\n[3/3] Testing Field OFF parallel training (2 seeds, {test_iters} iteration(s) = {test_steps:,} steps)...")
field_off_cfg = copy.deepcopy(test_cfg)
field_off_cfg.train.num_envs = 32
field_off_cfg.train.num_steps = 128
field_off_cfg.log.wandb = False
# Field OFF: same three field overrides the full-run config applies.
field_off_cfg.field.diffusion_rate = 0.0
field_off_cfg.field.decay_rate = 1.0
field_off_cfg.field.write_strength = 0.0

trainer_off = ParallelTrainer(
    config=field_off_cfg, num_seeds=2, seed_ids=[200, 201],
    checkpoint_dir=f'{CHECKPOINT_BASE}/test_hidden_food_off', master_seed=8888,
)
t0 = time.time()
metrics_off = trainer_off.train(num_iterations=test_iters, checkpoint_interval_minutes=60, resume=False)
print(f"   PASS: Field OFF test done in {time.time()-t0:.1f}s, reward={metrics_off.get('mean_reward', 'N/A')}")

print("\n" + "="*70)
print("ALL 3 TESTS PASSED! Proceed to full training.")
print("="*70)

## Configuration

Build Field ON and Field OFF configs with hidden food enabled.

In [None]:
import copy
import os
import time
import pickle
import numpy as np
from datetime import datetime, timedelta
from src.configs import Config
from src.training.parallel_train import ParallelTrainer

# --- Run budget (copied into config.train by build_config) ---
TOTAL_STEPS = 10_000_000          # step budget per seed
NUM_ENVS = 32                     # -> config.train.num_envs
NUM_STEPS = 128                   # -> config.train.num_steps

# --- Seed batching: 10 batches x 3 seeds = 30 seeds per condition ---
SEEDS_PER_BATCH = 3
TOTAL_BATCHES = 10

# --- Checkpointing (both values are passed to ParallelTrainer.train) ---
CHECKPOINT_INTERVAL_MINUTES = 30
RESUME = True

# One checkpoint root per condition, keyed by condition name.
CHECKPOINT_DIRS = {
    condition: f'{CHECKPOINT_BASE}/hidden_food_{condition}'
    for condition in ('field_on', 'field_off')
}


def build_config(field_enabled: bool) -> Config:
    """Return the experiment Config for one condition.

    Both conditions share identical environment, evolution, hidden-food,
    training and logging settings; they differ only in the field parameters.

    Args:
        field_enabled: When True the field parameters are left at their
            Config defaults. When False, diffusion_rate is set to 0.0,
            decay_rate to 1.0 and write_strength to 0.0.

    Returns:
        A newly constructed Config with all overrides applied.
    """
    config = Config()
    env = config.env
    evolution = config.evolution
    hidden_food = env.hidden_food
    train = config.train
    log = config.log

    # World: 32x32 grid, 16 starting agents, 40 regular food items.
    env.grid_size = 32
    env.num_agents = 16
    env.num_food = 40

    # Evolution: population may grow up to 64 agents.
    evolution.enabled = True
    evolution.max_agents = 64
    evolution.starting_energy = 200
    evolution.food_energy = 100
    evolution.energy_per_step = 1
    evolution.reproduce_threshold = 120
    evolution.reproduce_cost = 40
    evolution.mutation_std = 0.01

    # Hidden food (the coordination task). reveal_duration and
    # hidden_food_value_multiplier are intentionally left at their defaults.
    hidden_food.enabled = True
    hidden_food.num_hidden = 3
    hidden_food.required_agents = 3
    hidden_food.reveal_distance = 3

    # Training budget and logging, taken from the module-level constants.
    train.total_steps = TOTAL_STEPS
    train.num_envs = NUM_ENVS
    train.num_steps = NUM_STEPS
    train.seed = 42
    log.wandb = False
    log.save_interval = 0

    # Field ON needs no overrides; Field OFF zeroes the field out.
    if not field_enabled:
        field = config.field
        field.diffusion_rate = 0.0
        field.decay_rate = 1.0
        field.write_strength = 0.0

    return config


# Show the field parameters each condition will train with, one line apiece.
for label, use_field in (('Field ON', True), ('Field OFF', False)):
    cfg = build_config(use_field)
    field_cfg = cfg.field
    summary_line = (
        f"{label}: diffusion={field_cfg.diffusion_rate}, decay={field_cfg.decay_rate}, "
        f"write={field_cfg.write_strength}, hidden_food={cfg.env.hidden_food.enabled}"
    )
    print(summary_line)

## Autonomous Training — Field ON + Field OFF

Runs all 10 batches for Field ON, then all 10 batches for Field OFF.
Total: 60 seeds (30 per condition) at 10M steps each.

Expected runtime: ~8-12 hours per condition on TPU v6e.
Resume-safe: re-run and it picks up from latest checkpoints.

In [None]:
import traceback  # cell-local: used only to record batch failures below

# Run both conditions sequentially: every batch of Field ON, then every batch
# of Field OFF. Each batch trains SEEDS_PER_BATCH seeds with one
# ParallelTrainer and checkpoints into its own batch_<n> directory.
for condition_name, field_enabled in [('field_on', True), ('field_off', False)]:
    config = build_config(field_enabled)
    checkpoint_dir_base = CHECKPOINT_DIRS[condition_name]
    os.makedirs(checkpoint_dir_base, exist_ok=True)

    steps_per_iter = NUM_ENVS * NUM_STEPS * config.evolution.max_agents
    # Floor division drops the final partial iteration; max() keeps at least
    # one iteration if TOTAL_STEPS is ever set below steps_per_iter (the test
    # cell applies the same guard).
    num_iterations = max(1, TOTAL_STEPS // steps_per_iter)

    print("\n" + "="*70)
    print(f"CONDITION: {condition_name.upper()} (field_enabled={field_enabled})")
    print(f"Hidden food: enabled={config.env.hidden_food.enabled}")
    print(f"Batches: {TOTAL_BATCHES}, Seeds/batch: {SEEDS_PER_BATCH}")
    print(f"Iterations: {num_iterations}, Steps/iter: {steps_per_iter:,}")
    print("="*70)

    all_results = []
    cond_start = time.time()

    for batch_number in range(TOTAL_BATCHES):
        # Seed ids are contiguous per batch: batch 0 -> [0, 1, 2], batch 1 -> [3, 4, 5], ...
        seed_ids = list(range(batch_number * SEEDS_PER_BATCH, (batch_number + 1) * SEEDS_PER_BATCH))
        checkpoint_dir = f'{checkpoint_dir_base}/batch_{batch_number}'

        print(f"\n--- Batch {batch_number+1}/{TOTAL_BATCHES} | Seeds: {seed_ids} ---")

        try:
            trainer = ParallelTrainer(
                config=config, num_seeds=SEEDS_PER_BATCH, seed_ids=seed_ids,
                checkpoint_dir=checkpoint_dir, master_seed=42 + batch_number * 1000,
            )
            metrics = trainer.train(
                num_iterations=num_iterations,
                checkpoint_interval_minutes=CHECKPOINT_INTERVAL_MINUTES,
                resume=RESUME, print_interval=100,
            )
            all_results.append({
                'batch': batch_number, 'seed_ids': seed_ids,
                'metrics': metrics, 'success': True,
            })
            if 'mean_reward' in metrics:
                print(f"  Rewards: {metrics['mean_reward']}")
        except Exception as e:
            # Deliberately best-effort: one failed batch must not abort the
            # remaining batches of a multi-hour run. Keep the full traceback
            # (printed and stored) so the cause can be diagnosed afterwards;
            # str(e) alone is often not enough.
            error_trace = traceback.format_exc()
            print(f"  ERROR: {e}")
            print(error_trace)
            all_results.append({
                'batch': batch_number, 'seed_ids': seed_ids,
                'error': str(e), 'traceback': error_trace, 'success': False,
            })
            continue

        # Progress estimate: average wall time per batch so far, extrapolated
        # over the batches that remain in this condition.
        batches_done = batch_number + 1
        total_elapsed = time.time() - cond_start
        avg_per_batch = total_elapsed / batches_done
        remaining = avg_per_batch * (TOTAL_BATCHES - batches_done)
        eta = datetime.now() + timedelta(seconds=remaining)
        print(f"  Progress: {batches_done}/{TOTAL_BATCHES} | "
              f"Elapsed: {total_elapsed/3600:.1f}h | ETA: {eta.strftime('%H:%M')}")

    # Save training summary for this condition
    cond_time = time.time() - cond_start
    summary_path = f'{checkpoint_dir_base}/training_summary.pkl'
    with open(summary_path, 'wb') as f:
        pickle.dump({
            'all_results': all_results,
            'total_time_seconds': cond_time,
            'condition': condition_name,
            # Read from the config actually used rather than hard-coding True,
            # so the summary cannot drift from build_config().
            'hidden_food_enabled': config.env.hidden_food.enabled,
            'total_steps': TOTAL_STEPS,
            'config': {
                'grid_size': config.env.grid_size,
                'num_food': config.env.num_food,
                'max_agents': config.evolution.max_agents,
                'reproduce_threshold': config.evolution.reproduce_threshold,
                'reproduce_cost': config.evolution.reproduce_cost,
                'hidden_food_num_hidden': config.env.hidden_food.num_hidden,
                'hidden_food_required_agents': config.env.hidden_food.required_agents,
                'hidden_food_reveal_distance': config.env.hidden_food.reveal_distance,
                'hidden_food_value_multiplier': config.env.hidden_food.hidden_food_value_multiplier,
            },
        }, f)
    print(f"\n{condition_name} complete! Time: {cond_time/3600:.1f}h")
    print(f"Summary saved: {summary_path}")

print("\n" + "="*70)
print("ALL TRAINING COMPLETE!")
print("="*70)

## Final Summary

Scan all checkpoints for both conditions.

In [None]:
import glob as glob_mod
import pickle

import re  # cell-local: used only by _checkpoint_sort_key


def _checkpoint_sort_key(path):
    """Return a sort key that orders step_<N>*.pkl checkpoints by numeric N.

    Plain string sorting puts 'step_99.pkl' after 'step_100.pkl' when step
    numbers are not zero-padded, so the wrong file would be reported as the
    latest. Names without a numeric step sort after all numeric ones and are
    ordered by name among themselves.
    """
    name = os.path.basename(path)
    match = re.match(r'step_([0-9]+)', name)
    if match is None:
        return (1, 0, name)
    return (0, int(match.group(1)), name)


print("="*70)
print("CHECKPOINT SUMMARY")
print("="*70)

for condition_name in ['field_on', 'field_off']:
    checkpoint_dir_base = CHECKPOINT_DIRS[condition_name]
    print(f"\n{condition_name.upper()}:")
    print("-"*40)

    # Per-seed tallies for this condition:
    #   complete - seed directory holds at least one step_*.pkl checkpoint
    #   partial  - seed directory exists but holds no checkpoint yet
    #   missing  - the whole batch directory does not exist
    total_complete = 0
    total_partial = 0
    total_missing = 0

    for batch_idx in range(TOTAL_BATCHES):
        batch_dir = os.path.join(checkpoint_dir_base, f'batch_{batch_idx}')
        seed_ids = list(range(batch_idx * SEEDS_PER_BATCH, (batch_idx + 1) * SEEDS_PER_BATCH))
        if not os.path.exists(batch_dir):
            print(f"  Batch {batch_idx} (seeds {seed_ids}): NOT STARTED")
            total_missing += SEEDS_PER_BATCH
            continue

        batch_status = []
        for seed_dir in sorted(os.listdir(batch_dir)):
            seed_path = os.path.join(batch_dir, seed_dir)
            if not os.path.isdir(seed_path):
                continue  # skip stray files; only sub-directories are scanned
            pkl_files = glob_mod.glob(os.path.join(seed_path, 'step_*.pkl'))
            if pkl_files:
                # Highest step number, not the lexicographically last name.
                latest = max(pkl_files, key=_checkpoint_sort_key)
                size_mb = os.path.getsize(latest) / (1024 * 1024)
                batch_status.append(f"{seed_dir}: {os.path.basename(latest)} ({size_mb:.1f}MB)")
                total_complete += 1
            else:
                batch_status.append(f"{seed_dir}: no checkpoints")
                total_partial += 1

        print(f"  Batch {batch_idx} (seeds {seed_ids}): {len(batch_status)} seeds")
        for s in batch_status:
            print(f"    {s}")

    # Load training summary if available.
    # NOTE: pickle.load executes arbitrary code from the file; this is only
    # acceptable because the file is the one written by the training cell.
    summary_path = os.path.join(checkpoint_dir_base, 'training_summary.pkl')
    if os.path.exists(summary_path):
        with open(summary_path, 'rb') as f:
            summary = pickle.load(f)
        n_success = sum(1 for r in summary['all_results'] if r.get('success', False))
        n_fail = sum(1 for r in summary['all_results'] if not r.get('success', True))
        print(f"\n  Summary: {n_success} successful batches, {n_fail} failed")
        if 'total_time_seconds' in summary:
            print(f"  Total time: {summary['total_time_seconds']/3600:.1f}h")

    print(f"\n  Complete: {total_complete} | Partial: {total_partial} | Missing: {total_missing}")

print("\n" + "="*70)
print("Next: Run the analysis notebook (colab_hidden_food_analysis.ipynb)")
print("="*70)