# Isodose - Testing Antidepressant Treatments for Pruning-Induced Fragility in LLMs

In [None]:
"""
================================================================================
MULTI-MECHANISM RECOVERY FROM PRUNING-INDUCED FRAGILITY IN LLMs
WITH ISODOSE TREATMENT COMPARISON
================================================================================

Adapts the antidepressant comparison experiment to real Large Language Models
using structured width pruning as the "disease" model.

NEW FEATURE: ISODOSE COMPARISON
- Equalizes treatments based on computational cost (training FLOPs)
- Enables fair comparison of treatment efficacy per unit of "therapeutic effort"
- Analogous to comparing drugs at equivalent doses rather than arbitrary doses

MODIFICATIONS APPLIED:
- Removed bitsandbytes/4-bit quantization (Colab compatibility)
- Uses fp16 with device_map="auto"
- Starts from pre-pruned model oopere/pruned60-llama-3.2-1B
- Implements proxy evaluations (ARC-Easy, LAMBADA)
- Includes acute relapse via unstructured pruning
- Simplified longitudinal stress cycles
- ADDED: Isodose calibration and comparison framework

Based on:
- Pere Martra's pruning course: https://github.com/peremartra/Large-Language-Model-Notebooks-Course
- Llama-pruning tools: https://github.com/MedITSolutionsKurman/llama-pruning
- Paper: "Fragile Knowledge, Robust Instruction-Following" (arXiv 2512.22671)

Treatments tested:
1. KETAMINE-LIKE: High-rank LoRA (structural regrowth) + short aggressive fine-tuning
2. SSRI-LIKE: Low-rank LoRA + prolonged gradual stabilization
3. NEUROSTEROID-LIKE: High dropout (tonic inhibition) + consolidation training
================================================================================
"""

# ============================================================================
# CELL 1: INSTALLATIONS (Run first in Colab)
# ============================================================================
# !pip install -q transformers accelerate peft datasets torch --extra-index-url https://download.pytorch.org/whl/cu121
# !pip install -q sentencepiece einops

# ============================================================================
# CELL 2: IMPORTS AND CONFIGURATION
# ============================================================================
import torch
import torch.nn as nn
from torch.nn.utils import prune
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)
from datasets import load_dataset
import numpy as np
from typing import Dict, List, Tuple, Optional
import copy
import json
import os
from dataclasses import dataclass, field
from collections import defaultdict
import warnings
import gc
import time
import math

# Suppress every Python warning for the rest of the session. This keeps the
# notebook output compact, but it also hides deprecation notices.
warnings.filterwarnings('ignore')

# Seed the global torch and NumPy RNGs once, at import time, for reproducibility.
torch.manual_seed(42)
np.random.seed(42)

# ============================================================================
# CONFIGURATION
# ============================================================================
@dataclass
class ExperimentConfig:
    """
    Central settings object for the LLM pruning-recovery experiment.

    Every field has a default, so ``ExperimentConfig()`` yields a ready-to-use
    configuration. Fields are grouped as: model/runtime, experiment and
    evaluation, isodose, per-treatment LoRA hyper-parameters, longitudinal
    simulation, acute relapse, and fine-tuning data.

    DESIGN NOTES:
    - pruned_model: oopere/pruned60-llama-3.2-1B, a checkpoint with 60% of its
      MLP neurons removed; it plays the role of the severe "fragility" state
      (the depression analogue)
    - use_4bit stays False: fp16 is used instead, so bitsandbytes is not
      needed on Colab
    - eval_samples: kept small for speed

    ISODOSE PARAMETERS:
    - target_flops: computational budget all treatments are meant to match
      (None requests automatic calibration)
    - isodose_mode: run in isodose mode (True) or with the default
      per-treatment parameters (False)
    - isodose_reference: name of the treatment used as calibration reference
    """

    # ---- Model / runtime ----------------------------------------------------
    # Pre-pruned checkpoint (60% MLP pruning), so no pruning pass is needed.
    pruned_model: str = "oopere/pruned60-llama-3.2-1B"

    # 4-bit quantization is deliberately off; the model is loaded in fp16.
    use_4bit: bool = False

    # Resolved once, when the class body is evaluated.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # ---- Experiment / evaluation --------------------------------------------
    n_seeds: int = 3
    output_dir: str = "./llm_treatment_results"

    # Number of evaluation samples (a subset, for speed).
    eval_samples: int = 200

    # ---- Isodose ------------------------------------------------------------
    # Switch for the isodose comparison mode.
    isodose_mode: bool = True

    # FLOP budget; None means "calibrate automatically from the reference".
    target_flops: Optional[float] = None

    # Treatment whose cost defines the common dose.
    isodose_reference: str = "ketamine"

    # ---- Default treatment parameters (used when isodose_mode=False) ---------
    # Ketamine-like: high rank, few epochs, many target modules.
    ketamine_lora_rank: int = 64
    ketamine_lora_alpha: int = 128
    ketamine_epochs: int = 3
    ketamine_lr: float = 5e-5
    ketamine_target_modules: List[str] = field(
        default_factory=lambda: [
            "q_proj",
            "v_proj",
            "k_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
    )
    ketamine_dropout: float = 0.05

    # SSRI-like: low rank, many epochs, very small learning rate.
    ssri_lora_rank: int = 8
    ssri_lora_alpha: int = 16
    ssri_epochs: int = 15
    ssri_lr: float = 1e-6
    ssri_dropout: float = 0.1

    # Neurosteroid-like: moderate rank with heavy dropout.
    neuro_lora_rank: int = 32
    neuro_lora_alpha: int = 64
    neuro_epochs: int = 5
    neuro_lr: float = 3e-5
    neuro_dropout: float = 0.5

    # ---- Longitudinal simulation --------------------------------------------
    longitudinal_cycles: int = 8
    additional_prune_per_cycle: float = 0.10

    # Maintenance training epochs, one setting per treatment.
    ketamine_maintenance_epochs: int = 2
    ssri_maintenance_epochs: int = 5
    neuro_maintenance_epochs: int = 3

    # ---- Acute relapse simulation -------------------------------------------
    acute_relapse_prune_amount: float = 0.30

    # ---- Fine-tuning data ---------------------------------------------------
    finetune_dataset: str = "databricks/databricks-dolly-15k"
    finetune_subset_size: int = 500
    max_seq_length: int = 512


# Module-wide default configuration. Functions in this file that accept
# `config=None` fall back to this instance.
CONFIG = ExperimentConfig()


# ============================================================================
# CELL 3: ISODOSE CALCULATION FRAMEWORK
# ============================================================================
@dataclass
class TreatmentDose:
    """
    One treatment's hyper-parameters together with its computational "dose".

    ISODOSE CONCEPT:
    Pharmacology compares drugs at doses of equivalent effect or potency.
    Here the "dose" of an LLM treatment is its estimated training cost:
    - Training FLOPs = f(trainable_params, epochs, dataset_size, seq_length)
    - trainable_params reflects the "intensity", epochs the "duration"

    Holding that cost fixed lets the experiment ask which mechanism recovers
    more for the same computational budget.

    `trainable_params` and `estimated_flops` start at zero and are filled in
    by `compute_flops`.
    """
    treatment_name: str
    lora_rank: int
    lora_alpha: int
    epochs: int
    lr: float
    dropout: float
    target_modules: List[str]
    trainable_params: int = 0
    estimated_flops: float = 0.0

    def compute_flops(self, dataset_size: int, seq_length: int,
                      hidden_size: int = 2048, num_layers: int = 16):
        """
        Estimate the training FLOPs of this configuration and store the result.

        ESTIMATION MODEL (deliberately simplified):
        - Each adapted module contributes 2 * lora_rank * hidden_size LoRA
          parameters, i.e. rank * (in + out) with in = out = hidden_size.
        - Adapted modules = len(target_modules) * num_layers.
        - Forward + backward cost ~6 FLOPs per trainable parameter per token.
        - Total = 6 * trainable_params * seq_length * dataset_size * epochs.

        Side effects: overwrites `self.trainable_params` and
        `self.estimated_flops`.

        Args:
            dataset_size: Number of training samples
            seq_length: Sequence length
            hidden_size: Model hidden dimension
            num_layers: Number of transformer layers

        Returns:
            Estimated total training FLOPs
        """
        # Size of the LoRA adapter under the square-projection approximation.
        lora_params_each = 2 * self.lora_rank * hidden_size
        adapted_module_count = len(self.target_modules) * num_layers
        self.trainable_params = lora_params_each * adapted_module_count

        # Cost of one pass over the dataset, then scaled by the epoch count.
        one_sample_cost = 6 * self.trainable_params * seq_length
        one_epoch_cost = one_sample_cost * dataset_size
        self.estimated_flops = one_epoch_cost * self.epochs

        return self.estimated_flops


def _calibrate_epochs_to_flops(treatment_name: str, lora_rank: int, lora_alpha: int,
                               lr: float, dropout: float, target_modules: List[str],
                               original_epochs: int, target_flops: float,
                               config: ExperimentConfig) -> TreatmentDose:
    """Build a dose whose epoch count brings its estimated FLOPs closest to target_flops."""
    # Probe with a single epoch to learn the per-epoch cost of this configuration.
    probe = TreatmentDose(
        treatment_name=treatment_name,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        epochs=1,
        lr=lr,
        dropout=dropout,
        target_modules=target_modules
    )
    flops_per_epoch = probe.compute_flops(
        config.finetune_subset_size,
        config.max_seq_length
    )

    if flops_per_epoch <= 0:
        # A zero rank, empty module list, empty dataset or zero sequence length
        # makes the per-epoch cost zero; no epoch count can then match a budget.
        raise ValueError(
            f"Cannot calibrate {treatment_name}: estimated FLOPs per epoch is "
            f"{flops_per_epoch}"
        )

    # FLOPs scale linearly with epochs; never go below one epoch.
    isodose_epochs = max(1, int(round(target_flops / flops_per_epoch)))

    calibrated = TreatmentDose(
        treatment_name=treatment_name,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        epochs=isodose_epochs,
        lr=lr,
        dropout=dropout,
        target_modules=target_modules,
        trainable_params=probe.trainable_params,
        estimated_flops=flops_per_epoch * isodose_epochs
    )

    print(f"      Original epochs: {original_epochs}")
    print(f"      Isodose epochs: {isodose_epochs}")
    if target_flops > 0:
        print(f"      FLOP ratio: {calibrated.estimated_flops / target_flops:.2f}x")
    else:
        # Avoid dividing by a zero reference budget.
        print("      FLOP ratio: n/a (reference FLOPs is zero)")

    return calibrated


def calculate_isodose_parameters(reference_dose: TreatmentDose,
                                  target_treatment: str,
                                  config: ExperimentConfig) -> TreatmentDose:
    """
    Calculate parameters for a treatment to match the reference dose's FLOPs.

    ISODOSE CALIBRATION STRATEGY:
    Each treatment keeps its characteristic features:
    - Ketamine: High rank, short duration, aggressive LR
    - SSRI: Low rank, long duration, gentle LR (q_proj/v_proj only)
    - Neurosteroid: Moderate rank, high dropout

    Only EPOCHS are adjusted. Because estimated FLOPs scale linearly with
    epochs, the epoch count is solved directly and rounded to the nearest
    integer (minimum 1):

    target_epochs = target_flops / (flops_per_epoch)

    'ketamine' is not recalibrated: its configured parameters are returned
    together with the reference dose's parameter and FLOP figures.

    Args:
        reference_dose: The reference treatment dose (typically ketamine);
            its `estimated_flops` is the budget to match
        target_treatment: Name of treatment to calibrate
            ('ketamine', 'ssri' or 'neurosteroid')
        config: Experiment configuration

    Returns:
        TreatmentDose with calibrated parameters

    Raises:
        ValueError: If target_treatment is unknown, or if the treatment's
            estimated per-epoch FLOPs is not positive.
    """
    target_flops = reference_dose.estimated_flops

    print(f"\n    Calibrating {target_treatment.upper()} to match {reference_dose.estimated_flops:.2e} FLOPs")

    if target_treatment == 'ketamine':
        # Ketamine is typically the reference, return as-is
        return TreatmentDose(
            treatment_name='ketamine',
            lora_rank=config.ketamine_lora_rank,
            lora_alpha=config.ketamine_lora_alpha,
            epochs=config.ketamine_epochs,
            lr=config.ketamine_lr,
            dropout=config.ketamine_dropout,
            target_modules=config.ketamine_target_modules,
            trainable_params=reference_dose.trainable_params,
            estimated_flops=reference_dose.estimated_flops
        )

    if target_treatment == 'ssri':
        # SSRI: preserve low rank and gentle LR; attention q/v projections only.
        return _calibrate_epochs_to_flops(
            treatment_name='ssri',
            lora_rank=config.ssri_lora_rank,
            lora_alpha=config.ssri_lora_alpha,
            lr=config.ssri_lr,
            dropout=config.ssri_dropout,
            target_modules=["q_proj", "v_proj"],
            original_epochs=config.ssri_epochs,
            target_flops=target_flops,
            config=config
        )

    if target_treatment == 'neurosteroid':
        # Neurosteroid: preserve high dropout and moderate rank.
        return _calibrate_epochs_to_flops(
            treatment_name='neurosteroid',
            lora_rank=config.neuro_lora_rank,
            lora_alpha=config.neuro_lora_alpha,
            lr=config.neuro_lr,
            dropout=config.neuro_dropout,
            target_modules=["q_proj", "v_proj", "gate_proj", "up_proj"],
            original_epochs=config.neuro_epochs,
            target_flops=target_flops,
            config=config
        )

    raise ValueError(f"Unknown treatment: {target_treatment}")


def calibrate_all_treatments(config: ExperimentConfig) -> Dict[str, TreatmentDose]:
    """
    Calibrate all treatments to isodose based on the ketamine reference.

    CALIBRATION PROCESS:
    1. Calculate FLOPs for the reference treatment (ketamine)
    2. For SSRI and neurosteroid, adjust epochs to match the reference FLOPs
    3. Return dictionary of calibrated treatment doses

    The reference dose is always built from the ketamine settings. If
    `config.isodose_reference` names anything else, a warning is printed and
    ketamine is still used, so the printed banners always describe the
    reference that was actually applied.

    Args:
        config: Experiment configuration

    Returns:
        Dictionary mapping treatment names ('ketamine', 'ssri',
        'neurosteroid') to calibrated TreatmentDose objects
    """
    # The only reference this function knows how to build.
    reference_name = 'ketamine'
    requested_reference = str(config.isodose_reference).strip().lower()

    print("\n" + "=" * 70)
    print("ISODOSE CALIBRATION")
    print("=" * 70)
    print(f"  Reference treatment: {reference_name.upper()}")
    if requested_reference != reference_name:
        print(f"  [WARNING] isodose_reference={config.isodose_reference!r} is not "
              f"supported; calibrating against {reference_name.upper()} instead")
    print(f"  Calibration strategy: Adjust epochs to match reference FLOPs")
    print(f"  Preserved characteristics:")
    print(f"    - Ketamine: High rank ({config.ketamine_lora_rank}), many modules")
    print(f"    - SSRI: Low rank ({config.ssri_lora_rank}), gentle LR ({config.ssri_lr})")
    print(f"    - Neurosteroid: High dropout ({config.neuro_dropout})")
    print("-" * 70)

    # Calculate reference dose (ketamine)
    reference = TreatmentDose(
        treatment_name=reference_name,
        lora_rank=config.ketamine_lora_rank,
        lora_alpha=config.ketamine_lora_alpha,
        epochs=config.ketamine_epochs,
        lr=config.ketamine_lr,
        dropout=config.ketamine_dropout,
        target_modules=config.ketamine_target_modules
    )
    reference.compute_flops(config.finetune_subset_size, config.max_seq_length)

    print(f"\n  REFERENCE: {reference.treatment_name.upper()}")
    print(f"    LoRA rank: {reference.lora_rank}")
    print(f"    Epochs: {reference.epochs}")
    print(f"    Trainable params: {reference.trainable_params:,}")
    print(f"    Estimated FLOPs: {reference.estimated_flops:.2e}")

    # The reference goes in unchanged; the others are calibrated against it.
    calibrated_doses = {
        reference_name: reference
    }

    for treatment in ['ssri', 'neurosteroid']:
        calibrated_doses[treatment] = calculate_isodose_parameters(
            reference, treatment, config
        )

    # Print summary table
    print("\n" + "-" * 70)
    print("  ISODOSE CALIBRATION SUMMARY")
    print("-" * 70)
    print(f"\n  {'Treatment':<15} {'Rank':>6} {'Epochs':>8} {'Params':>12} {'FLOPs':>14} {'Ratio':>8}")
    print("  " + "-" * 65)

    for name, dose in calibrated_doses.items():
        if reference.estimated_flops:
            ratio_cell = f"{dose.estimated_flops / reference.estimated_flops:>7.2f}x"
        else:
            # A zero reference budget has no meaningful ratio.
            ratio_cell = f"{'n/a':>8}"
        print(f"  {name.capitalize():<15} {dose.lora_rank:>6} {dose.epochs:>8} "
              f"{dose.trainable_params:>12,} {dose.estimated_flops:>14.2e} {ratio_cell}")

    print("=" * 70)

    return calibrated_doses


# ============================================================================
# CELL 4: MODEL LOADING (fp16, NO 4-bit quantization)
# ============================================================================
def load_pruned_model(config: ExperimentConfig = None) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Load the pre-pruned checkpoint and its tokenizer, printing a short report.

    IMPLEMENTATION NOTES:
    - The checkpoint id comes from `config.pruned_model` (by default
      oopere/pruned60-llama-3.2-1B, a Llama-3.2-1B with 60% of MLP neurons
      removed via structured width pruning).
    - Weights are requested in fp16 with device_map="auto"; bitsandbytes is
      not involved.
    - If the tokenizer has no pad token, the EOS token is reused for padding.

    Args:
        config: Experiment configuration; the module-level CONFIG is used
            when omitted.

    Returns:
        Tuple of (model, tokenizer)
    """
    cfg = CONFIG if config is None else config
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    print(heavy_rule)
    print("LOADING PRE-PRUNED MODEL")
    print(heavy_rule)
    print(f"  Model: {cfg.pruned_model}")
    print("  Dtype: torch.float16 (NO 4-bit quantization)")
    print("  Device map: auto")
    print("  Expected state: 60% MLP neurons removed (fragile knowledge)")
    print(light_rule)

    # Time the model + tokenizer download/load together.
    started_at = time.time()

    model = AutoModelForCausalLM.from_pretrained(
        cfg.pruned_model,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True
    )
    tokenizer = AutoTokenizer.from_pretrained(cfg.pruned_model, trust_remote_code=True)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        print("  [INFO] Set pad_token = eos_token")

    load_time = time.time() - started_at

    # Tally parameter counts in a single pass.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    first_param = next(model.parameters())

    print(f"  Load time: {load_time:.2f}s")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print(f"  Model device: {first_param.device}")
    print(f"  Model dtype: {first_param.dtype}")
    print(heavy_rule)
    print("  [SUCCESS] Pruned model loaded (fragile state confirmed)")
    print(heavy_rule)

    return model, tokenizer


# ============================================================================
# CELL 5: EVALUATION FUNCTIONS (Proxy metrics for fragility)
# ============================================================================
def load_evaluation_datasets(config: ExperimentConfig = None) -> Dict:
    """
    Load the evaluation subsets used to measure knowledge fragility.

    Two test splits are loaded, shuffled with a fixed seed (42) and truncated:
    - 'arc':     ARC-Easy, up to `config.eval_samples` examples
    - 'lambada': LAMBADA, up to `config.eval_samples // 2` examples

    The requested size is clamped to the size of the split, so asking for
    more samples than exist yields the whole split instead of an error.

    Loading is best-effort: if a dataset cannot be loaded, a warning is
    printed and its entry is set to None.

    Args:
        config: Experiment configuration; the module-level CONFIG is used
            when omitted.

    Returns:
        Dict with keys 'arc' and 'lambada', each a dataset or None.
    """
    if config is None:
        config = CONFIG

    print("-" * 70)
    print("LOADING EVALUATION DATASETS")
    print("-" * 70)

    # (result key, display label, load_dataset positional args, requested size)
    benchmark_specs = (
        ('arc', 'ARC-Easy', ("allenai/ai2_arc", "ARC-Easy"), config.eval_samples),
        ('lambada', 'LAMBADA', ("lambada",), config.eval_samples // 2),
    )

    datasets = {}

    for key, label, load_args, requested in benchmark_specs:
        try:
            full_split = load_dataset(*load_args, split="test")
            # Clamp so an oversized request cannot raise inside select().
            keep = min(requested, len(full_split))
            subset = full_split.shuffle(seed=42).select(range(keep))
            datasets[key] = subset
            print(f"  {label}: {len(subset)} samples loaded")
        except Exception as e:
            # Deliberately broad: a missing benchmark should not abort the run.
            print(f"  [WARNING] Failed to load {label}: {e}")
            datasets[key] = None

    print("-" * 70)

    return datasets


def _first_choice_label(response: str) -> str:
    """Return the first run of alphanumeric characters in response ('' if none)."""
    label_chars = []
    for ch in response:
        if ch.isalnum():
            label_chars.append(ch)
        elif label_chars:
            break
    return "".join(label_chars)


def evaluate_arc_easy(model, tokenizer, dataset, max_samples: int = None) -> Dict:
    """
    Evaluate model on ARC-Easy multiple choice questions.

    Each example is rendered as a "Question / Choices / Answer:" prompt and
    completed greedily with up to 10 new tokens. The prediction is the first
    alphanumeric token of the completion (e.g. "B" from "B: the moon"), and it
    is correct when it equals the example's answer key, case-insensitively.

    Args:
        model: Causal LM exposing `eval()`, `generate()` and `device`
        tokenizer: Tokenizer matching the model
        dataset: Iterable of ARC examples with 'question', 'answerKey' and
            'choices' ({'label': [...], 'text': [...]}), or None
        max_samples: Optional cap on the number of examples evaluated

    Returns:
        Dict with 'accuracy' (percent), 'correct', 'total' and per-example
        'results'; or {'accuracy': 0.0, 'error': ...} when dataset is None.
    """
    if dataset is None:
        return {'accuracy': 0.0, 'error': 'Dataset not loaded'}

    model.eval()
    correct = 0
    total = 0
    results = []

    samples = dataset if max_samples is None else dataset.select(range(min(max_samples, len(dataset))))

    for example in samples:
        question = example['question']
        answer_key = example['answerKey']
        choices = example['choices']

        choice_text = "\n".join([
            f"{label}: {text}"
            for label, text in zip(choices['label'], choices['text'])
        ])

        prompt = f"Question: {question}\n\nChoices:\n{choice_text}\n\nAnswer:"

        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(model.device)

        with torch.no_grad():
            # Greedy decoding; no temperature is passed because it only
            # applies to sampling.
            outputs = model.generate(
                **inputs,
                max_new_tokens=10,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id
            )

        # Decode only the newly generated tokens. Slicing the decoded string
        # by len(prompt) breaks whenever decode(encode(prompt)) != prompt
        # (truncation, whitespace clean-up).
        prompt_token_count = inputs['input_ids'].shape[1]
        completion = tokenizer.decode(outputs[0][prompt_token_count:], skip_special_tokens=True)
        response = completion.strip().upper()

        # Compare the leading label token. A substring test would count any
        # completion that merely contains the key letter (e.g. "EARTH" for "A").
        is_correct = _first_choice_label(response) == answer_key.strip().upper()

        if is_correct:
            correct += 1
        total += 1

        results.append({
            'question': question[:50] + '...',
            'correct_answer': answer_key,
            'response': response[:20],
            'is_correct': is_correct
        })

    accuracy = 100.0 * correct / total if total > 0 else 0.0

    return {
        'accuracy': accuracy,
        'correct': correct,
        'total': total,
        'results': results
    }


def evaluate_lambada(model, tokenizer, dataset, max_samples: int = None) -> Dict:
    """
    Evaluate model on LAMBADA word prediction task.

    For each passage the final whitespace-separated word is held out, the
    remainder is completed greedily with up to 5 new tokens, and the first
    generated word is compared with the held-out word (both lower-cased and
    stripped of '.,!?'). A prediction is correct when it is non-empty and the
    target equals it or starts with it. Passages with fewer than two words
    are skipped.

    Args:
        model: Causal LM exposing `eval()`, `generate()` and `device`
        tokenizer: Tokenizer matching the model
        dataset: Iterable of examples with a 'text' field, or None
        max_samples: Optional cap on the number of examples evaluated

    Returns:
        Dict with 'accuracy' (percent), 'correct', 'total' and per-example
        'results'; or {'accuracy': 0.0, 'error': ...} when dataset is None.
    """
    if dataset is None:
        return {'accuracy': 0.0, 'error': 'Dataset not loaded'}

    model.eval()
    correct = 0
    total = 0
    results = []

    samples = dataset if max_samples is None else dataset.select(range(min(max_samples, len(dataset))))

    for example in samples:
        text = example['text']
        words = text.split()

        if len(words) < 2:
            continue

        target_word = words[-1].lower().strip('.,!?')
        context = ' '.join(words[:-1])

        inputs = tokenizer(
            context,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(model.device)

        with torch.no_grad():
            # Greedy decoding; no temperature is passed because it only
            # applies to sampling.
            outputs = model.generate(
                **inputs,
                max_new_tokens=5,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id
            )

        # Decode only the newly generated tokens. Slicing the decoded string
        # by len(context) breaks whenever decoding does not reproduce the
        # context exactly.
        context_token_count = inputs['input_ids'].shape[1]
        completion = tokenizer.decode(outputs[0][context_token_count:], skip_special_tokens=True)
        continuation = completion.strip().lower()
        prediction = continuation.split()[0] if continuation else ""
        prediction = prediction.strip('.,!?')

        # An empty prediction must not score: every string starts with "".
        is_correct = bool(prediction) and target_word.startswith(prediction)

        if is_correct:
            correct += 1
        total += 1

        results.append({
            'target': target_word,
            'prediction': prediction,
            'is_correct': is_correct
        })

    accuracy = 100.0 * correct / total if total > 0 else 0.0

    return {
        'accuracy': accuracy,
        'correct': correct,
        'total': total,
        'results': results
    }


def evaluate_model_full(model, tokenizer, eval_datasets: Dict,
                        config: ExperimentConfig = None) -> Dict:
    """
    Run both proxy benchmarks and combine them into one score.

    Args:
        model: Model to evaluate
        tokenizer: Tokenizer matching the model
        eval_datasets: Mapping with optional 'arc' and 'lambada' entries
        config: Experiment configuration; the module-level CONFIG is used
            when omitted.

    Returns:
        Dict with 'arc_easy' and 'lambada' accuracies (percent) and
        'composite', their unweighted mean.
    """
    config = CONFIG if config is None else config

    # (result key, label shown in the log, evaluator, key into eval_datasets)
    benchmarks = (
        ('arc_easy', "ARC-Easy", evaluate_arc_easy, 'arc'),
        ('lambada', "LAMBADA", evaluate_lambada, 'lambada'),
    )

    scores = {}
    for score_key, label, evaluator, dataset_key in benchmarks:
        print(f"    Running {label} evaluation...", end=" ")
        accuracy = evaluator(model, tokenizer, eval_datasets.get(dataset_key))['accuracy']
        scores[score_key] = accuracy
        print(f"Accuracy: {accuracy:.1f}%")

    scores['composite'] = (scores['arc_easy'] + scores['lambada']) / 2

    return scores


# ============================================================================
# CELL 6: DATASET PREPARATION
# ============================================================================
def prepare_finetune_dataset(tokenizer, config: ExperimentConfig, seed: int):
    """
    Build the tokenized fine-tuning set used for knowledge-recovery training.

    Steps: load the train split of `config.finetune_dataset`, shuffle it with
    `seed`, keep at most `config.finetune_subset_size` rows, render each row
    as an "### Instruction / ### Response" prompt, then tokenize with
    truncation and padding to `config.max_seq_length`.

    Args:
        tokenizer: Tokenizer used to encode the prompts
        config: Experiment configuration
        seed: Shuffle seed (controls which rows are selected)

    Returns:
        The tokenized dataset in torch format; the original columns are removed.
    """
    rule = "-" * 70
    print(rule)
    print(f"PREPARING FINE-TUNING DATASET (seed={seed})")
    print(rule)

    raw = load_dataset(config.finetune_dataset, split="train")
    subset_len = min(config.finetune_subset_size, len(raw))
    dataset = raw.shuffle(seed=seed).select(range(subset_len))

    print(f"  Dataset: {config.finetune_dataset}")
    print(f"  Subset size: {len(dataset)} samples")

    def _to_prompt(row):
        # 'context' / 'text' are only fallbacks for rows lacking the primary keys.
        fallback_instruction = row.get('context', '')
        fallback_response = row.get('text', '')
        instruction = row.get('instruction', fallback_instruction)
        response = row.get('response', fallback_response)
        return {'text': f"### Instruction:\n{instruction}\n\n### Response:\n{response}"}

    def _encode(batch):
        return tokenizer(
            batch['text'],
            truncation=True,
            max_length=config.max_seq_length,
            padding='max_length'
        )

    dataset = dataset.map(_to_prompt)
    # Drop every pre-existing column so only tokenizer outputs remain.
    dataset = dataset.map(_encode, batched=True, remove_columns=dataset.column_names)
    dataset.set_format('torch')

    print(f"  Tokenized: max_length={config.max_seq_length}")
    print(rule)

    return dataset


# ============================================================================
# CELL 7: TREATMENT APPLICATION WITH ISODOSE SUPPORT
# ============================================================================
def apply_treatment_with_dose(base_model, tokenizer, train_dataset,
                               dose: TreatmentDose, config: ExperimentConfig,
                               seed: int) -> Tuple[nn.Module, Dict]:
    """
    Apply a treatment using calibrated isodose parameters.

    UNIFIED TREATMENT APPLICATION:
    Any treatment is applied from the parameters of a TreatmentDose object,
    so default and isodose experiments share the same code path. The base
    model is deep-copied, wrapped with a LoRA adapter built from `dose`, and
    fine-tuned with the Hugging Face Trainer.

    SEEDING:
    `seed` is set on torch before the adapter is created and is also passed to
    TrainingArguments. Without the latter, Trainer re-seeds with its default
    seed, so runs with different `seed` values would share the same data
    order and dropout stream.

    Args:
        base_model: The pruned model to treat (left untouched; a copy is trained)
        tokenizer: Tokenizer
        train_dataset: Fine-tuning dataset
        dose: TreatmentDose object with calibrated parameters
        config: Experiment configuration
        seed: Random seed

    Returns:
        Tuple of (treated_model, treatment_stats)
    """
    print("\n" + "=" * 70)
    print(f"APPLYING {dose.treatment_name.upper()} TREATMENT")
    if config.isodose_mode:
        print("  [ISODOSE MODE]")
    print("=" * 70)
    print(f"  LoRA rank: {dose.lora_rank}")
    print(f"  LoRA alpha: {dose.lora_alpha}")
    print(f"  Training epochs: {dose.epochs}")
    print(f"  Learning rate: {dose.lr}")
    print(f"  Dropout: {dose.dropout}")
    print(f"  Target modules: {dose.target_modules}")
    print(f"  Estimated FLOPs: {dose.estimated_flops:.2e}")
    print("-" * 70)

    # Seeds the adapter initialisation performed by get_peft_model below.
    torch.manual_seed(seed)

    # Train a copy so the caller's base model can be reused for other treatments.
    model = copy.deepcopy(base_model)

    lora_config = LoraConfig(
        r=dose.lora_rank,
        lora_alpha=dose.lora_alpha,
        target_modules=dose.target_modules,
        lora_dropout=dose.dropout,
        bias="none",
        task_type="CAUSAL_LM"
    )

    model = get_peft_model(model, lora_config)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())

    print(f"  Actual trainable parameters: {trainable_params:,} ({100*trainable_params/total_params:.2f}%)")

    training_args = TrainingArguments(
        output_dir=os.path.join(config.output_dir, f"{dose.treatment_name}_seed{seed}"),
        num_train_epochs=dose.epochs,
        learning_rate=dose.lr,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        warmup_ratio=0.1,
        logging_steps=50,
        save_strategy="no",
        # fp16 mixed precision is only requested when CUDA is present;
        # config.device may be "cpu".
        fp16=torch.cuda.is_available(),
        # Propagate the experiment seed so each seed gets its own data order.
        seed=seed,
        report_to="none",
        remove_unused_columns=False
    )

    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator
    )

    print("  Starting training...")
    start_time = time.time()
    trainer.train()
    train_time = time.time() - start_time

    print(f"  Training complete in {train_time:.1f}s")
    print("=" * 70)

    stats = {
        'treatment': dose.treatment_name,
        'isodose_mode': config.isodose_mode,
        'lora_rank': dose.lora_rank,
        'epochs': dose.epochs,
        'lr': dose.lr,
        'dropout': dose.dropout,
        'trainable_params': trainable_params,
        'estimated_flops': dose.estimated_flops,
        'train_time': train_time
    }

    return model, stats


def apply_ketamine_treatment(base_model, tokenizer, train_dataset,
                             config: ExperimentConfig, seed: int,
                             dose: TreatmentDose = None):
    """
    Apply the ketamine-like treatment (compatibility wrapper).

    When `dose` is omitted, a dose is built from the `ketamine_*` settings in
    `config` and its FLOPs are estimated before training. Returns whatever
    `apply_treatment_with_dose` returns.
    """
    if dose is not None:
        return apply_treatment_with_dose(base_model, tokenizer, train_dataset,
                                          dose, config, seed)

    # No calibrated dose supplied: fall back to the configured defaults.
    default_dose = TreatmentDose(
        treatment_name='ketamine',
        lora_rank=config.ketamine_lora_rank,
        lora_alpha=config.ketamine_lora_alpha,
        epochs=config.ketamine_epochs,
        lr=config.ketamine_lr,
        dropout=config.ketamine_dropout,
        target_modules=config.ketamine_target_modules
    )
    default_dose.compute_flops(config.finetune_subset_size, config.max_seq_length)

    return apply_treatment_with_dose(base_model, tokenizer, train_dataset,
                                      default_dose, config, seed)


def apply_ssri_treatment(base_model, tokenizer, train_dataset,
                         config: ExperimentConfig, seed: int,
                         dose: TreatmentDose = None):
    """
    Apply the SSRI-like treatment (compatibility wrapper).

    When `dose` is omitted, a dose is built from the `ssri_*` settings in
    `config`, targeting only q_proj and v_proj, and its FLOPs are estimated
    before training. Returns whatever `apply_treatment_with_dose` returns.
    """
    if dose is not None:
        return apply_treatment_with_dose(base_model, tokenizer, train_dataset,
                                          dose, config, seed)

    # No calibrated dose supplied: fall back to the configured defaults.
    default_dose = TreatmentDose(
        treatment_name='ssri',
        lora_rank=config.ssri_lora_rank,
        lora_alpha=config.ssri_lora_alpha,
        epochs=config.ssri_epochs,
        lr=config.ssri_lr,
        dropout=config.ssri_dropout,
        target_modules=["q_proj", "v_proj"]
    )
    default_dose.compute_flops(config.finetune_subset_size, config.max_seq_length)

    return apply_treatment_with_dose(base_model, tokenizer, train_dataset,
                                      default_dose, config, seed)


def apply_neurosteroid_treatment(base_model, tokenizer, train_dataset,
                                  config: ExperimentConfig, seed: int,
                                  dose: TreatmentDose = None):
    """Run the neurosteroid-like arm (compatibility wrapper).

    When ``dose`` is None a default ``TreatmentDose`` is assembled from the
    ``neuro_*`` fields of ``config`` (targeting ``q_proj``, ``v_proj``,
    ``gate_proj`` and ``up_proj``) and its FLOPs estimate is computed from
    ``config.finetune_subset_size`` and ``config.max_seq_length``.
    A caller-supplied ``dose`` is forwarded untouched.

    Returns:
        Whatever ``apply_treatment_with_dose`` returns (callers in this file
        unpack it as a ``(model, stats)`` pair).
    """
    if dose is None:
        # Default (non-calibrated) neurosteroid settings; the target modules
        # are fixed here rather than read from config.
        default_settings = dict(
            treatment_name='neurosteroid',
            lora_rank=config.neuro_lora_rank,
            lora_alpha=config.neuro_lora_alpha,
            epochs=config.neuro_epochs,
            lr=config.neuro_lr,
            dropout=config.neuro_dropout,
            target_modules=["q_proj", "v_proj", "gate_proj", "up_proj"],
        )
        dose = TreatmentDose(**default_settings)
        dose.compute_flops(config.finetune_subset_size, config.max_seq_length)

    return apply_treatment_with_dose(
        base_model, tokenizer, train_dataset, dose, config, seed
    )


# ============================================================================
# CELL 8: ACUTE RELAPSE SIMULATION
# ============================================================================
def apply_acute_relapse(model, prune_amount: float = 0.30):
    """Simulate acute relapse by applying unstructured L1 pruning in place.

    Every ``nn.Linear`` reachable through ``model.named_modules()`` is pruned
    with ``prune.l1_unstructured`` on its ``weight`` and the pruning
    re-parametrization is then removed, so the zeros are written into
    ``weight`` itself.

    Pruning stays best-effort: a layer that raises is skipped rather than
    aborting the run, but skipped layers are counted and reported instead of
    being silently ignored.

    Args:
        model: Module whose linear layers are pruned (mutated in place).
        prune_amount: Value passed as ``amount`` to ``l1_unstructured``.

    Returns:
        The same ``model`` object that was passed in.
    """
    print(f"    Applying acute relapse: {prune_amount*100:.0f}% unstructured pruning")

    pruned_layers = 0
    skipped = []  # (module name, exception) for every layer that failed

    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        try:
            prune.l1_unstructured(module, name='weight', amount=prune_amount)
            # Make the pruning permanent (drops weight_orig / weight_mask).
            prune.remove(module, 'weight')
        except Exception as exc:
            skipped.append((name, exc))
        else:
            pruned_layers += 1

    print(f"    Pruned {pruned_layers} linear layers")
    if skipped:
        first_name, first_exc = skipped[0]
        print(f"    WARNING: skipped {len(skipped)} linear layer(s); "
              f"first failure in {first_name!r}: {first_exc}")

    return model


# ============================================================================
# CELL 9: LONGITUDINAL SIMULATION
# ============================================================================
def _apply_chronic_stress(model, amount: float) -> int:
    """Randomly prune ``amount`` of each nn.Linear weight in place; return how many layers were skipped."""
    skipped = 0
    for _, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        try:
            prune.random_unstructured(module, name='weight', amount=amount)
            prune.remove(module, 'weight')
        except Exception:
            # Best effort: keep going, but let the caller report the count.
            # (Not a bare ``except`` so Ctrl-C / SystemExit still propagate.)
            skipped += 1
    return skipped


def run_longitudinal_simulation(treated_model, tokenizer, train_dataset,
                                 eval_datasets, treatment_name: str,
                                 config: ExperimentConfig):
    """Simulate longitudinal course with chronic stress and maintenance.

    For ``config.longitudinal_cycles + 1`` evaluation points the model is
    scored on ARC-Easy (50 samples). Between evaluation points every
    ``nn.Linear`` is randomly pruned by ``config.additional_prune_per_cycle``
    and, when the treatment's maintenance epoch count is positive, the
    trainable parameters receive a short AdamW pass over the first (up to)
    100 training examples.

    Args:
        treated_model: Model to stress; it is mutated in place.
        tokenizer: Tokenizer handed to the evaluator and the LM data collator.
        train_dataset: Dataset supporting ``select`` and ``len``; source of
            the maintenance examples.
        eval_datasets: Mapping whose ``'arc'`` entry is passed to
            ``evaluate_arc_easy``.
        treatment_name: ``'ketamine'`` or ``'ssri'``; any other value selects
            the neurosteroid maintenance schedule.
        config: Experiment configuration.

    Returns:
        Dict with parallel lists ``'accuracy'`` and ``'cycle'``.
    """
    print(f"\n  --- Longitudinal Simulation for {treatment_name.upper()} ---")

    trajectory = {
        'accuracy': [],
        'cycle': []
    }

    model = treated_model

    if treatment_name == 'ketamine':
        maintenance_epochs = config.ketamine_maintenance_epochs
    elif treatment_name == 'ssri':
        maintenance_epochs = config.ssri_maintenance_epochs
    else:
        maintenance_epochs = config.neuro_maintenance_epochs

    # The maintenance batches never change between cycles (fixed subset, no
    # shuffling), so the loader is built once instead of once per cycle.
    train_loader = None
    if maintenance_epochs > 0 and config.longitudinal_cycles > 0:
        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
        maintenance_data = train_dataset.select(range(min(100, len(train_dataset))))
        train_loader = DataLoader(
            maintenance_data,
            batch_size=4,
            collate_fn=data_collator
        )

    for cycle in range(config.longitudinal_cycles + 1):
        print(f"    Cycle {cycle}:", end=" ")

        model.eval()
        with torch.no_grad():
            arc_result = evaluate_arc_easy(model, tokenizer, eval_datasets.get('arc'), max_samples=50)

        accuracy = arc_result['accuracy']
        trajectory['accuracy'].append(accuracy)
        trajectory['cycle'].append(cycle)

        print(f"ARC-Easy = {accuracy:.1f}%", end="")

        # The last iteration is evaluation-only: no stress after the final point.
        if cycle >= config.longitudinal_cycles:
            print()
            continue

        skipped_layers = _apply_chronic_stress(model, config.additional_prune_per_cycle)

        print(f" -> Stress applied ({config.additional_prune_per_cycle*100:.0f}% prune)", end="")
        if skipped_layers:
            print(f" [WARNING: {skipped_layers} layer(s) skipped]", end="")

        if train_loader is None:
            print()
            continue

        model.train()
        # A fresh optimizer per cycle, so AdamW moment estimates do not carry
        # over from one maintenance phase to the next.
        optimizer = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=1e-6
        )

        for _ in range(maintenance_epochs):
            for batch in train_loader:
                batch = {k: v.to(model.device) for k, v in batch.items()}
                optimizer.zero_grad()
                outputs = model(**batch)
                outputs.loss.backward()
                optimizer.step()

        print(f" -> Maintenance ({maintenance_epochs} epochs)")

    return trajectory


# ============================================================================
# CELL 10: SINGLE SEED EXPERIMENT WITH ISODOSE SUPPORT
# ============================================================================
def _release_cached_memory():
    """Run the garbage collector and, when CUDA is available, empty its allocator cache."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def run_single_seed(seed: int, config: ExperimentConfig = None,
                    calibrated_doses: Dict[str, TreatmentDose] = None) -> Dict:
    """
    Run complete experiment for a single random seed.

    ISODOSE INTEGRATION:
    If calibrated_doses is provided and config.isodose_mode is True,
    treatments will use the isodose-calibrated parameters.
    Otherwise, default parameters from config are used.

    Pipeline per seed: evaluate the untreated pruned model, then for each of
    ketamine / ssri / neurosteroid: treat a freshly loaded pruned model,
    evaluate, apply acute relapse, evaluate again, and finally treat a second
    fresh model for the longitudinal simulation. Models are released as soon
    as they are no longer needed so that at most the models of one phase are
    referenced at a time.

    Args:
        seed: Seed for torch and numpy; also forwarded to dataset
            preparation and the treatment functions.
        config: Experiment configuration; falls back to the module-level
            ``CONFIG`` when None.
        calibrated_doses: Optional mapping from treatment name to its
            calibrated ``TreatmentDose``.

    Returns:
        Dict with keys ``'seed'``, ``'isodose_mode'``, ``'treatments'`` and
        ``'longitudinal'``.
    """
    if config is None:
        config = CONFIG

    torch.manual_seed(seed)
    np.random.seed(seed)

    print("\n" + "#" * 70)
    print(f"#  SEED {seed} EXPERIMENT")
    if config.isodose_mode:
        print("#  [ISODOSE MODE ACTIVE]")
    print("#" * 70)

    results = {
        'seed': seed,
        'isodose_mode': config.isodose_mode,
        'treatments': {},
        'longitudinal': {}
    }

    eval_datasets = load_evaluation_datasets(config)

    # -------------------------------------------------------------------------
    # STEP 1: Load pruned model and evaluate untreated baseline
    # -------------------------------------------------------------------------
    print("\n[STEP 1] Loading pruned model (untreated baseline)")
    base_model, tokenizer = load_pruned_model(config)

    print("\n[STEP 2] Evaluating UNTREATED (pruned) baseline")
    print("-" * 70)
    untreated_results = evaluate_model_full(base_model, tokenizer, eval_datasets, config)

    # The baseline model is only needed for this evaluation; every treatment
    # reloads its own copy, so drop it before more models are loaded.
    del base_model
    _release_cached_memory()

    results['treatments']['untreated'] = {
        'post_treatment': untreated_results,
        'post_relapse': None,
        'relapse_drop': None
    }

    print(f"  UNTREATED BASELINE:")
    print(f"    ARC-Easy: {untreated_results['arc_easy']:.1f}%")
    print(f"    LAMBADA:  {untreated_results['lambada']:.1f}%")
    print(f"    Composite: {untreated_results['composite']:.1f}%")

    train_dataset = prepare_finetune_dataset(tokenizer, config, seed)

    # -------------------------------------------------------------------------
    # STEP 3: Apply treatments and evaluate
    # -------------------------------------------------------------------------
    # Insertion order defines the processing order and the step numbering.
    treatment_appliers = {
        'ketamine': apply_ketamine_treatment,
        'ssri': apply_ssri_treatment,
        'neurosteroid': apply_neurosteroid_treatment,
    }
    use_isodose = bool(config.isodose_mode and calibrated_doses)

    for step_idx, (treatment_name, apply_treatment) in enumerate(
            treatment_appliers.items(), start=1):
        print(f"\n[STEP 3.{step_idx}] Processing {treatment_name.upper()} treatment")

        # Get dose (isodose-calibrated or default)
        if use_isodose:
            dose = calibrated_doses[treatment_name]
            print(f"  Using ISODOSE parameters (target FLOPs: {dose.estimated_flops:.2e})")
        else:
            dose = None
            print(f"  Using DEFAULT parameters")

        print("  Reloading fresh pruned model...")
        fresh_model, _ = load_pruned_model(config)

        # Apply treatment
        treated_model, treatment_stats = apply_treatment(
            fresh_model, tokenizer, train_dataset, config, seed, dose
        )

        # Evaluate post-treatment
        print(f"\n  Evaluating POST-{treatment_name.upper()} performance:")
        post_treatment = evaluate_model_full(treated_model, tokenizer, eval_datasets, config)

        print(f"    ARC-Easy: {post_treatment['arc_easy']:.1f}%")
        print(f"    LAMBADA:  {post_treatment['lambada']:.1f}%")
        print(f"    Composite: {post_treatment['composite']:.1f}%")

        recovery = post_treatment['composite'] - untreated_results['composite']
        print(f"    Recovery from untreated: {recovery:+.1f}%")

        # Calculate efficiency metric for isodose comparison
        if use_isodose:
            flops = treatment_stats.get('estimated_flops')
            if flops:
                efficiency = recovery / (flops / 1e15)  # Recovery per PetaFLOP
                print(f"    Efficiency (recovery/PetaFLOP): {efficiency:.2f}")
                treatment_stats['efficiency'] = efficiency
            else:
                # A zero or missing estimate would otherwise raise here; the
                # aggregation step already tolerates a missing 'efficiency'.
                print("    Efficiency (recovery/PetaFLOP): N/A (no FLOPs estimate)")

        # Apply acute relapse
        print(f"\n  Simulating ACUTE RELAPSE:")
        apply_acute_relapse(treated_model, config.acute_relapse_prune_amount)

        # Evaluate post-relapse
        print(f"  Evaluating POST-RELAPSE performance:")
        post_relapse = evaluate_model_full(treated_model, tokenizer, eval_datasets, config)

        relapse_drop = post_treatment['composite'] - post_relapse['composite']
        print(f"    ARC-Easy: {post_relapse['arc_easy']:.1f}%")
        print(f"    LAMBADA:  {post_relapse['lambada']:.1f}%")
        print(f"    Composite: {post_relapse['composite']:.1f}%")
        print(f"    Relapse drop: {relapse_drop:.1f}%")

        results['treatments'][treatment_name] = {
            'stats': treatment_stats,
            'post_treatment': post_treatment,
            'post_relapse': post_relapse,
            'recovery': recovery,
            'relapse_drop': relapse_drop
        }

        # The relapsed model is finished with; free it before the
        # longitudinal phase loads and trains another copy.
        del treated_model, fresh_model
        _release_cached_memory()

        # Run longitudinal simulation
        print(f"\n  Running LONGITUDINAL simulation for {treatment_name}:")

        fresh_model2, _ = load_pruned_model(config)

        long_model, _ = apply_treatment(
            fresh_model2, tokenizer, train_dataset, config, seed, dose
        )

        trajectory = run_longitudinal_simulation(
            long_model, tokenizer, train_dataset, eval_datasets,
            treatment_name, config
        )

        results['longitudinal'][treatment_name] = trajectory

        del long_model, fresh_model2
        _release_cached_memory()

    print("\n" + "#" * 70)
    print(f"#  SEED {seed} COMPLETE")
    print("#" * 70)

    return results


# ============================================================================
# CELL 11: RESULTS AGGREGATION AND PRINTING
# ============================================================================
def aggregate_results(all_results: List[Dict]) -> Dict:
    """Combine per-seed result dicts into mean/std summaries.

    For each of the four arms (untreated plus the three treatments) every
    numeric post-treatment and post-relapse metric, the recovery, the relapse
    drop and, when present in ``stats``, the efficiency and estimated FLOPs
    are collected across seeds and summarised as ``mean`` / ``std`` /
    ``values``. Longitudinal accuracy trajectories of the three treatments
    are stacked per seed and summarised per cycle.

    Returns:
        Dict with ``'n_seeds'``, ``'isodose_mode'`` (taken from the first
        seed, False when there are no results), ``'treatments'`` and
        ``'longitudinal'``.
    """
    def _collect_numeric(bucket, prefix, section):
        # Only plain int/float entries are aggregated; anything else is ignored.
        for metric_name, metric_value in section.items():
            if isinstance(metric_value, (int, float)):
                bucket[f'{prefix}_{metric_name}'].append(metric_value)

    if all_results:
        isodose_flag = all_results[0].get('isodose_mode', False)
    else:
        isodose_flag = False

    summary = {
        'n_seeds': len(all_results),
        'isodose_mode': isodose_flag,
        'treatments': {},
        'longitudinal': {}
    }

    for arm in ('untreated', 'ketamine', 'ssri', 'neurosteroid'):
        collected = defaultdict(list)

        for run in all_results:
            if arm not in run.get('treatments', {}):
                continue
            arm_data = run['treatments'][arm]

            if 'post_treatment' in arm_data and arm_data['post_treatment']:
                _collect_numeric(collected, 'post', arm_data['post_treatment'])

            for scalar_key in ('recovery', 'relapse_drop'):
                scalar = arm_data.get(scalar_key)
                if scalar is not None:
                    collected[scalar_key].append(scalar)

            # Collect isodose-specific metrics
            if 'stats' in arm_data and arm_data['stats']:
                run_stats = arm_data['stats']
                if 'efficiency' in run_stats:
                    collected['efficiency'].append(run_stats['efficiency'])
                if 'estimated_flops' in run_stats:
                    collected['flops'].append(run_stats['estimated_flops'])

            if 'post_relapse' in arm_data and arm_data['post_relapse']:
                _collect_numeric(collected, 'relapse', arm_data['post_relapse'])

        summary['treatments'][arm] = {
            metric: {
                'mean': np.mean(samples),
                'std': np.std(samples),
                'values': samples
            }
            for metric, samples in collected.items()
            if samples
        }

    # Aggregate longitudinal
    for arm in ('ketamine', 'ssri', 'neurosteroid'):
        per_seed_curves = [
            run['longitudinal'][arm]['accuracy']
            for run in all_results
            if 'longitudinal' in run and arm in run['longitudinal']
        ]
        if not per_seed_curves:
            continue

        curves = np.array(per_seed_curves)
        first_point = curves[:, 0]
        last_point = curves[:, -1]
        total_drop = first_point - last_point

        summary['longitudinal'][arm] = {
            'accuracy_mean': np.mean(curves, axis=0).tolist(),
            'accuracy_std': np.std(curves, axis=0).tolist(),
            'initial_mean': float(np.mean(first_point)),
            'final_mean': float(np.mean(last_point)),
            'total_drop_mean': float(np.mean(total_drop)),
            'total_drop_std': float(np.std(total_drop))
        }

    return summary


def print_final_results(aggregated: Dict, config: ExperimentConfig,
                        calibrated_doses: Dict[str, TreatmentDose] = None):
    """Print comprehensive final results tables.

    Prints, in order: the isodose configuration (isodose mode with
    ``calibrated_doses`` only), Table 1 (post-treatment performance),
    Table 2 (recovery / relapse drop), Table 3 (isodose efficiency, isodose
    mode only), Table 4 (longitudinal trajectory, when any is available) and
    an interpretation guide.

    Missing data is rendered as ``N/A`` rather than as a made-up zero, arms
    without efficiency data are left out of the efficiency ranking, and
    Table 4 is limited to the arms that actually have a trajectory.

    Args:
        aggregated: Output of ``aggregate_results``.
        config: Experiment configuration (not read by this function; kept
            for interface compatibility).
        calibrated_doses: Optional mapping from treatment name to the
            calibrated dose shown in the isodose configuration summary.
    """
    treated_arms = ['ketamine', 'ssri', 'neurosteroid']

    print("\n")
    print("=" * 80)
    print("=" * 80)
    print("  FINAL AGGREGATED RESULTS")
    print(f"  {aggregated['n_seeds']} seeds completed")
    if aggregated.get('isodose_mode'):
        print("  [ISODOSE COMPARISON MODE]")
    print("=" * 80)
    print("=" * 80)

    # -------------------------------------------------------------------------
    # ISODOSE CONFIGURATION SUMMARY (if applicable)
    # -------------------------------------------------------------------------
    if aggregated.get('isodose_mode') and calibrated_doses:
        print("\n" + "-" * 80)
        print("  ISODOSE CONFIGURATION")
        print("  (All treatments calibrated to equivalent computational cost)")
        print("-" * 80)

        print(f"\n  {'Treatment':<15} {'Rank':>6} {'Epochs':>8} {'Dropout':>8} {'FLOPs':>14}")
        print("  " + "-" * 55)

        for name in treated_arms:
            if name in calibrated_doses:
                dose = calibrated_doses[name]
                print(f"  {name.capitalize():<15} {dose.lora_rank:>6} {dose.epochs:>8} "
                      f"{dose.dropout:>8.2f} {dose.estimated_flops:>14.2e}")

    # -------------------------------------------------------------------------
    # TABLE 1: Post-Treatment Performance
    # -------------------------------------------------------------------------
    print("\n" + "-" * 80)
    print("  TABLE 1: POST-TREATMENT PERFORMANCE")
    print("  (Mean ± Std across seeds, percentages)")
    print("-" * 80)

    print(f"\n  {'Treatment':<20} {'ARC-Easy':>15} {'LAMBADA':>15} {'Composite':>15}")
    print("  " + "-" * 65)

    labels = {
        'untreated': 'Untreated (pruned)',
        'ketamine': 'Ketamine-like',
        'ssri': 'SSRI-like',
        'neurosteroid': 'Neurosteroid-like'
    }

    for treat in ['untreated'] + treated_arms:
        if treat in aggregated['treatments']:
            data = aggregated['treatments'][treat]

            arc = data.get('post_arc_easy', {})
            lamb = data.get('post_lambada', {})
            comp = data.get('post_composite', {})

            arc_str = f"{arc.get('mean', 0):.1f}±{arc.get('std', 0):.1f}" if arc else "N/A"
            lamb_str = f"{lamb.get('mean', 0):.1f}±{lamb.get('std', 0):.1f}" if lamb else "N/A"
            comp_str = f"{comp.get('mean', 0):.1f}±{comp.get('std', 0):.1f}" if comp else "N/A"

            print(f"  {labels[treat]:<20} {arc_str:>15} {lamb_str:>15} {comp_str:>15}")

    # -------------------------------------------------------------------------
    # TABLE 2: Treatment Effects (Recovery and Relapse Resilience)
    # -------------------------------------------------------------------------
    print("\n" + "-" * 80)
    print("  TABLE 2: TREATMENT EFFECTS")
    print("  (Recovery = improvement over untreated, Relapse Drop = loss after acute stress)")
    print("-" * 80)

    print(f"\n  {'Treatment':<20} {'Recovery':>18} {'Relapse Drop':>18} {'Net Retained':>15}")
    print("  " + "-" * 70)

    for treat in treated_arms:
        if treat in aggregated['treatments']:
            data = aggregated['treatments'][treat]

            rec = data.get('recovery', {})
            drop = data.get('relapse_drop', {})

            rec_mean = rec.get('mean', 0)
            rec_std = rec.get('std', 0)
            drop_mean = drop.get('mean', 0)
            drop_std = drop.get('std', 0)

            # Same convention as Table 1: absent metrics show as N/A instead
            # of a zero that could be mistaken for a measured value.
            rec_str = f"{rec_mean:+.1f}±{rec_std:.1f}%" if rec else "N/A"
            drop_str = f"{drop_mean:.1f}±{drop_std:.1f}%" if drop else "N/A"
            net_str = f"{rec_mean - drop_mean:+.1f}%" if rec and drop else "N/A"

            print(f"  {labels[treat]:<20} {rec_str:>18} {drop_str:>18} {net_str:>15}")

    # -------------------------------------------------------------------------
    # TABLE 3: ISODOSE EFFICIENCY COMPARISON (NEW)
    # -------------------------------------------------------------------------
    if aggregated.get('isodose_mode'):
        print("\n" + "-" * 80)
        print("  TABLE 3: ISODOSE EFFICIENCY COMPARISON")
        print("  (Recovery per unit computational cost - higher is more efficient)")
        print("-" * 80)

        print(f"\n  {'Treatment':<20} {'Recovery':>12} {'FLOPs':>14} {'Efficiency*':>15}")
        print("  " + "-" * 60)

        efficiencies = []
        for treat in treated_arms:
            if treat in aggregated['treatments']:
                data = aggregated['treatments'][treat]

                rec = data.get('recovery', {})
                flops = data.get('flops', {}).get('mean', 0)
                eff = data.get('efficiency', {})

                if eff:
                    eff_mean = eff.get('mean', 0)
                    eff_std = eff.get('std', 0)
                    # Only arms with measured efficiency take part in the
                    # ranking; a placeholder 0 would outrank real negative
                    # efficiencies.
                    efficiencies.append((treat, eff_mean))
                    eff_str = f"{eff_mean:.2f}±{eff_std:.2f}"
                else:
                    eff_str = "N/A"

                rec_str = f"{rec.get('mean', 0):+.1f}%" if rec else "N/A"
                flops_str = f"{flops:.2e}" if flops else "N/A"

                print(f"  {labels[treat]:<20} {rec_str:>12} {flops_str:>14} {eff_str:>15}")

        print("\n  * Efficiency = Recovery (%) / PetaFLOPs")

        # Rank treatments by efficiency
        if efficiencies:
            efficiencies.sort(key=lambda x: x[1], reverse=True)
            print(f"\n  ISODOSE RANKING (by efficiency):")
            for i, (treat, eff) in enumerate(efficiencies):
                print(f"    {i+1}. {labels[treat]}: {eff:.2f}")

    # -------------------------------------------------------------------------
    # TABLE 4: Longitudinal Trajectory
    # -------------------------------------------------------------------------
    longitudinal = aggregated.get('longitudinal') or {}
    # Restrict the table to arms that actually have a trajectory so a partial
    # run does not fail on a missing key.
    long_treats = [treat for treat in treated_arms if treat in longitudinal]

    if long_treats:
        print("\n" + "-" * 80)
        print("  TABLE 4: LONGITUDINAL TRAJECTORY (ARC-Easy % over stress cycles)")
        print("-" * 80)

        # Shortest trajectory bounds the rows, so indexing below stays valid.
        n_cycles = min(len(longitudinal[treat]['accuracy_mean']) for treat in long_treats)

        header = f"\n  {'Cycle':<8}"
        for treat in long_treats:
            header += f" {treat.capitalize():>18}"
        print(header)
        print("  " + "-" * 62)

        for c in range(n_cycles):
            row = f"  {c:<8}"
            for treat in long_treats:
                mean = longitudinal[treat]['accuracy_mean'][c]
                std = longitudinal[treat]['accuracy_std'][c]
                row += f" {mean:>7.1f}±{std:>5.1f}%"
            print(row)

        print("\n  " + "-" * 62)
        print("  LONGITUDINAL SUMMARY:")

        for treat in long_treats:
            course = longitudinal[treat]
            print(f"    {labels[treat]:<20}")
            print(f"      Initial (cycle 0): {course['initial_mean']:.1f}%")
            print(f"      Final (cycle {n_cycles-1}):   {course['final_mean']:.1f}%")
            print(f"      Total drop:        {course['total_drop_mean']:.1f}±{course['total_drop_std']:.1f}%")

    # -------------------------------------------------------------------------
    # INTERPRETATION GUIDE
    # -------------------------------------------------------------------------
    print("\n" + "=" * 80)
    print("  INTERPRETATION GUIDE")
    print("=" * 80)

    if aggregated.get('isodose_mode'):
        print("""
    ISODOSE COMPARISON INTERPRETATION:

    The isodose framework enables FAIR comparison by equalizing computational cost.
    This answers: "Given the SAME therapeutic effort, which mechanism is superior?"

    KEY ISODOSE METRICS:
    1. Recovery (%) - Improvement over untreated baseline
    2. Efficiency - Recovery per PetaFLOP of computation
    3. Relapse Drop - Vulnerability to acute stress
    4. Longitudinal Stability - Durability under chronic stress

    WHAT ISODOSE REVEALS:
    - If a treatment with LOW structural capacity (SSRI) matches HIGH capacity (ketamine)
      at equal FLOPs, the MECHANISM itself may be more efficient
    - If HIGH capacity (ketamine) still wins at isodose, structural regrowth may be
      fundamentally more effective for this type of damage

    EXPECTED PATTERNS AT ISODOSE:

    KETAMINE-LIKE (r=64, short duration):
    - High efficiency IF structural regrowth is critical for knowledge recovery
    - May show diminishing returns if extra capacity isn't fully utilized

    SSRI-LIKE (r=8, extended duration):
    - At isodose, gets MORE epochs than default (scaled up to match FLOPs)
    - Tests whether gradual learning can compensate for low capacity
    - May show improved efficiency if default was "undertreated"

    NEUROSTEROID-LIKE (r=32, high dropout):
    - At isodose, dropout regularization applied over calibrated epochs
    - Tests whether inhibition-based stabilization is cost-effective

    COMPARING DEFAULT vs ISODOSE:
    - If rankings CHANGE between modes, the default comparison was unfair
    - If rankings PERSIST, the mechanism difference is robust to dosing
    """)
    else:
        print("""
    EXPECTED PATTERNS (based on mechanistic analogies):

    KETAMINE-LIKE (High-rank LoRA, r=64, aggressive training):
    - EXPECTED: Best acute recovery due to high new parameter capacity
    - EXPECTED: Smallest relapse drop (structural changes are robust)
    - EXPECTED: Best longitudinal stability under chronic stress
    - ANALOGY: Rapid synaptogenesis provides lasting structural benefits

    SSRI-LIKE (Low-rank LoRA, r=8, slow gradual training):
    - EXPECTED: Slowest, most variable recovery
    - EXPECTED: Moderate-to-large relapse vulnerability
    - EXPECTED: Variable longitudinal outcomes across seeds
    - ANALOGY: Gradual neuromodulation without major structural rewiring

    NEUROSTEROID-LIKE (Moderate LoRA, high dropout=0.5):
    - EXPECTED: Fast initial response due to regularization effects
    - EXPECTED: Intermediate relapse resilience
    - EXPECTED: May show state-dependence (dropout during training vs eval)
    - ANALOGY: Tonic inhibition provides rapid stabilization
    """)

    print("=" * 80)


# ============================================================================
# CELL 12: MAIN EXPERIMENT RUNNER
# ============================================================================
def _to_json_serializable(obj):
    """Recursively convert NumPy arrays/scalars inside *obj* to plain Python types.

    Handles every NumPy scalar type (via ``np.generic``), not just the
    32/64-bit float and int variants, and also converts NumPy scalars used as
    dict keys. Tuples are emitted as lists, which is how ``json`` would
    serialize them anyway. Anything else is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {
            (k.item() if isinstance(k, np.generic) else k): _to_json_serializable(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [_to_json_serializable(item) for item in obj]
    return obj


def _write_json(path, payload):
    """Write *payload* to *path* as indented JSON.

    The payload is serialized fully in memory first so that a serialization
    error cannot leave a truncated file on disk.
    """
    text = json.dumps(_to_json_serializable(payload), indent=2)
    with open(path, 'w') as f:
        f.write(text)


def _summarize_seed(seed, seed_results, isodose_mode):
    """Build the compact per-seed summary dict that is saved after each seed."""
    def _stats_summary(stats):
        # Treatments without stats are recorded with a null entry.
        if not stats:
            return None
        return {
            'epochs': stats.get('epochs'),
            'flops': stats.get('estimated_flops'),
            'efficiency': stats.get('efficiency'),
        }

    return {
        'seed': seed,
        'isodose_mode': isodose_mode,
        'treatments': {
            name: {
                'post_treatment': res.get('post_treatment'),
                'recovery': res.get('recovery'),
                'relapse_drop': res.get('relapse_drop'),
                'stats': _stats_summary(res.get('stats')),
            }
            for name, res in seed_results['treatments'].items()
        },
    }


def _print_experiment_header(config):
    """Print the title banner and the experiment configuration summary."""
    print("\n")
    print("#" * 80)
    print("#" * 80)
    print("#" + " " * 78 + "#")
    print("#" + " MULTI-MECHANISM RECOVERY FROM PRUNING-INDUCED FRAGILITY ".center(78) + "#")
    print("#" + " IN LARGE LANGUAGE MODELS ".center(78) + "#")
    if config.isodose_mode:
        print("#" + " [WITH ISODOSE COMPARISON] ".center(78) + "#")
    print("#" + " " * 78 + "#")
    print("#" * 80)
    print("#" * 80)

    print("\n" + "=" * 80)
    print("EXPERIMENT CONFIGURATION")
    print("=" * 80)
    print(f"  Base model (pre-pruned): {config.pruned_model}")
    print("  Pruning level: 60% MLP neurons removed (severe fragility)")
    print(f"  Number of seeds: {config.n_seeds}")
    print(f"  Evaluation samples: {config.eval_samples}")
    print(f"  Fine-tune dataset: {config.finetune_dataset}")
    print(f"  Fine-tune subset: {config.finetune_subset_size} samples")
    print(f"  Longitudinal cycles: {config.longitudinal_cycles}")
    print(f"  Acute relapse pruning: {config.acute_relapse_prune_amount*100:.0f}%")
    print(f"  Chronic stress pruning: {config.additional_prune_per_cycle*100:.0f}% per cycle")
    print(f"  ISODOSE MODE: {config.isodose_mode}")


def _print_default_treatment_parameters(config):
    """Print the un-calibrated (default) hyperparameters of each treatment."""
    print("\n" + "-" * 80)
    print("DEFAULT TREATMENT PARAMETERS")
    print("-" * 80)
    print("  KETAMINE-LIKE:")
    print(f"    LoRA rank: {config.ketamine_lora_rank}, alpha: {config.ketamine_lora_alpha}")
    print(f"    Epochs: {config.ketamine_epochs}, LR: {config.ketamine_lr}")
    print(f"    Dropout: {config.ketamine_dropout}")

    print("  SSRI-LIKE:")
    print(f"    LoRA rank: {config.ssri_lora_rank}, alpha: {config.ssri_lora_alpha}")
    print(f"    Epochs: {config.ssri_epochs}, LR: {config.ssri_lr}")
    print(f"    Dropout: {config.ssri_dropout}")

    print("  NEUROSTEROID-LIKE:")
    print(f"    LoRA rank: {config.neuro_lora_rank}, alpha: {config.neuro_lora_alpha}")
    print(f"    Epochs: {config.neuro_epochs}, LR: {config.neuro_lr}")
    print(f"    Dropout: {config.neuro_dropout} (HIGH - tonic inhibition)")


def run_experiment(config: ExperimentConfig = None):
    """
    Run the complete multi-seed experiment.

    ISODOSE MODE:
    If config.isodose_mode is True, treatments are first calibrated to
    equivalent computational cost before running experiments.

    Args:
        config: Experiment configuration. Falls back to the module-level
            ``CONFIG`` when ``None``.

    Returns:
        A dict with keys ``'all_results'`` (list of per-seed results),
        ``'aggregated'`` (output of ``aggregate_results``) and
        ``'calibrated_doses'`` (``None`` outside isodose mode). If every seed
        fails, an empty dict is returned instead.

    Side effects:
        Creates ``config.output_dir`` and writes ``seed_<n>_results.json``
        per successful seed, ``aggregated_results.json`` and, in isodose
        mode, ``isodose_calibration.json``. Failures while saving are
        reported as warnings and never discard computed results.
    """
    import traceback

    if config is None:
        config = CONFIG

    _print_experiment_header(config)

    # -------------------------------------------------------------------------
    # ISODOSE CALIBRATION (if enabled)
    # -------------------------------------------------------------------------
    calibrated_doses = None

    if config.isodose_mode:
        calibrated_doses = calibrate_all_treatments(config)
    else:
        _print_default_treatment_parameters(config)

    print("=" * 80)

    # Create output directory
    os.makedirs(config.output_dir, exist_ok=True)

    # Run experiments
    all_results = []

    for seed in range(config.n_seeds):
        print(f"\n{'*'*80}")
        print(f"  STARTING SEED {seed+1}/{config.n_seeds}")
        print(f"{'*'*80}")

        # Only the run itself decides whether the seed "failed"; a failing
        # seed is skipped so the remaining seeds can still run.
        try:
            seed_results = run_single_seed(seed, config, calibrated_doses)
        except Exception as e:
            print(f"\n  [ERROR] Seed {seed} failed: {e}")
            traceback.print_exc()
            continue

        all_results.append(seed_results)

        # Saving intermediate results is best-effort: a problem here must not
        # be reported as a failed seed, because the results are already kept.
        try:
            _write_json(
                os.path.join(config.output_dir, f'seed_{seed}_results.json'),
                _summarize_seed(seed, seed_results, config.isodose_mode),
            )
            print(f"\n  [SAVED] Seed {seed} results saved to {config.output_dir}/seed_{seed}_results.json")
        except Exception as e:
            print(f"\n  [WARNING] Could not save seed {seed} results: {e}")

    if not all_results:
        print("\n[FATAL] No successful seed runs!")
        return {}

    # Aggregate results
    print("\n" + "=" * 80)
    print("AGGREGATING RESULTS ACROSS SEEDS")
    print("=" * 80)

    aggregated = aggregate_results(all_results)

    # Print final results
    print_final_results(aggregated, config, calibrated_doses)

    # Save final aggregated results
    try:
        _write_json(os.path.join(config.output_dir, 'aggregated_results.json'), aggregated)
        print(f"\n[SAVED] Aggregated results saved to {config.output_dir}/aggregated_results.json")
    except Exception as e:
        print(f"\n[WARNING] Could not save aggregated results: {e}")

    # Save isodose calibration if applicable
    if config.isodose_mode and calibrated_doses:
        try:
            isodose_config = {
                name: {
                    'lora_rank': dose.lora_rank,
                    'lora_alpha': dose.lora_alpha,
                    'epochs': dose.epochs,
                    'lr': dose.lr,
                    'dropout': dose.dropout,
                    'target_modules': dose.target_modules,
                    'trainable_params': dose.trainable_params,
                    'estimated_flops': dose.estimated_flops
                }
                for name, dose in calibrated_doses.items()
            }

            _write_json(os.path.join(config.output_dir, 'isodose_calibration.json'), isodose_config)

            print(f"[SAVED] Isodose calibration saved to {config.output_dir}/isodose_calibration.json")
        except Exception as e:
            print(f"[WARNING] Could not save isodose calibration: {e}")

    print("\n" + "=" * 80)
    print("EXPERIMENT COMPLETE")
    print("=" * 80)

    return {
        'all_results': all_results,
        'aggregated': aggregated,
        'calibrated_doses': calibrated_doses
    }


# ============================================================================
# CELL 13: ENTRY POINT WITH MODE SELECTION
# ============================================================================
if __name__ == "__main__":
    # Build the configuration and pick the comparison mode.
    # isodose_mode=True enables the isodose comparison; set it to False to
    # run every treatment with its default parameters instead.
    config = ExperimentConfig()
    config.isodose_mode = True

    # Announce the selected mode before the (long) experiment starts.
    print("\n" + "=" * 80, "EXPERIMENT MODE SELECTION", "=" * 80, sep="\n")
    print(f"  isodose_mode = {config.isodose_mode}")
    print(
        *(
            (
                "  -> Treatments will be calibrated to EQUIVALENT computational cost",
                "  -> This enables FAIR comparison of mechanism efficacy",
            )
            if config.isodose_mode
            else (
                "  -> Treatments will use DEFAULT parameters",
                "  -> Comparisons reflect both mechanism AND dose differences",
            )
        ),
        sep="\n",
    )
    print("=" * 80)

    # Run the experiment; the returned dict is kept in `results` for later cells.
    results = run_experiment(config)


EXPERIMENT MODE SELECTION
  isodose_mode = True
  -> Treatments will be calibrated to EQUIVALENT computational cost
  -> This enables FAIR comparison of mechanism efficacy


################################################################################
################################################################################
#                                                                              #
#           MULTI-MECHANISM RECOVERY FROM PRUNING-INDUCED FRAGILITY            #
#                           IN LARGE LANGUAGE MODELS                           #
#                          [WITH ISODOSE COMPARISON]                           #
#                                                                              #
################################################################################
################################################################################

EXPERIMENT CONFIGURATION
  Base model (pre-pruned): oopere/pruned60-llama-3.2-1B
  Pruning leve

README.md: 0.00B [00:00, ?B/s]

ARC-Easy/train-00000-of-00001.parquet:   0%|          | 0.00/331k [00:00<?, ?B/s]

ARC-Easy/test-00000-of-00001.parquet:   0%|          | 0.00/346k [00:00<?, ?B/s]

ARC-Easy/validation-00000-of-00001.parqu(…):   0%|          | 0.00/86.1k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2251 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2376 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/570 [00:00<?, ? examples/s]

  ARC-Easy: 200 samples loaded


README.md: 0.00B [00:00, ?B/s]

plain_text/train-00000-of-00002.parquet:   0%|          | 0.00/269M [00:00<?, ?B/s]

plain_text/train-00001-of-00002.parquet:   0%|          | 0.00/281M [00:00<?, ?B/s]

plain_text/test-00000-of-00001.parquet:   0%|          | 0.00/1.14M [00:00<?, ?B/s]

plain_text/validation-00000-of-00001.par(…):   0%|          | 0.00/1.08M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2662 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5153 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4869 [00:00<?, ? examples/s]

  LAMBADA: 100 samples loaded
----------------------------------------------------------------------

[STEP 1] Loading pruned model (untreated baseline)
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.float16 (NO 4-bit quantization)
  Device map: auto
  Expected state: 60% MLP neurons removed (fragile knowledge)
----------------------------------------------------------------------


config.json:   0%|          | 0.00/883 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/1.51G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/180 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/335 [00:00<?, ?B/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


  Load time: 7.87s
  Total parameters: 752,650,240
  Trainable parameters: 752,650,240
  Model device: cuda:0
  Model dtype: torch.float16
  [SUCCESS] Pruned model loaded (fragile state confirmed)

[STEP 2] Evaluating UNTREATED (pruned) baseline
----------------------------------------------------------------------
    Running ARC-Easy evaluation... Accuracy: 22.5%
    Running LAMBADA evaluation... Accuracy: 20.0%
  UNTREATED BASELINE:
    ARC-Easy: 22.5%
    LAMBADA:  20.0%
    Composite: 21.2%
----------------------------------------------------------------------
PREPARING FINE-TUNING DATASET (seed=0)
----------------------------------------------------------------------


README.md: 0.00B [00:00, ?B/s]

databricks-dolly-15k.jsonl:   0%|          | 0.00/13.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/15011 [00:00<?, ? examples/s]

  Dataset: databricks/databricks-dolly-15k
  Subset size: 500 samples


Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

  Tokenized: max_length=512
----------------------------------------------------------------------

[STEP 3.1] Processing KETAMINE treatment
  Using ISODOSE parameters (target FLOPs: 1.35e+14)
  Reloading fresh pruned model...
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.float16 (NO 4-bit quantization)
  Device map: auto
  Expected state: 60% MLP neurons removed (fragile knowledge)
----------------------------------------------------------------------
  Load time: 3.53s
  Total parameters: 752,650,240
  Trainable parameters: 752,650,240
  Model device: cuda:0
  Model dtype: torch.float16
  [SUCCESS] Pruned model loaded (fragile state confirmed)

APPLYING KETAMINE TREATMENT
  [ISODOSE MODE]
  LoRA rank: 64
  LoRA alpha: 128
  Training epochs: 3
  Learning rate: 5e-05
  Dropout: 0.05
  Target modules: ['q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
  Estimated FLOPs: 1.35e+14
--------------------------------------------------

The model is already on multiple devices. Skipping the move to device specified in `args`.


  Actual trainable parameters: 29,989,888 (3.83%)
  Starting training...


Step,Training Loss
50,3.8002


  Training complete in 61.6s

  Evaluating POST-KETAMINE performance:
    Running ARC-Easy evaluation... Accuracy: 34.5%
    Running LAMBADA evaluation... Accuracy: 14.0%
    ARC-Easy: 34.5%
    LAMBADA:  14.0%
    Composite: 24.2%
    Recovery from untreated: +3.0%
    Efficiency (recovery/PetaFLOP): 22.17

  Simulating ACUTE RELAPSE:
    Applying acute relapse: 30% unstructured pruning
    Pruned 337 linear layers
  Evaluating POST-RELAPSE performance:
    Running ARC-Easy evaluation... Accuracy: 27.5%
    Running LAMBADA evaluation... Accuracy: 16.0%
    ARC-Easy: 27.5%
    LAMBADA:  16.0%
    Composite: 21.8%
    Relapse drop: 2.5%

  Running LONGITUDINAL simulation for ketamine:
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.float16 (NO 4-bit quantization)
  Device map: auto
  Expected state: 60% MLP neurons removed (fragile knowledge)
----------------------------------------------------------------------
  Load time: 1.62s
  Total parameters: 752,65

The model is already on multiple devices. Skipping the move to device specified in `args`.


  Actual trainable parameters: 29,989,888 (3.83%)
  Starting training...


Step,Training Loss
50,3.8002


  Training complete in 61.9s

  --- Longitudinal Simulation for KETAMINE ---
    Cycle 0: ARC-Easy = 30.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 1: ARC-Easy = 20.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 2: ARC-Easy = 14.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 3: ARC-Easy = 16.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 4: ARC-Easy = 26.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 5: ARC-Easy = 6.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 6: ARC-Easy = 14.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 7: ARC-Easy = 20.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 8: ARC-Easy = 20.0%

[STEP 3.2] Processing SSRI treatment
  Using ISODOSE parameters (target FLOPs: 1.35e+14)
  Reloading fresh pruned model...
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.floa

The model is already on multiple devices. Skipping the move to device specified in `args`.


  Load time: 1.70s
  Total parameters: 752,650,240
  Trainable parameters: 752,650,240
  Model device: cuda:0
  Model dtype: torch.float16
  [SUCCESS] Pruned model loaded (fragile state confirmed)

APPLYING SSRI TREATMENT
  [ISODOSE MODE]
  LoRA rank: 8
  LoRA alpha: 16
  Training epochs: 84
  Learning rate: 1e-06
  Dropout: 0.1
  Target modules: ['q_proj', 'v_proj']
  Estimated FLOPs: 1.35e+14
----------------------------------------------------------------------
  Actual trainable parameters: 851,968 (0.11%)
  Starting training...


Step,Training Loss
50,5.3426
100,5.3664
150,5.3336
200,5.338
250,5.3033
300,5.2759
350,5.262
400,5.2046
450,5.1396
500,5.1


  Training complete in 1247.4s

  Evaluating POST-SSRI performance:
    Running ARC-Easy evaluation... Accuracy: 23.0%
    Running LAMBADA evaluation... Accuracy: 16.0%
    ARC-Easy: 23.0%
    LAMBADA:  16.0%
    Composite: 19.5%
    Recovery from untreated: -1.8%
    Efficiency (recovery/PetaFLOP): -12.94

  Simulating ACUTE RELAPSE:
    Applying acute relapse: 30% unstructured pruning
    Pruned 177 linear layers
  Evaluating POST-RELAPSE performance:
    Running ARC-Easy evaluation... Accuracy: 31.5%
    Running LAMBADA evaluation... Accuracy: 13.0%
    ARC-Easy: 31.5%
    LAMBADA:  13.0%
    Composite: 22.2%
    Relapse drop: -2.8%

  Running LONGITUDINAL simulation for ssri:
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.float16 (NO 4-bit quantization)
  Device map: auto
  Expected state: 60% MLP neurons removed (fragile knowledge)
----------------------------------------------------------------------


The model is already on multiple devices. Skipping the move to device specified in `args`.


  Load time: 1.80s
  Total parameters: 752,650,240
  Trainable parameters: 752,650,240
  Model device: cuda:0
  Model dtype: torch.float16
  [SUCCESS] Pruned model loaded (fragile state confirmed)

APPLYING SSRI TREATMENT
  [ISODOSE MODE]
  LoRA rank: 8
  LoRA alpha: 16
  Training epochs: 84
  Learning rate: 1e-06
  Dropout: 0.1
  Target modules: ['q_proj', 'v_proj']
  Estimated FLOPs: 1.35e+14
----------------------------------------------------------------------
  Actual trainable parameters: 851,968 (0.11%)
  Starting training...


Step,Training Loss
50,5.3426
100,5.3664
150,5.3335
200,5.338
250,5.3033
300,5.2759
350,5.262
400,5.2046
450,5.1396
500,5.1


  Training complete in 1249.6s

  --- Longitudinal Simulation for SSRI ---
    Cycle 0: ARC-Easy = 20.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 1: ARC-Easy = 24.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 2: ARC-Easy = 6.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 3: ARC-Easy = 12.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 4: ARC-Easy = 8.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 5: ARC-Easy = 30.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 6: ARC-Easy = 4.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 7: ARC-Easy = 28.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 8: ARC-Easy = 18.0%

[STEP 3.3] Processing NEUROSTEROID treatment
  Using ISODOSE parameters (target FLOPs: 1.29e+14)
  Reloading fresh pruned model...
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.

The model is already on multiple devices. Skipping the move to device specified in `args`.


  Starting training...


Step,Training Loss
50,4.6244
100,3.6123
150,3.3721
200,3.2792
250,3.17
300,3.1506


  Training complete in 167.9s

  Evaluating POST-NEUROSTEROID performance:
    Running ARC-Easy evaluation... Accuracy: 30.0%
    Running LAMBADA evaluation... Accuracy: 11.0%
    ARC-Easy: 30.0%
    LAMBADA:  11.0%
    Composite: 20.5%
    Recovery from untreated: -0.8%
    Efficiency (recovery/PetaFLOP): -5.82

  Simulating ACUTE RELAPSE:
    Applying acute relapse: 30% unstructured pruning
    Pruned 241 linear layers
  Evaluating POST-RELAPSE performance:
    Running ARC-Easy evaluation... Accuracy: 22.0%
    Running LAMBADA evaluation... Accuracy: 15.0%
    ARC-Easy: 22.0%
    LAMBADA:  15.0%
    Composite: 18.5%
    Relapse drop: 2.0%

  Running LONGITUDINAL simulation for neurosteroid:
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.float16 (NO 4-bit quantization)
  Device map: auto
  Expected state: 60% MLP neurons removed (fragile knowledge)
----------------------------------------------------------------------
  Load time: 2.49s
  Total parameter

The model is already on multiple devices. Skipping the move to device specified in `args`.


  Actual trainable parameters: 8,860,672 (1.16%)
  Starting training...


Step,Training Loss
50,4.6244
100,3.6123
150,3.372
200,3.2792
250,3.1699
300,3.1506


  Training complete in 169.4s

  --- Longitudinal Simulation for NEUROSTEROID ---
    Cycle 0: ARC-Easy = 24.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 1: ARC-Easy = 34.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 2: ARC-Easy = 8.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 3: ARC-Easy = 28.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 4: ARC-Easy = 30.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 5: ARC-Easy = 8.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 6: ARC-Easy = 18.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 7: ARC-Easy = 20.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 8: ARC-Easy = 18.0%

######################################################################
#  SEED 0 COMPLETE
######################################################################

  [SAVED] Seed 0 results saved to ./llm

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

  Tokenized: max_length=512
----------------------------------------------------------------------

[STEP 3.1] Processing KETAMINE treatment
  Using ISODOSE parameters (target FLOPs: 1.35e+14)
  Reloading fresh pruned model...
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.float16 (NO 4-bit quantization)
  Device map: auto
  Expected state: 60% MLP neurons removed (fragile knowledge)
----------------------------------------------------------------------
  Load time: 1.59s
  Total parameters: 752,650,240
  Trainable parameters: 752,650,240
  Model device: cuda:0
  Model dtype: torch.float16
  [SUCCESS] Pruned model loaded (fragile state confirmed)

APPLYING KETAMINE TREATMENT
  [ISODOSE MODE]
  LoRA rank: 64
  LoRA alpha: 128
  Training epochs: 3
  Learning rate: 5e-05
  Dropout: 0.05
  Target modules: ['q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
  Estimated FLOPs: 1.35e+14
--------------------------------------------------

The model is already on multiple devices. Skipping the move to device specified in `args`.


  Actual trainable parameters: 29,989,888 (3.83%)
  Starting training...


Step,Training Loss
50,3.7958


  Training complete in 61.8s

  Evaluating POST-KETAMINE performance:
    Running ARC-Easy evaluation... Accuracy: 34.0%
    Running LAMBADA evaluation... Accuracy: 16.0%
    ARC-Easy: 34.0%
    LAMBADA:  16.0%
    Composite: 25.0%
    Recovery from untreated: +3.8%
    Efficiency (recovery/PetaFLOP): 27.72

  Simulating ACUTE RELAPSE:
    Applying acute relapse: 30% unstructured pruning
    Pruned 337 linear layers
  Evaluating POST-RELAPSE performance:
    Running ARC-Easy evaluation... Accuracy: 27.5%
    Running LAMBADA evaluation... Accuracy: 20.0%
    ARC-Easy: 27.5%
    LAMBADA:  20.0%
    Composite: 23.8%
    Relapse drop: 1.2%

  Running LONGITUDINAL simulation for ketamine:
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.float16 (NO 4-bit quantization)
  Device map: auto
  Expected state: 60% MLP neurons removed (fragile knowledge)
----------------------------------------------------------------------
  Load time: 26.27s
  Total parameters: 752,6

The model is already on multiple devices. Skipping the move to device specified in `args`.


  Actual trainable parameters: 29,989,888 (3.83%)
  Starting training...


Step,Training Loss
50,3.7957


  Training complete in 61.8s

  --- Longitudinal Simulation for KETAMINE ---
    Cycle 0: ARC-Easy = 30.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 1: ARC-Easy = 22.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 2: ARC-Easy = 16.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 3: ARC-Easy = 24.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 4: ARC-Easy = 18.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 5: ARC-Easy = 8.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 6: ARC-Easy = 18.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 7: ARC-Easy = 20.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 8: ARC-Easy = 16.0%

[STEP 3.2] Processing SSRI treatment
  Using ISODOSE parameters (target FLOPs: 1.35e+14)
  Reloading fresh pruned model...
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.floa

The model is already on multiple devices. Skipping the move to device specified in `args`.


  Load time: 2.36s
  Total parameters: 752,650,240
  Trainable parameters: 752,650,240
  Model device: cuda:0
  Model dtype: torch.float16
  [SUCCESS] Pruned model loaded (fragile state confirmed)

APPLYING SSRI TREATMENT
  [ISODOSE MODE]
  LoRA rank: 8
  LoRA alpha: 16
  Training epochs: 84
  Learning rate: 1e-06
  Dropout: 0.1
  Target modules: ['q_proj', 'v_proj']
  Estimated FLOPs: 1.35e+14
----------------------------------------------------------------------
  Actual trainable parameters: 851,968 (0.11%)
  Starting training...


Step,Training Loss
50,5.3586
100,5.3473
150,5.3519
200,5.3324
250,5.3083
300,5.2772
350,5.2412
400,5.2079
450,5.1094
500,5.0759


  Training complete in 1250.0s

  Evaluating POST-SSRI performance:
    Running ARC-Easy evaluation... Accuracy: 22.5%
    Running LAMBADA evaluation... Accuracy: 18.0%
    ARC-Easy: 22.5%
    LAMBADA:  18.0%
    Composite: 20.2%
    Recovery from untreated: -1.0%
    Efficiency (recovery/PetaFLOP): -7.39

  Simulating ACUTE RELAPSE:
    Applying acute relapse: 30% unstructured pruning
    Pruned 177 linear layers
  Evaluating POST-RELAPSE performance:
    Running ARC-Easy evaluation... Accuracy: 30.5%
    Running LAMBADA evaluation... Accuracy: 13.0%
    ARC-Easy: 30.5%
    LAMBADA:  13.0%
    Composite: 21.8%
    Relapse drop: -1.5%

  Running LONGITUDINAL simulation for ssri:
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.float16 (NO 4-bit quantization)
  Device map: auto
  Expected state: 60% MLP neurons removed (fragile knowledge)
----------------------------------------------------------------------


The model is already on multiple devices. Skipping the move to device specified in `args`.


  Load time: 1.86s
  Total parameters: 752,650,240
  Trainable parameters: 752,650,240
  Model device: cuda:0
  Model dtype: torch.float16
  [SUCCESS] Pruned model loaded (fragile state confirmed)

APPLYING SSRI TREATMENT
  [ISODOSE MODE]
  LoRA rank: 8
  LoRA alpha: 16
  Training epochs: 84
  Learning rate: 1e-06
  Dropout: 0.1
  Target modules: ['q_proj', 'v_proj']
  Estimated FLOPs: 1.35e+14
----------------------------------------------------------------------
  Actual trainable parameters: 851,968 (0.11%)
  Starting training...


Step,Training Loss
50,5.3587
100,5.3473
150,5.3519
200,5.3324
250,5.3083
300,5.2771
350,5.2412
400,5.2079
450,5.1093
500,5.0759


  Training complete in 1250.1s

  --- Longitudinal Simulation for SSRI ---
    Cycle 0: ARC-Easy = 24.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 1: ARC-Easy = 24.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 2: ARC-Easy = 6.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 3: ARC-Easy = 12.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 4: ARC-Easy = 8.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 5: ARC-Easy = 30.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 6: ARC-Easy = 4.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 7: ARC-Easy = 20.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 8: ARC-Easy = 18.0%

[STEP 3.3] Processing NEUROSTEROID treatment
  Using ISODOSE parameters (target FLOPs: 1.29e+14)
  Reloading fresh pruned model...
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.

The model is already on multiple devices. Skipping the move to device specified in `args`.


  Starting training...


Step,Training Loss
50,4.6285
100,3.6199
150,3.3952
200,3.267
250,3.1754
300,3.1333


  Training complete in 168.5s

  Evaluating POST-NEUROSTEROID performance:
    Running ARC-Easy evaluation... Accuracy: 27.5%
    Running LAMBADA evaluation... Accuracy: 11.0%
    ARC-Easy: 27.5%
    LAMBADA:  11.0%
    Composite: 19.2%
    Recovery from untreated: -2.0%
    Efficiency (recovery/PetaFLOP): -15.52

  Simulating ACUTE RELAPSE:
    Applying acute relapse: 30% unstructured pruning
    Pruned 241 linear layers
  Evaluating POST-RELAPSE performance:
    Running ARC-Easy evaluation... Accuracy: 24.0%
    Running LAMBADA evaluation... Accuracy: 22.0%
    ARC-Easy: 24.0%
    LAMBADA:  22.0%
    Composite: 23.0%
    Relapse drop: -3.8%

  Running LONGITUDINAL simulation for neurosteroid:
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.float16 (NO 4-bit quantization)
  Device map: auto
  Expected state: 60% MLP neurons removed (fragile knowledge)
----------------------------------------------------------------------
  Load time: 1.65s
  Total paramet

The model is already on multiple devices. Skipping the move to device specified in `args`.


  Actual trainable parameters: 8,860,672 (1.16%)
  Starting training...


Step,Training Loss
50,4.6285
100,3.6199
150,3.3952
200,3.267
250,3.1754
300,3.1333


  Training complete in 168.6s

  --- Longitudinal Simulation for NEUROSTEROID ---
    Cycle 0: ARC-Easy = 28.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 1: ARC-Easy = 28.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 2: ARC-Easy = 4.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 3: ARC-Easy = 30.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 4: ARC-Easy = 28.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 5: ARC-Easy = 6.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 6: ARC-Easy = 18.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 7: ARC-Easy = 22.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 8: ARC-Easy = 18.0%

######################################################################
#  SEED 1 COMPLETE
######################################################################

  [SAVED] Seed 1 results saved to ./llm

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

  Tokenized: max_length=512
----------------------------------------------------------------------

[STEP 3.1] Processing KETAMINE treatment
  Using ISODOSE parameters (target FLOPs: 1.35e+14)
  Reloading fresh pruned model...
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.float16 (NO 4-bit quantization)
  Device map: auto
  Expected state: 60% MLP neurons removed (fragile knowledge)
----------------------------------------------------------------------
  Load time: 1.50s
  Total parameters: 752,650,240
  Trainable parameters: 752,650,240
  Model device: cuda:0
  Model dtype: torch.float16
  [SUCCESS] Pruned model loaded (fragile state confirmed)

APPLYING KETAMINE TREATMENT
  [ISODOSE MODE]
  LoRA rank: 64
  LoRA alpha: 128
  Training epochs: 3
  Learning rate: 5e-05
  Dropout: 0.05
  Target modules: ['q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
  Estimated FLOPs: 1.35e+14
--------------------------------------------------

The model is already on multiple devices. Skipping the move to device specified in `args`.


  Actual trainable parameters: 29,989,888 (3.83%)
  Starting training...


Step,Training Loss
50,3.8027


  Training complete in 61.7s

  Evaluating POST-KETAMINE performance:
    Running ARC-Easy evaluation... Accuracy: 28.5%
    Running LAMBADA evaluation... Accuracy: 13.0%
    ARC-Easy: 28.5%
    LAMBADA:  13.0%
    Composite: 20.8%
    Recovery from untreated: -0.5%
    Efficiency (recovery/PetaFLOP): -3.70

  Simulating ACUTE RELAPSE:
    Applying acute relapse: 30% unstructured pruning
    Pruned 337 linear layers
  Evaluating POST-RELAPSE performance:
    Running ARC-Easy evaluation... Accuracy: 23.0%
    Running LAMBADA evaluation... Accuracy: 17.0%
    ARC-Easy: 23.0%
    LAMBADA:  17.0%
    Composite: 20.0%
    Relapse drop: 0.8%

  Running LONGITUDINAL simulation for ketamine:
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.float16 (NO 4-bit quantization)
  Device map: auto
  Expected state: 60% MLP neurons removed (fragile knowledge)
----------------------------------------------------------------------
  Load time: 1.63s
  Total parameters: 752,65

The model is already on multiple devices. Skipping the move to device specified in `args`.


  Actual trainable parameters: 29,989,888 (3.83%)
  Starting training...


Step,Training Loss
50,3.8027


  Training complete in 61.9s

  --- Longitudinal Simulation for KETAMINE ---
    Cycle 0: ARC-Easy = 22.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 1: ARC-Easy = 16.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 2: ARC-Easy = 20.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 3: ARC-Easy = 18.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 4: ARC-Easy = 22.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 5: ARC-Easy = 8.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 6: ARC-Easy = 18.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 7: ARC-Easy = 20.0% -> Stress applied (10% prune) -> Maintenance (2 epochs)
    Cycle 8: ARC-Easy = 16.0%

[STEP 3.2] Processing SSRI treatment
  Using ISODOSE parameters (target FLOPs: 1.35e+14)
  Reloading fresh pruned model...
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.floa

The model is already on multiple devices. Skipping the move to device specified in `args`.


  Load time: 1.68s
  Total parameters: 752,650,240
  Trainable parameters: 752,650,240
  Model device: cuda:0
  Model dtype: torch.float16
  [SUCCESS] Pruned model loaded (fragile state confirmed)

APPLYING SSRI TREATMENT
  [ISODOSE MODE]
  LoRA rank: 8
  LoRA alpha: 16
  Training epochs: 84
  Learning rate: 1e-06
  Dropout: 0.1
  Target modules: ['q_proj', 'v_proj']
  Estimated FLOPs: 1.35e+14
----------------------------------------------------------------------
  Actual trainable parameters: 851,968 (0.11%)
  Starting training...


Step,Training Loss
50,5.3536
100,5.362
150,5.3671
200,5.3613
250,5.3337
300,5.279
350,5.264
400,5.1615
450,5.194
500,5.1116


  Training complete in 1250.2s

  Evaluating POST-SSRI performance:
    Running ARC-Easy evaluation... Accuracy: 21.5%
    Running LAMBADA evaluation... Accuracy: 17.0%
    ARC-Easy: 21.5%
    LAMBADA:  17.0%
    Composite: 19.2%
    Recovery from untreated: -2.0%
    Efficiency (recovery/PetaFLOP): -14.78

  Simulating ACUTE RELAPSE:
    Applying acute relapse: 30% unstructured pruning
    Pruned 177 linear layers
  Evaluating POST-RELAPSE performance:
    Running ARC-Easy evaluation... Accuracy: 32.0%
    Running LAMBADA evaluation... Accuracy: 13.0%
    ARC-Easy: 32.0%
    LAMBADA:  13.0%
    Composite: 22.5%
    Relapse drop: -3.2%

  Running LONGITUDINAL simulation for ssri:
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.float16 (NO 4-bit quantization)
  Device map: auto
  Expected state: 60% MLP neurons removed (fragile knowledge)
----------------------------------------------------------------------


The model is already on multiple devices. Skipping the move to device specified in `args`.


  Load time: 1.71s
  Total parameters: 752,650,240
  Trainable parameters: 752,650,240
  Model device: cuda:0
  Model dtype: torch.float16
  [SUCCESS] Pruned model loaded (fragile state confirmed)

APPLYING SSRI TREATMENT
  [ISODOSE MODE]
  LoRA rank: 8
  LoRA alpha: 16
  Training epochs: 84
  Learning rate: 1e-06
  Dropout: 0.1
  Target modules: ['q_proj', 'v_proj']
  Estimated FLOPs: 1.35e+14
----------------------------------------------------------------------
  Actual trainable parameters: 851,968 (0.11%)
  Starting training...


Step,Training Loss
50,5.3536
100,5.362
150,5.3671
200,5.3613
250,5.3337
300,5.279
350,5.264
400,5.1615
450,5.194
500,5.1117


  Training complete in 1254.4s

  --- Longitudinal Simulation for SSRI ---
    Cycle 0: ARC-Easy = 20.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 1: ARC-Easy = 24.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 2: ARC-Easy = 4.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 3: ARC-Easy = 12.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 4: ARC-Easy = 8.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 5: ARC-Easy = 30.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 6: ARC-Easy = 4.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 7: ARC-Easy = 26.0% -> Stress applied (10% prune) -> Maintenance (5 epochs)
    Cycle 8: ARC-Easy = 18.0%

[STEP 3.3] Processing NEUROSTEROID treatment
  Using ISODOSE parameters (target FLOPs: 1.29e+14)
  Reloading fresh pruned model...
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.

The model is already on multiple devices. Skipping the move to device specified in `args`.


  Starting training...


Step,Training Loss
50,4.6142
100,3.5828
150,3.3725
200,3.2633
250,3.1684
300,3.122


  Training complete in 168.3s

  Evaluating POST-NEUROSTEROID performance:
    Running ARC-Easy evaluation... Accuracy: 28.0%
    Running LAMBADA evaluation... Accuracy: 12.0%
    ARC-Easy: 28.0%
    LAMBADA:  12.0%
    Composite: 20.0%
    Recovery from untreated: -1.2%
    Efficiency (recovery/PetaFLOP): -9.70

  Simulating ACUTE RELAPSE:
    Applying acute relapse: 30% unstructured pruning
    Pruned 241 linear layers
  Evaluating POST-RELAPSE performance:
    Running ARC-Easy evaluation... Accuracy: 25.0%
    Running LAMBADA evaluation... Accuracy: 17.0%
    ARC-Easy: 25.0%
    LAMBADA:  17.0%
    Composite: 21.0%
    Relapse drop: -1.0%

  Running LONGITUDINAL simulation for neurosteroid:
LOADING PRE-PRUNED MODEL
  Model: oopere/pruned60-llama-3.2-1B
  Dtype: torch.float16 (NO 4-bit quantization)
  Device map: auto
  Expected state: 60% MLP neurons removed (fragile knowledge)
----------------------------------------------------------------------
  Load time: 1.61s
  Total paramete

The model is already on multiple devices. Skipping the move to device specified in `args`.


  Actual trainable parameters: 8,860,672 (1.16%)
  Starting training...


Step,Training Loss
50,4.6141
100,3.5828
150,3.3725
200,3.2634
250,3.1685
300,3.122


  Training complete in 169.2s

  --- Longitudinal Simulation for NEUROSTEROID ---
    Cycle 0: ARC-Easy = 24.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 1: ARC-Easy = 30.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 2: ARC-Easy = 2.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 3: ARC-Easy = 30.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 4: ARC-Easy = 26.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 5: ARC-Easy = 8.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 6: ARC-Easy = 20.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 7: ARC-Easy = 24.0% -> Stress applied (10% prune) -> Maintenance (3 epochs)
    Cycle 8: ARC-Easy = 18.0%

######################################################################
#  SEED 2 COMPLETE
######################################################################

  [SAVED] Seed 2 results saved to ./llm

# Isodose Sweep and Generalization

In [None]:
"""
================================================================================
KETAMINE ISODOSE SWEEP & GENERALIZATION TEST
================================================================================

Extension of the Multi-Mechanism Recovery experiment to:
1. Sweep Ketamine-like treatment parameters to find optimal dosing
2. Test generalization of optimal dose across different pruned LLM architectures

Based on isodose results showing:
- Ketamine: +2.1% recovery, +15.4 recovery/PetaFLOP efficiency
- Superior resilience to relapse and longitudinal stress

This module adds:
- KetamineSweepConfig: Configuration for dose parameter sweeps
- Isodose sweep: Vary rank while maintaining constant FLOPs
- Budget sweep: Vary total computational budget at optimal rank
- Generalization test: Validate optimal dose on different model architectures

References:
- llama-pruning: https://github.com/MedITSolutionsKurman/llama-pruning
- LLM Course: https://github.com/peremartra/Large-Language-Model-Notebooks-Course
================================================================================
"""

# ============================================================================
# CELL 1: INSTALLATIONS (Run first in Colab)
# ============================================================================
# !pip install -q transformers accelerate peft datasets torch --extra-index-url https://download.pytorch.org/whl/cu121
# !pip install -q sentencepiece einops

# ============================================================================
# CELL 2: IMPORTS AND CONFIGURATION
# ============================================================================
import torch
import torch.nn as nn
from torch.nn.utils import prune
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)
from datasets import load_dataset
import numpy as np
from typing import Dict, List, Tuple, Optional, Any
import copy
import json
import os
from dataclasses import dataclass, field
from collections import defaultdict
from itertools import product
import warnings
import gc
import time
import math

# Silence every Python warning for the rest of the session so the notebook
# output stays readable.
# NOTE(review): this also hides deprecation warnings raised by the imported
# libraries; consider narrowing the filter while debugging.
warnings.filterwarnings('ignore')

# Set module-level default seeds (torch and NumPy RNGs) for reproducibility.
# apply_ketamine_dose() re-seeds torch with its own per-run `seed` argument.
torch.manual_seed(42)
np.random.seed(42)


# ============================================================================
# CELL 3: EXTENDED CONFIGURATION WITH KETAMINE SWEEP
# ============================================================================
# Immutable templates for the numeric list defaults; every config instance
# receives its own mutable copy so edits never leak between instances.
_DEFAULT_SWEEP_RANKS = (16, 32, 64, 96, 128)
_DEFAULT_BUDGET_MULTIPLIERS = (0.5, 1.0, 1.5, 2.0)


def _default_target_module_configs() -> Dict[str, List[str]]:
    """Build a fresh mapping of config name -> LoRA target module names."""
    attention = ["q_proj", "v_proj", "k_proj", "o_proj"]
    mlp = ["gate_proj", "up_proj", "down_proj"]
    return {
        'full': attention + mlp,
        'attention_only': list(attention),
        'mlp_only': list(mlp),
        'minimal': attention[:2],
    }


def _default_generalization_models() -> List[Tuple[str, str, str]]:
    """Build the default list of (model_id, description, architecture) entries."""
    return [
        ("oopere/pruned60-llama-3.2-1B", "Primary: Llama-3.2-1B 60% pruned", "llama"),
        # Add alternative models here - examples:
        # ("username/pruned-mistral-7b", "Mistral-7B pruned", "mistral"),
        # ("username/pruned-phi-3", "Phi-3 pruned", "phi"),
    ]


@dataclass
class KetamineSweepConfig:
    """
    Settings that drive the Ketamine-like dosing sweep.

    SWEEP MODES (value of ``sweep_mode``):
    - "isodose": vary LoRA rank, adjusting epochs to hold FLOPs constant
    - "budget": vary the total computational budget at a fixed rank
    - "full_grid": sweep several parameters at once

    DESIGN RATIONALE:
    The Ketamine-like treatment came out ahead in the isodose comparison, so
    this sweep searches for the best "dose" (rank/capacity) and then checks
    whether that dose carries over to other model architectures.
    """
    # Which sweep to run: "isodose", "budget", or "full_grid"
    sweep_mode: str = "isodose"

    # LoRA ranks to try (capacity / "synaptogenesis strength")
    ranks: List[int] = field(default_factory=lambda: list(_DEFAULT_SWEEP_RANKS))

    # LoRA alpha is derived as rank * alpha_multiplier
    alpha_multiplier: float = 2.0

    # Learning rates considered by the grid search
    learning_rates: List[float] = field(default_factory=lambda: [5e-5])

    # LoRA dropout values considered by the grid search
    dropouts: List[float] = field(default_factory=lambda: [0.05])

    # Named sets of LoRA target modules available to the sweep
    target_module_configs: Dict[str, List[str]] = field(
        default_factory=_default_target_module_configs
    )

    # Key into target_module_configs used for the rank sweep
    default_target_config: str = 'full'

    # FLOPs target calibrated from the reference ketamine treatment.
    # None requests automatic calibration from rank=64, epochs=3.
    base_flops: Optional[float] = None

    # Budget sweep multipliers, expressed relative to base_flops
    budget_multipliers: List[float] = field(
        default_factory=lambda: list(_DEFAULT_BUDGET_MULTIPLIERS)
    )

    # Models for the generalization test, as
    # (model_id, description, expected_architecture) tuples
    generalization_models: List[Tuple[str, str, str]] = field(
        default_factory=_default_generalization_models
    )

    # Seeds per configuration: 1 for fast ranking, 3 for validation
    sweep_seeds: int = 1
    validation_seeds: int = 3

    # How many top-ranked configurations get re-run with validation_seeds
    top_k_validate: int = 2

    # Evaluation sample count during the sweep (reduced for speed)
    sweep_eval_samples: int = 100

    # Run the acute relapse test for each sweep configuration
    test_relapse_in_sweep: bool = True

    # Run the (slower) longitudinal test for each sweep configuration
    test_longitudinal_in_sweep: bool = False

    # Directory that receives sweep outputs
    sweep_output_dir: str = "./ketamine_sweep_results"


def _default_device() -> str:
    """Return "cuda" when torch reports an available CUDA device, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def _default_ketamine_target_modules() -> List[str]:
    """Build a fresh list of the default Ketamine LoRA target modules."""
    attention = ["q_proj", "v_proj", "k_proj", "o_proj"]
    mlp = ["gate_proj", "up_proj", "down_proj"]
    return attention + mlp


@dataclass
class ExperimentConfig:
    """
    Top-level experiment settings, including the nested Ketamine sweep config.
    """
    # --- Model ---
    pruned_model: str = "oopere/pruned60-llama-3.2-1B"
    use_4bit: bool = False
    # Resolved once, when the class is defined
    device: str = _default_device()

    # --- Experiment ---
    n_seeds: int = 3
    output_dir: str = "./llm_treatment_results"
    eval_samples: int = 200

    # --- Isodose ---
    isodose_mode: bool = True
    target_flops: Optional[float] = None
    isodose_reference: str = "ketamine"

    # --- Ketamine sweep (nested configuration, fresh instance per config) ---
    ketamine_sweep: KetamineSweepConfig = field(default_factory=KetamineSweepConfig)

    # --- Run mode flags ---
    run_main_experiment: bool = False  # True runs the original 3-treatment comparison
    run_ketamine_sweep: bool = True    # True runs the Ketamine optimization sweep
    run_generalization_test: bool = True  # True tests on other models

    # --- Default Ketamine parameters (used when not in sweep mode) ---
    ketamine_lora_rank: int = 64
    ketamine_lora_alpha: int = 128
    ketamine_epochs: int = 3
    ketamine_lr: float = 5e-5
    ketamine_target_modules: List[str] = field(
        default_factory=_default_ketamine_target_modules
    )
    ketamine_dropout: float = 0.05

    # --- SSRI parameters (for comparison if needed) ---
    ssri_lora_rank: int = 8
    ssri_lora_alpha: int = 16
    ssri_epochs: int = 15
    ssri_lr: float = 1e-6
    ssri_dropout: float = 0.1

    # --- Neurosteroid parameters (for comparison if needed) ---
    neuro_lora_rank: int = 32
    neuro_lora_alpha: int = 64
    neuro_epochs: int = 5
    neuro_lr: float = 3e-5
    neuro_dropout: float = 0.5

    # --- Longitudinal stress simulation ---
    longitudinal_cycles: int = 8
    additional_prune_per_cycle: float = 0.10
    ketamine_maintenance_epochs: int = 2
    ssri_maintenance_epochs: int = 5
    neuro_maintenance_epochs: int = 3

    # --- Acute relapse ---
    acute_relapse_prune_amount: float = 0.30

    # --- Fine-tuning dataset ---
    finetune_dataset: str = "databricks/databricks-dolly-15k"
    finetune_subset_size: int = 500
    max_seq_length: int = 512

# Initialize global config: a module-level ExperimentConfig built from the
# defaults. load_evaluation_datasets() falls back to it when called without
# an explicit config.
CONFIG = ExperimentConfig()


# ============================================================================
# CELL 4: TREATMENT DOSE DATACLASS
# ============================================================================
@dataclass
class TreatmentDose:
    """
    One treatment configuration together with its computational "dose".

    Holds the LoRA hyper-parameters, the estimated cost filled in by
    compute_flops(), and a free-form ``results`` dict for sweep outcomes.
    """
    treatment_name: str
    lora_rank: int
    lora_alpha: int
    epochs: int
    lr: float
    dropout: float
    target_modules: List[str]
    trainable_params: int = 0
    estimated_flops: float = 0.0

    # Free-form storage for outcomes attached to this dose
    results: Dict[str, Any] = field(default_factory=dict)

    def compute_flops(self, dataset_size: int, seq_length: int,
                      hidden_size: int = 2048, num_layers: int = 16) -> float:
        """
        Estimate training FLOPs and cache the estimate on the instance.

        The estimate treats every adapted module as contributing
        2 * lora_rank * hidden_size trainable parameters and charges
        6 * trainable_params * seq_length FLOPs per sample per epoch.
        Updates ``trainable_params`` and ``estimated_flops`` as a side effect.
        """
        adapter_params_each = 2 * self.lora_rank * hidden_size
        adapted_module_count = len(self.target_modules) * num_layers
        self.trainable_params = adapter_params_each * adapted_module_count

        per_sample_cost = 6 * self.trainable_params * seq_length
        self.estimated_flops = per_sample_cost * dataset_size * self.epochs
        return self.estimated_flops

    def get_config_key(self) -> str:
        """Build a key from rank, epochs, lr, dropout and module-name prefixes."""
        # Only the text before the first underscore of each module name is
        # kept (e.g. "q_proj" -> "q"); prefixes are sorted, duplicates stay.
        prefixes = sorted(name.split("_")[0] for name in self.target_modules)
        key_parts = [
            f"r{self.lora_rank}",
            f"e{self.epochs}",
            f"lr{self.lr:.0e}",
            f"d{self.dropout}",
            "_".join(prefixes),
        ]
        return "_".join(key_parts)

    def to_dict(self) -> Dict:
        """Return the dose as a plain dict suitable for serialization."""
        exported = (
            'treatment_name', 'lora_rank', 'lora_alpha', 'epochs', 'lr',
            'dropout', 'target_modules', 'trainable_params',
            'estimated_flops', 'results',
        )
        return {attr: getattr(self, attr) for attr in exported}


@dataclass
class SweepResult:
    """
    Outcome of evaluating one dose on one model with one seed.

    Carries the raw metrics plus calculate_composite_score(), which folds
    them into a single number used to rank configurations.
    """
    dose: TreatmentDose
    seed: int
    model_id: str

    # Post-treatment accuracy
    post_treatment_composite: float = 0.0
    post_treatment_arc: float = 0.0
    post_treatment_lambada: float = 0.0

    # Change relative to the untreated baseline
    recovery_from_baseline: float = 0.0

    # Acute relapse outcome
    post_relapse_composite: float = 0.0
    relapse_drop: float = 0.0

    # Longitudinal outcome (optional; 0.0 is read as "not measured")
    longitudinal_final: float = 0.0
    longitudinal_drop: float = 0.0

    # Recovery per PetaFLOP
    efficiency: float = 0.0

    # Wall-clock training time
    train_time: float = 0.0

    def calculate_composite_score(self,
                                   recovery_weight: float = 0.4,
                                   resilience_weight: float = 0.3,
                                   efficiency_weight: float = 0.2,
                                   durability_weight: float = 0.1) -> float:
        """
        Combine the metrics into one weighted optimization score.

        COMPONENTS:
        1. Recovery   - recovery_from_baseline, used as-is
        2. Resilience - max(0, 10 - relapse_drop); floored at 0 but not
                        capped, so a negative relapse_drop scores above 10
        3. Efficiency - efficiency * 0.5
        4. Durability - max(0, 20 - longitudinal_drop) when
                        longitudinal_drop > 0; otherwise the resilience
                        component stands in for it

        A higher score means a better treatment configuration.
        """
        recovery_part = self.recovery_from_baseline
        resilience_part = max(0, 10 - self.relapse_drop)
        efficiency_part = self.efficiency * 0.5

        # A non-positive longitudinal drop is treated as "not measured".
        durability_part = (
            max(0, 20 - self.longitudinal_drop)
            if self.longitudinal_drop > 0
            else resilience_part
        )

        # Accumulate in a fixed order so floating-point results are stable.
        score = recovery_weight * recovery_part
        score = score + resilience_weight * resilience_part
        score = score + efficiency_weight * efficiency_part
        score = score + durability_weight * durability_part
        return score

    def to_dict(self) -> Dict:
        """Return all fields plus the default composite score as a dict."""
        payload = {
            'dose': self.dose.to_dict(),
            'seed': self.seed,
            'model_id': self.model_id,
        }
        metric_fields = (
            'post_treatment_composite', 'post_treatment_arc',
            'post_treatment_lambada', 'recovery_from_baseline',
            'post_relapse_composite', 'relapse_drop', 'longitudinal_final',
            'longitudinal_drop', 'efficiency', 'train_time',
        )
        for metric in metric_fields:
            payload[metric] = getattr(self, metric)
        payload['composite_score'] = self.calculate_composite_score()
        return payload


# ============================================================================
# CELL 5: MODEL LOADING UTILITIES
# ============================================================================
def load_model_for_sweep(model_id: str, config: ExperimentConfig) -> Tuple[AutoModelForCausalLM, AutoTokenizer, str]:
    """
    Load a causal LM and its tokenizer for sweep experiments.

    The model is loaded in float16 with device_map="auto"; the tokenizer's
    pad token falls back to its EOS token when none is defined.

    Args:
        model_id: Model identifier passed to from_pretrained().
        config: Experiment configuration. Currently not read by this
            function; kept for interface compatibility.

    Returns:
        A (model, tokenizer, architecture) tuple, where architecture is the
        string produced by detect_model_architecture().

    Raises:
        Exception: Any loading error is logged and then re-raised unchanged.
    """
    print(f"  Loading model: {model_id}")

    try:
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # checkpoint; only point this at model repositories you trust.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True
        )

        # Padding is required later when tokenizing with padding='max_length'.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Detect architecture
        architecture = detect_model_architecture(model)

        total_params = sum(p.numel() for p in model.parameters())
        print(f"    Architecture: {architecture}")
        print(f"    Parameters: {total_params:,}")
        print(f"    Device: {next(model.parameters()).device}")

        return model, tokenizer, architecture

    except Exception as e:
        # Log for the notebook output, then let the caller decide what to do.
        print(f"    [ERROR] Failed to load {model_id}: {e}")
        raise


def detect_model_architecture(model) -> str:
    """
    Detect the model architecture type for appropriate module targeting.

    Reads ``model.config.model_type`` (case-insensitively) and maps it to a
    coarse architecture family.

    Args:
        model: Object expected to expose a ``config`` with a ``model_type``
            string attribute.

    Returns:
        One of: 'llama', 'mistral', 'phi', 'qwen', 'gemma', 'unknown'.
        'unknown' is returned for unmapped types and also when ``config`` or
        ``model_type`` is missing, None, or not a string.
    """
    model_config = getattr(model, 'config', None)
    model_type = getattr(model_config, 'model_type', None)

    # Guard against absent/None/non-string model types instead of crashing
    # on .lower().
    if not isinstance(model_type, str):
        return 'unknown'

    architecture_map = {
        'llama': 'llama',
        'mistral': 'mistral',
        'phi': 'phi',
        'phi3': 'phi',
        'qwen': 'qwen',
        'qwen2': 'qwen',
        'gemma': 'gemma',
        'gemma2': 'gemma',
    }

    return architecture_map.get(model_type.lower(), 'unknown')


def get_target_modules_for_architecture(architecture: str, config_name: str,
                                         sweep_config: KetamineSweepConfig) -> List[str]:
    """
    Get appropriate target modules for a given architecture and configuration.

    Args:
        architecture: Architecture family string (e.g. 'llama', 'phi').
        config_name: Key into ``sweep_config.target_module_configs``. An
            unknown key silently falls back to the 'full' configuration.
        sweep_config: Sweep configuration holding the named module sets.

    Returns:
        A new list of module names (never the list stored in the config, so
        callers may mutate it freely). For 'phi' the names are translated to
        the fused projection names and de-duplicated in first-seen order;
        every other architecture gets the base names unchanged.
    """
    module_configs = sweep_config.target_module_configs
    base_modules = module_configs.get(config_name, module_configs['full'])

    if architecture == 'phi':
        # Fused-projection naming: Q/K/V share one module, as do gate/up.
        # NOTE(review): detect_model_architecture() maps both 'phi' and
        # 'phi3' model types here; confirm every such checkpoint actually
        # exposes these fused module names.
        name_map = {
            'q_proj': 'qkv_proj',
            'k_proj': 'qkv_proj',
            'v_proj': 'qkv_proj',
            'o_proj': 'o_proj',
            'gate_proj': 'gate_up_proj',
            'up_proj': 'gate_up_proj',
            'down_proj': 'down_proj'
        }
        # dict.fromkeys de-duplicates while keeping a deterministic order
        # (a set would reorder the names from run to run).
        return list(dict.fromkeys(name_map.get(m, m) for m in base_modules))

    # Default: base modules as-is (works for Llama, Mistral, Qwen, Gemma).
    # Return a copy so the shared config list cannot be mutated by callers.
    return list(base_modules)


# ============================================================================
# CELL 6: EVALUATION FUNCTIONS
# ============================================================================
def load_evaluation_datasets(config: ExperimentConfig = None,
                              max_samples: int = None) -> Dict:
    """
    Load the shuffled ARC-Easy and LAMBADA test subsets used for evaluation.

    ``config`` defaults to the global CONFIG and ``max_samples`` defaults to
    ``config.eval_samples``. ARC-Easy receives up to ``max_samples`` examples
    and LAMBADA up to ``max_samples // 2``. A dataset that fails to load is
    reported with a warning and stored as None rather than raising.

    Returns a dict with keys 'arc' and 'lambada'.
    """
    active_config = CONFIG if config is None else config
    sample_budget = active_config.eval_samples if max_samples is None else max_samples

    loaded = {}

    # (result key, printed label, load_dataset positional args,
    #  budget -> requested sample count, hard cap on the sample count)
    load_plan = (
        ('arc', "ARC-Easy", ("allenai/ai2_arc", "ARC-Easy"), lambda n: n, 2376),
        ('lambada', "LAMBADA", ("lambada",), lambda n: n // 2, 2500),
    )

    for key, label, dataset_args, requested, cap in load_plan:
        try:
            shuffled = load_dataset(*dataset_args, split="test").shuffle(seed=42)
            subset = shuffled.select(range(min(requested(sample_budget), cap)))
            loaded[key] = subset
            print(f"    {label}: {len(subset)} samples")
        except Exception as e:
            # Best effort: evaluation helpers treat a None dataset as "not loaded".
            print(f"    [WARNING] {label} load failed: {e}")
            loaded[key] = None

    return loaded


def evaluate_arc_easy(model, tokenizer, dataset, max_samples: int = None) -> Dict:
    """
    Evaluate model on ARC-Easy multiple choice questions.

    Each example is rendered as a "Question / Choices / Answer:" prompt and
    the model greedily generates up to 10 new tokens. The answer counts as
    correct when the first whitespace-delimited token of the continuation,
    stripped of surrounding punctuation and upper-cased, equals the answer
    key (so "A", "A:" and "(A)" match key "A", while "THE ANSWER" does not).

    Args:
        model: Causal LM exposing eval(), generate() and a ``device``.
        tokenizer: Tokenizer matching the model.
        dataset: Iterable of ARC examples with 'question', 'answerKey' and
            'choices' ({'label': [...], 'text': [...]}), or None.
        max_samples: Optional cap on the number of examples evaluated.

    Returns:
        {'accuracy': percent, 'correct': int, 'total': int}, or
        {'accuracy': 0.0, 'error': 'Dataset not loaded'} when dataset is None.
    """
    if dataset is None:
        return {'accuracy': 0.0, 'error': 'Dataset not loaded'}

    model.eval()
    correct = 0
    total = 0

    samples = dataset if max_samples is None else dataset.select(
        range(min(max_samples, len(dataset)))
    )

    for example in samples:
        question = example['question']
        answer_key = example['answerKey']
        choices = example['choices']

        choice_text = "\n".join(
            f"{label}: {text}"
            for label, text in zip(choices['label'], choices['text'])
        )

        prompt = f"Question: {question}\n\nChoices:\n{choice_text}\n\nAnswer:"

        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(model.device)
        prompt_length = inputs['input_ids'].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=10,
                temperature=0.0,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id
            )

        # Decode only the newly generated tokens. Slicing the decoded text by
        # len(prompt) misaligns whenever decoding does not reproduce the
        # prompt verbatim (truncation, whitespace or special-token handling).
        response = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip().upper()

        # Compare the leading token exactly; a substring test would credit
        # any continuation that merely contains the answer letter.
        pieces = response.split()
        predicted_label = pieces[0].strip(".,:;!?()[]") if pieces else ""

        if predicted_label and predicted_label == answer_key.strip().upper():
            correct += 1
        total += 1

    accuracy = 100.0 * correct / total if total > 0 else 0.0
    return {'accuracy': accuracy, 'correct': correct, 'total': total}


def evaluate_lambada(model, tokenizer, dataset, max_samples: int = None) -> Dict:
    """
    Evaluate model on LAMBADA word prediction task.

    The last whitespace-delimited word of each passage is the target; the
    rest is the context. The model greedily generates up to 5 new tokens and
    the first generated word (lower-cased, '.,!?' stripped) is the
    prediction. A non-empty prediction counts as correct when it equals the
    target or is a prefix of it. An empty prediction is always incorrect.

    Args:
        model: Causal LM exposing eval(), generate() and a ``device``.
        tokenizer: Tokenizer matching the model.
        dataset: Iterable of examples with a 'text' field, or None.
        max_samples: Optional cap on the number of examples evaluated.

    Returns:
        {'accuracy': percent, 'correct': int, 'total': int}, or
        {'accuracy': 0.0, 'error': 'Dataset not loaded'} when dataset is None.
        Passages with fewer than two words are skipped and not counted.
    """
    if dataset is None:
        return {'accuracy': 0.0, 'error': 'Dataset not loaded'}

    model.eval()
    correct = 0
    total = 0

    samples = dataset if max_samples is None else dataset.select(
        range(min(max_samples, len(dataset)))
    )

    for example in samples:
        words = example['text'].split()

        # Need at least one context word plus the target word.
        if len(words) < 2:
            continue

        target_word = words[-1].lower().strip('.,!?')
        context = ' '.join(words[:-1])

        inputs = tokenizer(
            context,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(model.device)
        prompt_length = inputs['input_ids'].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=5,
                temperature=0.0,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id
            )

        # Decode only the new tokens; slicing decoded text by len(context)
        # breaks when decoding does not reproduce the context verbatim.
        continuation = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        generated_words = continuation.strip().lower().split()
        prediction = generated_words[0].strip('.,!?') if generated_words else ""

        # The explicit non-empty check matters: str.startswith("") is always
        # True, so an empty generation would otherwise score as correct.
        if prediction and (target_word == prediction
                           or target_word.startswith(prediction)):
            correct += 1
        total += 1

    accuracy = 100.0 * correct / total if total > 0 else 0.0
    return {'accuracy': accuracy, 'correct': correct, 'total': total}


def evaluate_model_quick(model, tokenizer, eval_datasets: Dict,
                          max_samples: int = 50) -> Dict:
    """
    Run a reduced-size evaluation for use inside the sweep.

    ARC-Easy is scored on up to ``max_samples`` examples and LAMBADA on up to
    ``max_samples // 2``. Returns a dict with 'arc_easy', 'lambada' and
    'composite', the latter being the plain mean of the two accuracies.
    """
    arc_accuracy = evaluate_arc_easy(
        model, tokenizer, eval_datasets.get('arc'), max_samples
    )['accuracy']

    lambada_accuracy = evaluate_lambada(
        model, tokenizer, eval_datasets.get('lambada'), max_samples // 2
    )['accuracy']

    return {
        'arc_easy': arc_accuracy,
        'lambada': lambada_accuracy,
        'composite': (arc_accuracy + lambada_accuracy) / 2,
    }


# ============================================================================
# CELL 7: DATASET PREPARATION
# ============================================================================
def prepare_finetune_dataset(tokenizer, config: ExperimentConfig, seed: int):
    """
    Build the tokenized fine-tuning set used for knowledge recovery training.

    Loads the train split of ``config.finetune_dataset``, shuffles it with
    ``seed``, keeps at most ``config.finetune_subset_size`` rows, renders each
    row into an instruction/response prompt, and tokenizes it padded or
    truncated to ``config.max_seq_length``. The returned dataset holds only
    the tokenizer outputs, in torch format.
    """
    raw = load_dataset(config.finetune_dataset, split="train")

    subset_size = min(config.finetune_subset_size, len(raw))
    subset = raw.shuffle(seed=seed).select(range(subset_size))

    def to_prompt(row):
        # Fall back to alternative column names when the primary is absent.
        instruction = row.get('instruction', row.get('context', ''))
        response = row.get('response', row.get('text', ''))
        return {
            'text': f"### Instruction:\n{instruction}\n\n### Response:\n{response}"
        }

    prompts = subset.map(to_prompt)

    def to_tokens(batch):
        return tokenizer(
            batch['text'],
            truncation=True,
            max_length=config.max_seq_length,
            padding='max_length'
        )

    # Drop every original column so only tokenizer outputs remain.
    tokenized = prompts.map(
        to_tokens,
        batched=True,
        remove_columns=prompts.column_names
    )
    tokenized.set_format('torch')

    return tokenized


# ============================================================================
# CELL 8: TREATMENT APPLICATION
# ============================================================================
def apply_ketamine_dose(base_model, tokenizer, train_dataset,
                         dose: TreatmentDose, config: ExperimentConfig,
                         seed: int) -> Tuple[nn.Module, Dict]:
    """
    Fine-tune a LoRA-wrapped copy of `base_model` according to `dose`.

    The torch RNG is seeded first, then the base model is deep-copied and
    wrapped with a LoRA adapter whose rank, alpha, dropout and target modules
    come from `dose`. Training runs for `dose.epochs` epochs at `dose.lr` with
    a per-device batch size of 4 and 4 gradient-accumulation steps; no
    checkpoints are saved.

    Args:
        base_model: Model to treat; it is copied, not modified.
        tokenizer: Tokenizer handed to the language-modeling collator.
        train_dataset: Tokenized training dataset.
        dose: Treatment hyperparameters.
        config: Experiment configuration; supplies the sweep output directory.
        seed: Seed for `torch.manual_seed` and for the run's output folder name.

    Returns:
        `(treated_model, stats)` where `stats` holds 'trainable_params',
        'train_time' (wall-clock seconds spent in `trainer.train()`) and
        'estimated_flops' (copied from `dose.estimated_flops`, not measured).
    """
    torch.manual_seed(seed)

    # Seeding happens before the adapter is created so its initialisation is
    # reproducible for a given seed.
    treated = get_peft_model(
        copy.deepcopy(base_model),
        LoraConfig(
            r=dose.lora_rank,
            lora_alpha=dose.lora_alpha,
            target_modules=dose.target_modules,
            lora_dropout=dose.dropout,
            bias="none",
            task_type="CAUSAL_LM"
        )
    )

    n_trainable = sum(
        param.numel() for param in treated.parameters() if param.requires_grad
    )

    run_dir = os.path.join(
        config.ketamine_sweep.sweep_output_dir,
        f"sweep_{dose.get_config_key()}_seed{seed}"
    )
    trainer = Trainer(
        model=treated,
        args=TrainingArguments(
            output_dir=run_dir,
            num_train_epochs=dose.epochs,
            learning_rate=dose.lr,
            per_device_train_batch_size=4,
            gradient_accumulation_steps=4,
            warmup_ratio=0.1,
            logging_steps=100,
            save_strategy="no",
            fp16=True,
            report_to="none",
            remove_unused_columns=False
        ),
        train_dataset=train_dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
    )

    # Only the optimisation loop itself is timed.
    started = time.time()
    trainer.train()
    elapsed = time.time() - started

    return treated, {
        'trainable_params': n_trainable,
        'train_time': elapsed,
        'estimated_flops': dose.estimated_flops
    }


def apply_acute_relapse(model, prune_amount: float = 0.30):
    """
    Simulate acute relapse by applying unstructured pruning.

    For every `nn.Linear` returned by `model.named_modules()`, L1 unstructured
    pruning is applied to the 'weight' parameter with `amount=prune_amount`,
    and the pruning re-parametrisation is then removed so the zeroed weights
    become permanent. The model is modified in place.

    Pruning remains best-effort: a layer that cannot be pruned is skipped and
    reported instead of aborting the run. Only ordinary `Exception`s are
    tolerated, so KeyboardInterrupt / SystemExit still propagate.

    Args:
        model: Module to damage in place.
        prune_amount: Value forwarded as `amount` to `prune.l1_unstructured`.

    Returns:
        The same `model` object that was passed in.
    """
    skipped = []

    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        try:
            prune.l1_unstructured(module, name='weight', amount=prune_amount)
            prune.remove(module, 'weight')
        except Exception as exc:
            # Keep going: one un-prunable layer should not cancel the relapse.
            skipped.append((name, exc))

    if skipped:
        first_name, first_exc = skipped[0]
        print(f"      [WARN] Relapse pruning skipped {len(skipped)} layer(s); "
              f"first failure: {first_name or '<root>'}: {first_exc!r}")

    return model


# ============================================================================
# CELL 9: KETAMINE ISODOSE SWEEP
# ============================================================================
def generate_isodose_sweep_doses(config: ExperimentConfig,
                                  base_flops: float) -> List[TreatmentDose]:
    """
    Build one FLOP-matched Ketamine dose per LoRA rank in the sweep.

    ISODOSE PRINCIPLE:
    The LoRA rank (capacity / structural regrowth strength) is the variable
    under test. For each rank the epoch count is chosen so that the estimated
    training cost lands as close as possible to `base_flops`, which separates
    the effect of STRUCTURAL CAPACITY from the effect of total effort.

    Epochs are rounded to the nearest integer and never drop below 1, so the
    resulting cost only approximates `base_flops`; each dose's
    `estimated_flops` is rewritten to reflect the epochs actually assigned.

    Returns:
        Doses in the same order as `config.ketamine_sweep.ranks`.
    """
    sweep = config.ketamine_sweep
    modules = sweep.target_module_configs[sweep.default_target_config]

    def _calibrated_dose(rank):
        candidate = TreatmentDose(
            treatment_name='ketamine',
            lora_rank=rank,
            lora_alpha=int(rank * sweep.alpha_multiplier),
            epochs=3,  # provisional; only used to derive the per-epoch cost
            lr=sweep.learning_rates[0],
            dropout=sweep.dropouts[0],
            target_modules=modules
        )

        # Estimate cost at the provisional epoch count, then normalise to a
        # single epoch.
        candidate.compute_flops(config.finetune_subset_size, config.max_seq_length)
        per_epoch = candidate.estimated_flops / candidate.epochs

        # Pick the whole number of epochs that best matches the target budget.
        matched_epochs = max(1, int(round(base_flops / per_epoch)))
        candidate.epochs = matched_epochs
        candidate.estimated_flops = per_epoch * matched_epochs
        return candidate

    return [_calibrated_dose(rank) for rank in sweep.ranks]


def generate_budget_sweep_doses(config: ExperimentConfig,
                                 optimal_rank: int,
                                 base_flops: float) -> List[TreatmentDose]:
    """
    Build Ketamine doses that vary total compute at a fixed LoRA rank.

    BUDGET SWEEP PRINCIPLE:
    With the rank pinned to `optimal_rank` (the winner of the isodose sweep),
    the FLOP budget is scaled by each entry of
    `config.ketamine_sweep.budget_multipliers` to locate the point of
    diminishing returns.

    Each budget is converted to a whole number of epochs (rounded, minimum 1),
    and the dose's `estimated_flops` is rewritten to match those epochs.

    Returns:
        Doses in the same order as the budget multipliers.
    """
    sweep = config.ketamine_sweep
    modules = sweep.target_module_configs[sweep.default_target_config]

    def _dose_for_multiplier(multiplier):
        flops_budget = base_flops * multiplier

        candidate = TreatmentDose(
            treatment_name='ketamine',
            lora_rank=optimal_rank,
            lora_alpha=int(optimal_rank * sweep.alpha_multiplier),
            epochs=3,  # provisional; only used to derive the per-epoch cost
            lr=sweep.learning_rates[0],
            dropout=sweep.dropouts[0],
            target_modules=modules
        )

        candidate.compute_flops(config.finetune_subset_size, config.max_seq_length)
        per_epoch = candidate.estimated_flops / candidate.epochs

        epochs_for_budget = max(1, int(round(flops_budget / per_epoch)))
        candidate.epochs = epochs_for_budget
        candidate.estimated_flops = per_epoch * epochs_for_budget
        return candidate

    return [_dose_for_multiplier(m) for m in sweep.budget_multipliers]


def generate_target_module_sweep_doses(config: ExperimentConfig,
                                        optimal_rank: int,
                                        base_flops: float) -> List[TreatmentDose]:
    """
    Build FLOP-matched Ketamine doses that differ in which modules get LoRA.

    MODULE SWEEP PRINCIPLE:
    Determine whether the structural regrowth benefit originates from
    - attention layers (q, k, v, o projections),
    - MLP layers (gate, up, down projections),
    - or the combination of both.

    One dose is produced per entry of
    `config.ketamine_sweep.target_module_configs`, named
    'ketamine_<config_name>'. Rank is fixed at `optimal_rank` and epochs are
    rounded (minimum 1) so the estimated cost approximates `base_flops`.

    Returns:
        Doses in the iteration order of `target_module_configs`.
    """
    sweep = config.ketamine_sweep

    def _dose_for_modules(label, module_names):
        candidate = TreatmentDose(
            treatment_name=f'ketamine_{label}',
            lora_rank=optimal_rank,
            lora_alpha=int(optimal_rank * sweep.alpha_multiplier),
            epochs=3,  # provisional; only used to derive the per-epoch cost
            lr=sweep.learning_rates[0],
            dropout=sweep.dropouts[0],
            target_modules=module_names
        )

        candidate.compute_flops(config.finetune_subset_size, config.max_seq_length)
        per_epoch = candidate.estimated_flops / candidate.epochs

        matched_epochs = max(1, int(round(base_flops / per_epoch)))
        candidate.epochs = matched_epochs
        candidate.estimated_flops = per_epoch * matched_epochs
        return candidate

    return [
        _dose_for_modules(label, module_names)
        for label, module_names in sweep.target_module_configs.items()
    ]


def _resolve_target_config_name(dose, sweep_config) -> str:
    """Return the name of the target-module config whose module list matches `dose`."""
    default_name = sweep_config.default_target_config
    configs = sweep_config.target_module_configs
    wanted = list(dose.target_modules or [])

    # Try the default first so doses built from it resolve exactly as before.
    ordered_names = [default_name] + [n for n in configs if n != default_name]
    for name in ordered_names:
        if name in configs and list(configs[name]) == wanted:
            return name

    # Unknown module list (e.g. already adapted to another architecture).
    return default_name


def run_single_dose_evaluation(dose: TreatmentDose,
                                model_id: str,
                                config: ExperimentConfig,
                                seed: int,
                                eval_datasets: Dict,
                                train_dataset,
                                baseline_composite: float,
                                test_relapse: bool = True,
                                test_longitudinal: bool = False) -> SweepResult:
    """
    Run evaluation for a single dose configuration.

    Loads a fresh copy of `model_id`, applies the treatment described by
    `dose`, scores the treated model, and optionally measures how much the
    score falls after an acute-relapse pruning step.

    Args:
        dose: Treatment to apply. Its `target_modules` attribute is
            overwritten with the architecture-adapted module list, so callers
            that want to keep their dose untouched should pass a copy.
        model_id: Identifier handed to `load_model_for_sweep`.
        config: Experiment configuration.
        seed: Seed forwarded to the treatment.
        eval_datasets: Evaluation datasets forwarded to `evaluate_model_quick`.
        train_dataset: Tokenized training data for the treatment.
        baseline_composite: Composite score of the untreated model; recovery
            is reported relative to it.
        test_relapse: When true, prune the treated model and re-evaluate.
        test_longitudinal: Accepted for interface compatibility; this function
            does not act on it.

    Returns:
        A `SweepResult`. `efficiency` is recovery divided by
        `dose.estimated_flops / 1e15`, or 0 when the FLOP estimate is not
        positive. Relapse fields are only filled when `test_relapse` is true.
    """
    print(f"\n    Testing: rank={dose.lora_rank}, epochs={dose.epochs}, "
          f"FLOPs={dose.estimated_flops:.2e}")

    # Load fresh model
    model, tokenizer, architecture = load_model_for_sweep(model_id, config)
    treated_model = None

    try:
        sweep_config = config.ketamine_sweep

        # Adapt target modules for the architecture while keeping the config
        # the dose was built from; always using the default config would make
        # every dose of a target-module sweep train the same modules.
        dose.target_modules = get_target_modules_for_architecture(
            architecture,
            _resolve_target_config_name(dose, sweep_config),
            sweep_config
        )

        # Apply treatment
        treated_model, stats = apply_ketamine_dose(
            model, tokenizer, train_dataset, dose, config, seed
        )

        # Evaluate post-treatment
        post_treatment = evaluate_model_quick(
            treated_model, tokenizer, eval_datasets,
            max_samples=sweep_config.sweep_eval_samples
        )

        recovery = post_treatment['composite'] - baseline_composite
        efficiency = recovery / (dose.estimated_flops / 1e15) if dose.estimated_flops > 0 else 0

        print(f"      Post-treatment: {post_treatment['composite']:.1f}% "
              f"(recovery: {recovery:+.1f}%, efficiency: {efficiency:.2f})")

        # Create result object
        result = SweepResult(
            dose=dose,
            seed=seed,
            model_id=model_id,
            post_treatment_composite=post_treatment['composite'],
            post_treatment_arc=post_treatment['arc_easy'],
            post_treatment_lambada=post_treatment['lambada'],
            recovery_from_baseline=recovery,
            efficiency=efficiency,
            train_time=stats['train_time']
        )

        # Test relapse resilience
        if test_relapse:
            apply_acute_relapse(treated_model, config.acute_relapse_prune_amount)
            # NOTE(review): the relapse score uses half the sample budget of
            # the post-treatment score, so `relapse_drop` compares evaluations
            # of different size.
            post_relapse = evaluate_model_quick(
                treated_model, tokenizer, eval_datasets,
                max_samples=sweep_config.sweep_eval_samples // 2
            )
            result.post_relapse_composite = post_relapse['composite']
            result.relapse_drop = post_treatment['composite'] - post_relapse['composite']
            print(f"      Post-relapse: {post_relapse['composite']:.1f}% "
                  f"(drop: {result.relapse_drop:.1f}%)")

        return result
    finally:
        # Clean up on failure as well as success so the next run starts with
        # the memory released.
        del treated_model, model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def run_ketamine_isodose_sweep(config: ExperimentConfig) -> Dict:
    """
    Run the Ketamine isodose sweep and pick the best-scoring LoRA rank.

    SWEEP PROCEDURE:
    1. Score the untreated primary model to obtain the baseline
    2. Derive the FLOP budget from the reference Ketamine configuration
    3. Generate FLOP-matched doses, one per rank
    4. Evaluate every dose for each sweep seed
    5. Sort all individual runs by composite score (best first)
    6. Persist the results as JSON and return them

    The primary model is the first entry of
    `config.ketamine_sweep.generalization_models`. The "optimal" rank is the
    rank of the single highest-scoring run.

    Returns:
        Dict with the keys 'sweep_results' (sorted runs), 'optimal_rank',
        'optimal_result', 'baseline_composite', 'base_flops', 'eval_datasets'
        and 'train_dataset'.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 60

    print("\n" + heavy_rule)
    print("KETAMINE ISODOSE SWEEP")
    print(heavy_rule)

    sweep_cfg = config.ketamine_sweep
    os.makedirs(sweep_cfg.sweep_output_dir, exist_ok=True)

    # -------------------------------------------------------------------------
    # STEP 1: Baseline on the untreated primary model
    # -------------------------------------------------------------------------
    print("\n[STEP 1] Establishing baseline")
    print(light_rule)

    model_id = sweep_cfg.generalization_models[0][0]
    print(f"  Primary model: {model_id}")

    model, tokenizer, _architecture = load_model_for_sweep(model_id, config)

    print("  Loading evaluation datasets...")
    eval_datasets = load_evaluation_datasets(config, sweep_cfg.sweep_eval_samples)

    print("  Evaluating untreated baseline...")
    baseline_composite = evaluate_model_quick(
        model, tokenizer, eval_datasets, sweep_cfg.sweep_eval_samples
    )['composite']
    print(f"  Baseline composite: {baseline_composite:.1f}%")

    # The training set is tokenized once and shared by every run.
    print("  Preparing training dataset...")
    train_dataset = prepare_finetune_dataset(tokenizer, config, seed=0)

    # Each run reloads the model, so release this copy first.
    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # -------------------------------------------------------------------------
    # STEP 2: FLOP budget of the reference configuration
    # -------------------------------------------------------------------------
    print("\n[STEP 2] Calibrating base FLOPs")
    print(light_rule)

    base_flops = TreatmentDose(
        treatment_name='ketamine_reference',
        lora_rank=config.ketamine_lora_rank,
        lora_alpha=config.ketamine_lora_alpha,
        epochs=config.ketamine_epochs,
        lr=config.ketamine_lr,
        dropout=config.ketamine_dropout,
        target_modules=config.ketamine_target_modules
    ).compute_flops(config.finetune_subset_size, config.max_seq_length)

    print(f"  Reference rank: {config.ketamine_lora_rank}")
    print(f"  Reference epochs: {config.ketamine_epochs}")
    print(f"  Base FLOPs: {base_flops:.2e}")

    # -------------------------------------------------------------------------
    # STEP 3: FLOP-matched dose list
    # -------------------------------------------------------------------------
    print("\n[STEP 3] Generating isodose sweep configurations")
    print(light_rule)

    doses = generate_isodose_sweep_doses(config, base_flops)
    n_doses = len(doses)

    print(f"  Configurations to test: {n_doses}")
    print(f"\n  {'Rank':>6} {'Alpha':>8} {'Epochs':>8} {'FLOPs':>14} {'Ratio':>8}")
    print("  " + "-" * 50)

    for candidate in doses:
        flops_ratio = candidate.estimated_flops / base_flops
        print(f"  {candidate.lora_rank:>6} {candidate.lora_alpha:>8} {candidate.epochs:>8} "
              f"{candidate.estimated_flops:>14.2e} {flops_ratio:>7.2f}x")

    # -------------------------------------------------------------------------
    # STEP 4: Evaluate every (dose, seed) pair
    # -------------------------------------------------------------------------
    print("\n[STEP 4] Running isodose sweep")
    print(light_rule)

    sweep_results = []

    for position, candidate in enumerate(doses, start=1):
        print(f"\n  Configuration {position}/{n_doses}")

        for seed in range(sweep_cfg.sweep_seeds):
            # A copy is passed because the evaluation rewrites fields on the dose.
            sweep_results.append(
                run_single_dose_evaluation(
                    dose=copy.deepcopy(candidate),
                    model_id=model_id,
                    config=config,
                    seed=seed,
                    eval_datasets=eval_datasets,
                    train_dataset=train_dataset,
                    baseline_composite=baseline_composite,
                    test_relapse=sweep_cfg.test_relapse_in_sweep,
                    test_longitudinal=sweep_cfg.test_longitudinal_in_sweep
                )
            )

    # -------------------------------------------------------------------------
    # STEP 5: Score and order the runs
    # -------------------------------------------------------------------------
    print("\n[STEP 5] Ranking configurations")
    print(light_rule)

    for run in sweep_results:
        run.composite_score = run.calculate_composite_score()

    sweep_results.sort(key=lambda run: run.composite_score, reverse=True)

    print("\n  ISODOSE SWEEP RANKINGS (by composite score)")
    print(f"\n  {'Rank':>4} {'LoRA r':>8} {'Epochs':>8} {'Recovery':>10} "
          f"{'Efficiency':>12} {'Relapse':>10} {'Score':>8}")
    print("  " + "-" * 70)

    for place, run in enumerate(sweep_results, start=1):
        print(f"  {place:>4} {run.dose.lora_rank:>8} {run.dose.epochs:>8} "
              f"{run.recovery_from_baseline:>+9.1f}% "
              f"{run.efficiency:>11.2f} "
              f"{run.relapse_drop:>9.1f}% "
              f"{run.composite_score:>8.2f}")

    # The best single run defines the optimal configuration.
    optimal_result = sweep_results[0]
    optimal_rank = optimal_result.dose.lora_rank

    print("\n  OPTIMAL CONFIGURATION:")
    print(f"    LoRA rank: {optimal_rank}")
    print(f"    Epochs: {optimal_result.dose.epochs}")
    print(f"    Recovery: {optimal_result.recovery_from_baseline:+.1f}%")
    print(f"    Efficiency: {optimal_result.efficiency:.2f} recovery/PetaFLOP")
    print(f"    Relapse drop: {optimal_result.relapse_drop:.1f}%")
    print(f"    Composite score: {optimal_result.composite_score:.2f}")

    # Persist a JSON summary next to the per-run output folders.
    results_path = os.path.join(sweep_cfg.sweep_output_dir, 'isodose_sweep_results.json')
    with open(results_path, 'w') as f:
        json.dump(
            {
                'baseline_composite': baseline_composite,
                'base_flops': base_flops,
                'optimal_rank': optimal_rank,
                'results': [run.to_dict() for run in sweep_results]
            },
            f,
            indent=2
        )

    print(f"\n  [SAVED] Results saved to {sweep_cfg.sweep_output_dir}/isodose_sweep_results.json")

    return {
        'sweep_results': sweep_results,
        'optimal_rank': optimal_rank,
        'optimal_result': optimal_result,
        'baseline_composite': baseline_composite,
        'base_flops': base_flops,
        'eval_datasets': eval_datasets,
        'train_dataset': train_dataset
    }


# ============================================================================
# CELL 10: GENERALIZATION TEST
# ============================================================================
def run_generalization_test(optimal_dose: TreatmentDose,
                             sweep_data: Dict,
                             config: ExperimentConfig) -> Dict:
    """
    Test the optimal Ketamine dose on additional model architectures.

    GENERALIZATION HYPOTHESIS:
    If the optimal Ketamine-like treatment generalizes across architectures:
    - Structural regrowth (high-rank LoRA) is a fundamental recovery mechanism
    - The approach can be recommended as a general pruning recovery strategy

    If it doesn't generalize:
    - The optimal dose may be architecture-specific
    - Different pruning methods may require different recovery strategies

    Args:
        optimal_dose: Dose selected by the isodose sweep; it is copied before
            being adapted to each architecture.
        sweep_data: Output of the isodose sweep. 'optimal_result' supplies the
            primary model's row; 'eval_datasets', when present, is reused so
            all models are scored on the data the sweep used.
        config: Experiment configuration. Every entry of
            `config.ketamine_sweep.generalization_models` after the first is
            tested.

    Returns:
        Dict keyed by model id. Successful entries carry recovery, efficiency
        and relapse statistics averaged over the validation seeds; failed
        entries carry an 'error' string. When fewer than two models are
        configured, `{'skipped': True, 'reason': ...}` is returned instead.

    Side effects:
        Writes 'generalization_results.json' into the sweep output directory.
    """
    print("\n" + "=" * 80)
    print("GENERALIZATION TEST")
    print("=" * 80)

    sweep_config = config.ketamine_sweep

    if len(sweep_config.generalization_models) < 2:
        print("  [SKIP] No additional models configured for generalization test")
        print("  To enable generalization testing, add models to:")
        print("    config.ketamine_sweep.generalization_models")
        return {'skipped': True, 'reason': 'No additional models configured'}

    print(f"  Optimal dose: rank={optimal_dose.lora_rank}, epochs={optimal_dose.epochs}")
    print(f"  Models to test: {len(sweep_config.generalization_models) - 1}")

    generalization_results = {}

    # Get primary model results for comparison
    primary_model_id = sweep_config.generalization_models[0][0]
    primary_result = sweep_data.get('optimal_result')

    if primary_result is not None:
        generalization_results[primary_model_id] = {
            'description': sweep_config.generalization_models[0][1],
            'recovery': primary_result.recovery_from_baseline,
            'efficiency': primary_result.efficiency,
            'relapse_drop': primary_result.relapse_drop,
            'is_primary': True
        }

    # The evaluation data does not depend on the model, so reuse what the
    # sweep already loaded; it is only loaded here if the sweep did not
    # provide it.
    eval_datasets = sweep_data.get('eval_datasets')

    # Test on additional models (the configured architecture label is not
    # needed here; the architecture reported by the loader is used instead).
    for model_id, description, _ in sweep_config.generalization_models[1:]:
        print(f"\n  Testing on: {description}")
        print(f"    Model: {model_id}")
        print("-" * 60)

        model = None
        try:
            # Load model
            model, tokenizer, detected_arch = load_model_for_sweep(model_id, config)

            # Establish baseline for this model
            print("    Evaluating baseline...")
            if eval_datasets is None:
                eval_datasets = load_evaluation_datasets(
                    config, sweep_config.sweep_eval_samples
                )
            baseline = evaluate_model_quick(model, tokenizer, eval_datasets,
                                            sweep_config.sweep_eval_samples)
            baseline_composite = baseline['composite']
            print(f"    Baseline: {baseline_composite:.1f}%")

            # Prepare training data (tokenizer-specific, so rebuilt per model)
            train_dataset = prepare_finetune_dataset(tokenizer, config, seed=0)

            # Each seed reloads the model, so release the baseline copy first.
            model = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Create adapted dose for this architecture
            adapted_dose = copy.deepcopy(optimal_dose)
            adapted_dose.target_modules = get_target_modules_for_architecture(
                detected_arch,
                sweep_config.default_target_config,
                sweep_config
            )

            # Recalculate FLOPs for this architecture (may differ slightly)
            adapted_dose.compute_flops(config.finetune_subset_size, config.max_seq_length)

            # Run validation seeds
            seed_results = []
            for seed in range(sweep_config.validation_seeds):
                print(f"\n    Seed {seed + 1}/{sweep_config.validation_seeds}")

                result = run_single_dose_evaluation(
                    dose=copy.deepcopy(adapted_dose),
                    model_id=model_id,
                    config=config,
                    seed=seed,
                    eval_datasets=eval_datasets,
                    train_dataset=train_dataset,
                    baseline_composite=baseline_composite,
                    test_relapse=True,
                    test_longitudinal=False
                )
                seed_results.append(result)

            # Aggregate results
            seed_recoveries = [r.recovery_from_baseline for r in seed_results]
            avg_recovery = np.mean(seed_recoveries)
            std_recovery = np.std(seed_recoveries)
            avg_efficiency = np.mean([r.efficiency for r in seed_results])
            avg_relapse = np.mean([r.relapse_drop for r in seed_results])

            generalization_results[model_id] = {
                'description': description,
                'architecture': detected_arch,
                'baseline': baseline_composite,
                'recovery': avg_recovery,
                'recovery_std': std_recovery,
                'efficiency': avg_efficiency,
                'relapse_drop': avg_relapse,
                'is_primary': False,
                'n_seeds': len(seed_results)
            }

            print(f"\n    Results for {description}:")
            print(f"      Recovery: {avg_recovery:+.1f}% ± {std_recovery:.1f}%")
            print(f"      Efficiency: {avg_efficiency:.2f}")
            print(f"      Relapse drop: {avg_relapse:.1f}%")

        except Exception as e:
            # One failing model must not abort the remaining ones.
            print(f"    [ERROR] Failed to test {model_id}: {e}")
            generalization_results[model_id] = {
                'description': description,
                'error': str(e),
                'is_primary': False
            }
        finally:
            # If a step failed before the model was released, drop it now so
            # it is not still alive while the next model is being loaded.
            model = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # -------------------------------------------------------------------------
    # GENERALIZATION ANALYSIS
    # -------------------------------------------------------------------------
    print("\n" + "=" * 80)
    print("GENERALIZATION ANALYSIS")
    print("=" * 80)

    successful_models = {k: v for k, v in generalization_results.items()
                         if 'error' not in v and 'recovery' in v}

    if len(successful_models) >= 2:
        recoveries = [v['recovery'] for v in successful_models.values()]
        efficiencies = [v['efficiency'] for v in successful_models.values()]

        # Calculate generalization metrics
        recovery_mean = np.mean(recoveries)
        recovery_std = np.std(recoveries)
        recovery_range = max(recoveries) - min(recoveries)

        efficiency_mean = np.mean(efficiencies)
        efficiency_std = np.std(efficiencies)

        print(f"\n  CROSS-MODEL STATISTICS (n={len(successful_models)} models):")
        print(f"    Recovery: {recovery_mean:+.1f}% ± {recovery_std:.1f}%")
        print(f"    Recovery range: {recovery_range:.1f}%")
        print(f"    Efficiency: {efficiency_mean:.2f} ± {efficiency_std:.2f}")

        # Generalization score: lower std relative to mean = better generalization.
        # Skipped when the mean is exactly zero to avoid dividing by it.
        if recovery_mean != 0:
            generalization_coefficient = recovery_std / abs(recovery_mean)
            print(f"\n    Generalization coefficient (CV): {generalization_coefficient:.2f}")

            if generalization_coefficient < 0.3:
                print("    → STRONG GENERALIZATION: Optimal dose transfers well")
            elif generalization_coefficient < 0.5:
                print("    → MODERATE GENERALIZATION: Some architecture dependence")
            else:
                print("    → WEAK GENERALIZATION: Dose may be architecture-specific")

        # Compare to primary model
        if primary_model_id in successful_models:
            primary_recovery = successful_models[primary_model_id]['recovery']
            for model_id, data in successful_models.items():
                if model_id != primary_model_id:
                    delta = data['recovery'] - primary_recovery
                    print(f"\n    {data['description']}:")
                    print(f"      Delta vs primary: {delta:+.1f}%")

    # Save generalization results; default=str stringifies any value the JSON
    # encoder cannot handle natively.
    with open(os.path.join(sweep_config.sweep_output_dir,
                           'generalization_results.json'), 'w') as f:
        json.dump(generalization_results, f, indent=2, default=str)

    print(f"\n  [SAVED] Generalization results saved")

    return generalization_results


# ============================================================================
# CELL 11: MAIN SWEEP RUNNER
# ============================================================================
def run_full_ketamine_optimization(config: ExperimentConfig = None):
    """
    Run the complete Ketamine optimization pipeline.

    Phases:
    1. Isodose sweep (``run_ketamine_isodose_sweep``) to pick an optimal rank.
    2. Validation of the top-K sweep configurations with additional seeds.
       Runs only when ``validation_seeds > sweep_seeds``. If validation
       re-ranks the candidates, both the optimal rank AND the result object
       it came from are switched, so every later step (generalization dose,
       printed summary, saved JSON) describes the same configuration.
    3. Generalization test of the optimal dose, only when
       ``config.run_generalization_test`` is set.

    A final summary is printed and written to ``optimal_dose_summary.json``
    inside ``config.ketamine_sweep.sweep_output_dir``.

    Args:
        config: Experiment configuration. Falls back to the module-level
            ``CONFIG`` when None.

    Returns:
        dict with key ``'isodose_sweep'`` and, when the corresponding phase
        ran, ``'validation'`` (list of per-rank stats, best score first) and
        ``'generalization'``.
    """
    if config is None:
        config = CONFIG

    print("\n")
    print("#" * 80)
    print("#" * 80)
    print("#" + " " * 78 + "#")
    print("#" + " KETAMINE ISODOSE OPTIMIZATION & GENERALIZATION TEST ".center(78) + "#")
    print("#" + " " * 78 + "#")
    print("#" * 80)
    print("#" * 80)

    sweep_config = config.ketamine_sweep

    print("\n" + "=" * 80)
    print("OPTIMIZATION CONFIGURATION")
    print("=" * 80)
    print(f"  Sweep mode: {sweep_config.sweep_mode}")
    print(f"  Ranks to test: {sweep_config.ranks}")
    print(f"  Sweep seeds: {sweep_config.sweep_seeds}")
    print(f"  Validation seeds: {sweep_config.validation_seeds}")
    print(f"  Top-K to validate: {sweep_config.top_k_validate}")
    print(f"  Test relapse: {sweep_config.test_relapse_in_sweep}")
    print(f"  Test longitudinal: {sweep_config.test_longitudinal_in_sweep}")
    print(f"  Generalization models: {len(sweep_config.generalization_models)}")

    for model_id, desc, arch in sweep_config.generalization_models:
        print(f"    - {desc} ({arch})")

    print("=" * 80)

    results = {}

    # -------------------------------------------------------------------------
    # PHASE 1: Isodose Sweep
    # -------------------------------------------------------------------------
    print("\n" + "=" * 80)
    print("PHASE 1: ISODOSE SWEEP")
    print("=" * 80)

    sweep_data = run_ketamine_isodose_sweep(config)
    results['isodose_sweep'] = sweep_data

    optimal_rank = sweep_data['optimal_rank']
    optimal_result = sweep_data['optimal_result']
    base_flops = sweep_data['base_flops']

    # -------------------------------------------------------------------------
    # PHASE 2: Validation with more seeds (top configurations)
    # -------------------------------------------------------------------------
    if sweep_config.validation_seeds > sweep_config.sweep_seeds:
        print("\n" + "=" * 80)
        print("PHASE 2: VALIDATION OF TOP CONFIGURATIONS")
        print("=" * 80)

        # NOTE(review): assumes 'sweep_results' is ordered best-first, so the
        # slice picks the top-K candidates -- confirm against the sweep runner.
        top_results = sweep_data['sweep_results'][:sweep_config.top_k_validate]

        # Validation runs on the first entry of generalization_models
        # (loop-invariant, so looked up once; only needed when there is
        # something to validate).
        validation_model_id = (
            sweep_config.generalization_models[0][0] if top_results else None
        )

        validation_results = []
        # Pairs each stats entry with the sweep result it was derived from, so
        # the winning result object can be recovered after re-ranking.
        validated_pairs = []

        for i, result in enumerate(top_results):
            print(f"\n  Validating configuration {i+1}/{len(top_results)}: "
                  f"rank={result.dose.lora_rank}")

            # Run additional seeds (seeds below sweep_seeds were already used
            # by the sweep itself).
            additional_seeds = []
            for seed in range(sweep_config.sweep_seeds, sweep_config.validation_seeds):
                val_result = run_single_dose_evaluation(
                    dose=copy.deepcopy(result.dose),
                    model_id=validation_model_id,
                    config=config,
                    seed=seed,
                    eval_datasets=sweep_data['eval_datasets'],
                    train_dataset=sweep_data['train_dataset'],
                    baseline_composite=sweep_data['baseline_composite'],
                    test_relapse=True,
                    test_longitudinal=False
                )
                additional_seeds.append(val_result)

            # Combine with original seeds
            all_seed_results = [result] + additional_seeds

            seed_recoveries = [r.recovery_from_baseline for r in all_seed_results]
            avg_recovery = np.mean(seed_recoveries)
            std_recovery = np.std(seed_recoveries)
            avg_score = np.mean([r.calculate_composite_score() for r in all_seed_results])

            entry = {
                'rank': result.dose.lora_rank,
                'recovery_mean': avg_recovery,
                'recovery_std': std_recovery,
                'score_mean': avg_score,
                'n_seeds': len(all_seed_results)
            }
            validation_results.append(entry)
            validated_pairs.append((entry, result))

            print(f"    Validated recovery: {avg_recovery:+.1f}% ± {std_recovery:.1f}%")
            print(f"    Validated score: {avg_score:.2f}")

        results['validation'] = validation_results

        # Update optimal if validation changed ranking. Guarded so an empty
        # candidate list (e.g. top_k_validate == 0) does not raise IndexError.
        validation_results.sort(key=lambda x: x['score_mean'], reverse=True)
        if validated_pairs:
            # max() returns the first maximal element, matching the stable
            # descending sort above on ties.
            best_entry, best_result = max(
                validated_pairs, key=lambda pair: pair[0]['score_mean']
            )
            validated_optimal_rank = best_entry['rank']

            if validated_optimal_rank != optimal_rank:
                print(f"\n  [NOTE] Validation changed optimal rank: "
                      f"{optimal_rank} -> {validated_optimal_rank}")
                optimal_rank = validated_optimal_rank
                # Keep the result in sync with the rank: its dose carries the
                # isodose-calibrated epochs for THIS rank, and its metrics are
                # the ones reported in the summary below.
                optimal_result = best_result

    optimal_alpha = int(optimal_rank * sweep_config.alpha_multiplier)

    # -------------------------------------------------------------------------
    # PHASE 3: Generalization Test
    # -------------------------------------------------------------------------
    if config.run_generalization_test:
        print("\n" + "=" * 80)
        print("PHASE 3: GENERALIZATION TEST")
        print("=" * 80)

        # Create optimal dose for generalization
        optimal_dose = TreatmentDose(
            treatment_name='ketamine_optimal',
            lora_rank=optimal_rank,
            lora_alpha=optimal_alpha,
            epochs=optimal_result.dose.epochs,
            lr=sweep_config.learning_rates[0],
            dropout=sweep_config.dropouts[0],
            target_modules=sweep_config.target_module_configs[sweep_config.default_target_config]
        )
        optimal_dose.compute_flops(config.finetune_subset_size, config.max_seq_length)

        gen_results = run_generalization_test(optimal_dose, sweep_data, config)
        results['generalization'] = gen_results

    # -------------------------------------------------------------------------
    # FINAL SUMMARY
    # -------------------------------------------------------------------------
    target_modules = sweep_config.target_module_configs[sweep_config.default_target_config]

    print("\n" + "=" * 80)
    print("=" * 80)
    print("  KETAMINE OPTIMIZATION COMPLETE")
    print("=" * 80)
    print("=" * 80)

    print(f"\n  OPTIMAL KETAMINE DOSE:")
    print(f"    LoRA rank: {optimal_rank}")
    print(f"    Alpha: {optimal_alpha}")
    print(f"    Epochs (isodose): {optimal_result.dose.epochs}")
    print(f"    Target FLOPs: {base_flops:.2e}")

    # These figures come from the sweep-phase evaluation of the selected
    # configuration (not the multi-seed validation averages).
    print(f"\n  PRIMARY MODEL PERFORMANCE:")
    print(f"    Recovery: {optimal_result.recovery_from_baseline:+.1f}%")
    print(f"    Efficiency: {optimal_result.efficiency:.2f} recovery/PetaFLOP")
    print(f"    Relapse resilience: {optimal_result.relapse_drop:.1f}% drop")

    if 'generalization' in results and not results['generalization'].get('skipped'):
        gen = results['generalization']
        # Only per-model dict entries that actually carry a recovery figure
        # count as successful; non-dict metadata values are ignored.
        successful = {
            k: v for k, v in gen.items()
            if isinstance(v, dict) and 'recovery' in v
        }
        if len(successful) > 1:
            recoveries = [v['recovery'] for v in successful.values()]
            print(f"\n  GENERALIZATION:")
            print(f"    Models tested: {len(successful)}")
            print(f"    Recovery range: {min(recoveries):+.1f}% to {max(recoveries):+.1f}%")
            print(f"    Mean recovery: {np.mean(recoveries):+.1f}%")

    print("\n  RECOMMENDED CONFIGURATION FOR OTHER PRUNED LLMS:")
    print(f"    lora_rank = {optimal_rank}")
    print(f"    lora_alpha = {optimal_alpha}")
    print(f"    target_modules = {target_modules}")
    print(f"    dropout = {sweep_config.dropouts[0]}")
    print(f"    learning_rate = {sweep_config.learning_rates[0]}")
    print(f"    epochs = (calibrate to match {base_flops:.2e} FLOPs)")

    # Save final summary
    summary = {
        'optimal_rank': optimal_rank,
        'optimal_alpha': optimal_alpha,
        'optimal_epochs': optimal_result.dose.epochs,
        'base_flops': base_flops,
        'primary_recovery': optimal_result.recovery_from_baseline,
        'primary_efficiency': optimal_result.efficiency,
        'primary_relapse_drop': optimal_result.relapse_drop,
        'target_modules': target_modules,
        'dropout': sweep_config.dropouts[0],
        'learning_rate': sweep_config.learning_rates[0]
    }

    # Do not rely on an earlier phase having created the output directory.
    os.makedirs(sweep_config.sweep_output_dir, exist_ok=True)
    with open(os.path.join(sweep_config.sweep_output_dir, 'optimal_dose_summary.json'), 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"\n  [SAVED] All results saved to {sweep_config.sweep_output_dir}/")
    print("=" * 80)

    return results


# ============================================================================
# CELL 12: ENTRY POINT
# ============================================================================
if __name__ == "__main__":
    # Start from the default experiment settings; everything below overrides
    # individual fields for this run.
    config = ExperimentConfig()

    # -------------------------------------------------------------------------
    # SWEEP CONFIGURATION
    # -------------------------------------------------------------------------
    # Edit this table to customize the Ketamine sweep.
    sweep_overrides = {
        # LoRA ranks covered by the capacity sweep
        "ranks": [16, 32, 64, 96, 128],
        # Seeds used during the sweep (1 = fast, 3 = more reliable)
        "sweep_seeds": 1,
        # Total seeds used when validating the best configurations
        "validation_seeds": 3,
        # How many of the best configurations get the extra validation seeds
        "top_k_validate": 2,
        # (model id, description, architecture) triples; append further pruned
        # models here for cross-architecture testing
        "generalization_models": [
            ("oopere/pruned60-llama-3.2-1B", "Llama-3.2-1B 60% pruned", "llama"),  # Primary (keep)
            ("oopere/pruned40-gemma-2-2b", "Gemma-2-2B 40% MLP pruned (same pruning style/author)", "gemma"),
            ("pszemraj/Mistral-7B-v0.3-prune6", "Mistral-7B layer-pruned (6 layers removed)", "mistral"),
            ("pszemraj/Phi-3-small-8k-prune6", "Phi-3-small layer-pruned (6 layers removed)", "phi"),
        ],
    }
    for field_name, field_value in sweep_overrides.items():
        setattr(config.ketamine_sweep, field_name, field_value)

    # Which parts of the pipeline to execute
    run_flags = {
        "run_main_experiment": False,     # original 3-treatment comparison is skipped
        "run_ketamine_sweep": True,       # Ketamine optimization
        "run_generalization_test": True,  # evaluation on other architectures
    }
    for flag_name, flag_value in run_flags.items():
        setattr(config, flag_name, flag_value)

    # -------------------------------------------------------------------------
    # RUN OPTIMIZATION
    # -------------------------------------------------------------------------
    banner = "=" * 80
    print("\n" + banner)
    print("EXPERIMENT MODE SELECTION")
    print(banner)
    for flag_name in run_flags:
        print(f"  {flag_name}: {getattr(config, flag_name)}")
    print(banner)

    if config.run_ketamine_sweep:
        results = run_full_ketamine_optimization(config)
    elif not config.run_main_experiment:
        print("\n[INFO] No experiment mode selected. Set run_ketamine_sweep=True or run_main_experiment=True")
    else:
        # Original experiment path: only the isodose-enabled config copy is
        # prepared here; the actual run is intentionally left disabled.
        from dataclasses import replace
        original_config = replace(config, isodose_mode=True)
        # run_experiment(original_config)  # Uncomment to run


EXPERIMENT MODE SELECTION
  run_main_experiment: False
  run_ketamine_sweep: True
  run_generalization_test: True


################################################################################
################################################################################
#                                                                              #
#             KETAMINE ISODOSE OPTIMIZATION & GENERALIZATION TEST              #
#                                                                              #
################################################################################
################################################################################

OPTIMIZATION CONFIGURATION
  Sweep mode: isodose
  Ranks to test: [16, 32, 64, 96, 128]
  Sweep seeds: 1
  Validation seeds: 3
  Top-K to validate: 2
  Test relapse: True
  Test longitudinal: False
  Generalization models: 4
    - Llama-3.2-1B 60% pruned (llama)
    - Gemma-2-2B 40% MLP pruned (same pruning style/

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]


[STEP 2] Calibrating base FLOPs
------------------------------------------------------------
  Reference rank: 64
  Reference epochs: 3
  Base FLOPs: 1.35e+14

[STEP 3] Generating isodose sweep configurations
------------------------------------------------------------
  Configurations to test: 5

    Rank    Alpha   Epochs          FLOPs    Ratio
  --------------------------------------------------
      16       32       12       1.35e+14    1.00x
      32       64        6       1.35e+14    1.00x
      64      128        3       1.35e+14    1.00x
      96      192        2       1.35e+14    1.00x
     128      256        2       1.80e+14    1.33x

[STEP 4] Running isodose sweep
------------------------------------------------------------

  Configuration 1/5

    Testing: rank=16, epochs=12, FLOPs=1.35e+14
  Loading model: oopere/pruned60-llama-3.2-1B
    Architecture: llama
    Parameters: 752,650,240
    Device: cuda:0


The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss
100,3.9099
200,3.0312
300,2.7608


      Post-treatment: 20.5% (recovery: -2.0%, efficiency: -14.78)
      Post-relapse: 15.0% (drop: 5.5%)

  Configuration 2/5

    Testing: rank=32, epochs=6, FLOPs=1.35e+14
  Loading model: oopere/pruned60-llama-3.2-1B
    Architecture: llama
    Parameters: 752,650,240
    Device: cuda:0


The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss
100,3.6492


      Post-treatment: 26.0% (recovery: +3.5%, efficiency: 25.87)
      Post-relapse: 23.0% (drop: 3.0%)

  Configuration 3/5

    Testing: rank=64, epochs=3, FLOPs=1.35e+14
  Loading model: oopere/pruned60-llama-3.2-1B
    Architecture: llama
    Parameters: 752,650,240
    Device: cuda:0


The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss


      Post-treatment: 22.5% (recovery: +0.0%, efficiency: 0.00)
      Post-relapse: 22.0% (drop: 0.5%)

  Configuration 4/5

    Testing: rank=96, epochs=2, FLOPs=1.35e+14
  Loading model: oopere/pruned60-llama-3.2-1B
    Architecture: llama
    Parameters: 752,650,240
    Device: cuda:0


The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss


      Post-treatment: 20.0% (recovery: -2.5%, efficiency: -18.48)
      Post-relapse: 20.0% (drop: 0.0%)

  Configuration 5/5

    Testing: rank=128, epochs=2, FLOPs=1.80e+14
  Loading model: oopere/pruned60-llama-3.2-1B
    Architecture: llama
    Parameters: 752,650,240
    Device: cuda:0


The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss


      Post-treatment: 18.0% (recovery: -4.5%, efficiency: -24.95)
      Post-relapse: 20.0% (drop: -2.0%)

[STEP 5] Ranking configurations
------------------------------------------------------------

  ISODOSE SWEEP RANKINGS (by composite score)

  Rank   LoRA r   Epochs   Recovery   Efficiency    Relapse    Score
  ----------------------------------------------------------------------
     1       32        6      +3.5%       25.87       3.0%     6.79
     2       64        3      +0.0%        0.00       0.5%     3.80
     3       96        2      -2.5%      -18.48       0.0%     1.15
     4      128        2      -4.5%      -24.95      -2.0%     0.51
     5       16       12      -2.0%      -14.78       5.5%    -0.48

  OPTIMAL CONFIGURATION:
    LoRA rank: 32
    Epochs: 6
    Recovery: +3.5%
    Efficiency: 25.87 recovery/PetaFLOP
    Relapse drop: 3.0%
    Composite score: 6.79

  [SAVED] Results saved to ./ketamine_sweep_results/isodose_sweep_results.json

PHASE 2: VALIDATION OF

The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss
100,3.6469


      Post-treatment: 23.0% (recovery: +0.5%, efficiency: 3.70)
      Post-relapse: 21.0% (drop: 2.0%)

    Testing: rank=32, epochs=6, FLOPs=1.35e+14
  Loading model: oopere/pruned60-llama-3.2-1B
    Architecture: llama
    Parameters: 752,650,240
    Device: cuda:0


The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss
100,3.6494


      Post-treatment: 23.0% (recovery: +0.5%, efficiency: 3.70)
      Post-relapse: 23.0% (drop: 0.0%)
    Validated recovery: +1.5% ± 1.4%
    Validated score: 5.04

  Validating configuration 2/2: rank=64

    Testing: rank=64, epochs=3, FLOPs=1.35e+14
  Loading model: oopere/pruned60-llama-3.2-1B
    Architecture: llama
    Parameters: 752,650,240
    Device: cuda:0


The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss


      Post-treatment: 22.5% (recovery: +0.0%, efficiency: 0.00)
      Post-relapse: 17.0% (drop: 5.5%)

    Testing: rank=64, epochs=3, FLOPs=1.35e+14
  Loading model: oopere/pruned60-llama-3.2-1B
    Architecture: llama
    Parameters: 752,650,240
    Device: cuda:0


The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss


      Post-treatment: 23.5% (recovery: +1.0%, efficiency: 7.39)
      Post-relapse: 23.0% (drop: 0.5%)
    Validated recovery: +0.3% ± 0.5%
    Validated score: 3.51

PHASE 3: GENERALIZATION TEST

GENERALIZATION TEST
  Optimal dose: rank=32, epochs=6
  Models to test: 3

  Testing on: Gemma-2-2B 40% MLP pruned (same pruning style/author)
    Model: oopere/pruned40-gemma-2-2b
------------------------------------------------------------
  Loading model: oopere/pruned40-gemma-2-2b


config.json:   0%|          | 0.00/858 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.90G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/34.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/522 [00:00<?, ?B/s]

    Architecture: gemma
    Parameters: 1,951,923,456
    Device: cuda:0
    Evaluating baseline...
    ARC-Easy: 100 samples


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


    LAMBADA: 50 samples
    Baseline: 32.0%


Map:   0%|          | 0/500 [00:00<?, ? examples/s]


    Seed 1/3

    Testing: rank=32, epochs=6, FLOPs=1.35e+14
  Loading model: oopere/pruned40-gemma-2-2b
    Architecture: gemma
    Parameters: 1,951,923,456
    Device: cuda:0


The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss
100,2.44


      Post-treatment: 37.5% (recovery: +5.5%, efficiency: 40.65)
      Post-relapse: 32.0% (drop: 5.5%)

    Seed 2/3

    Testing: rank=32, epochs=6, FLOPs=1.35e+14
  Loading model: oopere/pruned40-gemma-2-2b
    Architecture: gemma
    Parameters: 1,951,923,456
    Device: cuda:0


The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss
100,2.4416


      Post-treatment: 35.0% (recovery: +3.0%, efficiency: 22.17)
      Post-relapse: 31.0% (drop: 4.0%)

    Seed 3/3

    Testing: rank=32, epochs=6, FLOPs=1.35e+14
  Loading model: oopere/pruned40-gemma-2-2b
    Architecture: gemma
    Parameters: 1,951,923,456
    Device: cuda:0


The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss
100,2.4421


      Post-treatment: 38.5% (recovery: +6.5%, efficiency: 48.04)
      Post-relapse: 30.0% (drop: 8.5%)

    Results for Gemma-2-2B 40% MLP pruned (same pruning style/author):
      Recovery: +5.0% ± 1.5%
      Efficiency: 36.96
      Relapse drop: 6.0%

  Testing on: Mistral-7B layer-pruned (6 layers removed)
    Model: pszemraj/Mistral-7B-v0.3-prune6
------------------------------------------------------------
  Loading model: pszemraj/Mistral-7B-v0.3-prune6


config.json:   0%|          | 0.00/644 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/2.06G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/587k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

    Architecture: mistral
    Parameters: 5,939,351,552
    Device: cuda:0
    Evaluating baseline...
    ARC-Easy: 100 samples
    LAMBADA: 50 samples
    Baseline: 67.5%


Map:   0%|          | 0/500 [00:00<?, ? examples/s]


    Seed 1/3

    Testing: rank=32, epochs=6, FLOPs=1.35e+14
  Loading model: pszemraj/Mistral-7B-v0.3-prune6


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

    Architecture: mistral
    Parameters: 5,939,351,552
    Device: cuda:0


The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss
100,1.6591


      Post-treatment: 72.5% (recovery: +5.0%, efficiency: 36.96)
      Post-relapse: 70.0% (drop: 2.5%)

    Seed 2/3

    Testing: rank=32, epochs=6, FLOPs=1.35e+14
  Loading model: pszemraj/Mistral-7B-v0.3-prune6


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

    Architecture: mistral
    Parameters: 5,939,351,552
    Device: cuda:0


The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss
100,1.6588


      Post-treatment: 71.0% (recovery: +3.5%, efficiency: 25.87)
      Post-relapse: 69.0% (drop: 2.0%)

    Seed 3/3

    Testing: rank=32, epochs=6, FLOPs=1.35e+14
  Loading model: pszemraj/Mistral-7B-v0.3-prune6


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

    Architecture: mistral
    Parameters: 5,939,351,552
    Device: cuda:0


The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss
100,1.6621


      Post-treatment: 69.5% (recovery: +2.0%, efficiency: 14.78)
      Post-relapse: 70.0% (drop: -0.5%)

    Results for Mistral-7B layer-pruned (6 layers removed):
      Recovery: +3.5% ± 1.2%
      Efficiency: 25.87
      Relapse drop: 1.3%

  Testing on: Phi-3-small layer-pruned (6 layers removed)
    Model: pszemraj/Phi-3-small-8k-prune6
------------------------------------------------------------
  Loading model: pszemraj/Phi-3-small-8k-prune6


config.json: 0.00B [00:00, ?B/s]

configuration_phi3_small.py: 0.00B [00:00, ?B/s]

tokenization_phi3_small.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-small-8k-instruct:
- tokenization_phi3_small.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-small-8k-instruct:
- configuration_phi3_small.py
- tokenization_phi3_small.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_phi3_small.py: 0.00B [00:00, ?B/s]

positional_embedding.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-small-8k-instruct:
- positional_embedding.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


triton_blocksparse_attention_layer.py: 0.00B [00:00, ?B/s]

triton_flash_blocksparse_attn.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-small-8k-instruct:
- triton_flash_blocksparse_attn.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-small-8k-instruct:
- triton_blocksparse_attention_layer.py
- triton_flash_blocksparse_attn.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-small-8k-instruct:
- modeling_phi3_small.py
- positional_embedding.py
- triton_blocksparse_attention_layer.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.87G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/2.50G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.80G [00:00<?, ?B/s]

    [ERROR] Failed to load pszemraj/Phi-3-small-8k-prune6: Flash Attention is not available, but is needed for dense attention
    [ERROR] Failed to test pszemraj/Phi-3-small-8k-prune6: Flash Attention is not available, but is needed for dense attention

GENERALIZATION ANALYSIS

  CROSS-MODEL STATISTICS (n=3 models):
    Recovery: +4.0% ± 0.7%
    Recovery range: 1.5%
    Efficiency: 29.57 ± 5.23

    Generalization coefficient (CV): 0.18
    → STRONG GENERALIZATION: Optimal dose transfers well

    Gemma-2-2B 40% MLP pruned (same pruning style/author):
      Delta vs primary: +1.5%

    Mistral-7B layer-pruned (6 layers removed):
      Delta vs primary: +0.0%

  [SAVED] Generalization results saved

  KETAMINE OPTIMIZATION COMPLETE

  OPTIMAL KETAMINE DOSE:
    LoRA rank: 32
    Alpha: 64
    Epochs (isodose): 6
    Target FLOPs: 1.35e+14

  PRIMARY MODEL PERFORMANCE:
    Recovery: +3.5%
    Efficiency: 25.87 recovery/PetaFLOP
    Relapse resilience: 3.0% drop

  GENERALIZATION:
    M

# The End