# Manic Conversion by Different Antidepressants

In [1]:
"""
================================================================================
MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT
WITH MANIC CONVERSION RISK MODELING
================================================================================

This script compares three antidepressant mechanisms and models their
differential risk of inducing manic conversion (treatment-emergent mania):

1. KETAMINE-LIKE: Gradient-guided synaptogenesis (moderate transient excitability)
2. SSRI-LIKE: Gradual stabilization (progressive excitability overshoot)
3. NEUROSTEROID-LIKE: Tonic inhibition enhancement (protective low excitability)

Manic conversion is modeled as treatment-induced hyperexcitability leading to
network dysfunction under positively-biased perturbation.
================================================================================
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from typing import Dict, Tuple, List
import warnings

warnings.filterwarnings('ignore', category=UserWarning)

# ============================================================================
# REPRODUCIBILITY
# ============================================================================
# One fixed seed for both global random number generators used in this file.
SEED = 42
torch.manual_seed(SEED)  # torch's global generator (torch.randn / torch.randn_like draws)
np.random.seed(SEED)  # NumPy's global generator; generate_blobs uses its own RandomState instead
DEVICE = torch.device('cpu')  # every batch is moved here with .to(DEVICE) before a forward pass

# ============================================================================
# CONFIGURATION
# ============================================================================
# Central experiment configuration. Functions below read these keys as their
# defaults; several keys are not read anywhere in the visible part of the file
# and are marked NOTE(review) so they can be confirmed against the remainder.
CONFIG = {
    # Data generation
    'n_train': 12000,       # samples in the training set
    'n_test': 4000,         # samples in the noisy test set
    'n_clean_test': 2000,   # samples in the test set generated with noise=0.0
    'data_noise': 0.8,      # multiplier on the standard-normal jitter around blob centres
    'batch_size': 128,      # training loader batch size (test loaders use a fixed 1000)

    # Network architecture
    'hidden_dims': [512, 512, 256],  # widths of the three hidden layers (fc1..fc3)
    'input_dim': 2,                  # blobs are 2-D points
    'output_dim': 4,                 # four blob classes

    # Training hyperparameters
    'baseline_epochs': 20,   # epochs for the dense baseline model
    'baseline_lr': 0.001,    # learning rate for the dense baseline model
    'finetune_epochs': 15,   # NOTE(review): not read in the visible code
    'finetune_lr': 0.0005,   # consolidation learning rate (ketamine / neurosteroid protocols)

    # Pruning parameters
    'prune_sparsity': 0.95,  # per-layer magnitude quantile below which weights are zeroed

    # Regrowth parameters
    'regrow_fraction': 0.5,               # NOTE(review): visible code reads 'comparison_ketamine_regrow' instead
    'regrow_init_scale': 0.03,            # multiplier on standard-normal values for regrown weights
    'gradient_accumulation_batches': 30,  # training batches used to score pruned positions

    # Stress levels for evaluation
    # NOTE(review): not read in the visible code; presumably iterated by the
    # evaluation section further down - confirm.
    'extended_stress_levels': {
        'none': 0.0,
        'moderate': 0.5,
        'high': 1.0,
        'severe': 1.5,
        'extreme': 2.5
    },

    # ========================================================================
    # MONOAMINERGIC (SSRI-LIKE) TREATMENT PARAMETERS
    # ========================================================================
    # Defaults picked up by ssri_treatment when its arguments are None.
    'monoaminergic_epochs': 100,
    'monoaminergic_lr': 1e-5,
    'monoaminergic_initial_stress': 0.5,  # internal noise at epoch 0, ramped linearly to 0.0
    'ssri_max_gain': 1.6,  # Peak excitability overshoot (gain reached at the last epoch)

    # ========================================================================
    # KETAMINE-LIKE TREATMENT PARAMETERS
    # ========================================================================
    'ketamine_gain': 1.25,  # Moderate glutamatergic surge (gain set by ketamine_treatment)

    # ========================================================================
    # NEUROSTEROID (GABAergic) TREATMENT PARAMETERS
    # ========================================================================
    # Defaults picked up by neurosteroid_treatment when its arguments are None.
    'neurosteroid_inhibition_strength': 0.7,  # multiplicative damping on hidden activations
    'neurosteroid_use_tanh': True,            # switch hidden activation from ReLU to tanh
    'neurosteroid_consolidation_epochs': 10,
    'neurosteroid_gain': 0.85,  # Protective reduction

    # ========================================================================
    # MANIC CONVERSION EVALUATION PARAMETERS
    # ========================================================================
    # NOTE(review): not read in the visible code; presumably forwarded to
    # evaluate_biased_stress as `bias` / `internal_sigma` - confirm.
    'mania_test_bias': 1.0,  # Positive bias magnitude
    'mania_test_sigma': 1.0,  # Internal noise for mania test

    # ========================================================================
    # MULTI-MECHANISM COMPARISON PARAMETERS
    # ========================================================================
    'comparison_ketamine_regrow': 0.5,   # default regrow_fraction of ketamine_treatment
    'comparison_ketamine_epochs': 15,    # default consolidation_epochs of ketamine_treatment
    # NOTE(review): the three keys below are not read in the visible code; the
    # SSRI and neurosteroid protocols take their defaults from the
    # 'monoaminergic_*' and 'neurosteroid_*' keys above.
    'comparison_ssri_epochs': 100,
    'comparison_neurosteroid_strength': 0.7,
    'comparison_neurosteroid_epochs': 10,
}


# ============================================================================
# DATA GENERATION
# ============================================================================
def generate_blobs(
    n_samples: int = 10000,
    noise: float = 0.8,
    seed: int = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample a four-class, two-dimensional Gaussian-blob dataset.

    Each sample is assigned a uniformly random class in {0, 1, 2, 3} and placed
    at that class's centre, (-3, -3), (3, 3), (-3, 3) or (3, -3), plus
    standard-normal jitter scaled by ``noise``.

    Args:
        n_samples: Number of points to draw.
        noise: Multiplier applied to the standard-normal jitter (0 places every
            point exactly on its class centre).
        seed: Seed for a private ``RandomState``; ``None`` gives an
            unseeded (non-reproducible) generator.

    Returns:
        ``(features, targets)``: a float32 tensor of shape ``(n_samples, 2)``
        and an int64 tensor of shape ``(n_samples,)``.
    """
    # RandomState(None) self-seeds exactly like RandomState() does.
    generator = np.random.RandomState(seed)

    blob_centers = np.array([[-3, -3], [3, 3], [-3, 3], [3, -3]])

    # Draw order matters for reproducibility: class ids first, jitter second.
    class_ids = generator.randint(0, 4, n_samples)
    jitter = generator.randn(n_samples, 2) * noise
    points = blob_centers[class_ids] + jitter

    features = torch.tensor(points, dtype=torch.float32)
    targets = torch.tensor(class_ids, dtype=torch.long)
    return features, targets


def create_data_loaders() -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the training, noisy-test and noise-free-test loaders.

    Dataset sizes, jitter level and the training batch size come from
    ``CONFIG``; each split uses its own fixed seed (100 / 200 / 300) so the
    data are identical from run to run.

    Returns:
        ``(train_loader, test_loader, clean_test_loader)``. Only the training
        loader shuffles; both test loaders use batches of 1000.
    """
    jitter = CONFIG['data_noise']

    train_set = TensorDataset(*generate_blobs(CONFIG['n_train'], noise=jitter, seed=100))
    test_set = TensorDataset(*generate_blobs(CONFIG['n_test'], noise=jitter, seed=200))
    # The "clean" split has zero jitter: every point sits on its class centre.
    clean_set = TensorDataset(*generate_blobs(CONFIG['n_clean_test'], noise=0.0, seed=300))

    return (
        DataLoader(train_set, batch_size=CONFIG['batch_size'], shuffle=True),
        DataLoader(test_set, batch_size=1000),
        DataLoader(clean_set, batch_size=1000),
    )


# Build the loaders once at import time. The training loops and the pruning
# manager's gradient accumulation read `train_loader` as a module-level global.
train_loader, test_loader, clean_test_loader = create_data_loaders()


# ============================================================================
# NETWORK ARCHITECTURE WITH EXCITABILITY GAIN
# ============================================================================
class StressAwareNetwork(nn.Module):
    """
    Four-layer fully connected classifier with run-time modulation of its
    hidden activity.

    Each of the three hidden layers applies, in this order: the activation
    function, optional additive Gaussian noise, multiplication by
    ``inhibition_strength`` and multiplication by ``gain_multiplier``.
    The output layer (``fc4``) is never modulated.

    Modulation knobs (plain attributes, not learned parameters):
    - stress_level: scale of zero-mean noise added in ``forward`` (0 = off)
    - inhibition_strength: multiplicative damping (1.0 = none)
    - use_tanh: use the bounded tanh activation instead of ReLU
    - gain_multiplier: multiplicative excitability gain (1.0 = balanced)
    """

    def __init__(self, hidden_dims: List[int] = None):
        """Create the layers; ``hidden_dims`` defaults to ``CONFIG['hidden_dims']``.

        Only the first three entries of ``hidden_dims`` are used.
        """
        super().__init__()
        widths = CONFIG['hidden_dims'] if hidden_dims is None else hidden_dims
        w1, w2, w3 = widths[0], widths[1], widths[2]

        # Layers are built in fc1..fc4 order so that weight initialisation
        # consumes the random number generator in a fixed sequence.
        self.fc1 = nn.Linear(CONFIG['input_dim'], w1)
        self.fc2 = nn.Linear(w1, w2)
        self.fc3 = nn.Linear(w2, w3)
        self.fc4 = nn.Linear(w3, CONFIG['output_dim'])

        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

        # Neutral modulation state: no noise, no damping, ReLU, unit gain.
        self.stress_level = 0.0
        self.inhibition_strength = 1.0
        self.use_tanh = False
        self.gain_multiplier = 1.0

        self.weight_layers = ['fc1', 'fc2', 'fc3', 'fc4']

    # ------------------------------------------------------------------
    # Private helpers shared by the public methods
    # ------------------------------------------------------------------
    @staticmethod
    def _linear_ramp(epoch: int, total_epochs: int, start: float, end: float) -> float:
        """Value at ``epoch`` of a straight line from ``start`` (epoch 0) to ``end`` (last epoch)."""
        # max(..., 1) keeps the single-epoch case from dividing by zero.
        fraction = epoch / max(total_epochs - 1, 1)
        return start + fraction * (end - start)

    def _pick_activation(self) -> nn.Module:
        """Hidden-layer activation currently selected by ``use_tanh``."""
        return self.tanh if self.use_tanh else self.relu

    def _hidden_layers(self):
        """The three modulated layers, in forward order."""
        return (self.fc1, self.fc2, self.fc3)

    def _modulate(self, h: torch.Tensor) -> torch.Tensor:
        """Apply damping first, then excitability gain."""
        h = h * self.inhibition_strength
        return h * self.gain_multiplier

    # ------------------------------------------------------------------
    # Modulation setters and schedules
    # ------------------------------------------------------------------
    def set_stress(self, level: float):
        """Set the scale of the internal noise injected by ``forward``."""
        self.stress_level = level

    def set_inhibition(self, strength: float, use_tanh: bool = False):
        """Set the damping factor and choose between tanh and ReLU."""
        self.inhibition_strength = strength
        self.use_tanh = use_tanh

    def set_gain(self, gain: float):
        """Set the excitability gain multiplier."""
        self.gain_multiplier = gain

    def reduce_stress_gradually(self, epoch: int, total_epochs: int,
                                 initial_stress: float = 0.5, final_stress: float = 0.0):
        """Set ``stress_level`` on a linear schedule from ``initial_stress`` to ``final_stress``."""
        self.stress_level = self._linear_ramp(epoch, total_epochs, initial_stress, final_stress)

    def increase_gain_gradually(self, epoch: int, total_epochs: int,
                                 initial_gain: float = 1.0, max_gain: float = 1.6):
        """
        Set ``gain_multiplier`` on a linear schedule from ``initial_gain`` to ``max_gain``.
        ``max_gain`` is reached at the final epoch.
        """
        self.gain_multiplier = self._linear_ramp(epoch, total_epochs, initial_gain, max_gain)

    # ------------------------------------------------------------------
    # Forward passes
    # ------------------------------------------------------------------
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass; zero-mean noise is injected only while ``stress_level > 0``."""
        act = self._pick_activation()
        inject_noise = self.stress_level > 0

        h = x
        for layer in self._hidden_layers():
            h = act(layer(h))
            if inject_noise:
                h = h + torch.randn_like(h) * self.stress_level
            h = self._modulate(h)

        # Output layer (no modulation)
        return self.fc4(h)

    def forward_with_biased_noise(self, x: torch.Tensor, internal_sigma: float = 1.0,
                                   bias: float = 1.0) -> torch.Tensor:
        """
        Forward pass with noise of scale ``internal_sigma`` shifted by ``bias``
        added after every hidden activation (mania vulnerability test).

        Unlike ``forward`` the noise is always injected and ``stress_level``
        is not consulted.
        """
        act = self._pick_activation()

        h = x
        for layer in self._hidden_layers():
            h = act(layer(h))
            h = h + (torch.randn_like(h) * internal_sigma + bias)  # Biased noise
            h = self._modulate(h)

        return self.fc4(h)

    def count_parameters(self) -> Tuple[int, int]:
        """Return ``(total, nonzero)`` element counts over all parameters, biases included."""
        total = 0
        nonzero = 0
        for param in self.parameters():
            total += param.numel()
            nonzero += int(param.ne(0).sum())
        return total, nonzero


# ============================================================================
# PRUNING MANAGER
# ============================================================================
class PruningManager:
    """Unstructured magnitude pruning and gradient-guided regrowth.

    One binary mask (1 = kept, 0 = pruned) and one gradient-score buffer are
    kept for every parameter whose name contains ``'weight'`` and that has at
    least two dimensions; biases are never pruned.
    """

    def __init__(self, model: StressAwareNetwork):
        """Create all-ones masks and zeroed score buffers for ``model``'s weight matrices."""
        self.model = model
        self.masks = {}
        self.gradient_buffer = {}

        for name, param in model.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                self.masks[name] = torch.ones_like(param, dtype=torch.float32)
                self.gradient_buffer[name] = torch.zeros_like(param)

    def prune_by_magnitude(self, sparsity: float, per_layer: bool = True) -> Dict[str, Dict]:
        """Zero the smallest-magnitude weights of every masked layer.

        Args:
            sparsity: Quantile in [0, 1] of each layer's absolute weights;
                entries strictly below that quantile are pruned (ties at the
                threshold are kept).
            per_layer: Accepted for interface compatibility but not consulted:
                the threshold is always computed separately for each layer.

        Returns:
            Per-layer dict with ``'kept'``, ``'total'`` and ``'actual_sparsity'``.
        """
        stats = {}

        for name, param in self.model.named_parameters():
            if name in self.masks:
                weights = param.data.abs()
                threshold = torch.quantile(weights.flatten(), sparsity)
                self.masks[name] = (weights >= threshold).float()
                param.data *= self.masks[name]

                kept = self.masks[name].sum().item()
                total = self.masks[name].numel()
                stats[name] = {'kept': int(kept), 'total': total, 'actual_sparsity': 1 - kept/total}

        return stats

    def _accumulate_gradients(self, num_batches: int = 30):
        """Sum absolute loss gradients at pruned positions over training batches.

        Reads batches from the module-level ``train_loader``. Internal noise is
        switched off and the gain set to 1.0 while scoring; both are restored
        afterwards, even if a batch raises. The model is left in train mode.

        Args:
            num_batches: Maximum number of training batches to score.
        """
        model = self.model
        loss_fn = nn.CrossEntropyLoss()

        for score_buffer in self.gradient_buffer.values():
            score_buffer.zero_()

        model.train()
        original_stress = model.stress_level
        original_gain = model.gain_multiplier
        model.set_stress(0.0)
        model.set_gain(1.0)

        try:
            # backward() adds into .grad, so anything left over from an earlier
            # training step would otherwise be counted in the first batch's scores.
            model.zero_grad()

            for batch_idx, (x, y) in enumerate(train_loader):
                if batch_idx >= num_batches:
                    break
                x, y = x.to(DEVICE), y.to(DEVICE)
                loss_fn(model(x), y).backward()

                with torch.no_grad():
                    for name, param in model.named_parameters():
                        # A parameter that received no gradient cannot be scored.
                        if name in self.masks and param.grad is not None:
                            pruned_mask = (self.masks[name] == 0).float()
                            self.gradient_buffer[name] += param.grad.abs() * pruned_mask
                model.zero_grad()
        finally:
            model.set_stress(original_stress)
            model.set_gain(original_gain)

    def gradient_guided_regrow(self, regrow_fraction: float,
                                init_scale: float = None) -> Dict[str, Dict]:
        """Re-enable the pruned connections with the largest accumulated gradients.

        Args:
            regrow_fraction: Fraction of each layer's pruned positions to
                restore. At least one position is restored in every layer that
                has any pruned weight.
            init_scale: Multiplier on the standard-normal values assigned to
                restored weights; defaults to ``CONFIG['regrow_init_scale']``.

        Returns:
            Per-layer dict with ``'regrown'`` and ``'still_pruned'`` counts.
        """
        if init_scale is None:
            init_scale = CONFIG['regrow_init_scale']

        self._accumulate_gradients(num_batches=CONFIG['gradient_accumulation_batches'])

        stats = {}
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue

            mask = self.masks[name]
            pruned_positions = (mask == 0)
            num_pruned = pruned_positions.sum().item()

            if num_pruned == 0:
                stats[name] = {'regrown': 0, 'still_pruned': 0}
                continue

            gradient_scores = self.gradient_buffer[name][pruned_positions]
            num_regrow = max(1, int(regrow_fraction * num_pruned))
            num_regrow = min(num_regrow, gradient_scores.numel())

            # top_indices index into the pruned subset; map them back to
            # positions in the flattened weight matrix.
            _, top_indices = torch.topk(gradient_scores.flatten(), num_regrow)
            flat_pruned_indices = torch.where(pruned_positions.flatten())[0]
            regrow_flat_indices = flat_pruned_indices[top_indices]

            flat_mask = mask.flatten()
            flat_param = param.data.flatten()
            flat_mask[regrow_flat_indices] = 1.0
            # Create the new values on the parameter's own device and dtype so
            # the assignment also works for non-CPU or non-float32 models.
            fresh_weights = torch.randn(num_regrow, device=param.device, dtype=param.dtype)
            flat_param[regrow_flat_indices] = fresh_weights * init_scale

            # Rebind rather than rely on flatten() having returned a view.
            self.masks[name] = flat_mask.view_as(mask)
            param.data = flat_param.view_as(param)

            stats[name] = {'regrown': num_regrow, 'still_pruned': int(num_pruned - num_regrow)}

        return stats

    def apply_masks(self):
        """Multiply every masked weight by its mask so pruned entries stay at zero."""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in self.masks:
                    param.data *= self.masks[name]

    def get_sparsity(self) -> float:
        """Return the fraction of mask entries that are zero (0.0 when there are no masks)."""
        total = sum(m.numel() for m in self.masks.values())
        zeros = sum((m == 0).sum().item() for m in self.masks.values())
        return zeros / total if total > 0 else 0.0


# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================
def train(model: StressAwareNetwork, epochs: int = 15, lr: float = 0.001,
          pruning_manager: PruningManager = None, verbose: bool = False) -> List[float]:
    """Train ``model`` with Adam and cross-entropy on the global ``train_loader``.

    Internal noise is disabled for the duration of training and the previous
    ``stress_level`` is put back before returning. The gain and inhibition
    settings of the model are left untouched.

    Args:
        model: Network to optimise in place.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        pruning_manager: When given, its masks are re-applied after every
            optimiser step so pruned weights stay at zero.
        verbose: Print the mean loss after each epoch.

    Returns:
        Mean batch loss of every epoch, in order.
    """
    adam = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    epoch_means = []

    stress_before = model.stress_level
    model.set_stress(0.0)

    for epoch_idx in range(epochs):
        model.train()
        running_loss = 0.0

        for inputs, targets in train_loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            adam.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            adam.step()

            # The optimiser may have moved pruned weights away from zero.
            if pruning_manager:
                pruning_manager.apply_masks()

            running_loss += batch_loss.item()

        epoch_means.append(running_loss / len(train_loader))
        if verbose:
            print(f"      Epoch {epoch_idx+1}/{epochs}, Loss: {epoch_means[-1]:.4f}")

    model.set_stress(stress_before)
    return epoch_means


def train_with_stress_and_gain_schedule(model: StressAwareNetwork, epochs: int, lr: float,
                                         initial_stress: float, final_stress: float = 0.0,
                                         initial_gain: float = 1.0, max_gain: float = 1.6,
                                         pruning_manager: PruningManager = None,
                                         verbose: bool = False, print_interval: int = 20) -> List[float]:
    """
    Train while internal noise ramps from ``initial_stress`` to ``final_stress``
    and the excitability gain ramps from ``initial_gain`` to ``max_gain``.

    Both schedules are linear in the epoch index and are updated once at the
    start of every epoch. On return ``stress_level`` is forced to 0.0 while
    ``gain_multiplier`` keeps the value of the last epoch.

    Returns the mean batch loss of every epoch, in order.
    """
    adam = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    epoch_means = []

    for epoch_idx in range(epochs):
        # Advance both schedules before this epoch's batches are processed.
        model.reduce_stress_gradually(epoch_idx, epochs, initial_stress, final_stress)
        model.increase_gain_gradually(epoch_idx, epochs, initial_gain, max_gain)

        model.train()
        running_loss = 0.0

        for inputs, targets in train_loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            adam.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            adam.step()

            if pruning_manager:
                pruning_manager.apply_masks()

            running_loss += batch_loss.item()

        epoch_means.append(running_loss / len(train_loader))

        if verbose:
            if (epoch_idx + 1) % print_interval == 0:
                print(f"      SSRI epoch {epoch_idx+1}/{epochs}, stress: {model.stress_level:.3f}, "
                      f"gain: {model.gain_multiplier:.3f}, loss: {epoch_means[-1]:.4f}")

    # Noise is switched off afterwards; the gain is deliberately left as is.
    model.set_stress(0.0)
    return epoch_means


# ============================================================================
# EVALUATION FUNCTIONS
# ============================================================================
def evaluate(model: StressAwareNetwork, loader: DataLoader,
             input_noise: float = 0.0, internal_stress: float = 0.0) -> float:
    """Return classification accuracy (percent) of ``model`` on ``loader``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loader: Batches of ``(inputs, labels)``.
        input_noise: Scale of Gaussian noise added to the inputs (0 = none).
        internal_stress: ``stress_level`` used during the forward passes.

    Side effect: ``stress_level`` is set to 0.0 on return, whatever it was
    before the call. Gain and inhibition settings are not modified.
    """
    model.eval()
    model.set_stress(internal_stress)

    hits = 0
    seen = 0
    perturb_inputs = input_noise > 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            if perturb_inputs:
                inputs = inputs + torch.randn_like(inputs) * input_noise
            predicted = model(inputs).argmax(dim=1)
            hits += (predicted == targets).sum().item()
            seen += targets.size(0)

    model.set_stress(0.0)
    return 100.0 * hits / seen


def evaluate_with_neurosteroid(model: StressAwareNetwork, loader: DataLoader,
                                input_noise: float = 0.0, internal_stress: float = 0.0) -> float:
    """Accuracy (percent) with the model's current inhibition and gain left active.

    The inhibition strength, activation choice and gain already set on
    ``model`` are used as they are. ``stress_level`` is set to
    ``internal_stress`` for the forward passes and to 0.0 on return.

    NOTE(review): the body performs the same steps as ``evaluate``; the two
    could share one implementation.
    """
    model.eval()
    model.set_stress(internal_stress)

    n_correct = 0
    n_total = 0

    with torch.no_grad():
        for batch_inputs, batch_labels in loader:
            batch_inputs, batch_labels = batch_inputs.to(DEVICE), batch_labels.to(DEVICE)
            if input_noise > 0:
                # Additive Gaussian perturbation of the raw inputs.
                batch_inputs = batch_inputs + torch.randn_like(batch_inputs) * input_noise
            matches = model(batch_inputs).argmax(dim=1) == batch_labels
            n_correct += matches.sum().item()
            n_total += batch_labels.size(0)

    model.set_stress(0.0)
    return 100.0 * n_correct / n_total


def evaluate_biased_stress(model: StressAwareNetwork, loader: DataLoader,
                           input_noise: float = 0.0, internal_sigma: float = 1.0,
                           bias: float = 1.0) -> float:
    """
    Accuracy (percent) when every hidden layer receives positively shifted noise.

    Predictions come from ``model.forward_with_biased_noise`` with the given
    ``internal_sigma`` and ``bias``; ``input_noise`` optionally perturbs the
    inputs as well. The result is used as a mania vulnerability proxy:
    lower accuracy is read as HIGHER manic conversion risk.

    The model is switched to eval mode; its stress, gain and inhibition
    settings are not modified.
    """
    model.eval()

    hits = 0
    seen = 0
    perturb_inputs = input_noise > 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            if perturb_inputs:
                inputs = inputs + torch.randn_like(inputs) * input_noise
            logits = model.forward_with_biased_noise(inputs, internal_sigma, bias)
            hits += (logits.argmax(dim=1) == targets).sum().item()
            seen += targets.size(0)

    return 100.0 * hits / seen


def get_avg_activation_magnitude(model: StressAwareNetwork, loader: DataLoader) -> Dict[str, float]:
    """
    Mean absolute hidden activation per layer (hyperexcitability risk indicator).

    The hidden layers are replayed without noise: activation, then
    multiplication by the model's inhibition strength and gain. For every
    batch the mean absolute value of each layer's output is recorded, and the
    per-batch means are then averaged (batches are weighted equally).

    Returns a dict with keys ``'layer1'``, ``'layer2'``, ``'layer3'`` and
    ``'mean'`` (the average of the three layer values).
    """
    model.eval()
    act = model.tanh if model.use_tanh else model.relu

    hidden = (('layer1', model.fc1), ('layer2', model.fc2), ('layer3', model.fc3))
    per_batch = {key: [] for key, _ in hidden}

    with torch.no_grad():
        for inputs, _ in loader:
            h = inputs.to(DEVICE)
            for key, layer in hidden:
                # Each layer consumes the modulated output of the previous one.
                h = act(layer(h))
                h = h * model.inhibition_strength * model.gain_multiplier
                per_batch[key].append(h.abs().mean().item())

    summary = {key: np.mean(values) for key, values in per_batch.items()}
    summary['mean'] = np.mean(list(summary.values()))
    return summary


# ============================================================================
# TREATMENT PROTOCOLS
# ============================================================================
def ketamine_treatment(model: StressAwareNetwork, pruning_mgr: PruningManager,
                       regrow_fraction: float = None, consolidation_epochs: int = None,
                       verbose: bool = True) -> Dict:
    """
    KETAMINE-LIKE TREATMENT: gradient-guided synaptogenesis with a moderate gain surge.

    Steps performed:
    1. Set the model gain to ``CONFIG['ketamine_gain']``.
    2. Regrow a fraction of the pruned connections, chosen by gradient score.
    3. Consolidate with ``train`` at ``CONFIG['finetune_lr']``, keeping the
       remaining pruned weights masked.

    The gain is not reset afterwards, so the returned ``'final_gain'`` is the
    ketamine gain. The protocol is meant to represent a LOW-MODERATE manic
    conversion risk.

    Args:
        regrow_fraction: Defaults to ``CONFIG['comparison_ketamine_regrow']``.
        consolidation_epochs: Defaults to ``CONFIG['comparison_ketamine_epochs']``.
        verbose: Print a progress report.

    Returns:
        Dict with ``'regrow_stats'``, ``'final_sparsity'`` and ``'final_gain'``.
    """
    regrow_fraction = (CONFIG['comparison_ketamine_regrow']
                       if regrow_fraction is None else regrow_fraction)
    consolidation_epochs = (CONFIG['comparison_ketamine_epochs']
                            if consolidation_epochs is None else consolidation_epochs)

    if verbose:
        header = (
            f"\n    KETAMINE-LIKE TREATMENT:",
            f"      Regrowth fraction: {regrow_fraction*100:.0f}%",
            f"      Consolidation: {consolidation_epochs} epochs",
            f"      Excitability gain: {CONFIG['ketamine_gain']} (moderate surge)",
            f"      Estimating gradient importance...",
        )
        for line in header:
            print(line)

    # Apply moderate glutamatergic surge
    model.set_gain(CONFIG['ketamine_gain'])

    regrow_stats = pruning_mgr.gradient_guided_regrow(regrow_fraction=regrow_fraction)
    total_regrown = sum(layer_stats['regrown'] for layer_stats in regrow_stats.values())

    if verbose:
        print(f"      Restored {total_regrown:,} synapses")
        print(f"      Consolidating new synapses (allows gain compensation)...")

    # Consolidation: the per-epoch losses returned by train are not needed here.
    train(model, epochs=consolidation_epochs,
          lr=CONFIG['finetune_lr'], pruning_manager=pruning_mgr)

    final_sparsity = pruning_mgr.get_sparsity()
    if verbose:
        print(f"      Final sparsity: {final_sparsity*100:.1f}%")
        print(f"      Final gain: {model.gain_multiplier:.2f}")

    outcome = {'regrow_stats': regrow_stats}
    outcome['final_sparsity'] = final_sparsity
    outcome['final_gain'] = model.gain_multiplier
    return outcome


def ssri_treatment(model: StressAwareNetwork, pruning_mgr: PruningManager,
                   epochs: int = None, learning_rate: float = None,
                   initial_stress: float = None, max_gain: float = None,
                   verbose: bool = True, print_interval: int = 25) -> Dict:
    """
    SSRI-LIKE TREATMENT: gradual stabilization with progressive excitability overshoot.

    Runs ``train_with_stress_and_gain_schedule``: internal noise falls linearly
    from ``initial_stress`` to 0.0 while the gain rises linearly from 1.0 to
    ``max_gain``. No connections are pruned or regrown; existing masks are
    re-applied during training. The gain stays at its final value on return.
    The protocol is meant to represent a HIGH manic conversion risk.

    Args:
        epochs: Defaults to ``CONFIG['monoaminergic_epochs']``.
        learning_rate: Defaults to ``CONFIG['monoaminergic_lr']``.
        initial_stress: Defaults to ``CONFIG['monoaminergic_initial_stress']``.
        max_gain: Defaults to ``CONFIG['ssri_max_gain']``.
        verbose: Print a progress report.
        print_interval: Report training progress every this many epochs.

    Returns:
        Dict with ``'final_sparsity'``, ``'training_losses'`` and ``'final_gain'``.
    """
    epochs = CONFIG['monoaminergic_epochs'] if epochs is None else epochs
    learning_rate = CONFIG['monoaminergic_lr'] if learning_rate is None else learning_rate
    initial_stress = (CONFIG['monoaminergic_initial_stress']
                      if initial_stress is None else initial_stress)
    max_gain = CONFIG['ssri_max_gain'] if max_gain is None else max_gain

    if verbose:
        header = (
            f"\n    SSRI-LIKE TREATMENT:",
            f"      Duration: {epochs} epochs (gradual)",
            f"      Learning rate: {learning_rate} (very low)",
            f"      Internal stress: {initial_stress} → 0.0",
            f"      Excitability gain: 1.0 → {max_gain} (progressive overshoot)",
            f"      Note: NO structural changes (fixed sparsity)",
            f"      WARNING: High manic conversion risk due to persistent gain",
        )
        for line in header:
            print(line)

    training_losses = train_with_stress_and_gain_schedule(
        model,
        epochs=epochs,
        lr=learning_rate,
        initial_stress=initial_stress,
        final_stress=0.0,
        initial_gain=1.0,
        max_gain=max_gain,
        pruning_manager=pruning_mgr,
        verbose=verbose,
        print_interval=print_interval,
    )

    final_sparsity = pruning_mgr.get_sparsity()
    if verbose:
        print(f"      Final sparsity: {final_sparsity*100:.1f}% (unchanged)")
        print(f"      Final gain: {model.gain_multiplier:.2f} (ELEVATED)")

    outcome = {'final_sparsity': final_sparsity}
    outcome['training_losses'] = training_losses
    outcome['final_gain'] = model.gain_multiplier
    return outcome


def neurosteroid_treatment(model: StressAwareNetwork, pruning_mgr: PruningManager,
                           inhibition_strength: float = None, use_tanh: bool = None,
                           consolidation_epochs: int = None, gain: float = None,
                           verbose: bool = True) -> Dict:
    """
    NEUROSTEROID-LIKE TREATMENT: enhanced tonic inhibition with a reduced gain.

    Steps performed:
    1. Set the damping factor, the tanh/ReLU choice and the gain on the model.
    2. Consolidate with ``train`` at ``CONFIG['finetune_lr']`` so the weights
       adapt to the new settings; existing masks are re-applied, nothing is
       pruned or regrown.

    Inhibition and gain stay in effect on return. The protocol is meant to
    represent the LOWEST manic conversion risk of the three.

    Args:
        inhibition_strength: Defaults to ``CONFIG['neurosteroid_inhibition_strength']``.
        use_tanh: Defaults to ``CONFIG['neurosteroid_use_tanh']``.
        consolidation_epochs: Defaults to ``CONFIG['neurosteroid_consolidation_epochs']``.
        gain: Defaults to ``CONFIG['neurosteroid_gain']``.
        verbose: Print a progress report.

    Returns:
        Dict with ``'final_sparsity'``, ``'inhibition_strength'``,
        ``'use_tanh'`` and ``'final_gain'``.
    """
    inhibition_strength = (CONFIG['neurosteroid_inhibition_strength']
                           if inhibition_strength is None else inhibition_strength)
    use_tanh = CONFIG['neurosteroid_use_tanh'] if use_tanh is None else use_tanh
    consolidation_epochs = (CONFIG['neurosteroid_consolidation_epochs']
                            if consolidation_epochs is None else consolidation_epochs)
    gain = CONFIG['neurosteroid_gain'] if gain is None else gain

    if verbose:
        header = (
            f"\n    NEUROSTEROID-LIKE TREATMENT:",
            f"      Inhibition strength: {inhibition_strength} ({(1-inhibition_strength)*100:.0f}% damping)",
            f"      Excitability gain: {gain} (protective reduction)",
            f"      Effective scale: {inhibition_strength * gain:.2f}",
            f"      Bounded activation (tanh): {use_tanh}",
            f"      Consolidation: {consolidation_epochs} epochs",
            f"      Note: LOWEST manic conversion risk",
        )
        for line in header:
            print(line)

    # Apply tonic inhibition modulation AND protective gain reduction
    model.set_inhibition(inhibition_strength, use_tanh)
    model.set_gain(gain)

    if verbose:
        print(f"      Applied tonic inhibition + protective gain modulation...")
        print(f"      Adapting to new activity dynamics...")

    # Consolidation: the per-epoch losses returned by train are not needed here.
    train(model, epochs=consolidation_epochs,
          lr=CONFIG['finetune_lr'], pruning_manager=pruning_mgr)

    final_sparsity = pruning_mgr.get_sparsity()
    if verbose:
        print(f"      Final sparsity: {final_sparsity*100:.1f}% (unchanged)")
        print(f"      Final gain: {model.gain_multiplier:.2f}")

    outcome = {'final_sparsity': final_sparsity}
    outcome['inhibition_strength'] = inhibition_strength
    outcome['use_tanh'] = use_tanh
    outcome['final_gain'] = model.gain_multiplier
    return outcome


# ============================================================================
# MAIN EXPERIMENT
# ============================================================================
def _clone_pruned_baseline(base_state_dict, base_masks):
    """Return a fresh ``(model, pruning_manager)`` pair replicating the pruned baseline."""
    model = StressAwareNetwork().to(DEVICE)
    model.load_state_dict(base_state_dict)
    manager = PruningManager(model)
    # Clone the masks so pruning/regrowth in one treatment arm cannot leak
    # into the other arms.
    manager.masks = {k: v.clone() for k, v in base_masks.items()}
    manager.apply_masks()
    return model, manager


def _collect_metrics(model, metrics, eval_fn):
    """Run the shared evaluation battery on ``model`` and store it in ``metrics``.

    ``eval_fn`` is called as ``eval_fn(model, loader, a, b)`` where ``b`` is the
    stress level (``a`` is presumably the input-noise level -- confirm against
    the evaluator's definition). The call order is fixed on purpose: the
    evaluators may draw random noise, so reordering could change the numbers
    obtained under a fixed seed.
    """
    metrics['clean'] = eval_fn(model, clean_test_loader, 0.0, 0.0)
    metrics['standard'] = eval_fn(model, test_loader, 0.0, 0.0)
    for stress_name, stress_level in CONFIG['extended_stress_levels'].items():
        metrics[f'stress_{stress_name}'] = eval_fn(model, test_loader, 0.0, stress_level)
    metrics['combined'] = eval_fn(model, test_loader, 1.0, 0.5)

    # Manic conversion metrics
    metrics['biased_stress'] = evaluate_biased_stress(
        model, test_loader,
        internal_sigma=CONFIG['mania_test_sigma'],
        bias=CONFIG['mania_test_bias']
    )
    metrics['activation_magnitude'] = get_avg_activation_magnitude(model, test_loader)
    return metrics


def _print_treated_metrics(metrics):
    """Print the post-treatment metric lines shared by all three treatment arms."""
    print(f"      Clean: {metrics['clean']:.1f}%")
    print(f"      Standard: {metrics['standard']:.1f}%")
    print(f"      Combined stress: {metrics['combined']:.1f}%")
    print(f"      Extreme stress: {metrics['stress_extreme']:.1f}%")
    print(f"      Biased stress (mania test): {metrics['biased_stress']:.1f}%")
    print(f"      Activation magnitude: {metrics['activation_magnitude']['mean']:.3f}")


def _simulate_relapse(model, manager, metrics, eval_fn):
    """Prune 40% of the surviving weights and record the combined-stress accuracy drop."""
    pre_relapse = metrics['combined']
    pre_sparsity = manager.get_sparsity()
    # Remove 40% of what is still connected, capped at 99% overall sparsity.
    target_sparsity = min(pre_sparsity + (1 - pre_sparsity) * 0.40, 0.99)
    manager.prune_by_magnitude(sparsity=target_sparsity, per_layer=True)
    manager.apply_masks()
    post_relapse = eval_fn(model, test_loader, 1.0, 0.5)
    metrics['relapse_drop'] = pre_relapse - post_relapse
    print(f"      Combined: {pre_relapse:.1f}% → {post_relapse:.1f}% (drop: {metrics['relapse_drop']:.1f}%)")


def _risk_level(gain):
    """Map a gain multiplier to the risk label shown in the comparison table."""
    if gain >= 1.5:
        return "HIGH"
    if gain >= 1.2:
        return "MODERATE"
    if gain < 1.0:
        return "LOW"
    return "BASELINE"


def run_multi_mechanism_experiment() -> Dict[str, Dict]:
    """
    Compare ketamine, SSRI, and neurosteroid treatment mechanisms
    including manic conversion risk assessment.

    A single network is trained and magnitude-pruned; each treatment arm then
    starts from a clone of that pruned baseline, is treated, evaluated, and
    subjected to a relapse simulation (further pruning). Progress and summary
    tables are printed along the way.

    Relies on the module-level ``test_loader`` and ``clean_test_loader``.

    Returns:
        Dict keyed by ``'untreated'``, ``'ketamine'``, ``'ssri'`` and
        ``'neurosteroid'``; each value is a dict of metrics (``'sparsity'``,
        ``'gain'``, ``'clean'``, ``'standard'``, ``'stress_<name>'``,
        ``'combined'``, ``'biased_stress'``, ``'activation_magnitude'`` and,
        for treated arms, ``'relapse_drop'``; the neurosteroid arm also has
        ``'effective_scale'`` and ``'off_medication_*'`` entries).
    """
    print("\n" + "="*80)
    print("  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT")
    print("  WITH MANIC CONVERSION RISK MODELING")
    print("="*80)

    print("""
  COMPARING THREE ANTIDEPRESSANT MECHANISMS:

  ┌─────────────────┬─────────────────────────────────────────────────────────┐
  │ Mechanism       │ Key Feature                          │ Mania Risk      │
  ├─────────────────┼──────────────────────────────────────┼─────────────────┤
  │ Ketamine        │ Synaptogenesis + moderate gain (1.25)│ LOW-MODERATE    │
  │ SSRI            │ Gradual stabilization + gain ↑ (1.6) │ HIGH            │
  │ Neurosteroid    │ Tonic inhibition + low gain (0.85)   │ LOW             │
  └─────────────────┴──────────────────────────────────────┴─────────────────┘

  Manic conversion modeled as:
  - Treatment-induced hyperexcitability (gain_multiplier > 1.0)
  - Vulnerability to positively-biased perturbation
  - Higher activation magnitude = higher latent risk
    """)

    # ========================================================================
    # PREPARE BASE PRUNED MODEL
    # ========================================================================
    print("-"*70)
    print("  Preparing shared pruned baseline...")
    print("-"*70)

    base_model = StressAwareNetwork().to(DEVICE)
    print(f"  Training full network ({CONFIG['baseline_epochs']} epochs)...")
    train(base_model, epochs=CONFIG['baseline_epochs'], lr=CONFIG['baseline_lr'])

    base_pruning_mgr = PruningManager(base_model)
    base_pruning_mgr.prune_by_magnitude(sparsity=CONFIG['prune_sparsity'], per_layer=True)

    initial_sparsity = base_pruning_mgr.get_sparsity()
    print(f"  Pruned to {initial_sparsity*100:.1f}% sparse")

    # Evaluate untreated state
    print("\n  UNTREATED PRUNED STATE:")
    untreated_results = _collect_metrics(
        base_model, {'sparsity': initial_sparsity * 100, 'gain': 1.0}, evaluate)

    print(f"    Clean: {untreated_results['clean']:.1f}%")
    print(f"    Standard: {untreated_results['standard']:.1f}%")
    print(f"    Combined stress: {untreated_results['combined']:.1f}%")
    print(f"    Biased stress (mania test): {untreated_results['biased_stress']:.1f}%")
    print(f"    Activation magnitude: {untreated_results['activation_magnitude']['mean']:.3f}")

    # Save state for cloning
    base_state_dict = {k: v.clone() for k, v in base_model.state_dict().items()}
    base_masks = {k: v.clone() for k, v in base_pruning_mgr.masks.items()}

    results = {'untreated': untreated_results}

    # ========================================================================
    # TREATMENT 1: KETAMINE-LIKE
    # ========================================================================
    print("\n" + "="*70)
    print("  TREATMENT 1: KETAMINE-LIKE (Synaptogenesis + Moderate Gain)")
    print("="*70)

    ketamine_model, ketamine_mgr = _clone_pruned_baseline(base_state_dict, base_masks)
    ketamine_treatment(ketamine_model, ketamine_mgr,
                       regrow_fraction=CONFIG['comparison_ketamine_regrow'],
                       consolidation_epochs=CONFIG['comparison_ketamine_epochs'])

    print("\n    POST-TREATMENT EVALUATION:")
    ketamine_results = _collect_metrics(ketamine_model, {
        'sparsity': ketamine_mgr.get_sparsity() * 100,
        'gain': ketamine_model.gain_multiplier
    }, evaluate)
    _print_treated_metrics(ketamine_results)
    print(f"      Final gain: {ketamine_results['gain']:.2f}")

    # Relapse simulation
    print("\n    RELAPSE SIMULATION:")
    _simulate_relapse(ketamine_model, ketamine_mgr, ketamine_results, evaluate)

    results['ketamine'] = ketamine_results

    # ========================================================================
    # TREATMENT 2: SSRI-LIKE
    # ========================================================================
    print("\n" + "="*70)
    print("  TREATMENT 2: SSRI-LIKE (Gradual Stabilization + Gain Overshoot)")
    print("="*70)

    ssri_model, ssri_mgr = _clone_pruned_baseline(base_state_dict, base_masks)
    ssri_treatment(ssri_model, ssri_mgr,
                   epochs=CONFIG['comparison_ssri_epochs'],
                   learning_rate=CONFIG['monoaminergic_lr'],
                   initial_stress=CONFIG['monoaminergic_initial_stress'],
                   max_gain=CONFIG['ssri_max_gain'],
                   print_interval=25)

    print("\n    POST-TREATMENT EVALUATION:")
    ssri_results = _collect_metrics(ssri_model, {
        'sparsity': ssri_mgr.get_sparsity() * 100,
        'gain': ssri_model.gain_multiplier
    }, evaluate)
    _print_treated_metrics(ssri_results)
    print(f"      Final gain: {ssri_results['gain']:.2f} (ELEVATED - HIGH MANIA RISK)")

    # Relapse simulation
    print("\n    RELAPSE SIMULATION:")
    _simulate_relapse(ssri_model, ssri_mgr, ssri_results, evaluate)

    results['ssri'] = ssri_results

    # ========================================================================
    # TREATMENT 3: NEUROSTEROID-LIKE
    # ========================================================================
    print("\n" + "="*70)
    print("  TREATMENT 3: NEUROSTEROID-LIKE (Tonic Inhibition + Protective Gain)")
    print("="*70)

    neuro_model, neuro_mgr = _clone_pruned_baseline(base_state_dict, base_masks)
    neurosteroid_treatment(neuro_model, neuro_mgr,
                           inhibition_strength=CONFIG['neurosteroid_inhibition_strength'],
                           use_tanh=CONFIG['neurosteroid_use_tanh'],
                           consolidation_epochs=CONFIG['neurosteroid_consolidation_epochs'],
                           gain=CONFIG['neurosteroid_gain'])

    # Evaluate WITH modulation active (patient on medication)
    print("\n    POST-TREATMENT EVALUATION (with modulation active):")
    neuro_results = _collect_metrics(neuro_model, {
        'sparsity': neuro_mgr.get_sparsity() * 100,
        'gain': neuro_model.gain_multiplier,
        'effective_scale': neuro_model.inhibition_strength * neuro_model.gain_multiplier
    }, evaluate_with_neurosteroid)
    _print_treated_metrics(neuro_results)
    print(f"      Final gain: {neuro_results['gain']:.2f}")
    print(f"      Effective scale: {neuro_results['effective_scale']:.2f} (LOWEST - LOW MANIA RISK)")

    # Test WITHOUT modulation (medication discontinued)
    print("\n    EVALUATION WITHOUT MODULATION (medication discontinued):")
    neuro_model.set_inhibition(1.0, False)
    neuro_model.set_gain(1.0)  # Reset gain to baseline
    off_med_combined = evaluate(neuro_model, test_loader, 1.0, 0.5)
    # Use the configured "extreme" level so this number stays comparable with
    # the on-medication 'stress_extreme' metric if the configuration changes.
    off_med_extreme = evaluate(neuro_model, test_loader, 0.0,
                               CONFIG['extended_stress_levels']['extreme'])
    off_med_biased = evaluate_biased_stress(
        neuro_model, test_loader,
        internal_sigma=CONFIG['mania_test_sigma'],
        bias=CONFIG['mania_test_bias']
    )
    off_med_activation = get_avg_activation_magnitude(neuro_model, test_loader)
    print(f"      Combined stress: {off_med_combined:.1f}%")
    print(f"      Extreme stress: {off_med_extreme:.1f}%")
    print(f"      Biased stress (mania test): {off_med_biased:.1f}%")
    print(f"      Activation magnitude: {off_med_activation['mean']:.3f}")
    neuro_results['off_medication_combined'] = off_med_combined
    neuro_results['off_medication_extreme'] = off_med_extreme
    neuro_results['off_medication_biased'] = off_med_biased
    neuro_results['off_medication_activation'] = off_med_activation

    # Restore modulation for relapse test
    neuro_model.set_inhibition(CONFIG['neurosteroid_inhibition_strength'],
                               CONFIG['neurosteroid_use_tanh'])
    neuro_model.set_gain(CONFIG['neurosteroid_gain'])

    # Relapse simulation
    print("\n    RELAPSE SIMULATION (with modulation active):")
    _simulate_relapse(neuro_model, neuro_mgr, neuro_results, evaluate_with_neurosteroid)

    results['neurosteroid'] = neuro_results

    # ========================================================================
    # COMPREHENSIVE COMPARISON
    # ========================================================================
    print("\n" + "="*80)
    print("  COMPREHENSIVE COMPARISON: ALL TREATMENTS")
    print("="*80)

    treatments = ['untreated', 'ketamine', 'ssri', 'neurosteroid']
    labels = {'untreated': 'Untreated (pruned)', 'ketamine': 'Ketamine-like',
              'ssri': 'SSRI-like', 'neurosteroid': 'Neurosteroid-like'}

    print(f"\n  {'Treatment':<22} {'Sparsity':>10} {'Clean':>8} {'Standard':>10} "
          f"{'Combined':>10} {'Extreme':>10} {'Relapse':>10}")
    print("  " + "-"*85)

    for t in treatments:
        r = results[t]
        # The untreated arm has no relapse simulation, hence the 'N/A' fallback.
        relapse = r.get('relapse_drop', 'N/A')
        relapse_str = f"{relapse:.1f}%" if isinstance(relapse, float) else relapse
        print(f"  {labels[t]:<22} {r['sparsity']:>9.1f}% {r['clean']:>7.1f}% "
              f"{r['standard']:>9.1f}% {r['combined']:>9.1f}% "
              f"{r['stress_extreme']:>9.1f}% {relapse_str:>10}")

    # ========================================================================
    # MANIC CONVERSION RISK COMPARISON
    # ========================================================================
    print("\n" + "="*80)
    print("  MANIC CONVERSION RISK COMPARISON")
    print("="*80)

    print(f"\n  {'Treatment':<22} {'Gain':>8} {'Biased Stress':>15} {'Act. Magnitude':>16} {'Risk Level':>12}")
    print("  " + "-"*75)

    for t in treatments:
        r = results[t]
        gain = r.get('gain', 1.0)
        biased = r.get('biased_stress', 0.0)
        act_mag = r.get('activation_magnitude', {}).get('mean', 0.0)

        # Risk level is derived from the gain multiplier alone.
        risk = _risk_level(gain)

        print(f"  {labels[t]:<22} {gain:>7.2f} {biased:>14.1f}% {act_mag:>15.3f} {risk:>12}")

    print("""
  INTERPRETATION:
  ─────────────────────────────────────────────────────────────────────────────
  • Biased Stress Accuracy: LOWER = HIGHER manic conversion risk
    (Network collapses under positively-biased perturbation due to hyperexcitability)

  • Activation Magnitude: HIGHER = HIGHER latent manic risk
    (Greater excitatory tone even without biased input)

  • Gain Multiplier: >1.0 = increased excitability, <1.0 = protective inhibition
    """)

    # ========================================================================
    # CLINICAL INTERPRETATION WITH MANIA RISK
    # ========================================================================
    print("\n" + "-"*80)
    print("  CLINICAL INTERPRETATION")
    print("-"*80)

    ket, ssri, neuro = results['ketamine'], results['ssri'], results['neurosteroid']
    untreated = results['untreated']

    print("\n  1. ANTIDEPRESSANT EFFICACY (Combined Stress):")
    for arm_label, arm in (("Ketamine:    ", ket), ("SSRI:        ", ssri), ("Neurosteroid:", neuro)):
        delta = arm['combined'] - untreated['combined']
        # '+' format flag prints the real sign; a hard-coded '+' prefix would
        # render a worsening as "+-x.x%".
        print(f"     {arm_label}{untreated['combined']:.1f}% → {arm['combined']:.1f}% ({delta:+.1f}%)")

    print("\n  2. MANIC CONVERSION RISK (Biased Stress - Lower = Higher Risk):")
    print(f"     Untreated:   {untreated['biased_stress']:.1f}% (baseline)")
    print(f"     Ketamine:    {ket['biased_stress']:.1f}% (moderate transient excitability)")
    print(f"     SSRI:        {ssri['biased_stress']:.1f}% (HIGH RISK - persistent gain overshoot)")
    print(f"     Neurosteroid:{neuro['biased_stress']:.1f}% (PROTECTIVE - tonic inhibition)")

    print("\n  3. EXCITABILITY PROFILE:")
    print(f"     Ketamine gain:     {ket['gain']:.2f} (moderate surge, good compensation)")
    print(f"     SSRI gain:         {ssri['gain']:.2f} (ELEVATED - poor compensation)")
    print(f"     Neurosteroid gain: {neuro['gain']:.2f} × {CONFIG['neurosteroid_inhibition_strength']:.2f} = {neuro['effective_scale']:.2f} (PROTECTIVE)")

    print("\n  4. CLINICAL ALIGNMENT:")
    print("""
     ┌─────────────────────┬───────────────────────────────────────────────────┐
     │ Clinical Evidence   │ Model Prediction                                  │
     ├─────────────────────┼───────────────────────────────────────────────────┤
     │ SSRIs: 20-40% switch│ HIGHEST gain (1.6), WORST biased-stress tolerance │
     │ rate in bipolar pts │ → Matches high manic conversion risk              │
     │                     │                                                   │
     │ Ketamine: Rare      │ MODERATE gain (1.25), effective consolidation     │
     │ switch reports      │ → Matches low/moderate risk in stabilized pts     │
     │                     │                                                   │
     │ Neurosteroids: Very │ LOWEST effective scale (0.6), BEST biased-stress  │
     │ low switch rates    │ → Matches protective GABAergic effect             │
     └─────────────────────┴───────────────────────────────────────────────────┘
    """)

    print("\n  5. TREATMENT SELECTION WITH BIPOLAR RISK:")
    print("""
     ┌──────────────────────────────────────┬──────────────────────────────────┐
     │ Clinical Scenario                    │ Suggested Mechanism              │
     ├──────────────────────────────────────┼──────────────────────────────────┤
     │ MDD with NO bipolar risk             │ Any (SSRIs acceptable first-line)│
     │ MDD with bipolar family history      │ Neurosteroid OR Ketamine + MS    │
     │ Bipolar depression (acute)           │ Neurosteroid (lowest switch risk)│
     │ Bipolar depression (maintenance)     │ Ketamine + mood stabilizer       │
     │ Treatment-resistant with mania hx    │ AVOID SSRI monotherapy           │
     └──────────────────────────────────────┴──────────────────────────────────┘

     MS = Mood Stabilizer (not modeled, but clinically essential)
    """)

    return results


# ============================================================================
# ENTRY POINT
# ============================================================================
if __name__ == "__main__":
    # Opening banner: an 80-column box of '#' with three centred title lines.
    _border = "#" * 80
    _spacer = "#" + " " * 78 + "#"
    print("\n" + _border)
    print(_spacer)
    for _title in (
        " MULTI-MECHANISM ANTIDEPRESSANT COMPARISON ",
        " WITH MANIC CONVERSION RISK MODELING ",
        " Ketamine vs SSRI vs Neurosteroid ",
    ):
        print("#" + _title.center(78) + "#")
    print(_spacer)
    print(_border)

    # Run the full comparison; the metrics stay available at module level.
    results = run_multi_mechanism_experiment()

    # Closing banner.
    _rule = "=" * 80
    print("\n" + _rule)
    print("  EXPERIMENT COMPLETE")
    print(_rule + "\n")


################################################################################
#                                                                              #
#                  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON                   #
#                     WITH MANIC CONVERSION RISK MODELING                      #
#                       Ketamine vs SSRI vs Neurosteroid                       #
#                                                                              #
################################################################################

  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT
  WITH MANIC CONVERSION RISK MODELING

  COMPARING THREE ANTIDEPRESSANT MECHANISMS:
  
  ┌─────────────────┬─────────────────────────────────────────────────────────┐
  │ Mechanism       │ Key Feature                          │ Mania Risk      │
  ├─────────────────┼──────────────────────────────────────┼─────────────────┤
  │ Ketamine        │ Synaptogenesis + moderate gain (1.25)│ LOW-MODERATE    │
  … (captured output truncated here)

## Repeated Across 10 Random Seeds

In [2]:
"""
================================================================================
MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT
WITH MANIC CONVERSION RISK MODELING
================================================================================

Compares three antidepressant mechanisms across 10 random seeds:
1. KETAMINE-LIKE: Gradient-guided synaptogenesis (moderate transient excitability)
2. SSRI-LIKE: Gradual stabilization (progressive excitability overshoot)
3. NEUROSTEROID-LIKE: Tonic inhibition enhancement (protective low excitability)
================================================================================
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from typing import Dict, Tuple, List
import warnings
from collections import defaultdict

# Silence UserWarning-category warnings for the whole run.
warnings.filterwarnings('ignore', category=UserWarning)

# Device that batches are moved to before each forward pass.
DEVICE = torch.device('cpu')

# ============================================================================
# CONFIGURATION
# ============================================================================
# Single source of truth for experiment settings; functions below read it
# directly rather than taking these values as parameters.
CONFIG = {
    # Repetition (module docstring: the comparison is run across 10 random seeds)
    'n_seeds': 10,

    # Data generation (read by create_data_loaders)
    'n_train': 12000,
    'n_test': 4000,
    'n_clean_test': 2000,
    'data_noise': 0.8,
    'batch_size': 128,

    # Network architecture (read by StressAwareNetwork.__init__)
    'hidden_dims': [512, 512, 256],
    'input_dim': 2,
    'output_dim': 4,

    # Training hyperparameters
    'baseline_epochs': 20,
    'baseline_lr': 0.001,
    'finetune_epochs': 15,
    'finetune_lr': 0.0005,

    # Pruning / regrowth parameters
    'prune_sparsity': 0.95,
    'regrow_fraction': 0.5,
    'regrow_init_scale': 0.03,           # default std of newly regrown weights
    'gradient_accumulation_batches': 30, # batches scored before regrowth
    # Named internal-stress levels, iterated as name -> level
    'extended_stress_levels': {
        'none': 0.0,
        'moderate': 0.5,
        'high': 1.0,
        'severe': 1.5,
        'extreme': 2.5
    },

    # SSRI-like ("monoaminergic") treatment
    'monoaminergic_epochs': 100,
    'monoaminergic_lr': 1e-5,
    'monoaminergic_initial_stress': 0.5,
    'ssri_max_gain': 1.6,

    # Ketamine-like treatment
    'ketamine_gain': 1.25,

    # Neurosteroid-like treatment
    'neurosteroid_inhibition_strength': 0.7,
    'neurosteroid_use_tanh': True,
    'neurosteroid_consolidation_epochs': 10,
    'neurosteroid_gain': 0.85,

    # Biased-noise ("mania") test: noise std and constant positive offset
    'mania_test_bias': 1.0,
    'mania_test_sigma': 1.0,

    # Per-treatment settings for the head-to-head comparison
    # NOTE(review): the two 'comparison_neurosteroid_*' keys duplicate the
    # 'neurosteroid_*' values above -- confirm which set the comparison reads.
    'comparison_ketamine_regrow': 0.5,
    'comparison_ketamine_epochs': 15,
    'comparison_ssri_epochs': 100,
    'comparison_neurosteroid_strength': 0.7,
    'comparison_neurosteroid_epochs': 10,
}


# ============================================================================
# DATA GENERATION
# ============================================================================
def generate_blobs(n_samples: int = 10000, noise: float = 0.8, seed: int = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample a 4-class 2-D Gaussian-blob classification dataset.

    Each sample is assigned a uniformly random class in ``{0, 1, 2, 3}`` and
    placed at that class's centre (one of the four corners ``(±3, ±3)``) plus
    isotropic Gaussian jitter scaled by ``noise``.

    Args:
        n_samples: Number of points to draw.
        noise: Standard deviation of the jitter around each centre.
        seed: Seed for a private ``RandomState``; ``None`` gives a
            non-reproducible draw. The global NumPy RNG is never touched.

    Returns:
        ``(points, class_ids)``: a float32 tensor of shape ``(n_samples, 2)``
        and an int64 tensor of shape ``(n_samples,)``.
    """
    # RandomState(None) is the same as RandomState(): fresh OS entropy.
    rng = np.random.RandomState(seed)
    blob_centers = np.array([[-3, -3], [3, 3], [-3, 3], [3, -3]])

    # Draw order (labels first, then jitter) fixes the sequence for a seed.
    class_ids = rng.randint(0, 4, n_samples)
    jitter = rng.randn(n_samples, 2) * noise
    points = blob_centers[class_ids] + jitter

    points_t = torch.tensor(points, dtype=torch.float32)
    class_ids_t = torch.tensor(class_ids, dtype=torch.long)
    return (points_t, class_ids_t)


def create_data_loaders(seed: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train / noisy-test / clean-test loaders for one experiment seed.

    The three splits use distinct data seeds derived from ``seed``
    (``seed*1000 + 100/200/300``). The clean test split is generated with
    zero noise.

    Args:
        seed: Experiment seed the per-split data seeds are derived from.

    Returns:
        ``(train_loader, test_loader, clean_test_loader)``. Only the training
        loader shuffles; both test loaders use batches of 1000.
    """
    seed_base = seed * 1000
    noise_std = CONFIG['data_noise']

    train_set = TensorDataset(*generate_blobs(CONFIG['n_train'], noise=noise_std, seed=seed_base + 100))
    test_set = TensorDataset(*generate_blobs(CONFIG['n_test'], noise=noise_std, seed=seed_base + 200))
    clean_set = TensorDataset(*generate_blobs(CONFIG['n_clean_test'], noise=0.0, seed=seed_base + 300))

    return (
        DataLoader(train_set, batch_size=CONFIG['batch_size'], shuffle=True),
        DataLoader(test_set, batch_size=1000),
        DataLoader(clean_set, batch_size=1000),
    )


# ============================================================================
# NETWORK ARCHITECTURE
# ============================================================================
class StressAwareNetwork(nn.Module):
    """Four-layer MLP whose hidden activity can be perturbed and rescaled.

    After each of the three hidden layers the activations are optionally
    perturbed with Gaussian noise ("stress") and then multiplied by
    ``inhibition_strength * gain_multiplier``. The hidden non-linearity is
    ReLU by default and tanh when ``use_tanh`` is set.
    """

    def __init__(self, hidden_dims: List[int] = None):
        """Create the layers; ``hidden_dims`` defaults to ``CONFIG['hidden_dims']``."""
        super().__init__()
        widths = CONFIG['hidden_dims'] if hidden_dims is None else hidden_dims
        first, second, third = widths[0], widths[1], widths[2]

        self.fc1 = nn.Linear(CONFIG['input_dim'], first)
        self.fc2 = nn.Linear(first, second)
        self.fc3 = nn.Linear(second, third)
        self.fc4 = nn.Linear(third, CONFIG['output_dim'])
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

        # Modulation state (plain attributes, not part of the state dict).
        self.stress_level = 0.0
        self.inhibition_strength = 1.0
        self.use_tanh = False
        self.gain_multiplier = 1.0
        self.weight_layers = ['fc1', 'fc2', 'fc3', 'fc4']

    # ------------------------------------------------------------------
    # Modulation setters
    # ------------------------------------------------------------------
    def set_stress(self, level: float):
        """Set the std of the zero-mean noise injected after each hidden layer."""
        self.stress_level = level

    def set_inhibition(self, strength: float, use_tanh: bool = False):
        """Set the inhibitory scale factor and choose tanh (True) or ReLU (False)."""
        self.inhibition_strength = strength
        self.use_tanh = use_tanh

    def set_gain(self, gain: float):
        """Set the excitability gain applied to every hidden layer's output."""
        self.gain_multiplier = gain

    def reduce_stress_gradually(self, epoch: int, total_epochs: int, initial_stress: float = 0.5, final_stress: float = 0.0):
        """Linearly move ``stress_level`` from ``initial_stress`` to ``final_stress``.

        ``epoch`` is zero-based; the final value is reached at
        ``epoch == total_epochs - 1``.
        """
        # max(..., 1) avoids a zero division when there is a single epoch.
        fraction_done = epoch / max(total_epochs - 1, 1)
        self.stress_level = initial_stress + fraction_done * (final_stress - initial_stress)

    def increase_gain_gradually(self, epoch: int, total_epochs: int, initial_gain: float = 1.0, max_gain: float = 1.6):
        """Linearly move ``gain_multiplier`` from ``initial_gain`` to ``max_gain``."""
        fraction_done = epoch / max(total_epochs - 1, 1)
        self.gain_multiplier = initial_gain + fraction_done * (max_gain - initial_gain)

    # ------------------------------------------------------------------
    # Forward passes
    # ------------------------------------------------------------------
    def _run_hidden_stack(self, x: torch.Tensor, perturb) -> torch.Tensor:
        """Apply the three hidden layers, then the linear read-out.

        ``perturb`` maps post-activation hidden values to their perturbed
        version; scaling by inhibition and gain happens after it.
        """
        nonlinearity = self.tanh if self.use_tanh else self.relu
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = perturb(nonlinearity(layer(hidden)))
            hidden = hidden * self.inhibition_strength * self.gain_multiplier
        return self.fc4(hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits; zero-mean noise is added only when ``stress_level > 0``."""
        sigma = self.stress_level

        def zero_mean_noise(hidden):
            if sigma > 0:
                return hidden + torch.randn_like(hidden) * sigma
            return hidden

        return self._run_hidden_stack(x, zero_mean_noise)

    def forward_with_biased_noise(self, x: torch.Tensor, internal_sigma: float = 1.0, bias: float = 1.0) -> torch.Tensor:
        """Return logits with noise of std ``internal_sigma`` and mean ``bias``.

        Unlike ``forward``, the perturbation is always applied and ignores
        ``stress_level``.
        """
        def shifted_noise(hidden):
            return hidden + (torch.randn_like(hidden) * internal_sigma + bias)

        return self._run_hidden_stack(x, shifted_noise)

    def count_parameters(self) -> Tuple[int, int]:
        """Return ``(total, nonzero)`` element counts over all parameters."""
        total = 0
        nonzero = 0
        for tensor in self.parameters():
            total += tensor.numel()
            nonzero += (tensor != 0).sum().item()
        return total, nonzero


# ============================================================================
# PRUNING MANAGER
# ============================================================================
class PruningManager:
    """Maintain binary weight masks for a model: prune, regrow, re-apply.

    One float mask (1 = keep, 0 = pruned) is kept per parameter whose name
    contains ``'weight'`` and that has at least two dimensions. Biases are
    never masked.
    """

    def __init__(self, model: "StressAwareNetwork", train_loader: DataLoader):
        """
        Args:
            model: Network whose weight matrices are managed.
            train_loader: Batches used to score pruned connections for
                regrowth; only needed by ``gradient_guided_regrow``.
        """
        self.model = model
        self.train_loader = train_loader
        self.masks = {}
        self.gradient_buffer = {}
        for name, param in model.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                self.masks[name] = torch.ones_like(param, dtype=torch.float32)
                self.gradient_buffer[name] = torch.zeros_like(param)

    def prune_by_magnitude(self, sparsity: float, per_layer: bool = True) -> Dict[str, Dict]:
        """Zero the smallest-magnitude weights of every masked layer.

        Args:
            sparsity: Quantile in ``[0, 1]`` of ``|w|`` used as the keep
                threshold within each layer.
            per_layer: Accepted for API compatibility but currently ignored;
                thresholds are always computed per layer.

        Returns:
            ``{name: {'kept', 'total', 'actual_sparsity'}}`` per masked layer.
        """
        stats = {}
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue
            weights = param.data.abs()
            threshold = torch.quantile(weights.flatten(), sparsity)
            # Intersect with the current mask: if the threshold lands on 0
            # (target sparsity at or below the fraction of already-zero
            # weights), ``weights >= threshold`` alone would mark every
            # position as kept and silently un-prune the layer.
            keep = (weights >= threshold) & (self.masks[name] != 0)
            self.masks[name] = keep.float()
            param.data *= self.masks[name]
            kept = self.masks[name].sum().item()
            total = self.masks[name].numel()
            stats[name] = {'kept': int(kept), 'total': total, 'actual_sparsity': 1 - kept/total}
        return stats

    def _accumulate_gradients(self, num_batches: int = 30):
        """Accumulate ``|grad|`` at pruned positions over up to ``num_batches`` batches.

        Scores are written to ``self.gradient_buffer``. Stress is set to 0 and
        gain to 1 while scoring, and both are restored afterwards.
        """
        model = self.model
        loss_fn = nn.CrossEntropyLoss()
        for name in self.gradient_buffer:
            self.gradient_buffer[name].zero_()
        model.train()
        original_stress = model.stress_level
        original_gain = model.gain_multiplier
        model.set_stress(0.0)
        model.set_gain(1.0)
        # Drop gradients left over from earlier training steps; otherwise the
        # first backward() would add to them and bias the first batch's scores.
        model.zero_grad()
        try:
            for batch_count, (x, y) in enumerate(self.train_loader):
                if batch_count >= num_batches:
                    break
                x, y = x.to(DEVICE), y.to(DEVICE)
                loss = loss_fn(model(x), y)
                loss.backward()
                with torch.no_grad():
                    for name, param in model.named_parameters():
                        if name in self.masks:
                            pruned_mask = (self.masks[name] == 0).float()
                            self.gradient_buffer[name] += param.grad.abs() * pruned_mask
                model.zero_grad()
        finally:
            # Restore modulation even if a batch raises.
            model.set_stress(original_stress)
            model.set_gain(original_gain)

    def gradient_guided_regrow(self, regrow_fraction: float, init_scale: float = None) -> Dict[str, Dict]:
        """Re-enable the pruned connections with the largest accumulated ``|grad|``.

        Args:
            regrow_fraction: Fraction of each layer's pruned weights to
                restore; at least one weight is regrown in any layer that has
                pruned weights.
            init_scale: Std of the Gaussian used to initialise regrown
                weights; defaults to ``CONFIG['regrow_init_scale']``.

        Returns:
            ``{name: {'regrown', 'still_pruned'}}`` per masked layer.
        """
        if init_scale is None:
            init_scale = CONFIG['regrow_init_scale']
        self._accumulate_gradients(num_batches=CONFIG['gradient_accumulation_batches'])
        stats = {}
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue
            mask = self.masks[name]
            pruned_positions = (mask == 0)
            num_pruned = pruned_positions.sum().item()
            if num_pruned == 0:
                stats[name] = {'regrown': 0, 'still_pruned': 0}
                continue
            gradient_scores = self.gradient_buffer[name][pruned_positions]
            num_regrow = max(1, int(regrow_fraction * num_pruned))
            num_regrow = min(num_regrow, gradient_scores.numel())
            _, top_indices = torch.topk(gradient_scores.flatten(), num_regrow)
            # Translate indices within the pruned subset to flat layer indices.
            flat_pruned_indices = torch.where(pruned_positions.flatten())[0]
            regrow_flat_indices = flat_pruned_indices[top_indices]
            flat_mask = mask.flatten()
            flat_param = param.data.flatten()
            flat_mask[regrow_flat_indices] = 1.0
            # Sample on the parameter's own device/dtype so the assignment
            # also works when the model is not on the default device.
            flat_param[regrow_flat_indices] = torch.randn(
                num_regrow, device=param.device, dtype=param.dtype) * init_scale
            self.masks[name] = flat_mask.view_as(mask)
            param.data = flat_param.view_as(param)
            stats[name] = {'regrown': num_regrow, 'still_pruned': int(num_pruned - num_regrow)}
        return stats

    def apply_masks(self):
        """Multiply every masked weight tensor by its mask, in place."""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in self.masks:
                    param.data *= self.masks[name]

    def get_sparsity(self) -> float:
        """Return the fraction of masked-out entries over all masks (0.0 if none)."""
        total = sum(m.numel() for m in self.masks.values())
        zeros = sum((m == 0).sum().item() for m in self.masks.values())
        return zeros / total if total > 0 else 0.0


# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================
def train(model: StressAwareNetwork, train_loader: DataLoader, epochs: int = 15, lr: float = 0.001,
          pruning_manager: PruningManager = None) -> List[float]:
    """Train ``model`` with Adam and cross-entropy loss while its stress is set to 0.0.

    Args:
        model: Network to optimise; its ``stress_level`` is read up front and
            put back before returning.
        train_loader: Yields ``(inputs, targets)`` batches.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        pruning_manager: Optional manager whose ``apply_masks()`` is called
            after every optimizer step, so pruned weights are re-masked after
            each update.

    Returns:
        One entry per epoch: the summed batch loss divided by
        ``len(train_loader)``.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    losses = []
    original_stress = model.stress_level
    model.set_stress(0.0)
    try:
        for _ in range(epochs):
            model.train()
            epoch_loss = 0.0
            for x, y in train_loader:
                x, y = x.to(DEVICE), y.to(DEVICE)
                optimizer.zero_grad()
                loss = loss_fn(model(x), y)
                loss.backward()
                optimizer.step()
                if pruning_manager is not None:
                    pruning_manager.apply_masks()
                epoch_loss += loss.item()
            losses.append(epoch_loss / len(train_loader))
    finally:
        # Restore the caller's stress level even if training is interrupted,
        # so a failed run does not leave the model silently modified.
        model.set_stress(original_stress)
    return losses


def train_with_stress_and_gain_schedule(model: StressAwareNetwork, train_loader: DataLoader, epochs: int, lr: float,
                                         initial_stress: float, final_stress: float = 0.0,
                                         initial_gain: float = 1.0, max_gain: float = 1.6,
                                         pruning_manager: PruningManager = None) -> List[float]:
    """Train with Adam/cross-entropy while the model's stress and gain follow a per-epoch schedule.

    At the start of every epoch the model is asked to update itself through
    ``reduce_stress_gradually`` and ``increase_gain_gradually``, which receive
    the epoch index, the epoch count and the respective start/end values. If a
    pruning manager is supplied, its masks are re-applied after each optimizer
    step. When training finishes the stress is set to 0.0; the gain is left at
    whatever the schedule last set.

    Returns:
        Per-epoch loss: summed batch loss divided by ``len(train_loader)``.
    """
    adam = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    epoch_means = []

    for epoch_idx in range(epochs):
        # Advance both schedules before any batch of this epoch is seen.
        model.reduce_stress_gradually(epoch_idx, epochs, initial_stress, final_stress)
        model.increase_gain_gradually(epoch_idx, epochs, initial_gain, max_gain)
        model.train()

        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            adam.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            adam.step()
            if pruning_manager:
                pruning_manager.apply_masks()
            running_loss += batch_loss.item()

        epoch_means.append(running_loss / len(train_loader))

    model.set_stress(0.0)
    return epoch_means


# ============================================================================
# EVALUATION FUNCTIONS
# ============================================================================
def evaluate(model: StressAwareNetwork, loader: DataLoader, input_noise: float = 0.0, internal_stress: float = 0.0) -> float:
    """Return classification accuracy (in percent) of ``model`` over ``loader``.

    The model is put in eval mode and its stress is set to ``internal_stress``
    for the duration of the pass. When ``input_noise`` is positive, Gaussian
    noise scaled by it is added to every input batch. The model's stress is
    set to 0.0 before returning, whatever it was on entry.
    """
    model.eval()
    model.set_stress(internal_stress)

    perturb_inputs = input_noise > 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            if perturb_inputs:
                inputs = inputs + torch.randn_like(inputs) * input_noise
            predicted = model(inputs).argmax(dim=1)
            n_correct += (predicted == labels).sum().item()
            n_seen += labels.size(0)

    model.set_stress(0.0)
    return 100.0 * n_correct / n_seen


def evaluate_biased_stress(model: StressAwareNetwork, loader: DataLoader, input_noise: float = 0.0,
                           internal_sigma: float = 1.0, bias: float = 1.0) -> float:
    """Return accuracy (in percent) when predictions come from ``forward_with_biased_noise``.

    Each batch is passed to ``model.forward_with_biased_noise(x, internal_sigma, bias)``
    instead of the regular forward pass. When ``input_noise`` is positive,
    Gaussian noise scaled by it is added to the inputs first.
    """
    model.eval()

    perturb_inputs = input_noise > 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            if perturb_inputs:
                inputs = inputs + torch.randn_like(inputs) * input_noise
            logits = model.forward_with_biased_noise(inputs, internal_sigma, bias)
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)

    return 100.0 * n_correct / n_seen


def get_avg_activation_magnitude(model: StressAwareNetwork, loader: DataLoader) -> float:
    """Return the mean absolute hidden activation over ``fc1``-``fc3`` and all batches.

    The hidden layers are re-computed here rather than through the model's
    forward pass: each layer output goes through tanh or ReLU (chosen by
    ``model.use_tanh``) and is scaled by ``inhibition_strength`` and
    ``gain_multiplier``. One mean-|activation| value is recorded per layer per
    batch, and the unweighted mean of those values is returned.
    """
    model.eval()
    if model.use_tanh:
        nonlinearity = model.tanh
    else:
        nonlinearity = model.relu

    magnitudes = []
    with torch.no_grad():
        for inputs, _ in loader:
            hidden = inputs.to(DEVICE)
            for layer in (model.fc1, model.fc2, model.fc3):
                hidden = nonlinearity(layer(hidden)) * model.inhibition_strength * model.gain_multiplier
                magnitudes.append(hidden.abs().mean().item())
    return np.mean(magnitudes)


# ============================================================================
# TREATMENT PROTOCOLS
# ============================================================================
def ketamine_treatment(model: StressAwareNetwork, pruning_mgr: PruningManager, train_loader: DataLoader) -> Dict:
    """Apply the ketamine-like protocol: set gain, regrow pruned weights, then fine-tune.

    Returns a dict with the manager's sparsity and the model's gain multiplier
    as they stand after fine-tuning.
    """
    model.set_gain(CONFIG['ketamine_gain'])
    pruning_mgr.gradient_guided_regrow(regrow_fraction=CONFIG['comparison_ketamine_regrow'])
    train(
        model,
        train_loader,
        epochs=CONFIG['comparison_ketamine_epochs'],
        lr=CONFIG['finetune_lr'],
        pruning_manager=pruning_mgr,
    )

    summary = {}
    summary['final_sparsity'] = pruning_mgr.get_sparsity()
    summary['final_gain'] = model.gain_multiplier
    return summary


def ssri_treatment(model: StressAwareNetwork, pruning_mgr: PruningManager, train_loader: DataLoader) -> Dict:
    """Apply the SSRI-like protocol: scheduled training with stress and gain ramps.

    Training starts at ``CONFIG['monoaminergic_initial_stress']`` with a final
    stress of 0.0, and at gain 1.0 with ``CONFIG['ssri_max_gain']`` as the
    maximum gain. Returns the resulting sparsity and gain multiplier.
    """
    schedule = dict(
        epochs=CONFIG['comparison_ssri_epochs'],
        lr=CONFIG['monoaminergic_lr'],
        initial_stress=CONFIG['monoaminergic_initial_stress'],
        final_stress=0.0,
        initial_gain=1.0,
        max_gain=CONFIG['ssri_max_gain'],
        pruning_manager=pruning_mgr,
    )
    train_with_stress_and_gain_schedule(model, train_loader, **schedule)

    summary = {}
    summary['final_sparsity'] = pruning_mgr.get_sparsity()
    summary['final_gain'] = model.gain_multiplier
    return summary


def neurosteroid_treatment(model: StressAwareNetwork, pruning_mgr: PruningManager, train_loader: DataLoader) -> Dict:
    """Apply the neurosteroid-like protocol: set inhibition and gain, then consolidate by training.

    Returns the resulting sparsity and gain multiplier, plus ``effective_scale``,
    the product of the model's inhibition strength and gain multiplier.
    """
    model.set_inhibition(CONFIG['neurosteroid_inhibition_strength'], CONFIG['neurosteroid_use_tanh'])
    model.set_gain(CONFIG['neurosteroid_gain'])
    train(
        model,
        train_loader,
        epochs=CONFIG['neurosteroid_consolidation_epochs'],
        lr=CONFIG['finetune_lr'],
        pruning_manager=pruning_mgr,
    )

    summary = {}
    summary['final_sparsity'] = pruning_mgr.get_sparsity()
    summary['final_gain'] = model.gain_multiplier
    summary['effective_scale'] = model.inhibition_strength * model.gain_multiplier
    return summary


# ============================================================================
# SINGLE SEED EXPERIMENT
# ============================================================================
def _clone_pruned_baseline(base_state_dict: Dict, base_masks: Dict,
                           train_loader: DataLoader) -> Tuple["StressAwareNetwork", "PruningManager"]:
    """Build a fresh model/manager pair that starts from the shared pruned baseline."""
    model = StressAwareNetwork().to(DEVICE)
    model.load_state_dict(base_state_dict)
    manager = PruningManager(model, train_loader)
    # Clone so that later pruning/regrowth on this copy cannot leak into other treatments.
    manager.masks = {k: v.clone() for k, v in base_masks.items()}
    manager.apply_masks()
    return model, manager


def _stress_test_battery(model: "StressAwareNetwork", test_loader: DataLoader,
                         clean_test_loader: DataLoader) -> Dict[str, float]:
    """Run the evaluation battery shared by every condition and return it as a dict.

    The evaluations run in a fixed order (clean, standard, combined, biased
    stress, activation magnitude, then each configured stress level) so that
    random-noise draws happen in the same sequence for every condition.
    """
    battery = {
        'clean': evaluate(model, clean_test_loader),
        'standard': evaluate(model, test_loader),
        'combined': evaluate(model, test_loader, 1.0, 0.5),
        'biased_stress': evaluate_biased_stress(model, test_loader,
                                                internal_sigma=CONFIG['mania_test_sigma'],
                                                bias=CONFIG['mania_test_bias']),
        'activation_magnitude': get_avg_activation_magnitude(model, test_loader),
    }
    for stress_name, stress_level in CONFIG['extended_stress_levels'].items():
        battery[f'stress_{stress_name}'] = evaluate(model, test_loader, 0.0, stress_level)
    return battery


def _relapse_drop(model: "StressAwareNetwork", pruning_mgr: "PruningManager", test_loader: DataLoader,
                  pre_relapse: float, extra_prune_fraction: float = 0.40) -> float:
    """Prune a further fraction of the surviving weights and return the combined-stress accuracy lost.

    The target sparsity removes ``extra_prune_fraction`` of the currently
    unpruned share and is capped at 0.99. The model and manager are modified
    in place.
    """
    current_sparsity = pruning_mgr.get_sparsity()
    target_sparsity = min(current_sparsity + (1 - current_sparsity) * extra_prune_fraction, 0.99)
    pruning_mgr.prune_by_magnitude(sparsity=target_sparsity)
    pruning_mgr.apply_masks()
    return pre_relapse - evaluate(model, test_loader, 1.0, 0.5)


def run_single_seed(seed: int) -> Dict[str, Dict]:
    """Run the whole experiment for one seed and return per-condition metrics.

    A baseline network is trained and magnitude-pruned; its weights and masks
    are snapshotted, and each treatment (ketamine-, SSRI-, neurosteroid-like)
    is applied to its own copy of that snapshot. Every condition is scored
    with the same evaluation battery. Each treated model is then pruned
    further to measure a relapse drop, and the neurosteroid model is
    additionally evaluated with inhibition and gain set to 1.0.

    Args:
        seed: Seeds both the torch and NumPy generators and is passed on to
            ``create_data_loaders``.

    Returns:
        Dict keyed by 'untreated', 'ketamine', 'ssri' and 'neurosteroid', each
        mapping metric names to values.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    train_loader, test_loader, clean_test_loader = create_data_loaders(seed)

    # Train and prune base model
    base_model = StressAwareNetwork().to(DEVICE)
    train(base_model, train_loader, epochs=CONFIG['baseline_epochs'], lr=CONFIG['baseline_lr'])
    base_pruning_mgr = PruningManager(base_model, train_loader)
    base_pruning_mgr.prune_by_magnitude(sparsity=CONFIG['prune_sparsity'])
    # Every other prune in this function is followed by apply_masks(); do the same
    # here so the untreated model is guaranteed to be evaluated with its pruned
    # weights zeroed. This is a no-op if prune_by_magnitude already zeroes them.
    base_pruning_mgr.apply_masks()

    base_state_dict = {k: v.clone() for k, v in base_model.state_dict().items()}
    base_masks = {k: v.clone() for k, v in base_pruning_mgr.masks.items()}

    results = {}

    # Untreated
    results['untreated'] = {
        'sparsity': base_pruning_mgr.get_sparsity() * 100,
        'gain': 1.0,
        **_stress_test_battery(base_model, test_loader, clean_test_loader),
    }

    # Ketamine
    ketamine_model, ketamine_mgr = _clone_pruned_baseline(base_state_dict, base_masks, train_loader)
    ketamine_treatment(ketamine_model, ketamine_mgr, train_loader)
    ketamine_results = {
        'sparsity': ketamine_mgr.get_sparsity() * 100,
        'gain': ketamine_model.gain_multiplier,
        **_stress_test_battery(ketamine_model, test_loader, clean_test_loader),
    }
    ketamine_results['relapse_drop'] = _relapse_drop(
        ketamine_model, ketamine_mgr, test_loader, ketamine_results['combined'])
    results['ketamine'] = ketamine_results

    # SSRI
    ssri_model, ssri_mgr = _clone_pruned_baseline(base_state_dict, base_masks, train_loader)
    ssri_treatment(ssri_model, ssri_mgr, train_loader)
    ssri_results = {
        'sparsity': ssri_mgr.get_sparsity() * 100,
        'gain': ssri_model.gain_multiplier,
        **_stress_test_battery(ssri_model, test_loader, clean_test_loader),
    }
    ssri_results['relapse_drop'] = _relapse_drop(
        ssri_model, ssri_mgr, test_loader, ssri_results['combined'])
    results['ssri'] = ssri_results

    # Neurosteroid
    neuro_model, neuro_mgr = _clone_pruned_baseline(base_state_dict, base_masks, train_loader)
    neurosteroid_treatment(neuro_model, neuro_mgr, train_loader)
    neuro_results = {
        'sparsity': neuro_mgr.get_sparsity() * 100,
        'gain': neuro_model.gain_multiplier,
        'effective_scale': neuro_model.inhibition_strength * neuro_model.gain_multiplier,
        **_stress_test_battery(neuro_model, test_loader, clean_test_loader),
    }

    # Off medication: re-evaluate the trained weights with inhibition and gain
    # both set to 1.0 (ReLU rather than tanh).
    neuro_model.set_inhibition(1.0, False)
    neuro_model.set_gain(1.0)
    neuro_results['off_medication_combined'] = evaluate(neuro_model, test_loader, 1.0, 0.5)
    neuro_results['off_medication_extreme'] = evaluate(neuro_model, test_loader, 0.0, 2.5)
    neuro_results['off_medication_biased'] = evaluate_biased_stress(
        neuro_model, test_loader,
        internal_sigma=CONFIG['mania_test_sigma'], bias=CONFIG['mania_test_bias'])
    neuro_results['off_medication_activation'] = get_avg_activation_magnitude(neuro_model, test_loader)

    # Relapse with modulation: switch the treatment settings back on before the extra pruning.
    neuro_model.set_inhibition(CONFIG['neurosteroid_inhibition_strength'], CONFIG['neurosteroid_use_tanh'])
    neuro_model.set_gain(CONFIG['neurosteroid_gain'])
    neuro_results['relapse_drop'] = _relapse_drop(
        neuro_model, neuro_mgr, test_loader, neuro_results['combined'])
    results['neurosteroid'] = neuro_results

    return results


# ============================================================================
# AGGREGATE RESULTS
# ============================================================================
def aggregate_results(all_results: List[Dict[str, Dict]]) -> Dict[str, Dict[str, Dict[str, float]]]:
    """Summarise per-seed metrics into mean/std/min/max for each treatment.

    Args:
        all_results: One dict per seed, mapping a treatment name to that
            seed's ``{metric_name: value}`` dict. Seeds that lack a treatment
            are skipped for that treatment.

    Returns:
        ``{treatment: {metric: {'mean', 'std', 'min', 'max'}}}`` for the four
        fixed treatments 'untreated', 'ketamine', 'ssri' and 'neurosteroid'.
        Only numeric values (Python or NumPy scalars) are aggregated; anything
        else is ignored. ``std`` is the population standard deviation
        (NumPy's default, ``ddof=0``). A treatment with no numeric metrics
        maps to an empty dict.
    """
    # Imported here so the function does not rely on a file-level import.
    from collections import defaultdict

    treatments = ['untreated', 'ketamine', 'ssri', 'neurosteroid']
    # NumPy scalars such as np.float32 do not subclass float; accept them
    # explicitly so a metric is not silently dropped because of its dtype.
    numeric_types = (int, float, np.integer, np.floating)
    aggregated = {}

    for treatment in treatments:
        metrics = defaultdict(list)
        for seed_result in all_results:
            for key, value in seed_result.get(treatment, {}).items():
                if isinstance(value, numeric_types):
                    metrics[key].append(value)

        # Cast to plain float so the result matches the declared return type.
        aggregated[treatment] = {
            key: {
                'mean': float(np.mean(values)),
                'std': float(np.std(values)),
                'min': float(np.min(values)),
                'max': float(np.max(values)),
            }
            for key, values in metrics.items()
        }

    return aggregated


# ============================================================================
# MAIN EXPERIMENT
# ============================================================================
def run_multi_seed_experiment() -> Dict:
    """Run the experiment for every seed, aggregate the metrics and print the full report.

    Seeds ``0 .. CONFIG['n_seeds'] - 1`` are each passed to ``run_single_seed``;
    the per-seed results are combined with ``aggregate_results`` and then
    printed as eight tables, a boxed summary and a final comparison matrix.

    Returns:
        Dict with 'all_results' (the list of per-seed result dicts) and
        'aggregated' (the output of ``aggregate_results``).
    """
    print("\n" + "="*80)
    print("  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT")
    print("  WITH MANIC CONVERSION RISK MODELING")
    print("  Running across {} random seeds".format(CONFIG['n_seeds']))
    print("="*80)

    all_results = []

    # The loop index itself is used as the seed; the progress line shows it 1-based.
    for seed in range(CONFIG['n_seeds']):
        print(f"\n  Seed {seed+1}/{CONFIG['n_seeds']}...", end=" ", flush=True)
        seed_results = run_single_seed(seed)
        all_results.append(seed_results)
        print("done")

    aggregated = aggregate_results(all_results)

    # ========================================================================
    # PRINT COMPREHENSIVE RESULTS
    # ========================================================================

    print("\n" + "="*80)
    print("  AGGREGATED RESULTS ACROSS {} SEEDS".format(CONFIG['n_seeds']))
    print("="*80)

    treatments = ['untreated', 'ketamine', 'ssri', 'neurosteroid']
    # Display names used in every table below.
    labels = {
        'untreated': 'Untreated (pruned)',
        'ketamine': 'Ketamine-like',
        'ssri': 'SSRI-like',
        'neurosteroid': 'Neurosteroid-like'
    }

    # ========================================================================
    # TABLE 1: ANTIDEPRESSANT EFFICACY
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 1: ANTIDEPRESSANT EFFICACY (Mean ± Std)")
    print("-"*80)
    print(f"\n  {'Treatment':<22} {'Sparsity':>12} {'Clean':>14} {'Standard':>14} {'Combined':>14}")
    print("  " + "-"*78)

    for t in treatments:
        r = aggregated[t]
        print(f"  {labels[t]:<22} "
              f"{r['sparsity']['mean']:>5.1f}±{r['sparsity']['std']:>4.1f}% "
              f"{r['clean']['mean']:>6.1f}±{r['clean']['std']:>4.1f}% "
              f"{r['standard']['mean']:>6.1f}±{r['standard']['std']:>4.1f}% "
              f"{r['combined']['mean']:>6.1f}±{r['combined']['std']:>4.1f}%")

    # ========================================================================
    # TABLE 2: STRESS RESILIENCE PROFILE
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 2: STRESS RESILIENCE PROFILE (Mean ± Std)")
    print("-"*80)
    print(f"\n  {'Treatment':<22} {'None':>14} {'Moderate':>14} {'High':>14} {'Severe':>14} {'Extreme':>14}")
    print("  " + "-"*92)

    # NOTE(review): the stress_* keys read here assume CONFIG['extended_stress_levels']
    # has entries named none/moderate/high/severe/extreme; that config is not
    # visible in this section, so confirm.
    for t in treatments:
        r = aggregated[t]
        print(f"  {labels[t]:<22} "
              f"{r['stress_none']['mean']:>6.1f}±{r['stress_none']['std']:>4.1f}% "
              f"{r['stress_moderate']['mean']:>6.1f}±{r['stress_moderate']['std']:>4.1f}% "
              f"{r['stress_high']['mean']:>6.1f}±{r['stress_high']['std']:>4.1f}% "
              f"{r['stress_severe']['mean']:>6.1f}±{r['stress_severe']['std']:>4.1f}% "
              f"{r['stress_extreme']['mean']:>6.1f}±{r['stress_extreme']['std']:>4.1f}%")

    # ========================================================================
    # TABLE 3: MANIC CONVERSION RISK METRICS
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 3: MANIC CONVERSION RISK METRICS (Mean ± Std)")
    print("-"*80)
    print(f"\n  {'Treatment':<22} {'Gain':>10} {'Biased Stress':>18} {'Act. Magnitude':>18}")
    print("  " + "-"*70)

    for t in treatments:
        r = aggregated[t]
        # Each cell is formatted first, then right-aligned as a whole string.
        gain_str = f"{r['gain']['mean']:>5.2f}±{r['gain']['std']:>4.2f}"
        biased_str = f"{r['biased_stress']['mean']:>6.1f}±{r['biased_stress']['std']:>4.1f}%"
        act_str = f"{r['activation_magnitude']['mean']:>6.3f}±{r['activation_magnitude']['std']:>5.3f}"
        print(f"  {labels[t]:<22} {gain_str:>10} {biased_str:>18} {act_str:>18}")

    print("\n  Note: Lower Biased Stress = Higher manic conversion risk")
    print("        Higher Activation Magnitude = Higher latent hyperexcitability")

    # ========================================================================
    # TABLE 4: RELAPSE VULNERABILITY
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 4: RELAPSE VULNERABILITY (Mean ± Std)")
    print("-"*80)
    print(f"\n  {'Treatment':<22} {'Relapse Drop':>18}")
    print("  " + "-"*42)

    # 'untreated' is left out: run_single_seed records no relapse_drop for it.
    for t in ['ketamine', 'ssri', 'neurosteroid']:
        r = aggregated[t]
        print(f"  {labels[t]:<22} {r['relapse_drop']['mean']:>6.1f}±{r['relapse_drop']['std']:>4.1f}%")

    # ========================================================================
    # TABLE 5: NEUROSTEROID MEDICATION DEPENDENCE
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 5: NEUROSTEROID MEDICATION DEPENDENCE (Mean ± Std)")
    print("-"*80)

    neuro = aggregated['neurosteroid']
    print(f"\n  {'Condition':<30} {'Combined':>18} {'Extreme':>18} {'Biased':>18}")
    print("  " + "-"*86)
    print(f"  {'On medication':<30} "
          f"{neuro['combined']['mean']:>6.1f}±{neuro['combined']['std']:>4.1f}% "
          f"{neuro['stress_extreme']['mean']:>6.1f}±{neuro['stress_extreme']['std']:>4.1f}% "
          f"{neuro['biased_stress']['mean']:>6.1f}±{neuro['biased_stress']['std']:>4.1f}%")
    print(f"  {'Off medication':<30} "
          f"{neuro['off_medication_combined']['mean']:>6.1f}±{neuro['off_medication_combined']['std']:>4.1f}% "
          f"{neuro['off_medication_extreme']['mean']:>6.1f}±{neuro['off_medication_extreme']['std']:>4.1f}% "
          f"{neuro['off_medication_biased']['mean']:>6.1f}±{neuro['off_medication_biased']['std']:>4.1f}%")

    # ========================================================================
    # TABLE 6: IMPROVEMENT FROM UNTREATED STATE
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 6: IMPROVEMENT FROM UNTREATED STATE")
    print("-"*80)

    untreated_combined = aggregated['untreated']['combined']['mean']
    print(f"\n  {'Treatment':<22} {'Untreated':>12} {'Treated':>12} {'Improvement':>14}")
    print("  " + "-"*62)

    # Improvement is the difference of across-seed means (treated minus untreated).
    for t in ['ketamine', 'ssri', 'neurosteroid']:
        treated = aggregated[t]['combined']['mean']
        improvement = treated - untreated_combined
        print(f"  {labels[t]:<22} {untreated_combined:>11.1f}% {treated:>11.1f}% {improvement:>+13.1f}%")

    # ========================================================================
    # TABLE 7: MANIC CONVERSION RISK RANKING
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 7: MANIC CONVERSION RISK RANKING")
    print("-"*80)

    risk_data = []
    for t in treatments:
        r = aggregated[t]
        risk_data.append({
            'treatment': labels[t],
            'gain': r['gain']['mean'],
            'biased_stress': r['biased_stress']['mean'],
            'activation': r['activation_magnitude']['mean']
        })

    # Rank by descending mean gain; ties are broken by ascending biased-stress accuracy.
    risk_data.sort(key=lambda x: (-x['gain'], x['biased_stress']))

    print(f"\n  {'Rank':<6} {'Treatment':<22} {'Gain':>8} {'Biased Stress':>15} {'Risk Assessment':<20}")
    print("  " + "-"*75)

    # The risk label depends only on the mean gain; biased stress and activation
    # magnitude do not influence it. Gains in [1.0, 1.2) fall through to BASELINE.
    for i, rd in enumerate(risk_data):
        if rd['gain'] >= 1.5:
            risk = "HIGH"
        elif rd['gain'] >= 1.2:
            risk = "MODERATE"
        elif rd['gain'] < 1.0:
            risk = "LOW (protective)"
        else:
            risk = "BASELINE"
        print(f"  {i+1:<6} {rd['treatment']:<22} {rd['gain']:>7.2f} {rd['biased_stress']:>14.1f}% {risk:<20}")

    # ========================================================================
    # TABLE 8: STATISTICAL SUMMARY
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 8: STATISTICAL SUMMARY (Range across seeds)")
    print("-"*80)
    print(f"\n  {'Treatment':<22} {'Combined (min-max)':>22} {'Biased Stress (min-max)':>26}")
    print("  " + "-"*72)

    for t in treatments:
        r = aggregated[t]
        combined_range = f"{r['combined']['min']:.1f} - {r['combined']['max']:.1f}%"
        biased_range = f"{r['biased_stress']['min']:.1f} - {r['biased_stress']['max']:.1f}%"
        print(f"  {labels[t]:<22} {combined_range:>22} {biased_range:>26}")

    # ========================================================================
    # COMPREHENSIVE ANALYSIS SUMMARY
    # ========================================================================
    print("\n" + "="*80)
    print("  COMPREHENSIVE ANALYSIS SUMMARY")
    print("="*80)

    print("""
  ┌────────────────────────────────────────────────────────────────────────────┐
  │ ANTIDEPRESSANT EFFICACY                                                    │
  ├────────────────────────────────────────────────────────────────────────────┤""")

    ket = aggregated['ketamine']
    ssri = aggregated['ssri']
    neuro = aggregated['neurosteroid']
    untreated = aggregated['untreated']

    # Best treatment by mean combined accuracy; on an exact tie Ketamine wins
    # over SSRI, and SSRI over Neurosteroid, because of the order of the checks.
    best_combined = max(ket['combined']['mean'], ssri['combined']['mean'], neuro['combined']['mean'])
    best_treatment = 'Ketamine' if ket['combined']['mean'] == best_combined else ('SSRI' if ssri['combined']['mean'] == best_combined else 'Neurosteroid')

    print(f"  │ Best combined stress performance: {best_treatment:<20} ({best_combined:.1f}%)         │")
    print(f"  │ Ketamine improvement from untreated: {ket['combined']['mean'] - untreated['combined']['mean']:>+5.1f}%                            │")
    print(f"  │ SSRI improvement from untreated:     {ssri['combined']['mean'] - untreated['combined']['mean']:>+5.1f}%                            │")
    print(f"  │ Neurosteroid improvement:            {neuro['combined']['mean'] - untreated['combined']['mean']:>+5.1f}%                            │")

    print("""  ├────────────────────────────────────────────────────────────────────────────┤
  │ MANIC CONVERSION RISK                                                      │
  ├────────────────────────────────────────────────────────────────────────────┤""")

    # NOTE(review): the treatment named on each risk row is fixed text; only the
    # number is computed. If the measured gains ever rank differently, these
    # lines will still say SSRI is highest and Neurosteroid lowest.
    print(f"  │ Highest risk (gain):      SSRI        (gain = {ssri['gain']['mean']:.2f})                       │")
    print(f"  │ Moderate risk (gain):     Ketamine    (gain = {ket['gain']['mean']:.2f})                       │")
    print(f"  │ Lowest risk (protective): Neurosteroid (effective scale = {neuro['effective_scale']['mean']:.2f})          │")
    print(f"  │                                                                          │")
    print(f"  │ Biased stress tolerance (higher = more resistant to manic switch):       │")
    print(f"  │   Neurosteroid: {neuro['biased_stress']['mean']:>5.1f}% ± {neuro['biased_stress']['std']:.1f}%                                          │")
    print(f"  │   Ketamine:     {ket['biased_stress']['mean']:>5.1f}% ± {ket['biased_stress']['std']:.1f}%                                          │")
    print(f"  │   SSRI:         {ssri['biased_stress']['mean']:>5.1f}% ± {ssri['biased_stress']['std']:.1f}%                                          │")

    print("""  ├────────────────────────────────────────────────────────────────────────────┤
  │ STRUCTURAL VS FUNCTIONAL CHANGES                                           │
  ├────────────────────────────────────────────────────────────────────────────┤""")

    # NOTE(review): "95%" and the REDUCED/UNCHANGED remarks are literal text, not
    # derived from CONFIG['prune_sparsity'] or from the measured sparsities; confirm
    # they still match the configuration.
    print(f"  │ Ketamine sparsity:     {ket['sparsity']['mean']:>5.1f}% (REDUCED from 95% - adds synapses)       │")
    print(f"  │ SSRI sparsity:         {ssri['sparsity']['mean']:>5.1f}% (UNCHANGED - no structural change)       │")
    print(f"  │ Neurosteroid sparsity: {neuro['sparsity']['mean']:>5.1f}% (UNCHANGED - functional modulation)     │")

    print("""  ├────────────────────────────────────────────────────────────────────────────┤
  │ RELAPSE VULNERABILITY                                                      │
  ├────────────────────────────────────────────────────────────────────────────┤""")

    print(f"  │ Ketamine relapse drop:     {ket['relapse_drop']['mean']:>5.1f}% ± {ket['relapse_drop']['std']:.1f}%                                │")
    print(f"  │ SSRI relapse drop:         {ssri['relapse_drop']['mean']:>5.1f}% ± {ssri['relapse_drop']['std']:.1f}%                                │")
    print(f"  │ Neurosteroid relapse drop: {neuro['relapse_drop']['mean']:>5.1f}% ± {neuro['relapse_drop']['std']:.1f}%                                │")

    print("""  ├────────────────────────────────────────────────────────────────────────────┤
  │ MEDICATION DEPENDENCE (Neurosteroid only)                                  │
  ├────────────────────────────────────────────────────────────────────────────┤""")

    # On-medication minus off-medication mean accuracy; positive means the
    # model does worse once the modulation is removed.
    on_off_diff = neuro['combined']['mean'] - neuro['off_medication_combined']['mean']
    print(f"  │ On medication combined:  {neuro['combined']['mean']:>5.1f}%                                        │")
    print(f"  │ Off medication combined: {neuro['off_medication_combined']['mean']:>5.1f}%                                        │")
    print(f"  │ Performance drop when discontinued: {on_off_diff:>+5.1f}%                            │")

    print("  └────────────────────────────────────────────────────────────────────────────┘")

    # ========================================================================
    # FINAL SUMMARY TABLE
    # ========================================================================
    print("\n" + "="*80)
    print("  FINAL COMPARISON MATRIX")
    print("="*80)

    print("""
  ┌─────────────────────┬────────────┬────────────┬──────────────┬─────────────┐
  │ Metric              │ Ketamine   │ SSRI       │ Neurosteroid │ Untreated   │
  ├─────────────────────┼────────────┼────────────┼──────────────┼─────────────┤""")

    # All cells are across-seed means; the untreated condition has no relapse figure.
    print(f"  │ Combined Stress (%) │ {ket['combined']['mean']:>10.1f} │ {ssri['combined']['mean']:>10.1f} │ {neuro['combined']['mean']:>12.1f} │ {untreated['combined']['mean']:>11.1f} │")
    print(f"  │ Extreme Stress (%)  │ {ket['stress_extreme']['mean']:>10.1f} │ {ssri['stress_extreme']['mean']:>10.1f} │ {neuro['stress_extreme']['mean']:>12.1f} │ {untreated['stress_extreme']['mean']:>11.1f} │")
    print(f"  │ Biased Stress (%)   │ {ket['biased_stress']['mean']:>10.1f} │ {ssri['biased_stress']['mean']:>10.1f} │ {neuro['biased_stress']['mean']:>12.1f} │ {untreated['biased_stress']['mean']:>11.1f} │")
    print(f"  │ Gain Multiplier     │ {ket['gain']['mean']:>10.2f} │ {ssri['gain']['mean']:>10.2f} │ {neuro['gain']['mean']:>12.2f} │ {untreated['gain']['mean']:>11.2f} │")
    print(f"  │ Activation Mag.     │ {ket['activation_magnitude']['mean']:>10.3f} │ {ssri['activation_magnitude']['mean']:>10.3f} │ {neuro['activation_magnitude']['mean']:>12.3f} │ {untreated['activation_magnitude']['mean']:>11.3f} │")
    print(f"  │ Sparsity (%)        │ {ket['sparsity']['mean']:>10.1f} │ {ssri['sparsity']['mean']:>10.1f} │ {neuro['sparsity']['mean']:>12.1f} │ {untreated['sparsity']['mean']:>11.1f} │")
    print(f"  │ Relapse Drop (%)    │ {ket['relapse_drop']['mean']:>10.1f} │ {ssri['relapse_drop']['mean']:>10.1f} │ {neuro['relapse_drop']['mean']:>12.1f} │         N/A │")

    print("  └─────────────────────┴────────────┴────────────┴──────────────┴─────────────┘")

    print("\n" + "="*80)
    print("  EXPERIMENT COMPLETE")
    print("  Results averaged across {} random seeds".format(CONFIG['n_seeds']))
    print("="*80 + "\n")

    return {'all_results': all_results, 'aggregated': aggregated}


# ============================================================================
# ENTRY POINT
# ============================================================================
if __name__ == "__main__":
    # Banner: an 80-column box of '#', each title centred in the 78 inner columns.
    banner_edge = "#"*80
    banner_blank = "#" + " "*78 + "#"
    banner_titles = [
        " MULTI-MECHANISM ANTIDEPRESSANT COMPARISON ",
        " WITH MANIC CONVERSION RISK MODELING ",
        " Ketamine vs SSRI vs Neurosteroid ",
        f" ({CONFIG['n_seeds']} Random Seeds) ",
    ]

    print("\n" + banner_edge)
    print(banner_blank)
    for banner_title in banner_titles:
        print("#" + banner_title.center(78) + "#")
    print(banner_blank)
    print(banner_edge)

    # Keep the outcome bound at module level so it stays inspectable after the run.
    results = run_multi_seed_experiment()


################################################################################
#                                                                              #
#                  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON                   #
#                     WITH MANIC CONVERSION RISK MODELING                      #
#                       Ketamine vs SSRI vs Neurosteroid                       #
#                              (10 Random Seeds)                               #
#                                                                              #
################################################################################

  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT
  WITH MANIC CONVERSION RISK MODELING
  Running across 10 random seeds

  Seed 1/10... done

  Seed 2/10... done

  Seed 3/10... done

  Seed 4/10... done

  Seed 5/10... done

  Seed 6/10... done

  Seed 7/10... done

  Seed 8/10... done

  Seed 9/10... done

  Seed 10/10... done

  AGGREGATED

# The End