# Field ON vs Field OFF: Statistical Comparison

30 seeds per condition, 10M training steps each, using the proven 64-agent configuration.

**All data is loaded from Google Drive** (checkpoints and training summaries for both conditions).

## Setup
1. Runtime > Change runtime type > **TPU v6e** + **High-RAM**
2. Run all cells (Ctrl+F9)

In [None]:
# Mount Google Drive: the checkpoint directories and the output directory used below live under it.
from google.colab import drive
drive.mount('/content/drive')

import os
REPO_DIR = '/content/emergence-lab'
GITHUB_USERNAME = "imashishkh"

# Clone the repo on a fresh runtime; otherwise pull the latest commits.
# NOTE: the `!` lines are IPython shell escapes, so this cell only runs inside Jupyter/Colab.
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/{GITHUB_USERNAME}/emergence-lab.git {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull

# Work from the repo root (so the later `src.*` imports can resolve) and install the
# package in editable mode with its dev extras.
os.chdir(REPO_DIR)
!pip install -e ".[dev]" -q
print(f"Working directory: {os.getcwd()}")

In [None]:
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import json
import pickle
from pathlib import Path

from src.analysis.statistics import (
    compute_iqm, compare_methods, welch_t_test, mann_whitney_test,
    probability_of_improvement,
)
from src.analysis.paper_figures import (
    setup_publication_style, plot_performance_profiles,
    save_figure,
)

# Per-condition checkpoint roots on Drive, plus the folder that figures/reports go to.
FIELD_ON_DIR = '/content/drive/MyDrive/emergence-lab/field_on/'
FIELD_OFF_DIR = '/content/drive/MyDrive/emergence-lab/field_off/'
OUTPUT_DIR = '/content/drive/MyDrive/emergence-lab/analysis_results/'
os.makedirs(OUTPUT_DIR, exist_ok=True)  # no-op when the folder already exists

print("Imports loaded. Output dir:", OUTPUT_DIR)
# Report whether each condition's directory exists before any loading is attempted.
for _label, _cond_dir in (("Field ON dir:  ", FIELD_ON_DIR),
                          ("Field OFF dir: ", FIELD_OFF_DIR)):
    print(f"{_label}{_cond_dir} (exists: {os.path.exists(_cond_dir)})")

In [None]:
# =============================================================================
# TEST STEP - Verify everything works before running full analysis
# =============================================================================
# Tests: imports, stats, plots, BOTH Drive dirs, checkpoint loading, eval + divergence
# If this passes, the full analysis will work. If it fails, fix before proceeding.
# =============================================================================
import time
import glob as glob_mod

print("="*70)
print("TEST MODE: Running quick verification...")
print("="*70)

errors = []


def _latest_step_ckpt(pkl_paths):
    """Return the ``step_<N>.pkl`` path with the largest integer step ``N``.

    ``sorted(...)[-1]`` compares strings, so ``step_9000000.pkl`` would beat
    ``step_10000000.pkl``. Unparsable names rank lowest; ties break on the path.
    """
    def _step(path):
        digits = os.path.splitext(os.path.basename(path))[0][len('step_'):]
        return int(digits) if digits.isdigit() else -1
    return max(pkl_paths, key=lambda p: (_step(p), p))


def _seed_ckpt_dirs(cond_dir):
    """Yield ``(seed_path, pkl_paths)`` for each seed dir under ``<cond_dir>/batch_*/``.

    Only seed directories holding at least one ``step_*.pkl`` are yielded. Batches and
    seeds are visited in sorted order, and non-directory entries are skipped (calling
    ``os.listdir`` on a stray file named ``batch_*`` would otherwise raise).
    """
    for bd in sorted(d for d in os.listdir(cond_dir) if d.startswith('batch_')):
        bp = os.path.join(cond_dir, bd)
        if not os.path.isdir(bp):
            continue
        for sd in sorted(os.listdir(bp)):
            sp = os.path.join(bp, sd)
            if not os.path.isdir(sp):
                continue
            pkls = glob_mod.glob(os.path.join(sp, 'step_*.pkl'))
            if pkls:
                yield sp, pkls


def _check_condition_dir(cond_name, cond_dir):
    """Assert *cond_dir* exists and has ``batch_*`` dirs.

    Returns ``(n_batches, n_seeds_with_ckpts)``.
    """
    assert os.path.exists(cond_dir), f"Directory not found: {cond_dir}"
    n_batches = sum(
        1 for d in os.listdir(cond_dir)
        if d.startswith('batch_') and os.path.isdir(os.path.join(cond_dir, d))
    )
    assert n_batches > 0, f"No batch directories found for {cond_name}"
    return n_batches, sum(1 for _ in _seed_ckpt_dirs(cond_dir))


# Test 1: Stats functions work
print("\n[1/6] Testing statistics functions...")
try:
    test_a = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    test_b = np.array([2.0, 3.0, 4.0, 5.0, 6.0])
    iqm_test = compute_iqm(test_a, n_bootstrap=100, seed=0)
    welch_test = welch_t_test(test_a, test_b)
    mw_test = mann_whitney_test(test_a, test_b)
    poi_test = probability_of_improvement(test_a, test_b, n_bootstrap=100, seed=0)
    comp_test = compare_methods({"A": test_a, "B": test_b}, n_bootstrap=100, seed=0)
    assert hasattr(iqm_test, 'iqm'), "IQM result missing .iqm"
    assert hasattr(welch_test, 'p_value'), "Welch result missing .p_value"
    assert 'prob_x_better' in poi_test, "POI missing prob_x_better"
    assert hasattr(comp_test, 'summary'), "Compare result missing .summary"
    print("   PASS: All statistics functions work correctly")
except Exception as e:
    errors.append(f"Stats functions: {e}")
    print(f"   FAIL: {e}")

# Test 2: Plot functions work
print("\n[2/6] Testing plot functions...")
try:
    setup_publication_style()
    fig_test, ax_test = plt.subplots()
    ax_test.plot([1, 2, 3])
    plt.close(fig_test)
    print("   PASS: Matplotlib + publication style working")
except Exception as e:
    errors.append(f"Plot functions: {e}")
    print(f"   FAIL: {e}")

# Tests 3-4: each condition's Drive directory (same check, run once per condition)
for test_no, (cond_name, cond_dir) in enumerate(
        [("Field ON", FIELD_ON_DIR), ("Field OFF", FIELD_OFF_DIR)], start=3):
    print(f"\n[{test_no}/6] Testing {cond_name} Drive access...")
    try:
        n_batches, n_seed_ckpts = _check_condition_dir(cond_name, cond_dir)
        print(f"   PASS: {cond_name} -- {n_batches} batches, {n_seed_ckpts} seed checkpoints")
    except Exception as e:
        errors.append(f"{cond_name} Drive: {e}")
        print(f"   FAIL: {e}")

# Test 5: Load one checkpoint from each condition
print("\n[5/6] Testing checkpoint loading (one per condition)...")
# Reset so Test 6 cannot silently reuse a checkpoint left over from an earlier run.
ckpt = None
config = None
try:
    import jax
    import jax.numpy as jnp
    from src.training.checkpointing import load_checkpoint
    from src.agents.network import ActorCritic
    from src.analysis.ablation import _run_episode_full
    from src.analysis.specialization import compute_weight_divergence

    for cond_name, cond_dir in [("Field ON", FIELD_ON_DIR), ("Field OFF", FIELD_OFF_DIR)]:
        first_seed = next(_seed_ckpt_dirs(cond_dir), None)
        assert first_seed is not None, f"No checkpoint found for {cond_name}"
        test_ckpt = _latest_step_ckpt(first_seed[1])
        ckpt = load_checkpoint(test_ckpt)
        config = ckpt['config']
        assert hasattr(config, 'env'), f"{cond_name} config not a dataclass"
        print(f"   {cond_name}: seed={ckpt.get('seed_id', -1)}, grid={config.env.grid_size}, "
              f"decay={config.field.decay_rate}, diffusion={config.field.diffusion_rate}")
    print("   PASS: Both conditions' checkpoints load correctly")
except Exception as e:
    errors.append(f"Checkpoint loading: {e}")
    print(f"   FAIL: {e}")

# Test 6: Eval episode + weight divergence (on last loaded checkpoint)
print("\n[6/6] Testing eval episode + weight divergence...")
try:
    assert ckpt is not None and config is not None, "No checkpoint loaded (see test 5)"
    agent_params = jax.tree_util.tree_map(lambda x: x[0], ckpt['agent_params'])
    network = ActorCritic(hidden_dims=tuple(config.agent.hidden_dims), num_actions=6)
    t0 = time.time()
    key = jax.random.PRNGKey(42)
    stats = _run_episode_full(
        network=network, params=ckpt['params'], config=config,
        key=key, condition="normal", evolution=True,
    )
    elapsed = time.time() - t0
    print(f"   Eval: reward={stats.total_reward:.1f}, pop={stats.final_population} ({elapsed:.1f}s)")

    div = compute_weight_divergence(agent_params)
    print(f"   Divergence: mean={div['mean_divergence']:.4f}, n_agents={len(div['agent_indices'])}")
    print("   PASS: Eval and weight divergence working")
except Exception as e:
    errors.append(f"Eval/divergence: {e}")
    print(f"   FAIL: {e}")

# Summary
print()
print("="*70)
if errors:
    print(f"TEST FAILED! {len(errors)} error(s):")
    for err in errors:
        print(f"  - {err}")
    print("\nDO NOT proceed until all tests pass.")
    print("="*70)
    raise RuntimeError(f"Test failed with {len(errors)} error(s)")
else:
    print("ALL 6 TESTS PASSED!")
    print("="*70)
    print("Proceed to run the full analysis below.")

## Phase 1: Load Training Data

Load per-seed training rewards (and populations) from `training_summary.pkl` on Drive when it exists; otherwise fall back to hardcoded values (verified by 4 independent audit agents).

In [None]:
# ========== LOAD TRAINING-TIME DATA FROM DRIVE ==========
import glob as glob_mod
import re

FAILED_THRESHOLD = 1.0  # Seeds with reward below this are considered "failed"


def _natural_key(name):
    """Sort key that orders embedded integers numerically (``seed_2`` < ``seed_10``)."""
    # re.split with one capturing group alternates text / digit-run tokens, so odd
    # positions are always digit runs and tuples never compare str against int.
    return tuple(int(tok) if i % 2 else tok
                 for i, tok in enumerate(re.split(r'(\d+)', name)))


def discover_checkpoints(drive_dir, condition_name):
    """Return the latest checkpoint path for every seed directory under *drive_dir*.

    Expected layout: ``<drive_dir>/batch_<k>/<seed_dir>/step_<N>.pkl``.

    * Every ``batch_<digits>`` directory is scanned, not just ``batch_0``..``batch_9``.
    * Batches and seed directories are visited in natural (numeric-aware) order, so
      ``seed_2`` precedes ``seed_10``.
    * Within a seed directory, the file with the largest integer step ``N`` is chosen.
      Lexicographic order would rank ``step_9000000.pkl`` above ``step_10000000.pkl``.

    Args:
        drive_dir: Root directory of one experimental condition.
        condition_name: Label used only in the printed summary line.

    Returns:
        List of checkpoint paths, one per seed directory containing any
        ``step_*.pkl``. Empty when *drive_dir* does not exist.
    """
    paths = []
    batch_names = []
    if os.path.isdir(drive_dir):
        batch_names = sorted(
            (d for d in os.listdir(drive_dir)
             if re.fullmatch(r'batch_\d+', d) and os.path.isdir(os.path.join(drive_dir, d))),
            key=_natural_key,
        )
    for batch_name in batch_names:
        batch_dir = os.path.join(drive_dir, batch_name)
        for seed_dir in sorted(os.listdir(batch_dir), key=_natural_key):
            seed_path = os.path.join(batch_dir, seed_dir)
            if not os.path.isdir(seed_path):
                continue
            pkl_files = glob_mod.glob(os.path.join(seed_path, 'step_*.pkl'))
            if pkl_files:
                # Highest numeric step wins; the full path breaks exact ties deterministically.
                paths.append(max(pkl_files,
                                 key=lambda p: (_natural_key(os.path.basename(p)), p)))
    print(f"  {condition_name}: Found {len(paths)} checkpoints on Drive")
    return paths

def load_training_summary(drive_dir, condition_name):
    """Load ``training_summary.pkl`` and flatten per-seed rewards and populations.

    Expected format (as written by the training driver; assumed, not validated)::

        {'all_results': [{'batch': ..., 'seed_ids': [...],
                          'metrics': {'mean_reward': [...], 'population_size': [...], ...},
                          'success': bool}, ...]}

    * Batches with ``success == False`` are skipped; a missing key counts as success.
    * Each batch metric is flattened with ``np.ravel``, so a scalar (single-seed batch),
      a list, or a NumPy/JAX array are all handled the same way.

    Args:
        drive_dir: Condition root directory containing ``training_summary.pkl``.
        condition_name: Label used only in printed messages.

    Returns:
        ``(rewards, populations)``: a float array and an int array, with populations
        rounded rather than truncated. Returns ``(None, None)`` when the file is missing
        or no successful batch contributed any seed, so the caller uses its fallback.

    Raises:
        ValueError: if the flattened rewards and populations differ in length.
    """
    summary_path = os.path.join(drive_dir, 'training_summary.pkl')
    if not os.path.exists(summary_path):
        print(f"  {condition_name}: No training_summary.pkl found")
        return None, None
    # NOTE: pickle can execute arbitrary code on load -- only use on self-written, trusted files.
    with open(summary_path, 'rb') as f:
        summary = pickle.load(f)

    reward_chunks = []
    population_chunks = []
    skipped = 0
    for batch in summary['all_results']:
        if not batch.get('success', True):
            skipped += 1
            continue
        metrics = batch['metrics']
        reward_chunks.append(np.ravel(np.asarray(metrics['mean_reward'], dtype=float)))
        population_chunks.append(np.ravel(np.asarray(metrics['population_size'], dtype=float)))

    if skipped > 0:
        print(f"  {condition_name}: WARNING - skipped {skipped} failed batches")

    rewards = np.concatenate(reward_chunks) if reward_chunks else np.array([], dtype=float)
    populations = (np.concatenate(population_chunks) if population_chunks
                   else np.array([], dtype=float))
    if rewards.size == 0:
        print(f"  {condition_name}: WARNING - training_summary.pkl contained no usable seeds")
        return None, None
    if rewards.shape != populations.shape:
        raise ValueError(
            f"{condition_name}: {rewards.size} rewards but {populations.size} populations "
            f"in {summary_path}"
        )
    # Round before the int cast so e.g. a mean population of 63.6 becomes 64, not 63.
    populations = np.rint(populations).astype(int)
    print(f"  {condition_name}: Loaded {len(rewards)} rewards from training_summary.pkl")
    return rewards, populations

print("Discovering data on Drive...")
field_on_ckpt_paths = discover_checkpoints(FIELD_ON_DIR, "Field ON")
field_off_ckpt_paths = discover_checkpoints(FIELD_OFF_DIR, "Field OFF")

# ---- Load Field ON data ----
# Prefer the Drive summary; fall back to the audited hardcoded values when it is unavailable.
on_rewards_drive, on_pops_drive = load_training_summary(FIELD_ON_DIR, "Field ON")
if on_rewards_drive is not None:
    field_on_rewards = on_rewards_drive
    field_on_populations = on_pops_drive
else:
    # Hardcoded from EXPERIMENT_LOG.md (verified by 4 independent audit agents)
    field_on_rewards = np.array([
        5.19, 5.38, 4.70, 3.09, 4.84, 5.50, 2.54, 0.00, 5.18, 4.52,
        5.09, 5.33, 5.38, 3.70, 5.24, 4.61, 3.46, 4.56, 5.48, 4.42,
        5.43, 4.56, 5.30, 4.99, 4.22, 4.33, 5.51, 4.19, 4.67, 5.20,
    ])
    field_on_populations = np.array([
        64, 64, 6, 22, 64, 64, 11, 0, 64, 20,
        64, 64, 40, 30, 48, 62, 39, 58, 64, 64,
        64, 58, 64, 30, 28, 50, 62, 52, 60, 58,
    ])
    print("  Field ON: Using hardcoded values (no training_summary.pkl)")

# ---- Load Field OFF data ----
off_rewards_drive, off_pops_drive = load_training_summary(FIELD_OFF_DIR, "Field OFF")
if off_rewards_drive is not None:
    field_off_rewards = off_rewards_drive
    field_off_populations_train = off_pops_drive
else:
    # Hardcoded (verified against training_summary.pkl by audit)
    field_off_rewards = np.array([
        5.527, 5.614, 5.378, 5.728, 5.547, 5.624, 5.618, 5.386, 5.685, 5.600,
        5.430, 4.785, 5.542, 5.556, 5.610, 5.709, 5.494, 5.487, 5.695, 5.661,
        5.457, 5.605, 5.180, 5.588, 5.428, 5.399, 5.476, 5.297, 5.712, 5.670,
    ])
    # No training-time populations were recorded for the fallback Field OFF data.
    field_off_populations_train = None
    print("  Field OFF: Using hardcoded values (no training_summary.pkl)")

# Populations from eval episodes (populated in Phase 3)
field_on_eval_populations = None
field_off_eval_populations = None

# Denominators use the actual seed counts rather than a literal 30, so the summary stays
# correct if a failed batch was skipped or more seeds were added.
print(f"\n{'='*60}")
print("TRAINING DATA SUMMARY")
print(f"{'='*60}")
print(f"Field ON:  {len(field_on_rewards)} seeds, mean={field_on_rewards.mean():.3f} +/- {field_on_rewards.std(ddof=1):.3f}")
print(f"Field OFF: {len(field_off_rewards)} seeds, mean={field_off_rewards.mean():.3f} +/- {field_off_rewards.std(ddof=1):.3f}")
print(f"Field ON population: mean={field_on_populations.mean():.1f}, at max(64): "
      f"{np.sum(field_on_populations==64)}/{len(field_on_populations)}")
if field_off_populations_train is not None:
    print(f"Field OFF population (train): mean={field_off_populations_train.mean():.1f}, at max(64): "
          f"{np.sum(field_off_populations_train==64)}/{len(field_off_populations_train)}")
print(f"Field ON failed seeds (reward < {FAILED_THRESHOLD}): {np.sum(field_on_rewards < FAILED_THRESHOLD)}")
print(f"\nCheckpoints on Drive: {len(field_on_ckpt_paths)} Field ON, {len(field_off_ckpt_paths)} Field OFF")

In [None]:
# ========== DESCRIPTIVE STATISTICS ==========
print("="*60)
print("DESCRIPTIVE STATISTICS")
print("="*60)

# Bootstrap each condition's IQM exactly once (10k resamples is not cheap) and keep the
# result for the plots/report below. Previously the same call was repeated after the loop.
iqm_results = {}
for name, rewards in [("Field ON", field_on_rewards), ("Field OFF", field_off_rewards)]:
    iqm = compute_iqm(rewards, n_bootstrap=10000, seed=42)
    iqm_results[name] = iqm
    print(f"\n{name} (n={len(rewards)}):")
    print(f"  Mean:   {rewards.mean():.4f} +/- {rewards.std(ddof=1):.4f}")
    print(f"  Median: {np.median(rewards):.4f}")
    print(f"  IQM:    {iqm.iqm:.4f} [{iqm.ci_lower:.4f}, {iqm.ci_upper:.4f}]")
    print(f"  Min:    {rewards.min():.4f}")
    print(f"  Max:    {rewards.max():.4f}")
    print(f"  CoV:    {rewards.std(ddof=1)/rewards.mean():.4f}")

# Store IQMs for later cells
iqm_on = iqm_results["Field ON"]
iqm_off = iqm_results["Field OFF"]

In [None]:
# ========== HYPOTHESIS TESTS ==========
print("="*60)
print("HYPOTHESIS TESTS")
print("="*60)

# Field OFF is always the first sample in this cell, Field ON the second.

# -- 1. Welch's t-test --
welch = welch_t_test(field_off_rewards, field_on_rewards)
d = abs(welch.effect_size)
# Cohen's rule-of-thumb bands; anything not below 0.8 (NaN included) is reported as LARGE.
d_band = next(
    (band for cutoff, band in ((0.2, "negligible"), (0.5, "small"), (0.8, "MEDIUM"))
     if d < cutoff),
    "LARGE",
)
print("\n1. Welch's t-test:")
print(f"   t = {welch.statistic:.4f}, p = {welch.p_value:.6f}")
print(f"   Cohen's d = {welch.effect_size:.4f} ({d_band})")
print(f"   Significant at alpha=0.05: {welch.significant}")

# -- 2. Mann-Whitney U (rank-based, no normality assumption) --
mw = mann_whitney_test(field_off_rewards, field_on_rewards)
print("\n2. Mann-Whitney U test:")
print(f"   U = {mw.statistic:.1f}, p = {mw.p_value:.6f}")
print(f"   Rank-biserial r = {mw.effect_size:.4f}")
print(f"   Significant at alpha=0.05: {mw.significant}")

# -- 3. Probability of improvement (bootstrapped) --
poi = probability_of_improvement(field_off_rewards, field_on_rewards, n_bootstrap=5000, seed=42)
print("\n3. Probability of Improvement:")
print(f"   P(Field OFF > Field ON) = {poi['prob_x_better']:.4f}")
print(f"   P(Field ON > Field OFF) = {poi['prob_y_better']:.4f}")
print(f"   95% CI: [{poi['ci_lower']:.4f}, {poi['ci_upper']:.4f}]")

# -- 4. Direction of the effect --
_off_mean = field_off_rewards.mean()
_on_mean = field_on_rewards.mean()
print("\n4. Direction:")
print(f"   Field OFF mean: {_off_mean:.4f}")
print(f"   Field ON mean:  {_on_mean:.4f}")
print(f"   Gap: {_off_mean - _on_mean:+.4f} (Field OFF {'higher' if _off_mean > _on_mean else 'lower'})")

# -- 5. Sensitivity: drop failed Field ON seeds (same FAILED_THRESHOLD as everywhere else) --
mask_on = field_on_rewards >= FAILED_THRESHOLD
on_filtered = field_on_rewards[mask_on]
welch_f = welch_t_test(field_off_rewards, on_filtered)
print(f"\n5. Sensitivity (excluding {np.count_nonzero(~mask_on)} failed Field ON seeds, threshold={FAILED_THRESHOLD}):")
print(f"   Field ON filtered: n={len(on_filtered)}, mean={on_filtered.mean():.4f}")
print(f"   Welch p = {welch_f.p_value:.6f}, Cohen's d = {welch_f.effect_size:.4f}")

In [None]:
# ========== FULL METHOD COMPARISON ==========
print("=" * 60)
print("FULL METHOD COMPARISON (rliable-style)")
print("=" * 60)

# Both conditions go through a single call with one bootstrap budget and seed.
condition_rewards = {
    "Field ON": field_on_rewards,
    "Field OFF": field_off_rewards,
}
comparison = compare_methods(condition_rewards, n_bootstrap=10000, seed=42)
print(comparison.summary)

## Phase 2: Comparison Plots

In [None]:
# ========== COMPARISON PLOTS ==========
setup_publication_style()

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# 1. Bar chart: IQM with CI
ax = axes[0]
methods = ['Field ON\n(Stigmergy)', 'Field OFF\n(No Field)']
iqm_vals = [iqm_on.iqm, iqm_off.iqm]
# Asymmetric error bars: distance from the IQM down to the CI lower bound / up to the upper bound.
iqm_lo = [iqm_on.iqm - iqm_on.ci_lower, iqm_off.iqm - iqm_off.ci_lower]
iqm_hi = [iqm_on.ci_upper - iqm_on.iqm, iqm_off.ci_upper - iqm_off.iqm]
colors = ['#009988', '#BBBBBB']

bars = ax.bar(methods, iqm_vals, yerr=[iqm_lo, iqm_hi], color=colors,
              edgecolor='black', linewidth=0.5, capsize=8)
for bar, val in zip(bars, iqm_vals):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(iqm_hi) + 0.05,
            f'{val:.3f}', ha='center', fontsize=10)
ax.set_ylabel('IQM Reward')
ax.set_title('(a) IQM + 95% CI')
sig = f"p = {welch.p_value:.4f}"
if welch.p_value < 0.001: sig += " ***"
elif welch.p_value < 0.01: sig += " **"
elif welch.p_value < 0.05: sig += " *"
# Place the p-value label in DATA coordinates: centred between the two bars (x=0.5) and
# above the tallest error bar. Then raise the y-limit so it stays inside the axes.
# (With ax.get_xaxis_transform() y is an axes fraction in [0, 1], so a reward-scale
# y of ~6 put the label far above the plot.)
sig_y = max(iqm_vals) + max(iqm_hi) + 0.3
ax.text(0.5, sig_y, sig, ha='center', fontsize=11)
ax.set_ylim(top=max(ax.get_ylim()[1], sig_y + 0.3))

# 2. Violin + swarm
ax = axes[1]
parts = ax.violinplot([field_on_rewards, field_off_rewards], positions=[1, 2],
                       showmeans=True, showmedians=True, showextrema=False)
for i, body in enumerate(parts['bodies']):
    body.set_facecolor(colors[i])
    body.set_alpha(0.4)

rng = np.random.default_rng(42)  # fixed jitter so the figure is reproducible
for i, (data, pos) in enumerate(zip([field_on_rewards, field_off_rewards], [1, 2])):
    jitter = rng.normal(0, 0.05, size=len(data))
    ax.scatter(np.full_like(data, pos) + jitter, data, alpha=0.6, s=25,
               color=colors[i], edgecolor='black', linewidth=0.3, zorder=3)

ax.scatter([1], [iqm_on.iqm], marker='D', s=80, color='red', zorder=5, label='IQM')
ax.scatter([2], [iqm_off.iqm], marker='D', s=80, color='red', zorder=5)
ax.set_xticks([1, 2])
ax.set_xticklabels(['Field ON', 'Field OFF'])
ax.set_ylabel('Mean Reward')
ax.set_title(f'(b) Distribution (all {len(field_on_rewards) + len(field_off_rewards)} seeds)')
ax.legend(fontsize=9)

# 3. Population comparison (training-time)
ax = axes[2]
pop_series = [field_on_populations]
if field_off_populations_train is not None:
    pop_series.append(field_off_populations_train)
# Shared bin edges, so the overlaid histograms compare bar-for-bar (independent
# bins=15 calls would give each condition different edges).
pop_bins = np.histogram_bin_edges(np.concatenate(pop_series).astype(float), bins=15)
ax.hist(field_on_populations, bins=pop_bins, alpha=0.6, color=colors[0],
        label='Field ON', edgecolor='black')
if field_off_populations_train is not None:
    ax.hist(field_off_populations_train, bins=pop_bins, alpha=0.6, color=colors[1],
            label='Field OFF', edgecolor='black')
ax.set_xlabel('Final Population (training)')
ax.set_ylabel('Count')
ax.set_title('(c) Population Distribution')
ax.axvline(x=64, color='red', linestyle='--', alpha=0.5, label='Max capacity')
ax.legend(fontsize=9)

plt.suptitle('Field ON vs Field OFF: 30-Seed Comparison', fontsize=14, y=1.02)
plt.tight_layout()
save_figure(fig, os.path.join(OUTPUT_DIR, 'main_comparison'))
plt.show()

In [None]:
# ========== PERFORMANCE PROFILES ==========
setup_publication_style()
profile_inputs = {
    "Field ON (Stigmergy)": field_on_rewards,
    "Field OFF (No Field)": field_off_rewards,
}
# NOTE(review): tau_range=(0, 1.05) presumes plot_performance_profiles normalises scores
# to roughly [0, 1]; the raw rewards passed here are ~5, so confirm that normalisation.
fig = plot_performance_profiles(
    profile_inputs,
    output_path=os.path.join(OUTPUT_DIR, 'performance_profiles'),
    tau_range=(0, 1.05),
)
plt.show()

In [None]:
# ========== REWARD vs POPULATION ==========
setup_publication_style()
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Field ON
ax = axes[0]
ax.scatter(field_on_populations, field_on_rewards, color='#009988', label='Field ON',
           s=50, alpha=0.7, edgecolor='black', linewidth=0.5)
ax.set_xlabel('Final Population')
ax.set_ylabel('Mean Reward')
ax.set_title('(a) Field ON: Reward vs Population')

# Annotate failed seeds.
# NOTE(review): `idx` is the position in field_on_rewards; it equals the seed id only if
# rewards are stored in seed order -- confirm before quoting these ids.
failed_idx = np.where(field_on_rewards < FAILED_THRESHOLD)[0]
for idx in failed_idx:
    ax.annotate(f'Seed {idx}\n(died)', xy=(field_on_populations[idx], field_on_rewards[idx]),
                fontsize=8, color='red', arrowprops=dict(arrowstyle='->', color='red'),
                xytext=(field_on_populations[idx]+5, field_on_rewards[idx]+1))

# Field OFF
ax = axes[1]
if field_off_populations_train is not None:
    ax.scatter(field_off_populations_train, field_off_rewards, color='#BBBBBB',
               label='Field OFF', s=50, alpha=0.7, edgecolor='black', linewidth=0.5)
    ax.set_xlabel('Final Population')
    ax.set_ylabel('Mean Reward')
    ax.set_title('(b) Field OFF: Reward vs Population')
else:
    ax.text(0.5, 0.5, 'Field OFF populations\navailable after eval (Phase 3)',
            ha='center', va='center', transform=ax.transAxes, fontsize=12)
    ax.set_title('(b) Field OFF: Reward vs Population')

plt.tight_layout()
save_figure(fig, os.path.join(OUTPUT_DIR, 'reward_vs_population'))
plt.show()

from scipy import stats as scipy_stats


def _pearson_report(label, x, y):
    """Print and return Pearson ``(r, p)`` for *x* vs *y*.

    If there are fewer than two points, or either input is constant (e.g. every seed
    ended at the 64-agent cap), the correlation is undefined. In that case an
    explanatory line is printed and ``(nan, nan)`` is returned without calling scipy,
    which would otherwise warn and return NaN, or raise for n < 2.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.size < 2 or np.ptp(x) == 0 or np.ptp(y) == 0:
        print(f"{label}: undefined (n={x.size}, constant input or too few points)")
        return float('nan'), float('nan')
    r_val, p_val = scipy_stats.pearsonr(x, y)
    print(f"{label}: r={r_val:.3f}, p={p_val:.6f}")
    return r_val, p_val


# Correlation for Field ON (excluding failed seeds)
mask = field_on_rewards >= FAILED_THRESHOLD
r, p = _pearson_report("Field ON correlation (pop vs reward, excl. failed)",
                       field_on_populations[mask], field_on_rewards[mask])

if field_off_populations_train is not None:
    r2, p2 = _pearson_report("Field OFF correlation (pop vs reward)",
                             field_off_populations_train, field_off_rewards)

## Phase 3: Checkpoint Analysis (BOTH Conditions)

Load all 60 checkpoints (30 per condition) from Drive, then run eval episodes and compute weight divergence for both Field ON and Field OFF.

In [None]:
# ========== LOAD ALL CHECKPOINTS (BOTH CONDITIONS) ==========
import jax
import jax.numpy as jnp
from src.training.checkpointing import load_checkpoint
from src.agents.network import ActorCritic
from src.analysis.ablation import _run_episode_full

def load_seed_data(ckpt_path):
    """Load one checkpoint and bundle what the eval and divergence cells need.

    Returns a dict with keys:
        'params':       ``ckpt['params']``, passed to the episode runner.
        'agent_params': ``ckpt['agent_params']`` with index 0 taken on every leaf's
                        leading axis. NOTE(review): presumably selects the first entry
                        of a batch/env dimension -- confirm against the checkpoint writer.
        'config':       the stored experiment config object.
        'network':      an ``ActorCritic`` built from ``config.agent.hidden_dims`` with 6 actions.
        'seed_id':      ``ckpt['seed_id']``, or -1 when the checkpoint lacks that key.
    """
    checkpoint = load_checkpoint(ckpt_path)
    cfg = checkpoint['config']
    return dict(
        params=checkpoint['params'],
        agent_params=jax.tree_util.tree_map(lambda leaf: leaf[0], checkpoint['agent_params']),
        config=cfg,
        network=ActorCritic(hidden_dims=tuple(cfg.agent.hidden_dims), num_actions=6),
        seed_id=checkpoint.get('seed_id', -1),
    )

# Verify checkpoint counts: the experiment expects exactly 30 seeds per condition.
ckpt_sets = [("Field ON", field_on_ckpt_paths), ("Field OFF", field_off_ckpt_paths)]
for name, paths in ckpt_sets:
    # Left-align "<name> checkpoints:" in 22 columns so both counts line up.
    print(f"{name + ' checkpoints:':<22} {len(paths)}")
for name, paths in ckpt_sets:
    assert len(paths) == 30, f"Expected 30 {name} checkpoints, got {len(paths)}"

# Load the first checkpoint of each condition and echo its key config values.
for name, paths in ckpt_sets:
    test_data = load_seed_data(paths[0])
    cfg = test_data['config']
    print(f"\n{name} config verification:")
    print(f"  seed={test_data['seed_id']}, grid={cfg.env.grid_size}, max_agents={cfg.evolution.max_agents}")
    print(f"  diffusion={cfg.field.diffusion_rate}, decay={cfg.field.decay_rate}, write={cfg.field.write_strength}")

print("\nAll checkpoints verified. Ready for eval episodes.")

In [None]:
# ========== EVAL EPISODES (BOTH CONDITIONS) ==========
# Run 5 eval episodes per seed to get population dynamics + survival stats
# Uses shared params (not per-agent); "normal" condition (field config differs between conditions)
# Keys depend only on (episode, seed position), so both conditions draw the same key
# sequence -- presumably intended as common random numbers; confirm if it matters.

NUM_EVAL_EPISODES = 5
NEAR_MAX_POPULATION = 60  # mean eval population at/above this counts as "near capacity"
eval_results_on = []
eval_results_off = []

for cond_name, ckpt_paths, results_list in [
    ("Field ON", field_on_ckpt_paths, eval_results_on),
    ("Field OFF", field_off_ckpt_paths, eval_results_off),
]:
    print(f"\n{'='*60}")
    print(f"EVALUATING {cond_name} ({len(ckpt_paths)} seeds x {NUM_EVAL_EPISODES} episodes)")
    print(f"{'='*60}")

    for i, ckpt_path in enumerate(ckpt_paths):
        seed_data = load_seed_data(ckpt_path)
        config = seed_data['config']

        seed_pops = []
        seed_rewards = []
        seed_births = []
        seed_deaths = []

        for ep in range(NUM_EVAL_EPISODES):
            key = jax.random.PRNGKey(ep * 1000 + i)
            stats = _run_episode_full(
                network=seed_data['network'],
                params=seed_data['params'],
                config=config,
                key=key,
                condition="normal",
                evolution=True,
            )
            seed_pops.append(stats.final_population)
            seed_rewards.append(stats.total_reward)
            seed_births.append(stats.total_births)
            seed_deaths.append(stats.total_deaths)

        results_list.append({
            'seed_id': seed_data['seed_id'],
            'ckpt_path': ckpt_path,
            'mean_total_reward': np.mean(seed_rewards),
            'std_total_reward': np.std(seed_rewards),
            'mean_population': np.mean(seed_pops),
            'mean_births': np.mean(seed_births),
            'mean_deaths': np.mean(seed_deaths),
            'survival_rate': np.mean(seed_pops) / config.evolution.max_agents,
            'all_rewards': seed_rewards,
            'all_populations': seed_pops,
        })

        if (i + 1) % 5 == 0 or i == 0:
            print(f"  [{i+1}/{len(ckpt_paths)}] seed {seed_data['seed_id']}: "
                  f"reward={np.mean(seed_rewards):.1f}, pop={np.mean(seed_pops):.1f}, "
                  f"births={np.mean(seed_births):.0f}, deaths={np.mean(seed_deaths):.0f}")

# Extract populations and eval rewards (same order as the checkpoint lists)
field_on_eval_populations = np.array([r['mean_population'] for r in eval_results_on])
field_off_eval_populations = np.array([r['mean_population'] for r in eval_results_off])
field_on_eval_rewards = np.array([r['mean_total_reward'] for r in eval_results_on])
field_off_eval_rewards = np.array([r['mean_total_reward'] for r in eval_results_off])

# Summary. Across-seed spread uses the sample std (ddof=1), matching the training summaries.
# The near-capacity count uses NEAR_MAX_POPULATION, which is NOT the max capacity itself;
# the label says so.
for cond_name, evals, pops, rewards in [
    ("Field ON", eval_results_on, field_on_eval_populations, field_on_eval_rewards),
    ("Field OFF", eval_results_off, field_off_eval_populations, field_off_eval_rewards),
]:
    print(f"\n{'='*60}")
    print(f"{cond_name} EVAL SUMMARY ({len(evals)} seeds x {NUM_EVAL_EPISODES} episodes)")
    print(f"  Mean total reward: {rewards.mean():.1f} +/- {rewards.std(ddof=1):.1f}")
    print(f"  Mean population:   {pops.mean():.1f} +/- {pops.std(ddof=1):.1f}")
    print(f"  Near max (pop >= {NEAR_MAX_POPULATION}): {np.sum(pops >= NEAR_MAX_POPULATION)}/{len(pops)}")
    print(f"  Mean births:       {np.mean([r['mean_births'] for r in evals]):.1f}")
    print(f"  Mean deaths:       {np.mean([r['mean_deaths'] for r in evals]):.1f}")

In [None]:
# ========== WEIGHT DIVERGENCE (BOTH CONDITIONS) ==========
from src.analysis.specialization import compute_weight_divergence

divergence_on = []
divergence_off = []

# One record per checkpoint, in the same order as the checkpoint path lists.
for cond_name, ckpt_paths, div_list in [
    ("Field ON", field_on_ckpt_paths, divergence_on),
    ("Field OFF", field_off_ckpt_paths, divergence_off),
]:
    print(f"\n{'='*60}")
    print(f"WEIGHT DIVERGENCE: {cond_name}")
    print(f"{'='*60}")

    for i, ckpt_path in enumerate(ckpt_paths):
        seed_data = load_seed_data(ckpt_path)
        div = compute_weight_divergence(seed_data['agent_params'])
        div_list.append({
            'seed_id': seed_data['seed_id'],
            'mean_divergence': float(div['mean_divergence']),
            'max_divergence': float(div['max_divergence']),
            'n_agents': len(div['agent_indices']),
        })

        if (i + 1) % 10 == 0 or i == 0:
            print(f"  [{i+1}/{len(ckpt_paths)}] seed {seed_data['seed_id']}: "
                  f"mean_div={div['mean_divergence']:.4f}, max_div={div['max_divergence']:.4f}")

# Extract arrays
on_mean_divs = np.array([r['mean_divergence'] for r in divergence_on])
off_mean_divs = np.array([r['mean_divergence'] for r in divergence_off])
on_max_divs = np.array([r['max_divergence'] for r in divergence_on])
off_max_divs = np.array([r['max_divergence'] for r in divergence_off])

# Summary. Across-seed spread uses the sample std (ddof=1), consistent with the
# training-reward summaries earlier in the notebook.
for name, mean_d, max_d in [("Field ON", on_mean_divs, on_max_divs),
                              ("Field OFF", off_mean_divs, off_max_divs)]:
    print(f"\n{name} DIVERGENCE SUMMARY:")
    print(f"  Mean: {mean_d.mean():.4f} +/- {mean_d.std(ddof=1):.4f}")
    print(f"  Max:  {max_d.mean():.4f} +/- {max_d.std(ddof=1):.4f}")
    print(f"  Range: [{mean_d.min():.4f}, {mean_d.max():.4f}]")

# ========== PHASE 3 PLOTS ==========
setup_publication_style()
colors = ['#009988', '#BBBBBB']


def _shared_bin_edges(arrays, n_bins):
    """Bin edges spanning all *arrays*, so overlaid histograms use identical bins."""
    pooled = np.concatenate([np.asarray(a, dtype=float).ravel() for a in arrays])
    return np.histogram_bin_edges(pooled, bins=n_bins)


fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# (a) Weight divergence comparison
ax = axes[0, 0]
div_bins = _shared_bin_edges([on_mean_divs, off_mean_divs], 12)
ax.hist(on_mean_divs, bins=div_bins, alpha=0.6, color=colors[0], label='Field ON', edgecolor='black')
ax.hist(off_mean_divs, bins=div_bins, alpha=0.6, color=colors[1], label='Field OFF', edgecolor='black')
ax.axvline(on_mean_divs.mean(), color=colors[0], linestyle='--', linewidth=2)
ax.axvline(off_mean_divs.mean(), color=colors[1], linestyle='--', linewidth=2)
ax.set_xlabel('Mean Pairwise Weight Divergence (cosine)')
ax.set_ylabel('Count')
ax.set_title('(a) Weight Divergence Distribution')
ax.legend()

# (b) Eval population comparison
ax = axes[0, 1]
pop_bins = _shared_bin_edges([field_on_eval_populations, field_off_eval_populations], 15)
ax.hist(field_on_eval_populations, bins=pop_bins, alpha=0.6, color=colors[0],
        label='Field ON', edgecolor='black')
ax.hist(field_off_eval_populations, bins=pop_bins, alpha=0.6, color=colors[1],
        label='Field OFF', edgecolor='black')
ax.axvline(x=64, color='red', linestyle='--', alpha=0.5, label='Max capacity')
ax.set_xlabel('Eval Population')
ax.set_ylabel('Count')
ax.set_title('(b) Eval Population Distribution')
ax.legend()

# (c) Eval reward vs population (both conditions)
ax = axes[1, 0]
ax.scatter(field_on_eval_populations, field_on_eval_rewards, color=colors[0],
           s=50, alpha=0.7, edgecolor='black', linewidth=0.5, label='Field ON')
ax.scatter(field_off_eval_populations, field_off_eval_rewards, color=colors[1],
           s=50, alpha=0.7, edgecolor='black', linewidth=0.5, label='Field OFF')
ax.set_xlabel('Eval Population')
ax.set_ylabel('Total Eval Reward')
ax.set_title('(c) Eval: Reward vs Population')
ax.legend()

# (d) Divergence vs training reward.
# NOTE(review): this pairs the i-th checkpoint (directory discovery order) with the i-th
# training reward (training_summary / hardcoded order), so it assumes both enumerate
# seeds in the same order. A length mismatch is reported and that condition skipped,
# instead of crashing inside scatter().
ax = axes[1, 1]
for label, divs, train_rewards, color in [
    ("Field ON", on_mean_divs, field_on_rewards, colors[0]),
    ("Field OFF", off_mean_divs, field_off_rewards, colors[1]),
]:
    if len(divs) != len(train_rewards):
        print(f"WARNING: {label}: {len(divs)} divergences vs {len(train_rewards)} "
              f"training rewards -- omitted from panel (d)")
        continue
    ax.scatter(divs, train_rewards, color=color,
               s=50, alpha=0.7, edgecolor='black', linewidth=0.5, label=label)
ax.set_xlabel('Mean Weight Divergence')
ax.set_ylabel('Training Reward')
ax.set_title('(d) Divergence vs Training Reward')
ax.legend()

plt.suptitle('Phase 3: Checkpoint Analysis (Both Conditions)', fontsize=14, y=1.01)
plt.tight_layout()
save_figure(fig, os.path.join(OUTPUT_DIR, 'checkpoint_analysis'))
plt.show()

# Divergence statistical comparison
div_welch = welch_t_test(on_mean_divs, off_mean_divs)
div_mw = mann_whitney_test(on_mean_divs, off_mean_divs)
print(f"\nDivergence comparison:")
print(f"  Welch t-test: t={div_welch.statistic:.4f}, p={div_welch.p_value:.6f}, d={div_welch.effect_size:.4f}")
print(f"  Mann-Whitney: U={div_mw.statistic:.1f}, p={div_mw.p_value:.6f}, r={div_mw.effect_size:.4f}")

## Phase 4: Report & Save

Generate formatted comparison report and save all results to Drive.

In [None]:
# ========== COMPARISON REPORT ==========
from datetime import datetime

# Significance string for the headline Welch test. Cut-offs are applied from loosest to
# strictest, so each later match overrides the earlier one; a NaN p-value matches none
# and stays "not significant".
sig_str = f"p = {welch.p_value:.4f} (not significant)"
for cutoff, stars in ((0.05, "*"), (0.01, "**")):
    if welch.p_value < cutoff:
        sig_str = f"p = {welch.p_value:.4f} ({stars})"
if welch.p_value < 0.001:
    sig_str = "p < 0.001 (***)"

# Cohen's d magnitude band; |d| not below 0.8 (NaN included) counts as large.
d_val = abs(welch.effect_size)
d_str = next(
    (band for limit, band in ((0.2, "negligible"), (0.5, "small"), (0.8, "medium"))
     if d_val < limit),
    "large",
)

winner = "Field OFF" if field_off_rewards.mean() > field_on_rewards.mean() else "Field ON"

# Coefficient of variation using the sample (ddof=1) standard deviation.
cov_on, cov_off = (
    arr.std(ddof=1) / arr.mean() for arr in (field_on_rewards, field_off_rewards)
)

on_nonfailed = field_on_rewards[field_on_rewards >= FAILED_THRESHOLD]
n_failed = int(np.count_nonzero(field_on_rewards < FAILED_THRESHOLD))

report = f"""# Field ON vs Field OFF: 30-Seed Comparison Report
Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Experiment Setup
- **Conditions**: Field ON (stigmergy) vs Field OFF (no shared field)
- **Seeds per condition**: 30
- **Training steps**: 10M per seed
- **Config**: 64-agent, grid=32, num_food=40, starting_energy=200
- **Field ON**: diffusion=0.1, decay=0.05, write_strength=1.0
- **Field OFF**: diffusion=0.0, decay=1.0, write_strength=0.0

## Key Results

### Training Reward Comparison
| Metric | Field ON | Field OFF |
|--------|----------|-----------|
| Mean | {field_on_rewards.mean():.4f} +/- {field_on_rewards.std(ddof=1):.4f} | {field_off_rewards.mean():.4f} +/- {field_off_rewards.std(ddof=1):.4f} |
| Median | {np.median(field_on_rewards):.4f} | {np.median(field_off_rewards):.4f} |
| IQM | {iqm_on.iqm:.4f} [{iqm_on.ci_lower:.4f}, {iqm_on.ci_upper:.4f}] | {iqm_off.iqm:.4f} [{iqm_off.ci_lower:.4f}, {iqm_off.ci_upper:.4f}] |
| Min | {field_on_rewards.min():.4f} | {field_off_rewards.min():.4f} |
| Max | {field_on_rewards.max():.4f} | {field_off_rewards.max():.4f} |
| CoV | {cov_on:.4f} | {cov_off:.4f} |

### Statistical Tests
- **Welch's t-test**: {sig_str}, Cohen's d = {welch.effect_size:.4f} ({d_str})
- **Mann-Whitney U**: U = {mw.statistic:.1f}, p = {mw.p_value:.6f}, rank-biserial r = {mw.effect_size:.4f}
- **P(Field OFF > Field ON)**: {poi['prob_x_better']:.4f}

### Winner: **{winner}** (by mean reward)

### Sensitivity Analysis (threshold = {FAILED_THRESHOLD})
- Excluding {n_failed} failed Field ON seeds:
  Field ON filtered mean = {on_nonfailed.mean():.4f} (n={len(on_nonfailed)})
  Welch p = {welch_f.p_value:.6f}, Cohen's d = {welch_f.effect_size:.4f}

### Eval Episodes ({NUM_EVAL_EPISODES} episodes/seed, BOTH conditions)
| Metric | Field ON | Field OFF |
|--------|----------|-----------|
| Mean total reward | {field_on_eval_rewards.mean():.1f} +/- {field_on_eval_rewards.std():.1f} | {field_off_eval_rewards.mean():.1f} +/- {field_off_eval_rewards.std():.1f} |
| Mean population | {field_on_eval_populations.mean():.1f} +/- {field_on_eval_populations.std():.1f} | {field_off_eval_populations.mean():.1f} +/- {field_off_eval_populations.std():.1f} |
| At max capacity | {np.sum(field_on_eval_populations >= 60)}/{len(field_on_eval_populations)} | {np.sum(field_off_eval_populations >= 60)}/{len(field_off_eval_populations)} |
| Mean births | {np.mean([r['mean_births'] for r in eval_results_on]):.1f} | {np.mean([r['mean_births'] for r in eval_results_off]):.1f} |
| Mean deaths | {np.mean([r['mean_deaths'] for r in eval_results_on]):.1f} | {np.mean([r['mean_deaths'] for r in eval_results_off]):.1f} |

### Weight Divergence (BOTH conditions)
| Metric | Field ON | Field OFF |
|--------|----------|-----------|
| Mean divergence | {on_mean_divs.mean():.4f} +/- {on_mean_divs.std():.4f} | {off_mean_divs.mean():.4f} +/- {off_mean_divs.std():.4f} |
| Max divergence | {on_max_divs.mean():.4f} +/- {on_max_divs.std():.4f} | {off_max_divs.mean():.4f} +/- {off_max_divs.std():.4f} |
| Divergence Welch p | {div_welch.p_value:.6f} | Cohen's d = {div_welch.effect_size:.4f} |

### Population Dynamics (training-time)
| Metric | Field ON | Field OFF |
|--------|----------|-----------|
| Mean final pop | {field_on_populations.mean():.1f} | {f'{field_off_populations_train.mean():.1f}' if field_off_populations_train is not None else 'N/A (eval only)'} |
| At max (64) | {np.sum(field_on_populations == 64)}/30 | {f'{np.sum(field_off_populations_train == 64)}/30' if field_off_populations_train is not None else 'N/A'} |
| Failed seeds | {n_failed} | 0 |

## Interpretation

**Surprising finding**: Field OFF agents achieve {'higher' if field_off_rewards.mean() > field_on_rewards.mean() else 'lower'} mean reward than Field ON.

Key observations:
1. Field ON has MUCH higher variance (CoV {cov_on:.3f} vs {cov_off:.3f})
2. Field ON has {n_failed} failed seed(s) with near-zero reward
3. Field OFF is remarkably consistent across all 30 seeds
4. When excluding failed seeds, the gap {'narrows' if abs(welch_f.effect_size) < abs(welch.effect_size) else 'remains'}
5. Weight divergence comparison: {'Field ON has higher divergence' if on_mean_divs.mean() > off_mean_divs.mean() else 'Field OFF has higher divergence'} (p={div_welch.p_value:.4f})

Possible explanations:
- The shared field may introduce a coordination overhead that hurts some seeds
- Field ON populations are more variable (some collapse, some max out)
- The field may be a harder optimization landscape requiring more training
- Field OFF is simpler: agents just learn individual foraging without field-reading costs
- Higher weight divergence in Field ON could indicate the field enables specialization

## Next Steps
- Investigate why some Field ON seeds fail (population collapse analysis)
- Try longer training (20M+ steps) to see if Field ON catches up
- Test with diversity_bonus and niche_pressure enabled
- Analyze field channel specialization in successful Field ON seeds
- Compare behavioral clustering between conditions
"""

from IPython.display import Markdown, display
display(Markdown(report))
print("\nReport generated successfully.")

In [None]:
# ========== SAVE RESULTS ==========
# Persist every computed result to OUTPUT_DIR as JSON (portable), pickle (same
# dict, no JSON conversion) and the Markdown report from the previous cell.
from datetime import datetime


def _json_default(obj):
    """Convert NumPy values that the stdlib JSON encoder rejects.

    The per-seed eval/divergence entries are built in earlier cells and may
    hold NumPy scalars (e.g. np.int64, np.float32) or arrays; json cannot
    encode those natively.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


results = {
    'metadata': {
        'generated': datetime.now().isoformat(),
        'field_on_seeds': len(field_on_rewards),
        'field_off_seeds': len(field_off_rewards),
        'steps_per_seed': 10_000_000,
        'eval_episodes_per_seed': NUM_EVAL_EPISODES,
        'data_source': 'Google Drive checkpoints + training_summary.pkl',
    },
    'field_on': {
        'training_rewards': field_on_rewards.tolist(),
        'training_populations': field_on_populations.tolist(),
        'mean_reward': float(field_on_rewards.mean()),
        'std_reward': float(field_on_rewards.std(ddof=1)),
        'iqm': float(iqm_on.iqm),
        'iqm_ci': [float(iqm_on.ci_lower), float(iqm_on.ci_upper)],
        'eval': {
            'populations': field_on_eval_populations.tolist(),
            'total_rewards': field_on_eval_rewards.tolist(),
            # Checkpoint paths are machine-specific; drop them from the export.
            'per_seed': [{k: v for k, v in r.items() if k != 'ckpt_path'}
                         for r in eval_results_on],
        },
        'weight_divergence': {
            'mean_divergences': on_mean_divs.tolist(),
            'max_divergences': on_max_divs.tolist(),
            'per_seed': divergence_on,
        },
    },
    'field_off': {
        'training_rewards': field_off_rewards.tolist(),
        'training_populations': field_off_populations_train.tolist() if field_off_populations_train is not None else None,
        'mean_reward': float(field_off_rewards.mean()),
        'std_reward': float(field_off_rewards.std(ddof=1)),
        'iqm': float(iqm_off.iqm),
        'iqm_ci': [float(iqm_off.ci_lower), float(iqm_off.ci_upper)],
        'eval': {
            'populations': field_off_eval_populations.tolist(),
            'total_rewards': field_off_eval_rewards.tolist(),
            'per_seed': [{k: v for k, v in r.items() if k != 'ckpt_path'}
                         for r in eval_results_off],
        },
        'weight_divergence': {
            'mean_divergences': off_mean_divs.tolist(),
            'max_divergences': off_max_divs.tolist(),
            'per_seed': divergence_off,
        },
    },
    'tests': {
        'welch_t': float(welch.statistic),
        'welch_p': float(welch.p_value),
        'cohens_d': float(welch.effect_size),
        'mann_whitney_u': float(mw.statistic),
        'mann_whitney_p': float(mw.p_value),
        'rank_biserial_r': float(mw.effect_size),
        'prob_off_better': float(poi['prob_x_better']),
        'divergence_welch_p': float(div_welch.p_value),
        'divergence_cohens_d': float(div_welch.effect_size),
    },
}

# Save JSON. Serialize to a string first so an encoding error cannot leave a
# truncated file on Drive (json.dump writes incrementally as it encodes).
json_path = os.path.join(OUTPUT_DIR, 'field_on_vs_off_results.json')
json_text = json.dumps(results, indent=2, default=_json_default)
with open(json_path, 'w', encoding='utf-8') as f:
    f.write(json_text)
print(f"JSON saved: {json_path}")

# Save pickle: the same dict without JSON conversion. Note most arrays above
# were already converted with .tolist(); only the per-seed entries keep
# whatever types the earlier cells produced.
pkl_path = os.path.join(OUTPUT_DIR, 'field_on_vs_off_results.pkl')
with open(pkl_path, 'wb') as f:
    pickle.dump(results, f)
print(f"Pickle saved: {pkl_path}")

# Save report markdown
md_path = os.path.join(OUTPUT_DIR, 'comparison_report.md')
with open(md_path, 'w', encoding='utf-8') as f:
    f.write(report)
print(f"Report saved: {md_path}")

# List all output files
print(f"\n{'='*60}")
print("ALL OUTPUT FILES:")
for fname in sorted(os.listdir(OUTPUT_DIR)):
    fpath = os.path.join(OUTPUT_DIR, fname)
    size_mb = os.path.getsize(fpath) / 1024 / 1024
    print(f"  {fname} ({size_mb:.2f} MB)")
print(f"\nDone! All results saved to {OUTPUT_DIR}")