# Modes of action (OCD) with Isoacute Comparison

In [None]:
!pip install --upgrade sympy
# Then restart runtime (Runtime -> Restart runtime)



In [None]:
#!/usr/bin/env python3
"""
================================================================================
COMPUTATIONAL MODEL VALIDATING THE OCD SYNAPTIC PRUNING HYPOTHESIS
WITH MULTI-MECHANISM ANTIDEPRESSANT COMPARISON
================================================================================

This model extends the OCD framework to compare three distinct treatment mechanisms:
- KETAMINE: Rapid structural repair via synaptogenesis
- SSRI: Gradual functional stabilization with fixed structure
- NEUROSTEROID: Rapid functional damping via tonic inhibition

All mechanisms are modeled through network architecture modifications:
- Weight masks / regrowth (structural changes)
- Multiplicative scaling of hidden states (inhibition)
- Activation functions (bounded vs unbounded)
- Internal noise injection (stress/adaptation)
- Training dynamics (fast vs slow, fixed vs dynamic weights)

THEORETICAL FRAMEWORK:
----------------------
1. KETAMINE-LIKE: Rapid structural synaptogenesis + brief consolidation
2. SSRI-LIKE: Functional stabilization via extended low-LR training + noise reduction
3. NEUROSTEROID-LIKE: Tonic GABAergic inhibition (medication-dependent)

KEY COMPARISONS:
----------------
- Acute effects: Immediate post-treatment symptom reduction
- Long-term/relapse risk: Resistance to secondary pruning, off-medication reversal

NEW IN VERSION 3.1 - ISO-DOSE COMPARISON:
-----------------------------------------
- Quantifiable dosing metrics for fair cross-mechanism comparison
- L1/L2 weight change norms as mechanism-agnostic "dose" proxy
- Synaptic turnover measurement
- Parameter sweeps to match dose across treatments
- Efficiency analysis: outcome per unit dose

Author: Computational Psychiatry Research
Date: January 2026
Version: 3.1 (Iso-Dose Fair Comparison Pipeline)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from typing import List, Tuple, Optional, Dict, Any, Union
from dataclasses import dataclass, field
from enum import Enum
import copy
import warnings
from collections import defaultdict
import time

# NOTE(review): with no category/module arguments this filter hides every
# warning raised anywhere in the process, not just a specific noisy one;
# consider narrowing it if a particular warning is the real target.
warnings.filterwarnings('ignore')

# =============================================================================
# CONFIGURATION
# =============================================================================

# Central experiment configuration. Functions and classes in this file read
# their defaults from this dict at call time (not at import time), so editing
# a value here before calling them changes their behaviour.
CONFIG = {
    # -------------------------------------------------------------------------
    # Architecture Parameters
    # -------------------------------------------------------------------------
    'input_dim': 2,
    # hidden_dims[0] is the width of the input projection, hidden_dims[1] the
    # GRU hidden size (see CSTCNetwork.__init__).
    'hidden_dims': [128, 64],
    # Number of output logits; apply_rule produces labels in 0..3.
    'output_dim': 4,
    'num_gru_layers': 2,

    # -------------------------------------------------------------------------
    # Training Parameters
    # -------------------------------------------------------------------------
    'batch_size': 32,
    'baseline_lr': 1e-3,
    'finetune_lr': 5e-4,
    'baseline_epochs': 50,
    'regrowth_epochs': 30,

    # -------------------------------------------------------------------------
    # Sequence/Task Parameters
    # -------------------------------------------------------------------------
    'seq_len': 200,
    'n_train_sequences': 500,
    'n_test_sequences': 100,
    # Test sequences switch rule every `switch_interval` steps; training
    # sequences instead draw max(1, seq_len // switch_interval) random switch
    # points (see generate_rule_switch_sequences).
    'switch_interval': 50,
    # apply_rule only knows rule ids 0..3 and raises ValueError for others.
    'n_rules': 4,

    # -------------------------------------------------------------------------
    # Pruning Parameters
    # -------------------------------------------------------------------------
    'target_sparsities': [0.0, 0.5, 0.7, 0.85, 0.90, 0.93, 0.95, 0.97],
    'regrowth_fraction': 0.50,
    # Regrown connections restart at (original weight * this scale).
    'regrowth_init_scale': 0.03,
    # GRU weight magnitudes are divided by this before the global pruning
    # threshold is applied, so values > 1 make recurrent weights easier to prune.
    'recurrence_bias': 1.2,

    # -------------------------------------------------------------------------
    # Treatment Duration Experiment Parameters
    # -------------------------------------------------------------------------
    'consolidation_epochs': [0, 5, 10, 15, 20],
    'relapse_prune_fraction': 0.40,

    # -------------------------------------------------------------------------
    # Iterative Regimen Experiment Parameters
    # -------------------------------------------------------------------------
    'acute_regrow_fractions': [0.60, 1.00],
    'chronic_cycles': [3, 6, 10],
    'per_cycle_regrow': 0.40,
    'per_cycle_epochs': 5,
    'final_consolidation': 15,

    # -------------------------------------------------------------------------
    # Stress/Noise Parameters
    # -------------------------------------------------------------------------
    'stress_levels': [0.0, 0.1, 0.3],
    'glutamate_noise_levels': [0.0, 0.2, 0.5],
    'relapse_noise': 0.2,

    # =========================================================================
    # MULTI-MECHANISM ANTIDEPRESSANT COMPARISON
    # =========================================================================
    'ocd_prune_sparsity': 0.95,
    'comparison_ketamine_regrow': 0.60,
    'comparison_ketamine_epochs': 10,
    'comparison_ssri_epochs': 120,
    'comparison_ssri_lr': 1e-5,
    'comparison_ssri_initial_stress': 0.4,
    'comparison_neurosteroid_strength': 0.65,
    'comparison_neurosteroid_use_tanh': True,
    'comparison_neurosteroid_epochs': 8,

    # =========================================================================
    # ISO-DOSE COMPARISON PARAMETERS (NEW)
    # =========================================================================
    'iso_dose_norm_type': 'l1',
    'iso_dose_target_doses': [0.005, 0.010, 0.020, 0.040],
    'iso_dose_tolerance': 0.002,
    # Default relative-change cutoff used by compute_all_dose_metrics when it
    # measures synaptic turnover.
    'iso_dose_turnover_threshold': 0.10,

    # Parameter sweep ranges
    'ketamine_regrow_sweep': [0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80],
    'ssri_epochs_sweep': [20, 40, 60, 80, 100, 120, 160, 200],
    'ssri_lr_sweep': [1e-6, 5e-6, 1e-5, 2e-5, 5e-5],
    'neurosteroid_strength_sweep': [0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85],

    # -------------------------------------------------------------------------
    # Reproducibility
    # -------------------------------------------------------------------------
    'seed': 42,
    'n_seeds': 3,
}

# Global compute device: CUDA when torch reports it as available, otherwise CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def set_seed(seed: int):
    """Seed the torch and NumPy random generators with the same value.

    When CUDA is available, all CUDA generators are seeded too and cuDNN is
    put into deterministic mode with its benchmark autotuner disabled.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def print_section_header(title: str, width: int = 80, char: str = "="):
    """Print a banner: blank line, rule, the title centered in `width`, rule."""
    rule = char * width
    print(f"\n{rule}", f"{title.center(width)}", f"{rule}", sep="\n")


def print_subsection_header(title: str, width: int = 60, char: str = "-"):
    """Print a lighter banner: blank line, rule, title indented two spaces, rule."""
    rule = char * width
    for line in (f"\n{rule}", f"  {title}", f"{rule}"):
        print(line)


# =============================================================================
# DOSING METRICS - MECHANISM-AGNOSTIC QUANTIFICATION
# =============================================================================

def compute_weight_change_norm(
    model_pre_state: Dict[str, torch.Tensor],
    model_post: nn.Module,
    norm_type: str = 'l1'
) -> float:
    """
    Compute total weight change magnitude as mechanism-agnostic dose proxy.

    Only parameters whose name contains ``'weight'`` and that also appear in
    ``model_pre_state`` are compared.

    Args:
        model_pre_state: Name -> tensor snapshot taken before the treatment.
        model_post: The model after the treatment.
        norm_type: ``'l1'`` (sum of absolute differences) or ``'l2'``
            (square root of the summed squared differences).

    Returns:
        The chosen norm divided by the number of compared weight elements,
        or 0.0 when no parameter was compared.

    Raises:
        ValueError: If ``norm_type`` is neither ``'l1'`` nor ``'l2'``.
            Previously an unknown value silently produced 0.0, which is
            indistinguishable from "no change".
    """
    if norm_type not in ('l1', 'l2'):
        raise ValueError(
            f"Unsupported norm_type: {norm_type!r} (expected 'l1' or 'l2')"
        )

    accumulated = 0.0
    total_params = 0

    for name, param in model_post.named_parameters():
        if 'weight' not in name or name not in model_pre_state:
            continue

        # The snapshot may live on a different device than the live model
        # (e.g. a CPU copy of a CUDA model); align it before subtracting.
        before = model_pre_state[name].to(param.device)
        diff = param.data - before
        total_params += param.numel()

        if norm_type == 'l1':
            accumulated += diff.abs().sum().item()
        else:
            accumulated += diff.pow(2).sum().item()

    if total_params == 0:
        return 0.0

    if norm_type == 'l2':
        # The square root is taken once over the global sum of squares.
        accumulated = accumulated ** 0.5

    return accumulated / total_params


def compute_synaptic_turnover(
    model_pre_state: Dict[str, torch.Tensor],
    model_post: nn.Module,
    threshold: float = 0.10
) -> float:
    """
    Fraction of weights whose relative change exceeds ``threshold``.

    Relative change is ``|post - pre| / max(|pre|, 1e-8)``, evaluated for
    every parameter whose name contains ``'weight'`` and that is present in
    ``model_pre_state``. Returns 0.0 when nothing was compared.
    """
    n_changed = 0
    n_total = 0

    for name, param in model_post.named_parameters():
        if 'weight' not in name or name not in model_pre_state:
            continue

        before = model_pre_state[name]
        # The floor on the denominator keeps previously-zero weights finite.
        scale = before.abs().clamp(min=1e-8)
        relative = (param.data - before).abs() / scale

        n_changed += (relative > threshold).sum().item()
        n_total += param.numel()

    if n_total > 0:
        return n_changed / n_total
    return 0.0


def compute_sparsity_change(
    model_pre_state: Dict[str, torch.Tensor],
    model_post: nn.Module
) -> float:
    """Absolute difference in zero-weight fraction between snapshot and model.

    A weight counts as zero when its magnitude is below 1e-8. Only entries
    whose name contains ``'weight'`` are considered on either side.
    """
    def zero_fraction(named_tensors):
        n_zero = 0
        n_all = 0
        for key, values in named_tensors:
            if 'weight' not in key:
                continue
            n_all += values.numel()
            n_zero += (values.abs() < 1e-8).sum().item()
        if n_all > 0:
            return n_zero / n_all
        return 0.0

    before = zero_fraction(model_pre_state.items())
    after = zero_fraction(
        (name, param.data) for name, param in model_post.named_parameters()
    )
    return abs(after - before)


@dataclass
class DoseMetrics:
    """Bundle of the dose measurements taken for a single treatment."""
    # Weight-change magnitudes, as returned by compute_weight_change_norm.
    l1_norm: float = 0.0
    l2_norm: float = 0.0
    # Fraction of weights whose relative change exceeded the turnover threshold.
    synaptic_turnover: float = 0.0
    # Absolute difference in zero-weight fraction before vs. after.
    sparsity_change: float = 0.0

    @property
    def primary_dose(self) -> float:
        """Headline dose value; currently the L1 measurement."""
        return self.l1_norm


def compute_all_dose_metrics(
    model_pre_state: Dict[str, torch.Tensor],
    model_post: nn.Module,
    turnover_threshold: float = None
) -> DoseMetrics:
    """Run every dose measurement and collect the results in a DoseMetrics.

    ``turnover_threshold`` falls back to
    ``CONFIG['iso_dose_turnover_threshold']`` when it is None.
    """
    threshold = (
        CONFIG['iso_dose_turnover_threshold']
        if turnover_threshold is None
        else turnover_threshold
    )

    l1 = compute_weight_change_norm(model_pre_state, model_post, 'l1')
    l2 = compute_weight_change_norm(model_pre_state, model_post, 'l2')
    turnover = compute_synaptic_turnover(model_pre_state, model_post, threshold)
    sparsity_delta = compute_sparsity_change(model_pre_state, model_post)

    return DoseMetrics(
        l1_norm=l1,
        l2_norm=l2,
        synaptic_turnover=turnover,
        sparsity_change=sparsity_delta
    )


# =============================================================================
# RULE DEFINITIONS FOR COGNITIVE FLEXIBILITY TASK
# =============================================================================

class Rule(Enum):
    """Labeling rules for the rule-switching task, analogous to Wisconsin
    Card Sorting Test dimensions.

    The integer values are the rule ids that ``apply_rule`` dispatches on.
    """
    X_SIGN = 0      # label = 2*[x >= 0] + [y >= 0]
    Y_SIGN = 1      # label = 2*[y >= 0] + [x >= 0]
    QUADRANT = 2    # label = [x >= 0] + 2*[y >= 0]
    DIAGONAL = 3    # label = 2*[y >= x] + [y >= -x]


def apply_rule(points: torch.Tensor, rule: int) -> torch.Tensor:
    """Map 2D points to a class label in 0..3 under the given rule id.

    The last axis of ``points`` holds (x, y); the result has the shape of
    ``points`` without that axis and dtype long.

    NOTE(review): the Y_SIGN and QUADRANT branches compute the same
    expression, ``2*[y >= 0] + [x >= 0]``, so a switch between those two
    rules does not change any label. Confirm whether QUADRANT was meant to
    use a different encoding.

    Raises:
        ValueError: If ``rule`` is not one of the Rule values.
    """
    x = points[..., 0]
    y = points[..., 1]

    if rule == Rule.X_SIGN.value:
        return (x >= 0).long() * 2 + (y >= 0).long()

    if rule == Rule.Y_SIGN.value:
        return (y >= 0).long() * 2 + (x >= 0).long()

    if rule == Rule.QUADRANT.value:
        return (x >= 0).long() + (y >= 0).long() * 2

    if rule == Rule.DIAGONAL.value:
        above_main = (y >= x).long()
        above_anti = (y >= -x).long()
        return above_main * 2 + above_anti

    raise ValueError(f"Unknown rule: {rule}")


# =============================================================================
# DATA GENERATION
# =============================================================================

def generate_base_points(n_points: int, noise: float = 0.8) -> torch.Tensor:
    """Sample ``n_points`` 2D points from four Gaussian blobs, one per quadrant.

    Each point picks one of the centers (+-1.5, +-1.5) uniformly at random
    and adds standard-normal jitter scaled by ``noise``.
    """
    quadrant_centers = torch.tensor(
        [[1.5, 1.5], [-1.5, 1.5], [-1.5, -1.5], [1.5, -1.5]],
        dtype=torch.float32,
    )

    # Draw order (cluster choice first, jitter second) is kept so seeded
    # runs produce the same points.
    picks = torch.randint(0, 4, (n_points,))
    jitter = torch.randn(n_points, 2) * noise

    return quadrant_centers[picks] + jitter


def generate_rule_switch_sequences(
    n_sequences: int,
    seq_len: int,
    switch_interval: int,
    noise: float = 0.8,
    deterministic_switches: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate sequences with rule switches for cognitive flexibility testing.

    For every sequence a starting rule is drawn uniformly from
    ``CONFIG['n_rules']`` rules. At each switch point the rule advances by a
    random non-zero offset (mod ``n_rules``), so it always changes.

    Args:
        n_sequences: Number of sequences to generate.
        seq_len: Number of time steps per sequence.
        switch_interval: With ``deterministic_switches`` the rule changes at
            every multiple of this value; otherwise
            ``max(1, seq_len // switch_interval)`` switch points are drawn
            without replacement from ``[20, seq_len - 20)``.
        noise: Jitter scale forwarded to ``generate_base_points``.
        deterministic_switches: Selects the switch-point scheme (see above).

    Returns:
        ``(data, labels, rules)`` stacked over sequences: ``data`` is
        ``(n_sequences, seq_len, 2)`` float, ``labels`` and ``rules`` are
        ``(n_sequences, seq_len)`` long.
    """
    n_rules = CONFIG['n_rules']
    all_data = []
    all_labels = []
    all_rules = []

    for _ in range(n_sequences):
        points = generate_base_points(seq_len, noise=noise)
        current_rule = torch.randint(0, n_rules, (1,)).item()

        if deterministic_switches:
            switch_points = set(range(switch_interval, seq_len, switch_interval))
        else:
            n_switches = max(1, seq_len // switch_interval)
            valid_range = list(range(20, seq_len - 20))
            if len(valid_range) >= n_switches:
                switch_points = set(np.random.choice(
                    valid_range, size=n_switches, replace=False
                ))
            else:
                # Sequence too short to place every switch: use all valid steps.
                switch_points = set(valid_range)

        # Build the rule track in a plain list (one tensor construction at
        # the end) instead of seq_len scalar tensor writes. Random offsets
        # are still drawn in increasing time order, so seeded output is
        # unchanged.
        rule_track = []
        for t in range(seq_len):
            if t in switch_points:
                current_rule = (current_rule + np.random.randint(1, n_rules)) % n_rules
            rule_track.append(int(current_rule))
        rules = torch.tensor(rule_track, dtype=torch.long)

        # Label all steps that share a rule in one vectorised call rather
        # than calling apply_rule once per time step.
        labels = torch.zeros(seq_len, dtype=torch.long)
        for rule_id in rules.unique().tolist():
            at_rule = rules == rule_id
            labels[at_rule] = apply_rule(points[at_rule], rule_id)

        all_data.append(points)
        all_labels.append(labels)
        all_rules.append(rules)

    return torch.stack(all_data), torch.stack(all_labels), torch.stack(all_rules)


def create_rule_switch_dataloaders(
    n_train: int = None,
    n_test: int = None,
    seq_len: int = None,
    batch_size: int = None
) -> Tuple[DataLoader, DataLoader, torch.Tensor]:
    """Build the train and test DataLoaders for the rule-switching task.

    Training sequences use random switch points and are shuffled; test
    sequences use deterministic switch points and keep their order.

    Returns:
        ``(train_loader, test_loader, test_rules)`` where ``test_rules`` is
        the per-step rule-id tensor of the test set.
    """
    # Any falsy argument (None or 0) falls back to the CONFIG default.
    n_train = n_train if n_train else CONFIG['n_train_sequences']
    n_test = n_test if n_test else CONFIG['n_test_sequences']
    seq_len = seq_len if seq_len else CONFIG['seq_len']
    batch_size = batch_size if batch_size else CONFIG['batch_size']

    # Train data is generated before test data; keep this order so seeded
    # runs reproduce the same sequences.
    train_tensors = generate_rule_switch_sequences(
        n_train, seq_len, CONFIG['switch_interval'], deterministic_switches=False
    )
    test_tensors = generate_rule_switch_sequences(
        n_test, seq_len, CONFIG['switch_interval'], deterministic_switches=True
    )

    train_loader = DataLoader(
        TensorDataset(*train_tensors), batch_size=batch_size, shuffle=True
    )
    test_loader = DataLoader(
        TensorDataset(*test_tensors), batch_size=batch_size, shuffle=False
    )

    return train_loader, test_loader, test_tensors[2]


# =============================================================================
# RECURRENT NETWORK ARCHITECTURE (CSTC LOOP MODEL) - ENHANCED
# =============================================================================

class CSTCNetwork(nn.Module):
    """
    Recurrent classifier modeling cortico-striato-thalamo-cortical (CSTC) loops.

    Data path: masked linear input projection -> ReLU -> multi-layer GRU ->
    masked linear readout. The input/readout masks are buffers multiplied
    into the weights on every forward pass; GRU weights are not masked here.

    Treatment knobs (plain attributes, changed through the setters):
      - ``stress_level``: std of Gaussian noise added before the GRU and, at
        half strength, after it.
      - ``glutamate_noise``: std of Gaussian noise added to the raw input.
      - ``inhibition_strength``: multiplier applied to the ReLU output and to
        the GRU output.
      - ``use_tanh``: squash the (scaled) GRU output with tanh.
    """

    def __init__(self, hidden_dims: List[int] = None, num_layers: int = None):
        super().__init__()
        dims = CONFIG['hidden_dims'] if hidden_dims is None else hidden_dims
        depth = CONFIG['num_gru_layers'] if num_layers is None else num_layers
        proj_width, recurrent_width = dims[0], dims[1]

        self.hidden_dim = recurrent_width
        self.num_layers = depth

        # Sub-modules are created in the order input -> GRU -> readout -> ReLU.
        self.input_fc = nn.Linear(CONFIG['input_dim'], proj_width)
        self.gru = nn.GRU(
            input_size=proj_width,
            hidden_size=recurrent_width,
            num_layers=depth,
            batch_first=True,
            # Inter-layer dropout only when there is more than one layer.
            dropout=0.0 if depth <= 1 else 0.1
        )
        self.output_fc = nn.Linear(recurrent_width, CONFIG['output_dim'])
        self.relu = nn.ReLU()

        # Treatment state: no noise, no inhibition, no tanh by default.
        self.stress_level = 0.0
        self.glutamate_noise = 0.0
        self.inhibition_strength = 1.0
        self.use_tanh = False

        # All-ones masks leave the weights untouched.
        self.register_buffer('input_mask', torch.ones_like(self.input_fc.weight))
        self.register_buffer('output_mask', torch.ones_like(self.output_fc.weight))
        self.gru_masks = {}

    def init_hidden(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return an all-zero hidden state of shape (num_layers, batch, hidden_dim)."""
        shape = (self.num_layers, batch_size, self.hidden_dim)
        return torch.zeros(*shape, device=device)

    def set_inhibition(self, strength: float, use_tanh: bool = False):
        """Set the neurosteroid-like gain (clamped to [0, 1]) and the tanh flag."""
        capped = min(1.0, strength)
        self.inhibition_strength = max(0.0, capped)
        self.use_tanh = use_tanh

    def reduce_stress_gradually(self, epoch: int, total_epochs: int,
                                initial_stress: float = 0.4, final_stress: float = 0.0):
        """SSRI-like schedule: interpolate stress linearly from initial to final.

        ``epoch`` 0 gives ``initial_stress``; ``epoch == total_epochs - 1``
        gives ``final_stress``.
        """
        last_epoch = max(total_epochs - 1, 1)
        progress = epoch / last_epoch
        span = final_stress - initial_stress
        self.stress_level = initial_stress + progress * span

    def forward(
        self,
        x: torch.Tensor,
        hidden: Optional[torch.Tensor] = None,
        return_hidden: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the network on a sequence (3D input) or on a single step (2D input).

        A 2D input is treated as a length-1 sequence and the time axis is
        squeezed out of the returned logits. The second return value is the
        final GRU state when ``return_hidden`` is True, otherwise None.
        """
        squeeze_time = x.dim() == 2
        if squeeze_time:
            x = x.unsqueeze(1)

        batch_size, _, _ = x.shape

        if hidden is None:
            hidden = self.init_hidden(batch_size, x.device)

        # Input-level noise.
        if self.glutamate_noise > 0:
            x = x + torch.randn_like(x) * self.glutamate_noise

        # Masked input projection followed by ReLU and tonic inhibition.
        projected = F.linear(x, self.input_fc.weight * self.input_mask, self.input_fc.bias)
        drive = self.relu(projected) * self.inhibition_strength

        # Internal noise ahead of the recurrent core.
        if self.stress_level > 0:
            drive = drive + torch.randn_like(drive) * self.stress_level

        recurrent_out, hidden = self.gru(drive, hidden)

        recurrent_out = recurrent_out * self.inhibition_strength
        if self.use_tanh:
            recurrent_out = torch.tanh(recurrent_out)

        # Internal noise after the recurrent core, at half strength.
        if self.stress_level > 0:
            recurrent_out = recurrent_out + torch.randn_like(recurrent_out) * self.stress_level * 0.5

        logits = F.linear(
            recurrent_out,
            self.output_fc.weight * self.output_mask,
            self.output_fc.bias
        )

        if squeeze_time:
            logits = logits.squeeze(1)

        return (logits, hidden) if return_hidden else (logits, None)

    def set_stress(self, level: float):
        """Set the internal noise level; negative values are clamped to 0."""
        self.stress_level = max(0.0, level)

    def set_glutamate_noise(self, level: float):
        """Set the input-only noise level; negative values are clamped to 0."""
        self.glutamate_noise = max(0.0, level)

    def get_sparsity(self) -> float:
        """Fraction of weight entries (names containing 'weight') with |w| < 1e-8."""
        weights = [p for n, p in self.named_parameters() if 'weight' in n]
        total = sum(p.numel() for p in weights)
        if total <= 0:
            return 0.0
        zeros = sum((p.abs() < 1e-8).sum().item() for p in weights)
        return zeros / total

    def get_treatment_state(self) -> Dict[str, Any]:
        """Snapshot of the treatment knobs plus current sparsity, for logging."""
        return dict(
            stress_level=self.stress_level,
            glutamate_noise=self.glutamate_noise,
            inhibition_strength=self.inhibition_strength,
            use_tanh=self.use_tanh,
            sparsity=self.get_sparsity(),
        )


# =============================================================================
# PRUNING MANAGER WITH TREATMENT SIMULATION CAPABILITIES
# =============================================================================

class CSTCPruningManager:
    """Manages pruning, regrowth, and relapse simulation for CSTC networks.

    For every model parameter whose name contains ``'weight'`` the manager
    keeps:

    * ``original_weights`` - a snapshot taken when the manager is built,
      used as the template for regrown connections;
    * ``masks`` - a 0/1 tensor marking which connections are currently alive.

    Masks for ``input_fc.weight`` and ``output_fc.weight`` are also mirrored
    into the model's ``input_mask`` / ``output_mask`` buffers, which the
    model multiplies into those weights on every forward pass. All other
    weights (the GRU) are only zeroed in place; call ``apply_masks`` after
    optimizer steps to keep them sparse.

    ``history`` holds one dict per prune/regrow operation, in call order.
    """

    # Weight parameters whose mask is also applied inside the model's forward
    # pass, mapped to (layer attribute, mask buffer attribute) on the model.
    _FORWARD_MASKED = {
        'input_fc.weight': ('input_fc', 'input_mask'),
        'output_fc.weight': ('output_fc', 'output_mask'),
    }

    def __init__(self, model: "CSTCNetwork"):
        self.model = model
        self.original_weights = {}
        self.masks = {}
        self.history = []
        self._save_original_weights()

    def _save_original_weights(self):
        """Snapshot every weight tensor and start it with an all-ones mask."""
        for name, param in self.model.named_parameters():
            if 'weight' in name:
                self.original_weights[name] = param.data.clone()
                self.masks[name] = torch.ones_like(param.data)

    def _sync_forward_mask(self, name: str, mask: torch.Tensor):
        """Copy ``mask`` into the model buffer used in forward, if ``name`` has one."""
        target = self._FORWARD_MASKED.get(name)
        if target is not None:
            getattr(self.model, target[1]).copy_(mask)

    def _lift_forward_masks(self) -> Dict[str, torch.Tensor]:
        """Temporarily replace the forward-pass masks by all-ones.

        Pruned weights are first forced to exactly zero, so the forward
        output is the same with or without the mask; what changes is that
        the gradient now reaches the pruned positions. Returns the saved
        masks for ``_restore_forward_masks``.
        """
        saved = {}
        for layer_attr, mask_attr in self._FORWARD_MASKED.values():
            layer = getattr(self.model, layer_attr, None)
            mask_buf = getattr(self.model, mask_attr, None)
            if layer is None or mask_buf is None:
                continue
            saved[mask_attr] = mask_buf.clone()
            layer.weight.data.mul_(mask_buf)
            mask_buf.fill_(1.0)
        return saved

    def _restore_forward_masks(self, saved: Dict[str, torch.Tensor]):
        """Undo ``_lift_forward_masks``."""
        for mask_attr, mask_value in saved.items():
            getattr(self.model, mask_attr).copy_(mask_value)

    @staticmethod
    def _effective_magnitude(name: str, param: torch.Tensor,
                             recurrence_bias: float) -> torch.Tensor:
        """|w|, divided by ``recurrence_bias`` for GRU weights."""
        magnitude = param.data.abs()
        if 'gru' in name and recurrence_bias != 1.0:
            magnitude = magnitude / recurrence_bias
        return magnitude

    @staticmethod
    def _nothing_pruned() -> Dict[str, Any]:
        """Result for a no-op prune; same keys as a real prune result."""
        return {'achieved_sparsity': 0.0, 'weights_pruned': 0, 'layer_stats': {}}

    def get_sparsity(self) -> float:
        """Get current network sparsity."""
        return self.model.get_sparsity()

    def prune_by_magnitude(
        self,
        sparsity: float,
        recurrence_bias: float = None
    ) -> Dict[str, Any]:
        """Apply global magnitude-based pruning.

        All trainable weights are pooled and the ``int(sparsity * N)``-th
        smallest effective magnitude becomes the threshold; every weight at
        or below it is zeroed and masked out. GRU magnitudes are divided by
        ``recurrence_bias`` first (default ``CONFIG['recurrence_bias']``), so
        a bias above 1 prunes recurrent weights preferentially. Ties at the
        threshold can push the achieved sparsity above the target.

        Returns:
            Dict with ``achieved_sparsity``, ``weights_pruned`` and per-layer
            ``layer_stats``. When nothing is pruned (``sparsity <= 0``, no
            trainable weights, or a target that rounds to zero weights) the
            values are 0.0 / 0 / ``{}``.
        """
        if recurrence_bias is None:
            recurrence_bias = CONFIG['recurrence_bias']

        if sparsity <= 0:
            return self._nothing_pruned()

        candidates = [
            (name, param, self._effective_magnitude(name, param, recurrence_bias))
            for name, param in self.model.named_parameters()
            if 'weight' in name and param.requires_grad
        ]
        if not candidates:
            return self._nothing_pruned()

        pooled = torch.cat([magnitude.flatten() for _, _, magnitude in candidates])
        k = int(sparsity * pooled.numel())
        if k == 0:
            return self._nothing_pruned()

        threshold = torch.kthvalue(pooled, k).values.item()

        total_pruned = 0
        layer_stats = {}

        for name, param, magnitude in candidates:
            mask = (magnitude > threshold).float()
            pruned_count = (mask == 0).sum().item()

            self.masks[name] = mask
            param.data *= mask
            self._sync_forward_mask(name, mask)

            total_pruned += pruned_count
            layer_stats[name] = {
                'pruned': pruned_count,
                'total': param.numel(),
                'layer_sparsity': pruned_count / param.numel()
            }

        achieved = self.get_sparsity()

        self.history.append({
            'operation': 'prune_magnitude',
            'target_sparsity': sparsity,
            'achieved_sparsity': achieved,
            'weights_pruned': total_pruned,
            'recurrence_bias': recurrence_bias
        })

        return {
            'achieved_sparsity': achieved,
            'weights_pruned': total_pruned,
            'layer_stats': layer_stats
        }

    def gradient_guided_regrow(
        self,
        train_loader: DataLoader = None,
        regrow_fraction: float = None,
        n_batches: int = 5,
        init_scale: float = None
    ) -> Dict[str, Any]:
        """Regrow connections guided by gradient importance.

        Importance is the absolute loss gradient of each weight, accumulated
        over the first ``n_batches`` batches. In every layer up to
        ``regrow_fraction`` of the pruned connections are revived, highest
        importance first, limited to those with non-zero importance. A
        revived weight restarts at ``original_weight * init_scale``.

        The forward-pass masks of the input/readout layers are lifted while
        importance is measured. With the masks in place the gradient at
        every pruned position of those layers is exactly zero, so they could
        never regrow.

        Args:
            train_loader: Batches of ``(data, labels, rules)``; built with
                ``create_rule_switch_dataloaders`` when None.
            regrow_fraction: Defaults to ``CONFIG['regrowth_fraction']``.
            n_batches: Number of batches used for the importance estimate.
            init_scale: Defaults to ``CONFIG['regrowth_init_scale']``.

        Returns:
            Dict with ``connections_regrown``, ``new_sparsity`` and per-layer
            ``layer_stats``.

        Side effects:
            Leaves the model in training mode and with cleared gradients.
        """
        if regrow_fraction is None:
            regrow_fraction = CONFIG['regrowth_fraction']
        if init_scale is None:
            init_scale = CONFIG['regrowth_init_scale']
        if train_loader is None:
            train_loader, _, _ = create_rule_switch_dataloaders()

        self.model.train()
        device = next(self.model.parameters()).device

        gradient_importance = {
            name: torch.zeros_like(param.data)
            for name, param in self.model.named_parameters()
            if 'weight' in name
        }

        criterion = nn.CrossEntropyLoss()

        saved_masks = self._lift_forward_masks()
        try:
            for batch_idx, (data, labels, _) in enumerate(train_loader):
                if batch_idx >= n_batches:
                    break

                data, labels = data.to(device), labels.to(device)
                self.model.zero_grad()

                logits, _ = self.model(data)
                # Flatten (batch, time) so each step is one classification.
                loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
                loss.backward()

                for name, param in self.model.named_parameters():
                    if 'weight' in name and param.grad is not None:
                        gradient_importance[name] += param.grad.abs()
        finally:
            # Always put the real masks back, and do not leave the probe
            # gradients behind for a later optimizer step to pick up.
            self._restore_forward_masks(saved_masks)
            self.model.zero_grad()

        total_regrown = 0
        layer_stats = {}

        for name, param in self.model.named_parameters():
            if 'weight' not in name:
                continue

            mask = self.masks.get(name, torch.ones_like(param.data))
            pruned_mask = (mask < 0.5)
            n_pruned = pruned_mask.sum().item()
            if n_pruned == 0:
                continue

            n_regrow = int(regrow_fraction * n_pruned)
            if n_regrow == 0:
                continue

            # Only pruned positions compete for regrowth.
            flat_importance = (gradient_importance[name] * pruned_mask.float()).flatten()
            n_positive = (flat_importance > 0).sum().item()
            if n_positive == 0:
                continue

            _, top_indices = torch.topk(flat_importance, min(n_regrow, n_positive))

            # Turn the flat winner indices into a boolean selector so the
            # mask and the weights can be updated in one vectorised step.
            selected = torch.zeros(mask.numel(), dtype=torch.bool, device=mask.device)
            selected[top_indices] = True
            selected = selected.view_as(mask)

            new_mask = mask.masked_fill(selected, 1.0)
            param.data[selected] = self.original_weights[name][selected] * init_scale
            self.masks[name] = new_mask
            self._sync_forward_mask(name, new_mask)

            regrown_count = len(top_indices)
            total_regrown += regrown_count
            layer_stats[name] = {
                'regrown': regrown_count,
                'remaining_pruned': n_pruned - regrown_count
            }

        new_sparsity = self.get_sparsity()

        self.history.append({
            'operation': 'gradient_regrow',
            'regrow_fraction': regrow_fraction,
            'connections_regrown': total_regrown,
            'new_sparsity': new_sparsity
        })

        return {
            'connections_regrown': total_regrown,
            'new_sparsity': new_sparsity,
            'layer_stats': layer_stats
        }

    def secondary_prune(
        self,
        fraction: float,
        bias_recurrent: bool = False,
        recurrence_multiplier: float = 1.5
    ) -> Dict[str, Any]:
        """Simulate relapse by pruning a fraction of surviving weights.

        Within each masked layer the smallest-magnitude ``fraction`` of the
        currently active weights is removed. With ``bias_recurrent`` the
        fraction for GRU layers is multiplied by ``recurrence_multiplier``
        and capped at 0.9. Ties at the threshold may remove slightly more
        than requested.

        Returns:
            Dict with ``total_pruned``, ``new_sparsity`` and per-layer
            ``layer_stats``.
        """
        stats = {}
        total_pruned = 0

        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue

            mask = self.masks[name]
            active_positions = (mask == 1)
            n_active = active_positions.sum().item()
            if n_active == 0:
                continue

            effective_fraction = fraction
            if bias_recurrent and 'gru' in name:
                effective_fraction = min(fraction * recurrence_multiplier, 0.9)

            num_to_prune = int(effective_fraction * n_active)
            if num_to_prune == 0:
                continue

            weights = param.data.abs()
            # Already-pruned positions get +inf so they never fall under the
            # k-th smallest value among the active ones.
            weights_active = weights.clone()
            weights_active[~active_positions] = float('inf')
            threshold = torch.kthvalue(weights_active.flatten(), num_to_prune).values.item()

            prune_mask = (weights <= threshold) & active_positions
            mask[prune_mask] = 0
            param.data[prune_mask] = 0

            pruned_count = prune_mask.sum().item()
            total_pruned += pruned_count

            stats[name] = {
                'pruned': pruned_count,
                'remaining': n_active - pruned_count,
                'effective_fraction': effective_fraction
            }

            self._sync_forward_mask(name, mask)

        new_sparsity = self.get_sparsity()

        self.history.append({
            'operation': 'secondary_prune',
            'fraction': fraction,
            'bias_recurrent': bias_recurrent,
            'total_pruned': total_pruned,
            'new_sparsity': new_sparsity
        })

        return {
            'total_pruned': total_pruned,
            'new_sparsity': new_sparsity,
            'layer_stats': stats
        }

    def apply_masks(self):
        """Re-apply masks after training steps to maintain sparsity pattern."""
        for name, param in self.model.named_parameters():
            if name in self.masks:
                param.data *= self.masks[name]

    def get_history_summary(self) -> str:
        """Get a formatted summary of all operations."""
        lines = ["Pruning Manager History:"]
        for i, op in enumerate(self.history):
            lines.append(f"  {i+1}. {op['operation']}: {op}")
        return "\n".join(lines)


# =============================================================================
# OCD-SPECIFIC EVALUATION METRICS
# =============================================================================

@dataclass
class OCDMetrics:
    """Result record for one evaluation of a network on the rule-switching task.

    Every numeric field defaults to 0.0 except ``inhibition_strength`` (1.0);
    ``use_tanh`` defaults to False.
    """

    # --- behavioural scores ---
    accuracy: float = 0.0
    perseverative_error_rate: float = 0.0
    switch_cost: float = 0.0
    trials_to_recover: float = 0.0
    repetition_rate: float = 0.0
    repetition_entropy: float = 0.0
    output_diversity: float = 0.0
    rule_inference_accuracy: float = 0.0
    flexibility_index: float = 0.0

    # --- network / treatment state (same names as the keys returned by
    #     CSTCNetwork.get_treatment_state) ---
    sparsity: float = 0.0
    stress_level: float = 0.0
    glutamate_noise: float = 0.0
    inhibition_strength: float = 1.0
    use_tanh: bool = False


def compute_ocd_metrics(
    model: CSTCNetwork,
    test_loader: DataLoader,
    device: torch.device,
    detailed: bool = False
) -> OCDMetrics:
    """Evaluate a model on the rule-switching task and summarise its behaviour.

    The model is switched to eval mode and stepped through every test
    sequence one timestep at a time, carrying the hidden state forward.
    Per-trial predictions are then analysed around rule-switch points, i.e.
    positions where the rule label differs from the previous timestep.

    Args:
        model: Network to evaluate; it is left in eval mode on return.
        test_loader: Yields ``(data, labels, rules)`` batches.
        device: Device the batches are moved to before the forward pass.
        detailed: Accepted for interface compatibility; not used here.

    Returns:
        A populated ``OCDMetrics`` instance.
    """
    model.eval()
    result = OCDMetrics()

    pred_chunks = []
    label_chunks = []
    rule_chunks = []
    hit_chunks = []

    with torch.no_grad():
        for data, labels, rules in test_loader:
            data = data.to(device)
            labels = labels.to(device)
            rules = rules.to(device)

            n_batch, n_steps, _ = data.shape
            hidden = model.init_hidden(n_batch, device)

            # Feed one trial at a time so the hidden state carries across trials.
            step_preds = []
            for step in range(n_steps):
                logits, hidden = model(data[:, step:step+1, :], hidden, return_hidden=True)
                step_preds.append(logits.squeeze(1).argmax(dim=-1))
            batch_preds = torch.stack(step_preds, dim=1)

            pred_chunks.append(batch_preds.cpu())
            label_chunks.append(labels.cpu())
            rule_chunks.append(rules.cpu())
            hit_chunks.append((batch_preds == labels).cpu())

    predictions = torch.cat(pred_chunks, dim=0)
    labels = torch.cat(label_chunks, dim=0)
    rules = torch.cat(rule_chunks, dim=0)
    correct = torch.cat(hit_chunks, dim=0)

    n_sequences, seq_len = predictions.shape

    result.accuracy = correct.float().mean().item()

    rule_changes = rules[:, 1:] != rules[:, :-1]

    persev_hits = 0
    persev_chances = 0
    post_switch_accs = []
    stable_accs = []
    recovery_counts = []

    for row in range(n_sequences):
        # +1 because rule_changes[:, i] compares timestep i+1 with timestep i.
        switches = (torch.where(rule_changes[row])[0] + 1).tolist()

        # Stable periods: the stretch between consecutive switches (starting
        # from 0), trimmed by 5 trials at each end; segments shorter than 10
        # trials are ignored.
        segment_starts = [0] + switches[:-1]
        for seg_start, switch_t in zip(segment_starts, switches):
            if switch_t - seg_start < 10:
                continue
            stable_acc = correct[row, seg_start+5:switch_t-5].float().mean().item()
            if not np.isnan(stable_acc):
                stable_accs.append(stable_acc)

        for switch_t in switches:
            # Switches too close to the end of the sequence are not analysed.
            if switch_t >= seq_len - 10:
                continue

            # Every incorrect trial in the post-switch window is an
            # opportunity; it counts as perseverative when the prediction
            # repeats the previous trial's prediction.
            for t in range(switch_t, min(switch_t + 10, seq_len)):
                if correct[row, t]:
                    continue
                persev_chances += 1
                if t > 0 and predictions[row, t] == predictions[row, t-1]:
                    persev_hits += 1

            post_acc = correct[row, switch_t:min(switch_t+5, seq_len)].float().mean().item()
            if not np.isnan(post_acc):
                post_switch_accs.append(post_acc)

            # First offset at which a 5-trial window reaches 0.8 accuracy;
            # 50 is recorded when no window in the search range qualifies.
            search_end = min(seq_len - 5, switch_t + 50)
            recovery_counts.append(next(
                (
                    start - switch_t
                    for start in range(switch_t, search_end)
                    if correct[row, start:start+5].float().mean().item() >= 0.8
                ),
                50,
            ))

    if persev_chances > 0:
        result.perseverative_error_rate = persev_hits / persev_chances

    if post_switch_accs and stable_accs:
        mean_stable = np.mean(stable_accs)
        mean_switch = np.mean(post_switch_accs)
        result.switch_cost = mean_stable - mean_switch
        result.flexibility_index = mean_switch / (mean_stable + 1e-8)

    if recovery_counts:
        result.trials_to_recover = np.mean(recovery_counts)

    if stable_accs:
        result.rule_inference_accuracy = np.mean(stable_accs)

    same_as_previous = (predictions[:, 1:] == predictions[:, :-1]).float()
    result.repetition_rate = same_as_previous.mean().item()

    # Entropy of the overall predicted-class distribution, and its ratio to
    # the maximum possible entropy for output_dim classes.
    class_counts = torch.bincount(predictions.flatten(), minlength=CONFIG['output_dim']).float()
    class_probs = class_counts / class_counts.sum()
    entropy = -(class_probs * torch.log(class_probs + 1e-8)).sum().item()
    result.repetition_entropy = entropy
    result.output_diversity = entropy / np.log(CONFIG['output_dim'])

    # Snapshot of the model's configuration at evaluation time.
    result.sparsity = model.get_sparsity()
    result.stress_level = model.stress_level
    result.glutamate_noise = model.glutamate_noise
    result.inhibition_strength = model.inhibition_strength
    result.use_tanh = model.use_tanh

    return result


# =============================================================================
# TRAINING FUNCTIONS
# =============================================================================

def train_epoch(
    model: CSTCNetwork,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    pruning_manager: Optional[CSTCPruningManager] = None
) -> float:
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Gradients are clipped to a max norm of 1.0 before each optimizer step.
    When a pruning manager is supplied, its masks are re-applied after every
    step so pruned weights do not drift away from the stored pattern.

    Args:
        model: Network to train; switched to train mode.
        train_loader: Yields ``(data, labels, _)`` batches; the third item is ignored.
        optimizer: Optimizer updating the model's parameters.
        criterion: Loss applied to flattened logits and labels.
        device: Device the batches are moved to.
        pruning_manager: Optional manager whose ``apply_masks`` is called per step.

    Returns:
        Sum of per-batch losses divided by the number of batches.
    """
    model.train()
    batch_losses = []

    for data, labels, _ in train_loader:
        data = data.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits, _ = model(data)

        flat_logits = logits.view(-1, CONFIG['output_dim'])
        flat_labels = labels.view(-1)
        loss = criterion(flat_logits, flat_labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        if pruning_manager is not None:
            pruning_manager.apply_masks()

        batch_losses.append(loss.item())

    return sum(batch_losses, 0.0) / len(batch_losses)


def train(
    model: CSTCNetwork,
    train_loader: DataLoader = None,
    test_loader: DataLoader = None,
    epochs: int = None,
    lr: float = None,
    pruning_manager: Optional[CSTCPruningManager] = None,
    verbose: bool = True,
    eval_interval: int = 10
) -> Dict[str, List[float]]:
    """Train ``model`` with Adam and cross-entropy, evaluating periodically.

    Args:
        model: Network to train; the device is taken from its parameters.
        train_loader: Training batches. If this or ``test_loader`` is None,
            both loaders are replaced by freshly created default ones.
        test_loader: Evaluation batches (see ``train_loader``).
        epochs: Number of epochs; defaults to ``CONFIG['baseline_epochs']``.
        lr: Learning rate; defaults to ``CONFIG['baseline_lr']``.
        pruning_manager: Optional manager forwarded to ``train_epoch``.
        verbose: Print a progress line at every evaluation.
        eval_interval: Evaluate every this many epochs, and on the last epoch.

    Returns:
        Dict with ``'loss'`` (one entry per epoch) and ``'accuracy'`` /
        ``'perseveration'`` (one entry per evaluation).
    """
    missing_loader = train_loader is None or test_loader is None
    if missing_loader:
        train_loader, test_loader, _ = create_rule_switch_dataloaders()
    epochs = CONFIG['baseline_epochs'] if epochs is None else epochs
    lr = CONFIG['baseline_lr'] if lr is None else lr

    device = next(model.parameters()).device
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    history = {'loss': [], 'accuracy': [], 'perseveration': []}

    for epoch_idx in range(epochs):
        epoch_num = epoch_idx + 1
        epoch_loss = train_epoch(
            model, train_loader, optimizer, criterion, device, pruning_manager
        )
        history['loss'].append(epoch_loss)

        is_eval_epoch = epoch_num % eval_interval == 0 or epoch_idx == epochs - 1
        if not is_eval_epoch:
            continue

        metrics = compute_ocd_metrics(model, test_loader, device)
        history['accuracy'].append(metrics.accuracy)
        history['perseveration'].append(metrics.perseverative_error_rate)

        if verbose:
            print(f"    Epoch {epoch_num:3d}/{epochs}: Loss={epoch_loss:.4f}, "
                  f"Acc={metrics.accuracy:.4f}, Persev={metrics.perseverative_error_rate:.4f}")

    return history


def train_with_stress_schedule(
    model: CSTCNetwork,
    train_loader: DataLoader,
    test_loader: DataLoader,
    epochs: int,
    lr: float,
    initial_stress: float,
    final_stress: float = 0.0,
    pruning_manager: Optional[CSTCPruningManager] = None,
    verbose: bool = False
) -> List[float]:
    """SSRI-like: Prolonged training with gradually reducing internal noise.

    Before every epoch the model's stress level is updated through
    ``model.reduce_stress_gradually``; the epoch itself is run by the shared
    ``train_epoch`` helper so that gradient clipping and mask re-application
    behave exactly as in the regular training loop.

    Args:
        model: Network to train; the device is taken from its parameters.
        train_loader: Training batches.
        test_loader: Evaluation batches, used only for verbose progress lines.
        epochs: Number of epochs to run.
        lr: Adam learning rate.
        initial_stress: Stress value passed to the schedule as its start point.
        final_stress: Stress value passed to the schedule as its end point and
            set explicitly on the model once training finishes.
        pruning_manager: Optional manager whose masks are re-applied after
            every optimizer step.
        verbose: Print a progress line every 30 epochs.

    Returns:
        Mean training loss of each epoch, in order.
    """
    device = next(model.parameters()).device
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    losses = []

    for epoch in range(epochs):
        model.reduce_stress_gradually(epoch, epochs, initial_stress, final_stress)

        # Delegate the batch loop to train_epoch rather than duplicating it;
        # the helper checks the manager with `is not None`, so a manager
        # object is never skipped because of its truth value.
        losses.append(
            train_epoch(model, train_loader, optimizer, criterion, device, pruning_manager)
        )

        if verbose and (epoch + 1) % 30 == 0:
            metrics = compute_ocd_metrics(model, test_loader, device)
            print(f"      SSRI epoch {epoch+1}/{epochs} | stress: {model.stress_level:.3f} | "
                  f"loss: {losses[-1]:.4f} | acc: {metrics.accuracy:.3f}")

    # Pin the stress level to its target regardless of where the schedule ended.
    model.set_stress(final_stress)
    return losses


# =============================================================================
# THREE TREATMENT PROTOCOL FUNCTIONS
# =============================================================================

def ketamine_treatment_ocd(
    model: CSTCNetwork,
    pruning_mgr: CSTCPruningManager,
    train_loader: DataLoader,
    test_loader: DataLoader,
    regrow_fraction: float = None,
    consolidation_epochs: int = None,
    verbose: bool = True
) -> Dict[str, Any]:
    """Ketamine-like protocol: gradient-guided regrowth, then brief consolidation.

    Args:
        model: Network to treat (modified in place).
        pruning_mgr: Manager that performs the regrowth and keeps masks applied.
        train_loader: Batches used for regrowth and consolidation training.
        test_loader: Evaluation batches forwarded to ``train``.
        regrow_fraction: Fraction passed to ``gradient_guided_regrow``;
            defaults to ``CONFIG['comparison_ketamine_regrow']``.
        consolidation_epochs: Fine-tuning epochs after regrowth; defaults to
            ``CONFIG['comparison_ketamine_epochs']``.
        verbose: Print the chosen parameters.

    Returns:
        Dict describing the treatment: regrown connection count, parameters
        used and the sparsity reported by the manager afterwards.
    """
    regrow_fraction = (
        CONFIG['comparison_ketamine_regrow'] if regrow_fraction is None else regrow_fraction
    )
    consolidation_epochs = (
        CONFIG['comparison_ketamine_epochs'] if consolidation_epochs is None else consolidation_epochs
    )

    if verbose:
        print(f"      [KETAMINE] regrow_fraction={regrow_fraction:.2f}, consolidation={consolidation_epochs} epochs")

    # Structural step: restore connections chosen by the pruning manager.
    regrowth = pruning_mgr.gradient_guided_regrow(train_loader, regrow_fraction=regrow_fraction)
    n_regrown = regrowth['connections_regrown']

    # Consolidation step: short fine-tune with masks kept in place.
    train(
        model,
        train_loader,
        test_loader,
        epochs=consolidation_epochs,
        lr=CONFIG['finetune_lr'],
        pruning_manager=pruning_mgr,
        verbose=False,
    )

    summary = {
        'treatment': 'ketamine',
        'mechanism': 'structural',
        'regrown': n_regrown,
        'regrow_fraction': regrow_fraction,
        'consolidation_epochs': consolidation_epochs,
    }
    summary['final_sparsity'] = pruning_mgr.get_sparsity()
    return summary


def ssri_treatment_ocd(
    model: CSTCNetwork,
    pruning_mgr: CSTCPruningManager,
    train_loader: DataLoader,
    test_loader: DataLoader,
    epochs: int = None,
    lr: float = None,
    initial_stress: float = None,
    verbose: bool = True
) -> Dict[str, Any]:
    """SSRI-like protocol: extended training under a decaying stress schedule.

    No regrowth is requested; the pruning manager is only used to keep masks
    applied during training and to report sparsity before and after.

    Args:
        model: Network to treat (modified in place).
        pruning_mgr: Manager whose masks are maintained during training.
        train_loader: Training batches.
        test_loader: Evaluation batches forwarded to the training routine.
        epochs: Training epochs; defaults to ``CONFIG['comparison_ssri_epochs']``.
        lr: Learning rate; defaults to ``CONFIG['comparison_ssri_lr']``.
        initial_stress: Starting stress; defaults to
            ``CONFIG['comparison_ssri_initial_stress']``.
        verbose: Print the chosen parameters.

    Returns:
        Dict describing the treatment, including the model's stress level
        after training and the sparsity before and after.
    """
    epochs = CONFIG['comparison_ssri_epochs'] if epochs is None else epochs
    lr = CONFIG['comparison_ssri_lr'] if lr is None else lr
    initial_stress = (
        CONFIG['comparison_ssri_initial_stress'] if initial_stress is None else initial_stress
    )

    if verbose:
        print(f"      [SSRI] epochs={epochs}, lr={lr:.0e}, initial_stress={initial_stress:.2f}")

    sparsity_before = pruning_mgr.get_sparsity()

    train_with_stress_schedule(
        model,
        train_loader,
        test_loader,
        epochs=epochs,
        lr=lr,
        initial_stress=initial_stress,
        final_stress=0.0,
        pruning_manager=pruning_mgr,
        verbose=False,
    )

    sparsity_after = pruning_mgr.get_sparsity()

    return {
        'treatment': 'ssri',
        'mechanism': 'functional',
        'epochs': epochs,
        'lr': lr,
        'initial_stress': initial_stress,
        'final_stress': model.stress_level,
        'initial_sparsity': sparsity_before,
        'final_sparsity': sparsity_after,
    }


def neurosteroid_treatment_ocd(
    model: CSTCNetwork,
    pruning_mgr: CSTCPruningManager,
    train_loader: DataLoader,
    test_loader: DataLoader,
    strength: float = None,
    use_tanh: bool = None,
    consolidation_epochs: int = None,
    verbose: bool = True
) -> Dict[str, Any]:
    """Neurosteroid-like protocol: switch on inhibition, then briefly consolidate.

    The inhibition setting is applied to the model via ``set_inhibition`` and
    stays active when this function returns.

    Args:
        model: Network to treat (modified in place).
        pruning_mgr: Manager whose masks are maintained during training.
        train_loader: Training batches.
        test_loader: Evaluation batches forwarded to ``train``.
        strength: Inhibition strength; defaults to
            ``CONFIG['comparison_neurosteroid_strength']``.
        use_tanh: Activation flag passed to ``set_inhibition``; defaults to
            ``CONFIG['comparison_neurosteroid_use_tanh']``.
        consolidation_epochs: Fine-tuning epochs; defaults to
            ``CONFIG['comparison_neurosteroid_epochs']``.
        verbose: Print the chosen parameters.

    Returns:
        Dict describing the treatment and the sparsity before and after.
    """
    strength = CONFIG['comparison_neurosteroid_strength'] if strength is None else strength
    use_tanh = CONFIG['comparison_neurosteroid_use_tanh'] if use_tanh is None else use_tanh
    consolidation_epochs = (
        CONFIG['comparison_neurosteroid_epochs']
        if consolidation_epochs is None
        else consolidation_epochs
    )

    if verbose:
        print(f"      [NEUROSTEROID] strength={strength:.2f}, use_tanh={use_tanh}, consolidation={consolidation_epochs} epochs")

    sparsity_before = pruning_mgr.get_sparsity()

    # Functional step: change how the network runs rather than its weights.
    model.set_inhibition(strength, use_tanh)

    train(
        model,
        train_loader,
        test_loader,
        epochs=consolidation_epochs,
        lr=CONFIG['finetune_lr'],
        pruning_manager=pruning_mgr,
        verbose=False,
    )

    sparsity_after = pruning_mgr.get_sparsity()

    return {
        'treatment': 'neurosteroid',
        'mechanism': 'functional_medication_dependent',
        'strength': strength,
        'use_tanh': use_tanh,
        'consolidation_epochs': consolidation_epochs,
        'initial_sparsity': sparsity_before,
        'final_sparsity': sparsity_after,
    }


# =============================================================================
# ISO-DOSE PARAMETER SWEEP FUNCTIONS
# =============================================================================

def run_ketamine_sweep(
    base_state: Dict[str, torch.Tensor],
    base_masks: Dict[str, torch.Tensor],
    train_loader: DataLoader,
    test_loader: DataLoader,
    device: torch.device,
    regrow_fractions: List[float] = None
) -> List[Dict[str, Any]]:
    """Apply the ketamine protocol at several regrow fractions from one baseline.

    For every fraction a fresh model is restored from ``base_state`` /
    ``base_masks``, treated, scored for dose and acute outcome, then subjected
    to a secondary prune and scored again.

    Args:
        base_state: State dict of the shared pruned baseline.
        base_masks: Pruning masks of the shared baseline.
        train_loader: Training batches used by the treatment.
        test_loader: Evaluation batches.
        device: Device for the restored models.
        regrow_fractions: Values to sweep; defaults to
            ``CONFIG['ketamine_regrow_sweep']``.

    Returns:
        One result dict per fraction, in sweep order.
    """
    fractions = CONFIG['ketamine_regrow_sweep'] if regrow_fractions is None else regrow_fractions

    def _restore_baseline():
        # Deep copies keep the shared baseline untouched across iterations.
        net = CSTCNetwork().to(device)
        net.load_state_dict(copy.deepcopy(base_state))
        manager = CSTCPruningManager(net)
        manager.masks = copy.deepcopy(base_masks)
        manager.apply_masks()
        return net, manager

    sweep = []
    for fraction in fractions:
        model, mgr = _restore_baseline()

        # Snapshot of weight tensors, used as the reference for the dose metrics.
        weights_before = {
            pname: p.data.clone() for pname, p in model.named_parameters() if 'weight' in pname
        }

        info = ketamine_treatment_ocd(
            model, mgr, train_loader, test_loader,
            regrow_fraction=fraction,
            verbose=False,
        )

        dose = compute_all_dose_metrics(weights_before, model)
        acute = compute_ocd_metrics(model, test_loader, device)
        persev_before_relapse = acute.perseverative_error_rate

        # Relapse challenge: prune again and re-evaluate.
        mgr.secondary_prune(fraction=CONFIG['relapse_prune_fraction'])
        relapse = compute_ocd_metrics(model, test_loader, device)

        sweep.append({
            'treatment': 'ketamine',
            'param_name': 'regrow_fraction',
            'param_value': fraction,
            'dose': dose,
            'acute_persev': acute.perseverative_error_rate,
            'acute_flex': acute.flexibility_index,
            'acute_accuracy': acute.accuracy,
            'relapse_persev': relapse.perseverative_error_rate,
            'relapse_delta': relapse.perseverative_error_rate - persev_before_relapse,
            'treatment_info': info,
        })

    return sweep


def run_ssri_sweep(
    base_state: Dict[str, torch.Tensor],
    base_masks: Dict[str, torch.Tensor],
    train_loader: DataLoader,
    test_loader: DataLoader,
    device: torch.device,
    epochs_list: List[int] = None,
    lr_list: List[float] = None
) -> List[Dict[str, Any]]:
    """Apply the SSRI protocol over a grid of epoch counts and learning rates.

    Each (epochs, lr) pair starts from a fresh copy of the shared baseline,
    is treated, scored for dose and acute outcome, then subjected to a
    secondary prune and scored again.

    Args:
        base_state: State dict of the shared pruned baseline.
        base_masks: Pruning masks of the shared baseline.
        train_loader: Training batches used by the treatment.
        test_loader: Evaluation batches.
        device: Device for the restored models.
        epochs_list: Epoch counts to sweep; defaults to
            ``CONFIG['ssri_epochs_sweep']``.
        lr_list: Learning rates to sweep; defaults to the single value
            ``CONFIG['comparison_ssri_lr']``.

    Returns:
        One result dict per (epochs, lr) pair, epochs varying slowest.
    """
    epoch_options = CONFIG['ssri_epochs_sweep'] if epochs_list is None else epochs_list
    lr_options = [CONFIG['comparison_ssri_lr']] if lr_list is None else lr_list

    def _restore_baseline():
        # Deep copies keep the shared baseline untouched across iterations.
        net = CSTCNetwork().to(device)
        net.load_state_dict(copy.deepcopy(base_state))
        manager = CSTCPruningManager(net)
        manager.masks = copy.deepcopy(base_masks)
        manager.apply_masks()
        return net, manager

    sweep = []
    for n_epochs in epoch_options:
        for rate in lr_options:
            model, mgr = _restore_baseline()

            # Snapshot of weight tensors, used as the reference for the dose metrics.
            weights_before = {
                pname: p.data.clone() for pname, p in model.named_parameters() if 'weight' in pname
            }

            info = ssri_treatment_ocd(
                model, mgr, train_loader, test_loader,
                epochs=n_epochs,
                lr=rate,
                verbose=False,
            )

            dose = compute_all_dose_metrics(weights_before, model)
            acute = compute_ocd_metrics(model, test_loader, device)
            persev_before_relapse = acute.perseverative_error_rate

            # Relapse challenge: prune again and re-evaluate.
            mgr.secondary_prune(fraction=CONFIG['relapse_prune_fraction'])
            relapse = compute_ocd_metrics(model, test_loader, device)

            sweep.append({
                'treatment': 'ssri',
                'param_name': 'epochs',
                'param_value': n_epochs,
                'lr': rate,
                'dose': dose,
                'acute_persev': acute.perseverative_error_rate,
                'acute_flex': acute.flexibility_index,
                'acute_accuracy': acute.accuracy,
                'relapse_persev': relapse.perseverative_error_rate,
                'relapse_delta': relapse.perseverative_error_rate - persev_before_relapse,
                'treatment_info': info,
            })

    return sweep


def run_neurosteroid_sweep(
    base_state: Dict[str, torch.Tensor],
    base_masks: Dict[str, torch.Tensor],
    train_loader: DataLoader,
    test_loader: DataLoader,
    device: torch.device,
    strength_list: List[float] = None
) -> List[Dict[str, Any]]:
    """Apply the neurosteroid protocol at several inhibition strengths.

    Each strength starts from a fresh copy of the shared baseline. Besides the
    acute and post-relapse scores, the model is also scored with inhibition
    temporarily switched off to measure the off-medication reversal.

    Args:
        base_state: State dict of the shared pruned baseline.
        base_masks: Pruning masks of the shared baseline.
        train_loader: Training batches used by the treatment.
        test_loader: Evaluation batches.
        device: Device for the restored models.
        strength_list: Strengths to sweep; defaults to
            ``CONFIG['neurosteroid_strength_sweep']``.

    Returns:
        One result dict per strength, in sweep order.
    """
    strengths = CONFIG['neurosteroid_strength_sweep'] if strength_list is None else strength_list

    def _restore_baseline():
        # Deep copies keep the shared baseline untouched across iterations.
        net = CSTCNetwork().to(device)
        net.load_state_dict(copy.deepcopy(base_state))
        manager = CSTCPruningManager(net)
        manager.masks = copy.deepcopy(base_masks)
        manager.apply_masks()
        return net, manager

    sweep = []
    for level in strengths:
        model, mgr = _restore_baseline()

        # Snapshot of weight tensors, used as the reference for the dose metrics.
        weights_before = {
            pname: p.data.clone() for pname, p in model.named_parameters() if 'weight' in pname
        }

        info = neurosteroid_treatment_ocd(
            model, mgr, train_loader, test_loader,
            strength=level,
            verbose=False,
        )

        dose = compute_all_dose_metrics(weights_before, model)
        acute = compute_ocd_metrics(model, test_loader, device)

        # Off-medication probe: evaluate with inhibition reset to (1.0, False),
        # then switch it back on before the relapse challenge.
        model.set_inhibition(1.0, False)
        off_med = compute_ocd_metrics(model, test_loader, device)
        model.set_inhibition(level, CONFIG['comparison_neurosteroid_use_tanh'])

        persev_before_relapse = acute.perseverative_error_rate

        # Relapse challenge: prune again and re-evaluate.
        mgr.secondary_prune(fraction=CONFIG['relapse_prune_fraction'])
        relapse = compute_ocd_metrics(model, test_loader, device)

        sweep.append({
            'treatment': 'neurosteroid',
            'param_name': 'strength',
            'param_value': level,
            'dose': dose,
            'acute_persev': acute.perseverative_error_rate,
            'acute_flex': acute.flexibility_index,
            'acute_accuracy': acute.accuracy,
            'off_med_persev': off_med.perseverative_error_rate,
            'off_med_reversal': off_med.perseverative_error_rate - acute.perseverative_error_rate,
            'relapse_persev': relapse.perseverative_error_rate,
            'relapse_delta': relapse.perseverative_error_rate - persev_before_relapse,
            'treatment_info': info,
        })

    return sweep


def find_iso_dose_params(
    sweep_results: List[Dict[str, Any]],
    target_dose: float,
    tolerance: float = None
) -> Optional[Dict[str, Any]]:
    """Find the sweep entry whose L1 dose is closest to ``target_dose``.

    Previously the tolerance was checked but the closest entry was returned
    either way, so treatments could be reported as "dose-matched" at
    arbitrarily different doses. The tolerance is now enforced.

    Args:
        sweep_results: Sweep entries; each must hold an object with an
            ``l1_norm`` attribute under the ``'dose'`` key.
        target_dose: Dose level to match.
        tolerance: Largest accepted absolute difference between an entry's
            dose and ``target_dose``; defaults to
            ``CONFIG['iso_dose_tolerance']``.

    Returns:
        The closest entry (the first one on ties) when it lies within
        ``tolerance`` of the target, otherwise ``None``. ``None`` is also
        returned for an empty sweep.
    """
    if tolerance is None:
        tolerance = CONFIG['iso_dose_tolerance']

    best_match = None
    best_diff = float('inf')

    for result in sweep_results:
        diff = abs(result['dose'].l1_norm - target_dose)
        # Strict comparison keeps the earliest entry when several are equally close.
        if diff < best_diff:
            best_diff = diff
            best_match = result

    # Explicit None check: an entry must not be rejected just because it is
    # a falsy container.
    if best_match is None or best_diff > tolerance:
        return None
    return best_match


# =============================================================================
# ISO-DOSE COMPARISON EXPERIMENT
# =============================================================================

def run_iso_dose_comparison_experiment(
    device: torch.device,
    seed: int = None,
    verbose: bool = True
) -> Dict[str, Any]:
    """
    Run iso-dose comparison across all three treatment mechanisms.

    Sweeps parameters for each treatment, measures dose (L1 weight change norm),
    and compares outcomes at matched dose levels.

    The experiment trains one baseline model, prunes it, and uses that single
    pruned state as the starting point for every sweep configuration. All
    progress and result tables are printed to stdout.

    Args:
        device: Device on which models are created and evaluated.
        seed: Random seed; defaults to ``CONFIG['seed']``.
        verbose: Accepted for interface compatibility; this function prints
            its report regardless of the value.

    Returns:
        Dict with keys ``'untreated'`` (baseline metrics), ``'sweeps'``
        (per-treatment sweep result lists), ``'iso_dose_comparisons'``
        (closest sweep entry per treatment, keyed by the float target dose)
        and ``'summary'`` (per-treatment aggregate statistics).
    """
    if seed is None:
        seed = CONFIG['seed']
    set_seed(seed)

    print_section_header("ISO-DOSE FAIR COMPARISON EXPERIMENT", char="█")
    print(f"\n  Seed: {seed}")
    print(f"  Device: {device}")
    print(f"  Dose metric: L1 weight change norm (normalized per parameter)")

    train_loader, test_loader, _ = create_rule_switch_dataloaders()

    # =========================================================================
    # PHASE 1: Create shared pruned baseline
    # =========================================================================
    print_subsection_header("PHASE 1: Creating Shared Pruned Baseline")

    base_model = CSTCNetwork().to(device)
    print("  Training healthy baseline model...")
    train(base_model, train_loader, test_loader, verbose=False)

    base_mgr = CSTCPruningManager(base_model)
    prune_stats = base_mgr.prune_by_magnitude(sparsity=CONFIG['ocd_prune_sparsity'])

    print(f"  Applied developmental over-pruning: {CONFIG['ocd_prune_sparsity']*100:.0f}%")
    print(f"  Achieved sparsity: {prune_stats['achieved_sparsity']*100:.1f}%")

    # Deep copies: every sweep configuration restores from these snapshots, so
    # they must not alias the live baseline model or manager.
    base_state = copy.deepcopy(base_model.state_dict())
    base_masks = copy.deepcopy(base_mgr.masks)

    untreated_metrics = compute_ocd_metrics(base_model, test_loader, device)
    print(f"\n  UNTREATED OCD BASELINE:")
    print(f"    Perseverative Errors: {untreated_metrics.perseverative_error_rate:.4f}")
    print(f"    Flexibility Index:    {untreated_metrics.flexibility_index:.4f}")
    print(f"    Accuracy:             {untreated_metrics.accuracy:.4f}")

    results = {
        'untreated': {
            'persev': untreated_metrics.perseverative_error_rate,
            'flex': untreated_metrics.flexibility_index,
            'accuracy': untreated_metrics.accuracy,
            'sparsity': untreated_metrics.sparsity
        },
        'sweeps': {},
        'iso_dose_comparisons': {}
    }

    # =========================================================================
    # PHASE 2: Parameter Sweeps
    # =========================================================================
    print_subsection_header("PHASE 2: Parameter Sweeps (Measuring Dose-Response)")

    # Each sweep uses its CONFIG-default parameter list (none is passed here).
    print("\n  [KETAMINE] Sweeping regrow_fraction...")
    ketamine_results = run_ketamine_sweep(
        base_state, base_masks, train_loader, test_loader, device
    )
    results['sweeps']['ketamine'] = ketamine_results
    print(f"    Completed {len(ketamine_results)} configurations")

    print("\n  [SSRI] Sweeping epochs...")
    ssri_results = run_ssri_sweep(
        base_state, base_masks, train_loader, test_loader, device
    )
    results['sweeps']['ssri'] = ssri_results
    print(f"    Completed {len(ssri_results)} configurations")

    print("\n  [NEUROSTEROID] Sweeping strength...")
    neurosteroid_results = run_neurosteroid_sweep(
        base_state, base_masks, train_loader, test_loader, device
    )
    results['sweeps']['neurosteroid'] = neurosteroid_results
    print(f"    Completed {len(neurosteroid_results)} configurations")

    # =========================================================================
    # PHASE 3: Dose-Response Analysis
    # =========================================================================
    print_subsection_header("PHASE 3: Dose-Response Curves")

    print("\n  KETAMINE DOSE-RESPONSE:")
    print(f"  {'regrow_frac':>12} {'L1 Dose':>12} {'Turnover':>12} {'Acute Prsv':>12} {'Relapse Δ':>12}")
    print("  " + "-" * 64)
    for r in ketamine_results:
        print(f"  {r['param_value']:>12.2f} {r['dose'].l1_norm:>12.6f} {r['dose'].synaptic_turnover:>12.4f} "
              f"{r['acute_persev']:>12.4f} {r['relapse_delta']:>+12.4f}")

    print("\n  SSRI DOSE-RESPONSE:")
    print(f"  {'epochs':>12} {'L1 Dose':>12} {'Turnover':>12} {'Acute Prsv':>12} {'Relapse Δ':>12}")
    print("  " + "-" * 64)
    for r in ssri_results:
        print(f"  {r['param_value']:>12} {r['dose'].l1_norm:>12.6f} {r['dose'].synaptic_turnover:>12.4f} "
              f"{r['acute_persev']:>12.4f} {r['relapse_delta']:>+12.4f}")

    print("\n  NEUROSTEROID DOSE-RESPONSE:")
    print(f"  {'strength':>12} {'L1 Dose':>12} {'Turnover':>12} {'Acute Prsv':>12} {'Off-med Δ':>12} {'Relapse Δ':>12}")
    print("  " + "-" * 76)
    for r in neurosteroid_results:
        print(f"  {r['param_value']:>12.2f} {r['dose'].l1_norm:>12.6f} {r['dose'].synaptic_turnover:>12.4f} "
              f"{r['acute_persev']:>12.4f} {r['off_med_reversal']:>+12.4f} {r['relapse_delta']:>+12.4f}")

    # =========================================================================
    # PHASE 4: Iso-Dose Matching
    # =========================================================================
    print_subsection_header("PHASE 4: Iso-Dose Matched Comparisons")

    # Pool the doses of all treatments to find the overall observed range.
    all_doses = []
    for r in ketamine_results:
        all_doses.append(r['dose'].l1_norm)
    for r in ssri_results:
        all_doses.append(r['dose'].l1_norm)
    for r in neurosteroid_results:
        all_doses.append(r['dose'].l1_norm)

    dose_range = (min(all_doses), max(all_doses))
    print(f"\n  Observed dose range: {dose_range[0]:.6f} - {dose_range[1]:.6f}")

    # Keep only configured targets that fall inside the observed range; if
    # none do, fall back to the quartiles of the pooled doses.
    target_doses = CONFIG['iso_dose_target_doses']
    valid_targets = [d for d in target_doses if dose_range[0] <= d <= dose_range[1]]

    if not valid_targets:
        dose_percentiles = [25, 50, 75]
        valid_targets = [np.percentile(all_doses, p) for p in dose_percentiles]
        print(f"  Using dose percentiles: {[f'{d:.6f}' for d in valid_targets]}")
    else:
        print(f"  Target doses: {valid_targets}")

    for target_dose in valid_targets:
        print(f"\n  ISO-DOSE TARGET: {target_dose:.6f}")
        print("  " + "=" * 70)

        ket_match = find_iso_dose_params(ketamine_results, target_dose)
        ssri_match = find_iso_dose_params(ssri_results, target_dose)
        neuro_match = find_iso_dose_params(neurosteroid_results, target_dose)

        iso_comparison = {
            'target_dose': target_dose,
            'ketamine': ket_match,
            'ssri': ssri_match,
            'neurosteroid': neuro_match
        }
        # NOTE: the dict key is the float target dose itself.
        results['iso_dose_comparisons'][target_dose] = iso_comparison

        print(f"\n  {'Treatment':<15} {'Param':<15} {'Actual Dose':>12} {'Acute Prsv':>12} {'Relapse Δ':>12} {'Efficiency':>12}")
        print("  " + "-" * 80)

        for name, match in [('Ketamine', ket_match), ('SSRI', ssri_match), ('Neurosteroid', neuro_match)]:
            if match:
                param_str = f"{match['param_name']}={match['param_value']}"
                actual_dose = match['dose'].l1_norm
                acute_p = match['acute_persev']
                relapse_d = match['relapse_delta']

                # Efficiency: perseveration reduction vs. the untreated
                # baseline per unit dose; the epsilon guards a zero dose.
                persev_reduction = untreated_metrics.perseverative_error_rate - acute_p
                efficiency = persev_reduction / (actual_dose + 1e-8)

                print(f"  {name:<15} {param_str:<15} {actual_dose:>12.6f} {acute_p:>12.4f} {relapse_d:>+12.4f} {efficiency:>12.2f}")
            else:
                print(f"  {name:<15} {'N/A':<15} {'N/A':>12} {'N/A':>12} {'N/A':>12} {'N/A':>12}")

    # =========================================================================
    # PHASE 5: Efficiency Analysis
    # =========================================================================
    print_subsection_header("PHASE 5: Treatment Efficiency Analysis")

    print("\n  EFFICIENCY = (Perseveration Reduction) / (L1 Dose)")
    print("  Higher efficiency = better outcome per unit of network change")

    print("\n  KETAMINE EFFICIENCY:")
    print(f"  {'regrow_frac':>12} {'Dose':>12} {'Prsv Reduc':>12} {'Efficiency':>12}")
    print("  " + "-" * 52)
    for r in ketamine_results:
        reduction = untreated_metrics.perseverative_error_rate - r['acute_persev']
        efficiency = reduction / (r['dose'].l1_norm + 1e-8)
        print(f"  {r['param_value']:>12.2f} {r['dose'].l1_norm:>12.6f} {reduction:>12.4f} {efficiency:>12.2f}")

    print("\n  SSRI EFFICIENCY:")
    print(f"  {'epochs':>12} {'Dose':>12} {'Prsv Reduc':>12} {'Efficiency':>12}")
    print("  " + "-" * 52)
    for r in ssri_results:
        reduction = untreated_metrics.perseverative_error_rate - r['acute_persev']
        efficiency = reduction / (r['dose'].l1_norm + 1e-8)
        print(f"  {r['param_value']:>12} {r['dose'].l1_norm:>12.6f} {reduction:>12.4f} {efficiency:>12.2f}")

    print("\n  NEUROSTEROID EFFICIENCY:")
    print(f"  {'strength':>12} {'Dose':>12} {'Prsv Reduc':>12} {'Efficiency':>12}")
    print("  " + "-" * 52)
    for r in neurosteroid_results:
        reduction = untreated_metrics.perseverative_error_rate - r['acute_persev']
        efficiency = reduction / (r['dose'].l1_norm + 1e-8)
        print(f"  {r['param_value']:>12.2f} {r['dose'].l1_norm:>12.6f} {reduction:>12.4f} {efficiency:>12.2f}")

    # =========================================================================
    # PHASE 6: Summary Statistics
    # =========================================================================
    print_subsection_header("PHASE 6: Summary Statistics")

    def compute_sweep_stats(sweep_results):
        """Aggregate one treatment's sweep into ranges, best values and efficiencies."""
        doses = [r['dose'].l1_norm for r in sweep_results]
        turnovers = [r['dose'].synaptic_turnover for r in sweep_results]
        acute_persevs = [r['acute_persev'] for r in sweep_results]
        relapse_deltas = [r['relapse_delta'] for r in sweep_results]

        reductions = [untreated_metrics.perseverative_error_rate - p for p in acute_persevs]
        efficiencies = [r / (d + 1e-8) for r, d in zip(reductions, doses)]

        # "Best" means lowest for both perseveration and relapse delta.
        return {
            'dose_range': (min(doses), max(doses)),
            'turnover_range': (min(turnovers), max(turnovers)),
            'best_acute_persev': min(acute_persevs),
            'best_relapse_delta': min(relapse_deltas),
            'max_efficiency': max(efficiencies),
            'mean_efficiency': np.mean(efficiencies)
        }

    ket_stats = compute_sweep_stats(ketamine_results)
    ssri_stats = compute_sweep_stats(ssri_results)
    neuro_stats = compute_sweep_stats(neurosteroid_results)

    print("\n  TREATMENT SUMMARY:")
    print(f"  {'Metric':<25} {'Ketamine':>15} {'SSRI':>15} {'Neurosteroid':>15}")
    print("  " + "-" * 75)
    print(f"  {'Dose Range (L1)':<25} {ket_stats['dose_range'][0]:.4f}-{ket_stats['dose_range'][1]:.4f}"
          f"   {ssri_stats['dose_range'][0]:.4f}-{ssri_stats['dose_range'][1]:.4f}"
          f"   {neuro_stats['dose_range'][0]:.4f}-{neuro_stats['dose_range'][1]:.4f}")
    print(f"  {'Best Acute Persev':<25} {ket_stats['best_acute_persev']:>15.4f} {ssri_stats['best_acute_persev']:>15.4f} {neuro_stats['best_acute_persev']:>15.4f}")
    print(f"  {'Best Relapse Δ':<25} {ket_stats['best_relapse_delta']:>+15.4f} {ssri_stats['best_relapse_delta']:>+15.4f} {neuro_stats['best_relapse_delta']:>+15.4f}")
    print(f"  {'Max Efficiency':<25} {ket_stats['max_efficiency']:>15.2f} {ssri_stats['max_efficiency']:>15.2f} {neuro_stats['max_efficiency']:>15.2f}")
    print(f"  {'Mean Efficiency':<25} {ket_stats['mean_efficiency']:>15.2f} {ssri_stats['mean_efficiency']:>15.2f} {neuro_stats['mean_efficiency']:>15.2f}")

    results['summary'] = {
        'ketamine': ket_stats,
        'ssri': ssri_stats,
        'neurosteroid': neuro_stats
    }

    # =========================================================================
    # PHASE 7: Detailed Results Table
    # =========================================================================
    print_section_header("DETAILED ISO-DOSE COMPARISON RESULTS", char="═")

    for target_dose, comparison in results['iso_dose_comparisons'].items():
        print(f"\n  TARGET DOSE: {target_dose:.6f}")
        print("  " + "=" * 90)

        print(f"\n  {'Treatment':<15} {'Parameter':<20} {'Dose L1':>10} {'Turnover':>10} {'ΔSparsity':>10} "
              f"{'Acute Prsv':>12} {'Relapse Δ':>12}")
        print("  " + "-" * 90)

        for treatment_name in ['ketamine', 'ssri', 'neurosteroid']:
            match = comparison.get(treatment_name)
            if match:
                param_str = f"{match['param_name']}={match['param_value']}"
                d = match['dose']
                print(f"  {treatment_name.capitalize():<15} {param_str:<20} {d.l1_norm:>10.6f} {d.synaptic_turnover:>10.4f} "
                      f"{d.sparsity_change:>10.4f} {match['acute_persev']:>12.4f} {match['relapse_delta']:>+12.4f}")

        print("\n  OUTCOME COMPARISON AT THIS DOSE LEVEL:")

        # (name, acute perseveration, relapse delta) for every matched treatment.
        treatments_at_dose = []
        for t_name in ['ketamine', 'ssri', 'neurosteroid']:
            m = comparison.get(t_name)
            if m:
                treatments_at_dose.append((t_name, m['acute_persev'], m['relapse_delta']))

        if treatments_at_dose:
            # Lower is better for both criteria.
            best_acute = min(treatments_at_dose, key=lambda x: x[1])
            best_relapse = min(treatments_at_dose, key=lambda x: x[2])

            print(f"    Best acute perseveration: {best_acute[0].capitalize()} ({best_acute[1]:.4f})")
            print(f"    Best relapse resistance:  {best_relapse[0].capitalize()} (Δ = {best_relapse[2]:+.4f})")

            if best_acute[0] == best_relapse[0]:
                print(f"    >> {best_acute[0].upper()} dominates at this dose level")
            else:
                print(f"    >> Trade-off: {best_acute[0].capitalize()} for acute, {best_relapse[0].capitalize()} for durability")

    # =========================================================================
    # FINAL CONCLUSIONS
    # =========================================================================
    print_section_header("ISO-DOSE EXPERIMENT CONCLUSIONS", char="█")

    # NOTE(review): the text below is a fixed string; it is not derived from
    # the results computed in this run.
    print("""
  ISO-DOSE COMPARISON FINDINGS:

  1. DOSE QUANTIFICATION:
     - L1 weight change norm provides mechanism-agnostic dose measurement
     - Ketamine produces highest dose (structural regrowth)
     - SSRI produces moderate dose (gradual weight refinement)
     - Neurosteroid produces lowest dose (minimal weight changes, runtime modulation)

  2. EFFICIENCY ANALYSIS:
     - Efficiency = outcome improvement per unit dose
     - Reveals which mechanism achieves most benefit with least network alteration
     - Critical for understanding treatment optimization

  3. ISO-DOSE MATCHING:
     - At matched dose levels, treatments can be fairly compared
     - Removes bias from arbitrary hyperparameter scaling
     - Reveals inherent mechanism advantages vs parameter-driven effects

  4. CLINICAL IMPLICATIONS:
     - High-efficiency treatments may be preferred when minimizing side effects
     - Low-dose/high-effect treatments may have better safety profiles
     - Dose-response curves guide titration strategies
    """)

    return results


# =============================================================================
# ORIGINAL MULTI-MECHANISM COMPARISON (PRESERVED)
# =============================================================================

def run_multi_mechanism_ocd_experiment(
    device: torch.device,
    seed: int = None,
    verbose: bool = True
) -> Dict[str, Any]:
    """Compare three antidepressant mechanisms in the OCD pruning framework.

    Pipeline:
      1. Train a ``CSTCNetwork`` and over-prune it by magnitude to
         ``CONFIG['ocd_prune_sparsity']``; snapshot the resulting state dict
         and masks as the shared baseline.
      2. Evaluate the untreated baseline with ``compute_ocd_metrics``.
      3. For each treatment (ketamine, ssri, neurosteroid, in that order):
         restore a fresh copy of the baseline, apply the treatment, measure
         dose metrics against the pre-treatment weights, measure acute
         metrics, then apply secondary pruning
         (``CONFIG['relapse_prune_fraction']``) and measure again.
         For the neurosteroid arm only, an extra "off-medication" evaluation
         is made with ``set_inhibition(1.0, False)`` before inhibition is
         restored from ``CONFIG`` for the relapse test.
      4. Print dose, acute-effect and relapse summary tables.

    Args:
        device: Device the models are moved to and evaluated on.
        seed: Seed passed to ``set_seed``; falls back to ``CONFIG['seed']``
            when ``None``.
        verbose: Forwarded to the treatment functions. The report printed by
            this function itself is emitted regardless of this flag.

    Returns:
        Dict with an ``'untreated'`` entry (baseline metrics) and one entry
        per treatment name holding treatment info, dose metrics, acute and
        relapse metrics, improvements over baseline, and ``off_med_*`` values
        (``None`` for treatments without an off-medication test).
    """
    if seed is None:
        seed = CONFIG['seed']
    set_seed(seed)

    print_section_header("MULTI-MECHANISM ANTIDEPRESSANT COMPARISON", char="█")
    print(f"\n  Comparing treatment mechanisms in OCD model (Seed: {seed})")

    train_loader, test_loader, _ = create_rule_switch_dataloaders()

    print_subsection_header("PHASE 1: Creating Shared Pruned Baseline")

    base_model = CSTCNetwork().to(device)
    print("  Training healthy baseline model...")
    train(base_model, train_loader, test_loader, verbose=False)

    base_mgr = CSTCPruningManager(base_model)
    prune_stats = base_mgr.prune_by_magnitude(sparsity=CONFIG['ocd_prune_sparsity'])

    print(f"  Applied developmental over-pruning: {CONFIG['ocd_prune_sparsity']*100:.0f}%")
    print(f"  Achieved sparsity: {prune_stats['achieved_sparsity']*100:.1f}%")

    # Snapshots so every treatment arm starts from the identical pruned network.
    base_state = copy.deepcopy(base_model.state_dict())
    base_masks = copy.deepcopy(base_mgr.masks)

    results = {}

    print_subsection_header("PHASE 2: Untreated Baseline Evaluation")

    untreated_metrics = compute_ocd_metrics(base_model, test_loader, device)
    results['untreated'] = {
        'sparsity': untreated_metrics.sparsity,
        'accuracy': untreated_metrics.accuracy,
        'persev': untreated_metrics.perseverative_error_rate,
        'switch_cost': untreated_metrics.switch_cost,
        'flex_index': untreated_metrics.flexibility_index,
        'repetition_rate': untreated_metrics.repetition_rate,
        'trials_to_recover': untreated_metrics.trials_to_recover
    }

    print(f"  UNTREATED OCD STATE:")
    print(f"    Sparsity:              {untreated_metrics.sparsity*100:.1f}%")
    print(f"    Accuracy:              {untreated_metrics.accuracy:.4f}")
    print(f"    Perseverative Errors:  {untreated_metrics.perseverative_error_rate:.4f}")
    print(f"    Flexibility Index:     {untreated_metrics.flexibility_index:.4f}")

    def clone_baseline():
        """Return a fresh (model, pruning manager) pair restored to the pruned baseline."""
        model = CSTCNetwork().to(device)
        model.load_state_dict(copy.deepcopy(base_state))
        mgr = CSTCPruningManager(model)
        mgr.masks = copy.deepcopy(base_masks)
        mgr.apply_masks()
        return model, mgr

    # Ordered dispatch table: insertion order defines the order of the arms
    # and of the rows in the summary tables.
    treatment_fns = {
        'ketamine': ketamine_treatment_ocd,
        'ssri': ssri_treatment_ocd,
        'neurosteroid': neurosteroid_treatment_ocd,
    }
    treatments = list(treatment_fns)

    # Guards the efficiency ratio against a zero L1 dose.
    dose_epsilon = 1e-8
    # Efficiency is computed once per treatment and reused in the summary
    # table so the per-arm report and the table cannot disagree.
    efficiencies = {}

    for treatment in treatments:
        print_subsection_header(f"PHASE 3: {treatment.upper()} Treatment")

        model, mgr = clone_baseline()

        # Pre-treatment weights, used as the reference for dose metrics.
        pre_state = {n: p.data.clone() for n, p in model.named_parameters() if 'weight' in n}

        treatment_info = treatment_fns[treatment](
            model, mgr, train_loader, test_loader, verbose=verbose
        )

        dose_metrics = compute_all_dose_metrics(pre_state, model)

        acute_metrics = compute_ocd_metrics(model, test_loader, device)

        off_med_metrics = None
        if treatment == 'neurosteroid':
            # Evaluate with inhibition switched off, then restore the
            # configured inhibition before the relapse simulation.
            model.set_inhibition(1.0, False)
            off_med_metrics = compute_ocd_metrics(model, test_loader, device)
            model.set_inhibition(
                CONFIG['comparison_neurosteroid_strength'],
                CONFIG['comparison_neurosteroid_use_tanh']
            )
        # Explicit None test: do not rely on the metrics object's truthiness.
        has_off_med = off_med_metrics is not None

        pre_relapse_persev = acute_metrics.perseverative_error_rate
        pre_relapse_flex = acute_metrics.flexibility_index
        pre_relapse_acc = acute_metrics.accuracy

        mgr.secondary_prune(fraction=CONFIG['relapse_prune_fraction'])

        relapse_metrics = compute_ocd_metrics(model, test_loader, device)

        # Sign convention: positive deltas mean the model got worse after
        # secondary pruning (more perseveration, less flexibility/accuracy).
        relapse_delta_persev = relapse_metrics.perseverative_error_rate - pre_relapse_persev
        relapse_delta_flex = pre_relapse_flex - relapse_metrics.flexibility_index
        relapse_delta_acc = pre_relapse_acc - relapse_metrics.accuracy

        # Positive improvements mean the treated model beats the untreated one.
        improvement_persev = (
            untreated_metrics.perseverative_error_rate - acute_metrics.perseverative_error_rate
        )
        improvement_flex = acute_metrics.flexibility_index - untreated_metrics.flexibility_index
        improvement_acc = acute_metrics.accuracy - untreated_metrics.accuracy

        results[treatment] = {
            'treatment_info': treatment_info,
            'dose_metrics': {
                'l1_norm': dose_metrics.l1_norm,
                'l2_norm': dose_metrics.l2_norm,
                'synaptic_turnover': dose_metrics.synaptic_turnover,
                'sparsity_change': dose_metrics.sparsity_change
            },
            'acute_sparsity': acute_metrics.sparsity,
            'acute_accuracy': acute_metrics.accuracy,
            'acute_persev': acute_metrics.perseverative_error_rate,
            'acute_flex': acute_metrics.flexibility_index,
            'acute_repetition': acute_metrics.repetition_rate,
            'acute_switch_cost': acute_metrics.switch_cost,
            'acute_recovery': acute_metrics.trials_to_recover,
            'improvement_persev': improvement_persev,
            'improvement_flex': improvement_flex,
            'improvement_acc': improvement_acc,
            'relapse_sparsity': relapse_metrics.sparsity,
            'relapse_persev': relapse_metrics.perseverative_error_rate,
            'relapse_flex': relapse_metrics.flexibility_index,
            'relapse_delta_persev': relapse_delta_persev,
            'relapse_delta_flex': relapse_delta_flex,
            'relapse_delta_acc': relapse_delta_acc,
            'off_med_persev': off_med_metrics.perseverative_error_rate if has_off_med else None,
            'off_med_flex': off_med_metrics.flexibility_index if has_off_med else None,
            'off_med_acc': off_med_metrics.accuracy if has_off_med else None,
        }

        print(f"\n    DOSE METRICS:")
        print(f"      L1 Weight Change:    {dose_metrics.l1_norm:.6f}")
        print(f"      L2 Weight Change:    {dose_metrics.l2_norm:.6f}")
        print(f"      Synaptic Turnover:   {dose_metrics.synaptic_turnover:.4f}")
        print(f"      Sparsity Change:     {dose_metrics.sparsity_change:.4f}")

        print(f"\n    ACUTE EFFECTS:")
        print(f"      Sparsity:            {acute_metrics.sparsity*100:.1f}%")
        print(f"      Accuracy:            {acute_metrics.accuracy:.4f} (Δ = {improvement_acc:+.4f})")
        print(f"      Perseveration:       {acute_metrics.perseverative_error_rate:.4f} (Δ = {-improvement_persev:+.4f})")
        print(f"      Flexibility:         {acute_metrics.flexibility_index:.4f} (Δ = {improvement_flex:+.4f})")

        efficiency = improvement_persev / (dose_metrics.l1_norm + dose_epsilon)
        efficiencies[treatment] = efficiency
        print(f"      Efficiency:          {efficiency:.2f} (persev reduction / dose)")

        if has_off_med:
            print(f"\n    OFF-MEDICATION TEST:")
            print(f"      Perseveration:       {off_med_metrics.perseverative_error_rate:.4f}")
            off_med_delta = off_med_metrics.perseverative_error_rate - acute_metrics.perseverative_error_rate
            print(f"      Reversal:            {off_med_delta:+.4f} perseveration increase")

        print(f"\n    RELAPSE SIMULATION ({CONFIG['relapse_prune_fraction']*100:.0f}% secondary pruning):")
        print(f"      Sparsity:            {relapse_metrics.sparsity*100:.1f}%")
        print(f"      Perseveration:       {relapse_metrics.perseverative_error_rate:.4f} (Δ = {relapse_delta_persev:+.4f})")
        print(f"      Flexibility:         {relapse_metrics.flexibility_index:.4f} (Δ = {-relapse_delta_flex:+.4f})")

    print_section_header("COMPREHENSIVE TREATMENT COMPARISON", char="═")

    print("\n  DOSE COMPARISON:")
    print(f"  {'Treatment':<15} {'L1 Dose':>12} {'L2 Dose':>12} {'Turnover':>12} {'ΔSparsity':>12}")
    print("  " + "-" * 65)
    for treatment in treatments:
        dm = results[treatment]['dose_metrics']
        print(f"  {treatment.capitalize():<15} {dm['l1_norm']:>12.6f} {dm['l2_norm']:>12.6f} "
              f"{dm['synaptic_turnover']:>12.4f} {dm['sparsity_change']:>12.4f}")

    print("\n  ACUTE EFFECTS:")
    print(f"  {'Treatment':<15} {'Sparsity':>10} {'Accuracy':>10} {'Persev':>10} {'Flex':>10} {'Efficiency':>12}")
    print("  " + "-" * 70)
    print(f"  {'Untreated':<15} {results['untreated']['sparsity']*100:>9.1f}% "
          f"{results['untreated']['accuracy']:>10.4f} {results['untreated']['persev']:>10.4f} "
          f"{results['untreated']['flex_index']:>10.4f} {'N/A':>12}")

    for treatment in treatments:
        r = results[treatment]
        print(f"  {treatment.capitalize():<15} {r['acute_sparsity']*100:>9.1f}% "
              f"{r['acute_accuracy']:>10.4f} {r['acute_persev']:>10.4f} "
              f"{r['acute_flex']:>10.4f} {efficiencies[treatment]:>12.2f}")

    print("\n  RELAPSE VULNERABILITY:")
    print(f"  {'Treatment':<15} {'ΔPersev':>12} {'ΔFlexibility':>12} {'Interpretation':<25}")
    print("  " + "-" * 70)

    for treatment in treatments:
        r = results[treatment]
        # Buckets on the perseveration increase caused by secondary pruning.
        if r['relapse_delta_persev'] < 0.02:
            interp = "Relapse resistant"
        elif r['relapse_delta_persev'] < 0.05:
            interp = "Moderate relapse"
        else:
            interp = "High relapse risk"

        print(f"  {treatment.capitalize():<15} {r['relapse_delta_persev']:>+12.4f} "
              f"{-r['relapse_delta_flex']:>+12.4f} {interp:<25}")

    return results


# =============================================================================
# MAIN ENTRY POINT
# =============================================================================

def main():
    """Run the complete validation suite and return both experiments' results.

    Prints a title banner and environment details, calls ``set_seed`` with
    ``CONFIG['seed']``, then runs the multi-mechanism comparison followed by
    the iso-dose comparison on ``DEVICE``.

    Returns:
        Dict with keys ``'multi_mechanism'`` and ``'iso_dose'`` holding the
        respective experiment results.
    """
    full_bar = "█" * 80
    hollow_row = "█" + " " * 78 + "█"
    title_rows = (
        "COMPUTATIONAL MODEL: OCD SYNAPTIC PRUNING HYPOTHESIS",
        "WITH ISO-DOSE FAIR COMPARISON PIPELINE",
    )

    # Opening banner: 80-column frame with centered title rows.
    print("\n" + full_bar)
    print(hollow_row)
    for title in title_rows:
        print("█" + title.center(78) + "█")
    print(hollow_row)
    print(full_bar)

    # Environment summary.
    print(f"\n  PyTorch Version: {torch.__version__}")
    print(f"  Device: {DEVICE}")
    print(f"  CUDA Available: {torch.cuda.is_available()}")

    set_seed(CONFIG['seed'])

    outcome = {}

    print("\n  Running Multi-Mechanism Comparison (with dose metrics)...")
    outcome['multi_mechanism'] = run_multi_mechanism_ocd_experiment(DEVICE, verbose=True)

    print("\n  Running Iso-Dose Comparison Experiment...")
    outcome['iso_dose'] = run_iso_dose_comparison_experiment(DEVICE, verbose=True)

    # Closing banner.
    print("\n" + full_bar)
    print("█" + "EXPERIMENT COMPLETE".center(78) + "█")
    print(full_bar)

    return outcome


# Script entry point: run the full suite only when this file is executed
# directly. The dict returned by main() is bound to the module-level name
# `results`.
if __name__ == "__main__":
    results = main()


████████████████████████████████████████████████████████████████████████████████
█                                                                              █
█             COMPUTATIONAL MODEL: OCD SYNAPTIC PRUNING HYPOTHESIS             █
█                    WITH ISO-DOSE FAIR COMPARISON PIPELINE                    █
█                                                                              █
████████████████████████████████████████████████████████████████████████████████

  PyTorch Version: 2.9.0+cu126
  Device: cuda
  CUDA Available: True

  Running Multi-Mechanism Comparison (with dose metrics)...

████████████████████████████████████████████████████████████████████████████████
                   MULTI-MECHANISM ANTIDEPRESSANT COMPARISON                    
████████████████████████████████████████████████████████████████████████████████

  Comparing treatment mechanisms in OCD model (Seed: 42)

------------------------------------------------------------
  PHASE 1: Creating Shared Pruned Baseline

# The End