# Pheromone Hyperparameter Sweep

## Methodology

**One-at-a-time sweep** across 6 parameter groups for the biological pheromone system.
Each parameter group is swept independently while all other parameters are held at their base values.
22 swept values plus 1 field-off baseline give 23 configs; with 3 seeds per config, that is **69 training runs**.

## Sweep Design

| Group | Parameter | Values | Count |
|-------|-----------|--------|-------|
| 1 | Recruitment decay (`channel_decay_rates[0]`) | 0.02, 0.03, 0.05, 0.08 | 4 |
| 2 | Recruitment diffusion (`channel_diffusion_rates[0]`) | 0.2, 0.3, 0.5, 0.7 | 4 |
| 3 | Territory write strength (`field.territory_write_strength`) | 0.005, 0.01, 0.02, 0.05 | 4 |
| 4 | Compass noise (`nest.compass_noise_rate`) | 0.05, 0.10, 0.15, 0.20 | 4 |
| 5 | Scout sip fraction (`nest.food_sip_fraction`) | 0.05, 0.10, 0.15 | 3 |
| 6 | Nest radius (`nest.radius`) | 2, 3, 4 | 3 |
| baseline | Field OFF (no pheromone) | — | 1 |
| **Total** | | | **23 configs × 3 seeds = 69 runs** |
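
The one-at-a-time design can be sketched in plain Python to confirm the config and run counts. This is an illustrative sketch, not the notebook's builder: the `SWEEP_GRID` dict and `one_at_a_time` helper are hypothetical, and the keys are display labels copied from the table above rather than config paths.

```python
# Hypothetical sketch: enumerate a one-at-a-time (OAT) sweep.
# Keys are display labels copied from the sweep table, not config paths.
SWEEP_GRID = {
    "channel_decay_rates[0]": [0.02, 0.03, 0.05, 0.08],
    "channel_diffusion_rates[0]": [0.2, 0.3, 0.5, 0.7],
    "territory_write_strength": [0.005, 0.01, 0.02, 0.05],
    "compass_noise_rate": [0.05, 0.10, 0.15, 0.20],
    "food_sip_fraction": [0.05, 0.10, 0.15],
    "nest_radius": [2, 3, 4],
}
SEEDS_PER_CONFIG = 3


def one_at_a_time(grid):
    """Yield (group, param, value); each config changes exactly one parameter."""
    for group, (param, values) in enumerate(grid.items(), start=1):
        for value in values:
            yield group, param, value


configs = [(0, "field_off", None)]  # the baseline runs first
configs += one_at_a_time(SWEEP_GRID)
num_runs = len(configs) * SEEDS_PER_CONFIG
print(len(configs), num_runs)  # 23 69
```

Every group's value list also contains that parameter's base value (0.05, 0.5, 0.01, 0.10, 0.05, 2), so the unmodified base config is trained six times, each time under different seeds.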

## Base Config

| Parameter | Value | Rationale |
|-----------|-------|-----------|
| `grid_size` | 40 | Large arena for trail formation |
| `num_agents` | 16 | Starting population |
| `num_food` | 25 | More generous (carry-back mechanic makes food harder) |
| `food_energy` | 100 | More generous for pheromone system |
| `max_agents` | 64 | Population cap |
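| `total_steps` | 2M | Per-config training budget; the loop runs 8 iterations × 262,144 steps = 2,097,152 (~2.1M) |
| `nest.radius` | 2 | 5×5 nest area (default) |
| `nest.compass_noise_rate` | 0.10 | Base value for Group 4 |
| `nest.food_sip_fraction` | 0.05 | Base value for Group 5 |
| `field.territory_write_strength` | 0.01 | Base value for Group 3 |
| `channel_diffusion_rates` | (0.5, 0.01, 0.0, 0.0) | Ch 0 (recruitment) spreads wide, ch 1 (territory) stays local |
| `channel_decay_rates` | (0.05, 0.0001, 0.0, 0.0) | Ch 0 (recruitment) fades fast, ch 1 (territory) near-permanent |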
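
To make the decay rates concrete, the sketch below converts a per-step decay rate into a half-life. It assumes decay is multiplicative (each step the channel is scaled by `1 - decay_rate`); the environment's actual update may differ, and `half_life_steps` is a hypothetical helper.

```python
import math


def half_life_steps(decay_rate: float) -> float:
    """Steps until a deposit falls to half its value.

    ASSUMPTION: each step multiplies the channel by (1 - decay_rate).
    """
    if decay_rate <= 0.0:
        return math.inf  # no decay: the deposit never halves
    if decay_rate >= 1.0:
        return 0.0  # full decay: the deposit is gone after one step
    return math.log(0.5) / math.log(1.0 - decay_rate)


for name, rate in [("recruitment", 0.05), ("territory", 0.0001)]:
    print(f"{name}: ~{half_life_steps(rate):.1f} steps")
```

Under that assumption, recruitment pheromone (0.05) halves in about 13.5 steps, while territory (0.0001) takes about 6,900 steps, far longer than a 500-step episode. The Group 1 range spans roughly 34 steps (0.02) down to roughly 8 steps (0.08).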

## Decision Gates

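- **After Group 1** (first 5 configs: baseline + Group 1): if every completed config's mean final eval population is below 5, print a warning and suggest raising `num_food` to 30, raising `food_energy` to 120, or lowering `energy_per_step` to 0.5. The sweep continues either way.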
- **After full sweep**: Auto-select best value per group, combine into best config, run confirmation.
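
The early gate reduces to one comparison: the best config-mean population against a threshold. A minimal pure-function sketch (the name `early_gate` is hypothetical) that mirrors the check in the training loop:

```python
def early_gate(mean_pop_by_config: dict, threshold: float = 5.0) -> tuple:
    """Return (passed, best_mean_pop) for the early decision gate.

    The gate fails only when EVERY completed config's mean final population
    is below `threshold`, i.e. when the best one is below it.
    """
    best = max(mean_pop_by_config.values(), default=0.0)  # 0.0 if nothing completed
    return best >= threshold, best
```

Keeping the check pure makes it easy to unit-test and to re-run against the pickled results after a resume.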

In [None]:
# Cell 1 — Setup
from google.colab import drive
drive.mount('/content/drive')

import os
REPO_DIR = '/content/emergence-lab'
GITHUB_USERNAME = 'imashishkh21'

if not os.path.exists(REPO_DIR):
    !git clone https://github.com/{GITHUB_USERNAME}/emergence-lab.git {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull origin main

os.chdir(REPO_DIR)
!pip install -e ".[dev]" -q
!pip install rliable -q

# Verify pheromone system support
from src.configs import NestConfig, FieldConfig
assert hasattr(NestConfig, 'radius'), "NestConfig missing radius! Pull latest main."
assert hasattr(FieldConfig, 'territory_write_strength'), (
    "FieldConfig missing territory_write_strength! Pull latest main."
)
assert hasattr(FieldConfig, 'channel_decay_rates'), (
    "FieldConfig missing channel_decay_rates! Pull latest main."
)
print(f"NestConfig defaults: radius={NestConfig().radius}, "
      f"sip={NestConfig().food_sip_fraction}, noise={NestConfig().compass_noise_rate}")
print(f"FieldConfig territory_write_strength default: {FieldConfig().territory_write_strength}")

import jax
print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")
print(f"JAX backend: {jax.default_backend()}")

## Sweep Configuration

Each config is built from the base by overriding a single parameter.
The field-off baseline disables the field entirely (diffusion=0, decay=1, write_strength=0).

**NOTE:** `territory_write_strength` must be a configurable field in `FieldConfig`.
If Group 3 fails with `AttributeError`, the pheromone config changes haven't been merged.
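
To see why a decay rate of 1.0 switches the field off regardless of diffusion, here is a toy per-channel update in NumPy. It is a sketch under stated assumptions, not the environment's code: the 4-neighbour averaging kernel, the zero-padded border, and the "diffuse, then decay" ordering are all assumptions, and `diffuse_and_decay` is a hypothetical name.

```python
import numpy as np


def diffuse_and_decay(field, diffusion_rates, decay_rates):
    """One HYPOTHETICAL field update: 4-neighbour diffusion, then decay.

    field: (H, W, C) array; rates are per-channel sequences of length C.
    """
    padded = np.pad(field, ((1, 1), (1, 1), (0, 0)))  # zero-valued border
    neighbour_mean = (
        padded[:-2, 1:-1] + padded[2:, 1:-1]  # up, down
        + padded[1:-1, :-2] + padded[1:-1, 2:]  # left, right
    ) / 4.0
    d = np.asarray(diffusion_rates, dtype=field.dtype)  # broadcasts over C
    k = np.asarray(decay_rates, dtype=field.dtype)
    mixed = (1.0 - d) * field + d * neighbour_mean
    return mixed * (1.0 - k)


field = np.zeros((5, 5, 4))
field[2, 2, :] = 1.0  # one unit deposit in every channel at the centre

base = diffuse_and_decay(field, (0.5, 0.01, 0.0, 0.0), (0.05, 0.0001, 0.0, 0.0))
off = diffuse_and_decay(field, (0.0, 0.0, 0.0, 0.0), (1.0, 1.0, 1.0, 1.0))
print(base[2, 2, 0], base[2, 3, 0])  # recruitment: ~0.475 at the centre, ~0.119 next to it
print(off.max())  # 0.0: every channel is cleared in a single step
```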

In [None]:
# Cell 3 — Sweep Configuration
from dataclasses import dataclass
from src.configs import Config, TrainingMode

# --- Constants ---
NUM_ENVS = 32
NUM_STEPS = 128
MAX_AGENTS = 64
STEPS_PER_ITER = NUM_ENVS * NUM_STEPS * MAX_AGENTS  # 262,144
NUM_ITERATIONS = 8   # ~2.1M steps
TOTAL_STEPS = 2_000_000


def build_config(**overrides) -> Config:
    """Build base pheromone config with optional overrides."""
    cfg = Config()

    # Environment
    cfg.env.grid_size = 40
    cfg.env.num_agents = 16
    cfg.env.num_food = 25
    cfg.env.max_steps = 500
    cfg.env.observation_radius = 5

    # Evolution
    cfg.evolution.enabled = True
    cfg.evolution.food_energy = 100
    cfg.evolution.starting_energy = 200
    cfg.evolution.max_energy = 300
    cfg.evolution.reproduce_threshold = 180
    cfg.evolution.reproduce_cost = 80
    cfg.evolution.energy_per_step = 1
    cfg.evolution.max_agents = MAX_AGENTS
    cfg.evolution.mutation_std = 0.01

    # Training
    cfg.train.training_mode = TrainingMode.GRADIENT
    cfg.train.num_envs = NUM_ENVS
    cfg.train.num_steps = NUM_STEPS
    cfg.train.total_steps = TOTAL_STEPS
    cfg.log.wandb = False
    cfg.log.save_interval = 0

    # Nest
    cfg.nest.radius = 2
    cfg.nest.food_sip_fraction = 0.05
    cfg.nest.compass_noise_rate = 0.10

    # Field: per-channel pheromone settings
    cfg.field.num_channels = 4
    cfg.field.channel_diffusion_rates = (0.5, 0.01, 0.0, 0.0)
    cfg.field.channel_decay_rates = (0.05, 0.0001, 0.0, 0.0)
    cfg.field.field_value_cap = 1.0
    cfg.field.territory_write_strength = 0.01

    # Hidden food: disabled for pheromone sweep (focus on foraging)
    cfg.env.hidden_food.enabled = False

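    # Apply overrides (dotted paths such as 'nest.radius'); fail loudly on typos
    for key, val in overrides.items():
        parts = key.split('.')
        obj = cfg
        for part in parts[:-1]:
            obj = getattr(obj, part)
        if not hasattr(obj, parts[-1]):
            raise AttributeError(f"Unknown config field: {key}")
        setattr(obj, parts[-1], val)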

    return cfg


def build_field_off_config() -> Config:
    """Build field-OFF baseline: same base but field disabled.

    Only per-channel rates are set — they override the scalar defaults in env.py.
    All channels get zero diffusion + instant decay, and territory writes are off.
    """
    cfg = build_config()
    cfg.field.channel_diffusion_rates = (0.0, 0.0, 0.0, 0.0)
    cfg.field.channel_decay_rates = (1.0, 1.0, 1.0, 1.0)
    cfg.field.territory_write_strength = 0.0
    return cfg


# --- Build all 23 sweep configs ---
@dataclass
class SweepEntry:
    name: str
    group: int
    param_name: str
    param_value: object
    config: Config


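# Base per-channel rates, read from build_config() so they can't drift from it
_base_field = build_config().field
BASE_DECAY_RATES = tuple(_base_field.channel_decay_rates)
BASE_DIFFUSION_RATES = tuple(_base_field.channel_diffusion_rates)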

sweep_configs: list[SweepEntry] = []

# Baseline FIRST so early decision gate can compare against it
sweep_configs.append(SweepEntry(
    name='baseline_field_off',
    group=0,
    param_name='field_off',
    param_value=None,
    config=build_field_off_config(),
))

# Group 1: Recruitment decay (channel_decay_rates[0])
for val in [0.02, 0.03, 0.05, 0.08]:
    rates = (val,) + BASE_DECAY_RATES[1:]
    cfg = build_config(**{'field.channel_decay_rates': rates})
    sweep_configs.append(SweepEntry(
        name=f'g1_recruit_decay_{val}',
        group=1,
        param_name='channel_decay_rates[0]',
        param_value=val,
        config=cfg,
    ))

# Group 2: Recruitment diffusion (channel_diffusion_rates[0])
for val in [0.2, 0.3, 0.5, 0.7]:
    rates = (val,) + BASE_DIFFUSION_RATES[1:]
    cfg = build_config(**{'field.channel_diffusion_rates': rates})
    sweep_configs.append(SweepEntry(
        name=f'g2_recruit_diffusion_{val}',
        group=2,
        param_name='channel_diffusion_rates[0]',
        param_value=val,
        config=cfg,
    ))

# Group 3: Territory write strength
for val in [0.005, 0.01, 0.02, 0.05]:
    cfg = build_config(**{'field.territory_write_strength': val})
    sweep_configs.append(SweepEntry(
        name=f'g3_territory_write_{val}',
        group=3,
        param_name='territory_write_strength',
        param_value=val,
        config=cfg,
    ))

# Group 4: Compass noise
for val in [0.05, 0.10, 0.15, 0.20]:
    cfg = build_config(**{'nest.compass_noise_rate': val})
    sweep_configs.append(SweepEntry(
        name=f'g4_compass_noise_{val}',
        group=4,
        param_name='compass_noise_rate',
        param_value=val,
        config=cfg,
    ))

# Group 5: Scout sip fraction
for val in [0.05, 0.10, 0.15]:
    cfg = build_config(**{'nest.food_sip_fraction': val})
    sweep_configs.append(SweepEntry(
        name=f'g5_sip_fraction_{val}',
        group=5,
        param_name='food_sip_fraction',
        param_value=val,
        config=cfg,
    ))

# Group 6: Nest radius
for val in [2, 3, 4]:
    cfg = build_config(**{'nest.radius': val})
    sweep_configs.append(SweepEntry(
        name=f'g6_nest_radius_{val}',
        group=6,
        param_name='nest_radius',
        param_value=val,
        config=cfg,
    ))

assert len(sweep_configs) == 23, f"Expected 23 configs, got {len(sweep_configs)}"

# Print summary
print(f"Total sweep configs: {len(sweep_configs)}")
print(f"Total runs (3 seeds each): {len(sweep_configs) * 3}")
print(f"Steps per config: ~{NUM_ITERATIONS * STEPS_PER_ITER:,}")
print()
print(f"{'#':<4} {'Name':<30} {'Group':>5} {'Param':<30} {'Value'}")
print('-' * 85)
for i, sc in enumerate(sweep_configs):
    print(f"{i:<4} {sc.name:<30} {sc.group:>5} {sc.param_name:<30} {sc.param_value}")

## Training Phase

Train all 23 configs, 3 seeds each. Results are saved to Google Drive after every config.
Resume-safe: re-running the cell skips configs that completed successfully and retries any that failed.

**Config order:** the baseline runs first (index 0), followed by Groups 1–6.

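**Early decision gate** after the first 5 configs (baseline + Group 1):
if every completed config's mean final eval population is below 5, the loop prints a warning with suggested parameter adjustments and keeps going. The gate runs only when config index 4 is trained in the current session, so a resume that skips it also skips the gate.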
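
The resume behaviour rests on two ideas: write the results file atomically, and treat only `success: True` entries as done. A standalone sketch of that pattern follows; the helper names are hypothetical, and the notebook's own `save_results` writes a fixed `.tmp` path instead of using `tempfile`. `os.replace` is atomic on POSIX filesystems, but a synced Google Drive mount may not give the same guarantee.

```python
import os
import pickle
import tempfile


def save_results_atomic(path: str, results: dict) -> None:
    """Pickle `results` to a temp file next to `path`, then rename it over `path`."""
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_path, path)  # readers see the old file or the new one, never half of one
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)  # don't leave stray temp files behind
        raise


def load_results(path: str) -> dict:
    """Return the saved results dict, or an empty dict on a fresh start."""
    if not os.path.exists(path):
        return {}
    with open(path, "rb") as f:
        return pickle.load(f)


def pending_configs(names: list, results: dict) -> list:
    """Configs still to train: never attempted, or attempted and failed."""
    return [n for n in names if not results.get(n, {}).get("success")]
```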

In [None]:
# Cell 5 — Training Loop
import gc
import os
import pickle
import time
import traceback
from datetime import datetime, timedelta

import jax
import jax.numpy as jnp
import numpy as np

from src.training.parallel_train import ParallelTrainer
from src.agents.network import ActorCritic
from src.agents.policy import sample_actions
from src.environment.env import reset, step
from src.environment.obs import get_observations

DRIVE_BASE = '/content/drive/MyDrive/emergence-lab/pheromone_sweep'
RESULTS_PATH = f'{DRIVE_BASE}/sweep_results.pkl'
os.makedirs(DRIVE_BASE, exist_ok=True)

SEEDS_PER_CONFIG = 3


# --- Eval function using jax.lax.scan for speed ---
def run_eval(network, params, config, key, num_steps=500):
    """Run a single eval episode using lax.scan.

    Follows the collect_rollout pattern in src/training/rollout.py:
    scan body closes over network, params, config; carry holds
    (state, key, total_reward); runs for num_steps.
    """
    key, reset_key = jax.random.split(key)
    init_state = reset(reset_key, config)

    def _eval_step(carry, _unused):
        state, rng, total_reward = carry
        obs = get_observations(state, config)          # (max_agents, obs_dim)
        obs_batched = obs[None, :, :]                   # (1, max_agents, obs_dim)
        rng, act_key = jax.random.split(rng)
        actions, _, _, _ = sample_actions(network, params, obs_batched, act_key)
        actions = actions[0]                             # (max_agents,)
        state, rewards, done, info = step(state, actions, config)
        alive = state.agent_alive.astype(jnp.float32)
        total_reward = total_reward + jnp.sum(rewards * alive)
        return (state, rng, total_reward), None

    (final_state, _, total_reward), _ = jax.lax.scan(
        _eval_step, (init_state, key, jnp.float32(0.0)), None, length=num_steps,
    )

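    # Extract metrics AFTER the scan (not inside)
    # Trail strength: mean channel-0 (recruitment) value over cells above 0.01;
    # 0.0 when no cell exceeds the threshold (the denominator is clamped to 1).
    ch0 = jnp.asarray(final_state.field_state.values[:, :, 0])
    nonzero_mask = ch0 > 0.01
    n_cells = jnp.sum(nonzero_mask.astype(jnp.float32))
    trail_strength = jnp.sum(jnp.where(nonzero_mask, ch0, 0.0)) / jnp.maximum(n_cells, 1.0)
    final_pop = jnp.sum(final_state.agent_alive.astype(jnp.int32))
    return {
        'total_reward': float(total_reward),
        'final_population': int(final_pop),
        'trail_strength': float(trail_strength),
        # Final / starting population. Reproduction is on, so this can exceed 1.0
        # (up to max_agents / num_agents = 64 / 16 = 4.0 with the base config).
        'survival_rate': float(final_pop) / config.env.num_agents,
    }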


# --- Load existing results for resume ---
if os.path.exists(RESULTS_PATH):
    with open(RESULTS_PATH, 'rb') as f:
        all_results = pickle.load(f)
    print(f"Resumed: {len(all_results)} configs already completed")
else:
    all_results = {}


def save_results():
    """Save results dict to Drive (atomic write)."""
    tmp_path = RESULTS_PATH + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(all_results, f, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp_path, RESULTS_PATH)


# --- Main training loop ---
sweep_start = time.time()
total_configs = len(sweep_configs)

for i, sc in enumerate(sweep_configs):
    # Skip if already completed
    if sc.name in all_results and all_results[sc.name].get('success'):
        print(f"[{i+1}/{total_configs}] {sc.name} - SKIPPED (already done)")
        continue

    print(f"\n{'='*60}")
    print(f"[{i+1}/{total_configs}] Training: {sc.name} (Group {sc.group})")
    print(f"  {sc.param_name} = {sc.param_value}")
    print(f"{'='*60}")

    seed_ids = [300 + i * 3, 300 + i * 3 + 1, 300 + i * 3 + 2]
    checkpoint_dir = f'{DRIVE_BASE}/{sc.name}'

    try:
        t0 = time.time()
        trainer = ParallelTrainer(
            config=sc.config,
            num_seeds=SEEDS_PER_CONFIG,
            seed_ids=seed_ids,
            checkpoint_dir=checkpoint_dir,
            master_seed=42 + i,
        )
        metrics = trainer.train(
            num_iterations=NUM_ITERATIONS,
            checkpoint_interval_minutes=999,
            resume=False,
            print_interval=2,
        )
        train_time = time.time() - t0

        # --- Eval per seed ---
        num_actions = getattr(sc.config.agent, 'num_actions', 5)
        ps = trainer._parallel_state
        network = ActorCritic(
            hidden_dims=tuple(sc.config.agent.hidden_dims),
            num_actions=num_actions,
        )

        seed_evals = []
        for s in range(SEEDS_PER_CONFIG):
            seed_params = jax.tree.map(lambda x: x[s], ps.params)
            eval_key = jax.random.PRNGKey(1000 + seed_ids[s])
            eval_result = run_eval(network, seed_params, sc.config, eval_key, num_steps=500)
            seed_evals.append(eval_result)

        # Aggregate across seeds
        metric_keys = seed_evals[0].keys()
        agg = {}
        for k in metric_keys:
            vals = [e[k] for e in seed_evals]
            agg[f'{k}_mean'] = float(np.mean(vals))
            agg[f'{k}_std'] = float(np.std(vals))
            agg[f'{k}_values'] = vals

        # Population at the end of training: alive agents averaged over envs, per seed
        alive = np.array(ps.env_state.agent_alive)  # (num_seeds, num_envs, max_agents)
        train_pop_per_seed = alive.sum(axis=(1, 2)) / alive.shape[1]

        all_results[sc.name] = {
            'success': True,
            'group': sc.group,
            'param_name': sc.param_name,
            'param_value': sc.param_value,
            'train_metrics': metrics,
            'train_time': train_time,
            'train_pop_per_seed': [float(p) for p in train_pop_per_seed],
            'eval_per_seed': seed_evals,
            'eval_agg': agg,
        }

        print(f"  Done in {train_time:.0f}s")
        print(f"  Reward: {agg['total_reward_mean']:.1f} +/- {agg['total_reward_std']:.1f}")
        print(f"  Population: {agg['final_population_mean']:.1f} +/- {agg['final_population_std']:.1f}")
        print(f"  Trail strength: {agg['trail_strength_mean']:.4f}")
        print(f"  Survival rate: {agg['survival_rate_mean']:.2f}")

    except Exception as e:
        print(f"  FAILED: {e}")
        traceback.print_exc()
        all_results[sc.name] = {
            'success': False,
            'group': sc.group,
            'param_name': sc.param_name,
            'param_value': sc.param_value,
            'error': str(e),
        }

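    finally:
        # Drop the trainer AND its state (`ps` aliases trainer._parallel_state);
        # deleting only `trainer` would keep the device buffers alive via `ps`
        trainer = ps = None
        gc.collect()
        jax.clear_caches()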

    # Save after every config
    save_results()

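    # Progress report. The ETA uses the per-config train_time stored in the
    # results (eval/save overhead excluded), so it stays sensible after a resume,
    # when `elapsed` only covers the current session.
    completed = sum(1 for r in all_results.values() if r.get('success'))
    elapsed = time.time() - sweep_start
    train_times = [r['train_time'] for r in all_results.values() if r.get('success')]
    if train_times:
        avg_time = float(np.mean(train_times))
        remaining = avg_time * (total_configs - completed)
        eta = datetime.now() + timedelta(seconds=remaining)
        print(f"  Progress: {completed}/{total_configs} | "
              f"Elapsed: {elapsed/60:.0f}min | ETA: {eta.strftime('%H:%M')}")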

    # --- EARLY DECISION GATE after baseline + Group 1 (first 5 configs) ---
    # Config order: [0]=baseline, [1-4]=Group 1 decay values
    if i == 4:
        print("\n" + "="*60)
        print("EARLY DECISION GATE (after baseline + Group 1)")
        print("="*60)
        max_pop = 0
        for name, r in all_results.items():
            if r.get('success'):
                pop = r['eval_agg'].get('final_population_mean', 0)
                max_pop = max(max_pop, pop)
                print(f"  {name}: pop={pop:.1f}, reward={r['eval_agg']['total_reward_mean']:.1f}")
        if max_pop < 5:
            print("\n\u274c WARNING: All runs have population < 5!")
            print("  Consider adjusting base params:")
            print("    - Increase num_food from 25 to 30")
            print("    - Increase food_energy from 100 to 120")
            print("    - Decrease energy_per_step from 1 to 0.5")
            print("  Continuing sweep, but results may be unreliable.")
        else:
            print(f"\n\u2705 PASS: Max population = {max_pop:.1f}. Proceeding.")

total_elapsed = time.time() - sweep_start
completed = sum(1 for r in all_results.values() if r.get('success'))
print(f"\n{'='*60}")
print(f"SWEEP COMPLETE: {completed}/{total_configs} successful in {total_elapsed/60:.0f} min")
print(f"{'='*60}")

## Analysis Phase

Load results, build comparison tables, generate plots, auto-select best config.
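
The tables and heatmap below rely on two small computations: mean ± std across seeds, and per-column min-max normalization. A sketch of both follows (helper names are hypothetical). Note that `np.std` defaults to `ddof=0`, the population std, which is what the training loop stores; with only 3 seeds the sample std (`ddof=1`) is about 1.22× larger (√1.5).

```python
import numpy as np


def aggregate_seeds(seed_evals: list, ddof: int = 0) -> dict:
    """Collapse per-seed eval dicts into {<metric>_mean, <metric>_std}."""
    agg = {}
    for key in seed_evals[0]:
        vals = np.asarray([e[key] for e in seed_evals], dtype=float)
        agg[f"{key}_mean"] = float(vals.mean())
        agg[f"{key}_std"] = float(vals.std(ddof=ddof))  # ddof=0: population std
    return agg


def minmax_columns(data) -> np.ndarray:
    """Scale each column to [0, 1]; constant columns map to 0.5, as in the heatmap."""
    data = np.asarray(data, dtype=float)
    lo, hi = data.min(axis=0), data.max(axis=0)
    span = hi - lo
    scaled = (data - lo) / np.where(span > 0, span, 1.0)  # avoid divide-by-zero
    return np.where(span > 0, scaled, 0.5)
```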

In [None]:
# Cell 7 — Load Results & Master Comparison Table
import pickle
import numpy as np

DRIVE_BASE = '/content/drive/MyDrive/emergence-lab/pheromone_sweep'
RESULTS_PATH = f'{DRIVE_BASE}/sweep_results.pkl'

with open(RESULTS_PATH, 'rb') as f:
    all_results = pickle.load(f)

print(f"Loaded {len(all_results)} configs")

# Build rows sorted by total_reward
rows = []
for name, r in all_results.items():
    if not r.get('success'):
        continue
    agg = r['eval_agg']
    rows.append({
        'name': name,
        'group': r['group'],
        'param_name': r['param_name'],
        'param_value': r['param_value'],
        'reward_mean': agg['total_reward_mean'],
        'reward_std': agg['total_reward_std'],
        'pop_mean': agg['final_population_mean'],
        'pop_std': agg['final_population_std'],
        'trail_mean': agg['trail_strength_mean'],
        'trail_std': agg['trail_strength_std'],
        'survival_mean': agg['survival_rate_mean'],
        'survival_std': agg['survival_rate_std'],
    })

rows.sort(key=lambda r: r['reward_mean'], reverse=True)

# Print formatted table
print(f"\n{'Rank':<5} {'Group':>5} {'Config Name':<30} "
      f"{'Reward':>16} {'Population':>16} {'Trail Str':>14} {'Survival':>12}")
print('-' * 105)
for rank, row in enumerate(rows, 1):
    print(f"{rank:<5} {row['group']:>5} {row['name']:<30} "
          f"{row['reward_mean']:>7.1f}+/-{row['reward_std']:<6.1f} "
          f"{row['pop_mean']:>7.1f}+/-{row['pop_std']:<6.1f} "
          f"{row['trail_mean']:>7.4f}+/-{row['trail_std']:<5.4f} "
          f"{row['survival_mean']:>5.2f}+/-{row['survival_std']:<4.2f}")

# Find baseline for reference
baseline_row = next((r for r in rows if r['name'] == 'baseline_field_off'), None)
if baseline_row:
    print(f"\nBaseline (field OFF): reward={baseline_row['reward_mean']:.1f}, "
          f"pop={baseline_row['pop_mean']:.1f}")

n_failed = sum(1 for r in all_results.values() if not r.get('success'))
if n_failed > 0:
    print(f"\n{n_failed} configs FAILED:")
    for name, r in all_results.items():
        if not r.get('success'):
            print(f"  {name}: {r.get('error', 'unknown')[:100]}")

In [None]:
# Cell 8 — Plots
import matplotlib.pyplot as plt
import os

FIGURE_DIR = f'{DRIVE_BASE}/figures'
os.makedirs(FIGURE_DIR, exist_ok=True)

GROUP_LABELS = {
    1: 'Recruitment Decay',
    2: 'Recruitment Diffusion',
    3: 'Territory Write Strength',
    4: 'Compass Noise',
    5: 'Scout Sip Fraction',
    6: 'Nest Radius',
}

GROUP_COLORS = {
    1: '#EE7733',
    2: '#0077BB',
    3: '#33BBEE',
    4: '#EE3377',
    5: '#009988',
    6: '#CC3311',
}

# Baseline reward line
baseline_reward = baseline_row['reward_mean'] if baseline_row else 0.0

# --- Figure 1: Grouped bar charts, one subplot per parameter group ---
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
axes_flat = axes.flatten()

for idx, group_id in enumerate([1, 2, 3, 4, 5, 6]):
    ax = axes_flat[idx]
    group_rows = [r for r in rows if r['group'] == group_id]
    group_rows.sort(key=lambda r: float(r['param_value']))

    x_labels = [str(r['param_value']) for r in group_rows]
    x_pos = np.arange(len(group_rows))
    means = [r['reward_mean'] for r in group_rows]
    stds = [r['reward_std'] for r in group_rows]

    ax.bar(x_pos, means, yerr=stds, color=GROUP_COLORS[group_id],
           edgecolor='black', linewidth=0.8, capsize=4, alpha=0.85)
    ax.axhline(y=baseline_reward, color='gray', linestyle='--', linewidth=1.5,
               label=f'Field OFF ({baseline_reward:.0f})')
    ax.set_xticks(x_pos)
    ax.set_xticklabels(x_labels)
    ax.set_xlabel(group_rows[0]['param_name'] if group_rows else '')
    ax.set_ylabel('Total Reward (mean +/- std)')
    ax.set_title(f'Group {group_id}: {GROUP_LABELS[group_id]}')
    ax.legend(fontsize=8)

fig.suptitle('Pheromone Sweep: Reward by Parameter Group', fontsize=16, y=1.02)
fig.tight_layout()
fig.savefig(os.path.join(FIGURE_DIR, 'grouped_bars.png'), dpi=300, bbox_inches='tight')
plt.show()

# --- Figure 2: Heatmap of all 23 configs x 4 metrics ---
metric_keys = ['reward_mean', 'pop_mean', 'trail_mean', 'survival_mean']
metric_labels = ['Total Reward', 'Population', 'Trail Strength', 'Survival Rate']

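# Sort rows by group, then numerically by param_value (the baseline's None sorts first)
heatmap_rows = sorted(
    rows,
    key=lambda r: (r['group'], float('-inf') if r['param_value'] is None else float(r['param_value'])),
)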
n_configs = len(heatmap_rows)
n_metrics = len(metric_keys)

heatmap_data = np.zeros((n_configs, n_metrics))
for i, row in enumerate(heatmap_rows):
    for j, mk in enumerate(metric_keys):
        heatmap_data[i, j] = row[mk]

# Normalize per column (min-max)
heatmap_norm = np.zeros_like(heatmap_data)
for j in range(n_metrics):
    col = heatmap_data[:, j]
    cmin, cmax = col.min(), col.max()
    if cmax > cmin:
        heatmap_norm[:, j] = (col - cmin) / (cmax - cmin)
    else:
        heatmap_norm[:, j] = 0.5

fig2, ax2 = plt.subplots(figsize=(10, max(8, n_configs * 0.4)))
im = ax2.imshow(heatmap_norm, aspect='auto', cmap='YlOrRd')
ax2.set_xticks(range(n_metrics))
ax2.set_xticklabels(metric_labels, rotation=30, ha='right')
ax2.set_yticks(range(n_configs))
ax2.set_yticklabels([f"G{r['group']} {r['name']}" for r in heatmap_rows], fontsize=7)
# Annotate cells with raw values
for i in range(n_configs):
    for j in range(n_metrics):
        val = heatmap_data[i, j]
        fmt = '.0f' if j < 2 else '.3f' if j == 2 else '.2f'
        ax2.text(j, i, f"{val:{fmt}}", ha='center', va='center', fontsize=6,
                 color='white' if heatmap_norm[i, j] > 0.6 else 'black')
fig2.colorbar(im, ax=ax2, label='Normalized (per column)', shrink=0.8)
ax2.set_title('Pheromone Sweep: All Configs x Metrics Heatmap')
fig2.tight_layout()
fig2.savefig(os.path.join(FIGURE_DIR, 'heatmap.png'), dpi=300, bbox_inches='tight')
plt.show()

# --- Figure 3: Best-of-each-group summary bar chart ---
best_per_group = {}
for row in rows:
    g = row['group']
    if g == 0:
        continue  # skip baseline
    if g not in best_per_group or row['reward_mean'] > best_per_group[g]['reward_mean']:
        best_per_group[g] = row

fig3, ax3 = plt.subplots(figsize=(10, 5))
labels = []
means = []
stds = []
colors = []
for g in sorted(best_per_group.keys()):
    r = best_per_group[g]
    labels.append(f"G{g}: {r['param_name']}={r['param_value']}")
    means.append(r['reward_mean'])
    stds.append(r['reward_std'])
    colors.append(GROUP_COLORS.get(g, '#999999'))

# Add baseline
if baseline_row:
    labels.append('Baseline (Field OFF)')
    means.append(baseline_row['reward_mean'])
    stds.append(baseline_row['reward_std'])
    colors.append('#BBBBBB')

x_pos = np.arange(len(labels))
ax3.barh(x_pos, means, xerr=stds, color=colors, edgecolor='black', linewidth=0.8, capsize=4)
ax3.set_yticks(x_pos)
ax3.set_yticklabels(labels, fontsize=9)
ax3.set_xlabel('Total Reward (mean +/- std)')
ax3.set_title('Best Value Per Parameter Group vs Baseline')
ax3.invert_yaxis()
fig3.tight_layout()
fig3.savefig(os.path.join(FIGURE_DIR, 'best_per_group.png'), dpi=300, bbox_inches='tight')
plt.show()

print(f"Figures saved to {FIGURE_DIR}/")

In [None]:
# Cell 9 — Auto-select Best Per Group + Save Combined Config
import yaml

print("Best parameter value per group (by mean total_reward):")
print(f"{'Group':>5} {'Parameter':<30} {'Best Value':>12} {'Reward':>12}")
print('-' * 65)

best_values = {}  # group -> (param_name, best_value, reward)
for g in sorted(best_per_group.keys()):
    r = best_per_group[g]
    best_values[g] = (r['param_name'], r['param_value'], r['reward_mean'])
    print(f"{g:>5} {r['param_name']:<30} {str(r['param_value']):>12} {r['reward_mean']:>12.1f}")

# Build combined best config
best_overrides = {}
for g, (param_name, param_value, _) in best_values.items():
    if g == 1:
        base_rates = list(build_config().field.channel_decay_rates)
        base_rates[0] = param_value
        best_overrides['field.channel_decay_rates'] = tuple(base_rates)
    elif g == 2:
        base_rates = list(build_config().field.channel_diffusion_rates)
        base_rates[0] = param_value
        best_overrides['field.channel_diffusion_rates'] = tuple(base_rates)
    elif g == 3:
        best_overrides['field.territory_write_strength'] = param_value
    elif g == 4:
        best_overrides['nest.compass_noise_rate'] = param_value
    elif g == 5:
        best_overrides['nest.food_sip_fraction'] = param_value
    elif g == 6:
        best_overrides['nest.radius'] = param_value

best_config = build_config(**best_overrides)

print("\n" + "="*60)
print("RECOMMENDED COMBINED CONFIG")
print("="*60)
print(f"  channel_decay_rates:     {best_config.field.channel_decay_rates}")
print(f"  channel_diffusion_rates: {best_config.field.channel_diffusion_rates}")
print(f"  territory_write_strength: {best_config.field.territory_write_strength}")
print(f"  compass_noise_rate:       {best_config.nest.compass_noise_rate}")
print(f"  food_sip_fraction:        {best_config.nest.food_sip_fraction}")
print(f"  nest_radius:              {best_config.nest.radius}")

# Save as YAML
yaml_path = '/content/drive/MyDrive/emergence-lab/pheromone_best_config.yaml'
best_config.to_yaml(yaml_path)
print(f"\nSaved best config to {yaml_path}")

# Also print as copy-paste Python for convenience
print("\n# Copy-paste config for full training notebook:")
print(f"cfg.field.channel_decay_rates = {best_config.field.channel_decay_rates}")
print(f"cfg.field.channel_diffusion_rates = {best_config.field.channel_diffusion_rates}")
print(f"cfg.field.territory_write_strength = {best_config.field.territory_write_strength}")
print(f"cfg.nest.compass_noise_rate = {best_config.nest.compass_noise_rate}")
print(f"cfg.nest.food_sip_fraction = {best_config.nest.food_sip_fraction}")
print(f"cfg.nest.radius = {best_config.nest.radius}")

In [None]:
# Cell 10 — Confirmation Run: Best Combined vs Default vs Field-OFF
import gc
import pickle
import time
import jax
import jax.numpy as jnp
import numpy as np
from src.training.parallel_train import ParallelTrainer
from src.agents.network import ActorCritic

CONFIRM_ITERS = NUM_ITERATIONS  # Same as sweep: ~2M steps
CONFIRM_SEEDS = 3

confirm_configs = {
    'best_combined': best_config,
    'default_base': build_config(),   # All defaults
}

confirm_results = {}

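# Reuse field-off baseline from sweep if available. Note: it was trained and
# evaluated with the sweep's seeds (300-302), not the 900-902 used below, so
# field_off is not seed-matched with the other two conditions.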
if 'baseline_field_off' in all_results and all_results['baseline_field_off'].get('success'):
    confirm_results['field_off'] = all_results['baseline_field_off']['eval_agg']
    print("Reusing field-off baseline from sweep.")
else:
    confirm_configs['field_off'] = build_field_off_config()

for cond_name, cfg in confirm_configs.items():
    print(f"\n{'='*60}")
    print(f"Confirmation: {cond_name}")
    print(f"{'='*60}")

    seed_ids = [900, 901, 902]
    checkpoint_dir = f'{DRIVE_BASE}/confirm_{cond_name}'

    try:
        t0 = time.time()
        trainer = ParallelTrainer(
            config=cfg,
            num_seeds=CONFIRM_SEEDS,
            seed_ids=seed_ids,
            checkpoint_dir=checkpoint_dir,
            master_seed=999,
        )
        metrics = trainer.train(
            num_iterations=CONFIRM_ITERS,
            checkpoint_interval_minutes=999,
            resume=False,
            print_interval=2,
        )
        train_time = time.time() - t0

        # Eval
        num_actions = getattr(cfg.agent, 'num_actions', 5)
        ps = trainer._parallel_state
        network = ActorCritic(
            hidden_dims=tuple(cfg.agent.hidden_dims),
            num_actions=num_actions,
        )
        seed_evals = []
        for s in range(CONFIRM_SEEDS):
            seed_params = jax.tree.map(lambda x: x[s], ps.params)
            eval_key = jax.random.PRNGKey(2000 + seed_ids[s])
            eval_result = run_eval(network, seed_params, cfg, eval_key, num_steps=500)
            seed_evals.append(eval_result)

        metric_keys = seed_evals[0].keys()
        agg = {}
        for k in metric_keys:
            vals = [e[k] for e in seed_evals]
            agg[f'{k}_mean'] = float(np.mean(vals))
            agg[f'{k}_std'] = float(np.std(vals))

        confirm_results[cond_name] = agg
        print(f"  Done in {train_time:.0f}s")
        print(f"  Reward: {agg['total_reward_mean']:.1f} +/- {agg['total_reward_std']:.1f}")
        print(f"  Population: {agg['final_population_mean']:.1f}")

    except Exception as e:
        print(f"  FAILED: {e}")
        import traceback; traceback.print_exc()
        confirm_results[cond_name] = {'error': str(e)}

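    finally:
        # Drop the trainer AND its state (`ps` aliases trainer._parallel_state)
        # so device buffers can be freed before the next condition
        trainer = ps = None
        gc.collect()
        jax.clear_caches()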

# --- Comparison Table ---
print("\n" + "="*70)
print("CONFIRMATION COMPARISON")
print("="*70)
print(f"{'Condition':<20} {'Reward':>16} {'Population':>16} {'Trail':>14} {'Survival':>12}")
print('-' * 82)

for cond_name in ['best_combined', 'default_base', 'field_off']:
    agg = confirm_results.get(cond_name, {})
    if 'error' in agg:
        print(f"{cond_name:<20} FAILED")
        continue
    print(f"{cond_name:<20} "
          f"{agg.get('total_reward_mean', 0):>7.1f}+/-{agg.get('total_reward_std', 0):<6.1f} "
          f"{agg.get('final_population_mean', 0):>7.1f}+/-{agg.get('final_population_std', 0):<6.1f} "
          f"{agg.get('trail_strength_mean', 0):>7.4f}+/-{agg.get('trail_strength_std', 0):<5.4f} "
          f"{agg.get('survival_rate_mean', 0):>5.2f}+/-{agg.get('survival_rate_std', 0):<4.2f}")

# Save confirmation results
confirm_path = f'{DRIVE_BASE}/confirmation_results.pkl'
with open(confirm_path, 'wb') as f:
    pickle.dump(confirm_results, f)
print(f"\nConfirmation results saved to {confirm_path}")

## Summary Report

### Winning Configuration

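The best value for each parameter group was selected by the highest mean total reward across 3 seeds (one 500-step eval episode per seed).
Because each group was swept with the other parameters held at base values, the sweep cannot show how the winners interact. The winners were therefore combined into a single "best" config and validated in a confirmation run
against the default base config and the field-off baseline.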
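
The selection-and-merge step can be sketched as a pure function over the table rows. This is an illustrative sketch, not Cell 9 itself: `combine_winners` and `OVERRIDE_BUILDERS` are hypothetical names, the base tuples are copied from the base config table, and the example rows carry made-up rewards.

```python
# Base per-channel tuples, copied from the base config table.
BASE_DECAY = (0.05, 0.0001, 0.0, 0.0)
BASE_DIFFUSION = (0.5, 0.01, 0.0, 0.0)


def _with_channel0(base: tuple, value: float) -> tuple:
    """Replace channel 0 (recruitment) of a per-channel tuple, keeping the rest."""
    return (value,) + tuple(base[1:])


# Sweep group -> function turning the winning value into (dotted_key, override)
OVERRIDE_BUILDERS = {
    1: lambda v: ("field.channel_decay_rates", _with_channel0(BASE_DECAY, v)),
    2: lambda v: ("field.channel_diffusion_rates", _with_channel0(BASE_DIFFUSION, v)),
    3: lambda v: ("field.territory_write_strength", v),
    4: lambda v: ("nest.compass_noise_rate", v),
    5: lambda v: ("nest.food_sip_fraction", v),
    6: lambda v: ("nest.radius", v),
}


def combine_winners(rows: list) -> dict:
    """Pick the highest reward_mean per group (group 0 = baseline is skipped)."""
    best = {}
    for row in rows:
        g = row["group"]
        if g == 0:
            continue
        if g not in best or row["reward_mean"] > best[g]["reward_mean"]:
            best[g] = row
    return dict(OVERRIDE_BUILDERS[g](best[g]["param_value"]) for g in sorted(best))


# Illustrative rows with made-up rewards
example_rows = [
    {"group": 0, "param_value": None, "reward_mean": 50.0},
    {"group": 1, "param_value": 0.03, "reward_mean": 80.0},
    {"group": 1, "param_value": 0.08, "reward_mean": 60.0},
    {"group": 6, "param_value": 3, "reward_mean": 70.0},
]
overrides = combine_winners(example_rows)
print(overrides)
```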

### How to Use

The best config has been saved as YAML to:
```
/content/drive/MyDrive/emergence-lab/pheromone_best_config.yaml
```

Load it in a full training notebook:
```python
from src.configs import Config
config = Config.from_yaml('/content/drive/MyDrive/emergence-lab/pheromone_best_config.yaml')
config.train.total_steps = 10_000_000  # Full 10M step run
```

### Recommended Next Steps

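1. **If pheromone > field-off** (`best_combined` beats `field_off` in the confirmation table by a gap clearly larger than the seed-to-seed std): run the full 10M steps with the best config and 15 seeds.
2. **If pheromone ~ field-off** (difference within seed noise): check `trail_strength`. If trails form but don't improve reward, the observation/action space may need tuning.
3. **If pheromone < field-off**: the pheromone system implementation may need debugging; check that the carry-back mechanics are working.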
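
One crude way to decide which branch applies is to compare the reward gap against the seed-to-seed spread. The rule below is a hypothetical heuristic, not a significance test; with 3 seeds it is only a screen (the notebook already installs `rliable`, which is better suited to interval estimates).

```python
def compare_to_field_off(best_mean: float, best_std: float,
                         off_mean: float, off_std: float, k: float = 1.0) -> str:
    """HYPOTHETICAL screen: 'better'/'worse' only when the reward gap exceeds
    k times the summed seed stds; otherwise 'similar'. Not a significance test."""
    gap = best_mean - off_mean
    margin = k * (best_std + off_std)
    if gap > margin:
        return "better"
    if gap < -margin:
        return "worse"
    return "similar"
```

With `cr = confirm_results`, a call would look like `compare_to_field_off(cr['best_combined']['total_reward_mean'], cr['best_combined']['total_reward_std'], cr['field_off']['total_reward_mean'], cr['field_off']['total_reward_std'])`. Raising `k` makes the screen more conservative.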

### Sweep Results Location

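All results are saved to Google Drive under `/content/drive/MyDrive/emergence-lab/`:
- `pheromone_sweep/sweep_results.pkl`: full results dict
- `pheromone_sweep/confirmation_results.pkl`: confirmation-run aggregates
- `pheromone_sweep/<config_name>/` and `pheromone_sweep/confirm_<condition>/`: per-run trainer checkpoint directories
- `pheromone_sweep/figures/`: publication-quality plots (300 DPI)
- `pheromone_best_config.yaml`: best combined config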