# Survival Pressure: Does Environmental Harshness Make Stigmergy Essential?

## Hypothesis

| Prediction | Mechanism |
|------------|----------|
| Field ON agents survive better | Field encodes spatial memory of food locations, enabling efficient foraging under scarcity |
| Field ON agents coordinate to reveal hidden food | Field gradients guide multiple agents to converge near hidden food sites |
| Field OFF agents starve or fail to coordinate | Without shared medium, agents cannot communicate food locations or coordinate for hidden food |
| The harsher the environment, the bigger the gap | Stigmergy becomes *essential*, not just *helpful*, when survival is at stake |

## Experimental Design

| Condition | Field | Seeds | Steps |
|-----------|-------|-------|-------|
| **Field ON** | diffusion=0.1, decay=0.05, write=1.0 | 30 | 10M |
| **Field OFF** | diffusion=0.0, decay=1.0, write=0.0 | 30 | 10M |

## Harsh Environment Config

| Parameter | Value | Why |
|-----------|-------|-----|
| grid_size | 40 | Larger arena = harder to find food |
| num_food | 10 | **SCARCE** (vs 40 in standard config) |
| food_energy | 60 | **REDUCED** (vs 100 standard) |
| num_hidden | 8 | Abundant hidden food — the coordination prize |
| required_agents | 5 | Hard coordination requirement |
| reveal_distance | 3 | Chebyshev distance for reveal |
| hidden_food_value_multiplier | 10.0 | Each hidden food worth 600 energy (10x regular) |
| max_agents | 64 | Population cap |
| starting_energy | 200 | Starting energy per agent |
| reproduce_threshold | 120 | Energy needed to reproduce |
| reproduce_cost | 40 | Energy cost of reproduction |

**The survival pressure**: Regular food alone may not sustain a large population. Agents that learn to coordinate and reveal hidden food get a massive energy windfall (600 per hidden food item). The field should be essential for this coordination.

In [None]:
# Mount Google Drive; later cells write checkpoints and results under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

import os
# Clone the repository only when the checkout does not exist yet.
if not os.path.exists('/content/emergence-lab'):
    !git clone https://github.com/imashishkh21/emergence-lab.git /content/emergence-lab
%cd /content/emergence-lab
# Pull main on every run so the notebook executes the latest committed code.
!git pull origin main

# Editable install of the project with its "dev" extras, plus rliable.
!pip install -e ".[dev]" -q
!pip install rliable -q

import jax
# Print the devices and backend JAX selected for this session.
print(f"JAX devices: {jax.devices()}")
print(f"JAX backend: {jax.default_backend()}")

## Phase 1 — Quick Validation

Before committing to 60 full training runs, verify:
1. Populations survive the harsh environment (at least one condition has mean pop > 10)
2. Hidden food is being revealed at all (informational only — coordination may take longer than 2M steps to emerge)

**Setup**: 3 seeds ON + 3 seeds OFF, ~2.1M steps (8 iterations of 262,144 steps each)

In [None]:
import time, gc
import jax
import jax.numpy as jnp
import numpy as np
from src.configs import Config, TrainingMode
from src.training.parallel_train import ParallelTrainer
from src.agents.network import ActorCritic
from src.agents.policy import get_deterministic_actions
from src.environment.env import reset, step
from src.environment.obs import get_observations

# Rollout geometry; these notebook-level names are reused by later cells.
NUM_ENVS = 32
NUM_STEPS = 128
MAX_AGENTS = 64  # build_config() sets evolution.max_agents to the same literal; keep them in sync
# One iteration is counted as envs x rollout steps x agent slots.
STEPS_PER_ITER = NUM_ENVS * NUM_STEPS * MAX_AGENTS  # 262,144
VALIDATION_ITERS = 8  # ~2M steps


def build_config(field_enabled: bool) -> Config:
    """Return the harsh-environment ``Config`` used by every run in this notebook.

    Args:
        field_enabled: When ``False`` the field is neutralised (zero
            diffusion, decay of 1.0, zero write strength). When ``True``
            no field attribute is touched, so the field keeps whatever
            ``Config()`` provides.
            NOTE(review): the design table lists diffusion=0.1, decay=0.05,
            write=1.0 for Field ON — presumably the ``Config`` defaults;
            confirm against ``src.configs``.

    Returns:
        A new ``Config`` with the environment, evolution, hidden-food,
        training and logging overrides applied.
    """
    config = Config()

    # (section object, attribute overrides) pairs, applied in listed order.
    overrides = [
        # 40x40 grid, num_agents=16 (population cap is max_agents below),
        # scarce regular food, explicit episode length.
        (config.env, {
            'grid_size': 40,
            'num_agents': 16,
            'num_food': 10,
            'max_steps': 500,
        }),
        # Evolution with reduced energy per food item.
        (config.evolution, {
            'enabled': True,
            'max_agents': 64,
            'starting_energy': 200,
            'food_energy': 60,
            'energy_per_step': 1,
            'reproduce_threshold': 120,
            'reproduce_cost': 40,
            'mutation_std': 0.01,
        }),
        # Hidden food: the coordination prize.
        (config.env.hidden_food, {
            'enabled': True,
            'num_hidden': 8,
            'required_agents': 5,
            'reveal_distance': 3,
            'hidden_food_value_multiplier': 10.0,
        }),
        # Training settings.
        (config.train, {
            'training_mode': TrainingMode.GRADIENT,
            'num_envs': NUM_ENVS,
            'num_steps': NUM_STEPS,
            'seed': 42,
        }),
        # Logging: wandb off, save_interval 0.
        (config.log, {
            'wandb': False,
            'save_interval': 0,
        }),
    ]
    if not field_enabled:
        # Field OFF: nothing is written, nothing spreads, everything decays.
        overrides.append((config.field, {
            'diffusion_rate': 0.0,
            'decay_rate': 1.0,
            'write_strength': 0.0,
        }))

    for section, values in overrides:
        for attr_name, attr_value in values.items():
            setattr(section, attr_name, attr_value)
    return config


def run_hidden_food_eval(network, params, config, key, num_episodes=1):
    """Roll out deterministic evaluation episodes and tally food statistics.

    Args:
        network: Policy network passed to ``get_deterministic_actions``.
        params: Parameters for ``network``.
        config: Experiment config; supplies ``env.max_steps``,
            ``evolution.food_energy`` and the hidden-food value multiplier.
        key: JAX PRNG key; split once per episode.
        num_episodes: Number of episodes to run (must be at least 1, since
            the aggregate is keyed off the first episode's result).

    Returns:
        Dict holding the mean of each per-episode metric across episodes,
        plus ``'per_episode'`` with the list of raw per-episode dicts.
    """
    episodes = []
    for _ in range(num_episodes):
        key, episode_key = jax.random.split(key)
        state = reset(episode_key, config)

        reward_total = 0.0
        regular_eaten = 0.0
        reveal_count = 0
        hidden_eaten = 0.0

        for _t in range(config.env.max_steps):
            # Add a leading batch axis for the policy, then strip it again.
            batched_obs = get_observations(state, config)[None, :, :]
            actions = get_deterministic_actions(network, params, batched_obs)[0]

            revealed_before = state.hidden_food_revealed
            state, rewards, done, info = step(state, actions, config)

            reward_total += float(jnp.sum(rewards))
            regular_eaten += float(info['food_collected_this_step'])
            hidden_eaten += float(info['hidden_food_collected_this_step'])

            revealed_after = state.hidden_food_revealed
            if revealed_before is not None and revealed_after is not None:
                # Count entries that flipped from unrevealed to revealed this step.
                reveal_count += int(jnp.sum((~revealed_before) & revealed_after))

            if bool(done):
                break

        energy_per_food = config.evolution.food_energy
        hidden_multiplier = config.env.hidden_food.hidden_food_value_multiplier
        survivors = int(jnp.sum(state.agent_alive.astype(jnp.int32)))
        episodes.append({
            'total_reward': reward_total,
            'regular_food_collected': regular_eaten,
            'hidden_food_revealed': reveal_count,
            'hidden_food_collected': hidden_eaten,
            'regular_food_energy': regular_eaten * energy_per_food,
            'hidden_food_energy': hidden_eaten * energy_per_food * hidden_multiplier,
            'final_population': survivors,
        })

    # Mean of every metric across episodes; raw results kept alongside.
    agg = {
        metric: np.mean([episode[metric] for episode in episodes])
        for metric in episodes[0]
    }
    agg['per_episode'] = episodes
    return agg


# --- Validation training ---
# For each condition: train three seeds for VALIDATION_ITERS iterations, read
# the live population off the trainer, run a one-episode evaluation with the
# first seed's parameters, and store everything in `validation_results`.
validation_results = {}
for cond_name, field_on in [('field_on', True), ('field_off', False)]:
    print(f"\n{'='*60}")
    print(f"VALIDATION: {cond_name.upper()}")
    print(f"{'='*60}")
    config = build_config(field_on)
    seed_ids = [100, 101, 102]
    trainer = ParallelTrainer(
        config=config, num_seeds=len(seed_ids), seed_ids=seed_ids,
        checkpoint_dir=f'/tmp/validation_{cond_name}', master_seed=9999,
    )
    t0 = time.time()
    metrics = trainer.train(
        num_iterations=VALIDATION_ITERS,
        checkpoint_interval_minutes=999,   # No checkpoint writes
        resume=False,
        print_interval=2,
    )
    elapsed = time.time() - t0

    # Extract live population from trainer state.
    # NOTE(review): `_parallel_state` is a private attribute of ParallelTrainer.
    ps = trainer._parallel_state
    # ps.env_state.agent_alive: presumably (num_seeds, num_envs, max_agents)
    alive = np.array(ps.env_state.agent_alive)
    # Alive agents summed over envs and slots, divided by the env count:
    # mean population per env for each seed.
    pop_per_seed = alive.sum(axis=(1, 2)) / alive.shape[1]
    final_pops = [float(p) for p in pop_per_seed]

    # Quick eval: 1 episode per condition with seed 0's params.
    # Build the network from the config (as load_seed_data() does) rather
    # than hard-coding the hidden sizes, so it always matches the params.
    network = ActorCritic(
        hidden_dims=tuple(config.agent.hidden_dims), num_actions=6
    )
    eval_key = jax.random.PRNGKey(42)
    # Index 0 along the leading axis of every leaf = first seed's params.
    seed0_params = jax.tree_util.tree_map(lambda x: x[0], ps.params)
    eval_result = run_hidden_food_eval(network, seed0_params, config, eval_key, num_episodes=1)

    validation_results[cond_name] = {
        'metrics': metrics, 'elapsed': elapsed,
        'final_pops': final_pops, 'eval': eval_result,
    }
    print(f"\n{cond_name}: {elapsed:.0f}s")
    print(f"  Final populations (per seed, mean across envs): {[f'{p:.1f}' for p in final_pops]}")
    print(f"  Eval reward: {eval_result['total_reward']:.1f}")
    print(f"  Eval HF revealed: {eval_result['hidden_food_revealed']}")
    print(f"  Eval HF collected: {eval_result['hidden_food_collected']:.1f}")
    print(f"  Eval final pop: {eval_result['final_population']}")

    # Release the trainer and compiled caches before the next condition.
    del trainer
    gc.collect()
    jax.clear_caches()

# --- Decision gate ---
# Summarise the validation runs: pass if either condition keeps a mean
# population above 10; hidden-food reveals are reported but do not gate.
print("\n" + "="*60)
print("VALIDATION DECISION GATE")
print("="*60)
on_pops = validation_results['field_on']['final_pops']
off_pops = validation_results['field_off']['final_pops']
# Mean over the validation seeds of each seed's mean-per-env population.
mean_on = np.mean(on_pops)
mean_off = np.mean(off_pops)
print(f"Field ON  mean population: {mean_on:.1f}")
print(f"Field OFF mean population: {mean_off:.1f}")

# Reveal counts come from the single evaluation episode run per condition.
on_hfr = validation_results['field_on']['eval']['hidden_food_revealed']
off_hfr = validation_results['field_off']['eval']['hidden_food_revealed']
print(f"Field ON  hidden food revealed (eval): {on_hfr}")
print(f"Field OFF hidden food revealed (eval): {off_hfr}")

if mean_on > 10 or mean_off > 10:
    print("\nPASS: At least one condition sustains population > 10. Proceed to Phase 2.")
else:
    print("\nFAIL: Both conditions have dangerously low populations.")
    print("  Fallback 1: Increase num_food from 10 to 15 (slightly less harsh)")
    print("  Fallback 2: Increase food_energy from 60 to 80 (more reward per food)")
    print("  Fallback 3: Reduce required_agents from 5 to 4 (easier hidden food)")
    print("  Apply ONE change at a time and re-run validation.")

# Zero reveals in both conditions is only noted, never treated as a failure.
if on_hfr == 0 and off_hfr == 0:
    print("\nNOTE: No hidden food revealed in either condition at 2M steps.")
    print("  This is expected - coordination may take longer to emerge.")
    print("  If populations are healthy, proceed. HF coordination is the 10M-step question.")

## Phase 2 — Full Training

30 seeds Field ON + 30 seeds Field OFF, ~10M steps each (38 iterations of 262,144 steps = 9,961,472 steps per seed).

**Execution plan**:
- 10 batches x 3 seeds per condition
- Seed IDs: ON [0-29], OFF [50-79]
- Resume-safe with checkpoint detection
- Memory cleanup between every batch
- Population crash detection after each batch
- Checkpoints saved to Google Drive every 60 minutes

In [None]:
import os, time, gc, pickle
from datetime import datetime, timedelta

# Phase 2 sweep settings.
TOTAL_STEPS = 10_000_000
SEEDS_PER_BATCH = 3
TOTAL_BATCHES = 10
CHECKPOINT_INTERVAL_MINUTES = 60
RESUME = True  # passed to trainer.train(resume=...) in the training cell

# Same product as STEPS_PER_ITER in the validation cell.
steps_per_iter = NUM_ENVS * NUM_STEPS * MAX_AGENTS  # 262,144
# Floor division: 38 iterations = 9,961,472 steps, slightly under TOTAL_STEPS.
num_iterations = TOTAL_STEPS // steps_per_iter       # 38

# One Drive directory per condition; batches go into batch_<n> subdirectories.
CHECKPOINT_BASE = '/content/drive/MyDrive/emergence-lab'
CHECKPOINT_DIRS = {
    'field_on': f'{CHECKPOINT_BASE}/survival_pressure_on',
    'field_off': f'{CHECKPOINT_BASE}/survival_pressure_off',
}

print(f"Steps per iteration: {steps_per_iter:,}")
print(f"Total iterations: {num_iterations}")
print(f"Total seeds per condition: {SEEDS_PER_BATCH * TOTAL_BATCHES}")
print(f"Total training runs: {2 * SEEDS_PER_BATCH * TOTAL_BATCHES}")

In [None]:
# Full training sweep: for each condition, train TOTAL_BATCHES batches of
# SEEDS_PER_BATCH seeds, record per-batch results (or the error), and write a
# pickled summary into the condition's checkpoint directory at the end.
for condition_name, field_enabled in [('field_on', True), ('field_off', False)]:
    config = build_config(field_enabled)
    config.train.total_steps = TOTAL_STEPS
    checkpoint_dir_base = CHECKPOINT_DIRS[condition_name]
    os.makedirs(checkpoint_dir_base, exist_ok=True)

    all_results = []
    cond_start = time.time()

    for batch_number in range(TOTAL_BATCHES):
        # Field ON uses seed IDs 0-29, Field OFF uses 50-79.
        if condition_name == 'field_on':
            seed_ids = list(range(
                batch_number * SEEDS_PER_BATCH,
                (batch_number + 1) * SEEDS_PER_BATCH
            ))
        else:
            seed_ids = list(range(
                50 + batch_number * SEEDS_PER_BATCH,
                50 + (batch_number + 1) * SEEDS_PER_BATCH
            ))

        checkpoint_dir = f'{checkpoint_dir_base}/batch_{batch_number}'
        batch_start = time.time()

        print(f"\n{'='*60}")
        print(f"{condition_name.upper()} - Batch {batch_number+1}/{TOTAL_BATCHES} - Seeds {seed_ids}")
        elapsed_total = time.time() - cond_start
        if batch_number > 0:
            # ETA from the average wall time of the batches finished so far.
            avg_per_batch = elapsed_total / batch_number
            remaining = avg_per_batch * (TOTAL_BATCHES - batch_number)
            eta = datetime.now() + timedelta(seconds=remaining)
            print(f"ETA: {eta.strftime('%H:%M:%S')} ({remaining/60:.0f} min remaining)")
        print(f"{'='*60}")

        try:
            trainer = ParallelTrainer(
                config=config, num_seeds=SEEDS_PER_BATCH,
                seed_ids=seed_ids, checkpoint_dir=checkpoint_dir,
                master_seed=42 + batch_number * 1000,
            )
            metrics = trainer.train(
                num_iterations=num_iterations,
                checkpoint_interval_minutes=CHECKPOINT_INTERVAL_MINUTES,
                resume=RESUME,
                print_interval=5,
            )
            batch_time = time.time() - batch_start

            # Population crash detection
            # NOTE(review): `_parallel_state` is a private attribute of ParallelTrainer.
            ps = trainer._parallel_state
            alive = np.array(ps.env_state.agent_alive)  # presumably (num_seeds, num_envs, max_agents)
            # Mean alive count per env, for each seed.
            pop_per_seed = alive.sum(axis=(1, 2)) / alive.shape[1]
            batch_pops = [float(p) for p in pop_per_seed]
            print(f"  Final populations: {[f'{p:.0f}' for p in batch_pops]}")
            if all(p < 3 for p in batch_pops):
                print(f"  WARNING: All seeds in batch {batch_number} have near-zero population!")
                print(f"  Population may have crashed. Check environment parameters.")

            all_results.append({
                'batch': batch_number, 'seed_ids': seed_ids,
                'metrics': metrics, 'success': True,
                'time_seconds': batch_time, 'final_pops': batch_pops,
            })
            print(f"Batch {batch_number} done in {batch_time/60:.1f} min")
        except Exception as e:
            # Best effort: log the failure and move on to the next batch.
            print(f"ERROR in batch {batch_number}: {e}")
            import traceback
            traceback.print_exc()
            all_results.append({
                'batch': batch_number, 'seed_ids': seed_ids,
                'success': False, 'error': str(e),
            })
        finally:
            # `trainer` is unbound when the constructor itself raised; that is
            # the only error worth ignoring here (a bare `except` would also
            # swallow KeyboardInterrupt and real bugs).
            try:
                del trainer
            except NameError:
                pass
            gc.collect()
            jax.clear_caches()

    cond_time = time.time() - cond_start
    summary_path = f'{checkpoint_dir_base}/training_summary.pkl'
    with open(summary_path, 'wb') as f:
        pickle.dump({
            'all_results': all_results,
            'total_time_seconds': cond_time,
            'condition': condition_name,
            'total_steps': TOTAL_STEPS,
            'config': {
                'grid_size': config.env.grid_size,
                'num_food': config.env.num_food,
                'food_energy': config.evolution.food_energy,
                'num_hidden': config.env.hidden_food.num_hidden,
                'required_agents': config.env.hidden_food.required_agents,
                'hidden_food_value_multiplier': config.env.hidden_food.hidden_food_value_multiplier,
                'field_enabled': field_enabled,
                'max_agents': config.evolution.max_agents,
            },
        }, f)
    print(f"\n{condition_name} COMPLETE in {cond_time/3600:.1f} hours")
    print(f"Summary saved to {summary_path}")

## Phase 3 — Analysis

Load checkpoints from both conditions, run evaluation episodes, compute statistics, and generate publication-quality figures.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import json, pickle, gc, os
import glob as glob_mod
from pathlib import Path
from datetime import datetime
from scipy import stats as scipy_stats

import jax
import jax.numpy as jnp

from src.configs import (
    Config, TrainingMode, HiddenFoodConfig, EnvConfig, FieldConfig,
    AgentConfig, TrainConfig, LogConfig, AnalysisConfig,
    EvolutionConfig, SpecializationConfig, FreezeEvolveConfig, ArchiveConfig,
)
from src.agents.network import ActorCritic
from src.agents.policy import get_deterministic_actions
from src.environment.env import reset, step
from src.environment.obs import get_observations
from src.training.checkpointing import load_checkpoint
from src.analysis.specialization import compute_weight_divergence
from src.analysis.statistics import (
    compute_iqm, compare_methods, welch_t_test,
    mann_whitney_test, probability_of_improvement,
)
from src.analysis.paper_figures import setup_publication_style, save_figure

# Apply the project's figure styling before any plot is created.
setup_publication_style()
%matplotlib inline

# Figures and evaluation progress files are written here (on Google Drive).
OUTPUT_DIR = '/content/drive/MyDrive/emergence-lab/survival_pressure_results'
os.makedirs(OUTPUT_DIR, exist_ok=True)


def reconstruct_config(d):
    """Rebuild a ``Config`` from the plain-dict form of a checkpoint config.

    Args:
        d: Either a ``Config`` (returned unchanged) or a mapping with
            optional per-section dicts (``'env'``, ``'field'``, ``'agent'``,
            ``'train'``, ``'log'``, ``'analysis'``, ``'evolution'``,
            ``'specialization'``, ``'freeze_evolve'``, ``'archive'``).
            Missing sections fall back to the section class defaults.

    Returns:
        A ``Config`` whose sections are built from the stored values, with
        nested or serialised fields converted back to their runtime types.
    """
    if isinstance(d, Config):
        return d

    def section(name):
        # Shallow copy so the input mapping is never mutated.
        return dict(d.get(name, {}))

    # env: the nested hidden-food dict becomes a HiddenFoodConfig.
    env_kwargs = section('env')
    hidden_food = env_kwargs.get('hidden_food')
    if isinstance(hidden_food, dict):
        env_kwargs['hidden_food'] = HiddenFoodConfig(**hidden_food)

    # train: a string training mode becomes the TrainingMode enum.
    train_kwargs = section('train')
    mode = train_kwargs.get('training_mode')
    if isinstance(mode, str):
        train_kwargs['training_mode'] = TrainingMode(mode)

    # agent: a list of hidden sizes becomes a tuple.
    agent_kwargs = section('agent')
    dims = agent_kwargs.get('hidden_dims')
    if isinstance(dims, list):
        agent_kwargs['hidden_dims'] = tuple(dims)

    return Config(
        env=EnvConfig(**env_kwargs),
        field=FieldConfig(**section('field')),
        agent=AgentConfig(**agent_kwargs),
        train=TrainConfig(**train_kwargs),
        log=LogConfig(**section('log')),
        analysis=AnalysisConfig(**section('analysis')),
        evolution=EvolutionConfig(**section('evolution')),
        specialization=SpecializationConfig(**section('specialization')),
        freeze_evolve=FreezeEvolveConfig(**section('freeze_evolve')),
        archive=ArchiveConfig(**section('archive')),
    )


def discover_checkpoints(base_dir, num_batches=10):
    """Return the latest checkpoint path for every seed under ``base_dir``.

    Scans ``base_dir/batch_<i>/<seed_dir>/step_*.pkl`` for ``i`` in
    ``range(num_batches)``. Missing batch directories and seed directories
    without checkpoints are skipped.

    Args:
        base_dir: Condition directory containing the ``batch_<i>`` folders.
        num_batches: Number of batch indices to scan (default 10, matching
            the training cell's batch count).

    Returns:
        List of checkpoint paths, one per seed directory that has at least
        one ``step_*.pkl``, ordered by batch index and then by seed
        directory name.
    """
    def step_key(path):
        # Compare step numbers as integers: plain name order would rank
        # 'step_999999.pkl' above 'step_1000000.pkl'. A non-numeric suffix
        # ranks below every numbered step; ties fall back to name order.
        stem = os.path.splitext(os.path.basename(path))[0]
        suffix = stem[len('step_'):]
        return (int(suffix) if suffix.isdecimal() else -1, path)

    paths = []
    for batch_idx in range(num_batches):
        batch_dir = os.path.join(base_dir, f'batch_{batch_idx}')
        if not os.path.isdir(batch_dir):
            continue
        for seed_dir_name in sorted(os.listdir(batch_dir)):
            seed_path = os.path.join(batch_dir, seed_dir_name)
            if not os.path.isdir(seed_path):
                continue
            pkl_files = glob_mod.glob(os.path.join(seed_path, 'step_*.pkl'))
            if pkl_files:
                paths.append(max(pkl_files, key=step_key))  # Latest step
    return paths


def load_training_summary(base_dir):
    """Load ``training_summary.pkl`` and collect rewards and populations.

    Values are gathered from the ``'mean_reward'`` and ``'population_size'``
    entries of each successful batch's ``'metrics'`` dict.
    NOTE(review): presumably one value per seed in each entry — confirm
    against what ``ParallelTrainer.train`` returns.

    Args:
        base_dir: Condition directory containing ``training_summary.pkl``.

    Returns:
        ``(rewards, populations)``. Both are ``None`` when the summary file
        is missing or holds no rewards. ``populations`` is also ``None``
        when no population values were recorded, so callers can test
        ``is not None`` before calling ``.min()`` / ``.max()`` instead of
        hitting an empty array.
    """
    summary_path = os.path.join(base_dir, 'training_summary.pkl')
    if not os.path.exists(summary_path):
        return None, None
    with open(summary_path, 'rb') as f:
        # pickle is acceptable here only because the file is produced by
        # this notebook's own training cell; never point it at untrusted data.
        summary = pickle.load(f)

    rewards = []
    populations = []
    for batch in summary['all_results']:
        if not batch.get('success', True):
            continue
        metrics = batch.get('metrics') or {}
        if 'mean_reward' in metrics:
            rewards.extend(metrics['mean_reward'])
        if 'population_size' in metrics:
            populations.extend(metrics['population_size'])

    if not rewards:
        return None, None
    population_array = np.array(populations, dtype=float) if populations else None
    return np.array(rewards), population_array


def load_seed_data(ckpt_path):
    """Load one checkpoint and keep only what the analysis cells need.

    Args:
        ckpt_path: Path handed to ``load_checkpoint``.

    Returns:
        Dict with ``'params'``, ``'agent_params'`` (index 0 along the
        leading axis of every leaf), ``'config'`` (a reconstructed
        ``Config``), ``'network'`` (an ``ActorCritic`` sized from the config)
        and ``'seed_id'`` (``-1`` when the checkpoint has none).
    """
    checkpoint = load_checkpoint(ckpt_path)
    cfg = reconstruct_config(checkpoint['config'])

    def first_entry(leaf):
        return leaf[0]

    # agent_params is presumably (num_envs, max_agents, ...); taking index 0
    # keeps env 0 only, giving (max_agents, ...) for the divergence analysis.
    env0_agent_params = jax.tree_util.tree_map(
        first_entry, checkpoint['agent_params']
    )
    policy_net = ActorCritic(
        hidden_dims=tuple(cfg.agent.hidden_dims), num_actions=6
    )

    seed_data = {
        'params': checkpoint['params'],
        'agent_params': env0_agent_params,
        'config': cfg,
        'network': policy_net,
        'seed_id': checkpoint.get('seed_id', -1),
    }
    # Drop the local reference to the full checkpoint before returning.
    del checkpoint
    return seed_data


# run_hidden_food_eval() already defined in Cell 3
print("Analysis utilities loaded.")

In [None]:
# --- Load Training Data ---
# Locate the latest checkpoint per seed and load the pickled training
# summaries for both conditions.
CHECKPOINT_BASE = '/content/drive/MyDrive/emergence-lab'

field_on_dir = f'{CHECKPOINT_BASE}/survival_pressure_on'
field_off_dir = f'{CHECKPOINT_BASE}/survival_pressure_off'

field_on_ckpt_paths = discover_checkpoints(field_on_dir)
field_off_ckpt_paths = discover_checkpoints(field_off_dir)

# 30 = seeds per condition in the full sweep (10 batches x 3 seeds).
print(f"Found {len(field_on_ckpt_paths)}/30 Field ON seeds")
print(f"Found {len(field_off_ckpt_paths)}/30 Field OFF seeds")

# Load training summaries
# Each call returns (None, None) when the summary file is missing or has no rewards.
on_rewards, on_populations = load_training_summary(field_on_dir)
off_rewards, off_populations = load_training_summary(field_off_dir)

if on_rewards is not None:
    print(f"\nField ON  training rewards: {len(on_rewards)} seeds, mean={np.mean(on_rewards):.3f}")
if off_rewards is not None:
    print(f"Field OFF training rewards: {len(off_rewards)} seeds, mean={np.mean(off_rewards):.3f}")

In [None]:
# --- Training Reward Statistics ---
# Compare training rewards between conditions using the project's statistics
# helpers; runs only when both summaries were loaded.
if on_rewards is not None and off_rewards is not None:
    comparison = compare_methods(
        {'Field ON': on_rewards, 'Field OFF': off_rewards},
        alpha=0.05, n_bootstrap=10000, seed=42,
    )
    print(comparison.summary)

    # Additional tests
    welch = welch_t_test(on_rewards, off_rewards)
    mw = mann_whitney_test(on_rewards, off_rewards)
    poi = probability_of_improvement(on_rewards, off_rewards, seed=42)

    # `d` is the effect size reported by welch_t_test.
    print(f"\nWelch's t-test: t={welch.statistic:.3f}, p={welch.p_value:.4f}, d={welch.effect_size:.3f}")
    print(f"Mann-Whitney U: U={mw.statistic:.1f}, p={mw.p_value:.4f}")
    # x = Field ON (first argument), y = Field OFF (second argument).
    print(f"P(Field ON > Field OFF): {poi['prob_x_better']:.3f}")
    print(f"P(Field OFF > Field ON): {poi['prob_y_better']:.3f}")
else:
    print("Training summaries not found. Run training cells first.")
    # Keep the name defined so later cells can test it.
    comparison = None

In [None]:
# --- Training Comparison Plots ---
# Three-panel figure of the training results: IQM bars, reward distribution,
# and population histogram. The colours are reused by later plotting cells.
COLOR_ON = '#009988'
COLOR_OFF = '#BBBBBB'

if on_rewards is not None and off_rewards is not None:
    fig, axes = plt.subplots(1, 3, figsize=(16, 5))

    # Panel 1: IQM bars with CI
    ax = axes[0]
    iqm_on = compute_iqm(on_rewards, seed=42)
    iqm_off = compute_iqm(off_rewards, seed=42)
    bars = ax.bar(
        ['Field ON', 'Field OFF'],
        [iqm_on.iqm, iqm_off.iqm],
        color=[COLOR_ON, COLOR_OFF],
        edgecolor='black', linewidth=0.8,
    )
    # Asymmetric error bars: first row is distance to the lower CI bound,
    # second row is distance to the upper CI bound.
    ax.errorbar(
        [0, 1],
        [iqm_on.iqm, iqm_off.iqm],
        yerr=[
            [iqm_on.iqm - iqm_on.ci_lower, iqm_off.iqm - iqm_off.ci_lower],
            [iqm_on.ci_upper - iqm_on.iqm, iqm_off.ci_upper - iqm_off.iqm],
        ],
        fmt='none', color='black', capsize=5,
    )
    ax.set_ylabel('IQM Reward')
    ax.set_title('Training Reward (IQM + 95% CI)')

    # Panel 2: Violin + swarm
    ax = axes[1]
    parts = ax.violinplot(
        [on_rewards, off_rewards], positions=[0, 1],
        showmeans=True, showmedians=True,
    )
    for i, pc in enumerate(parts['bodies']):
        pc.set_facecolor([COLOR_ON, COLOR_OFF][i])
        pc.set_alpha(0.6)
    # Swarm overlay
    # Fixed-seed horizontal jitter so the figure is reproducible.
    jitter_on = np.random.default_rng(42).uniform(-0.1, 0.1, len(on_rewards))
    jitter_off = np.random.default_rng(43).uniform(-0.1, 0.1, len(off_rewards))
    ax.scatter(jitter_on, on_rewards, c=COLOR_ON, s=15, alpha=0.7, zorder=3)
    ax.scatter(1 + jitter_off, off_rewards, c=COLOR_OFF, s=15, alpha=0.7, zorder=3, edgecolors='gray')
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Field ON', 'Field OFF'])
    ax.set_ylabel('Final Reward')
    ax.set_title('Reward Distribution')

    # Panel 3: Population histogram
    ax = axes[2]
    # NOTE(review): an empty (but non-None) populations array would make
    # .min()/.max() below raise; this guard only checks for None.
    if on_populations is not None and off_populations is not None:
        # Shared bin edges so the two histograms are directly comparable.
        bins = np.linspace(
            min(on_populations.min(), off_populations.min()),
            max(on_populations.max(), off_populations.max()),
            20,
        )
        ax.hist(on_populations, bins=bins, alpha=0.6, color=COLOR_ON, label='Field ON', edgecolor='black')
        ax.hist(off_populations, bins=bins, alpha=0.6, color=COLOR_OFF, label='Field OFF', edgecolor='black')
        ax.set_xlabel('Final Population')
        ax.set_ylabel('Count')
        ax.set_title('Population Distribution')
        ax.legend()
    else:
        ax.text(0.5, 0.5, 'No population data', ha='center', va='center', transform=ax.transAxes)

    fig.suptitle('Survival Pressure: Training Comparison', fontsize=16, y=1.02)
    fig.tight_layout()
    save_figure(fig, os.path.join(OUTPUT_DIR, 'training_comparison'))
    plt.show()
else:
    print("No training data to plot.")

In [None]:
# --- Load & Eval All Checkpoints ---
# For every discovered checkpoint: run one evaluation episode, compute weight
# divergence, and append to the per-condition lists. Progress is pickled
# after each seed so a crashed session keeps what was already evaluated.
hf_eval_on = []
hf_eval_off = []
divergence_on = []
divergence_off = []

for cond_name, ckpt_paths, eval_list, div_list in [
    ("Field ON", field_on_ckpt_paths, hf_eval_on, divergence_on),
    ("Field OFF", field_off_ckpt_paths, hf_eval_off, divergence_off),
]:
    print(f"\n{'='*40}")
    print(f"Evaluating {cond_name}: {len(ckpt_paths)} seeds")
    print(f"{'='*40}")

    for i, ckpt_path in enumerate(ckpt_paths):
        print(f"  [{cond_name}] Eval seed {i+1}/{len(ckpt_paths)}: {os.path.basename(ckpt_path)}")
        seed_data = load_seed_data(ckpt_path)
        # The evaluation key depends on the checkpoint's position in the
        # list, so the i-th ON and i-th OFF checkpoints share a key.
        key = jax.random.PRNGKey(42 + i)

        hf_result = run_hidden_food_eval(
            seed_data['network'], seed_data['params'],
            seed_data['config'], key, num_episodes=1,
        )
        hf_result['seed_id'] = seed_data['seed_id']
        eval_list.append(hf_result)

        # Compute weight divergence: agent_params is (max_agents, ...) from env 0
        max_agents = seed_data['config'].evolution.max_agents
        # NOTE(review): every slot is marked alive, so parameters stored in
        # dead slots are included in the divergence — confirm this is intended.
        alive_mask = np.ones(max_agents, dtype=bool)
        div = compute_weight_divergence(seed_data['agent_params'], alive_mask)
        div_list.append({
            'seed_id': seed_data['seed_id'],
            'mean_divergence': float(div['mean_divergence']),
            'max_divergence': float(div['max_divergence']),
        })

        del seed_data
        gc.collect()

        # Intermediate save - survives crashes
        # File name is e.g. eval_progress_field_on.pkl.
        progress_path = os.path.join(
            OUTPUT_DIR,
            f'eval_progress_{cond_name.lower().replace(" ", "_")}.pkl'
        )
        with open(progress_path, 'wb') as f:
            pickle.dump({
                'condition': cond_name,
                'completed_seeds': i + 1,
                'total_seeds': len(ckpt_paths),
                'hf_eval': list(eval_list),
                'divergence': list(div_list),
            }, f)

# Extract numpy arrays for analysis
# One entry per evaluated checkpoint; the on_/off_ prefixes name the condition.
on_hf_revealed = np.array([r['hidden_food_revealed'] for r in hf_eval_on])
off_hf_revealed = np.array([r['hidden_food_revealed'] for r in hf_eval_off])
on_hf_collected = np.array([r['hidden_food_collected'] for r in hf_eval_on])
off_hf_collected = np.array([r['hidden_food_collected'] for r in hf_eval_off])
on_regular_food = np.array([r['regular_food_collected'] for r in hf_eval_on])
off_regular_food = np.array([r['regular_food_collected'] for r in hf_eval_off])
on_hf_energy = np.array([r['hidden_food_energy'] for r in hf_eval_on])
off_hf_energy = np.array([r['hidden_food_energy'] for r in hf_eval_off])
on_regular_energy = np.array([r['regular_food_energy'] for r in hf_eval_on])
off_regular_energy = np.array([r['regular_food_energy'] for r in hf_eval_off])
on_total_reward = np.array([r['total_reward'] for r in hf_eval_on])
off_total_reward = np.array([r['total_reward'] for r in hf_eval_off])
on_final_pop = np.array([r['final_population'] for r in hf_eval_on])
off_final_pop = np.array([r['final_population'] for r in hf_eval_off])
on_mean_div = np.array([d['mean_divergence'] for d in divergence_on])
off_mean_div = np.array([d['mean_divergence'] for d in divergence_off])

print(f"\nEval complete: {len(hf_eval_on)} ON seeds, {len(hf_eval_off)} OFF seeds")

In [None]:
# --- Hidden Food Statistics ---
# Welch's t-test per evaluation metric, printed as a table and stored in
# `stat_results` keyed by metric name.
metrics_to_test = [
    ('hidden_food_revealed', on_hf_revealed, off_hf_revealed),
    ('hidden_food_collected', on_hf_collected, off_hf_collected),
    ('regular_food_collected', on_regular_food, off_regular_food),
    ('hidden_food_energy', on_hf_energy, off_hf_energy),
    ('total_reward', on_total_reward, off_total_reward),
    ('final_population', on_final_pop.astype(float), off_final_pop.astype(float)),
]

print(f"{'Metric':<30} {'ON mean':>10} {'OFF mean':>10} {'t':>8} {'p':>10} {'d':>8}")
print("-" * 80)

stat_results = {}
for name, on_vals, off_vals in metrics_to_test:
    on_mean = np.mean(on_vals)
    off_mean = np.mean(off_vals)
    # np.std default: population standard deviation (ddof=0).
    on_std = np.std(on_vals)
    off_std = np.std(off_vals)

    # Handle all-zero case gracefully
    # NOTE(review): this branch also fires when both groups are constant
    # with *different* means, and then reports p=1.0 for them.
    if on_std == 0 and off_std == 0:
        t_val, p_val, d_val = 0.0, 1.0, 0.0
    else:
        welch = welch_t_test(on_vals, off_vals)
        t_val = welch.statistic
        p_val = welch.p_value
        d_val = welch.effect_size

    # Significance stars; thresholds match the legend printed below.
    sig = "***" if p_val < 0.001 else "**" if p_val < 0.01 else "*" if p_val < 0.05 else ""
    print(f"{name:<30} {on_mean:>10.2f} {off_mean:>10.2f} {t_val:>8.2f} {p_val:>10.4f} {d_val:>8.2f} {sig}")

    stat_results[name] = {
        'on_mean': on_mean, 'on_std': on_std,
        'off_mean': off_mean, 'off_std': off_std,
        't': t_val, 'p': p_val, 'd': d_val,
    }

print("\n* p<0.05  ** p<0.01  *** p<0.001")

In [None]:
# --- Hidden Food Coordination Plots ---
# Five-panel figure built from the per-seed evaluation arrays. Bar heights
# are means across seeds; error bars are np.std (ddof=0) across seeds.
fig, axes = plt.subplots(1, 5, figsize=(24, 5))

# Panel 1: HF revealed
ax = axes[0]
means = [np.mean(on_hf_revealed), np.mean(off_hf_revealed)]
stds = [np.std(on_hf_revealed), np.std(off_hf_revealed)]
bars = ax.bar(['Field ON', 'Field OFF'], means, yerr=stds, color=[COLOR_ON, COLOR_OFF],
              edgecolor='black', linewidth=0.8, capsize=5)
ax.set_ylabel('Count')
ax.set_title('Hidden Food Revealed')
# Explain an otherwise empty panel when nothing was revealed.
if means[0] == 0 and means[1] == 0:
    ax.annotate('Both zero - coordination\nnot yet achieved',
                xy=(0.5, 0.5), xycoords='axes fraction', ha='center', va='center',
                fontsize=9, color='gray')

# Panel 2: HF collected
ax = axes[1]
means = [np.mean(on_hf_collected), np.mean(off_hf_collected)]
stds = [np.std(on_hf_collected), np.std(off_hf_collected)]
ax.bar(['Field ON', 'Field OFF'], means, yerr=stds, color=[COLOR_ON, COLOR_OFF],
       edgecolor='black', linewidth=0.8, capsize=5)
ax.set_ylabel('Count')
ax.set_title('Hidden Food Collected')

# Panel 3: Regular food
ax = axes[2]
means = [np.mean(on_regular_food), np.mean(off_regular_food)]
stds = [np.std(on_regular_food), np.std(off_regular_food)]
ax.bar(['Field ON', 'Field OFF'], means, yerr=stds, color=[COLOR_ON, COLOR_OFF],
       edgecolor='black', linewidth=0.8, capsize=5)
ax.set_ylabel('Count')
ax.set_title('Regular Food Collected')

# Panel 4: Energy breakdown (stacked bar)
ax = axes[3]
on_reg_e = np.mean(on_regular_energy)
on_hf_e = np.mean(on_hf_energy)
off_reg_e = np.mean(off_regular_energy)
off_hf_e = np.mean(off_hf_energy)
ax.bar(['Field ON', 'Field OFF'], [on_reg_e, off_reg_e],
       color=[COLOR_ON, COLOR_OFF], edgecolor='black', linewidth=0.8, label='Regular food')
# Hidden-food energy is stacked on top of the regular-food bars.
ax.bar(['Field ON', 'Field OFF'], [on_hf_e, off_hf_e],
       bottom=[on_reg_e, off_reg_e],
       color=['#00665f', '#888888'], edgecolor='black', linewidth=0.8, label='Hidden food')
ax.set_ylabel('Energy')
ax.set_title('Energy Breakdown')
ax.legend()

# Panel 5: HF collected violin
ax = axes[4]
# Draw the violin only when at least one seed collected hidden food;
# otherwise show zero-height bars with an explanatory note.
if np.any(on_hf_collected > 0) or np.any(off_hf_collected > 0):
    parts = ax.violinplot(
        [on_hf_collected, off_hf_collected], positions=[0, 1],
        showmeans=True, showmedians=True,
    )
    for i, pc in enumerate(parts['bodies']):
        pc.set_facecolor([COLOR_ON, COLOR_OFF][i])
        pc.set_alpha(0.6)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Field ON', 'Field OFF'])
else:
    ax.bar(['Field ON', 'Field OFF'], [0, 0], color=[COLOR_ON, COLOR_OFF],
           edgecolor='black', linewidth=0.8)
    ax.annotate('No hidden food collected\nin either condition',
                xy=(0.5, 0.5), xycoords='axes fraction', ha='center', va='center',
                fontsize=9, color='gray')
ax.set_ylabel('Hidden Food Collected')
ax.set_title('HF Collected Distribution')

fig.suptitle('Survival Pressure: Hidden Food Coordination', fontsize=16, y=1.02)
fig.tight_layout()
save_figure(fig, os.path.join(OUTPUT_DIR, 'hidden_food_coordination'))
plt.show()

In [None]:
# --- Divergence & Correlation Plots ---
# 2x2 figure: weight-divergence and eval-population histograms on top,
# two scatter plots with Pearson correlations below.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Panel 1: Weight divergence histogram
ax = axes[0, 0]
if len(on_mean_div) > 0 and len(off_mean_div) > 0:
    # Shared bin edges across both conditions.
    bins = np.linspace(
        min(on_mean_div.min(), off_mean_div.min()),
        max(on_mean_div.max(), off_mean_div.max()),
        20,
    )
    ax.hist(on_mean_div, bins=bins, alpha=0.6, color=COLOR_ON, label='Field ON', edgecolor='black')
    ax.hist(off_mean_div, bins=bins, alpha=0.6, color=COLOR_OFF, label='Field OFF', edgecolor='black')
    ax.set_xlabel('Mean Weight Divergence')
    ax.set_ylabel('Count')
    ax.set_title('Weight Divergence Distribution')
    ax.legend()

# Panel 2: Eval population histogram
ax = axes[0, 1]
if len(on_final_pop) > 0 and len(off_final_pop) > 0:
    all_pops = np.concatenate([on_final_pop, off_final_pop])
    bins = np.linspace(all_pops.min(), all_pops.max(), 20)
    ax.hist(on_final_pop, bins=bins, alpha=0.6, color=COLOR_ON, label='Field ON', edgecolor='black')
    ax.hist(off_final_pop, bins=bins, alpha=0.6, color=COLOR_OFF, label='Field OFF', edgecolor='black')
    ax.set_xlabel('Final Population (eval)')
    ax.set_ylabel('Count')
    ax.set_title('Eval Population Distribution')
    ax.legend()

# Panel 3: HF collected vs divergence scatter (with Pearson r)
ax = axes[1, 0]
# ON seeds first, then OFF seeds; `colors` follows the same order.
all_hf = np.concatenate([on_hf_collected, off_hf_collected])
all_div = np.concatenate([on_mean_div, off_mean_div])
colors = [COLOR_ON] * len(on_hf_collected) + [COLOR_OFF] * len(off_hf_collected)
ax.scatter(all_div, all_hf, c=colors, s=30, alpha=0.7, edgecolors='black', linewidth=0.5)
# Correlation is computed on the pooled ON+OFF points, and only when there
# are more than two points and neither variable is constant.
if len(all_div) > 2 and np.std(all_div) > 0 and np.std(all_hf) > 0:
    r, p = scipy_stats.pearsonr(all_div, all_hf)
    ax.set_title(f'HF Collected vs Divergence (r={r:.3f}, p={p:.3f})')
else:
    ax.set_title('HF Collected vs Divergence')
ax.set_xlabel('Mean Weight Divergence')
ax.set_ylabel('Hidden Food Collected')
# Legend
# Proxy handles: the scatter is a single artist with per-point colours.
from matplotlib.lines import Line2D
legend_elements = [
    Line2D([0], [0], marker='o', color='w', markerfacecolor=COLOR_ON, markersize=8, label='Field ON'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=COLOR_OFF, markersize=8, label='Field OFF'),
]
ax.legend(handles=legend_elements)

# Panel 4: Reward vs population scatter
ax = axes[1, 1]
all_rew = np.concatenate([on_total_reward, off_total_reward])
all_pop = np.concatenate([on_final_pop.astype(float), off_final_pop.astype(float)])
# Reuses `colors` from panel 3 (same ON-then-OFF ordering and lengths).
ax.scatter(all_pop, all_rew, c=colors, s=30, alpha=0.7, edgecolors='black', linewidth=0.5)
if len(all_pop) > 2 and np.std(all_pop) > 0 and np.std(all_rew) > 0:
    r, p = scipy_stats.pearsonr(all_pop, all_rew)
    ax.set_title(f'Reward vs Population (r={r:.3f}, p={p:.3f})')
else:
    ax.set_title('Reward vs Population')
ax.set_xlabel('Final Population')
ax.set_ylabel('Total Reward')
ax.legend(handles=legend_elements)

fig.suptitle('Survival Pressure: Divergence & Correlations', fontsize=16, y=1.02)
fig.tight_layout()
save_figure(fig, os.path.join(OUTPUT_DIR, 'divergence_correlation'))
plt.show()

In [None]:
# --- Population Trajectory Plot ---
# For each condition, read training_summary.pkl and plot the mean +/- std of
# the per-seed final populations of every successful batch.
fig, ax = plt.subplots(figsize=(10, 5))

has_trajectory = False
for _cond_name, base_dir, color, label in [
    ('field_on', field_on_dir, COLOR_ON, 'Field ON'),
    ('field_off', field_off_dir, COLOR_OFF, 'Field OFF'),
]:
    summary_path = os.path.join(base_dir, 'training_summary.pkl')
    if not os.path.exists(summary_path):
        continue
    # NOTE(review): pickle.load can execute arbitrary code; this presumes the
    # summary was written by this experiment's own training run -- confirm
    # before pointing the notebook at files from elsewhere.
    with open(summary_path, 'rb') as f:
        summary = pickle.load(f)

    # Aggregate each batch on its own instead of stacking all batches into one
    # 2-D array: batches with different seed counts would otherwise fail to
    # stack. The original batch position is kept so that skipped (failed)
    # batches leave a gap rather than shifting later batches to the left.
    batch_indices, batch_means, batch_stds = [], [], []
    for batch_idx, batch in enumerate(summary['all_results']):
        if not batch.get('success', False) or 'final_pops' not in batch:
            continue
        pops = np.asarray(batch['final_pops'], dtype=float).ravel()
        if pops.size == 0:
            continue
        batch_indices.append(batch_idx)
        batch_means.append(pops.mean())
        batch_stds.append(pops.std())

    if not batch_indices:
        continue

    has_trajectory = True
    batch_indices = np.array(batch_indices)
    batch_means = np.array(batch_means)
    batch_stds = np.array(batch_stds)

    ax.plot(batch_indices, batch_means, color=color, label=label, linewidth=2)
    ax.fill_between(
        batch_indices, batch_means - batch_stds, batch_means + batch_stds,
        color=color, alpha=0.2,
    )

if has_trajectory:
    ax.set_xlabel('Batch Index')
    ax.set_ylabel('Final Population (mean across seeds in batch)')
    ax.set_title('Population Trajectory Across Batches')
    ax.legend()
    ax.grid(True, alpha=0.3)
    save_figure(fig, os.path.join(OUTPUT_DIR, 'population_trajectory'))
    plt.show()
else:
    print("No population trajectory data available.")
    # Nothing was drawn; close the figure so no empty axes are rendered.
    plt.close(fig)

In [None]:
# --- Summary Report ---
# Build the Markdown report in `report` (it is also written to disk by the
# save cell) and render it inline.
from IPython.display import display, Markdown


def _mean_pm_std(values):
    """Format ``mean +/- std`` to 4 decimals, or ``n/a`` when there are no samples.

    The empty check mirrors the ``len(...) > 0`` guard used when the same
    divergence arrays are saved, and avoids NumPy's empty-slice warning.
    """
    if len(values) == 0:
        return "n/a"
    return f"{np.mean(values):.4f} +/- {np.std(values):.4f}"


def _significance_stars(p):
    """Return '***', '**', '*' or '' for p below 0.001, 0.01, 0.05 or otherwise."""
    for threshold, stars in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p < threshold:
            return stars
    return ""


def _finding(sr, bold_claim, trailer, null_claim):
    """Return one 'Key Findings' bullet for a stat-result dict.

    When p < 0.05 the condition with the larger mean is named as the winner;
    otherwise the null statement is reported with its p-value.
    """
    if sr['p'] < 0.05:
        winner = "Field ON" if sr['on_mean'] > sr['off_mean'] else "Field OFF"
        return f"- **{winner} {bold_claim}**{trailer} (p={sr['p']:.4f}, d={sr['d']:.2f})\n"
    return f"- {null_claim} (p={sr['p']:.4f})\n"


report = f"""# Survival Pressure Experiment: Results Summary

**Date**: {datetime.now().strftime('%Y-%m-%d %H:%M')}

## Experimental Setup
- **Grid**: 40x40, **Regular food**: 10 (scarce), **Food energy**: 60 (reduced)
- **Hidden food**: 8 items, require 5 agents within distance 3, value 10x (600 energy each)
- **Max agents**: 64, **Starting energy**: 200, **Reproduce threshold**: 120
- **Steps**: 10M per seed, **Seeds**: {len(hf_eval_on)} ON + {len(hf_eval_off)} OFF

## Training Rewards
"""

# The training-reward table is only emitted when both conditions have data.
if on_rewards is not None and off_rewards is not None:
    report += f"""| Metric | Field ON | Field OFF |
|--------|----------|----------|
| Mean reward | {np.mean(on_rewards):.4f} | {np.mean(off_rewards):.4f} |
| Std reward | {np.std(on_rewards):.4f} | {np.std(off_rewards):.4f} |
| IQM | {compute_iqm(on_rewards, seed=42).iqm:.4f} | {compute_iqm(off_rewards, seed=42).iqm:.4f} |
"""

report += """
## Hidden Food Coordination (Eval Episodes)

| Metric | Field ON | Field OFF | t-stat | p-value | Cohen's d |
|--------|----------|----------|--------|---------|----------|
"""

# One table row per evaluated metric, with significance stars on the p-value.
report += "".join(
    f"| {name} | {sr['on_mean']:.2f} +/- {sr['on_std']:.2f} | {sr['off_mean']:.2f} +/- {sr['off_std']:.2f} | {sr['t']:.2f} | {sr['p']:.4f}{_significance_stars(sr['p'])} | {sr['d']:.2f} |\n"
    for name, sr in stat_results.items()
)

report += f"""
## Weight Divergence
| Metric | Field ON | Field OFF |
|--------|----------|----------|
| Mean divergence | {_mean_pm_std(on_mean_div)} | {_mean_pm_std(off_mean_div)} |

## Key Findings
"""

# Auto-generate key findings
if 'total_reward' in stat_results:
    report += _finding(
        stat_results['total_reward'],
        "significantly outperforms", " on total reward",
        "No significant difference in total reward",
    )

if 'hidden_food_collected' in stat_results:
    sr = stat_results['hidden_food_collected']
    # A comparison is only meaningful if at least one condition collected any.
    if sr['on_mean'] > 0 or sr['off_mean'] > 0:
        report += _finding(
            sr,
            "collects more hidden food", "",
            "No significant difference in hidden food collection",
        )
    else:
        report += "- Neither condition achieved hidden food collection in eval episodes\n"

if 'final_population' in stat_results:
    report += _finding(
        stat_results['final_population'],
        "sustains larger populations", "",
        "No significant population difference",
    )

report += f"""
## Output Files
- Figures: `{OUTPUT_DIR}/training_comparison.{{pdf,png}}`
- Figures: `{OUTPUT_DIR}/hidden_food_coordination.{{pdf,png}}`
- Figures: `{OUTPUT_DIR}/divergence_correlation.{{pdf,png}}`
- Figures: `{OUTPUT_DIR}/population_trajectory.{{pdf,png}}`
- Results: `{OUTPUT_DIR}/all_results.json`
- Results: `{OUTPUT_DIR}/all_results.pkl`
- Report: `{OUTPUT_DIR}/summary_report.md`
"""

display(Markdown(report))

In [None]:
# --- Save All Results ---
# Writes three artifacts to OUTPUT_DIR: a JSON summary, a pickle holding the
# same summary plus the raw per-seed objects, and the Markdown report.


def _json_safe(obj):
    """Return a copy of ``obj`` with non-finite floats (NaN, +/-inf) replaced by None.

    Statistics such as a t-test on zero-variance samples yield NaN, which
    ``json.dump`` would write as the bare token ``NaN`` -- not valid strict
    JSON and rejected by many non-Python parsers.
    """
    if isinstance(obj, dict):
        return {key: _json_safe(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_json_safe(value) for value in obj]
    if isinstance(obj, float) and not np.isfinite(obj):
        return None
    return obj


def _mean_std_or_none(values):
    """Return ``(mean, std)`` as plain floats, or ``(None, None)`` for no samples."""
    if len(values) == 0:
        return None, None
    return float(np.mean(values)), float(np.std(values))


div_on_mean, div_on_std = _mean_std_or_none(on_mean_div)
div_off_mean, div_off_std = _mean_std_or_none(off_mean_div)

# JSON-serializable results
json_results = {
    'experiment': 'survival_pressure',
    'date': datetime.now().isoformat(),
    'config': {
        'grid_size': 40, 'num_food': 10, 'food_energy': 60,
        'num_hidden': 8, 'required_agents': 5,
        'hidden_food_value_multiplier': 10.0,
        'max_agents': 64, 'total_steps': TOTAL_STEPS,
    },
    'n_seeds': {
        'field_on': len(hf_eval_on),
        'field_off': len(hf_eval_off),
    },
    'training_rewards': {},
    'eval_stats': {},
    'divergence': {
        'field_on_mean': div_on_mean,
        'field_on_std': div_on_std,
        'field_off_mean': div_off_mean,
        'field_off_std': div_off_std,
    },
}

# Training rewards stay an empty dict unless both conditions have data.
if on_rewards is not None and off_rewards is not None:
    json_results['training_rewards'] = {
        'field_on_mean': float(np.mean(on_rewards)),
        'field_on_std': float(np.std(on_rewards)),
        'field_off_mean': float(np.mean(off_rewards)),
        'field_off_std': float(np.std(off_rewards)),
    }

# float() converts NumPy scalars to plain Python floats for the JSON encoder.
json_results['eval_stats'] = {
    name: {
        'on_mean': float(sr['on_mean']),
        'on_std': float(sr['on_std']),
        'off_mean': float(sr['off_mean']),
        'off_std': float(sr['off_std']),
        't_statistic': float(sr['t']),
        'p_value': float(sr['p']),
        'cohens_d': float(sr['d']),
    }
    for name, sr in stat_results.items()
}

# Save JSON. Non-finite values are written as null; allow_nan=False makes the
# encoder fail loudly if one slips through instead of emitting invalid JSON.
json_path = os.path.join(OUTPUT_DIR, 'all_results.json')
with open(json_path, 'w', encoding='utf-8') as f:
    json.dump(_json_safe(json_results), f, indent=2, allow_nan=False)
print(f"JSON saved: {json_path}")

# Save pickle (includes numpy arrays). The unsanitized summary is used here,
# so NaN statistics are preserved as floats for Python consumers.
pkl_results = {
    **json_results,
    'hf_eval_on': hf_eval_on,
    'hf_eval_off': hf_eval_off,
    'divergence_on': divergence_on,
    'divergence_off': divergence_off,
    'on_rewards': on_rewards,
    'off_rewards': off_rewards,
    'on_populations': on_populations,
    'off_populations': off_populations,
}
pkl_path = os.path.join(OUTPUT_DIR, 'all_results.pkl')
with open(pkl_path, 'wb') as f:
    pickle.dump(pkl_results, f)
print(f"Pickle saved: {pkl_path}")

# Save markdown report
md_path = os.path.join(OUTPUT_DIR, 'summary_report.md')
with open(md_path, 'w', encoding='utf-8') as f:
    f.write(report)
print(f"Report saved: {md_path}")

print(f"\nAll results saved to {OUTPUT_DIR}")