# Hidden Food Coordination Analysis — Field ON vs Field OFF

**Hypothesis**: Field ON agents can use the shared field to signal hidden food locations,
enabling coordination. Field OFF agents have NO signaling mechanism.

## Key Metrics
- Hidden food revealed: how many times K agents cluster to reveal hidden food (a lower bound: a reveal followed by collection in the same step is not counted)
- Hidden food collected: how many hidden food items are actually eaten
- Regular food collected: baseline foraging performance
- Reward breakdown: regular food energy vs hidden food energy

## Setup
1. Runtime > Change runtime type > **TPU v6e** + **High-RAM**
2. Run all cells (Ctrl+F9)

In [None]:
# Mount Google Drive; the checkpoint and results directories used below live under MyDrive.
from google.colab import drive
drive.mount('/content/drive')

import os
REPO_DIR = '/content/emergence-lab'
GITHUB_USERNAME = "imashishkh21"

# Clone the repo on the first run; on later runs just pull the latest commits.
# (`!` lines are IPython shell escapes; `{NAME}` interpolates Python variables.)
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/{GITHUB_USERNAME}/emergence-lab.git {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull

# Work from the repo root, then install the package in editable mode with the
# `dev` and `phase5` extras.
os.chdir(REPO_DIR)
!pip install -e ".[dev,phase5]" -q

# Report the JAX version and visible accelerators so a wrong runtime type is obvious.
import jax
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"Device count: {jax.device_count()}")

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import json, pickle, gc, os, copy
import glob as glob_mod
from pathlib import Path
from collections import defaultdict
from datetime import datetime

import jax
import jax.numpy as jnp
from scipy import stats as scipy_stats

from src.configs import Config
from src.agents.network import ActorCritic
from src.agents.policy import get_deterministic_actions
from src.environment.env import reset, step
from src.environment.obs import get_observations
from src.analysis.specialization import compute_weight_divergence
from src.analysis.ablation import _run_episode_full
from src.training.checkpointing import load_checkpoint
from src.analysis.statistics import (
    compute_iqm, compare_methods, welch_t_test, mann_whitney_test,
    probability_of_improvement,
)
from src.analysis.paper_figures import (
    setup_publication_style, plot_performance_profiles, save_figure,
)

# Importing paper_figures selects the non-interactive Agg backend; switch back
# to inline rendering so plt.show() displays figures in the notebook.
%matplotlib inline

# Per-condition training outputs (batch_*/<seed>/step_*.pkl, plus an optional
# training_summary.pkl) and the destination for everything this notebook saves.
FIELD_ON_DIR = '/content/drive/MyDrive/emergence-lab/hidden_food_field_on/'
FIELD_OFF_DIR = '/content/drive/MyDrive/emergence-lab/hidden_food_field_off/'
OUTPUT_DIR = '/content/drive/MyDrive/emergence-lab/hidden_food_analysis_results/'
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("Imports loaded.")
print(f"Field ON dir:  {FIELD_ON_DIR} (exists: {os.path.exists(FIELD_ON_DIR)})")
print(f"Field OFF dir: {FIELD_OFF_DIR} (exists: {os.path.exists(FIELD_OFF_DIR)})")

In [None]:
# =============================================================================
# TEST STEP - Verify everything works before running full analysis
# =============================================================================
import time

print("="*70)
print("TEST MODE: Running quick verification...")
print("="*70)

errors = []


def _latest_step_pkl(seed_dir):
    """Return the highest-step ``step_*.pkl`` path in *seed_dir*, or None.

    Steps are compared as integers, so ``step_1000.pkl`` beats ``step_999.pkl``
    even without zero padding (a plain ``sorted(...)[-1]`` would pick
    ``step_999.pkl``). Names whose step part is not an integer rank lowest;
    ties fall back to the path so the choice is deterministic.
    """
    pkls = glob_mod.glob(os.path.join(seed_dir, 'step_*.pkl'))
    if not pkls:
        return None

    def step_key(path):
        step = os.path.basename(path)[len('step_'):-len('.pkl')]
        return (step.isdecimal(), int(step) if step.isdecimal() else -1, path)

    return max(pkls, key=step_key)


def _count_seed_checkpoints(cond_dir):
    """Return ``(n_batch_dirs, n_seed_dirs_with_a_checkpoint)`` under *cond_dir*."""
    batch_names = [d for d in os.listdir(cond_dir) if d.startswith('batch_')]
    n_ckpts = 0
    for bd in batch_names:
        bp = os.path.join(cond_dir, bd)
        for sd in os.listdir(bp):
            sp = os.path.join(bp, sd)
            if os.path.isdir(sp) and _latest_step_pkl(sp) is not None:
                n_ckpts += 1
    return len(batch_names), n_ckpts


# Test 1: Stats functions work
print("\n[1/7] Testing statistics functions...")
try:
    test_a = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    test_b = np.array([2.0, 3.0, 4.0, 5.0, 6.0])
    iqm_test = compute_iqm(test_a, n_bootstrap=100, seed=0)
    welch_test = welch_t_test(test_a, test_b)
    mw_test = mann_whitney_test(test_a, test_b)
    poi_test = probability_of_improvement(test_a, test_b, n_bootstrap=100, seed=0)
    comp_test = compare_methods({"A": test_a, "B": test_b}, n_bootstrap=100, seed=0)
    assert hasattr(iqm_test, 'iqm'), "IQM result missing .iqm"
    assert hasattr(welch_test, 'p_value'), "Welch result missing .p_value"
    assert 'prob_x_better' in poi_test, "POI missing prob_x_better"
    assert hasattr(comp_test, 'summary'), "Compare result missing .summary"
    print("   PASS: All statistics functions work correctly")
except Exception as e:
    errors.append(f"Stats functions: {e}")
    print(f"   FAIL: {e}")

# Test 2: Plot functions work
print("\n[2/7] Testing plot functions...")
try:
    setup_publication_style()
    fig_test, ax_test = plt.subplots()
    ax_test.plot([1, 2, 3])
    plt.close(fig_test)
    print("   PASS: Matplotlib + publication style working")
except Exception as e:
    errors.append(f"Plot functions: {e}")
    print(f"   FAIL: {e}")

# Tests 3-4: Drive directory for each condition
for test_idx, (cond_name, cond_dir) in enumerate(
        [("Field ON", FIELD_ON_DIR), ("Field OFF", FIELD_OFF_DIR)], start=3):
    print(f"\n[{test_idx}/7] Testing {cond_name} Drive access...")
    try:
        assert os.path.exists(cond_dir), f"Directory not found: {cond_dir}"
        n_batches, n_ckpts = _count_seed_checkpoints(cond_dir)
        assert n_batches > 0, f"No batch directories found for {cond_name}"
        print(f"   PASS: {cond_name} -- {n_batches} batches, {n_ckpts} seed checkpoints")
    except Exception as e:
        errors.append(f"{cond_name} Drive: {e}")
        print(f"   FAIL: {e}")

# Test 5: Checkpoint loading - verify hidden food enabled.
# Tests 6-7 reuse the last checkpoint/config/network loaded here, so start from
# None to make a failure here surface as a clear error there (not a NameError).
print("\n[5/7] Testing checkpoint loading (one per condition)...")
ckpt = config = network = None
try:
    for cond_name, cond_dir in [("Field ON", FIELD_ON_DIR), ("Field OFF", FIELD_OFF_DIR)]:
        test_ckpt = None
        batch_dirs = sorted(d for d in os.listdir(cond_dir) if d.startswith('batch_'))
        for bd in batch_dirs:
            bp = os.path.join(cond_dir, bd)
            for sd in sorted(os.listdir(bp)):
                sp = os.path.join(bp, sd)
                if os.path.isdir(sp):
                    test_ckpt = _latest_step_pkl(sp)
                    if test_ckpt:
                        break
            if test_ckpt:
                break
        assert test_ckpt is not None, f"No checkpoint found for {cond_name}"
        ckpt = load_checkpoint(test_ckpt)
        config = ckpt['config']
        assert hasattr(config, 'env'), f"{cond_name} config not a dataclass"
        assert config.env.hidden_food.enabled, f"{cond_name}: hidden_food not enabled!"
        print(f"   {cond_name}: seed={ckpt.get('seed_id', -1)}, grid={config.env.grid_size}, "
              f"hidden_food={config.env.hidden_food.enabled}, "
              f"decay={config.field.decay_rate}, diffusion={config.field.diffusion_rate}")
    print("   PASS: Both conditions' checkpoints load correctly with hidden food enabled")
except Exception as e:
    errors.append(f"Checkpoint loading: {e}")
    print(f"   FAIL: {e}")

# Test 6: Standard eval episode
print("\n[6/7] Testing standard eval episode...")
try:
    if ckpt is None:
        raise RuntimeError("no checkpoint loaded (Test 5 failed)")
    # Exercise the same per-agent param extraction load_seed_data() uses later.
    agent_params = jax.tree_util.tree_map(lambda x: x[0], ckpt['agent_params'])
    network = ActorCritic(hidden_dims=tuple(config.agent.hidden_dims), num_actions=6)
    t0 = time.perf_counter()
    key = jax.random.PRNGKey(42)
    stats = _run_episode_full(
        network=network, params=ckpt['params'], config=config,
        key=key, condition="normal", evolution=True,
    )
    elapsed = time.perf_counter() - t0
    print(f"   Eval: reward={stats.total_reward:.1f}, pop={stats.final_population} ({elapsed:.1f}s)")
    print("   PASS: Standard eval episode works")
except Exception as e:
    errors.append(f"Standard eval: {e}")
    print(f"   FAIL: {e}")

# Test 7: Hidden food custom eval (a single manual step)
print("\n[7/7] Testing hidden food custom eval...")
try:
    if ckpt is None or network is None:
        raise RuntimeError("no checkpoint/network available (Test 5 or 6 failed)")
    key = jax.random.PRNGKey(99)
    state = reset(key, config)
    obs = get_observations(state, config)
    obs_batched = obs[None, :, :]  # add a leading batch axis of size 1
    actions = get_deterministic_actions(network, ckpt['params'], obs_batched)
    actions = actions[0]  # drop the batch axis again
    pre_revealed = state.hidden_food_revealed
    state2, rewards, done, info = step(state, actions, config)
    assert 'hidden_food_collected_this_step' in info
    assert 'food_collected_this_step' in info
    newly_revealed = (~pre_revealed) & state2.hidden_food_revealed
    print(f"   Custom eval step: regular_food={info['food_collected_this_step']}, "
          f"hf_collected={info['hidden_food_collected_this_step']}, "
          f"newly_revealed={int(jnp.sum(newly_revealed))}")
    print("   PASS: Hidden food custom eval mechanics work")
except Exception as e:
    errors.append(f"Hidden food eval: {e}")
    print(f"   FAIL: {e}")

# Summary
print()
print("="*70)
if errors:
    print(f"TEST FAILED! {len(errors)} error(s):")
    for err in errors:
        print(f"  - {err}")
    print("\nDO NOT proceed until all tests pass.")
    print("="*70)
    raise RuntimeError(f"Test failed with {len(errors)} error(s)")
else:
    print("ALL 7 TESTS PASSED!")
    print("="*70)
    print("Proceed to run the full analysis below.")

## Custom Eval: Hidden Food Metrics

This function tracks per-step hidden food statistics that aren't captured
by the standard `_run_episode_full()`.

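**Warning**: ~5-10 minutes per seed on TPU. The episode loop runs step by step in Python (up to `config.env.max_steps` steps), and each step makes several separate, non-JIT-compiled JAX calls (observations, policy, environment step). With 60 seeds at 1 episode each, expect ~5-10 hours total.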

In [None]:
def run_hidden_food_eval(network, params, config, key, num_episodes=1):
    """Run eval episodes tracking hidden food reveal and collection.

    Actions come from ``get_deterministic_actions``. The environment is stepped
    one timestep at a time from Python (no JIT), so each episode is slow.

    Args:
        network: ActorCritic module.
        params: Network parameters.
        config: Config with hidden_food enabled.
        key: PRNG key; split once per episode to seed ``reset``.
        num_episodes: Episodes to run (must be >= 1). Default 1 (set to 5 for
            more precision, but each episode takes ~5-10 min on TPU).

    Returns:
        dict mapping each per-episode metric to its mean over episodes, plus
        ``'per_episode'``: the list of per-episode metric dicts.

    Raises:
        ValueError: if ``num_episodes < 1`` (there would be nothing to average).

    Note on reveal counting:
        Reveals are counted by diffing state.hidden_food_revealed before/after
        each step. If food is revealed AND collected in the same step, the diff
        misses it (revealed resets to False). So:
        - hidden_food_collected is EXACT (from info dict)
        - hidden_food_revealed is a LOWER BOUND
        - revealed + collected is an UPPER BOUND on reveal events (some
          collected food may have been revealed in a prior step); it is not
          returned, so compute it from the returned fields if needed.
    """
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    # Energy constants are fixed for the whole run; read them once.
    food_energy = config.evolution.food_energy
    hf_multiplier = config.env.hidden_food.hidden_food_value_multiplier

    all_results = []

    for _ in range(num_episodes):
        key, ep_key = jax.random.split(key)
        state = reset(ep_key, config)

        ep_reward = 0.0
        ep_regular_food = 0.0
        ep_hf_revealed = 0
        ep_hf_collected = 0.0
        ep_births = 0
        ep_deaths = 0

        for _t in range(config.env.max_steps):
            obs = get_observations(state, config)
            obs_batched = obs[None, :, :]  # add a leading batch axis of size 1
            actions = get_deterministic_actions(network, params, obs_batched)
            actions = actions[0]  # drop the batch axis again

            # Snapshot the reveal mask so newly revealed items can be diffed.
            pre_revealed = state.hidden_food_revealed

            state, rewards, done, info = step(state, actions, config)

            # float()/int() pull each scalar back to the host every step.
            ep_reward += float(jnp.sum(rewards))
            ep_regular_food += float(info['food_collected_this_step'])
            ep_hf_collected += float(info['hidden_food_collected_this_step'])
            ep_births += int(info['births_this_step'])
            ep_deaths += int(info['deaths_this_step'])

            # Count newly revealed hidden food (lower bound - see docstring)
            if pre_revealed is not None and state.hidden_food_revealed is not None:
                newly_revealed = (~pre_revealed) & state.hidden_food_revealed
                ep_hf_revealed += int(jnp.sum(newly_revealed))

            if bool(done):
                break

        final_pop = int(jnp.sum(state.agent_alive.astype(jnp.int32)))

        # Energy breakdown: hidden food is worth `hf_multiplier` regular items.
        regular_energy = ep_regular_food * food_energy
        hidden_energy = ep_hf_collected * food_energy * hf_multiplier

        all_results.append({
            'total_reward': ep_reward,
            'regular_food_collected': ep_regular_food,
            'hidden_food_revealed': ep_hf_revealed,
            'hidden_food_collected': ep_hf_collected,
            'regular_food_energy': regular_energy,
            'hidden_food_energy': hidden_energy,
            'final_population': final_pop,
            'total_births': ep_births,
            'total_deaths': ep_deaths,
        })

    # Aggregate: per-metric mean across episodes.
    agg = {k: np.mean([r[k] for r in all_results]) for k in all_results[0]}
    agg['per_episode'] = all_results
    return agg


# Definition-only cell: nothing is evaluated here; this just confirms the cell ran.
print("Hidden food eval function defined.")

## Phase 1: Load Training Data

Discover checkpoints and load training summaries from Drive.

In [None]:
# ========== DISCOVER CHECKPOINTS + LOAD TRAINING SUMMARIES ==========

def _checkpoint_sort_key(path):
    """Sort key ranking ``step_<n>.pkl`` files by their integer step ``n``.

    Lexicographic order would put ``step_999.pkl`` after ``step_1000.pkl``.
    Non-integer step names rank below all integer ones; ties fall back to the
    path so the choice is deterministic.
    """
    step = os.path.basename(path)[len('step_'):-len('.pkl')]
    return (step.isdecimal(), int(step) if step.isdecimal() else -1, path)


def discover_checkpoints(drive_dir, condition_name, max_batches=None):
    """Discover the latest checkpoint for every seed directory under *drive_dir*.

    Expected layout: ``drive_dir/batch_<i>/<seed_dir>/step_<n>.pkl``. For each
    seed directory, the ``step_*.pkl`` with the highest integer step is kept.

    Args:
        drive_dir: Root directory of one experimental condition. If it does
            not exist, no checkpoints are found.
        condition_name: Label used in the progress message.
        max_batches: If given, only ``batch_<i>`` with ``i < max_batches`` are
            scanned. The default (None) scans every ``batch_<i>`` directory
            instead of silently stopping at ``batch_9``.

    Returns:
        List of checkpoint paths ordered by batch index, then seed dir name.
        The order matters: callers derive eval PRNG keys from the list index.
    """
    batch_dirs = []
    if os.path.isdir(drive_dir):
        for entry in os.listdir(drive_dir):
            index = entry[len('batch_'):]
            if not (entry.startswith('batch_') and index.isdecimal()):
                continue
            if max_batches is not None and int(index) >= max_batches:
                continue
            batch_path = os.path.join(drive_dir, entry)
            if os.path.isdir(batch_path):
                batch_dirs.append((int(index), batch_path))
    # Numeric order keeps batch_2 before batch_10.
    batch_dirs.sort()

    paths = []
    for _, batch_dir in batch_dirs:
        for seed_dir in sorted(os.listdir(batch_dir)):
            seed_path = os.path.join(batch_dir, seed_dir)
            if not os.path.isdir(seed_path):
                continue
            pkl_files = glob_mod.glob(os.path.join(seed_path, 'step_*.pkl'))
            if pkl_files:
                paths.append(max(pkl_files, key=_checkpoint_sort_key))
    print(f"  {condition_name}: Found {len(paths)} checkpoints")
    return paths


def load_training_summary(drive_dir, condition_name):
    """Load ``training_summary.pkl`` and extract per-seed rewards + populations.

    Batches whose ``success`` flag is False are skipped (a missing flag counts
    as success). Rewards come from ``metrics['mean_reward']`` and populations
    from ``metrics['population_size']`` of every remaining batch.

    Args:
        drive_dir: Condition root directory containing the summary pickle.
        condition_name: Label used in progress/warning messages.

    Returns:
        ``(rewards, populations)``: rewards as an ndarray; populations as an int
        ndarray, or None if no batch reported them. Returns ``(None, None)``
        when the file is missing or holds no reward metrics.

    Note:
        Uses ``pickle.load``. Only point this at summaries written by this
        project's own training runs, since unpickling can execute arbitrary code.
    """
    summary_path = os.path.join(drive_dir, 'training_summary.pkl')
    if not os.path.exists(summary_path):
        print(f"  {condition_name}: No training_summary.pkl found")
        return None, None
    with open(summary_path, 'rb') as f:
        summary = pickle.load(f)

    rewards = []
    populations = []
    skipped = 0
    for batch in summary.get('all_results', []):
        if not batch.get('success', True):
            skipped += 1
            continue
        metrics = batch.get('metrics') or {}
        rewards.extend(metrics.get('mean_reward', []))
        populations.extend(metrics.get('population_size', []))
    if skipped > 0:
        print(f"  {condition_name}: WARNING - skipped {skipped} failed batches")
    if not rewards:
        # Distinguish "file present but empty" from "file missing" for the caller's log.
        print(f"  {condition_name}: training_summary.pkl contains no 'mean_reward' metrics")
        return None, None
    rewards = np.array(rewards)
    populations = np.array(populations, dtype=int) if populations else None
    print(f"  {condition_name}: Loaded {len(rewards)} rewards from training_summary.pkl")
    return rewards, populations


print("Discovering data on Drive...")
field_on_ckpt_paths = discover_checkpoints(FIELD_ON_DIR, "Field ON")
field_off_ckpt_paths = discover_checkpoints(FIELD_OFF_DIR, "Field OFF")

# Training summaries: each pair is (None, None) when no usable summary exists.
on_rewards_drive, on_pops_drive = load_training_summary(FIELD_ON_DIR, "Field ON")
off_rewards_drive, off_pops_drive = load_training_summary(FIELD_OFF_DIR, "Field OFF")

# Canonical names consumed by the downstream statistics and plotting cells.
field_on_rewards, field_on_populations = on_rewards_drive, on_pops_drive
field_off_rewards, field_off_populations = off_rewards_drive, off_pops_drive

_rule = '=' * 60
print(f"\n{_rule}")
print("TRAINING DATA SUMMARY")
print(_rule)
print(f"Checkpoints: {len(field_on_ckpt_paths)} Field ON, {len(field_off_ckpt_paths)} Field OFF")
for _prefix, _rw in (("Field ON:  ", field_on_rewards), ("Field OFF: ", field_off_rewards)):
    if _rw is None:
        print(f"{_prefix}No training_summary.pkl (will use eval data only)")
    else:
        print(f"{_prefix}{len(_rw)} seeds, mean={_rw.mean():.3f} +/- {_rw.std(ddof=1):.3f}")

In [None]:
# ========== DESCRIPTIVE STATISTICS (Training Reward) ==========
print("="*60)
print("DESCRIPTIVE STATISTICS (Training Reward)")
print("="*60)

for name, rewards in (("Field ON", field_on_rewards), ("Field OFF", field_off_rewards)):
    if rewards is None:
        print(f"\n{name}: No training data available")
        continue
    iqm = compute_iqm(rewards, n_bootstrap=10000, seed=42)
    _sd = rewards.std(ddof=1)  # sample SD, reused for the CoV row
    # (label, value) rows; labels are left-aligned in an 8-character column.
    _rows = (
        ("Mean:", f"{rewards.mean():.4f} +/- {_sd:.4f}"),
        ("Median:", f"{np.median(rewards):.4f}"),
        ("IQM:", f"{iqm.iqm:.4f} [{iqm.ci_lower:.4f}, {iqm.ci_upper:.4f}]"),
        ("Min:", f"{rewards.min():.4f}"),
        ("Max:", f"{rewards.max():.4f}"),
        ("CoV:", f"{_sd/rewards.mean():.4f}"),
    )
    print(f"\n{name} (n={len(rewards)}):")
    for _label, _value in _rows:
        print(f"  {_label:<8}{_value}")

In [None]:
# ========== HYPOTHESIS TESTS (Training Reward) ==========
print("="*60)
print("HYPOTHESIS TESTS (Training Reward)")
print("="*60)

if field_on_rewards is None or field_off_rewards is None:
    print("\nSkipping: training data not available for both conditions.")
    welch = mw = poi = None
else:
    # Parametric, rank-based and bootstrap comparisons of Field ON vs Field OFF.
    welch = welch_t_test(field_on_rewards, field_off_rewards)
    mw = mann_whitney_test(field_on_rewards, field_off_rewards)
    poi = probability_of_improvement(field_on_rewards, field_off_rewards, n_bootstrap=5000, seed=42)

    # Conventional Cohen's d bands: <0.2 negligible, <0.5 small, <0.8 medium, else large.
    d = abs(welch.effect_size)
    if d < 0.2:
        d_str = "negligible"
    elif d < 0.5:
        d_str = "small"
    elif d < 0.8:
        d_str = "MEDIUM"
    else:
        d_str = "LARGE"

    on_mean = field_on_rewards.mean()
    off_mean = field_off_rewards.mean()

    print("\n1. Welch's t-test:")
    print(f"   t = {welch.statistic:.4f}, p = {welch.p_value:.6f}")
    print(f"   Cohen's d = {welch.effect_size:.4f} ({d_str})")

    print("\n2. Mann-Whitney U test:")
    print(f"   U = {mw.statistic:.1f}, p = {mw.p_value:.6f}")
    print(f"   Rank-biserial r = {mw.effect_size:.4f}")

    print("\n3. Probability of Improvement:")
    print(f"   P(Field ON > Field OFF) = {poi['prob_x_better']:.4f}")
    print(f"   P(Field OFF > Field ON) = {poi['prob_y_better']:.4f}")

    print("\n4. Direction:")
    print(f"   Field ON mean:  {on_mean:.4f}")
    print(f"   Field OFF mean: {off_mean:.4f}")
    print(f"   Gap: {on_mean - off_mean:+.4f}")

In [None]:
# ========== FULL METHOD COMPARISON (rliable) ==========
print("="*60)
print("FULL METHOD COMPARISON (rliable-style)")
print("="*60)

# Define `comparison` on both branches (as the hypothesis-test cell does for
# welch/mw/poi) so later cells can test `comparison is None` instead of
# hitting a NameError when training data is missing.
comparison = None
if field_on_rewards is not None and field_off_rewards is not None:
    comparison = compare_methods(
        {"Field ON": field_on_rewards, "Field OFF": field_off_rewards},
        n_bootstrap=10000, seed=42,
    )
    print(comparison.summary)
else:
    print("Skipping: training data not available for both conditions.")

## Phase 2: Comparison Plots (Training Data)

In [None]:
# ========== COMPARISON PLOTS (Training Data) ==========
if field_on_rewards is not None and field_off_rewards is not None:
    setup_publication_style()
    colors = ['#009988', '#BBBBBB']

    iqm_on = compute_iqm(field_on_rewards, n_bootstrap=10000, seed=42)
    iqm_off = compute_iqm(field_off_rewards, n_bootstrap=10000, seed=42)

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # 1. Bar chart: IQM with its (possibly asymmetric) bootstrap CI as error bars
    ax = axes[0]
    methods = ['Field ON\n(Stigmergy)', 'Field OFF\n(No Field)']
    iqm_vals = [iqm_on.iqm, iqm_off.iqm]
    iqm_lo = [iqm_on.iqm - iqm_on.ci_lower, iqm_off.iqm - iqm_off.ci_lower]
    iqm_hi = [iqm_on.ci_upper - iqm_on.iqm, iqm_off.ci_upper - iqm_off.iqm]
    bars = ax.bar(methods, iqm_vals, yerr=[iqm_lo, iqm_hi], color=colors,
                  edgecolor='black', linewidth=0.5, capsize=8)
    for bar, val in zip(bars, iqm_vals):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(iqm_hi) + 0.05,
                f'{val:.3f}', ha='center', fontsize=10)
    ax.set_ylabel('IQM Reward')
    ax.set_title('(a) IQM + 95% CI')

    # 2. Violin + jittered per-seed points
    ax = axes[1]
    parts = ax.violinplot([field_on_rewards, field_off_rewards], positions=[1, 2],
                           showmeans=True, showmedians=True, showextrema=False)
    for i, body in enumerate(parts['bodies']):
        body.set_facecolor(colors[i])
        body.set_alpha(0.4)
    rng = np.random.default_rng(42)
    for i, (data, pos) in enumerate(zip([field_on_rewards, field_off_rewards], [1, 2])):
        jitter = rng.normal(0, 0.05, size=len(data))
        ax.scatter(np.full_like(data, pos) + jitter, data, alpha=0.6, s=25,
                   color=colors[i], edgecolor='black', linewidth=0.3, zorder=3)
    ax.set_xticks([1, 2])
    ax.set_xticklabels(['Field ON', 'Field OFF'])
    ax.set_ylabel('Mean Reward')
    ax.set_title('(b) Distribution')

    # 3. Population comparison
    ax = axes[2]
    pop_series = [
        (field_on_populations, colors[0], 'Field ON'),
        (field_off_populations, colors[1], 'Field OFF'),
    ]
    available_pops = [p for p, _, _ in pop_series if p is not None]
    # One set of bin edges for both conditions: two independent `bins=15`
    # calls would bin each condition over its own range, so the overlaid
    # histograms would not be comparable bar-for-bar.
    pop_bins = (np.histogram_bin_edges(np.concatenate(available_pops), bins=15)
                if available_pops else 15)
    for pops, color, label in pop_series:
        if pops is not None:
            ax.hist(pops, bins=pop_bins, alpha=0.6, color=color,
                    label=label, edgecolor='black')
    ax.set_xlabel('Final Population (training)')
    ax.set_ylabel('Count')
    ax.set_title('(c) Population Distribution')
    # NOTE(review): 64 is hard-coded; confirm it matches the max-agent
    # capacity in the checkpoints' config.
    ax.axvline(x=64, color='red', linestyle='--', alpha=0.5, label='Max capacity')
    ax.legend(fontsize=9)

    plt.suptitle('Hidden Food: Field ON vs Field OFF (Training Data)', fontsize=14, y=1.02)
    plt.tight_layout()
    save_figure(fig, os.path.join(OUTPUT_DIR, 'training_comparison'))
    plt.show()
else:
    print("Skipping training plots: data not available.")

In [None]:
# ========== PERFORMANCE PROFILES ==========
if field_on_rewards is None or field_off_rewards is None:
    print("Skipping performance profiles: training data not available.")
else:
    setup_publication_style()
    # Legend label -> per-seed training rewards for that condition.
    _profile_inputs = {
        "Field ON (Stigmergy)": field_on_rewards,
        "Field OFF (No Field)": field_off_rewards,
    }
    fig = plot_performance_profiles(
        _profile_inputs,
        output_path=os.path.join(OUTPUT_DIR, 'performance_profiles'),
        tau_range=(0, 1.05),
    )
    plt.show()

## Phase 3: Checkpoint Analysis — Hidden Food Eval

Load ALL 60 checkpoints. For each seed, run hidden food eval + compute weight divergence
in a SINGLE pass (don't load checkpoints twice).

**Warning**: This takes ~5-10 minutes per seed on TPU. Total: ~5-10 hours for 60 seeds.

In [None]:
# ========== LOAD + EVAL ALL CHECKPOINTS (BOTH CONDITIONS) ==========

# Eval episodes per checkpoint. Each episode takes roughly 5-10 min on TPU, so
# raising this to 5 for higher precision costs ~25-50 min per seed.
NUM_EVAL_EPISODES = 1


def load_seed_data(ckpt_path):
    """Load one checkpoint and keep only what the eval + divergence analysis needs.

    Returns:
        dict with ``params`` (network params), ``agent_params`` (the checkpoint's
        ``agent_params`` with index 0 taken along the leading axis of every
        leaf), ``config``, a freshly built ``network`` (ActorCritic, 6 actions)
        and ``seed_id`` (-1 when the checkpoint does not record one).
    """
    ckpt = load_checkpoint(ckpt_path)
    cfg = ckpt['config']
    seed_data = {
        'params': ckpt['params'],
        'agent_params': jax.tree_util.tree_map(lambda leaf: leaf[0], ckpt['agent_params']),
        'config': cfg,
        'network': ActorCritic(hidden_dims=tuple(cfg.agent.hidden_dims), num_actions=6),
        'seed_id': ckpt.get('seed_id', -1),
    }
    # Drop the reference to the full checkpoint (incl. optimizer state); only
    # the fields copied above stay alive.
    del ckpt
    return seed_data


import time

hf_eval_on = []
hf_eval_off = []
divergence_on = []
divergence_off = []

# Completed per-seed results are re-pickled here after every seed, so a Colab
# disconnect part-way through the multi-hour run does not lose finished work.
PARTIAL_RESULTS_PATH = os.path.join(OUTPUT_DIR, 'hidden_food_eval_partial.pkl')


def _save_partial_results():
    """Best-effort snapshot of every result collected so far to PARTIAL_RESULTS_PATH."""
    snapshot = {
        'hf_eval_on': hf_eval_on,
        'hf_eval_off': hf_eval_off,
        'divergence_on': divergence_on,
        'divergence_off': divergence_off,
    }
    try:
        with open(PARTIAL_RESULTS_PATH, 'wb') as f:
            pickle.dump(snapshot, f)
    except (OSError, pickle.PicklingError, TypeError) as err:
        # Never abort hours of evaluation because a snapshot write failed.
        print(f"  WARNING: could not save partial results: {err}")


for cond_name, ckpt_paths, eval_list, div_list in [
    ("Field ON", field_on_ckpt_paths, hf_eval_on, divergence_on),
    ("Field OFF", field_off_ckpt_paths, hf_eval_off, divergence_off),
]:
    print(f"\n{'='*60}")
    print(f"ANALYZING {cond_name} ({len(ckpt_paths)} seeds, {NUM_EVAL_EPISODES} ep/seed)")
    print(f"{'='*60}")

    n_seeds = len(ckpt_paths)
    for i, ckpt_path in enumerate(ckpt_paths):
        t_seed = time.perf_counter()
        seed_data = load_seed_data(ckpt_path)
        # The i-th checkpoint of each condition gets the same eval key.
        key = jax.random.PRNGKey(42 + i)

        # Hidden food eval
        hf_result = run_hidden_food_eval(
            seed_data['network'], seed_data['params'],
            seed_data['config'], key, num_episodes=NUM_EVAL_EPISODES,
        )
        hf_result['seed_id'] = seed_data['seed_id']
        eval_list.append(hf_result)

        # Weight divergence
        div = compute_weight_divergence(seed_data['agent_params'])
        div_list.append({
            'seed_id': seed_data['seed_id'],
            'mean_divergence': float(div['mean_divergence']),
            'max_divergence': float(div['max_divergence']),
        })

        del seed_data
        gc.collect()
        _save_partial_results()

        # Log the first seed, every 5th seed, and always the final seed.
        if i == 0 or (i + 1) % 5 == 0 or i + 1 == n_seeds:
            print(f"  [{i+1}/{n_seeds}] seed {eval_list[-1]['seed_id']}: "
                  f"reg_food={hf_result['regular_food_collected']:.0f}, "
                  f"hf_revealed={hf_result['hidden_food_revealed']:.1f}, "
                  f"hf_collected={hf_result['hidden_food_collected']:.1f}, "
                  f"div={div['mean_divergence']:.4f} "
                  f"({time.perf_counter() - t_seed:.0f}s)")

# ========== EXTRACT NUMPY ARRAYS ==========
def _per_seed(results, field):
    """Collect one scalar metric from a list of per-seed result dicts."""
    return np.array([res[field] for res in results])


on_hf_revealed = _per_seed(hf_eval_on, 'hidden_food_revealed')
off_hf_revealed = _per_seed(hf_eval_off, 'hidden_food_revealed')
on_hf_collected = _per_seed(hf_eval_on, 'hidden_food_collected')
off_hf_collected = _per_seed(hf_eval_off, 'hidden_food_collected')
on_regular_food = _per_seed(hf_eval_on, 'regular_food_collected')
off_regular_food = _per_seed(hf_eval_off, 'regular_food_collected')
on_regular_energy = _per_seed(hf_eval_on, 'regular_food_energy')
off_regular_energy = _per_seed(hf_eval_off, 'regular_food_energy')
on_hidden_energy = _per_seed(hf_eval_on, 'hidden_food_energy')
off_hidden_energy = _per_seed(hf_eval_off, 'hidden_food_energy')
on_eval_rewards = _per_seed(hf_eval_on, 'total_reward')
off_eval_rewards = _per_seed(hf_eval_off, 'total_reward')
on_eval_pops = _per_seed(hf_eval_on, 'final_population')
off_eval_pops = _per_seed(hf_eval_off, 'final_population')

on_mean_divs = _per_seed(divergence_on, 'mean_divergence')
off_mean_divs = _per_seed(divergence_off, 'mean_divergence')

# ========== SUMMARY TABLE ==========
_rule = '=' * 60
print(f"\n{_rule}")
print("HIDDEN FOOD EVAL SUMMARY")
print(_rule)
_conditions = (
    ("Field ON", on_hf_revealed, on_hf_collected, on_regular_food,
     on_regular_energy, on_hidden_energy, on_eval_rewards, on_eval_pops, on_mean_divs),
    ("Field OFF", off_hf_revealed, off_hf_collected, off_regular_food,
     off_regular_energy, off_hidden_energy, off_eval_rewards, off_eval_pops, off_mean_divs),
)
for name, hf_rev, hf_col, reg_food, reg_e, hid_e, ev_rew, ev_pop, divs in _conditions:
    _lines = [
        f"\n{name} ({len(hf_rev)} seeds):",
        f"  Hidden food revealed:  {hf_rev.mean():.2f} +/- {hf_rev.std(ddof=1):.2f}",
        f"  Hidden food collected: {hf_col.mean():.2f} +/- {hf_col.std(ddof=1):.2f}",
        f"  Regular food:          {reg_food.mean():.1f} +/- {reg_food.std(ddof=1):.1f}",
        f"  Regular energy:        {reg_e.mean():.0f}",
        f"  Hidden energy:         {hid_e.mean():.0f}",
        f"  Total eval reward:     {ev_rew.mean():.1f} +/- {ev_rew.std(ddof=1):.1f}",
        f"  Final population:      {ev_pop.mean():.1f} +/- {ev_pop.std(ddof=1):.1f}",
        f"  Weight divergence:     {divs.mean():.4f} +/- {divs.std(ddof=1):.4f}",
    ]
    print("\n".join(_lines))

In [None]:
# ========== HIDDEN FOOD STATISTICAL TESTS ==========
print("="*60)
print("HIDDEN FOOD COORDINATION TESTS")
print("="*60)

for metric_name, on_vals, off_vals in [
    ("Hidden Food Revealed", on_hf_revealed, off_hf_revealed),
    ("Hidden Food Collected", on_hf_collected, off_hf_collected),
    ("Regular Food Collected", on_regular_food, off_regular_food),
    ("Hidden Food Energy", on_hidden_energy, off_hidden_energy),
]:
    print(f"\n--- {metric_name} ---")
    print(f"  Field ON:  {on_vals.mean():.2f} +/- {on_vals.std(ddof=1):.2f}")
    print(f"  Field OFF: {off_vals.mean():.2f} +/- {off_vals.std(ddof=1):.2f}")

    # Two-sample tests need at least two seeds per condition.
    if len(on_vals) < 2 or len(off_vals) < 2:
        print("  Fewer than 2 seeds in a condition -- test not applicable.")
        continue

    # Handle all-zero case
    if on_vals.sum() == 0 and off_vals.sum() == 0:
        print("  Both conditions show zero -- no coordination observed.")
        continue

    both_constant = on_vals.std() == 0 and off_vals.std() == 0
    if both_constant and on_vals[0] == off_vals[0]:
        print("  No variance in either condition -- test not applicable.")
        continue

    if both_constant:
        # Constant but different values: the conditions are perfectly
        # separated. The t-statistic is undefined (zero variance), but this is
        # the strongest possible difference, not "not applicable".
        print(f"  No variance in either condition, but values differ "
              f"({on_vals[0]:.2f} vs {off_vals[0]:.2f}) -- perfect separation; "
              f"t-test undefined.")
    else:
        w = welch_t_test(on_vals, off_vals)
        mw_hf = mann_whitney_test(on_vals, off_vals)
        print(f"  Welch: t={w.statistic:.3f}, p={w.p_value:.6f}, d={w.effect_size:.3f}")
        print(f"  Mann-Whitney: U={mw_hf.statistic:.1f}, p={mw_hf.p_value:.6f}")

    if off_vals.sum() == 0:
        print("  NOTE: Field OFF shows ZERO -- agents cannot coordinate without field.")

In [None]:
# ========== HIDDEN FOOD PLOTS (THE MONEY SHOTS) ==========
# Same palette as the training plots: teal = Field ON, grey = Field OFF.
colors = ['#009988', '#BBBBBB']
setup_publication_style()

fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(25, 5))

# Helper: Field ON vs Field OFF mean +/- SD bars, flagging an all-zero Field OFF.
def bar_with_zero_handling(ax, on_vals, off_vals, title, ylabel):
    """Draw a two-bar (Field ON vs Field OFF) mean +/- SD chart on *ax*.

    Error bars are sample SDs (ddof=1), or 0 for a single seed. Bars use the
    module-level ``colors`` palette and each is labelled with its mean. When
    Field OFF sums to zero while Field ON is positive, a red italic note is
    drawn inside the axes.
    """
    groups = (on_vals, off_vals)
    heights = [g.mean() for g in groups]
    spreads = [g.std(ddof=1) if len(g) > 1 else 0 for g in groups]
    bars = ax.bar(['Field ON', 'Field OFF'], heights, yerr=spreads, color=colors,
                  edgecolor='black', linewidth=0.5, capsize=8)
    label_pad = max(spreads) * 0.1 + 0.01
    for rect, height in zip(bars, heights):
        x_center = rect.get_x() + rect.get_width()/2
        ax.text(x_center, rect.get_height() + label_pad, f'{height:.2f}',
                ha='center', fontsize=10)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    off_zero_on_positive = off_vals.sum() == 0 and on_vals.sum() > 0
    if off_zero_on_positive:
        ax.text(0.5, 0.1, 'Field OFF: No coordination\nobserved (confirms hypothesis)',
                ha='center', transform=ax.transAxes, fontsize=9, color='red',
                style='italic')

# Panels (a)-(c): ON-vs-OFF bar comparisons, one per per-seed count metric.
# Each tuple is (Field ON values, Field OFF values, panel title, y-label).
simple_bar_panels = [
    (on_hf_revealed, off_hf_revealed, '(a) Hidden Food Revealed', 'Reveal Events'),
    (on_hf_collected, off_hf_collected, '(b) Hidden Food Collected', 'Items Collected'),
    (on_regular_food, off_regular_food, '(c) Regular Food Collected', 'Items Collected'),
]
for panel_idx, panel_spec in enumerate(simple_bar_panels):
    bar_with_zero_handling(axes[panel_idx], *panel_spec)

# Panel (d): stacked energy bars -- regular-food energy as the base segment,
# hidden-food energy stacked on top, each stack labelled with its total.
ax = axes[3]
stack_positions = [0, 1]
reg_means = [on_regular_energy.mean(), off_regular_energy.mean()]
hid_means = [on_hidden_energy.mean(), off_hidden_energy.mean()]
ax.bar(stack_positions, reg_means, color=['#009988', '#BBBBBB'],
       edgecolor='black', linewidth=0.5, label='Regular Food Energy')
ax.bar(stack_positions, hid_means, bottom=reg_means, color=['#006655', '#888888'],
       edgecolor='black', linewidth=0.5, label='Hidden Food Energy')
for col_idx, (base_energy, bonus_energy) in enumerate(zip(reg_means, hid_means)):
    stack_total = base_energy + bonus_energy
    # Fixed 50-unit (data-coordinate) offset above each stack.
    ax.text(col_idx, stack_total + 50, f'{stack_total:.0f}', ha='center', fontsize=10)
ax.set_xticks(stack_positions)
ax.set_xticklabels(['Field ON', 'Field OFF'])
ax.set_ylabel('Total Energy')
ax.set_title('(d) Energy Breakdown')
ax.legend(fontsize=8)

# Panel (e): violins of per-seed hidden food collected. A condition whose
# values are all zero gets no violin (x position 1 = ON, 2 = OFF).
ax = axes[4]
has_on_data = on_hf_collected.sum() > 0
has_off_data = off_hf_collected.sum() > 0

if not (has_on_data or has_off_data):
    ax.text(0.5, 0.5, 'No hidden food coordination\nobserved in either condition',
            ha='center', va='center', transform=ax.transAxes, fontsize=10, color='red')
else:
    # (values, x position, face color) for each condition that has data.
    violin_series = [
        (series_vals, series_pos, series_color)
        for series_vals, series_pos, series_color, series_present in (
            (on_hf_collected, 1, colors[0], has_on_data),
            (off_hf_collected, 2, colors[1], has_off_data),
        )
        if series_present
    ]
    parts = ax.violinplot([spec[0] for spec in violin_series],
                          positions=[spec[1] for spec in violin_series],
                          showmeans=True, showmedians=True, showextrema=False)
    for body, spec in zip(parts['bodies'], violin_series):
        body.set_facecolor(spec[2])
        body.set_alpha(0.4)
    if not has_off_data:
        ax.text(2, ax.get_ylim()[1] * 0.5, 'No coordination\nobserved',
                ha='center', fontsize=9, color='red', style='italic')

ax.set_xticks([1, 2])
ax.set_xticklabels(['Field ON', 'Field OFF'])
ax.set_ylabel('Hidden Food Collected')
ax.set_title('(e) HF Collected Distribution')

plt.suptitle('Hidden Food Coordination: Field ON vs Field OFF', fontsize=14, y=1.02)
plt.tight_layout()
save_figure(fig, os.path.join(OUTPUT_DIR, 'hidden_food_coordination'))
plt.show()

In [None]:
# ========== WEIGHT DIVERGENCE + CORRELATION PLOTS ==========
# 2x2 figure: (a) weight-divergence histograms, (b) eval-population
# histograms, (c) hidden food collected vs divergence, (d) eval reward vs
# final population. Field ON is teal, Field OFF grey.
setup_publication_style()
colors = ['#009988', '#BBBBBB']

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# (a) Weight divergence histogram; dashed lines mark each condition's mean.
ax = axes[0, 0]
ax.hist(on_mean_divs, bins=12, alpha=0.6, color=colors[0], label='Field ON', edgecolor='black')
ax.hist(off_mean_divs, bins=12, alpha=0.6, color=colors[1], label='Field OFF', edgecolor='black')
ax.axvline(on_mean_divs.mean(), color=colors[0], linestyle='--', linewidth=2)
ax.axvline(off_mean_divs.mean(), color=colors[1], linestyle='--', linewidth=2)
ax.set_xlabel('Mean Pairwise Weight Divergence (cosine)')
ax.set_ylabel('Count')
ax.set_title('(a) Weight Divergence Distribution')
ax.legend()

# (b) Eval population histogram
ax = axes[0, 1]
ax.hist(on_eval_pops, bins=15, alpha=0.6, color=colors[0], label='Field ON', edgecolor='black')
ax.hist(off_eval_pops, bins=15, alpha=0.6, color=colors[1], label='Field OFF', edgecolor='black')
# NOTE(review): 64 is hard-coded as the population cap (the report describes a
# "64-agent" config) -- keep in sync with the training config.
ax.axvline(x=64, color='red', linestyle='--', alpha=0.5, label='Max capacity')
ax.set_xlabel('Eval Population')
ax.set_ylabel('Count')
ax.set_title('(b) Eval Population Distribution')
ax.legend()

# (c) Hidden food collected vs weight divergence (scatter)
ax = axes[1, 0]
ax.scatter(on_mean_divs, on_hf_collected, color=colors[0], s=50, alpha=0.7,
           edgecolor='black', linewidth=0.5, label='Field ON')
ax.scatter(off_mean_divs, off_hf_collected, color=colors[1], s=50, alpha=0.7,
           edgecolor='black', linewidth=0.5, label='Field OFF')
# Pearson correlation for Field ON only. r is undefined when either series is
# constant, so require variance in BOTH. Checking only sum() > 0 let a constant
# non-zero count through, which gave r=nan plus a ConstantInputWarning.
if (on_hf_collected.sum() > 0 and on_hf_collected.std() > 0
        and on_mean_divs.std() > 0):
    r, p = scipy_stats.pearsonr(on_mean_divs, on_hf_collected)
    ax.text(0.05, 0.95, f'Field ON: r={r:.3f}, p={p:.4f}',
            transform=ax.transAxes, fontsize=9, va='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
ax.set_xlabel('Mean Weight Divergence')
ax.set_ylabel('Hidden Food Collected')
ax.set_title('(c) Coordination vs Specialization')
ax.legend()

# (d) Eval reward vs population scatter
ax = axes[1, 1]
ax.scatter(on_eval_pops, on_eval_rewards, color=colors[0], s=50, alpha=0.7,
           edgecolor='black', linewidth=0.5, label='Field ON')
ax.scatter(off_eval_pops, off_eval_rewards, color=colors[1], s=50, alpha=0.7,
           edgecolor='black', linewidth=0.5, label='Field OFF')
ax.set_xlabel('Final Population')
ax.set_ylabel('Total Eval Reward')
ax.set_title('(d) Eval: Reward vs Population')
ax.legend()

plt.suptitle('Weight Divergence & Correlations', fontsize=14, y=1.01)
plt.tight_layout()
save_figure(fig, os.path.join(OUTPUT_DIR, 'divergence_correlation'))
plt.show()

# ----- Divergence statistical comparison (per-seed mean divergence, ON vs OFF) -----
div_welch = welch_t_test(on_mean_divs, off_mean_divs)
div_mw = mann_whitney_test(on_mean_divs, off_mean_divs)
print("\nDivergence comparison:")
print("  Welch t-test: "
      f"t={div_welch.statistic:.4f}, p={div_welch.p_value:.6f}, d={div_welch.effect_size:.4f}")
print("  Mann-Whitney: "
      f"U={div_mw.statistic:.1f}, p={div_mw.p_value:.6f}, r={div_mw.effect_size:.4f}")

## Phase 4: Report & Save

Generate formatted comparison report and save all results to Drive.

In [None]:
# ========== GENERATE MARKDOWN REPORT ==========
# Assemble a Markdown summary of the Field ON vs Field OFF comparison from the
# per-seed arrays computed in earlier cells, render it inline, and keep it in
# `report` for the save cell.
from IPython.display import Markdown, display

# Significance threshold used only to word the interpretation section.
report_alpha = 0.05

# Compute stats for report
div_welch = welch_t_test(on_mean_divs, off_mean_divs)

# Hidden food stats: run Welch only when at least one condition has variance
# (a t-test on two constant samples is undefined).
hf_rev_welch = None
hf_col_welch = None
if on_hf_revealed.sum() > 0 or off_hf_revealed.sum() > 0:
    if on_hf_revealed.std() > 0 or off_hf_revealed.std() > 0:
        hf_rev_welch = welch_t_test(on_hf_revealed, off_hf_revealed)
if on_hf_collected.sum() > 0 or off_hf_collected.sum() > 0:
    if on_hf_collected.std() > 0 or off_hf_collected.std() > 0:
        hf_col_welch = welch_t_test(on_hf_collected, off_hf_collected)

# Correlation: Pearson r is undefined when either input is constant, so both
# series need non-zero variance. The message distinguishes "nothing collected"
# from "no variance".
if on_hf_collected.sum() == 0:
    corr_str = "N/A (no hidden food collected)"
elif on_hf_collected.std() == 0 or on_mean_divs.std() == 0:
    corr_str = "N/A (no variance in hidden food collected or weight divergence)"
else:
    r_corr, p_corr = scipy_stats.pearsonr(on_mean_divs, on_hf_collected)
    corr_str = f"r={r_corr:.3f}, p={p_corr:.4f}"

# Energy advantage
energy_advantage = on_hidden_energy.mean() - off_hidden_energy.mean()

# NOTE(review): the setup values below are hard-coded text, not read from the
# loaded configs -- keep them in sync with the training runs.
report = f"""# Hidden Food Coordination Analysis Report
Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Experiment Setup
- **Conditions**: Field ON (stigmergy) vs Field OFF (no shared field)
- **Seeds per condition**: {len(field_on_ckpt_paths)} Field ON, {len(field_off_ckpt_paths)} Field OFF
- **Training steps**: 10M per seed
- **Config**: 64-agent, grid=32, num_food=40
- **Hidden food**: 3 items, require K=3 agents within distance 3 to reveal, 5x value (500 energy)
- **Field ON**: diffusion=0.1, decay=0.05, write_strength=1.0
- **Field OFF**: diffusion=0.0, decay=1.0, write_strength=0.0
- **Eval episodes per seed**: {NUM_EVAL_EPISODES}

## Training Reward Comparison
"""

if field_on_rewards is not None and field_off_rewards is not None:
    iqm_on = compute_iqm(field_on_rewards, n_bootstrap=10000, seed=42)
    iqm_off = compute_iqm(field_off_rewards, n_bootstrap=10000, seed=42)
    report += f"""| Metric | Field ON | Field OFF |
|--------|----------|----------|
| Mean | {field_on_rewards.mean():.4f} +/- {field_on_rewards.std(ddof=1):.4f} | {field_off_rewards.mean():.4f} +/- {field_off_rewards.std(ddof=1):.4f} |
| Median | {np.median(field_on_rewards):.4f} | {np.median(field_off_rewards):.4f} |
| IQM | {iqm_on.iqm:.4f} [{iqm_on.ci_lower:.4f}, {iqm_on.ci_upper:.4f}] | {iqm_off.iqm:.4f} [{iqm_off.ci_lower:.4f}, {iqm_off.ci_upper:.4f}] |

"""
else:
    report += "Training summaries not available on Drive.\n\n"

report += f"""## Hidden Food Coordination Metrics (THE KEY SECTION)

| Metric | Field ON | Field OFF |
|--------|----------|----------|
| Hidden food revealed | {on_hf_revealed.mean():.2f} +/- {on_hf_revealed.std(ddof=1):.2f} | {off_hf_revealed.mean():.2f} +/- {off_hf_revealed.std(ddof=1):.2f} |
| Hidden food collected | {on_hf_collected.mean():.2f} +/- {on_hf_collected.std(ddof=1):.2f} | {off_hf_collected.mean():.2f} +/- {off_hf_collected.std(ddof=1):.2f} |
| Regular food collected | {on_regular_food.mean():.1f} +/- {on_regular_food.std(ddof=1):.1f} | {off_regular_food.mean():.1f} +/- {off_regular_food.std(ddof=1):.1f} |
| Regular food energy | {on_regular_energy.mean():.0f} | {off_regular_energy.mean():.0f} |
| Hidden food energy | {on_hidden_energy.mean():.0f} | {off_hidden_energy.mean():.0f} |
| Total energy | {(on_regular_energy + on_hidden_energy).mean():.0f} | {(off_regular_energy + off_hidden_energy).mean():.0f} |

### Statistical Tests (Hidden Food)
"""

if hf_rev_welch is not None:
    report += f"- **Hidden Food Revealed**: Welch t={hf_rev_welch.statistic:.3f}, p={hf_rev_welch.p_value:.6f}, d={hf_rev_welch.effect_size:.3f}\n"
else:
    report += "- **Hidden Food Revealed**: Test not applicable (zero variance or all zeros)\n"

if hf_col_welch is not None:
    report += f"- **Hidden Food Collected**: Welch t={hf_col_welch.statistic:.3f}, p={hf_col_welch.p_value:.6f}, d={hf_col_welch.effect_size:.3f}\n"
else:
    report += "- **Hidden Food Collected**: Test not applicable (zero variance or all zeros)\n"

report += f"""
### Energy Breakdown
- Field ON gets **{energy_advantage:.0f} extra energy** from hidden food coordination
- This represents a {energy_advantage / max(off_regular_energy.mean(), 1) * 100:.1f}% bonus over Field OFF's regular food energy

### Interpretation
"""

if off_hf_revealed.sum() == 0 and on_hf_revealed.sum() > 0:
    # Zero OFF reveals shows the field is needed here; it does not show the
    # field alone suffices, so the report only claims necessity.
    report += """**Field OFF shows ZERO hidden food reveals.** This confirms agents CANNOT coordinate
without the shared field. The field is necessary for the coordination task in this setup.
"""
elif on_hf_revealed.mean() > off_hf_revealed.mean():
    # Qualify the claim with the Welch result instead of relying on the raw
    # mean difference alone.
    report += "Field ON shows higher mean hidden food reveals than Field OFF"
    if hf_rev_welch is None:
        report += " (no within-condition variance, so no significance test was run).\n"
    elif hf_rev_welch.p_value < report_alpha:
        report += (f" (Welch p={hf_rev_welch.p_value:.4f}).\n"
                   "The shared field enables agents to signal hidden food locations.\n")
    else:
        report += (f", but the difference is not statistically significant "
                   f"(Welch p={hf_rev_welch.p_value:.4f} >= {report_alpha}); treat as suggestive only.\n")
else:
    report += """Results are inconclusive or unexpected. Further investigation needed.
"""

report += f"""
## Eval Episode Metrics

| Metric | Field ON | Field OFF |
|--------|----------|----------|
| Mean total reward | {on_eval_rewards.mean():.1f} +/- {on_eval_rewards.std(ddof=1):.1f} | {off_eval_rewards.mean():.1f} +/- {off_eval_rewards.std(ddof=1):.1f} |
| Mean population | {on_eval_pops.mean():.1f} +/- {on_eval_pops.std(ddof=1):.1f} | {off_eval_pops.mean():.1f} +/- {off_eval_pops.std(ddof=1):.1f} |

## Weight Divergence Comparison

| Metric | Field ON | Field OFF |
|--------|----------|----------|
| Mean divergence | {on_mean_divs.mean():.4f} +/- {on_mean_divs.std(ddof=1):.4f} | {off_mean_divs.mean():.4f} +/- {off_mean_divs.std(ddof=1):.4f} |

- Welch t-test (ON vs OFF): p={div_welch.p_value:.6f}, Cohen's d={div_welch.effect_size:.4f}

## Correlation: Hidden Food vs Divergence
- Field ON Pearson correlation (hidden food collected vs weight divergence): {corr_str}

## Conclusions

1. **Coordination signal**: {'The shared field enables hidden food coordination (Field ON > Field OFF)' if on_hf_collected.mean() > off_hf_collected.mean() else 'Results inconclusive on coordination advantage'}
2. **Energy advantage**: Field ON gains {energy_advantage:.0f} extra energy from hidden food
3. **Specialization**: {'Field ON shows higher weight divergence, suggesting the field enables specialization' if on_mean_divs.mean() > off_mean_divs.mean() else 'Weight divergence comparison requires further analysis'}
4. **Note on reveal counting**: Hidden food revealed counts are LOWER BOUNDS (same-step reveal+collect events are missed by the diff method). Hidden food collected counts are EXACT.
"""

display(Markdown(report))
print("\nReport generated successfully.")

In [None]:
# ========== SAVE RESULTS ==========
# Persist all per-seed metrics as JSON (portable) and pickle (keeps numpy
# arrays), write the Markdown report, then list everything in OUTPUT_DIR.


def _json_default(obj):
    """Convert numpy / JAX scalars and arrays into JSON-native Python values.

    ``json.dump`` calls this only for objects it cannot serialize itself.
    Numpy scalars/arrays expose ``tolist()``, which returns plain Python
    scalars/lists; JAX arrays are expected to expose it too (not verified here).
    Anything without ``tolist()`` still raises TypeError, as the stock encoder does.
    """
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Prepare results dict
results = {
    'metadata': {
        'generated': datetime.now().isoformat(),
        'field_on_seeds': len(field_on_ckpt_paths),
        'field_off_seeds': len(field_off_ckpt_paths),
        'steps_per_seed': 10_000_000,
        'eval_episodes_per_seed': NUM_EVAL_EPISODES,
        'hidden_food_config': {
            'num_hidden': 3,
            'required_agents': 3,
            'reveal_distance': 3,
            'value_multiplier': 5.0,
        },
    },
    'hidden_food_metrics': {
        'field_on': {
            'hf_revealed': on_hf_revealed.tolist(),
            'hf_collected': on_hf_collected.tolist(),
            'regular_food': on_regular_food.tolist(),
            'regular_energy': on_regular_energy.tolist(),
            'hidden_energy': on_hidden_energy.tolist(),
        },
        'field_off': {
            'hf_revealed': off_hf_revealed.tolist(),
            'hf_collected': off_hf_collected.tolist(),
            'regular_food': off_regular_food.tolist(),
            'regular_energy': off_regular_energy.tolist(),
            'hidden_energy': off_hidden_energy.tolist(),
        },
    },
    'eval_metrics': {
        'field_on': {
            'rewards': on_eval_rewards.tolist(),
            'populations': on_eval_pops.tolist(),
            # Per-seed summaries minus the bulky per-episode detail.
            'per_seed': [{k: v for k, v in r.items() if k != 'per_episode'}
                         for r in hf_eval_on],
        },
        'field_off': {
            'rewards': off_eval_rewards.tolist(),
            'populations': off_eval_pops.tolist(),
            'per_seed': [{k: v for k, v in r.items() if k != 'per_episode'}
                         for r in hf_eval_off],
        },
    },
    'weight_divergence': {
        'field_on': on_mean_divs.tolist(),
        'field_off': off_mean_divs.tolist(),
        'welch_p': float(div_welch.p_value),
        'cohens_d': float(div_welch.effect_size),
    },
    'training_rewards': {
        'field_on': field_on_rewards.tolist() if field_on_rewards is not None else None,
        'field_off': field_off_rewards.tolist() if field_off_rewards is not None else None,
    },
}

# Save JSON. The `per_seed` dicts come from the eval cell and may hold numpy or
# JAX scalars (assumption -- their construction is outside this cell). The
# stock encoder would raise TypeError part-way through the write, leaving a
# truncated file and skipping the pickle/report saves; `_json_default` converts
# such values instead.
json_path = os.path.join(OUTPUT_DIR, 'hidden_food_results.json')
with open(json_path, 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2, default=_json_default)
print(f"JSON saved: {json_path}")

# Save pickle (preserves numpy arrays)
pkl_path = os.path.join(OUTPUT_DIR, 'hidden_food_results.pkl')
with open(pkl_path, 'wb') as f:
    pickle.dump({
        **results,
        # Include numpy arrays directly for convenience
        'numpy_arrays': {
            'on_hf_revealed': on_hf_revealed,
            'off_hf_revealed': off_hf_revealed,
            'on_hf_collected': on_hf_collected,
            'off_hf_collected': off_hf_collected,
            'on_regular_food': on_regular_food,
            'off_regular_food': off_regular_food,
            'on_regular_energy': on_regular_energy,
            'off_regular_energy': off_regular_energy,
            'on_hidden_energy': on_hidden_energy,
            'off_hidden_energy': off_hidden_energy,
            'on_eval_rewards': on_eval_rewards,
            'off_eval_rewards': off_eval_rewards,
            'on_eval_pops': on_eval_pops,
            'off_eval_pops': off_eval_pops,
            'on_mean_divs': on_mean_divs,
            'off_mean_divs': off_mean_divs,
        },
    }, f)
print(f"Pickle saved: {pkl_path}")

# Save report markdown
md_path = os.path.join(OUTPUT_DIR, 'hidden_food_report.md')
with open(md_path, 'w', encoding='utf-8') as f:
    f.write(report)
print(f"Report saved: {md_path}")

# List all output files
print(f"\n{'='*60}")
print("ALL OUTPUT FILES:")
for fname in sorted(os.listdir(OUTPUT_DIR)):
    fpath = os.path.join(OUTPUT_DIR, fname)
    size_mb = os.path.getsize(fpath) / 1024 / 1024
    print(f"  {fname} ({size_mb:.2f} MB)")
print(f"\nDone! All results saved to {OUTPUT_DIR}")