# Overnight Comparison: Field ON + Patch Scaling vs Field OFF Baseline

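Runs both conditions automatically:
1. Field ON + Patch Scaling (3 seeds, 10M steps)
2. Field OFF baseline, patch scaling disabled (3 seeds, 10M steps)

Then prints a comparison table (mean +/- std of the final per-seed reward and population, plus wall-clock time).

Note: the two conditions differ in *both* the field and patch scaling, so any difference between them cannot be attributed to patch scaling alone.

**Just run all cells and check the results in the morning.** Checkpoints are written to Google Drive, and training resumes from them if the runtime is interrupted.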


In [None]:
# Mount Google Drive so that checkpoints written under /content/drive
# (see CHECKPOINT_BASE in the next cell) land on persistent storage.
from google.colab import drive
drive.mount('/content/drive')

import os

REPO_DIR = '/content/emergence-lab'
GITHUB_USERNAME = 'imashishkh21'

# Clone the repo on first run; on later runs just pull the latest `main`.
# NOTE: `!` lines are IPython shell escapes, so this cell only runs inside
# Jupyter/Colab, not as a plain Python script.
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/{GITHUB_USERNAME}/emergence-lab.git {REPO_DIR}
else:
    !cd {REPO_DIR} && git pull origin main

# Work from the repo root (later cells import `src.*`) and install the
# package in editable mode together with its `dev` extras.
os.chdir(REPO_DIR)
!pip install -e ".[dev]" -q

import jax
# Sanity check: show which devices JAX will run on before starting long jobs.
print(f"JAX devices: {jax.devices()}")


In [None]:
from src.configs import Config

# Budget shared by every condition: NUM_SEEDS seeds, each trained for
# TOTAL_STEPS steps, with checkpoints under CHECKPOINT_BASE/<condition_name>.
TOTAL_STEPS: int = 10 * 1_000_000  # 10M steps
NUM_SEEDS: int = 3
CHECKPOINT_BASE: str = "/content/drive/MyDrive/emergence-lab/overnight_comparison"


def build_config(patch_scaling_enabled: bool, field_enabled: bool = True) -> Config:
    """Create the Config used by one experimental condition.

    All conditions share the same environment, evolution, training, logging
    and patch settings; only the two flags below vary.

    Args:
        patch_scaling_enabled: Written to ``config.nest.patch_scaling_enabled``.
        field_enabled: When False, every field channel gets decay rate 1.0 and
            diffusion rate 0.0 (the notebook's way of switching the field
            off). When True, channel 1 alone gets decay rate 1.0.

    Returns:
        A freshly constructed ``Config`` with the overrides applied.
    """
    cfg = Config()
    env, evo, train, log, nest, field_cfg = (
        cfg.env, cfg.evolution, cfg.train, cfg.log, cfg.nest, cfg.field,
    )

    # Environment.
    env.grid_size = 32
    env.num_agents = 16
    env.num_food = 40

    # Evolution.
    evo.enabled = True
    evo.max_agents = 64
    evo.starting_energy = 200
    evo.food_energy = 100
    evo.reproduce_threshold = 120
    evo.reproduce_cost = 40
    evo.mutation_std = 0.01

    # Training / logging: no wandb, no periodic saves from the config side.
    train.total_steps = TOTAL_STEPS
    train.num_envs = 32
    train.num_steps = 128
    log.wandb = False
    log.save_interval = 0

    # Patch scaling (radius and cap are shared; only the switch varies).
    nest.patch_scaling_enabled = patch_scaling_enabled
    nest.patch_radius = 2
    nest.patch_n_cap = 6

    # Channel 1 decays fully (rate 1.0) in every condition.
    field_cfg.channel_decay_rates = (0.05, 1.0, 0.0, 0.0)
    if field_enabled:
        return cfg

    # "Field OFF": every channel decays fully and nothing diffuses.
    field_cfg.channel_decay_rates = (1.0, 1.0, 1.0, 1.0)
    field_cfg.channel_diffusion_rates = (0.0, 0.0, 0.0, 0.0)
    return cfg


# Define conditions to run
# Conditions to run, in order, as (name, Config) pairs. The two differ in
# BOTH patch scaling and the field, so a gap between them is not attributable
# to patch scaling alone.
CONDITIONS = [
    (label, build_config(patch_scaling_enabled=scaling_on, field_enabled=field_on))
    for label, scaling_on, field_on in (
        ("scaling_ON", True, True),
        ("field_OFF", False, False),
    )
]

print(f"Will run {len(CONDITIONS)} conditions x {NUM_SEEDS} seeds x {TOTAL_STEPS:,} steps")
print(f"Checkpoints base: {CHECKPOINT_BASE}")


In [None]:
import gc
import time
import numpy as np

from src.training.parallel_train import ParallelTrainer

import traceback

import jax

# condition name -> {'metrics', 'time', 'success': True} on success, or
# {'success': False, 'error', 'traceback'} on failure. Read by the
# comparison cell.
all_results = {}

for condition_name, config in CONDITIONS:
    print(f"\n{'='*60}")
    print(f"RUNNING: {condition_name}")
    print(f"{'='*60}")

    checkpoint_dir = f"{CHECKPOINT_BASE}/{condition_name}"
    seed_ids = list(range(NUM_SEEDS))

    # NOTE(review): assumes one trainer iteration consumes
    # num_envs * num_steps * max_agents steps (every agent slot counted);
    # confirm against ParallelTrainer before trusting the step budget.
    steps_per_iter = config.train.num_envs * config.train.num_steps * config.evolution.max_agents
    # Use this condition's own budget, not the global constant. Never schedule
    # zero iterations: floor division could otherwise yield 0, and the run
    # would silently train nothing.
    num_iterations = max(1, config.train.total_steps // steps_per_iter)

    print(f"Checkpoint dir: {checkpoint_dir}")
    print(f"Seeds: {seed_ids}")
    print(f"Steps/iter: {steps_per_iter}")
    print(f"Iterations: {num_iterations}")

    trainer = None
    t0 = time.time()
    try:
        trainer = ParallelTrainer(
            config=config,
            num_seeds=NUM_SEEDS,
            seed_ids=seed_ids,
            checkpoint_dir=checkpoint_dir,
            master_seed=42,
        )

        metrics = trainer.train(
            num_iterations=num_iterations,
            checkpoint_interval_minutes=30,
            resume=True,  # Resume if exists
            print_interval=50,
        )
    except Exception as e:
        # Carry on with the remaining conditions, but keep the full traceback;
        # str(e) alone is rarely enough to debug an overnight failure.
        tb = traceback.format_exc()
        print(f"FAILED: {e}")
        print(tb)
        all_results[condition_name] = {'success': False, 'error': str(e), 'traceback': tb}
    else:
        elapsed = time.time() - t0

        all_results[condition_name] = {
            'metrics': metrics,
            'time': elapsed,
            'success': True,
        }

        print(f"\n{condition_name} completed in {elapsed/60:.1f} minutes")
        # Summary printing must not be able to turn a finished run into a
        # failure, so only call .get when the result is mapping-like.
        get_metric = getattr(metrics, 'get', None)
        if get_metric is None:
            print(f"Final metrics: {metrics!r}")
        else:
            print(f"Final reward (per-seed): {get_metric('mean_reward', 'N/A')}")
            print(f"Final population (per-seed): {get_metric('population_size', 'N/A')}")
    finally:
        # Drop the trainer reference so it can be collected before the next
        # condition builds its own.
        trainer = None
        gc.collect()
        if hasattr(jax, 'clear_caches'):
            try:
                jax.clear_caches()
            except Exception as e:
                # Best-effort cleanup: report it, but don't abort remaining runs.
                print(f"Warning: jax.clear_caches() failed: {e}")

print("\n" + "="*60)
print("ALL CONDITIONS COMPLETE")
print("="*60)


In [None]:
import numpy as np

def _summarize_metric(metrics, key):
    """Return ``(mean, std)`` of ``metrics[key]``, or None if unavailable.

    ``metrics`` is whatever the trainer returned. Only its ``.get`` method is
    used. A missing key, a non-mapping ``metrics``, non-numeric values or an
    empty array all yield None, so the table shows N/A instead of a
    fabricated 0.0.
    """
    getter = getattr(metrics, 'get', None)
    if getter is None:
        return None
    values = getter(key)
    if values is None:
        return None
    try:
        arr = np.asarray(values, dtype=float)
    except (TypeError, ValueError):
        return None
    if arr.size == 0:
        return None
    return float(np.mean(arr)), float(np.std(arr))


print("\n" + "="*60)
print("RESULTS COMPARISON")
print("="*60)

# Column widths match the row cells below:
# "{:>8.4f} +/- {:<8.4f}" is 21 chars and "{:>8.2f} +/- {:<6.2f}" is 19 chars.
header = f"{'Condition':<20} {'Reward':>21} {'Population':>19} {'Time(min)':>10}"
print("\n" + header)
print("-" * len(header))

for name, result in all_results.items():
    if not result.get('success'):
        print(f"{name:<20} FAILED: {str(result.get('error', 'unknown'))[:50]}")
        continue

    metrics = result.get('metrics')
    reward = _summarize_metric(metrics, 'mean_reward')
    pop = _summarize_metric(metrics, 'population_size')
    elapsed_min = float(result.get('time', 0.0)) / 60.0

    if reward is not None:
        reward_cell = f"{reward[0]:>8.4f} +/- {reward[1]:<8.4f}"
    else:
        reward_cell = f"{'N/A':>21}"
    if pop is not None:
        pop_cell = f"{pop[0]:>8.2f} +/- {pop[1]:<6.2f}"
    else:
        pop_cell = f"{'N/A':>19}"

    print(f"{name:<20} {reward_cell} {pop_cell} {elapsed_min:>10.1f}")

# Relative reward difference of scaling_ON against the field_OFF baseline.
on_result = all_results.get('scaling_ON', {})
off_result = all_results.get('field_OFF', {})
if on_result.get('success') and off_result.get('success'):
    on_stats = _summarize_metric(on_result.get('metrics'), 'mean_reward')
    off_stats = _summarize_metric(off_result.get('metrics'), 'mean_reward')

    if on_stats is None or off_stats is None:
        print("\nCannot compare: 'mean_reward' is missing for at least one condition.")
    elif off_stats[0] == 0:
        print("\nCannot compute relative difference: field_OFF mean reward is 0.")
    else:
        improvement = (on_stats[0] - off_stats[0]) / abs(off_stats[0]) * 100.0
        print("\n" + "=" * len(header))
        print(f"SCALING ON vs FIELD OFF: {improvement:+.1f}% reward difference")
        print("=" * len(header))

print("\nCheckpoints saved to:", CHECKPOINT_BASE)
