# Improved version v5

In [None]:
"""
================================================================================
HIERARCHICAL PREDICTIVE CODING MODEL OF PSYCHOSIS
ISO-DOSE TREATMENT COMPARISON: KETAMINE vs ANTIPSYCHOTICS vs ECT
================================================================================

Disease model (Bayesian predictive coding):
  Psychosis = precision imbalance + synaptic pruning across cortical hierarchy

  Positive symptoms (precision dysregulation):
    Sensory layers:  LOW gain + HIGH noise -> attenuated bottom-up PE
    Belief layers:   HIGH gain + LOW noise -> overprecise top-down priors
    + Aberrant plasticity (spurious strong connections at higher layers)
    + Global hyperexcitability

  Negative symptoms (synaptic pruning + capacity loss + effort deficit):
    Reversible pruning:      weights scaled to near-zero (treatable)
    Irreversible pruning:    weights permanently zeroed (treatment-resistant)
    Capacity gating:         unit_masks gate heavily-pruned units
    Hard death:              units below floor contribute ZERO
    Cascading degradation:   weakened units degrade via positive feedback
    Capacity-dependent exc:  reduced capacity -> blunted layer output
    Effort/motivation:       effort parameter reduces higher-layer drive
    Cognitive deficit:       measured on demanding probe task

All state stored as nn.Parameters in the network:
  Layer weights & biases     (nn.Linear, requires_grad=True)
  Per-layer gains            (nn.Parameter, requires_grad=False)
  Per-layer noise levels     (nn.Parameter, requires_grad=False)
  Global excitability        (nn.Parameter, requires_grad=False)
  Effort/motivation          (nn.Parameter, requires_grad=False)
  Per-unit capacity masks    (nn.ParameterList, requires_grad=False)
  Per-unit side-effect noise (nn.ParameterList, requires_grad=False)

Treatments (ALL effects are network-parameter modifications):
  Ketamine:       NMDA antagonism -> WORSENS positive + synaptogenesis
  Antipsychotics: D2 antagonism   -> TREATS positive, minimal synaptogenesis
  ECT:            Seizure reset   -> TREATS both, strongest synaptogenesis

Extensions:
  Separate positive / negative severity for patient profiles
  Chronic treatment + withdrawal with treatment-specific decay
  Relapse heterogeneity (precision fast, structural slow)
  Cognitive probe task for sensitive negative-symptom measurement
  10 patient profiles with varying positive/negative ratios

Improvements (v2):
  [1] Decoupled consolidation: unified epochs config + zero-consolidation ablation
  [2] Fixed iso-dose: reduced ECT noise + targeted dose metric
  [3] Fixed hallucination metric: logit-margin threshold (avoids softmax saturation)
  [4] Normalized PE metric: relative mismatch, not absolute scale

Iso-dose metric: mean absolute change (L1 norm / #params) over ALL parameters
                 (mechanism-agnostic)
                 + targeted mean |change| (precision + masks + effort only)
================================================================================
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from dataclasses import dataclass
import warnings

# Run everything on the CPU.
DEVICE = torch.device('cpu')

# One global seed for both RNGs so data shuffling, noise injection and
# weight initialisation are repeatable between runs.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Silence every warning for tidier notebook output.  Note that this also
# hides numerical RuntimeWarnings (e.g. mean of an empty slice).
warnings.filterwarnings('ignore')

# ============================================================================
# CONFIGURATION
# ============================================================================
# Global experiment configuration; every component below reads it as CFG[...].
# Per-layer lists (gains, noise, biases) are indexed by hidden layer,
# 0 = lowest ('sensory') layer, last = highest ('belief') layer.
CFG = {
    'n_train': 10000, 'n_test': 3000, 'n_clean': 2000,
    'data_noise': 0.7, 'batch_size': 128,
    'hidden_dims': [256, 256, 128], 'in_dim': 2, 'out_dim': 4,
    'base_epochs': 25, 'base_lr': 1e-3,

    # Precision profiles
    # PsychMgr.induce interpolates healthy -> psychosis values by severity_pos.
    'healthy_gains':   [1.0, 1.0, 1.0],
    'healthy_noise':   [0.0, 0.0, 0.0],
    'psychosis_gains': [0.30, 0.60, 2.50],
    'psychosis_noise': [1.20, 0.60, 0.05],
    'psychosis_exc':   1.50,
    'aberrant_frac': 0.03, 'aberrant_scale': 3.0,
    'aberrant_bias': [0.10, 0.30, 0.60],

    # Synaptic pruning (negative symptoms)
    # pruning_irrev_ratio: share of pruned weights that are zeroed for good;
    # the rest are scaled by pruning_scale (reversible).
    'pruning_frac_base': 0.25,
    'pruning_irrev_ratio': 0.10,
    'pruning_bias': [0.30, 0.55, 0.90],
    'pruning_scale': 0.01,

    # Capacity gating
    # Units whose mask is not above unit_alive_floor are gated to zero in
    # PCNet.forward; cap_exc_floor is the layer output scale when no unit
    # survives.
    'unit_death_threshold': 0.10,
    'unit_alive_floor': 0.08,
    'unit_death_rate': 2.5,
    'pruning_cascade': 0.55,
    'cap_exc_floor': 0.50,

    # Effort / motivation (negative symptoms)
    'effort_healthy': 1.0,
    'effort_deficit_rate': 0.35,

    # Cognitive probe task
    'cog_center_scale': 0.50,
    'cog_noise': 1.2,
    'cog_n_test': 2000,

    # ── [FIX 1] Unified consolidation epochs ──
    # Retraining epochs after a treatment (treat_ketamine reads 'ketamine');
    # run_zero_consolidation=True skips retraining entirely (ablation).
    'consolidation_epochs': {
        'default': 5,
        'ketamine': 5,
        'antipsychotic': 5,
        'ect_per_session': 3,
    },
    'run_zero_consolidation': False,

    # Ketamine
    # NOTE(review): treat_ketamine uses consolidation_epochs['ketamine'], not
    # 'ket_epochs'; 'ap_epochs' / 'ect_retrain_ep' may be similarly stale —
    # verify before relying on them.
    'ket_gain_atten': 0.50, 'ket_noise_boost': 1.50,
    'ket_aberrant_boost': 1.50, 'ket_exc_boost': 1.30,
    'ket_synaptogen': 0.60, 'ket_cap_restore': 0.80,
    'ket_effort_restore': 0.35,
    'ket_side_noise': 0.05,
    'ket_epochs': 5, 'ket_lr': 3e-4,

    # Antipsychotics
    'ap_gain_norm': 0.70, 'ap_noise_red': 0.50,
    'ap_aberrant_weak': 0.60, 'ap_exc_norm': 0.70,
    'ap_synaptogen': 0.05, 'ap_cap_restore': 0.30,
    'ap_effort_restore': 0.05,
    'ap_side_noise': 0.03,
    'ap_epochs': 30, 'ap_lr': 3e-4,

    # ECT
    'ect_gain_reset': 0.12, 'ect_noise_reset': 0.12,
    'ect_weight_reset': 0.05, 'ect_aberrant_reset': 0.10,
    'ect_seizure_noise': 0.015,
    'ect_exc_reset': 0.10,
    'ect_synaptogen': 0.15, 'ect_cap_restore': 1.0,
    'ect_effort_restore': 0.12,
    'ect_irrev_recovery': 0.03,
    'ect_postictal_noise': 0.08, 'ect_se_decay': 0.70,
    'ect_retrain_ep': 3, 'ect_retrain_lr': 5e-4,

    # ── [FIX 3 v2] Logit-margin hallucination threshold ──
    # halluc_rate falls back to 5.0 while this is still None.
    'halluc_logit_thr': None,       # Set dynamically after healthy training

    # Chronic treatment + withdrawal
    'chronic_steps': 12,
    'maintenance_dose_frac': 0.40,
    'maintenance_interval': 3,
    'withdrawal_steps': 15,
    'relapse_drift_rate': 0.025,
    'treatment_durability': {
        'ketamine': 0.25,
        'antipsychotic': 0.55,
        'ect': 0.85,
    },
    'homeostatic_rate': 0.005,

    # Side-effect decay rates (per homeostatic step)
    'se_decay_rate': {
        'ketamine': 0.12,
        'antipsychotic': 0.02,
        'ect': 0.08,
    },
    'se_decay_default': 0.05,

    # Relapse heterogeneity (multipliers on base frac)
    'relapse_precision_mult': 1.0,
    'relapse_pruning_mult': 0.6,
    'relapse_capacity_mult': 0.5,
    'relapse_effort_mult': 0.7,

    # 10 patient profiles (severity_pos, severity_neg)
    'patient_seeds': [42, 137, 256, 314, 501, 619, 733, 842, 951, 1066],
    'patient_profiles': [
        (1.2, 0.8),
        (0.8, 1.2),
        (1.0, 1.0),
        (1.5, 0.5),
        (0.5, 1.5),
        (1.1, 1.1),
        (0.9, 0.9),
        (1.3, 0.7),
        (0.7, 1.3),
        (1.0, 1.2),
    ],

    # Sweeps
    'ket_sweep':  [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
    'ap_sweep':   [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    'ect_sweep':  [1, 2, 3, 4, 6, 8, 10, 12],

    # External stress levels: added to every layer's noise std in
    # PCNet.forward; full_eval reports 's_<name>' accuracy for each.
    'stress': {'none': 0.0, 'mild': 0.3, 'moderate': 0.7,
               'severe': 1.2, 'extreme': 2.0},
    'relapse_revert': 0.40,
}


# ============================================================================
# DATA
# ============================================================================
def make_blobs(n, noise, seed):
    """
    Sample a 4-class Gaussian-blob dataset in 2-D.

    Class k is centred on one corner of the square (±3, ±3); every point
    gets isotropic Gaussian jitter with standard deviation *noise*.

    Args:
        n:     number of samples.
        noise: jitter standard deviation (0.0 gives exact centres).
        seed:  seed of the private NumPy RandomState (global RNG untouched).

    Returns:
        (X, y): float32 tensor of shape (n, 2) and int64 labels of shape (n,).
    """
    centers = np.array([[-3, -3], [3, 3], [-3, 3], [3, -3]])
    gen = np.random.RandomState(seed)
    labels = gen.randint(0, 4, n)
    points = centers[labels] + gen.randn(n, 2) * noise
    features = torch.tensor(points, dtype=torch.float32)
    targets = torch.tensor(labels, dtype=torch.long)
    return features, targets

def make_cog_probe(n, seed):
    """
    Sample the harder 'cognitive probe' variant of the 4-blob task.

    Same corner layout as ``make_blobs`` but the centres are shrunk by
    CFG['cog_center_scale'] and jittered with CFG['cog_noise'], so the
    classes overlap more.

    Returns:
        (X, y): float32 tensor of shape (n, 2) and int64 labels of shape (n,).
    """
    shrink = CFG['cog_center_scale']
    centers = np.array([[-3, -3], [3, 3], [-3, 3], [3, -3]]) * shrink
    gen = np.random.RandomState(seed)
    labels = gen.randint(0, 4, n)
    points = centers[labels] + gen.randn(n, 2) * CFG['cog_noise']
    return (torch.tensor(points, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.long))

def make_all_loaders():
    """
    Build the four DataLoaders used throughout the experiment.

    Returns (train, test, clean, cog):
      train  noisy blobs (CFG['n_train']), shuffled, CFG['batch_size'] per batch
      test   noisy blobs (CFG['n_test']), batches of 1000
      clean  noise-free blobs (CFG['n_clean']), batches of 1000
      cog    cognitive-probe blobs (CFG['cog_n_test']), batches of 1000
    Each split draws from its own fixed data seed (100 / 200 / 300 / 400).
    """
    def _as_loader(tensors, **loader_kw):
        # Wrap an (X, y) pair in a TensorDataset + DataLoader.
        return DataLoader(TensorDataset(*tensors), **loader_kw)

    train_xy = make_blobs(CFG['n_train'], CFG['data_noise'], 100)
    test_xy = make_blobs(CFG['n_test'], CFG['data_noise'], 200)
    clean_xy = make_blobs(CFG['n_clean'], 0.0, 300)
    cog_xy = make_cog_probe(CFG['cog_n_test'], 400)

    return (
        _as_loader(train_xy, batch_size=CFG['batch_size'], shuffle=True),
        _as_loader(test_xy, batch_size=1000),
        _as_loader(clean_xy, batch_size=1000),
        _as_loader(cog_xy, batch_size=1000),
    )

# Module-level loaders shared by training and every evaluation helper.
train_ld, test_ld, clean_ld, cog_ld = make_all_loaders()


# ============================================================================
# NETWORK
# ============================================================================
class PCNet(nn.Module):
    """
    Predictive-coding network.

    All persistent state lives in nn.Parameters, so it is captured by
    ``state_dict`` / ``snap``.  The only plain attributes are the transient
    ``ext_stress`` knob and the ``_acts`` cache of the last forward pass.

    Forward per hidden layer i (le_i = 1 - (1 - effort) * (i + 1) / n_hid):
      h = ReLU(W_i x + b_i)
      h *= gain_i * exc * le_i
      alive_gate = (unit_mask_i > unit_alive_floor).float()
      h *= unit_mask_i * alive_gate
      cap_scale = cap_floor + (1-cap_floor) * mean(unit_mask_i * alive_gate)
      h *= cap_scale
      h += N(0, noise_i + ext_stress) + N(0, |side_effect_i|)
    Both noise terms are halved when the module is in eval mode.
    """
    def __init__(self):
        super().__init__()
        dims = [CFG['in_dim']] + CFG['hidden_dims'] + [CFG['out_dim']]
        self.layers = nn.ModuleList(
            [nn.Linear(dims[i], dims[i+1]) for i in range(len(dims)-1)])
        self.n_hid = len(CFG['hidden_dims'])
        self.relu = nn.ReLU()

        # Frozen "physiological" parameters (never touched by the optimiser).
        self.gains = nn.Parameter(
            torch.tensor(CFG['healthy_gains'], dtype=torch.float32),
            requires_grad=False)
        self.noise_stds = nn.Parameter(
            torch.tensor(CFG['healthy_noise'], dtype=torch.float32),
            requires_grad=False)
        self.exc = nn.Parameter(torch.tensor(1.0), requires_grad=False)
        self.effort = nn.Parameter(
            torch.tensor(CFG['effort_healthy'], dtype=torch.float32),
            requires_grad=False)

        # One capacity mask and one side-effect noise vector per hidden layer.
        self.unit_masks = nn.ParameterList([
            nn.Parameter(torch.ones(d, dtype=torch.float32),
                         requires_grad=False)
            for d in CFG['hidden_dims']
        ])
        self.side_effects = nn.ParameterList([
            nn.Parameter(torch.zeros(d, dtype=torch.float32),
                         requires_grad=False)
            for d in CFG['hidden_dims']
        ])

        self.ext_stress = 0.0   # transient extra noise std set by evaluators
        self._acts = []         # detached hidden activations of last forward

    def forward(self, x):
        self._acts = []
        h = x
        for i in range(self.n_hid):
            h = self.relu(self.layers[i](h))
            # Effort deficit weakens higher layers progressively.
            le = 1.0 - (1.0 - self.effort.item()) * (i + 1) / self.n_hid
            h = h * self.gains[i] * self.exc * le
            floor = CFG['unit_alive_floor']
            alive_gate = (self.unit_masks[i].data > floor).float()
            h = h * self.unit_masks[i] * alive_gate
            # Whole-layer output shrinks with the remaining capacity.
            cm = (self.unit_masks[i] * alive_gate).mean().item()
            cs = CFG['cap_exc_floor'] + (1.0 - CFG['cap_exc_floor']) * cm
            h = h * cs
            ns = self.noise_stds[i].item() + self.ext_stress
            if ns > 0:
                sc = ns if self.training else ns * 0.5
                h = h + torch.randn_like(h) * sc
            se = self.side_effects[i]
            if se.abs().max().item() > 1e-6:
                ssc = 1.0 if self.training else 0.5
                h = h + torch.randn_like(h) * se.abs() * ssc
            self._acts.append(h.detach())
        return self.layers[-1](h)

    # ── [FIX 4] Normalized PE: relative mismatch, not absolute scale ──
    def compute_pes(self, x):
        """
        Per-layer normalised top-down prediction error for input batch *x*.

        Runs a forward pass in eval mode, then for each adjacent pair of
        hidden activations predicts the lower layer from the upper one with
        the transpose of the connecting weight (``a_{i+1} @ W_{i+1}``) and
        measures mean |a_i - pred| / mean|a_i| (normalised per sample).

        No autograd graph is built, and the module's previous train/eval
        mode is restored before returning.

        Returns:
            list of n_hid - 1 floats, lowest layer pair first.
        """
        was_training = self.training
        self.eval()
        pes = []
        try:
            # The whole computation runs without autograd: the weight
            # requires grad, so a bare matmul would otherwise record a graph.
            with torch.no_grad():
                self(x)
                for i in range(len(self._acts) - 1):
                    upper, lower = self._acts[i + 1], self._acts[i]
                    pred = upper @ self.layers[i + 1].weight
                    scale = lower.abs().mean(dim=1, keepdim=True) + 1e-8
                    pes.append(((lower - pred) / scale).abs().mean().item())
        finally:
            self.train(was_training)
        return pes

    def snap(self):
        """Return a deep copy of the full state dict."""
        return {k: v.clone() for k, v in self.state_dict().items()}

    def load_snap(self, s):
        """Load a state dict produced by ``snap`` (copied, not aliased)."""
        self.load_state_dict({k: v.clone() for k, v in s.items()})


# ============================================================================
# PSYCHOSIS MANAGER
# ============================================================================
class PsychMgr:
    """
    Induces psychosis in a PCNet and keeps the bookkeeping that the
    treatments and metrics rely on.

    Bookkeeping dicts are keyed by hidden-layer weight names
    (``'layers.<i>.weight'``):
      ab_masks     bool mask of aberrantly altered weights
      ab_orig      layer weights as they were just before the aberrant edit
      prune_masks  bool mask of reversibly pruned weights
      prune_orig   layer weights as they were just before pruning
      irrev_masks  bool mask of irreversibly pruned (zeroed) weights
    ``healthy`` holds a ``snap()`` of the model taken by ``save_healthy``.
    """
    def __init__(self, model: PCNet):
        self.m = model
        self.healthy = None
        self.ab_masks    = {}
        self.ab_orig     = {}
        self.prune_masks = {}
        self.prune_orig  = {}
        self.irrev_masks = {}

    @staticmethod
    def _layer_bias(biases, i):
        """Return ``biases[i]``, reusing the last entry for deeper layers."""
        return biases[min(i, len(biases) - 1)]

    def save_healthy(self):
        """Snapshot the current model state into ``self.healthy``."""
        self.healthy = self.m.snap()

    def induce(self, severity_pos=1.0, severity_neg=1.0):
        """
        Induce psychosis in place with separate positive/negative severity.

        All effects are network-parameter modifications:
          1. Precision imbalance: gains, noise stds and excitability move
             from the healthy toward the psychosis profile in proportion to
             *severity_pos* (extrapolated for severity > 1).
          2. Aberrant plasticity: per hidden layer, a random subset of
             weights (CFG['aberrant_frac'] * severity_pos * layer bias) is
             scaled up and jittered.
          3. Synaptic pruning: a random subset of the remaining
             non-aberrant weights is pruned; CFG['pruning_irrev_ratio'] of
             them are zeroed (irreversible), the rest multiplied by
             CFG['pruning_scale'] (reversible).  Units whose pruned-input
             fraction exceeds CFG['unit_death_threshold'] get a reduced
             unit mask, floored at CFG['unit_alive_floor'].
          4. Cascading degradation: unit masks drop further by
             sqrt(1 - mask) * cascade * layer bias.
          5. Effort deficit: effort lowered by severity_neg *
             CFG['effort_deficit_rate'] (floored at 0.1).

        Random choices use the global torch RNG (randperm / randn).

        Returns:
            dict with the severities, number of aberrant / pruned weights
            and the resulting gains, noise, exc and effort.
        """
        # ─── 1. Precision imbalance (positive symptoms) ───
        hg = torch.tensor(CFG['healthy_gains'], dtype=torch.float32)
        pg = torch.tensor(CFG['psychosis_gains'], dtype=torch.float32)
        hn = torch.tensor(CFG['healthy_noise'], dtype=torch.float32)
        pn = torch.tensor(CFG['psychosis_noise'], dtype=torch.float32)
        self.m.gains.data.copy_(hg + severity_pos * (pg - hg))
        self.m.noise_stds.data.copy_(hn + severity_pos * (pn - hn))
        self.m.exc.data.fill_(
            1.0 + severity_pos * (CFG['psychosis_exc'] - 1.0))

        # ─── 2. Aberrant plasticity (positive symptoms) ───
        tot_aberrant = 0
        frac  = CFG['aberrant_frac'] * severity_pos
        scale = 1.0 + (CFG['aberrant_scale'] - 1.0) * severity_pos
        bias  = CFG['aberrant_bias']
        for i in range(self.m.n_hid):
            nm = f'layers.{i}.weight'
            w  = self.m.layers[i].weight.data
            self.ab_orig[nm] = w.clone()
            n  = w.numel()
            nab = int(n * frac * self._layer_bias(bias, i))
            mask = torch.zeros(w.shape, dtype=torch.bool)
            if nab > 0:
                flat = w.flatten()
                idx  = torch.randperm(flat.numel())[:nab]
                mf   = torch.zeros(flat.numel(), dtype=torch.bool)
                mf[idx] = True
                mask = mf.view_as(w)
                flat[idx] *= scale
                flat[idx] += torch.randn(nab) * 0.5 * severity_pos
                # flatten() of a contiguous weight is normally a view, so
                # this copy is a safety net in case it returned a copy.
                w.copy_(flat.view_as(w))
                tot_aberrant += nab
            self.ab_masks[nm] = mask

        # ─── 3. Synaptic pruning (negative symptoms) ───
        tot_rev = tot_irrev = 0
        thr   = CFG['unit_death_threshold']
        floor = CFG['unit_alive_floor']
        dr    = CFG['unit_death_rate']

        for i in range(self.m.n_hid):
            nm = f'layers.{i}.weight'
            w  = self.m.layers[i].weight.data
            self.prune_orig[nm] = w.clone()
            n = w.numel()
            p_frac = (severity_neg * CFG['pruning_frac_base']
                      * self._layer_bias(CFG['pruning_bias'], i))
            n_prune = int(n * p_frac)

            rev_mask   = torch.zeros(w.shape, dtype=torch.bool)
            irrev_mask = torch.zeros(w.shape, dtype=torch.bool)

            if n_prune > 0:
                n_irrev = int(n_prune * CFG['pruning_irrev_ratio'])
                n_rev   = n_prune - n_irrev
                # Aberrant weights are never pruned.
                ab_flat = self.ab_masks[nm].flatten()
                avail   = (~ab_flat).nonzero(as_tuple=True)[0]
                if len(avail) < n_prune:
                    n_prune = len(avail)
                    n_irrev = int(n_prune * CFG['pruning_irrev_ratio'])
                    n_rev   = n_prune - n_irrev

                perm = torch.randperm(len(avail))[:n_prune]
                pidx = avail[perm]
                flat = w.flatten()

                irrev_flat = torch.zeros(n, dtype=torch.bool)
                if n_irrev > 0:
                    irrev_flat[pidx[:n_irrev]] = True
                    flat[pidx[:n_irrev]] = 0.0
                    tot_irrev += n_irrev
                irrev_mask = irrev_flat.view_as(w)

                rev_flat = torch.zeros(n, dtype=torch.bool)
                if n_rev > 0:
                    rev_flat[pidx[n_irrev:]] = True
                    flat[pidx[n_irrev:]] *= CFG['pruning_scale']
                    tot_rev += n_rev
                rev_mask = rev_flat.view_as(w)
                w.copy_(flat.view_as(w))

                # Units that lost too many input weights lose capacity.
                combined = rev_mask | irrev_mask
                if combined.any():
                    per_unit = combined.float().mean(dim=1)
                    for j in range(w.shape[0]):
                        pf = per_unit[j].item()
                        if pf > thr:
                            excess = (pf - thr) / (1.0 - thr)
                            alive  = max(floor, 1.0 - excess * dr)
                            self.m.unit_masks[i].data[j] = alive

            self.prune_masks[nm] = rev_mask
            self.irrev_masks[nm] = irrev_mask

        # ─── 4. Cascading degradation (positive feedback on damage) ───
        cascade = CFG['pruning_cascade'] * severity_neg
        for i in range(self.m.n_hid):
            mask = self.m.unit_masks[i].data
            lb   = self._layer_bias(CFG['pruning_bias'], i)
            damage = (1.0 - mask).pow(0.5) * cascade * lb
            mask.sub_(damage)
            mask.clamp_(min=floor)

        # ─── 5. Effort deficit (negative symptoms) ───
        self.m.effort.data.fill_(
            max(0.1, CFG['effort_healthy']
                - severity_neg * CFG['effort_deficit_rate']))

        return {
            'severity_pos': severity_pos,
            'severity_neg': severity_neg,
            'aberrant': tot_aberrant,
            'pruned_reversible': tot_rev,
            'pruned_irreversible': tot_irrev,
            'gains': self.m.gains.data.tolist(),
            'noise': self.m.noise_stds.data.tolist(),
            'exc': self.m.exc.data.item(),
            'effort': self.m.effort.data.item(),
        }

    def pathology(self):
        """
        Summarise the current pathology.

        Positive side: gain spread (max - min), mean noise std and the
        fraction of hidden weights marked aberrant.  Negative side: fraction
        of hidden weights marked pruned (counted from the masks, not from
        current weight values), mean unit-mask capacity and effort.

          positive_composite = (gain_spread/2.5 + mean_noise/1.2
                                + aberrant_frac*50) / 3
          negative_composite = pruning_frac + (1 - capacity) + (1 - effort)
        NOTE(review): 2.5 and 1.2 equal the largest psychosis gain / noise
        in CFG — presumably chosen as normalisers; confirm.
        """
        gi = self.m.gains.data.max().item() - self.m.gains.data.min().item()
        an = self.m.noise_stds.data.mean().item()
        tw = sum(self.m.layers[i].weight.numel()
                 for i in range(self.m.n_hid))
        ta = sum(v.sum().item() for v in self.ab_masks.values())
        tr = sum(v.sum().item() for v in self.prune_masks.values())
        ti = sum(v.sum().item() for v in self.irrev_masks.values())
        af = ta / tw if tw else 0
        pf = (tr + ti) / tw if tw else 0
        tu = sum(self.m.unit_masks[i].numel()
                 for i in range(self.m.n_hid))
        au = sum(self.m.unit_masks[i].data.sum().item()
                 for i in range(self.m.n_hid))
        cap = au / tu if tu else 1.0
        se  = sum(self.m.side_effects[i].data.abs().sum().item()
                  for i in range(self.m.n_hid))
        eff = self.m.effort.item()
        pos = (gi / 2.5 + an / 1.2 + af * 50) / 3
        neg = pf + (1.0 - cap) + (1.0 - eff)
        return {
            'gain_imbalance': gi, 'avg_noise': an,
            'aberrant_frac': af, 'pruning_frac': pf,
            'pruning_rev': int(tr), 'pruning_irrev': int(ti),
            'capacity': cap, 'effort': eff, 'side_effects': se,
            'positive_composite': pos, 'negative_composite': neg,
            'total_composite': (pos + neg) / 2,
        }

    def active_pruning_frac(self):
        """
        Fraction of originally pruned positions (reversible + irreversible)
        whose current magnitude is still below 20% of the value recorded in
        ``prune_orig``.  Returns 0.0 when nothing was pruned.
        """
        total = still = 0
        for nm in self.prune_masks:
            idx = int(nm.split('.')[1])
            w = self.m.layers[idx].weight.data
            orig = self.prune_orig[nm]
            comb = self.prune_masks[nm] | self.irrev_masks.get(
                nm, torch.zeros_like(self.prune_masks[nm]))
            nc = comb.sum().item()
            total += nc
            if nc > 0:
                ratio = w[comb].abs() / (orig[comb].abs() + 1e-8)
                still += (ratio < 0.20).sum().item()
        return still / total if total > 0 else 0.0

    def clone(self):
        """Deep-copy the model snapshot and all bookkeeping dicts."""
        return {
            'snap':        self.m.snap(),
            'ab_masks':    {k: v.clone() for k, v in self.ab_masks.items()},
            'ab_orig':     {k: v.clone() for k, v in self.ab_orig.items()},
            'prune_masks': {k: v.clone() for k, v in self.prune_masks.items()},
            'prune_orig':  {k: v.clone() for k, v in self.prune_orig.items()},
            'irrev_masks': {k: v.clone() for k, v in self.irrev_masks.items()},
        }

    def load(self, st):
        """Restore a state produced by ``clone`` (``healthy`` is untouched)."""
        self.m.load_snap(st['snap'])
        self.ab_masks    = {k: v.clone() for k, v in st['ab_masks'].items()}
        self.ab_orig     = {k: v.clone() for k, v in st['ab_orig'].items()}
        self.prune_masks = {k: v.clone() for k, v in st['prune_masks'].items()}
        self.prune_orig  = {k: v.clone() for k, v in st['prune_orig'].items()}
        self.irrev_masks = {k: v.clone() for k, v in st['irrev_masks'].items()}


# ============================================================================
# TRAINING & EVALUATION
# ============================================================================
def train_net(model, epochs, lr, verbose=False, *, loader=None):
    """
    Train the trainable parameters of *model* with Adam + cross-entropy.

    Only parameters with ``requires_grad=True`` are optimised, so frozen
    parameters (e.g. a PCNet's gains, noise, masks, effort) are untouched.
    The model is left in train mode.

    Args:
        model:   module mapping an input batch to class logits.
        epochs:  number of passes over the data; ``<= 0`` is a no-op.
        lr:      Adam learning rate.
        verbose: print the mean loss every 5 epochs.
        loader:  iterable of ``(x, y)`` batches; defaults to the
                 module-level training loader ``train_ld``.

    Returns:
        List with the mean batch loss of each epoch (empty if no training).

    Raises:
        ValueError: if the loader yields no batches.
    """
    if epochs <= 0:
        return []
    data = train_ld if loader is None else loader
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = optim.Adam(trainable, lr=lr)
    ce = nn.CrossEntropyLoss()
    losses = []
    for ep in range(epochs):
        model.train()
        epoch_loss, n_batches = 0.0, 0
        for x, y in data:
            opt.zero_grad()
            loss = ce(model(x), y)
            loss.backward()
            opt.step()
            epoch_loss += loss.item()
            n_batches += 1
        # Counting batches (instead of len(loader)) supports any iterable.
        if n_batches == 0:
            raise ValueError("train_net(): loader yielded no batches")
        losses.append(epoch_loss / n_batches)
        if verbose and (ep + 1) % 5 == 0:
            print(f"    Ep {ep+1}/{epochs}  loss={losses[-1]:.4f}")
    return losses

def acc(model, loader, inp_noise=0.0, ext_stress=0.0):
    """
    Classification accuracy (in percent) of *model* on *loader*.

    The model is switched to eval mode and left there.  ``model.ext_stress``
    is temporarily set to *ext_stress* and always restored afterwards, even
    if the forward pass raises.

    Args:
        model:      module with a writable ``ext_stress`` attribute.
        loader:     iterable of ``(x, y)`` batches.
        inp_noise:  std of Gaussian noise added to every input (0 = off).
        ext_stress: value of ``model.ext_stress`` during the evaluation.

    Returns:
        100 * correct / total.

    Raises:
        ValueError: if *loader* yields no samples.
    """
    model.eval()
    old = model.ext_stress
    model.ext_stress = ext_stress
    correct = total = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                if inp_noise > 0:
                    x = x + torch.randn_like(x) * inp_noise
                correct += (model(x).argmax(1) == y).sum().item()
                total += y.size(0)
    finally:
        model.ext_stress = old
    if total == 0:
        raise ValueError("acc(): loader yielded no samples")
    return 100.0 * correct / total


# ============================================================================
# [FIX 3 v2] HALLUCINATION METRIC — LOGIT-MARGIN BASED
# ============================================================================
#
# WHY the old metric failed:
#   softmax max-confidence saturates near 1.0 on this trivially-separable
#   4-class 2D task.  Healthy mean confidence on noise ≈ 0.97, so
#   mean + σ = 1.06 > 1.0 (theoretical softmax max).  Threshold is
#   unreachable → halluc_rate ≡ 0 for ALL conditions.
#
# FIX: Use the logit-margin (gap between top-2 raw logits) instead.
#   - Unbounded, no saturation.
#   - Psychotic model's amplified belief-layer gain (effective ~2×)
#     produces larger logits → wider margins → higher halluc rate.
#   - Threshold = healthy 75th-percentile → healthy ≈ 0.25 by calibration.
# ============================================================================

def compute_halluc_threshold(model, n=2000, runs=10):
    """
    Describe the model's logit-margin distribution on pure-noise inputs.

    The margin is (largest logit) - (second-largest logit).  *runs* batches
    of *n* inputs drawn from N(0, 4^2) are pushed through the model in eval
    mode and all margins are pooled.

    Returns:
        (mean, std, p75), where p75 is the pooled margin at sorted index
        int(0.75 * count).  Used as the hallucination threshold, p75 makes
        the calibrating model's own rate about 25%.
    """
    model.eval()
    per_run = []
    with torch.no_grad():
        for _ in range(runs):
            probe = torch.randn(n, CFG['in_dim']) * 4.0
            best_two = model(probe).topk(2, dim=1).values
            per_run.append(best_two[:, 0] - best_two[:, 1])
    margins = torch.cat(per_run)
    cut = int(len(margins) * 0.75)
    p75 = margins.sort().values[cut].item()
    return margins.mean().item(), margins.std().item(), p75


def halluc_rate(model, n=1000, runs=5):
    """
    Fraction of pure-noise inputs whose logit margin exceeds the threshold.

    The threshold is CFG['halluc_logit_thr'] (set from a healthy model via
    ``compute_halluc_threshold``), or 5.0 while that is still None.  Each of
    *runs* batches holds *n* inputs drawn from N(0, 4^2); the per-batch
    fractions are averaged.
    """
    model.eval()
    threshold = CFG.get('halluc_logit_thr')
    if threshold is None:
        threshold = 5.0                 # Conservative fallback
    fractions = []
    with torch.no_grad():
        for _ in range(runs):
            probe = torch.randn(n, CFG['in_dim']) * 4.0
            best_two = model(probe).topk(2, dim=1).values
            confident = (best_two[:, 0] - best_two[:, 1]) > threshold
            fractions.append(confident.float().mean().item())
    return float(np.mean(fractions))


def flexibility(model, loader, pert=2.0, runs=3):
    """
    Belief flexibility: mean per-sample KL(p(x) || p(x + eps)).

    eps ~ N(0, pert^2) is drawn from the global torch RNG for every batch.
    Each run is one full pass over *loader*; the result is the mean over
    *runs* of (summed, non-negative per-sample KL / number of samples).
    Higher values mean the predictions shift more under perturbation.
    The model is put in eval mode.
    """
    model.eval()

    def _one_pass():
        kl_total, seen = 0, 0
        with torch.no_grad():
            for x, _label in loader:
                p_clean = torch.softmax(model(x), 1).clamp(1e-8)
                x_noisy = x + torch.randn_like(x) * pert
                p_noisy = torch.softmax(model(x_noisy), 1).clamp(1e-8)
                per_sample = (p_clean * (p_clean.log() - p_noisy.log())).sum(1)
                kl_total += per_sample.clamp(0).sum().item()
                seen += x.size(0)
        return kl_total / seen

    return float(np.mean([_one_pass() for _ in range(runs)]))

def pe_score(model, loader):
    """
    Average ``model.compute_pes`` over every batch of *loader*.

    Returns:
        {'layers': per-layer mean PE as a list, 'total': sum of those means}.
    """
    model.eval()
    with torch.no_grad():
        per_batch = [model.compute_pes(x) for x, _ in loader]
    layer_means = np.mean(per_batch, axis=0)
    return {'layers': layer_means.tolist(), 'total': float(layer_means.sum())}

def _net_metrics(model):
    cap = sum(model.unit_masks[i].data.mean().item()
              for i in range(model.n_hid)) / model.n_hid
    se  = sum(model.side_effects[i].data.abs().mean().item()
              for i in range(model.n_hid)) / model.n_hid
    eff = model.effort.item()
    return cap, se, eff

def full_eval(model, label="", mgr=None):
    """
    Run the complete evaluation battery on *model*.

    Accuracies (percent): clean, standard, combined (input noise 1.0 +
    stress 0.5), cog, cog_load (stress 0.5) and one 's_<name>' per
    CFG['stress'] level.  Positive-symptom metrics: halluc, flex, pe.
    Negative-symptom metrics: neg_score = (100 - clean) / 100, capacity,
    side_effects, effort, pruning_active, neg_composite, pos_composite
    (the last three need *mgr*; otherwise 0.0 / neg_score / 0.0).
    Prints a summary when *label* is non-empty.

    The metric calls consume the global torch RNG, so they run in a fixed
    order.
    """
    res = {
        'clean':    acc(model, clean_ld),
        'standard': acc(model, test_ld),
        'combined': acc(model, test_ld, 1.0, 0.5),
        'cog':      acc(model, cog_ld),
        'cog_load': acc(model, cog_ld, ext_stress=0.5),
    }
    res.update({f's_{name}': acc(model, test_ld, ext_stress=level)
                for name, level in CFG['stress'].items()})
    res['halluc'] = halluc_rate(model)
    res['flex'] = flexibility(model, clean_ld)
    res['pe'] = pe_score(model, clean_ld)
    res['neg_score'] = max(0.0, (100.0 - res['clean']) / 100.0)

    cap, se, eff = _net_metrics(model)
    res.update(capacity=cap, side_effects=se, effort=eff)

    if mgr is None:
        res.update(pruning_active=0.0,
                   neg_composite=res['neg_score'],
                   pos_composite=0.0)
    else:
        res['pruning_active'] = mgr.active_pruning_frac()
        path = mgr.pathology()
        # NOTE(review): path['pruning_frac'] counts pruning *masks*, which the
        # treatments visible here (e.g. treat_ketamine) do not update, so this
        # term may not reflect synaptogenesis; 'pruning_active' tracks the
        # actual weights — confirm which one is intended.
        res['neg_composite'] = (res['neg_score']
                                + path['pruning_frac']
                                + (1-cap)
                                + (1-eff)) / 4
        res['pos_composite'] = path['positive_composite']

    if not label:
        return res

    print(f"\n  {label}")
    print(f"    Acc  clean={res['clean']:.1f}%  std={res['standard']:.1f}%  "
          f"comb={res['combined']:.1f}%  extreme={res['s_extreme']:.1f}%")
    print(f"    Cog  probe={res['cog']:.1f}%  loaded={res['cog_load']:.1f}%")
    print(f"    Pos  halluc={res['halluc']:.3f}  "
          f"flex={res['flex']:.4f}  PE={res['pe']['total']:.3f}")
    print(f"    Neg  neg={res['neg_score']:.3f}  cap={cap:.3f}  "
          f"eff={eff:.3f}  pruning={res['pruning_active']:.3f}  "
          f"SE={se:.4f}")
    return res

def light_eval(model, mgr=None):
    """
    Cheap subset of ``full_eval`` for use inside sweeps and time courses.

    Returns clean / combined / cog accuracy, a single-batch (n=500)
    hallucination rate, neg_score, capacity, side_effects, effort and
    pruning_active (0.0 when *mgr* is not given).
    """
    out = {
        'clean':    acc(model, clean_ld),
        'combined': acc(model, test_ld, 1.0, 0.5),
        'cog':      acc(model, cog_ld),
        'halluc':   halluc_rate(model, n=500, runs=1),
    }
    out['neg_score'] = max(0.0, (100.0 - out['clean']) / 100.0)
    capacity, side_fx, effort = _net_metrics(model)
    out.update(capacity=capacity, side_effects=side_fx, effort=effort)
    out['pruning_active'] = mgr.active_pruning_frac() if mgr else 0.0
    return out


# ============================================================================
# DOSE METRICS
# ============================================================================
@dataclass
class Dose:
    """
    Size of a treatment's parameter change, as produced by ``compute_dose``.

    Field order matters: ``compute_dose`` constructs instances positionally.
    """
    l1: float = 0.0          # sum |Δ| / #elements over all state entries
    l2: float = 0.0          # sqrt(sum Δ²) / #elements over all entries
    t_l1: float = 0.0        # same as l1, excluding 'layers.' / 'side_effects.'
    gain_chg: float = 0.0    # sum |Δ gains|
    noise_chg: float = 0.0   # sum |Δ noise_stds|
    exc_chg: float = 0.0     # |Δ exc|
    cap_chg: float = 0.0     # sum |Δ| over unit_masks
    se_chg: float = 0.0      # sum |Δ| over side_effects
    effort_chg: float = 0.0  # |Δ effort|

def compute_dose(pre_snap, model) -> Dose:
    """
    Measure how far *model* has moved away from the snapshot *pre_snap*.

    Over every state entry present in both:
      l1   = sum |Δ| / total number of elements
      l2   = sqrt(sum Δ²) / total number of elements
      t_l1 = sum |Δ| / #elements, restricted to entries whose name contains
             neither 'layers.' nor 'side_effects.' (gains, noise_stds, exc,
             effort, unit_masks)
    plus raw (unnormalised) absolute changes of gains, noise_stds, exc,
    effort, unit_masks (cap_chg) and side_effects (se_chg).
    *pre_snap* must contain 'gains', 'noise_stds', 'exc' and 'effort'.
    """
    current = model.state_dict()
    abs_sum = sq_sum = 0.0
    n_all = 0
    mask_sum = side_sum = 0.0
    targeted_sum = 0.0
    n_targeted = 0
    for key, before in pre_snap.items():
        if key not in current:
            continue
        after = current[key]
        delta = (after.float() - before.float()).abs()
        delta_l1 = delta.sum().item()
        abs_sum += delta_l1
        sq_sum += (delta ** 2).sum().item()
        n_all += after.numel()
        if 'unit_masks' in key:
            mask_sum += delta_l1
        if 'side_effects' in key:
            side_sum += delta_l1
        if 'layers.' not in key and 'side_effects.' not in key:
            targeted_sum += delta_l1
            n_targeted += after.numel()

    mean_l1 = abs_sum / (n_all or 1)
    l2_per_elem = (sq_sum ** 0.5) / (n_all or 1)
    targeted = targeted_sum / (n_targeted or 1)

    def _abs_change(name):
        return (current[name] - pre_snap[name]).abs()

    return Dose(
        mean_l1, l2_per_elem, targeted,
        _abs_change('gains').sum().item(),
        _abs_change('noise_stds').sum().item(),
        _abs_change('exc').item(),
        mask_sum, side_sum,
        _abs_change('effort').item(),
    )


# ============================================================================
# TREATMENT PROTOCOLS
# ============================================================================

def treat_ketamine(model: PCNet, mgr: PsychMgr, intensity=1.0, verbose=True):
    """
    KETAMINE (NMDA antagonist), applied in place to *model*.

      Positive: WORSENS  - gains of layers 0..n_hid//2 attenuated, every
                noise std boosted, aberrant weights amplified, excitability
                raised
      Negative: IMPROVES - reversibly pruned weights pulled toward their
                pre-pruning values, unit masks and effort pulled toward
                healthy
      Side effects: uniform dissociative noise added to every unit

    Afterwards the model is retrained for
    CFG['consolidation_epochs']['ketamine'] epochs at CFG['ket_lr']
    (0 epochs when CFG['run_zero_consolidation'] is set).

    Args:
        model:     network to modify.
        mgr:       PsychMgr holding the aberrant / pruning masks of *model*.
        intensity: scales every effect.
        verbose:   print a header and the resulting precision parameters.
    """
    if verbose:
        print(f"\n    KETAMINE  intensity={intensity:.2f}")

    n_hid = model.n_hid

    # Precision disruption: attenuate the lower half of the hierarchy
    # (layers 0 .. n_hid // 2) and inflate noise in every layer.
    gain_keep = 1 - intensity * CFG['ket_gain_atten']
    for layer in range(min(n_hid, n_hid // 2 + 1)):
        model.gains.data[layer] = max(0.05,
                                      model.gains[layer].item() * gain_keep)
    noise_mult = 1 + intensity * (CFG['ket_noise_boost'] - 1)
    for layer in range(n_hid):
        model.noise_stds.data[layer] = (
            model.noise_stds[layer].item() * noise_mult + intensity * 0.3)

    # Aberrant amplification.
    ab_mult = 1 + intensity * (CFG['ket_aberrant_boost'] - 1)
    for name, mask in mgr.ab_masks.items():
        if not mask.any():
            continue
        model.layers[int(name.split('.')[1])].weight.data[mask] *= ab_mult

    # Excitability.
    model.exc.data.mul_(1 + intensity * (CFG['ket_exc_boost'] - 1))

    # Synaptogenesis: pull reversibly pruned weights toward pre-pruning values.
    syn = intensity * CFG['ket_synaptogen']
    for name, mask in mgr.prune_masks.items():
        if not mask.any():
            continue
        w = model.layers[int(name.split('.')[1])].weight.data
        w[mask] = w[mask] + syn * (mgr.prune_orig[name][mask] - w[mask])

    # Capacity restoration: move every unit mask toward 1.
    cap_rate = intensity * CFG['ket_synaptogen'] * CFG['ket_cap_restore']
    for layer in range(n_hid):
        unit_mask = model.unit_masks[layer].data
        unit_mask += cap_rate * (1.0 - unit_mask)
        unit_mask.clamp_(0, 1)

    # Effort restoration toward the healthy level.
    eff_rate = intensity * CFG['ket_effort_restore']
    effort = model.effort.data
    effort += eff_rate * (CFG['effort_healthy'] - effort)
    effort.clamp_(0.1, CFG['effort_healthy'])

    # Side effects: uniform dissociative noise on every unit.
    for layer in range(n_hid):
        side = model.side_effects[layer].data
        side += intensity * CFG['ket_side_noise'] * torch.ones_like(side)

    # ── [FIX 1] Consolidation — unified config (skipped in the ablation) ──
    n_ep = (0 if CFG['run_zero_consolidation']
            else CFG['consolidation_epochs']['ketamine'])
    train_net(model, n_ep, CFG['ket_lr'])

    if verbose:
        gains_txt = [f'{g:.2f}' for g in model.gains.data.tolist()]
        noise_txt = [f'{n:.2f}' for n in model.noise_stds.data.tolist()]
        print(f"      gains={gains_txt}  noise={noise_txt}  "
              f"exc={model.exc.item():.2f}  eff={model.effort.item():.2f}")


def treat_antipsychotic(model: PCNet, mgr: PsychMgr, dose=1.0, verbose=True):
    """
    ANTIPSYCHOTIC (D2 antagonist)
      Positive: TREATS  (normalizes precision + attenuates aberrant wts)
      Negative: MINIMAL (weak synaptogenesis + negligible effort)
      Side effects: extrapyramidal noise in lower layers

    Every effect is an in-place edit of ``model``'s parameters, followed by a
    consolidation run of ``train_net`` (skipped when
    ``CFG['run_zero_consolidation']`` is set).

    Args:
        model:   network modified in place.
        mgr:     pathology manager providing the aberrant / pruning masks and
                 the reference weights they are pulled back toward
                 (``ab_orig``, ``prune_orig``).
        dose:    treatment strength; scales every ``CFG['ap_*']`` coefficient.
        verbose: print a header and a post-treatment parameter summary.
    """
    if verbose:
        print(f"\n    ANTIPSYCHOTIC  dose={dose:.2f}")

    hg = torch.tensor(CFG['healthy_gains'], dtype=torch.float32)

    # Precision normalization: layer i is pulled toward its healthy gain with
    # weight (i+1)/n_hid, so higher layers are normalized more strongly.
    for i in range(model.n_hid):
        w = (i + 1) / model.n_hid
        nrm = dose * CFG['ap_gain_norm'] * w
        model.gains.data[i] += nrm * (hg[i] - model.gains[i].item())

    # Noise reduction. A noise std cannot drop below zero, so the scale factor
    # is floored at 0; otherwise large doses would flip the sign of the std
    # (which e.g. torch.normal rejects).
    noise_scale = max(0.0, 1 - dose * CFG['ap_noise_red'])
    for i in range(model.n_hid):
        model.noise_stds.data[i] *= noise_scale

    # Aberrant attenuation: blend aberrant weights back toward their
    # pre-aberration values.
    for nm, mask in mgr.ab_masks.items():
        if mask.sum() == 0: continue
        idx = int(nm.split('.')[1])
        w_d = model.layers[idx].weight.data
        orig = mgr.ab_orig.get(nm)
        if orig is not None:
            att = dose * CFG['ap_aberrant_weak']
            w_d[mask] = w_d[mask] * (1-att) + orig[mask] * att

    # Excitability normalization toward 1.0
    model.exc.data.add_(dose * CFG['ap_exc_norm'] * (1.0 - model.exc.item()))

    # Minimal synaptogenesis (reversibly pruned weights + capacity masks)
    syn = dose * CFG['ap_synaptogen']
    if syn > 0:
        for nm, mask in mgr.prune_masks.items():
            if mask.sum() == 0: continue
            idx = int(nm.split('.')[1])
            w = model.layers[idx].weight.data
            w[mask] = w[mask] + syn * (mgr.prune_orig[nm][mask] - w[mask])
        cap_r = syn * CFG['ap_cap_restore']
        for i in range(model.n_hid):
            model.unit_masks[i].data += (
                cap_r * (1.0 - model.unit_masks[i].data))
            model.unit_masks[i].data.clamp_(0, 1)

    # Minimal effort restoration
    eff_r = dose * CFG['ap_effort_restore']
    model.effort.data += eff_r * (CFG['effort_healthy'] - model.effort.data)
    model.effort.data.clamp_(0.1, CFG['effort_healthy'])

    # Extrapyramidal side effects (only the first two hidden layers)
    for i in range(min(2, model.n_hid)):
        model.side_effects[i].data += (
            dose * CFG['ap_side_noise']
            * torch.ones_like(model.side_effects[i].data))

    # ── [FIX 1] Consolidation — unified config ──
    if CFG['run_zero_consolidation']:
        n_ep = 0
    else:
        n_ep = CFG['consolidation_epochs']['antipsychotic']
    train_net(model, n_ep, CFG['ap_lr'])

    if verbose:
        print(f"      gains={[f'{g:.2f}' for g in model.gains.data.tolist()]}  "
              f"noise={[f'{n:.2f}' for n in model.noise_stds.data.tolist()]}  "
              f"exc={model.exc.item():.2f}  eff={model.effort.item():.2f}")


def treat_ect(model: PCNet, mgr: PsychMgr, sessions=6, verbose=True):
    """
    ECT (seizure reset)
      Positive: TREATS  (global precision normalization per session)
      Negative: TREATS  (strongest synaptogenesis + irreversible recovery)
      Side effects: postictal noise, decays across sessions

    Each session applies, in order: a Gaussian seizure perturbation of all
    ``layers.*`` parameters, gain/noise normalization toward the CFG healthy
    values, a pull of layer parameters toward ``mgr.healthy`` (when set),
    aberrant-weight normalization, reversible synaptogenesis, partial
    irreversible recovery, capacity and effort restoration, postictal side
    effects, excitability normalization, and a consolidation run.

    Args:
        model:    network modified in place.
        mgr:      pathology manager (masks, reference weights, healthy snap).
        sessions: number of ECT sessions.
        verbose:  print a header and a post-treatment parameter summary.
    """
    if verbose:
        print(f"\n    ECT  sessions={sessions}")

    hg = torch.tensor(CFG['healthy_gains'], dtype=torch.float32)
    hn = torch.tensor(CFG['healthy_noise'], dtype=torch.float32)
    h_sd = mgr.healthy

    for s in range(sessions):
        # Seizure perturbation
        with torch.no_grad():
            for nm, p in model.named_parameters():
                if nm.startswith('layers.'):
                    p.data += torch.randn_like(p) * CFG['ect_seizure_noise']

        # Precision normalization
        for i in range(model.n_hid):
            model.gains.data[i] += (
                CFG['ect_gain_reset']
                * (hg[i] - model.gains[i].item()))
            model.noise_stds.data[i] += (
                CFG['ect_noise_reset']
                * (hn[i] - model.noise_stds[i].item()))

        # Weight reset toward healthy: p <- (1-rw)*p + rw*healthy.
        # add_ does not modify its argument, so no clone of h_sd is needed.
        if h_sd is not None:
            rw = CFG['ect_weight_reset']
            with torch.no_grad():
                for nm, p in model.named_parameters():
                    if nm in h_sd and nm.startswith('layers.'):
                        p.data.mul_(1-rw).add_(h_sd[nm], alpha=rw)

        # Aberrant normalization
        ra = CFG['ect_aberrant_reset']
        for nm, mask in mgr.ab_masks.items():
            if mask.sum() == 0: continue
            idx = int(nm.split('.')[1])
            w_d = model.layers[idx].weight.data
            orig = mgr.ab_orig.get(nm)
            if orig is not None:
                w_d[mask] = w_d[mask] * (1-ra) + orig[mask] * ra

        # Synaptogenesis (reversible)
        syn = CFG['ect_synaptogen']
        for nm, mask in mgr.prune_masks.items():
            if mask.sum() == 0: continue
            idx = int(nm.split('.')[1])
            w = model.layers[idx].weight.data
            w[mask] = w[mask] + syn * (mgr.prune_orig[nm][mask] - w[mask])

        # Partial irreversible recovery. Uses the same convergent form as
        # reversible synaptogenesis: each session closes a fraction `irr` of
        # the gap to the pre-pruning weight. A purely additive `+ irr*orig`
        # would compound over sessions and, together with the healthy reset
        # above, settle above the original weight.
        irr = CFG['ect_irrev_recovery']
        for nm, mask in mgr.irrev_masks.items():
            if mask.sum() == 0: continue
            idx = int(nm.split('.')[1])
            w = model.layers[idx].weight.data
            w[mask] = w[mask] + irr * (mgr.prune_orig[nm][mask] - w[mask])

        # Capacity restoration
        cap_r = syn * CFG['ect_cap_restore']
        for i in range(model.n_hid):
            model.unit_masks[i].data += (
                cap_r * (1.0 - model.unit_masks[i].data))
            model.unit_masks[i].data.clamp_(0, 1)

        # Effort restoration
        eff_r = CFG['ect_effort_restore']
        model.effort.data += (
            eff_r * (CFG['effort_healthy'] - model.effort.data))
        model.effort.data.clamp_(0.1, CFG['effort_healthy'])

        # Postictal side effects: decay the accumulated level, then add this
        # session's postictal noise.
        for i in range(model.n_hid):
            model.side_effects[i].data *= CFG['ect_se_decay']
            model.side_effects[i].data += (
                CFG['ect_postictal_noise']
                * torch.ones_like(model.side_effects[i].data))

        # Excitability normalization toward 1.0
        model.exc.data.add_(
            CFG['ect_exc_reset'] * (1.0 - model.exc.item()))

        # ── [FIX 1] Post-session consolidation — unified config ──
        if CFG['run_zero_consolidation']:
            n_ep = 0
        else:
            n_ep = CFG['consolidation_epochs']['ect_per_session']
        train_net(model, n_ep, CFG['ect_retrain_lr'])

    if verbose:
        print(f"      gains={[f'{g:.2f}' for g in model.gains.data.tolist()]}  "
              f"noise={[f'{n:.2f}' for n in model.noise_stds.data.tolist()]}  "
              f"exc={model.exc.item():.2f}  eff={model.effort.item():.2f}")


# ============================================================================
# RELAPSE & HOMEOSTASIS
# ============================================================================
def do_relapse(model, mgr, psy_state, frac=None):
    """
    Heterogeneous relapse: precision fast, structural slow.
    All effects are parameter reversions.

    Each parameter group is blended back toward the psychotic snapshot
    ``psy_state['snap']`` by ``base * CFG['relapse_<group>_mult']``:

      gains / noise_stds / exc and aberrant weights -> precision mult
      effort                                        -> effort mult
      unit capacity masks                           -> capacity mult
      reversibly-pruned weights, irreversible decay -> pruning mult

    Args:
        model:     network modified in place.
        mgr:       pathology manager providing ``ab_masks``, ``prune_masks``
                   and ``irrev_masks`` keyed by parameter name.
        psy_state: psychotic state; ``psy_state['snap']`` holds 'gains',
                   'noise_stds', 'exc', 'effort', 'layers.<idx>.weight' and
                   (when present) 'unit_masks.<i>' entries.
        frac:      base reversion fraction. ``None`` uses
                   ``CFG['relapse_revert']``. An explicit ``0.0`` means no
                   relapse and is NOT replaced by the default.
    """
    # `frac or default` would silently turn frac=0.0 into a full default
    # relapse; only a missing value should fall back to the config.
    base = CFG['relapse_revert'] if frac is None else frac
    fp  = base * CFG['relapse_precision_mult']
    fpr = base * CFG['relapse_pruning_mult']
    fc  = base * CFG['relapse_capacity_mult']
    fe  = base * CFG['relapse_effort_mult']
    ps  = psy_state['snap']

    # Precision relapse (fast)
    model.gains.data      += fp * (ps['gains'] - model.gains.data)
    model.noise_stds.data += fp * (ps['noise_stds'] - model.noise_stds.data)
    model.exc.data        += fp * (ps['exc'] - model.exc.data)

    # Effort relapse (moderate)
    model.effort.data += fe * (ps['effort'] - model.effort.data)

    # Capacity relapse (slow). Masks are gates in [0, 1] (every treatment
    # clamps them); keep that invariant if a multiplier makes fc exceed 1.
    for i in range(model.n_hid):
        key = f'unit_masks.{i}'
        if key in ps:
            model.unit_masks[i].data += (
                fc * (ps[key] - model.unit_masks[i].data))
            model.unit_masks[i].data.clamp_(0, 1)

    # Aberrant weight relapse (fast)
    for nm, mask in mgr.ab_masks.items():
        if mask.sum() == 0: continue
        idx = int(nm.split('.')[1])
        w_d = model.layers[idx].weight.data
        pw  = ps[f'layers.{idx}.weight']
        w_d[mask] = w_d[mask] * (1-fp) + pw[mask] * fp

    # Pruning weight relapse (slow)
    for nm, mask in mgr.prune_masks.items():
        if mask.sum() == 0: continue
        idx = int(nm.split('.')[1])
        w_d = model.layers[idx].weight.data
        pw  = ps[f'layers.{idx}.weight']
        w_d[mask] = w_d[mask] * (1-fpr) + pw[mask] * fpr

    # Irreversible push toward zero (slow)
    for nm, mask in mgr.irrev_masks.items():
        if mask.sum() == 0: continue
        idx = int(nm.split('.')[1])
        model.layers[idx].weight.data[mask] *= (1 - fpr)


def homeostatic_step(model, strength=1.0, se_decay=None):
    """
    Slow parameter drift toward healthy + side-effect decay.

    Moves gains, noise_stds, exc and effort a fraction
    ``CFG['homeostatic_rate'] * strength`` of the way toward their healthy
    values (``CFG['healthy_gains']``, ``CFG['healthy_noise']``, 1.0 and
    ``CFG['effort_healthy']`` respectively), then scales every layer's
    side-effect tensor by ``1 - se_decay``.

    Args:
        model:    network modified in place.
        strength: multiplier on the homeostatic rate.
        se_decay: per-step side-effect decay fraction; ``None`` uses
                  ``CFG['se_decay_default']``.
    """
    hr = CFG['homeostatic_rate'] * strength
    # Build the healthy targets on the parameters' own device/dtype: a CPU
    # target vector cannot be combined with parameters that live on a GPU.
    gains = model.gains.data
    noise = model.noise_stds.data
    hg = torch.tensor(CFG['healthy_gains'],
                      dtype=gains.dtype, device=gains.device)
    hn = torch.tensor(CFG['healthy_noise'],
                      dtype=noise.dtype, device=noise.device)
    model.gains.data      += hr * (hg - model.gains.data)
    model.noise_stds.data += hr * (hn - model.noise_stds.data)
    model.exc.data        += hr * (1.0 - model.exc.data)
    model.effort.data     += hr * (CFG['effort_healthy'] - model.effort.data)
    sd = se_decay if se_decay is not None else CFG['se_decay_default']
    for i in range(model.n_hid):
        model.side_effects[i].data *= (1 - sd)


# ============================================================================
# CHRONIC SIMULATION
# ============================================================================
def simulate_chronic(psy_state, h_snap, treat_fn, param, treat_name):
    """
    Chronic treatment -> maintenance -> withdrawal.
    All effects are network-parameter modifications.

    Phases (each ends with a ``light_eval``):
      baseline   - fresh network restored to ``psy_state``;
      acute      - one ``treat_fn(model, mgr, param)`` course;
      chronic    - ``CFG['chronic_steps']`` steps of homeostasis plus one
                   training epoch, with a maintenance course (dose scaled by
                   ``CFG['maintenance_dose_frac']``) every
                   ``CFG['maintenance_interval']`` steps;
      withdrawal - ``CFG['withdrawal_steps']`` steps of drift-scaled relapse,
                   weakened homeostasis and training.

    Durability is the CFG base durability times a cumulative-dose scale
    (normalised by 6.0 for ECT, 0.5 otherwise, capped at 1.5), itself capped
    at 0.98; the withdrawal relapse drift shrinks as durability grows.

    Returns:
        dict with 'baseline', 'acute', 'chronic', 'withdrawal' evaluations,
        plus 'durability' and 'cumulative' dose.
    """
    model = PCNet().to(DEVICE)
    mgr = PsychMgr(model)
    mgr.load(psy_state)
    mgr.healthy = h_snap

    evals = {'baseline': light_eval(model, mgr)}

    treat_fn(model, mgr, param, verbose=False)
    evals['acute'] = light_eval(model, mgr)

    is_ect = treat_name == 'ect'
    if is_ect:
        # ECT maintenance is a whole number of sessions (at least one).
        maint_p = max(1, int(param * CFG['maintenance_dose_frac']))
    else:
        maint_p = param * CFG['maintenance_dose_frac']

    se_rate = CFG['se_decay_rate'].get(treat_name, CFG['se_decay_default'])
    cumulative = float(param) if isinstance(param, int) else param
    maint_increment = float(maint_p) if isinstance(maint_p, int) else maint_p

    for step in range(1, CFG['chronic_steps'] + 1):
        if step % CFG['maintenance_interval'] == 0:
            treat_fn(model, mgr, maint_p, verbose=False)
            cumulative += maint_increment
        homeostatic_step(model, se_decay=se_rate)
        train_net(model, 1, CFG['base_lr'] * 0.3)
    evals['chronic'] = light_eval(model, mgr)

    dose_norm = 6.0 if is_ect else 0.5
    dose_scale = min(1.5, cumulative / dose_norm)
    eff_dur = min(0.98, CFG['treatment_durability'][treat_name] * dose_scale)
    drift = CFG['relapse_drift_rate'] * (1 - eff_dur)

    for _ in range(CFG['withdrawal_steps']):
        do_relapse(model, mgr, psy_state, frac=drift)
        homeostatic_step(model, strength=0.3, se_decay=se_rate)
        train_net(model, 1, CFG['base_lr'] * 0.2)
    evals['withdrawal'] = light_eval(model, mgr)

    return {**evals, 'durability': eff_dur, 'cumulative': cumulative}


# ============================================================================
# ISO-DOSE SWEEP
# ============================================================================
def do_sweep(psy_st, h_snap, treat_fn, pname, pvals, unt_comb, verbose=True):
    """
    Run ``treat_fn`` once per value in ``pvals`` on a fresh psychotic network.

    For every value: restore ``psy_st``, treat, measure the parameter-change
    dose with ``compute_dose``, run ``full_eval``, then apply a default
    ``do_relapse`` and re-measure accuracy with ``acc(..., 1.0, 0.5)``.

    Returns:
        list of dicts with keys 'p' (swept value), 'dose', 'ev', 'imp'
        (combined score minus ``unt_comb``), 'rdrop' (combined score minus
        post-relapse accuracy), 'post_rel', and the post-relapse 'gains' /
        'noise' lists.
    """
    results = []
    for val in pvals:
        net = PCNet().to(DEVICE)
        pm = PsychMgr(net)
        pm.load(psy_st)
        pm.healthy = h_snap

        before = net.snap()
        treat_fn(net, pm, val, verbose=False)
        dose = compute_dose(before, net)
        ev = full_eval(net, mgr=pm)
        combined = ev['combined']
        improvement = combined - unt_comb

        # Relapse test: precision/gain lists below are read AFTER relapse.
        do_relapse(net, pm, psy_st)
        post_relapse = acc(net, test_ld, 1.0, 0.5)
        relapse_drop = combined - post_relapse

        results.append({
            'p': val, 'dose': dose, 'ev': ev, 'imp': improvement,
            'rdrop': relapse_drop, 'post_rel': post_relapse,
            'gains': net.gains.data.tolist(),
            'noise': net.noise_stds.data.tolist(),
        })

        if not verbose:
            continue
        print(f"      {pname}={val}  L1={dose.l1:.6f}  tL1={dose.t_l1:.6f}  "
              f"comb={ev['combined']:.1f}%  imp={improvement:+.1f}%  "
              f"cog={ev['cog']:.1f}%  hal={ev['halluc']:.3f}  "
              f"neg={ev['neg_score']:.3f}  eff={ev['effort']:.2f}  "
              f"cap={ev['capacity']:.3f}  SE={ev['side_effects']:.4f}  "
              f"rel={relapse_drop:.1f}%")
    return results

def iso_match(results, target, use_targeted=False):
    """
    Return the sweep entry whose dose lies closest to ``target``.

    Args:
        results:      sweep entries (as built by ``do_sweep``); each entry's
                      'dose' object exposes ``l1`` and ``t_l1``.
        target:       dose value to match.
        use_targeted: compare ``t_l1`` (targeted L1) instead of ``l1``.

    Returns:
        The first entry, in ``results`` order, minimising ``|dose - target|``.

    Raises:
        ValueError: if ``results`` is empty (propagated from ``min``).
    """
    field = 't_l1' if use_targeted else 'l1'

    def distance(entry):
        return abs(getattr(entry['dose'], field) - target)

    return min(results, key=distance)


# ============================================================================
# INDIVIDUAL VARIABILITY
# ============================================================================
def run_patient(patient_idx, h_snap):
    """
    Single patient with unique pathology mask and severity profile.

    Seeds the global torch and numpy RNGs with the patient's seed, rebuilds
    the healthy network from ``h_snap``, induces psychosis with the patient's
    (positive, negative) severities, then evaluates ketamine (0.7),
    antipsychotic (0.7) and ECT (8 sessions), each on its own fresh copy of
    the psychotic state.

    Returns:
        dict with 'psychotic' (eval), 'unt_comb' (untreated combined score),
        'profile' (pos_sev, neg_sev) and one {'ev', 'imp'} entry per
        treatment name.
    """
    seed = CFG['patient_seeds'][patient_idx]
    pos_sev, neg_sev = CFG['patient_profiles'][patient_idx]

    torch.manual_seed(seed)
    np.random.seed(seed)

    net = PCNet().to(DEVICE)
    net.load_snap(h_snap)
    pm = PsychMgr(net)
    pm.save_healthy()
    pm.induce(severity_pos=pos_sev, severity_neg=neg_sev)
    psychotic_ev = full_eval(net, mgr=pm)
    psy_state = pm.clone()
    baseline = psychotic_ev['combined']

    out = {
        'psychotic': psychotic_ev,
        'unt_comb': baseline,
        'profile': (pos_sev, neg_sev),
    }

    treatments = (
        ('ketamine', treat_ketamine, 0.7),
        ('antipsychotic', treat_antipsychotic, 0.7),
        ('ect', treat_ect, 8),
    )
    for name, treat, arg in treatments:
        treated = PCNet().to(DEVICE)
        tm = PsychMgr(treated)
        tm.load(psy_state)
        tm.healthy = h_snap
        treat(treated, tm, arg, verbose=False)
        ev = full_eval(treated, mgr=tm)
        out[name] = {'ev': ev, 'imp': ev['combined'] - baseline}
    return out


# ============================================================================
# MAIN EXPERIMENT
# ============================================================================
def run():
    print("=" * 80)
    print("  HIERARCHICAL PREDICTIVE CODING MODEL OF PSYCHOSIS")
    print("  Iso-Dose + Negative Symptoms + Chronic Dynamics")
    print("  + Cognitive Probe + 10-Patient Variability")
    print("  v2: Decoupled consolidation | Fixed iso-dose | Logit-margin halluc | Norm PE")
    print("=" * 80)

    # == PHASE 1: HEALTHY BASELINE ==========================================
    print("\n" + "-"*70)
    print("  PHASE 1 -- Healthy baseline")
    print("-"*70)
    model = PCNet().to(DEVICE)
    tp = sum(p.numel() for p in model.parameters())
    n_lp = sum(p.numel() for n, p in model.named_parameters()
               if n.startswith('layers.'))
    n_um = sum(model.unit_masks[i].numel() for i in range(model.n_hid))
    n_se = sum(model.side_effects[i].numel() for i in range(model.n_hid))
    print(f"  Net: {CFG['in_dim']}->{CFG['hidden_dims']}->{CFG['out_dim']}  "
          f"({tp:,} total params)")
    print(f"    Layer wt/bias: {n_lp:,}  |  "
          f"Precision: {2*model.n_hid+1}  |  "
          f"Effort: 1  |  Unit masks: {n_um}  |  Side effects: {n_se}")
    train_net(model, CFG['base_epochs'], CFG['base_lr'], verbose=True)

    # ── [FIX 3 v2] Compute logit-margin threshold from healthy model ──
    h_margin_mean, h_margin_std, h_margin_p75 = compute_halluc_threshold(model)
    CFG['halluc_logit_thr'] = h_margin_p75
    print(f"\n  Halluc threshold (logit margin on noise):")
    print(f"    mean={h_margin_mean:.3f}  std={h_margin_std:.3f}  "
          f"thr(p75)={h_margin_p75:.3f}")
    print(f"    (healthy model will show ~25% halluc rate by calibration)")

    h_ev   = full_eval(model, "HEALTHY BASELINE")
    h_snap = model.snap()

    # == PHASE 2: INDUCE PSYCHOSIS ==========================================
    print("\n" + "-"*70)
    print("  PHASE 2 -- Induce psychosis (positive + negative symptoms)")
    print("-"*70)
    print("""
  POSITIVE (precision imbalance + aberrant weights):
    Gains:  1.00->0.30 / 1.00->0.60 / 1.00->2.50   [nn.Parameter]
    Noise:  0.00->1.20 / 0.00->0.60 / 0.00->0.05   [nn.Parameter]
    Exc:    1.00->1.50   + aberrant weights           [nn.Parameter + layers]

  NEGATIVE (pruning + capacity + effort):
    ~25% weights pruned (higher-layer bias)            [layer weights]
    Cascading unit degradation                         [unit_masks nn.ParameterList]
    Hard gating: units below floor contribute ZERO
    Capacity-dependent excitability dampening
    Effort/motivation: higher layers more impaired     [effort nn.Parameter]
    Cognitive probe: sensitive to capacity loss
    """)
    mgr = PsychMgr(model); mgr.save_healthy()
    ps  = mgr.induce(severity_pos=1.0, severity_neg=1.0)
    path = mgr.pathology()
    print(f"  Aberrant connections:   {ps['aberrant']:,}")
    print(f"  Pruned (reversible):    {ps['pruned_reversible']:,}")
    print(f"  Pruned (irreversible):  {ps['pruned_irreversible']:,}")
    print(f"  Network capacity:       {path['capacity']:.3f}")
    print(f"  Effort:                 {ps['effort']:.3f}")
    print(f"  Pathology -- pos: {path['positive_composite']:.3f}  "
          f"neg: {path['negative_composite']:.3f}")
    p_ev = full_eval(model, "PSYCHOTIC STATE", mgr=mgr)
    psy_state = mgr.clone()
    unt_comb  = p_ev['combined']

    # == PHASE 3: DEFAULT-DOSE TREATMENTS ===================================
    print("\n" + "="*80)
    print("  PHASE 3 -- Treatments (all network-parameter modifications)")
    print(f"  Consolidation epochs: ket={CFG['consolidation_epochs']['ketamine']}  "
          f"ap={CFG['consolidation_epochs']['antipsychotic']}  "
          f"ect={CFG['consolidation_epochs']['ect_per_session']}/session")
    print("="*80)
    tx = {}
    specs = [
        ('KETAMINE',       treat_ketamine,      0.7),
        ('ANTIPSYCHOTIC',  treat_antipsychotic,  0.7),
        ('ECT',            treat_ect,            8),
    ]
    for label, fn, arg in specs:
        print(f"\n  -- {label} --")
        m2 = PCNet().to(DEVICE); mg2 = PsychMgr(m2)
        mg2.load(psy_state); mg2.healthy = h_snap
        pre = m2.snap()
        fn(m2, mg2, arg)
        d  = compute_dose(pre, m2)
        ev = full_eval(m2, f"  POST-{label}", mgr=mg2)
        imp = ev['combined'] - unt_comb
        pr  = ev['combined']
        do_relapse(m2, mg2, psy_state)
        postr = acc(m2, test_ld, 1.0, 0.5)
        key = label.lower()
        tx[key] = {'ev': ev, 'dose': d, 'imp': imp,
                   'rdrop': pr - postr,
                   'gains': m2.gains.data.tolist(),
                   'noise': m2.noise_stds.data.tolist(),
                   'exc': m2.exc.item(),
                   'effort': m2.effort.item()}

    # == PHASE 3b: ZERO-CONSOLIDATION ABLATION ==============================
    print("\n" + "-"*70)
    print("  PHASE 3b -- Zero-consolidation ablation (raw manipulation effects)")
    print("  [FIX 1] Reveals pure pharmacological effects without training confound")
    print("-"*70)

    old_zero = CFG['run_zero_consolidation']
    CFG['run_zero_consolidation'] = True
    tx_zero = {}
    for label, fn, arg in specs:
        m2 = PCNet().to(DEVICE); mg2 = PsychMgr(m2)
        mg2.load(psy_state); mg2.healthy = h_snap
        pre = m2.snap()
        fn(m2, mg2, arg, verbose=False)
        d  = compute_dose(pre, m2)
        ev = full_eval(m2, mgr=mg2)
        key = label.lower()
        tx_zero[key] = {'ev': ev, 'dose': d, 'imp': ev['combined'] - unt_comb}
    CFG['run_zero_consolidation'] = old_zero

    print(f"\n  {'Treatment':<16} {'With Consol':>14} {'Zero Consol':>14} "
          f"{'Delta':>8} {'Zero L1':>10} {'Full L1':>10}")
    print("  " + "-"*72)
    for k in ['ketamine', 'antipsychotic', 'ect']:
        wc = tx[k]['ev']['combined']
        zc = tx_zero[k]['ev']['combined']
        print(f"  {k.capitalize():<16} {wc:>13.1f}% {zc:>13.1f}% "
              f"{wc-zc:>+7.1f}% {tx_zero[k]['dose'].l1:>10.6f} "
              f"{tx[k]['dose'].l1:>10.6f}")

    print(f"\n  {'Treatment':<16} {'Halluc(w)':>10} {'Halluc(0)':>10} "
          f"{'Neg(w)':>8} {'Neg(0)':>8} {'Cog(w)':>8} {'Cog(0)':>8}")
    print("  " + "-"*66)
    for k in ['ketamine', 'antipsychotic', 'ect']:
        we = tx[k]['ev']
        ze = tx_zero[k]['ev']
        print(f"  {k.capitalize():<16} {we['halluc']:>10.3f} {ze['halluc']:>10.3f} "
              f"{we['neg_score']:>8.3f} {ze['neg_score']:>8.3f} "
              f"{we['cog']:>7.1f}% {ze['cog']:>7.1f}%")

    # -- Summary table --
    print("\n" + "="*80)
    print("  SUMMARY TABLE")
    print("="*80)
    hdr = (f"  {'State':<18} {'Cln':>6} {'Cmb':>6} {'Ext':>6} "
           f"{'Cog':>6} {'Hal':>6} {'Neg':>5} {'Cap':>5} "
           f"{'Eff':>5} {'SE':>5} {'L1':>9} {'tL1':>9}")
    print(f"\n{hdr}\n  " + "-"*100)
    rows = [
        ('Healthy',     h_ev, None),
        ('Psychotic',   p_ev, None),
        ('+ Ketamine',  tx['ketamine']['ev'],      tx['ketamine']['dose']),
        ('+ Antipsych', tx['antipsychotic']['ev'],  tx['antipsychotic']['dose']),
        ('+ ECT',       tx['ect']['ev'],            tx['ect']['dose']),
    ]
    for nm, ev, d in rows:
        ds  = f"{d.l1:.6f}" if d else "       --"
        tds = f"{d.t_l1:.6f}" if d else "       --"
        print(f"  {nm:<18} {ev['clean']:>5.1f}% {ev['combined']:>5.1f}% "
              f"{ev['s_extreme']:>5.1f}% {ev['cog']:>5.1f}% "
              f"{ev['halluc']:>6.3f} {ev['neg_score']:>5.3f} "
              f"{ev['capacity']:>5.3f} {ev['effort']:>5.2f} "
              f"{ev['side_effects']:>5.3f} {ds:>9} {tds:>9}")

    print(f"\n  Improvement from psychotic baseline:")
    for k in ['ketamine', 'antipsychotic', 'ect']:
        r = tx[k]
        print(f"    {k.capitalize():<16} comb={r['imp']:+6.1f}%  "
              f"cog={r['ev']['cog']:.1f}%  "
              f"neg={r['ev']['neg_score']:.3f}  "
              f"eff={r['ev']['effort']:.2f}  "
              f"cap={r['ev']['capacity']:.3f}  "
              f"SE={r['ev']['side_effects']:.4f}  "
              f"relapse={r['rdrop']:.1f}%")

    # == PHASE 4: PRECISION TRAJECTORIES ====================================
    print("\n" + "-"*70)
    print("  PRECISION + EFFORT TRAJECTORIES (nn.Parameter values)")
    print("-"*70)
    lnames = ['Sensory ', 'Percept ', 'Belief  ']
    print(f"\n  {'Layer':<9} {'Healthy':>14} {'Psychotic':>14} "
          f"{'+ Ket':>14} {'+AP':>14} {'+ ECT':>14}")
    print("  " + "-"*75)
    for i, ln in enumerate(lnames):
        gvals = [CFG['healthy_gains'][i], CFG['psychosis_gains'][i],
                 tx['ketamine']['gains'][i], tx['antipsychotic']['gains'][i],
                 tx['ect']['gains'][i]]
        nvals = [CFG['healthy_noise'][i], CFG['psychosis_noise'][i],
                 tx['ketamine']['noise'][i], tx['antipsychotic']['noise'][i],
                 tx['ect']['noise'][i]]
        parts = [f"g={gvals[j]:.2f} n={nvals[j]:.2f}" for j in range(5)]
        print(f"  {ln} {'  '.join(parts)}")

    evals = [CFG['effort_healthy'], p_ev['effort'],
             tx['ketamine']['effort'], tx['antipsychotic']['effort'],
             tx['ect']['effort']]
    print(f"  {'Effort':<9} " + "  ".join(
        f"{'e='+f'{e:.2f}':>14}" for e in evals))

    # == PHASE 5: ISO-DOSE SWEEPS ==========================================
    print("\n" + "="*80)
    print("  PHASE 5 -- Iso-dose parameter sweeps")
    print("  L1 = all param changes | tL1 = targeted (precision+masks+effort only)")
    print("="*80)

    print("\n  [KETAMINE] sweeping NMDA blockade intensity ...")
    ks = do_sweep(psy_state, h_snap, treat_ketamine,
                  'int', CFG['ket_sweep'], unt_comb)
    print("\n  [ANTIPSYCHOTIC] sweeping D2 blockade dose ...")
    aps = do_sweep(psy_state, h_snap, treat_antipsychotic,
                   'dose', CFG['ap_sweep'], unt_comb)
    print("\n  [ECT] sweeping session count ...")
    es = do_sweep(psy_state, h_snap, treat_ect,
                  'sess', CFG['ect_sweep'], unt_comb)

    # == PHASE 6: ISO-DOSE MATCHING ========================================
    print("\n" + "="*80)
    print("  PHASE 6 -- Iso-dose matched comparisons")
    print("  [FIX 2] Using both full L1 and targeted L1 matching")
    print("="*80)

    kd = [r['dose'].l1 for r in ks]
    ad = [r['dose'].l1 for r in aps]
    ed = [r['dose'].l1 for r in es]
    lo = max(min(kd), min(ad), min(ed))
    hi = min(max(kd), max(ad), max(ed))
    print(f"\n  Full L1 dose ranges:")
    print(f"    Ketamine      {min(kd):.6f} -- {max(kd):.6f}")
    print(f"    Antipsychotic {min(ad):.6f} -- {max(ad):.6f}")
    print(f"    ECT           {min(ed):.6f} -- {max(ed):.6f}")
    print(f"    Overlap       {lo:.6f} -- {hi:.6f}")

    kd_t = [r['dose'].t_l1 for r in ks]
    ad_t = [r['dose'].t_l1 for r in aps]
    ed_t = [r['dose'].t_l1 for r in es]
    lo_t = max(min(kd_t), min(ad_t), min(ed_t))
    hi_t = min(max(kd_t), max(ad_t), max(ed_t))
    print(f"\n  Targeted L1 dose ranges (precision+masks+effort):")
    print(f"    Ketamine      {min(kd_t):.6f} -- {max(kd_t):.6f}")
    print(f"    Antipsychotic {min(ad_t):.6f} -- {max(ad_t):.6f}")
    print(f"    ECT           {min(ed_t):.6f} -- {max(ed_t):.6f}")
    print(f"    Overlap       {lo_t:.6f} -- {hi_t:.6f}")

    tgts = (np.linspace(lo, hi, 5).tolist() if hi > lo else
            [float(np.percentile(kd+ad+ed, p)) for p in [20, 40, 60, 80]])
    print(f"\n  Full L1 Targets: {[f'{t:.6f}' for t in tgts]}\n")

    for ti, tgt in enumerate(tgts):
        km = iso_match(ks, tgt)
        am = iso_match(aps, tgt)
        em = iso_match(es, tgt)
        print(f"  ISO-DOSE(L1) {ti+1}  target ~ {tgt:.6f}")
        print(f"  {'Tx':<14} {'P':>5} {'L1':>10} {'tL1':>10} {'Cmb':>7} {'Imp':>7} "
              f"{'Cog':>6} {'Hal':>6} {'Neg':>5} {'Cap':>5} "
              f"{'Eff':>4} {'SE':>5} {'Rel':>6}")
        print("  " + "-"*100)
        bc = ('', -999); br = ('', 999)
        bh = ('', 999);  bn = ('', 999)
        for lbl, r in [('Ketamine', km),
                        ('Antipsychotic', am),
                        ('ECT', em)]:
            ns = r['ev']['neg_score']
            print(f"  {lbl:<14} {str(r['p']):>5} {r['dose'].l1:>10.6f} "
                  f"{r['dose'].t_l1:>10.6f} "
                  f"{r['ev']['combined']:>6.1f}% {r['imp']:>+6.1f}% "
                  f"{r['ev']['cog']:>5.1f}% "
                  f"{r['ev']['halluc']:>6.3f} {ns:>5.3f} "
                  f"{r['ev']['capacity']:>5.3f} "
                  f"{r['ev']['effort']:>4.2f} "
                  f"{r['ev']['side_effects']:>5.3f} "
                  f"{r['rdrop']:>5.1f}%")
            if r['ev']['combined'] > bc[1]: bc = (lbl, r['ev']['combined'])
            if r['rdrop'] < br[1]:          br = (lbl, r['rdrop'])
            if r['ev']['halluc'] < bh[1]:   bh = (lbl, r['ev']['halluc'])
            if ns < bn[1]:                  bn = (lbl, ns)
        print(f"    -> Best: comb={bc[0]}  relapse={br[0]}  "
              f"halluc(low)={bh[0]}  neg={bn[0]}\n")

    tgts_t = (np.linspace(lo_t, hi_t, 5).tolist() if hi_t > lo_t else
              [float(np.percentile(kd_t+ad_t+ed_t, p))
               for p in [20, 40, 60, 80]])
    print(f"\n  Targeted L1 Targets: {[f'{t:.6f}' for t in tgts_t]}\n")

    for ti, tgt in enumerate(tgts_t):
        km = iso_match(ks, tgt, use_targeted=True)
        am = iso_match(aps, tgt, use_targeted=True)
        em = iso_match(es, tgt, use_targeted=True)
        print(f"  ISO-DOSE(tL1) {ti+1}  target ~ {tgt:.6f}")
        print(f"  {'Tx':<14} {'P':>5} {'tL1':>10} {'L1':>10} {'Cmb':>7} {'Imp':>7} "
              f"{'Cog':>6} {'Hal':>6} {'Neg':>5} {'Cap':>5} "
              f"{'Eff':>4} {'SE':>5} {'Rel':>6}")
        print("  " + "-"*100)
        bc = ('', -999); br = ('', 999)
        bh = ('', 999);  bn = ('', 999)
        for lbl, r in [('Ketamine', km),
                        ('Antipsychotic', am),
                        ('ECT', em)]:
            ns = r['ev']['neg_score']
            print(f"  {lbl:<14} {str(r['p']):>5} {r['dose'].t_l1:>10.6f} "
                  f"{r['dose'].l1:>10.6f} "
                  f"{r['ev']['combined']:>6.1f}% {r['imp']:>+6.1f}% "
                  f"{r['ev']['cog']:>5.1f}% "
                  f"{r['ev']['halluc']:>6.3f} {ns:>5.3f} "
                  f"{r['ev']['capacity']:>5.3f} "
                  f"{r['ev']['effort']:>4.2f} "
                  f"{r['ev']['side_effects']:>5.3f} "
                  f"{r['rdrop']:>5.1f}%")
            if r['ev']['combined'] > bc[1]: bc = (lbl, r['ev']['combined'])
            if r['rdrop'] < br[1]:          br = (lbl, r['rdrop'])
            if r['ev']['halluc'] < bh[1]:   bh = (lbl, r['ev']['halluc'])
            if ns < bn[1]:                  bn = (lbl, ns)
        print(f"    -> Best: comb={bc[0]}  relapse={br[0]}  "
              f"halluc(low)={bh[0]}  neg={bn[0]}\n")

    # == PHASE 7: EFFICIENCY ================================================
    print("="*80)
    print("  PHASE 7 -- Treatment efficiency = |improvement| / L1 dose")
    print("="*80)
    for label, sw in [('KETAMINE', ks),
                      ('ANTIPSYCHOTIC', aps),
                      ('ECT', es)]:
        print(f"\n  {label}:")
        print(f"  {'P':>5} {'L1':>10} {'tL1':>10} {'Imp':>8} {'Eff':>9} "
              f"{'tEff':>9} {'Cog':>6} {'Hal':>6} {'Neg':>5} {'Cap':>5} "
              f"{'Efrt':>5} {'SE':>5}")
        print("  " + "-"*90)
        for r in sw:
            ef  = abs(r['imp']) / (r['dose'].l1 + 1e-8)
            tef = abs(r['imp']) / (r['dose'].t_l1 + 1e-8)
            print(f"  {str(r['p']):>5} {r['dose'].l1:>10.6f} "
                  f"{r['dose'].t_l1:>10.6f} "
                  f"{r['imp']:>+7.1f}% {ef:>9.1f} {tef:>9.1f} "
                  f"{r['ev']['cog']:>5.1f}% "
                  f"{r['ev']['halluc']:>6.3f} "
                  f"{r['ev']['neg_score']:>5.3f} "
                  f"{r['ev']['capacity']:>5.3f} "
                  f"{r['ev']['effort']:>5.2f} "
                  f"{r['ev']['side_effects']:>5.3f}")

    # == PHASE 8: CHRONIC + WITHDRAWAL =====================================
    print("\n" + "="*80)
    print("  PHASE 8 -- Chronic treatment + withdrawal dynamics")
    print("  Relapse: precision=fast, structural=slow, effort=moderate")
    print("  SE decay: ketamine=fast, ECT=moderate, antipsychotic=slow")
    print("="*80)

    chronic = {}
    ch_specs = [('ketamine', treat_ketamine, 0.7),
                ('antipsychotic', treat_antipsychotic, 0.7),
                ('ect', treat_ect, 8)]
    for name, fn, param in ch_specs:
        print(f"\n  Simulating {name.upper()} "
              f"(chronic={CFG['chronic_steps']}, "
              f"withdrawal={CFG['withdrawal_steps']}) ...")
        chronic[name] = simulate_chronic(
            psy_state, h_snap, fn, param, name)

    print(f"\n  {'Treatment':<16} {'Phase':<12} {'Cln':>6} {'Cmb':>6} "
          f"{'Cog':>6} {'Hal':>6} {'Neg':>5} {'Cap':>5} "
          f"{'Eff':>5} {'SE':>5}")
    print("  " + "-"*78)
    for name in ['ketamine', 'antipsychotic', 'ect']:
        cr = chronic[name]
        for phase in ['baseline', 'acute', 'chronic', 'withdrawal']:
            ev = cr[phase]
            print(f"  {name.capitalize():<16} {phase:<12} "
                  f"{ev['clean']:>5.1f}% {ev['combined']:>5.1f}% "
                  f"{ev['cog']:>5.1f}% "
                  f"{ev['halluc']:>6.3f} {ev['neg_score']:>5.3f} "
                  f"{ev['capacity']:>5.3f} {ev['effort']:>5.2f} "
                  f"{ev['side_effects']:>5.3f}")
        print()

    print(f"\n  DURABILITY COMPARISON:")
    print(f"  {'Treatment':<16} {'Acute D':>8} {'Final D':>8} "
          f"{'Retained':>9} {'Durability':>10} {'CumulDose':>10}")
    print("  " + "-"*64)
    for name in ['ketamine', 'antipsychotic', 'ect']:
        cr = chronic[name]
        bc = cr['baseline']['combined']
        ad = cr['acute']['combined'] - bc
        fd = cr['withdrawal']['combined'] - bc
        if abs(ad) > 0.5:
            ret = f"{fd/ad*100:.0f}%"
        else:
            ret = "N/A"
        print(f"  {name.capitalize():<16} {ad:>+7.1f}% {fd:>+7.1f}% "
              f"{ret:>9} {cr['durability']:>10.2f} "
              f"{cr['cumulative']:>10.1f}")

    # == PHASE 9: 10-PATIENT VARIABILITY ===================================
    print("\n" + "="*80)
    print("  PHASE 9 -- Individual variability (10 patients)")
    print("  Each patient: unique seed + (pos_sev, neg_sev) profile")
    print("="*80)

    print(f"\n  {'#':>3} {'Seed':>5} {'Pos':>4} {'Neg':>4} "
          f"{'PsyCmb':>7} {'PsyCog':>7}   "
          f"{'Ket':>7} {'AP':>7} {'ECT':>7}")
    print("  " + "-"*66)

    pat_results = []
    for pi in range(len(CFG['patient_seeds'])):
        pr = run_patient(pi, h_snap)
        pat_results.append(pr)
        pos_s, neg_s = pr['profile']
        print(f"  {pi+1:>3} {CFG['patient_seeds'][pi]:>5} "
              f"{pos_s:>4.1f} {neg_s:>4.1f} "
              f"{pr['psychotic']['combined']:>6.1f}% "
              f"{pr['psychotic']['cog']:>6.1f}%   "
              f"{pr['ketamine']['imp']:>+6.1f}% "
              f"{pr['antipsychotic']['imp']:>+6.1f}% "
              f"{pr['ect']['imp']:>+6.1f}%")

    print(f"\n  AGGREGATE (mean +/- std across "
          f"{len(pat_results)} patients):")
    print(f"  {'Treatment':<16} {'Comb Imp':>14} {'Cog Acc':>14} "
          f"{'Neg Score':>14} {'Effort':>14} {'Capacity':>14}")
    print("  " + "-"*86)
    for name in ['ketamine', 'antipsychotic', 'ect']:
        imps = [pr[name]['imp'] for pr in pat_results]
        cogs = [pr[name]['ev']['cog'] for pr in pat_results]
        negs = [pr[name]['ev']['neg_score'] for pr in pat_results]
        effs = [pr[name]['ev']['effort'] for pr in pat_results]
        caps = [pr[name]['ev']['capacity'] for pr in pat_results]
        print(f"  {name.capitalize():<16} "
              f"{np.mean(imps):>+5.1f}+/-{np.std(imps):>4.1f}%  "
              f"{np.mean(cogs):>5.1f}+/-{np.std(cogs):>4.1f}%  "
              f"{np.mean(negs):>5.3f}+/-{np.std(negs):>4.3f}  "
              f"{np.mean(effs):>5.2f}+/-{np.std(effs):>4.2f}  "
              f"{np.mean(caps):>5.3f}+/-{np.std(caps):>4.3f}")

    # Profile-stratified analysis
    print(f"\n  PROFILE-STRATIFIED RESPONSE:")
    print(f"  {'Type':<18} {'N':>3} {'Best Tx':>14} "
          f"{'Worst Tx':>14}")
    print("  " + "-"*54)
    pos_dom = [p for p in pat_results
               if p['profile'][0] > p['profile'][1]]
    neg_dom = [p for p in pat_results
               if p['profile'][1] > p['profile'][0]]
    balanced = [p for p in pat_results
                if abs(p['profile'][0] - p['profile'][1]) <= 0.1]

    from collections import Counter
    for lbl, grp in [('Pos-dominant', pos_dom),
                     ('Neg-dominant', neg_dom),
                     ('Balanced', balanced)]:
        if not grp:
            continue
        best_txs = []
        worst_txs = []
        for pr in grp:
            imps = {n: pr[n]['imp'] for n in
                    ['ketamine', 'antipsychotic', 'ect']}
            best_txs.append(max(imps, key=imps.get))
            worst_txs.append(min(imps, key=imps.get))
        bc = Counter(best_txs).most_common(1)[0]
        wc = Counter(worst_txs).most_common(1)[0]
        print(f"  {lbl:<18} {len(grp):>3} "
              f"{bc[0]:>10}({bc[1]}) "
              f"{wc[0]:>10}({wc[1]})")

    # == FIX SUMMARY ========================================================
    print("\n" + "="*80)
    print("  v2 FIX SUMMARY")
    print("="*80)
    print(f"\n  [FIX 1] Decoupled consolidation")
    print(f"    Unified epochs: ket={CFG['consolidation_epochs']['ketamine']}  "
          f"ap={CFG['consolidation_epochs']['antipsychotic']} (was 30)  "
          f"ect={CFG['consolidation_epochs']['ect_per_session']}/sess")
    print(f"    Zero-consol ablation: AP delta = "
          f"{tx['antipsychotic']['ev']['combined'] - tx_zero['antipsychotic']['ev']['combined']:+.1f}%  "
          f"(consolidation contribution)")

    print(f"\n  [FIX 2] Fixed iso-dose overlap")
    print(f"    ECT seizure noise: 0.15 -> {CFG['ect_seizure_noise']}")
    print(f"    Targeted L1 overlap: {lo_t:.6f} -- {hi_t:.6f}")
    print(f"    Full L1 overlap:     {lo:.6f} -- {hi:.6f}")

    print(f"\n  [FIX 3 v2] Logit-margin hallucination metric")
    print(f"    Old: softmax thr = mean+sigma → exceeded 1.0 → always 0")
    print(f"    New: logit margin thr(p75) = {h_margin_p75:.3f}")
    print(f"    Halluc rates: healthy={h_ev['halluc']:.3f}  "
          f"psychotic={p_ev['halluc']:.3f}  "
          f"ket={tx['ketamine']['ev']['halluc']:.3f}  "
          f"ap={tx['antipsychotic']['ev']['halluc']:.3f}  "
          f"ect={tx['ect']['ev']['halluc']:.3f}")
    psychotic_higher = p_ev['halluc'] > h_ev['halluc']
    print(f"    Psychotic > healthy: {psychotic_higher} "
          f"({'CORRECT' if psychotic_higher else 'INVESTIGATE — see note'})")
    if not psychotic_higher:
        print(f"    NOTE: If psychotic <= healthy, the stochastic noise injection")
        print(f"    at sensory layers may be REDUCING logit margins faster than")
        print(f"    the belief-layer amplification increases them. This is a")
        print(f"    modelling limitation of the feedforward architecture (no")
        print(f"    actual top-down override). Consider increasing belief gain")
        print(f"    or reducing sensory noise in the psychosis profile.")

    print(f"\n  [FIX 4] Normalized PE metric")
    print(f"    PE total: healthy={h_ev['pe']['total']:.3f}  "
          f"psychotic={p_ev['pe']['total']:.3f}  "
          f"ket={tx['ketamine']['ev']['pe']['total']:.3f}  "
          f"ap={tx['antipsychotic']['ev']['pe']['total']:.3f}  "
          f"ect={tx['ect']['ev']['pe']['total']:.3f}")

    return {'healthy': h_ev, 'psychotic': p_ev, 'treatments': tx,
            'treatments_zero': tx_zero,
            'sweeps': {'ket': ks, 'ap': aps, 'ect': es},
            'chronic': chronic, 'patients': pat_results,
            'halluc_threshold': h_margin_p75}


# ============================================================================
if __name__ == "__main__":
    # Run every experiment phase. The returned results dict stays bound at
    # module level, so it is available afterwards (e.g. in an interactive
    # session).
    results = run()
    rule = "=" * 80
    # Closing banner; the final rule carries an extra trailing blank line.
    for banner_line in (rule, "  EXPERIMENT COMPLETE", rule + "\n"):
        print(banner_line)

  HIERARCHICAL PREDICTIVE CODING MODEL OF PSYCHOSIS
  Iso-Dose + Negative Symptoms + Chronic Dynamics
  + Cognitive Probe + 10-Patient Variability
  v2: Decoupled consolidation | Fixed iso-dose | Logit-margin halluc | Norm PE

----------------------------------------------------------------------
  PHASE 1 -- Healthy baseline
----------------------------------------------------------------------
  Net: 2->[256, 256, 128]->4  (101,260 total params)
    Layer wt/bias: 99,972  |  Precision: 7  |  Effort: 1  |  Unit masks: 640  |  Side effects: 640
    Ep 5/25  loss=0.0001
    Ep 10/25  loss=0.0000
    Ep 15/25  loss=0.0000
    Ep 20/25  loss=0.0000
    Ep 25/25  loss=0.0000

  Halluc threshold (logit margin on noise):
    mean=16.188  std=11.492  thr(p75)=23.130
    (healthy model will show ~25% halluc rate by calibration)

  HEALTHY BASELINE
    Acc  clean=100.0%  std=100.0%  comb=98.5%  extreme=100.0%
    Cog  probe=79.7%  loaded=80.0%
    Pos  halluc=0.248  flex=1.0314  PE=4.321
    Ne

# The end