# Field ON vs Field OFF: Statistical Comparison (v2 — Fresh Training)

**v2 experiment**: Trains both conditions from scratch using the post-energy-fix codebase
with the biological pheromone system (Phase 6). Config loaded from sweep YAML if available,
otherwise falls back to v1 sweep values.

30 seeds per condition (10 batches x 3 seeds), 10M steps each, 5 eval episodes per seed.
Metrics: reward, population, trail_strength, survival_rate.

**Key changes from v1:**
- No hardcoded data — all metrics from fresh training + eval episodes
- Field ON config uses sweep-optimized pheromone parameters
- Field OFF inherits ALL params from Field ON (controlled experiment)
- Energy fixes: crop refuel + free write steps
- New metrics: trail_strength, survival_rate

## Setup
1. Runtime > Change runtime type > **TPU v6e** + **High-RAM**
2. Run all cells (Ctrl+F9)

In [None]:
# Mount Google Drive; the result and checkpoint paths used later in this
# notebook all live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

import os
# Local checkout location and the GitHub account that hosts the repository.
REPO_DIR = '/content/emergence-lab'
GITHUB_USERNAME = "imashishkh21"

# First run: clone the repo. Later runs: pull into the existing checkout.
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/{GITHUB_USERNAME}/emergence-lab.git {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull

# Switch to the repo root and install it in editable mode with the
# "dev" and "phase5" extras (quiet output).
os.chdir(REPO_DIR)
!pip install -e ".[dev,phase5]" -q
print(f"Working directory: {os.getcwd()}")

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import json
import pickle
import gc
import os
import math
import re
import time
import traceback
import yaml
from pathlib import Path

import jax
import jax.numpy as jnp

from src.configs import Config, TrainingMode
from src.training.parallel_train import ParallelTrainer
from src.agents.network import ActorCritic
from src.agents.policy import sample_actions
from src.environment.env import reset, step
from src.environment.obs import get_observations
from src.analysis.statistics import (
    compute_iqm, compare_methods, welch_t_test, mann_whitney_test,
    probability_of_improvement,
)
from src.analysis.paper_figures import (
    setup_publication_style, plot_performance_profiles,
    save_figure,
)

# paper_figures.py sets matplotlib backend to Agg at import time.
# Override to inline for Colab display.
%matplotlib inline

# --- Paths ---
# Everything this experiment writes goes under DRIVE_BASE on Google Drive:
# one directory per condition plus a shared directory for analysis output.
DRIVE_BASE = '/content/drive/MyDrive/emergence-lab/field_on_vs_off_v2'
FIELD_ON_DIR = os.path.join(DRIVE_BASE, 'field_on')
FIELD_OFF_DIR = os.path.join(DRIVE_BASE, 'field_off')
OUTPUT_DIR = os.path.join(DRIVE_BASE, 'analysis_results')
# Optional sweep output; build_field_on_config() reads it when it exists.
BEST_CONFIG_PATH = '/content/drive/MyDrive/emergence-lab/pheromone_best_config_v2.yaml'
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(FIELD_ON_DIR, exist_ok=True)
os.makedirs(FIELD_OFF_DIR, exist_ok=True)

# --- Training constants ---
NUM_ENVS = 32
NUM_STEPS = 128
MAX_AGENTS = 64
TOTAL_STEPS = 10_000_000
SEEDS_PER_BATCH = 3
TOTAL_BATCHES = 10  # 30 seeds total
NUM_EVAL_EPISODES = 5
# Steps counted per training iteration: envs x rollout length x agent slots.
STEPS_PER_ITER = NUM_ENVS * NUM_STEPS * MAX_AGENTS  # 262,144
# Rounded up, so the run covers at least TOTAL_STEPS.
NUM_ITERATIONS = math.ceil(TOTAL_STEPS / STEPS_PER_ITER)  # ~39

# Echo the resolved paths and the training plan for a quick visual check.
print("Imports loaded.")
print(f"Drive base:    {DRIVE_BASE}")
print(f"Field ON dir:  {FIELD_ON_DIR}")
print(f"Field OFF dir: {FIELD_OFF_DIR}")
print(f"Output dir:    {OUTPUT_DIR}")
print(f"Best config:   {BEST_CONFIG_PATH} (exists: {os.path.exists(BEST_CONFIG_PATH)})")
print(f"\nTraining plan: {TOTAL_BATCHES} batches x {SEEDS_PER_BATCH} seeds = "
      f"{TOTAL_BATCHES * SEEDS_PER_BATCH} seeds per condition")
print(f"Steps per seed: {TOTAL_STEPS:,} ({NUM_ITERATIONS} iterations)")
print(f"Eval episodes per seed: {NUM_EVAL_EPISODES}")

In [None]:
# =============================================================================
# Config Builders + Eval Functions + Helpers
# =============================================================================

# --- Part A: Config builder with YAML fallback ---

def build_field_on_config() -> Config:
    """Build the Field ON config: shared env/evolution/training setup plus field + nest params.

    Field and nest parameters are read from the YAML file at
    ``BEST_CONFIG_PATH`` when that file exists. Otherwise the hard-coded
    v1 sweep values are applied and a warning is printed.

    YAML handling:
      * List values are converted to tuples before being assigned.
      * Keys the config section does not have are skipped and reported.
      * A missing ``field`` or ``nest`` section is reported.

    Returns:
        A freshly constructed ``Config`` for the Field ON condition.

    Raises:
        ValueError: If the YAML file exists but its top level is not a
            mapping (for example, the file is empty).
    """

    def _apply_overrides(section, overrides, section_name):
        """Copy YAML overrides onto one config section and report skipped keys."""
        skipped = []
        for k, v in (overrides or {}).items():
            if isinstance(v, list):
                v = tuple(v)
            if hasattr(section, k):
                setattr(section, k, v)
            else:
                skipped.append(k)
        if skipped:
            print(f"  WARNING: ignored unknown {section_name} keys: {skipped}")

    cfg = Config()
    # Environment
    cfg.env.grid_size = 40
    cfg.env.num_agents = 16
    cfg.env.num_food = 25
    cfg.env.max_steps = 500
    # Evolution (survival-friendly)
    cfg.evolution.enabled = True
    cfg.evolution.food_energy = 100
    cfg.evolution.starting_energy = 200
    cfg.evolution.max_energy = 300
    cfg.evolution.reproduce_threshold = 180
    cfg.evolution.reproduce_cost = 80
    cfg.evolution.energy_per_step = 1
    cfg.evolution.max_agents = MAX_AGENTS
    cfg.evolution.mutation_std = 0.01
    # Training
    cfg.train.training_mode = TrainingMode.GRADIENT
    cfg.train.num_envs = NUM_ENVS
    cfg.train.num_steps = NUM_STEPS
    cfg.train.total_steps = TOTAL_STEPS
    cfg.log.wandb = False
    cfg.log.save_interval = 0
    # Field base
    cfg.field.num_channels = 4
    cfg.field.field_value_cap = 1.0

    # Field + Nest: load from sweep YAML or fall back to v1 values
    if os.path.exists(BEST_CONFIG_PATH):
        print(f"Loading best sweep config from {BEST_CONFIG_PATH}")
        with open(BEST_CONFIG_PATH, 'r') as f:
            best = yaml.safe_load(f)
        # safe_load returns None for an empty file. Fail loudly here rather
        # than start a long training run on an unintended configuration.
        if not isinstance(best, dict):
            raise ValueError(
                f"{BEST_CONFIG_PATH} must contain a YAML mapping with 'field' "
                f"and/or 'nest' sections, got {type(best).__name__}"
            )
        for section_name, section in (('field', cfg.field), ('nest', cfg.nest)):
            if section_name in best:
                _apply_overrides(section, best[section_name], section_name)
            else:
                print(f"  WARNING: no '{section_name}' section in sweep config; "
                      f"keeping Config defaults for it")
        print("  Loaded sweep-optimized field + nest config")
    else:
        print(f"WARNING: {BEST_CONFIG_PATH} not found.")
        print("  Using v1 sweep values. Re-run after v2 sweep completes for best results.")
        cfg.field.channel_diffusion_rates = (0.7, 0.01, 0.0, 0.0)
        cfg.field.channel_decay_rates = (0.08, 0.0001, 0.0, 0.0)
        cfg.field.territory_write_strength = 0.05
        cfg.nest.radius = 4
        cfg.nest.compass_noise_rate = 0.15

    return cfg


def build_field_off_config() -> Config:
    """Return the Field OFF config: the Field ON config with three field settings overridden.

    Starts from ``build_field_on_config()`` so every other parameter
    (environment, evolution, training, nest) is shared with Field ON.
    Only the per-channel diffusion rates, the per-channel decay rates and
    the territory write strength are replaced.
    """
    # Diffusion 0 and write strength 0 on every channel; decay 1.0 is
    # presumably "clear the channel every step" -- confirm against the field code.
    field_overrides = {
        'channel_diffusion_rates': (0.0, 0.0, 0.0, 0.0),
        'channel_decay_rates': (1.0, 1.0, 1.0, 1.0),
        'territory_write_strength': 0.0,
    }
    off_cfg = build_field_on_config()
    for attr_name, attr_value in field_overrides.items():
        setattr(off_cfg.field, attr_name, attr_value)
    return off_cfg


# Print config summary showing both conditions share non-field params
print("=" * 70)
print("CONFIG SUMMARY")
print("=" * 70)
# Both builders run here, so their "Loading..." / "WARNING..." messages
# appear once per call before the table.
cfg_on = build_field_on_config()
cfg_off = build_field_off_config()

print(f"\n{'Parameter':<40} {'Field ON':>15} {'Field OFF':>15} {'Match?':>8}")
print("-" * 80)
# Shared params (should match)
# Each row is (label, Field ON value, Field OFF value); a "NO!" in the last
# column means the two conditions differ on a parameter meant to be shared.
for label, on_val, off_val in [
    ("env.grid_size", cfg_on.env.grid_size, cfg_off.env.grid_size),
    ("env.num_agents", cfg_on.env.num_agents, cfg_off.env.num_agents),
    ("env.num_food", cfg_on.env.num_food, cfg_off.env.num_food),
    ("evolution.max_agents", cfg_on.evolution.max_agents, cfg_off.evolution.max_agents),
    ("evolution.food_energy", cfg_on.evolution.food_energy, cfg_off.evolution.food_energy),
    ("evolution.reproduce_threshold", cfg_on.evolution.reproduce_threshold, cfg_off.evolution.reproduce_threshold),
    ("nest.radius", cfg_on.nest.radius, cfg_off.nest.radius),
    ("nest.compass_noise_rate", cfg_on.nest.compass_noise_rate, cfg_off.nest.compass_noise_rate),
    ("nest.food_sip_fraction", cfg_on.nest.food_sip_fraction, cfg_off.nest.food_sip_fraction),
]:
    match = "YES" if on_val == off_val else "NO!"
    print(f"  {label:<38} {str(on_val):>15} {str(off_val):>15} {match:>8}")

# Field params (should differ)
# These are the three settings build_field_off_config() overrides.
print("\nField-specific params (should differ):")
for label, on_val, off_val in [
    ("field.channel_diffusion_rates", cfg_on.field.channel_diffusion_rates, cfg_off.field.channel_diffusion_rates),
    ("field.channel_decay_rates", cfg_on.field.channel_decay_rates, cfg_off.field.channel_decay_rates),
    ("field.territory_write_strength", cfg_on.field.territory_write_strength, cfg_off.field.territory_write_strength),
]:
    print(f"  {label:<38} {str(on_val):>15} {str(off_val):>15}")

del cfg_on, cfg_off  # Free memory


# --- Part B: run_eval() function ---

def run_eval(network, params, config, key, num_steps=500):
    """Roll out one evaluation episode with ``jax.lax.scan`` and summarise it.

    The environment is reset once and then stepped exactly ``num_steps``
    times; the ``done`` flag returned by ``step`` is not consulted.

    Args:
        network: Policy network passed to ``sample_actions``.
        params: Parameters for ``network``.
        config: Experiment config used by reset/step/observation code.
        key: JAX PRNG key; split into a reset key and a rollout key.
        num_steps: Number of environment steps to run.

    Returns:
        Dict with ``total_reward`` (float), ``final_population`` (int),
        ``trail_strength`` (float) and ``survival_rate`` (float, final
        population divided by ``config.env.num_agents``).
    """
    rollout_key, env_key = jax.random.split(key)
    start_state = reset(env_key, config)

    def _advance(carry, _):
        env_state, step_key, reward_so_far = carry
        # Add a leading axis of size 1 before sampling, then drop it again.
        batched_obs = get_observations(env_state, config)[None, :, :]
        step_key, sample_key = jax.random.split(step_key)
        batched_actions, _, _, _ = sample_actions(network, params, batched_obs, sample_key)
        env_state, step_rewards, _done, _info = step(env_state, batched_actions[0], config)
        # Only agents alive *after* this step contribute to the running total.
        alive_mask = env_state.agent_alive.astype(jnp.float32)
        reward_so_far = reward_so_far + jnp.sum(step_rewards * alive_mask)
        return (env_state, step_key, reward_so_far), None

    initial_carry = (start_state, rollout_key, jnp.float32(0.0))
    (end_state, _, episode_reward), _ = jax.lax.scan(
        _advance, initial_carry, None, length=num_steps,
    )

    # Trail strength: mean of channel-0 values over cells above 0.01,
    # or 0.0 when no cell exceeds that threshold.
    channel0 = jnp.asarray(end_state.field_state.values[:, :, 0])
    trail_cells = channel0 > 0.01
    trail_total = jnp.sum(jnp.where(trail_cells, channel0, 0.0))
    trail_count = jnp.maximum(jnp.sum(trail_cells.astype(jnp.float32)), 1.0)
    trail_strength = jnp.where(jnp.any(trail_cells), trail_total / trail_count, 0.0)

    survivors = jnp.sum(end_state.agent_alive.astype(jnp.int32))

    summary = {}
    summary['total_reward'] = float(episode_reward)
    summary['final_population'] = int(survivors)
    summary['trail_strength'] = float(trail_strength)
    summary['survival_rate'] = float(survivors) / config.env.num_agents
    return summary


# --- Part C: eval_after_batch() helper ---

def eval_after_batch(trainer, config, seed_ids, num_episodes=NUM_EVAL_EPISODES):
    """Evaluate every seed of a finished training batch and average its episodes.

    Per-seed parameters are sliced out of ``trainer._parallel_state.params``
    by position: the i-th entry of ``seed_ids`` is paired with index ``i``
    along the leading axis of every parameter leaf. This is internal API;
    ParallelTrainer offers no public accessor for per-seed params.

    Args:
        trainer: Trainer whose ``_parallel_state.params`` holds the stacked params.
        config: Config used to build the network and run the episodes.
        seed_ids: Seed identifiers, in the same order as the stacked params.
        num_episodes: Eval episodes per seed.

    Returns:
        One dict per seed with ``seed_id``, the per-seed mean of each metric,
        and the raw per-episode results under ``all_episodes``.
    """
    policy_net = ActorCritic(
        hidden_dims=tuple(config.agent.hidden_dims),
        num_actions=config.agent.num_actions,
    )
    metric_names = ('total_reward', 'final_population', 'trail_strength', 'survival_rate')
    summaries = []
    for slot, seed_id in enumerate(seed_ids):
        # Internal access: pick this seed's slice from every stacked leaf.
        per_seed_params = jax.tree.map(
            lambda leaf: leaf[slot], trainer._parallel_state.params,
        )
        # Eval keys depend only on (seed_id, episode index).
        episodes = [
            run_eval(policy_net, per_seed_params, config,
                     jax.random.PRNGKey(seed_id * 1000 + ep))
            for ep in range(num_episodes)
        ]
        summary = {'seed_id': seed_id}
        for metric in metric_names:
            summary[metric] = float(np.mean([episode[metric] for episode in episodes]))
        summary['all_episodes'] = episodes
        summaries.append(summary)
        print(f"    Seed {seed_id}: reward={summary['total_reward']:.1f}, "
              f"pop={summary['final_population']:.1f}, "
              f"trail={summary['trail_strength']:.3f}")
    return summaries


# --- Part D: Early sanity check helper ---

def sanity_check_batch(seed_evals, condition_name, batch_number):
    """Print loud warnings if the first batch (batch 0) looks broken.

    Only batch 0 is checked; every other batch returns immediately without
    printing. For batch 0:
      * an empty ``seed_evals`` produces a warning (there is nothing to average);
      * each seed with 0 population AND 0 reward gets its own warning;
      * a batch-average population below 2 produces a warning, otherwise
        an "OK" line with the averages is printed.

    Args:
        seed_evals: Per-seed dicts with ``seed_id``, ``final_population``
            and ``total_reward``.
        condition_name: Label used in the printed messages.
        batch_number: Index of the batch that just finished.

    Returns:
        None. All output goes to stdout.
    """
    if batch_number != 0:
        return
    if len(seed_evals) == 0:
        # Averaging an empty list would print a misleading "OK ... nan" line.
        print("\n" + "!" * 70)
        print(f"WARNING: {condition_name} batch 0 returned no seed evals; "
              f"nothing to sanity-check.")
        print("!" * 70 + "\n")
        return
    for s in seed_evals:
        if s['final_population'] == 0 and s['total_reward'] == 0:
            print("\n" + "!" * 70)
            print(f"WARNING: {condition_name} seed {s['seed_id']} has 0 population AND 0 reward!")
            print("Training may be broken. Check config and environment setup.")
            print("!" * 70 + "\n")
    avg_pop = np.mean([s['final_population'] for s in seed_evals])
    avg_rew = np.mean([s['total_reward'] for s in seed_evals])
    if avg_pop < 2:
        print("\n" + "!" * 70)
        print(f"WARNING: {condition_name} batch 0 avg population = {avg_pop:.1f}")
        print("Population near-zero. Consider: increase food_energy, decrease energy_per_step")
        print("!" * 70 + "\n")
    else:
        print(f"  Sanity check OK: batch 0 avg pop={avg_pop:.1f}, avg reward={avg_rew:.1f}")


print("\nConfig builders + eval functions ready.")

## Phase 1: Train Field ON (30 seeds, 10M steps) + Eval

10 batches x 3 seeds. Each batch trains to completion, then runs 5 eval episodes per seed.
Resume-safe: batches that already have a successful entry in the results pickle are skipped; failed batches are retried on the next run.

In [None]:
# ========== FIELD ON: TRAINING + EVAL ==========
# Load any results saved by a previous session so finished batches can be skipped.
RESULTS_PATH_ON = os.path.join(FIELD_ON_DIR, 'eval_results.pkl')
if os.path.exists(RESULTS_PATH_ON):
    with open(RESULTS_PATH_ON, 'rb') as f:
        all_results_on = pickle.load(f)
else:
    all_results_on = []

# Only successful batches count as completed; failed ones are retried.
# NOTE(review): failure records from earlier attempts stay in the list, so a
# retried batch can have several entries (old failures plus the new result).
completed_batches_on = {r['batch'] for r in all_results_on if r.get('success')}
config_on = build_field_on_config()

for batch_number in range(TOTAL_BATCHES):
    if batch_number in completed_batches_on:
        print(f"[Batch {batch_number}] SKIPPED (already completed)")
        continue

    # Batch b owns the consecutive seed ids [b * SEEDS_PER_BATCH, (b + 1) * SEEDS_PER_BATCH).
    seed_ids = list(range(batch_number * SEEDS_PER_BATCH,
                          (batch_number + 1) * SEEDS_PER_BATCH))
    checkpoint_dir = os.path.join(FIELD_ON_DIR, f'batch_{batch_number}')

    print(f"\n{'='*60}")
    print(f"[Batch {batch_number}] Training Field ON seeds {seed_ids}")
    print(f"{'='*60}")

    try:
        # master_seed depends only on the batch number.
        trainer = ParallelTrainer(
            config=config_on, num_seeds=SEEDS_PER_BATCH,
            seed_ids=seed_ids, checkpoint_dir=checkpoint_dir,
            master_seed=42 + batch_number * 1000,
        )
        # resume=True presumably continues from checkpoints in checkpoint_dir --
        # confirm against ParallelTrainer.train.
        train_metrics = trainer.train(
            num_iterations=NUM_ITERATIONS,
            checkpoint_interval_minutes=30,
            resume=True, print_interval=5,
        )
        # Run proper eval episodes (not training metrics)
        print(f"  Running {NUM_EVAL_EPISODES} eval episodes per seed...")
        seed_evals = eval_after_batch(trainer, config_on, seed_ids)
        sanity_check_batch(seed_evals, "Field ON", batch_number)

        all_results_on.append({
            'batch': batch_number, 'seed_ids': seed_ids,
            'eval': seed_evals,
            'train_metrics': train_metrics,
            'success': True,
        })
    except Exception as e:
        # Record the failure and move on to the next batch instead of aborting the run.
        traceback.print_exc()
        all_results_on.append({
            'batch': batch_number, 'seed_ids': seed_ids,
            'error': str(e), 'success': False,
        })
    finally:
        # `trainer` is unbound if the constructor itself raised, hence the NameError guard.
        try:
            del trainer
        except NameError:
            pass
        gc.collect()
        jax.clear_caches()

    # Atomic save after each batch
    # Write to a temp file first, then swap it into place with os.replace.
    tmp = RESULTS_PATH_ON + '.tmp'
    with open(tmp, 'wb') as f:
        pickle.dump(all_results_on, f, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp, RESULTS_PATH_ON)

print(f"\nField ON: {len([r for r in all_results_on if r.get('success')])} / "
      f"{TOTAL_BATCHES} batches completed")

## Phase 2: Train Field OFF (30 seeds, 10M steps) + Eval

Same structure as Phase 1 but with field channels disabled.
All other params (env, evolution, nest) are identical to Field ON.

In [None]:
# ========== FIELD OFF: TRAINING + EVAL ==========
# Load any results saved by a previous session so finished batches can be skipped.
RESULTS_PATH_OFF = os.path.join(FIELD_OFF_DIR, 'eval_results.pkl')
if os.path.exists(RESULTS_PATH_OFF):
    with open(RESULTS_PATH_OFF, 'rb') as f:
        all_results_off = pickle.load(f)
else:
    all_results_off = []

# Only successful batches count as completed; failed ones are retried.
# NOTE(review): failure records from earlier attempts stay in the list, so a
# retried batch can have several entries (old failures plus the new result).
completed_batches_off = {r['batch'] for r in all_results_off if r.get('success')}
config_off = build_field_off_config()

for batch_number in range(TOTAL_BATCHES):
    if batch_number in completed_batches_off:
        print(f"[Batch {batch_number}] SKIPPED (already completed)")
        continue

    # Batch b owns the consecutive seed ids [b * SEEDS_PER_BATCH, (b + 1) * SEEDS_PER_BATCH).
    seed_ids = list(range(batch_number * SEEDS_PER_BATCH,
                          (batch_number + 1) * SEEDS_PER_BATCH))
    checkpoint_dir = os.path.join(FIELD_OFF_DIR, f'batch_{batch_number}')

    print(f"\n{'='*60}")
    print(f"[Batch {batch_number}] Training Field OFF seeds {seed_ids}")
    print(f"{'='*60}")

    try:
        # master_seed depends only on the batch number.
        trainer = ParallelTrainer(
            config=config_off, num_seeds=SEEDS_PER_BATCH,
            seed_ids=seed_ids, checkpoint_dir=checkpoint_dir,
            master_seed=42 + batch_number * 1000,
        )
        # resume=True presumably continues from checkpoints in checkpoint_dir --
        # confirm against ParallelTrainer.train.
        train_metrics = trainer.train(
            num_iterations=NUM_ITERATIONS,
            checkpoint_interval_minutes=30,
            resume=True, print_interval=5,
        )
        # Run proper eval episodes (not training metrics)
        print(f"  Running {NUM_EVAL_EPISODES} eval episodes per seed...")
        seed_evals = eval_after_batch(trainer, config_off, seed_ids)
        sanity_check_batch(seed_evals, "Field OFF", batch_number)

        all_results_off.append({
            'batch': batch_number, 'seed_ids': seed_ids,
            'eval': seed_evals,
            'train_metrics': train_metrics,
            'success': True,
        })
    except Exception as e:
        # Record the failure and move on to the next batch instead of aborting the run.
        traceback.print_exc()
        all_results_off.append({
            'batch': batch_number, 'seed_ids': seed_ids,
            'error': str(e), 'success': False,
        })
    finally:
        # `trainer` is unbound if the constructor itself raised, hence the NameError guard.
        try:
            del trainer
        except NameError:
            pass
        gc.collect()
        jax.clear_caches()

    # Atomic save after each batch
    # Write to a temp file first, then swap it into place with os.replace.
    tmp = RESULTS_PATH_OFF + '.tmp'
    with open(tmp, 'wb') as f:
        pickle.dump(all_results_off, f, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp, RESULTS_PATH_OFF)

print(f"\nField OFF: {len([r for r in all_results_off if r.get('success')])} / "
      f"{TOTAL_BATCHES} batches completed")

In [None]:
# ========== EXTRACT EVAL METRICS INTO ARRAYS ==========
FAILED_THRESHOLD = 1.0  # Seeds with reward below this are considered "failed"

def extract_eval_arrays(all_results):
    """Flatten successful batch records into one NumPy array per metric.

    Records are visited in ascending ``batch`` order; records without a
    truthy ``success`` flag are skipped. Within a batch, seeds keep the
    order they have in the record's ``eval`` list.

    Args:
        all_results: Batch records, each a dict with ``batch`` and, when
            successful, ``eval`` (a list of per-seed metric dicts).

    Returns:
        Tuple ``(rewards, populations, trail_strengths, survival_rates)``
        of 1-D arrays with one entry per seed.
    """
    ordered_records = sorted(all_results, key=lambda rec: rec['batch'])
    seed_rows = [
        seed_row
        for rec in ordered_records
        if rec.get('success')
        for seed_row in rec['eval']
    ]
    metric_keys = ('total_reward', 'final_population', 'trail_strength', 'survival_rate')
    return tuple(
        np.array([seed_row[metric] for seed_row in seed_rows])
        for metric in metric_keys
    )

# Per-seed metric arrays for each condition, ordered by batch.
field_on_rewards, field_on_populations, field_on_trails, field_on_survival = \
    extract_eval_arrays(all_results_on)
field_off_rewards, field_off_populations, field_off_trails, field_off_survival = \
    extract_eval_arrays(all_results_off)

# Alias for downstream analysis cells
field_off_populations_train = field_off_populations

# Every batch must have succeeded in both conditions. The expected count is
# derived from the training plan so it cannot drift from the constants.
expected_seeds = TOTAL_BATCHES * SEEDS_PER_BATCH
assert len(field_on_rewards) == expected_seeds, \
    f"Expected {expected_seeds} Field ON seeds, got {len(field_on_rewards)}"
assert len(field_off_rewards) == expected_seeds, \
    f"Expected {expected_seeds} Field OFF seeds, got {len(field_off_rewards)}"

# Print summary table
print("=" * 80)
print("EVAL METRICS SUMMARY")
print("=" * 80)
print(f"\n{'Metric':<25} {'Field ON':>25} {'Field OFF':>25}")
print("-" * 80)
# Each row shows mean +/- population std (NumPy default, ddof=0) per condition.
for label, on_arr, off_arr in [
    ("Reward", field_on_rewards, field_off_rewards),
    ("Population", field_on_populations, field_off_populations),
    ("Trail Strength", field_on_trails, field_off_trails),
    ("Survival Rate", field_on_survival, field_off_survival),
]:
    print(f"  {label:<23} {on_arr.mean():>10.3f} +/- {on_arr.std():<10.3f} "
          f"{off_arr.mean():>10.3f} +/- {off_arr.std():<10.3f}")

print(f"\nTotal seeds: {len(field_on_rewards)} ON, {len(field_off_rewards)} OFF")

In [None]:
# ========== DESCRIPTIVE STATISTICS ==========
print("="*60)
print("DESCRIPTIVE STATISTICS")
print("="*60)

# Compute each condition's bootstrap IQM once and keep it for later cells,
# instead of running the same 10k-resample bootstrap a second time.
# (Assumes compute_iqm is deterministic for a fixed seed -- it takes seed=42.)
iqm_by_condition = {}
for name, rewards in [("Field ON", field_on_rewards), ("Field OFF", field_off_rewards)]:
    iqm = compute_iqm(rewards, n_bootstrap=10000, seed=42)
    iqm_by_condition[name] = iqm
    print(f"\n{name} (n={len(rewards)}):")
    # Spread is reported as the sample standard deviation (ddof=1).
    print(f"  Mean:   {rewards.mean():.4f} +/- {rewards.std(ddof=1):.4f}")
    print(f"  Median: {np.median(rewards):.4f}")
    print(f"  IQM:    {iqm.iqm:.4f} [{iqm.ci_lower:.4f}, {iqm.ci_upper:.4f}]")
    print(f"  Min:    {rewards.min():.4f}")
    print(f"  Max:    {rewards.max():.4f}")
    # Coefficient of variation: sample std divided by mean.
    print(f"  CoV:    {rewards.std(ddof=1)/rewards.mean():.4f}")

# Store IQMs for later cells (the comparison plots read iqm_on / iqm_off)
iqm_on = iqm_by_condition["Field ON"]
iqm_off = iqm_by_condition["Field OFF"]

In [None]:
# ========== HYPOTHESIS TESTS ==========
print("="*60)
print("HYPOTHESIS TESTS")
print("="*60)

# Welch's t-test
# Field OFF is passed first and Field ON second in every test of this cell.
# NOTE(review): the sign of t / Cohen's d presumably follows that order
# (OFF minus ON) -- confirm in src.analysis.statistics.
welch = welch_t_test(field_off_rewards, field_on_rewards)
print(f"\n1. Welch's t-test:")
print(f"   t = {welch.statistic:.4f}, p = {welch.p_value:.6f}")
print(f"   Cohen's d = {welch.effect_size:.4f}", end="")
# Label the effect-size magnitude with cut-offs at 0.2 / 0.5 / 0.8.
d = abs(welch.effect_size)
if d < 0.2: print(" (negligible)")
elif d < 0.5: print(" (small)")
elif d < 0.8: print(" (MEDIUM)")
else: print(" (LARGE)")
print(f"   Significant at alpha=0.05: {welch.significant}")

# Mann-Whitney U
mw = mann_whitney_test(field_off_rewards, field_on_rewards)
print(f"\n2. Mann-Whitney U test:")
print(f"   U = {mw.statistic:.1f}, p = {mw.p_value:.6f}")
print(f"   Rank-biserial r = {mw.effect_size:.4f}")
print(f"   Significant at alpha=0.05: {mw.significant}")

# Probability of Improvement
# The labels below treat 'prob_x_better' as Field OFF (first argument) and
# 'prob_y_better' as Field ON (second argument).
poi = probability_of_improvement(field_off_rewards, field_on_rewards, n_bootstrap=5000, seed=42)
print(f"\n3. Probability of Improvement:")
print(f"   P(Field OFF > Field ON) = {poi['prob_x_better']:.4f}")
print(f"   P(Field ON > Field OFF) = {poi['prob_y_better']:.4f}")
print(f"   95% CI: [{poi['ci_lower']:.4f}, {poi['ci_upper']:.4f}]")

# Direction
print(f"\n4. Direction:")
print(f"   Field OFF mean: {field_off_rewards.mean():.4f}")
print(f"   Field ON mean:  {field_on_rewards.mean():.4f}")
print(f"   Gap: {field_off_rewards.mean() - field_on_rewards.mean():+.4f} (Field OFF {'higher' if field_off_rewards.mean() > field_on_rewards.mean() else 'lower'})")

# Sensitivity: exclude failed seeds (using consistent threshold)
# NOTE(review): only Field ON is filtered; Field OFF is compared unfiltered.
# Confirm this asymmetry is intended. If fewer than two Field ON seeds pass
# the threshold, the test below has too little data to be meaningful.
mask_on = field_on_rewards >= FAILED_THRESHOLD
on_filtered = field_on_rewards[mask_on]
welch_f = welch_t_test(field_off_rewards, on_filtered)
print(f"\n5. Sensitivity (excluding {np.sum(~mask_on)} failed Field ON seeds, threshold={FAILED_THRESHOLD}):")
print(f"   Field ON filtered: n={len(on_filtered)}, mean={on_filtered.mean():.4f}")
print(f"   Welch p = {welch_f.p_value:.6f}, Cohen's d = {welch_f.effect_size:.4f}")

In [None]:
# ========== FULL METHOD COMPARISON ==========
print("="*60)
print("FULL METHOD COMPARISON (rliable-style)")
print("="*60)

# Compare both conditions on per-seed eval reward with a seeded bootstrap,
# then print the summary produced by compare_methods.
comparison = compare_methods(
    {"Field ON": field_on_rewards, "Field OFF": field_off_rewards},
    n_bootstrap=10000, seed=42,
)
print(comparison.summary)

## Trail Strength & Survival Analysis

Trail strength: Field ON should show actual pheromone trails; OFF should be ~0.
Survival rate: final population / starting agents. Because agents can reproduce (up to `max_agents`), this ratio can exceed 1.

In [None]:
# ========== TRAIL STRENGTH & SURVIVAL ANALYSIS ==========
setup_publication_style()
# colors[0] is used for Field ON, colors[1] for Field OFF.
colors = ['#009988', '#BBBBBB']

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (a) Trail strength distribution
# Overlaid per-seed histograms; dashed vertical lines mark each condition's mean.
ax = axes[0]
ax.hist(field_on_trails, bins=15, alpha=0.6, color=colors[0],
        label='Field ON', edgecolor='black')
ax.hist(field_off_trails, bins=15, alpha=0.6, color=colors[1],
        label='Field OFF', edgecolor='black')
ax.axvline(field_on_trails.mean(), color=colors[0], linestyle='--', linewidth=2)
ax.axvline(field_off_trails.mean(), color=colors[1], linestyle='--', linewidth=2)
ax.set_xlabel('Trail Strength (mean Ch0 value in non-zero cells)')
ax.set_ylabel('Count')
ax.set_title('(a) Trail Strength Distribution')
ax.legend()

# (b) Survival rate distribution
# Same layout as panel (a), for survival rate.
ax = axes[1]
ax.hist(field_on_survival, bins=15, alpha=0.6, color=colors[0],
        label='Field ON', edgecolor='black')
ax.hist(field_off_survival, bins=15, alpha=0.6, color=colors[1],
        label='Field OFF', edgecolor='black')
ax.axvline(field_on_survival.mean(), color=colors[0], linestyle='--', linewidth=2)
ax.axvline(field_off_survival.mean(), color=colors[1], linestyle='--', linewidth=2)
ax.set_xlabel('Survival Rate (final pop / starting agents)')
ax.set_ylabel('Count')
ax.set_title('(b) Survival Rate Distribution')
ax.legend()

plt.suptitle('Trail Strength & Survival: Field ON vs OFF', fontsize=14, y=1.02)
plt.tight_layout()
save_figure(fig, os.path.join(OUTPUT_DIR, 'trail_survival_comparison'))
plt.show()

# Statistical tests
# Field ON is the first argument here (the reward tests pass Field OFF first).
trail_welch = welch_t_test(field_on_trails, field_off_trails)
surv_welch = welch_t_test(field_on_survival, field_off_survival)

# Spreads below use NumPy's default std (ddof=0).
print("Trail Strength Comparison:")
print(f"  Field ON:  {field_on_trails.mean():.4f} +/- {field_on_trails.std():.4f}")
print(f"  Field OFF: {field_off_trails.mean():.4f} +/- {field_off_trails.std():.4f}")
print(f"  Welch t={trail_welch.statistic:.4f}, p={trail_welch.p_value:.6f}, "
      f"d={trail_welch.effect_size:.4f}")

print(f"\nSurvival Rate Comparison:")
print(f"  Field ON:  {field_on_survival.mean():.4f} +/- {field_on_survival.std():.4f}")
print(f"  Field OFF: {field_off_survival.mean():.4f} +/- {field_off_survival.std():.4f}")
print(f"  Welch t={surv_welch.statistic:.4f}, p={surv_welch.p_value:.6f}, "
      f"d={surv_welch.effect_size:.4f}")

## Comparison Plots

In [None]:
# ========== COMPARISON PLOTS ==========
# Three-panel figure: (a) IQM with CI, (b) per-seed reward distribution,
# (c) final-population histograms. Reads iqm_on / iqm_off and welch from
# the statistics cells above.
setup_publication_style()

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# 1. Bar chart: IQM with CI
ax = axes[0]
methods = ['Field ON\n(Stigmergy)', 'Field OFF\n(No Field)']
iqm_vals = [iqm_on.iqm, iqm_off.iqm]
# Asymmetric error bars: distance from the IQM down to / up to the CI bounds.
iqm_lo = [iqm_on.iqm - iqm_on.ci_lower, iqm_off.iqm - iqm_off.ci_lower]
iqm_hi = [iqm_on.ci_upper - iqm_on.iqm, iqm_off.ci_upper - iqm_off.iqm]
colors = ['#009988', '#BBBBBB']

bars = ax.bar(methods, iqm_vals, yerr=[iqm_lo, iqm_hi], color=colors,
              edgecolor='black', linewidth=0.5, capsize=8)
for bar, val in zip(bars, iqm_vals):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(iqm_hi) + 0.05,
            f'{val:.3f}', ha='center', fontsize=10)
ax.set_ylabel('IQM Reward')
ax.set_title('(a) IQM + 95% CI')
sig = f"p = {welch.p_value:.4f}"
if welch.p_value < 0.001: sig += " ***"
elif welch.p_value < 0.01: sig += " **"
elif welch.p_value < 0.05: sig += " *"
# get_xaxis_transform() reads x in data coordinates and y as a fraction of
# the axes height. Passing a data-scale y here would place the label far
# outside the plot, so y is given as an axes fraction just below the top.
# x=0.5 is the midpoint between the two bars (categorical positions 0 and 1).
ax.text(0.5, 0.97, sig, ha='center', va='top', fontsize=11,
        transform=ax.get_xaxis_transform())

# 2. Violin + swarm
ax = axes[1]
parts = ax.violinplot([field_on_rewards, field_off_rewards], positions=[1, 2],
                       showmeans=True, showmedians=True, showextrema=False)
for i, body in enumerate(parts['bodies']):
    body.set_facecolor(colors[i])
    body.set_alpha(0.4)

# Seeded horizontal jitter so the per-seed points look the same on every run.
rng = np.random.default_rng(42)
for i, (data, pos) in enumerate(zip([field_on_rewards, field_off_rewards], [1, 2])):
    jitter = rng.normal(0, 0.05, size=len(data))
    ax.scatter(np.full_like(data, pos) + jitter, data, alpha=0.6, s=25,
               color=colors[i], edgecolor='black', linewidth=0.3, zorder=3)

ax.scatter([1], [iqm_on.iqm], marker='D', s=80, color='red', zorder=5, label='IQM')
ax.scatter([2], [iqm_off.iqm], marker='D', s=80, color='red', zorder=5)
ax.set_xticks([1, 2])
ax.set_xticklabels(['Field ON', 'Field OFF'])
ax.set_ylabel('Mean Reward')
ax.set_title('(b) Distribution (all 60 seeds)')
ax.legend(fontsize=9)

# 3. Population comparison (eval episodes)
# Both arrays come from the per-seed eval results, not from training metrics,
# so the axis is labelled accordingly.
ax = axes[2]
ax.hist(field_on_populations, bins=15, alpha=0.6, color=colors[0],
        label='Field ON', edgecolor='black')
ax.hist(field_off_populations_train, bins=15, alpha=0.6, color=colors[1],
        label='Field OFF', edgecolor='black')
ax.set_xlabel('Final Population (eval)')
ax.set_ylabel('Count')
ax.set_title('(c) Population Distribution')
# The capacity line follows the configured agent cap rather than a literal.
ax.axvline(x=MAX_AGENTS, color='red', linestyle='--', alpha=0.5, label='Max capacity')
ax.legend(fontsize=9)

plt.suptitle('Field ON vs Field OFF: 30-Seed Comparison', fontsize=14, y=1.02)
plt.tight_layout()
save_figure(fig, os.path.join(OUTPUT_DIR, 'main_comparison'))
plt.show()

In [None]:
# ========== PERFORMANCE PROFILES ==========
# Delegates plotting to plot_performance_profiles, which is also given the
# output path for the figure.
# NOTE(review): tau_range=(0, 1.05) is a fixed threshold range -- confirm it
# covers the reward scale of this experiment.
setup_publication_style()
fig = plot_performance_profiles(
    {"Field ON (Stigmergy)": field_on_rewards, "Field OFF (No Field)": field_off_rewards},
    output_path=os.path.join(OUTPUT_DIR, 'performance_profiles'),
    tau_range=(0, 1.05),
)
plt.show()

In [None]:
# ========== REWARD vs POPULATION ==========
# One scatter panel per condition: final population (x) against reward (y),
# one point per seed.
setup_publication_style()
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Field ON
ax = axes[0]
ax.scatter(field_on_populations, field_on_rewards, color='#009988', label='Field ON',
           s=50, alpha=0.7, edgecolor='black', linewidth=0.5)
ax.set_xlabel('Final Population')
ax.set_ylabel('Mean Reward')
ax.set_title('(a) Field ON: Reward vs Population')

# Annotate failed seeds
# `idx` is the position in the extracted arrays; it is used as the seed label.
failed_idx = np.where(field_on_rewards < FAILED_THRESHOLD)[0]
for idx in failed_idx:
    ax.annotate(f'Seed {idx}\n(died)', xy=(field_on_populations[idx], field_on_rewards[idx]),
                fontsize=8, color='red', arrowprops=dict(arrowstyle='->', color='red'),
                xytext=(field_on_populations[idx]+5, field_on_rewards[idx]+1))

# Field OFF
# NOTE(review): field_off_populations_train is an alias of the eval
# populations array set in the extraction cell, so the `is not None` checks
# in this cell look always-true and the placeholder branch looks unreachable.
ax = axes[1]
if field_off_populations_train is not None:
    ax.scatter(field_off_populations_train, field_off_rewards, color='#BBBBBB',
               label='Field OFF', s=50, alpha=0.7, edgecolor='black', linewidth=0.5)
    ax.set_xlabel('Final Population')
    ax.set_ylabel('Mean Reward')
    ax.set_title('(b) Field OFF: Reward vs Population')
else:
    ax.text(0.5, 0.5, 'Field OFF populations\navailable after eval (Phase 3)',
            ha='center', va='center', transform=ax.transAxes, fontsize=12)
    ax.set_title('(b) Field OFF: Reward vs Population')

plt.tight_layout()
save_figure(fig, os.path.join(OUTPUT_DIR, 'reward_vs_population'))
plt.show()

# Correlation for Field ON (excluding failed seeds)
# Field OFF below is correlated over all seeds, without this filter.
from scipy import stats as scipy_stats
mask = field_on_rewards >= FAILED_THRESHOLD
r, p = scipy_stats.pearsonr(field_on_populations[mask], field_on_rewards[mask])
print(f"Field ON correlation (pop vs reward, excl. failed): r={r:.3f}, p={p:.6f}")

if field_off_populations_train is not None:
    r2, p2 = scipy_stats.pearsonr(field_off_populations_train, field_off_rewards)
    print(f"Field OFF correlation (pop vs reward): r={r2:.3f}, p={p2:.6f}")

## Phase 3: Checkpoint Analysis (BOTH Conditions)

Load ALL 60 checkpoints from Drive, run eval episodes and compute weight divergence for both Field ON and Field OFF.

In [None]:
# ========== LOAD ALL CHECKPOINTS (BOTH CONDITIONS) ==========
import glob as glob_mod
from src.training.checkpointing import load_checkpoint
from src.analysis.ablation import _run_episode_full

def discover_checkpoints(drive_dir, condition_name, num_batches=10):
    """Find the latest checkpoint of every seed directory under ``drive_dir``.

    Scans ``<drive_dir>/batch_<i>/<seed_dir>/step_*.pkl`` for
    ``i in range(num_batches)``. For each seed directory, the file with the
    largest numeric step in its *file name* is kept. Files whose name has no
    digits after ``step_`` (for example ``step_final.pkl``) are ignored.

    Args:
        drive_dir: Directory containing the ``batch_<i>`` sub-directories.
        condition_name: Label used in the printed summary line.
        num_batches: Number of ``batch_<i>`` directories to look for.

    Returns:
        List of checkpoint paths, ordered by batch index and then by sorted
        seed-directory name. Seed directories without a numbered checkpoint
        contribute nothing.
    """
    # Match on the base name only, so digits elsewhere in the path
    # (e.g. a parent directory called "step_3") cannot be mistaken for the step.
    step_pattern = re.compile(r'step_(\d+)')
    paths = []
    for batch_idx in range(num_batches):
        batch_dir = os.path.join(drive_dir, f'batch_{batch_idx}')
        if not os.path.isdir(batch_dir):
            continue
        for seed_dir in sorted(os.listdir(batch_dir)):
            seed_path = os.path.join(batch_dir, seed_dir)
            if not os.path.isdir(seed_path):
                continue
            numbered = []
            for candidate in glob_mod.glob(os.path.join(seed_path, 'step_*.pkl')):
                step_match = step_pattern.match(os.path.basename(candidate))
                if step_match:
                    numbered.append((int(step_match.group(1)), candidate))
            if numbered:
                # Compare step numbers as integers, so step_100 beats step_20.
                paths.append(max(numbered, key=lambda item: item[0])[1])
    print(f"  {condition_name}: Found {len(paths)} checkpoints on Drive")
    return paths

# Collect one checkpoint path per seed directory for each condition.
field_on_ckpt_paths = discover_checkpoints(FIELD_ON_DIR, "Field ON")
field_off_ckpt_paths = discover_checkpoints(FIELD_OFF_DIR, "Field OFF")

def load_seed_data(ckpt_path):
    """Load one checkpoint and bundle what eval and divergence analysis need.

    Args:
        ckpt_path: Path of the checkpoint file passed to ``load_checkpoint``.

    Returns:
        Dict with keys ``params``, ``agent_params``, ``config``, ``network``
        and ``seed_id`` (``-1`` when the checkpoint stores no seed id).
    """
    checkpoint = load_checkpoint(ckpt_path)
    cfg = checkpoint['config']

    # Keep index 0 along the leading axis of every leaf of the agent-params
    # pytree. NOTE(review): presumably this drops a leading batch/env axis;
    # confirm against the checkpoint layout.
    first_slice = jax.tree_util.tree_map(
        lambda leaf: leaf[0], checkpoint['agent_params']
    )

    # Rebuild the policy network from the architecture stored in the config.
    policy_net = ActorCritic(
        hidden_dims=tuple(cfg.agent.hidden_dims),
        num_actions=cfg.agent.num_actions,
    )

    return dict(
        params=checkpoint['params'],
        agent_params=first_slice,
        config=cfg,
        network=policy_net,
        seed_id=checkpoint.get('seed_id', -1),
    )

# Report how many checkpoints were discovered, then require exactly 30 per
# condition before any evaluation starts.
print(f"Field ON checkpoints:  {len(field_on_ckpt_paths)}")
print(f"Field OFF checkpoints: {len(field_off_ckpt_paths)}")
for name, paths in (("Field ON", field_on_ckpt_paths), ("Field OFF", field_off_ckpt_paths)):
    assert len(paths) == 30, f"Expected 30 {name} checkpoints, got {len(paths)}"

# Smoke test: load the first checkpoint of each condition and echo the config
# values that distinguish the two setups.
for name, paths in (("Field ON", field_on_ckpt_paths), ("Field OFF", field_off_ckpt_paths)):
    test_data = load_seed_data(paths[0])
    cfg = test_data['config']
    field_cfg = cfg.field
    print(
        f"\n{name} config verification:",
        f"  seed={test_data['seed_id']}, grid={cfg.env.grid_size}, max_agents={cfg.evolution.max_agents}",
        f"  channel_diffusion={field_cfg.channel_diffusion_rates}, "
        f"channel_decay={field_cfg.channel_decay_rates}, "
        f"territory_write={field_cfg.territory_write_strength}",
        sep="\n",
    )

print("\nAll checkpoints verified. Ready for eval episodes.")

In [None]:
# ========== EVAL EPISODES + WEIGHT DIVERGENCE (BOTH CONDITIONS) ==========
# Single pass: load each checkpoint once, run eval AND compute divergence
# This avoids loading 60 checkpoints twice from Drive.
from src.analysis.specialization import compute_weight_divergence

NUM_EVAL_EPISODES = 5

# Per-condition accumulators; the loop below appends one record per checkpoint.
eval_results_on, eval_results_off = [], []
divergence_on, divergence_off = [], []

analysis_plan = (
    ("Field ON", field_on_ckpt_paths, eval_results_on, divergence_on),
    ("Field OFF", field_off_ckpt_paths, eval_results_off, divergence_off),
)

for cond_name, ckpt_paths, eval_list, div_list in analysis_plan:
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"ANALYZING {cond_name} ({len(ckpt_paths)} seeds)")
    print(banner)

    for i, ckpt_path in enumerate(ckpt_paths):
        seed_data = load_seed_data(ckpt_path)
        config = seed_data['config']
        seed_id = seed_data['seed_id']

        # --- Eval episodes ---
        seed_pops, seed_rewards, seed_births, seed_deaths = [], [], [], []
        for ep in range(NUM_EVAL_EPISODES):
            # The PRNG key depends only on the episode number and the
            # checkpoint index, so index i gets the same keys in both
            # conditions.
            stats = _run_episode_full(
                network=seed_data['network'],
                params=seed_data['params'],
                config=config,
                key=jax.random.PRNGKey(ep * 1000 + i),
                condition="normal",
                evolution=True,
            )
            seed_pops.append(stats.final_population)
            seed_rewards.append(stats.total_reward)
            seed_births.append(stats.total_births)
            seed_deaths.append(stats.total_deaths)

        mean_reward = np.mean(seed_rewards)
        mean_pop = np.mean(seed_pops)
        eval_list.append({
            'seed_id': seed_id,
            'ckpt_path': ckpt_path,
            'mean_total_reward': mean_reward,
            'std_total_reward': np.std(seed_rewards),
            'mean_population': mean_pop,
            'mean_births': np.mean(seed_births),
            'mean_deaths': np.mean(seed_deaths),
            # Mean final population relative to the configured starting count.
            'survival_rate': mean_pop / config.env.num_agents,
            'all_rewards': seed_rewards,
            'all_populations': seed_pops,
        })

        # --- Weight divergence ---
        div = compute_weight_divergence(seed_data['agent_params'])
        div_list.append({
            'seed_id': seed_id,
            'mean_divergence': float(div['mean_divergence']),
            'max_divergence': float(div['max_divergence']),
            'n_agents': len(div['agent_indices']),
        })

        # Drop the checkpoint before the next one is loaded.
        del seed_data
        gc.collect()

        # Progress line for the first seed and then every fifth one.
        if i == 0 or (i + 1) % 5 == 0:
            print(f"  [{i+1}/{len(ckpt_paths)}] seed {seed_id}: "
                  f"reward={mean_reward:.1f}, pop={mean_pop:.1f}, "
                  f"div={div['mean_divergence']:.4f}")

def _per_seed_array(records, field):
    """Stack one field of the per-seed record dicts into a NumPy array."""
    return np.array([record[field] for record in records])


# Per-seed eval means, in checkpoint order.
field_on_eval_populations = _per_seed_array(eval_results_on, 'mean_population')
field_off_eval_populations = _per_seed_array(eval_results_off, 'mean_population')
field_on_eval_rewards = _per_seed_array(eval_results_on, 'mean_total_reward')
field_off_eval_rewards = _per_seed_array(eval_results_off, 'mean_total_reward')

# Per-seed weight-divergence summaries, in the same order.
on_mean_divs = _per_seed_array(divergence_on, 'mean_divergence')
off_mean_divs = _per_seed_array(divergence_off, 'mean_divergence')
on_max_divs = _per_seed_array(divergence_on, 'max_divergence')
off_max_divs = _per_seed_array(divergence_off, 'max_divergence')

# Console summary: one block of aggregate statistics per condition.
summary_rows = (
    ("Field ON", eval_results_on, field_on_eval_populations,
     field_on_eval_rewards, on_mean_divs, on_max_divs),
    ("Field OFF", eval_results_off, field_off_eval_populations,
     field_off_eval_rewards, off_mean_divs, off_max_divs),
)
for cond_name, evals, pops, rewards, mean_d, max_d in summary_rows:
    # Seeds whose mean eval population is 60 or more count as "at max capacity".
    n_at_capacity = np.sum(pops >= 60)
    avg_births = np.mean([rec['mean_births'] for rec in evals])
    avg_deaths = np.mean([rec['mean_deaths'] for rec in evals])

    print(f"\n{'='*60}")
    print(f"{cond_name} SUMMARY ({len(evals)} seeds x {NUM_EVAL_EPISODES} episodes)")
    print(f"  Eval reward:       {rewards.mean():.1f} +/- {rewards.std():.1f}")
    print(f"  Eval population:   {pops.mean():.1f} +/- {pops.std():.1f}")
    print(f"  At max capacity:   {n_at_capacity}/{len(pops)}")
    print(f"  Mean births:       {avg_births:.1f}")
    print(f"  Mean deaths:       {avg_deaths:.1f}")
    print(f"  Mean divergence:   {mean_d.mean():.4f} +/- {mean_d.std():.4f}")
    print(f"  Max divergence:    {max_d.mean():.4f} +/- {max_d.std():.4f}")

In [None]:
# ========== PHASE 3 PLOTS + STATISTICAL COMPARISON ==========
setup_publication_style()
colors = ['#009988', '#BBBBBB']
cond_labels = ('Field ON', 'Field OFF')
hist_style = dict(alpha=0.6, edgecolor='black')
scatter_style = dict(s=50, alpha=0.7, edgecolor='black', linewidth=0.5)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
(ax_div, ax_pop), (ax_eval, ax_train) = axes

# (a) Weight divergence: overlaid histograms plus a dashed line at each mean.
for divs, color, label in zip((on_mean_divs, off_mean_divs), colors, cond_labels):
    ax_div.hist(divs, bins=12, color=color, label=label, **hist_style)
for divs, color in zip((on_mean_divs, off_mean_divs), colors):
    ax_div.axvline(divs.mean(), color=color, linestyle='--', linewidth=2)
ax_div.set(
    xlabel='Mean Pairwise Weight Divergence (cosine)',
    ylabel='Count',
    title='(a) Weight Divergence Distribution',
)
ax_div.legend()

# (b) Eval population histograms with a reference line at 64 agents.
for pops, color, label in zip(
    (field_on_eval_populations, field_off_eval_populations), colors, cond_labels
):
    ax_pop.hist(pops, bins=15, color=color, label=label, **hist_style)
ax_pop.axvline(x=64, color='red', linestyle='--', alpha=0.5, label='Max capacity')
ax_pop.set(
    xlabel='Eval Population',
    ylabel='Count',
    title='(b) Eval Population Distribution',
)
ax_pop.legend()

# (c) Eval reward against eval population, one point per seed.
for pops, rewards, color, label in zip(
    (field_on_eval_populations, field_off_eval_populations),
    (field_on_eval_rewards, field_off_eval_rewards),
    colors,
    cond_labels,
):
    ax_eval.scatter(pops, rewards, color=color, label=label, **scatter_style)
ax_eval.set(
    xlabel='Eval Population',
    ylabel='Total Eval Reward',
    title='(c) Eval: Reward vs Population',
)
ax_eval.legend()

# (d) Weight divergence against the reward arrays from the earlier phases.
# NOTE(review): assumes those arrays are in the same seed order as the
# discovered checkpoints; confirm.
for divs, rewards, color, label in zip(
    (on_mean_divs, off_mean_divs),
    (field_on_rewards, field_off_rewards),
    colors,
    cond_labels,
):
    ax_train.scatter(divs, rewards, color=color, label=label, **scatter_style)
ax_train.set(
    xlabel='Mean Weight Divergence',
    ylabel='Training Reward',
    title='(d) Divergence vs Training Reward',
)
ax_train.legend()

plt.suptitle('Phase 3: Checkpoint Analysis (Both Conditions)', fontsize=14, y=1.01)
plt.tight_layout()
save_figure(fig, os.path.join(OUTPUT_DIR, 'checkpoint_analysis'))
plt.show()

# Compare per-seed mean weight divergence between conditions with a
# parametric (Welch) and a rank-based (Mann-Whitney) test.
div_welch = welch_t_test(on_mean_divs, off_mean_divs)
div_mw = mann_whitney_test(on_mean_divs, off_mean_divs)
print(
    "\nDivergence comparison:",
    f"  Welch t-test: t={div_welch.statistic:.4f}, p={div_welch.p_value:.6f}, d={div_welch.effect_size:.4f}",
    f"  Mann-Whitney: U={div_mw.statistic:.1f}, p={div_mw.p_value:.6f}, r={div_mw.effect_size:.4f}",
    sep="\n",
)

## Phase 4: Report & Save

Generate formatted comparison report and save all results to Drive.

In [None]:
# ========== COMPARISON REPORT ==========
from datetime import datetime

# Significance label for the reward Welch test: stars at the conventional
# 0.001 / 0.01 / 0.05 cut-offs, otherwise "not significant".
reward_p = welch.p_value
if reward_p < 0.001:
    sig_str = "p < 0.001 (***)"
else:
    sig_str = f"p = {reward_p:.4f} (not significant)"
    for p_cutoff, stars in ((0.01, "**"), (0.05, "*")):
        if reward_p < p_cutoff:
            sig_str = f"p = {reward_p:.4f} ({stars})"
            break

# Verbal size of |Cohen's d| using the 0.2 / 0.5 / 0.8 boundaries.
d_val = abs(welch.effect_size)
d_str = "large"
for d_upper, d_label in ((0.2, "negligible"), (0.5, "small"), (0.8, "medium")):
    if d_val < d_upper:
        d_str = d_label
        break

reward_mean_on = field_on_rewards.mean()
reward_mean_off = field_off_rewards.mean()
winner = "Field OFF" if reward_mean_off > reward_mean_on else "Field ON"

# Coefficient of variation with the sample std (ddof=1); infinite when the
# mean is exactly zero.
if reward_mean_on != 0:
    cov_on = field_on_rewards.std(ddof=1) / reward_mean_on
else:
    cov_on = float('inf')
if reward_mean_off != 0:
    cov_off = field_off_rewards.std(ddof=1) / reward_mean_off
else:
    cov_off = float('inf')

# Field ON seeds below the failure threshold, and the remaining ones.
n_failed = int(np.count_nonzero(field_on_rewards < FAILED_THRESHOLD))
on_nonfailed = field_on_rewards[field_on_rewards >= FAILED_THRESHOLD]

# Config source description
if os.path.exists(BEST_CONFIG_PATH):
    config_source = f"sweep-optimized ({os.path.basename(BEST_CONFIG_PATH)})"
else:
    config_source = "v1 sweep fallback values"

# Direction sentence for the Interpretation section. It always names the
# condition with the higher mean reward; the previous inline expression printed
# "Field ON achieves lower mean reward" whenever Field ON was actually ahead.
if field_off_rewards.mean() > field_on_rewards.mean():
    direction_str = "Field OFF achieves higher mean reward."
elif field_on_rewards.mean() > field_off_rewards.mean():
    direction_str = "Field ON achieves higher mean reward."
else:
    direction_str = "Field ON and Field OFF achieve equal mean reward."

# Markdown report assembled from the statistics computed in the earlier cells.
# The "At max" row counts seeds with population >= 60 and uses the actual
# number of seeds as the denominator.
report = f"""# Field ON vs Field OFF: 30-Seed Comparison Report (v2 — Fresh Training)
Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Experiment Setup
- **Conditions**: Field ON (biological pheromone system) vs Field OFF (no shared field)
- **Seeds per condition**: 30 (10 batches x 3 seeds)
- **Training steps**: 10M per seed
- **Eval episodes**: {NUM_EVAL_EPISODES} per seed (post-training)
- **Config**: grid=40, num_agents=16 (start), max_agents=64, num_food=25
- **Energy**: starting=200, food=100, reproduce_threshold=180, cost=80
- **Config source**: {config_source}
- **Codebase**: post-energy-fix (crop refuel + free write steps)
- **Key design**: Field OFF inherits ALL params from Field ON (controlled experiment)
  - Same nest radius, compass noise, evolution params
  - ONLY difference: field channels zeroed out

## Key Results

### Eval Reward Comparison (from {NUM_EVAL_EPISODES} eval episodes per seed)
| Metric | Field ON | Field OFF |
|--------|----------|-----------|
| Mean | {field_on_rewards.mean():.4f} +/- {field_on_rewards.std(ddof=1):.4f} | {field_off_rewards.mean():.4f} +/- {field_off_rewards.std(ddof=1):.4f} |
| Median | {np.median(field_on_rewards):.4f} | {np.median(field_off_rewards):.4f} |
| IQM | {iqm_on.iqm:.4f} [{iqm_on.ci_lower:.4f}, {iqm_on.ci_upper:.4f}] | {iqm_off.iqm:.4f} [{iqm_off.ci_lower:.4f}, {iqm_off.ci_upper:.4f}] |
| Min | {field_on_rewards.min():.4f} | {field_off_rewards.min():.4f} |
| Max | {field_on_rewards.max():.4f} | {field_off_rewards.max():.4f} |
| CoV | {cov_on:.4f} | {cov_off:.4f} |

### Trail Strength & Survival
| Metric | Field ON | Field OFF |
|--------|----------|-----------|
| Trail strength | {field_on_trails.mean():.4f} +/- {field_on_trails.std():.4f} | {field_off_trails.mean():.4f} +/- {field_off_trails.std():.4f} |
| Survival rate | {field_on_survival.mean():.4f} +/- {field_on_survival.std():.4f} | {field_off_survival.mean():.4f} +/- {field_off_survival.std():.4f} |
| Trail Welch p | {trail_welch.p_value:.6f} | d = {trail_welch.effect_size:.4f} |
| Survival Welch p | {surv_welch.p_value:.6f} | d = {surv_welch.effect_size:.4f} |

### Statistical Tests (Reward)
- **Welch's t-test**: {sig_str}, Cohen's d = {welch.effect_size:.4f} ({d_str})
- **Mann-Whitney U**: U = {mw.statistic:.1f}, p = {mw.p_value:.6f}, rank-biserial r = {mw.effect_size:.4f}
- **P(Field OFF > Field ON)**: {poi['prob_x_better']:.4f}

### Winner: **{winner}** (by mean eval reward)

### Sensitivity Analysis (threshold = {FAILED_THRESHOLD})
- Excluding {n_failed} failed Field ON seeds:
  Field ON filtered mean = {on_nonfailed.mean():.4f} (n={len(on_nonfailed)})
  Welch p = {welch_f.p_value:.6f}, Cohen's d = {welch_f.effect_size:.4f}

### Population Comparison
| Metric | Field ON | Field OFF |
|--------|----------|-----------|
| Mean population | {field_on_populations.mean():.1f} +/- {field_on_populations.std():.1f} | {field_off_populations.mean():.1f} +/- {field_off_populations.std():.1f} |
| At max (64) | {np.sum(field_on_populations >= 60)}/{len(field_on_populations)} | {np.sum(field_off_populations >= 60)}/{len(field_off_populations)} |
| Failed seeds | {n_failed} | {int(np.sum(field_off_rewards < FAILED_THRESHOLD))} |

### Checkpoint Analysis ({NUM_EVAL_EPISODES} eval episodes/seed)
| Metric | Field ON | Field OFF |
|--------|----------|-----------|
| Eval total reward | {field_on_eval_rewards.mean():.1f} +/- {field_on_eval_rewards.std():.1f} | {field_off_eval_rewards.mean():.1f} +/- {field_off_eval_rewards.std():.1f} |
| Eval population | {field_on_eval_populations.mean():.1f} +/- {field_on_eval_populations.std():.1f} | {field_off_eval_populations.mean():.1f} +/- {field_off_eval_populations.std():.1f} |

### Weight Divergence
| Metric | Field ON | Field OFF |
|--------|----------|-----------|
| Mean divergence | {on_mean_divs.mean():.4f} +/- {on_mean_divs.std():.4f} | {off_mean_divs.mean():.4f} +/- {off_mean_divs.std():.4f} |
| Max divergence | {on_max_divs.mean():.4f} +/- {on_max_divs.std():.4f} | {off_max_divs.mean():.4f} +/- {off_max_divs.std():.4f} |
| Divergence Welch p | {div_welch.p_value:.6f} | Cohen's d = {div_welch.effect_size:.4f} |

## Interpretation

**Direction**: {direction_str}

Key observations:
1. Trail strength: Field ON = {field_on_trails.mean():.4f}, Field OFF = {field_off_trails.mean():.4f} (p={trail_welch.p_value:.4f})
2. Survival rate: Field ON = {field_on_survival.mean():.3f}, Field OFF = {field_off_survival.mean():.3f}
3. Variance: Field ON CoV = {cov_on:.3f}, Field OFF CoV = {cov_off:.3f}
4. Failed seeds: {n_failed} Field ON seeds with reward < {FAILED_THRESHOLD}
5. Weight divergence: {'Field ON higher' if on_mean_divs.mean() > off_mean_divs.mean() else 'Field OFF higher'} (p={div_welch.p_value:.4f})

## Next Steps
- If Field ON > OFF: biological pheromone system enables collective intelligence
- If Field OFF > ON: pheromone overhead may need longer training or better hyperparameters
- Investigate trail formation patterns in successful Field ON seeds
- Test with diversity_bonus and niche_pressure enabled
"""

from IPython.display import Markdown, display

# Render the report inline as formatted Markdown, then confirm on stdout.
rendered_report = Markdown(report)
display(rendered_report)
print("\nReport generated successfully.")

In [None]:
# ========== SAVE RESULTS ==========
def _condition_results(*, rewards, populations, trails, survival, iqm,
                       ckpt_populations, ckpt_rewards, ckpt_records,
                       mean_divs, max_divs, div_records):
    """Build the results section for one condition (Field ON or Field OFF).

    Arrays are converted to plain lists; the per-seed checkpoint records are
    copied without their ``ckpt_path`` entry.
    """
    return {
        'eval_rewards': rewards.tolist(),
        'eval_populations': populations.tolist(),
        'eval_trail_strengths': trails.tolist(),
        'eval_survival_rates': survival.tolist(),
        'mean_reward': float(rewards.mean()),
        'std_reward': float(rewards.std(ddof=1)),
        'iqm': float(iqm.iqm),
        'iqm_ci': [float(iqm.ci_lower), float(iqm.ci_upper)],
        'checkpoint_eval': {
            'populations': ckpt_populations.tolist(),
            'total_rewards': ckpt_rewards.tolist(),
            'per_seed': [
                {field: value for field, value in record.items() if field != 'ckpt_path'}
                for record in ckpt_records
            ],
        },
        'weight_divergence': {
            'mean_divergences': mean_divs.tolist(),
            'max_divergences': max_divs.tolist(),
            'per_seed': div_records,
        },
    }


# Everything that gets persisted: run metadata, one section per condition, and
# the headline test statistics as plain floats.
results = {
    'metadata': {
        'generated': datetime.now().isoformat(),
        'field_on_seeds': 30,
        'field_off_seeds': 30,
        'steps_per_seed': 10_000_000,
        'eval_episodes_per_seed': NUM_EVAL_EPISODES,
        'data_source': 'Fresh training + eval episodes (v2)',
        'config_source': config_source,
    },
    'field_on': _condition_results(
        rewards=field_on_rewards,
        populations=field_on_populations,
        trails=field_on_trails,
        survival=field_on_survival,
        iqm=iqm_on,
        ckpt_populations=field_on_eval_populations,
        ckpt_rewards=field_on_eval_rewards,
        ckpt_records=eval_results_on,
        mean_divs=on_mean_divs,
        max_divs=on_max_divs,
        div_records=divergence_on,
    ),
    'field_off': _condition_results(
        rewards=field_off_rewards,
        populations=field_off_populations,
        trails=field_off_trails,
        survival=field_off_survival,
        iqm=iqm_off,
        ckpt_populations=field_off_eval_populations,
        ckpt_rewards=field_off_eval_rewards,
        ckpt_records=eval_results_off,
        mean_divs=off_mean_divs,
        max_divs=off_max_divs,
        div_records=divergence_off,
    ),
    'tests': {
        'welch_t': float(welch.statistic),
        'welch_p': float(welch.p_value),
        'cohens_d': float(welch.effect_size),
        'mann_whitney_u': float(mw.statistic),
        'mann_whitney_p': float(mw.p_value),
        'rank_biserial_r': float(mw.effect_size),
        'prob_off_better': float(poi['prob_x_better']),
        'divergence_welch_p': float(div_welch.p_value),
        'divergence_cohens_d': float(div_welch.effect_size),
        'trail_welch_p': float(trail_welch.p_value),
        'trail_cohens_d': float(trail_welch.effect_size),
        'survival_welch_p': float(surv_welch.p_value),
        'survival_cohens_d': float(surv_welch.effect_size),
    },
}

def _json_default(obj):
    """Fallback encoder for values json cannot serialize natively.

    The per-seed eval records hold the raw outputs of ``np.mean`` and of the
    episode statistics, which may be NumPy/JAX scalars or arrays. Anything
    exposing ``tolist()`` is converted to plain Python; other types still
    raise ``TypeError`` as json would by default.
    """
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save JSON
json_path = os.path.join(OUTPUT_DIR, 'field_on_vs_off_results.json')
with open(json_path, 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2, default=_json_default)
print(f"JSON saved: {json_path}")

# Save pickle (keeps the exact Python objects stored in `results`)
pkl_path = os.path.join(OUTPUT_DIR, 'field_on_vs_off_results.pkl')
with open(pkl_path, 'wb') as f:
    pickle.dump(results, f)
print(f"Pickle saved: {pkl_path}")

# Save report markdown. The report contains non-ASCII characters (the em dash
# in its title), so the encoding is fixed rather than left to the locale.
md_path = os.path.join(OUTPUT_DIR, 'comparison_report.md')
with open(md_path, 'w', encoding='utf-8') as f:
    f.write(report)
print(f"Report saved: {md_path}")

# List all output files
print(f"\n{'='*60}")
print("ALL OUTPUT FILES:")
for fname in sorted(os.listdir(OUTPUT_DIR)):
    fpath = os.path.join(OUTPUT_DIR, fname)
    size_mb = os.path.getsize(fpath) / 1024 / 1024
    print(f"  {fname} ({size_mb:.2f} MB)")
print(f"\nDone! All results saved to {OUTPUT_DIR}")