# Spike Propagation Analysis: 2D Sweep (K, rate_hz)

## Goal

Study the neighbor-activation probability and the firing rate as a function of:
- **K (recurrent coupling)**: scaling factor for the synaptic weights
- **rate_hz (external input)**: rate of the thalamic stimulus

## Strategy

1. **2D sweep**: simulate every (K, rate_hz) combination → obtain (FR, P, σ)
2. **FR matching**: for each (K≠0, FR_target), find the K=0 run with FR≈FR_target
3. **Network contribution**: ΔP = P_coupled - P_baseline (same FR, different origin of the activity)
4. **Visualization**: heatmaps, 1D slices, analysis of ΔP(K, FR)
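Steps 2 and 3 can be illustrated on toy numbers (not sweep results). The sketch below interpolates the K=0 curve P(FR) at each coupled firing rate with `np.interp` and returns ΔP. Unlike the `interp1d(..., fill_value='extrapolate')` call used later in the notebook, it returns NaN for firing rates outside the baseline range instead of extrapolating. The function name `delta_p_by_fr_matching` is hypothetical.

```python
import numpy as np

def delta_p_by_fr_matching(fr_base, p_base, fr_coupled, p_coupled):
    """ΔP = P_coupled - P_baseline(FR), with P_baseline interpolated on the K=0 curve.

    Coupled points whose FR falls outside the baseline FR range get NaN
    (no extrapolation).
    """
    fr_base = np.asarray(fr_base, dtype=float)
    p_base = np.asarray(p_base, dtype=float)
    order = np.argsort(fr_base)              # np.interp needs increasing x
    fr_sorted, p_sorted = fr_base[order], p_base[order]

    fr_coupled = np.atleast_1d(np.asarray(fr_coupled, dtype=float))
    p_coupled = np.atleast_1d(np.asarray(p_coupled, dtype=float))
    p_baseline = np.interp(fr_coupled, fr_sorted, p_sorted)
    outside = (fr_coupled < fr_sorted[0]) | (fr_coupled > fr_sorted[-1])
    p_baseline[outside] = np.nan
    return p_coupled - p_baseline

# Toy example: baseline P grows with FR; coupled points at FR = 3, 5 and 10 Hz.
# The 10 Hz point lies outside the baseline range, so its ΔP is NaN.
delta_p = delta_p_by_fr_matching(
    fr_base=[2.0, 4.0, 6.0], p_base=[0.01, 0.02, 0.03],
    fr_coupled=[3.0, 5.0, 10.0], p_coupled=[0.05, 0.06, 0.20],
)
print(delta_p)
```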

## Hypotheses

- FR ≈ a·rate_hz (approximately linear relationship)
- K=0 defines the spurious activity (baseline)
- ΔP(K>0) isolates the contribution of the recurrent network dynamics
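The K=0 baseline is needed because the response window also counts coincidences. As a rough reference, assuming a neighbor fires as an independent Poisson process at rate r (an idealization, not a property of the Izhikevich model), the probability that it fires at least once in a window of length w is P_chance = 1 - exp(-r·w) ≈ r·w for small r·w. Multiplying by the out-degree gives the expected coincidental σ. The out-degree of 80 below is an assumption: p_intra × Ne = 0.1 × 800, if E→E connections are drawn independently with probability p_intra.

```python
import numpy as np

def chance_response_probability(rate_hz, window_ms):
    """P(an independent Poisson neuron at rate_hz fires at least once in window_ms)."""
    return 1.0 - np.exp(-np.asarray(rate_hz, dtype=float) * window_ms / 1000.0)

def chance_sigma(rate_hz, window_ms, n_neighbors):
    """Expected number of neighbours 'activated' purely by coincidence per spike."""
    return n_neighbors * chance_response_probability(rate_hz, window_ms)

# Assumed out-degree: p_intra * Ne = 0.1 * 800 = 80 (hypothetical, see lead-in)
for fr in (2.0, 5.0, 10.0, 20.0):
    p = chance_response_probability(fr, window_ms=5.0)
    s = chance_sigma(fr, window_ms=5.0, n_neighbors=80)
    print(f"FR={fr:5.1f} Hz  P_chance={p:.4f}  sigma_chance={s:.2f}")
```

Under these assumptions, at 5 Hz with w = 5 ms, P_chance ≈ 0.025, i.e. about 2 of 80 neighbors would "respond" by chance alone. This is one reason the σ = 1 reference line in the plots is best read against the K=0 baseline rather than in absolute terms.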

---

## 1. Setup and Configuration

In [None]:
# =============================================================================
# IMPORTS
# =============================================================================
import os
import sys
from pathlib import Path

# Move to the project root so that `src` is importable
if Path.cwd().name == 'notebooks':
    os.chdir('..')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from brian2 import start_scope  # only start_scope is used directly in this notebook
from datetime import datetime
from collections import defaultdict
import pickle
from tqdm.auto import tqdm
import pandas as pd
from scipy.interpolate import interp1d
from scipy.ndimage import gaussian_filter

# Project imports
from src.two_populations.model import IzhikevichNetwork
from src.two_populations.metrics import analyze_simulation_results
from src.two_populations.helpers.logger import setup_logger

# Configure logger
logger = setup_logger(
    experiment_name="spike_propagation_2d",
    console_level="INFO",
    file_level="DEBUG",
    log_to_file=False
)

logger.info(f"Working directory: {Path.cwd()}")
logger.info(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

# Plot style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

## 2. Sweep Parameters

In [None]:
# =============================================================================
# SWEEP CONFIGURATION
# =============================================================================

# Network size
Ne = 800
Ni = 200

# Simulation parameters (times in ms)
SIM_CONFIG = {
    'dt_ms': 0.1,
    'T_ms': 3000,
    'warmup_ms': 500
}

# Fixed network parameters
NETWORK_PARAMS = {
    'Ne': Ne,
    'Ni': Ni,
    'noise_exc': 0.884,
    'noise_inh': 0.60,
    'p_intra': 0.1,
    'delay': 0.0,
    'stim_start_ms': None,
    'stim_duration_ms': SIM_CONFIG['T_ms'],
    'stim_base': 1.0,
    'stim_elevated': None
}

# Parameter grid to sweep (note: the spacing is non-uniform)
K_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 7.0, 10.0])
RATE_HZ_VALUES = np.array([2, 3, 4, 5, 6, 7, 8, 10, 12, 15, 20])

# Propagation-analysis parameters
PROPAGATION_CONFIG = {
    'window_ms': 5.0,         # Time window for detecting responses (ms)
    'min_weight': 0.0,        # Minimum synaptic weight for a connection to count
    'min_spikes': 20,         # Minimum number of spikes for a presynaptic neuron to be included
}

# Seeds
FIXED_SEED = 100
VARIABLE_SEED = 200

# Output directory
OUTPUT_DIR = Path('results/spike_propagation_2d')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

logger.info("Sweep configured:")
logger.info(f"  K values: {K_VALUES}")
logger.info(f"  rate_hz values: {RATE_HZ_VALUES}")
logger.info(f"  Total combinations: {len(K_VALUES) * len(RATE_HZ_VALUES)}")
logger.info(f"  Simulation: {SIM_CONFIG['T_ms']}ms @ dt={SIM_CONFIG['dt_ms']}ms")
logger.info(f"  Network: {Ne}E + {Ni}I, p_intra={NETWORK_PARAMS['p_intra']}")

## 3. Propagation Analysis Class

In [None]:
# =============================================================================
# PROPAGATION ANALYZER
# =============================================================================

class PropagationAnalyzer:
    """
    Analyze forward E→E propagation:
    when neuron i fires, how many of its neighbors j respond within a time window?
    
    Metrics:
        - P_transmission: fraction of the E→E neighbors that respond, averaged over presynaptic spikes
        - σ (sigma): branching ratio = <n_activated> (mean number of responding neighbors per spike)
        - firing_rate: mean firing rate over active neurons (Hz)
    """
    
    def __init__(self, window_ms=5.0, min_weight=0.0, min_spikes=20):
        self.window = window_ms
        self.min_weight = min_weight
        self.min_spikes = min_spikes
        
    def extract_connectivity_E2E(self, synapses_intra, Ne, verbose=False):
        """
        Extract the E→E connectivity graph from a Brian2 Synapses object.
        
        Returns:
            neighbors: dict {pre_idx: [post_idx_1, post_idx_2, ...]}
            weights: dict {(pre, post): weight}
        """
        neighbors = defaultdict(list)
        weights = {}
        
        pre_indices = np.array(synapses_intra.i)
        post_indices = np.array(synapses_intra.j)
        syn_weights = np.array(synapses_intra.w)
        
        # Keep E→E synapses (indices < Ne are assumed excitatory) with w >= min_weight
        E2E_mask = (pre_indices < Ne) & (post_indices < Ne) & (syn_weights >= 0.0)
        mask = E2E_mask & (syn_weights >= self.min_weight)
        
        for pre, post, w in zip(pre_indices[mask], post_indices[mask], syn_weights[mask]):
            neighbors[int(pre)].append(int(post))
            weights[(int(pre), int(post))] = float(w)
        
        if verbose:
            degrees = [len(v) for v in neighbors.values()]
            logger.debug(f"  E→E connections: {np.sum(mask)} (w>={self.min_weight})")
            if degrees:  # np.max raises on an empty list
                logger.debug(f"  Out-degree: mean={np.mean(degrees):.1f}, max={np.max(degrees)}")
        
        return dict(neighbors), weights
    
    def organize_spike_times(self, spike_times_arr, spike_indices_arr):
        """
        Group spike times by neuron.
        
        Returns:
            spike_dict: {neuron_idx: sorted_spike_times_array}
        """
        spike_dict = defaultdict(list)
        
        for t, idx in zip(spike_times_arr, spike_indices_arr):
            spike_dict[int(idx)].append(float(t))
        
        spike_dict = {k: np.sort(v) for k, v in spike_dict.items()}
        return dict(spike_dict)
    
    def count_responses_single_spike(self, pre_spike_time, post_neuron_spikes):
        """
        Check whether the postsynaptic neuron fired in the open window (t, t + window).
        """
        if len(post_neuron_spikes) == 0:
            return False
        
        responses = post_neuron_spikes[
            (post_neuron_spikes > pre_spike_time) & 
            (post_neuron_spikes < pre_spike_time + self.window)
        ]
        return len(responses) > 0
    
    def analyze(self, spike_dict, neighbors, T_total, warmup=0.0):
        """
        Main propagation analysis.
        
        Args:
            spike_dict: {neuron_idx: spike_times (ms)}
            neighbors: {pre_idx: [post_idx_list]}
            T_total: total duration (ms)
            warmup: initial period to exclude (ms)
            
        Returns:
            dict with metrics: P_transmission, sigma, firing_rate, stats
        """
        # Drop spikes that fall in the warmup period
        spike_dict_filtered = {
            nid: times[times >= warmup] 
            for nid, times in spike_dict.items()
        }
        
        T_analysis = T_total - warmup
        
        ratios_per_spike = []
        activated_counts = []
        per_neuron_stats = {}
        
        total_spikes_analyzed = 0
        neurons_analyzed = 0
        
        for pre_idx in neighbors.keys():
            if pre_idx not in spike_dict_filtered:
                continue
            
            pre_spikes = spike_dict_filtered[pre_idx]
            
            if len(pre_spikes) < self.min_spikes:
                continue
            
            post_neighbors = neighbors[pre_idx]
            n_neighbors = len(post_neighbors)
            
            if n_neighbors == 0:
                continue
            
            neuron_ratios = []
            neuron_activated = []
            
            for spike_time in pre_spikes:
                n_activated = 0
                
                for post_idx in post_neighbors:
                    if post_idx not in spike_dict_filtered:
                        continue
                    
                    post_spikes = spike_dict_filtered[post_idx]
                    
                    if self.count_responses_single_spike(spike_time, post_spikes):
                        n_activated += 1
                
                ratio = n_activated / n_neighbors
                
                ratios_per_spike.append(ratio)
                activated_counts.append(n_activated)
                neuron_ratios.append(ratio)
                neuron_activated.append(n_activated)
                
                total_spikes_analyzed += 1
            
            per_neuron_stats[pre_idx] = {
                'n_spikes': len(pre_spikes),
                'n_neighbors': n_neighbors,
                'mean_ratio': np.mean(neuron_ratios),
                'mean_activated': np.mean(neuron_activated)
            }
            neurons_analyzed += 1
        
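        # Mean firing rate over the neurons present in spike_dict, i.e. those that fired
        # at least once during the whole run (neurons that never fired are not counted)
        total_spikes = sum(len(times) for times in spike_dict_filtered.values())
        n_neurons = len(spike_dict_filtered)
        firing_rate = (total_spikes / n_neurons / T_analysis) * 1000.0 if n_neurons > 0 else 0.0  # Hz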
        
        ratios_per_spike = np.array(ratios_per_spike)
        activated_counts = np.array(activated_counts)
        
        results = {
            'P_transmission': np.mean(ratios_per_spike) if len(ratios_per_spike) > 0 else 0.0,
            'P_transmission_std': np.std(ratios_per_spike) if len(ratios_per_spike) > 0 else 0.0,
            'sigma': np.mean(activated_counts) if len(activated_counts) > 0 else 0.0,
            'sigma_std': np.std(activated_counts) if len(activated_counts) > 0 else 0.0,
            'firing_rate': firing_rate,
            'ratio_distribution': ratios_per_spike,
            'activated_counts': activated_counts,
            'per_neuron': per_neuron_stats,
            'stats': {
                'n_neurons_analyzed': neurons_analyzed,
                'total_spikes_analyzed': total_spikes_analyzed,
                'total_spikes': total_spikes,
                'n_neurons_active': n_neurons,
                'T_analysis': T_analysis
            }
        }
        
        return results

logger.success("PropagationAnalyzer class defined")
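The inner loop of `analyze` calls `count_responses_single_spike` once per (spike, neighbor) pair and scans the whole neighbor spike train each time. Because `organize_spike_times` returns sorted arrays, the same test can be run for all presynaptic spikes at once with `np.searchsorted`. The sketch below (hypothetical helpers `responded` and `activated_counts`) keeps the strict window (t, t + window) used above.

```python
import numpy as np

def responded(pre_spike_times, post_spike_times, window_ms):
    """For each presynaptic spike time t, True if the postsynaptic neuron fires
    at least once in the open window (t, t + window_ms).

    post_spike_times must be sorted in ascending order.
    """
    pre = np.asarray(pre_spike_times, dtype=float)
    post = np.asarray(post_spike_times, dtype=float)
    # number of post spikes <= t  -> index of the first post spike strictly after t
    first_after = np.searchsorted(post, pre, side='right')
    # number of post spikes < t + window -> upper bound is excluded
    first_outside = np.searchsorted(post, pre + window_ms, side='left')
    return first_outside > first_after

def activated_counts(pre_spike_times, neighbor_spike_trains, window_ms):
    """n_activated per presynaptic spike: how many neighbours respond to it."""
    pre = np.asarray(pre_spike_times, dtype=float)
    counts = np.zeros(pre.shape[0], dtype=int)
    for post in neighbor_spike_trains:
        counts += responded(pre, post, window_ms)   # bool array added as 0/1
    return counts
```

Dividing these counts by the number of neighbors gives the per-spike ratios that `analyze` averages into `P_transmission`; neighbors that never fire simply contribute zeros.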

## 4. Parameterized Simulation Function

In [None]:
# =============================================================================
# SIMULATION RUNNER
# =============================================================================

def run_single_simulation(k_factor, rate_hz, trial=0, verbose=False):
    """
    Run one simulation with parameters (k_factor, rate_hz).
    
    Args:
        k_factor: recurrent coupling factor
        rate_hz: external stimulus rate (Hz)
        trial: trial index (used for the seeds)
        verbose: if True, log details
        
    Returns:
        dict with:
            - network: IzhikevichNetwork object
            - results: simulation results
            - spike_dict: {neuron_idx: spike_times}
            - neighbors: E→E connectivity graph
            - weights: synaptic weights
    """
    start_scope()
    
    # Build the network
    network = IzhikevichNetwork(
        dt_val=SIM_CONFIG['dt_ms'],
        T_total=SIM_CONFIG['T_ms'],
        fixed_seed=FIXED_SEED,
        variable_seed=VARIABLE_SEED,
        trial=trial
    )
    
    # Population parameters (inhibitory coupling scaled as 3.9 × k_factor)
    params = {
        **NETWORK_PARAMS,
        'k_exc': k_factor,
        'k_inh': k_factor * 3.9,
        'rate_hz': rate_hz
    }
    
    # Create population A
    pop_A = network.create_population2(name='A', **params)
    
    # Set up monitors (voltages are not recorded, to save memory)
    network.setup_monitors(['A'], record_v_dt=None, sample_fraction=0.0)
    
    # Run the simulation
    results = network.run_simulation()
    
    # Extract E→E connectivity
    analyzer = PropagationAnalyzer(
        window_ms=PROPAGATION_CONFIG['window_ms'],
        min_weight=PROPAGATION_CONFIG['min_weight'],
        min_spikes=PROPAGATION_CONFIG['min_spikes']
    )
    
    neighbors, weights = analyzer.extract_connectivity_E2E(
        network.populations['A']['syn_intra'],
        Ne,
        verbose=verbose
    )
    
    # Group spikes by neuron
    spike_dict = analyzer.organize_spike_times(
        results['A']['spike_times'],
        results['A']['spike_indices']
    )
    
    if verbose:
        n_spikes = len(results['A']['spike_times'])
        logger.info(f"  Simulation completed: {n_spikes} spikes")
    
    return {
        'network': network,
        'results': results,
        'spike_dict': spike_dict,
        'neighbors': neighbors,
        'weights': weights
    }

logger.success("Simulation runner function defined")

## 5. 2D Sweep: (K, rate_hz) → (FR, P, σ)

In [None]:
# =============================================================================
# 2D SWEEP
# =============================================================================

def run_2d_sweep(K_values, rate_hz_values, save_results=True):
    """
    Run the full 2D sweep.
    
    Args:
        K_values: array of K values
        rate_hz_values: array of rate_hz values
        save_results: if True, save the results to a pickle file
        
    Returns:
        pandas.DataFrame with one row per (K, rate_hz) simulation
        (metrics are NaN for runs that raised an exception)
    """
    results_list = []
    total_sims = len(K_values) * len(rate_hz_values)
    
    logger.info(f"Starting 2D sweep: {total_sims} simulations")
    logger.info(f"K: {K_values}")
    logger.info(f"rate_hz: {rate_hz_values}")
    
    # Create the analyzer
    analyzer = PropagationAnalyzer(
        window_ms=PROPAGATION_CONFIG['window_ms'],
        min_weight=PROPAGATION_CONFIG['min_weight'],
        min_spikes=PROPAGATION_CONFIG['min_spikes']
    )
    
    # Progress bar
    with tqdm(total=total_sims, desc="2D Sweep") as pbar:
        for k_val in K_values:
            for rate_val in rate_hz_values:
                try:
                    # Simulate
                    sim_data = run_single_simulation(
                        k_factor=k_val,
                        rate_hz=rate_val,
                        trial=0,
                        verbose=False
                    )
                    
                    # Analyze propagation
                    prop_results = analyzer.analyze(
                        spike_dict=sim_data['spike_dict'],
                        neighbors=sim_data['neighbors'],
                        T_total=SIM_CONFIG['T_ms'],
                        warmup=SIM_CONFIG['warmup_ms']
                    )
                    
                    # Store a condensed result entry
                    result_entry = {
                        'k': k_val,
                        'rate_hz': rate_val,
                        'firing_rate': prop_results['firing_rate'],
                        'P_transmission': prop_results['P_transmission'],
                        'P_transmission_std': prop_results['P_transmission_std'],
                        'sigma': prop_results['sigma'],
                        'sigma_std': prop_results['sigma_std'],
                        'n_neurons_analyzed': prop_results['stats']['n_neurons_analyzed'],
                        'total_spikes': prop_results['stats']['total_spikes']
                    }
                    
                    results_list.append(result_entry)
                    
                    pbar.set_postfix({
                        'K': f"{k_val:.1f}",
                        'rate': f"{rate_val:.0f}",
                        'FR': f"{prop_results['firing_rate']:.1f}",
                        'P': f"{prop_results['P_transmission']:.3f}"
                    })
                    
                except Exception as e:
                    logger.error(f"Error at K={k_val}, rate={rate_val}: {str(e)}")
                    result_entry = {
                        'k': k_val,
                        'rate_hz': rate_val,
                        'firing_rate': np.nan,
                        'P_transmission': np.nan,
                        'P_transmission_std': np.nan,
                        'sigma': np.nan,
                        'sigma_std': np.nan,
                        'n_neurons_analyzed': 0,
                        'total_spikes': 0
                    }
                    results_list.append(result_entry)
                
                pbar.update(1)
    
    # Convert to DataFrame
    df_results = pd.DataFrame(results_list)
    
    # Save results
    if save_results:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_file = OUTPUT_DIR / f'sweep_2d_{timestamp}.pkl'
        with open(output_file, 'wb') as f:
            pickle.dump({
                'df_results': df_results,
                'K_values': K_values,
                'rate_hz_values': rate_hz_values,
                'config': {
                    'SIM_CONFIG': SIM_CONFIG,
                    'NETWORK_PARAMS': NETWORK_PARAMS,
                    'PROPAGATION_CONFIG': PROPAGATION_CONFIG
                }
            }, f)
        logger.success(f"Results saved to {output_file}")
    
    logger.success(f"2D sweep completed: {len(df_results)} simulations")
    
    return df_results

logger.success("2D sweep function defined")

In [None]:
# =============================================================================
# RUN THE 2D SWEEP
# =============================================================================

# NOTE: this cell takes ~30-60 minutes depending on the hardware.
# You can skip it and load previous results in the next section.

logger.info("Starting 2D sweep...")
df_sweep = run_2d_sweep(K_VALUES, RATE_HZ_VALUES, save_results=True)

# Show summary
print("\n" + "="*80)
print("SWEEP SUMMARY")
print("="*80)
print(df_sweep.describe())
print("\n" + "="*80)

## 6. Load Results (Optional)

In [None]:
# =============================================================================
# LOAD PREVIOUS RESULTS (OPTIONAL)
# =============================================================================

# If the sweep has already been run, its results can be loaded instead:

# load_file = OUTPUT_DIR / 'sweep_2d_YYYYMMDD_HHMMSS.pkl'  # <-- edit with your file name
# with open(load_file, 'rb') as f:
#     loaded_data = pickle.load(f)
# df_sweep = loaded_data['df_results']
# logger.info(f"Loaded results from {load_file}")
# print(df_sweep.head())
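Because the pickle name embeds a `YYYYMMDD_HHMMSS` timestamp, sorting the file names also sorts them chronologically. A small helper (hypothetical, not part of the project) can therefore pick the most recent sweep automatically instead of editing the path by hand:

```python
import pickle
from pathlib import Path

def load_latest_sweep(output_dir):
    """Load the most recent sweep_2d_*.pkl file in output_dir.

    Returns (path, data), where data is the dict written by run_2d_sweep.
    Relies on the zero-padded timestamp in the file name sorting chronologically.
    """
    candidates = sorted(Path(output_dir).glob('sweep_2d_*.pkl'))
    if not candidates:
        raise FileNotFoundError(f"No sweep_2d_*.pkl files found in {output_dir}")
    latest = candidates[-1]
    with open(latest, 'rb') as f:
        data = pickle.load(f)
    return latest, data
```

Usage would be `load_file, loaded_data = load_latest_sweep(OUTPUT_DIR)` followed by `df_sweep = loaded_data['df_results']`.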

## 7. FR Matching Analysis

In [None]:
# =============================================================================
# BASELINE MATCHING BY FIRING RATE
# =============================================================================

def compute_network_contribution(df_sweep):
    """
    Estimate the network contribution by matching on firing rate.
    
    For each coupled point (K>0, FR_target):
        1. Interpolate the baseline (K=0) curves P(FR) and σ(FR) at FR_target
        2. Compute ΔP = P_coupled - P_baseline
        3. Compute Δσ = σ_coupled - σ_baseline
    
    Args:
        df_sweep: DataFrame with the sweep results
        
    Returns:
        df_matched: one row per coupled point, with columns P_baseline,
            sigma_baseline, delta_P, delta_sigma, fold_change_P,
            fold_change_sigma and `extrapolated` (True when FR_target lies
            outside the baseline FR range)
        df_baseline: the K=0 rows of df_sweep
    """
    # Split baseline (K=0) and coupled (K>0) points
    df_baseline = df_sweep[df_sweep['k'] == 0.0].copy()
    df_coupled = df_sweep[df_sweep['k'] > 0.0].copy()
    
    logger.info(f"Baseline points: {len(df_baseline)}")
    logger.info(f"Coupled points: {len(df_coupled)}")
    
    # Baseline curves P(FR) and σ(FR): drop failed runs (NaN) and sort by FR,
    # which is the interpolation variable
    baseline_sorted = (
        df_baseline
        .dropna(subset=['firing_rate', 'P_transmission', 'sigma'])
        .sort_values('firing_rate')
    )
    FR_min = baseline_sorted['firing_rate'].min()
    FR_max = baseline_sorted['firing_rate'].max()
    
    interp_P = interp1d(
        baseline_sorted['firing_rate'],
        baseline_sorted['P_transmission'],
        kind='linear',
        fill_value='extrapolate'
    )
    
    interp_sigma = interp1d(
        baseline_sorted['firing_rate'],
        baseline_sorted['sigma'],
        kind='linear',
        fill_value='extrapolate'
    )
    
    # For each coupled point, read off the baseline values at the same FR
    matched_results = []
    
    for _, row in df_coupled.iterrows():
        FR_target = row['firing_rate']
        
        try:
            P_baseline = float(interp_P(FR_target))
            sigma_baseline = float(interp_sigma(FR_target))
        except Exception as e:
            logger.warning(f"Could not match K={row['k']}, rate={row['rate_hz']}: {e}")
            continue
        
        matched_results.append({
            'k': row['k'],
            'rate_hz': row['rate_hz'],
            'firing_rate': FR_target,
            'P_transmission': row['P_transmission'],
            'sigma': row['sigma'],
            'P_baseline': P_baseline,
            'sigma_baseline': sigma_baseline,
            'delta_P': row['P_transmission'] - P_baseline,
            'delta_sigma': row['sigma'] - sigma_baseline,
            'fold_change_P': row['P_transmission'] / P_baseline if P_baseline > 0 else np.nan,
            'fold_change_sigma': row['sigma'] / sigma_baseline if sigma_baseline > 0 else np.nan,
            # Matches outside the baseline FR range rely on linear extrapolation
            'extrapolated': not (FR_min <= FR_target <= FR_max)
        })
    
    df_matched = pd.DataFrame(matched_results)
    
    logger.success(f"Matched {len(df_matched)} points")
    
    return df_matched, df_baseline

# Run the matching
df_network_contribution, df_baseline = compute_network_contribution(df_sweep)

# Show summary
print("\n" + "="*80)
print("NETWORK CONTRIBUTION SUMMARY")
print("="*80)
print(df_network_contribution.groupby('k')[['delta_P', 'delta_sigma', 'fold_change_P']].mean())
print("\n" + "="*80)
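The strategy phrases the matching as "find a K=0 run with FR≈FR_target", while `compute_network_contribution` interpolates (and extrapolates) the baseline curve. A stricter alternative, sketched here with a hypothetical tolerance `tol_hz`, pairs each coupled point with the nearest K=0 run and rejects pairs whose firing rates differ by more than the tolerance:

```python
import numpy as np

def nearest_baseline_match(fr_base, fr_target, tol_hz):
    """Index of the K=0 run whose FR is closest to each target FR,
    or -1 if the closest one differs by more than tol_hz (or is NaN)."""
    fr_base = np.asarray(fr_base, dtype=float)
    fr_target = np.atleast_1d(np.asarray(fr_target, dtype=float))
    dist = np.abs(fr_target[:, None] - fr_base[None, :])   # shape (n_target, n_base)
    dist = np.where(np.isnan(dist), np.inf, dist)          # failed runs never match
    idx = np.argmin(dist, axis=1)
    best = dist[np.arange(len(fr_target)), idx]
    return np.where(best <= tol_hz, idx, -1)
```

Points with index -1 would be left out of ΔP; how many are rejected indicates how well the K=0 runs cover the FR range reached with coupling.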

## 8. Visualizations: 2D Heatmaps

In [None]:
# =============================================================================
# HEATMAPS 2D: P(K, rate_hz), FR(K, rate_hz), σ(K, rate_hz)
# =============================================================================

def plot_2d_heatmaps(df_sweep, K_values, rate_hz_values):
    """
    Plot 2D heatmaps of the main metrics (top row: raw, bottom row: smoothed).
    
    The (K, rate_hz) grid is non-uniform, so cells are drawn at their index
    positions and the ticks are labeled with the actual parameter values
    (a single uniform `extent` would misplace the cells along both axes).
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    metrics = [
        ('firing_rate', 'Firing Rate (Hz)', 'viridis'),
        ('P_transmission', 'P_transmission', 'plasma'),
        ('sigma', 'σ (branching ratio)', 'coolwarm')
    ]
    
    for col_idx, (metric, title, cmap) in enumerate(metrics):
        # Rows = K, columns = rate_hz, reindexed so that every grid value is present
        pivot = df_sweep.pivot_table(
            index='k',
            columns='rate_hz',
            values=metric,
            aggfunc='mean'
        ).reindex(index=K_values, columns=rate_hz_values)
        
        # Gaussian smoothing in index space; NaN cells (failed runs) are filled
        # with the mean first, otherwise gaussian_filter would spread the NaNs
        filled = np.nan_to_num(pivot.values, nan=np.nanmean(pivot.values))
        smoothed = gaussian_filter(filled, sigma=0.8)
        
        for row_idx, (data, suffix) in enumerate([(pivot.values, ''), (smoothed, ' (smoothed)')]):
            ax = axes[row_idx, col_idx]
            im = ax.imshow(data, aspect='auto', cmap=cmap, origin='lower')
            
            ax.set_xticks(np.arange(len(rate_hz_values)))
            ax.set_xticklabels([f'{float(r):g}' for r in rate_hz_values])
            ax.set_yticks(np.arange(len(K_values)))
            ax.set_yticklabels([f'{float(k):g}' for k in K_values])
            
            ax.set_xlabel('rate_hz (Hz)', fontsize=12)
            ax.set_ylabel('K (coupling)', fontsize=12)
            ax.set_title(f'{title}{suffix}', fontsize=13, fontweight='bold')
            plt.colorbar(im, ax=ax, label=title)
    
    plt.tight_layout()
    return fig

# Generate heatmaps
fig_heatmaps = plot_2d_heatmaps(df_sweep, K_VALUES, RATE_HZ_VALUES)
fig_heatmaps.savefig(OUTPUT_DIR / 'heatmaps_2d.png', dpi=300, bbox_inches='tight')
plt.show()

logger.success("Heatmaps generated")
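The sweep grid is non-uniform (rate_hz jumps from 8 to 10, 12, 15 and 20 Hz; K from 4 to 5, 7 and 10), so drawing it with equal-width cells hides the true spacing. If physical axes are preferred, matplotlib's `pcolormesh(X, Y, C)` accepts explicit cell edges, with one more edge than cells along each axis. A sketch of computing those edges from the grid values (the helper name is hypothetical):

```python
import numpy as np

def centers_to_edges(centers):
    """Cell edges for 1D grid values that are treated as cell centres.

    Inner edges are midpoints between neighbouring values; the outer edges
    extend the first and last cells symmetrically around their centres.
    """
    c = np.asarray(centers, dtype=float)
    if c.size < 2:
        raise ValueError("need at least two grid values")
    mid = 0.5 * (c[:-1] + c[1:])
    first = c[0] - (mid[0] - c[0])
    last = c[-1] + (c[-1] - mid[-1])
    return np.concatenate([[first], mid, [last]])
```

With a pivot table like the one in `plot_2d_heatmaps`, the call would be `ax.pcolormesh(centers_to_edges(pivot.columns), centers_to_edges(pivot.index), pivot.values, cmap=cmap)`.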

## 9. Visualizations: 1D Slices

In [None]:
# =============================================================================
# 1D SLICES: P vs K (fixed rate_hz), P vs rate_hz (fixed K)
# =============================================================================

def plot_1d_slices(df_sweep, K_values, rate_hz_values):
    """
    Plot 1D slices of the metrics.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    
    # === ROW 1: P_transmission and σ ===
    
    # P vs K (several rate_hz)
    ax = axes[0, 0]
    rate_hz_samples = [4, 8, 12, 20]
    for rate in rate_hz_samples:
        df_slice = df_sweep[df_sweep['rate_hz'] == rate]
        ax.plot(df_slice['k'], df_slice['P_transmission'], 
               'o-', label=f'rate={rate}Hz', linewidth=2, markersize=6)
    ax.set_xlabel('K (coupling)', fontsize=12)
    ax.set_ylabel('P_transmission', fontsize=12)
    ax.set_title('P_transmission vs K', fontsize=13, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(alpha=0.3)
    
    # P vs rate_hz (several K)
    ax = axes[0, 1]
    K_samples = [0.0, 1.0, 2.5, 5.0, 10.0]
    for k in K_samples:
        df_slice = df_sweep[df_sweep['k'] == k]
        ax.plot(df_slice['rate_hz'], df_slice['P_transmission'], 
               'o-', label=f'K={k}', linewidth=2, markersize=6)
    ax.set_xlabel('rate_hz (Hz)', fontsize=12)
    ax.set_ylabel('P_transmission', fontsize=12)
    ax.set_title('P_transmission vs rate_hz', fontsize=13, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(alpha=0.3)
    
    # σ vs K
    ax = axes[0, 2]
    for rate in rate_hz_samples:
        df_slice = df_sweep[df_sweep['rate_hz'] == rate]
        ax.plot(df_slice['k'], df_slice['sigma'], 
               'o-', label=f'rate={rate}Hz', linewidth=2, markersize=6)
    ax.axhline(1.0, color='red', linestyle='--', linewidth=2, alpha=0.5, label='Critical (σ=1)')
    ax.set_xlabel('K (coupling)', fontsize=12)
    ax.set_ylabel('σ (branching ratio)', fontsize=12)
    ax.set_title('σ vs K', fontsize=13, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(alpha=0.3)
    
    # === ROW 2: Firing Rate ===
    
    # FR vs K
    ax = axes[1, 0]
    for rate in rate_hz_samples:
        df_slice = df_sweep[df_sweep['rate_hz'] == rate]
        ax.plot(df_slice['k'], df_slice['firing_rate'], 
               'o-', label=f'rate={rate}Hz', linewidth=2, markersize=6)
    ax.set_xlabel('K (coupling)', fontsize=12)
    ax.set_ylabel('Firing Rate (Hz)', fontsize=12)
    ax.set_title('Firing Rate vs K', fontsize=13, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(alpha=0.3)
    
    # FR vs rate_hz
    ax = axes[1, 1]
    for k in K_samples:
        df_slice = df_sweep[df_sweep['k'] == k]
        ax.plot(df_slice['rate_hz'], df_slice['firing_rate'], 
               'o-', label=f'K={k}', linewidth=2, markersize=6)
    ax.set_xlabel('rate_hz (Hz)', fontsize=12)
    ax.set_ylabel('Firing Rate (Hz)', fontsize=12)
    ax.set_title('Firing Rate vs rate_hz', fontsize=13, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(alpha=0.3)
    
    # FR vs rate_hz (linearity)
    ax = axes[1, 2]
    palette = plt.cm.viridis(np.linspace(0, 1, len(K_samples)))
    for idx, k in enumerate(K_samples):
        df_slice = df_sweep[df_sweep['k'] == k].dropna(subset=['firing_rate'])
        ax.scatter(df_slice['rate_hz'], df_slice['firing_rate'], 
                  label=f'K={k}', s=80, alpha=0.7, color=palette[idx])
        # Linear fit (needs at least 3 valid points; NaNs from failed runs would break polyfit)
        if len(df_slice) > 2:
            z = np.polyfit(df_slice['rate_hz'], df_slice['firing_rate'], 1)
            p = np.poly1d(z)
            ax.plot(df_slice['rate_hz'], p(df_slice['rate_hz']), 
                   '--', color=palette[idx], linewidth=1.5, alpha=0.5)
    ax.set_xlabel('rate_hz (Hz)', fontsize=12)
    ax.set_ylabel('Firing Rate (Hz)', fontsize=12)
    ax.set_title('FR ≈ a·rate_hz (linearity check)', fontsize=13, fontweight='bold')
    ax.legend(fontsize=9, loc='upper left')
    ax.grid(alpha=0.3)
    
    plt.tight_layout()
    return fig

# Generate 1D slices
fig_slices = plot_1d_slices(df_sweep, K_VALUES, RATE_HZ_VALUES)
fig_slices.savefig(OUTPUT_DIR / 'slices_1d.png', dpi=300, bbox_inches='tight')
plt.show()

logger.success("1D slices generated")
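The linearity panel shows the fits only visually. To quantify the hypothesis FR ≈ a·rate_hz, the same `np.polyfit` fit can be tabulated per K together with its coefficient of determination R². A sketch (hypothetical helper; the columns `k`, `rate_hz` and `firing_rate` are those of `df_sweep`):

```python
import numpy as np
import pandas as pd

def linearity_table(df, min_points=3):
    """Per-K least-squares fit FR = a * rate_hz + b, with R^2 of the fit."""
    rows = []
    valid = df.dropna(subset=['rate_hz', 'firing_rate'])   # skip failed runs
    for k, g in valid.groupby('k'):
        if len(g) < min_points:
            continue
        x = g['rate_hz'].to_numpy(dtype=float)
        y = g['firing_rate'].to_numpy(dtype=float)
        a, b = np.polyfit(x, y, 1)
        ss_res = np.sum((y - (a * x + b)) ** 2)
        ss_tot = np.sum((y - y.mean()) ** 2)
        r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else np.nan
        rows.append({'k': k, 'slope_a': a, 'intercept_b': b, 'r2': r2})
    return pd.DataFrame(rows, columns=['k', 'slope_a', 'intercept_b', 'r2'])
```

`linearity_table(df_sweep)` would give one row per K; a clearly nonzero intercept b, or a low R² at large K, would argue against a purely proportional relation.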

## 10. Network Contribution Analysis: ΔP, Δσ

In [None]:
# =============================================================================
# NETWORK CONTRIBUTION VISUALIZATION
# =============================================================================

def plot_network_contribution(df_matched, df_baseline):
    """
    Plot the network contribution (ΔP, Δσ).
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    
    # === ROW 1: ΔP ===
    
    # ΔP vs K
    ax = axes[0, 0]
    df_grouped = df_matched.groupby('k')['delta_P'].agg(['mean', 'std'])
    ax.errorbar(df_grouped.index, df_grouped['mean'], yerr=df_grouped['std'],
               fmt='o-', linewidth=2, markersize=8, capsize=5, capthick=2)
    ax.axhline(0, color='red', linestyle='--', linewidth=1.5, alpha=0.5)
    ax.set_xlabel('K (coupling)', fontsize=12)
    ax.set_ylabel('ΔP = P_coupled - P_baseline', fontsize=12)
    ax.set_title('Network Contribution to P', fontsize=13, fontweight='bold')
    ax.grid(alpha=0.3)
    
    # ΔP vs FR
    ax = axes[0, 1]
    K_samples = [1.0, 2.5, 5.0, 10.0]
    for k in K_samples:
        df_k = df_matched[df_matched['k'] == k]
        ax.scatter(df_k['firing_rate'], df_k['delta_P'], 
                  label=f'K={k}', s=80, alpha=0.7)
    ax.axhline(0, color='red', linestyle='--', linewidth=1.5, alpha=0.5)
    ax.set_xlabel('Firing Rate (Hz)', fontsize=12)
    ax.set_ylabel('ΔP', fontsize=12)
    ax.set_title('ΔP vs Firing Rate', fontsize=13, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(alpha=0.3)
    
    # Fold change P
    ax = axes[0, 2]
    df_grouped = df_matched.groupby('k')['fold_change_P'].agg(['mean', 'std'])
    ax.errorbar(df_grouped.index, df_grouped['mean'], yerr=df_grouped['std'],
               fmt='o-', linewidth=2, markersize=8, capsize=5, capthick=2, color='purple')
    ax.axhline(1.0, color='red', linestyle='--', linewidth=1.5, alpha=0.5, label='No change')
    ax.set_xlabel('K (coupling)', fontsize=12)
    ax.set_ylabel('Fold Change = P_coupled / P_baseline', fontsize=12)
    ax.set_title('P Fold Change', fontsize=13, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(alpha=0.3)
    
    # === ROW 2: Δσ ===
    
    # Δσ vs K
    ax = axes[1, 0]
    df_grouped = df_matched.groupby('k')['delta_sigma'].agg(['mean', 'std'])
    ax.errorbar(df_grouped.index, df_grouped['mean'], yerr=df_grouped['std'],
               fmt='o-', linewidth=2, markersize=8, capsize=5, capthick=2, color='orange')
    ax.axhline(0, color='red', linestyle='--', linewidth=1.5, alpha=0.5)
    ax.set_xlabel('K (coupling)', fontsize=12)
    ax.set_ylabel('Δσ = σ_coupled - σ_baseline', fontsize=12)
    ax.set_title('Network Contribution to σ', fontsize=13, fontweight='bold')
    ax.grid(alpha=0.3)
    
    # Δσ vs FR
    ax = axes[1, 1]
    for k in K_samples:
        df_k = df_matched[df_matched['k'] == k]
        ax.scatter(df_k['firing_rate'], df_k['delta_sigma'], 
                  label=f'K={k}', s=80, alpha=0.7)
    ax.axhline(0, color='red', linestyle='--', linewidth=1.5, alpha=0.5)
    ax.set_xlabel('Firing Rate (Hz)', fontsize=12)
    ax.set_ylabel('Δσ', fontsize=12)
    ax.set_title('Δσ vs Firing Rate', fontsize=13, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(alpha=0.3)
    
    # Baseline vs Coupled scatter
    ax = axes[1, 2]
    colors = plt.cm.plasma(np.linspace(0, 1, len(K_samples)))
    for idx, k in enumerate(K_samples):
        df_k = df_matched[df_matched['k'] == k]
        ax.scatter(df_k['P_baseline'], df_k['P_transmission'],
                  label=f'K={k}', s=80, alpha=0.7, color=colors[idx])
    # Diagonal
    lims = [0, max(df_matched['P_baseline'].max(), df_matched['P_transmission'].max())]
    ax.plot(lims, lims, 'k--', linewidth=1.5, alpha=0.5, label='P_coupled = P_baseline')
    ax.set_xlabel('P_baseline (K=0)', fontsize=12)
    ax.set_ylabel('P_coupled (K>0)', fontsize=12)
    ax.set_title('P: Coupled vs Baseline', fontsize=13, fontweight='bold')
    ax.legend(fontsize=9)
    ax.grid(alpha=0.3)
    
    plt.tight_layout()
    return fig

# Generate contribution plots
fig_contribution = plot_network_contribution(df_network_contribution, df_baseline)
fig_contribution.savefig(OUTPUT_DIR / 'network_contribution.png', dpi=300, bbox_inches='tight')
plt.show()

logger.success("Network contribution plots generated")

## 11. Key Results Table

In [None]:
# =============================================================================
# SUMMARY TABLE
# =============================================================================

print("\n" + "="*100)
print("SUMMARY: BASELINE (K=0) vs COUPLED (K>0)")
print("="*100)

# Baseline
df_baseline_summary = df_baseline.describe()[['firing_rate', 'P_transmission', 'sigma']]
print("\nBASELINE (K=0):")
print(df_baseline_summary)

# Coupled (mean by K)
df_coupled = df_sweep[df_sweep['k'] > 0.0]
df_coupled_summary = df_coupled.groupby('k')[['firing_rate', 'P_transmission', 'sigma']].mean()
print("\nCOUPLED (K>0) - MEAN BY K:")
print(df_coupled_summary)

# Network contribution
df_contribution_summary = df_network_contribution.groupby('k')[['delta_P', 'delta_sigma', 'fold_change_P']].mean()
print("\nNETWORK CONTRIBUTION - MEAN BY K:")
print(df_contribution_summary)

print("\n" + "="*100)

# Save table
df_contribution_summary.to_csv(OUTPUT_DIR / 'network_contribution_summary.csv')
logger.success(f"Summary table saved to {OUTPUT_DIR / 'network_contribution_summary.csv'}")

## 12. Conclusions and Next Steps

### Expected results:

1. **FR ≈ a·rate_hz**: check that the firing rate depends linearly on the external input
2. **P(K, rate_hz)**: the transmission probability increases with K
3. **ΔP > 0 for K>0**: the network contributes positively to propagation
4. **σ(K) → 1**: transition toward criticality at an optimal coupling
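For item 4, one way to locate the coupling at which σ reaches 1 for a fixed rate_hz is linear interpolation between the two sweep points that bracket σ = 1. A sketch (hypothetical helper); since σ as defined here also includes coincidental responses, the crossing is best compared with the K=0 value of σ at the same firing rate:

```python
import numpy as np

def first_crossing(x, y, level=1.0):
    """x at which y first reaches `level`, by linear interpolation between
    consecutive points; NaN if it never does. x must be sorted ascending."""
    x = np.asarray(x, dtype=float)
    d = np.asarray(y, dtype=float) - level
    for i in range(len(d) - 1):
        d0, d1 = d[i], d[i + 1]
        if np.isnan(d0) or np.isnan(d1):
            continue                      # skip segments touching failed runs
        if d0 == 0.0:
            return x[i]
        if d0 * d1 < 0:                   # sign change: level lies in (x[i], x[i+1])
            return x[i] + (x[i + 1] - x[i]) * (-d0) / (d1 - d0)
    if len(d) > 0 and d[-1] == 0.0:
        return x[-1]
    return np.nan
```

For example, with `s = df_sweep[df_sweep['rate_hz'] == 8].sort_values('k')`, the estimate would be `first_crossing(s['k'], s['sigma'])`.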

### Complementary analyses:

- [ ] Sweep with delays (τ ≠ 0)
- [ ] Analysis of ISI distributions
- [ ] Separate E and I contributions to propagation
- [ ] Correlation with INT
- [ ] Criticality analysis (avalanches)
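For the ISI item, a common first summary is the coefficient of variation of the inter-spike intervals, CV = std(ISI)/mean(ISI), which is 1 for a Poisson process and 0 for a perfectly regular train. A sketch that takes the `spike_dict` format produced by `organize_spike_times` (the helper name and the `min_spikes` default are assumptions):

```python
import numpy as np

def isi_cv(spike_dict, warmup_ms=0.0, min_spikes=3):
    """Per-neuron coefficient of variation of the inter-spike intervals.

    spike_dict: {neuron_idx: spike times in ms}. Neurons with fewer than
    min_spikes spikes after warmup_ms are skipped. Returns {neuron_idx: CV}.
    """
    cvs = {}
    for nid, times in spike_dict.items():
        t = np.sort(np.asarray(times, dtype=float))
        t = t[t >= warmup_ms]                 # same warmup convention as analyze()
        if len(t) < min_spikes:
            continue
        isi = np.diff(t)
        mean_isi = isi.mean()
        if mean_isi > 0:
            cvs[nid] = isi.std() / mean_isi
    return cvs
```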

### Optimizations:

- [ ] Parallelize simulations (multiple trials)
- [ ] Reduce temporal resolution in the analysis
- [ ] Save only metrics (not the full raster)

---