# PRUNE-WITHOUT-REPAIR HYPOTHESIS IRREVERSIBILITY SWEEP EXPERIMENT

# In [None]:
"""
================================================================================
COMPUTATIONAL MODEL: PRUNE-WITHOUT-REPAIR HYPOTHESIS
IRREVERSIBILITY SWEEP EXPERIMENT
================================================================================

VERSION 9.0 - DIAGNOSIS-FREE REVERSIBILITY TESTING
───────────────────────────────────────────────────

This version refactors the neurodevelopmental disorder model into a clean,
diagnosis-free experiment that isolates the core "prune-without-repair"
mechanism as requested.

╔═══════════════════════════════════════════════════════════════════════════════╗
║                         VERSION 9.0 KEY CHANGES                               ║
╠═══════════════════════════════════════════════════════════════════════════════╣
║                                                                               ║
║  1. REMOVED ALL DIAGNOSIS-SPECIFIC CODE                                       ║
║     ────────────────────────────────────                                      ║
║     • Deleted DISORDER_PROTOCOLS entirely                                     ║
║     • Removed ADHD-specific: internal_noise escalation, distractors,         ║
║       impulsivity_penalty, variability_penalty                               ║
║     • Removed SCZ-specific: progressive phases, noise_escalation,            ║
║       catastrophic_threshold                                                 ║
║     • Removed OCD/ASD-specific: rigidity_persistence, restricted_boost       ║
║     • Simplified data generation: pure rule-learning task                    ║
║                                                                               ║
║  2. RENAMED "IQ" → "COGNITIVE INDEX" (CI)                                    ║
║     ──────────────────────────────────────                                    ║
║     • All variables: iq → ci, composite_iq → composite_ci                    ║
║     • IQMetrics → CIMetrics                                                  ║
║     • Healthy baseline anchored at ~115 (arbitrary proxy)                    ║
║                                                                               ║
║  3. NEW CORE EXPERIMENT: IRREVERSIBILITY SWEEP                                ║
║     ──────────────────────────────────────────                                ║
║     • Fixed pruning severity (calibration_factor=1.8, base_sparsity=0.95)   ║
║     • Fixed repair_factor=0.5 (impaired plasticity)                          ║
║     • Sweep irreversibility_factor = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]         ║
║     • Multi-seed averaging for robustness (default n=10)                     ║
║                                                                               ║
║  4. HYPOTHESIS TESTED                                                         ║
║     ─────────────────────                                                     ║
║     When pruning is excessive but FULLY REVERSIBLE (low irreversibility),   ║
║     treatment can largely restore performance ("prune-with-repair").        ║
║                                                                               ║
║     When pruning damage is PARTIALLY/FULLY IRREVERSIBLE (high irrev.),      ║
║     recovery is blocked or minimal ("prune-without-repair"), even with      ║
║     plasticity-enhancing treatment.                                          ║
║                                                                               ║
║  EXPECTED GRADIENT:                                                           ║
║  ┌─────────────────────────────────────────────────────────────────────────┐ ║
║  │ Irreversibility │  Recovery   │  Interpretation                        │ ║
║  │ 0.0 (reversible)│  +20-30 pts │  Near-full recovery (prune-with-repair)│ ║
║  │ 0.2             │  +15-20 pts │  Good but incomplete                   │ ║
║  │ 0.4             │  +8-15 pts  │  Moderate recovery                     │ ║
║  │ 0.6             │  +2-8 pts   │  Minimal recovery                      │ ║
║  │ 0.8             │  0 to +3 pts│  Near-zero recovery                    │ ║
║  │ 1.0 (permanent) │  ~0 pts     │  No recovery (prune-without-repair)    │ ║
║  └─────────────────────────────────────────────────────────────────────────┘ ║
║                                                                               ║
╚═══════════════════════════════════════════════════════════════════════════════╝

Author: Computational Psychiatry Research
Date: January 2026
Version: 9.0 (Diagnosis-Free Irreversibility Sweep)
================================================================================
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from typing import List, Tuple, Optional, Dict, Any
from dataclasses import dataclass, field
from enum import Enum
import warnings
from collections import defaultdict

warnings.filterwarnings('ignore')


# =============================================================================
# GLOBAL STATE FOR CLINICAL CALIBRATION
# =============================================================================
#
# ┌──────────────────────────────────────────────────────────────────────────────┐
# │ ANNOTATION: Global calibration state for consistent CI scoring              │
# │                                                                              │
# │ PURPOSE:                                                                     │
# │ ─────────                                                                    │
# │ Cognitive Index (CI) scoring requires a REFERENCE POINT (healthy baseline). │
# │ This global state ensures:                                                   │
# │                                                                              │
# │ 1. CONSISTENT NORMALIZATION                                                  │
# │    All conditions are measured relative to the same healthy baseline.       │
# │    Formula: CI = 115 + 15 × (raw - healthy_raw) / population_sd             │
# │                                                                              │
# │ 2. RECOVERY CEILING ENFORCEMENT                                              │
# │    Treated models cannot exceed healthy CI (prevents unrealistic gains).    │
# │    max_recovery_ci = healthy_ci × recovery_ceiling_factor                   │
# │                                                                              │
# │ 3. DOMAIN-SPECIFIC BASELINES                                                 │
# │    Track fluid, crystallized, executive separately for profiling.           │
# │                                                                              │
# │ INITIALIZATION SEQUENCE:                                                     │
# │ ────────────────────────                                                     │
# │ 1. Train healthy model → achieve stable performance                         │
# │ 2. Compute raw composite and domain scores                                  │
# │ 3. Store as calibration anchor (healthy_raw_composite)                      │
# │ 4. All subsequent CI calculations reference this anchor                     │
# │ 5. Set calibrated=True to activate anchored normalization                   │
# └──────────────────────────────────────────────────────────────────────────────┘
# =============================================================================

@dataclass
class GlobalCalibrationState:
    """
    Reference values that anchor raw scores to the Cognitive Index (CI) scale.

    One module-level instance (``GLOBAL_CALIBRATION``) is created right after
    this class. Every field has a default, so the instance is usable before
    any healthy model has been trained; ``calibrated`` records whether the
    defaults have been replaced by measured values.

    Fields (declaration order is significant for positional construction):

        healthy_ci:             CI assigned to the healthy reference model.
        healthy_raw_composite:  Raw composite score (0-1 scale) of the healthy
                                model; the anchor point for normalization.
        population_sd_raw:      Spread of raw scores treated as one standard
                                deviation when converting to CI points.
        calibrated:             False until calibration has been performed.
        max_recovery_ci:        Upper bound on post-treatment CI. The default
                                (117.0) sits slightly above the default
                                ``healthy_ci`` (115.0).
        healthy_fluid,
        healthy_crystallized,
        healthy_executive:      Per-domain raw scores of the healthy model,
                                kept for profile comparison.
    """
    # --- CI-scale anchor -----------------------------------------------------
    healthy_ci: float = 115.0
    healthy_raw_composite: float = 0.75
    population_sd_raw: float = 0.15

    # --- calibration status and recovery cap ---------------------------------
    calibrated: bool = False
    max_recovery_ci: float = 117.0

    # --- per-domain healthy baselines (raw scale) ----------------------------
    healthy_fluid: float = 0.65
    healthy_crystallized: float = 0.80
    healthy_executive: float = 0.70


# Shared, mutable calibration record for the whole module.
GLOBAL_CALIBRATION = GlobalCalibrationState()


# =============================================================================
# CONFIGURATION
# =============================================================================
#
# ┌──────────────────────────────────────────────────────────────────────────────┐
# │ ANNOTATION: Centralized configuration for all experiment parameters         │
# │                                                                              │
# │ ORGANIZATION:                                                                │
# │ ─────────────                                                                │
# │ • Architecture: Network structure (hidden dims, layers, etc.)               │
# │ • Training: Learning rates, epochs, batch sizes                             │
# │ • Task: Sequence lengths, rule counts, data generation                      │
# │ • Pruning: Sparsity levels, regrowth parameters                            │
# │ • CI Scaling: Normalization parameters for Cognitive Index                  │
# │ • Irreversibility Sweep: Core experiment parameters                         │
# │ • Multi-seed: Reproducibility settings                                      │
# │                                                                              │
# │ KEY PARAMETERS FOR IRREVERSIBILITY SWEEP:                                   │
# │ ──────────────────────────────────────────                                  │
# │ • pruning_calibration_factor = 1.8 (excessive pruning severity)            │
# │ • base_sparsity = 0.95 (target ~90-95% connections pruned)                 │
# │ • repair_factor = 0.5 (impaired plasticity, like NMDA hypofunction)        │
# │ • treatment_regrowth_fraction = 0.3 (fraction of pruned to regrow)         │
# └──────────────────────────────────────────────────────────────────────────────┘
# =============================================================================

# Single source of experiment parameters. Functions in this file read keys
# from this dict at call time (not at import time), so values may be changed
# before an experiment is launched.
CONFIG = {
    # =========================================================================
    # ARCHITECTURE
    # =========================================================================
    'input_dim': 2,                     # 2D input points (x, y coordinates)
    'hidden_dims': [128, 64],           # Two hidden layers
    # apply_extended_rule() reduces its labels modulo this value; the eight
    # rules defined there only emit labels 0-3.
    'output_dim': 8,                    # 8-class classification
    'num_gru_layers': 2,                # Recurrent depth

    # =========================================================================
    # TRAINING
    # =========================================================================
    'batch_size': 32,                   # Default batch size in create_ci_task_dataloaders()
    'baseline_lr': 1e-3,                # Learning rate for initial training
    'finetune_lr': 5e-4,                # Learning rate for fine-tuning
    'baseline_epochs': 50,              # Epochs for initial training
    'regrowth_epochs': 30,              # Epochs after regrowth
    'consolidation_epochs_default': 20, # Epochs for treatment consolidation

    # =========================================================================
    # TASK
    # =========================================================================
    # NOTE(review): generate_ci_task_data() defaults to 'ci_task_seq_len'
    # (below), not 'seq_len' -- confirm where 'seq_len' is consumed.
    'seq_len': 200,                     # Sequence length per trial
    'n_train_sequences': 500,           # Training set size (default n_train)
    'n_test_sequences': 100,            # Test set size (default n_test)
    'n_rules': 4,                       # Training rules
    'n_rules_extended': 8,              # Total rules (including novel)

    # =========================================================================
    # PRUNING
    # =========================================================================
    'base_optimal_sparsity': 0.75,      # Optimal sparsity for healthy network
    'regrowth_fraction': 0.50,          # Fraction of pruned to regrow
    'regrowth_init_scale': 0.03,        # Scale for regrown weight init
    'recurrence_bias': 1.2,             # Protection factor for recurrent weights

    # =========================================================================
    # CI TASK
    # =========================================================================
    'ci_task_seq_len': 300,             # Default seq_len of generate_ci_task_data()
    'ci_training_rules': [0, 1, 2, 3],  # Rules used in training
    'ci_novel_rules': [4, 5, 6, 7],     # Novel rules for fluid intelligence
    'ci_noise_levels': [0.0, 0.3, 0.6, 0.9, 1.2],  # Noise robustness sweep
    'ci_multi_step_depth': 3,           # Counter start value for a multi-step episode

    # =========================================================================
    # CI SCALING
    # =========================================================================
    # The target / anchor / SD values below equal the defaults of
    # GlobalCalibrationState (healthy_ci, healthy_raw_composite,
    # population_sd_raw).
    'ci_healthy_target': 115.0,         # Target CI for healthy model
    'ci_population_mean': 100.0,        # Population mean (for reference)
    'ci_population_sd': 15.0,           # Population SD (15 pts per SD)
    'ci_raw_anchor': 0.75,              # Expected raw score for healthy
    'ci_raw_sd': 0.15,                  # Raw score SD

    'ci_domain_weights': {              # Domain weights in composite (sum to 1.0)
        'fluid': 0.35,                  # Novel rule generalization
        'crystallized': 0.30,           # Trained rule accuracy
        'executive': 0.35,              # Multi-step integration
    },

    # =========================================================================
    # IRREVERSIBILITY SWEEP (CORE EXPERIMENT)
    # =========================================================================
    'pruning_calibration_factor': 1.8,  # Excessive pruning severity
    'base_sparsity': 0.95,              # Target sparsity (90-95%)
    'repair_factor': 0.5,               # Impaired plasticity
    'treatment_regrowth_fraction': 0.3, # Fraction to regrow in treatment
    'treatment_consolidation_epochs': 20,      # Same value as 'consolidation_epochs_default'
    'treatment_consolidation_lr_factor': 0.7,  # NOTE(review): presumably scales 'finetune_lr' -- confirm

    # Default irreversibility levels to sweep
    'irreversibility_levels': [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],

    # =========================================================================
    # MULTI-SEED
    # =========================================================================
    'seed': 42,                         # Base random seed
    'n_seeds': 10,                      # Number of seeds for averaging
}

# Use the GPU when PyTorch reports CUDA as available; otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def set_seed(seed: int):
    """
    Seed every random number generator this module draws from.

    Seeds NumPy's global generator and PyTorch's CPU generator. When CUDA is
    available it additionally seeds all CUDA devices and switches cuDNN to
    deterministic mode with autotuning (benchmark) turned off.

    Args:
        seed: Integer seed applied to all generators.

    Note:
        Python's built-in ``random`` module is not seeded here. Some CUDA
        kernels can remain non-deterministic even with these settings.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return

    # GPU path: seed every device and trade cuDNN speed for repeatability.
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# =============================================================================
# PRINTING UTILITIES (ENHANCED ANNOTATIONS)
# =============================================================================

def print_section_header(title: str, width: int = 80, char: str = "="):
    """
    Print a section banner: blank line, rule, centered title, rule.

    Args:
        title: Text centered within ``width`` columns.
        width: Number of times ``char`` is repeated to form each rule.
        char:  String repeated to draw the rules above and below the title.
    """
    rule = char * width
    print("\n" + rule)
    print(title.center(width))
    print(rule)


def print_subsection_header(title: str, width: int = 60, char: str = "-"):
    """
    Print a subsection banner: blank line, rule, indented title, rule.

    Args:
        title: Text printed after a two-space indent (not centered).
        width: Number of times ``char`` is repeated to form each rule.
        char:  String repeated to draw the rules.
    """
    rule = char * width
    print("\n" + rule)
    print("  " + f"{title}")
    print(rule)


def print_annotation(text: str, indent: int = 4, prefix: str = "→"):
    """
    Print one annotated line: ``indent`` spaces, ``prefix``, a space, ``text``.

    Args:
        text:   Message to print.
        indent: Number of leading spaces.
        prefix: Marker placed before the message (an arrow by default).
    """
    lead = "".join((" " * indent, prefix, " "))
    print(lead + f"{text}")


def print_box(lines: List[str], title: Optional[str] = None, width: int = 74):
    """
    Print text in a bordered box for emphasis.

    ┌──────────────────────────────────────────────────────────────────────────┐
    │ TITLE (optional)                                                         │
    │                                                                          │
    │ Line 1 of content                                                        │
    │ Line 2 of content                                                        │
    └──────────────────────────────────────────────────────────────────────────┘

    Args:
        lines: Content rows, printed top to bottom.
        title: Optional heading row; when truthy it is followed by one blank
               row. ``None`` or ``""`` omits both.
        width: Box width in columns, borders included. The box itself is
               indented by four spaces. Text occupies ``width - 4`` columns.

    Any row (title included) longer than ``width - 4`` characters is clipped
    and ends in ``"..."`` so the right-hand border stays aligned.
    """
    inner = width - 4  # text columns between "│ " and " │"

    def _fit(text: str) -> str:
        # Clip to the inner width; max() keeps the slice bound non-negative
        # for very narrow boxes.
        if len(text) > inner:
            return text[:max(width - 7, 0)] + "..."
        return text

    horizontal = "─" * (width - 2)

    print("    ┌" + horizontal + "┐")
    if title:
        # The title goes through the same clipping as content rows; an
        # unclipped long title would push the right border out of line.
        print(f"    │ {_fit(title):<{inner}} │")
        print("    │" + " " * (width - 2) + "│")
    for line in lines:
        print(f"    │ {_fit(line):<{inner}} │")
    print("    └" + horizontal + "┘")


def print_debug(seed: int, label: str, ci: float, raw: float = None, extra: str = ""):
    """
    Print one per-seed debug line.

    Output shape: ``      [DEBUG] Seed 42 Condition: CI=95.3, raw=0.652, extra``

    Args:
        seed:  Seed number shown after "Seed".
        label: Condition name shown before the colon.
        ci:    Cognitive Index, printed with one decimal place.
        raw:   Optional raw score, printed with three decimals; omitted when
               ``None``.
        extra: Optional trailing text; omitted when empty.
    """
    fields = [f"CI={ci:.1f}"]
    if raw is not None:
        fields.append(f"raw={raw:.3f}")
    if extra:
        fields.append(f"{extra}")
    print(f"      [DEBUG] Seed {seed} {label}: " + ", ".join(fields))


def print_table_row(cols: List[str], widths: List[int], sep: str = "│"):
    """
    Print one table row with each cell centered in its column.

    Args:
        cols:   Cell texts, left to right.
        widths: Column widths, paired positionally with ``cols``; surplus
                entries on either side are ignored.
        sep:    Separator drawn at the left edge and after every cell.
    """
    cells = (f" {text:^{span}} {sep}" for text, span in zip(cols, widths))
    print(sep + "".join(cells))


def print_table_separator(widths: List[int], left: str = "├", mid: str = "┼",
                          right: str = "┤", fill: str = "─"):
    """
    Print a horizontal table rule matching the given column widths.

    Each column contributes ``width + 2`` fill characters (one cell of
    padding on each side). Columns are joined with ``mid`` and the rule is
    closed with ``right``.

    Args:
        widths: Column widths; an empty list prints only ``left``.
        left:   Leading junction character.
        mid:    Junction between adjacent columns.
        right:  Closing junction character.
        fill:   Character used for the horizontal run.
    """
    runs = [fill * (span + 2) for span in widths]
    # With no columns there is no closing junction either.
    tail = mid.join(runs) + right if runs else ""
    print(left + tail)


# =============================================================================
# EXTENDED RULE DEFINITIONS (SIMPLIFIED - NO DISORDER-SPECIFIC TASKS)
# =============================================================================

class ExtendedRule(Enum):
    """
    Names for the eight integer rule ids accepted by ``apply_extended_rule``.

    Ids 0-3 match CONFIG['ci_training_rules']; ids 4-7 match
    CONFIG['ci_novel_rules'].

    Labels as computed by ``apply_extended_rule`` ([c] is 1 when c holds,
    otherwise 0; "bands" means the count of listed thresholds reached):

        X_SIGN   (0): 2*[x >= 0] + [y >= 0]
        Y_SIGN   (1): 2*[y >= 0] + [x >= 0]
        QUADRANT (2): [x >= 0] + 2*[y >= 0]
        DIAGONAL (3): 2*[y >= x] + [y >= -x]
        DISTANCE (4): bands of sqrt(x^2 + y^2) at 1, 2, 3
        ANGLE    (5): bands of atan2(y, x) at -pi/2, 0, pi/2
        SUM      (6): bands of x + y at -1, 0, 1
        PRODUCT  (7): bands of x * y at -1, 0, 1

    NOTE(review): the Y_SIGN and QUADRANT formulas are the same expression,
    so those two rules label every point identically -- confirm intended.
    """
    # Rules seen during training
    X_SIGN = 0
    Y_SIGN = 1
    QUADRANT = 2
    DIAGONAL = 3

    # Held-out rules used to probe generalization
    DISTANCE = 4
    ANGLE = 5
    SUM = 6
    PRODUCT = 7


def apply_extended_rule(points: torch.Tensor, rule: int) -> torch.Tensor:
    """
    Label 2D points according to one of the eight classification rules.

    Args:
        points: Tensor whose last dimension holds (x, y); any leading shape.
        rule:   Rule id, 0-7 (see ``ExtendedRule``).

    Returns:
        Long tensor with the leading shape of ``points``. Every rule yields
        a value in 0-3, which is then reduced modulo CONFIG['output_dim'].

    Raises:
        ValueError: If ``rule`` is not one of 0-7.
    """
    xs = points[..., 0]
    ys = points[..., 1]

    def _band_index(values: torch.Tensor, cuts) -> torch.Tensor:
        # Count how many of the ascending cut-points each value reaches.
        # For ascending cuts this equals overwriting a zero tensor with
        # 1, 2, 3 at successively higher thresholds.
        bands = torch.zeros_like(xs, dtype=torch.long)
        for cut in cuts:
            bands = bands + (values >= cut).long()
        return bands

    # --- two-bit rules: each label packs two half-plane tests ---------------
    # NOTE(review): rules 1 and 2 evaluate to the same expression, so
    # Y_SIGN and QUADRANT label every point identically -- confirm intended.
    if rule == 0:    # X_SIGN: x test is the high bit
        labels = 2 * (xs >= 0).long() + (ys >= 0).long()
    elif rule == 1:  # Y_SIGN: y test is the high bit
        labels = 2 * (ys >= 0).long() + (xs >= 0).long()
    elif rule == 2:  # QUADRANT
        labels = (xs >= 0).long() + 2 * (ys >= 0).long()
    elif rule == 3:  # DIAGONAL: side of y = x (high bit) and of y = -x (low bit)
        above_main = (ys >= xs).long()
        above_anti = (ys >= -xs).long()
        labels = 2 * above_main + above_anti
    # --- banded rules: a scalar feature cut at three thresholds -------------
    elif rule == 4:  # DISTANCE from the origin
        labels = _band_index(torch.sqrt(xs**2 + ys**2), (1.0, 2.0, 3.0))
    elif rule == 5:  # ANGLE from the positive x-axis
        labels = _band_index(torch.atan2(ys, xs), (-np.pi/2, 0, np.pi/2))
    elif rule == 6:  # SUM of coordinates
        labels = _band_index(xs + ys, (-1, 0, 1))
    elif rule == 7:  # PRODUCT of coordinates
        labels = _band_index(xs * ys, (-1, 0, 1))
    else:
        raise ValueError(f"Unknown rule: {rule}")

    return labels % CONFIG['output_dim']


# =============================================================================
# DATA GENERATION (SIMPLIFIED - NO DISTRACTORS)
# =============================================================================

def generate_ci_task_data(
    n_sequences: int,
    seq_len: int = None,
    training_rules: List[int] = None,
    include_novel: bool = True,
    noise_level: float = 0.0,
    include_multi_step: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
    Generate CI proxy task data (simplified, no distractors).

    Each sequence is a run of random 2D points labelled by a rule that
    occasionally switches.

    Args:
        n_sequences:        Number of sequences to generate.
        seq_len:            Timesteps per sequence; defaults to
                            CONFIG['ci_task_seq_len'].
        training_rules:     Rule ids treated as trained; defaults to
                            CONFIG['ci_training_rules'].
        include_novel:      When True, CONFIG['ci_novel_rules'] are added to
                            the pool the active rule is drawn from.
        noise_level:        Std-dev of Gaussian noise added to the points.
        include_multi_step: When True, random multi-step episodes are tagged.

    Returns:
        (data, labels, trial_types, metadata) where
        data is float (n_sequences, seq_len, 2), labels and trial_types are
        long (n_sequences, seq_len), and metadata records the training
        rules, the novel rules actually in play, and the noise level.

    Raises:
        ValueError: Propagated from ``apply_extended_rule`` for a rule id
            outside 0-7.

    Task structure:
        • Points are drawn from N(0, 1.5^2) per coordinate.
        • From timestep 11 onward the rule is redrawn with 2% probability
          per step.
        • A multi-step episode starts with 10% probability on any step
          where none is running and tags exactly
          CONFIG['ci_multi_step_depth'] consecutive steps.

    Trial type encoding (bitfield):
        Bit 0 (value 1): Novel rule trial
        Bit 1 (value 2): Multi-step integration trial
        Bit 2 (value 4): High noise trial (noise_level > 0.5)
    """
    if seq_len is None:
        seq_len = CONFIG['ci_task_seq_len']
    if training_rules is None:
        training_rules = CONFIG['ci_training_rules']

    novel_rules = CONFIG['ci_novel_rules']
    all_rules = list(training_rules) + (list(novel_rules) if include_novel else [])

    # Loop invariants, hoisted out of the per-timestep loop.
    novel_rule_set = set(novel_rules)
    multi_step_depth = CONFIG['ci_multi_step_depth']
    high_noise_bit = 4 if noise_level > 0.5 else 0

    all_data, all_labels, all_trial_types = [], [], []

    for _ in range(n_sequences):
        # Generate random 2D points
        points = torch.randn(seq_len, 2) * 1.5

        # NOTE(review): noise is added BEFORE labels are computed, so labels
        # always agree with the noisy inputs the model sees; noise widens the
        # input distribution but never produces an input/label mismatch.
        # Confirm this is the intended notion of "noise robustness".
        if noise_level > 0:
            points = points + torch.randn_like(points) * noise_level

        rule_per_step = []
        type_per_step = []

        current_rule = int(np.random.choice(all_rules))
        steps_remaining = 0  # > 0 while a multi-step episode is running

        for t in range(seq_len):
            # Occasional rule switch. The random draw is taken on every
            # step (before the t > 10 test) to keep the RNG stream fixed.
            if np.random.random() < 0.02 and t > 10:
                current_rule = int(np.random.choice(all_rules))
            rule_per_step.append(current_rule)

            # Multi-step activation; the draw is taken whenever
            # include_multi_step is set, even if an episode is running.
            if include_multi_step and np.random.random() < 0.1 and steps_remaining <= 0:
                steps_remaining = multi_step_depth

            # Encode trial type
            trial_type = high_noise_bit
            if current_rule in novel_rule_set:
                trial_type |= 1
            if steps_remaining > 0:
                # Tag first, then count down, so an episode spans exactly
                # `multi_step_depth` steps (decrementing first would drop
                # the final step and tag nothing at depth 1).
                trial_type |= 2
                steps_remaining -= 1
            type_per_step.append(trial_type)

        rules = torch.tensor(rule_per_step, dtype=torch.long)
        trial_types = torch.tensor(type_per_step, dtype=torch.long)

        # Generate labels: one vectorised call per distinct rule instead of
        # one call per timestep. Labels depend only on (point, rule), so the
        # result is identical.
        labels = torch.zeros(seq_len, dtype=torch.long)
        for rule_id in sorted(set(rule_per_step)):
            at_rule = rules == rule_id
            labels[at_rule] = apply_extended_rule(points[at_rule], rule_id)

        all_data.append(points)
        all_labels.append(labels)
        all_trial_types.append(trial_types)

    metadata = {
        'training_rules': training_rules,
        'novel_rules': novel_rules if include_novel else [],
        'noise_level': noise_level,
    }

    return torch.stack(all_data), torch.stack(all_labels), torch.stack(all_trial_types), metadata


def create_ci_task_dataloaders(
    n_train: int = None,
    n_test: int = None,
    noise_level: float = 0.0,
    include_novel: bool = False,
    batch_size: int = None
) -> Tuple[DataLoader, DataLoader, Dict[str, Any]]:
    """
    Build the train and test loaders for CI assessment.

    Args:
        n_train:       Training sequences; falsy values fall back to
                       CONFIG['n_train_sequences'].
        n_test:        Test sequences; falsy values fall back to
                       CONFIG['n_test_sequences'].
        noise_level:   Noise applied to the TEST set only.
        include_novel: Whether the TEST set may draw novel rules.
        batch_size:    Batch size for both loaders; falsy values fall back
                       to CONFIG['batch_size'].

    Returns:
        (train_loader, test_loader, test_metadata). Each batch is a
        (points, labels, trial_types) triple.

    The training set is always generated without novel rules and without
    noise, and is shuffled. The test set honours ``include_novel`` and
    ``noise_level`` and keeps its generation order.
    """
    if not n_train:
        n_train = CONFIG['n_train_sequences']
    if not n_test:
        n_test = CONFIG['n_test_sequences']
    if not batch_size:
        batch_size = CONFIG['batch_size']

    # Training data is generated first, then test data, so the shared random
    # streams are consumed in a fixed order.
    train_bundle = generate_ci_task_data(
        n_train, include_novel=False, noise_level=0.0
    )
    test_bundle = generate_ci_task_data(
        n_test, include_novel=include_novel, noise_level=noise_level
    )

    # Each bundle is (data, labels, trial_types, metadata); only the three
    # tensors go into the dataset.
    train_set = TensorDataset(*train_bundle[:3])
    test_set = TensorDataset(*test_bundle[:3])

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader, test_bundle[3]


# =============================================================================
# CI METRICS DATACLASS
# =============================================================================

@dataclass
class CIMetrics:
    """
    Comprehensive Cognitive Index metrics (renamed from IQMetrics).

    Plain record of float fields grouped by cognitive domain, model state and
    treatment response. Every field has a default, so ``CIMetrics()`` yields an
    all-defaults record: 0.0 everywhere except composite_ci (100.0) and the
    calibration/glutamate/repair factors (1.0). The class itself performs no
    computation; the values are filled in by code outside this definition.

    ┌──────────────────────────────────────────────────────────────────────────┐
    │ METRIC CATEGORIES:                                                       │
    │                                                                          │
    │ FLUID INTELLIGENCE:                                                      │
    │ • fluid_accuracy: Performance on novel rules (generalization)           │
    │ • fluid_generalization_gap: Trained - novel accuracy                    │
    │ • fluid_transfer_efficiency: Novel/trained ratio                        │
    │                                                                          │
    │ CRYSTALLIZED INTELLIGENCE:                                               │
    │ • crystallized_accuracy: Performance on trained rules                   │
    │ • crystallized_stability: Consistency across trials                     │
    │                                                                          │
    │ EXECUTIVE FUNCTION:                                                      │
    │ • executive_multi_step: Performance on integration trials               │
    │ • executive_flexibility: Adaptation to rule switches                    │
    │                                                                          │
    │ NOISE ROBUSTNESS:                                                        │
    │ • noise_robustness: Performance retention under noise                   │
    │ • noise_slope: Rate of performance decline with noise                   │
    │                                                                          │
    │ COMPOSITE CI:                                                            │
    │ • raw_composite: Weighted average of domain scores (0-1)                │
    │ • composite_ci: Scaled Cognitive Index (~55-145)                        │
    │                                                                          │
    │ MODEL STATE:                                                             │
    │ • sparsity: Fraction of zero weights                                    │
    │ • irreversibility_factor: Fraction of permanent synaptic loss           │
    │                                                                          │
    │ TREATMENT RESPONSE:                                                      │
    │ • pre_treatment_ci: CI before intervention                              │
    │ • post_treatment_ci: CI after intervention                              │
    │ • recovery_delta: Post - Pre CI                                         │
    └──────────────────────────────────────────────────────────────────────────┘

    Fields not listed in the box (the *_z_score values, ci_confidence,
    calibration_factor, glutamate_factor, repair_factor and
    irreversible_fraction) are annotated inline next to their declarations.
    """
    # Fluid Intelligence
    fluid_accuracy: float = 0.0
    fluid_generalization_gap: float = 0.0
    fluid_transfer_efficiency: float = 0.0
    # presumably the standardized fluid-domain score — TODO confirm against the code that fills it
    fluid_z_score: float = 0.0

    # Crystallized Intelligence
    crystallized_accuracy: float = 0.0
    crystallized_stability: float = 0.0
    # presumably the standardized crystallized-domain score — TODO confirm
    crystallized_z_score: float = 0.0

    # Executive Function
    executive_multi_step: float = 0.0
    executive_flexibility: float = 0.0
    # presumably the standardized executive-domain score — TODO confirm
    executive_z_score: float = 0.0

    # Noise Robustness
    noise_robustness: float = 0.0
    noise_slope: float = 0.0
    # NOTE(review): named "speed" but grouped with the noise metrics; meaning
    # cannot be determined from this definition — verify where it is assigned.
    speed_z_score: float = 0.0

    # Composite CI
    raw_composite: float = 0.0
    composite_ci: float = 100.0  # the only score whose default is not 0.0
    # NOTE(review): units/range of the confidence value are not visible here
    ci_confidence: float = 0.0

    # Model State
    sparsity: float = 0.0
    # The three factors below default to 1.0, the same defaults used by the
    # StressAwareNetwork constructor arguments of the same names; presumably
    # they snapshot the evaluated model's settings — TODO confirm.
    calibration_factor: float = 1.0
    glutamate_factor: float = 1.0
    repair_factor: float = 1.0
    # NOTE(review): two similarly named fields. CSTCPruningManager uses
    # "irreversibility_factor" for the requested share of pruned weights to
    # lock and "irreversible_fraction" for the measured share of all tracked
    # weights that are locked; presumably these fields mirror that split.
    irreversibility_factor: float = 0.0
    irreversible_fraction: float = 0.0

    # Treatment Response
    pre_treatment_ci: float = 0.0
    post_treatment_ci: float = 0.0
    recovery_delta: float = 0.0


# =============================================================================
# STRESS-AWARE NETWORK (SIMPLIFIED)
# =============================================================================

class StressAwareNetwork(nn.Module):
    """
    Recurrent classifier whose first and last linear layers can be masked.

    DATA FLOW:
        x → input_fc → ReLU → (optional glutamate scaling) → GRU → output_fc

    Layer sizes come from the constructor arguments; any argument left as
    None is read from the module-level CONFIG. The output width always comes
    from CONFIG['output_dim'].

    STORED SETTINGS:
        • glutamate_factor   – multiplier applied to the post-ReLU activations
                               whenever it differs from 1.0
        • calibration_factor – kept on the module; not used in forward()
        • repair_factor      – kept on the module; not used in forward()

    MASKS:
        input_mask and output_mask are registered buffers (saved with the
        module state, but not trainable parameters) shaped like the weights
        of input_fc and output_fc. forward() multiplies each weight matrix by
        its mask before applying the linear map. The GRU weights have no
        forward-pass mask.
    """

    def __init__(
        self,
        hidden_dims: List[int] = None,
        num_layers: int = None,
        glutamate_factor: float = 1.0,
        input_dim: int = None,
        calibration_factor: float = 1.0,
        repair_factor: float = 1.0
    ):
        super().__init__()

        # Anything left unspecified falls back to the global configuration.
        hidden_dims = CONFIG['hidden_dims'] if hidden_dims is None else hidden_dims
        num_layers = CONFIG['num_gru_layers'] if num_layers is None else num_layers
        input_dim = CONFIG['input_dim'] if input_dim is None else input_dim

        fc_width, gru_width = hidden_dims[0], hidden_dims[1]

        self.input_dim = input_dim
        self.hidden_dim = gru_width
        self.num_layers = num_layers
        self.glutamate_factor = glutamate_factor
        self.calibration_factor = calibration_factor
        self.repair_factor = repair_factor

        # Layers (inter-layer GRU dropout only exists for stacked GRUs).
        gru_dropout = 0.1 if num_layers > 1 else 0.0
        self.input_fc = nn.Linear(input_dim, fc_width)
        self.gru = nn.GRU(
            input_size=fc_width,
            hidden_size=gru_width,
            num_layers=num_layers,
            batch_first=True,
            dropout=gru_dropout
        )
        self.output_fc = nn.Linear(gru_width, CONFIG['output_dim'])
        self.relu = nn.ReLU()

        # All-ones masks: nothing is pruned until a manager edits them.
        self.register_buffer('input_mask', torch.ones_like(self.input_fc.weight))
        self.register_buffer('output_mask', torch.ones_like(self.output_fc.weight))

    def init_hidden(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return an all-zero GRU state of shape (num_layers, batch_size, hidden_dim)."""
        state_shape = (self.num_layers, batch_size, self.hidden_dim)
        return torch.zeros(*state_shape, device=device)

    def forward(self, x: torch.Tensor, hidden: Optional[torch.Tensor] = None,
                return_hidden: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run the masked network on a sequence or on a single time step.

        A 2-D input is treated as one time step: a length-1 time axis is
        inserted before the GRU and removed from the logits afterwards.

        Args:
            x: 3-D batch-first sequence input, or 2-D single-step input.
            hidden: initial GRU state; zeros are used when omitted.
            return_hidden: when True the final GRU state is returned too.

        Returns:
            (logits, final_state) if return_hidden else (logits, None).
        """
        squeeze_time = x.dim() == 2
        if squeeze_time:
            x = x.unsqueeze(1)

        n_batch, _, _ = x.shape
        if hidden is None:
            hidden = self.init_hidden(n_batch, x.device)

        # Stage 1: masked input projection followed by ReLU.
        projected = F.linear(x, self.input_fc.weight * self.input_mask, self.input_fc.bias)
        activations = self.relu(projected)

        # Stage 2: optional excitability scaling.
        if self.glutamate_factor != 1.0:
            activations = activations * self.glutamate_factor

        # Stage 3: recurrent processing.
        sequence_out, hidden = self.gru(activations, hidden)

        # Stage 4: masked read-out.
        logits = F.linear(
            sequence_out,
            self.output_fc.weight * self.output_mask,
            self.output_fc.bias
        )
        if squeeze_time:
            logits = logits.squeeze(1)

        if return_hidden:
            return logits, hidden
        return logits, None

    def get_sparsity(self) -> float:
        """
        Fraction of near-zero entries (|w| < 1e-8) across every parameter
        whose name contains 'weight'.

        Returns:
            float: Sparsity in range [0, 1]; 0.0 if there are no such parameters.
        """
        weight_tensors = [p for name, p in self.named_parameters() if 'weight' in name]
        n_total = sum(p.numel() for p in weight_tensors)
        n_zero = sum((p.abs() < 1e-8).sum().item() for p in weight_tensors)
        return n_zero / n_total if n_total > 0 else 0.0


# =============================================================================
# PRUNING MANAGER
# =============================================================================

class CSTCPruningManager:
    """
    Pruning manager with irreversibility support.

    MASK VALUES (one mask tensor per parameter whose name contains 'weight'):
        mask =  1.0 : active connection (not pruned)
        mask =  0.0 : pruned but REVERSIBLE (eligible for regrowth)
        mask = -1.0 : pruned and IRREVERSIBLE (never regrown while
                      respect_irreversibility=True)

    The masks live in ``self.masks``. For 'input_fc.weight' and
    'output_fc.weight' a 0/1 copy (negatives clamped to 0) is also written to
    the model's ``input_mask`` / ``output_mask`` buffers, which the model
    applies in its forward pass.

    KEY METHODS:
        • calibrate_prune(): reversible magnitude pruning at base × factor
        • excessive_prune(): magnitude pruning, then lock a share as -1
        • gradient_guided_regrow(): restore reversible connections
        • get_irreversible_fraction(): share of tracked weights marked -1
        • apply_masks(): re-zero pruned weights (e.g. after an optimizer step)

    Irreversible (-1) marks survive any later call to prune_by_magnitude().
    """

    def __init__(self, model: StressAwareNetwork):
        """Attach to *model* and snapshot its current weights for later regrowth."""
        self.model = model
        self.original_weights = {}
        self.masks = {}
        self.history = []
        self._save_original_weights()

    def _save_original_weights(self):
        """Clone every '*weight*' parameter and start it with an all-ones mask."""
        for name, param in self.model.named_parameters():
            if 'weight' in name:
                self.original_weights[name] = param.data.clone()
                self.masks[name] = torch.ones_like(param.data)

    def get_sparsity(self) -> float:
        """Get current network sparsity (delegates to the model)."""
        return self.model.get_sparsity()

    def get_irreversible_fraction(self) -> float:
        """
        Get fraction of tracked weights that are irreversibly pruned.

        Computed as (number of mask entries < 0) / (number of mask entries)
        over all tensors in ``self.masks``; 0.0 when no masks are tracked.
        """
        total_params = 0
        irreversible_params = 0
        for mask in self.masks.values():
            total_params += mask.numel()
            irreversible_params += (mask < 0).sum().item()
        return irreversible_params / total_params if total_params > 0 else 0.0

    def calibrate_prune(
        self,
        base_sparsity: float = None,
        calibration_factor: float = None
    ) -> Dict[str, Any]:
        """
        Apply calibrated pruning (standard, reversible).

        Target sparsity is ``base_sparsity * calibration_factor`` clamped to
        [0, 0.99]. Missing arguments default to
        CONFIG['base_optimal_sparsity'] and the model's calibration_factor.
        The factor used is written back to ``model.calibration_factor``.

        Returns:
            The prune_by_magnitude() result extended with
            'calibration_factor' and 'effective_sparsity'.
        """
        if base_sparsity is None:
            base_sparsity = CONFIG['base_optimal_sparsity']
        if calibration_factor is None:
            calibration_factor = self.model.calibration_factor

        effective_sparsity = max(0.0, min(0.99, base_sparsity * calibration_factor))
        self.model.calibration_factor = calibration_factor

        result = self.prune_by_magnitude(sparsity=effective_sparsity)
        result['calibration_factor'] = calibration_factor
        result['effective_sparsity'] = effective_sparsity

        self.history.append({
            'operation': 'calibrate_prune',
            'calibration_factor': calibration_factor,
            'achieved_sparsity': result['achieved_sparsity']
        })

        return result

    def excessive_prune(
        self,
        base_sparsity: float = 0.95,
        irreversibility_factor: float = 0.3,
        calibration_factor: float = 1.5
    ) -> Dict[str, Any]:
        """
        Apply EXCESSIVE pruning with IRREVERSIBILITY.

        PROCESS:
            1. Magnitude-prune to min(0.99, base_sparsity × calibration_factor).
            2. In every mask, pick a random subset of the reversible entries
               (mask == 0) of size int(irreversibility_factor × count) and
               mark them -1. Entries already at -1 are left untouched.

        Args:
            base_sparsity: target fraction of weights to prune.
            irreversibility_factor: share of reversible pruned weights to
                lock; values outside [0, 1] are clamped for the selection.
            calibration_factor: multiplier applied to base_sparsity.

        Returns:
            The prune_by_magnitude() result extended with
            'irreversibility_factor' (as passed), 'irreversible_count'
            (entries locked by this call) and 'irreversible_fraction'.
        """
        effective_sparsity = min(0.99, base_sparsity * calibration_factor)
        result = self.prune_by_magnitude(sparsity=effective_sparsity)

        # Clamp so a factor > 1 cannot report more locks than candidates.
        lock_share = min(1.0, max(0.0, irreversibility_factor))

        total_locked = 0
        for name, mask in self.masks.items():
            reversible_indices = torch.where((mask == 0).flatten())[0]
            n_to_lock = int(lock_share * reversible_indices.numel())
            if n_to_lock == 0:
                continue

            shuffled = torch.randperm(len(reversible_indices))
            lock_indices = reversible_indices[shuffled[:n_to_lock]]
            flat_mask = mask.flatten()
            flat_mask[lock_indices] = -1.0
            self.masks[name] = flat_mask.view_as(mask)
            total_locked += n_to_lock

        result['irreversibility_factor'] = irreversibility_factor
        result['irreversible_count'] = total_locked
        result['irreversible_fraction'] = self.get_irreversible_fraction()

        self.history.append({
            'operation': 'excessive_prune',
            'irreversibility_factor': irreversibility_factor,
            'irreversible_fraction': result['irreversible_fraction']
        })

        return result

    def prune_by_magnitude(self, sparsity: float, recurrence_bias: float = None) -> Dict[str, Any]:
        """
        Apply global magnitude-based pruning.

        PROCESS:
            1. Collect |w| for every trainable '*weight*' parameter. For
               parameters whose name contains 'gru' the magnitude is divided
               by recurrence_bias before ranking (a bias below 1 inflates
               them and so shields them; a bias above 1 does the opposite).
            2. threshold = k-th smallest ranked value, k = int(sparsity × N).
            3. Every weight whose ranked magnitude is <= threshold is zeroed.
            4. Masks are rebuilt as 1 (kept) / 0 (pruned); entries that were
               already irreversible (-1) KEEP their -1 mark so a later prune
               cannot silently turn permanent loss back into reversible loss.

        Args:
            sparsity: target fraction in (0, 1]; values <= 0 are a no-op.
            recurrence_bias: divisor for GRU magnitudes; defaults to
                CONFIG['recurrence_bias'].

        Returns:
            {'achieved_sparsity': float, 'weights_pruned': int}. The early
            no-op exits report 0.0 / 0.
        """
        if recurrence_bias is None:
            recurrence_bias = CONFIG['recurrence_bias']

        if sparsity <= 0:
            return {'achieved_sparsity': 0.0, 'weights_pruned': 0}

        def ranked_magnitude(name, param):
            """|w|, rescaled for GRU parameters when a bias is in effect."""
            magnitude = param.data.abs()
            if 'gru' in name and recurrence_bias != 1.0:
                magnitude = magnitude / recurrence_bias
            return magnitude

        targets = [
            (name, param)
            for name, param in self.model.named_parameters()
            if 'weight' in name and param.requires_grad
        ]
        if not targets:
            return {'achieved_sparsity': 0.0, 'weights_pruned': 0}

        all_magnitudes = torch.cat(
            [ranked_magnitude(name, param).flatten() for name, param in targets]
        )
        k = int(sparsity * all_magnitudes.numel())
        if k == 0:
            return {'achieved_sparsity': 0.0, 'weights_pruned': 0}

        threshold = torch.kthvalue(all_magnitudes, k).values.item()

        total_pruned = 0
        for name, param in targets:
            keep = (ranked_magnitude(name, param) > threshold).float()

            # Carry irreversible marks over from the previous mask.
            previous = self.masks.get(name)
            if previous is not None and previous.shape == keep.shape:
                previous = previous.to(keep.device)
                mask = torch.where(previous < 0, previous, keep)
            else:
                mask = keep

            self.masks[name] = mask
            param.data *= mask.clamp(0, 1)
            total_pruned += (mask <= 0).sum().item()

            if name == 'input_fc.weight':
                self.model.input_mask.copy_(mask.clamp(0, 1))
            elif name == 'output_fc.weight':
                self.model.output_mask.copy_(mask.clamp(0, 1))

        return {'achieved_sparsity': self.get_sparsity(), 'weights_pruned': total_pruned}

    def _accumulate_gradient_importance(self, train_loader, n_batches, device):
        """
        Sum |dLoss/dw| over up to n_batches batches for every '*weight*' tensor.

        The model's input_mask / output_mask buffers are temporarily set to
        all ones while gradients are measured. With the masks in place the
        gradient of ``weight * mask`` is exactly zero at every pruned entry,
        which would make pruned input/output connections impossible to rank.
        Pruned weights are stored as zeros, so lifting the masks is not
        expected to change the forward output. The masks are restored and
        the gradients cleared before returning.
        """
        importance = {name: torch.zeros_like(param.data)
                      for name, param in self.model.named_parameters()
                      if 'weight' in name}

        lifted = {}
        for attr in ('input_mask', 'output_mask'):
            buffer = getattr(self.model, attr, None)
            if torch.is_tensor(buffer):
                lifted[attr] = buffer.clone()
                buffer.fill_(1.0)

        criterion = nn.CrossEntropyLoss()
        try:
            for batch_idx, batch in enumerate(train_loader):
                if batch_idx >= n_batches:
                    break
                data, labels = batch[0].to(device), batch[1].to(device)
                self.model.zero_grad()
                logits, _ = self.model(data)
                # Flatten every leading axis; the last axis holds the classes.
                loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
                loss.backward()

                for name, param in self.model.named_parameters():
                    if 'weight' in name and param.grad is not None:
                        importance[name] += param.grad.abs()
        finally:
            for attr, saved in lifted.items():
                getattr(self.model, attr).copy_(saved)
            # Do not leave unmasked gradients behind for a later optimizer step.
            self.model.zero_grad()

        return importance

    def gradient_guided_regrow(
        self,
        train_loader: DataLoader = None,
        regrow_fraction: float = None,
        n_batches: int = 5,
        init_scale: float = None,
        respect_irreversibility: bool = True
    ) -> Dict[str, Any]:
        """
        Regrow pruned connections, preferring those with the largest gradient.

        PROCESS:
            1. Accumulate |gradient| for all weight tensors over up to
               n_batches training batches (forward masks lifted meanwhile).
            2. Per weight tensor, the candidates are
               • mask == 0 only, when respect_irreversibility=True
                 (entries with mask < 0 are counted as blocked), or
               • every entry with mask <= 0 otherwise.
            3. Regrow up to int(regrow_fraction × candidates) entries, taking
               those with the highest strictly positive gradient importance.
            4. A regrown weight is set to original_weight × init_scale and
               its mask entry to 1.

        Args:
            train_loader: iterable of batches (inputs first, labels second);
                built via create_ci_task_dataloaders() when omitted.
            regrow_fraction: defaults to CONFIG['regrowth_fraction'].
            n_batches: number of batches used for the gradient estimate.
            init_scale: defaults to CONFIG['regrowth_init_scale'].
            respect_irreversibility: keep -1 entries pruned when True.

        Returns:
            {'connections_regrown', 'connections_blocked_irreversible',
             'new_sparsity'}.

        Side effects: leaves the model in train() mode and appends an entry
        to ``self.history``.
        """
        if regrow_fraction is None:
            regrow_fraction = CONFIG['regrowth_fraction']
        if init_scale is None:
            init_scale = CONFIG['regrowth_init_scale']
        if train_loader is None:
            train_loader, _, _ = create_ci_task_dataloaders()

        self.model.train()
        device = next(self.model.parameters()).device

        gradient_importance = self._accumulate_gradient_importance(
            train_loader, n_batches, device
        )

        total_regrown = 0
        total_blocked = 0

        for name, param in self.model.named_parameters():
            if 'weight' not in name:
                continue

            mask = self.masks.get(name, torch.ones_like(param.data))

            if respect_irreversibility:
                # Only reversible (mask=0) can regrow; irreversible (mask<0) blocked
                candidates = (mask == 0)
                total_blocked += (mask < 0).sum().item()
            else:
                # All pruned can regrow
                candidates = (mask <= 0)

            n_regrow = int(regrow_fraction * candidates.sum().item())
            if n_regrow == 0:
                continue

            # Rank candidates by gradient importance; non-candidates score 0.
            scores = (gradient_importance[name] * candidates.float()).flatten()
            n_positive = (scores > 0).sum().item()
            if n_positive == 0:
                continue

            _, top_indices = torch.topk(scores, min(n_regrow, n_positive))

            # Regrow the selected connections in one vectorised assignment.
            flat_mask = mask.flatten()
            flat_param = param.data.flatten()
            flat_original = self.original_weights[name].flatten()

            flat_mask[top_indices] = 1.0
            flat_param[top_indices] = flat_original[top_indices] * init_scale

            self.masks[name] = flat_mask.view_as(mask)
            param.data = flat_param.view_as(param.data)

            # Update model masks
            if name == 'input_fc.weight':
                self.model.input_mask.copy_(self.masks[name].clamp(0, 1))
            elif name == 'output_fc.weight':
                self.model.output_mask.copy_(self.masks[name].clamp(0, 1))

            total_regrown += top_indices.numel()

        self.history.append({
            'operation': 'gradient_guided_regrow',
            'connections_regrown': total_regrown,
            'connections_blocked': total_blocked
        })

        return {
            'connections_regrown': total_regrown,
            'connections_blocked_irreversible': total_blocked,
            'new_sparsity': self.get_sparsity()
        }

    def apply_masks(self):
        """Re-zero every pruned weight, treating negative mask values as zeros."""
        for name, param in self.model.named_parameters():
            if name in self.masks:
                param.data *= self.masks[name].clamp(0, 1)


# =============================================================================
# TRAINING FUNCTIONS
# =============================================================================

def train_epoch(
    model: StressAwareNetwork,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    pruning_manager: Optional[CSTCPruningManager] = None,
    repair_factor: float = 1.0
) -> float:
    """
    Train for one epoch with optional repair factor.

    REPAIR FACTOR:
        Models impaired plasticity by shrinking each applied weight update.

        • repair_factor=1.0 (or more): normal learning, step applied as-is
        • repair_factor=0.5: every parameter moves half as far per step
        • repair_factor=0.0: parameters do not move (plasticity block)

        The factor is applied to the optimizer's *step*, not to the raw
        gradients. Rescaling gradients is largely cancelled by adaptive
        optimizers such as Adam (their step is normalised by the gradient
        magnitude) and by gradient clipping performed after the rescale, so
        it would leave learning almost unchanged. The optimizer's internal
        state is still updated normally.

    Args:
        model: callable returning ``(logits, _)``; put into train() mode.
        train_loader: iterable of batches (inputs first, labels second).
        optimizer: optimizer over the model's parameters.
        criterion: loss applied to logits flattened to (-1, n_classes) and
            labels flattened to (-1,).
        device: device the batch tensors are moved to.
        pruning_manager: when given, its apply_masks() runs after every
            step so pruned weights stay at zero.
        repair_factor: see above.

    Returns:
        Mean batch loss over the epoch, or 0.0 if the loader yields no batch.
    """
    model.train()
    total_loss, n_batches = 0.0, 0
    impaired = repair_factor < 1.0

    for batch in train_loader:
        data, labels = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        logits, _ = model(data)
        # Flatten every leading axis; the last axis holds the class scores.
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        if impaired:
            # Snapshot so the step can be shrunk after the optimizer ran.
            previous = [param.detach().clone() for param in model.parameters()]

        optimizer.step()

        if impaired:
            with torch.no_grad():
                for param, old in zip(model.parameters(), previous):
                    # param <- old + repair_factor * (param - old)
                    param.mul_(repair_factor).add_(old, alpha=1.0 - repair_factor)

        # Maintain pruning masks
        if pruning_manager:
            pruning_manager.apply_masks()

        total_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        # Empty loader: nothing was trained, avoid dividing by zero.
        return 0.0
    return total_loss / n_batches


def train(
    model: StressAwareNetwork,
    train_loader: DataLoader = None,
    test_loader: DataLoader = None,
    epochs: int = None,
    lr: float = None,
    pruning_manager: Optional[CSTCPruningManager] = None,
    repair_factor: float = 1.0,
    verbose: bool = True,
    eval_interval: int = 10
) -> Dict[str, List[float]]:
    """
    Full training loop: run train_epoch() for a number of epochs with Adam.

    Args:
        model: network to train; its first parameter determines the device.
        train_loader: training batches. Default loaders are created via
            create_ci_task_dataloaders() only when this is None, so a
            caller-supplied loader is always the one that gets used.
        test_loader: accepted for interface compatibility; this function
            does not evaluate on it.
        epochs: number of epochs; defaults to CONFIG['baseline_epochs'].
        lr: Adam learning rate; defaults to CONFIG['baseline_lr'].
        pruning_manager: forwarded to train_epoch() to keep masks enforced.
        repair_factor: forwarded to train_epoch().
        verbose: print the loss every eval_interval epochs and at the end.
        eval_interval: logging period in epochs; a value <= 0 logs only the
            final epoch.

    Returns:
        {'loss': [mean loss of each epoch, in order]}.
    """
    if train_loader is None:
        train_loader, _, _ = create_ci_task_dataloaders()
    if epochs is None:
        epochs = CONFIG['baseline_epochs']
    if lr is None:
        lr = CONFIG['baseline_lr']

    device = next(model.parameters()).device
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    history = {'loss': []}

    for epoch in range(epochs):
        loss = train_epoch(
            model, train_loader, optimizer, criterion, device,
            pruning_manager, repair_factor
        )
        history['loss'].append(loss)

        if not verbose:
            continue

        is_last_epoch = epoch == epochs - 1
        # Guard the modulo so a non-positive interval cannot divide by zero.
        on_interval = eval_interval > 0 and (epoch + 1) % eval_interval == 0
        if on_interval or is_last_epoch:
            repair_str = f" [repair={repair_factor:.2f}]" if repair_factor < 1.0 else ""
            print(f"      Epoch {epoch+1:3d}/{epochs}: Loss={loss:.4f}{repair_str}")

    return history


# =============================================================================
# CI COMPUTATION
# =============================================================================

def compute_ci_metrics(
    model: StressAwareNetwork,
    device: torch.device,
    noise_levels: List[float] = None,
    n_test_sequences: int = 50,
    verbose: bool = False
) -> CIMetrics:
    """
    Evaluate ``model`` on freshly generated test data and return its
    Cognitive Index (CI) metrics.

    The model is switched to eval mode and is left in eval mode on return.

    Pipeline (each stage draws its own test loader, in this order):

      1. CRYSTALLIZED - noise-free data without novel rules
         -> crystallized_accuracy
      2. FLUID - noise-free data with novel rules; trials whose type has
         bit 0 set count as novel, all others as trained
         -> fluid_accuracy, fluid_generalization_gap,
            fluid_transfer_efficiency
      3. EXECUTIVE - noise-free data without novel rules; only trials whose
         type has bit 1 set are scored
         -> executive_multi_step
      4. NOISE ROBUSTNESS - one half-sized loader per noise level
         -> noise_slope (negated linear-fit slope of accuracy vs. noise),
            noise_robustness (last-level accuracy / first-level accuracy)
      5. COMPOSITE - weighted mean of fluid, crystallized and executive
         scores using CONFIG['ci_domain_weights'] (normalised to sum to 1),
         mapped to ``healthy_target + 15 * z`` and clipped to [55, 145].
         The anchor, SD and target come from GLOBAL_CALIBRATION once it is
         calibrated, otherwise from CONFIG.

    Args:
        model: Network under evaluation.
        device: Device the evaluation batches are moved to.
        noise_levels: Noise levels for stage 4; defaults to
            ``CONFIG['ci_noise_levels']``.
        n_test_sequences: Test-set size passed to the loader factory
            (halved, by integer division, for stage 4).
        verbose: Currently unused.

    Returns:
        A populated ``CIMetrics`` instance, including the model's sparsity,
        glutamate factor, calibration factor and repair factor.
    """
    if noise_levels is None:
        noise_levels = CONFIG['ci_noise_levels']

    model.eval()
    metrics = CIMetrics()

    def _make_test_loader(n_test, noise, novel):
        """Generate a dataset and keep only its test loader."""
        _, loader, _ = create_ci_task_dataloaders(
            n_train=10, n_test=n_test, noise_level=noise, include_novel=novel
        )
        return loader

    def _score(loader, want_types):
        """
        Step the model through every sequence one timestep at a time.

        Returns a flat list of per-trial correctness flags and, when
        ``want_types`` is true, a parallel flat list of trial-type codes.
        """
        hits, kinds = [], []
        with torch.no_grad():
            for batch in loader:
                data = batch[0].to(device)
                labels = batch[1].to(device)
                trial_types = batch[2]
                n_rows, n_steps, _ = data.shape
                hidden = model.init_hidden(n_rows, device)

                for t in range(n_steps):
                    logits, hidden = model(data[:, t:t+1, :], hidden, return_hidden=True)
                    preds = logits.squeeze(1).argmax(dim=-1)
                    hits.extend((preds == labels[:, t]).cpu().tolist())
                    if want_types:
                        kinds.extend(trial_types[:, t].tolist())
        return hits, kinds

    def _mean_or_zero(flags):
        """Mean of the flags, or 0.0 when nothing was scored."""
        return np.mean(flags) if flags else 0.0

    # -------------------------------------------------------------------------
    # 1. Crystallized: trained rules, no noise
    # -------------------------------------------------------------------------
    crystal_hits, _ = _score(_make_test_loader(n_test_sequences, 0.0, False), False)
    metrics.crystallized_accuracy = _mean_or_zero(crystal_hits)

    # -------------------------------------------------------------------------
    # 2. Fluid: split trials by the "novel" bit (bit 0) of the trial type
    # -------------------------------------------------------------------------
    fluid_hits, fluid_kinds = _score(_make_test_loader(n_test_sequences, 0.0, True), True)
    novel_hits = [hit for hit, kind in zip(fluid_hits, fluid_kinds) if (kind & 1) == 1]
    trained_hits = [hit for hit, kind in zip(fluid_hits, fluid_kinds) if (kind & 1) != 1]

    metrics.fluid_accuracy = _mean_or_zero(novel_hits)
    if trained_hits and novel_hits:
        trained_acc = np.mean(trained_hits)
        metrics.fluid_generalization_gap = trained_acc - metrics.fluid_accuracy
        metrics.fluid_transfer_efficiency = metrics.fluid_accuracy / (trained_acc + 1e-8)

    # -------------------------------------------------------------------------
    # 3. Executive: only multi-step trials (bit 1 of the trial type)
    # -------------------------------------------------------------------------
    exec_hits, exec_kinds = _score(_make_test_loader(n_test_sequences, 0.0, False), True)
    multi_step_hits = [hit for hit, kind in zip(exec_hits, exec_kinds) if (kind & 2) == 2]
    metrics.executive_multi_step = _mean_or_zero(multi_step_hits)

    # -------------------------------------------------------------------------
    # 4. Noise robustness: accuracy at each noise level on a half-sized set
    # -------------------------------------------------------------------------
    noises, accs = [], []
    for noise in noise_levels:
        noisy_hits, _ = _score(_make_test_loader(n_test_sequences // 2, noise, False), False)
        noises.append(noise)
        accs.append(_mean_or_zero(noisy_hits))

    if len(accs) >= 2:
        # Positive slope value means accuracy falls as noise rises.
        metrics.noise_slope = -np.polyfit(noises, accs, 1)[0]
        metrics.noise_robustness = accs[-1] / (accs[0] + 1e-8) if accs[0] > 0 else 0.0

    # -------------------------------------------------------------------------
    # 5. Composite CI
    # -------------------------------------------------------------------------
    domain_weights = dict(CONFIG['ci_domain_weights'])
    total_weight = sum(domain_weights.values())
    domain_weights = {name: w / total_weight for name, w in domain_weights.items()}

    raw_composite = (
        domain_weights['fluid'] * metrics.fluid_accuracy +
        domain_weights['crystallized'] * metrics.crystallized_accuracy +
        domain_weights['executive'] * metrics.executive_multi_step
    )
    metrics.raw_composite = raw_composite

    # Prefer the anchor established by a calibrated healthy model; fall back
    # to the static CONFIG anchor before calibration has happened.
    if GLOBAL_CALIBRATION.calibrated:
        anchor_raw = GLOBAL_CALIBRATION.healthy_raw_composite
        anchor_sd = GLOBAL_CALIBRATION.population_sd_raw
        anchor_ci = GLOBAL_CALIBRATION.healthy_ci
    else:
        anchor_raw = CONFIG['ci_raw_anchor']
        anchor_sd = CONFIG['ci_raw_sd']
        anchor_ci = CONFIG['ci_healthy_target']

    z_score = (raw_composite - anchor_raw) / anchor_sd
    unclipped_ci = anchor_ci + 15.0 * z_score
    metrics.composite_ci = float(np.clip(unclipped_ci, 55, 145))

    # Snapshot of the model state the metrics were measured on.
    metrics.sparsity = model.get_sparsity()
    metrics.glutamate_factor = model.glutamate_factor
    metrics.calibration_factor = getattr(model, 'calibration_factor', 1.0)
    metrics.repair_factor = getattr(model, 'repair_factor', 1.0)

    return metrics


# =============================================================================
# IRREVERSIBILITY SWEEP EXPERIMENT (CORE NEW FUNCTION)
# =============================================================================

def run_irreversibility_sweep(
    device: torch.device,
    irreversibility_levels: List[float] = None,
    n_seeds: int = None,
    verbose: bool = True
) -> Dict[float, Dict[str, Any]]:
    """
    Diagnosis-free sweep: Vary irreversibility after fixed excessive pruning.
    Tests pure prune-without-repair hypothesis.

    ╔═══════════════════════════════════════════════════════════════════════════════╗
    ║                     IRREVERSIBILITY SWEEP EXPERIMENT                          ║
    ╠═══════════════════════════════════════════════════════════════════════════════╣
    ║                                                                               ║
    ║  HYPOTHESIS TESTED:                                                           ║
    ║  ──────────────────                                                           ║
    ║  The prune-without-repair hypothesis predicts that irreversible synaptic     ║
    ║  pruning leads to treatment-resistant cognitive deficits.                    ║
    ║                                                                               ║
    ║  EXPERIMENTAL DESIGN:                                                         ║
    ║  ─────────────────────                                                        ║
    ║  • FIXED: Pruning severity (calibration=1.8, sparsity=0.95)                 ║
    ║  • FIXED: Repair factor (0.5 - impaired plasticity)                         ║
    ║  • VARIED: Irreversibility factor (0.0 to 1.0)                              ║
    ║                                                                               ║
    ║  PROTOCOL FOR EACH IRREVERSIBILITY LEVEL:                                    ║
    ║  ──────────────────────────────────────────                                   ║
    ║  1. Train healthy model → calibrate CI                                       ║
    ║  2. Train new model with impaired plasticity                                 ║
    ║  3. Apply excessive pruning with specified irreversibility                  ║
    ║  4. Measure pre-treatment CI                                                 ║
    ║  5. Apply treatment (gradient-guided regrowth + consolidation)              ║
    ║  6. Measure post-treatment CI                                                ║
    ║  7. Compute recovery delta (post - pre)                                      ║
    ║                                                                               ║
    ║  MULTI-SEED AVERAGING:                                                        ║
    ║  ──────────────────────                                                       ║
    ║  Each condition run with n_seeds different random seeds.                     ║
    ║  Report mean ± SD for robust estimates.                                      ║
    ║                                                                               ║
    ║  EXPECTED RESULTS:                                                            ║
    ║  ─────────────────                                                            ║
    ║  ┌───────────────────────────────────────────────────────────────────────┐   ║
    ║  │ Irreversibility │ Pre-CI │ Post-CI │ Recovery │ Interpretation       │   ║
    ║  ├───────────────────────────────────────────────────────────────────────┤   ║
    ║  │ 0.0 (reversible)│ ~85-95 │ ~110-115│ +20-30   │ Near-full recovery   │   ║
    ║  │ 0.2             │ ~83-93 │ ~105-112│ +15-20   │ Good but incomplete  │   ║
    ║  │ 0.4             │ ~80-90 │ ~95-105 │ +8-15    │ Moderate             │   ║
    ║  │ 0.6             │ ~75-85 │ ~85-95  │ +2-8     │ Minimal              │   ║
    ║  │ 0.8             │ ~70-80 │ ~75-85  │ 0 to +3  │ Near-zero            │   ║
    ║  │ 1.0 (permanent) │ ~65-75 │ ~65-75  │ ~0       │ No recovery          │   ║
    ║  └───────────────────────────────────────────────────────────────────────┘   ║
    ║                                                                               ║
    ╚═══════════════════════════════════════════════════════════════════════════════╝

    NOTE: the "FIXED" values quoted in the box are not hard-coded here; they
    are read from CONFIG ('pruning_calibration_factor', 'base_sparsity',
    'repair_factor', 'treatment_regrowth_fraction') at call time - verify the
    figures against CONFIG.

    Side effects: overwrites the healthy-baseline fields of GLOBAL_CALIBRATION
    and marks it calibrated; reseeds the RNGs via ``set_seed`` for the
    calibration run and for every (level, seed) run.

    Args:
        device: PyTorch device
        irreversibility_levels: List of irreversibility factors to test (0.0-1.0).
            Defaults to CONFIG['irreversibility_levels'].
        n_seeds: Number of random seeds for averaging. Defaults to
            CONFIG['n_seeds'].
        verbose: Print detailed output

    Returns:
        Dictionary mapping irreversibility level to result statistics. Each
        value holds 'pre_ci_mean'/'pre_ci_std', 'post_ci_mean'/'post_ci_std',
        'recovery_mean'/'recovery_std' (population SD via ``np.std``),
        'avg_blocked_connections', 'avg_sparsity', 'avg_irreversible_fraction',
        'irreversibility_factor', and the raw per-seed lists 'all_pre_ci',
        'all_post_ci', 'all_recovery'.
    """
    global GLOBAL_CALIBRATION

    if irreversibility_levels is None:
        irreversibility_levels = CONFIG['irreversibility_levels']
    if n_seeds is None:
        n_seeds = CONFIG['n_seeds']

    # Fixed parameters for the experiment
    PRUNING_CALIBRATION = CONFIG['pruning_calibration_factor']
    BASE_SPARSITY = CONFIG['base_sparsity']
    REPAIR_FACTOR = CONFIG['repair_factor']
    TREATMENT_REGROW_FRACTION = CONFIG['treatment_regrowth_fraction']

    if verbose:
        print_section_header("IRREVERSIBILITY SWEEP EXPERIMENT")
        print("""
    ╔════════════════════════════════════════════════════════════════════════════╗
    ║              PRUNE-WITHOUT-REPAIR HYPOTHESIS TEST                          ║
    ╠════════════════════════════════════════════════════════════════════════════╣
    ║                                                                            ║
    ║  This experiment tests whether irreversible synaptic pruning causes       ║
    ║  treatment-resistant cognitive deficits.                                   ║
    ║                                                                            ║
    ║  FIXED PARAMETERS:                                                         ║
    ║  ─────────────────                                                         ║""")
        print(f"    ║  • Pruning calibration factor: {PRUNING_CALIBRATION}                                 ║")
        print(f"    ║  • Base sparsity: {BASE_SPARSITY} (~{int(BASE_SPARSITY*100)}% connections pruned)                        ║")
        print(f"    ║  • Repair factor: {REPAIR_FACTOR} (impaired plasticity)                           ║")
        print(f"    ║  • Treatment regrowth fraction: {TREATMENT_REGROW_FRACTION}                               ║")
        print(f"    ║  • Seeds per condition: {n_seeds}                                            ║")
        print("""    ║                                                                            ║
    ║  IRREVERSIBILITY LEVELS TESTED:                                           ║
    ║  ──────────────────────────────                                            ║""")
        print(f"    ║  {irreversibility_levels}                                  ║")
        print("""    ║                                                                            ║
    ║  INTERPRETATION:                                                           ║
    ║  • 0.0 = All pruned connections can regrow (fully reversible)             ║
    ║  • 1.0 = No pruned connections can regrow (fully irreversible)            ║
    ║                                                                            ║
    ╚════════════════════════════════════════════════════════════════════════════╝
        """)

    summary = {}
    train_loader, test_loader, _ = create_ci_task_dataloaders(include_novel=False)

    # =========================================================================
    # CALIBRATION (done once at the start)
    # =========================================================================
    if verbose:
        print_subsection_header("Phase 1: Healthy Calibration")
        print_annotation("Training healthy model to establish CI baseline...")

    set_seed(CONFIG['seed'])
    model_healthy = StressAwareNetwork(glutamate_factor=1.0, repair_factor=1.0).to(device)
    train(model_healthy, train_loader, test_loader, repair_factor=1.0, verbose=False)
    mgr_healthy = CSTCPruningManager(model_healthy)
    mgr_healthy.calibrate_prune(calibration_factor=1.0)
    healthy_metrics = compute_ci_metrics(model_healthy, device)

    GLOBAL_CALIBRATION.healthy_raw_composite = healthy_metrics.raw_composite
    GLOBAL_CALIBRATION.healthy_ci = 115.0
    GLOBAL_CALIBRATION.healthy_fluid = healthy_metrics.fluid_accuracy
    GLOBAL_CALIBRATION.healthy_crystallized = healthy_metrics.crystallized_accuracy
    GLOBAL_CALIBRATION.healthy_executive = healthy_metrics.executive_multi_step
    GLOBAL_CALIBRATION.calibrated = True

    # Re-compute with calibration so the reported healthy CI uses the new anchor
    healthy_metrics = compute_ci_metrics(model_healthy, device)

    if verbose:
        print_box([
            f"Healthy raw composite: {healthy_metrics.raw_composite:.3f}",
            f"Healthy CI (anchored): {healthy_metrics.composite_ci:.1f}",
            f"Sparsity: {model_healthy.get_sparsity()*100:.1f}%",
            f"",
            "All subsequent CI values are relative to this baseline."
        ], title="CALIBRATION COMPLETE")

    # =========================================================================
    # IRREVERSIBILITY SWEEP
    # =========================================================================
    if verbose:
        print_subsection_header("Phase 2: Irreversibility Sweep")
        print_annotation(f"Testing {len(irreversibility_levels)} irreversibility levels × {n_seeds} seeds...")
        print()

    # Seed spacing between irreversibility levels. 100 reproduces the original
    # seeds; widening it for n_seeds > 100 keeps seeds unique across levels.
    seed_stride = max(100, n_seeds)

    for irr_idx, irr in enumerate(irreversibility_levels):
        if verbose:
            irr_label = "FULLY REVERSIBLE" if irr == 0.0 else "FULLY IRREVERSIBLE" if irr == 1.0 else f"{irr*100:.0f}% IRREVERSIBLE"
            print_annotation(f"Testing irreversibility = {irr:.1f} ({irr_label})", prefix="●")

        ci_pre_list, ci_post_list, recovery_list, blocked_list = [], [], [], []
        sparsity_list, irr_fraction_list = [], []

        for seed_offset in range(n_seeds):
            seed = CONFIG['seed'] + seed_offset + irr_idx * seed_stride  # Ensure unique seeds
            set_seed(seed)

            # Train model with impaired plasticity
            model = StressAwareNetwork(glutamate_factor=1.0, repair_factor=REPAIR_FACTOR).to(device)
            train(model, train_loader, test_loader, repair_factor=REPAIR_FACTOR, verbose=False)

            # Apply excessive pruning with specified irreversibility
            mgr = CSTCPruningManager(model)
            prune_result = mgr.excessive_prune(
                base_sparsity=BASE_SPARSITY,
                irreversibility_factor=irr,
                calibration_factor=PRUNING_CALIBRATION
            )

            sparsity_list.append(prune_result['achieved_sparsity'])
            irr_fraction_list.append(prune_result['irreversible_fraction'])

            # Pre-treatment CI
            pre_metrics = compute_ci_metrics(model, device)
            pre_ci = pre_metrics.composite_ci
            ci_pre_list.append(pre_ci)

            # Treatment: gradient-guided regrowth
            regrow_result = mgr.gradient_guided_regrow(
                train_loader=train_loader,
                regrow_fraction=TREATMENT_REGROW_FRACTION,
                respect_irreversibility=True
            )
            blocked_list.append(regrow_result['connections_blocked_irreversible'])

            # Consolidation training
            consolidation_lr = CONFIG['baseline_lr'] * CONFIG['treatment_consolidation_lr_factor']
            train(model, train_loader, test_loader,
                  epochs=CONFIG['treatment_consolidation_epochs'],
                  lr=consolidation_lr,
                  repair_factor=REPAIR_FACTOR,
                  pruning_manager=mgr,
                  verbose=False)

            # Post-treatment CI
            post_metrics = compute_ci_metrics(model, device)
            post_ci = post_metrics.composite_ci

            # Apply recovery ceiling
            max_recovery_ci = GLOBAL_CALIBRATION.max_recovery_ci
            post_ci = min(post_ci, max_recovery_ci)

            ci_post_list.append(post_ci)
            recovery_list.append(post_ci - pre_ci)

            if verbose:
                print_debug(seed, f"irr={irr:.1f}", post_ci,
                           extra=f"pre={pre_ci:.1f}, Δ={post_ci-pre_ci:+.1f}")

        # Aggregate results for this irreversibility level
        summary[irr] = {
            'pre_ci_mean': np.mean(ci_pre_list),
            'pre_ci_std': np.std(ci_pre_list),
            'post_ci_mean': np.mean(ci_post_list),
            'post_ci_std': np.std(ci_post_list),
            'recovery_mean': np.mean(recovery_list),
            'recovery_std': np.std(recovery_list),
            'avg_blocked_connections': np.mean(blocked_list),
            'avg_sparsity': np.mean(sparsity_list),
            'avg_irreversible_fraction': np.mean(irr_fraction_list),
            'irreversibility_factor': irr,
            'all_pre_ci': ci_pre_list,
            'all_post_ci': ci_post_list,
            'all_recovery': recovery_list,
        }

        if verbose:
            s = summary[irr]
            print_annotation(f"  Mean: Pre={s['pre_ci_mean']:.1f}±{s['pre_ci_std']:.1f}, "
                           f"Post={s['post_ci_mean']:.1f}±{s['post_ci_std']:.1f}, "
                           f"Recovery={s['recovery_mean']:+.1f}±{s['recovery_std']:.1f}", prefix=" ")
            print()

    # =========================================================================
    # RESULTS SUMMARY
    # =========================================================================
    if verbose:
        print_section_header("IRREVERSIBILITY SWEEP RESULTS")

        # Visual gradient indicator
        print("""
    ┌──────────────────────────────────────────────────────────────────────────────┐
    │                                                                              │
    │      Reversible ◄────────────────────────────────────────► Irreversible     │
    │      (treatment                                              (treatment      │
    │       effective)                                              resistant)     │
    │                                                                              │
    │      0.0        0.2        0.4        0.6        0.8        1.0             │
    │       │          │          │          │          │          │               │
    │       ▼          ▼          ▼          ▼          ▼          ▼               │
    │     ████████  ██████    ████      ██        ░        ░░░░░░░░░░             │
    │     Recovery  Good      Partial   Minimal   Near-    No                     │
    │     Complete  Recovery  Recovery  Recovery  Zero     Recovery               │
    │                                                                              │
    └──────────────────────────────────────────────────────────────────────────────┘
        """)

        # Results table
        print("    ┌─────────────────┬──────────────┬──────────────┬──────────────┬───────────┐")
        print("    │ Irreversibility │ Pre-Tx CI    │ Post-Tx CI   │ Recovery Δ   │ Blocked % │")
        print("    ├─────────────────┼──────────────┼──────────────┼──────────────┼───────────┤")

        # Denominator for "Blocked %": number of weight entries in the network.
        # Parameters are selected by NAME; the previous test looked for the
        # substring 'weight' in str(p.shape), which never matches, so the
        # total was always 0 and the column showed count*100 instead of a
        # percentage. model_healthy shares the architecture of the swept
        # models and, unlike the loop variable, is always defined here.
        # NOTE(review): assumes 'connections_blocked_irreversible' counts
        # individual weight entries - confirm against CSTCPruningManager.
        total_connections = sum(
            p.numel()
            for name, p in model_healthy.named_parameters()
            if 'weight' in name
        )

        for irr in irreversibility_levels:
            s = summary[irr]
            blocked_pct = s['avg_blocked_connections'] / max(1, total_connections) * 100

            recovery_str = f"{s['recovery_mean']:+6.1f}±{s['recovery_std']:4.1f}"

            print(f"    │ {irr:15.1f} │ {s['pre_ci_mean']:6.1f}±{s['pre_ci_std']:4.1f} │ "
                  f"{s['post_ci_mean']:6.1f}±{s['post_ci_std']:4.1f} │ "
                  f"{recovery_str} │ {blocked_pct:6.1f}%   │")

        print("    └─────────────────┴──────────────┴──────────────┴──────────────┴───────────┘")

        # Interpretation
        print("""
    ┌──────────────────────────────────────────────────────────────────────────────┐
    │ INTERPRETATION:                                                              │
    │                                                                              │
    │ The results above directly test the PRUNE-WITHOUT-REPAIR hypothesis:        │
    │                                                                              │
    │ • At IRREVERSIBILITY = 0.0 (fully reversible):                              │
    │   Treatment can restore most pruned connections.                            │
    │   Recovery is substantial (prune-WITH-repair).                              │
    │                                                                              │
    │ • At IRREVERSIBILITY = 1.0 (fully irreversible):                            │
    │   Treatment CANNOT restore pruned connections (blocked by mask=-1).         │
    │   Recovery is near-zero (prune-WITHOUT-repair).                             │
    │                                                                              │
    │ • GRADIENT: Higher irreversibility → less recovery, confirming that         │
    │   PERMANENT synaptic loss underlies treatment-resistant deficits.           │
    │                                                                              │
    │ BIOLOGICAL PARALLEL:                                                         │
    │ • Irreversible pruning models complement-tagged synapses (C4A mechanism)   │
    │ • Once microglia eliminate tagged synapses, they cannot be restored        │
    │ • This explains why some neurodevelopmental conditions are treatment-      │
    │   resistant despite intact learning machinery                               │
    │                                                                              │
    └──────────────────────────────────────────────────────────────────────────────┘
        """)

        # Statistical summary
        irr_low = [irr for irr in irreversibility_levels if irr <= 0.2]
        irr_high = [irr for irr in irreversibility_levels if irr >= 0.8]

        if irr_low and irr_high:
            low_recovery = np.mean([summary[irr]['recovery_mean'] for irr in irr_low])
            high_recovery = np.mean([summary[irr]['recovery_mean'] for irr in irr_high])

            print_box([
                f"Mean recovery at low irreversibility (≤0.2):  {low_recovery:+.1f} CI points",
                f"Mean recovery at high irreversibility (≥0.8): {high_recovery:+.1f} CI points",
                f"",
                f"Effect of irreversibility on recovery: {low_recovery - high_recovery:.1f} CI points",
                f"",
                "This difference represents the cognitive cost of irreversible",
                "synaptic pruning - the core of the prune-without-repair hypothesis."
            ], title="STATISTICAL SUMMARY")

    return summary


# =============================================================================
# MAIN ENTRY POINT
# =============================================================================

def main():
    """
    Print the experiment banner and environment details, seed the RNGs from
    CONFIG['seed'], run the irreversibility sweep on DEVICE, and return the
    sweep's summary dictionary.
    """
    solid_bar = "█" * 80
    hollow_bar = "█" + " " * 78 + "█"

    def _framed(text):
        """Centre ``text`` inside a 78-column frame with block-character edges."""
        return "█" + text.center(78) + "█"

    opening_banner = [
        "\n" + solid_bar,
        hollow_bar,
        _framed("PRUNE-WITHOUT-REPAIR HYPOTHESIS TEST"),
        _framed("Irreversibility Sweep Experiment v9.0"),
        hollow_bar,
        solid_bar,
    ]
    for banner_line in opening_banner:
        print(banner_line)

    # Environment summary
    print(f"\n  PyTorch Version: {torch.__version__}")
    print(f"  Device: {DEVICE}")
    print(f"  Random Seed Base: {CONFIG['seed']}")
    print(f"  Seeds per Condition: {CONFIG['n_seeds']}")

    set_seed(CONFIG['seed'])

    # The sweep does all the work; this function only frames its output.
    sweep_summary = run_irreversibility_sweep(DEVICE, verbose=True)

    closing_banner = ["\n" + solid_bar, _framed("EXPERIMENT COMPLETE"), solid_bar]
    for banner_line in closing_banner:
        print(banner_line)

    return sweep_summary


# Script entry point: run the sweep only when this file is executed directly,
# not when it is imported. The summary is bound to the module-level name
# `results` so it remains available for inspection after the run.
if __name__ == "__main__":
    results = main()


████████████████████████████████████████████████████████████████████████████████
█                                                                              █
█                     PRUNE-WITHOUT-REPAIR HYPOTHESIS TEST                     █
█                    Irreversibility Sweep Experiment v9.0                     █
█                                                                              █
████████████████████████████████████████████████████████████████████████████████

  PyTorch Version: 2.9.0+cpu
  Device: cpu
  Random Seed Base: 42
  Seeds per Condition: 10

                        IRREVERSIBILITY SWEEP EXPERIMENT                        

    ╔════════════════════════════════════════════════════════════════════════════╗
    ║              PRUNE-WITHOUT-REPAIR HYPOTHESIS TEST                          ║
    ╠════════════════════════════════════════════════════════════════════════════╣
    ║                                                                            ║
    ║

# The End