# ASD Models

In [None]:
"""
================================================================================
EXTENDED DEVELOPMENTAL PRUNING SIMULATION FOR AUTISM SPECTRUM DISORDER
================================================================================

VERSION 9.2: DENSITY SWEEP WITH DIFFERENTIAL PRUNING

EXPERIMENT 1: DENSITY SWEEP
---------------------------
EARLY STAGE:
  - Sweep densities from 10% to 100% (10% intervals)
  - Compare learning dynamics across density levels
  - This models different degrees of early developmental pruning

LATE STAGE (for each early density):
  - Normal: 20% additional pruning from early state
  - ASD: 50% additional pruning from early state
  - Compare Normal vs ASD outcomes at each density level

This design allows us to see:
1. How initial density affects learning and resilience
2. How the same starting point diverges with different late pruning
3. The interaction between early density and late pruning severity

================================================================================
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from typing import Dict, Tuple, List
import warnings

warnings.filterwarnings('ignore', category=UserWarning)


# ============================================================================
# SECTION 1: CONFIGURATION
# ============================================================================

# Seed the global torch and NumPy random number generators so that weight
# initialisation, loader shuffling and noise injection start from a fixed state.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Device that the training and evaluation code moves tensors to via `.to(DEVICE)`.
DEVICE = torch.device('cpu')

# Central experiment configuration. The functions and classes in this file read
# their defaults from this dict by key.
CONFIG = {
    # =======================================================================
    # DATA GENERATION PARAMETERS
    # =======================================================================
    # Sample counts; n_train / n_test are used for BOTH the visual and the voice task.
    'n_train': 12000,
    'n_test': 4000,
    # Size of the extra visual set that is generated with noise=0.0.
    'n_clean_test': 2000,
    # Batch size of the two training loaders (test loaders use a fixed 1000).
    'batch_size': 128,

    # Visual task: Clear, well-separated clusters
    # Noise is the multiplier applied to standard-normal offsets around each centre.
    'visual_noise': 0.8,
    'visual_centers': [[-3, -3], [3, 3], [-3, 3], [3, -3]],

    # Voice task: SAME centers as visual, PERMUTED labels (true conflict)
    'voice_noise': 0.8,
    # A sample drawn around centre index i receives the voice label permutation[i].
    'voice_label_permutation': [2, 3, 0, 1],

    # =======================================================================
    # NETWORK ARCHITECTURE
    # =======================================================================
    # Output widths of fc1, fc2 and fc3 respectively.
    'hidden_dims': [256, 256, 128],
    # The network input is the data point concatenated with the task-ID vector,
    # i.e. input_dim + task_id_dim features.
    'input_dim': 2,
    'task_id_dim': 2,
    'output_dim': 4,

    # =======================================================================
    # TRAINING HYPERPARAMETERS
    # =======================================================================
    'baseline_epochs': 30,
    # Default learning rate of train_taskgated_with_curves when lr is None.
    'baseline_lr': 0.001,
    'finetune_lr': 0.0005,
    # Default epoch count of train_taskgated_with_curves when epochs is None.
    'multitask_epochs': 60,

    # =======================================================================
    # VERSION 9.2: DENSITY SWEEP PARAMETERS
    # =======================================================================
    # Fractions of weights KEPT in the early stage (1.0 = unpruned).
    'early_density_levels': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    'early_training_epochs': 15,

    # Late stage pruning (fraction to REMOVE from early state)
    'late_pruning_normal': 0.20,  # Remove 20% of remaining weights
    'late_pruning_asd': 0.50,     # Remove 50% of remaining weights
    'late_finetune_epochs': 10,

    # =======================================================================
    # STRESS LEVELS
    # =======================================================================
    # Each value is the scale of the Gaussian noise that the network adds to its
    # hidden activations after every ReLU; entries equal to 0 are skipped by
    # comprehensive_evaluation.
    'stress_levels': {
        'none': 0.0,
        'mild': 0.5,
        'moderate': 1.0,
        'high': 2.0,
        'severe': 3.0
    },

    # =======================================================================
    # TASK CONTEXT AMBIGUITY LEVELS
    # =======================================================================
    # Each value is the visual weight w; the task-ID fed to the network is
    # [w, 1 - w], so 1.0 is an unambiguous visual cue and 0.5 is fully ambiguous.
    'ambiguity_levels': {
        'pure_visual': 1.0,
        'slight': 0.85,
        'moderate': 0.7,
        'high': 0.6,
        'severe': 0.55,
        'ambiguous': 0.5
    },
}


# ============================================================================
# SECTION 2: DATA GENERATION
# ============================================================================

def generate_visual_blobs(
    n_samples: int = 10000,
    noise: float = None,
    seed: int = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample labelled Gaussian blobs around ``CONFIG['visual_centers']``.

    Args:
        n_samples: Number of points to draw.
        noise: Multiplier applied to the standard-normal offset around each
            centre. ``None`` falls back to ``CONFIG['visual_noise']``; ``0.0``
            places every point exactly on its centre.
        seed: Seed for a private ``RandomState``. ``None`` requests a fresh,
            unpredictable state; every integer, including ``0``, is honoured.

    Returns:
        ``(data, labels)`` where ``data`` is a float32 tensor of shape
        ``(n_samples, n_features)`` and ``labels`` is an int64 tensor holding
        the index of the centre each point was drawn around.
    """
    if noise is None:
        noise = CONFIG['visual_noise']

    # RandomState(None) already means "seed from fresh entropy", so passing the
    # argument straight through keeps the unseeded behaviour while making
    # seed=0 reproducible (a truthiness test used to discard it).
    rng = np.random.RandomState(seed)
    centers = np.array(CONFIG['visual_centers'])

    # Derive class count and dimensionality from the configured centres instead
    # of hard-coding 4 and 2; the draws are identical for the current config.
    labels = rng.randint(0, len(centers), n_samples)
    data = centers[labels] + rng.randn(n_samples, centers.shape[1]) * noise

    return (
        torch.tensor(data, dtype=torch.float32),
        torch.tensor(labels, dtype=torch.long)
    )


def generate_voice_blobs_conflicting(
    n_samples: int = 10000,
    noise: float = None,
    seed: int = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample blobs around the visual centres but with permuted ("voice") labels.

    Points are drawn around ``CONFIG['visual_centers']`` exactly like the visual
    task; a point drawn around centre ``i`` is then labelled
    ``CONFIG['voice_label_permutation'][i]``, so the same location carries a
    different label than in the visual task.

    Args:
        n_samples: Number of points to draw.
        noise: Multiplier applied to the standard-normal offset around each
            centre. ``None`` falls back to ``CONFIG['voice_noise']``.
        seed: Seed for a private ``RandomState``. ``None`` requests a fresh,
            unpredictable state; every integer, including ``0``, is honoured.

    Returns:
        ``(data, labels)`` where ``data`` is a float32 tensor of shape
        ``(n_samples, n_features)`` and ``labels`` is an int64 tensor of the
        permuted labels.
    """
    if noise is None:
        noise = CONFIG['voice_noise']

    # Pass the seed through unchanged: None still yields an unseeded state,
    # while seed=0 is no longer mistaken for "no seed".
    rng = np.random.RandomState(seed)
    centers = np.array(CONFIG['visual_centers'])

    original_labels = rng.randint(0, len(centers), n_samples)
    data = centers[original_labels] + rng.randn(n_samples, centers.shape[1]) * noise

    # Vectorised lookup: element k becomes permutation[original_labels[k]].
    permutation = np.asarray(CONFIG['voice_label_permutation'])
    conflicting_labels = permutation[original_labels]

    return (
        torch.tensor(data, dtype=torch.float32),
        torch.tensor(conflicting_labels, dtype=torch.long)
    )


def create_multitask_data_loaders() -> Dict[str, DataLoader]:
    """Generate the visual and voice datasets and wrap them in data loaders.

    Progress is reported on stdout. Datasets are generated with fixed seeds
    (visual: 100/200/300, voice: 400/500).

    Returns:
        A dict with the loaders ``visual_train``, ``visual_test``,
        ``visual_clean``, ``voice_train`` and ``voice_test``, plus the raw test
        tensors under ``visual_test_data``, ``visual_test_labels``,
        ``voice_test_data`` and ``voice_test_labels``. Training loaders shuffle
        and use ``CONFIG['batch_size']``; the other loaders keep sample order
        and use batches of 1000.
    """
    rule = "=" * 70

    def shuffled(features, targets):
        # Training loader: configured batch size, reshuffled on every pass.
        return DataLoader(TensorDataset(features, targets),
                          batch_size=CONFIG['batch_size'], shuffle=True)

    def ordered(features, targets):
        # Evaluation loader: fixed batches of 1000 in generation order.
        return DataLoader(TensorDataset(features, targets),
                          batch_size=1000, shuffle=False)

    print("\n" + rule)
    print(" CREATING TASK-GATED MULTI-TASK DATA LOADERS")
    print(rule)

    # ---- visual task -------------------------------------------------------
    print("\n Generating Visual Task Data...")
    print(f"   - Training samples: {CONFIG['n_train']:,}")
    print(f"   - Test samples: {CONFIG['n_test']:,}")
    print(f"   - Clean test samples: {CONFIG['n_clean_test']:,}")
    print(f"   - Noise level: {CONFIG['visual_noise']}")
    print(f"   - Cluster centers: {CONFIG['visual_centers']}")

    visual_train = generate_visual_blobs(CONFIG['n_train'], seed=100)
    visual_test = generate_visual_blobs(CONFIG['n_test'], seed=200)
    # Noise-free variant: every point sits exactly on its cluster centre.
    visual_clean = generate_visual_blobs(CONFIG['n_clean_test'], noise=0.0, seed=300)

    # ---- voice task --------------------------------------------------------
    print("\n Generating Voice Task Data (CONFLICTING labels)...")
    print(f"   - Training samples: {CONFIG['n_train']:,}")
    print(f"   - Test samples: {CONFIG['n_test']:,}")
    print(f"   - Noise level: {CONFIG['voice_noise']}")
    print(f"   - Label permutation: {CONFIG['voice_label_permutation']}")
    print("   - NOTE: Same spatial distribution as visual, but labels are permuted")

    voice_train = generate_voice_blobs_conflicting(CONFIG['n_train'], seed=400)
    voice_test = generate_voice_blobs_conflicting(CONFIG['n_test'], seed=500)

    # ---- assemble ----------------------------------------------------------
    loaders = {
        'visual_train': shuffled(*visual_train),
        'visual_test': ordered(*visual_test),
        'visual_clean': ordered(*visual_clean),
        'voice_train': shuffled(*voice_train),
        'voice_test': ordered(*voice_test),
        # Raw test tensors, used by evaluations that build their own batches.
        'visual_test_data': visual_test[0],
        'visual_test_labels': visual_test[1],
        'voice_test_data': voice_test[0],
        'voice_test_labels': voice_test[1]
    }

    print("\n Data Loaders Created Successfully:")
    print(f"   - Visual train loader: {len(loaders['visual_train'])} batches")
    print(f"   - Visual test loader: {len(loaders['visual_test'])} batches")
    print(f"   - Voice train loader: {len(loaders['voice_train'])} batches")
    print(f"   - Voice test loader: {len(loaders['voice_test'])} batches")
    print(f"   - Batch size: {CONFIG['batch_size']}")

    return loaders


# ============================================================================
# SECTION 3: NETWORK
# ============================================================================

class TaskGatedNetwork(nn.Module):
    """MLP with three hidden layers that receives a task-ID vector with its input.

    The data point and the task-ID vector are concatenated and passed through
    ``fc1 -> fc2 -> fc3`` (each followed by ReLU) and a linear ``head``.
    When ``stress_level`` is positive, Gaussian noise scaled by that level is
    added to the activations after every hidden ReLU; this happens in both
    train and eval mode because it depends only on ``stress_level``.

    Attributes:
        stress_level: Scale of the injected activation noise (0 disables it).
        weight_layers: Names of all linear layers, head included.
        backbone_layers: Names of the hidden linear layers (head excluded).
    """

    def __init__(self, hidden_dims: List[int] = None):
        """Build the layers.

        Args:
            hidden_dims: Output widths of fc1, fc2 and fc3. ``None`` uses
                ``CONFIG['hidden_dims']``.
        """
        super().__init__()

        if hidden_dims is None:
            hidden_dims = CONFIG['hidden_dims']

        input_dim = CONFIG['input_dim'] + CONFIG['task_id_dim']

        # Creation order (fc1, fc2, fc3, head) fixes both the order in which
        # the global RNG initialises the weights and the parameter ordering.
        self.fc1 = nn.Linear(input_dim, hidden_dims[0])
        self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
        self.fc3 = nn.Linear(hidden_dims[1], hidden_dims[2])
        self.head = nn.Linear(hidden_dims[2], CONFIG['output_dim'])

        self.relu = nn.ReLU()
        self.stress_level = 0.0

        self.weight_layers = ['fc1', 'fc2', 'fc3', 'head']
        self.backbone_layers = ['fc1', 'fc2', 'fc3']

    def set_stress(self, level: float):
        """Set the scale of the activation noise used by ``forward``."""
        self.stress_level = level

    def _hidden_block(self, layer: nn.Module, h: torch.Tensor) -> torch.Tensor:
        """Apply one hidden layer, ReLU and (if enabled) stress noise."""
        h = self.relu(layer(h))
        if self.stress_level > 0:
            h = h + torch.randn_like(h) * self.stress_level
        return h

    def forward(self, x: torch.Tensor, task_id: torch.Tensor) -> torch.Tensor:
        """Return the head outputs for ``x`` under the given task-ID rows.

        ``x`` and ``task_id`` are concatenated along dim 1, so they must have
        the same number of rows.
        """
        h = torch.cat([x, task_id], dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            h = self._hidden_block(layer, h)
        return self.head(h)

    def count_parameters(self) -> Tuple[int, int]:
        """Return ``(total, nonzero)`` element counts over all parameters.

        Biases are included in both counts.
        """
        total = sum(p.numel() for p in self.parameters())
        nonzero = sum((p != 0).sum().item() for p in self.parameters())
        return total, nonzero

    def print_architecture(self):
        """Print a summary of the layer sizes and parameter counts."""
        # Report the widths this instance was actually built with; reading
        # CONFIG here would be wrong for a network created with custom
        # hidden_dims.
        hidden = [self.fc1.out_features, self.fc2.out_features, self.fc3.out_features]

        print("\n" + "-"*50)
        print(" NETWORK ARCHITECTURE")
        print("-"*50)
        print(f"   Input dimension: {CONFIG['input_dim']} + {CONFIG['task_id_dim']} (task ID) = {CONFIG['input_dim'] + CONFIG['task_id_dim']}")
        print(f"   Hidden layers: {hidden}")
        print(f"   Output dimension: {CONFIG['output_dim']}")
        print("\n   Layer Details:")
        for name, param in self.named_parameters():
            print(f"     {name}: {list(param.shape)}")
        total, nonzero = self.count_parameters()
        print(f"\n   Total parameters: {total:,}")
        print(f"   Non-zero parameters: {nonzero:,}")
        print("-"*50)


# ============================================================================
# SECTION 4: PRUNING MANAGER
# ============================================================================

class TaskGatedPruningManager:
    """Maintains binary pruning masks for the weight matrices of a model.

    A mask (1 = active, 0 = pruned) is kept for every parameter whose name
    contains ``'weight'`` and that has at least two dimensions. With
    ``prune_head=False`` only layers listed in ``model.backbone_layers`` are
    managed. Pruning methods zero the affected weights in place; call
    ``apply_masks`` after optimiser steps to keep pruned weights at zero.

    Attributes:
        model: The managed model.
        prune_head: Whether layers outside ``model.backbone_layers`` are managed.
        masks: Parameter name -> float mask tensor of the parameter's shape.
        history: One ``(kind, amount, stats)`` tuple per pruning call.
        gradient_buffer: Parameter name -> zero tensor of the parameter's
            shape; it is only allocated here, not otherwise used by this class.
    """

    def __init__(self, model: "TaskGatedNetwork", prune_head: bool = True):
        self.model = model
        self.prune_head = prune_head
        self.masks = {}
        self.history = []
        self.gradient_buffer = {}

        for name, param in model.named_parameters():
            if 'weight' in name and param.dim() >= 2:
                layer_name = name.replace('.weight', '')
                if prune_head or layer_name in model.backbone_layers:
                    self.masks[name] = torch.ones_like(param, dtype=torch.float32)
                    self.gradient_buffer[name] = torch.zeros_like(param)

    @staticmethod
    def _mask_summary(mask: torch.Tensor) -> Dict:
        """Return kept/total counts and the resulting sparsity of one mask."""
        kept = mask.sum().item()
        total = mask.numel()
        return {'kept': int(kept), 'total': total, 'actual_sparsity': 1 - kept/total}

    def prune_by_magnitude(self, sparsity: float, per_layer: bool = True) -> Dict[str, Dict]:
        """Rebuild every mask so that small-magnitude weights are pruned.

        The threshold is the ``sparsity`` quantile of the absolute weights,
        taken per layer or, with ``per_layer=False``, over all managed layers
        together. Weights whose magnitude is at or above the threshold stay.
        Masks are recomputed from the current weight values, not intersected
        with the previous masks.

        Args:
            sparsity: Quantile in ``[0, 1]`` of weights to prune.
            per_layer: Use one threshold per layer (True) or one global
                threshold (False).

        Returns:
            Parameter name -> ``{'kept', 'total', 'actual_sparsity'}``.
        """
        shared_threshold = None
        if not per_layer:
            all_weights = torch.cat([
                self.model.get_parameter(name).data.abs().flatten()
                for name in self.masks
            ])
            shared_threshold = torch.quantile(all_weights, sparsity)

        stats = {}
        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue

            magnitudes = param.data.abs()
            if shared_threshold is None:
                threshold = torch.quantile(magnitudes.flatten(), sparsity)
            else:
                threshold = shared_threshold

            self.masks[name] = (magnitudes >= threshold).float()
            param.data *= self.masks[name]
            stats[name] = self._mask_summary(self.masks[name])

        self.history.append(('prune', sparsity, stats))
        return stats

    def prune_to_density(self, target_density: float, per_layer: bool = True) -> Dict[str, Dict]:
        """Prune by magnitude so that about ``target_density`` of weights remain."""
        target_sparsity = 1.0 - target_density
        return self.prune_by_magnitude(target_sparsity, per_layer)

    def prune_fraction_of_remaining(self, fraction_to_remove: float, per_layer: bool = True) -> Dict[str, Dict]:
        """Remove a fraction of currently active weights.

        For each managed layer the threshold is the ``fraction_to_remove``
        quantile of the absolute values of its active weights; active weights
        below it are pruned. Layers with no active weights, or where the
        fraction rounds down to zero weights, are left untouched.

        Args:
            fraction_to_remove: Share of each layer's active weights to prune.
            per_layer: NOTE(review): accepted for signature symmetry with
                ``prune_by_magnitude`` but not consulted; pruning is always
                per layer. Confirm whether a global mode is wanted.

        Returns:
            Parameter name -> ``{'removed', 'remaining', 'fraction'}``.
        """
        stats = {}

        for name, param in self.model.named_parameters():
            if name not in self.masks:
                continue

            active_positions = (self.masks[name] == 1)
            num_active = active_positions.sum().item()
            if num_active == 0:
                stats[name] = {'removed': 0, 'remaining': 0, 'fraction': 0}
                continue

            if int(fraction_to_remove * num_active) == 0:
                stats[name] = {'removed': 0, 'remaining': int(num_active), 'fraction': 0}
                continue

            magnitudes = param.data.abs()

            # Threshold over the active weights only, so already-pruned zeros
            # do not drag the quantile down.
            threshold = torch.quantile(magnitudes[active_positions], fraction_to_remove)

            # Keep weights at or above the threshold that are still active.
            new_mask = ((magnitudes >= threshold) & active_positions).float()

            self.masks[name] = new_mask
            param.data *= new_mask

            new_active = new_mask.sum().item()
            stats[name] = {
                'removed': int(num_active - new_active),
                'remaining': int(new_active),
                'fraction': (num_active - new_active) / num_active
            }

        self.history.append(('prune_fraction', fraction_to_remove, stats))
        return stats

    def apply_masks(self):
        """Zero every pruned weight again (e.g. after an optimiser step)."""
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in self.masks:
                    param.data *= self.masks[name]

    def get_sparsity(self) -> float:
        """Return the fraction of masked-out entries over all managed layers."""
        total_params = sum(m.numel() for m in self.masks.values())
        zero_params = sum((m == 0).sum().item() for m in self.masks.values())
        return zero_params / total_params if total_params > 0 else 0.0

    def get_density(self) -> float:
        """Return ``1 - get_sparsity()``."""
        return 1.0 - self.get_sparsity()

    def print_mask_stats(self, label: str = ""):
        """Print active/total counts per managed layer and overall."""
        print(f"\n   Mask Statistics {label}:")
        total_params = 0
        total_active = 0
        for name, mask in self.masks.items():
            active = mask.sum().item()
            total = mask.numel()
            total_params += total
            total_active += active
            print(f"     {name}: {int(active):,}/{total:,} active ({100*active/total:.1f}%)")
        print(f"     TOTAL: {int(total_active):,}/{total_params:,} active ({100*total_active/total_params:.1f}%)")


# ============================================================================
# SECTION 5: TRAINING
# ============================================================================

def _train_single_task_batch(
    model: nn.Module,
    optimizer: optim.Optimizer,
    loss_fn: nn.Module,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    task_id_row: torch.Tensor,
    pruning_manager=None
) -> float:
    """Run one optimiser step on a batch of a single task and return its loss."""
    inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
    task_id = task_id_row.repeat(inputs.size(0), 1)

    optimizer.zero_grad()
    loss = loss_fn(model(inputs, task_id), targets)
    loss.backward()
    optimizer.step()

    # The optimiser may move pruned weights away from zero; re-zero them.
    if pruning_manager:
        pruning_manager.apply_masks()

    return loss.item()


def train_taskgated_with_curves(
    model: TaskGatedNetwork,
    loaders: Dict[str, DataLoader],
    epochs: int = None,
    lr: float = None,
    pruning_manager: TaskGatedPruningManager = None,
    verbose: bool = True,
    track_test_accuracy: bool = True
) -> Dict[str, List[float]]:
    """Train on the visual and voice tasks with alternating batches.

    Every step trains on one visual batch (task ID ``[1, 0]``) and then one
    voice batch (task ID ``[0, 1]``) using a fresh Adam optimiser and
    cross-entropy loss. An epoch lasts ``len(loaders['visual_train'])`` steps.
    If the voice loader runs out earlier, its iterator is restarted and that
    step performs no voice update. Stress noise is disabled during training.

    Args:
        model: Network to train in place.
        loaders: Must provide ``'visual_train'`` and ``'voice_train'``; the
            test loaders are needed when ``track_test_accuracy`` is set.
        epochs: Number of epochs; ``None`` uses ``CONFIG['multitask_epochs']``.
        lr: Learning rate; ``None`` uses ``CONFIG['baseline_lr']``.
        pruning_manager: If given, its masks are re-applied after every
            optimiser step.
        verbose: Print the configuration, a line every 5th epoch and a summary.
        track_test_accuracy: Evaluate both tasks after every epoch.

    Returns:
        Per-epoch lists under ``'visual_loss'``, ``'voice_loss'``,
        ``'combined_loss'``, ``'visual_acc'`` and ``'voice_acc'`` (the accuracy
        lists stay empty when ``track_test_accuracy`` is False).
    """
    if epochs is None:
        epochs = CONFIG['multitask_epochs']
    if lr is None:
        lr = CONFIG['baseline_lr']

    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    history = {
        'visual_loss': [], 'voice_loss': [], 'combined_loss': [],
        'visual_acc': [], 'voice_acc': []
    }

    visual_task_id = torch.tensor([[1.0, 0.0]]).to(DEVICE)
    voice_task_id = torch.tensor([[0.0, 1.0]]).to(DEVICE)

    model.set_stress(0.0)

    if verbose:
        print(f"\n   Training Configuration:")
        print(f"     - Epochs: {epochs}")
        print(f"     - Learning rate: {lr}")
        print(f"     - Optimizer: Adam")
        print(f"     - Loss function: CrossEntropyLoss")
        if pruning_manager:
            print(f"     - Pruning manager: Active (density={pruning_manager.get_density()*100:.1f}%)")
        print()

    steps_per_epoch = len(loaders['visual_train'])

    for epoch in range(epochs):
        model.train()
        visual_loss = 0.0
        voice_loss = 0.0
        visual_batches = 0
        voice_batches = 0

        visual_iter = iter(loaders['visual_train'])
        voice_iter = iter(loaders['voice_train'])

        batch_idx = 0
        while True:
            # Visual step. Only next() is guarded so that an exhausted
            # iterator is distinguished from errors inside the update.
            try:
                x_v, y_v = next(visual_iter)
            except StopIteration:
                visual_iter = iter(loaders['visual_train'])
            else:
                visual_loss += _train_single_task_batch(
                    model, optimizer, loss_fn, x_v, y_v, visual_task_id, pruning_manager)
                visual_batches += 1

            # Voice step; an exhausted loader is restarted and this step
            # is skipped for the voice task.
            try:
                x_a, y_a = next(voice_iter)
            except StopIteration:
                voice_iter = iter(loaders['voice_train'])
            else:
                voice_loss += _train_single_task_batch(
                    model, optimizer, loss_fn, x_a, y_a, voice_task_id, pruning_manager)
                voice_batches += 1

            batch_idx += 1
            if batch_idx >= steps_per_epoch:
                break

        avg_visual_loss = visual_loss / max(visual_batches, 1)
        avg_voice_loss = voice_loss / max(voice_batches, 1)
        avg_combined_loss = (avg_visual_loss + avg_voice_loss) / 2

        history['visual_loss'].append(avg_visual_loss)
        history['voice_loss'].append(avg_voice_loss)
        history['combined_loss'].append(avg_combined_loss)

        if track_test_accuracy:
            visual_acc = evaluate_task_quick(model, loaders, 'visual')
            voice_acc = evaluate_task_quick(model, loaders, 'voice')
            history['visual_acc'].append(visual_acc)
            history['voice_acc'].append(voice_acc)

        if verbose and (epoch + 1) % 5 == 0:
            if track_test_accuracy:
                print(f"    Epoch {epoch+1:>3}/{epochs}: Loss V={avg_visual_loss:.4f} A={avg_voice_loss:.4f} "
                      f"Combined={avg_combined_loss:.4f} | Acc V={history['visual_acc'][-1]:.1f}% A={history['voice_acc'][-1]:.1f}%")
            else:
                print(f"    Epoch {epoch+1:>3}/{epochs}: Loss V={avg_visual_loss:.4f} A={avg_voice_loss:.4f} "
                      f"Combined={avg_combined_loss:.4f}")

    if verbose:
        print(f"\n   Training Complete:")
        # With epochs == 0 the history is empty; skip the final metrics rather
        # than raising IndexError on history[...][-1].
        if history['combined_loss']:
            print(f"     - Final combined loss: {history['combined_loss'][-1]:.4f}")
            if track_test_accuracy:
                print(f"     - Final visual accuracy: {history['visual_acc'][-1]:.1f}%")
                print(f"     - Final voice accuracy: {history['voice_acc'][-1]:.1f}%")

    return history


def evaluate_task_quick(model: TaskGatedNetwork, loaders: Dict[str, DataLoader], task: str = 'visual') -> float:
    """Return the stress-free test accuracy (in percent) for one task.

    Intended for use between training epochs: the model is switched to eval
    mode with stress disabled for the measurement and put back into train mode
    before returning. Any ``task`` other than ``'visual'`` selects the voice
    task.
    """
    model.eval()
    model.set_stress(0.0)

    if task == 'visual':
        loader_key, gate = 'visual_test', [[1.0, 0.0]]
    else:
        loader_key, gate = 'voice_test', [[0.0, 1.0]]
    batches = loaders[loader_key]
    task_id_base = torch.tensor(gate).to(DEVICE)

    hits, seen = 0, 0
    with torch.no_grad():
        for inputs, targets in batches:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            # One copy of the task-ID row per sample in the batch.
            task_id = task_id_base.repeat(inputs.size(0), 1)
            predicted = model(inputs, task_id).argmax(dim=1)
            hits += (predicted == targets).sum().item()
            seen += targets.size(0)

    model.train()
    return 100.0 * hits / seen


# ============================================================================
# SECTION 6: EVALUATION
# ============================================================================

def evaluate_task(model: TaskGatedNetwork, loaders: Dict[str, DataLoader],
                  task: str = 'visual', internal_stress: float = 0.0) -> float:
    """Return the test accuracy (in percent) for one task under a stress level.

    The model is put into eval mode and its stress level is set to
    ``internal_stress`` for the measurement, then reset to 0. The model is
    left in eval mode. Any ``task`` other than ``'visual'`` selects the voice
    task.
    """
    model.eval()
    model.set_stress(internal_stress)

    if task == 'visual':
        loader_key, gate = 'visual_test', [[1.0, 0.0]]
    else:
        loader_key, gate = 'voice_test', [[0.0, 1.0]]
    batches = loaders[loader_key]
    task_id_base = torch.tensor(gate).to(DEVICE)

    hits, seen = 0, 0
    with torch.no_grad():
        for inputs, targets in batches:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            # One copy of the task-ID row per sample in the batch.
            task_id = task_id_base.repeat(inputs.size(0), 1)
            predicted = model(inputs, task_id).argmax(dim=1)
            hits += (predicted == targets).sum().item()
            seen += targets.size(0)

    # Leave the network noise-free for whoever uses it next.
    model.set_stress(0.0)
    return 100.0 * hits / seen


def evaluate_with_ambiguous_context(model: TaskGatedNetwork, loaders: Dict[str, DataLoader],
                                    visual_weight: float = 0.5, internal_stress: float = 0.0) -> float:
    """Return visual-task accuracy (in percent) under a soft task-ID cue.

    The visual test set is scored against its visual labels while the network
    receives the task ID ``[visual_weight, 1 - visual_weight]`` instead of the
    one-hot visual cue. The stress level is reset to 0 afterwards and the
    model is left in eval mode.
    """
    model.eval()
    model.set_stress(internal_stress)

    samples = loaders['visual_test_data']
    targets = loaders['visual_test_labels']
    task_id_base = torch.tensor([[visual_weight, 1 - visual_weight]]).to(DEVICE)

    chunk = 1000
    hits, seen = 0, 0
    with torch.no_grad():
        for lo in range(0, len(samples), chunk):
            hi = min(lo + chunk, len(samples))
            x = samples[lo:hi].to(DEVICE)
            y = targets[lo:hi].to(DEVICE)
            gate = task_id_base.repeat(x.size(0), 1)
            predicted = model(x, gate).argmax(dim=1)
            hits += (predicted == y).sum().item()
            seen += y.size(0)

    model.set_stress(0.0)
    return 100.0 * hits / seen


def evaluate_with_blended_input_and_ambiguous_context(
    model: TaskGatedNetwork, loaders: Dict[str, DataLoader],
    input_blend: float = 0.5, visual_weight: float = 0.5, internal_stress: float = 0.0) -> float:
    """Return visual-task accuracy (in percent) on inputs mixed with voice samples.

    Each visual test point is replaced by
    ``(1 - input_blend) * visual + input_blend * voice`` where the voice point
    is drawn uniformly at random (global torch RNG) from the voice test set.
    Predictions are scored against the visual labels, with the task ID
    ``[visual_weight, 1 - visual_weight]``. The stress level is reset to 0
    afterwards and the model is left in eval mode.
    """
    model.eval()
    model.set_stress(internal_stress)

    visual_samples = loaders['visual_test_data']
    visual_targets = loaders['visual_test_labels']
    voice_samples = loaders['voice_test_data']
    task_id_base = torch.tensor([[visual_weight, 1 - visual_weight]]).to(DEVICE)

    chunk = 1000
    hits, seen = 0, 0
    with torch.no_grad():
        for lo in range(0, len(visual_samples), chunk):
            hi = min(lo + chunk, len(visual_samples))
            x_visual = visual_samples[lo:hi].to(DEVICE)
            y_visual = visual_targets[lo:hi].to(DEVICE)

            # Random voice partners are drawn before the forward pass so the
            # RNG is consumed in a fixed order.
            picks = torch.randint(0, len(voice_samples), (hi - lo,))
            x_voice = voice_samples[picks].to(DEVICE)

            x_blended = (1 - input_blend) * x_visual + input_blend * x_voice
            gate = task_id_base.repeat(x_blended.size(0), 1)
            predicted = model(x_blended, gate).argmax(dim=1)
            hits += (predicted == y_visual).sum().item()
            seen += y_visual.size(0)

    model.set_stress(0.0)
    return 100.0 * hits / seen


def comprehensive_evaluation(model: TaskGatedNetwork, loaders: Dict[str, DataLoader],
                             label: str, print_results: bool = True) -> Dict[str, float]:
    """Run the full evaluation battery on ``model`` and optionally print it.

    Measured, in this order: clear accuracy on both tasks, visual accuracy at
    every positive stress level, visual accuracy at every ambiguity level
    (plus absolute and relative drops versus the clear visual accuracy for
    weights below 1.0), three blended-input tests, and parameter
    density/sparsity as reported by ``model.count_parameters()``.

    Args:
        model: Network to evaluate.
        loaders: Loaders and raw test tensors as returned by
            ``create_multitask_data_loaders``.
        label: Heading used in the printed report.
        print_results: Print the report to stdout when True.

    Returns:
        Metric name -> value, with keys inserted in evaluation order.
    """
    model.set_stress(0.0)
    stress_cfg = CONFIG['stress_levels']
    ambiguity_cfg = CONFIG['ambiguity_levels']

    # --- clear-context accuracy ---------------------------------------------
    results = {
        'visual_clear': evaluate_task(model, loaders, 'visual'),
        'voice_clear': evaluate_task(model, loaders, 'voice'),
    }

    # --- stress tolerance (zero-stress entries are the clear case) ----------
    for stress_name, stress_level in stress_cfg.items():
        if not stress_level > 0:
            continue
        results[f'visual_stress_{stress_name}'] = evaluate_task(
            model, loaders, 'visual', internal_stress=stress_level)

    # --- ambiguity tolerance --------------------------------------------------
    for amb_name, visual_weight in ambiguity_cfg.items():
        results[f'visual_ambig_{amb_name}'] = evaluate_with_ambiguous_context(
            model, loaders, visual_weight=visual_weight)

    baseline = results['visual_clear']
    for amb_name, visual_weight in ambiguity_cfg.items():
        if not visual_weight < 1.0:
            continue
        drop = baseline - results[f'visual_ambig_{amb_name}']
        results[f'ambig_drop_{amb_name}'] = drop
        results[f'ambig_drop_pct_{amb_name}'] = (drop / baseline * 100) if baseline > 0 else 0

    # --- blended inputs: (result key, input_blend, visual_weight) -----------
    blend_tests = (
        ('visual_blend50_ambig', 0.5, 0.5),
        ('visual_blend30_clear', 0.3, 1.0),
        ('visual_blend50_clear', 0.5, 1.0),
    )
    for key, blend, weight in blend_tests:
        results[key] = evaluate_with_blended_input_and_ambiguous_context(
            model, loaders, input_blend=blend, visual_weight=weight)

    # --- parameter statistics -------------------------------------------------
    total, nonzero = model.count_parameters()
    results['sparsity'] = 100 * (1 - nonzero / total)
    results['density'] = 100 * (nonzero / total)
    results['total_params'] = total
    results['nonzero_params'] = nonzero

    if not print_results:
        return results

    print(f"\n{'='*70}")
    print(f" EVALUATION: {label}")
    print(f"{'='*70}")
    print(f"\n Network Statistics:")
    print(f"   Total parameters: {total:,}")
    print(f"   Non-zero parameters: {nonzero:,}")
    print(f"   Density: {results['density']:.1f}%")
    print(f"   Sparsity: {results['sparsity']:.1f}%")

    print(f"\n Basic Task Performance:")
    print(f"   Visual task accuracy: {results['visual_clear']:.1f}%")
    print(f"   Voice task accuracy: {results['voice_clear']:.1f}%")

    print(f"\n Stress Tolerance (Visual Task):")
    for stress_name, stress_level in stress_cfg.items():
        if not stress_level > 0:
            continue
        acc = results[f'visual_stress_{stress_name}']
        drop = results['visual_clear'] - acc
        print(f"   {stress_name.capitalize():>10} (σ={stress_level:.1f}): {acc:.1f}% (drop: {drop:+.1f}%)")

    print(f"\n Ambiguity Tolerance (Visual Task):")
    for amb_name, visual_weight in ambiguity_cfg.items():
        acc = results[f'visual_ambig_{amb_name}']
        if visual_weight < 1.0:
            drop = results[f'ambig_drop_{amb_name}']
            print(f"   {amb_name.capitalize():>12} (w={visual_weight:.2f}): {acc:.1f}% (drop: {drop:+.1f}%)")
        else:
            print(f"   {amb_name.capitalize():>12} (w={visual_weight:.2f}): {acc:.1f}%")

    print(f"\n Blended Input Tests:")
    print(f"   50% blend + ambiguous context: {results['visual_blend50_ambig']:.1f}%")
    print(f"   30% blend + clear context: {results['visual_blend30_clear']:.1f}%")
    print(f"   50% blend + clear context: {results['visual_blend50_clear']:.1f}%")

    return results


# ============================================================================
# SECTION 7: DENSITY SWEEP EXPERIMENT (EXPERIMENT 1)
# ============================================================================

def _density_to_pct(density: float) -> int:
    """Return *density* (a 0-1 fraction) as an integer percentage label.

    Rounds rather than truncates: ``int(0.57 * 100)`` is 56 because
    ``0.57 * 100`` evaluates to 56.99999999999999 in binary floating point,
    which would mislabel (or collide) result keys.
    """
    return int(round(density * 100))


def _print_early_stage_table(early_results: Dict[str, Dict], density_pcts: List[int]) -> None:
    """Print one row of headline metrics for each evaluated early-stage density.

    Args:
        early_results: Mapping of ``'early_<pct>'`` keys to evaluation dicts.
        density_pcts: Integer density labels to print, in display order.
    """
    print(f"\n {'Density':>10} {'Visual':>10} {'Voice':>10} {'Stress Hi':>12} {'Ambiguous':>12} {'Blend+Ambig':>12}")
    print(" " + "-"*75)

    for density_pct in density_pcts:
        r = early_results[f"early_{density_pct}"]
        print(f" {r['actual_density']:>9.1f}% {r['visual_clear']:>9.1f}% {r['voice_clear']:>9.1f}% "
              f"{r.get('visual_stress_high', 0):>11.1f}% {r.get('visual_ambig_ambiguous', 0):>11.1f}% "
              f"{r.get('visual_blend50_ambig', 0):>11.1f}%")


def _run_late_stage_arm(arm_name: str, prune_fraction: float, density_pct: int,
                        early_state: Dict, early_mask: Dict, loaders) -> Dict:
    """Run one late-stage arm: restore an early state, prune further, fine-tune, evaluate.

    Args:
        arm_name: Upper-case label used in the printed output ('NORMAL' or 'ASD').
        prune_fraction: Fraction of the remaining weights to remove.
        density_pct: Integer label of the early density this arm starts from.
        early_state: Saved model ``state_dict`` from the early stage.
        early_mask: Saved pruning masks from the early stage (cloned, not mutated).
        loaders: Data loaders forwarded to training and evaluation.

    Returns:
        The evaluation dict, extended with 'early_density', 'final_density'
        and 'pruning_fraction'.
    """
    print("\n" + "-"*60)
    print(f"   {arm_name} CONDITION: Remove {prune_fraction*100:.0f}% of remaining weights")
    print("-"*60)

    print(f"\n   Loading early state ({density_pct}% density)...")
    model = TaskGatedNetwork().to(DEVICE)
    model.load_state_dict(early_state)
    mgr = TaskGatedPruningManager(model)
    # Clone so that each arm prunes its own copy of the early-stage masks.
    mgr.masks = {k: v.clone() for k, v in early_mask.items()}

    pre_prune_density = mgr.get_density() * 100
    print(f"   Pre-pruning density: {pre_prune_density:.1f}%")

    # Apply additional pruning
    print(f"   Applying {prune_fraction*100:.0f}% additional pruning...")
    prune_stats = mgr.prune_fraction_of_remaining(prune_fraction)
    final_density = mgr.get_density() * 100

    print(f"   Pruning details:")
    for layer_name, stats in prune_stats.items():
        print(f"     {layer_name}: removed {stats['removed']:,} weights, {stats['remaining']:,} remaining ({stats['fraction']*100:.1f}% removed)")

    print(f"   Post-pruning density: {final_density:.1f}%")

    # Fine-tune
    print(f"\n   Fine-tuning for {CONFIG['late_finetune_epochs']} epochs...")
    train_taskgated_with_curves(
        model, loaders,
        epochs=CONFIG['late_finetune_epochs'],
        lr=CONFIG['finetune_lr'],
        pruning_manager=mgr,
        verbose=True,
        track_test_accuracy=True
    )

    results = comprehensive_evaluation(
        model, loaders,
        f"LATE {arm_name} (from {density_pct}% → {final_density:.1f}%)",
        print_results=True
    )
    results['early_density'] = density_pct
    results['final_density'] = final_density
    results['pruning_fraction'] = prune_fraction
    return results


def run_density_sweep_experiment() -> Dict[str, Dict]:
    """
    EXPERIMENT 1: Density sweep with differential late pruning

    EARLY STAGE: For each density level in CONFIG['early_density_levels']:
      - Prune a copy of the trained baseline to the target density
      - Train for CONFIG['early_training_epochs'] epochs
      - Evaluate and compare across densities

    LATE STAGE: For each early density:
      - Normal arm: Remove CONFIG['late_pruning_normal'] of remaining weights
      - ASD arm: Remove CONFIG['late_pruning_asd'] of remaining weights
      - Compare Normal vs ASD at each density

    Returns:
        Dict with keys:
          'baseline'          - evaluation of the unpruned baseline
          'early_stage'       - 'early_<pct>' -> evaluation dict
          'late_stage_normal' - 'late_normal_<pct>' -> evaluation dict
          'late_stage_asd'    - 'late_asd_<pct>' -> evaluation dict
          'comparisons'       - <pct> (int) -> Normal-minus-ASD differences
          'learning_curves'   - training history for 'baseline' and each 'early_<pct>'
        If every configured density is skipped, only 'baseline' and its
        learning curve are populated.
    """
    print("\n" + "#"*80)
    print("#" + " EXPERIMENT 1: DENSITY SWEEP WITH DIFFERENTIAL LATE PRUNING ".center(78) + "#")
    print("#" + " Version 9.2: Comprehensive Early/Late Stage Comparison ".center(78) + "#")
    print("#"*80)

    print(f"""
 EXPERIMENTAL DESIGN:
 ====================

 EARLY STAGE (density sweep):
   Densities tested: {[f'{d*100:.0f}%' for d in CONFIG['early_density_levels']]}
   Training epochs per density: {CONFIG['early_training_epochs']}
   Learning rate: {CONFIG['finetune_lr']}

 LATE STAGE (differential pruning from each early state):
   Normal condition: Remove {CONFIG['late_pruning_normal']*100:.0f}% of remaining weights
   ASD condition: Remove {CONFIG['late_pruning_asd']*100:.0f}% of remaining weights
   Fine-tuning epochs: {CONFIG['late_finetune_epochs']}

 RATIONALE:
   This design tests how initial network density (early pruning) affects:
   1. Baseline learning capacity
   2. Resilience to further pruning (late stage)
   3. The differential impact of normal vs. excessive late pruning
    """)

    loaders = create_multitask_data_loaders()

    all_results = {
        'early_stage': {},
        'late_stage_normal': {},
        'late_stage_asd': {},
        'learning_curves': {},
        'comparisons': {}
    }

    # ========================================================================
    # STAGE 0: Train full density baseline
    # ========================================================================
    print("\n" + "="*80)
    print(" STAGE 0: TRAINING FULL DENSITY BASELINE")
    print("="*80)

    print("\n Creating full density network...")
    full_model = TaskGatedNetwork().to(DEVICE)
    full_model.print_architecture()

    print(f"\n Training baseline model for {CONFIG['baseline_epochs']} epochs...")
    full_history = train_taskgated_with_curves(
        full_model, loaders,
        epochs=CONFIG['baseline_epochs'],
        verbose=True,
        track_test_accuracy=True
    )

    all_results['baseline'] = comprehensive_evaluation(
        full_model, loaders, "FULL DENSITY BASELINE (100%)", print_results=True
    )
    all_results['learning_curves']['baseline'] = full_history

    # Snapshot the trained weights; every density level restarts from this copy.
    full_state = {k: v.clone() for k, v in full_model.state_dict().items()}
    print("\n Baseline state saved for density sweep experiments.")

    # ========================================================================
    # EARLY STAGE: Density Sweep
    # ========================================================================
    print("\n" + "="*80)
    print(" EARLY STAGE: DENSITY SWEEP (10% to 100%)")
    print(" Comparing learning dynamics across density levels")
    print("="*80)

    early_states = {}  # Store states for late stage
    early_masks = {}   # Store masks for late stage

    print(f"\n Testing {len(CONFIG['early_density_levels'])} density levels...")
    print(f" Each model starts from baseline, is pruned to target density, then trained for {CONFIG['early_training_epochs']} epochs.\n")

    for i, target_density in enumerate(CONFIG['early_density_levels']):
        density_pct = _density_to_pct(target_density)

        print("\n" + "-"*70)
        print(f" EARLY STAGE - Density Level {i+1}/{len(CONFIG['early_density_levels'])}: {density_pct}%")
        print("-"*70)

        # Skip very low density (essentially no parameters)
        if target_density < 0.05:
            print(f"   SKIPPED: Density {density_pct}% is too sparse for meaningful learning.")
            continue

        # Create model from full state
        print(f"\n   Loading baseline state...")
        model = TaskGatedNetwork().to(DEVICE)
        model.load_state_dict(full_state)
        mgr = TaskGatedPruningManager(model)

        # Prune to target density
        print(f"   Pruning to {density_pct}% density...")
        prune_stats = mgr.prune_to_density(target_density)
        actual_density = mgr.get_density() * 100

        print(f"   Pruning complete:")
        for layer_name, stats in prune_stats.items():
            print(f"     {layer_name}: kept {stats['kept']:,}/{stats['total']:,} weights ({(1-stats['actual_sparsity'])*100:.1f}% density)")

        mgr.print_mask_stats(f"(Target: {density_pct}%)")

        # Train at this density
        print(f"\n   Training for {CONFIG['early_training_epochs']} epochs at {actual_density:.1f}% density...")
        history = train_taskgated_with_curves(
            model, loaders,
            epochs=CONFIG['early_training_epochs'],
            lr=CONFIG['finetune_lr'],
            pruning_manager=mgr,
            verbose=True,
            track_test_accuracy=True
        )

        # Evaluate
        key = f"early_{density_pct}"
        results = comprehensive_evaluation(
            model, loaders,
            f"EARLY STAGE - {density_pct}% Density",
            print_results=True
        )
        results['target_density'] = target_density * 100
        results['actual_density'] = actual_density

        all_results['early_stage'][key] = results
        all_results['learning_curves'][key] = history

        # Save state for late stage
        early_states[density_pct] = {k: v.clone() for k, v in model.state_dict().items()}
        early_masks[density_pct] = {k: v.clone() for k, v in mgr.masks.items()}

        print(f"\n   State saved for late stage experiments.")

    if not early_states:
        # Nothing was trained (empty or all-skipped density list); the late
        # stage and the statistics below would have no data to work with.
        print("\n No density level was evaluated; skipping late stage and analysis.")
        return all_results

    # Report exactly the levels that were trained, rather than re-deriving the
    # list from CONFIG with a separate threshold.
    evaluated_pcts = sorted(early_states)

    # Early Stage Summary Table
    print("\n" + "="*70)
    print(" EARLY STAGE SUMMARY TABLE")
    print("="*70)
    _print_early_stage_table(all_results['early_stage'], evaluated_pcts)

    # ========================================================================
    # LATE STAGE: Differential Pruning
    # ========================================================================
    print("\n" + "="*80)
    print(" LATE STAGE: DIFFERENTIAL PRUNING (Normal vs ASD)")
    print("="*80)
    print(f"""
 For each early density state, we now apply additional late-stage pruning:
   - Normal condition: Remove {CONFIG['late_pruning_normal']*100:.0f}% of remaining weights
   - ASD condition: Remove {CONFIG['late_pruning_asd']*100:.0f}% of remaining weights

 This models developmental scenarios where:
   - Normal development has moderate synaptic pruning during adolescence
   - ASD may have insufficient pruning (this experiment tests EXCESSIVE pruning as comparison)
    """)

    for density_pct in evaluated_pcts:
        print("\n" + "="*70)
        print(f" LATE STAGE: Starting from EARLY {density_pct}% Density")
        print("="*70)

        early_density = density_pct
        early_results = all_results['early_stage'][f'early_{density_pct}']

        print(f"\n   Early stage baseline performance:")
        print(f"     Visual accuracy: {early_results['visual_clear']:.1f}%")
        print(f"     Voice accuracy: {early_results['voice_clear']:.1f}%")
        print(f"     Stress tolerance (high): {early_results.get('visual_stress_high', 0):.1f}%")
        print(f"     Ambiguity tolerance: {early_results.get('visual_ambig_ambiguous', 0):.1f}%")

        # ----------------------------------------------------------------
        # NORMAL ARM: moderate additional pruning
        # ----------------------------------------------------------------
        results_normal = _run_late_stage_arm(
            'NORMAL', CONFIG['late_pruning_normal'], density_pct,
            early_states[density_pct], early_masks[density_pct], loaders
        )
        normal_density = results_normal['final_density']
        all_results['late_stage_normal'][f"late_normal_{density_pct}"] = results_normal

        # ----------------------------------------------------------------
        # ASD ARM: heavier additional pruning from the same early state
        # ----------------------------------------------------------------
        results_asd = _run_late_stage_arm(
            'ASD', CONFIG['late_pruning_asd'], density_pct,
            early_states[density_pct], early_masks[density_pct], loaders
        )
        asd_density = results_asd['final_density']
        all_results['late_stage_asd'][f"late_asd_{density_pct}"] = results_asd

        # ----------------------------------------------------------------
        # COMPARISON: Normal vs ASD
        # ----------------------------------------------------------------
        print("\n" + "-"*60)
        print(f"   COMPARISON: Normal vs ASD (from {density_pct}% early density)")
        print("-"*60)

        # All *_diff values are Normal minus ASD (positive = Normal better).
        comparison = {
            'early_density': early_density,
            'normal_final_density': normal_density,
            'asd_final_density': asd_density,
            'visual_diff': results_normal['visual_clear'] - results_asd['visual_clear'],
            'voice_diff': results_normal['voice_clear'] - results_asd['voice_clear'],
            'stress_diff': results_normal.get('visual_stress_high', 0) - results_asd.get('visual_stress_high', 0),
            'ambig_diff': results_normal.get('visual_ambig_ambiguous', 0) - results_asd.get('visual_ambig_ambiguous', 0),
            'blend_ambig_diff': results_normal.get('visual_blend50_ambig', 0) - results_asd.get('visual_blend50_ambig', 0)
        }
        all_results['comparisons'][density_pct] = comparison

        print(f"\n   {'Metric':<25} {'Normal':>12} {'ASD':>12} {'Diff (N-A)':>12}")
        print(f"   {'-'*65}")
        print(f"   {'Final Density':<25} {normal_density:>11.1f}% {asd_density:>11.1f}% {'---':>12}")
        print(f"   {'Visual Accuracy':<25} {results_normal['visual_clear']:>11.1f}% {results_asd['visual_clear']:>11.1f}% {comparison['visual_diff']:>+11.1f}%")
        print(f"   {'Voice Accuracy':<25} {results_normal['voice_clear']:>11.1f}% {results_asd['voice_clear']:>11.1f}% {comparison['voice_diff']:>+11.1f}%")
        print(f"   {'Stress Tolerance (high)':<25} {results_normal.get('visual_stress_high', 0):>11.1f}% {results_asd.get('visual_stress_high', 0):>11.1f}% {comparison['stress_diff']:>+11.1f}%")
        print(f"   {'Ambiguity Tolerance':<25} {results_normal.get('visual_ambig_ambiguous', 0):>11.1f}% {results_asd.get('visual_ambig_ambiguous', 0):>11.1f}% {comparison['ambig_diff']:>+11.1f}%")
        print(f"   {'Blend + Ambiguity':<25} {results_normal.get('visual_blend50_ambig', 0):>11.1f}% {results_asd.get('visual_blend50_ambig', 0):>11.1f}% {comparison['blend_ambig_diff']:>+11.1f}%")

        # Interpretation
        print(f"\n   Interpretation:")
        if comparison['visual_diff'] > 0:
            print(f"     → Normal outperforms ASD by {comparison['visual_diff']:.1f}% on visual task")
        elif comparison['visual_diff'] < 0:
            print(f"     → ASD outperforms Normal by {-comparison['visual_diff']:.1f}% on visual task")
        else:
            print(f"     → Normal and ASD perform equally on visual task")

        if comparison['stress_diff'] > 0:
            print(f"     → Normal has better stress tolerance by {comparison['stress_diff']:.1f}%")
        elif comparison['stress_diff'] < 0:
            print(f"     → ASD has better stress tolerance by {-comparison['stress_diff']:.1f}%")

    # ========================================================================
    # COMPREHENSIVE SUMMARY TABLES
    # ========================================================================
    print("\n" + "="*80)
    print(" COMPREHENSIVE SUMMARY TABLES")
    print("="*80)

    # Early Stage Summary
    print("\n" + "-"*70)
    print(" TABLE 1: EARLY STAGE RESULTS (Density Sweep)")
    print("-"*70)
    _print_early_stage_table(all_results['early_stage'], evaluated_pcts)

    # Late Stage Comparison Table
    print("\n" + "-"*70)
    print(" TABLE 2: LATE STAGE COMPARISON (Normal vs ASD)")
    print("-"*70)
    print(f"\n Normal: {CONFIG['late_pruning_normal']*100:.0f}% pruned | ASD: {CONFIG['late_pruning_asd']*100:.0f}% pruned\n")
    print(f" {'Early%':>8} {'N-Dens':>8} {'A-Dens':>8} {'N-Vis':>8} {'A-Vis':>8} {'Diff':>8} "
          f"{'N-Str':>8} {'A-Str':>8} {'Diff':>8}")
    print(" " + "-"*85)

    for density_pct in sorted(all_results['comparisons'].keys()):
        comp = all_results['comparisons'][density_pct]
        r_n = all_results['late_stage_normal'][f"late_normal_{density_pct}"]
        r_a = all_results['late_stage_asd'][f"late_asd_{density_pct}"]

        print(f" {density_pct:>7}% {r_n['final_density']:>7.1f}% {r_a['final_density']:>7.1f}% "
              f"{r_n['visual_clear']:>7.1f}% {r_a['visual_clear']:>7.1f}% {comp['visual_diff']:>+7.1f}% "
              f"{r_n.get('visual_stress_high', 0):>7.1f}% {r_a.get('visual_stress_high', 0):>7.1f}% "
              f"{comp['stress_diff']:>+7.1f}%")

    # Progression Table
    print("\n" + "-"*70)
    print(" TABLE 3: EARLY → LATE PROGRESSION")
    print("-"*70)
    print(f"\n {'Early%':>8} {'Early-Vis':>10} {'N-Vis':>10} {'A-Vis':>10} {'E→N Δ':>10} {'E→A Δ':>10}")
    print(" " + "-"*65)

    for density_pct in sorted(all_results['comparisons'].keys()):
        e_vis = all_results['early_stage'][f"early_{density_pct}"]['visual_clear']
        n_vis = all_results['late_stage_normal'][f"late_normal_{density_pct}"]['visual_clear']
        a_vis = all_results['late_stage_asd'][f"late_asd_{density_pct}"]['visual_clear']
        e_to_n = n_vis - e_vis
        e_to_a = a_vis - e_vis

        print(f" {density_pct:>7}% {e_vis:>9.1f}% {n_vis:>9.1f}% {a_vis:>9.1f}% {e_to_n:>+9.1f}% {e_to_a:>+9.1f}%")

    # ========================================================================
    # ANALYSIS
    # ========================================================================
    print("\n" + "="*80)
    print(" DETAILED ANALYSIS")
    print("="*80)

    # Find optimal early density (ties resolve to the lowest density because
    # the candidates are in ascending order and max() keeps the first maximum).
    early_visuals = [(d, all_results['early_stage'][f'early_{d}']['visual_clear'])
                     for d in evaluated_pcts]
    best_early_density, best_early_acc = max(early_visuals, key=lambda x: x[1])

    comparisons = list(all_results['comparisons'].values())
    visual_diffs = [c['visual_diff'] for c in comparisons]
    stress_diffs = [c['stress_diff'] for c in comparisons]
    ambig_diffs = [c['ambig_diff'] for c in comparisons]

    # Average differences
    avg_visual_diff = np.mean(visual_diffs)
    avg_stress_diff = np.mean(stress_diffs)
    avg_ambig_diff = np.mean(ambig_diffs)

    # Standard deviations
    std_visual_diff = np.std(visual_diffs)
    std_stress_diff = np.std(stress_diffs)

    print(f"""
 EARLY STAGE FINDINGS:
 =====================

 Optimal Early Density:
   - Best visual accuracy: {best_early_acc:.1f}% at {best_early_density}% density
   - This suggests that moderate pruning may help, but excessive pruning hurts

 Density-Performance Relationship:
   - Lower densities (10-30%): Likely degraded performance due to insufficient capacity
   - Moderate densities (40-70%): May achieve good balance
   - High densities (80-100%): Full capacity, but potentially over-parameterized

 Learning Dynamics:
   - Sparse networks may learn more slowly but generalize differently
   - Dense networks learn faster but may overfit to training distribution
    """)

    # The headline conclusion is driven solely by the mean visual-accuracy gap.
    normal_wins = avg_visual_diff > 0
    conclusion = ("Normal condition generally performs BETTER" if normal_wins
                  else "ASD condition performs comparably or BETTER")
    asd_effect = "HURTS" if normal_wins else "does NOT necessarily hurt"

    print(f"""
 LATE STAGE FINDINGS (Normal vs ASD):
 ====================================

 Overall Differences (Normal - ASD, averaged across all early densities):
   - Visual accuracy: {avg_visual_diff:+.2f}% ± {std_visual_diff:.2f}%
   - Stress tolerance: {avg_stress_diff:+.2f}% ± {std_stress_diff:.2f}%
   - Ambiguity tolerance: {avg_ambig_diff:+.2f}%

 Statistical Interpretation:
   - Positive difference = Normal outperforms ASD
   - Negative difference = ASD outperforms Normal

 Conclusion: {conclusion}
   → Heavier late pruning (ASD) {asd_effect} performance
    """)

    # Density-dependent effect analysis
    print("\n DENSITY-DEPENDENT EFFECTS:")
    print(" " + "-"*60)
    print(f"\n How does early density affect the Normal-ASD gap?\n")

    for density_pct in sorted(all_results['comparisons'].keys()):
        comp = all_results['comparisons'][density_pct]
        effect = "Normal BETTER" if comp['visual_diff'] > 0 else "ASD BETTER" if comp['visual_diff'] < 0 else "EQUAL"
        magnitude = abs(comp['visual_diff'])

        # Interpretation
        if density_pct <= 30:
            context = "(low starting density - already resource-constrained)"
        elif density_pct <= 60:
            context = "(moderate starting density)"
        else:
            context = "(high starting density - ample resources)"

        print(f"   {density_pct}% early density:")
        print(f"     → {effect} by {magnitude:.1f}% {context}")
        print(f"     → Final densities: Normal={comp['normal_final_density']:.1f}%, ASD={comp['asd_final_density']:.1f}%")

    # Key insights
    print(f"""

 KEY INSIGHTS:
 =============

 1. Interaction Effect:
    The impact of late pruning depends critically on early density.
    - High early density provides "buffer" against late pruning damage
    - Low early density leaves network vulnerable to any further pruning

 2. Resilience Patterns:
    Networks with more initial connections can better absorb pruning
    without catastrophic performance loss.

 3. Implications for ASD Modeling:
    If ASD involves altered pruning dynamics, both timing AND magnitude matter:
    - Early hyper-connectivity + insufficient late pruning = one pattern
    - Normal early + excessive late pruning = different pattern (tested here)

 4. Stress and Ambiguity:
    These secondary metrics often show larger Normal-ASD gaps than basic accuracy,
    suggesting pruning affects robustness more than raw performance.
    """)

    return all_results


# ============================================================================
# SECTION 8: MAIN
# ============================================================================

if __name__ == "__main__":
    # Script entry point: print a description of Experiment 1, run it, then
    # print a guide to the structure of the returned results dictionary.

    # The experiment reads its parameters from CONFIG, so the banner does too;
    # hard-coded numbers here would go stale whenever CONFIG is edited.
    # Formatting mirrors the design banner in run_density_sweep_experiment.
    banner_densities = ", ".join(f"{d*100:.0f}%" for d in CONFIG['early_density_levels'])
    banner_normal_pct = f"{CONFIG['late_pruning_normal']*100:.0f}%"
    banner_asd_pct = f"{CONFIG['late_pruning_asd']*100:.0f}%"

    print("\n" + "#"*80)
    print("#" + " DEVELOPMENTAL PRUNING SIMULATION FOR ASD ".center(78) + "#")
    print("#" + " VERSION 9.2: DENSITY SWEEP WITH DIFFERENTIAL PRUNING ".center(78) + "#")
    print("#" + " EXPERIMENT 1 ONLY ".center(78) + "#")
    print("#"*80)

    print(f"""
 ============================================================================
 EXPERIMENT 1: DENSITY SWEEP WITH DIFFERENTIAL LATE PRUNING
 ============================================================================

 OBJECTIVE:
 ----------
 Understand how initial network density (early pruning) affects:
   1. Learning capacity and task performance
   2. Resilience to subsequent pruning
   3. Differential outcomes between Normal and ASD pruning regimes

 EARLY STAGE:
 ------------
   • Test densities: {banner_densities}
   • Each model starts from trained baseline
   • Prune to target density, then train for {CONFIG['early_training_epochs']} epochs
   • Evaluate on: accuracy, stress tolerance, ambiguity handling

 LATE STAGE:
 -----------
   • For each early density, create two conditions:
     - Normal: Remove {banner_normal_pct} of remaining weights
     - ASD: Remove {banner_asd_pct} of remaining weights
   • Fine-tune each for {CONFIG['late_finetune_epochs']} epochs
   • Compare Normal vs ASD outcomes

 METRICS:
 --------
   • Visual task accuracy (primary)
   • Voice task accuracy (secondary, conflicting task)
   • Stress tolerance (performance under noise injection)
   • Ambiguity tolerance (performance with mixed task signals)
   • Blended input handling (combining sensory streams)

 ============================================================================
    """)

    # Run experiment
    print("\n Starting Experiment 1...")
    sweep_results = run_density_sweep_experiment()

    # Final summary
    print("\n" + "="*80)
    print(" EXPERIMENT 1 COMPLETE")
    print("="*80)

    # The key names listed below illustrate the naming pattern for a
    # 10%-to-100% sweep; the actual keys follow CONFIG['early_density_levels'].
    print(f"""
 RESULTS STRUCTURE:
 ==================

 sweep_results dictionary contains:

   'baseline':
     Results for full density (100%) baseline network

   'early_stage':
     Results for each density level (keys: 'early_10', 'early_20', ..., 'early_100')
     Each contains: visual_clear, voice_clear, stress metrics, ambiguity metrics, etc.

   'late_stage_normal':
     Results after Normal pruning ({banner_normal_pct} removed) from each early state
     Keys: 'late_normal_10', 'late_normal_20', ..., 'late_normal_100'

   'late_stage_asd':
     Results after ASD pruning ({banner_asd_pct} removed) from each early state
     Keys: 'late_asd_10', 'late_asd_20', ..., 'late_asd_100'

   'comparisons':
     Direct Normal vs ASD comparisons at each early density
     Keys: 10, 20, 30, ..., 100 (integer density values)
     Each contains: visual_diff, stress_diff, ambig_diff, etc.

   'learning_curves':
     Training history (loss, accuracy per epoch) for baseline and each early density
    """)

    print("="*80)
    print(" END OF EXPERIMENT 1")
    print("="*80 + "\n")


################################################################################
#                   DEVELOPMENTAL PRUNING SIMULATION FOR ASD                   #
#             VERSION 9.2: DENSITY SWEEP WITH DIFFERENTIAL PRUNING             #
#                              EXPERIMENT 1 ONLY                               #
################################################################################

 EXPERIMENT 1: DENSITY SWEEP WITH DIFFERENTIAL LATE PRUNING

 OBJECTIVE:
 ----------
 Understand how initial network density (early pruning) affects:
   1. Learning capacity and task performance
   2. Resilience to subsequent pruning
   3. Differential outcomes between Normal and ASD pruning regimes

 EARLY STAGE:
 ------------
   • Test densities: 10%, 20%, 30%, 40%, 50%, 60%, 70%, 80%, 90%, 100%
   • Each model starts from trained baseline
   • Prune to target density, then train for 15 epochs
   • Evaluate on: accuracy, stress tolerance, ambiguity handling

 LATE STAGE:
 -----------


# The End