# Depression Relapse Comparison

In [1]:
"""
================================================================================
MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT
WITH LONGITUDINAL RELAPSE RISK MODELING
================================================================================

Compares three antidepressant mechanisms across 10 random seeds:
1. KETAMINE-LIKE: Gradient-guided synaptogenesis
2. SSRI-LIKE: Gradual stabilization without structural changes
3. NEUROSTEROID-LIKE: Tonic inhibition enhancement

Includes longitudinal simulation of depressive relapse under chronic stress.
================================================================================
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from typing import Dict, Tuple, List
import warnings
from collections import defaultdict
import copy

# Suppress every UserWarning for the remainder of the run.
warnings.filterwarnings(action='ignore', category=UserWarning)

# Device that every model and batch is moved to via `.to(DEVICE)`.
DEVICE = torch.device('cpu')

# ============================================================================
# CONFIGURATION
# ============================================================================
CONFIG = dict(
    # --- experiment / data ---------------------------------------------------
    n_seeds=10,                          # number of random seeds to repeat the experiment over
    n_train=12000,                       # training samples per seed
    n_test=4000,                         # noisy test samples per seed
    n_clean_test=2000,                   # noise-free test samples per seed
    data_noise=0.8,                      # std of the blob noise for the train/test splits
    batch_size=128,                      # training batch size
    # --- network ---------------------------------------------------------------
    hidden_dims=[512, 512, 256],         # widths of fc1/fc2/fc3 outputs
    input_dim=2,
    output_dim=4,
    # --- baseline training, pruning and regrowth ---------------------------------
    baseline_epochs=20,
    baseline_lr=0.001,
    finetune_epochs=15,                  # NOTE(review): appears unused by the treatment code -- confirm
    finetune_lr=0.0005,                  # LR for ketamine/neurosteroid fine-tuning and maintenance
    prune_sparsity=0.95,                 # per-layer fraction of weights pruned before treatment
    regrow_fraction=0.5,                 # NOTE(review): treatments read comparison_ketamine_regrow instead
    regrow_init_scale=0.03,              # std of the random values given to regrown weights
    gradient_accumulation_batches=30,    # batches used to score pruned weights for regrowth
    # internal-noise std for the stress_* evaluations (iterated in insertion order)
    extended_stress_levels=dict(
        none=0.0,
        moderate=0.5,
        high=1.0,
        severe=1.5,
        extreme=2.5,
    ),
    # --- treatment arms ----------------------------------------------------------
    monoaminergic_epochs=100,            # NOTE(review): SSRI arm reads comparison_ssri_epochs instead
    monoaminergic_lr=1e-5,               # LR of the SSRI-like stress-schedule training
    monoaminergic_initial_stress=0.5,    # starting stress of the SSRI-like schedule (decays to 0)
    neurosteroid_inhibition_strength=0.7,  # gain multiplied into every hidden activation
    neurosteroid_use_tanh=True,          # switch hidden activations from ReLU to tanh
    neurosteroid_consolidation_epochs=10,
    comparison_ketamine_regrow=0.5,      # fraction of pruned weights regrown by the ketamine arm
    comparison_ketamine_epochs=15,
    comparison_ssri_epochs=100,
    comparison_neurosteroid_strength=0.7,  # NOTE(review): neurosteroid arm reads neurosteroid_* keys instead
    comparison_neurosteroid_epochs=10,     # NOTE(review): see comparison_neurosteroid_strength

    # Longitudinal relapse parameters
    longitudinal_cycles=8,               # chronic-stress cycles after treatment
    longitudinal_pruning_per_cycle=0.10,  # fraction of surviving weights pruned each cycle
    longitudinal_ketamine_maintenance_epochs=5,
    longitudinal_ssri_maintenance_epochs=20,
    longitudinal_ssri_maintenance_stress=0.3,
    longitudinal_neurosteroid_maintenance_epochs=10,
    relapse_threshold=80.0,              # accuracy (%) below which a trajectory counts as relapsed
)


# ============================================================================
# DATA GENERATION
# ============================================================================
def generate_blobs(n_samples: int = 10000, noise: float = 0.8, seed: int = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample a 4-class, 2-D Gaussian-blob classification dataset.

    Each label is drawn uniformly from {0, 1, 2, 3}. Its point is the matching
    corner of the square (+-3, +-3) plus isotropic Gaussian noise.

    Args:
        n_samples: Number of points to draw.
        noise: Standard deviation of the noise added to each class center.
        seed: Seed for a private ``np.random.RandomState``. ``None`` seeds it
            from fresh OS entropy, so the output is then not reproducible.

    Returns:
        ``(data, labels)``: a float32 tensor of shape ``(n_samples, 2)`` and an
        int64 tensor of shape ``(n_samples,)``.
    """
    generator = np.random.RandomState() if seed is None else np.random.RandomState(seed)
    corner_points = np.array([[-3, -3], [3, 3], [-3, 3], [3, -3]])
    # Labels are drawn before the noise; keep this order for reproducibility.
    class_ids = generator.randint(0, 4, n_samples)
    jitter = generator.randn(n_samples, 2) * noise
    points = corner_points[class_ids] + jitter
    data_tensor = torch.tensor(points, dtype=torch.float32)
    label_tensor = torch.tensor(class_ids, dtype=torch.long)
    return data_tensor, label_tensor


def create_data_loaders(seed: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train, noisy-test and clean-test loaders for one seed.

    Each split is drawn by ``generate_blobs`` with its own seed derived from
    ``seed`` (``seed * 1000`` plus 100 / 200 / 300). The splits therefore use
    distinct random streams and are reproducible per seed. The clean split uses
    zero noise, so every point sits exactly on its class center.

    Returns:
        ``(train_loader, test_loader, clean_test_loader)``. Only the training
        loader shuffles; both test loaders use batches of 1000.
    """
    offset = seed * 1000
    train_x, train_y = generate_blobs(CONFIG['n_train'], noise=CONFIG['data_noise'], seed=offset + 100)
    test_x, test_y = generate_blobs(CONFIG['n_test'], noise=CONFIG['data_noise'], seed=offset + 200)
    clean_x, clean_y = generate_blobs(CONFIG['n_clean_test'], noise=0.0, seed=offset + 300)

    def _loader(inputs, targets, **loader_kwargs):
        return DataLoader(TensorDataset(inputs, targets), **loader_kwargs)

    return (
        _loader(train_x, train_y, batch_size=CONFIG['batch_size'], shuffle=True),
        _loader(test_x, test_y, batch_size=1000),
        _loader(clean_x, clean_y, batch_size=1000),
    )


# ============================================================================
# NETWORK ARCHITECTURE
# ============================================================================
class StressAwareNetwork(nn.Module):
    """Four-layer MLP whose hidden activations can be perturbed and scaled.

    Two runtime knobs model the experimental conditions:

    * ``stress_level``: std of Gaussian noise added after each hidden
      activation. It is applied whenever it is > 0, in train and eval mode alike.
    * ``inhibition_strength`` / ``use_tanh``: a multiplicative gain applied to
      every hidden activation, and a switch from ReLU to tanh.

    These knobs are plain attributes, not parameters or buffers, so they are
    not part of ``state_dict()``. Callers must re-apply them after
    ``load_state_dict``.
    """

    def __init__(self, hidden_dims: List[int] = None):
        super().__init__()
        dims = CONFIG['hidden_dims'] if hidden_dims is None else hidden_dims
        widths = [CONFIG['input_dim'], dims[0], dims[1], dims[2], CONFIG['output_dim']]
        # Layers are created in fc1..fc4 order (this order consumes the init RNG).
        self.fc1 = nn.Linear(widths[0], widths[1])
        self.fc2 = nn.Linear(widths[1], widths[2])
        self.fc3 = nn.Linear(widths[2], widths[3])
        self.fc4 = nn.Linear(widths[3], widths[4])
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.stress_level = 0.0
        self.inhibition_strength = 1.0
        self.use_tanh = False
        self.weight_layers = ['fc1', 'fc2', 'fc3', 'fc4']

    def set_stress(self, level: float):
        """Set the std of the hidden-layer noise (0 disables it)."""
        self.stress_level = level

    def set_inhibition(self, strength: float, use_tanh: bool = False):
        """Set the hidden-activation gain and choose tanh (True) or ReLU (False)."""
        self.inhibition_strength = strength
        self.use_tanh = use_tanh

    def reduce_stress_gradually(self, epoch: int, total_epochs: int, initial_stress: float = 0.5, final_stress: float = 0.0):
        """Linearly interpolate the stress from ``initial_stress`` (epoch 0) to ``final_stress`` (last epoch)."""
        span = max(total_epochs - 1, 1)
        fraction = epoch / span
        self.stress_level = initial_stress + fraction * (final_stress - initial_stress)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP; each hidden layer is activate -> (optional noise) -> gain."""
        act_fn = self.tanh if self.use_tanh else self.relu
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = act_fn(layer(hidden))
            if self.stress_level > 0:
                hidden = hidden + torch.randn_like(hidden) * self.stress_level
            hidden = hidden * self.inhibition_strength
        return self.fc4(hidden)

    def count_parameters(self) -> Tuple[int, int]:
        """Return ``(total, nonzero)`` element counts over all parameters, biases included."""
        params = list(self.parameters())
        total_count = sum(p.numel() for p in params)
        nonzero_count = sum((p != 0).sum().item() for p in params)
        return total_count, nonzero_count


# ============================================================================
# PRUNING MANAGER
# ============================================================================
class PruningManager:
    """Maintain binary masks over a model's weight matrices and prune/regrow them.

    Only parameters whose name contains ``'weight'`` and that have at least two
    dimensions receive a mask, so biases are excluded. ``masks[name]`` is a
    float 0/1 tensor with the parameter's shape. ``gradient_buffer[name]``
    accumulates |grad| at pruned positions and is used to rank regrowth
    candidates.
    """

    def __init__(self, model: StressAwareNetwork, train_loader: DataLoader):
        self.model = model
        self.train_loader = train_loader
        self.masks = {}
        self.gradient_buffer = {}
        for name, param in model.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                self.masks[name] = torch.ones_like(param, dtype=torch.float32)
                self.gradient_buffer[name] = torch.zeros_like(param)

    def prune_by_magnitude(self, sparsity: float, per_layer: bool = True) -> Dict[str, Dict]:
        """Zero the smallest-magnitude weights until the target sparsity is reached.

        Args:
            sparsity: Target fraction of pruned weights, in [0, 1].
            per_layer: If True (default), each matrix is thresholded at the
                ``sparsity`` quantile of its own |w|. If False, a single
                threshold is computed over all masked weights jointly.

        Pruning is monotone: an already-masked weight stays masked, even when
        the threshold falls to 0 because a layer is already sparser than the
        target. Without this, ``|w| >= 0`` would re-enable every weight.

        Returns:
            Per-parameter ``{'kept', 'total', 'actual_sparsity'}``.
        """
        masked = [(name, param) for name, param in self.model.named_parameters() if name in self.masks]
        global_threshold = None
        if not per_layer and masked:
            all_magnitudes = torch.cat([param.data.abs().flatten() for _, param in masked])
            global_threshold = torch.quantile(all_magnitudes, sparsity)
        stats = {}
        for name, param in masked:
            weights = param.data.abs()
            if global_threshold is None:
                threshold = torch.quantile(weights.flatten(), sparsity)
            else:
                threshold = global_threshold
            # AND with the existing mask so previously pruned weights are never revived.
            new_mask = (weights >= threshold).float() * self.masks[name]
            self.masks[name] = new_mask
            param.data *= new_mask
            kept = new_mask.sum().item()
            total = new_mask.numel()
            stats[name] = {'kept': int(kept), 'total': total, 'actual_sparsity': 1 - kept / total}
        return stats

    def _accumulate_gradients(self, num_batches: int = 30):
        """Sum |grad| at pruned positions over up to ``num_batches`` training batches.

        The pass is noise-free: stress is set to 0 and restored afterwards,
        even on error. Gradients already stored on the parameters, such as
        those left by the last step of a previous training loop, are cleared
        first so they do not leak into the scores. No optimizer step is taken.
        The model is left in ``train()`` mode.
        """
        model = self.model
        loss_fn = nn.CrossEntropyLoss()
        for buffer in self.gradient_buffer.values():
            buffer.zero_()
        model.train()
        original_stress = model.stress_level
        model.set_stress(0.0)
        model.zero_grad()
        try:
            for batch_idx, (x, y) in enumerate(self.train_loader):
                if batch_idx >= num_batches:
                    break
                x, y = x.to(DEVICE), y.to(DEVICE)
                loss = loss_fn(model(x), y)
                loss.backward()
                with torch.no_grad():
                    for name, param in model.named_parameters():
                        if name in self.masks and param.grad is not None:
                            pruned_mask = (self.masks[name] == 0).float()
                            self.gradient_buffer[name] += param.grad.abs() * pruned_mask
                model.zero_grad()
        finally:
            model.set_stress(original_stress)

    def gradient_guided_regrow(self, regrow_fraction: float, init_scale: float = None) -> Dict[str, Dict]:
        """Re-enable the pruned weights with the largest accumulated |grad|.

        In each masked matrix, ``max(1, int(regrow_fraction * num_pruned))``
        pruned positions are regrown. A non-positive ``regrow_fraction``
        regrows nothing. Regrown weights are drawn from ``N(0, init_scale^2)``
        on the parameter's device and dtype.

        Args:
            regrow_fraction: Fraction of each layer's pruned weights to regrow.
            init_scale: Std of the regrown values; defaults to
                ``CONFIG['regrow_init_scale']``.

        Returns:
            Per-parameter ``{'regrown', 'still_pruned'}`` counts.
        """
        if init_scale is None:
            init_scale = CONFIG['regrow_init_scale']
        self._accumulate_gradients(num_batches=CONFIG['gradient_accumulation_batches'])
        stats = {}
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue
            mask = self.masks[name]
            pruned_positions = (mask == 0)
            num_pruned = int(pruned_positions.sum().item())
            if num_pruned == 0 or regrow_fraction <= 0:
                stats[name] = {'regrown': 0, 'still_pruned': num_pruned}
                continue
            # Boolean indexing and torch.where both enumerate positions in row-major order,
            # so top_indices maps directly into flat_pruned_indices.
            gradient_scores = self.gradient_buffer[name][pruned_positions]
            num_regrow = min(max(1, int(regrow_fraction * num_pruned)), gradient_scores.numel())
            _, top_indices = torch.topk(gradient_scores, num_regrow)
            flat_pruned_indices = torch.where(pruned_positions.flatten())[0]
            regrow_flat_indices = flat_pruned_indices[top_indices]
            flat_mask = mask.flatten()
            flat_param = param.data.flatten()
            flat_mask[regrow_flat_indices] = 1.0
            flat_param[regrow_flat_indices] = torch.randn(
                num_regrow, device=flat_param.device, dtype=flat_param.dtype) * init_scale
            self.masks[name] = flat_mask.view_as(mask)
            param.data = flat_param.view_as(param)
            stats[name] = {'regrown': num_regrow, 'still_pruned': num_pruned - num_regrow}
        return stats

    def apply_masks(self):
        """Multiply every masked parameter by its mask in place (pruned entries become 0)."""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in self.masks:
                    param.data *= self.masks[name]

    def get_sparsity(self) -> float:
        """Fraction of zero mask entries across all masked matrices (0.0 if there are none)."""
        total = sum(m.numel() for m in self.masks.values())
        zeros = sum((m == 0).sum().item() for m in self.masks.values())
        return zeros / total if total > 0 else 0.0

    def clone_masks(self) -> Dict[str, torch.Tensor]:
        """Return independent copies of the current masks."""
        return {k: v.clone() for k, v in self.masks.items()}

    def set_masks(self, masks: Dict[str, torch.Tensor]):
        """Replace the masks with copies of ``masks`` and apply them to the model."""
        self.masks = {k: v.clone() for k, v in masks.items()}
        self.apply_masks()


# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================
def _run_training_epoch(model: StressAwareNetwork, train_loader: DataLoader, optimizer: optim.Optimizer,
                        loss_fn: nn.Module, pruning_manager: PruningManager = None) -> float:
    """Run one optimization pass over ``train_loader``; return the mean batch loss.

    Masks are re-applied after every optimizer step when ``pruning_manager`` is
    given, so pruned weights stay exactly zero. Returns NaN for an empty loader.
    """
    model.train()
    running_loss = 0.0
    n_batches = 0
    for x, y in train_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        if pruning_manager is not None:
            pruning_manager.apply_masks()
        running_loss += loss.item()
        n_batches += 1
    return running_loss / n_batches if n_batches else float('nan')


def train(model: StressAwareNetwork, train_loader: DataLoader, epochs: int = 15, lr: float = 0.001,
          pruning_manager: PruningManager = None) -> List[float]:
    """Train ``model`` noise-free with a fresh Adam optimizer.

    Stress is forced to 0 during training. The caller's stress level is
    restored afterwards, even if training raises.

    Returns:
        Mean training loss per epoch.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    losses = []
    original_stress = model.stress_level
    model.set_stress(0.0)
    try:
        for _ in range(epochs):
            losses.append(_run_training_epoch(model, train_loader, optimizer, loss_fn, pruning_manager))
    finally:
        model.set_stress(original_stress)
    return losses


def train_with_stress_schedule(model: StressAwareNetwork, train_loader: DataLoader, epochs: int, lr: float,
                               initial_stress: float, final_stress: float = 0.0,
                               pruning_manager: PruningManager = None) -> List[float]:
    """Train with hidden-layer stress decaying linearly from ``initial_stress`` to ``final_stress``.

    Stress is updated once per epoch via ``model.reduce_stress_gradually``.
    Afterwards it is reset to 0.0, not restored, even if training raises.

    Returns:
        Mean training loss per epoch.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    losses = []
    try:
        for epoch in range(epochs):
            model.reduce_stress_gradually(epoch, epochs, initial_stress, final_stress)
            losses.append(_run_training_epoch(model, train_loader, optimizer, loss_fn, pruning_manager))
    finally:
        model.set_stress(0.0)
    return losses


# ============================================================================
# EVALUATION FUNCTIONS
# ============================================================================
def evaluate(model: StressAwareNetwork, loader: DataLoader, input_noise: float = 0.0, internal_stress: float = 0.0) -> float:
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``.

    Args:
        model: Network to score. It is switched to ``eval()`` mode and left there.
        loader: Yields ``(inputs, labels)`` batches.
        input_noise: Std of Gaussian noise added to each input batch (0 disables it).
        internal_stress: Hidden-layer noise std applied via ``model.set_stress``.

    Returns:
        ``100 * correct / total``.

    Raises:
        ValueError: If ``loader`` yields no samples.

    The model's stress level is reset to 0.0 afterwards, even if evaluation
    raises. It is not restored to its prior value.
    """
    model.eval()
    model.set_stress(internal_stress)
    correct, total = 0, 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(DEVICE), y.to(DEVICE)
                if input_noise > 0:
                    x = x + torch.randn_like(x) * input_noise
                correct += (model(x).argmax(dim=1) == y).sum().item()
                total += y.size(0)
    finally:
        model.set_stress(0.0)
    if total == 0:
        raise ValueError("evaluate() received a loader that yielded no samples")
    return 100.0 * correct / total


def evaluate_with_neurosteroid(model: StressAwareNetwork, loader: DataLoader, input_noise: float = 0.0, internal_stress: float = 0.0) -> float:
    """Accuracy (%) of a neurosteroid-modulated model; identical to ``evaluate``.

    The inhibition gain and tanh switch are attributes the model's forward pass
    reads directly (set via ``set_inhibition``), so no extra handling is needed.
    This wrapper is kept so existing call sites continue to work.
    """
    return evaluate(model, loader, input_noise, internal_stress)


# ============================================================================
# TREATMENT PROTOCOLS
# ============================================================================
def ketamine_treatment(model: StressAwareNetwork, pruning_mgr: PruningManager, train_loader: DataLoader) -> Dict:
    """Ketamine-like arm: gradient-guided regrowth of pruned weights, then masked fine-tuning.

    Regrows ``CONFIG['comparison_ketamine_regrow']`` of each layer's pruned
    weights, then trains for ``CONFIG['comparison_ketamine_epochs']`` epochs at
    ``CONFIG['finetune_lr']``, keeping the updated masks enforced.

    Returns:
        ``{'final_sparsity': fraction of masked weights still zero}``.
    """
    regrow_fraction = CONFIG['comparison_ketamine_regrow']
    pruning_mgr.gradient_guided_regrow(regrow_fraction=regrow_fraction)
    train(
        model,
        train_loader,
        epochs=CONFIG['comparison_ketamine_epochs'],
        lr=CONFIG['finetune_lr'],
        pruning_manager=pruning_mgr,
    )
    final_sparsity = pruning_mgr.get_sparsity()
    return {'final_sparsity': final_sparsity}


def ssri_treatment(model: StressAwareNetwork, pruning_mgr: PruningManager, train_loader: DataLoader) -> Dict:
    """SSRI-like arm: slow, low-LR training under a linearly decaying stress schedule.

    There is no structural change. The masks are only re-applied after each
    step, so the pruned set and its sparsity stay as they were.

    Returns:
        ``{'final_sparsity': fraction of masked weights still zero}``.
    """
    schedule = dict(
        epochs=CONFIG['comparison_ssri_epochs'],
        lr=CONFIG['monoaminergic_lr'],
        initial_stress=CONFIG['monoaminergic_initial_stress'],
        final_stress=0.0,
        pruning_manager=pruning_mgr,
    )
    train_with_stress_schedule(model, train_loader, **schedule)
    return {'final_sparsity': pruning_mgr.get_sparsity()}


def neurosteroid_treatment(model: StressAwareNetwork, pruning_mgr: PruningManager, train_loader: DataLoader) -> Dict:
    """Neurosteroid-like arm: scale hidden activations, switch to tanh, then consolidate.

    Sets the model's inhibition gain and activation, then trains under the
    existing masks for ``CONFIG['neurosteroid_consolidation_epochs']`` epochs.

    Returns:
        ``{'final_sparsity', 'inhibition_strength', 'use_tanh'}``, read back
        from the manager and the model after training.
    """
    # NOTE(review): CONFIG also defines 'comparison_neurosteroid_strength' and
    # 'comparison_neurosteroid_epochs' (same values today). The ketamine/SSRI
    # arms read their 'comparison_*' keys but this arm does not -- confirm intent.
    gain = CONFIG['neurosteroid_inhibition_strength']
    tanh_enabled = CONFIG['neurosteroid_use_tanh']
    model.set_inhibition(gain, tanh_enabled)
    train(
        model,
        train_loader,
        epochs=CONFIG['neurosteroid_consolidation_epochs'],
        lr=CONFIG['finetune_lr'],
        pruning_manager=pruning_mgr,
    )
    summary = {'final_sparsity': pruning_mgr.get_sparsity()}
    summary['inhibition_strength'] = model.inhibition_strength
    summary['use_tanh'] = model.use_tanh
    return summary


# ============================================================================
# LONGITUDINAL RELAPSE SIMULATION
# ============================================================================
def run_longitudinal_simulation(post_states: Dict, post_masks: Dict, post_inhib: Dict,
                                train_loader: DataLoader, test_loader: DataLoader,
                                input_noise: float = 1.0, internal_stress: float = 0.5) -> Dict:
    """
    Simulate longitudinal course with chronic stress and maintenance treatment.

    Each arm ('ketamine', 'ssri', 'neurosteroid') is rebuilt from its saved
    post-treatment weights and masks. The neurosteroid arm also gets its saved
    ``(strength, use_tanh)`` inhibition settings. A cycle-0 baseline is
    recorded first. Then, for each of ``CONFIG['longitudinal_cycles']`` cycles,
    every arm is:

    1. pruned by ``longitudinal_pruning_per_cycle`` of its surviving weights,
    2. given its arm-specific maintenance training,
    3. evaluated under ``input_noise`` / ``internal_stress``. The defaults are
       the "combined" condition used elsewhere.

    Arms are processed in the fixed order above. The torch RNG stream (weight
    init, evaluation noise) depends on this order.

    Returns:
        ``{arm: {'accuracy': [...], 'sparsity': [...]}}`` with
        ``num_cycles + 1`` entries each; sparsity is in percent.
    """
    num_cycles = CONFIG['longitudinal_cycles']
    pruning_per_cycle = CONFIG['longitudinal_pruning_per_cycle']
    treatments = ['ketamine', 'ssri', 'neurosteroid']

    trajectories = {treat: {'accuracy': [], 'sparsity': []} for treat in treatments}

    def _record(treat, model, mgr):
        # Sparsity first, then accuracy (same order as each trajectory entry pair).
        trajectories[treat]['sparsity'].append(mgr.get_sparsity() * 100)
        evaluator = evaluate_with_neurosteroid if treat == 'neurosteroid' else evaluate
        trajectories[treat]['accuracy'].append(evaluator(model, test_loader, input_noise, internal_stress))

    long_models = {}
    for treat in treatments:
        model = StressAwareNetwork().to(DEVICE)
        model.load_state_dict(copy.deepcopy(post_states[treat]))
        mgr = PruningManager(model, train_loader)
        mgr.set_masks(post_masks[treat])
        if treat == 'neurosteroid':
            # Inhibition settings are not part of state_dict, so re-apply them.
            strength, use_tanh = post_inhib['neurosteroid']
            model.set_inhibition(strength, use_tanh)
        long_models[treat] = (model, mgr)
        _record(treat, model, mgr)  # cycle-0 post-treatment baseline

    for _ in range(num_cycles):
        # Ongoing stress: prune a fixed fraction of the surviving weights.
        # prune_by_magnitude already zeroes the pruned parameters.
        for model, mgr in long_models.values():
            current_sp = mgr.get_sparsity()
            mgr.prune_by_magnitude(sparsity=current_sp + (1 - current_sp) * pruning_per_cycle, per_layer=True)

        # Chronic maintenance (treatment-specific).
        for treat, (model, mgr) in long_models.items():
            if treat == 'ketamine':
                # NOTE(review): maintenance is plain masked fine-tuning; no regrowth is
                # re-applied per cycle -- confirm this is the intended dosing model.
                train(model, train_loader, epochs=CONFIG['longitudinal_ketamine_maintenance_epochs'],
                      lr=CONFIG['finetune_lr'], pruning_manager=mgr)
            elif treat == 'ssri':
                train_with_stress_schedule(model, train_loader,
                                           epochs=CONFIG['longitudinal_ssri_maintenance_epochs'],
                                           lr=CONFIG['monoaminergic_lr'],
                                           initial_stress=CONFIG['longitudinal_ssri_maintenance_stress'],
                                           final_stress=0.0, pruning_manager=mgr)
            elif treat == 'neurosteroid':
                train(model, train_loader, epochs=CONFIG['longitudinal_neurosteroid_maintenance_epochs'],
                      lr=CONFIG['finetune_lr'], pruning_manager=mgr)

        # Evaluate after maintenance.
        for treat, (model, mgr) in long_models.items():
            _record(treat, model, mgr)

    return trajectories


# ============================================================================
# SINGLE SEED EXPERIMENT
# ============================================================================
def _clone_pruned_model(base_state_dict: Dict, base_masks: Dict[str, torch.Tensor],
                        train_loader: DataLoader) -> Tuple[StressAwareNetwork, PruningManager]:
    """Build a fresh network loaded with the pruned base weights, plus its mask manager."""
    model = StressAwareNetwork().to(DEVICE)
    model.load_state_dict(copy.deepcopy(base_state_dict))
    mgr = PruningManager(model, train_loader)
    mgr.set_masks(base_masks)
    return model, mgr


def _evaluate_profile(model: StressAwareNetwork, mgr: PruningManager, test_loader: DataLoader,
                      clean_test_loader: DataLoader, evaluator) -> Dict[str, float]:
    """Collect sparsity (%) plus clean / standard / combined / per-stress-level accuracies."""
    profile = {
        'sparsity': mgr.get_sparsity() * 100,
        'clean': evaluator(model, clean_test_loader),
        'standard': evaluator(model, test_loader),
        # "combined" condition: input noise 1.0 plus internal stress 0.5.
        'combined': evaluator(model, test_loader, 1.0, 0.5),
    }
    for stress_name, stress_level in CONFIG['extended_stress_levels'].items():
        profile[f'stress_{stress_name}'] = evaluator(model, test_loader, 0.0, stress_level)
    return profile


def _acute_relapse_drop(model: StressAwareNetwork, mgr: PruningManager, test_loader: DataLoader,
                        pre_relapse: float, evaluator, extra_fraction: float = 0.40,
                        max_sparsity: float = 0.99) -> float:
    """Prune ``extra_fraction`` of the surviving weights (capped at ``max_sparsity``).

    Returns the resulting drop in combined-condition accuracy relative to
    ``pre_relapse``.
    """
    current = mgr.get_sparsity()
    target_sparsity = min(current + (1 - current) * extra_fraction, max_sparsity)
    mgr.prune_by_magnitude(sparsity=target_sparsity)
    mgr.apply_masks()
    return pre_relapse - evaluator(model, test_loader, 1.0, 0.5)


def run_single_seed(seed: int) -> Dict:
    """Run the full comparison (untreated + three arms + longitudinal) for one seed.

    Steps:
    1. Train a dense base network and prune it to ``CONFIG['prune_sparsity']``.
    2. Evaluate it as the untreated arm.
    3. For each treatment arm, start from a copy of the pruned base, treat it,
       evaluate it, and save its post-treatment state.
    4. Measure the accuracy drop after an acute extra 40% pruning.
       The neurosteroid arm is additionally evaluated "off medication"
       (default inhibition and ReLU) before this step.
    5. Run the longitudinal simulation from the saved post-treatment states.

    The order of operations is fixed because the torch RNG stream (weight init,
    shuffling, evaluation noise) depends on it.

    Returns:
        ``{'untreated', 'ketamine', 'ssri', 'neurosteroid': metric dicts,
        'longitudinal': trajectories}``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    train_loader, test_loader, clean_test_loader = create_data_loaders(seed)

    # Train and prune base model
    base_model = StressAwareNetwork().to(DEVICE)
    train(base_model, train_loader, epochs=CONFIG['baseline_epochs'], lr=CONFIG['baseline_lr'])
    base_pruning_mgr = PruningManager(base_model, train_loader)
    base_pruning_mgr.prune_by_magnitude(sparsity=CONFIG['prune_sparsity'])

    base_state_dict = {k: v.clone() for k, v in base_model.state_dict().items()}
    base_masks = base_pruning_mgr.clone_masks()

    results = {}
    post_states = {}
    post_masks = {}
    post_inhib = {}

    results['untreated'] = _evaluate_profile(base_model, base_pruning_mgr, test_loader,
                                             clean_test_loader, evaluate)

    arms = (
        ('ketamine', ketamine_treatment, evaluate),
        ('ssri', ssri_treatment, evaluate),
        ('neurosteroid', neurosteroid_treatment, evaluate_with_neurosteroid),
    )
    for arm, treatment_fn, evaluator in arms:
        model, mgr = _clone_pruned_model(base_state_dict, base_masks, train_loader)
        treatment_fn(model, mgr, train_loader)
        arm_results = _evaluate_profile(model, mgr, test_loader, clean_test_loader, evaluator)

        # Save post-treatment state for the longitudinal simulation (before acute relapse).
        post_states[arm] = copy.deepcopy(model.state_dict())
        post_masks[arm] = mgr.clone_masks()

        if arm == 'neurosteroid':
            post_inhib[arm] = (model.inhibition_strength, model.use_tanh)
            # Off-medication evaluation: default gain and ReLU.
            model.set_inhibition(1.0, False)
            arm_results['off_medication_combined'] = evaluate(model, test_loader, 1.0, 0.5)
            arm_results['off_medication_extreme'] = evaluate(
                model, test_loader, 0.0, CONFIG['extended_stress_levels']['extreme'])
            # Restore exactly the modulation the treatment applied.
            model.set_inhibition(*post_inhib[arm])

        arm_results['relapse_drop'] = _acute_relapse_drop(model, mgr, test_loader,
                                                          arm_results['combined'], evaluator)
        results[arm] = arm_results

    results['longitudinal'] = run_longitudinal_simulation(post_states, post_masks, post_inhib,
                                                          train_loader, test_loader)
    return results


# ============================================================================
# AGGREGATE RESULTS
# ============================================================================
def aggregate_results(all_results: List[Dict], relapse_threshold: float = None) -> Dict:
    """Summarize per-seed results into mean/std/min/max statistics.

    Args:
        all_results: One dict per seed, as returned by ``run_single_seed``.
        relapse_threshold: Accuracy (%) below which a longitudinal trajectory
            counts as relapsed. Defaults to ``CONFIG['relapse_threshold']``.

    Returns:
        ``{'treatments': {arm: {metric: {'mean', 'std', 'min', 'max'}}},
        'longitudinal': {arm: {trajectory and relapse statistics}}}``.
        Only int/float metrics are aggregated, and standard deviations are
        population (ddof=0). ``relapse_cycles`` holds, per seed, the first
        index below the threshold, or ``len(trajectory)`` if it never drops.
        Arms with no longitudinal data in any seed are omitted from
        ``'longitudinal'``.
    """
    if relapse_threshold is None:
        relapse_threshold = CONFIG['relapse_threshold']

    aggregated = {'treatments': {}, 'longitudinal': {}}

    # Aggregate treatment metrics
    for treatment in ['untreated', 'ketamine', 'ssri', 'neurosteroid']:
        metrics = defaultdict(list)
        for seed_result in all_results:
            for key, value in seed_result.get(treatment, {}).items():
                if isinstance(value, (int, float)):
                    metrics[key].append(value)
        aggregated['treatments'][treatment] = {
            key: {
                'mean': np.mean(values),
                'std': np.std(values),
                'min': np.min(values),
                'max': np.max(values),
            }
            for key, values in metrics.items()
        }

    # Aggregate longitudinal trajectories
    for treat in ['ketamine', 'ssri', 'neurosteroid']:
        runs = [seed_result['longitudinal'][treat] for seed_result in all_results
                if treat in seed_result.get('longitudinal', {})]
        if not runs:
            # Nothing to aggregate; indexing the empty arrays below would fail.
            continue

        acc_trajectories = [run['accuracy'] for run in runs]
        acc_array = np.array(acc_trajectories)
        sp_array = np.array([run['sparsity'] for run in runs])

        final_accs = acc_array[:, -1]
        drops = acc_array[:, 0] - final_accs

        # First cycle below threshold; len(traj) is the "never relapsed" sentinel.
        relapse_cycles = [
            next((c for c, acc in enumerate(traj) if acc < relapse_threshold), len(traj))
            for traj in acc_trajectories
        ]
        relapsed = [c for c, traj in zip(relapse_cycles, acc_trajectories) if c < len(traj)]

        aggregated['longitudinal'][treat] = {
            'accuracy_mean': np.mean(acc_array, axis=0).tolist(),
            'accuracy_std': np.std(acc_array, axis=0).tolist(),
            'sparsity_mean': np.mean(sp_array, axis=0).tolist(),
            'sparsity_std': np.std(sp_array, axis=0).tolist(),
            'accuracy_min': np.min(acc_array, axis=0).tolist(),
            'accuracy_max': np.max(acc_array, axis=0).tolist(),
            'total_drop_mean': np.mean(drops),
            'total_drop_std': np.std(drops),
            'final_accuracy_mean': np.mean(final_accs),
            'final_accuracy_std': np.std(final_accs),
            'relapse_cycles': relapse_cycles,
            'mean_relapse_cycle': np.mean(relapsed) if relapsed else None,
            'n_relapsed': len(relapsed),
            'n_no_relapse': len(relapse_cycles) - len(relapsed),
        }

    return aggregated


# ============================================================================
# PRINT RESULTS
# ============================================================================
def print_results(aggregated: Dict):
    treatments = ['untreated', 'ketamine', 'ssri', 'neurosteroid']
    labels = {
        'untreated': 'Untreated (pruned)',
        'ketamine': 'Ketamine-like',
        'ssri': 'SSRI-like',
        'neurosteroid': 'Neurosteroid-like'
    }

    print("\n" + "="*80)
    print("  AGGREGATED RESULTS ACROSS {} SEEDS".format(CONFIG['n_seeds']))
    print("="*80)

    # ========================================================================
    # TABLE 1: ANTIDEPRESSANT EFFICACY
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 1: POST-TREATMENT EFFICACY (Mean ± Std)")
    print("-"*80)
    print(f"\n  {'Treatment':<22} {'Sparsity':>12} {'Clean':>14} {'Standard':>14} {'Combined':>14}")
    print("  " + "-"*78)

    for t in treatments:
        r = aggregated['treatments'][t]
        print(f"  {labels[t]:<22} "
              f"{r['sparsity']['mean']:>5.1f}±{r['sparsity']['std']:>4.1f}% "
              f"{r['clean']['mean']:>6.1f}±{r['clean']['std']:>4.1f}% "
              f"{r['standard']['mean']:>6.1f}±{r['standard']['std']:>4.1f}% "
              f"{r['combined']['mean']:>6.1f}±{r['combined']['std']:>4.1f}%")

    # ========================================================================
    # TABLE 2: STRESS RESILIENCE PROFILE
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 2: STRESS RESILIENCE PROFILE (Mean ± Std)")
    print("-"*80)
    print(f"\n  {'Treatment':<22} {'None':>14} {'Moderate':>14} {'High':>14} {'Severe':>14} {'Extreme':>14}")
    print("  " + "-"*92)

    for t in treatments:
        r = aggregated['treatments'][t]
        print(f"  {labels[t]:<22} "
              f"{r['stress_none']['mean']:>6.1f}±{r['stress_none']['std']:>4.1f}% "
              f"{r['stress_moderate']['mean']:>6.1f}±{r['stress_moderate']['std']:>4.1f}% "
              f"{r['stress_high']['mean']:>6.1f}±{r['stress_high']['std']:>4.1f}% "
              f"{r['stress_severe']['mean']:>6.1f}±{r['stress_severe']['std']:>4.1f}% "
              f"{r['stress_extreme']['mean']:>6.1f}±{r['stress_extreme']['std']:>4.1f}%")

    # ========================================================================
    # TABLE 3: ACUTE RELAPSE VULNERABILITY
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 3: ACUTE RELAPSE VULNERABILITY (40% additional pruning)")
    print("-"*80)
    print(f"\n  {'Treatment':<22} {'Relapse Drop':>18}")
    print("  " + "-"*42)

    for t in ['ketamine', 'ssri', 'neurosteroid']:
        r = aggregated['treatments'][t]
        print(f"  {labels[t]:<22} {r['relapse_drop']['mean']:>6.1f}±{r['relapse_drop']['std']:>4.1f}%")

    # ========================================================================
    # TABLE 4: NEUROSTEROID MEDICATION DEPENDENCE
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 4: NEUROSTEROID MEDICATION DEPENDENCE (Mean ± Std)")
    print("-"*80)

    neuro = aggregated['treatments']['neurosteroid']
    print(f"\n  {'Condition':<30} {'Combined':>18} {'Extreme':>18}")
    print("  " + "-"*68)
    print(f"  {'On medication':<30} "
          f"{neuro['combined']['mean']:>6.1f}±{neuro['combined']['std']:>4.1f}% "
          f"{neuro['stress_extreme']['mean']:>6.1f}±{neuro['stress_extreme']['std']:>4.1f}%")
    print(f"  {'Off medication':<30} "
          f"{neuro['off_medication_combined']['mean']:>6.1f}±{neuro['off_medication_combined']['std']:>4.1f}% "
          f"{neuro['off_medication_extreme']['mean']:>6.1f}±{neuro['off_medication_extreme']['std']:>4.1f}%")

    # ========================================================================
    # TABLE 5: LONGITUDINAL TRAJECTORY (ACCURACY BY CYCLE)
    # ========================================================================
    print("\n" + "="*80)
    print("  LONGITUDINAL RELAPSE RISK UNDER CHRONIC TREATMENT")
    print("="*80)

    print("\n" + "-"*80)
    print("  TABLE 5: COMBINED-STRESS ACCURACY (%) BY CYCLE (Mean ± Std)")
    print("-"*80)

    num_cycles = CONFIG['longitudinal_cycles']
    header = f"  {'Cycle':<6}"
    for t in ['ketamine', 'ssri', 'neurosteroid']:
        header += f" {t.capitalize():>18}"
    print(header)
    print("  " + "-"*62)

    for c in range(num_cycles + 1):
        row = f"  {c:<6}"
        for t in ['ketamine', 'ssri', 'neurosteroid']:
            mean = aggregated['longitudinal'][t]['accuracy_mean'][c]
            std = aggregated['longitudinal'][t]['accuracy_std'][c]
            row += f" {mean:>7.1f}±{std:>5.1f}%"
        print(row)

    # ========================================================================
    # TABLE 6: LONGITUDINAL SPARSITY BY CYCLE
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 6: NETWORK SPARSITY (%) BY CYCLE (Mean ± Std)")
    print("-"*80)

    header = f"  {'Cycle':<6}"
    for t in ['ketamine', 'ssri', 'neurosteroid']:
        header += f" {t.capitalize():>18}"
    print(header)
    print("  " + "-"*62)

    for c in range(num_cycles + 1):
        row = f"  {c:<6}"
        for t in ['ketamine', 'ssri', 'neurosteroid']:
            mean = aggregated['longitudinal'][t]['sparsity_mean'][c]
            std = aggregated['longitudinal'][t]['sparsity_std'][c]
            row += f" {mean:>7.1f}±{std:>5.1f}%"
        print(row)

    # ========================================================================
    # TABLE 7: LONGITUDINAL ACCURACY RANGE BY CYCLE
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 7: ACCURACY RANGE BY CYCLE (Min - Max across seeds)")
    print("-"*80)

    header = f"  {'Cycle':<6}"
    for t in ['ketamine', 'ssri', 'neurosteroid']:
        header += f" {t.capitalize():>22}"
    print(header)
    print("  " + "-"*76)

    for c in range(num_cycles + 1):
        row = f"  {c:<6}"
        for t in ['ketamine', 'ssri', 'neurosteroid']:
            min_val = aggregated['longitudinal'][t]['accuracy_min'][c]
            max_val = aggregated['longitudinal'][t]['accuracy_max'][c]
            row += f" {min_val:>7.1f} - {max_val:>7.1f}%"
        print(row)

    # ========================================================================
    # TABLE 8: RELAPSE SUMMARY STATISTICS
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 8: LONGITUDINAL RELAPSE SUMMARY")
    print("-"*80)

    print(f"\n  {'Metric':<40} {'Ketamine':>14} {'SSRI':>14} {'Neurosteroid':>14}")
    print("  " + "-"*84)

    # Total drop
    row = f"  {'Total accuracy drop (cycle 0 to 8)':<40}"
    for t in ['ketamine', 'ssri', 'neurosteroid']:
        mean = aggregated['longitudinal'][t]['total_drop_mean']
        std = aggregated['longitudinal'][t]['total_drop_std']
        row += f" {mean:>5.1f}±{std:>4.1f}%"
    print(row)

    # Final accuracy
    row = f"  {'Final accuracy (cycle 8)':<40}"
    for t in ['ketamine', 'ssri', 'neurosteroid']:
        mean = aggregated['longitudinal'][t]['final_accuracy_mean']
        std = aggregated['longitudinal'][t]['final_accuracy_std']
        row += f" {mean:>5.1f}±{std:>4.1f}%"
    print(row)

    # Number relapsed
    row = f"  {'Seeds with relapse (<{:.0f}%)'.format(CONFIG['relapse_threshold']):<40}"
    for t in ['ketamine', 'ssri', 'neurosteroid']:
        n_relapsed = aggregated['longitudinal'][t]['n_relapsed']
        n_total = CONFIG['n_seeds']
        row += f" {n_relapsed:>5}/{n_total:<8}"
    print(row)

    # Number no relapse
    row = f"  {'Seeds without relapse (≥{:.0f}%)'.format(CONFIG['relapse_threshold']):<40}"
    for t in ['ketamine', 'ssri', 'neurosteroid']:
        n_no = aggregated['longitudinal'][t]['n_no_relapse']
        n_total = CONFIG['n_seeds']
        row += f" {n_no:>5}/{n_total:<8}"
    print(row)

    # Mean relapse cycle
    row = f"  {'Mean cycle at relapse (if relapsed)':<40}"
    for t in ['ketamine', 'ssri', 'neurosteroid']:
        mean_cycle = aggregated['longitudinal'][t]['mean_relapse_cycle']
        if mean_cycle is not None:
            row += f" {mean_cycle:>13.1f}"
        else:
            row += f" {'N/A':>13}"
    print(row)

    # ========================================================================
    # TABLE 9: RELAPSE CYCLE DISTRIBUTION
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 9: RELAPSE CYCLE DISTRIBUTION (cycle when accuracy < {:.0f}%)".format(CONFIG['relapse_threshold']))
    print("-"*80)

    print(f"\n  {'Seed':<8}")
    for t in ['ketamine', 'ssri', 'neurosteroid']:
        print(f" {t.capitalize():>14}", end="")
    print()
    print("  " + "-"*50)

    for seed_idx in range(CONFIG['n_seeds']):
        row = f"  {seed_idx:<8}"
        for t in ['ketamine', 'ssri', 'neurosteroid']:
            cycle = aggregated['longitudinal'][t]['relapse_cycles'][seed_idx]
            if cycle >= num_cycles + 1:
                row += f" {'No relapse':>14}"
            else:
                row += f" {cycle:>14}"
        print(row)

    # ========================================================================
    # TABLE 10: IMPROVEMENT FROM UNTREATED STATE
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 10: IMPROVEMENT FROM UNTREATED STATE")
    print("-"*80)

    untreated_combined = aggregated['treatments']['untreated']['combined']['mean']
    print(f"\n  {'Treatment':<22} {'Untreated':>12} {'Treated':>12} {'Improvement':>14}")
    print("  " + "-"*62)

    for t in ['ketamine', 'ssri', 'neurosteroid']:
        treated = aggregated['treatments'][t]['combined']['mean']
        improvement = treated - untreated_combined
        print(f"  {labels[t]:<22} {untreated_combined:>11.1f}% {treated:>11.1f}% {improvement:>+13.1f}%")

    # ========================================================================
    # TABLE 11: FINAL COMPARISON MATRIX
    # ========================================================================
    print("\n" + "="*80)
    print("  FINAL COMPARISON MATRIX")
    print("="*80)

    print("""
  ┌─────────────────────────────────┬────────────┬────────────┬──────────────┬─────────────┐
  │ Metric                          │ Ketamine   │ SSRI       │ Neurosteroid │ Untreated   │
  ├─────────────────────────────────┼────────────┼────────────┼──────────────┼─────────────┤""")

    ket = aggregated['treatments']['ketamine']
    ssri = aggregated['treatments']['ssri']
    neuro = aggregated['treatments']['neurosteroid']
    untreated = aggregated['treatments']['untreated']

    print(f"  │ Post-treatment combined (%)     │ {ket['combined']['mean']:>10.1f} │ {ssri['combined']['mean']:>10.1f} │ {neuro['combined']['mean']:>12.1f} │ {untreated['combined']['mean']:>11.1f} │")
    print(f"  │ Extreme stress (%)              │ {ket['stress_extreme']['mean']:>10.1f} │ {ssri['stress_extreme']['mean']:>10.1f} │ {neuro['stress_extreme']['mean']:>12.1f} │ {untreated['stress_extreme']['mean']:>11.1f} │")
    print(f"  │ Acute relapse drop (%)          │ {ket['relapse_drop']['mean']:>10.1f} │ {ssri['relapse_drop']['mean']:>10.1f} │ {neuro['relapse_drop']['mean']:>12.1f} │         N/A │")
    print(f"  │ Sparsity (%)                    │ {ket['sparsity']['mean']:>10.1f} │ {ssri['sparsity']['mean']:>10.1f} │ {neuro['sparsity']['mean']:>12.1f} │ {untreated['sparsity']['mean']:>11.1f} │")

    print("  ├─────────────────────────────────┼────────────┼────────────┼──────────────┼─────────────┤")
    print("  │ LONGITUDINAL (8 cycles)         │            │            │              │             │")
    print("  ├─────────────────────────────────┼────────────┼────────────┼──────────────┼─────────────┤")

    ket_long = aggregated['longitudinal']['ketamine']
    ssri_long = aggregated['longitudinal']['ssri']
    neuro_long = aggregated['longitudinal']['neurosteroid']

    print(f"  │ Final accuracy (cycle 8) (%)   │ {ket_long['final_accuracy_mean']:>10.1f} │ {ssri_long['final_accuracy_mean']:>10.1f} │ {neuro_long['final_accuracy_mean']:>12.1f} │         N/A │")
    print(f"  │ Total drop over 8 cycles (%)   │ {ket_long['total_drop_mean']:>10.1f} │ {ssri_long['total_drop_mean']:>10.1f} │ {neuro_long['total_drop_mean']:>12.1f} │         N/A │")
    print(f"  │ Seeds relapsed (< 80%)         │ {ket_long['n_relapsed']:>10} │ {ssri_long['n_relapsed']:>10} │ {neuro_long['n_relapsed']:>12} │         N/A │")
    print(f"  │ Seeds no relapse (≥ 80%)       │ {ket_long['n_no_relapse']:>10} │ {ssri_long['n_no_relapse']:>10} │ {neuro_long['n_no_relapse']:>12} │         N/A │")

    print("  └─────────────────────────────────┴────────────┴────────────┴──────────────┴─────────────┘")

    # ========================================================================
    # TABLE 12: LONGITUDINAL CONFIGURATION
    # ========================================================================
    print("\n" + "-"*80)
    print("  TABLE 12: LONGITUDINAL SIMULATION CONFIGURATION")
    print("-"*80)

    print(f"\n  {'Parameter':<45} {'Value':>20}")
    print("  " + "-"*67)
    print(f"  {'Number of cycles':<45} {CONFIG['longitudinal_cycles']:>20}")
    print(f"  {'Pruning per cycle (% of remaining)':<45} {CONFIG['longitudinal_pruning_per_cycle']*100:>19.0f}%")
    print(f"  {'Relapse threshold':<45} {CONFIG['relapse_threshold']:>19.0f}%")
    print(f"  {'Ketamine maintenance epochs/cycle':<45} {CONFIG['longitudinal_ketamine_maintenance_epochs']:>20}")
    print(f"  {'SSRI maintenance epochs/cycle':<45} {CONFIG['longitudinal_ssri_maintenance_epochs']:>20}")
    print(f"  {'SSRI maintenance initial stress':<45} {CONFIG['longitudinal_ssri_maintenance_stress']:>20}")
    print(f"  {'Neurosteroid maintenance epochs/cycle':<45} {CONFIG['longitudinal_neurosteroid_maintenance_epochs']:>20}")


# ============================================================================
# MAIN EXPERIMENT
# ============================================================================
def run_multi_seed_experiment() -> Dict:
    """Run every configured seed, aggregate the results, and print the summary tables.

    Seeds ``0 .. CONFIG['n_seeds'] - 1`` are passed to ``run_single_seed`` in
    order, with a one-line progress message per seed.

    Returns:
        dict with ``'all_results'`` (the per-seed results from
        ``run_single_seed``, in seed order) and ``'aggregated'`` (the output
        of ``aggregate_results`` on those results).
    """
    n_seeds = CONFIG['n_seeds']
    double_rule = "=" * 80

    print("\n" + double_rule)
    for heading in ("  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT",
                    "  WITH LONGITUDINAL RELAPSE RISK MODELING",
                    "  Running across {} random seeds".format(n_seeds)):
        print(heading)
    print(double_rule)

    per_seed = []
    for seed in range(n_seeds):
        # Progress message stays on one line until "done" is printed.
        print(f"\n  Seed {seed + 1}/{n_seeds}...", end=" ", flush=True)
        per_seed.append(run_single_seed(seed))
        print("done")

    summary = aggregate_results(per_seed)
    print_results(summary)

    print("\n" + double_rule)
    print("  EXPERIMENT COMPLETE")
    print("  Results averaged across {} random seeds".format(n_seeds))
    print(double_rule + "\n")

    return {'all_results': per_seed, 'aggregated': summary}


# ============================================================================
# ENTRY POINT
# ============================================================================
if __name__ == "__main__":
    # Framed banner: 80-char '#' rules around centred 78-char title lines.
    _banner_rule = "#" * 80
    _banner_blank = "#" + " " * 78 + "#"

    print("\n" + _banner_rule)
    print(_banner_blank)
    for _banner_title in (" MULTI-MECHANISM ANTIDEPRESSANT COMPARISON ",
                          " WITH LONGITUDINAL RELAPSE RISK MODELING ",
                          " Ketamine vs SSRI vs Neurosteroid "):
        print("#" + _banner_title.center(78) + "#")
    _banner_title = (f" ({CONFIG['n_seeds']} Random Seeds, "
                     f"{CONFIG['longitudinal_cycles']} Longitudinal Cycles) ")
    print("#" + _banner_title.center(78) + "#")
    print(_banner_blank)
    print(_banner_rule)

    results = run_multi_seed_experiment()


################################################################################
#                                                                              #
#                  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON                   #
#                   WITH LONGITUDINAL RELAPSE RISK MODELING                    #
#                       Ketamine vs SSRI vs Neurosteroid                       #
#                   (10 Random Seeds, 8 Longitudinal Cycles)                   #
#                                                                              #
################################################################################

  MULTI-MECHANISM ANTIDEPRESSANT COMPARISON EXPERIMENT
  WITH LONGITUDINAL RELAPSE RISK MODELING
  Running across 10 random seeds

  Seed 1/10... done

  Seed 2/10... done

  Seed 3/10... done

  Seed 4/10... done

  Seed 5/10... done

  Seed 6/10... done

  Seed 7/10... done

  Seed 8/10... done

  Seed 9/10... done

  Seed 10/10... done

  AGGREG

# The End