# Modes of Antidepressants

In [2]:
"""
================================================================================
MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT
================================================================================

This script runs only Experiment 4: comparing three antidepressant mechanisms:
1. KETAMINE-LIKE: Gradient-guided synaptogenesis
2. SSRI-LIKE: Gradual stabilization without structural changes
3. NEUROSTEROID-LIKE: Tonic inhibition enhancement

All treatments start from identical pruned (depressed) network states.
================================================================================
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from typing import Dict, Tuple, List
import warnings

warnings.filterwarnings('ignore', category=UserWarning)

# ============================================================================
# REPRODUCIBILITY
# ============================================================================
# One seed shared by the torch and NumPy global random generators.
SEED = 42
torch.manual_seed(SEED)  # seeds torch's global RNG (used by the torch.randn / randn_like calls below)
np.random.seed(SEED)  # seeds NumPy's global RNG; generate_blobs builds its own RandomState instead
DEVICE = torch.device('cpu')  # batches and models are moved here via .to(DEVICE)

# ============================================================================
# CONFIGURATION
# ============================================================================
# Central hyperparameter table; every value below is looked up by key elsewhere in this file.
CONFIG = {
    # Data generation
    # Sample counts for the train / noisy-test / noise-free-test sets, the
    # standard deviation of the blob noise, and the training batch size.
    'n_train': 12000,
    'n_test': 4000,
    'n_clean_test': 2000,
    'data_noise': 0.8,
    'batch_size': 128,

    # Network architecture
    # Layer widths of StressAwareNetwork: input_dim -> hidden_dims[0..2] -> output_dim.
    'hidden_dims': [512, 512, 256],
    'input_dim': 2,
    'output_dim': 4,

    # Training hyperparameters
    # baseline_* train the dense network; finetune_lr is the learning rate of the
    # consolidation phases in ketamine_treatment and neurosteroid_treatment.
    # NOTE(review): 'finetune_epochs' is not read in the visible portion of this file.
    'baseline_epochs': 20,
    'baseline_lr': 0.001,
    'finetune_epochs': 15,
    'finetune_lr': 0.0005,

    # Pruning parameters
    # Fraction of weights removed by PruningManager.prune_by_magnitude.
    'prune_sparsity': 0.95,

    # Regrowth parameters
    # regrow_init_scale scales the random values given to regrown weights;
    # gradient_accumulation_batches is how many batches are used to score pruned positions.
    # NOTE(review): 'regrow_fraction' is not read in the visible portion of this
    # file (ketamine_treatment defaults to 'comparison_ketamine_regrow').
    'regrow_fraction': 0.5,
    'regrow_init_scale': 0.03,
    'gradient_accumulation_batches': 30,

    # Stress levels for evaluation
    # Name -> internal noise magnitude passed to evaluate(..., internal_stress=...).
    'extended_stress_levels': {
        'none': 0.0,
        'moderate': 0.5,
        'high': 1.0,
        'severe': 1.5,
        'extreme': 2.5
    },

    # ========================================================================
    # MONOAMINERGIC (SSRI-LIKE) TREATMENT PARAMETERS
    # ========================================================================
    # Biological rationale: SSRIs increase synaptic serotonin, leading to
    # gradual receptor adaptations over weeks. No rapid synaptogenesis.
    # Network analog: Fixed sparsity, very low LR, gradual noise reduction.
    'monoaminergic_epochs': 100,
    'monoaminergic_lr': 1e-5,
    'monoaminergic_initial_stress': 0.5,

    # ========================================================================
    # NEUROSTEROID (GABAergic) TREATMENT PARAMETERS
    # ========================================================================
    # Biological rationale: Neurosteroids enhance tonic GABA inhibition,
    # reducing network excitability rapidly (days, not weeks).
    # Network analog: Global activation damping, bounded activations.
    'neurosteroid_inhibition_strength': 0.7,
    'neurosteroid_use_tanh': True,
    'neurosteroid_consolidation_epochs': 10,

    # ========================================================================
    # MULTI-MECHANISM COMPARISON PARAMETERS
    # ========================================================================
    # NOTE(review): in the visible portion of this file the neurosteroid arm
    # reads the 'neurosteroid_*' keys above, so 'comparison_neurosteroid_strength'
    # and 'comparison_neurosteroid_epochs' appear unused - confirm before
    # relying on them.
    'comparison_ketamine_regrow': 0.5,
    'comparison_ketamine_epochs': 15,
    'comparison_ssri_epochs': 100,
    'comparison_neurosteroid_strength': 0.7,
    'comparison_neurosteroid_epochs': 10,
}


# ============================================================================
# DATA GENERATION
# ============================================================================
def generate_blobs(
    n_samples: int = 10000,
    noise: float = 0.8,
    seed: int = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample a 4-class, 2-D Gaussian blob classification dataset.

    The four class centers sit at (-3, -3), (3, 3), (-3, 3) and (3, -3). Each
    point is its class center plus Gaussian noise scaled by ``noise``.

    Args:
        n_samples: Number of points to draw.
        noise: Standard deviation of the noise added to each coordinate.
        seed: Seed for a private ``RandomState``; ``None`` gives an
            unseeded (non-reproducible) generator.

    Returns:
        ``(points, labels)``: a float32 tensor of shape ``(n_samples, 2)`` and
        an int64 tensor of shape ``(n_samples,)`` with values in 0..3.
    """
    # RandomState(None) self-seeds exactly like RandomState(), so no branch is needed.
    generator = np.random.RandomState(seed)

    blob_centers = np.array([[-3, -3], [3, 3], [-3, 3], [3, -3]])
    # Draw order (labels first, then noise) fixes the sample for a given seed.
    class_ids = generator.randint(0, 4, n_samples)
    jitter = generator.randn(n_samples, 2) * noise
    points = blob_centers[class_ids] + jitter

    return (
        torch.tensor(points, dtype=torch.float32),
        torch.tensor(class_ids, dtype=torch.long),
    )


def create_data_loaders() -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the training, noisy-test and noise-free-test loaders.

    Dataset sizes, blob noise and the training batch size come from ``CONFIG``.
    Each split uses its own fixed seed (100 / 200 / 300). Only the training
    loader shuffles; both test loaders use batches of 1000.

    Returns:
        ``(train_loader, test_loader, clean_test_loader)``.
    """
    def _dataset(size_key: str, noise: float, seed: int) -> TensorDataset:
        # One blob sample wrapped as a dataset of (point, label) pairs.
        points, labels = generate_blobs(CONFIG[size_key], noise=noise, seed=seed)
        return TensorDataset(points, labels)

    train_set = _dataset('n_train', CONFIG['data_noise'], 100)
    test_set = _dataset('n_test', CONFIG['data_noise'], 200)
    # The clean split has zero noise, so every point lies exactly on its class center.
    clean_set = _dataset('n_clean_test', 0.0, 300)

    return (
        DataLoader(train_set, batch_size=CONFIG['batch_size'], shuffle=True),
        DataLoader(test_set, batch_size=1000),
        DataLoader(clean_set, batch_size=1000),
    )


# Module-level loaders: the training and gradient-accumulation code reads
# `train_loader` as a global rather than taking it as an argument.
train_loader, test_loader, clean_test_loader = create_data_loaders()


# ============================================================================
# NETWORK ARCHITECTURE
# ============================================================================
class StressAwareNetwork(nn.Module):
    """
    Four-layer fully connected classifier with switchable activity modulation.

    Three plain attributes (not parameters, so not part of ``state_dict``)
    change how the hidden layers behave:
    - stress_level: std of Gaussian noise added after each hidden activation
    - inhibition_strength: factor every hidden activation is multiplied by
    - use_tanh: use tanh instead of ReLU as the hidden activation
    """

    def __init__(self, hidden_dims: List[int] = None):
        """Build the layers; ``hidden_dims`` defaults to ``CONFIG['hidden_dims']``.

        Only the first three entries of ``hidden_dims`` are used.
        """
        super().__init__()
        if hidden_dims is None:
            hidden_dims = CONFIG['hidden_dims']

        width_1, width_2, width_3 = hidden_dims[0], hidden_dims[1], hidden_dims[2]
        self.fc1 = nn.Linear(CONFIG['input_dim'], width_1)
        self.fc2 = nn.Linear(width_1, width_2)
        self.fc3 = nn.Linear(width_2, width_3)
        self.fc4 = nn.Linear(width_3, CONFIG['output_dim'])

        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

        # Modulation state, all neutral by default.
        self.stress_level = 0.0           # no injected noise
        self.inhibition_strength = 1.0    # multiplying by 1.0 leaves activity unchanged
        self.use_tanh = False             # ReLU hidden activations

        self.weight_layers = ['fc1', 'fc2', 'fc3', 'fc4']

    def set_stress(self, level: float):
        """Set the std of the noise injected after each hidden layer."""
        self.stress_level = level

    def set_inhibition(self, strength: float, use_tanh: bool = False):
        """Set the hidden-activation scaling factor and the activation choice."""
        self.inhibition_strength = strength
        self.use_tanh = use_tanh

    def reduce_stress_gradually(self, epoch: int, total_epochs: int,
                                 initial_stress: float = 0.5, final_stress: float = 0.0):
        """Set ``stress_level`` by linear interpolation over the epoch index.

        Epoch 0 gives ``initial_stress`` and epoch ``total_epochs - 1`` gives
        ``final_stress``; the denominator is floored at 1 so a single-epoch
        schedule does not divide by zero.
        """
        fraction_done = epoch / max(total_epochs - 1, 1)
        self.stress_level = initial_stress + fraction_done * (final_stress - initial_stress)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for ``x``.

        Every hidden layer applies, in order: linear map, activation, optional
        additive noise (only when ``stress_level > 0``), then scaling by
        ``inhibition_strength``. The output layer is a plain linear map.
        """
        squash = self.tanh if self.use_tanh else self.relu

        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = squash(layer(hidden))
            if self.stress_level > 0:
                hidden = hidden + torch.randn_like(hidden) * self.stress_level
            hidden = hidden * self.inhibition_strength

        # Output layer is left unmodulated.
        return self.fc4(hidden)

    def count_parameters(self) -> Tuple[int, int]:
        """Return ``(total, nonzero)`` element counts over all parameters."""
        total = 0
        nonzero = 0
        for tensor in self.parameters():
            total += tensor.numel()
            nonzero += (tensor != 0).sum().item()
        return total, nonzero


# ============================================================================
# PRUNING MANAGER
# ============================================================================
class PruningManager:
    """Manages unstructured magnitude pruning and gradient-guided regrowth.

    One float mask (1 = active, 0 = pruned) is kept for every model parameter
    whose name contains ``'weight'`` and that has at least two dimensions.
    Biases and any other parameters are never masked.
    """

    def __init__(self, model: "StressAwareNetwork"):
        """Create all-ones masks and zeroed gradient buffers for ``model``."""
        self.model = model
        self.masks = {}
        self.gradient_buffer = {}

        for name, param in model.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                self.masks[name] = torch.ones_like(param, dtype=torch.float32)
                self.gradient_buffer[name] = torch.zeros_like(param)

    def prune_by_magnitude(self, sparsity: float, per_layer: bool = True) -> Dict[str, Dict]:
        """Zero out the smallest-magnitude weights and record them in the masks.

        Args:
            sparsity: Quantile in [0, 1] of absolute weight values used as the
                cut-off; weights below it are pruned.
            per_layer: If True, each masked tensor gets its own cut-off. If
                False, one cut-off is computed over all masked tensors
                together, so layers may end up with different sparsities.

        Returns:
            Per-parameter dict with ``kept``, ``total`` and ``actual_sparsity``.

        Already-pruned positions always stay pruned. Without that, a cut-off
        of 0 (sparsity at or below the current fraction of zeros) would mark
        every zeroed weight as active again.
        """
        prunable = [(name, param) for name, param in self.model.named_parameters()
                    if name in self.masks]
        if not prunable:
            return {}

        global_threshold = None
        if not per_layer:
            pooled = torch.cat([param.data.abs().flatten() for _, param in prunable])
            global_threshold = torch.quantile(pooled, sparsity)

        stats = {}
        for name, param in prunable:
            magnitudes = param.data.abs()
            if per_layer:
                threshold = torch.quantile(magnitudes.flatten(), sparsity)
            else:
                threshold = global_threshold

            # Intersect with the previous mask: pruning may only remove connections.
            new_mask = (magnitudes >= threshold).float() * self.masks[name]
            self.masks[name] = new_mask
            param.data *= new_mask

            kept = int(new_mask.sum().item())
            total = new_mask.numel()
            stats[name] = {'kept': kept, 'total': total, 'actual_sparsity': 1 - kept/total}

        return stats

    def _accumulate_gradients(self, num_batches: int = 30):
        """Sum absolute loss gradients at pruned positions over a few batches.

        Reads batches from the module-level ``train_loader`` with internal
        stress switched off, and writes the result into
        ``self.gradient_buffer``. The model is left in training mode; its
        stress level is restored even if a batch raises.
        """
        model = self.model
        loss_fn = nn.CrossEntropyLoss()

        for buffer in self.gradient_buffer.values():
            buffer.zero_()

        model.train()
        original_stress = model.stress_level
        model.set_stress(0.0)
        try:
            # Gradients left over from earlier training would otherwise be
            # added into the first batch's scores.
            model.zero_grad()

            for batch_index, (x, y) in enumerate(train_loader):
                if batch_index >= num_batches:
                    break
                x, y = x.to(DEVICE), y.to(DEVICE)
                loss_fn(model(x), y).backward()

                with torch.no_grad():
                    for name, param in model.named_parameters():
                        # A masked parameter that did not take part in the
                        # forward pass has no gradient to score.
                        if name in self.masks and param.grad is not None:
                            pruned_mask = (self.masks[name] == 0).float()
                            self.gradient_buffer[name] += param.grad.abs() * pruned_mask
                model.zero_grad()
        finally:
            model.set_stress(original_stress)

    def gradient_guided_regrow(self, regrow_fraction: float,
                                init_scale: float = None) -> Dict[str, Dict]:
        """Re-activate the pruned connections with the largest gradient scores.

        Args:
            regrow_fraction: Fraction of each tensor's pruned positions to
                restore. At least one position is restored in any tensor that
                has pruned positions, even for a fraction of 0.
            init_scale: Scale of the Gaussian values given to restored
                weights; defaults to ``CONFIG['regrow_init_scale']``.

        Returns:
            Per-parameter dict with ``regrown`` and ``still_pruned`` counts.
        """
        if init_scale is None:
            init_scale = CONFIG['regrow_init_scale']

        self._accumulate_gradients(num_batches=CONFIG['gradient_accumulation_batches'])

        stats = {}
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue

            mask = self.masks[name]
            pruned_positions = (mask == 0)
            num_pruned = pruned_positions.sum().item()

            if num_pruned == 0:
                stats[name] = {'regrown': 0, 'still_pruned': 0}
                continue

            # Boolean indexing and torch.where on the flattened mask both walk
            # the tensor in row-major order, so score i belongs to
            # flat_pruned_indices[i].
            gradient_scores = self.gradient_buffer[name][pruned_positions]
            num_regrow = max(1, int(regrow_fraction * num_pruned))
            num_regrow = min(num_regrow, gradient_scores.numel())

            _, top_indices = torch.topk(gradient_scores, num_regrow)
            flat_pruned_indices = torch.where(pruned_positions.flatten())[0]
            regrow_flat_indices = flat_pruned_indices[top_indices]

            # Draw on the parameter's own device and dtype so the assignment
            # also works when the model is not on the default device.
            new_weights = torch.randn(num_regrow, device=param.device,
                                      dtype=param.dtype) * init_scale

            flat_mask = mask.flatten()
            flat_param = param.data.flatten()
            flat_mask[regrow_flat_indices] = 1.0
            flat_param[regrow_flat_indices] = new_weights

            self.masks[name] = flat_mask.view_as(mask)
            param.data = flat_param.view_as(param)

            stats[name] = {'regrown': num_regrow, 'still_pruned': int(num_pruned - num_regrow)}

        return stats

    def apply_masks(self):
        """Multiply every masked parameter by its mask to re-zero pruned weights."""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in self.masks:
                    param.data *= self.masks[name]

    def get_sparsity(self) -> float:
        """Return the fraction of mask entries that are 0 (0.0 if there are no masks)."""
        total = sum(m.numel() for m in self.masks.values())
        zeros = sum((m == 0).sum().item() for m in self.masks.values())
        return zeros / total if total > 0 else 0.0


# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================
def train(model: StressAwareNetwork, epochs: int = 15, lr: float = 0.001,
          pruning_manager: PruningManager = None, verbose: bool = False) -> List[float]:
    """Train ``model`` with Adam on the module-level ``train_loader``.

    Internal stress is switched off for the whole run and put back to its
    previous value afterwards. If a pruning manager is given, its masks are
    re-applied after every optimizer step so pruned weights stay at zero.

    Returns:
        Mean training loss of each epoch, in order.
    """
    criterion = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=lr)

    def _run_epoch() -> float:
        # One pass over the training data; returns the mean batch loss.
        model.train()
        running = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            adam.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            adam.step()
            if pruning_manager:
                pruning_manager.apply_masks()
            running += batch_loss.item()
        return running / len(train_loader)

    stress_before = model.stress_level
    model.set_stress(0.0)

    epoch_means = []
    for epoch_idx in range(epochs):
        epoch_means.append(_run_epoch())
        if verbose:
            print(f"      Epoch {epoch_idx+1}/{epochs}, Loss: {epoch_means[-1]:.4f}")

    model.set_stress(stress_before)
    return epoch_means


def train_with_stress_schedule(model: StressAwareNetwork, epochs: int, lr: float,
                                initial_stress: float, final_stress: float = 0.0,
                                pruning_manager: PruningManager = None,
                                verbose: bool = False, print_interval: int = 20) -> List[float]:
    """Train with Adam while internal stress is ramped linearly per epoch.

    Before each epoch the model's stress level is set from the linear schedule
    ``initial_stress -> final_stress``. Masks are re-applied after every
    optimizer step when a pruning manager is supplied. On return the stress
    level is set to 0.0, whatever ``final_stress`` was.

    Returns:
        Mean training loss of each epoch, in order.
    """
    criterion = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=lr)
    history = []

    for epoch_idx in range(epochs):
        model.reduce_stress_gradually(epoch_idx, epochs, initial_stress, final_stress)
        model.train()

        batch_values = []
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            adam.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            adam.step()
            if pruning_manager:
                pruning_manager.apply_masks()
            batch_values.append(batch_loss.item())

        history.append(sum(batch_values) / len(train_loader))

        should_report = verbose and (epoch_idx + 1) % print_interval == 0
        if should_report:
            print(f"      SSRI epoch {epoch_idx+1}/{epochs}, stress: {model.stress_level:.3f}, loss: {history[-1]:.4f}")

    model.set_stress(0.0)
    return history


# ============================================================================
# EVALUATION FUNCTIONS
# ============================================================================
def evaluate(model: StressAwareNetwork, loader: DataLoader,
             input_noise: float = 0.0, internal_stress: float = 0.0) -> float:
    """Return classification accuracy (percent) of ``model`` on ``loader``.

    Args:
        model: Network to score; it is put in eval mode.
        loader: Batches of ``(inputs, labels)``.
        input_noise: Std of Gaussian noise added to the inputs (skipped when 0).
        internal_stress: Stress level set on the model during scoring.

    The model's stress level is set to 0.0 before returning; its inhibition
    settings are not touched.
    """
    model.eval()
    model.set_stress(internal_stress)

    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            if input_noise > 0:
                inputs = inputs + torch.randn_like(inputs) * input_noise
            predictions = model(inputs).argmax(dim=1)
            n_correct += (predictions == targets).sum().item()
            n_seen += targets.size(0)

    model.set_stress(0.0)
    return 100.0 * n_correct / n_seen


def evaluate_with_neurosteroid(model: StressAwareNetwork, loader: DataLoader,
                                input_noise: float = 0.0, internal_stress: float = 0.0) -> float:
    """Evaluate accuracy with the model's current inhibition settings left active.

    ``evaluate`` only changes the stress level and never touches
    ``inhibition_strength`` or ``use_tanh``, so whatever modulation the caller
    has set on the model stays in effect. This function used to be a
    line-for-line copy of ``evaluate``; it now delegates so the two cannot
    drift apart. The separate name is kept so call sites can state that
    modulation is meant to be on.

    Args and return value are the same as for ``evaluate``.
    """
    return evaluate(model, loader, input_noise=input_noise, internal_stress=internal_stress)


# ============================================================================
# TREATMENT PROTOCOLS
# ============================================================================
def ketamine_treatment(model: StressAwareNetwork, pruning_mgr: PruningManager,
                       regrow_fraction: float = None, consolidation_epochs: int = None,
                       verbose: bool = True) -> Dict:
    """
    KETAMINE-LIKE TREATMENT: Gradient-guided synaptogenesis.

    Biological model:
    - NMDA antagonism → BDNF release → mTOR activation → new spine formation
    - Activity-dependent targeting of new synapses
    - Brief consolidation to strengthen useful connections

    Key feature: ADDS NEW SYNAPSES (reduces sparsity)

    Args:
        model: Pruned network to treat (modified in place).
        pruning_mgr: Manager holding the masks for ``model``.
        regrow_fraction: Fraction of pruned connections to restore; defaults
            to ``CONFIG['comparison_ketamine_regrow']``.
        consolidation_epochs: Training epochs after regrowth; defaults to
            ``CONFIG['comparison_ketamine_epochs']``.
        verbose: Print progress.

    Returns:
        Dict with ``regrow_stats`` (per-layer regrowth counts),
        ``final_sparsity`` and ``training_losses`` (per-epoch consolidation
        loss, under the same key ssri_treatment uses).
    """
    if regrow_fraction is None:
        regrow_fraction = CONFIG['comparison_ketamine_regrow']
    if consolidation_epochs is None:
        consolidation_epochs = CONFIG['comparison_ketamine_epochs']

    if verbose:
        print("\n    KETAMINE-LIKE TREATMENT:")
        print(f"      Regrowth fraction: {regrow_fraction*100:.0f}%")
        print(f"      Consolidation: {consolidation_epochs} epochs")
        print("      Estimating gradient importance...")

    regrow_stats = pruning_mgr.gradient_guided_regrow(regrow_fraction=regrow_fraction)
    total_regrown = sum(s['regrown'] for s in regrow_stats.values())

    if verbose:
        print(f"      Restored {total_regrown:,} synapses")
        print("      Consolidating new synapses...")

    # Masks are re-applied during training, so only the regrown and surviving
    # connections are updated.
    consolidation_losses = train(model, epochs=consolidation_epochs,
                                  lr=CONFIG['finetune_lr'], pruning_manager=pruning_mgr)

    final_sparsity = pruning_mgr.get_sparsity()
    if verbose:
        print(f"      Final sparsity: {final_sparsity*100:.1f}%")

    return {'regrow_stats': regrow_stats, 'final_sparsity': final_sparsity,
            'training_losses': consolidation_losses}


def ssri_treatment(model: StressAwareNetwork, pruning_mgr: PruningManager,
                   epochs: int = None, learning_rate: float = None,
                   initial_stress: float = None, verbose: bool = True,
                   print_interval: int = 25) -> Dict:
    """
    SSRI-LIKE TREATMENT: Gradual stabilization without structural changes.

    Biological model:
    - Increased synaptic serotonin → gradual receptor adaptations
    - 5-HT1A autoreceptor desensitization over weeks
    - Improved signal-to-noise in existing circuits

    Key feature: NO NEW SYNAPSES (sparsity unchanged)

    Args:
        model: Pruned network to treat (modified in place).
        pruning_mgr: Manager holding the masks for ``model``; masks are
            re-applied during training but never changed here.
        epochs: Training epochs; defaults to ``CONFIG['monoaminergic_epochs']``.
        learning_rate: Adam learning rate; defaults to
            ``CONFIG['monoaminergic_lr']``.
        initial_stress: Starting internal stress, ramped linearly to 0.0;
            defaults to ``CONFIG['monoaminergic_initial_stress']``.
        verbose: Print progress.
        print_interval: Report every this many epochs when verbose.

    Returns:
        Dict with ``initial_sparsity`` and ``final_sparsity`` (so callers can
        check that the structure was unchanged) and ``training_losses``.
    """
    if epochs is None:
        epochs = CONFIG['monoaminergic_epochs']
    if learning_rate is None:
        learning_rate = CONFIG['monoaminergic_lr']
    if initial_stress is None:
        initial_stress = CONFIG['monoaminergic_initial_stress']

    if verbose:
        print("\n    SSRI-LIKE TREATMENT:")
        print(f"      Duration: {epochs} epochs (gradual)")
        print(f"      Learning rate: {learning_rate} (very low)")
        print(f"      Internal stress: {initial_stress} → 0.0")
        print("      Note: NO structural changes (fixed sparsity)")

    # Taken before training so the result reports both ends of the run.
    initial_sparsity = pruning_mgr.get_sparsity()

    losses = train_with_stress_schedule(model, epochs=epochs, lr=learning_rate,
                                         initial_stress=initial_stress, final_stress=0.0,
                                         pruning_manager=pruning_mgr, verbose=verbose,
                                         print_interval=print_interval)

    final_sparsity = pruning_mgr.get_sparsity()
    if verbose:
        print(f"      Final sparsity: {final_sparsity*100:.1f}% (unchanged)")

    return {'initial_sparsity': initial_sparsity, 'final_sparsity': final_sparsity,
            'training_losses': losses}


def neurosteroid_treatment(model: StressAwareNetwork, pruning_mgr: PruningManager,
                           inhibition_strength: float = None, use_tanh: bool = None,
                           consolidation_epochs: int = None, verbose: bool = True) -> Dict:
    """
    NEUROSTEROID-LIKE TREATMENT: Enhanced tonic inhibition.

    Biological model:
    - Enhanced extrasynaptic GABA-A receptor activation
    - Tonic (sustained) inhibition reduces network excitability
    - Rapid onset (days, not weeks)

    Key features:
    - NO NEW SYNAPSES (sparsity unchanged)
    - Works by DAMPING activity rather than building structure
    - Medication-dependent (effects reverse when stopped)

    Args:
        model: Pruned network to treat (modified in place). The inhibition
            settings are left active on the model when this returns.
        pruning_mgr: Manager holding the masks for ``model``.
        inhibition_strength: Factor hidden activations are multiplied by;
            defaults to ``CONFIG['neurosteroid_inhibition_strength']``.
        use_tanh: Use tanh hidden activations; defaults to
            ``CONFIG['neurosteroid_use_tanh']``.
        consolidation_epochs: Training epochs under modulation; defaults to
            ``CONFIG['neurosteroid_consolidation_epochs']``.
        verbose: Print progress.

    Returns:
        Dict with ``final_sparsity``, ``inhibition_strength``, ``use_tanh``
        and ``training_losses`` (per-epoch consolidation loss, under the same
        key ssri_treatment uses).
    """
    if inhibition_strength is None:
        inhibition_strength = CONFIG['neurosteroid_inhibition_strength']
    if use_tanh is None:
        use_tanh = CONFIG['neurosteroid_use_tanh']
    if consolidation_epochs is None:
        consolidation_epochs = CONFIG['neurosteroid_consolidation_epochs']

    if verbose:
        print("\n    NEUROSTEROID-LIKE TREATMENT:")
        print(f"      Inhibition strength: {inhibition_strength} ({(1-inhibition_strength)*100:.0f}% damping)")
        print(f"      Bounded activation (tanh): {use_tanh}")
        print(f"      Consolidation: {consolidation_epochs} epochs")
        print("      Note: NO structural changes (fixed sparsity)")

    # Apply tonic inhibition modulation
    model.set_inhibition(inhibition_strength, use_tanh)
    if verbose:
        print("      Applied tonic inhibition modulation...")
        print("      Adapting to new activity dynamics...")

    # Train with the modulation switched on so the surviving weights adapt to it.
    consolidation_losses = train(model, epochs=consolidation_epochs,
                                  lr=CONFIG['finetune_lr'], pruning_manager=pruning_mgr)

    final_sparsity = pruning_mgr.get_sparsity()
    if verbose:
        print(f"      Final sparsity: {final_sparsity*100:.1f}% (unchanged)")

    return {'final_sparsity': final_sparsity, 'inhibition_strength': inhibition_strength,
            'use_tanh': use_tanh, 'training_losses': consolidation_losses}


# ============================================================================
# MAIN EXPERIMENT
# ============================================================================
def run_multi_mechanism_experiment() -> Dict[str, Dict]:
    """
    Compare ketamine, SSRI, and neurosteroid treatment mechanisms.

    All treatments start from identical 95% sparse (depressed) networks.
    """
    print("\n" + "="*80)
    print("  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT")
    print("="*80)

    print("""
  COMPARING THREE ANTIDEPRESSANT MECHANISMS:

  ┌─────────────────┬─────────────────────────────────────────────────────────┐
  │ Mechanism       │ Key Feature                                             │
  ├─────────────────┼─────────────────────────────────────────────────────────┤
  │ Ketamine        │ Gradient-guided synaptogenesis (↑ density)              │
  │ SSRI            │ Gradual noise reduction (stabilizes existing weights)   │
  │ Neurosteroid    │ Tonic inhibition (damps activity, bounds firing)        │
  └─────────────────┴─────────────────────────────────────────────────────────┘

  All treatments start from identical 95% sparse (depressed) networks.
    """)

    # ========================================================================
    # PREPARE BASE PRUNED MODEL
    # ========================================================================
    print("-"*70)
    print("  Preparing shared pruned baseline...")
    print("-"*70)

    base_model = StressAwareNetwork().to(DEVICE)
    print(f"  Training full network ({CONFIG['baseline_epochs']} epochs)...")
    train(base_model, epochs=CONFIG['baseline_epochs'], lr=CONFIG['baseline_lr'])

    base_pruning_mgr = PruningManager(base_model)
    base_pruning_mgr.prune_by_magnitude(sparsity=CONFIG['prune_sparsity'], per_layer=True)

    initial_sparsity = base_pruning_mgr.get_sparsity()
    print(f"  Pruned to {initial_sparsity*100:.1f}% sparse")

    # Evaluate untreated state
    print("\n  UNTREATED PRUNED STATE:")
    untreated_results = {'sparsity': initial_sparsity * 100}
    untreated_results['clean'] = evaluate(base_model, clean_test_loader, 0.0, 0.0)
    untreated_results['standard'] = evaluate(base_model, test_loader, 0.0, 0.0)
    for stress_name, stress_level in CONFIG['extended_stress_levels'].items():
        untreated_results[f'stress_{stress_name}'] = evaluate(base_model, test_loader, 0.0, stress_level)
    untreated_results['combined'] = evaluate(base_model, test_loader, 1.0, 0.5)

    print(f"    Clean: {untreated_results['clean']:.1f}%")
    print(f"    Standard: {untreated_results['standard']:.1f}%")
    print(f"    Combined stress: {untreated_results['combined']:.1f}%")
    print(f"    Extreme stress: {untreated_results['stress_extreme']:.1f}%")

    # Save state for cloning
    base_state_dict = {k: v.clone() for k, v in base_model.state_dict().items()}
    base_masks = {k: v.clone() for k, v in base_pruning_mgr.masks.items()}

    results = {'untreated': untreated_results}

    # ========================================================================
    # TREATMENT 1: KETAMINE-LIKE
    # ========================================================================
    print("\n" + "="*70)
    print("  TREATMENT 1: KETAMINE-LIKE (Synaptogenesis)")
    print("="*70)

    ketamine_model = StressAwareNetwork().to(DEVICE)
    ketamine_model.load_state_dict(base_state_dict)
    ketamine_mgr = PruningManager(ketamine_model)
    ketamine_mgr.masks = {k: v.clone() for k, v in base_masks.items()}
    ketamine_mgr.apply_masks()

    ketamine_stats = ketamine_treatment(ketamine_model, ketamine_mgr,
                                         regrow_fraction=CONFIG['comparison_ketamine_regrow'],
                                         consolidation_epochs=CONFIG['comparison_ketamine_epochs'])

    print("\n    POST-TREATMENT EVALUATION:")
    ketamine_results = {'sparsity': ketamine_mgr.get_sparsity() * 100}
    ketamine_results['clean'] = evaluate(ketamine_model, clean_test_loader, 0.0, 0.0)
    ketamine_results['standard'] = evaluate(ketamine_model, test_loader, 0.0, 0.0)
    for stress_name, stress_level in CONFIG['extended_stress_levels'].items():
        ketamine_results[f'stress_{stress_name}'] = evaluate(ketamine_model, test_loader, 0.0, stress_level)
    ketamine_results['combined'] = evaluate(ketamine_model, test_loader, 1.0, 0.5)

    print(f"      Clean: {ketamine_results['clean']:.1f}%")
    print(f"      Standard: {ketamine_results['standard']:.1f}%")
    print(f"      Combined stress: {ketamine_results['combined']:.1f}%")
    print(f"      Extreme stress: {ketamine_results['stress_extreme']:.1f}%")

    # Relapse simulation
    print("\n    RELAPSE SIMULATION:")
    pre_relapse = ketamine_results['combined']
    pre_sparsity = ketamine_mgr.get_sparsity()
    target_sparsity = min(pre_sparsity + (1 - pre_sparsity) * 0.40, 0.99)
    ketamine_mgr.prune_by_magnitude(sparsity=target_sparsity, per_layer=True)
    ketamine_mgr.apply_masks()
    post_relapse = evaluate(ketamine_model, test_loader, 1.0, 0.5)
    ketamine_results['relapse_drop'] = pre_relapse - post_relapse
    print(f"      Combined: {pre_relapse:.1f}% → {post_relapse:.1f}% (drop: {ketamine_results['relapse_drop']:.1f}%)")

    results['ketamine'] = ketamine_results

    # ========================================================================
    # TREATMENT 2: SSRI-LIKE
    # ========================================================================
    print("\n" + "="*70)
    print("  TREATMENT 2: SSRI-LIKE (Gradual Stabilization)")
    print("="*70)

    ssri_model = StressAwareNetwork().to(DEVICE)
    ssri_model.load_state_dict(base_state_dict)
    ssri_mgr = PruningManager(ssri_model)
    ssri_mgr.masks = {k: v.clone() for k, v in base_masks.items()}
    ssri_mgr.apply_masks()

    ssri_stats = ssri_treatment(ssri_model, ssri_mgr,
                                 epochs=CONFIG['comparison_ssri_epochs'],
                                 learning_rate=CONFIG['monoaminergic_lr'],
                                 initial_stress=CONFIG['monoaminergic_initial_stress'],
                                 print_interval=25)

    print("\n    POST-TREATMENT EVALUATION:")
    ssri_results = {'sparsity': ssri_mgr.get_sparsity() * 100}
    ssri_results['clean'] = evaluate(ssri_model, clean_test_loader, 0.0, 0.0)
    ssri_results['standard'] = evaluate(ssri_model, test_loader, 0.0, 0.0)
    for stress_name, stress_level in CONFIG['extended_stress_levels'].items():
        ssri_results[f'stress_{stress_name}'] = evaluate(ssri_model, test_loader, 0.0, stress_level)
    ssri_results['combined'] = evaluate(ssri_model, test_loader, 1.0, 0.5)

    print(f"      Clean: {ssri_results['clean']:.1f}%")
    print(f"      Standard: {ssri_results['standard']:.1f}%")
    print(f"      Combined stress: {ssri_results['combined']:.1f}%")
    print(f"      Extreme stress: {ssri_results['stress_extreme']:.1f}%")

    # Relapse simulation
    print("\n    RELAPSE SIMULATION:")
    pre_relapse = ssri_results['combined']
    pre_sparsity = ssri_mgr.get_sparsity()
    target_sparsity = min(pre_sparsity + (1 - pre_sparsity) * 0.40, 0.99)
    ssri_mgr.prune_by_magnitude(sparsity=target_sparsity, per_layer=True)
    ssri_mgr.apply_masks()
    post_relapse = evaluate(ssri_model, test_loader, 1.0, 0.5)
    ssri_results['relapse_drop'] = pre_relapse - post_relapse
    print(f"      Combined: {pre_relapse:.1f}% → {post_relapse:.1f}% (drop: {ssri_results['relapse_drop']:.1f}%)")

    results['ssri'] = ssri_results

    # ========================================================================
    # TREATMENT 3: NEUROSTEROID-LIKE
    # ========================================================================
    print("\n" + "="*70)
    print("  TREATMENT 3: NEUROSTEROID-LIKE (Tonic Inhibition)")
    print("="*70)

    neuro_model = StressAwareNetwork().to(DEVICE)
    neuro_model.load_state_dict(base_state_dict)
    neuro_mgr = PruningManager(neuro_model)
    neuro_mgr.masks = {k: v.clone() for k, v in base_masks.items()}
    neuro_mgr.apply_masks()

    neuro_stats = neurosteroid_treatment(neuro_model, neuro_mgr,
                                          inhibition_strength=CONFIG['neurosteroid_inhibition_strength'],
                                          use_tanh=CONFIG['neurosteroid_use_tanh'],
                                          consolidation_epochs=CONFIG['neurosteroid_consolidation_epochs'])

    # Evaluate WITH modulation active (patient on medication)
    print("\n    POST-TREATMENT EVALUATION (with modulation active):")
    neuro_results = {'sparsity': neuro_mgr.get_sparsity() * 100}
    neuro_results['clean'] = evaluate_with_neurosteroid(neuro_model, clean_test_loader, 0.0, 0.0)
    neuro_results['standard'] = evaluate_with_neurosteroid(neuro_model, test_loader, 0.0, 0.0)
    for stress_name, stress_level in CONFIG['extended_stress_levels'].items():
        neuro_results[f'stress_{stress_name}'] = evaluate_with_neurosteroid(neuro_model, test_loader, 0.0, stress_level)
    neuro_results['combined'] = evaluate_with_neurosteroid(neuro_model, test_loader, 1.0, 0.5)

    print(f"      Clean: {neuro_results['clean']:.1f}%")
    print(f"      Standard: {neuro_results['standard']:.1f}%")
    print(f"      Combined stress: {neuro_results['combined']:.1f}%")
    print(f"      Extreme stress: {neuro_results['stress_extreme']:.1f}%")

    # Test WITHOUT modulation (medication discontinued)
    print("\n    EVALUATION WITHOUT MODULATION (medication discontinued):")
    neuro_model.set_inhibition(1.0, False)
    off_med_combined = evaluate(neuro_model, test_loader, 1.0, 0.5)
    off_med_extreme = evaluate(neuro_model, test_loader, 0.0, 2.5)
    print(f"      Combined stress: {off_med_combined:.1f}%")
    print(f"      Extreme stress: {off_med_extreme:.1f}%")
    neuro_results['off_medication_combined'] = off_med_combined
    neuro_results['off_medication_extreme'] = off_med_extreme

    # Restore modulation for relapse test
    neuro_model.set_inhibition(CONFIG['neurosteroid_inhibition_strength'],
                                CONFIG['neurosteroid_use_tanh'])

    # Relapse simulation
    print("\n    RELAPSE SIMULATION (with modulation active):")
    pre_relapse = neuro_results['combined']
    pre_sparsity = neuro_mgr.get_sparsity()
    target_sparsity = min(pre_sparsity + (1 - pre_sparsity) * 0.40, 0.99)
    neuro_mgr.prune_by_magnitude(sparsity=target_sparsity, per_layer=True)
    neuro_mgr.apply_masks()
    post_relapse = evaluate_with_neurosteroid(neuro_model, test_loader, 1.0, 0.5)
    neuro_results['relapse_drop'] = pre_relapse - post_relapse
    print(f"      Combined: {pre_relapse:.1f}% → {post_relapse:.1f}% (drop: {neuro_results['relapse_drop']:.1f}%)")

    results['neurosteroid'] = neuro_results

    # ========================================================================
    # COMPREHENSIVE COMPARISON
    # ========================================================================
    print("\n" + "="*80)
    print("  COMPREHENSIVE COMPARISON: ALL TREATMENTS")
    print("="*80)

    treatments = ['untreated', 'ketamine', 'ssri', 'neurosteroid']
    labels = {'untreated': 'Untreated (pruned)', 'ketamine': 'Ketamine-like',
              'ssri': 'SSRI-like', 'neurosteroid': 'Neurosteroid-like'}

    print(f"\n  {'Treatment':<22} {'Sparsity':>10} {'Clean':>8} {'Standard':>10} "
          f"{'Combined':>10} {'Extreme':>10} {'Relapse':>10}")
    print("  " + "-"*85)

    for t in treatments:
        r = results[t]
        relapse = r.get('relapse_drop', 'N/A')
        relapse_str = f"{relapse:.1f}%" if isinstance(relapse, float) else relapse
        print(f"  {labels[t]:<22} {r['sparsity']:>9.1f}% {r['clean']:>7.1f}% "
              f"{r['standard']:>9.1f}% {r['combined']:>9.1f}% "
              f"{r['stress_extreme']:>9.1f}% {relapse_str:>10}")

    # Stress resilience profile
    print("\n  STRESS RESILIENCE PROFILE:")
    print(f"\n  {'Treatment':<22} {'None':>8} {'Moderate':>10} {'High':>8} {'Severe':>8} {'Extreme':>10}")
    print("  " + "-"*70)

    for t in treatments:
        r = results[t]
        print(f"  {labels[t]:<22} {r['stress_none']:>7.1f}% {r['stress_moderate']:>9.1f}% "
              f"{r['stress_high']:>7.1f}% {r['stress_severe']:>7.1f}% {r['stress_extreme']:>9.1f}%")

    # ========================================================================
    # ANALYSIS
    # ========================================================================
    print("\n" + "-"*80)
    print("  ANALYSIS")
    print("-"*80)

    ket, ssri, neuro = results['ketamine'], results['ssri'], results['neurosteroid']
    untreated = results['untreated']

    print("\n  1. IMPROVEMENT FROM UNTREATED STATE (Combined Stress):")
    print(f"     Ketamine:    {untreated['combined']:.1f}% → {ket['combined']:.1f}% (+{ket['combined'] - untreated['combined']:.1f}%)")
    print(f"     SSRI:        {untreated['combined']:.1f}% → {ssri['combined']:.1f}% (+{ssri['combined'] - untreated['combined']:.1f}%)")
    print(f"     Neurosteroid:{untreated['combined']:.1f}% → {neuro['combined']:.1f}% (+{neuro['combined'] - untreated['combined']:.1f}%)")

    print("\n  2. STRUCTURAL VS FUNCTIONAL CHANGES:")
    print(f"     Ketamine sparsity:    {ket['sparsity']:.1f}% (REDUCED from 95%)")
    print(f"     SSRI sparsity:        {ssri['sparsity']:.1f}% (UNCHANGED)")
    print(f"     Neurosteroid sparsity:{neuro['sparsity']:.1f}% (UNCHANGED)")
    print("\n     → Ketamine is the ONLY treatment that adds new connections")

    print("\n  3. EXTREME STRESS RESILIENCE (σ=2.5):")
    print(f"     Ketamine:    {ket['stress_extreme']:.1f}%")
    print(f"     SSRI:        {ssri['stress_extreme']:.1f}%")
    print(f"     Neurosteroid:{neuro['stress_extreme']:.1f}%")

    print("\n  4. NEUROSTEROID MEDICATION DEPENDENCE:")
    print(f"     Combined ON medication:  {neuro['combined']:.1f}%")
    print(f"     Combined OFF medication: {neuro['off_medication_combined']:.1f}%")
    print(f"     Extreme ON medication:   {neuro['stress_extreme']:.1f}%")
    print(f"     Extreme OFF medication:  {neuro['off_medication_extreme']:.1f}%")

    print("\n  5. RELAPSE VULNERABILITY:")
    print(f"     Ketamine:    {ket['relapse_drop']:.1f}% drop")
    print(f"     SSRI:        {ssri['relapse_drop']:.1f}% drop")
    print(f"     Neurosteroid:{neuro['relapse_drop']:.1f}% drop")

    # ========================================================================
    # CLINICAL INTERPRETATION
    # ========================================================================
    print("\n" + "-"*80)
    print("  CLINICAL INTERPRETATION")
    print("-"*80)

    print("""
  KEY FINDINGS:

  1. MECHANISM MATTERS: Different antidepressants work through distinct routes
     - Ketamine REBUILDS: Adds new synapses, restores structural density
     - SSRIs REFINE: Strengthen existing pathways via gradual adaptation
     - Neurosteroids STABILIZE: Damp hyperexcitability, bound activity range

  2. SPEED-DURABILITY TRADEOFF:
     - Ketamine: Fast onset, durable changes (new structure persists)
     - Neurosteroid: Fast onset, medication-dependent (dynamics reset if stopped)
     - SSRI: Slow onset, moderate durability (refined weights persist)

  3. TREATMENT SELECTION IMPLICATIONS:

     ┌──────────────────────────────────────┬──────────────────────────────────┐
     │ Clinical Scenario                    │ Suggested Mechanism              │
     ├──────────────────────────────────────┼──────────────────────────────────┤
     │ Severe, treatment-resistant MDD      │ Ketamine (structural repair)     │
     │ Postpartum depression, acute crisis  │ Neurosteroid (rapid stabilize)   │
     │ Mild-moderate, first-line            │ SSRI (gradual, acceptable)       │
     │ Recurrent with high relapse risk     │ Ketamine (durable structure)     │
     │ Hyperexcitable/anxious component     │ Neurosteroid (activity damping)  │
     └──────────────────────────────────────┴──────────────────────────────────┘

  4. COMBINATION THERAPY RATIONALE:
     - Ketamine + SSRI: Structural + neuromodulatory benefits
     - Neurosteroid + SSRI: Rapid stabilization while waiting for SSRI onset
     - Ketamine + Psychotherapy: New synapses + activity-guided consolidation
    """)

    return results


# ============================================================================
# ENTRY POINT
# ============================================================================
if __name__ == "__main__":
    # Script entry point: print a boxed title banner, run the comparison
    # experiment, then print a completion footer.
    banner_width = 80
    inner_width = banner_width - 2  # space between the two '#' side walls
    hash_border = "#" * banner_width
    empty_row = "#" + " " * inner_width + "#"

    # Top of the box, preceded by a blank line.
    print("\n" + hash_border)
    print(empty_row)
    # Title rows, each centered between the side walls.
    for title in (
        " MULTI-MECHANISM ANTIDEPRESSANT COMPARISON ",
        " Ketamine vs SSRI vs Neurosteroid ",
    ):
        print("#" + title.center(inner_width) + "#")
    print(empty_row)
    print(hash_border)

    # Keep the returned dictionary bound at module level so it remains
    # available for inspection after the run (e.g. in a notebook session).
    results = run_multi_mechanism_experiment()

    footer_rule = "=" * banner_width
    print("\n" + footer_rule)
    print("  EXPERIMENT COMPLETE")
    print(footer_rule + "\n")


################################################################################
#                                                                              #
#                  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON                   #
#                       Ketamine vs SSRI vs Neurosteroid                       #
#                                                                              #
################################################################################

  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT

  COMPARING THREE ANTIDEPRESSANT MECHANISMS:
  
  ┌─────────────────┬─────────────────────────────────────────────────────────┐
  │ Mechanism       │ Key Feature                                             │
  ├─────────────────┼─────────────────────────────────────────────────────────┤
  │ Ketamine        │ Gradient-guided synaptogenesis (↑ density)              │
  │ SSRI            │ Gradual noise reduction (stabilizes existing weights)   │
  │ Neur

# The End