# Experiment 12b: Keystone Edge Protection

**Phase 4 - Testing Targeted Conservation Strategies**

## Background

Experiment 12 found that **82% of edge removals IMPROVE recovery**. This counterintuitive result suggests that most network connections hinder recovery rather than support it. Only **6% of edges are true "keystones"**, meaning their removal significantly harms recovery capacity.
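Here is a minimal sketch of how the 82% / 6% split could be computed from per-edge removal results. It assumes each edge has a measured change in mean recovery fraction when that edge alone is removed, where a positive value means removal helps. It also uses a hypothetical `keystone_threshold`, because the cutoff Experiment 12 actually used for "significantly harms" is not given here. The impact values are toy numbers, not Experiment 12 results.

```python
def classify_edges(removal_impact, keystone_threshold=-0.05):
    """Split edges by how removing each one changes mean recovery fraction.

    removal_impact maps (src, tgt) -> recovery_without_edge - baseline_recovery.
    keystone_threshold is a hypothetical cutoff for "significantly harms recovery".
    """
    n = len(removal_impact)
    improving = [e for e, d in removal_impact.items() if d > 0]
    keystones = [e for e, d in removal_impact.items() if d <= keystone_threshold]
    return {
        "pct_improve": 100.0 * len(improving) / n,
        "pct_keystone": 100.0 * len(keystones) / n,
        # Most harmful removal first
        "keystones": sorted(keystones, key=removal_impact.get),
    }

# Toy impacts (not Experiment 12 results)
impacts = {
    ("cell_0", "cell_1"): -0.08,
    ("cell_2", "cell_3"): -0.06,
    ("cell_4", "cell_5"): 0.02,
    ("cell_6", "cell_7"): 0.01,
    ("cell_8", "cell_9"): -0.01,
}
print(classify_edges(impacts))
```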

## Key Question

**Does protecting ONLY the keystone edges (6%) provide equivalent recovery benefits to protecting the entire network?**

If so, this has major conservation implications: resources can be concentrated on a small subset of critical connections.

## Experimental Design

| Condition | Description | N Edges |
|-----------|-------------|--------|
| `full_network` | All edges intact (baseline) | ~100 |
| `keystone_only` | Only keystone edges (the 6 whose removal most harms recovery in Exp 12) | 6 |
| `random_6` | Random 6 edges preserved | 6 |
| `random_10pct` | Random 10% of edges preserved | ~10 |
| `top_flow_6` | Top 6 edges by flow preserved | 6 |
| `no_edges` | All edges removed (isolated cells) | 0 |

## Protocol

**NO FORCING** - matches Experiment 12 conditions exactly.
- Cascade phase: Lévy α-stable noise (α=1.5, σ=0.06) triggers tipping
- Recovery phase: Gaussian noise (the α=2.0 stable case, σ=0.04) allows potential recovery
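The two-phase protocol can be made concrete with a toy, coupling-free sketch. This is roughly what the `no_edges` condition isolates. It is not `run_two_phase_experiment`, and it rests on several assumptions:

- The cusp drift is `-x**3 + x`. This reads `a=-1.0, b=1.0, c=0.0` in the worker as coefficients of x³, x and 1.
- Noise is additive and scaled by `dt**(1/alpha)`.
- Each cell starts healthy at x = -1, and x > 0 counts as tipped.
- The step `dt` is smaller and the durations are shorter than in `CONFIG`, so the explicit Euler step stays stable.
- A hypothetical `jump_cap` truncates the rare very large Lévy jumps.

Symmetric α-stable draws come from the Chambers-Mallows-Stuck method. At α = 2 they are Gaussian with variance 2, so a scale σ corresponds to a standard deviation of √2·σ.

```python
import numpy as np

def symmetric_stable(alpha, size, rng):
    """Symmetric alpha-stable draws with unit scale (Chambers-Mallows-Stuck, beta=0)."""
    v = rng.uniform(-np.pi / 2, np.pi / 2, size)
    w = rng.standard_exponential(size)
    return (np.sin(alpha * v) / np.cos(v) ** (1.0 / alpha)
            * (np.cos((1.0 - alpha) * v) / w) ** ((1.0 - alpha) / alpha))

def two_phase_isolated(n_cells=20, cascade_steps=1000, recovery_steps=2000, dt=0.05,
                       cascade=(1.5, 0.06), recovery=(2.0, 0.04), jump_cap=1.5, seed=0):
    """Toy uncoupled cells: dx = (-x**3 + x) dt + alpha-stable noise (assumed model)."""
    rng = np.random.default_rng(seed)
    x = -np.ones(n_cells)          # assumed healthy state; x > 0 counts as tipped
    frac_tipped = []               # fraction tipped at the end of each phase
    for steps, (alpha, sigma) in ((cascade_steps, cascade), (recovery_steps, recovery)):
        for _ in range(steps):
            # Assumed self-similar scaling of stable increments over a step of length dt
            noise = sigma * dt ** (1.0 / alpha) * symmetric_stable(alpha, n_cells, rng)
            # Hypothetical cap on rare huge jumps so the explicit Euler update stays bounded
            noise = np.clip(noise, -jump_cap, jump_cap)
            x = x + (-x**3 + x) * dt + noise
        frac_tipped.append(float(np.mean(x > 0)))
    return x, frac_tipped

x_final, (tipped_cascade, tipped_final) = two_phase_isolated()
print(f"Tipped after cascade: {tipped_cascade:.0%}, after recovery: {tipped_final:.0%}")
```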

## Expected Outcomes

- `keystone_only` should have recovery close to `full_network`
- `random_6` and `top_flow_6` should have much lower recovery
- This would validate the keystone identification from Experiment 12

## 1. Setup and Imports

In [None]:
import sys
sys.path.insert(0, '/opt/research-local/src')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
import time
from pathlib import Path
from netCDF4 import Dataset
from dask.distributed import as_completed

from energy_constrained import get_dask_client

print("Imports successful!")
print(f"NumPy version: {np.__version__}")

In [None]:
# Connect to Dask cluster
client = get_dask_client()
print(f"Connected to: {client.scheduler_info()['address']}")
print(f"Workers: {len(client.scheduler_info()['workers'])}")
print(f"Total threads: {sum(w['nthreads'] for w in client.scheduler_info()['workers'].values())}")

## 2. Load Amazon Data

In [None]:
DATA_PATH = Path('/opt/research-local/data/amazon/amazon_adaptation_model/average_network/era5_new_network_data')

def load_amazon_data(year=2003, months=(7, 8, 9)):
    """Load Amazon moisture recycling data and average it over the given months."""
    all_rain = []
    all_evap = []
    all_network = []
    
    for month in months:
        file_path = DATA_PATH / f'1deg_{year}_{month:02d}.nc'
        if file_path.exists():
            with Dataset(file_path, 'r') as ds:
                all_rain.append(ds.variables['rain'][:])
                all_evap.append(ds.variables['evap'][:])
                all_network.append(ds.variables['network'][:])
        else:
            print(f"Warning: {file_path} not found, skipping")
    
    # Fail early rather than averaging an empty list
    if not all_rain:
        raise FileNotFoundError(f"No data files for year {year}, months {list(months)} in {DATA_PATH}")
    
    return {
        'rain': np.mean(all_rain, axis=0),
        'evap': np.mean(all_evap, axis=0),
        'network': np.mean(all_network, axis=0),
        'n_cells': len(all_rain[0])
    }

amazon_data = load_amazon_data(year=2003)
print(f"Loaded Amazon data: {amazon_data['n_cells']} cells")
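By default, `netCDF4` returns masked arrays that mask values equal to a variable's `_FillValue` or `missing_value`. Converting a list of masked arrays to a plain ndarray, as `np.mean(list, axis=0)` does, can drop the mask and let fill values enter the average. This notebook does not check whether the ERA5 files contain masked cells. If they do, mask-aware monthly averaging could look like the sketch below. The helper name `masked_monthly_mean` is hypothetical, and the arrays are toy values.

```python
import numpy as np

def masked_monthly_mean(monthly_arrays):
    """Average monthly fields element-wise, ignoring masked (fill) values."""
    if not monthly_arrays:
        raise ValueError("no monthly fields to average")
    stacked = np.ma.stack([np.ma.asarray(a) for a in monthly_arrays])
    return stacked.mean(axis=0)

# Toy fields: July has a fill value (masked) in the middle cell
july = np.ma.array([2.0, 1e20, 4.0], mask=[False, True, False])
aug = np.ma.array([4.0, 3.0, 6.0], mask=[False, False, False])
print(masked_monthly_mean([july, aug]))
```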

## 3. Experiment Configuration

In [None]:
CONFIG = {
    'n_cells': 50,
    'min_flow': 1.0,
    'barrier_height': 0.2,
    'cascade_duration': 200,
    'recovery_duration': 800,
    'dt': 0.5,
    'cascade_sigma': 0.06,
    'cascade_alpha': 1.5,
    'recovery_sigma': 0.04,
    'recovery_alpha': 2.0,
    'forcing': 0.0,  # NO FORCING (matches Exp 12); informational only, not passed to the solver
    'n_runs': 20,
    'base_seed': 42,
}

# Keystone edges from Experiment 12 (most critical for recovery)
# These are the edges whose removal HURTS recovery the most
KEYSTONE_EDGES = [
    ('cell_28', 'cell_14'),  # Most critical: -7.9% recovery impact
    ('cell_27', 'cell_12'),  # Second: from Exp 12 top 20 list
    ('cell_32', 'cell_40'),
    ('cell_31', 'cell_4'),
    ('cell_17', 'cell_32'),
    ('cell_31', 'cell_24'),
]

KEYSTONE_NODES = ['cell_14', 'cell_31', 'cell_27', 'cell_32', 'cell_28', 'cell_40', 'cell_12']

print("=" * 60)
print("EXPERIMENT 12b: KEYSTONE PROTECTION TEST")
print("=" * 60)
print(f"Runs per condition: {CONFIG['n_runs']}")
print(f"Forcing: {CONFIG['forcing']} (NO FORCING - matches Exp 12)")
print(f"Keystone edges: {len(KEYSTONE_EDGES)}")

## 4. Build Edge List

In [None]:
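def get_full_edge_list(data, config):
    """Return directed edges (src, tgt, flow) among the n_cells highest-flow cells.

    Cells are labelled by their position in top_indices (ascending total flow),
    so 'cell_{n_cells-1}' is the highest-flow cell. The selection is deterministic.
    """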
    network_matrix = data['network']
    n_cells = config['n_cells']
    min_flow = config['min_flow']
    
    total_flow = network_matrix.sum(axis=0) + network_matrix.sum(axis=1)
    top_indices = np.argsort(total_flow)[-n_cells:]
    
    edges = []
    for i, idx_i in enumerate(top_indices):
        for j, idx_j in enumerate(top_indices):
            if i != j:
                flow = network_matrix[idx_i, idx_j]
                if flow > min_flow:
                    edges.append((f'cell_{i}', f'cell_{j}', flow))
    
    return edges, top_indices

full_edges, top_indices = get_full_edge_list(amazon_data, CONFIG)
print(f"Total edges in network: {len(full_edges)}")
print(f"Keystone edges to protect: {len(KEYSTONE_EDGES)}")
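The cell-selection and thresholding step can be checked on a toy 4×4 flow matrix. The sketch below restates the logic of `get_full_edge_list` in standalone form under the hypothetical name `top_cells_and_edges`. As in the worker, rows are sources and columns are targets.

The example also shows that `cell_i` labels are positions in `top_indices`, which is sorted by ascending total flow, rather than original grid indices. So the names in `KEYSTONE_EDGES` refer to the same cells only if Experiment 12 used the same `n_cells` and input data.

```python
import numpy as np

def top_cells_and_edges(network, n_cells, min_flow):
    """Pick the n_cells cells with the largest in+out flow and keep edges above min_flow."""
    total_flow = network.sum(axis=0) + network.sum(axis=1)   # inflow + outflow per cell
    top = np.argsort(total_flow)[-n_cells:]                  # ascending, so highest flow is last
    edges = [(f"cell_{i}", f"cell_{j}", float(network[a, b]))
             for i, a in enumerate(top)
             for j, b in enumerate(top)
             if i != j and network[a, b] > min_flow]
    return top, edges

# Toy flows: toy[src, tgt]
toy = np.array([[0.0, 5.0, 0.0, 0.0],
                [2.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.5],
                [3.0, 0.0, 0.0, 0.0]])
top, edges = top_cells_and_edges(toy, n_cells=3, min_flow=1.0)
print(top)    # → [3 1 0]  (cell_0 is grid index 3, cell_2 is grid index 0)
print(edges)  # → [('cell_0', 'cell_2', 3.0), ('cell_1', 'cell_2', 2.0), ('cell_2', 'cell_1', 5.0)]
```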

## 5. Define Experimental Conditions

In [None]:
all_edge_keys = [(e[0], e[1]) for e in full_edges]
edges_by_flow = sorted(full_edges, key=lambda x: x[2], reverse=True)
top_flow_edges = [(e[0], e[1]) for e in edges_by_flow[:6]]

np.random.seed(CONFIG['base_seed'])
random_6_edges = [all_edge_keys[i] for i in np.random.choice(len(all_edge_keys), 6, replace=False)]
random_10pct_edges = [all_edge_keys[i] for i in np.random.choice(len(all_edge_keys), 
                                                                   max(1, len(all_edge_keys)//10), 
                                                                   replace=False)]

CONDITIONS = {
    'full_network': {'edges': 'all', 'description': 'All edges intact (baseline)'},
    'keystone_only': {'edges': KEYSTONE_EDGES, 'description': f'Only {len(KEYSTONE_EDGES)} keystone edges'},
    'random_6': {'edges': random_6_edges, 'description': 'Random 6 edges'},
    'random_10pct': {'edges': random_10pct_edges, 'description': f'Random {len(random_10pct_edges)} edges (10%)'},
    'top_flow_6': {'edges': top_flow_edges, 'description': 'Top 6 edges by flow'},
    'no_edges': {'edges': 'none', 'description': 'No edges (isolated cells)'},
}

print("EXPERIMENTAL CONDITIONS:")
print("=" * 60)
for name, cond in CONDITIONS.items():
    n_edges = len(full_edges) if cond['edges'] == 'all' else (0 if cond['edges'] == 'none' else len(cond['edges']))
    print(f"  {name:20} : {n_edges:3} edges - {cond['description']}")

print(f"\nTotal simulations: {len(CONDITIONS) * CONFIG['n_runs']}")
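The worker adds a coupling only when a requested pair is also in its own thresholded edge list. A keystone edge whose flow falls at or below `min_flow` (for example, with different data or a different `n_cells`) would therefore be dropped silently, and `keystone_only` would run with fewer than 6 edges. Here is a hypothetical pre-flight check, shown on toy edges. Note that edges are directed.

```python
def check_protected_edges(protected, available):
    """Return the protected (src, tgt) pairs missing from the available edges.

    `available` may hold (src, tgt) or (src, tgt, flow) tuples.
    """
    available_set = {tuple(e[:2]) for e in available}
    return [tuple(e) for e in protected if tuple(e) not in available_set]

# Toy edges with flows; direction matters
available = [("cell_0", "cell_1", 2.5), ("cell_2", "cell_3", 1.4)]
protected = [("cell_0", "cell_1"), ("cell_3", "cell_2")]
print(check_protected_edges(protected, available))  # → [('cell_3', 'cell_2')]
```

In this notebook the call would be `check_protected_edges(KEYSTONE_EDGES, full_edges)`. An empty list means all six keystone edges are present.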

## 6. Worker Function (Standalone)

In [None]:
def run_protection_experiment(data_bytes, condition_name, edges_to_keep, config, seed):
    """
    Worker function using the actual run_two_phase_experiment from solvers.
    This ensures we match Experiment 12's methodology exactly.
    """
    import numpy as np
    import pickle
    import sys
    
    if '/opt/research-local/src' not in sys.path:
        sys.path.insert(0, '/opt/research-local/src')
    
    from energy_constrained import (
        EnergyConstrainedNetwork,
        EnergyConstrainedCusp,
        GradientDrivenCoupling,
    )
    from energy_constrained.solvers import run_two_phase_experiment
    
    # Deserialize data
    data = pickle.loads(data_bytes)
    np.random.seed(seed)
    
    # Extract parameters
    network_matrix = data['network']
    n_cells = config['n_cells']
    min_flow = config['min_flow']
    barrier_height = config['barrier_height']
    
    # Select top cells by total flow
    total_flow = network_matrix.sum(axis=0) + network_matrix.sum(axis=1)
    top_indices = np.argsort(total_flow)[-n_cells:]
    
    # Build network using EnergyConstrainedNetwork
    net = EnergyConstrainedNetwork()
    
    # Add elements
    for i in range(n_cells):
        element = EnergyConstrainedCusp(
            a=-1.0, b=1.0, c=0.0, x_0=0.0,
            barrier_height=barrier_height,
            dissipation_rate=0.1
        )
        net.add_element(f'cell_{i}', element)
    
    # Build all possible edges
    all_edges = {}
    for i, idx_i in enumerate(top_indices):
        for j, idx_j in enumerate(top_indices):
            if i != j:
                flow = network_matrix[idx_i, idx_j]
                if flow > min_flow:
                    all_edges[(f'cell_{i}', f'cell_{j}')] = flow
    
    # Determine which edges to keep
    if edges_to_keep == 'all':
        edges_set = set(all_edges.keys())
    elif edges_to_keep == 'none':
        edges_set = set()
    else:
        edges_set = set(tuple(e) for e in edges_to_keep)
    
    # Add couplings for kept edges
    n_edges_added = 0
    for (src, tgt), flow in all_edges.items():
        if (src, tgt) in edges_set:
            coupling = GradientDrivenCoupling(
                conductivity=flow / 100.0,
                state_coupling=0.1
            )
            net.add_coupling(src, tgt, coupling)
            n_edges_added += 1
    
    # Run two-phase experiment using the SAME function as Experiment 12
    result = run_two_phase_experiment(
        network=net,
        cascade_duration=config['cascade_duration'],
        recovery_duration=config['recovery_duration'],
        dt=config['dt'],
        cascade_sigma=config['cascade_sigma'],
        cascade_alpha=config['cascade_alpha'],
        recovery_sigma=config['recovery_sigma'],
        recovery_alpha=config['recovery_alpha'],
        seed=seed
    )
    
    # Extract metrics (same as Exp 12)
    n_cells_actual = result.x_full.shape[1]
    n_tip_events = 0
    n_recover_events = 0
    
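    # Count zero crossings per cell. Convention used here: x > 0 is the tipped state,
    # so a sign increase (- to +) is a tip event and a sign decrease (+ to -) a recovery.
    for j in range(n_cells_actual):
        x_traj = result.x_full[:, j]
        signs = np.sign(x_traj)
        sign_changes = np.diff(signs)
        n_tip_events += np.sum(sign_changes > 0)
        n_recover_events += np.sum(sign_changes < 0)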
    
    tip_recovery_ratio = n_tip_events / n_recover_events if n_recover_events > 0 else np.nan
    
    return {
        'condition': condition_name,
        'n_edges': n_edges_added,
        'seed': seed,
        'pct_tipped_cascade': result.metrics['pct_tipped_at_cascade_end'],
        'final_pct_tipped': result.metrics['final_pct_tipped'],
        'recovery_fraction': result.metrics['recovery_fraction'],
        'n_tip_events': n_tip_events,
        'n_recover_events': n_recover_events,
        'tip_recovery_ratio': tip_recovery_ratio,
        'n_permanent_tips': result.metrics['n_permanent_tips'],
    }

print("Worker function defined (using run_two_phase_experiment).")
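The per-cell loop in the worker counts zero crossings of each trajectory. Below is a vectorized equivalent under the hypothetical name `count_crossings`. It assumes `x_full` has shape `(n_steps, n_cells)`, as the worker's `result.x_full[:, j]` indexing implies.

One caveat applies to both versions. A step that lands exactly on 0 splits one crossing into two sign changes, so that crossing is counted twice.

```python
import numpy as np

def count_crossings(x_full):
    """Count upward (tip) and downward (recovery) zero crossings over all cells.

    x_full: array of shape (n_steps, n_cells); x > 0 is treated as tipped.
    """
    steps = np.diff(np.sign(x_full), axis=0)   # sign change between consecutive steps
    n_tip = int(np.sum(steps > 0))
    n_recover = int(np.sum(steps < 0))
    ratio = n_tip / n_recover if n_recover > 0 else float("nan")
    return n_tip, n_recover, ratio

x_full = np.array([[-1.0, -1.0],
                   [ 0.5, -0.8],
                   [ 0.7,  0.2],
                   [-0.3,  0.4]])
print(count_crossings(x_full))  # → (2, 1, 2.0)
```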

## 7. Run Experiment

In [None]:
data_bytes = pickle.dumps(amazon_data)
print(f"Data serialized: {len(data_bytes) / 1024:.1f} KB")

print("\n" + "=" * 60)
print("EXPERIMENT 12b: Starting Keystone Protection Test")
print("=" * 60)
start_time = time.time()

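# Scatter the serialized data once instead of shipping it with every task
data_future = client.scatter(data_bytes, broadcast=True)

futures = []
for cond_idx, (condition_name, condition) in enumerate(CONDITIONS.items()):
    edges = condition['edges']
    for run_idx in range(CONFIG['n_runs']):
        # Deterministic seed; built-in hash() of a str is salted per session, so it is avoided
        seed = CONFIG['base_seed'] + 10000 * cond_idx + run_idx
        future = client.submit(
            run_protection_experiment,
            data_future, condition_name, edges, CONFIG, seed
        )
        futures.append(future)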

print(f"Submitted {len(futures)} tasks")

all_results = []
print("\nProgress:")
for i, future in enumerate(as_completed(futures)):
    result = future.result()
    all_results.append(result)
    
    if (i + 1) % 20 == 0 or (i + 1) == len(futures):
        elapsed = time.time() - start_time
        print(f"  Completed {i+1}/{len(futures)} ({100*(i+1)/len(futures):.1f}%) - {elapsed:.0f}s elapsed")

elapsed = time.time() - start_time
print("\n" + "=" * 60)
print(f"COMPLETE: {len(all_results)} simulations in {elapsed:.1f}s ({elapsed/60:.1f} min)")
print("=" * 60)
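Python salts `hash()` of strings per interpreter session (see `PYTHONHASHSEED`), so a seed built from `hash(condition_name)` changes after every kernel restart. The sketch below derives a session-independent seed from the condition name and run index, using `zlib.crc32` and NumPy's `SeedSequence`. The helper `run_seed` is hypothetical.

```python
import zlib

import numpy as np

def run_seed(base_seed, condition_name, run_idx):
    """Seed that depends only on (base_seed, condition_name, run_idx), not on the session."""
    name_key = zlib.crc32(condition_name.encode("utf-8"))    # stable 32-bit name digest
    state = np.random.SeedSequence([base_seed, name_key, run_idx]).generate_state(1)
    return int(state[0])

print(run_seed(42, "keystone_only", 0))
```

Unlike an offset based on a condition's position in `CONDITIONS`, this seed does not change if conditions are added or reordered.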

## 8. Results Analysis

In [None]:
df = pd.DataFrame(all_results)
print(f"Results shape: {df.shape}")
df.head(10)

In [None]:
summary = df.groupby('condition').agg({
    'n_edges': 'first',
    'recovery_fraction': ['mean', 'std'],
    'pct_tipped_cascade': ['mean', 'std'],
    'final_pct_tipped': ['mean', 'std'],
    'tip_recovery_ratio': ['mean', 'std'],
}).round(4)

summary.columns = ['_'.join(col) if col[1] else col[0] for col in summary.columns]
summary = summary.reset_index()

condition_order = ['full_network', 'keystone_only', 'random_6', 'random_10pct', 'top_flow_6', 'no_edges']
summary['order'] = summary['condition'].map({c: i for i, c in enumerate(condition_order)})
summary = summary.sort_values('order').drop('order', axis=1)

print("SUMMARY BY CONDITION:")
print("=" * 80)
print(summary.to_string(index=False))
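With 20 runs per condition, mean ± std alone does not show how confidently `keystone_only` differs from `random_6`. The sketch below computes a percentile-bootstrap confidence interval for the difference in mean recovery fraction. Each condition uses its own seeds, so the two samples are resampled independently. `bootstrap_diff_ci` is a hypothetical helper, and the inputs are toy values.

```python
import numpy as np

def bootstrap_diff_ci(a, b, n_boot=5000, ci=95.0, seed=0):
    """Percentile bootstrap CI for mean(a) - mean(b), resampling each group independently."""
    rng = np.random.default_rng(seed)
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    # Row r of each resample matrix is one bootstrap sample; boot_diffs[r] is its mean difference
    boot_diffs = (rng.choice(a, size=(n_boot, a.size)).mean(axis=1)
                  - rng.choice(b, size=(n_boot, b.size)).mean(axis=1))
    tail = (100.0 - ci) / 2.0
    lo, hi = np.percentile(boot_diffs, [tail, 100.0 - tail])
    return a.mean() - b.mean(), lo, hi

diff, lo, hi = bootstrap_diff_ci([0.30, 0.35, 0.40], [0.10, 0.12, 0.08])
print(f"difference = {diff:.3f}, 95% CI [{lo:.3f}, {hi:.3f}]")
```

Applied to the results, the inputs would be `df.loc[df['condition'] == 'keystone_only', 'recovery_fraction'].dropna()` and the same selection for `'random_6'`. An interval that excludes 0 suggests a difference beyond run-to-run noise.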

## 9. Visualization

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

colors = {
    'full_network': 'green',
    'keystone_only': 'blue',
    'random_6': 'orange',
    'random_10pct': 'purple',
    'top_flow_6': 'red',
    'no_edges': 'gray',
}

# Panel 1: Recovery Fraction
ax = axes[0, 0]
x_pos = range(len(condition_order))
means = [summary[summary['condition'] == c]['recovery_fraction_mean'].values[0] for c in condition_order]
stds = [summary[summary['condition'] == c]['recovery_fraction_std'].values[0] for c in condition_order]
bar_colors = [colors[c] for c in condition_order]

ax.bar(x_pos, means, yerr=stds, capsize=5, color=bar_colors, edgecolor='black', alpha=0.8)
ax.set_xticks(x_pos)
ax.set_xticklabels([c.replace('_', '\n') for c in condition_order], fontsize=10)
ax.set_ylabel('Recovery Fraction', fontsize=12)
ax.set_title('Recovery by Protection Strategy', fontsize=14)
ax.axhline(means[0], color='green', linestyle='--', alpha=0.5, label='Full network baseline')
ax.grid(True, alpha=0.3, axis='y')
ax.legend()

# Panel 2: Box plot
ax = axes[0, 1]
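box_data = [df[df['condition'] == c]['recovery_fraction'].dropna().values for c in condition_order]
# Set tick labels separately: boxplot's `labels` kwarg was renamed `tick_labels` in Matplotlib 3.9
bp = ax.boxplot(box_data, patch_artist=True)
ax.set_xticks(range(1, len(condition_order) + 1))
ax.set_xticklabels([c.replace('_', '\n') for c in condition_order], fontsize=10)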
for patch, color in zip(bp['boxes'], bar_colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)
ax.set_ylabel('Recovery Fraction', fontsize=12)
ax.set_title('Recovery Distribution by Strategy', fontsize=14)
ax.grid(True, alpha=0.3, axis='y')

# Panel 3: Cascade vs Final
ax = axes[1, 0]
cascade_means = [summary[summary['condition'] == c]['pct_tipped_cascade_mean'].values[0] for c in condition_order]
final_means = [summary[summary['condition'] == c]['final_pct_tipped_mean'].values[0] for c in condition_order]

x = np.arange(len(condition_order))
width = 0.35
ax.bar(x - width/2, cascade_means, width, label='End of Cascade', color='tomato', edgecolor='black')
ax.bar(x + width/2, final_means, width, label='Final', color='steelblue', edgecolor='black')
ax.set_xticks(x)
ax.set_xticklabels([c.replace('_', '\n') for c in condition_order], fontsize=10)
ax.set_ylabel('% Cells Tipped', fontsize=12)
ax.set_title('Cascade Impact and Final State', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

# Panel 4: Tip/Recovery Ratio
ax = axes[1, 1]
ratio_means = [summary[summary['condition'] == c]['tip_recovery_ratio_mean'].values[0] for c in condition_order]
ratio_stds = [summary[summary['condition'] == c]['tip_recovery_ratio_std'].values[0] for c in condition_order]

ax.bar(x_pos, ratio_means, yerr=ratio_stds, capsize=5, color=bar_colors, edgecolor='black', alpha=0.8)
ax.axhline(1.0, color='red', linestyle='--', label='Symmetric (ratio=1)')
ax.set_xticks(x_pos)
ax.set_xticklabels([c.replace('_', '\n') for c in condition_order], fontsize=10)
ax.set_ylabel('Tip/Recovery Event Ratio', fontsize=12)
ax.set_title('Tipping Asymmetry by Strategy', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
Path('/workspace/data').mkdir(parents=True, exist_ok=True)  # make sure the output directory exists
plt.savefig('/workspace/data/exp12b_keystone_protection.png', dpi=150, bbox_inches='tight')
plt.show()
print("\nPlot saved to /workspace/data/exp12b_keystone_protection.png")

## 10. Key Findings

In [None]:
print("\n" + "=" * 70)
print("EXPERIMENT 12b: KEY FINDINGS")
print("=" * 70)

full_recovery = summary[summary['condition'] == 'full_network']['recovery_fraction_mean'].values[0]
keystone_recovery = summary[summary['condition'] == 'keystone_only']['recovery_fraction_mean'].values[0]
random_recovery = summary[summary['condition'] == 'random_6']['recovery_fraction_mean'].values[0]
topflow_recovery = summary[summary['condition'] == 'top_flow_6']['recovery_fraction_mean'].values[0]
no_edge_recovery = summary[summary['condition'] == 'no_edges']['recovery_fraction_mean'].values[0]

keystone_efficiency = keystone_recovery / full_recovery * 100 if full_recovery > 0 else 0
random_efficiency = random_recovery / full_recovery * 100 if full_recovery > 0 else 0
topflow_efficiency = topflow_recovery / full_recovery * 100 if full_recovery > 0 else 0

# Keystone edges as a percentage of all edges in the network
keystone_edge_pct = 100 * len(KEYSTONE_EDGES) / len(full_edges)
# Share of full recovery per share of edges kept (1.0 = recovery proportional to edge count)
per_edge_leverage = keystone_efficiency / keystone_edge_pct
# Direct comparison with the same-size random control
keystone_vs_random = keystone_recovery / random_recovery if random_recovery > 0 else float('nan')

print(f"""
1. RECOVERY COMPARISON:
   Full network (baseline):  {full_recovery:.1%}
   Keystone only (6 edges):  {keystone_recovery:.1%}  ({keystone_efficiency:.0f}% of full)
   Random 6 edges:           {random_recovery:.1%}  ({random_efficiency:.0f}% of full)
   Top flow 6 edges:         {topflow_recovery:.1%}  ({topflow_efficiency:.0f}% of full)
   No edges:                 {no_edge_recovery:.1%}

2. KEYSTONE EFFECTIVENESS:
   Keystone edges achieve {keystone_efficiency:.0f}% of full network recovery
   with only {len(KEYSTONE_EDGES)} edges ({keystone_edge_pct:.1f}% of total)
   
   Per-edge leverage: {per_edge_leverage:.1f}x what recovery proportional to edge count would give
   Versus random 6 edges: {keystone_vs_random:.1f}x the recovery

3. CONSERVATION IMPLICATION:
   {'Keystone edges retain more than half of full-network recovery.' if keystone_efficiency > 50 else 'Keystone edges retain half or less of full-network recovery.'}
   Focus protection on: {', '.join(KEYSTONE_NODES[:5])}
""")

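Two baselines are easy to conflate here:

- Dividing the keystone share of full-network recovery by the keystone share of edges compares against recovery that scales linearly with edge count.
- Dividing keystone recovery by `random_6` recovery compares against the same-size random control.

The sketch below computes both from summary means. `protection_metrics` is a hypothetical helper, and the inputs are toy values.

```python
def protection_metrics(full, protected, random_same_size, n_protected, n_total):
    """Compare a protected-edge condition against two different baselines."""
    nan = float("nan")
    pct_of_full = 100.0 * protected / full if full > 0 else nan
    pct_edges = 100.0 * n_protected / n_total
    return {
        "pct_of_full": pct_of_full,
        "pct_edges": pct_edges,
        # 1.0 means recovery is exactly proportional to the share of edges kept
        "per_edge_leverage": pct_of_full / pct_edges,
        # Ratio to the same-size random control
        "vs_random": protected / random_same_size if random_same_size > 0 else nan,
    }

m = protection_metrics(full=0.40, protected=0.30, random_same_size=0.10,
                       n_protected=6, n_total=100)
print(m)
```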
## 11. Save Results

In [None]:
df.to_csv('/workspace/data/experiment12b_keystone_protection_full.csv', index=False)
print(f"Full results saved to /workspace/data/experiment12b_keystone_protection_full.csv")

summary.to_csv('/workspace/data/experiment12b_keystone_protection_summary.csv', index=False)
print(f"Summary saved to /workspace/data/experiment12b_keystone_protection_summary.csv")

print("\n" + "=" * 60)
print("EXPERIMENT 12b COMPLETE")
print("=" * 60)