# Smart Field Test: Does State-Dependent Writing Improve Collective Intelligence?

## Hypothesis

| Prediction | Mechanism |
|------------|----------|
| Smart field > Presence field > No field | State-dependent writes encode richer information (energy, food found, reproduction readiness) into each channel |
| Smart field agents coordinate better for hidden food | Channels carry semantic meaning (ch0=energy, ch1=food, ch2=hidden food, ch3=reproduction), enabling targeted coordination |
| Presence field still beats no field | Even uniform presence marks carry spatial memory of agent activity |
| The gap widens under harsh conditions | Richer field information becomes essential when food is scarce |

## Experimental Design

| Condition | Field | write_mode | Seeds | Steps |
|-----------|-------|------------|-------|-------|
| **Field OFF** | diffusion=0, decay=1, write=0 | presence | 15 | 10M |
| **Field ON (presence)** | default (diffusion=0.1, decay=0.05, write=1.0) | presence | 15 | 10M |
| **Field ON (smart)** | default (diffusion=0.1, decay=0.05, write=1.0) | state_dependent | 15 | 10M |

## Harsh Environment Config

| Parameter | Value | Why |
|-----------|-------|-----|
| grid_size | 40 | Larger arena = harder to find food |
| num_food | 10 | **SCARCE** (vs 40 in standard config) |
| food_energy | 60 | **REDUCED** (vs 100 standard) |
| num_hidden | 8 | Abundant hidden food — the coordination prize |
| required_agents | 5 | Hard coordination requirement |
| reveal_distance | 3 | Chebyshev distance for reveal |
| hidden_food_value_multiplier | 10.0 | Each hidden food worth 600 energy (10x regular) |
| max_agents | 64 | Population cap |
| starting_energy | 200 | Starting energy per agent |
| reproduce_threshold | 120 | Energy needed to reproduce |
| reproduce_cost | 40 | Energy cost of reproduction |

In [None]:
# Colab bootstrap cell: mount Drive, fetch/update the repository, install it,
# and sanity-check the pieces the rest of the notebook depends on.

# Later cells write checkpoints and results under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

# Clone only when the checkout is missing; an existing checkout is just updated.
import os
if not os.path.exists('/content/emergence-lab'):
    !git clone https://github.com/imashishkh21/emergence-lab.git /content/emergence-lab
%cd /content/emergence-lab
!git pull origin main

# Editable install of the project with its "dev" extras, plus rliable.
!pip install -e ".[dev]" -q
!pip install rliable -q

# Verify write_mode support
# Fail fast here rather than mid-training: build_config() below sets
# config.field.write_mode for every condition.
from src.configs import FieldConfig
assert hasattr(FieldConfig, 'write_mode'), "FieldConfig missing write_mode! Pull latest main."
print(f"FieldConfig.write_mode default: {FieldConfig().write_mode}")

# Show which devices/backend JAX picked up before any training is launched.
import jax
print(f"JAX devices: {jax.devices()}")
print(f"JAX backend: {jax.default_backend()}")

## Phase 1 — Quick Validation

Before committing to 45 full training runs, verify:
1. Populations survive the harsh environment (at least one condition has mean pop > 10)
2. Check if hidden food is being revealed at all (may take longer than 2M steps)

**Setup**: 3 seeds per condition (9 runs total), ~2M steps (~8 iterations of 262K steps each)

In [None]:
import time, gc
import jax
import jax.numpy as jnp
import numpy as np
from src.configs import Config, TrainingMode
from src.training.parallel_train import ParallelTrainer
from src.agents.network import ActorCritic
from src.agents.policy import get_deterministic_actions
from src.environment.env import reset, step
from src.environment.obs import get_observations

# Rollout geometry shared by the validation and full-training cells.
# NUM_ENVS / NUM_STEPS are copied into config.train by build_config().
NUM_ENVS = 32
NUM_STEPS = 128
# Same value as config.evolution.max_agents set in build_config().
MAX_AGENTS = 64
# Step budget consumed by one training iteration (envs x rollout steps x agent slots).
STEPS_PER_ITER = NUM_ENVS * NUM_STEPS * MAX_AGENTS  # 262,144
VALIDATION_ITERS = 8  # ~2M steps


def build_config(mode: str) -> Config:
    """Return the harsh-environment ``Config`` for one experimental condition.

    All conditions share the same environment, evolution, hidden-food and
    training settings; only attributes of ``config.field`` differ per mode.
    Field attributes that a mode does not list are left at whatever
    ``Config()`` initialised them to.

    Args:
        mode: One of 'field_off', 'field_on_presence', 'field_on_smart'.

    Returns:
        A newly constructed ``Config`` with the overrides applied.

    Raises:
        ValueError: If ``mode`` is not one of the three names above.
    """
    cfg = Config()

    # Environment: large grid, scarce regular food.
    env = cfg.env
    env.grid_size = 40
    env.num_agents = 16
    env.num_food = 10              # SCARCE
    env.max_steps = 500

    # Hidden food: the coordination prize.
    hidden = env.hidden_food
    hidden.enabled = True
    hidden.num_hidden = 8
    hidden.required_agents = 5
    hidden.reveal_distance = 3
    hidden.hidden_food_value_multiplier = 10.0

    # Evolution / energy economy.
    evo = cfg.evolution
    evo.enabled = True
    evo.max_agents = 64
    evo.starting_energy = 200
    evo.food_energy = 60           # REDUCED
    evo.energy_per_step = 1
    evo.reproduce_threshold = 120
    evo.reproduce_cost = 40
    evo.mutation_std = 0.01

    # Training and logging.
    train = cfg.train
    train.training_mode = TrainingMode.GRADIENT
    train.num_envs = NUM_ENVS
    train.num_steps = NUM_STEPS
    train.seed = 42
    cfg.log.wandb = False
    cfg.log.save_interval = 0

    # Per-condition field settings, applied attribute by attribute.
    field_settings = {
        'field_off': {
            'diffusion_rate': 0.0,
            'decay_rate': 1.0,
            'write_strength': 0.0,
            'write_mode': 'presence',
        },
        'field_on_presence': {'write_mode': 'presence'},
        'field_on_smart': {'write_mode': 'state_dependent'},
    }
    if mode not in field_settings:
        raise ValueError(f"Unknown mode: {mode}. Use 'field_off', 'field_on_presence', or 'field_on_smart'")
    for attr_name, value in field_settings[mode].items():
        setattr(cfg.field, attr_name, value)

    return cfg


def run_hidden_food_eval(network, params, config, key, num_episodes=1):
    """Roll out deterministic evaluation episodes and summarise food metrics.

    Each episode starts from ``reset`` with a fresh sub-key split off ``key``
    and runs for at most ``config.env.max_steps`` steps, stopping early when
    the environment reports ``done``.

    Args:
        network: Policy network handed to ``get_deterministic_actions``.
        params: Parameters for ``network``.
        config: Experiment config; read for ``env.max_steps``,
            ``evolution.food_energy`` and
            ``env.hidden_food.hidden_food_value_multiplier``.
        key: JAX PRNG key; split once per episode.
        num_episodes: Number of episodes to run.

    Returns:
        Dict with the mean over episodes of every per-episode metric, plus
        ``'per_episode'`` holding the list of per-episode dicts.
    """
    food_energy = config.evolution.food_energy
    hf_multiplier = config.env.hidden_food.hidden_food_value_multiplier

    episodes = []
    for _ in range(num_episodes):
        key, episode_key = jax.random.split(key)
        state = reset(episode_key, config)

        reward_total = 0.0
        regular_eaten = 0.0
        revealed_count = 0
        hidden_eaten = 0.0

        for _ in range(config.env.max_steps):
            obs = get_observations(state, config)
            # Add a leading batch axis for the policy call, then strip it again.
            actions = get_deterministic_actions(network, params, obs[None, :, :])[0]

            revealed_before = state.hidden_food_revealed
            state, rewards, done, info = step(state, actions, config)
            revealed_after = state.hidden_food_revealed

            reward_total += float(jnp.sum(rewards))
            regular_eaten += float(info['food_collected_this_step'])
            hidden_eaten += float(info['hidden_food_collected_this_step'])

            # Count only items that flipped from unrevealed to revealed this step.
            if revealed_before is not None and revealed_after is not None:
                revealed_count += int(jnp.sum((~revealed_before) & revealed_after))

            if bool(done):
                break

        survivors = int(jnp.sum(state.agent_alive.astype(jnp.int32)))
        episodes.append({
            'total_reward': reward_total,
            'regular_food_collected': regular_eaten,
            'hidden_food_revealed': revealed_count,
            'hidden_food_collected': hidden_eaten,
            'regular_food_energy': regular_eaten * food_energy,
            'hidden_food_energy': hidden_eaten * food_energy * hf_multiplier,
            'final_population': survivors,
        })

    summary = {
        metric: np.mean([episode[metric] for episode in episodes])
        for metric in episodes[0]
    }
    summary['per_episode'] = episodes
    return summary


# --- Validation training ---
# Short (VALIDATION_ITERS iterations) run of every condition, used as a
# go/no-go gate before the full experiment. Each entry is
# (condition name understood by build_config, seed IDs for that condition).
CONDITIONS = [
    ('field_off', [200, 201, 202]),
    ('field_on_presence', [203, 204, 205]),
    ('field_on_smart', [206, 207, 208]),
]

validation_results = {}
for cond_name, seed_ids in CONDITIONS:
    print(f"\n{'='*60}")
    print(f"VALIDATION: {cond_name.upper()}")
    print(f"{'='*60}")
    config = build_config(cond_name)
    trainer = ParallelTrainer(
        # num_seeds follows the seed list so the two can never disagree.
        config=config, num_seeds=len(seed_ids), seed_ids=seed_ids,
        checkpoint_dir=f'/tmp/validation_{cond_name}', master_seed=9999,
    )
    t0 = time.time()
    metrics = trainer.train(
        num_iterations=VALIDATION_ITERS,
        # Interval far longer than the run, with resume disabled.
        checkpoint_interval_minutes=999,
        resume=False,
        print_interval=2,
    )
    elapsed = time.time() - t0

    # Extract live population from trainer state
    # NOTE(review): reaches into a private attribute of ParallelTrainer.
    ps = trainer._parallel_state
    # presumably (num_seeds, num_envs, max_agents) — TODO confirm
    alive = np.array(ps.env_state.agent_alive)
    pop_per_seed = alive.sum(axis=(1, 2)) / alive.shape[1]  # mean across envs
    final_pops = [float(p) for p in pop_per_seed]

    # Quick eval: 1 episode per condition with seed 0's params.
    # The network shape comes from the config (as load_seed_data does) instead
    # of a hard-coded (64, 64), so it always matches the trained parameters.
    network = ActorCritic(
        hidden_dims=tuple(config.agent.hidden_dims), num_actions=6
    )
    eval_key = jax.random.PRNGKey(42)
    seed0_params = jax.tree.map(lambda x: x[0], ps.params)
    eval_result = run_hidden_food_eval(network, seed0_params, config, eval_key, num_episodes=1)

    validation_results[cond_name] = {
        'metrics': metrics, 'elapsed': elapsed,
        'final_pops': final_pops, 'eval': eval_result,
    }
    print(f"\n{cond_name}: {elapsed:.0f}s")
    print(f"  Final populations (per seed, mean across envs): {[f'{p:.1f}' for p in final_pops]}")
    print(f"  Eval reward: {eval_result['total_reward']:.1f}")
    print(f"  Eval HF revealed: {eval_result['hidden_food_revealed']}")
    print(f"  Eval HF collected: {eval_result['hidden_food_collected']:.1f}")
    print(f"  Eval final pop: {eval_result['final_population']}")

    # Release the trainer and compiled-function caches before the next condition.
    del trainer
    gc.collect()
    jax.clear_caches()

# --- Decision gate ---
# Pass if the best condition's mean final population (over its seeds) exceeds 10.
print("\n" + "="*60)
print("VALIDATION DECISION GATE")
print("="*60)
max_pop = 0
for cond_name, _ in CONDITIONS:
    pops = validation_results[cond_name]['final_pops']
    mean_pop = np.mean(pops)
    hfr = validation_results[cond_name]['eval']['hidden_food_revealed']
    print(f"{cond_name:<25} mean population: {mean_pop:.1f}  HF revealed: {hfr}")
    max_pop = max(max_pop, mean_pop)

if max_pop > 10:
    print("\n\u2705 PASS: At least one condition sustains population > 10. Proceed to Phase 2.")
else:
    print("\n\u274c FAIL: All conditions have dangerously low populations.")
    print("  Fallback 1: Increase num_food from 10 to 15 (slightly less harsh)")
    print("  Fallback 2: Increase food_energy from 60 to 80 (more reward per food)")
    print("  Fallback 3: Reduce required_agents from 5 to 4 (easier hidden food)")
    print("  Apply ONE change at a time and re-run validation.")

# Absence of hidden-food reveals is reported but does not fail the gate.
all_hfr = [validation_results[c]['eval']['hidden_food_revealed'] for c, _ in CONDITIONS]
if all(h == 0 for h in all_hfr):
    print("\nNOTE: No hidden food revealed in any condition at 2M steps.")
    print("  This is expected \u2014 coordination may take longer to emerge.")
    print("  If populations are healthy, proceed. HF coordination is the 10M-step question.")

## Phase 2 — Full Training

15 seeds per condition (45 total), 10M steps each.

**Execution plan**:
- 5 batches x 3 seeds per condition
- Seed IDs: OFF [50-64], Presence [0-14], Smart [100-114]
- Resume-safe with checkpoint detection
- Memory cleanup between every batch
- Population crash detection after each batch
- Checkpoints saved to Google Drive every 60 minutes

In [None]:
import os, time, gc, pickle
from datetime import datetime, timedelta

# Full-experiment schedule: TOTAL_BATCHES batches of SEEDS_PER_BATCH seeds per condition.
TOTAL_STEPS = 10_000_000
SEEDS_PER_BATCH = 3
TOTAL_BATCHES = 5
CHECKPOINT_INTERVAL_MINUTES = 60
RESUME = True  # forwarded to trainer.train(resume=...)

# NOTE(review): floor division yields 38 iterations = 9,961,472 steps,
# i.e. slightly fewer than TOTAL_STEPS.
steps_per_iter = NUM_ENVS * NUM_STEPS * MAX_AGENTS  # 262,144
num_iterations = TOTAL_STEPS // steps_per_iter       # 38

# One Drive directory per condition; each batch gets a batch_<n> subdirectory
# and the per-condition training_summary.pkl is written at the top level.
CHECKPOINT_BASE = '/content/drive/MyDrive/emergence-lab'
CHECKPOINT_DIRS = {
    'field_off': f'{CHECKPOINT_BASE}/smart_field_off',
    'field_on_presence': f'{CHECKPOINT_BASE}/smart_field_presence',
    'field_on_smart': f'{CHECKPOINT_BASE}/smart_field_smart',
}

# Seed ID ranges (non-overlapping)
# Each list is sliced into consecutive SEEDS_PER_BATCH-sized chunks by the training loop.
SEED_RANGES = {
    'field_off': list(range(50, 65)),           # [50..64]
    'field_on_presence': list(range(0, 15)),     # [0..14]
    'field_on_smart': list(range(100, 115)),     # [100..114]
}

print(f"Steps per iteration: {steps_per_iter:,}")
print(f"Total iterations: {num_iterations}")
print(f"Seeds per condition: {SEEDS_PER_BATCH * TOTAL_BATCHES}")
# NOTE(review): the leading 3 is the hard-coded number of conditions.
print(f"Total training runs: {3 * SEEDS_PER_BATCH * TOTAL_BATCHES}")

# Train every condition batch by batch. A failing batch is recorded and the
# loop moves on (deliberate best-effort); one summary pickle is written per
# condition after its last batch.
for condition_name in ['field_off', 'field_on_presence', 'field_on_smart']:
    config = build_config(condition_name)
    config.train.total_steps = TOTAL_STEPS
    checkpoint_dir_base = CHECKPOINT_DIRS[condition_name]
    os.makedirs(checkpoint_dir_base, exist_ok=True)
    all_seeds = SEED_RANGES[condition_name]

    all_results = []
    cond_start = time.time()

    for batch_number in range(TOTAL_BATCHES):
        seed_ids = all_seeds[
            batch_number * SEEDS_PER_BATCH : (batch_number + 1) * SEEDS_PER_BATCH
        ]
        checkpoint_dir = f'{checkpoint_dir_base}/batch_{batch_number}'
        batch_start = time.time()

        print(f"\n{'='*60}")
        print(f"{condition_name.upper()} - Batch {batch_number+1}/{TOTAL_BATCHES} - Seeds {seed_ids}")
        elapsed_total = time.time() - cond_start
        if batch_number > 0:
            # ETA for this condition, extrapolated from the batches finished so far.
            avg_per_batch = elapsed_total / batch_number
            remaining = avg_per_batch * (TOTAL_BATCHES - batch_number)
            eta = datetime.now() + timedelta(seconds=remaining)
            print(f"ETA: {eta.strftime('%H:%M:%S')} ({remaining/60:.0f} min remaining)")
        print(f"{'='*60}")

        try:
            trainer = ParallelTrainer(
                config=config, num_seeds=SEEDS_PER_BATCH,
                seed_ids=seed_ids, checkpoint_dir=checkpoint_dir,
                master_seed=42 + batch_number * 1000,
            )
            metrics = trainer.train(
                num_iterations=num_iterations,
                checkpoint_interval_minutes=CHECKPOINT_INTERVAL_MINUTES,
                resume=RESUME,
                print_interval=5,
            )
            batch_time = time.time() - batch_start

            # Population crash detection
            # NOTE(review): reaches into a private attribute of ParallelTrainer.
            ps = trainer._parallel_state
            alive = np.array(ps.env_state.agent_alive)
            # Alive agents per seed, averaged over axis 1 (envs, per the validation cell).
            pop_per_seed = alive.sum(axis=(1, 2)) / alive.shape[1]
            batch_pops = [float(p) for p in pop_per_seed]
            print(f"  Final populations: {[f'{p:.0f}' for p in batch_pops]}")
            if all(p < 3 for p in batch_pops):
                print(f"  \u26a0\ufe0f WARNING: All seeds in batch {batch_number} have near-zero population!")
                print(f"  Population may have crashed. Check environment parameters.")

            all_results.append({
                'batch': batch_number, 'seed_ids': seed_ids,
                'metrics': metrics, 'success': True,
                'time_seconds': batch_time, 'final_pops': batch_pops,
            })
            print(f"Batch {batch_number} done in {batch_time/60:.1f} min")
        except Exception as e:
            # Best-effort: log the failure, record it, and continue with the next batch.
            print(f"ERROR in batch {batch_number}: {e}")
            import traceback; traceback.print_exc()
            all_results.append({
                'batch': batch_number, 'seed_ids': seed_ids,
                'success': False, 'error': str(e),
            })
        finally:
            # Rebinding (rather than `del` guarded by a bare `except:`) drops the
            # trainer reference whether or not construction succeeded, without
            # swallowing unrelated exceptions such as KeyboardInterrupt.
            trainer = None
            gc.collect()
            jax.clear_caches()

    cond_time = time.time() - cond_start
    summary_path = f'{checkpoint_dir_base}/training_summary.pkl'
    with open(summary_path, 'wb') as f:
        pickle.dump({
            'all_results': all_results,
            'total_time_seconds': cond_time,
            'condition': condition_name,
            'total_steps': TOTAL_STEPS,
            # Plain-value snapshot of the settings that define the condition.
            'config': {
                'grid_size': config.env.grid_size,
                'num_food': config.env.num_food,
                'food_energy': config.evolution.food_energy,
                'num_hidden': config.env.hidden_food.num_hidden,
                'required_agents': config.env.hidden_food.required_agents,
                'hidden_food_value_multiplier': config.env.hidden_food.hidden_food_value_multiplier,
                'field_enabled': condition_name != 'field_off',
                'write_mode': config.field.write_mode,
                'max_agents': config.evolution.max_agents,
            },
        }, f)
    print(f"\n{condition_name} COMPLETE in {cond_time/3600:.1f} hours")
    print(f"Summary saved to {summary_path}")

## Phase 3 — Analysis

Load checkpoints from all 3 conditions, run evaluation episodes, compute statistics, and generate publication-quality figures.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import json, pickle, gc, os
import glob as glob_mod
from pathlib import Path
from datetime import datetime
from scipy import stats as scipy_stats

import jax
import jax.numpy as jnp

from src.configs import (
    Config, TrainingMode, HiddenFoodConfig, EnvConfig, FieldConfig,
    AgentConfig, TrainConfig, LogConfig, AnalysisConfig,
    EvolutionConfig, SpecializationConfig, FreezeEvolveConfig, ArchiveConfig,
)
from src.agents.network import ActorCritic
from src.agents.policy import get_deterministic_actions
from src.environment.env import reset, step
from src.environment.obs import get_observations
from src.training.checkpointing import load_checkpoint
from src.analysis.specialization import compute_weight_divergence
from src.analysis.statistics import (
    compute_iqm, compare_methods, welch_t_test,
    mann_whitney_test, probability_of_improvement,
)

%matplotlib inline
from src.analysis.paper_figures import setup_publication_style, save_figure

# Apply the project's figure styling before any plot is created.
setup_publication_style()

# Figures and intermediate evaluation pickles are written here.
OUTPUT_DIR = '/content/drive/MyDrive/emergence-lab/smart_field_results'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Condition labels and colors ---
# COND_NAMES fixes the order used for every table column and plot position below;
# the two dicts map each condition name to its display label and plot colour.
COND_NAMES = ['field_off', 'field_on_presence', 'field_on_smart']
COND_LABELS = {'field_off': 'Field OFF', 'field_on_presence': 'Presence', 'field_on_smart': 'Smart'}
COND_COLORS = {'field_off': '#BBBBBB', 'field_on_presence': '#009988', 'field_on_smart': '#EE7733'}


def reconstruct_config(d):
    """Rebuild a ``Config`` dataclass from the plain dict stored in a checkpoint.

    A value that already is a ``Config`` is returned unchanged. Otherwise each
    top-level section of ``d`` is passed as keyword arguments to its dataclass;
    a missing section falls back to that dataclass's defaults. Four sections
    need a fix-up first because their stored form differs from what the
    dataclass expects.

    Args:
        d: ``Config`` instance or nested dict as returned by ``load_checkpoint``.

    Returns:
        The reconstructed (or original) ``Config``.
    """
    if isinstance(d, Config):
        return d

    # env: the nested hidden_food dict becomes a HiddenFoodConfig.
    env_kwargs = dict(d.get('env', {}))
    hidden_food = env_kwargs.get('hidden_food')
    if isinstance(hidden_food, dict):
        env_kwargs['hidden_food'] = HiddenFoodConfig(**hidden_food)

    # train: training_mode stored as a string becomes the TrainingMode enum.
    train_kwargs = dict(d.get('train', {}))
    training_mode = train_kwargs.get('training_mode')
    if isinstance(training_mode, str):
        train_kwargs['training_mode'] = TrainingMode(training_mode)

    # agent: hidden_dims stored as a list becomes a tuple.
    agent_kwargs = dict(d.get('agent', {}))
    hidden_dims = agent_kwargs.get('hidden_dims')
    if isinstance(hidden_dims, list):
        agent_kwargs['hidden_dims'] = tuple(hidden_dims)

    # field: checkpoints written before write_mode existed get 'presence'.
    field_kwargs = dict(d.get('field', {}))
    field_kwargs.setdefault('write_mode', 'presence')

    return Config(
        env=EnvConfig(**env_kwargs),
        field=FieldConfig(**field_kwargs),
        agent=AgentConfig(**agent_kwargs),
        train=TrainConfig(**train_kwargs),
        log=LogConfig(**d.get('log', {})),
        analysis=AnalysisConfig(**d.get('analysis', {})),
        evolution=EvolutionConfig(**d.get('evolution', {})),
        specialization=SpecializationConfig(**d.get('specialization', {})),
        freeze_evolve=FreezeEvolveConfig(**d.get('freeze_evolve', {})),
        archive=ArchiveConfig(**d.get('archive', {})),
    )


def discover_checkpoints(base_dir, max_batches=20):
    """Return the latest checkpoint file of every seed under ``base_dir``.

    Looks for ``base_dir/batch_<i>/<seed_dir>/step_*.pkl`` with ``i`` in
    ``range(max_batches)`` and keeps one path per seed directory: the file
    with the highest step number. Steps are compared numerically, so
    ``step_10.pkl`` is newer than ``step_9.pkl`` (a plain string sort would
    pick the older file). Names without a number after ``step_`` rank after
    all numbered ones and are ordered by name among themselves.

    Args:
        base_dir: Directory that contains the ``batch_<i>`` subdirectories.
        max_batches: Number of batch indices (starting at 0) to probe.

    Returns:
        List of checkpoint paths, ordered by batch index and then by sorted
        seed directory name. Empty if nothing is found.
    """
    import re  # local import: not among this notebook's top-level imports

    step_pattern = re.compile(r'step_(\d+)')

    def _step_order(path):
        """Sort key: numbered checkpoints by step, then unnumbered ones by name."""
        name = os.path.basename(path)
        match = step_pattern.match(name)
        if match:
            return (0, int(match.group(1)), name)
        return (1, 0, name)

    paths = []
    for batch_idx in range(max_batches):
        batch_dir = os.path.join(base_dir, f'batch_{batch_idx}')
        # Missing batches are skipped (as is a stray file of that name).
        if not os.path.isdir(batch_dir):
            continue
        for seed_dir_name in sorted(os.listdir(batch_dir)):
            seed_path = os.path.join(batch_dir, seed_dir_name)
            if not os.path.isdir(seed_path):
                continue
            pkl_files = glob_mod.glob(os.path.join(seed_path, 'step_*.pkl'))
            if pkl_files:
                paths.append(max(pkl_files, key=_step_order))  # Latest step
    return paths


def load_training_summary(base_dir):
    """Load ``training_summary.pkl`` and collect rewards and populations.

    Values are concatenated across all batches that are not marked as failed
    (``success`` is treated as true when the key is absent), reading
    ``metrics['mean_reward']`` and ``metrics['population_size']``.

    Args:
        base_dir: Condition directory containing ``training_summary.pkl``.

    Returns:
        ``(rewards, populations)``. Both are ``None`` when the summary file is
        missing or holds no rewards. ``populations`` is also ``None`` when
        rewards exist but no population values were recorded, so callers that
        test ``is not None`` never receive an empty array.
    """
    summary_path = os.path.join(base_dir, 'training_summary.pkl')
    if not os.path.exists(summary_path):
        return None, None
    # NOTE: pickle executes arbitrary code on load; only open summaries that
    # this notebook wrote itself.
    with open(summary_path, 'rb') as f:
        summary = pickle.load(f)

    rewards = []
    populations = []
    for batch in summary['all_results']:
        if not batch.get('success', True):
            continue
        metrics = batch.get('metrics', {})
        if 'mean_reward' in metrics:
            rewards.extend(metrics['mean_reward'])
        if 'population_size' in metrics:
            populations.extend(metrics['population_size'])

    if not rewards:
        return None, None
    population_array = np.array(populations, dtype=float) if populations else None
    return np.array(rewards), population_array


def load_seed_data(ckpt_path):
    """Load one checkpoint and keep only what the analysis cells need.

    Args:
        ckpt_path: Path handed to ``load_checkpoint``.

    Returns:
        Dict with keys ``'params'`` (as stored), ``'agent_params'`` (index 0
        along the leading axis of every leaf), ``'config'`` (rebuilt via
        ``reconstruct_config``), ``'network'`` (an ``ActorCritic`` sized from
        the config) and ``'seed_id'`` (``-1`` when the checkpoint has none).
    """
    checkpoint = load_checkpoint(ckpt_path)
    cfg = reconstruct_config(checkpoint['config'])

    def _first_entry(leaf):
        # Per the original note the leading axis of agent_params is the env
        # axis (num_envs, max_agents, ...) — TODO confirm; env 0 is kept for
        # the divergence analysis.
        return leaf[0]

    seed_data = {
        'params': checkpoint['params'],
        'agent_params': jax.tree_util.tree_map(_first_entry, checkpoint['agent_params']),
        'config': cfg,
        'network': ActorCritic(
            hidden_dims=tuple(cfg.agent.hidden_dims), num_actions=6
        ),
        'seed_id': checkpoint.get('seed_id', -1),
    }
    # Drop the local reference to the full checkpoint before returning.
    del checkpoint
    return seed_data


# --- Load checkpoints ---
# Locate the newest checkpoint of every seed and load the training summaries
# for each condition, printing a one-line status per condition.
CHECKPOINT_BASE = '/content/drive/MyDrive/emergence-lab'

cond_dirs = {
    cond: f'{CHECKPOINT_BASE}/smart_field_{suffix}'
    for cond, suffix in (
        ('field_off', 'off'),
        ('field_on_presence', 'presence'),
        ('field_on_smart', 'smart'),
    )
}

cond_ckpt_paths, cond_rewards, cond_populations = {}, {}, {}

for cond in COND_NAMES:
    cond_dir = cond_dirs[cond]
    cond_ckpt_paths[cond] = discover_checkpoints(cond_dir)
    cond_rewards[cond], cond_populations[cond] = load_training_summary(cond_dir)

    n_ckpt = len(cond_ckpt_paths[cond])
    status = f"{COND_LABELS[cond]:<12} checkpoints: {n_ckpt}/15"
    # The reward mean is appended only when a training summary was found.
    if cond_rewards[cond] is not None:
        status += f"  training reward mean={np.mean(cond_rewards[cond]):.4f}"
    print(status)

In [None]:
# --- Training Reward Statistics (ANOVA + pairwise) ---

# Pairwise comparisons as (condition A, condition B, display label).
# Defined before the `if` so the name exists even when training summaries are
# missing: the per-metric evaluation statistics iterate PAIRS unconditionally.
PAIRS = [
    ('field_on_smart', 'field_on_presence', 'Smart vs Presence (KEY)'),
    ('field_on_smart', 'field_off', 'Smart vs OFF'),
    ('field_on_presence', 'field_off', 'Presence vs OFF'),
]

# Check we have training data for all conditions
have_rewards = all(cond_rewards[c] is not None for c in COND_NAMES)

if have_rewards:
    # ANOVA across 3 conditions
    f_stat, anova_p = scipy_stats.f_oneway(
        cond_rewards['field_off'],
        cond_rewards['field_on_presence'],
        cond_rewards['field_on_smart'],
    )
    print(f"ANOVA: F={f_stat:.3f}, p={anova_p:.6f}")
    if anova_p < 0.05:
        print("  Significant difference across conditions.")
    else:
        print("  No significant difference across conditions.")

    # Pairwise comparisons: Welch t-test plus probability of improvement.
    print(f"\n{'Comparison':<35} {'t':>8} {'p':>10} {'d':>8} {'P(A>B)':>8}")
    print("-" * 75)
    for cond_a, cond_b, label in PAIRS:
        welch = welch_t_test(cond_rewards[cond_a], cond_rewards[cond_b])
        poi = probability_of_improvement(cond_rewards[cond_a], cond_rewards[cond_b], seed=42)
        sig = "***" if welch.p_value < 0.001 else "**" if welch.p_value < 0.01 else "*" if welch.p_value < 0.05 else ""
        print(f"{label:<35} {welch.statistic:>8.3f} {welch.p_value:>10.4f} {welch.effect_size:>8.3f} {poi['prob_x_better']:>8.3f} {sig}")

    # IQM comparison, keyed by display label.
    print("\nIQM Comparison:")
    score_dict = {COND_LABELS[c]: cond_rewards[c] for c in COND_NAMES}
    comparison = compare_methods(score_dict, alpha=0.05, n_bootstrap=10000, seed=42)
    print(comparison.summary)
else:
    print("Training summaries not found for all conditions. Run training cells first.")
    comparison = None

In [None]:
# --- Training Comparison Plots (3 conditions) ---
# One row of three panels built from the training summaries: IQM bars with
# confidence intervals, reward violins with jittered points, and population
# histograms.

if not have_rewards:
    print("No training data to plot.")
else:
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    tick_labels = [COND_LABELS[c] for c in COND_NAMES]
    palette = [COND_COLORS[c] for c in COND_NAMES]
    x_pos = np.arange(len(COND_NAMES))

    # Panel 1: IQM bars with CI
    ax = axes[0]
    iqm_vals = {c: compute_iqm(cond_rewards[c], seed=42) for c in COND_NAMES}
    iqm_points = [iqm_vals[c].iqm for c in COND_NAMES]
    # errorbar expects distances from the point, not absolute CI bounds.
    err_below = [iqm_vals[c].iqm - iqm_vals[c].ci_lower for c in COND_NAMES]
    err_above = [iqm_vals[c].ci_upper - iqm_vals[c].iqm for c in COND_NAMES]
    bars = ax.bar(
        x_pos,
        iqm_points,
        color=palette,
        edgecolor='black', linewidth=0.8,
        tick_label=tick_labels,
    )
    ax.errorbar(
        x_pos,
        iqm_points,
        yerr=[err_below, err_above],
        fmt='none', color='black', capsize=5,
    )
    ax.set_ylabel('IQM Reward')
    ax.set_title('Training Reward (IQM + 95% CI)')

    # Panel 2: Violin + swarm
    ax = axes[1]
    data_list = [cond_rewards[c] for c in COND_NAMES]
    parts = ax.violinplot(data_list, positions=x_pos, showmeans=True, showmedians=True)
    for body, face_color in zip(parts['bodies'], palette):
        body.set_facecolor(face_color)
        body.set_alpha(0.6)
    for i, cond in enumerate(COND_NAMES):
        rewards = cond_rewards[cond]
        # Fixed per-condition seed keeps the jitter reproducible across runs.
        jitter = np.random.default_rng(42 + i).uniform(-0.1, 0.1, len(rewards))
        ax.scatter(i + jitter, rewards, c=COND_COLORS[cond], s=15, alpha=0.7, zorder=3)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(tick_labels)
    ax.set_ylabel('Final Reward')
    ax.set_title('Reward Distribution')

    # Panel 3: Population distribution
    ax = axes[2]
    have_pops = all(cond_populations[c] is not None for c in COND_NAMES)
    if not have_pops:
        ax.text(0.5, 0.5, 'No population data', ha='center', va='center', transform=ax.transAxes)
    else:
        # Shared bin edges so the three histograms are directly comparable.
        all_pop_vals = np.concatenate([cond_populations[c] for c in COND_NAMES])
        bins = np.linspace(all_pop_vals.min(), all_pop_vals.max(), 20)
        for cond in COND_NAMES:
            ax.hist(cond_populations[cond], bins=bins, alpha=0.5,
                    color=COND_COLORS[cond], label=COND_LABELS[cond], edgecolor='black')
        ax.set_xlabel('Final Population')
        ax.set_ylabel('Count')
        ax.set_title('Population Distribution')
        ax.legend()

    fig.suptitle('Smart Field Test: Training Comparison', fontsize=16, y=1.02)
    fig.tight_layout()
    save_figure(fig, os.path.join(OUTPUT_DIR, 'training_comparison'))
    plt.show()

In [None]:
# --- Load & Eval All Checkpoints + Weight Divergence ---
# For every discovered checkpoint: run one evaluation episode, compute the
# weight divergence across agent slots, and persist progress after each seed.

cond_hf_eval = {c: [] for c in COND_NAMES}
cond_divergence = {c: [] for c in COND_NAMES}

for cond in COND_NAMES:
    ckpt_paths = cond_ckpt_paths[cond]
    label = COND_LABELS[cond]
    n_seeds = len(ckpt_paths)
    banner = '=' * 40
    print(f"\n{banner}")
    print(f"Evaluating {label}: {n_seeds} seeds")
    print(f"{banner}")

    # One progress file per condition, rewritten after every seed.
    progress_path = os.path.join(OUTPUT_DIR, f'eval_progress_{cond}.pkl')

    for seed_no, ckpt_path in enumerate(ckpt_paths, start=1):
        print(f"  [{label}] Eval seed {seed_no}/{n_seeds}: {os.path.basename(ckpt_path)}")
        seed_data = load_seed_data(ckpt_path)
        seed_id = seed_data['seed_id']
        key = jax.random.PRNGKey(42 + (seed_no - 1))

        hf_result = run_hidden_food_eval(
            seed_data['network'], seed_data['params'],
            seed_data['config'], key, num_episodes=1,
        )
        hf_result['seed_id'] = seed_id
        cond_hf_eval[cond].append(hf_result)

        # Compute weight divergence
        # NOTE(review): every agent slot is marked alive here, so the result
        # includes slots that may hold dead agents — confirm this is intended.
        max_agents = seed_data['config'].evolution.max_agents
        alive_mask = np.ones(max_agents, dtype=bool)
        div = compute_weight_divergence(seed_data['agent_params'], alive_mask)
        cond_divergence[cond].append({
            'seed_id': seed_id,
            'mean_divergence': float(div['mean_divergence']),
            'max_divergence': float(div['max_divergence']),
        })

        del seed_data
        gc.collect()

        # Intermediate save after each seed
        snapshot = {
            'condition': cond,
            'completed_seeds': seed_no,
            'total_seeds': n_seeds,
            'hf_eval': list(cond_hf_eval[cond]),
            'divergence': list(cond_divergence[cond]),
        }
        with open(progress_path, 'wb') as f:
            pickle.dump(snapshot, f)

# Build numpy arrays per condition per metric
EVAL_METRICS = [
    'total_reward', 'hidden_food_revealed', 'hidden_food_collected',
    'regular_food_collected', 'hidden_food_energy', 'regular_food_energy',
    'final_population',
]

cond_arrays = {}
for cond in COND_NAMES:
    per_metric = {
        metric: np.array([result[metric] for result in cond_hf_eval[cond]])
        for metric in EVAL_METRICS
    }
    per_metric['mean_divergence'] = np.array(
        [entry['mean_divergence'] for entry in cond_divergence[cond]]
    )
    cond_arrays[cond] = per_metric

for cond in COND_NAMES:
    n = len(cond_hf_eval[cond])
    print(f"{COND_LABELS[cond]}: {n} seeds evaluated")

In [None]:
# --- Hidden Food Stats + Grouped Bar/Violin Plots ---

# (key into cond_arrays, panel title) for every metric reported and plotted.
PLOT_METRICS = [
    ('hidden_food_revealed', 'Hidden Food Revealed'),
    ('hidden_food_collected', 'Hidden Food Collected'),
    ('regular_food_collected', 'Regular Food Collected'),
    ('total_reward', 'Total Reward'),
    ('final_population', 'Final Population'),
    ('mean_divergence', 'Weight Divergence'),
]


def _significance_stars(p_value):
    """Return '***', '**', '*' or '' for p below 0.001, 0.01, 0.05 or none."""
    for cutoff, stars in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p_value < cutoff:
            return stars
    return ""


# --- Statistical table (ANOVA + pairwise per metric) ---
print(f"{'Metric':<25} {'OFF':>8} {'Pres':>8} {'Smart':>8} {'F':>8} {'p(ANOVA)':>10}")
print("-" * 75)

stat_results = {}
for metric, _ in PLOT_METRICS:
    vals = {c: cond_arrays[c][metric] for c in COND_NAMES}
    means = {c: np.mean(samples) for c, samples in vals.items()}
    stds = {c: np.std(samples) for c, samples in vals.items()}

    # Skip ANOVA if all values are zero/constant
    if all(stds[c] == 0 for c in COND_NAMES):
        f_stat, anova_p = 0.0, 1.0
    else:
        f_stat, anova_p = scipy_stats.f_oneway(
            vals['field_off'], vals['field_on_presence'], vals['field_on_smart']
        )

    sig = _significance_stars(anova_p)
    print(f"{metric:<25} {means['field_off']:>8.2f} {means['field_on_presence']:>8.2f} {means['field_on_smart']:>8.2f} {f_stat:>8.2f} {anova_p:>10.4f} {sig}")

    # Pairwise t-tests; a pair with no variance on either side gets a neutral result.
    pairwise = {}
    for cond_a, cond_b, label in PAIRS:
        if stds[cond_a] == 0 and stds[cond_b] == 0:
            pairwise[label] = {'t': 0.0, 'p': 1.0, 'd': 0.0}
            continue
        w = welch_t_test(vals[cond_a], vals[cond_b])
        pairwise[label] = {'t': w.statistic, 'p': w.p_value, 'd': w.effect_size}

    stat_results[metric] = {
        'means': means,
        'stds': stds,
        'anova_f': f_stat, 'anova_p': anova_p,
        'pairwise': pairwise,
    }

print("\n* p<0.05  ** p<0.01  *** p<0.001")

# --- Grouped bar chart ---
# One panel per metric; each panel shows mean +/- std for the three conditions
# and is annotated with that metric's ANOVA p-value from stat_results.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes_flat = axes.flatten()

bar_width = 0.25
x_pos = np.arange(1)  # single group
bar_centers = [j * bar_width for j in range(3)]

for idx, (metric, title) in enumerate(PLOT_METRICS):
    ax = axes_flat[idx]
    is_first_panel = idx == 0
    for center, cond in zip(bar_centers, COND_NAMES):
        vals = cond_arrays[cond][metric]
        ax.bar(
            center, np.mean(vals), bar_width,
            yerr=np.std(vals), color=COND_COLORS[cond],
            edgecolor='black', linewidth=0.8, capsize=4,
            # Only the first panel carries labels; it hosts the shared legend.
            label=COND_LABELS[cond] if is_first_panel else None,
        )
    ax.set_xticks(bar_centers)
    ax.set_xticklabels([COND_LABELS[c] for c in COND_NAMES], fontsize=9)
    ax.set_title(title)

    # Annotate ANOVA p-value
    sr = stat_results[metric]
    p_value = sr['anova_p']
    if p_value < 0.001:
        sig = "***"
    elif p_value < 0.01:
        sig = "**"
    elif p_value < 0.05:
        sig = "*"
    else:
        sig = "ns"
    ax.annotate(f"ANOVA p={p_value:.3f} {sig}",
                xy=(0.5, 0.95), xycoords='axes fraction', ha='center', va='top', fontsize=8)

axes_flat[0].legend(loc='upper left', fontsize=9)
fig.suptitle('Smart Field Test: Eval Metrics (Grouped Bars)', fontsize=16, y=1.02)
fig.tight_layout()
save_figure(fig, os.path.join(OUTPUT_DIR, 'eval_grouped_bars'))
plt.show()

# --- Violin plots ---
# Same six metrics as the bar chart, shown as distributions with jittered
# per-seed points. Metrics that are zero everywhere get a placeholder panel.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes_flat = axes.flatten()

for idx, (metric, title) in enumerate(PLOT_METRICS):
    ax = axes_flat[idx]
    data_list = [cond_arrays[c][metric] for c in COND_NAMES]

    # Check if all data is zero
    has_nonzero = any(np.any(d != 0) for d in data_list)
    if has_nonzero:
        parts = ax.violinplot(data_list, positions=range(3), showmeans=True, showmedians=True)
        for body, cond in zip(parts['bodies'], COND_NAMES):
            body.set_facecolor(COND_COLORS[cond])
            body.set_alpha(0.6)
        # Swarm overlay
        for j, (cond, samples) in enumerate(zip(COND_NAMES, data_list)):
            # Fixed per-condition seed keeps the jitter reproducible.
            jitter = np.random.default_rng(42 + j).uniform(-0.1, 0.1, len(samples))
            ax.scatter(j + jitter, samples, c=COND_COLORS[cond], s=12, alpha=0.7, zorder=3)
    else:
        for j, cond in enumerate(COND_NAMES):
            ax.bar(j, 0, color=COND_COLORS[cond], edgecolor='black', linewidth=0.8)
        ax.annotate('All zero', xy=(0.5, 0.5), xycoords='axes fraction',
                    ha='center', va='center', fontsize=9, color='gray')

    ax.set_xticks(range(3))
    ax.set_xticklabels([COND_LABELS[c] for c in COND_NAMES], fontsize=9)
    ax.set_title(title)

fig.suptitle('Smart Field Test: Eval Metrics (Violin)', fontsize=16, y=1.02)
fig.tight_layout()
save_figure(fig, os.path.join(OUTPUT_DIR, 'eval_violins'))
plt.show()

In [None]:
# --- Divergence & Correlation Plots ---
# 2x2 figure built from the per-seed evaluation arrays in cond_arrays.
from matplotlib.lines import Line2D

fig, axes = plt.subplots(2, 2, figsize=(14, 11))

# Proxy legend handles: one coloured marker per condition.
legend_elements = [
    Line2D([0], [0], marker='o', color='w', markerfacecolor=COND_COLORS[c],
           markersize=8, label=COND_LABELS[c])
    for c in COND_NAMES
]

# Panel 1: Weight divergence histogram (3 conditions overlaid)
ax = axes[0, 0]
all_divs = [cond_arrays[c]['mean_divergence'] for c in COND_NAMES]
# Plot only if at least one condition has data; empty conditions are skipped.
if any(len(d) > 0 for d in all_divs):
    all_vals = np.concatenate([d for d in all_divs if len(d) > 0])
    # Shared bin edges so the overlaid histograms are comparable.
    bins = np.linspace(all_vals.min(), all_vals.max(), 20)
    for cond in COND_NAMES:
        if len(cond_arrays[cond]['mean_divergence']) > 0:
            ax.hist(cond_arrays[cond]['mean_divergence'], bins=bins, alpha=0.5,
                    color=COND_COLORS[cond], label=COND_LABELS[cond], edgecolor='black')
    ax.set_xlabel('Mean Weight Divergence')
    ax.set_ylabel('Count')
    ax.set_title('Weight Divergence Distribution')
    ax.legend()

# Panel 2: Population histogram (3 conditions overlaid)
# Uses the final population of the evaluation episodes, not the training summary.
ax = axes[0, 1]
all_pops_list = [cond_arrays[c]['final_population'] for c in COND_NAMES]
if any(len(p) > 0 for p in all_pops_list):
    all_pop_vals = np.concatenate([p for p in all_pops_list if len(p) > 0])
    bins = np.linspace(all_pop_vals.min(), all_pop_vals.max(), 20)
    for cond in COND_NAMES:
        if len(cond_arrays[cond]['final_population']) > 0:
            ax.hist(cond_arrays[cond]['final_population'], bins=bins, alpha=0.5,
                    color=COND_COLORS[cond], label=COND_LABELS[cond], edgecolor='black')
    ax.set_xlabel('Final Population (eval)')
    ax.set_ylabel('Count')
    ax.set_title('Eval Population Distribution')
    ax.legend()

# Panel 3: Scatter — HF collected vs divergence (color by condition)
ax = axes[1, 0]
for cond in COND_NAMES:
    hf = cond_arrays[cond]['hidden_food_collected']
    div = cond_arrays[cond]['mean_divergence']
    if len(hf) > 0:
        ax.scatter(div, hf, c=COND_COLORS[cond], s=30, alpha=0.7,
                   edgecolors='black', linewidth=0.5)
all_hf = np.concatenate([cond_arrays[c]['hidden_food_collected'] for c in COND_NAMES if len(cond_arrays[c]['hidden_food_collected']) > 0])
all_div = np.concatenate([cond_arrays[c]['mean_divergence'] for c in COND_NAMES if len(cond_arrays[c]['mean_divergence']) > 0])
if len(all_div) > 2 and np.std(all_div) > 0 and np.std(all_hf) > 0:
    r, p = scipy_stats.pearsonr(all_div, all_hf)
    ax.set_title(f'HF Collected vs Divergence (r={r:.3f}, p={p:.3f})')
else:
    ax.set_title('HF Collected vs Divergence')
ax.set_xlabel('Mean Weight Divergence')
ax.set_ylabel('Hidden Food Collected')
ax.legend(handles=legend_elements)

# Panel 4: Scatter — reward vs population (color by condition)
ax = axes[1, 1]
for cond in COND_NAMES:
    rew = cond_arrays[cond]['total_reward']
    pop = cond_arrays[cond]['final_population'].astype(float)
    if len(rew) > 0:
        ax.scatter(pop, rew, c=COND_COLORS[cond], s=30, alpha=0.7,
                   edgecolors='black', linewidth=0.5)
all_rew = np.concatenate([cond_arrays[c]['total_reward'] for c in COND_NAMES if len(cond_arrays[c]['total_reward']) > 0])
all_pop = np.concatenate([cond_arrays[c]['final_population'].astype(float) for c in COND_NAMES if len(cond_arrays[c]['final_population']) > 0])
if len(all_pop) > 2 and np.std(all_pop) > 0 and np.std(all_rew) > 0:
    r, p = scipy_stats.pearsonr(all_pop, all_rew)
    ax.set_title(f'Reward vs Population (r={r:.3f}, p={p:.3f})')
else:
    ax.set_title('Reward vs Population')
ax.set_xlabel('Final Population')
ax.set_ylabel('Total Reward')
ax.legend(handles=legend_elements)

fig.suptitle('Smart Field Test: Divergence & Correlations', fontsize=16, y=1.02)
fig.tight_layout()
save_figure(fig, os.path.join(OUTPUT_DIR, 'divergence_correlation'))
plt.show()

In [None]:
# --- Summary Report ---
# Builds a Markdown report (setup, per-metric ANOVA table, pairwise
# comparisons, auto-generated key findings, output file list) into `report`
# and renders it inline. `report` and `n_seeds` are reused when saving results.
from IPython.display import display, Markdown


def _sig_stars(p_value):
    """Return '***' / '**' / '*' for p < 0.001 / 0.01 / 0.05, else ''.

    A NaN p-value yields '' because every comparison with NaN is False.
    """
    if p_value < 0.001:
        return "***"
    if p_value < 0.01:
        return "**"
    if p_value < 0.05:
        return "*"
    return ""


def _mean_std_cell(metric_stats, cond):
    """Format one table cell as 'mean+/-std' with two decimals."""
    return f"{metric_stats['means'][cond]:.2f}+/-{metric_stats['stds'][cond]:.2f}"


def _best_condition(metric_stats):
    """Return the condition name with the highest mean for this metric."""
    return max(COND_NAMES, key=lambda c: metric_stats['means'][c])


n_seeds = {c: len(cond_hf_eval[c]) for c in COND_NAMES}

report = f"""# Smart Field Test: Results Summary

**Date**: {datetime.now().strftime('%Y-%m-%d %H:%M')}

## Experimental Setup
- **Grid**: 40x40, **Regular food**: 10 (scarce), **Food energy**: 60 (reduced)
- **Hidden food**: 8 items, require 5 agents within distance 3, value 10x (600 energy each)
- **Max agents**: 64, **Starting energy**: 200, **Reproduce threshold**: 120
- **Steps**: 10M per seed
- **Seeds**: {n_seeds['field_off']} OFF + {n_seeds['field_on_presence']} Presence + {n_seeds['field_on_smart']} Smart

## Conditions
| Condition | write_mode | Field params |
|-----------|------------|-------------|
| Field OFF | presence | diffusion=0, decay=1, write=0 |
| Presence | presence | diffusion=0.1, decay=0.05, write=1.0 |
| Smart | state_dependent | diffusion=0.1, decay=0.05, write=1.0 |

## Eval Results (ANOVA + Pairwise)

| Metric | OFF | Presence | Smart | ANOVA p |
|--------|-----|----------|-------|---------|
"""

# Remaining sections are collected as fragments and joined once at the end.
sections = []

# ANOVA table: one row per plotted metric.
for metric, title in PLOT_METRICS:
    sr = stat_results[metric]
    sections.append(
        f"| {title} "
        f"| {_mean_std_cell(sr, 'field_off')} "
        f"| {_mean_std_cell(sr, 'field_on_presence')} "
        f"| {_mean_std_cell(sr, 'field_on_smart')} "
        f"| {sr['anova_p']:.4f}{_sig_stars(sr['anova_p'])} |\n"
    )

sections.append("\n### Pairwise Comparisons (Smart vs Presence = KEY)\n\n")
sections.append("| Metric | Comparison | t | p | Cohen's d |\n")
sections.append("|--------|-----------|---|---|----------|\n")

for metric, title in PLOT_METRICS:
    for label, pw in stat_results[metric]['pairwise'].items():
        sections.append(
            f"| {title} | {label} | {pw['t']:.2f} | {pw['p']:.4f}{_sig_stars(pw['p'])} | {pw['d']:.2f} |\n"
        )

sections.append("\n## Key Findings\n\n")

# Auto-generate key findings
if 'total_reward' in stat_results:
    sr = stat_results['total_reward']
    if sr['anova_p'] < 0.05:
        best = _best_condition(sr)
        sections.append(f"- **{COND_LABELS[best]} has highest total reward** (ANOVA p={sr['anova_p']:.4f})\n")
        # Check key comparison: smart vs presence
        pw = sr['pairwise']['Smart vs Presence (KEY)']
        if pw['p'] < 0.05:
            # NOTE(review): assumes a positive t means Smart > Presence —
            # confirm against the argument order used where the test is run.
            winner = 'Smart' if pw['t'] > 0 else 'Presence'
            sections.append(f"- **{winner} significantly outperforms** on reward vs the other field mode (p={pw['p']:.4f}, d={pw['d']:.2f})\n")
        else:
            sections.append(f"- Smart vs Presence reward difference not significant (p={pw['p']:.4f})\n")
    else:
        sections.append(f"- No significant reward difference across conditions (ANOVA p={sr['anova_p']:.4f})\n")

if 'hidden_food_collected' in stat_results:
    sr = stat_results['hidden_food_collected']
    any_hf = any(sr['means'][c] > 0 for c in COND_NAMES)
    if not any_hf:
        # Three conditions are compared, so "Neither" would be wrong here.
        sections.append("- No condition achieved hidden food collection in eval episodes\n")
    elif sr['anova_p'] < 0.05:
        best = _best_condition(sr)
        sections.append(f"- **{COND_LABELS[best]} collects more hidden food** (ANOVA p={sr['anova_p']:.4f})\n")
    else:
        sections.append(f"- No significant difference in hidden food collection (ANOVA p={sr['anova_p']:.4f})\n")

# Metrics whose finding only depends on the ANOVA p-value:
# (metric key, claim when significant, sentence when not significant).
simple_findings = [
    ('final_population', 'sustains largest populations', 'No significant population difference'),
    ('mean_divergence', 'shows highest weight divergence', 'No significant divergence difference'),
]
for metric, claim, null_text in simple_findings:
    if metric not in stat_results:
        continue
    sr = stat_results[metric]
    if sr['anova_p'] < 0.05:
        best = _best_condition(sr)
        sections.append(f"- **{COND_LABELS[best]} {claim}** (ANOVA p={sr['anova_p']:.4f})\n")
    else:
        sections.append(f"- {null_text} (ANOVA p={sr['anova_p']:.4f})\n")

sections.append(f"""\n## Output Files
- Figures: `{OUTPUT_DIR}/training_comparison.{{pdf,png}}`
- Figures: `{OUTPUT_DIR}/eval_grouped_bars.{{pdf,png}}`
- Figures: `{OUTPUT_DIR}/eval_violins.{{pdf,png}}`
- Figures: `{OUTPUT_DIR}/divergence_correlation.{{pdf,png}}`
- Results: `{OUTPUT_DIR}/all_results.json`
- Results: `{OUTPUT_DIR}/all_results.pkl`
- Report: `{OUTPUT_DIR}/summary_report.md`
""")

report += "".join(sections)

display(Markdown(report))

# --- Save all results ---
# Writes three artifacts to OUTPUT_DIR: a JSON summary of the eval statistics,
# a pickle with the same summary plus the raw per-condition data, and the
# Markdown report built above.
json_results = {
    'experiment': 'smart_field_test',
    'date': datetime.now().isoformat(),
    'config': {
        'grid_size': 40, 'num_food': 10, 'food_energy': 60,
        'num_hidden': 8, 'required_agents': 5,
        'hidden_food_value_multiplier': 10.0,
        'max_agents': 64, 'total_steps': TOTAL_STEPS,
    },
    'n_seeds': n_seeds,
    'eval_stats': {},
}

# Pairwise labels become key fragments: spaces -> '_', parentheses dropped.
_label_table = str.maketrans(' ', '_', '()')

for metric, _title in PLOT_METRICS:
    sr = stat_results[metric]
    # float() turns the statistics into plain Python floats so json can
    # serialise them.
    entry = {
        'anova_f': float(sr['anova_f']),
        'anova_p': float(sr['anova_p']),
    }
    for cond in COND_NAMES:
        entry[f'{cond}_mean'] = float(sr['means'][cond])
        entry[f'{cond}_std'] = float(sr['stds'][cond])
    for label, pw in sr['pairwise'].items():
        safe_label = label.translate(_label_table)
        entry[f'pw_{safe_label}_t'] = float(pw['t'])
        entry[f'pw_{safe_label}_p'] = float(pw['p'])
        entry[f'pw_{safe_label}_d'] = float(pw['d'])
    json_results['eval_stats'][metric] = entry

# NOTE(review): json.dump writes NaN/Infinity as non-standard tokens by
# default; if any statistic can be NaN, confirm downstream readers accept it.
json_path = os.path.join(OUTPUT_DIR, 'all_results.json')
with open(json_path, 'w', encoding='utf-8') as f:
    json.dump(json_results, f, indent=2)
print(f"JSON saved: {json_path}")

# The pickle additionally keeps the raw data; conditions whose reward or
# population entry is None are left out.
pkl_results = {
    **json_results,
    'cond_hf_eval': cond_hf_eval,
    'cond_divergence': cond_divergence,
    'cond_rewards': {c: cond_rewards[c] for c in COND_NAMES if cond_rewards[c] is not None},
    'cond_populations': {c: cond_populations[c] for c in COND_NAMES if cond_populations[c] is not None},
}
pkl_path = os.path.join(OUTPUT_DIR, 'all_results.pkl')
with open(pkl_path, 'wb') as f:
    pickle.dump(pkl_results, f)
print(f"Pickle saved: {pkl_path}")

# Explicit UTF-8 so the report is written identically on every platform,
# whatever characters the condition labels or metric titles contain.
md_path = os.path.join(OUTPUT_DIR, 'summary_report.md')
with open(md_path, 'w', encoding='utf-8') as f:
    f.write(report)
print(f"Report saved: {md_path}")

print(f"\nAll results saved to {OUTPUT_DIR}")