# ALS Plasticity

### v2

In [None]:
"""
ALS Progression Model Pipeline v2.0
Microglial Pruning Continuum Framework — Amended

All symptoms and therapeutic changes are derived from network architecture state.

AMENDMENTS (v2.0):
    1. Sub-layer granularity: output split into bulbar/respiratory/fine_motor
       sub-modules for symptom-specific circuit mapping
    2. Aggregate toxicity as direct weight corruption (TDP-43 spread model):
       pruned connections perturb surviving neighbors in weight space
    3. Gradient-guided selective regrowth during treatment (BDNF-targeting):
       high-gradient pruned connections prioritized for regrowth
    4. Excitotoxicity derived from architectural instability:
       noise = f(sparsity) — high sparsity yields unstable / noisier circuits
    5. Stochastic flares, onset subtypes, cycle-to-month calibration,
       seed-based patient heterogeneity
    6. Riluzole comparator treatment arm (glutamate-modulation only)
    7. Weight instability metric (std dev of alive weights) → NFL biomarker proxy

Architecture-to-Symptom Mapping (v2.0):
    prefrontal_limbic sparsity             ->  depression_score (PHQ-like 0-48)
    upper_motor sparsity                   ->  spasticity_score (0-4)
    lower_motor sparsity                   ->  weakness_score (0-4)
    neuromuscular_junction sparsity        ->  fatigue_score (0-10)
    bulbar sub-layer sparsity              ->  bulbar_score (0-4) [speech/swallow]
    respiratory sub-layer sparsity + NMJ   ->  respiratory_risk_pct (0-100)
    fine_motor sub-layer sparsity          ->  fine_motor_score (0-4)
    overall accuracy                       ->  alsfrs_r_proxy (0-48)
    weight instability (std alive weights) ->  biomarker_nfl_proxy
    total sparsity                         ->  overall_disability_pct (0-100)
    PL sparsity + aggregate + instability  ->  cognitive_score (0-30)
    alive motor connections ratio          ->  motor_integrity_pct (0-100)
    motor-layer noise magnitude            ->  excitotoxicity_load

Architecture-to-Treatment Mapping (v2.0):
    Ketamine:
        NMDA blockade              ->  noise reduction (on sparsity-derived noise)
        Synaptogenesis (BDNF)      ->  gradient-guided mask regrowth
        Anti-inflammatory          ->  reduced per-cycle pruning rate
        Plasticity window          ->  optimizer LR multiplier
        Decay model                ->  half-life exponential on all effects
    Riluzole:
        Glutamate modulation       ->  noise damping in motor layers
        Mild neuroprotection       ->  slight pruning reduction
        No synaptogenesis          ->  no regrowth boost
        No plasticity boost        ->  no LR change
        Steady-state PK            ->  short half-life, continuous dosing
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import csv
import json
import copy
from collections import defaultdict


# ============================================================
# SECTION 1: NETWORK ARCHITECTURE
# ============================================================

class MotorCircuitNetwork(nn.Module):
    """
    Feed-forward motor circuit whose layers stand in for biological
    compartments.

    Trunk:  prefrontal_limbic (input_dim->512) -> upper_motor (512->512)
            -> lower_motor (512->256) -> neuromuscular_junction (256->128)
    Branch: the NMJ activation feeds three parallel 128->64 sub-circuits
            (bulbar, respiratory, fine_motor); their activations are
            concatenated (192 features) and read out by final_output
            (192->output_dim).

    If ``noise_profile`` maps a layer name to a value, Gaussian noise with
    that scale is added to the layer's pre-activation before ReLU
    (excitotoxicity model).  final_output never receives noise.  Pruning
    the sub-circuits differentially is what allows bulbar-onset vs
    limb-onset phenotypes.
    """

    LAYER_NAMES = [
        'prefrontal_limbic',
        'upper_motor',
        'lower_motor',
        'neuromuscular_junction',
        'bulbar',
        'respiratory',
        'fine_motor',
        'final_output'
    ]

    # Sequential trunk layers, then the parallel sub-circuits fed by the NMJ.
    _TRUNK = ('prefrontal_limbic', 'upper_motor', 'lower_motor',
              'neuromuscular_junction')
    _BRANCHES = ('bulbar', 'respiratory', 'fine_motor')

    def __init__(self, input_dim=2, output_dim=4):
        super().__init__()
        # Construction order fixes both initialisation RNG draws and the
        # parameter registration order, so it is kept deliberately.
        self.prefrontal_limbic = nn.Linear(input_dim, 512)
        self.upper_motor = nn.Linear(512, 512)
        self.lower_motor = nn.Linear(512, 256)
        self.neuromuscular_junction = nn.Linear(256, 128)
        self.bulbar = nn.Linear(128, 64)        # speech / swallow circuits
        self.respiratory = nn.Linear(128, 64)   # diaphragm / breathing
        self.fine_motor = nn.Linear(128, 64)    # hand / dexterity
        self.final_output = nn.Linear(192, output_dim)  # 3 branches x 64
        self.relu = nn.ReLU()

    def _activate(self, name, pre, noise_profile):
        """Add the layer's configured noise (if any) to ``pre``, then ReLU."""
        if name in noise_profile:
            pre = pre + torch.randn_like(pre) * noise_profile[name]
        return self.relu(pre)

    def forward(self, x, noise_profile=None):
        """
        Run the circuit.

        Args:
            x: input batch; the branch concatenation uses dim=1, so a
               batched 2-D input is expected.
            noise_profile: optional {layer_name: noise scale}.

        Returns:
            final_output applied to the concatenated branch activations.
        """
        profile = {} if noise_profile is None else noise_profile

        hidden = x
        for name in self._TRUNK:
            hidden = self._activate(name, getattr(self, name)(hidden), profile)

        # Noise RNG draws happen in bulbar -> respiratory -> fine_motor order.
        branch_outputs = [
            self._activate(name, getattr(self, name)(hidden), profile)
            for name in self._BRANCHES
        ]
        return self.final_output(torch.cat(branch_outputs, dim=1))

    def get_layer_params(self):
        """Return {layer_name: (weight, bias)} in LAYER_NAMES order."""
        params = {}
        for name in self.LAYER_NAMES:
            layer = getattr(self, name)
            params[name] = (layer.weight, layer.bias)
        return params


# ============================================================
# SECTION 2: TOXIC AGGREGATE TRACKER
# ============================================================

class ToxicAggregateTracker:
    """
    Per-layer protein-aggregate load (TDP-43 / C9orf72 spread model).

    Load accumulates in proportion to the fraction of connections newly
    pruned in a layer and decays geometrically through clearance
    (autophagy).  Load above ``toxicity_threshold`` accelerates pruning
    (get_toxicity_multiplier), any load suppresses regrowth
    (get_regrowth_penalty), and corrupt_surviving_weights() perturbs alive
    weights next to freshly pruned ones, so the toxicity is expressed
    directly in the network's weights.  Weight instability (std dev of
    alive weights) can then be read out as an NFL biomarker proxy.

    Args:
        layer_names: names of the layers to track.
        clearance_rate: default fraction of each layer's load removed per
            attempt_clearance() call.
        accumulation_rate: load added per unit of newly pruned fraction.
        toxicity_threshold: load above which pruning is accelerated.
        corruption_spread_radius: half-width, in flattened (row-major)
            index space, of the neighbourhood corrupted around each pruned
            weight.  The neighbourhood may straddle row boundaries.
        corruption_magnitude: base std dev of the Gaussian perturbation,
            scaled by (1 + layer load).
    """

    def __init__(self, layer_names, clearance_rate=0.05,
                 accumulation_rate=0.8, toxicity_threshold=0.3,
                 corruption_spread_radius=5, corruption_magnitude=0.005):
        self.layer_names = layer_names
        self.clearance_rate = clearance_rate
        self.accumulation_rate = accumulation_rate
        self.toxicity_threshold = toxicity_threshold
        self.corruption_spread_radius = corruption_spread_radius
        self.corruption_magnitude = corruption_magnitude
        self.aggregates = {name: 0.0 for name in layer_names}
        self.history = {name: [] for name in layer_names}

    def accumulate(self, layer_name, newly_pruned_fraction):
        """Add load proportional to the fraction of weights just pruned."""
        self.aggregates[layer_name] += newly_pruned_fraction * self.accumulation_rate

    def attempt_clearance(self, autophagy_efficiency=None):
        """
        Remove a fraction of every layer's load (never below zero).

        Args:
            autophagy_efficiency: fraction to clear this call; defaults to
                ``clearance_rate`` when None.
        """
        eff = (autophagy_efficiency if autophagy_efficiency is not None
               else self.clearance_rate)
        for name in self.layer_names:
            cleared = self.aggregates[name] * eff
            self.aggregates[name] = max(0.0, self.aggregates[name] - cleared)

    def get_toxicity_multiplier(self, layer_name):
        """Pruning-rate multiplier: 1 + 2 * (load - threshold) above threshold, else 1."""
        agg = self.aggregates[layer_name]
        if agg > self.toxicity_threshold:
            return 1.0 + (agg - self.toxicity_threshold) * 2.0
        return 1.0

    def get_regrowth_penalty(self, layer_name):
        """Regrowth scale in [0, 1]: 1 - 1.5 * load, floored at 0."""
        agg = self.aggregates[layer_name]
        return max(0.0, 1.0 - agg * 1.5)

    def corrupt_surviving_weights(self, model, masks,
                                  newly_pruned_indices_per_layer):
        """
        Perturb alive weights near newly pruned connections.

        For each layer, every alive weight whose flattened index lies
        within ``corruption_spread_radius`` of a pruned index receives
        N(0, 1) * magnitude noise, with
        magnitude = corruption_magnitude * (1 + layer load).

        Args:
            model: object exposing get_layer_params() -> {name: (weight, bias)}.
            masks: {name: {'weight': binary mask shaped like weight}}
                (1 = alive).
            newly_pruned_indices_per_layer: {name: flat indices} given as a
                LongTensor, a sequence of ints, or None.  None / empty
                entries are skipped without touching the model.

        Helper tensors are created on the weight's device, and the update
        is applied through a full-shape delta with ``add_`` so it also
        lands when the weight tensor is non-contiguous (where
        ``flatten()`` would silently return a copy).
        """
        layer_params = model.get_layer_params()
        radius = self.corruption_spread_radius

        for layer_name, pruned_indices in newly_pruned_indices_per_layer.items():
            if pruned_indices is None:
                continue
            idx = torch.as_tensor(pruned_indices, dtype=torch.long)
            if idx.numel() == 0:
                continue

            weight, _ = layer_params[layer_name]
            w = weight.data
            device = w.device
            n = w.numel()
            idx = idx.to(device).flatten()
            flat_mask = masks[layer_name]['weight'].reshape(-1).to(device)

            # Corruption magnitude scales with aggregate load
            agg = self.aggregates.get(layer_name, 0.0)
            magnitude = self.corruption_magnitude * (1.0 + agg)

            # Vectorized neighbourhood expansion around each pruned index
            offsets = torch.arange(-radius, radius + 1, device=device)
            zone = (idx.unsqueeze(1) + offsets.unsqueeze(0)).flatten()
            zone = torch.unique(zone.clamp(0, n - 1))

            # Only alive neighbours are corrupted; pruned weights stay zero.
            targets = zone[flat_mask[zone] == 1]
            num_targets = targets.numel()
            if num_targets == 0:
                continue

            delta = torch.zeros(n, dtype=w.dtype, device=device)
            delta[targets] = torch.randn(num_targets, device=device) * magnitude
            w.add_(delta.view(w.shape))

    def record(self):
        """Append the current load of every layer to ``history``."""
        for name in self.layer_names:
            self.history[name].append(self.aggregates[name])

    def get_state(self):
        """Return a snapshot dict {layer_name: current load}."""
        return {name: self.aggregates[name] for name in self.layer_names}


# ============================================================
# SECTION 3: MICROGLIAL PRUNING ENGINE
# ============================================================

class MicroglialPruningEngine:
    """
    Binary-mask pruning and regrowth over a model's weight matrices.

    The model must expose get_layer_params() -> {name: (weight, bias)}.
    One mask per weight (1 = alive, 0 = pruned) lives in
    ``self.masks[name]['weight']``.  A bias mask is also allocated, but
    nothing in this class applies it.

    Pruning mixes two pathways:
        - complement-tagged: a ``complement_bias`` share of each prune
          request removes the smallest-|w| alive connections;
        - activity-dependent: the remainder is removed uniformly at random
          from the connections still alive.
    Regrowth flips masks 0 -> 1 and re-initialises the weight from
    N(0, init_std).  regrow_layer picks targets uniformly at random.
    gradient_guided_regrow (v2.0, BDNF-style) prefers output rows whose
    alive connections carry the largest accumulated |grad|.

    get_weight_instability() reports the std dev of alive weights per
    layer, used downstream as an NFL biomarker proxy.
    """

    def __init__(self, model, complement_bias=0.7):
        self.model = model
        self.complement_bias = complement_bias
        self.masks = {}
        self.sparsity_history = {}
        # EMA of |grad| per weight element, updated by accumulate_gradients()
        self.gradient_accumulator = {}

        for name, (weight, bias) in model.get_layer_params().items():
            self.masks[name] = {
                'weight': torch.ones_like(weight.data),
                'bias': torch.ones_like(bias.data)
            }
            self.sparsity_history[name] = []
            self.gradient_accumulator[name] = torch.zeros_like(weight.data)

    # ----------------------------------------------------------
    # Pruning
    # ----------------------------------------------------------

    def prune_layer(self, layer_name, rate, toxicity_multiplier=1.0):
        """
        Prune a fraction of the currently alive weights of one layer.

        Args:
            layer_name: key into model.get_layer_params().
            rate: fraction of alive connections to prune this call.
            toxicity_multiplier: scales ``rate`` (aggregate acceleration);
                the product is capped at 0.95.

        Returns:
            (newly_pruned_fraction, pruned_flat_indices).  The fraction is
            the number of connections actually pruned divided by the
            layer's total weight count.  The indices are a 1-D LongTensor
            (empty when nothing was pruned), used by
            ToxicAggregateTracker for weight-corruption spread.
        """
        effective_rate = min(rate * toxicity_multiplier, 0.95)
        weight, _ = self.model.get_layer_params()[layer_name]
        mask_w = self.masks[layer_name]['weight']
        empty = torch.tensor([], dtype=torch.long, device=mask_w.device)

        alive = (mask_w == 1)
        num_alive = alive.sum().item()
        if num_alive == 0:
            return 0.0, empty

        num_to_prune = int(effective_rate * num_alive)
        if num_to_prune <= 0:
            return 0.0, empty

        all_pruned = []
        num_complement = int(self.complement_bias * num_to_prune)
        num_random = num_to_prune - num_complement

        if num_complement > 0:
            # Weakest alive connections first; dead positions are pushed to
            # +inf so topk(largest=False) never selects them.
            magnitudes = weight.data.abs()
            magnitudes[~alive] = float('inf')
            flat = magnitudes.flatten()
            _, indices = torch.topk(
                flat, min(num_complement, flat.numel()), largest=False
            )
            indices = indices[flat[indices] < float('inf')]
            mask_w.flatten()[indices] = 0
            all_pruned.append(indices)

        if num_random > 0:
            alive_indices = torch.nonzero(
                (mask_w == 1).flatten(), as_tuple=False
            ).squeeze(-1)
            if alive_indices.numel() > 0:
                count = min(num_random, alive_indices.numel())
                perm = torch.randperm(
                    alive_indices.numel(), device=alive_indices.device
                )[:count]
                selected = alive_indices[perm]
                mask_w.flatten()[selected] = 0
                all_pruned.append(selected)

        weight.data *= mask_w
        pruned_indices = torch.cat(all_pruned) if all_pruned else empty
        # Report what was actually removed, not what was requested.
        newly_pruned_frac = pruned_indices.numel() / mask_w.numel()
        return newly_pruned_frac, pruned_indices

    # ----------------------------------------------------------
    # Regrowth (standard random)
    # ----------------------------------------------------------

    def regrow_layer(self, layer_name, fraction, regrowth_penalty=1.0,
                     init_std=0.03):
        """
        Revive ``int(fraction * regrowth_penalty * num_dead)`` dead weights
        chosen uniformly at random, re-initialised from N(0, init_std).

        Returns:
            The number of connections regrown (0 if none are dead or the
            requested count rounds to zero or below).
        """
        effective_fraction = fraction * regrowth_penalty
        weight, _ = self.model.get_layer_params()[layer_name]
        mask_w = self.masks[layer_name]['weight']

        dead_indices = torch.nonzero(
            (mask_w == 0).flatten(), as_tuple=False
        ).squeeze(-1)
        num_dead = dead_indices.numel()
        if num_dead == 0:
            return 0

        num_regrow = int(effective_fraction * num_dead)
        if num_regrow <= 0:
            return 0

        count = min(num_regrow, num_dead)
        perm = torch.randperm(num_dead, device=dead_indices.device)[:count]
        selected = dead_indices[perm]

        mask_w.flatten()[selected] = 1
        weight.data.flatten()[selected] = (
            torch.randn(count, device=weight.device) * init_std
        )
        return count

    # ----------------------------------------------------------
    # Gradient-guided regrowth (v2.0)
    # ----------------------------------------------------------

    def accumulate_gradients(self):
        """
        Update the EMA (0.9 old + 0.1 new) of |grad| per weight element.

        Call after a backward pass; layers whose weight has no gradient are
        left unchanged.  gradient_guided_regrow reads this accumulator.
        """
        for name, (weight, _) in self.model.get_layer_params().items():
            if weight.grad is not None:
                self.gradient_accumulator[name] = (
                    0.9 * self.gradient_accumulator[name]
                    + 0.1 * weight.grad.data.abs()
                )

    def gradient_guided_regrow(self, layer_name, fraction,
                               regrowth_penalty=1.0, init_std=0.03):
        """
        Regrow dead connections, prioritising rows (output neurons) with the
        highest mean accumulated |grad| over their alive connections.

        Models BDNF-driven synaptogenesis targeting the most stressed
        circuits.  Assumes a 2-D weight (rows reduced with dim=1).  Falls
        back to uniform random choice when every dead score is below 1e-5
        (e.g. the accumulator is still all zero).

        Returns:
            The number of connections regrown.
        """
        effective_fraction = fraction * regrowth_penalty
        weight, _ = self.model.get_layer_params()[layer_name]
        mask_w = self.masks[layer_name]['weight']
        grad_acc = self.gradient_accumulator[layer_name]

        dead = (mask_w == 0)
        num_dead = dead.sum().item()
        if num_dead == 0:
            return 0

        num_regrow = int(effective_fraction * num_dead)
        if num_regrow <= 0:
            return 0

        # Mean |grad| per output neuron (row) among alive connections
        alive_count_per_row = mask_w.sum(dim=1).clamp(min=1)
        row_mean_grad = (grad_acc * mask_w).sum(dim=1) / alive_count_per_row

        # Broadcast the row score to its elements, keep dead ones only, and
        # add tiny noise (< the 1e-5 fallback threshold) to break ties.
        need_scores = row_mean_grad.unsqueeze(1).expand_as(mask_w)
        need_scores = (need_scores * dead.float()
                       + dead.float() * torch.rand_like(need_scores) * 1e-6)

        flat_scores = need_scores.flatten()
        dead_indices = torch.nonzero(
            dead.flatten(), as_tuple=False
        ).squeeze(-1)
        dead_scores = flat_scores[dead_indices]
        count = min(num_regrow, dead_indices.numel())

        if dead_scores.max() < 1e-5:
            perm = torch.randperm(
                dead_indices.numel(), device=dead_indices.device
            )[:count]
            selected = dead_indices[perm]
        else:
            _, top_k = torch.topk(dead_scores, count, largest=True)
            selected = dead_indices[top_k]

        mask_w.flatten()[selected] = 1
        weight.data.flatten()[selected] = (
            torch.randn(count, device=weight.device) * init_std
        )
        return count

    # ----------------------------------------------------------
    # Weight instability (v2.0 — NFL biomarker proxy)
    # ----------------------------------------------------------

    def get_weight_instability(self):
        """
        Return {layer_name: std dev of alive weights}.

        Higher values indicate more corruption / dysfunction.  Layers with
        fewer than two alive weights report 0.0: the unbiased std of a
        single element is NaN, which would poison downstream sums such as
        the NFL biomarker proxy.
        """
        instabilities = {}
        for name, (weight, _) in self.model.get_layer_params().items():
            alive = (self.masks[name]['weight'] == 1)
            if alive.sum().item() < 2:
                instabilities[name] = 0.0
            else:
                instabilities[name] = weight.data[alive].std().item()
        return instabilities

    # ----------------------------------------------------------
    # Mask utilities
    # ----------------------------------------------------------

    def apply_all_masks(self):
        """Zero every pruned weight (bias masks are not applied)."""
        layer_params = self.model.get_layer_params()
        for name in self.masks:
            weight, _ = layer_params[name]
            weight.data *= self.masks[name]['weight']

    def get_layer_sparsity(self, layer_name):
        """Fraction of the layer's weights that are pruned."""
        mask_w = self.masks[layer_name]['weight']
        total = mask_w.numel()
        dead = (mask_w == 0).sum().item()
        return dead / total if total > 0 else 0.0

    def get_all_sparsities(self):
        """Return {layer_name: sparsity} for every tracked layer."""
        return {name: self.get_layer_sparsity(name) for name in self.masks}

    def get_total_sparsity(self):
        """Pruned fraction over all tracked weights combined."""
        total, dead = 0, 0
        for name in self.masks:
            m = self.masks[name]['weight']
            total += m.numel()
            dead += (m == 0).sum().item()
        return dead / total if total > 0 else 0.0

    def record_sparsities(self):
        """Append each layer's current sparsity to ``sparsity_history``."""
        for name in self.masks:
            self.sparsity_history[name].append(self.get_layer_sparsity(name))

    def get_alive_counts(self):
        """Return {layer_name: number of alive weights}."""
        counts = {}
        for name in self.masks:
            m = self.masks[name]['weight']
            counts[name] = (m == 1).sum().item()
        return counts

    def get_total_counts(self):
        """Return {layer_name: total number of weights}."""
        counts = {}
        for name in self.masks:
            m = self.masks[name]['weight']
            counts[name] = m.numel()
        return counts


# ============================================================
# SECTION 4: SYMPTOM MAPPER
# ============================================================

class SymptomMapper:
    """
    Translate network-architecture state into clinical-style scores.

    Every score is a fixed function of the arguments passed to
    compute_symptoms(); the only state kept is ``self.history``, which
    appends each computed score under its key on every call.

    Mapping (v2.0):
        depression_score       = prefrontal_limbic sparsity * 48
        spasticity_score       = upper_motor sparsity * 4
        weakness_score         = lower_motor sparsity * 4
        fatigue_score          = neuromuscular_junction sparsity * 10
        bulbar_score           = bulbar sparsity * 4
        respiratory_risk_pct   = (0.6 * respiratory + 0.4 * NMJ sparsity) * 100
        fine_motor_score       = fine_motor sparsity * 4
        alsfrs_r_proxy         = accuracy * 48 / 100
        biomarker_nfl_proxy    = sum(weight instabilities) * 100
                                 (sum(aggregate loads) * 100 if none given)
        overall_disability_pct = total sparsity * 100
        cognitive_score        = max(0, (1 - PL sparsity) * 30
                                        - PL aggregate * 10
                                        - PL instability * 5)
        motor_integrity_pct    = alive / total weights over MOTOR_LAYERS * 100
        excitotoxicity_load    = sum of noise scales over MOTOR_LAYERS
        treatment_active       = 1.0 / 0.0 flag
    """

    # Layers counted as "motor" (v2.0 includes the three sub-layers).
    MOTOR_LAYERS = ('upper_motor', 'lower_motor', 'neuromuscular_junction',
                    'bulbar', 'respiratory', 'fine_motor')

    def __init__(self):
        # score name -> list of values, one entry per compute_symptoms() call
        self.history = defaultdict(list)

    def compute_symptoms(self, sparsities, accuracy, aggregate_state,
                         total_sparsity, noise_profile, alive_counts,
                         total_counts, treatment_active,
                         weight_instabilities=None):
        """
        Compute all symptom scores for one cycle and record them.

        Args:
            sparsities: {layer_name: pruned fraction}; missing layers are 0.
            accuracy: network accuracy; the /100 scaling suggests percent
                (0-100). NOTE(review): confirm against the caller.
            aggregate_state: {layer_name: aggregate load}; None -> {}.
            total_sparsity: pruned fraction over all layers.
            noise_profile: {layer_name: noise scale}; None -> {}.
            alive_counts: {layer_name: alive weight count}; None -> {}.
            total_counts: {layer_name: total weight count}; None -> {}.
            treatment_active: truthy when a treatment is in effect.
            weight_instabilities: optional {layer_name: std of alive weights}.

        Returns:
            dict mapping score name -> value (also appended to history).
        """
        aggregate_state = {} if aggregate_state is None else aggregate_state
        noise_profile = {} if noise_profile is None else noise_profile
        alive_counts = {} if alive_counts is None else alive_counts
        total_counts = {} if total_counts is None else total_counts

        pl_sp = sparsities.get('prefrontal_limbic', 0.0)
        um_sp = sparsities.get('upper_motor', 0.0)
        lm_sp = sparsities.get('lower_motor', 0.0)
        nmj_sp = sparsities.get('neuromuscular_junction', 0.0)
        bul_sp = sparsities.get('bulbar', 0.0)
        resp_sp = sparsities.get('respiratory', 0.0)
        fine_sp = sparsities.get('fine_motor', 0.0)

        # NFL proxy: prefer the weight-instability readout (v2.0); fall
        # back to raw aggregate load when instabilities are not supplied.
        if weight_instabilities is not None:
            nfl = sum(weight_instabilities.values()) * 100.0
        else:
            nfl = sum(aggregate_state.values()) * 100.0

        pl_agg = aggregate_state.get('prefrontal_limbic', 0.0)
        pl_instability = 0.0
        if weight_instabilities is not None:
            pl_instability = weight_instabilities.get('prefrontal_limbic', 0.0)
        cog = ((1.0 - pl_sp) * 30.0
               - pl_agg * 10.0
               - pl_instability * 5.0)

        motor_alive = sum(alive_counts.get(ln, 0) for ln in self.MOTOR_LAYERS)
        motor_total = sum(total_counts.get(ln, 0) for ln in self.MOTOR_LAYERS)
        motor_integrity = (
            motor_alive / motor_total * 100.0 if motor_total > 0 else 0.0
        )
        motor_noise = sum(noise_profile.get(ln, 0.0) for ln in self.MOTOR_LAYERS)

        # Key order is preserved for downstream CSV/JSON column order.
        symptoms = {
            'depression_score': pl_sp * 48.0,
            'spasticity_score': um_sp * 4.0,
            'weakness_score': lm_sp * 4.0,
            'fatigue_score': nmj_sp * 10.0,
            'bulbar_score': bul_sp * 4.0,
            'respiratory_risk_pct': (resp_sp * 0.6 + nmj_sp * 0.4) * 100.0,
            'fine_motor_score': fine_sp * 4.0,
            'alsfrs_r_proxy': accuracy * 48.0 / 100.0,
            'biomarker_nfl_proxy': nfl,
            'overall_disability_pct': total_sparsity * 100.0,
            'cognitive_score': max(0.0, cog),
            'motor_integrity_pct': motor_integrity,
            'excitotoxicity_load': motor_noise,
            # Record-keeping metadata
            'treatment_active': 1.0 if treatment_active else 0.0,
        }

        for key, val in symptoms.items():
            self.history[key].append(val)

        return symptoms


# ============================================================
# SECTION 5: TREATMENT MODULE
# ============================================================

class KetamineTreatment:
    """
    Ketamine arm, expressed as knobs on the network architecture.

    Effect strengths are multiplied by ``dose`` once, at construction.  At
    read time they are scaled again by
    decay = 0.5 ** (cycles_since_dose / half_life).  Evaluating the decay
    while it is below 0.05 switches the treatment off.

    get_current_effects() keys:
        regrowth_boost            regrowth_strength * decay
        noise_reduction           noise_reduction * decay
        pruning_reduction         pruning_reduction * decay
        plasticity_lr_multiplier  1 + (plasticity_boost - 1) * decay
        selective_regrow_targets  'high_gradient_pruned' while decay > 0.1,
                                  otherwise None (gradient-guided regrowth cue)
    """

    def __init__(self, regrowth_strength=0.6, noise_reduction=0.5,
                 pruning_reduction=0.4, plasticity_boost=2.0,
                 half_life=3, dose=1.0):
        scale = dose
        self.regrowth_strength = regrowth_strength * scale
        self.noise_reduction = noise_reduction * scale
        self.pruning_reduction = pruning_reduction * scale
        self.plasticity_boost = plasticity_boost * scale
        self.half_life = half_life
        self.dose = dose
        # Dosing state: inert until administer() is called.
        self.active = False
        self.cycles_since_dose = 0

    def administer(self):
        """Give a dose: (re)activate and restart the decay clock."""
        self.active = True
        self.cycles_since_dose = 0

    def get_decay_factor(self):
        """
        Return the current effect scale, or 0.0 when inactive.

        Side effect: sets ``active`` to False the first time the scale is
        found below 0.05.
        """
        if not self.active:
            return 0.0
        half_lives_elapsed = self.cycles_since_dose / max(self.half_life, 1e-6)
        remaining = 0.5 ** half_lives_elapsed
        if remaining < 0.05:
            self.active = False
            remaining = 0.0
        return remaining

    def tick(self):
        """Advance the decay clock by one cycle (no-op while inactive)."""
        if not self.active:
            return
        self.cycles_since_dose += 1

    def get_current_effects(self):
        """Return the decayed effect dict for the current cycle."""
        decay = self.get_decay_factor()
        regrow_target = None
        if decay > 0.1:
            regrow_target = 'high_gradient_pruned'
        lr_multiplier = 1.0 + (self.plasticity_boost - 1.0) * decay
        return dict(
            regrowth_boost=self.regrowth_strength * decay,
            noise_reduction=self.noise_reduction * decay,
            pruning_reduction=self.pruning_reduction * decay,
            plasticity_lr_multiplier=lr_multiplier,
            selective_regrow_targets=regrow_target,
            active=self.active,
            decay_factor=decay,
            treatment_type='ketamine',
        )


class RiluzoleTreatment:
    """
    Riluzole comparator arm: glutamate modulation only.

    Architecture mapping:
        - noise_reduction: damps motor-layer (sparsity-derived) noise.
        - pruning_reduction: mild neuroprotection.
        - regrowth_boost is always 0.0 (no synaptogenesis).
        - plasticity_lr_multiplier is always 1.0 (no plasticity boost).
        - Short default half-life; repeated dosing gives a steady state.

    Strengths are scaled by ``dose`` at construction and by
    decay = 0.5 ** (cycles_since_dose / half_life) at read time.
    Evaluating the decay below 0.05 switches the treatment off.
    """

    def __init__(self, noise_reduction=0.3, pruning_reduction=0.15,
                 half_life=2, dose=1.0):
        scale = dose
        self.noise_reduction = noise_reduction * scale
        self.pruning_reduction = pruning_reduction * scale
        self.half_life = half_life
        self.dose = dose
        # Dosing state: inert until administer() is called.
        self.active = False
        self.cycles_since_dose = 0

    def administer(self):
        """Give a dose: (re)activate and restart the decay clock."""
        self.active = True
        self.cycles_since_dose = 0

    def get_decay_factor(self):
        """
        Return the current effect scale, or 0.0 when inactive.

        Side effect: sets ``active`` to False the first time the scale is
        found below 0.05.
        """
        if not self.active:
            return 0.0
        half_lives_elapsed = self.cycles_since_dose / max(self.half_life, 1e-6)
        remaining = 0.5 ** half_lives_elapsed
        if remaining < 0.05:
            self.active = False
            remaining = 0.0
        return remaining

    def tick(self):
        """Advance the decay clock by one cycle (no-op while inactive)."""
        if not self.active:
            return
        self.cycles_since_dose += 1

    def get_current_effects(self):
        """Return the decayed effect dict; regrowth and LR effects are fixed."""
        decay = self.get_decay_factor()
        return dict(
            regrowth_boost=0.0,
            noise_reduction=self.noise_reduction * decay,
            pruning_reduction=self.pruning_reduction * decay,
            plasticity_lr_multiplier=1.0,
            selective_regrow_targets=None,
            active=self.active,
            decay_factor=decay,
            treatment_type='riluzole',
        )


class TreatmentSchedule:
    """
    Dosing calendar wrapped around a KetamineTreatment or RiluzoleTreatment.

    Args:
        dose_cycles: iterable of cycle numbers on which a dose is given.
        treatment_type: 'ketamine' or 'riluzole'.
        dose: dose multiplier forwarded to the treatment.
        half_life, regrowth_strength, noise_reduction, pruning_reduction,
        plasticity_boost: optional overrides.  ``None`` (the default)
            means "use the selected treatment's own default".  Ketamine
            therefore keeps 0.6 / 0.5 / 0.4 / 2.0 with half-life 3, and the
            riluzole comparator keeps its milder profile (noise 0.3,
            pruning 0.15, half-life 2).  Previously it silently inherited
            ketamine-level values.  regrowth_strength and
            plasticity_boost are ignored for riluzole.

    Raises:
        ValueError: if ``treatment_type`` is not recognised.
    """

    def __init__(self, dose_cycles=None, treatment_type='ketamine',
                 dose=1.0, half_life=None,
                 regrowth_strength=None, noise_reduction=None,
                 pruning_reduction=None, plasticity_boost=None):
        self.dose_cycles = set(dose_cycles) if dose_cycles is not None else set()
        self.treatment_type = treatment_type

        overrides = {
            'half_life': half_life,
            'noise_reduction': noise_reduction,
            'pruning_reduction': pruning_reduction,
        }
        if treatment_type == 'ketamine':
            overrides['regrowth_strength'] = regrowth_strength
            overrides['plasticity_boost'] = plasticity_boost
            factory = KetamineTreatment
        elif treatment_type == 'riluzole':
            factory = RiluzoleTreatment
        else:
            raise ValueError(f"Unknown treatment type: {treatment_type}")

        # Forward only explicitly supplied values; each treatment class owns
        # its defaults.
        kwargs = {key: val for key, val in overrides.items() if val is not None}
        self.treatment = factory(dose=dose, **kwargs)

    def check_and_administer(self, cycle):
        """Dose the treatment if ``cycle`` is scheduled; return whether it was."""
        if cycle in self.dose_cycles:
            self.treatment.administer()
            return True
        return False

    def get_effects(self):
        """Return the wrapped treatment's current effect dict."""
        return self.treatment.get_current_effects()

    def tick(self):
        """Advance the wrapped treatment's decay clock by one cycle."""
        self.treatment.tick()


# ============================================================
# SECTION 6: EXCITOTOXICITY ENGINE
# ============================================================

class ExcitotoxicityEngine:
    """
    v2.0: Noise derived from architectural instability.

    For each profiled layer:

        noise = (base + progression * cycle)
                * (1 + sparsity * instability_gain)
                * vulnerability
                * max(0, 1 - treatment_noise_reduction)

    Sparser layers therefore receive proportionally more noise, so the
    excitotoxic load emerges from the network's own architecture state
    instead of an external schedule.  Motor layers share
    ``motor_vulnerability`` scaled by a fixed per-layer factor; the limbic
    layer uses ``limbic_vulnerability``.  'final_output' is not profiled.
    """

    def __init__(self, base_noise=0.05, progression_rate=0.005,
                 motor_vulnerability=1.5, limbic_vulnerability=1.0,
                 instability_gain=3.0):
        self.base_noise = base_noise
        self.progression_rate = progression_rate
        self.motor_vulnerability = motor_vulnerability
        self.limbic_vulnerability = limbic_vulnerability
        self.instability_gain = instability_gain

    def get_noise_profile(self, cycle, current_sparsities=None,
                          treatment_effects=None):
        """
        Return a dict ``{layer_name: noise_magnitude}`` for this cycle.

        ``current_sparsities`` maps layer name to its pruned fraction
        (missing layers count as 0).  ``treatment_effects`` may carry a
        'noise_reduction' entry; reductions above 1 clamp the noise to 0.
        """
        reduction = (0.0 if treatment_effects is None
                     else treatment_effects.get('noise_reduction', 0.0))
        sparsities = {} if current_sparsities is None else current_sparsities

        base = self.base_noise + self.progression_rate * cycle
        reduction_mult = max(0.0, 1.0 - reduction)

        motor = self.motor_vulnerability
        layer_vulnerability = (
            ('prefrontal_limbic', self.limbic_vulnerability),
            ('upper_motor', motor),
            ('lower_motor', motor * 1.2),
            ('neuromuscular_junction', motor * 0.8),
            ('bulbar', motor * 1.0),
            ('respiratory', motor * 1.1),
            ('fine_motor', motor * 0.9),
        )

        # Factor order matches the formula above so results are bit-identical.
        return {
            name: (base
                   * (1.0 + sparsities.get(name, 0.0) * self.instability_gain)
                   * vuln
                   * reduction_mult)
            for name, vuln in layer_vulnerability
        }


# ============================================================
# SECTION 7: DATA GENERATOR
# ============================================================

class MotorTaskGenerator:
    """
    Synthetic 4-class task:
        0 = limb_function
        1 = bulbar_function
        2 = respiratory_function
        3 = fine_motor_function

    Each class is an isotropic 2-D Gaussian blob (std 0.5) centred on one
    corner of [-1, 1]^2.  ``n_samples // 4`` points are drawn per class (any
    remainder is dropped), shuffled, and split 80/20 into train/test.

    All randomness comes from the private ``self.rng`` generator seeded with
    ``seed``.  Constructing a generator therefore does not reseed or consume
    the process-global torch RNG.  The previous version called
    ``torch.manual_seed`` here, which silently reset global state for every
    caller.
    """

    def __init__(self, n_samples=2000, seed=42):
        self.n_samples = n_samples
        self.rng = torch.Generator()
        self.rng.manual_seed(seed)
        self._generate(seed)

    def _generate(self, seed):
        """(Re)build the train/test split deterministically from ``seed``."""
        # Reseed the private generator, not the global RNG, so regenerating
        # with the same seed reproduces the same split.
        self.rng.manual_seed(seed)
        centers = torch.tensor([
            [-1.0, -1.0],
            [ 1.0, -1.0],
            [-1.0,  1.0],
            [ 1.0,  1.0]
        ])
        x_list, y_list = [], []
        per_class = self.n_samples // 4
        for c in range(4):
            x_c = centers[c] + torch.randn(per_class, 2, generator=self.rng) * 0.5
            y_c = torch.full((per_class,), c, dtype=torch.long)
            x_list.append(x_c)
            y_list.append(y_c)

        x = torch.cat(x_list)
        y = torch.cat(y_list)
        perm = torch.randperm(len(x), generator=self.rng)
        x, y = x[perm], y[perm]

        split = int(0.8 * len(x))
        self.x_train, self.x_test = x[:split], x[split:]
        self.y_train, self.y_test = y[:split], y[split:]

    def get_train(self):
        """Return ``(x_train, y_train)``."""
        return self.x_train, self.y_train

    def get_test(self):
        """Return ``(x_test, y_test)``."""
        return self.x_test, self.y_test


# ============================================================
# SECTION 8: TRAINING AND EVALUATION
# ============================================================

def train_epoch(model, x, y, optimizer, criterion, pruning_engine,
                noise_profile=None, n_steps=5, batch_size=256):
    """
    Run ``n_steps`` optimizer steps on random minibatches and return the mean loss.

    Each step draws up to ``batch_size`` rows without replacement via
    ``torch.randperm`` and forwards them with ``noise_profile``.  It then
    back-propagates and steps the optimizer.  Finally it calls
    ``pruning_engine.apply_all_masks()`` (presumably so that pruned
    connections updated by the step are masked out again).  The result is
    ``np.mean`` of the per-step loss values.
    """
    model.train()
    n_rows = len(x)
    step_losses = []
    for _step in range(n_steps):
        batch = torch.randperm(n_rows)[:batch_size]
        logits = model(x[batch], noise_profile=noise_profile)
        batch_loss = criterion(logits, y[batch])

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        pruning_engine.apply_all_masks()

        step_losses.append(batch_loss.item())
    return np.mean(step_losses)


def evaluate(model, x, y, noise_profile=None, num_classes=4):
    """
    Score ``model`` on ``(x, y)`` in eval mode without gradient tracking.

    Args:
        model: module called as ``model(x, noise_profile=...)`` returning
            logits whose argmax along dim 1 is the predicted class.
        x, y: inputs and integer class labels.
        noise_profile: forwarded unchanged to the model.
        num_classes: number of class indices reported in ``per_class``.
            The default of 4 matches the 4-class motor task.

    Returns:
        ``(acc, per_class)``.  ``acc`` is the overall percent correct.
        ``per_class`` maps each ``c in range(num_classes)`` to the percent
        correct on samples labelled ``c``, or 0.0 when no sample has that
        label.
    """
    model.eval()
    with torch.no_grad():
        out = model(x, noise_profile=noise_profile)
        pred = out.argmax(dim=1)
        # One correctness vector shared by the overall and per-class scores.
        correct = (pred == y).float()
        acc = correct.mean().item() * 100.0
        per_class = {}
        for c in range(num_classes):
            mask = (y == c)
            if mask.any():
                per_class[c] = correct[mask].mean().item() * 100.0
            else:
                per_class[c] = 0.0
    return acc, per_class


# ============================================================
# SECTION 9: DISEASE PROFILES
# ============================================================

class DiseaseProfile:
    """
    Factory of disease-parameter dicts consumed by ProgressionPipeline.

    Every factory returns a fresh dict with the same key order.  Besides the
    name and onset label, each dict carries:
        - pruning_rates for all 8 layers (including sub-layers)
        - months_per_cycle for calendar calibration
        - phase_boundaries for preclinical / symptomatic / advanced
        - flare_threshold, flare_probability, flare_magnitude
        - corruption_spread_radius, corruption_magnitude
        - instability_gain for sparsity-driven excitotoxicity
        - onset_type label

    The onset subtypes (ALS_Bulbar_Onset, ALS_Limb_Onset) and the ALS/MDD
    comorbid profile are derived from the base ALS profile by overriding
    individual entries.  Phenotypic variation therefore comes purely from
    differential parameters such as sub-layer pruning rates, with no
    external symptom rules.
    """

    @staticmethod
    def als_profile():
        """Base (mixed-onset) ALS parameter set."""
        return dict(
            name='ALS',
            onset_type='mixed',
            total_cycles=30,
            months_per_cycle=1.5,
            phase_boundaries=dict(preclinical=5, symptomatic=20, advanced=30),
            pruning_rates=dict(
                prefrontal_limbic=0.04,
                upper_motor=0.10,
                lower_motor=0.12,
                neuromuscular_junction=0.08,
                bulbar=0.05,
                respiratory=0.06,
                fine_motor=0.09,
                final_output=0.04,
            ),
            pruning_acceleration=0.005,
            autophagy_clearance=0.05,
            aggregate_accumulation=0.8,
            toxicity_threshold=0.2,
            base_regrowth=0.02,
            excitotoxicity_base=0.05,
            excitotoxicity_progression=0.008,
            motor_vulnerability=2.0,
            limbic_vulnerability=0.8,
            instability_gain=3.0,
            complement_bias=0.7,
            corruption_spread_radius=5,
            corruption_magnitude=0.005,
            flare_threshold=0.3,
            flare_probability=0.10,
            flare_magnitude=2.0,
        )

    @staticmethod
    def als_bulbar_onset_profile():
        """Bulbar-onset: higher bulbar/respiratory pruning, lower limb."""
        profile = DiseaseProfile.als_profile()
        profile.update(name='ALS_Bulbar_Onset', onset_type='bulbar')
        profile['pruning_rates'].update(
            bulbar=0.14,
            respiratory=0.10,
            fine_motor=0.06,
            lower_motor=0.08,
        )
        return profile

    @staticmethod
    def als_limb_onset_profile():
        """Limb-onset: higher fine_motor/lower_motor pruning, lower bulbar."""
        profile = DiseaseProfile.als_profile()
        profile.update(name='ALS_Limb_Onset', onset_type='limb')
        profile['pruning_rates'].update(
            fine_motor=0.14,
            lower_motor=0.14,
            bulbar=0.04,
            respiratory=0.05,
        )
        return profile

    @staticmethod
    def mdd_profile():
        """Major depression: limbic-dominant pruning, strong clearance/regrowth."""
        return dict(
            name='MDD',
            onset_type='none',
            total_cycles=30,
            months_per_cycle=1.0,
            phase_boundaries=dict(preclinical=3, symptomatic=15, advanced=30),
            pruning_rates=dict(
                prefrontal_limbic=0.10,
                upper_motor=0.02,
                lower_motor=0.02,
                neuromuscular_junction=0.01,
                bulbar=0.01,
                respiratory=0.01,
                fine_motor=0.01,
                final_output=0.02,
            ),
            pruning_acceleration=0.002,
            autophagy_clearance=0.30,
            aggregate_accumulation=0.2,
            toxicity_threshold=0.5,
            base_regrowth=0.08,
            excitotoxicity_base=0.03,
            excitotoxicity_progression=0.002,
            motor_vulnerability=0.5,
            limbic_vulnerability=2.0,
            instability_gain=2.0,
            complement_bias=0.5,
            corruption_spread_radius=3,
            corruption_magnitude=0.002,
            flare_threshold=0.5,
            flare_probability=0.05,
            flare_magnitude=1.5,
        )

    @staticmethod
    def als_mdd_comorbid_profile():
        """ALS base with heavier limbic involvement and earlier/more frequent flares."""
        profile = DiseaseProfile.als_profile()
        profile.update(
            name='ALS_MDD_Comorbid',
            phase_boundaries=dict(preclinical=4, symptomatic=18, advanced=30),
            limbic_vulnerability=1.8,
            corruption_magnitude=0.006,
            flare_threshold=0.25,
            flare_probability=0.12,
        )
        profile['pruning_rates'].update(
            prefrontal_limbic=0.08,
            final_output=0.05,
        )
        return profile

    @staticmethod
    def control_profile():
        """Healthy control: uniform minimal pruning, no flares, strong repair."""
        layer_order = (
            'prefrontal_limbic', 'upper_motor', 'lower_motor',
            'neuromuscular_junction', 'bulbar', 'respiratory',
            'fine_motor', 'final_output',
        )
        return dict(
            name='Control',
            onset_type='none',
            total_cycles=30,
            months_per_cycle=1.5,
            phase_boundaries=dict(preclinical=30, symptomatic=30, advanced=30),
            pruning_rates=dict.fromkeys(layer_order, 0.01),
            pruning_acceleration=0.0,
            autophagy_clearance=0.50,
            aggregate_accumulation=0.1,
            toxicity_threshold=0.8,
            base_regrowth=0.15,
            excitotoxicity_base=0.01,
            excitotoxicity_progression=0.0,
            motor_vulnerability=1.0,
            limbic_vulnerability=1.0,
            instability_gain=1.0,
            complement_bias=0.5,
            corruption_spread_radius=2,
            corruption_magnitude=0.001,
            flare_threshold=0.8,
            flare_probability=0.0,
            flare_magnitude=1.0,
        )


# ============================================================
# SECTION 10: MAIN SIMULATION PIPELINE
# ============================================================

class ProgressionPipeline:
    """
    All symptoms and therapeutic changes are derived from network
    architecture state (masks, weights, sparsity, aggregate load,
    noise profile, accuracy, weight instability).

    v2.0 simulation loop per cycle:
        1. Check / administer treatment
        2. Stochastic flare check (per layer, sparsity-gated)
        3. Pruning (with flare magnitude boost)
        4. Toxic aggregate weight corruption (TDP-43 spread)
        5. Autophagy clearance
        6. Regrowth (gradient-guided if treatment active, else random)
        7. Apply masks
        8. Training (with sparsity-derived excitotoxicity noise)
        9. Accumulate gradients for future guided regrowth
       10. Evaluate accuracy and per-class
       11. Compute symptoms from architecture state
       12. Record all metrics including month, phase, flares

    ``run()`` appends one baseline record (cycle -1) plus one record per
    cycle to ``self.results`` and returns that list.  A second call would
    append a second series.  ``export_csv`` / ``export_json`` write the
    accumulated records.
    """

    def __init__(self, disease_profile, treatment_schedule=None,
                 pretrain_epochs=50, seed=42):
        """
        Args:
            disease_profile: parameter dict such as those returned by
                ``DiseaseProfile`` factories; stored as ``self.config``.
            treatment_schedule: optional ``TreatmentSchedule``; ``None``
                means untreated.
            pretrain_epochs: number of single-minibatch Adam steps in
                ``pretrain()``.
            seed: seeds the global torch and numpy RNGs, the task data, and
                the per-layer heterogeneity modifiers.
        """
        # Global seeding: affects model init and every later global draw.
        torch.manual_seed(seed)
        np.random.seed(seed)

        self.config = disease_profile
        self.treatment_schedule = treatment_schedule
        self.pretrain_epochs = pretrain_epochs
        self.seed = seed

        self.model = MotorCircuitNetwork()
        self.data = MotorTaskGenerator(n_samples=2000, seed=seed)

        self.pruning_engine = MicroglialPruningEngine(
            self.model,
            complement_bias=self.config['complement_bias']
        )

        self.aggregates = ToxicAggregateTracker(
            layer_names=MotorCircuitNetwork.LAYER_NAMES,
            clearance_rate=self.config['autophagy_clearance'],
            accumulation_rate=self.config['aggregate_accumulation'],
            toxicity_threshold=self.config['toxicity_threshold'],
            corruption_spread_radius=self.config.get(
                'corruption_spread_radius', 5
            ),
            corruption_magnitude=self.config.get(
                'corruption_magnitude', 0.005
            )
        )

        self.excitotoxicity = ExcitotoxicityEngine(
            base_noise=self.config['excitotoxicity_base'],
            progression_rate=self.config['excitotoxicity_progression'],
            motor_vulnerability=self.config['motor_vulnerability'],
            limbic_vulnerability=self.config['limbic_vulnerability'],
            instability_gain=self.config.get('instability_gain', 3.0)
        )

        self.symptom_mapper = SymptomMapper()
        self.criterion = nn.CrossEntropyLoss()
        self.base_lr = 0.001
        self.results = []

        # v2.0: patient heterogeneity — seed-based vulnerability modifiers.
        # Uses its own RandomState (seed + 1000) so it does not consume the
        # global numpy stream used for flare draws.
        rng_het = np.random.RandomState(seed + 1000)
        self.vulnerability_modifiers = {}
        for ln in MotorCircuitNetwork.LAYER_NAMES:
            self.vulnerability_modifiers[ln] = rng_het.uniform(0.85, 1.15)

    def _get_phase(self, cycle):
        """
        Map a cycle index to 'preclinical', 'symptomatic' or 'advanced'.

        ``cycle < preclinical`` -> 'preclinical'; ``cycle < symptomatic`` ->
        'symptomatic'; otherwise 'advanced'.  The 'advanced' boundary value
        is not consulted.
        """
        boundaries = self.config.get('phase_boundaries', {})
        if cycle < boundaries.get('preclinical', 5):
            return 'preclinical'
        elif cycle < boundaries.get('symptomatic', 20):
            return 'symptomatic'
        else:
            return 'advanced'

    def _get_month(self, cycle):
        """Convert a cycle index to elapsed months via ``months_per_cycle``."""
        return cycle * self.config.get('months_per_cycle', 1.5)

    def pretrain(self):
        """
        Train the healthy network, then return ``(acc, per_class)`` on the test split.

        Runs ``pretrain_epochs`` Adam steps, each on one random minibatch of
        up to 256 training rows, without noise.  Pruning masks are not
        applied here; presumably they start fully alive.
        """
        optimizer = optim.Adam(self.model.parameters(), lr=self.base_lr)
        x_train, y_train = self.data.get_train()
        x_test, y_test = self.data.get_test()

        for _ in range(self.pretrain_epochs):
            self.model.train()
            idx = torch.randperm(len(x_train))[:256]
            out = self.model(x_train[idx])
            loss = self.criterion(out, y_train[idx])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        acc, per_class = evaluate(self.model, x_test, y_test)
        return acc, per_class

    def run(self):
        """
        Pretrain, record a baseline, then simulate ``total_cycles`` cycles.

        Returns ``self.results``, a list of flat dicts.  Each dict holds
        cycle/month/phase metadata and treatment/flare state.  It also holds
        per-layer ``sparsity_*``, ``aggregate_*``, ``noise_*`` and
        ``weight_instability_*`` values, ``class_{c}_acc`` per-class
        accuracy, and ``symptom_*`` scores from the SymptomMapper.
        """
        baseline_acc, baseline_per_class = self.pretrain()
        x_train, y_train = self.data.get_train()
        x_test, y_test = self.data.get_test()
        total_cycles = self.config['total_cycles']

        # ---- Baseline record ----
        # Architecture metrics are fixed at zero for the pretrained network.
        baseline_record = {
            'cycle': -1,
            'month': 0.0,
            'phase': 'baseline',
            'condition': self.config['name'],
            'onset_type': self.config.get('onset_type', 'mixed'),
            'accuracy': baseline_acc,
            'train_loss': 0.0,
            'total_sparsity': 0.0,
            'dose_given': False,
            'treatment_active': False,
            'treatment_decay': 0.0,
            'treatment_type': 'none',
            'flare_active': False,
            'flare_layers': '[]'
        }
        for ln in MotorCircuitNetwork.LAYER_NAMES:
            baseline_record[f'sparsity_{ln}'] = 0.0
            baseline_record[f'aggregate_{ln}'] = 0.0
            baseline_record[f'noise_{ln}'] = 0.0
            baseline_record[f'weight_instability_{ln}'] = 0.0
        for c_id, c_acc in baseline_per_class.items():
            baseline_record[f'class_{c_id}_acc'] = c_acc

        baseline_symptoms = self.symptom_mapper.compute_symptoms(
            {ln: 0.0 for ln in MotorCircuitNetwork.LAYER_NAMES},
            baseline_acc,
            {ln: 0.0 for ln in MotorCircuitNetwork.LAYER_NAMES},
            0.0,
            {ln: 0.0 for ln in MotorCircuitNetwork.LAYER_NAMES
             if ln != 'final_output'},
            self.pruning_engine.get_alive_counts(),
            self.pruning_engine.get_total_counts(),
            False,
            weight_instabilities={ln: 0.0
                                  for ln in MotorCircuitNetwork.LAYER_NAMES}
        )
        for sname, sval in baseline_symptoms.items():
            baseline_record[f'symptom_{sname}'] = sval
        self.results.append(baseline_record)

        # ---- Main simulation loop ----
        for cycle in range(total_cycles):
            record = {
                'cycle': cycle,
                'month': self._get_month(cycle),
                'phase': self._get_phase(cycle),
                'condition': self.config['name'],
                'onset_type': self.config.get('onset_type', 'mixed'),
            }

            # 1. Treatment
            # Neutral effects used when no schedule is attached.
            null_effects = {
                'regrowth_boost': 0.0, 'noise_reduction': 0.0,
                'pruning_reduction': 0.0, 'plasticity_lr_multiplier': 1.0,
                'selective_regrow_targets': None,
                'active': False, 'decay_factor': 0.0,
                'treatment_type': 'none'
            }
            dose_given = False
            treatment_effects = null_effects

            if self.treatment_schedule is not None:
                dose_given = self.treatment_schedule.check_and_administer(cycle)
                treatment_effects = self.treatment_schedule.get_effects()

            record['dose_given'] = dose_given
            record['treatment_active'] = treatment_effects['active']
            record['treatment_decay'] = treatment_effects['decay_factor']
            record['treatment_type'] = treatment_effects.get(
                'treatment_type', 'none'
            )

            # 2. Stochastic flare check (v2.0)
            # Only layers already past flare_threshold sparsity can flare;
            # the numpy draw is skipped (short-circuit) for the others.
            flare_active = {}
            flare_thresh = self.config.get('flare_threshold', 0.3)
            flare_prob = self.config.get('flare_probability', 0.10)
            for layer_name in self.config['pruning_rates']:
                sp = self.pruning_engine.get_layer_sparsity(layer_name)
                if sp > flare_thresh and np.random.random() < flare_prob:
                    flare_active[layer_name] = True
                else:
                    flare_active[layer_name] = False

            record['flare_active'] = any(flare_active.values())
            record['flare_layers'] = json.dumps(
                [ln for ln, a in flare_active.items() if a]
            )

            # 3. Pruning (with toxicity spread indices)
            newly_pruned_per_layer = {}
            for layer_name, base_rate in self.config['pruning_rates'].items():
                # Heterogeneity modifier (v2.0)
                het_mod = self.vulnerability_modifiers.get(layer_name, 1.0)
                acceleration = self.config['pruning_acceleration'] * cycle
                effective_rate = base_rate * het_mod + acceleration
                effective_rate *= max(
                    0.0, 1.0 - treatment_effects['pruning_reduction']
                )

                # Flare burst (v2.0)
                if flare_active.get(layer_name, False):
                    effective_rate *= self.config.get('flare_magnitude', 2.0)

                tox_mult = self.aggregates.get_toxicity_multiplier(layer_name)
                newly_pruned_frac, pruned_indices = (
                    self.pruning_engine.prune_layer(
                        layer_name, effective_rate, tox_mult
                    )
                )
                self.aggregates.accumulate(layer_name, newly_pruned_frac)
                newly_pruned_per_layer[layer_name] = pruned_indices

            # 4. Toxic aggregate weight corruption (v2.0)
            self.aggregates.corrupt_surviving_weights(
                self.model, self.pruning_engine.masks,
                newly_pruned_per_layer
            )

            # 5. Autophagy clearance
            self.aggregates.attempt_clearance(
                self.config['autophagy_clearance']
            )

            # 6. Regrowth (gradient-guided if treatment active)
            for layer_name in self.config['pruning_rates']:
                regrow_frac = (self.config['base_regrowth']
                               + treatment_effects['regrowth_boost'])
                regrowth_penalty = self.aggregates.get_regrowth_penalty(
                    layer_name
                )

                if (treatment_effects.get('selective_regrow_targets')
                        == 'high_gradient_pruned'):
                    self.pruning_engine.gradient_guided_regrow(
                        layer_name, regrow_frac, regrowth_penalty
                    )
                else:
                    self.pruning_engine.regrow_layer(
                        layer_name, regrow_frac, regrowth_penalty
                    )

            # 7. Apply masks
            self.pruning_engine.apply_all_masks()

            # 8. Training with sparsity-derived excitotoxicity
            # A fresh plain SGD (no momentum, hence no carried state) each
            # cycle so the treatment's lr multiplier takes effect.
            lr = self.base_lr * treatment_effects['plasticity_lr_multiplier']
            optimizer = optim.SGD(self.model.parameters(), lr=lr)

            current_sparsities = self.pruning_engine.get_all_sparsities()
            noise_profile = self.excitotoxicity.get_noise_profile(
                cycle,
                current_sparsities=current_sparsities,
                treatment_effects=treatment_effects
            )

            train_loss = train_epoch(
                self.model, x_train, y_train, optimizer, self.criterion,
                self.pruning_engine, noise_profile=noise_profile, n_steps=3
            )

            # 9. Accumulate gradients for future guided regrowth (v2.0)
            # Presumably reads the .grad left by the last training step.
            self.pruning_engine.accumulate_gradients()

            # 10. Evaluate (under the same excitotoxic noise)
            acc, per_class = evaluate(
                self.model, x_test, y_test, noise_profile=noise_profile
            )

            # 11. Record architecture state and compute symptoms
            sparsities = self.pruning_engine.get_all_sparsities()
            total_sparsity = self.pruning_engine.get_total_sparsity()
            self.pruning_engine.record_sparsities()
            agg_state = self.aggregates.get_state()
            self.aggregates.record()
            alive_counts = self.pruning_engine.get_alive_counts()
            total_counts = self.pruning_engine.get_total_counts()
            weight_instabilities = self.pruning_engine.get_weight_instability()

            symptoms = self.symptom_mapper.compute_symptoms(
                sparsities, acc, agg_state, total_sparsity,
                noise_profile, alive_counts, total_counts,
                treatment_effects['active'],
                weight_instabilities=weight_instabilities
            )

            # 12. Pack record
            record['accuracy'] = acc
            record['train_loss'] = train_loss
            record['total_sparsity'] = total_sparsity

            for ln, sp in sparsities.items():
                record[f'sparsity_{ln}'] = sp
            for ln, ag in agg_state.items():
                record[f'aggregate_{ln}'] = ag
            for sname, sval in symptoms.items():
                record[f'symptom_{sname}'] = sval
            for c_id, c_acc in per_class.items():
                record[f'class_{c_id}_acc'] = c_acc
            for ln, nv in noise_profile.items():
                record[f'noise_{ln}'] = nv
            for ln, wi in weight_instabilities.items():
                record[f'weight_instability_{ln}'] = wi

            self.results.append(record)

            # Treatment decay advances after the cycle is recorded.
            if self.treatment_schedule is not None:
                self.treatment_schedule.tick()

        return self.results

    def export_csv(self, filepath):
        """
        Write ``self.results`` to ``filepath`` as UTF-8 CSV.

        Columns are the union of all record keys in first-seen order.
        Records may carry different columns; for example, the baseline
        writes a noise column for every layer, while cycle records only get
        the layers the excitotoxicity engine profiles.  Missing cells are
        left empty.  Nothing is written when there are no results.
        """
        if not self.results:
            return
        # Ordered union in one pass (was a list-membership scan per key).
        fieldnames = list(dict.fromkeys(
            key for record in self.results for key in record
        ))
        with open(filepath, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames,
                                    extrasaction='ignore')
            writer.writeheader()
            writer.writerows(self.results)

    def export_json(self, filepath):
        """Write ``self.results`` as indented UTF-8 JSON; non-JSON values are str()-ed."""
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(self.results, f, indent=2, default=str)


# ============================================================
# SECTION 11: SCENARIO RUNNER
# ============================================================

def build_treatment_schedule(arm_name, dose_cycles, params=None):
    """
    Build a ``TreatmentSchedule`` for one trial arm, or ``None`` if untreated.

    ``arm_name`` is accepted for call-site readability but is not used.
    ``dose_cycles`` of ``None`` means the arm is untreated.  ``params``
    overrides any of the ketamine defaults below, including
    ``treatment_type``.
    """
    if dose_cycles is None:
        return None

    schedule_kwargs = dict(
        dose=1.0,
        half_life=3,
        treatment_type='ketamine',
        regrowth_strength=0.6,
        noise_reduction=0.5,
        pruning_reduction=0.4,
        plasticity_boost=2.0,
    )
    if params is not None:
        schedule_kwargs.update(params)

    return TreatmentSchedule(dose_cycles=dose_cycles, **schedule_kwargs)


def run_all_scenarios(seeds=None):
    """
    Simulate every (condition, treatment arm, seed) combination.

    Args:
        seeds: iterable of integer seeds; defaults to ``[42]``.

    Returns:
        dict mapping ``"{condition}__{arm}__seed{seed}"`` to the record list
        returned by ``ProgressionPipeline.run()`` for that combination.
        Each condition's profile dict is shared by all of its runs.
    """
    seed_list = [42] if seeds is None else seeds

    profiles = {
        'ALS': DiseaseProfile.als_profile(),
        'ALS_Bulbar': DiseaseProfile.als_bulbar_onset_profile(),
        'ALS_Limb': DiseaseProfile.als_limb_onset_profile(),
        'MDD': DiseaseProfile.mdd_profile(),
        'ALS_MDD': DiseaseProfile.als_mdd_comorbid_profile(),
        'Control': DiseaseProfile.control_profile()
    }

    # Riluzole arms: glutamate modulation only (no regrowth / plasticity).
    riluzole_overrides = {
        'treatment_type': 'riluzole', 'half_life': 2,
        'noise_reduction': 0.3, 'pruning_reduction': 0.15,
        'regrowth_strength': 0.0, 'plasticity_boost': 1.0
    }

    # arm name -> (dose_cycles, override_params)
    treatment_arms = {
        'no_treatment': (None, None),
        'ketamine_q5': ([5, 10, 15, 20], None),
        'ketamine_early': ([2, 4, 6], None),
        'ketamine_late': ([15, 18, 21, 24], None),
        'ketamine_maintenance': (list(range(3, 28, 3)), None),
        'riluzole_continuous': (list(range(30)), dict(riluzole_overrides)),
        'riluzole_early': (list(range(15)), dict(riluzole_overrides)),
    }

    all_results = {}
    for cond_name, profile in profiles.items():
        for arm_name, (dose_cycles, params) in treatment_arms.items():
            for seed in seed_list:
                pipeline = ProgressionPipeline(
                    disease_profile=profile,
                    treatment_schedule=build_treatment_schedule(
                        arm_name, dose_cycles, params
                    ),
                    pretrain_epochs=50,
                    seed=seed
                )
                all_results[f"{cond_name}__{arm_name}__seed{seed}"] = (
                    pipeline.run()
                )

    return all_results


# ============================================================
# SECTION 12: CONSOLE OUTPUT
# ============================================================

def print_scenario_table(scenario_name, results):
    """
    Print a fixed-width, per-cycle table of one scenario's records to stdout.

    Missing record fields default to 0 (or '?' for the phase).  The Tx
    column shows 'DOS' on dose cycles, otherwise the treatment decay factor
    while a treatment is active.  The Flr column shows '*' on flare cycles.
    """
    rule = '=' * 140
    print(f"\n{rule}")
    print(f"  {scenario_name}")
    print(rule)
    print(
        f"{'Cyc':>3} {'Mo':>5} {'Phs':>5} | {'Acc%':>6} | {'Sprs%':>6} | "
        f"{'PL':>5} {'UM':>5} {'LM':>5} {'NMJ':>5} "
        f"{'Bul':>5} {'Rsp':>5} {'FM':>5} | "
        f"{'Dep':>5} {'Weak':>5} {'BulS':>5} {'Resp':>5} | "
        f"{'ALS':>5} {'NFL':>6} {'Cog':>5} | {'Tx':>4} {'Flr':>3}"
    )
    print('-' * 140)

    sparsity_layers = (
        'prefrontal_limbic', 'upper_motor', 'lower_motor',
        'neuromuscular_junction', 'bulbar', 'respiratory', 'fine_motor',
    )

    for rec in results:
        get = rec.get

        if get('dose_given', False):
            tx_str = 'DOS'
        elif get('treatment_active', False):
            tx_str = f"{get('treatment_decay', 0.0):.2f}"
        else:
            tx_str = ''
        flr_str = '*' if get('flare_active', False) else ''

        layer_cells = ' '.join(
            f"{get(f'sparsity_{ln}', 0.0) * 100.0:4.1f}%"
            for ln in sparsity_layers
        )

        print(
            f"{get('cycle', -1):3d} {get('month', 0.0):5.1f} "
            f"{get('phase', '?')[:5]:>5} | "
            f"{get('accuracy', 0.0):6.1f} | "
            f"{get('total_sparsity', 0.0) * 100.0:5.1f}% | "
            f"{layer_cells} | "
            f"{get('symptom_depression_score', 0.0):5.1f} "
            f"{get('symptom_weakness_score', 0.0):5.2f} "
            f"{get('symptom_bulbar_score', 0.0):5.2f} "
            f"{get('symptom_respiratory_risk_pct', 0.0):5.1f} | "
            f"{get('symptom_alsfrs_r_proxy', 0.0):5.1f} "
            f"{get('symptom_biomarker_nfl_proxy', 0.0):6.1f} "
            f"{get('symptom_cognitive_score', 0.0):5.1f} | "
            f"{tx_str:>4} {flr_str:>3}"
        )


def print_aggregate_comparison(all_results, cycle_index=-1):
    """Print a one-row-per-scenario comparison table for a single cycle.

    Args:
        all_results: Mapping of scenario name -> list of per-cycle result
            dicts. Every metric is read with ``dict.get`` so a missing key
            prints as 0.0. Sparsity values are stored as fractions and are
            displayed as percentages (multiplied by 100).
        cycle_index: Index into each scenario's result list selecting the
            cycle to compare. Defaults to -1 (the final cycle); negative
            values count from the end, as with ordinary list indexing.

    Scenarios are printed in sorted name order. A scenario whose result list
    has no entry at ``cycle_index`` (for example an empty list) is reported
    as a "no data" row instead of aborting the whole table with IndexError.
    """
    # Label the banner with the cycle actually being compared; the default
    # keeps the historical "(final cycle)" wording.
    if cycle_index == -1:
        cycle_label = "final cycle"
    else:
        cycle_label = f"cycle index {cycle_index}"

    print(f"\n{'=' * 155}")
    print(f"  CROSS-SCENARIO COMPARISON ({cycle_label})")
    print(f"{'=' * 155}")
    header = (
        f"{'Scenario':<50} | {'Acc%':>6} | {'Sprs%':>6} | "
        f"{'Dep':>5} | {'Weak':>5} | {'BulS':>5} | {'FMS':>5} | "
        f"{'ALSFRS':>6} | {'NFL':>6} | "
        f"{'Resp%':>5} | {'MotInt%':>7} | {'Cog':>5}"
    )
    print(header)
    print('-' * 155)

    for scenario_name in sorted(all_results):
        results = all_results[scenario_name]
        try:
            r = results[cycle_index]
        except IndexError:
            # One short/empty run should not hide every other scenario.
            print(f"{scenario_name:<50} | (no data for {cycle_label})")
            continue

        acc = r.get('accuracy', 0.0)
        sp = r.get('total_sparsity', 0.0) * 100.0
        dep = r.get('symptom_depression_score', 0.0)
        wk = r.get('symptom_weakness_score', 0.0)
        buls = r.get('symptom_bulbar_score', 0.0)
        fms = r.get('symptom_fine_motor_score', 0.0)
        als = r.get('symptom_alsfrs_r_proxy', 0.0)
        nfl = r.get('symptom_biomarker_nfl_proxy', 0.0)
        resp = r.get('symptom_respiratory_risk_pct', 0.0)
        mi = r.get('symptom_motor_integrity_pct', 0.0)
        cog = r.get('symptom_cognitive_score', 0.0)

        # Column widths mirror the header above ('%' suffix adds one char).
        print(
            f"{scenario_name:<50} | {acc:6.1f} | {sp:5.1f}% | "
            f"{dep:5.1f} | {wk:5.2f} | {buls:5.2f} | {fms:5.2f} | "
            f"{als:6.1f} | {nfl:6.1f} | "
            f"{resp:5.1f} | {mi:6.1f}% | {cog:5.1f}"
        )


# ============================================================
# SECTION 13: ENTRY POINT
# ============================================================

if __name__ == '__main__':
    # Banner summarizing the v2.0 amendments before the run starts.
    print("=" * 60)
    print("  ALS PROGRESSION MODEL PIPELINE v2.0")
    print("  Microglial Pruning Continuum Framework")
    print("  All symptoms derived from network architecture")
    print("")
    print("  Amendments:")
    print("    - Sub-layer granularity (bulbar/respiratory/fine_motor)")
    print("    - Aggregate toxicity as weight corruption")
    print("    - Gradient-guided selective regrowth")
    print("    - Sparsity-driven excitotoxicity")
    print("    - Stochastic flares & onset subtypes")
    print("    - Riluzole comparator arm")
    print("    - Weight instability NFL proxy")
    print("    - Patient heterogeneity via seed modifiers")
    print("=" * 60)

    # Single source of truth for the seed: the display names below are built
    # from it, so changing the seed cannot leave them pointing at keys that
    # were never produced.
    display_seed = 42
    all_results = run_all_scenarios(seeds=[display_seed])

    # (condition, treatment) pairs to print in full, in display order. Keys in
    # all_results follow '<condition>__<treatment>__seed<N>' (e.g.
    # 'ALS__no_treatment__seed42' in the captured output).
    display_pairs = [
        ('ALS', 'no_treatment'),
        ('ALS', 'ketamine_q5'),
        ('ALS', 'ketamine_maintenance'),
        ('ALS', 'riluzole_continuous'),
        ('ALS_Bulbar', 'no_treatment'),
        ('ALS_Bulbar', 'ketamine_q5'),
        ('ALS_Limb', 'no_treatment'),
        ('ALS_Limb', 'ketamine_q5'),
        ('ALS_MDD', 'no_treatment'),
        ('ALS_MDD', 'ketamine_q5'),
        ('MDD', 'no_treatment'),
        ('MDD', 'ketamine_q5'),
        ('Control', 'no_treatment'),
    ]
    display_scenarios = [
        f"{condition}__{treatment}__seed{display_seed}"
        for condition, treatment in display_pairs
    ]

    # Print the requested tables; remember any name with no results instead of
    # silently dropping it, so typos or stale names are visible.
    missing_scenarios = []
    for scenario in display_scenarios:
        if scenario in all_results:
            print_scenario_table(scenario, all_results[scenario])
        else:
            missing_scenarios.append(scenario)

    if missing_scenarios:
        print(
            f"\nWARNING: {len(missing_scenarios)} display scenario(s) not "
            f"found in results: {', '.join(missing_scenarios)}"
        )

    print_aggregate_comparison(all_results)

    print(f"\nTotal scenarios executed: {len(all_results)}")
    print("Pipeline complete.")

  ALS PROGRESSION MODEL PIPELINE v2.0
  Microglial Pruning Continuum Framework
  All symptoms derived from network architecture

  Amendments:
    - Sub-layer granularity (bulbar/respiratory/fine_motor)
    - Aggregate toxicity as weight corruption
    - Gradient-guided selective regrowth
    - Sparsity-driven excitotoxicity
    - Stochastic flares & onset subtypes
    - Riluzole comparator arm
    - Weight instability NFL proxy
    - Patient heterogeneity via seed modifiers

  ALS__no_treatment__seed42
Cyc    Mo   Phs |   Acc% |  Sprs% |    PL    UM    LM   NMJ   Bul   Rsp    FM |   Dep  Weak  BulS  Resp |   ALS    NFL   Cog |   Tx Flr
--------------------------------------------------------------------------------------------------------------------------------------------
 -1   0.0 basel |   95.7 |   0.0% |  0.0%  0.0%  0.0%  0.0%  0.0%  0.0%  0.0% |   0.0  0.00  0.00   0.0 |  46.0    0.0  30.0 |         
  0   0.0 precl |   95.2 |   9.7% |  3.7%  8.8% 12.9%  7.1%  5.6%  5.6%  7.9% 

# The End
