# The Original Model

In [None]:
"""
Developmental Pruning Simulation: Demonstrating threshold effects and plasticity recovery

This script models a simplified analog of adolescent synaptic pruning in neural networks:
1. Train an overparameterized network (childhood: dense connectivity)
2. Prune aggressively (adolescent pruning: efficiency vs. fragility tradeoff)
3. Observe performance collapse, especially under noise (psychiatric vulnerability)
4. Regrow connections + fine-tune (therapeutic plasticity restoration)
5. Observe recovery (treatment response)

Key psychiatric analogs (illustrative metaphors, not a validated disease model):
- Excessive pruning → reduced cognitive flexibility, noise tolerance (schizophrenia prodrome)
- Regrowth/plasticity → ketamine-like rapid antidepressant effects, cognitive remediation
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from collections import OrderedDict

# ============================================================================
# REPRODUCIBILITY: Essential for scientific validity of simulation
# ============================================================================
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Everything runs on the CPU: many CUDA kernels are nondeterministic, which
# would undermine the fixed seeds above.
DEVICE = torch.device('cpu')


# ============================================================================
# DATA GENERATION
# ============================================================================
def generate_blobs(n_samples: int = 10000, noise: float = 0.8, seed: int = None):
    """
    Sample a 4-class, 2-D dataset of isotropic Gaussian blobs.

    Class k is centred on one corner of the square with vertices (±3, ±3):
    class 0 → (-3, -3), 1 → (3, 3), 2 → (-3, 3), 3 → (3, -3).
    Labels are drawn uniformly at random, so the classes are balanced only in
    expectation, not exactly. A larger `noise` means more class overlap,
    i.e. a harder task.

    Args:
        n_samples: Number of points to draw.
        noise: Standard deviation of the Gaussian jitter around each centre.
        seed: Seed for a private RandomState (the global NumPy RNG is not
              touched). None seeds it from fresh OS entropy. Use different
              seeds for train and test splits.

    Returns:
        (features, labels): float32 tensor of shape [n_samples, 2] and
        int64 tensor of shape [n_samples].
    """
    # RandomState(None) behaves exactly like RandomState(): seeded from OS entropy.
    rng = np.random.RandomState(seed)

    corners = np.array([[-3, -3], [3, 3], [-3, 3], [3, -3]])

    # Draw order matters for reproducibility: labels first, then the jitter.
    y = rng.randint(0, 4, n_samples)
    jitter = noise * rng.randn(n_samples, 2)
    x = corners[y] + jitter

    features = torch.tensor(x, dtype=torch.float32)
    targets = torch.tensor(y, dtype=torch.long)
    return features, targets


# Each split gets its own seed so train and test are independent draws;
# reusing one seed would replay the same points (data leakage).
train_data, train_labels = generate_blobs(12000, noise=0.8, seed=100)
test_data, test_labels = generate_blobs(4000, noise=0.8, seed=200)

# Noise-free split: with noise=0.0 every point sits exactly on its class
# centre, so this probes the learned decision boundaries with no class overlap.
clean_test_data, clean_test_labels = generate_blobs(2000, noise=0.0, seed=300)

# Only the training loader shuffles; evaluation loaders iterate in fixed order.
train_loader = DataLoader(TensorDataset(train_data, train_labels),
                          batch_size=128, shuffle=True)
test_loader = DataLoader(TensorDataset(test_data, test_labels),
                         batch_size=1000)
clean_test_loader = DataLoader(TensorDataset(clean_test_data, clean_test_labels),
                               batch_size=1000)


# ============================================================================
# NETWORK ARCHITECTURE
# ============================================================================
class Net(nn.Module):
    """
    Intentionally oversized MLP classifier for 2-D inputs and 4 classes.

    Layer widths 2 → 512 → 512 → 256 → 4, with a ReLU after every linear
    layer except the last. That is 396,548 parameters (weights + biases),
    far more capacity than a 2-D, 4-class problem needs. The surplus stands
    in for early-childhood synaptic overgrowth; pruning later probes which
    connections are actually necessary.

    Sub-layers live in `self.net` under the names fc1..fc4 / relu1..relu3,
    so parameters are named e.g. 'net.fc1.weight'.
    """

    def __init__(self):
        super().__init__()
        widths = (2, 512, 512, 256, 4)
        n_linear = len(widths) - 1

        layers = OrderedDict()
        pairs = zip(widths[:-1], widths[1:])
        for idx, (fan_in, fan_out) in enumerate(pairs, start=1):
            layers[f'fc{idx}'] = nn.Linear(fan_in, fan_out)
            # No activation after the output layer: it emits raw logits.
            if idx < n_linear:
                layers[f'relu{idx}'] = nn.ReLU()

        self.net = nn.Sequential(layers)

    def forward(self, x):
        logits = self.net(x)
        return logits

    def count_parameters(self):
        """Return (total parameter count, count of parameters that are exactly non-zero)."""
        total = 0
        nonzero = 0
        for tensor in self.parameters():
            total += tensor.numel()
            nonzero += int(torch.count_nonzero(tensor))
        return total, nonzero


# ============================================================================
# PRUNING INFRASTRUCTURE
# ============================================================================
class PruningManager:
    """
    Manages binary weight masks for unstructured (per-weight) magnitude
    pruning and random regrowth experiments.

    Key design choices:
    - Per-layer pruning: Prevents small layers from being entirely eliminated
    - Mask persistence: pruned weights stay zero during training (via
      apply_masks), and repeated prune calls never re-enable an already-pruned
      position. Only regrow_random turns positions back on.
    - Tracks pruning/regrowth history for analysis

    Only parameters whose name contains 'weight' and that have >= 2 dims
    (weight matrices) are masked; biases are never pruned.

    Psychiatric analog: This is like tracking which synapses are eliminated
    during adolescent pruning - the pattern matters, not just the total count.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.masks = {}  # parameter name → float mask (1.0 = keep, 0.0 = pruned)
        self.history = []  # ('prune' | 'regrow', amount, per-layer stats) tuples

        # Initialize masks to all-ones (nothing pruned yet)
        for name, param in model.named_parameters():
            if 'weight' in name and param.dim() >= 2:  # Only prune weight matrices
                self.masks[name] = torch.ones_like(param, dtype=torch.float32)

    def _layer_stats(self, name: str) -> dict:
        """Summarize how many weights the current mask for `name` keeps."""
        mask = self.masks[name]
        kept = mask.sum().item()
        total = mask.numel()
        return {
            'kept': int(kept),
            'total': total,
            'actual_sparsity': 1 - kept / total,
        }

    def _apply_threshold(self, name: str, param: torch.Tensor, threshold: torch.Tensor):
        """Prune positions of `param` whose magnitude is below `threshold`."""
        keep = (param.data.abs() >= threshold).float()
        # Intersect with the existing mask. If earlier pruning already zeroed
        # at least the requested fraction, the quantile threshold is 0 and
        # `abs >= 0` would mark every position - including pruned ones - as
        # kept, silently resurrecting them during the next training run.
        self.masks[name] = keep * self.masks[name]
        param.data *= self.masks[name]

    def prune_by_magnitude(self, sparsity: float, per_layer: bool = True):
        """
        Prune weights by magnitude (smallest weights → zero).

        Args:
            sparsity: Fraction of weights to prune (0.95 = keep only top 5%)
            per_layer: If True, prune each layer independently to target sparsity.
                      If False, use a global threshold (can eliminate entire layers).

        Returns:
            Dict mapping parameter name → {'kept', 'total', 'actual_sparsity'}

        Positions pruned by earlier calls stay pruned, so actual sparsity can
        exceed `sparsity` when the model was already sparser than requested.

        NOTE: torch.quantile rejects very large inputs (on the order of 16M
        elements in current PyTorch), which bounds the layer size (per-layer)
        or total masked size (global) this supports.

        Biological note: Magnitude-based pruning is loosely analogous to Hebbian
        "use it or lose it", on the assumption that larger weights reflect
        more important pathways.
        """
        stats = {}

        if per_layer:
            # RECOMMENDED: a separate threshold per layer keeps every layer alive
            for name, param in self.model.named_parameters():
                if name in self.masks:
                    threshold = torch.quantile(param.data.abs().flatten(), sparsity)
                    self._apply_threshold(name, param, threshold)
                    stats[name] = self._layer_stats(name)
        else:
            # Global threshold: can cause layer collapse, included for comparison
            all_weights = torch.cat([
                self.model.get_parameter(name).data.abs().flatten()
                for name in self.masks
            ])
            threshold = torch.quantile(all_weights, sparsity)

            for name, param in self.model.named_parameters():
                if name in self.masks:
                    self._apply_threshold(name, param, threshold)
                    stats[name] = self._layer_stats(name)

        self.history.append(('prune', sparsity, stats))
        return stats

    def regrow_random(self, regrow_fraction: float, init_scale: float = 0.03):
        """
        Randomly regrow a fraction of pruned connections with small initial weights.

        Args:
            regrow_fraction: Fraction of currently-pruned weights to restore
                (0.5 = half); the count per layer is rounded down.
            init_scale: Std dev of the Gaussian used to initialize regrown weights

        Returns:
            Dict mapping parameter name → {'regrown', 'still_pruned'}

        Uses the global torch RNG (randperm for selection, randn for values).

        Psychiatric analog: This models neuroplasticity interventions like:
        - Ketamine: Rapid synaptogenesis in prefrontal cortex
        - Environmental enrichment: Activity-dependent sprouting
        - Cognitive remediation: Strengthening underused pathways

        The small initial weights mean regrown connections must "prove themselves"
        through learning - they're not immediately functional.
        """
        stats = {}

        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue

            mask = self.masks[name]
            flat_pruned_indices = torch.where(mask.reshape(-1) == 0)[0]
            num_pruned = flat_pruned_indices.numel()

            if num_pruned == 0:
                stats[name] = {'regrown': 0, 'still_pruned': 0}
                continue

            num_regrow = int(regrow_fraction * num_pruned)
            if num_regrow == 0:
                stats[name] = {'regrown': 0, 'still_pruned': int(num_pruned)}
                continue

            # Randomly choose which pruned positions come back
            perm = torch.randperm(num_pruned)[:num_regrow]
            regrow_indices = flat_pruned_indices[perm]

            # Edit flat copies and write them back explicitly. `flatten()` only
            # returns a view for contiguous tensors, so writing through it
            # could silently land in a temporary copy.
            new_mask = mask.reshape(-1).clone()
            new_mask[regrow_indices] = 1.0
            self.masks[name] = new_mask.view_as(mask)

            new_values = param.data.reshape(-1).clone()
            new_values[regrow_indices] = torch.randn(
                num_regrow, dtype=param.dtype, device=param.device
            ) * init_scale
            param.data.copy_(new_values.view_as(param.data))

            stats[name] = {
                'regrown': num_regrow,
                'still_pruned': int(num_pruned - num_regrow)
            }

        self.history.append(('regrow', regrow_fraction, stats))
        return stats

    def apply_masks(self):
        """
        Zero out weights where mask is 0.

        CRITICAL: Must be called after each optimizer step during training
        to maintain sparsity pattern. Without this, pruned weights drift
        back to non-zero values through gradient updates.
        """
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in self.masks:
                    param.data *= self.masks[name]

    def get_sparsity(self):
        """Fraction of masked weight-matrix entries currently pruned (biases excluded)."""
        total_params = 0
        zero_params = 0
        for mask in self.masks.values():
            total_params += mask.numel()
            zero_params += (mask == 0).sum().item()
        return zero_params / total_params if total_params > 0 else 0


# ============================================================================
# TRAINING AND EVALUATION
# ============================================================================
def train(
    model: nn.Module,
    epochs: int = 15,
    lr: float = 0.001,
    pruning_manager: "PruningManager" = None,
    verbose: bool = False,
    loader: DataLoader = None,
):
    """
    Train model with Adam + cross-entropy, optionally enforcing sparsity masks.

    Args:
        model: Network to train (must have parameters; batches are moved to
               the device of its first parameter)
        epochs: Training epochs
        lr: Learning rate (lower for fine-tuning after regrowth)
        pruning_manager: If provided, its masks are applied once before
                         training and again after every optimizer step
        verbose: Print mean loss each epoch
        loader: Batches of (inputs, integer class labels). Defaults to the
                module-level `train_loader`.

    Returns:
        List with the mean per-batch loss of each epoch (NaN for an epoch
        with no batches). Callers that ignore the return value are unaffected.

    The apply_masks() calls are crucial - without them, gradient updates
    would resurrect pruned weights, defeating the pruning.
    """
    if loader is None:
        loader = train_loader

    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    # Follow the model's own device rather than a global constant.
    device = next(model.parameters()).device

    # Guarantee the very first forward pass already sees the sparse network.
    if pruning_manager is not None:
        pruning_manager.apply_masks()

    history = []
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        n_batches = 0

        for x, y in loader:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()

            # CRITICAL: Re-apply masks after each optimizer step
            # This maintains the sparsity structure
            if pruning_manager is not None:
                pruning_manager.apply_masks()

            epoch_loss += loss.item()
            n_batches += 1

        mean_loss = epoch_loss / n_batches if n_batches else float('nan')
        history.append(mean_loss)

        if verbose:
            print(f"  Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")

    return history


def evaluate(model: nn.Module, loader: DataLoader, noise_std: float = 0.0):
    """
    Evaluate model accuracy, optionally with added input noise.

    Args:
        model: Network to evaluate
        loader: DataLoader with test data
        noise_std: Standard deviation of Gaussian noise to add to inputs
                  (tests robustness to perturbations)

    Returns:
        Accuracy as percentage

    Psychiatric note: Noise tolerance is a key measure of network robustness.
    Over-pruned networks lose their "margin of safety" and misclassify
    under perturbation - analogous to cognitive instability under stress.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)

            # Add noise to test robustness
            if noise_std > 0:
                x = x + torch.randn_like(x) * noise_std

            predictions = model(x).argmax(dim=1)
            correct += (predictions == y).sum().item()
            total += y.size(0)

    return 100.0 * correct / total


def comprehensive_eval(model: nn.Module, label: str):
    """
    Score `model` under four input conditions and print a labelled report.

    Conditions, evaluated in this order (noisy runs draw input noise from the
    global torch RNG, so the order matters for reproducibility):
      clean       - clean_test_loader, no added noise
      standard    - test_loader, no added noise
      noisy       - test_loader + Gaussian input noise, std 1.0
      very_noisy  - test_loader + Gaussian input noise, std 2.0

    Sparsity comes from model.count_parameters(), i.e. it is measured over
    all parameters, biases included.

    Returns:
        dict with accuracies (%) under 'clean', 'standard', 'noisy',
        'very_noisy' and the percentage of zero parameters under 'sparsity'.
    """
    conditions = (
        ('clean', clean_test_loader, 0.0),
        ('standard', test_loader, 0.0),
        ('noisy', test_loader, 1.0),
        ('very_noisy', test_loader, 2.0),
    )
    scores = {key: evaluate(model, data, noise_std=std) for key, data, std in conditions}

    total, nonzero = model.count_parameters()
    scores['sparsity'] = 100 * (1 - nonzero / total)

    banner = '=' * 60
    print(f"\n{banner}")
    print(f"{label}")
    print(f"{banner}")
    print(f"  Parameters: {nonzero:,} / {total:,} non-zero ({scores['sparsity']:.1f}% sparse)")
    print(f"  Clean test accuracy:     {scores['clean']:.1f}%  (no noise in data)")
    print(f"  Standard test accuracy:  {scores['standard']:.1f}%  (noise=0.8 in data)")
    print(f"  Noisy test accuracy:     {scores['noisy']:.1f}%  (+1.0 input noise)")
    print(f"  Very noisy accuracy:     {scores['very_noisy']:.1f}%  (+2.0 input noise)")

    return scores


# ============================================================================
# MAIN EXPERIMENT
# ============================================================================
def run_experiment():
    """
    Full pipeline: train dense → prune 95% → regrow half of the pruned weights → fine-tune.

    Stages:
    1. BASELINE: Train the overparameterized Net for 20 epochs (childhood connectivity)
    2. PRUNING: Per-layer magnitude pruning of 95% of each weight matrix
       (adolescent pruning), then re-evaluate
    3. PLASTICITY: Randomly regrow 50% of pruned weights (init std 0.03), then
       fine-tune for 15 epochs at lr=0.0005 with masks enforced
    Each stage is scored with comprehensive_eval.

    Hypotheses being probed (analogies, not established results):
    - Excessive pruning → vulnerability to noise/stress (prodromal symptoms)
    - Plasticity restoration → therapeutic recovery
    The closing KEY OBSERVATIONS are computed from the measured accuracies,
    so they report whether this particular run supports each hypothesis.

    Returns:
        dict with keys 'full', 'pruned', 'recovered', each holding the dict
        returned by comprehensive_eval for that stage.
    """

    print("\n" + "="*70)
    print("DEVELOPMENTAL PRUNING SIMULATION")
    print("Modeling synaptic pruning, vulnerability, and plasticity recovery")
    print("="*70)

    results = {}

    # ========================================================================
    # STAGE 1: Train full network (childhood - dense connectivity)
    # ========================================================================
    print("\n[STAGE 1] Training full network (childhood connectivity)...")

    model = Net().to(DEVICE)
    train(model, epochs=20, lr=0.001)
    results['full'] = comprehensive_eval(model, "FULL NETWORK (Pre-pruning baseline)")

    # ========================================================================
    # STAGE 2: Aggressive pruning (adolescent synaptic elimination)
    # ========================================================================
    print("\n[STAGE 2] Applying aggressive pruning (95% sparsity)...")
    print("          (Modeling excessive adolescent synaptic elimination)")

    pruning_mgr = PruningManager(model)
    prune_stats = pruning_mgr.prune_by_magnitude(sparsity=0.95, per_layer=True)

    print("\n  Per-layer pruning statistics:")
    for name, stats in prune_stats.items():
        print(f"    {name}: kept {stats['kept']}/{stats['total']} "
              f"({100*stats['actual_sparsity']:.1f}% pruned)")

    results['pruned'] = comprehensive_eval(model, "AFTER AGGRESSIVE PRUNING")

    # ========================================================================
    # STAGE 3: Plasticity restoration (therapeutic intervention)
    # ========================================================================
    print("\n[STAGE 3] Restoring plasticity (regrowing 50% of pruned connections)...")
    print("          (Modeling therapeutic neuroplasticity intervention)")

    regrow_stats = pruning_mgr.regrow_random(regrow_fraction=0.5, init_scale=0.03)

    print("\n  Per-layer regrowth statistics:")
    for name, stats in regrow_stats.items():
        print(f"    {name}: regrew {stats['regrown']}, "
              f"still pruned {stats['still_pruned']}")

    # Fine-tune with lower learning rate (careful retraining of new connections)
    print("\n  Fine-tuning with regrown connections...")
    train(model, epochs=15, lr=0.0005, pruning_manager=pruning_mgr)

    results['recovered'] = comprehensive_eval(model, "AFTER PLASTICITY RESTORATION")

    # ========================================================================
    # SUMMARY
    # ========================================================================
    print("\n" + "="*70)
    print("SUMMARY: Threshold Effects and Recovery")
    print("="*70)

    print(f"""
    Metric              Full    Pruned   Recovered
    ─────────────────────────────────────────────────
    Clean accuracy      {results['full']['clean']:5.1f}%   {results['pruned']['clean']:5.1f}%   {results['recovered']['clean']:5.1f}%
    Standard accuracy   {results['full']['standard']:5.1f}%   {results['pruned']['standard']:5.1f}%   {results['recovered']['standard']:5.1f}%
    Noisy accuracy      {results['full']['noisy']:5.1f}%   {results['pruned']['noisy']:5.1f}%   {results['recovered']['noisy']:5.1f}%
    Very noisy acc      {results['full']['very_noisy']:5.1f}%   {results['pruned']['very_noisy']:5.1f}%   {results['recovered']['very_noisy']:5.1f}%
    Sparsity            {results['full']['sparsity']:5.1f}%   {results['pruned']['sparsity']:5.1f}%   {results['recovered']['sparsity']:5.1f}%
    """)

    _print_key_observations(results)

    return results


def _print_key_observations(results):
    """Print conclusions derived from the measured `results` instead of fixed claims."""
    full, pruned, recovered = results['full'], results['pruned'], results['recovered']

    print("KEY OBSERVATIONS:")

    # 1. Does pruning hurt more under input noise than on clean data?
    clean_drop = full['clean'] - pruned['clean']
    noisy_drop = full['noisy'] - pruned['noisy']
    if noisy_drop > clean_drop:
        print(f"  1. Pruning cost more accuracy under noise ({noisy_drop:.1f} pts) "
              f"than on clean data ({clean_drop:.1f} pts)")
        print("     → Over-pruned networks lose robustness (psychiatric vulnerability)")
    else:
        print(f"  1. Pruning cost {noisy_drop:.1f} pts under noise vs {clean_drop:.1f} pts "
              f"on clean data")
        print("     → The noise-vulnerability prediction is NOT supported by this run")

    # 2. How much of the standard-test loss did regrowth + fine-tuning win back?
    lost = full['standard'] - pruned['standard']
    regained = recovered['standard'] - pruned['standard']
    if lost > 0:
        print(f"  2. Plasticity regained {regained:.1f} of {lost:.1f} lost pts "
              f"({100 * regained / lost:.0f}%) on the standard test set")
    else:
        print("  2. Pruning did not reduce standard accuracy; there was nothing to recover")
    print("     → Compare 'Recovered' with 'Full' to judge how far structural loss was reversed")

    # 3. Only per-layer pruning ran here, so no pattern comparison can be claimed.
    print("  3. This run used per-layer pruning only")
    print("     → To test whether the pruning pattern matters, rerun with "
          "prune_by_magnitude(per_layer=False)")


# ============================================================================
# SPARSITY SWEEP: Finding the critical threshold
# ============================================================================
def sparsity_sweep(sparsity_levels=None, epochs: int = 20, lr: float = 0.001):
    """
    Measure accuracy across pruning levels to locate a critical threshold.

    This models the question: "How much pruning is too much?"

    One dense Net is trained once. For each level, the trained weights are
    restored and then magnitude-pruned per layer, so every row differs only
    in how much was pruned. Retraining per level would mix in different
    initializations and data orders.

    Args:
        sparsity_levels: Fractions of each weight matrix to prune. Defaults to
            [0.0, 0.5, 0.7, 0.8, 0.9, 0.95, 0.97, 0.99].
        epochs: Training epochs for the dense base model.
        lr: Learning rate for the dense base model.

    Returns:
        List of dicts {'sparsity', 'clean', 'standard', 'noisy'} (accuracies
        in %), one per level, in evaluation order.

    Hypothesis (not guaranteed; it depends on task and architecture): little
    loss at low sparsity, gradual degradation, then a sharp "cliff" where
    accuracy collapses toward chance (25% for 4 classes).
    """
    if sparsity_levels is None:
        sparsity_levels = [0.0, 0.5, 0.7, 0.8, 0.9, 0.95, 0.97, 0.99]

    print("\n" + "="*70)
    print("SPARSITY SWEEP: Finding the critical pruning threshold")
    print("="*70)

    model = Net().to(DEVICE)
    train(model, epochs=epochs, lr=lr)
    # Snapshot of the trained dense weights, restored before each prune.
    trained_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    print(f"\n{'Sparsity':>10} {'Clean':>10} {'Standard':>10} {'Noisy':>10}")
    print("-" * 45)

    rows = []
    for sparsity in sparsity_levels:
        model.load_state_dict(trained_state)

        if sparsity > 0:
            PruningManager(model).prune_by_magnitude(sparsity=sparsity, per_layer=True)

        clean = evaluate(model, clean_test_loader)
        standard = evaluate(model, test_loader)
        noisy = evaluate(model, test_loader, noise_std=1.0)

        print(f"{sparsity*100:>9.0f}% {clean:>9.1f}% {standard:>9.1f}% {noisy:>9.1f}%")
        rows.append({'sparsity': sparsity, 'clean': clean,
                     'standard': standard, 'noisy': noisy})

    print("\nNote: Look for the 'cliff' - where performance drops sharply.")
    print("This threshold varies by task complexity and network architecture.")

    return rows


if __name__ == "__main__":
    # Entry point: the prune → regrow experiment. Its results dict stays at
    # module level so it can be inspected afterwards.
    results = run_experiment()

    # The threshold sweep is opt-in because it trains another network;
    # uncomment the next line to explore pruning thresholds.
    # sparsity_sweep()


DEVELOPMENTAL PRUNING SIMULATION
Modeling synaptic pruning, vulnerability, and plasticity recovery

[STAGE 1] Training full network (childhood connectivity)...

FULL NETWORK (Pre-pruning baseline)
  Parameters: 396,548 / 396,548 non-zero (0.0% sparse)
  Clean test accuracy:     100.0%  (no noise in data)
  Standard test accuracy:  100.0%  (noise=0.8 in data)
  Noisy test accuracy:     97.8%  (+1.0 input noise)
  Very noisy accuracy:     83.6%  (+2.0 input noise)

[STAGE 2] Applying aggressive pruning (95% sparsity)...
          (Modeling excessive adolescent synaptic elimination)

  Per-layer pruning statistics:
    net.fc1.weight: kept 52/1024 (94.9% pruned)
    net.fc2.weight: kept 13108/262144 (95.0% pruned)
    net.fc3.weight: kept 6554/131072 (95.0% pruned)
    net.fc4.weight: kept 52/1024 (94.9% pruned)

AFTER AGGRESSIVE PRUNING
  Parameters: 21,050 / 396,548 non-zero (94.7% sparse)
  Clean test accuracy:     50.8%  (no noise in data)
  Standard test accuracy:  43.1%  (noise=0.8 in data)
  [... remaining output was truncated in this export ...]

# The End