# Manic Relapse Chronic Treatment

# In [1]:
"""
================================================================================
MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT
WITH MANIC CONVERSION RISK MODELING + LONGITUDINAL MANIC RELAPSE
================================================================================

This script compares three antidepressant mechanisms and models:
1. Acute treatment efficacy
2. Manic conversion risk (treatment-emergent mania)
3. Longitudinal manic relapse after medication discontinuation

Key Architecture/Parameter Changes for Longitudinal Manic Relapse:
- Added mood stabilizer (MS) protection parameters to StressAwareNetwork:
  - ms_protection_level: Overall MS protection (0.0-1.0)
  - ms_gain_cap: Caps excitability overshoot when MS active
  - ms_inhib_bias: Adds mild inhibitory bias to activations
  - ms_bias_damping: Reduces impact of positive-biased noise
  - lingering_ms_decay_rate: Treatment-specific decay after discontinuation

- Treatment-specific decay rates (architecture parameter differences):
  - Ketamine: slow decay (0.002/step) - structural changes persist
  - SSRI: fast decay (0.015/step) - functional overshoot re-emerges quickly
  - Neurosteroid: medium decay (0.008/step) - state-dependent adaptation

- Manic relapse = biased stress accuracy drops below threshold after discontinuation
================================================================================
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from typing import Dict, Tuple, List
import warnings
from collections import defaultdict
import copy

# Silence every warning of category UserWarning for the whole process.
warnings.filterwarnings('ignore', category=UserWarning)

# Device that models and batches are moved to via .to(DEVICE); fixed to CPU.
DEVICE = torch.device('cpu')

# ============================================================================
# CONFIGURATION
# ============================================================================
# Central experiment settings; every function in this file reads its
# hyperparameters from this dict by key.
# NOTE(review): 'n_seeds', 'finetune_epochs', 'regrow_fraction',
# 'monoaminergic_epochs', 'comparison_neurosteroid_strength' and
# 'comparison_neurosteroid_epochs' are not read by the code reviewed here;
# presumably they are consumed elsewhere or are leftovers - confirm.
CONFIG = {
    'n_seeds': 10,
    # Sample counts and Gaussian noise std for the blob datasets, plus the
    # training mini-batch size (see create_data_loaders).
    'n_train': 12000,
    'n_test': 4000,
    'n_clean_test': 2000,
    'data_noise': 0.8,
    'batch_size': 128,
    # Layer widths of StressAwareNetwork: input -> hidden_dims -> output.
    'hidden_dims': [512, 512, 256],
    'input_dim': 2,
    'output_dim': 4,
    # Baseline training of the unpruned network (run_single_seed).
    'baseline_epochs': 20,
    'baseline_lr': 0.001,
    # 'finetune_lr' is the learning rate of the ketamine and neurosteroid
    # training phases.
    'finetune_epochs': 15,
    'finetune_lr': 0.0005,
    # Fraction of weights zeroed per layer by magnitude pruning.
    'prune_sparsity': 0.95,
    'regrow_fraction': 0.5,
    # Std of the random values assigned to regrown weights.
    'regrow_init_scale': 0.03,
    # Number of training batches used to score pruned weights for regrowth.
    'gradient_accumulation_batches': 30,
    # Named internal-noise levels; each is passed to evaluate() as
    # internal_stress and reported under the key 'stress_<name>'.
    'extended_stress_levels': {
        'none': 0.0,
        'moderate': 0.5,
        'high': 1.0,
        'severe': 1.5,
        'extreme': 2.5
    },
    # SSRI protocol: learning rate and starting internal stress of the
    # scheduled training; the gain ramps from 1.0 up to 'ssri_max_gain'.
    'monoaminergic_epochs': 100,
    'monoaminergic_lr': 1e-5,
    'monoaminergic_initial_stress': 0.5,
    'ssri_max_gain': 1.6,
    # Fixed gain applied by the ketamine protocol.
    'ketamine_gain': 1.25,
    # Neurosteroid protocol: inhibition multiplier, tanh activation switch,
    # number of consolidation epochs, and gain.
    'neurosteroid_inhibition_strength': 0.7,
    'neurosteroid_use_tanh': True,
    'neurosteroid_consolidation_epochs': 10,
    'neurosteroid_gain': 0.85,
    # Mean shift and std of the biased internal noise used by the manic
    # vulnerability test (evaluate_biased_stress).
    'mania_test_bias': 1.0,
    'mania_test_sigma': 1.0,
    # Settings of the head-to-head comparison: ketamine regrow fraction and
    # epochs, SSRI epochs.
    'comparison_ketamine_regrow': 0.5,
    'comparison_ketamine_epochs': 15,
    'comparison_ssri_epochs': 100,
    'comparison_neurosteroid_strength': 0.7,
    'comparison_neurosteroid_epochs': 10,

    # Mood Stabilizer (MS) Architecture Parameters
    # Read by StressAwareNetwork: gain ceiling under full protection, size of
    # the subtractive inhibitory shift, and the fraction by which the test
    # bias is damped. 'ms_max_protection' is the protection level set at the
    # start of the maintenance phase.
    'ms_gain_cap': 1.05,
    'ms_inhib_bias_strength': 0.15,
    'ms_bias_damping_factor': 0.3,
    'ms_max_protection': 1.0,

    # Longitudinal Manic Relapse Parameters
    # Maintenance durations are used as epoch counts; the threshold is an
    # accuracy in percent below which a run counts as relapsed.
    'maintenance_durations': [25, 50, 100, 150, 200, 300],
    'maintenance_lr': 1e-6,
    'post_discontinuation_steps': 50,
    'manic_relapse_threshold': 60.0,

    # Treatment-specific MS decay rates (architecture parameter differences)
    # Amount subtracted from the MS protection level per post-discontinuation
    # step.
    'ketamine_ms_decay_rate': 0.002,
    'ssri_ms_decay_rate': 0.015,
    'neurosteroid_ms_decay_rate': 0.008,
}


# ============================================================================
# DATA GENERATION
# ============================================================================
def generate_blobs(n_samples: int = 10000, noise: float = 0.8, seed: int = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample a four-class 2-D Gaussian blob dataset.

    Each point is one of the four corner centres (+-3, +-3) plus isotropic
    Gaussian jitter scaled by ``noise``. Class labels are drawn uniformly.

    Args:
        n_samples: Number of points to draw.
        noise: Standard deviation of the jitter added to each centre.
        seed: Seed for a private ``RandomState``; ``None`` gives a
            non-deterministic generator.

    Returns:
        ``(data, labels)`` where ``data`` is a float32 tensor of shape
        ``(n_samples, 2)`` and ``labels`` an int64 tensor of shape
        ``(n_samples,)`` with values in ``0..3``.
    """
    # RandomState(None) is the same as RandomState(): both seed from the OS.
    generator = np.random.RandomState(seed)
    corner_centres = np.array([[-3, -3], [3, 3], [-3, 3], [3, -3]])

    # Labels are drawn before the jitter so the random stream is consumed in
    # a fixed order for a given seed.
    class_ids = generator.randint(0, 4, n_samples)
    jitter = generator.randn(n_samples, 2) * noise
    points = corner_centres[class_ids] + jitter

    features = torch.tensor(points, dtype=torch.float32)
    targets = torch.tensor(class_ids, dtype=torch.long)
    return (features, targets)


def create_data_loaders(seed: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train, noisy-test and clean-test loaders for one seed.

    The three splits use distinct generator seeds derived from ``seed``
    (``seed*1000 + 100/200/300``). The clean test split is generated with
    zero noise, so its points sit exactly on the class centres.

    Returns:
        ``(train_loader, test_loader, clean_test_loader)``. Only the training
        loader shuffles; both test loaders use a batch size of 1000.
    """
    noise_std = CONFIG['data_noise']

    train_set = TensorDataset(
        *generate_blobs(CONFIG['n_train'], noise=noise_std, seed=seed*1000+100)
    )
    test_set = TensorDataset(
        *generate_blobs(CONFIG['n_test'], noise=noise_std, seed=seed*1000+200)
    )
    clean_set = TensorDataset(
        *generate_blobs(CONFIG['n_clean_test'], noise=0.0, seed=seed*1000+300)
    )

    return (
        DataLoader(train_set, batch_size=CONFIG['batch_size'], shuffle=True),
        DataLoader(test_set, batch_size=1000),
        DataLoader(clean_set, batch_size=1000),
    )


# ============================================================================
# NETWORK ARCHITECTURE WITH MS PROTECTION PARAMETERS
# ============================================================================
class StressAwareNetwork(nn.Module):
    """
    Four-layer fully connected classifier with non-learned modulation knobs.

    The hidden activations of the three hidden layers are modulated by plain
    Python attributes (not ``nn.Parameter``s):

    - ``stress_level``: std of zero-mean Gaussian noise added after each
      hidden activation in ``forward``
    - ``inhibition_strength`` / ``use_tanh``: multiplicative scale and choice
      of tanh instead of ReLU (GABAergic modulation)
    - ``gain_multiplier``: excitability gain multiplied into each hidden layer
    - ``ms_protection_level``: mood stabilizer (MS) protection in [0, 1];
      when positive it caps the effective gain towards
      ``CONFIG['ms_gain_cap']``, subtracts an inhibitory shift from every
      hidden layer, and damps the bias of the biased-noise test
    - ``lingering_ms_decay_rate``: amount removed from the protection level
      by each ``decay_ms_protection`` call
    """

    def __init__(self, hidden_dims: List[int] = None):
        super().__init__()
        widths = CONFIG['hidden_dims'] if hidden_dims is None else hidden_dims

        # Layers are built in fc1..fc4 order so their random initialisation
        # consumes the RNG in a fixed sequence.
        self.fc1 = nn.Linear(CONFIG['input_dim'], widths[0])
        self.fc2 = nn.Linear(widths[0], widths[1])
        self.fc3 = nn.Linear(widths[1], widths[2])
        self.fc4 = nn.Linear(widths[2], CONFIG['output_dim'])
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

        # Core modulation state (neutral defaults)
        self.stress_level = 0.0
        self.inhibition_strength = 1.0
        self.use_tanh = False
        self.gain_multiplier = 1.0

        # Mood stabilizer state (inactive by default)
        self.ms_protection_level = 0.0
        self.lingering_ms_decay_rate = 0.0

        self.weight_layers = ['fc1', 'fc2', 'fc3', 'fc4']

    # ------------------------------------------------------------------
    # Setters for the modulation state
    # ------------------------------------------------------------------
    def set_stress(self, level: float):
        """Set the std of the internal noise used by ``forward``."""
        self.stress_level = level

    def set_inhibition(self, strength: float, use_tanh: bool = False):
        """Set the inhibition scale and whether tanh replaces ReLU."""
        self.inhibition_strength = strength
        self.use_tanh = use_tanh

    def set_gain(self, gain: float):
        """Set the raw excitability gain (before any MS cap)."""
        self.gain_multiplier = gain

    def set_ms_protection(self, level: float, decay_rate: float = 0.0):
        """Set the MS protection level (clamped to [0, 1]) and its decay rate."""
        clamped = max(0.0, min(1.0, level))
        self.ms_protection_level = clamped
        self.lingering_ms_decay_rate = decay_rate

    def decay_ms_protection(self):
        """Remove one decay step from the MS protection, never going below 0."""
        step = self.lingering_ms_decay_rate
        if step > 0:
            self.ms_protection_level = max(0.0, self.ms_protection_level - step)

    def reset_antidepressant_effects(self):
        """Return gain, inhibition, activation choice and stress to neutral.

        MS protection and its decay rate are left untouched.
        """
        self.stress_level = 0.0
        self.use_tanh = False
        self.inhibition_strength = 1.0
        self.gain_multiplier = 1.0

    def get_effective_gain(self) -> float:
        """Return the gain actually applied in the forward passes.

        Without MS protection this is ``gain_multiplier``. With protection,
        a gain above ``CONFIG['ms_gain_cap']`` is pulled linearly towards the
        cap (fully capped at protection 1.0); a gain at or below the cap is
        returned unchanged.
        """
        if not (self.ms_protection_level > 0):
            return self.gain_multiplier

        cap = CONFIG['ms_gain_cap']
        unprotected_share = 1.0 - self.ms_protection_level
        blended = cap + unprotected_share * (self.gain_multiplier - cap)
        return min(self.gain_multiplier, max(cap, blended))

    # ------------------------------------------------------------------
    # Linear schedules used during scheduled training
    # ------------------------------------------------------------------
    def reduce_stress_gradually(self, epoch: int, total_epochs: int, initial_stress: float = 0.5, final_stress: float = 0.0):
        """Set the stress level by linear interpolation over the epochs."""
        fraction = epoch / max(total_epochs - 1, 1)
        self.stress_level = initial_stress + fraction * (final_stress - initial_stress)

    def increase_gain_gradually(self, epoch: int, total_epochs: int, initial_gain: float = 1.0, max_gain: float = 1.6):
        """Set the gain by linear interpolation over the epochs."""
        fraction = epoch / max(total_epochs - 1, 1)
        self.gain_multiplier = initial_gain + fraction * (max_gain - initial_gain)

    # ------------------------------------------------------------------
    # Forward passes
    # ------------------------------------------------------------------
    def _propagate(self, x: torch.Tensor, perturb) -> torch.Tensor:
        """Run the three modulated hidden layers and the output layer.

        ``perturb`` maps each hidden activation to its noisy version; the
        scaling and MS shift applied afterwards are shared by both public
        forward passes.
        """
        nonlinearity = self.tanh if self.use_tanh else self.relu
        gain = self.get_effective_gain()
        ms_active = self.ms_protection_level > 0

        h = x
        for layer in (self.fc1, self.fc2, self.fc3):
            h = perturb(nonlinearity(layer(h)))
            h = h * self.inhibition_strength * gain
            if ms_active:
                # MS inhibitory bias
                h = h - CONFIG['ms_inhib_bias_strength'] * self.ms_protection_level
        return self.fc4(h)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Standard pass; adds zero-mean noise only when stress is positive."""
        sigma = self.stress_level

        def add_stress_noise(h: torch.Tensor) -> torch.Tensor:
            if sigma > 0:
                return h + torch.randn_like(h) * sigma
            return h

        return self._propagate(x, add_stress_noise)

    def forward_with_biased_noise(self, x: torch.Tensor, internal_sigma: float = 1.0, bias: float = 1.0) -> torch.Tensor:
        """Forward pass with positively-biased internal noise (manic vulnerability test).

        Noise with std ``internal_sigma`` and mean ``bias`` is always added,
        independent of ``stress_level``. MS protection shrinks the bias by
        ``CONFIG['ms_bias_damping_factor'] * ms_protection_level``.
        """
        damping = 1.0 - (CONFIG['ms_bias_damping_factor'] * self.ms_protection_level)
        damped_bias = bias * damping

        def add_biased_noise(h: torch.Tensor) -> torch.Tensor:
            return h + (torch.randn_like(h) * internal_sigma + damped_bias)

        return self._propagate(x, add_biased_noise)

    def count_parameters(self) -> Tuple[int, int]:
        """Return ``(total, nonzero)`` element counts over all parameters."""
        total = 0
        nonzero = 0
        for tensor in self.parameters():
            total += tensor.numel()
            nonzero += (tensor != 0).sum().item()
        return total, nonzero


# ============================================================================
# PRUNING MANAGER
# ============================================================================
class PruningManager:
    """Magnitude pruning and gradient-guided regrowth for a network.

    One binary float mask and one gradient score buffer are kept for every
    parameter whose name contains ``'weight'`` and that has at least two
    dimensions. A mask entry of 0 marks a pruned weight.
    """

    def __init__(self, model: StressAwareNetwork, train_loader: DataLoader):
        self.model = model
        self.train_loader = train_loader
        self.masks = {}
        self.gradient_buffer = {}
        for name, param in model.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                self.masks[name] = torch.ones_like(param, dtype=torch.float32)
                self.gradient_buffer[name] = torch.zeros_like(param)

    def prune_by_magnitude(self, sparsity: float, per_layer: bool = True) -> Dict[str, Dict]:
        """Zero the smallest-magnitude weights of every masked layer.

        The threshold is the ``sparsity`` quantile of each layer's absolute
        weights; weights at or above it are kept. The previous mask of the
        layer is replaced, not intersected.

        Args:
            sparsity: Quantile in [0, 1] used as the per-layer threshold.
            per_layer: Accepted for interface compatibility; pruning is always
                done layer by layer and this flag is not consulted.

        Returns:
            Per-layer dict with ``kept``, ``total`` and ``actual_sparsity``.
        """
        stats = {}
        for name, param in self.model.named_parameters():
            if name in self.masks:
                weights = param.data.abs()
                threshold = torch.quantile(weights.flatten(), sparsity)
                self.masks[name] = (weights >= threshold).float()
                param.data *= self.masks[name]
                kept = self.masks[name].sum().item()
                total = self.masks[name].numel()
                stats[name] = {'kept': int(kept), 'total': total, 'actual_sparsity': 1 - kept/total}
        return stats

    def _accumulate_gradients(self, num_batches: int = 30):
        """Score pruned positions by summed absolute gradient.

        Runs up to ``num_batches`` training batches with stress 0.0 and gain
        1.0 and adds ``|grad|`` at pruned positions into the score buffers.
        The model's stress and gain are restored afterwards, even if a batch
        raises. The model is left in training mode.
        """
        model = self.model
        loss_fn = nn.CrossEntropyLoss()
        for name in self.gradient_buffer:
            self.gradient_buffer[name].zero_()
        model.train()
        original_stress = model.stress_level
        original_gain = model.gain_multiplier
        model.set_stress(0.0)
        model.set_gain(1.0)
        try:
            # Gradients left behind by earlier training would otherwise be
            # summed into the first batch's scores.
            model.zero_grad()
            for batch_count, (x, y) in enumerate(self.train_loader):
                if batch_count >= num_batches:
                    break
                x, y = x.to(DEVICE), y.to(DEVICE)
                loss = loss_fn(model(x), y)
                loss.backward()
                with torch.no_grad():
                    for name, param in model.named_parameters():
                        if name not in self.masks or param.grad is None:
                            # A masked parameter that received no gradient
                            # contributes nothing to the scores.
                            continue
                        pruned_mask = (self.masks[name] == 0).float()
                        self.gradient_buffer[name] += param.grad.abs() * pruned_mask
                model.zero_grad()
        finally:
            model.set_stress(original_stress)
            model.set_gain(original_gain)

    def gradient_guided_regrow(self, regrow_fraction: float, init_scale: float = None) -> Dict[str, Dict]:
        """Re-enable the pruned weights with the largest gradient scores.

        For each masked layer, ``regrow_fraction`` of the currently pruned
        positions (at least one) are unmasked and given random values with
        std ``init_scale``.

        Args:
            regrow_fraction: Share of pruned positions to regrow per layer.
            init_scale: Std of the new weights; defaults to
                ``CONFIG['regrow_init_scale']``.

        Returns:
            Per-layer dict with ``regrown`` and ``still_pruned`` counts.
        """
        if init_scale is None:
            init_scale = CONFIG['regrow_init_scale']
        self._accumulate_gradients(num_batches=CONFIG['gradient_accumulation_batches'])
        stats = {}
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue
            mask = self.masks[name]
            pruned_positions = (mask == 0)
            num_pruned = pruned_positions.sum().item()
            if num_pruned == 0:
                stats[name] = {'regrown': 0, 'still_pruned': 0}
                continue
            gradient_scores = self.gradient_buffer[name][pruned_positions]
            num_regrow = max(1, int(regrow_fraction * num_pruned))
            num_regrow = min(num_regrow, gradient_scores.numel())
            _, top_indices = torch.topk(gradient_scores.flatten(), num_regrow)
            # Map ranks within the pruned subset back to flat layer indices.
            flat_pruned_indices = torch.where(pruned_positions.flatten())[0]
            regrow_flat_indices = flat_pruned_indices[top_indices]
            flat_mask = mask.flatten()
            flat_param = param.data.flatten()
            flat_mask[regrow_flat_indices] = 1.0
            # Draw on the parameter's own device/dtype so the assignment also
            # works when the model does not live on the default device.
            flat_param[regrow_flat_indices] = torch.randn(
                num_regrow, device=param.device, dtype=param.dtype
            ) * init_scale
            self.masks[name] = flat_mask.view_as(mask)
            param.data = flat_param.view_as(param)
            stats[name] = {'regrown': num_regrow, 'still_pruned': int(num_pruned - num_regrow)}
        return stats

    def apply_masks(self):
        """Multiply every masked parameter by its mask in place."""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in self.masks:
                    param.data *= self.masks[name]

    def get_sparsity(self) -> float:
        """Return the fraction of mask entries that are 0 (0.0 if no masks)."""
        total = sum(m.numel() for m in self.masks.values())
        zeros = sum((m == 0).sum().item() for m in self.masks.values())
        return zeros / total if total > 0 else 0.0

    def clone_masks(self) -> Dict[str, torch.Tensor]:
        """Return an independent copy of all masks."""
        return {k: v.clone() for k, v in self.masks.items()}


# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================
def train(model: StressAwareNetwork, train_loader: DataLoader, epochs: int = 15, lr: float = 0.001,
          pruning_manager: PruningManager = None) -> List[float]:
    """Train with Adam and cross-entropy, with internal stress switched off.

    The model's stress level is forced to 0.0 for the duration of training
    and set back to its previous value before returning. If a pruning manager
    is given, its masks are re-applied after every optimizer step so pruned
    weights stay at zero.

    Returns:
        Mean batch loss of each epoch, in order.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    saved_stress = model.stress_level
    model.set_stress(0.0)

    epoch_means = []
    for _ in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            if pruning_manager:
                pruning_manager.apply_masks()
            running_loss += batch_loss.item()
        epoch_means.append(running_loss / len(train_loader))

    model.set_stress(saved_stress)
    return epoch_means


def train_with_stress_and_gain_schedule(model: StressAwareNetwork, train_loader: DataLoader, epochs: int, lr: float,
                                         initial_stress: float, final_stress: float = 0.0,
                                         initial_gain: float = 1.0, max_gain: float = 1.6,
                                         pruning_manager: PruningManager = None) -> List[float]:
    """Train while ramping internal stress and gain linearly over the epochs.

    At the start of every epoch the stress level is interpolated from
    ``initial_stress`` to ``final_stress`` and the gain from ``initial_gain``
    to ``max_gain``. Masks of an optional pruning manager are re-applied
    after every optimizer step. On return the stress level is 0.0 and the
    gain keeps the value set in the last epoch.

    Returns:
        Mean batch loss of each epoch, in order.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    epoch_means = []
    for epoch_idx in range(epochs):
        # Update both schedules before any batch of this epoch is seen.
        model.reduce_stress_gradually(epoch_idx, epochs, initial_stress, final_stress)
        model.increase_gain_gradually(epoch_idx, epochs, initial_gain, max_gain)
        model.train()

        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            if pruning_manager:
                pruning_manager.apply_masks()
            running_loss += batch_loss.item()
        epoch_means.append(running_loss / len(train_loader))

    model.set_stress(0.0)
    return epoch_means


# ============================================================================
# EVALUATION FUNCTIONS
# ============================================================================
def evaluate(model: StressAwareNetwork, loader: DataLoader, input_noise: float = 0.0, internal_stress: float = 0.0) -> float:
    """Return classification accuracy in percent over ``loader``.

    Args:
        model: Network to evaluate; put into eval mode.
        loader: Batches of ``(inputs, labels)``.
        input_noise: If positive, std of Gaussian noise added to the inputs.
        internal_stress: Stress level set on the model during evaluation.

    Note:
        The model's stress level is 0.0 on return, whatever it was before.
    """
    model.eval()
    model.set_stress(internal_stress)

    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            if input_noise > 0:
                inputs = inputs + torch.randn_like(inputs) * input_noise
            predictions = model(inputs).argmax(dim=1)
            hits += (predictions == targets).sum().item()
            seen += targets.size(0)

    model.set_stress(0.0)
    return 100.0 * hits / seen


def evaluate_biased_stress(model: StressAwareNetwork, loader: DataLoader, input_noise: float = 0.0,
                           internal_sigma: float = 1.0, bias: float = 1.0) -> float:
    """Evaluate under positively-biased noise (manic vulnerability test).

    Every batch goes through ``model.forward_with_biased_noise`` with the
    given ``internal_sigma`` and ``bias``. If ``input_noise`` is positive,
    Gaussian noise of that std is added to the inputs first.

    Returns:
        Accuracy in percent.
    """
    model.eval()

    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            if input_noise > 0:
                inputs = inputs + torch.randn_like(inputs) * input_noise
            logits = model.forward_with_biased_noise(inputs, internal_sigma, bias)
            hits += (logits.argmax(dim=1) == targets).sum().item()
            seen += targets.size(0)

    return 100.0 * hits / seen


def get_avg_activation_magnitude(model: StressAwareNetwork, loader: DataLoader) -> float:
    """Return the mean absolute hidden activation, without any noise.

    For every batch the three hidden layers are evaluated with the same
    deterministic modulation that ``model.forward`` applies: activation,
    multiplication by inhibition strength and the effective (MS-capped) gain,
    and the MS inhibitory shift when protection is active. With MS protection
    at 0 this is the raw gain and no shift.

    Returns:
        Unweighted mean over all per-layer, per-batch means of ``|h|``;
        ``nan`` if the loader yields no batches.
    """
    model.eval()
    activation = model.tanh if model.use_tanh else model.relu
    # Use the gain the forward pass really applies, not the raw multiplier,
    # so the reported magnitude stays meaningful under MS protection.
    effective_gain = model.get_effective_gain()
    ms_level = model.ms_protection_level
    all_acts = []
    with torch.no_grad():
        for x, _ in loader:
            h = x.to(DEVICE)
            for layer in (model.fc1, model.fc2, model.fc3):
                h = activation(layer(h)) * model.inhibition_strength * effective_gain
                if ms_level > 0:
                    h = h - CONFIG['ms_inhib_bias_strength'] * ms_level
                all_acts.append(torch.mean(torch.abs(h)).item())
    if not all_acts:
        # np.mean([]) would give the same value but emit a RuntimeWarning.
        return float('nan')
    return float(np.mean(all_acts))


# ============================================================================
# TREATMENT PROTOCOLS
# ============================================================================
def ketamine_treatment(model: StressAwareNetwork, pruning_mgr: PruningManager, train_loader: DataLoader) -> Dict:
    """Apply the ketamine protocol: fixed gain, regrowth, then fine-tuning.

    Sets the gain to ``CONFIG['ketamine_gain']``, regrows a share of the
    pruned weights guided by gradients, and trains with the masks enforced.

    Returns:
        Dict with the resulting mask sparsity and the model's gain.
    """
    model.set_gain(CONFIG['ketamine_gain'])
    pruning_mgr.gradient_guided_regrow(regrow_fraction=CONFIG['comparison_ketamine_regrow'])
    train(
        model,
        train_loader,
        epochs=CONFIG['comparison_ketamine_epochs'],
        lr=CONFIG['finetune_lr'],
        pruning_manager=pruning_mgr,
    )

    summary = {}
    summary['final_sparsity'] = pruning_mgr.get_sparsity()
    summary['final_gain'] = model.gain_multiplier
    return summary


def ssri_treatment(model: StressAwareNetwork, pruning_mgr: PruningManager, train_loader: DataLoader) -> Dict:
    """Apply the SSRI protocol: long, low-LR training with ramped stress/gain.

    Stress falls linearly from ``CONFIG['monoaminergic_initial_stress']`` to
    0.0 while the gain rises from 1.0 to ``CONFIG['ssri_max_gain']``. No
    weights are regrown; the existing masks stay enforced.

    Returns:
        Dict with the mask sparsity and the model's final gain.
    """
    schedule = {
        'epochs': CONFIG['comparison_ssri_epochs'],
        'lr': CONFIG['monoaminergic_lr'],
        'initial_stress': CONFIG['monoaminergic_initial_stress'],
        'final_stress': 0.0,
        'initial_gain': 1.0,
        'max_gain': CONFIG['ssri_max_gain'],
        'pruning_manager': pruning_mgr,
    }
    train_with_stress_and_gain_schedule(model, train_loader, **schedule)

    summary = {}
    summary['final_sparsity'] = pruning_mgr.get_sparsity()
    summary['final_gain'] = model.gain_multiplier
    return summary


def neurosteroid_treatment(model: StressAwareNetwork, pruning_mgr: PruningManager, train_loader: DataLoader) -> Dict:
    """Apply the neurosteroid protocol: inhibition + gain, then consolidation.

    Sets the inhibition strength / activation choice and the gain from
    ``CONFIG``, then trains for the configured consolidation epochs with the
    masks enforced.

    Returns:
        Dict with the mask sparsity, the gain, and ``effective_scale`` (the
        product of inhibition strength and gain).
    """
    inhibition = CONFIG['neurosteroid_inhibition_strength']
    tanh_flag = CONFIG['neurosteroid_use_tanh']
    model.set_inhibition(inhibition, tanh_flag)
    model.set_gain(CONFIG['neurosteroid_gain'])

    train(
        model,
        train_loader,
        epochs=CONFIG['neurosteroid_consolidation_epochs'],
        lr=CONFIG['finetune_lr'],
        pruning_manager=pruning_mgr,
    )

    summary = {}
    summary['final_sparsity'] = pruning_mgr.get_sparsity()
    summary['final_gain'] = model.gain_multiplier
    summary['effective_scale'] = model.inhibition_strength * model.gain_multiplier
    return summary


# ============================================================================
# LONGITUDINAL MANIC RELAPSE SIMULATION
# ============================================================================
def simulate_longitudinal_manic_relapse(
    model_state: Dict[str, torch.Tensor],
    mask_state: Dict[str, torch.Tensor],
    treatment_type: str,
    treatment_params: Dict,
    train_loader: DataLoader,
    test_loader: DataLoader
) -> Dict:
    """
    Simulate manic relapse after discontinuation for each maintenance duration.

    For every entry of ``CONFIG['maintenance_durations']`` an independent
    model is rebuilt from ``model_state`` / ``mask_state`` and taken through:

    1. Maintenance: treatment parameters and full MS protection are active
       while the model is trained for ``duration`` epochs at the maintenance
       learning rate.
    2. Discontinuation: antidepressant parameters are reset at once; the MS
       protection then decays for ``CONFIG['post_discontinuation_steps']``
       steps at the treatment-specific rate.
    3. Trigger test: accuracy under biased internal noise is measured; a
       value below ``CONFIG['manic_relapse_threshold']`` counts as relapse.

    Args:
        model_state: Post-treatment ``state_dict`` to start every run from.
        mask_state: Post-treatment pruning masks, keyed by parameter name.
        treatment_type: ``'ketamine'``, ``'ssri'`` or ``'neurosteroid'``;
            any other value raises ``KeyError``.
        treatment_params: Treatment summary; ``'final_gain'`` is used if
            present, otherwise the treatment's configured gain.
        train_loader: Data for the maintenance training.
        test_loader: Data for the trigger test.

    Returns:
        Dict of per-duration lists plus summary statistics (relapse count and
        probability in percent, mean/std of the biased accuracy, decay rate).
    """
    # Treatment-specific MS decay rate; unknown treatments fail here.
    ms_decay_rate = {
        'ketamine': CONFIG['ketamine_ms_decay_rate'],
        'ssri': CONFIG['ssri_ms_decay_rate'],
        'neurosteroid': CONFIG['neurosteroid_ms_decay_rate']
    }[treatment_type]

    tested_durations = []
    accuracies = []
    protections = []
    relapse_flags = []
    gains = []

    for duration in CONFIG['maintenance_durations']:
        # Fresh network loaded with the post-treatment weights and masks.
        model = StressAwareNetwork().to(DEVICE)
        model.load_state_dict(copy.deepcopy(model_state))

        mgr = PruningManager(model, train_loader)
        mgr.masks = {key: mask.clone() for key, mask in mask_state.items()}
        mgr.apply_masks()

        # Re-activate the antidepressant's modulation parameters.
        if treatment_type == 'neurosteroid':
            model.set_inhibition(CONFIG['neurosteroid_inhibition_strength'], CONFIG['neurosteroid_use_tanh'])
            fallback_gain = CONFIG['neurosteroid_gain']
        elif treatment_type == 'ssri':
            fallback_gain = CONFIG['ssri_max_gain']
        else:
            # Only 'ketamine' can reach this branch; other names already
            # failed the decay-rate lookup.
            fallback_gain = CONFIG['ketamine_gain']
        model.set_gain(treatment_params.get('final_gain', fallback_gain))

        # Full MS protection for the whole maintenance phase.
        model.set_ms_protection(CONFIG['ms_max_protection'], ms_decay_rate)

        # Maintenance: ultra-low learning rate, MS and antidepressant active.
        train(model, train_loader, epochs=duration, lr=CONFIG['maintenance_lr'], pruning_manager=mgr)

        # Discontinuation: antidepressant off immediately, MS fades stepwise.
        model.reset_antidepressant_effects()
        remaining_steps = CONFIG['post_discontinuation_steps']
        while remaining_steps > 0:
            model.decay_ms_protection()
            remaining_steps -= 1

        protection_now = model.ms_protection_level
        gain_now = model.get_effective_gain()

        # Manic trigger test under biased internal noise.
        biased_acc = evaluate_biased_stress(
            model, test_loader,
            internal_sigma=CONFIG['mania_test_sigma'],
            bias=CONFIG['mania_test_bias']
        )

        tested_durations.append(duration)
        accuracies.append(biased_acc)
        protections.append(protection_now)
        relapse_flags.append(biased_acc < CONFIG['manic_relapse_threshold'])
        gains.append(gain_now)

    results = {
        'maintenance_durations': tested_durations,
        'biased_stress_accuracies': accuracies,
        'ms_protection_at_test': protections,
        'relapsed': relapse_flags,
        'effective_gain_at_test': gains
    }

    # Summary statistics over all durations.
    results['relapse_count'] = sum(results['relapsed'])
    results['total_tests'] = len(results['relapsed'])
    results['relapse_probability'] = (results['relapse_count'] / results['total_tests']) * 100
    results['mean_biased_acc'] = np.mean(results['biased_stress_accuracies'])
    results['std_biased_acc'] = np.std(results['biased_stress_accuracies'])
    results['ms_decay_rate'] = ms_decay_rate

    return results


# ============================================================================
# SINGLE SEED EXPERIMENT
# ============================================================================
def run_single_seed(seed: int) -> Dict[str, Dict]:
    """Run the full experiment for one random seed.

    Trains and magnitude-prunes a baseline network, then, starting from that
    same pruned state, applies each of the three treatments to its own copy.
    Every arm is evaluated on clean data, noisy data, combined input noise
    and internal stress, the biased-noise test, and each configured stress
    level. Treated arms additionally get the longitudinal manic relapse
    simulation and an acute relapse test (further pruning of the treated
    model).

    Args:
        seed: Seeds torch and numpy and selects the datasets.

    Returns:
        Dict with the keys 'untreated', 'ketamine', 'ssri' and
        'neurosteroid', each mapping metric names to values. Treated arms
        carry the simulation output under 'longitudinal' and the accuracy
        loss after extra pruning under 'relapse_drop'.

    Note:
        The evaluations and training runs draw from the global torch RNG, so
        the results depend on the order of the statements below.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    train_loader, test_loader, clean_test_loader = create_data_loaders(seed)

    # Train and prune base model
    base_model = StressAwareNetwork().to(DEVICE)
    train(base_model, train_loader, epochs=CONFIG['baseline_epochs'], lr=CONFIG['baseline_lr'])
    base_pruning_mgr = PruningManager(base_model, train_loader)
    base_pruning_mgr.prune_by_magnitude(sparsity=CONFIG['prune_sparsity'])

    # Snapshot of the pruned baseline; every treatment arm starts from it.
    base_state_dict = {k: v.clone() for k, v in base_model.state_dict().items()}
    base_masks = {k: v.clone() for k, v in base_pruning_mgr.masks.items()}

    results = {}

    # Untreated
    # 'combined' = input noise std 1.0 together with internal stress 0.5.
    untreated_results = {
        'sparsity': base_pruning_mgr.get_sparsity() * 100,
        'gain': 1.0,
        'clean': evaluate(base_model, clean_test_loader),
        'standard': evaluate(base_model, test_loader),
        'combined': evaluate(base_model, test_loader, 1.0, 0.5),
        'biased_stress': evaluate_biased_stress(base_model, test_loader, internal_sigma=CONFIG['mania_test_sigma'], bias=CONFIG['mania_test_bias']),
        'activation_magnitude': get_avg_activation_magnitude(base_model, test_loader)
    }
    for stress_name, stress_level in CONFIG['extended_stress_levels'].items():
        untreated_results[f'stress_{stress_name}'] = evaluate(base_model, test_loader, 0.0, stress_level)
    results['untreated'] = untreated_results

    # ========================================================================
    # KETAMINE TREATMENT + LONGITUDINAL MANIC RELAPSE
    # ========================================================================
    # Fresh copy of the pruned baseline with its masks restored.
    ketamine_model = StressAwareNetwork().to(DEVICE)
    ketamine_model.load_state_dict(copy.deepcopy(base_state_dict))
    ketamine_mgr = PruningManager(ketamine_model, train_loader)
    ketamine_mgr.masks = {k: v.clone() for k, v in base_masks.items()}
    ketamine_mgr.apply_masks()
    ketamine_stats = ketamine_treatment(ketamine_model, ketamine_mgr, train_loader)

    # Save post-treatment state for longitudinal simulation
    ketamine_post_state = {k: v.clone() for k, v in ketamine_model.state_dict().items()}
    ketamine_post_masks = ketamine_mgr.clone_masks()

    ketamine_results = {
        'sparsity': ketamine_mgr.get_sparsity() * 100,
        'gain': ketamine_model.gain_multiplier,
        'clean': evaluate(ketamine_model, clean_test_loader),
        'standard': evaluate(ketamine_model, test_loader),
        'combined': evaluate(ketamine_model, test_loader, 1.0, 0.5),
        'biased_stress': evaluate_biased_stress(ketamine_model, test_loader, internal_sigma=CONFIG['mania_test_sigma'], bias=CONFIG['mania_test_bias']),
        'activation_magnitude': get_avg_activation_magnitude(ketamine_model, test_loader)
    }
    for stress_name, stress_level in CONFIG['extended_stress_levels'].items():
        ketamine_results[f'stress_{stress_name}'] = evaluate(ketamine_model, test_loader, 0.0, stress_level)

    # Longitudinal manic relapse simulation
    # Works on copies built from the saved state, so ketamine_model itself is
    # not modified by it.
    ketamine_longitudinal = simulate_longitudinal_manic_relapse(
        ketamine_post_state, ketamine_post_masks, 'ketamine', ketamine_stats,
        train_loader, test_loader
    )
    ketamine_results['longitudinal'] = ketamine_longitudinal

    # Acute relapse
    # Prune a further 40% of the currently unpruned share (capped at 0.99)
    # and report the loss in 'combined' accuracy.
    pre_relapse = ketamine_results['combined']
    target_sparsity = min(ketamine_mgr.get_sparsity() + (1 - ketamine_mgr.get_sparsity()) * 0.40, 0.99)
    ketamine_mgr.prune_by_magnitude(sparsity=target_sparsity)
    ketamine_mgr.apply_masks()
    ketamine_results['relapse_drop'] = pre_relapse - evaluate(ketamine_model, test_loader, 1.0, 0.5)
    results['ketamine'] = ketamine_results

    # ========================================================================
    # SSRI TREATMENT + LONGITUDINAL MANIC RELAPSE
    # ========================================================================
    # Fresh copy of the pruned baseline with its masks restored.
    ssri_model = StressAwareNetwork().to(DEVICE)
    ssri_model.load_state_dict(copy.deepcopy(base_state_dict))
    ssri_mgr = PruningManager(ssri_model, train_loader)
    ssri_mgr.masks = {k: v.clone() for k, v in base_masks.items()}
    ssri_mgr.apply_masks()
    ssri_stats = ssri_treatment(ssri_model, ssri_mgr, train_loader)

    # Save post-treatment state
    ssri_post_state = {k: v.clone() for k, v in ssri_model.state_dict().items()}
    ssri_post_masks = ssri_mgr.clone_masks()

    ssri_results = {
        'sparsity': ssri_mgr.get_sparsity() * 100,
        'gain': ssri_model.gain_multiplier,
        'clean': evaluate(ssri_model, clean_test_loader),
        'standard': evaluate(ssri_model, test_loader),
        'combined': evaluate(ssri_model, test_loader, 1.0, 0.5),
        'biased_stress': evaluate_biased_stress(ssri_model, test_loader, internal_sigma=CONFIG['mania_test_sigma'], bias=CONFIG['mania_test_bias']),
        'activation_magnitude': get_avg_activation_magnitude(ssri_model, test_loader)
    }
    for stress_name, stress_level in CONFIG['extended_stress_levels'].items():
        ssri_results[f'stress_{stress_name}'] = evaluate(ssri_model, test_loader, 0.0, stress_level)

    # Longitudinal manic relapse simulation
    ssri_longitudinal = simulate_longitudinal_manic_relapse(
        ssri_post_state, ssri_post_masks, 'ssri', ssri_stats,
        train_loader, test_loader
    )
    ssri_results['longitudinal'] = ssri_longitudinal

    # Acute relapse: same extra-pruning test as for ketamine.
    pre_relapse = ssri_results['combined']
    target_sparsity = min(ssri_mgr.get_sparsity() + (1 - ssri_mgr.get_sparsity()) * 0.40, 0.99)
    ssri_mgr.prune_by_magnitude(sparsity=target_sparsity)
    ssri_mgr.apply_masks()
    ssri_results['relapse_drop'] = pre_relapse - evaluate(ssri_model, test_loader, 1.0, 0.5)
    results['ssri'] = ssri_results

    # ========================================================================
    # NEUROSTEROID TREATMENT + LONGITUDINAL MANIC RELAPSE
    # ========================================================================
    # Fresh copy of the pruned baseline with its masks restored.
    neuro_model = StressAwareNetwork().to(DEVICE)
    neuro_model.load_state_dict(copy.deepcopy(base_state_dict))
    neuro_mgr = PruningManager(neuro_model, train_loader)
    neuro_mgr.masks = {k: v.clone() for k, v in base_masks.items()}
    neuro_mgr.apply_masks()
    neuro_stats = neurosteroid_treatment(neuro_model, neuro_mgr, train_loader)

    # Save post-treatment state
    neuro_post_state = {k: v.clone() for k, v in neuro_model.state_dict().items()}
    neuro_post_masks = neuro_mgr.clone_masks()

    neuro_results = {
        'sparsity': neuro_mgr.get_sparsity() * 100,
        'gain': neuro_model.gain_multiplier,
        'effective_scale': neuro_model.inhibition_strength * neuro_model.gain_multiplier,
        'clean': evaluate(neuro_model, clean_test_loader),
        'standard': evaluate(neuro_model, test_loader),
        'combined': evaluate(neuro_model, test_loader, 1.0, 0.5),
        'biased_stress': evaluate_biased_stress(neuro_model, test_loader, internal_sigma=CONFIG['mania_test_sigma'], bias=CONFIG['mania_test_bias']),
        'activation_magnitude': get_avg_activation_magnitude(neuro_model, test_loader)
    }
    for stress_name, stress_level in CONFIG['extended_stress_levels'].items():
        neuro_results[f'stress_{stress_name}'] = evaluate(neuro_model, test_loader, 0.0, stress_level)

    # Longitudinal manic relapse simulation
    neuro_longitudinal = simulate_longitudinal_manic_relapse(
        neuro_post_state, neuro_post_masks, 'neurosteroid', neuro_stats,
        train_loader, test_loader
    )
    neuro_results['longitudinal'] = neuro_longitudinal

    # Off medication
    # Same weights, but inhibition, activation choice and gain set back to
    # their neutral values.
    neuro_model.set_inhibition(1.0, False)
    neuro_model.set_gain(1.0)
    neuro_results['off_medication_combined'] = evaluate(neuro_model, test_loader, 1.0, 0.5)
    neuro_results['off_medication_extreme'] = evaluate(neuro_model, test_loader, 0.0, 2.5)
    neuro_results['off_medication_biased'] = evaluate_biased_stress(neuro_model, test_loader, internal_sigma=CONFIG['mania_test_sigma'], bias=CONFIG['mania_test_bias'])
    neuro_results['off_medication_activation'] = get_avg_activation_magnitude(neuro_model, test_loader)

    # Relapse with modulation
    # Medication parameters are switched back on before the extra pruning so
    # the drop is measured against the on-medication 'combined' accuracy.
    neuro_model.set_inhibition(CONFIG['neurosteroid_inhibition_strength'], CONFIG['neurosteroid_use_tanh'])
    neuro_model.set_gain(CONFIG['neurosteroid_gain'])
    pre_relapse = neuro_results['combined']
    target_sparsity = min(neuro_mgr.get_sparsity() + (1 - neuro_mgr.get_sparsity()) * 0.40, 0.99)
    neuro_mgr.prune_by_magnitude(sparsity=target_sparsity)
    neuro_mgr.apply_masks()
    neuro_results['relapse_drop'] = pre_relapse - evaluate(neuro_model, test_loader, 1.0, 0.5)
    results['neurosteroid'] = neuro_results

    return results


# ============================================================================
# AGGREGATE RESULTS
# ============================================================================
def _summary_stats(values) -> Dict:
    """Return the mean/std/min/max of ``values`` as a dict.

    For an empty sequence every statistic is NaN. ``np.min``/``np.max`` raise
    ``ValueError`` on empty input, and ``np.mean``/``np.std`` return NaN with a
    RuntimeWarning, so the empty case is handled explicitly.
    """
    if len(values) == 0:
        nan = float('nan')
        return {'mean': nan, 'std': nan, 'min': nan, 'max': nan}
    return {
        'mean': np.mean(values),
        'std': np.std(values),
        'min': np.min(values),
        'max': np.max(values),
    }


def aggregate_results(all_results: List[Dict[str, Dict]], config=None) -> Dict:
    """Combine per-seed results into cross-seed summary statistics.

    Args:
        all_results: One dict per seed, keyed by treatment name. Each treatment
            dict maps metric names to values; only numeric values are
            aggregated. A ``'longitudinal'`` entry, when present, must provide
            ``'relapse_probability'`` plus the parallel lists
            ``'maintenance_durations'``, ``'biased_stress_accuracies'``,
            ``'ms_protection_at_test'``, ``'relapsed'`` and
            ``'effective_gain_at_test'``.
        config: Optional mapping supplying ``'maintenance_durations'`` and the
            ``'<treatment>_ms_decay_rate'`` entries. Defaults to the
            module-level ``CONFIG``.

    Returns:
        A dict with two keys:

        * ``'treatments'``: treatment -> metric -> ``{'mean', 'std', 'min', 'max'}``.
        * ``'longitudinal'``: treatment -> relapse-probability statistics,
          ``'ms_decay_rate'`` and a ``'per_duration'`` dict keyed by maintenance
          duration. Relapse-probability statistics are NaN when no seed
          supplied longitudinal data for the treatment.
    """
    cfg = CONFIG if config is None else config
    treatments = ['untreated', 'ketamine', 'ssri', 'neurosteroid']
    aggregated = {'treatments': {}, 'longitudinal': {}}

    # NumPy scalars (e.g. np.float32, np.int64) are not instances of the
    # built-in int/float, so they are listed explicitly to avoid dropping them.
    numeric_types = (int, float, np.integer, np.floating)

    # Aggregate scalar treatment metrics; nested entries such as the
    # 'longitudinal' dict are skipped by the type check.
    for treatment in treatments:
        metrics = defaultdict(list)
        for seed_result in all_results:
            if treatment not in seed_result:
                continue
            for key, value in seed_result[treatment].items():
                if isinstance(value, numeric_types):
                    metrics[key].append(value)

        aggregated['treatments'][treatment] = {
            key: _summary_stats(values) for key, values in metrics.items()
        }

    # Aggregate longitudinal manic relapse data
    for treatment in ['ketamine', 'ssri', 'neurosteroid']:
        probs = []
        per_duration = defaultdict(
            lambda: {'biased_acc': [], 'ms_protection': [], 'relapsed': [], 'effective_gain': []}
        )

        for seed_result in all_results:
            if treatment not in seed_result or 'longitudinal' not in seed_result[treatment]:
                continue
            long = seed_result[treatment]['longitudinal']
            probs.append(long['relapse_probability'])

            # Indexed access (rather than zip) so a length mismatch between
            # the parallel lists raises instead of being silently truncated.
            for i, dur in enumerate(long['maintenance_durations']):
                bucket = per_duration[dur]
                bucket['biased_acc'].append(long['biased_stress_accuracies'][i])
                bucket['ms_protection'].append(long['ms_protection_at_test'][i])
                bucket['relapsed'].append(long['relapsed'][i])
                bucket['effective_gain'].append(long['effective_gain_at_test'][i])

        prob_stats = _summary_stats(probs)
        summary = {
            'relapse_probability_mean': prob_stats['mean'],
            'relapse_probability_std': prob_stats['std'],
            'relapse_probability_min': prob_stats['min'],
            'relapse_probability_max': prob_stats['max'],
            'ms_decay_rate': cfg[f'{treatment}_ms_decay_rate'],
            'per_duration': {}
        }
        aggregated['longitudinal'][treatment] = summary

        # Only configured durations are reported, and only when at least one
        # seed contributed data for them.
        for dur in cfg['maintenance_durations']:
            dur_data = per_duration.get(dur)
            if not dur_data or not dur_data['biased_acc']:
                continue
            n_relapsed = sum(dur_data['relapsed'])
            n_total = len(dur_data['relapsed'])
            summary['per_duration'][dur] = {
                'biased_acc_mean': np.mean(dur_data['biased_acc']),
                'biased_acc_std': np.std(dur_data['biased_acc']),
                'ms_protection_mean': np.mean(dur_data['ms_protection']),
                'ms_protection_std': np.std(dur_data['ms_protection']),
                'effective_gain_mean': np.mean(dur_data['effective_gain']),
                'relapse_rate': (n_relapsed / n_total) * 100,
                'n_relapsed': n_relapsed,
                'n_total': n_total
            }

    return aggregated


# ============================================================================
# MAIN EXPERIMENT
# ============================================================================
# Treatment arms reported in the relapse and longitudinal tables
# (every arm except 'untreated').
_ACTIVE_TREATMENTS = ('ketamine', 'ssri', 'neurosteroid')


def _print_banner(*lines: str, rule: str = "-") -> None:
    """Print a blank line, an 80-char rule, each line indented by two spaces, and a closing rule."""
    print("\n" + rule*80)
    for line in lines:
        print("  " + line)
    print(rule*80)


def _fmt_mean_std_pct(stat: Dict, mean_width: int = 6) -> str:
    """Format an aggregated metric dict as ``mean±std%`` with one decimal place."""
    return f"{stat['mean']:>{mean_width}.1f}±{stat['std']:>4.1f}%"


def _print_per_duration_table(longitudinal: Dict, labels: Dict[str, str], durations: List,
                              field: str, cell_spec: str, suffix: str = "") -> None:
    """Print one row per active treatment with ``field`` for each maintenance duration.

    Durations with no aggregated entry for a treatment are shown as ``N/A``.
    """
    header = f"  {'Treatment':<18}" + "".join(f" {dur:>8}" for dur in durations)
    print(f"\n{header} epochs")
    print("  " + "-"*(18 + 9*len(durations)))

    for t in _ACTIVE_TREATMENTS:
        per_duration = longitudinal[t]['per_duration']
        row = f"  {labels[t]:<18}"
        for dur in durations:
            if dur in per_duration:
                row += " " + format(per_duration[dur][field], cell_spec) + suffix
            else:
                row += f" {'N/A':>8}"
        print(row)


def _print_aggregated_report(aggregated: Dict) -> None:
    """Print every results table for the dict returned by ``aggregate_results``."""
    _print_banner("AGGREGATED RESULTS ACROSS {} SEEDS".format(CONFIG['n_seeds']), rule="=")

    treatments = ['untreated', 'ketamine', 'ssri', 'neurosteroid']
    labels = {
        'untreated': 'Untreated (pruned)',
        'ketamine': 'Ketamine-like',
        'ssri': 'SSRI-like',
        'neurosteroid': 'Neurosteroid-like'
    }
    per_treatment = aggregated['treatments']
    longitudinal = aggregated['longitudinal']

    # ------------------------------------------------------------------
    # TABLE 1: ANTIDEPRESSANT EFFICACY
    # ------------------------------------------------------------------
    _print_banner("TABLE 1: ANTIDEPRESSANT EFFICACY (Mean ± Std)")
    print(f"\n  {'Treatment':<22} {'Sparsity':>12} {'Clean':>14} {'Standard':>14} {'Combined':>14}")
    print("  " + "-"*78)

    for t in treatments:
        r = per_treatment[t]
        print(f"  {labels[t]:<22} "
              f"{_fmt_mean_std_pct(r['sparsity'], 5)} "
              f"{_fmt_mean_std_pct(r['clean'])} "
              f"{_fmt_mean_std_pct(r['standard'])} "
              f"{_fmt_mean_std_pct(r['combined'])}")

    # ------------------------------------------------------------------
    # TABLE 2: STRESS RESILIENCE PROFILE
    # ------------------------------------------------------------------
    _print_banner("TABLE 2: STRESS RESILIENCE PROFILE (Mean ± Std)")
    print(f"\n  {'Treatment':<22} {'None':>14} {'Moderate':>14} {'High':>14} {'Severe':>14} {'Extreme':>14}")
    print("  " + "-"*92)

    stress_keys = ('stress_none', 'stress_moderate', 'stress_high', 'stress_severe', 'stress_extreme')
    for t in treatments:
        r = per_treatment[t]
        cells = " ".join(_fmt_mean_std_pct(r[key]) for key in stress_keys)
        print(f"  {labels[t]:<22} {cells}")

    # ------------------------------------------------------------------
    # TABLE 3: MANIC CONVERSION RISK METRICS
    # ------------------------------------------------------------------
    _print_banner("TABLE 3: MANIC CONVERSION RISK METRICS (Mean ± Std)")
    print(f"\n  {'Treatment':<22} {'Gain':>10} {'Biased Stress':>18} {'Act. Magnitude':>18}")
    print("  " + "-"*70)

    for t in treatments:
        r = per_treatment[t]
        gain_str = f"{r['gain']['mean']:>5.2f}±{r['gain']['std']:>4.2f}"
        biased_str = _fmt_mean_std_pct(r['biased_stress'])
        act_str = f"{r['activation_magnitude']['mean']:>6.3f}±{r['activation_magnitude']['std']:>5.3f}"
        print(f"  {labels[t]:<22} {gain_str:>10} {biased_str:>18} {act_str:>18}")

    print("\n  Note: Lower Biased Stress Accuracy = Higher manic conversion risk")
    print("        Higher Activation Magnitude = Higher latent hyperexcitability")

    # ------------------------------------------------------------------
    # TABLE 4: ACUTE RELAPSE VULNERABILITY
    # ------------------------------------------------------------------
    _print_banner("TABLE 4: ACUTE RELAPSE VULNERABILITY (Mean ± Std)")
    print(f"\n  {'Treatment':<22} {'Relapse Drop':>18}")
    print("  " + "-"*42)

    for t in _ACTIVE_TREATMENTS:
        print(f"  {labels[t]:<22} {_fmt_mean_std_pct(per_treatment[t]['relapse_drop'])}")

    # ------------------------------------------------------------------
    # LONGITUDINAL MANIC RELAPSE RESULTS
    # ------------------------------------------------------------------
    _print_banner("LONGITUDINAL MANIC RELAPSE AFTER DISCONTINUATION", rule="=")

    decay_steps = CONFIG['post_discontinuation_steps']
    print("\n  Protocol: Subjects on MS + antidepressant for chronic maintenance,")
    print("            then all medications discontinued, MS protection decays,")
    print("            manic trigger test administered after decay period.")
    print(f"            Manic relapse threshold: biased stress accuracy < {CONFIG['manic_relapse_threshold']:.0f}%")
    print(f"            Post-discontinuation decay steps: {decay_steps}")

    # ------------------------------------------------------------------
    # TABLE 5: MS DECAY RATE PARAMETERS (ARCHITECTURE)
    # ------------------------------------------------------------------
    _print_banner("TABLE 5: MOOD STABILIZER DECAY RATE PARAMETERS (Architecture)")
    # The second column is rate * steps (the total over the whole decay
    # period), so it is labelled as a total rather than a per-step value.
    print(f"\n  {'Treatment':<22} {'MS Decay Rate':>16} {'Total Decay':>16}")
    print("  " + "-"*56)

    for t in _ACTIVE_TREATMENTS:
        decay_rate = CONFIG[f'{t}_ms_decay_rate']
        total_decay = decay_rate * decay_steps
        print(f"  {labels[t]:<22} {decay_rate:>15.4f} {total_decay:>15.2f} (over {decay_steps} steps)")

    # ------------------------------------------------------------------
    # TABLE 6: OVERALL MANIC RELAPSE PROBABILITY
    # ------------------------------------------------------------------
    _print_banner("TABLE 6: OVERALL MANIC RELAPSE PROBABILITY")
    print(f"\n  {'Treatment':<22} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
    print("  " + "-"*64)

    for t in _ACTIVE_TREATMENTS:
        r = longitudinal[t]
        cells = " ".join(
            f"{r[f'relapse_probability_{stat}']:>9.1f}%" for stat in ('mean', 'std', 'min', 'max')
        )
        print(f"  {labels[t]:<22} {cells}")

    # ------------------------------------------------------------------
    # TABLES 7-9: ONE VALUE PER TREATMENT AND MAINTENANCE DURATION
    # ------------------------------------------------------------------
    durations = CONFIG['maintenance_durations']

    _print_banner("TABLE 7: MANIC RELAPSE RATE BY MAINTENANCE DURATION")
    _print_per_duration_table(longitudinal, labels, durations, 'relapse_rate', '>7.0f', '%')

    _print_banner("TABLE 8: BIASED STRESS ACCURACY BY MAINTENANCE DURATION (Mean)")
    _print_per_duration_table(longitudinal, labels, durations, 'biased_acc_mean', '>7.1f', '%')

    _print_banner("TABLE 9: MS PROTECTION LEVEL AT TEST TIME (Mean)")
    _print_per_duration_table(longitudinal, labels, durations, 'ms_protection_mean', '>8.3f')

    # ------------------------------------------------------------------
    # TABLE 10: DETAILED LONGITUDINAL RESULTS PER DURATION
    # ------------------------------------------------------------------
    _print_banner("TABLE 10: DETAILED LONGITUDINAL RESULTS PER DURATION")

    for dur in durations:
        print(f"\n  Maintenance Duration: {dur} epochs")
        print(f"  {'Treatment':<18} {'Biased Acc':>14} {'MS Protection':>14} {'Eff. Gain':>12} {'Relapse Rate':>14}")
        print("  " + "-"*74)

        for t in _ACTIVE_TREATMENTS:
            if dur not in longitudinal[t]['per_duration']:
                continue
            d = longitudinal[t]['per_duration'][dur]
            print(f"  {labels[t]:<18} "
                  f"{d['biased_acc_mean']:>6.1f}±{d['biased_acc_std']:>4.1f}% "
                  f"{d['ms_protection_mean']:>6.3f}±{d['ms_protection_std']:>4.3f} "
                  f"{d['effective_gain_mean']:>11.3f} "
                  f"{d['relapse_rate']:>7.0f}% ({d['n_relapsed']}/{d['n_total']})")

    # ------------------------------------------------------------------
    # TABLE 11: NEUROSTEROID MEDICATION DEPENDENCE
    # ------------------------------------------------------------------
    _print_banner("TABLE 11: NEUROSTEROID MEDICATION DEPENDENCE (Mean ± Std)")

    neuro = per_treatment['neurosteroid']
    print(f"\n  {'Condition':<30} {'Combined':>18} {'Extreme':>18} {'Biased':>18}")
    print("  " + "-"*86)
    medication_rows = (
        ('On medication', ('combined', 'stress_extreme', 'biased_stress')),
        ('Off medication', ('off_medication_combined', 'off_medication_extreme', 'off_medication_biased')),
    )
    for condition, keys in medication_rows:
        cells = " ".join(_fmt_mean_std_pct(neuro[key]) for key in keys)
        print(f"  {condition:<30} {cells}")

    # ------------------------------------------------------------------
    # SUMMARY COMPARISON MATRIX
    # ------------------------------------------------------------------
    _print_banner("SUMMARY COMPARISON MATRIX", rule="=")

    print("""
  ┌─────────────────────┬────────────┬────────────┬──────────────┬─────────────┐
  │ Metric              │ Ketamine   │ SSRI       │ Neurosteroid │ Untreated   │
  ├─────────────────────┼────────────┼────────────┼──────────────┼─────────────┤""")

    ket = per_treatment['ketamine']
    ssri = per_treatment['ssri']
    untreated = per_treatment['untreated']

    print(f"  │ Combined Stress (%) │ {ket['combined']['mean']:>10.1f} │ {ssri['combined']['mean']:>10.1f} │ {neuro['combined']['mean']:>12.1f} │ {untreated['combined']['mean']:>11.1f} │")
    print(f"  │ Biased Stress (%)   │ {ket['biased_stress']['mean']:>10.1f} │ {ssri['biased_stress']['mean']:>10.1f} │ {neuro['biased_stress']['mean']:>12.1f} │ {untreated['biased_stress']['mean']:>11.1f} │")
    print(f"  │ Gain Multiplier     │ {ket['gain']['mean']:>10.2f} │ {ssri['gain']['mean']:>10.2f} │ {neuro['gain']['mean']:>12.2f} │ {untreated['gain']['mean']:>11.2f} │")
    print(f"  │ Activation Mag.     │ {ket['activation_magnitude']['mean']:>10.3f} │ {ssri['activation_magnitude']['mean']:>10.3f} │ {neuro['activation_magnitude']['mean']:>12.3f} │ {untreated['activation_magnitude']['mean']:>11.3f} │")
    print(f"  │ Acute Relapse Drop  │ {ket['relapse_drop']['mean']:>10.1f} │ {ssri['relapse_drop']['mean']:>10.1f} │ {neuro['relapse_drop']['mean']:>12.1f} │         N/A │")

    ket_long = longitudinal['ketamine']
    ssri_long = longitudinal['ssri']
    neuro_long = longitudinal['neurosteroid']

    print(f"  │ Manic Relapse Prob. │ {ket_long['relapse_probability_mean']:>9.1f}% │ {ssri_long['relapse_probability_mean']:>9.1f}% │ {neuro_long['relapse_probability_mean']:>11.1f}% │         N/A │")
    print(f"  │ MS Decay Rate       │ {CONFIG['ketamine_ms_decay_rate']:>10.4f} │ {CONFIG['ssri_ms_decay_rate']:>10.4f} │ {CONFIG['neurosteroid_ms_decay_rate']:>12.4f} │         N/A │")

    print("  └─────────────────────┴────────────┴────────────┴──────────────┴─────────────┘")

    # ------------------------------------------------------------------
    # LONGITUDINAL RELAPSE RISK RANKING
    # ------------------------------------------------------------------
    _print_banner("LONGITUDINAL MANIC RELAPSE RISK RANKING")

    # Ascending by mean relapse probability; sorted() is stable, so ties keep
    # the ketamine / SSRI / neurosteroid order.
    risk_data = sorted(
        (
            (labels[t], longitudinal[t]['relapse_probability_mean'], CONFIG[f'{t}_ms_decay_rate'])
            for t in _ACTIVE_TREATMENTS
        ),
        key=lambda entry: entry[1],
    )

    print(f"\n  {'Rank':<6} {'Treatment':<22} {'Relapse Prob':>14} {'MS Decay Rate':>14}")
    print("  " + "-"*58)

    for rank, (name, prob, decay) in enumerate(risk_data, start=1):
        print(f"  {rank:<6} {name:<22} {prob:>13.1f}% {decay:>14.4f}")

    print("\n  Note: Lower relapse probability = better long-term stability")
    print("        Lower MS decay rate = structural changes persist longer")

    _print_banner(
        "EXPERIMENT COMPLETE",
        "Results averaged across {} random seeds".format(CONFIG['n_seeds']),
        rule="=",
    )
    print()


def run_multi_seed_experiment() -> Dict:
    """Run the experiment for every configured seed, aggregate, and print the report.

    Calls ``run_single_seed(seed)`` for ``seed`` in ``range(CONFIG['n_seeds'])``,
    passes the collected per-seed results to ``aggregate_results`` and prints
    all result tables to stdout.

    Returns:
        ``{'all_results': <list of per-seed results>, 'aggregated': <aggregate dict>}``.
    """
    _print_banner(
        "MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT",
        "WITH MANIC CONVERSION RISK + LONGITUDINAL MANIC RELAPSE",
        "Running across {} random seeds".format(CONFIG['n_seeds']),
        rule="=",
    )

    all_results = []

    for seed in range(CONFIG['n_seeds']):
        # Progress line stays open (end=" ") until the seed finishes.
        print(f"\n  Seed {seed+1}/{CONFIG['n_seeds']}...", end=" ", flush=True)
        all_results.append(run_single_seed(seed))
        print("done")

    aggregated = aggregate_results(all_results)
    _print_aggregated_report(aggregated)

    return {'all_results': all_results, 'aggregated': aggregated}


# ============================================================================
# ENTRY POINT
# ============================================================================
if __name__ == "__main__":
    # Start-up banner: a blank line, an 80-column '#' rule, framed rows with
    # their text centred in the 78 inner columns (the empty rows give the
    # blank padding lines), and a closing rule.
    print(
        "\n" + "#"*80,
        *(
            "#" + text.center(78) + "#"
            for text in (
                "",
                " MULTI-MECHANISM ANTIDEPRESSANT COMPARISON ",
                " WITH MANIC CONVERSION RISK + LONGITUDINAL MANIC RELAPSE ",
                " Ketamine vs SSRI vs Neurosteroid ",
                f" ({CONFIG['n_seeds']} Random Seeds) ",
                "",
            )
        ),
        "#"*80,
        sep="\n",
    )

    # Keep the outcome in a module-level name so it can be inspected afterwards.
    results = run_multi_seed_experiment()


################################################################################
#                                                                              #
#                  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON                   #
#           WITH MANIC CONVERSION RISK + LONGITUDINAL MANIC RELAPSE            #
#                       Ketamine vs SSRI vs Neurosteroid                       #
#                              (10 Random Seeds)                               #
#                                                                              #
################################################################################

  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT
  WITH MANIC CONVERSION RISK + LONGITUDINAL MANIC RELAPSE
  Running across 10 random seeds

  Seed 1/10... done

  Seed 2/10... done

  Seed 3/10... done

  Seed 4/10... done

  Seed 5/10... done

  Seed 6/10... done

  Seed 7/10... done

  Seed 8/10... done

  Seed 9/10... done

  Seed 10/10..

# The End