# GRU Upgrade

In [2]:
"""
================================================================================
MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT
WITH MANIC CONVERSION RISK MODELING + LONGITUDINAL MANIC RELAPSE + KINDLING
================================================================================

GRU-BASED RECURRENT ARCHITECTURE UPGRADE:
- Replaced feed-forward MLP with GRUCell-based recurrent network
- Per-layer modulations become per-time-step modulations
- Hidden state carries memory across steps (better instability modeling)
- Positive bias compounds over recurrent steps (more realistic mania)
- All other logic (pruning, scarring, treatments, kindling) unchanged

This script compares three antidepressant mechanisms and models:
1. Acute treatment efficacy
2. Manic conversion risk (treatment-emergent mania)
3. Longitudinal manic relapse after medication discontinuation
4. Kindling hypothesis: progressive sensitization via uniform irreversible scarring

Key Architecture/Parameter Changes for Kindling:
- Added scar_masks to PruningManager: Tracks permanently damaged connections
- Uniform base scarring (0.05) applied upon each manic relapse
- Emergent severity scaling based on current gain and activation magnitude
- Gradient-guided regrowth respects scar_masks (cannot regrow scarred positions)
- Multi-cycle simulation with weakening triggers to test autonomy emergence
- Early adversity modeling via random initial permanent pruning

Treatment differentiation emerges from existing architecture:
- Ketamine: Regrowth compensates for scars, moderate gain
- SSRI: High gain amplifies vulnerability, no structural repair
- Neurosteroid: Low gain/inhibition buffers acute damage
================================================================================
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from typing import Dict, Tuple, List
import warnings
from collections import defaultdict
import copy

# Suppress every warning of category UserWarning for the whole process.
warnings.filterwarnings('ignore', category=UserWarning)

# Compute device shared by the training/evaluation helpers below:
# CUDA when torch reports it as available, otherwise the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ============================================================================
# CONFIGURATION
# ============================================================================
# Single module-level dictionary of experiment hyperparameters; the code in
# this file looks values up by key (CONFIG['...']) rather than passing them
# around as arguments.
CONFIG = {
    # --- Experiment size and synthetic data ---
    # n_train / n_test / n_clean_test are the sample counts used by
    # create_data_loaders; data_noise is the blob noise for the train and test
    # sets (the clean test set is generated with noise 0.0).
    'n_seeds': 10,
    'n_train': 12000,
    'n_test': 4000,
    'n_clean_test': 2000,
    'data_noise': 0.8,
    'batch_size': 128,
    'hidden_dims': [512, 512, 256],  # Kept for reference/compatibility
    # input_dim / output_dim size the GRUCell input and the final Linear layer.
    'input_dim': 2,
    'output_dim': 4,
    # --- Baseline training, fine-tuning, pruning and regrowth ---
    'baseline_epochs': 20,
    'baseline_lr': 0.001,
    'finetune_epochs': 15,
    'finetune_lr': 0.0005,
    'prune_sparsity': 0.95,
    'regrow_fraction': 0.5,
    # Default init scale and gradient-scoring batch count used by
    # PruningManager.gradient_guided_regrow.
    'regrow_init_scale': 0.03,
    'gradient_accumulation_batches': 30,
    # Named stress levels (label -> numeric stress value).
    'extended_stress_levels': {
        'none': 0.0,
        'moderate': 0.5,
        'high': 1.0,
        'severe': 1.5,
        'extreme': 2.5
    },
    # --- Treatment parameters (gain, inhibition, schedule lengths) ---
    'monoaminergic_epochs': 100,
    'monoaminergic_lr': 1e-5,
    'monoaminergic_initial_stress': 0.5,
    'ssri_max_gain': 1.6,
    'ketamine_gain': 1.25,
    'neurosteroid_inhibition_strength': 0.7,
    'neurosteroid_use_tanh': True,
    'neurosteroid_consolidation_epochs': 10,
    'neurosteroid_gain': 0.85,
    'mania_test_bias': 1.0,
    'mania_test_sigma': 1.0,
    'comparison_ketamine_regrow': 0.5,
    'comparison_ketamine_epochs': 15,
    'comparison_ssri_epochs': 100,
    'comparison_neurosteroid_strength': 0.7,
    'comparison_neurosteroid_epochs': 10,

    # Mood Stabilizer (MS) Architecture Parameters
    # ms_gain_cap, ms_inhib_bias_strength and ms_bias_damping_factor are read
    # by RecurrentStressNetwork when ms_protection_level > 0.
    'ms_gain_cap': 1.05,
    'ms_inhib_bias_strength': 0.15,
    'ms_bias_damping_factor': 0.3,
    'ms_max_protection': 1.0,

    # Longitudinal Manic Relapse Parameters
    'maintenance_durations': [25, 50, 100, 150, 200, 300],
    'maintenance_lr': 1e-6,
    'post_discontinuation_steps': 50,
    'manic_relapse_threshold': 60.0,

    # Treatment-specific MS decay rates (architecture parameter differences)
    'ketamine_ms_decay_rate': 0.002,
    'ssri_ms_decay_rate': 0.015,
    'neurosteroid_ms_decay_rate': 0.008,

    # Kindling Parameters (Architecture-based)
    'kindling_base_scar_sparsity': 0.05,
    'kindling_severity_min': 0.5,
    'kindling_severity_max': 2.0,
    'kindling_num_cycles': 6,
    'kindling_initial_trigger_bias': 1.5,
    'kindling_final_trigger_bias': 0.5,
    'kindling_inter_episode_maintenance_epochs': 20,
    'kindling_autonomy_threshold': 70.0,
    'early_adversity_max_scar': 0.06,

    # GRU-specific parameters (NEW)
    # seq_len is how many times create_data_loaders repeats each static input
    # along the sequence axis; gru_hidden_size is the GRUCell hidden width.
    'seq_len': 20,           # Acts as "depth" (original had ~3 hidden layers + mods)
    'gru_hidden_size': 384,  # Balanced size (original total params ~ similar)
}


# ============================================================================
# DATA GENERATION
# ============================================================================
def generate_blobs(n_samples: int = 10000, noise: float = 0.8, seed: int = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample a four-class, two-dimensional blob dataset.

    Each sample picks one of the four corner centres (+/-3, +/-3) uniformly at
    random and adds standard-normal jitter scaled by ``noise``.

    Args:
        n_samples: Number of points to draw.
        noise: Multiplier applied to the standard-normal jitter.
        seed: Seed for a private ``numpy.random.RandomState``; ``None`` lets
            NumPy seed the generator itself.

    Returns:
        ``(points, labels)``: a float32 tensor of shape ``[n_samples, 2]`` and
        an int64 tensor of shape ``[n_samples]`` holding the centre index.
    """
    # RandomState(None) self-seeds exactly like RandomState(), so one call
    # covers both the seeded and the unseeded case.
    sampler = np.random.RandomState(seed)
    blob_centers = np.array([[-3, -3], [3, 3], [-3, 3], [3, -3]])

    # Draw order matters for reproducibility: labels first, then jitter.
    class_ids = sampler.randint(0, 4, n_samples)
    jitter = sampler.randn(n_samples, 2) * noise
    points = blob_centers[class_ids] + jitter

    features = torch.tensor(points, dtype=torch.float32)
    targets = torch.tensor(class_ids, dtype=torch.long)
    return (features, targets)


def create_data_loaders(seed: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train, noisy-test and clean-test loaders for one seed.

    Three independent blob datasets are generated with seeds derived from
    ``seed`` (``seed*1000 + 100/200/300``). Each static 2-D point is repeated
    ``CONFIG['seq_len']`` times along a new sequence axis so the recurrent
    network sees the same input at every step.

    Args:
        seed: Experiment seed used to derive the three dataset seeds.

    Returns:
        ``(train_loader, test_loader, clean_test_loader)``. Only the training
        loader shuffles; the two test loaders use a batch size of 1000.
    """
    steps = CONFIG['seq_len']
    seed_base = seed * 1000

    def _as_sequence(points: torch.Tensor) -> torch.Tensor:
        # [N, 2] -> [N, seq_len, 2] by copying the point at every time step.
        return points.unsqueeze(1).repeat(1, steps, 1)

    train_x, train_y = generate_blobs(CONFIG['n_train'], noise=CONFIG['data_noise'], seed=seed_base + 100)
    test_x, test_y = generate_blobs(CONFIG['n_test'], noise=CONFIG['data_noise'], seed=seed_base + 200)
    # The clean test set sits exactly on the class centres (noise 0.0).
    clean_x, clean_y = generate_blobs(CONFIG['n_clean_test'], noise=0.0, seed=seed_base + 300)

    train_set = TensorDataset(_as_sequence(train_x), train_y)
    test_set = TensorDataset(_as_sequence(test_x), test_y)
    clean_set = TensorDataset(_as_sequence(clean_x), clean_y)

    return (
        DataLoader(train_set, batch_size=CONFIG['batch_size'], shuffle=True),
        DataLoader(test_set, batch_size=1000),
        DataLoader(clean_set, batch_size=1000),
    )


# ============================================================================
# GRU-BASED RECURRENT NETWORK ARCHITECTURE WITH MS PROTECTION PARAMETERS
# ============================================================================
class RecurrentStressNetwork(nn.Module):
    """
    GRUCell-based recurrent classifier with per-step modulations.

    Architecture:
    - A single GRUCell consumes one time step of the input per iteration.
    - After every cell update the hidden state is scaled by the inhibition
      strength, optionally shifted down by the mood-stabilizer (MS) bias, and
      passed through ReLU or Tanh.
    - The hidden state after the last step is mapped to logits by ``fc``.

    Modulation state (plain Python attributes, not learned parameters):
    - stress_level: std of Gaussian noise added to the gained input in forward().
    - gain_multiplier: input scaling; capped through get_effective_gain() when
      MS protection is active.
    - inhibition_strength / use_tanh: hidden-state scaling and choice of
      post-step activation.
    - ms_protection_level / lingering_ms_decay_rate: MS protection in [0, 1]
      and the amount removed by each decay_ms_protection() call.
    """

    def __init__(self):
        super().__init__()
        self.hidden_size = CONFIG['gru_hidden_size']
        self.seq_len = CONFIG['seq_len']

        # GRU cell for recurrent processing
        self.gru_cell = nn.GRUCell(CONFIG['input_dim'], self.hidden_size)

        # Output layer
        self.fc = nn.Linear(self.hidden_size, CONFIG['output_dim'])

        # Activation functions
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

        # Core modulation parameters
        self.stress_level = 0.0
        self.inhibition_strength = 1.0
        self.use_tanh = False
        self.gain_multiplier = 1.0

        # Mood Stabilizer (MS) parameters
        self.ms_protection_level = 0.0
        self.lingering_ms_decay_rate = 0.0

        # Weight layers for pruning manager compatibility.
        # NOTE(review): the 'fc' entry is a module name while the other two are
        # parameter names; nothing in this class reads the list — confirm how
        # callers use it.
        self.weight_layers = ['gru_cell.weight_ih', 'gru_cell.weight_hh', 'fc']

    def set_stress(self, level: float):
        """Set the std of the input noise injected by forward()."""
        self.stress_level = level

    def set_inhibition(self, strength: float, use_tanh: bool = False):
        """Set the hidden-state scale and pick Tanh (True) or ReLU (False)."""
        self.inhibition_strength = strength
        self.use_tanh = use_tanh

    def set_gain(self, gain: float):
        """Set the raw input gain (before any MS cap is applied)."""
        self.gain_multiplier = gain

    def set_ms_protection(self, level: float, decay_rate: float = 0.0):
        """Set MS protection, clamped to [0, 1], and its per-call decay rate."""
        self.ms_protection_level = max(0.0, min(1.0, level))
        self.lingering_ms_decay_rate = decay_rate

    def decay_ms_protection(self):
        """Lower MS protection by the decay rate, never going below 0."""
        if self.lingering_ms_decay_rate > 0:
            self.ms_protection_level = max(0.0, self.ms_protection_level - self.lingering_ms_decay_rate)

    def reset_antidepressant_effects(self):
        """Restore gain, inhibition, activation choice and stress to defaults.

        MS protection is deliberately left untouched by this method.
        """
        self.gain_multiplier = 1.0
        self.inhibition_strength = 1.0
        self.use_tanh = False
        self.stress_level = 0.0

    def get_effective_gain(self) -> float:
        """Return the gain actually applied to inputs.

        Without MS protection this is ``gain_multiplier``. With protection the
        gain is limited to a ceiling that interpolates from the raw gain
        (protection -> 0) down to ``CONFIG['ms_gain_cap']`` (protection = 1);
        gains already at or below the cap pass through unchanged.
        """
        if self.ms_protection_level > 0:
            max_allowed = CONFIG['ms_gain_cap'] + (1.0 - self.ms_protection_level) * (self.gain_multiplier - CONFIG['ms_gain_cap'])
            return min(self.gain_multiplier, max(CONFIG['ms_gain_cap'], max_allowed))
        return self.gain_multiplier

    def reduce_stress_gradually(self, epoch: int, total_epochs: int, initial_stress: float = 0.5, final_stress: float = 0.0):
        """Linearly interpolate stress_level from initial to final over the epochs."""
        progress = epoch / max(total_epochs - 1, 1)
        self.stress_level = initial_stress + progress * (final_stress - initial_stress)

    def increase_gain_gradually(self, epoch: int, total_epochs: int, initial_gain: float = 1.0, max_gain: float = 1.6):
        """Linearly interpolate gain_multiplier from initial_gain to max_gain."""
        progress = epoch / max(total_epochs - 1, 1)
        self.gain_multiplier = initial_gain + progress * (max_gain - initial_gain)

    def _initial_hidden(self, x: torch.Tensor) -> torch.Tensor:
        """Return the all-zero starting hidden state for a batch shaped like x."""
        return torch.zeros(x.size(0), self.hidden_size, device=x.device)

    def _modulate_hidden(self, h: torch.Tensor, activation: nn.Module) -> torch.Tensor:
        """Apply inhibition scaling, the MS bias (if active) and the activation."""
        h = h * self.inhibition_strength
        if self.ms_protection_level > 0.0:
            h = h - CONFIG['ms_inhib_bias_strength'] * self.ms_protection_level
        return activation(h)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with per-step modulations.

        x: [batch, seq_len, input_dim]
        Returns the logits computed from the final hidden state.
        """
        h = self._initial_hidden(x)
        activation = self.tanh if self.use_tanh else self.relu
        effective_gain = self.get_effective_gain()

        for t in range(x.size(1)):
            # Gain scales the input before the cell sees it.
            inp = x[:, t, :] * effective_gain

            # Stress is fresh Gaussian noise on the input at every step.
            if self.stress_level > 0.0:
                inp = inp + torch.randn_like(inp) * self.stress_level

            h = self.gru_cell(inp, h)
            h = self._modulate_hidden(h, activation)

        return self.fc(h)

    def forward_with_biased_noise(self, x: torch.Tensor, internal_sigma: float = 1.0, bias: float = 1.0) -> torch.Tensor:
        """
        Forward pass with biased hidden-state noise for mania testing.

        Unlike forward(), no input stress noise is added; instead Gaussian
        noise with std ``internal_sigma`` plus a constant offset is added to the
        hidden state after every cell update, so it feeds into later steps.

        x: [batch, seq_len, input_dim]
        """
        h = self._initial_hidden(x)
        activation = self.tanh if self.use_tanh else self.relu
        effective_gain = self.get_effective_gain()

        # MS protection dampens the bias proportionally to its level.
        bias_multiplier = 1.0 - (CONFIG['ms_bias_damping_factor'] * self.ms_protection_level)
        effective_bias = bias * bias_multiplier

        for t in range(x.size(1)):
            inp = x[:, t, :] * effective_gain
            h = self.gru_cell(inp, h)

            # Noise goes in AFTER the cell update and BEFORE inhibition, so it
            # is carried forward through the recurrence.
            h = h + (torch.randn_like(h) * internal_sigma + effective_bias)
            h = self._modulate_hidden(h, activation)

        return self.fc(h)

    def count_parameters(self) -> Tuple[int, int]:
        """Return ``(total, nonzero)`` element counts over all parameters."""
        total = sum(p.numel() for p in self.parameters())
        nonzero = sum((p != 0).sum().item() for p in self.parameters())
        return total, nonzero


# Alias for backward compatibility: code that still refers to
# StressAwareNetwork gets the GRU-based RecurrentStressNetwork class.
StressAwareNetwork = RecurrentStressNetwork


# ============================================================================
# PRUNING MANAGER WITH SCAR MASKS (KINDLING ARCHITECTURE)
# ============================================================================
class PruningManager:
    """
    Magnitude pruning / gradient-guided regrowth with permanent scar masks.

    One mask triple is kept for every model parameter whose name contains
    'weight' and that has at least two dimensions:
    - masks: 1 = connection active, 0 = pruned.
    - scar_masks: 1 = permanently damaged (scarred); scarred positions are
      never regrown by gradient_guided_regrow.
    - gradient_buffer: accumulated |gradient| used to rank regrowth candidates.

    The manager mutates the model's weights in place (through ``param.data``).
    """

    def __init__(self, model: "RecurrentStressNetwork", train_loader: DataLoader):
        """
        Args:
            model: Network whose weights are managed. Besides ``nn.Module``
                behaviour, _accumulate_gradients relies on its
                ``stress_level`` / ``gain_multiplier`` attributes and
                ``set_stress`` / ``set_gain`` methods.
            train_loader: Iterable of ``(x, y)`` batches used for gradient scoring.
        """
        self.model = model
        self.train_loader = train_loader
        self.masks = {}
        self.scar_masks = {}  # Permanent damage tracking
        self.gradient_buffer = {}

        for name, param in model.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                self.masks[name] = torch.ones_like(param, dtype=torch.float32)
                self.scar_masks[name] = torch.zeros_like(param, dtype=torch.float32)  # 1 = scarred
                self.gradient_buffer[name] = torch.zeros_like(param)

    def prune_by_magnitude(self, sparsity: float, per_layer: bool = True, permanent: bool = False) -> Dict[str, Dict]:
        """
        Prune each managed tensor by magnitude.

        The threshold is the ``sparsity`` quantile of the absolute values at
        non-scarred positions; positions below it are masked out and zeroed.
        The mask is recomputed from current magnitudes on every call.

        Args:
            sparsity: Quantile in [0, 1] used as the pruning threshold.
            per_layer: Accepted for interface compatibility; thresholds are
                always computed per tensor and this flag is not read.
            permanent: If True, newly pruned positions are also marked as scars.

        Returns:
            Per-tensor stats with 'kept', 'total' and 'actual_sparsity'.
        """
        stats = {}
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue

            # Only consider non-scarred positions for pruning
            available_mask = (self.scar_masks[name] == 0).float()
            weights = param.data.abs() * available_mask

            # Get threshold from available weights only
            available_weights = weights[available_mask > 0]
            if available_weights.numel() == 0:
                stats[name] = {'kept': 0, 'total': self.masks[name].numel(), 'actual_sparsity': 1.0}
                continue

            threshold = torch.quantile(available_weights.flatten(), sparsity)
            new_prune_mask = (weights >= threshold).float()

            # Scarred positions stay masked regardless of the threshold.
            self.masks[name] = new_prune_mask * (1 - self.scar_masks[name])
            param.data *= self.masks[name]

            if permanent:
                # Mark newly pruned positions as permanent scars
                newly_pruned = (new_prune_mask == 0) * available_mask
                self.scar_masks[name] = torch.clamp(self.scar_masks[name] + newly_pruned, 0, 1)

            kept = self.masks[name].sum().item()
            total = self.masks[name].numel()
            stats[name] = {'kept': int(kept), 'total': total, 'actual_sparsity': 1 - kept/total}
        return stats

    def apply_episode_scar(self, base_sparsity: float, severity_factor: float = 1.0) -> Dict[str, Dict]:
        """
        Permanently scar the weakest active connections after a manic episode.

        The scarred fraction is ``base_sparsity * severity_factor`` clamped to
        [0.02, 0.2] and is applied to the currently active, non-scarred
        positions of every managed tensor (at least one position per tensor).
        Ties at the threshold magnitude are all scarred.

        Returns:
            Per-tensor stats with 'scarred' (new scars) and 'total_scars'.
        """
        effective_sparsity = base_sparsity * severity_factor
        effective_sparsity = min(0.2, max(0.02, effective_sparsity))  # Clamp reasonable range

        stats = {}
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue

            # Only scar currently active (unmasked, non-scarred) positions
            active_mask = (self.masks[name] > 0) * (self.scar_masks[name] == 0)
            active_weights = param.data.abs() * active_mask.float()

            active_values = active_weights[active_mask]
            if active_values.numel() == 0:
                stats[name] = {'scarred': 0, 'total_scars': self.scar_masks[name].sum().item()}
                continue

            # Threshold = magnitude of the k-th smallest active connection
            num_to_scar = max(1, int(effective_sparsity * active_values.numel()))
            threshold = torch.kthvalue(active_values, min(num_to_scar, active_values.numel())).values

            new_scars = (active_weights <= threshold) * active_mask
            self.scar_masks[name] = torch.clamp(self.scar_masks[name] + new_scars.float(), 0, 1)
            self.masks[name] = self.masks[name] * (1 - new_scars.float())
            param.data *= self.masks[name]

            stats[name] = {
                'scarred': new_scars.sum().item(),
                'total_scars': self.scar_masks[name].sum().item()
            }
        return stats

    def _accumulate_gradients(self, num_batches: int = 30):
        """
        Fill gradient_buffer with summed |grad| at pruned, non-scarred positions.

        Runs up to ``num_batches`` training batches with stress 0 and gain 1.
        The model's previous stress and gain are restored afterwards, even if
        a batch raises. The model is left in train mode with cleared grads.
        """
        model = self.model
        loss_fn = nn.CrossEntropyLoss()
        for name in self.gradient_buffer:
            self.gradient_buffer[name].zero_()

        # Feed batches to wherever the model actually lives.
        device = next(model.parameters()).device

        model.train()
        original_stress = model.stress_level
        original_gain = model.gain_multiplier
        model.set_stress(0.0)
        model.set_gain(1.0)
        try:
            # backward() adds to existing .grad, so gradients left over from a
            # previous training step must be cleared before the first batch.
            model.zero_grad()
            batch_count = 0
            for x, y in self.train_loader:
                if batch_count >= num_batches:
                    break
                x, y = x.to(device), y.to(device)
                loss = loss_fn(model(x), y)
                loss.backward()
                with torch.no_grad():
                    for name, param in model.named_parameters():
                        if name not in self.masks or param.grad is None:
                            continue
                        # Only accumulate for pruned, non-scarred positions
                        pruned_mask = (self.masks[name] == 0).float()
                        non_scarred = (self.scar_masks[name] == 0).float()
                        self.gradient_buffer[name] += param.grad.abs() * pruned_mask * non_scarred
                model.zero_grad()
                batch_count += 1
        finally:
            model.set_stress(original_stress)
            model.set_gain(original_gain)

    def gradient_guided_regrow(self, regrow_fraction: float, init_scale: float = None) -> Dict[str, Dict]:
        """
        Regrow pruned connections with the largest accumulated gradients.

        Scarred positions are never regrown. For every managed tensor,
        ``regrow_fraction`` of the pruned, non-scarred positions (at least one)
        is re-enabled and re-initialised with Gaussian noise scaled by
        ``init_scale`` (default ``CONFIG['regrow_init_scale']``).

        Returns:
            Per-tensor stats with 'regrown', 'still_pruned' and 'scarred'.
        """
        if init_scale is None:
            init_scale = CONFIG['regrow_init_scale']
        self._accumulate_gradients(num_batches=CONFIG['gradient_accumulation_batches'])
        stats = {}
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue
            mask = self.masks[name]
            scar_mask = self.scar_masks[name]

            # Only consider pruned AND non-scarred positions for regrowth
            regrowable = (mask == 0) & (scar_mask == 0)
            num_regrowable = regrowable.sum().item()

            if num_regrowable == 0:
                stats[name] = {'regrown': 0, 'still_pruned': 0, 'scarred': scar_mask.sum().item()}
                continue

            gradient_scores = self.gradient_buffer[name][regrowable]
            num_regrow = max(1, int(regrow_fraction * num_regrowable))
            num_regrow = min(num_regrow, gradient_scores.numel())

            # Map top-k positions within the candidate list back to flat
            # indices of the full tensor.
            _, top_indices = torch.topk(gradient_scores.flatten(), num_regrow)
            flat_regrowable_indices = torch.where(regrowable.flatten())[0]
            regrow_flat_indices = flat_regrowable_indices[top_indices]

            flat_mask = mask.flatten()
            flat_param = param.data.flatten()
            flat_mask[regrow_flat_indices] = 1.0
            flat_param[regrow_flat_indices] = torch.randn(num_regrow, device=param.device) * init_scale

            # Reassign so the update holds whether flatten() gave a view or a copy.
            self.masks[name] = flat_mask.view_as(mask)
            param.data = flat_param.view_as(param)

            stats[name] = {
                'regrown': num_regrow,
                'still_pruned': int(num_regrowable - num_regrow),
                'scarred': scar_mask.sum().item()
            }
        return stats

    def apply_masks(self):
        """Zero every masked-out weight (call after each optimizer step)."""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in self.masks:
                    param.data *= self.masks[name]

    def get_sparsity(self) -> float:
        """Return the fraction of managed positions whose mask is 0."""
        total = sum(m.numel() for m in self.masks.values())
        zeros = sum((m == 0).sum().item() for m in self.masks.values())
        return zeros / total if total > 0 else 0.0

    def get_scar_fraction(self) -> float:
        """Return the fraction of managed positions that are scarred."""
        total = sum(m.numel() for m in self.scar_masks.values())
        scarred = sum(m.sum().item() for m in self.scar_masks.values())
        return scarred / total if total > 0 else 0.0

    def clone_masks(self) -> Dict[str, torch.Tensor]:
        """Return an independent copy of the pruning masks."""
        return {k: v.clone() for k, v in self.masks.items()}

    def clone_scar_masks(self) -> Dict[str, torch.Tensor]:
        """Return an independent copy of the scar masks."""
        return {k: v.clone() for k, v in self.scar_masks.items()}

    def apply_early_adversity(self, max_scar_fraction: float, seed: int) -> float:
        """
        Apply random early-adversity scarring, reproducibly for a given seed.

        A scar fraction is drawn uniformly from [0, max_scar_fraction] and that
        share of every managed tensor is scarred at random non-scarred
        positions (mask cleared, weight zeroed). Both the fraction and the
        positions depend only on ``seed``, so the same seed gives the same
        scars on identically shaped models; the global torch RNG is not used.
        If fewer non-scarred positions remain than requested, all of them are
        scarred.

        Returns:
            The sampled scar fraction.
        """
        rng = np.random.RandomState(seed)
        scar_fraction = rng.uniform(0.0, max_scar_fraction)

        # Private generator: positions must not depend on (or disturb) the
        # global torch RNG state.
        generator = torch.Generator()
        generator.manual_seed(int(seed))

        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue
            num_to_scar = int(scar_fraction * param.numel())
            if num_to_scar <= 0:
                continue

            flat_scar = self.scar_masks[name].flatten()
            flat_mask = self.masks[name].flatten()
            flat_param = param.data.flatten()

            # Random positions to scar, limited to what is still available
            available_idx = torch.where(flat_scar == 0)[0]
            num_to_scar = min(num_to_scar, available_idx.numel())
            if num_to_scar > 0:
                perm = torch.randperm(available_idx.numel(), generator=generator)[:num_to_scar]
                scar_idx = available_idx[perm.to(available_idx.device)]
                flat_scar[scar_idx] = 1.0
                flat_mask[scar_idx] = 0.0
                flat_param[scar_idx] = 0.0

            self.scar_masks[name] = flat_scar.view_as(self.scar_masks[name])
            self.masks[name] = flat_mask.view_as(self.masks[name])
            param.data = flat_param.view_as(param)

        return scar_fraction


# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================
def train(model: RecurrentStressNetwork, train_loader: DataLoader, epochs: int = 15, lr: float = 0.001,
          pruning_manager: PruningManager = None) -> List[float]:
    """Train ``model`` with Adam and cross-entropy, without stress noise.

    Stress is forced to 0.0 for the duration of training and the previous
    stress level is put back before returning. When a pruning manager is
    given, its masks are re-applied after every optimizer step so pruned
    weights stay at zero.

    Args:
        model: Network to optimise in place.
        train_loader: Source of ``(inputs, labels)`` batches.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        pruning_manager: Optional manager whose masks are enforced.

    Returns:
        Mean batch loss for each epoch, in order.
    """
    adam = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    stress_before = model.stress_level
    model.set_stress(0.0)

    mean_losses = []
    for _ in range(epochs):
        model.train()
        running_total = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            adam.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            adam.step()

            # Re-zero pruned weights that the optimizer may have moved.
            if pruning_manager:
                pruning_manager.apply_masks()

            running_total += batch_loss.item()
        mean_losses.append(running_total / len(train_loader))

    model.set_stress(stress_before)
    return mean_losses


def train_with_stress_and_gain_schedule(model: RecurrentStressNetwork, train_loader: DataLoader, epochs: int, lr: float,
                                         initial_stress: float, final_stress: float = 0.0,
                                         initial_gain: float = 1.0, max_gain: float = 1.6,
                                         pruning_manager: PruningManager = None) -> List[float]:
    """Train while stress and gain follow linear per-epoch schedules.

    At the start of every epoch the model's stress is interpolated from
    ``initial_stress`` to ``final_stress`` and its gain from ``initial_gain``
    to ``max_gain``. After the last epoch stress is set to 0.0; the gain is
    left at whatever the schedule last assigned.

    Args:
        model: Network to optimise in place.
        train_loader: Source of ``(inputs, labels)`` batches.
        epochs: Number of epochs (also the schedule length).
        lr: Adam learning rate.
        initial_stress: Stress level used in the first epoch.
        final_stress: Stress level reached in the last epoch.
        initial_gain: Gain used in the first epoch.
        max_gain: Gain reached in the last epoch.
        pruning_manager: Optional manager whose masks are enforced after
            every optimizer step.

    Returns:
        Mean batch loss for each epoch, in order.
    """
    adam = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    mean_losses = []
    for epoch_idx in range(epochs):
        # Advance both schedules before any batch of this epoch runs.
        model.reduce_stress_gradually(epoch_idx, epochs, initial_stress, final_stress)
        model.increase_gain_gradually(epoch_idx, epochs, initial_gain, max_gain)
        model.train()

        running_total = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            adam.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            adam.step()

            if pruning_manager:
                pruning_manager.apply_masks()

            running_total += batch_loss.item()
        mean_losses.append(running_total / len(train_loader))

    model.set_stress(0.0)
    return mean_losses


# ============================================================================
# EVALUATION FUNCTIONS
# ============================================================================
def evaluate(model: RecurrentStressNetwork, loader: DataLoader, input_noise: float = 0.0, internal_stress: float = 0.0) -> float:
    """Return classification accuracy (in percent) of ``model`` on ``loader``.

    The model is switched to eval mode and its stress level is set to
    ``internal_stress`` for the duration of the evaluation. Afterwards the
    stress level is reset to 0.0 (not to its previous value), including when
    a batch raises, so evaluation noise never leaks into later code.

    Args:
        model: Network to evaluate.
        loader: Source of ``(inputs, labels)`` batches.
        input_noise: Std of Gaussian noise added to the inputs; 0 disables it.
        internal_stress: Stress level passed to the model during evaluation.

    Returns:
        Accuracy in [0, 100]; 0.0 if the loader yields no samples.
    """
    model.eval()
    model.set_stress(internal_stress)
    correct, total = 0, 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(DEVICE), y.to(DEVICE)
                if input_noise > 0:
                    x = x + torch.randn_like(x) * input_noise
                correct += (model(x).argmax(dim=1) == y).sum().item()
                total += y.size(0)
    finally:
        model.set_stress(0.0)

    # An empty loader would otherwise divide by zero.
    if total == 0:
        return 0.0
    return 100.0 * correct / total


def evaluate_biased_stress(model: RecurrentStressNetwork, loader: DataLoader, input_noise: float = 0.0,
                           internal_sigma: float = 1.0, bias: float = 1.0) -> float:
    """Return accuracy (in percent) under biased hidden-state noise.

    Predictions come from ``model.forward_with_biased_noise`` instead of the
    regular forward pass. The model is switched to eval mode; its stress
    level is not modified.

    Args:
        model: Network to evaluate.
        loader: Source of ``(inputs, labels)`` batches.
        input_noise: Std of Gaussian noise added to the inputs; 0 disables it.
        internal_sigma: Noise std forwarded to ``forward_with_biased_noise``.
        bias: Bias forwarded to ``forward_with_biased_noise``.

    Returns:
        Accuracy in [0, 100]; 0.0 if the loader yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            if input_noise > 0:
                x = x + torch.randn_like(x) * input_noise
            out = model.forward_with_biased_noise(x, internal_sigma, bias)
            correct += (out.argmax(dim=1) == y).sum().item()
            total += y.size(0)

    # An empty loader would otherwise divide by zero.
    if total == 0:
        return 0.0
    return 100.0 * correct / total


def get_avg_activation_magnitude(model: RecurrentStressNetwork, loader: DataLoader) -> float:
    """
    Average the mean absolute hidden activation over every step and batch.

    The recurrence is replayed manually: input scaled by the raw
    ``gain_multiplier``, GRU cell update, inhibition scaling, activation.
    One mean-|h| value is recorded per time step per batch and the result is
    the unweighted mean of those values (0.0 when the loader is empty).

    NOTE(review): unlike ``model.forward`` this replay uses the raw gain
    rather than ``get_effective_gain()`` and skips the MS protection bias, so
    it differs from the real forward pass while MS protection is active —
    confirm that is intended.
    """
    model.eval()
    nonlinearity = model.tanh if model.use_tanh else model.relu
    step_magnitudes = []

    with torch.no_grad():
        for batch, _ in loader:
            batch = batch.to(DEVICE)
            hidden = torch.zeros(batch.size(0), model.hidden_size, device=batch.device)

            for step in range(batch.size(1)):
                scaled_input = batch[:, step, :] * model.gain_multiplier
                hidden = model.gru_cell(scaled_input, hidden)
                hidden = nonlinearity(hidden * model.inhibition_strength)
                step_magnitudes.append(hidden.abs().mean().item())

    if not step_magnitudes:
        return 0.0
    return np.mean(step_magnitudes)


# ============================================================================
# TREATMENT PROTOCOLS
# ============================================================================
def ketamine_treatment(model: RecurrentStressNetwork, pruning_mgr: PruningManager, train_loader: DataLoader) -> Dict:
    """Ketamine protocol: set gain, regrow connections, then fine-tune.

    Steps, in order: set the model gain to ``CONFIG['ketamine_gain']``, regrow
    ``CONFIG['comparison_ketamine_regrow']`` of the regrowable connections via
    the pruning manager, and train for ``CONFIG['comparison_ketamine_epochs']``
    epochs at the fine-tune learning rate with masks enforced.

    Returns:
        Dict with 'final_sparsity' and 'final_gain'.
    """
    model.set_gain(CONFIG['ketamine_gain'])

    regrow_share = CONFIG['comparison_ketamine_regrow']
    pruning_mgr.gradient_guided_regrow(regrow_fraction=regrow_share)

    train(
        model,
        train_loader,
        epochs=CONFIG['comparison_ketamine_epochs'],
        lr=CONFIG['finetune_lr'],
        pruning_manager=pruning_mgr,
    )

    summary = {}
    summary['final_sparsity'] = pruning_mgr.get_sparsity()
    summary['final_gain'] = model.gain_multiplier
    return summary


def ssri_treatment(model: RecurrentStressNetwork, pruning_mgr: PruningManager, train_loader: DataLoader) -> Dict:
    """SSRI protocol: long, low-learning-rate training on a stress/gain schedule.

    Trains for ``CONFIG['comparison_ssri_epochs']`` epochs at
    ``CONFIG['monoaminergic_lr']`` while stress falls from
    ``CONFIG['monoaminergic_initial_stress']`` to 0.0 and gain rises from 1.0
    to ``CONFIG['ssri_max_gain']``. No connections are regrown.

    Returns:
        Dict with 'final_sparsity' and 'final_gain'.
    """
    schedule = dict(
        epochs=CONFIG['comparison_ssri_epochs'],
        lr=CONFIG['monoaminergic_lr'],
        initial_stress=CONFIG['monoaminergic_initial_stress'],
        final_stress=0.0,
        initial_gain=1.0,
        max_gain=CONFIG['ssri_max_gain'],
        pruning_manager=pruning_mgr,
    )
    train_with_stress_and_gain_schedule(model, train_loader, **schedule)

    summary = {}
    summary['final_sparsity'] = pruning_mgr.get_sparsity()
    summary['final_gain'] = model.gain_multiplier
    return summary


def neurosteroid_treatment(model: RecurrentStressNetwork, pruning_mgr: PruningManager, train_loader: DataLoader) -> Dict:
    """Neurosteroid protocol: set inhibition and gain, then consolidate.

    Sets the inhibition strength / activation choice and the gain from the
    ``neurosteroid_*`` CONFIG entries, then trains for
    ``CONFIG['neurosteroid_consolidation_epochs']`` epochs at the fine-tune
    learning rate with masks enforced.

    NOTE(review): the ketamine and SSRI protocols read ``comparison_*`` keys,
    while this one reads ``neurosteroid_inhibition_strength`` and
    ``neurosteroid_consolidation_epochs`` rather than
    ``comparison_neurosteroid_strength`` / ``comparison_neurosteroid_epochs``.
    The values currently match — confirm which keys are meant to drive it.

    Returns:
        Dict with 'final_sparsity', 'final_gain' and 'effective_scale'
        (inhibition strength multiplied by gain).
    """
    inhibition = CONFIG['neurosteroid_inhibition_strength']
    tanh_flag = CONFIG['neurosteroid_use_tanh']
    model.set_inhibition(inhibition, tanh_flag)
    model.set_gain(CONFIG['neurosteroid_gain'])

    train(
        model,
        train_loader,
        epochs=CONFIG['neurosteroid_consolidation_epochs'],
        lr=CONFIG['finetune_lr'],
        pruning_manager=pruning_mgr,
    )

    summary = {}
    summary['final_sparsity'] = pruning_mgr.get_sparsity()
    summary['final_gain'] = model.gain_multiplier
    summary['effective_scale'] = model.inhibition_strength * model.gain_multiplier
    return summary


# ============================================================================
# KINDLING SIMULATION (MULTI-CYCLE WITH UNIFORM SCARRING)
# ============================================================================
def simulate_kindling(
    model_state: Dict[str, torch.Tensor],
    mask_state: Dict[str, torch.Tensor],
    scar_state: Dict[str, torch.Tensor],
    treatment_type: str,
    train_loader: DataLoader,
    test_loader: DataLoader,
    seed: int
) -> Dict:
    """
    Simulate kindling: progressive sensitization through repeated manic episodes.

    Architecture mechanism:
    - Each relapse triggers uniform base scarring (same for all treatments)
    - Severity factor computed from current gain and activation magnitude (emergent)
    - Scars accumulate across cycles (permanent mask positions)
    - Trigger bias weakens over cycles to test spontaneous episode emergence
    - Treatment differences emerge from:
      - Ketamine: gradient-guided regrowth partially compensates for scars
      - SSRI: high gain amplifies severity factor
      - Neurosteroid: low gain/inhibition reduces severity factor

    Args:
        model_state: State dict loaded into a fresh RecurrentStressNetwork.
        mask_state: Pruning masks copied into the new PruningManager.
        scar_state: Scar masks copied into the new PruningManager.
        treatment_type: 'ketamine', 'ssri' or 'neurosteroid'. Any other value
            leaves gain/inhibition untouched and skips inter-episode maintenance.
        train_loader: Data used for inter-episode maintenance training.
        test_loader: Data used for the biased-stress trigger tests.
        seed: Not read by this function; kept so existing callers keep working.

    Returns:
        Dict of per-cycle lists ('cycles', 'biased_accuracies', 'trigger_biases',
        'relapsed', 'scar_fractions', 'sparsities', 'severity_factors',
        'activation_magnitudes', 'effective_gains') plus the summary entries
        'total_relapses', 'final_scar_fraction', 'final_sparsity',
        'autonomy_test_accuracy', 'autonomy_achieved' and
        'first_weak_trigger_relapse' (cycle index or None).
    """

    num_cycles = CONFIG['kindling_num_cycles']
    initial_bias = CONFIG['kindling_initial_trigger_bias']
    final_bias = CONFIG['kindling_final_trigger_bias']

    results = {
        'cycles': [],
        'biased_accuracies': [],
        'trigger_biases': [],
        'relapsed': [],
        'scar_fractions': [],
        'sparsities': [],
        'severity_factors': [],
        'activation_magnitudes': [],
        'effective_gains': []
    }

    # Create model from state
    model = RecurrentStressNetwork().to(DEVICE)
    model.load_state_dict(copy.deepcopy(model_state))

    mgr = PruningManager(model, train_loader)
    mgr.masks = {k: v.clone() for k, v in mask_state.items()}
    mgr.scar_masks = {k: v.clone() for k, v in scar_state.items()}
    mgr.apply_masks()

    # Apply treatment-specific parameters
    if treatment_type == 'ketamine':
        model.set_gain(CONFIG['ketamine_gain'])
    elif treatment_type == 'ssri':
        model.set_gain(CONFIG['ssri_max_gain'])
    elif treatment_type == 'neurosteroid':
        model.set_inhibition(CONFIG['neurosteroid_inhibition_strength'], CONFIG['neurosteroid_use_tanh'])
        model.set_gain(CONFIG['neurosteroid_gain'])

    # CONFIG key of the learning rate used for inter-episode maintenance.
    # The three treatments differ only in this key (plus ketamine's regrowth);
    # None means "unknown treatment, no maintenance", matching the old if/elif chain.
    maintenance_lr_key = {
        'ketamine': 'finetune_lr',
        'ssri': 'monoaminergic_lr',
        'neurosteroid': 'finetune_lr',
    }.get(treatment_type)

    for cycle in range(num_cycles):
        # Calculate weakening trigger bias (linear interpolation)
        progress = cycle / max(num_cycles - 1, 1)
        trigger_bias = initial_bias + progress * (final_bias - initial_bias)

        # Inter-episode maintenance (treatment-specific recovery); skipped before the first episode
        if cycle > 0 and maintenance_lr_key is not None:
            if treatment_type == 'ketamine':
                # Ketamine: attempt regrowth during maintenance
                mgr.gradient_guided_regrow(regrow_fraction=0.3)
            train(model, train_loader, epochs=CONFIG['kindling_inter_episode_maintenance_epochs'],
                  lr=CONFIG[maintenance_lr_key], pruning_manager=mgr)

        # Measure current architectural state (inputs to the severity factor)
        current_gain = model.get_effective_gain()
        current_act_mag = get_avg_activation_magnitude(model, test_loader)

        # Manic trigger test with current bias
        biased_acc = evaluate_biased_stress(
            model, test_loader,
            internal_sigma=CONFIG['mania_test_sigma'],
            bias=trigger_bias
        )

        # Determine relapse
        relapsed = biased_acc < CONFIG['manic_relapse_threshold']

        # Compute severity factor (emergent from architecture)
        # Higher gain and higher activation magnitude = more severe episode = more scarring
        severity_factor = 1.0 + (current_gain - 1.0) + (current_act_mag - 0.1) * 2.0
        severity_factor = max(CONFIG['kindling_severity_min'],
                            min(CONFIG['kindling_severity_max'], severity_factor))

        # Apply uniform scarring if relapsed
        if relapsed:
            mgr.apply_episode_scar(
                base_sparsity=CONFIG['kindling_base_scar_sparsity'],
                severity_factor=severity_factor
            )

        # Record results; scar fraction and sparsity are read AFTER any scarring from this cycle
        results['cycles'].append(cycle)
        results['trigger_biases'].append(trigger_bias)
        results['biased_accuracies'].append(biased_acc)
        results['relapsed'].append(relapsed)
        results['scar_fractions'].append(mgr.get_scar_fraction())
        results['sparsities'].append(mgr.get_sparsity())
        results['severity_factors'].append(severity_factor)
        results['activation_magnitudes'].append(current_act_mag)
        results['effective_gains'].append(current_gain)

    # Summary statistics. With zero configured cycles nothing was recorded, so
    # fall back to the manager's current values instead of indexing an empty list.
    results['total_relapses'] = sum(results['relapsed'])
    results['final_scar_fraction'] = (
        results['scar_fractions'][-1] if results['scar_fractions'] else mgr.get_scar_fraction()
    )
    results['final_sparsity'] = (
        results['sparsities'][-1] if results['sparsities'] else mgr.get_sparsity()
    )

    # Autonomy test: relapse at minimal trigger
    final_acc_minimal = evaluate_biased_stress(
        model, test_loader,
        internal_sigma=CONFIG['mania_test_sigma'],
        bias=0.3  # Very weak trigger
    )
    results['autonomy_test_accuracy'] = final_acc_minimal
    results['autonomy_achieved'] = final_acc_minimal < CONFIG['kindling_autonomy_threshold']

    # First cycle of relapse at weak trigger (bias <= 1.0)
    weak_trigger_relapses = [i for i, (b, r) in enumerate(zip(results['trigger_biases'], results['relapsed']))
                             if b <= 1.0 and r]
    results['first_weak_trigger_relapse'] = weak_trigger_relapses[0] if weak_trigger_relapses else None

    return results


# ============================================================================
# LONGITUDINAL MANIC RELAPSE SIMULATION
# ============================================================================
def simulate_longitudinal_manic_relapse(
    model_state: Dict[str, torch.Tensor],
    mask_state: Dict[str, torch.Tensor],
    treatment_type: str,
    treatment_params: Dict,
    train_loader: DataLoader,
    test_loader: DataLoader
) -> Dict:
    """
    Simulate longitudinal manic relapse after medication discontinuation.

    For every entry of CONFIG['maintenance_durations'] a fresh model is rebuilt
    from ``model_state``/``mask_state``, put back on its treatment settings,
    trained for that many epochs, taken off treatment via
    ``reset_antidepressant_effects`` and then tested under biased stress after
    CONFIG['post_discontinuation_steps'] calls to ``decay_ms_protection``.

    Args:
        model_state: State dict loaded into each fresh model.
        mask_state: Pruning masks copied into each fresh PruningManager.
        treatment_type: 'ketamine', 'ssri' or 'neurosteroid'; anything else
            raises KeyError at the decay-rate lookup.
        treatment_params: Optional 'final_gain' overrides the CONFIG default gain.
        train_loader: Data for the maintenance training.
        test_loader: Data for the biased-stress test.

    Returns:
        Dict with per-duration lists and the summary entries 'relapse_count',
        'total_tests', 'relapse_probability' (percent), 'mean_biased_acc',
        'std_biased_acc' and 'ms_decay_rate'.
    """
    # All three rates are looked up before selecting one.
    ms_decay_rate = {
        'ketamine': CONFIG['ketamine_ms_decay_rate'],
        'ssri': CONFIG['ssri_ms_decay_rate'],
        'neurosteroid': CONFIG['neurosteroid_ms_decay_rate']
    }[treatment_type]

    # CONFIG key holding the gain used when treatment_params has no 'final_gain'.
    # treatment_type is guaranteed to be one of these three by the lookup above.
    fallback_gain_key = {
        'ketamine': 'ketamine_gain',
        'ssri': 'ssri_max_gain',
        'neurosteroid': 'neurosteroid_gain'
    }[treatment_type]

    tested_durations = []
    accuracies = []
    protection_levels = []
    relapse_flags = []
    gains_at_test = []

    for n_epochs in CONFIG['maintenance_durations']:
        net = RecurrentStressNetwork().to(DEVICE)
        net.load_state_dict(copy.deepcopy(model_state))

        pruner = PruningManager(net, train_loader)
        pruner.masks = {name: mask.clone() for name, mask in mask_state.items()}
        pruner.apply_masks()

        # Restore on-treatment settings; only the neurosteroid arm also sets inhibition.
        if treatment_type == 'neurosteroid':
            net.set_inhibition(CONFIG['neurosteroid_inhibition_strength'], CONFIG['neurosteroid_use_tanh'])
        net.set_gain(treatment_params.get('final_gain', CONFIG[fallback_gain_key]))

        # Maintenance phase, then discontinuation.
        net.set_ms_protection(CONFIG['ms_max_protection'], ms_decay_rate)
        train(net, train_loader, epochs=n_epochs, lr=CONFIG['maintenance_lr'], pruning_manager=pruner)
        net.reset_antidepressant_effects()

        for _ in range(CONFIG['post_discontinuation_steps']):
            net.decay_ms_protection()

        # Snapshot protection and gain before running the trigger test.
        protection_now = net.ms_protection_level
        gain_now = net.get_effective_gain()

        accuracy = evaluate_biased_stress(
            net, test_loader,
            internal_sigma=CONFIG['mania_test_sigma'],
            bias=CONFIG['mania_test_bias']
        )
        did_relapse = accuracy < CONFIG['manic_relapse_threshold']

        tested_durations.append(n_epochs)
        accuracies.append(accuracy)
        protection_levels.append(protection_now)
        relapse_flags.append(did_relapse)
        gains_at_test.append(gain_now)

    n_relapsed = sum(relapse_flags)
    n_tests = len(relapse_flags)

    return {
        'maintenance_durations': tested_durations,
        'biased_stress_accuracies': accuracies,
        'ms_protection_at_test': protection_levels,
        'relapsed': relapse_flags,
        'effective_gain_at_test': gains_at_test,
        'relapse_count': n_relapsed,
        'total_tests': n_tests,
        'relapse_probability': (n_relapsed / n_tests) * 100,
        'mean_biased_acc': np.mean(accuracies),
        'std_biased_acc': np.std(accuracies),
        'ms_decay_rate': ms_decay_rate,
    }


# ============================================================================
# SINGLE SEED EXPERIMENT
# ============================================================================
def run_single_seed(seed: int) -> Dict[str, Dict]:
    """Run the whole experiment for one random seed.

    Seeds torch and numpy, builds the data loaders, trains and magnitude-prunes
    a baseline model, applies early adversity, and snapshots that baseline.
    Each treatment arm (ketamine, SSRI, neurosteroid) then starts from a copy
    of the snapshot and goes through: treatment, evaluation battery,
    longitudinal relapse simulation, kindling simulation and an acute relapse
    (further magnitude pruning). The neurosteroid arm additionally records
    off-medication metrics before its acute relapse.

    The order of operations is kept strictly sequential because model
    construction, training and evaluation all happen under one seeded RNG.

    Returns:
        Dict with 'early_adversity_fraction' plus one metrics dict per arm
        under 'untreated', 'ketamine', 'ssri' and 'neurosteroid'.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    train_loader, test_loader, clean_test_loader = create_data_loaders(seed)

    # ------------------------------------------------------------------
    # Shared baseline: train, prune, apply early adversity, snapshot
    # ------------------------------------------------------------------
    base_model = RecurrentStressNetwork().to(DEVICE)
    train(base_model, train_loader, epochs=CONFIG['baseline_epochs'], lr=CONFIG['baseline_lr'])
    base_pruning_mgr = PruningManager(base_model, train_loader)
    base_pruning_mgr.prune_by_magnitude(sparsity=CONFIG['prune_sparsity'])

    early_adversity_fraction = base_pruning_mgr.apply_early_adversity(
        CONFIG['early_adversity_max_scar'], seed
    )

    base_state_dict = {name: tensor.clone() for name, tensor in base_model.state_dict().items()}
    base_masks = {name: mask.clone() for name, mask in base_pruning_mgr.masks.items()}
    base_scars = base_pruning_mgr.clone_scar_masks()

    # ------------------------------------------------------------------
    # Local helpers (closures over the loaders and the baseline snapshot)
    # ------------------------------------------------------------------
    def biased_probe(net):
        """Biased-stress accuracy at the standard mania test settings."""
        return evaluate_biased_stress(
            net, test_loader,
            internal_sigma=CONFIG['mania_test_sigma'],
            bias=CONFIG['mania_test_bias'],
        )

    def evaluation_battery(net, metrics):
        """Append the shared evaluation metrics to `metrics` in the canonical key order."""
        metrics['clean'] = evaluate(net, clean_test_loader)
        metrics['standard'] = evaluate(net, test_loader)
        metrics['combined'] = evaluate(net, test_loader, 1.0, 0.5)
        metrics['biased_stress'] = biased_probe(net)
        metrics['activation_magnitude'] = get_avg_activation_magnitude(net, test_loader)
        for stress_name, stress_level in CONFIG['extended_stress_levels'].items():
            metrics[f'stress_{stress_name}'] = evaluate(net, test_loader, 0.0, stress_level)
        return metrics

    def treated_arm(label, protocol, include_effective_scale=False):
        """Fork the baseline, treat it, evaluate it and run both follow-up simulations."""
        net = RecurrentStressNetwork().to(DEVICE)
        net.load_state_dict(copy.deepcopy(base_state_dict))
        pruner = PruningManager(net, train_loader)
        pruner.masks = {name: mask.clone() for name, mask in base_masks.items()}
        pruner.scar_masks = {name: mask.clone() for name, mask in base_scars.items()}
        pruner.apply_masks()
        stats = protocol(net, pruner, train_loader)

        # Post-treatment snapshot handed to the simulations, which rebuild their own models.
        post_state = {name: tensor.clone() for name, tensor in net.state_dict().items()}
        post_masks = pruner.clone_masks()
        post_scars = pruner.clone_scar_masks()

        metrics = {
            'sparsity': pruner.get_sparsity() * 100,
            'scar_fraction': pruner.get_scar_fraction() * 100,
            'gain': net.gain_multiplier,
        }
        if include_effective_scale:
            metrics['effective_scale'] = net.inhibition_strength * net.gain_multiplier
        evaluation_battery(net, metrics)

        # Longitudinal relapse
        metrics['longitudinal'] = simulate_longitudinal_manic_relapse(
            post_state, post_masks, label, stats,
            train_loader, test_loader
        )

        # Kindling simulation
        metrics['kindling'] = simulate_kindling(
            post_state, post_masks, post_scars,
            label, train_loader, test_loader, seed
        )
        return net, pruner, metrics

    def acute_relapse(net, pruner, metrics):
        """Prune 40% of the remaining density (capped at 0.99 sparsity) and record the accuracy drop."""
        pre_relapse = metrics['combined']
        target_sparsity = min(pruner.get_sparsity() + (1 - pruner.get_sparsity()) * 0.40, 0.99)
        pruner.prune_by_magnitude(sparsity=target_sparsity)
        pruner.apply_masks()
        metrics['relapse_drop'] = pre_relapse - evaluate(net, test_loader, 1.0, 0.5)

    results = {'early_adversity_fraction': early_adversity_fraction}

    # ------------------------------------------------------------------
    # Untreated baseline
    # ------------------------------------------------------------------
    results['untreated'] = evaluation_battery(base_model, {
        'sparsity': base_pruning_mgr.get_sparsity() * 100,
        'scar_fraction': base_pruning_mgr.get_scar_fraction() * 100,
        'gain': 1.0,
    })

    # ------------------------------------------------------------------
    # Ketamine and SSRI arms (identical pipeline, different protocol)
    # ------------------------------------------------------------------
    for label, protocol in (('ketamine', ketamine_treatment), ('ssri', ssri_treatment)):
        arm_model, arm_mgr, arm_results = treated_arm(label, protocol)
        acute_relapse(arm_model, arm_mgr, arm_results)
        results[label] = arm_results

    # ------------------------------------------------------------------
    # Neurosteroid arm (adds effective_scale and off-medication metrics)
    # ------------------------------------------------------------------
    neuro_model, neuro_mgr, neuro_results = treated_arm(
        'neurosteroid', neurosteroid_treatment, include_effective_scale=True
    )

    # Off medication: neutral inhibition and gain
    neuro_model.set_inhibition(1.0, False)
    neuro_model.set_gain(1.0)
    neuro_results['off_medication_combined'] = evaluate(neuro_model, test_loader, 1.0, 0.5)
    neuro_results['off_medication_extreme'] = evaluate(neuro_model, test_loader, 0.0, 2.5)
    neuro_results['off_medication_biased'] = biased_probe(neuro_model)
    neuro_results['off_medication_activation'] = get_avg_activation_magnitude(neuro_model, test_loader)

    # Back on medication before the acute relapse test
    neuro_model.set_inhibition(CONFIG['neurosteroid_inhibition_strength'], CONFIG['neurosteroid_use_tanh'])
    neuro_model.set_gain(CONFIG['neurosteroid_gain'])
    acute_relapse(neuro_model, neuro_mgr, neuro_results)
    results['neurosteroid'] = neuro_results

    return results


# ============================================================================
# AGGREGATE RESULTS
# ============================================================================
def aggregate_results(all_results: List[Dict[str, Dict]]) -> Dict:
    """Combine per-seed result dicts into cross-seed summary statistics.

    Args:
        all_results: One dict per seed, as returned by ``run_single_seed``.
            Seeds may lack a treatment, or its 'longitudinal'/'kindling' entry;
            such seeds are skipped for that part of the summary.

    Returns:
        Dict with four sections:
        - 'early_adversity': mean/std/min/max in percent.
        - 'treatments': per treatment, mean/std/min/max of every scalar metric.
        - 'longitudinal': per treated arm, relapse-probability statistics,
          the configured decay rate and a 'per_duration' breakdown.
        - 'kindling': per treated arm, relapse/scar/autonomy statistics and a
          'per_cycle' breakdown.
        Whenever no seed contributed data to a statistic, its value is NaN
        rather than an exception.
    """
    nan = float('nan')

    def _reduce(fn, values):
        """Apply a numpy reduction; NaN for an empty list (np.min/np.max would raise)."""
        return fn(values) if values else nan

    def _percent_true(flags):
        """Percentage of truthy entries; NaN for an empty list (avoids ZeroDivisionError)."""
        return (sum(flags) / len(flags)) * 100 if flags else nan

    # Python numbers plus numpy scalars: np.float32 and numpy integers do not
    # subclass float/int and would otherwise be dropped from the aggregation.
    scalar_types = (int, float, np.integer, np.floating)

    treatments = ['untreated', 'ketamine', 'ssri', 'neurosteroid']
    treated_arms = ['ketamine', 'ssri', 'neurosteroid']
    aggregated = {'treatments': {}, 'longitudinal': {}, 'kindling': {}}

    # Aggregate early adversity (fractions -> percent)
    early_adversity = [r['early_adversity_fraction'] for r in all_results if 'early_adversity_fraction' in r]
    aggregated['early_adversity'] = {
        'mean': _reduce(np.mean, early_adversity) * 100,
        'std': _reduce(np.std, early_adversity) * 100,
        'min': _reduce(np.min, early_adversity) * 100,
        'max': _reduce(np.max, early_adversity) * 100
    }

    # Aggregate treatment metrics (scalar entries only; nested dicts are handled below)
    for treatment in treatments:
        metrics = defaultdict(list)
        for seed_result in all_results:
            if treatment in seed_result:
                for key, value in seed_result[treatment].items():
                    if isinstance(value, scalar_types):
                        metrics[key].append(value)

        aggregated['treatments'][treatment] = {}
        for key, values in metrics.items():
            aggregated['treatments'][treatment][key] = {
                'mean': np.mean(values),
                'std': np.std(values),
                'min': np.min(values),
                'max': np.max(values)
            }

    # Aggregate longitudinal data
    for treatment in treated_arms:
        longitudinal_data = {
            'relapse_probabilities': [],
            'per_duration': defaultdict(lambda: {'biased_acc': [], 'ms_protection': [], 'relapsed': [], 'effective_gain': []})
        }

        for seed_result in all_results:
            if treatment in seed_result and 'longitudinal' in seed_result[treatment]:
                long = seed_result[treatment]['longitudinal']
                longitudinal_data['relapse_probabilities'].append(long['relapse_probability'])

                for i, dur in enumerate(long['maintenance_durations']):
                    bucket = longitudinal_data['per_duration'][dur]
                    bucket['biased_acc'].append(long['biased_stress_accuracies'][i])
                    bucket['ms_protection'].append(long['ms_protection_at_test'][i])
                    bucket['relapsed'].append(long['relapsed'][i])
                    bucket['effective_gain'].append(long['effective_gain_at_test'][i])

        probs = longitudinal_data['relapse_probabilities']
        aggregated['longitudinal'][treatment] = {
            'relapse_probability_mean': _reduce(np.mean, probs),
            'relapse_probability_std': _reduce(np.std, probs),
            'relapse_probability_min': _reduce(np.min, probs),
            'relapse_probability_max': _reduce(np.max, probs),
            'ms_decay_rate': CONFIG[f'{treatment}_ms_decay_rate'],
            'per_duration': {}
        }

        for dur in CONFIG['maintenance_durations']:
            dur_data = longitudinal_data['per_duration'][dur]
            # Durations that no seed reported are left out of the breakdown.
            if dur_data['biased_acc']:
                aggregated['longitudinal'][treatment]['per_duration'][dur] = {
                    'biased_acc_mean': np.mean(dur_data['biased_acc']),
                    'biased_acc_std': np.std(dur_data['biased_acc']),
                    'ms_protection_mean': np.mean(dur_data['ms_protection']),
                    'ms_protection_std': np.std(dur_data['ms_protection']),
                    'effective_gain_mean': np.mean(dur_data['effective_gain']),
                    'relapse_rate': _percent_true(dur_data['relapsed']),
                    'n_relapsed': sum(dur_data['relapsed']),
                    'n_total': len(dur_data['relapsed'])
                }

    # Aggregate kindling data
    for treatment in treated_arms:
        kindling_data = {
            'total_relapses': [],
            'final_scar_fractions': [],
            'autonomy_achieved': [],
            'autonomy_test_accuracies': [],
            'per_cycle': defaultdict(lambda: {
                'biased_acc': [], 'trigger_bias': [], 'relapsed': [],
                'scar_fraction': [], 'severity_factor': [], 'activation_mag': []
            })
        }

        for seed_result in all_results:
            if treatment in seed_result and 'kindling' in seed_result[treatment]:
                kind = seed_result[treatment]['kindling']
                kindling_data['total_relapses'].append(kind['total_relapses'])
                kindling_data['final_scar_fractions'].append(kind['final_scar_fraction'])
                kindling_data['autonomy_achieved'].append(kind['autonomy_achieved'])
                kindling_data['autonomy_test_accuracies'].append(kind['autonomy_test_accuracy'])

                for i, cycle in enumerate(kind['cycles']):
                    bucket = kindling_data['per_cycle'][cycle]
                    bucket['biased_acc'].append(kind['biased_accuracies'][i])
                    bucket['trigger_bias'].append(kind['trigger_biases'][i])
                    bucket['relapsed'].append(kind['relapsed'][i])
                    bucket['scar_fraction'].append(kind['scar_fractions'][i])
                    bucket['severity_factor'].append(kind['severity_factors'][i])
                    bucket['activation_mag'].append(kind['activation_magnitudes'][i])

        aggregated['kindling'][treatment] = {
            'total_relapses_mean': _reduce(np.mean, kindling_data['total_relapses']),
            'total_relapses_std': _reduce(np.std, kindling_data['total_relapses']),
            'final_scar_fraction_mean': _reduce(np.mean, kindling_data['final_scar_fractions']) * 100,
            'final_scar_fraction_std': _reduce(np.std, kindling_data['final_scar_fractions']) * 100,
            'autonomy_rate': _percent_true(kindling_data['autonomy_achieved']),
            'autonomy_test_acc_mean': _reduce(np.mean, kindling_data['autonomy_test_accuracies']),
            'autonomy_test_acc_std': _reduce(np.std, kindling_data['autonomy_test_accuracies']),
            'per_cycle': {}
        }

        for cycle in range(CONFIG['kindling_num_cycles']):
            cycle_data = kindling_data['per_cycle'][cycle]
            # Cycles that no seed reported are left out of the breakdown.
            if cycle_data['biased_acc']:
                aggregated['kindling'][treatment]['per_cycle'][cycle] = {
                    'biased_acc_mean': np.mean(cycle_data['biased_acc']),
                    'biased_acc_std': np.std(cycle_data['biased_acc']),
                    'trigger_bias': np.mean(cycle_data['trigger_bias']),
                    'relapse_rate': _percent_true(cycle_data['relapsed']),
                    'scar_fraction_mean': np.mean(cycle_data['scar_fraction']) * 100,
                    'scar_fraction_std': np.std(cycle_data['scar_fraction']) * 100,
                    'severity_factor_mean': np.mean(cycle_data['severity_factor']),
                    'activation_mag_mean': np.mean(cycle_data['activation_mag'])
                }

    return aggregated


# ============================================================================
# MAIN EXPERIMENT
# ============================================================================
# Treatment keys in the order they are reported, and their display labels.
_ALL_TREATMENTS = ('untreated', 'ketamine', 'ssri', 'neurosteroid')
_ACTIVE_TREATMENTS = ('ketamine', 'ssri', 'neurosteroid')
_TREATMENT_LABELS = {
    'untreated': 'Untreated (pruned)',
    'ketamine': 'Ketamine-like',
    'ssri': 'SSRI-like',
    'neurosteroid': 'Neurosteroid-like',
}

# Width of the horizontal rules that frame every report section.
_RULE_WIDTH = 80

# NOTE(review): the report prints this literal rather than reading it from
# CONFIG -- confirm it matches the bias the autonomy test actually uses.
_AUTONOMY_TEST_BIAS = 0.3


def _print_heading(*title_lines: str, rule: str = "-") -> None:
    """Print a blank line, a rule, the given title lines, and a closing rule."""
    print("\n" + rule * _RULE_WIDTH)
    for line in title_lines:
        print(line)
    print(rule * _RULE_WIDTH)


def _print_architecture_info() -> None:
    """Print the GRU settings from CONFIG and an estimated parameter count."""
    print("\n  Architecture: GRUCell-based recurrent network")
    print(f"  Sequence length (depth analogue): {CONFIG['seq_len']}")
    print(f"  GRU hidden size: {CONFIG['gru_hidden_size']}")

    input_dim = CONFIG['input_dim']
    hidden = CONFIG['gru_hidden_size']
    output_dim = CONFIG['output_dim']
    # Three gates, each with an input weight row, a hidden weight row and
    # two bias terms (weight_ih, weight_hh, bias_ih, bias_hh).
    gru_params = 3 * hidden * (input_dim + hidden + 2)
    # Linear read-out: weight matrix plus bias.
    fc_params = hidden * output_dim + output_dim
    total_params = gru_params + fc_params
    print(f"  Estimated total parameters: ~{total_params:,}")


def _print_early_adversity_table(aggregated: Dict) -> None:
    """Print TABLE 1: initial scar fraction statistics."""
    _print_heading("  TABLE 1: EARLY ADVERSITY (Uniform Random Scarring)")
    ea = aggregated['early_adversity']
    print(f"\n  Initial scar fraction: {ea['mean']:.2f}% ± {ea['std']:.2f}% (range: {ea['min']:.2f}% - {ea['max']:.2f}%)")


def _print_efficacy_table(aggregated: Dict) -> None:
    """Print TABLE 2: sparsity, scar fraction, clean and combined accuracy."""
    _print_heading("  TABLE 2: ANTIDEPRESSANT EFFICACY (Mean ± Std)")
    print(f"\n  {'Treatment':<22} {'Sparsity':>12} {'Scar %':>10} {'Clean':>14} {'Combined':>14}")
    print("  " + "-" * 74)

    for t in _ALL_TREATMENTS:
        r = aggregated['treatments'][t]
        # Treatments without a recorded scar fraction are shown as 0.0±0.0.
        scar = r.get('scar_fraction', {'mean': 0, 'std': 0})
        print(f"  {_TREATMENT_LABELS[t]:<22} "
              f"{r['sparsity']['mean']:>5.1f}±{r['sparsity']['std']:>4.1f}% "
              f"{scar['mean']:>4.1f}±{scar['std']:>3.1f}% "
              f"{r['clean']['mean']:>6.1f}±{r['clean']['std']:>4.1f}% "
              f"{r['combined']['mean']:>6.1f}±{r['combined']['std']:>4.1f}%")


def _print_manic_risk_table(aggregated: Dict) -> None:
    """Print TABLE 3: gain, biased-stress accuracy and activation magnitude."""
    _print_heading("  TABLE 3: MANIC CONVERSION RISK METRICS (Mean ± Std)")
    print(f"\n  {'Treatment':<22} {'Gain':>10} {'Biased Stress':>18} {'Act. Magnitude':>18}")
    print("  " + "-" * 70)

    for t in _ALL_TREATMENTS:
        r = aggregated['treatments'][t]
        gain_str = f"{r['gain']['mean']:>5.2f}±{r['gain']['std']:>4.2f}"
        biased_str = f"{r['biased_stress']['mean']:>6.1f}±{r['biased_stress']['std']:>4.1f}%"
        act_str = f"{r['activation_magnitude']['mean']:>6.3f}±{r['activation_magnitude']['std']:>5.3f}"
        print(f"  {_TREATMENT_LABELS[t]:<22} {gain_str:>10} {biased_str:>18} {act_str:>18}")


def _print_acute_relapse_table(aggregated: Dict) -> None:
    """Print TABLE 4: acute relapse drop for the three active treatments."""
    _print_heading("  TABLE 4: ACUTE RELAPSE VULNERABILITY (Mean ± Std)")
    print(f"\n  {'Treatment':<22} {'Relapse Drop':>18}")
    print("  " + "-" * 42)

    for t in _ACTIVE_TREATMENTS:
        r = aggregated['treatments'][t]
        print(f"  {_TREATMENT_LABELS[t]:<22} {r['relapse_drop']['mean']:>6.1f}±{r['relapse_drop']['std']:>4.1f}%")


def _print_longitudinal_tables(aggregated: Dict) -> None:
    """Print the longitudinal section: protocol note, TABLE 5 and TABLE 6."""
    _print_heading("  LONGITUDINAL MANIC RELAPSE AFTER DISCONTINUATION", rule="=")

    print("\n  Protocol: MS + antidepressant maintenance, then all medications discontinued.")
    print(f"            Manic relapse threshold: biased stress accuracy < {CONFIG['manic_relapse_threshold']:.0f}%")

    # TABLE 5: overall relapse probability.
    _print_heading("  TABLE 5: OVERALL MANIC RELAPSE PROBABILITY")
    print(f"\n  {'Treatment':<22} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
    print("  " + "-" * 64)

    for t in _ACTIVE_TREATMENTS:
        r = aggregated['longitudinal'][t]
        print(f"  {_TREATMENT_LABELS[t]:<22} "
              f"{r['relapse_probability_mean']:>9.1f}% "
              f"{r['relapse_probability_std']:>9.1f}% "
              f"{r['relapse_probability_min']:>9.1f}% "
              f"{r['relapse_probability_max']:>9.1f}%")

    # TABLE 6: one column per configured maintenance duration.
    _print_heading("  TABLE 6: MANIC RELAPSE RATE BY MAINTENANCE DURATION")

    durations = CONFIG['maintenance_durations']
    header = f"  {'Treatment':<18}" + "".join(f" {dur:>8}" for dur in durations)
    print(f"\n{header} epochs")
    print("  " + "-" * (18 + 9 * len(durations)))

    for t in _ACTIVE_TREATMENTS:
        per_duration = aggregated['longitudinal'][t]['per_duration']
        cells = []
        for dur in durations:
            if dur in per_duration:
                cells.append(f" {per_duration[dur]['relapse_rate']:>7.0f}%")
            else:
                # Durations with no aggregated entry are shown as N/A.
                cells.append(f" {'N/A':>8}")
        print(f"  {_TREATMENT_LABELS[t]:<18}" + "".join(cells))


def _print_kindling_tables(aggregated: Dict) -> None:
    """Print the kindling section: mechanism notes and TABLES 7, 8 and 9."""
    _print_heading("  KINDLING: PROGRESSIVE SENSITIZATION VIA REPEATED EPISODES", rule="=")

    print("\n  Architecture mechanism: Uniform base scarring upon each manic relapse.")
    print("  Severity factor computed from current gain and activation magnitude (emergent).")
    print("  Trigger bias weakens over cycles to test spontaneous episode emergence.")
    print(f"  Base scar sparsity: {CONFIG['kindling_base_scar_sparsity']*100:.1f}%")
    print(f"  Autonomy threshold: biased stress accuracy < {CONFIG['kindling_autonomy_threshold']:.0f}% at weak trigger")

    # TABLE 7: per-treatment kindling summary.
    _print_heading("  TABLE 7: KINDLING SUMMARY (Mean ± Std)")
    print(f"\n  {'Treatment':<18} {'Total Relapses':>16} {'Final Scar %':>14} {'Autonomy Rate':>14} {'Autonomy Acc':>14}")
    print("  " + "-" * 78)

    for t in _ACTIVE_TREATMENTS:
        k = aggregated['kindling'][t]
        print(f"  {_TREATMENT_LABELS[t]:<18} "
              f"{k['total_relapses_mean']:>7.1f}±{k['total_relapses_std']:>5.1f} "
              f"{k['final_scar_fraction_mean']:>6.1f}±{k['final_scar_fraction_std']:>4.1f}% "
              f"{k['autonomy_rate']:>13.0f}% "
              f"{k['autonomy_test_acc_mean']:>6.1f}±{k['autonomy_test_acc_std']:>4.1f}%")

    # TABLE 8: per-cycle progression, one sub-table per treatment.
    _print_heading("  TABLE 8: KINDLING PROGRESSION BY CYCLE")

    for t in _ACTIVE_TREATMENTS:
        print(f"\n  {_TREATMENT_LABELS[t]}:")
        print(f"  {'Cycle':>6} {'Trigger':>8} {'Biased Acc':>14} {'Relapse %':>12} {'Scar %':>14} {'Severity':>10}")
        print("  " + "-" * 66)

        per_cycle = aggregated['kindling'][t]['per_cycle']
        for cycle in range(CONFIG['kindling_num_cycles']):
            # Cycles absent from the aggregate are skipped without a row.
            if cycle not in per_cycle:
                continue
            c = per_cycle[cycle]
            print(f"  {cycle:>6} {c['trigger_bias']:>8.2f} "
                  f"{c['biased_acc_mean']:>6.1f}±{c['biased_acc_std']:>4.1f}% "
                  f"{c['relapse_rate']:>11.0f}% "
                  f"{c['scar_fraction_mean']:>6.1f}±{c['scar_fraction_std']:>4.1f}% "
                  f"{c['severity_factor_mean']:>9.2f}")

    # TABLE 9: the kindling settings the run used.
    _print_heading("  TABLE 9: KINDLING ARCHITECTURE PARAMETERS")
    print(f"\n  {'Parameter':<40} {'Value':>15}")
    print("  " + "-" * 57)
    print(f"  {'Base scar sparsity per episode':<40} {CONFIG['kindling_base_scar_sparsity']*100:>14.1f}%")
    print(f"  {'Severity factor range':<40} {CONFIG['kindling_severity_min']:>6.1f} - {CONFIG['kindling_severity_max']:.1f}")
    print(f"  {'Number of cycles':<40} {CONFIG['kindling_num_cycles']:>15}")
    print(f"  {'Initial trigger bias':<40} {CONFIG['kindling_initial_trigger_bias']:>15.2f}")
    print(f"  {'Final trigger bias':<40} {CONFIG['kindling_final_trigger_bias']:>15.2f}")
    print(f"  {'Autonomy test bias':<40} {_AUTONOMY_TEST_BIAS:>15.2f}")


def _print_neurosteroid_dependence_table(aggregated: Dict) -> None:
    """Print TABLE 10: neurosteroid accuracy on versus off medication."""
    _print_heading("  TABLE 10: NEUROSTEROID MEDICATION DEPENDENCE (Mean ± Std)")

    neuro = aggregated['treatments']['neurosteroid']
    print(f"\n  {'Condition':<30} {'Combined':>18} {'Extreme':>18} {'Biased':>18}")
    print("  " + "-" * 86)
    print(f"  {'On medication':<30} "
          f"{neuro['combined']['mean']:>6.1f}±{neuro['combined']['std']:>4.1f}% "
          f"{neuro['stress_extreme']['mean']:>6.1f}±{neuro['stress_extreme']['std']:>4.1f}% "
          f"{neuro['biased_stress']['mean']:>6.1f}±{neuro['biased_stress']['std']:>4.1f}%")
    print(f"  {'Off medication':<30} "
          f"{neuro['off_medication_combined']['mean']:>6.1f}±{neuro['off_medication_combined']['std']:>4.1f}% "
          f"{neuro['off_medication_extreme']['mean']:>6.1f}±{neuro['off_medication_extreme']['std']:>4.1f}% "
          f"{neuro['off_medication_biased']['mean']:>6.1f}±{neuro['off_medication_biased']['std']:>4.1f}%")


def _print_summary_matrix(aggregated: Dict) -> None:
    """Print the boxed matrix comparing the headline metrics side by side."""
    _print_heading("  SUMMARY COMPARISON MATRIX", rule="=")

    print("""
  ┌─────────────────────┬────────────┬────────────┬──────────────┬─────────────┐
  │ Metric              │ Ketamine   │ SSRI       │ Neurosteroid │ Untreated   │
  ├─────────────────────┼────────────┼────────────┼──────────────┼─────────────┤""")

    ket = aggregated['treatments']['ketamine']
    ssri = aggregated['treatments']['ssri']
    neuro = aggregated['treatments']['neurosteroid']
    untreated = aggregated['treatments']['untreated']

    print(f"  │ Combined Stress (%) │ {ket['combined']['mean']:>10.1f} │ {ssri['combined']['mean']:>10.1f} │ {neuro['combined']['mean']:>12.1f} │ {untreated['combined']['mean']:>11.1f} │")
    print(f"  │ Biased Stress (%)   │ {ket['biased_stress']['mean']:>10.1f} │ {ssri['biased_stress']['mean']:>10.1f} │ {neuro['biased_stress']['mean']:>12.1f} │ {untreated['biased_stress']['mean']:>11.1f} │")
    print(f"  │ Gain Multiplier     │ {ket['gain']['mean']:>10.2f} │ {ssri['gain']['mean']:>10.2f} │ {neuro['gain']['mean']:>12.2f} │ {untreated['gain']['mean']:>11.2f} │")
    print(f"  │ Acute Relapse Drop  │ {ket['relapse_drop']['mean']:>10.1f} │ {ssri['relapse_drop']['mean']:>10.1f} │ {neuro['relapse_drop']['mean']:>12.1f} │         N/A │")

    ket_long = aggregated['longitudinal']['ketamine']
    ssri_long = aggregated['longitudinal']['ssri']
    neuro_long = aggregated['longitudinal']['neurosteroid']

    print(f"  │ Long. Relapse Prob. │ {ket_long['relapse_probability_mean']:>9.1f}% │ {ssri_long['relapse_probability_mean']:>9.1f}% │ {neuro_long['relapse_probability_mean']:>11.1f}% │         N/A │")

    ket_kind = aggregated['kindling']['ketamine']
    ssri_kind = aggregated['kindling']['ssri']
    neuro_kind = aggregated['kindling']['neurosteroid']

    print(f"  │ Kindling Relapses   │ {ket_kind['total_relapses_mean']:>10.1f} │ {ssri_kind['total_relapses_mean']:>10.1f} │ {neuro_kind['total_relapses_mean']:>12.1f} │         N/A │")
    print(f"  │ Final Scar (%)      │ {ket_kind['final_scar_fraction_mean']:>10.1f} │ {ssri_kind['final_scar_fraction_mean']:>10.1f} │ {neuro_kind['final_scar_fraction_mean']:>12.1f} │         N/A │")
    print(f"  │ Autonomy Rate (%)   │ {ket_kind['autonomy_rate']:>10.0f} │ {ssri_kind['autonomy_rate']:>10.0f} │ {neuro_kind['autonomy_rate']:>12.0f} │         N/A │")

    print("  └─────────────────────┴────────────┴────────────┴──────────────┴─────────────┘")


def _print_kindling_risk_ranking(aggregated: Dict) -> None:
    """Print the active treatments ranked by autonomy rate, then relapses."""
    _print_heading("  KINDLING RISK RANKING (by autonomy rate and total relapses)")

    kindling = aggregated['kindling']
    risk_data = [
        (
            _TREATMENT_LABELS[t],
            kindling[t]['autonomy_rate'],
            kindling[t]['total_relapses_mean'],
            kindling[t]['final_scar_fraction_mean'],
        )
        for t in _ACTIVE_TREATMENTS
    ]
    # Ascending sort: rank 1 is the treatment with the lowest autonomy rate
    # (ties broken by fewer relapses), i.e. the most kindling-resistant one.
    risk_data.sort(key=lambda x: (x[1], x[2]))

    print(f"\n  {'Rank':<6} {'Treatment':<22} {'Autonomy %':>12} {'Relapses':>12} {'Final Scar %':>14}")
    print("  " + "-" * 68)

    for rank, (name, auto, rel, scar) in enumerate(risk_data, start=1):
        print(f"  {rank:<6} {name:<22} {auto:>11.0f}% {rel:>12.1f} {scar:>13.1f}%")

    print("\n  Note: Lower autonomy rate and fewer relapses = better kindling resistance")
    print("        Scar accumulation reflects permanent structural damage from episodes")


def run_multi_seed_experiment() -> Dict:
    """Run the experiment for every seed, aggregate, and print the full report.

    Calls ``run_single_seed(seed)`` for each seed in
    ``range(CONFIG['n_seeds'])``, passes the collected per-seed results to
    ``aggregate_results``, and prints the report sections (Tables 1-10, the
    summary matrix and the kindling risk ranking) to stdout.

    Returns:
        A dict with two keys: ``'all_results'`` (the list of per-seed results
        in seed order) and ``'aggregated'`` (the value returned by
        ``aggregate_results``).
    """
    n_seeds = CONFIG['n_seeds']

    _print_heading(
        "  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT",
        "  WITH MANIC CONVERSION + LONGITUDINAL RELAPSE + KINDLING",
        "  GRU-BASED RECURRENT ARCHITECTURE",
        "  Running across {} random seeds".format(n_seeds),
        rule="=",
    )
    _print_architecture_info()

    all_results = []
    for seed in range(n_seeds):
        # Progress line is completed with "done" once the seed has finished.
        print(f"\n  Seed {seed+1}/{n_seeds}...", end=" ", flush=True)
        all_results.append(run_single_seed(seed))
        print("done")

    aggregated = aggregate_results(all_results)

    _print_heading("  AGGREGATED RESULTS ACROSS {} SEEDS".format(n_seeds), rule="=")

    _print_early_adversity_table(aggregated)
    _print_efficacy_table(aggregated)
    _print_manic_risk_table(aggregated)
    _print_acute_relapse_table(aggregated)
    _print_longitudinal_tables(aggregated)
    _print_kindling_tables(aggregated)
    _print_neurosteroid_dependence_table(aggregated)
    _print_summary_matrix(aggregated)
    _print_kindling_risk_ranking(aggregated)

    print("\n" + "=" * _RULE_WIDTH)
    print("  EXPERIMENT COMPLETE")
    print("  GRU-based recurrent architecture with per-step modulations")
    print("  Results averaged across {} random seeds".format(n_seeds))
    print("=" * _RULE_WIDTH + "\n")

    return {'all_results': all_results, 'aggregated': aggregated}


# ============================================================================
# ENTRY POINT
# ============================================================================
if __name__ == "__main__":
    # Start-up banner: an 80-column box of '#' with each title centred in
    # the 78 columns between the side borders.
    _inner = 78
    _edge = "#" * (_inner + 2)
    _blank_row = "#" + " " * _inner + "#"
    _titles = (
        " MULTI-MECHANISM ANTIDEPRESSANT COMPARISON ",
        " WITH MANIC CONVERSION + LONGITUDINAL RELAPSE + KINDLING ",
        " GRU-BASED RECURRENT ARCHITECTURE ",
        " Ketamine vs SSRI vs Neurosteroid ",
        f" ({CONFIG['n_seeds']} Random Seeds) ",
    )

    print("\n" + _edge)
    print(_blank_row)
    for _title in _titles:
        print("#" + _title.center(_inner) + "#")
    print(_blank_row)
    print(_edge)

    # Keep the outcome bound at module level so it can be inspected after
    # the run (e.g. from an interactive session).
    results = run_multi_seed_experiment()


################################################################################
#                                                                              #
#                  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON                   #
#           WITH MANIC CONVERSION + LONGITUDINAL RELAPSE + KINDLING            #
#                       GRU-BASED RECURRENT ARCHITECTURE                       #
#                       Ketamine vs SSRI vs Neurosteroid                       #
#                              (10 Random Seeds)                               #
#                                                                              #
################################################################################

  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT
  WITH MANIC CONVERSION + LONGITUDINAL RELAPSE + KINDLING
  GRU-BASED RECURRENT ARCHITECTURE
  Running across 10 random seeds

  Architecture: GRUCell-based recurrent network
  Sequence length (depth analogue): 20


# The End