In [None]:
# =============================================================================
# cell 1: environment setup (v14.1 - Hyperparameter Tuning)
# =============================================================================
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import sys
import time
import math
import json
import base64
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, asdict, field
from typing import Dict, List, Optional, Tuple, Any
import warnings
warnings.filterwarnings('ignore')

# v15 full validation run: torch.compile stays off because the rich spike
# instrumentation outputs are more stable without graph compilation.
USE_TORCH_COMPILE = False
USE_GRADIENT_CHECKPOINTING = True

# Wall-clock timestamp identifying this run.
RUN_TIMESTAMP = datetime.now().strftime('%Y-%m-%d_%H%M%S')
print(f"run timestamp: {RUN_TIMESTAMP}")

# Platform detection: Kaggle exposes KAGGLE_KERNEL_RUN_TYPE; Colab exposes
# COLAB_GPU and/or has google.colab already imported.
IS_KAGGLE = 'KAGGLE_KERNEL_RUN_TYPE' in os.environ
IS_COLAB = 'COLAB_GPU' in os.environ or 'google.colab' in sys.modules
if IS_KAGGLE:
    PLATFORM = 'kaggle'
elif IS_COLAB:
    PLATFORM = 'colab'
else:
    PLATFORM = 'local'

if IS_KAGGLE:
    OUTPUT_DIR = '/kaggle/working/outputs'
else:
    OUTPUT_DIR = 'outputs'

# Create the artefact sub-directories up front (no-op if they already exist).
for subdir in ('figures', 'checkpoints', 'logs', 'results'):
    os.makedirs(f'{OUTPUT_DIR}/{subdir}', exist_ok=True)

print(f"platform: {PLATFORM}")
print(f"output directory: {OUTPUT_DIR}")
for _flag_label, _flag_on in (
    ("torch.compile", USE_TORCH_COMPILE),
    ("gradient checkpointing", USE_GRADIENT_CHECKPOINTING),
):
    print(f"{_flag_label}: {'enabled' if _flag_on else 'disabled'}")


In [None]:
# =============================================================================
# cell 2: pytorch and hardware setup (v14.1)
# =============================================================================
# Dependency bootstrap: fail early with clear errors, auto-install when allowed.
import importlib
import subprocess

# Global default for auto-installing missing packages; set GERHARD_AUTO_INSTALL_DEPS=0 to disable.
AUTO_INSTALL_MISSING_DEPS = os.environ.get('GERHARD_AUTO_INSTALL_DEPS', '1') == '1'

def ensure_dependency(
    import_name: str,
    pip_name: Optional[str] = None,
    required: bool = True,
    auto_install: Optional[bool] = None,
) -> bool:
    """Make sure ``import_name`` is importable, optionally pip-installing it.

    Args:
        import_name: Module name passed to ``importlib.import_module``.
        pip_name: Distribution name for ``pip install``; defaults to ``import_name``.
        required: If True, an unresolvable dependency raises; otherwise a warning
            is printed and False is returned.
        auto_install: Per-call override of the module-level
            ``AUTO_INSTALL_MISSING_DEPS`` flag. ``None`` (default) uses the flag.

    Returns:
        True if the module is importable (possibly after installing it); False if it
        is missing and ``required`` is False.

    Raises:
        ModuleNotFoundError: If the module is missing (or installing it failed)
            and ``required`` is True.
    """
    pip_target = pip_name or import_name
    try:
        importlib.import_module(import_name)
        return True
    except ModuleNotFoundError as exc:
        if auto_install is None:
            auto_install = AUTO_INSTALL_MISSING_DEPS
        if auto_install:
            print(f"missing dependency '{import_name}', attempting pip install: {pip_target}")
            try:
                subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--quiet', pip_target])
                # Import path finders cache directory listings; without this a package
                # installed after interpreter start-up may not be found by the import below.
                importlib.invalidate_caches()
                importlib.import_module(import_name)
                print(f"installed dependency '{import_name}'")
                return True
            except Exception as install_exc:
                message = (
                    f"failed to install dependency '{import_name}' via pip target '{pip_target}'. "
                    f"set GERHARD_AUTO_INSTALL_DEPS=0 to disable auto-install attempts."
                )
                if required:
                    raise ModuleNotFoundError(message) from install_exc
                print(f"warning: {message}")
                return False
        message = (
            f"missing dependency '{import_name}'. "
            f"install '{pip_target}' or set GERHARD_AUTO_INSTALL_DEPS=1 for automatic install."
        )
        if required:
            raise ModuleNotFoundError(message) from exc
        print(f"warning: {message}")
        return False

# Hard requirements: these packages are used later in the notebook, so a
# missing one must fail fast (ensure_dependency raises when required=True).
for _dep_name in ('tqdm', 'transformers', 'datasets'):
    ensure_dependency(_dep_name, _dep_name, required=True)

# matplotlib is optional: training/evaluation still run without it and only
# plot generation is skipped.
MATPLOTLIB_AVAILABLE = ensure_dependency('matplotlib', 'matplotlib', required=False)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.checkpoint import checkpoint
import numpy as np

# Select the non-interactive Agg backend (no display needed); when matplotlib
# is unavailable, plt is set to None so later plotting code can skip itself.
if not MATPLOTLIB_AVAILABLE:
    plt = None
    print("warning: matplotlib unavailable; plot generation will be skipped.")
else:
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

from tqdm.auto import tqdm

SEED = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed torch (CPU generator) and numpy's global RNG before anything random happens.
torch.manual_seed(SEED)
np.random.seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Allow TF32 tensor-core math for matmuls and cuDNN convolutions.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Let cuDNN auto-tune its algorithm choice (favours speed over bitwise determinism).
    torch.backends.cudnn.benchmark = True

    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"gpu: {gpu_name}")
    print(f"memory: {gpu_memory:.1f} gb")

# v14: 'high' float32 matmul precision is only requested when torch.compile is on.
if USE_TORCH_COMPILE and hasattr(torch, 'set_float32_matmul_precision'):
    torch.set_float32_matmul_precision('high')
    print("float32 matmul precision: high (for torch.compile)")

print(f"device: {DEVICE}")
print(f"pytorch: {torch.__version__}")

# torch.compile ships with PyTorch 2.0+. In recent releases torch.__version__ is a
# TorchVersion, whose comparison operators are version-aware rather than lexical.
TORCH_COMPILE_AVAILABLE = hasattr(torch, 'compile') and torch.__version__ >= '2.0'
print(f"torch.compile available: {TORCH_COMPILE_AVAILABLE}")


In [None]:
# =============================================================================
# cell 4: configuration (v14.1 - Hyperparameter Tuning per External LLM)
# =============================================================================
@dataclass
class Config:
    """Hyperparameters for distilling the frozen GPT-2 teacher into the spiking student.

    Fields are grouped by feature (FDD, hard CE, CTKD, early stopping, POCL,
    spikes, ...). Inline comments record the run in which a value was changed.
    The pre-training validation cell sanity-checks a subset of these fields.
    """
    # Version for dynamic labeling (NEVER hardcode versions elsewhere!)
    VERSION: str = 'v15'
    VERSION_DESC: str = 'SpikingBrain information encoding validation'
    
    # gpt-2 teacher (frozen, pre-trained)
    teacher_name: str = "gpt2"

    # student model architecture - v14.1 raised d_model to 512 (~56M params quoted then);
    # v14.3 scaled it to 768, so the ~56M figure no longer applies (TODO: re-measure)
    d_model: int = 768      # v14.3: safer scaling (512->768 instead of 512->1024)
    n_layers: int = 5       # v10 value (DO NOT reduce)
    vocab_size: int = 50257  # GPT-2 tokenizer vocabulary size
    max_seq_len: int = 256

    # ==========================================================================
    # v14.1: Feature Dynamics Distillation (FDD) with CKA Loss
    # ==========================================================================
    # FDD aligns layer-wise dynamics (Δh) between student and teacher
    # Uses CKA (Centered Kernel Alignment) - projector-free, dimension-agnostic
    use_fdd: bool = True
    fdd_weight: float = 0.1           # v14.1: 100x increase over 0.001 (CKA bounded [0,1], safe)
    fdd_warmup_steps: int = 500       # Don't enable until step 500
    fdd_loss_type: str = "cka"        # Options: "cka" (recommended), "mse"
    fdd_kill_threshold: float = 0.10  # Disable if PPL increases >10% (fraction, not percent)

    # ==========================================================================
    # v14.1: Hard Distillation (CE with ground truth)
    # ==========================================================================
    # Anchors student to correct tokens, not just teacher's soft distribution
    ce_hard_weight: float = 0.5       # Ground truth CE loss weight
    
    # Layer mapping: student_layer -> teacher_layer
    # With 5 student layers and 12 teacher layers:
    # We align early/middle/late semantic representations
    # Default: {0: 2, 2: 6, 4: 10}
    # NOTE(review): the mapping itself is not a field here; it is presumably
    # built elsewhere from fdd_n_align_layers — confirm.
    fdd_n_align_layers: int = 3       # Number of layer pairs to align

    # ==========================================================================
    # v14.1: Extended Training (same as v13.1)
    # ==========================================================================
    distill_steps: int = 7000       # v14.3: more steps for larger model
    distill_lr: float = 2e-4       # v14.3: reduced for larger model
    warmup_steps: int = 100
    min_lr: float = 1e-6

    # v14.1: gradient accumulation
    accumulation_steps: int = 2       # effective batch = batch_size * accumulation_steps (8 * 2 = 16)

    # ==========================================================================
    # v14.1: Early Stopping (same as v13.1)
    # ==========================================================================
    use_early_stopping: bool = True
    early_stopping_patience: int = 800  # v14.3: more patience for larger model
    min_ppl_delta: float = 1.0  # minimum PPL improvement that counts as progress

    # ==========================================================================
    # v14.1: POCL DISABLED (failed in v13)
    # ==========================================================================
    use_pocl: bool = False
    pocl_stages: int = 3
    pocl_temp_schedule: tuple = (1.0, 1.5, 2.0)
    pocl_pretrain_steps: int = 100

    # ==========================================================================
    # v14.1: CTKD ENABLED (proven in v12.1, v13.1)
    # ==========================================================================
    use_ctkd: bool = True
    tau_min: float = 1.0   # lower temperature bound (must be < tau_max)
    tau_max: float = 5.0   # upper temperature bound
    tau_init: float = 2.0  # initial temperature, expected within [tau_min, tau_max]
    lambda_max: float = 1.0  # peak gradient-reversal strength
    lambda_warmup_ratio: float = 0.25  # v14.3: slower CTKD ramp-up (fraction of training with λ = 0)

    # Legacy flags (all disabled for v14.1)
    use_learnable_temperature: bool = False
    use_channel_wise_spikes: bool = False
    use_progressive_stages: bool = False
    temperature: float = 2.0

    # Hidden alignment DISABLED (using FDD instead)
    hidden_align_weight: float = 0.0
    teacher_d_model: int = 768   # GPT-2 small hidden size
    teacher_n_layers: int = 12   # GPT-2 small layer count
    temperature_lr: float = 0.001

    # lora for ttt (presumably test-time training — confirm)
    lora_rank: int = 8
    lora_alpha: float = 16.0
    ttt_lr: float = 1e-4
    ttt_steps: int = 100

    # spiking parameters
    # NOTE(review): threshold_mix / surrogate_temp mirror the TrainableTernarySpike
    # constructor arguments; presumably passed through — confirm.
    spike_alpha: float = 1.0
    spike_threshold_mix: float = 0.35   # must be in [0, 1] (checked by validation)
    spike_surrogate_temp: float = 0.10  # must be > 0 (checked by validation)

    # v15 spike semantic/health shaping
    use_spike_semantic_loss: bool = True
    spike_semantic_weight: float = 0.08
    spike_semantic_warmup_steps: int = 400
    spike_target_threshold_scale: float = 0.75

    # general training
    batch_size: int = 8
    max_grad_norm: float = 1.0
    eval_interval: int = 300

config = Config()

# Rough student size using the same heuristic as the pre-training validation
# cell: 2 * vocab * d for embeddings plus ~8 * d^2 per layer (not an exact count).
student_params_est = (
    2 * config.vocab_size * config.d_model
    + 8 * config.n_layers * config.d_model * config.d_model
)
# fdd_weight before v14.1, quoted in the change log below.
PREV_FDD_WEIGHT = 0.001
fdd_weight_ratio = config.fdd_weight / PREV_FDD_WEIGHT

print(f"configuration ({config.VERSION} - {config.VERSION_DESC}):")
print(f"  teacher: {config.teacher_name} (124m params)")
print(f"  student: d={config.d_model}, layers={config.n_layers} (~{student_params_est / 1e6:.0f}M params est.)")
print()
print(f"{config.VERSION} CHANGES (based on external LLM diagnosis):")
print(f"  d_model: 320 -> 512 (v14.1 capacity increase for ternary compensation) -> {config.d_model} (current)")
print(f"  fdd_weight: {PREV_FDD_WEIGHT} -> {config.fdd_weight} (enable alignment gradient signal)")
print(f"  ce_hard_weight: {config.ce_hard_weight} (NEW - ground truth anchoring)")
print()
print(f"{config.VERSION} INNOVATION - Feature Dynamics Distillation (FDD):")
print(f"  use_fdd: {config.use_fdd}")
print(f"  fdd_weight: {config.fdd_weight} ({fdd_weight_ratio:.0f}x vs {PREV_FDD_WEIGHT}, CKA bounded [0,1], safe)")
print(f"  fdd_warmup_steps: {config.fdd_warmup_steps}")
print(f"  fdd_loss_type: {config.fdd_loss_type}")
print(f"  fdd_n_align_layers: {config.fdd_n_align_layers}")
print(f"  fdd_kill_threshold: {config.fdd_kill_threshold} "
      f"({config.fdd_kill_threshold:.0%} PPL increase triggers disable)")
print()
print(f"  FDD Strategy:")
print(f"    - Align layer DYNAMICS (Δh), not just hidden states")
print(f"    - Use CKA loss (projector-free, dimension-agnostic)")
print(f"    - {fdd_weight_ratio:.0f}x weight increase (was too weak before)")
print(f"    - Safety kill-switch if PPL regresses")
print()
print(f"{config.VERSION}: Hard Distillation:")
print(f"  ce_hard_weight: {config.ce_hard_weight}")
# Weights come from config so the printed formula cannot drift from the settings.
print(f"  Formula: L = KL + {config.ce_hard_weight}*CE + {config.fdd_weight}*FDD")
print()
print(f"{config.VERSION}: CTKD (proven technique):")
print(f"  use_ctkd: {config.use_ctkd}")
print(f"  Temperature bounds: [{config.tau_min}, {config.tau_max}]")
print(f"  Lambda warmup: {config.lambda_warmup_ratio*100:.0f}%")
print()
print(f"{config.VERSION}: Extended Training:")
print(f"  distill_steps: {config.distill_steps}")
print(f"  warmup_steps: {config.warmup_steps}")
print(f"  min_lr: {config.min_lr}")
print()
print(f"{config.VERSION}: Early Stopping:")
print(f"  use_early_stopping: {config.use_early_stopping}")
print(f"  patience: {config.early_stopping_patience} steps")
print(f"  min_delta: {config.min_ppl_delta} PPL")
print()
print(f"disabled features:")
print(f"  POCL: {config.use_pocl} (failed in v13)")
print(f"  channel-wise spikes: {config.use_channel_wise_spikes}")
print(f"  old hidden alignment: {config.hidden_align_weight}")
print()
print(f"training:")
print(f"  accumulation: {config.accumulation_steps} (effective batch = {config.batch_size * config.accumulation_steps})")
print(f"  spike_threshold_mix: {config.spike_threshold_mix}")
print(f"  spike_surrogate_temp: {config.spike_surrogate_temp}")
print(f"  use_spike_semantic_loss: {config.use_spike_semantic_loss}")
if config.use_spike_semantic_loss:
    print(f"    spike_semantic_weight: {config.spike_semantic_weight}")
    print(f"    spike_semantic_warmup_steps: {config.spike_semantic_warmup_steps}")
    print(f"    spike_target_threshold_scale: {config.spike_target_threshold_scale}")
print()
print(f"targets:")
print(f"  PPL: validate v14.3 (306.89) spike encoding quality")


In [None]:
# =============================================================================
# cell 3: PRE-TRAINING VALIDATION (run before training to catch issues)
# =============================================================================
print("=" * 70)
print("PRE-TRAINING VALIDATION")
print("=" * 70)

# Hard failures go to validation_errors (training is aborted below);
# soft issues go to validation_warnings (reported only).
validation_errors = []
validation_warnings = []

# 1. Config Sanity Checks
print("")
print("[1] CONFIG SANITY CHECKS")

if config.d_model < 256:
    validation_errors.append(f"d_model={config.d_model} too small (min 256)")
elif config.d_model > 2048:
    validation_warnings.append(f"d_model={config.d_model} very large - check VRAM")
print(f"  d_model: {config.d_model}")

if config.n_layers < 3:
    validation_errors.append(f"n_layers={config.n_layers} too few")
print(f"  n_layers: {config.n_layers}")

print(f"  fdd_weight: {config.fdd_weight}")
print(f"  ce_hard_weight: {config.ce_hard_weight}")
# A negative loss weight would flip the sign of its loss term.
for weight_name in ('fdd_weight', 'ce_hard_weight'):
    weight_value = getattr(config, weight_name)
    if weight_value < 0:
        validation_errors.append(f"{weight_name}={weight_value} must be >= 0")
if config.use_fdd:
    # Options listed on Config.fdd_loss_type.
    if config.fdd_loss_type not in ('cka', 'mse'):
        validation_errors.append(f"fdd_loss_type={config.fdd_loss_type!r} must be 'cka' or 'mse'")
    if config.fdd_warmup_steps >= config.distill_steps:
        validation_warnings.append(
            f"fdd_warmup_steps={config.fdd_warmup_steps} >= distill_steps={config.distill_steps} "
            f"- FDD may never activate"
        )
print(f"  spike_threshold_mix: {config.spike_threshold_mix}")
if not (0.0 <= config.spike_threshold_mix <= 1.0):
    validation_errors.append(f"spike_threshold_mix={config.spike_threshold_mix} must be in [0, 1]")
if config.spike_surrogate_temp <= 0:
    validation_errors.append(f"spike_surrogate_temp={config.spike_surrogate_temp} must be > 0")
if config.use_spike_semantic_loss:
    print(f"  spike_semantic_weight: {config.spike_semantic_weight}")
    print(f"  spike_semantic_warmup_steps: {config.spike_semantic_warmup_steps}")
    if config.spike_semantic_weight < 0:
        validation_errors.append(f"spike_semantic_weight={config.spike_semantic_weight} must be >= 0")
if config.use_ctkd:
    print(f"  tau range: [{config.tau_min}, {config.tau_max}], tau_init: {config.tau_init}")
    # CTKDTemperature divides by (tau_max - tau_min) and clamps an out-of-range init.
    if not config.tau_max > config.tau_min:
        validation_errors.append(f"tau_max={config.tau_max} must be > tau_min={config.tau_min}")
    elif not (config.tau_min <= config.tau_init <= config.tau_max):
        validation_warnings.append(
            f"tau_init={config.tau_init} outside [{config.tau_min}, {config.tau_max}] "
            f"- CTKDTemperature would clamp it"
        )
    if not (0.0 <= config.lambda_warmup_ratio <= 1.0):
        validation_errors.append(f"lambda_warmup_ratio={config.lambda_warmup_ratio} must be in [0, 1]")
print(f"  VERSION: {config.VERSION}")
print(f"  VERSION_DESC: {config.VERSION_DESC}")

# 2. Memory Estimation
print("")
print("[2] MEMORY ESTIMATION")
# Heuristic only: 2 * vocab * d for embeddings + ~8 * d^2 weights per layer.
embed_params = config.vocab_size * config.d_model * 2
layer_params = config.n_layers * config.d_model * config.d_model * 8
total_params_est = embed_params + layer_params
print(f"  Estimated params: ~{total_params_est/1e6:.1f}M")

# fp32 (4 bytes) x 3 copies of the parameters; excludes activations and the teacher.
vram_est_gb = (total_params_est * 4 * 3) / 1e9
print(f"  Estimated VRAM: ~{vram_est_gb:.1f}GB")

if vram_est_gb > 14:
    validation_warnings.append(f"VRAM estimate {vram_est_gb:.1f}GB may exceed 16GB limit")
    print(f"  WARNING: May exceed 16GB VRAM limit!")

# 3. Training Config
print("")
print("[3] TRAINING CONFIG")
print(f"  distill_steps: {config.distill_steps}")
print(f"  distill_lr: {config.distill_lr}")
print(f"  batch_size: {config.batch_size}")
if config.batch_size < 1:
    validation_errors.append(f"batch_size={config.batch_size} must be >= 1")
if hasattr(config, 'accumulation_steps'):
    eff_batch = config.batch_size * config.accumulation_steps
    print(f"  effective_batch: {eff_batch}")
    if config.accumulation_steps < 1:
        validation_errors.append(f"accumulation_steps={config.accumulation_steps} must be >= 1")
if config.min_lr > config.distill_lr:
    validation_warnings.append(f"min_lr={config.min_lr} > distill_lr={config.distill_lr}")
print(f"  early_stopping: patience={config.early_stopping_patience}")

# 4. Feature Flags
print("")
print("[4] FEATURE FLAGS")
print(f"  use_fdd: {config.use_fdd}")
print(f"  use_ctkd: {config.use_ctkd}")
print(f"  use_pocl: {config.use_pocl}")

# Summary
print("")
print("=" * 70)
if validation_errors:
    print(f"VALIDATION FAILED - {len(validation_errors)} errors:")
    for e in validation_errors:
        print(f"   - {e}")
    raise RuntimeError("Fix validation errors before training!")
elif validation_warnings:
    print(f"VALIDATION PASSED WITH {len(validation_warnings)} WARNINGS:")
    for w in validation_warnings:
        print(f"   - {w}")
else:
    print("ALL VALIDATIONS PASSED")
print("=" * 70)


In [None]:
# =============================================================================
# cell 6: v13 PROPER CTKD Implementation
# =============================================================================
# References:
# - CTKD Paper: https://arxiv.org/abs/2211.16231
# - GRL Origin: Ganin & Lempitsky (2015) https://arxiv.org/abs/1409.7495
# - torch-gradient-reversal: https://pypi.org/project/torch-gradient-reversal/

# -----------------------------------------------------------------------------
# GradientReversalFunction (Custom Autograd)
# -----------------------------------------------------------------------------
class GradientReversalFunction(torch.autograd.Function):
    """
    Autograd function behind the Gradient Reversal Layer (GRL).

    Forward pass:  y = x (returned as a copy, not the input tensor itself).
    Backward pass: dL/dx = -lambda_ * dL/dy.

    Placed in front of a parameter, it lets a single backward pass drive a
    min-max game: whatever sits upstream of the GRL receives negated gradients
    (so it ascends the loss) while the rest of the graph descends it normally.

    Reference: Ganin & Lempitsky, "Unsupervised Domain Adaptation by Backpropagation"
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        # lambda_ is a plain hyperparameter; stash it for the backward pass.
        ctx.lambda_ = lambda_
        # Hand back a copy so the output is a distinct tensor from the input
        # (avoids in-place aliasing issues).
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output.neg().mul(ctx.lambda_)
        # One gradient per forward input: x gets the reversed gradient,
        # lambda_ (not a tensor to optimise) gets None.
        return reversed_grad, None


class GradientReversalLayer(nn.Module):
    """
    nn.Module front-end for GradientReversalFunction.

    The reversal strength is kept as a plain Python attribute (not a Parameter
    or buffer), so it is neither trained nor stored in the state_dict.

    Usage:
        grl = GradientReversalLayer()
        grl.set_lambda(0.5)  # Set adversarial strength
        y = grl(x)  # Forward: y = x, Backward: grad_x = -0.5 * grad_y
    """

    def __init__(self):
        super().__init__()
        # Full reversal until set_lambda() says otherwise.
        self.lambda_ = 1.0

    def set_lambda(self, lambda_: float):
        """Update the reversal strength: 1 flips the gradient's sign, 0 zeroes it entirely."""
        self.lambda_ = lambda_

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        strength = self.lambda_
        return GradientReversalFunction.apply(x, strength)


# -----------------------------------------------------------------------------
# Lambda Scheduler (Cosine with Warmup)
# -----------------------------------------------------------------------------
def get_lambda(step: int, total_steps: int, lambda_max: float = 1.0, 
               warmup_ratio: float = 0.2) -> float:
    """
    Half-cosine ramp for the adversarial (gradient-reversal) strength λ.

    - Warmup (step < int(total_steps * warmup_ratio)): λ = 0. Since the GRL
      backward multiplies by -λ, a temperature behind the GRL receives zero
      gradient through it during warmup (it does not "learn freely").
    - After warmup: λ rises smoothly from 0 to lambda_max along a half-cosine.
    - At and beyond total_steps, λ is held at lambda_max (progress is clamped),
      instead of the cosine swinging back down for step > total_steps.

    Args:
        step: Current training step
        total_steps: Total number of training steps
        lambda_max: Maximum λ value (default 1.0 = full reversal)
        warmup_ratio: Fraction of training for warmup (default 0.2 = 20%)

    Returns:
        Current λ value in [0, lambda_max]
    """
    warmup_steps = int(total_steps * warmup_ratio)

    if step < warmup_steps:
        return 0.0

    # Progress after warmup, clamped to 1 so overshooting total_steps keeps λ at its peak.
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    progress = min(progress, 1.0)
    # Cosine increase from 0 to lambda_max
    return lambda_max * (1 - math.cos(math.pi * progress)) / 2


# -----------------------------------------------------------------------------
# CTKDTemperature (Proper Implementation with GRL)
# -----------------------------------------------------------------------------
class CTKDTemperature(nn.Module):
    """
    Curriculum Temperature for Knowledge Distillation (CTKD).
    
    Key features:
    1. Adversarial learning via Gradient Reversal Layer
    2. Sigmoid bounding for smooth gradients at boundaries
    3. Proper initialization via logit transform
    
    The temperature module tries to MAXIMIZE the KL loss (via GRL),
    finding the "hardest" temperature for the student.
    The student tries to MINIMIZE the KL loss.
    This adversarial game leads to optimal curriculum difficulty.
    
    Reference: Li et al., "Curriculum Temperature for Knowledge Distillation", AAAI 2023
    """
    
    def __init__(self, tau_min: float = 1.0, tau_max: float = 5.0, init: float = 2.0):
        """
        Args:
            tau_min: Minimum temperature (default 1.0)
            tau_max: Maximum temperature (default 5.0, conservative for LLMs);
                must be strictly greater than tau_min
            init: Initial temperature (default 2.0). Values outside the range are
                clamped to 1% / 99% of the way from tau_min to tau_max.

        Raises:
            ValueError: If tau_max <= tau_min. An empty range would divide by zero
                in the logit initialization; a reversed one would silently swap
                the meaning of the bounds.
        """
        # Validate before creating any state (written as `not >` so NaN also fails).
        if not tau_max > tau_min:
            raise ValueError(
                f"CTKDTemperature requires tau_max > tau_min, got tau_min={tau_min}, tau_max={tau_max}"
            )
        super().__init__()
        self.tau_min = tau_min
        self.tau_range = tau_max - tau_min
        
        # Initialize raw parameter so sigmoid outputs init value
        # sigmoid(raw) = (init - tau_min) / tau_range
        # raw = logit((init - tau_min) / tau_range)
        init_normalized = (init - tau_min) / self.tau_range
        init_normalized = max(0.01, min(0.99, init_normalized))  # Clamp for numerical stability
        init_raw = math.log(init_normalized / (1 - init_normalized))  # logit function
        
        self.raw_temp = nn.Parameter(torch.tensor(init_raw, dtype=torch.float32))
        self.grl = GradientReversalLayer()
        
        # Store config for logging
        self.tau_min_val = tau_min
        self.tau_max_val = tau_max
        self.init_val = init
    
    def forward(self, lambda_: float) -> torch.Tensor:
        """
        Compute temperature with GRL applied.
        
        Args:
            lambda_: Current adversarial strength from scheduler
        
        Returns:
            Temperature τ ∈ [tau_min, tau_max]
        """
        # Set GRL strength
        self.grl.set_lambda(lambda_)
        
        # Apply GRL to raw parameter (this is where gradient reversal happens!)
        raw_reversed = self.grl(self.raw_temp)
        
        # Sigmoid bounding (smooth, differentiable at boundaries)
        tau = self.tau_min + self.tau_range * torch.sigmoid(raw_reversed)
        
        return tau
    
    def get_temperature(self) -> float:
        """Get current temperature without GRL (for logging/display)."""
        with torch.no_grad():
            tau = self.tau_min + self.tau_range * torch.sigmoid(self.raw_temp)
            return tau.item()
    
    def get_raw_value(self) -> float:
        """Get raw (unbounded) parameter value (for debugging)."""
        return self.raw_temp.item()


# -----------------------------------------------------------------------------
# Legacy Classes (kept for backward compatibility)
# -----------------------------------------------------------------------------
class LearnableTemperature(nn.Module):
    """
    DEPRECATED: plain learnable temperature with no gradient reversal.

    The temperature is stored in log-space and exponentiated on use, clamped to
    [1.0, 10.0]. Retained only for backward compatibility; prefer
    CTKDTemperature.

    WARNING: This class caused temperature runaway in v12!
    """

    def __init__(self, init: float = 2.0):
        super().__init__()
        # Log-space parametrisation keeps the unclamped temperature positive.
        self.log_temp = nn.Parameter(torch.tensor(init).log())

    def forward(self) -> torch.Tensor:
        temperature = self.log_temp.exp()
        return temperature.clamp(min=1.0, max=10.0)

    def get_temperature(self) -> float:
        current = self.forward()
        return current.item()


class ChannelWiseTernarySpike(nn.Module):
    """
    Ternary spike quantizer with a learnable alpha and amplitude per channel.

    Channel c emits +amplitude[c], -amplitude[c] or 0 depending on whether x
    exceeds +/- alpha[c] * mean(|x|) taken over dims 0 and 1 (threshold clamped
    to [0.01, 10]). A straight-through estimator gives x an identity gradient.

    DISABLED in v13 due to structural symmetry issue with RWKV.
    """

    def __init__(self, d_model: int, alpha_init: float = 1.0):
        super().__init__()
        self.d_model = d_model
        self.alpha = nn.Parameter(torch.ones(d_model) * alpha_init)
        self.amplitude = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # assumes x is (batch, seq, d_model): statistics are pooled over dims 0 and 1
        channel_mean_abs = x.abs().mean(dim=(0, 1), keepdim=True)
        thresh = (self.alpha * channel_mean_abs).clamp(min=0.01, max=10.0)

        # The firing decision itself is non-differentiable.
        with torch.no_grad():
            signs = (x > thresh).float() - (x < -thresh).float()

        # Straight-through: the value is the spike train, d(out)/dx is the identity.
        return self.amplitude * signs + (x - x.detach())

    def get_amplitude(self) -> float:
        amp = self.amplitude
        return amp.mean().item()

    def get_stats(self) -> dict:
        alpha, amp = self.alpha, self.amplitude
        return {
            'alpha_mean': alpha.mean().item(),
            'alpha_std': alpha.std().item(),
            'amplitude_mean': amp.mean().item(),
            'amplitude_std': amp.std().item(),
        }

    def get_amplitude_stats(self) -> dict:
        amp = self.amplitude
        return {
            'mean': amp.mean().item(),
            'std': amp.std().item(),
            'min': amp.min().item(),
            'max': amp.max().item(),
        }


class TrainableTernarySpike(nn.Module):
    """
    Ternary spike quantizer with one learnable scalar amplitude (from v8).

    The firing threshold blends a per-token scale (mean |x| over the last dim)
    with a per-channel scale (mean |x| over dims 0 and 1); ``threshold_mix``
    is the weight on the per-channel part. ``surrogate_temp`` sets the
    sharpness of the optional sigmoid "soft activity" returned as aux output.
    """

    def __init__(
        self,
        alpha: float = 1.0,
        threshold_mix: float = 0.35,
        surrogate_temp: float = 0.10,
    ):
        super().__init__()
        self.alpha = alpha
        self.threshold_mix = threshold_mix
        self.surrogate_temp = surrogate_temp
        self.amplitude = nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor, return_aux: bool = False):
        mag = x.abs()
        per_token = mag.mean(dim=-1, keepdim=True)
        per_channel = mag.mean(dim=(0, 1), keepdim=True)
        blended = (1.0 - self.threshold_mix) * per_token + self.threshold_mix * per_channel
        threshold = (self.alpha * blended).clamp(min=0.01, max=10.0)

        # Spike decision is non-differentiable; gradients flow via the STE term below.
        with torch.no_grad():
            signs = (x > threshold).float() - (x < -threshold).float()

        # Amplitude is clamped to [0.25, 4.0] at use time (the parameter itself is not).
        amp = self.amplitude.clamp(min=0.25, max=4.0)
        out = amp * signs + (x - x.detach())

        if not return_aux:
            return out
        aux = {
            'threshold': threshold.detach(),
            'soft_activity': torch.sigmoid((mag - threshold) / self.surrogate_temp),
        }
        return out, aux

    def get_amplitude(self) -> float:
        amp = self.amplitude
        return amp.item()


def get_stage_params(step: int, total_steps: int = 3000) -> dict:
    """
    Look up the progressive-training (POCL) stage settings for ``step``.

    Stage boundaries sit at 40% and 70% of ``total_steps``:
      stage 1: step < 0.4 * total_steps
      stage 2: step < 0.7 * total_steps
      stage 3: everything after that
    Kept for infrastructure.
    """
    if step < total_steps * 0.4:
        return dict(stage=1, temp_target=1.0, align_mult=0.0, alpha=0.9)
    if step < total_steps * 0.7:
        return dict(stage=2, temp_target=1.5, align_mult=0.5, alpha=0.7)
    return dict(stage=3, temp_target=2.0, align_mult=1.0, alpha=0.5)


# -----------------------------------------------------------------------------
# Unit Tests for CTKD Components
# -----------------------------------------------------------------------------
print("="*60)
print("v13 CTKD Component Tests")
print("="*60)

# Test 1: GRL Gradient Reversal
print("\n[1] GRL Gradient Reversal Test")
grl = GradientReversalLayer()
grl.set_lambda(1.0)
x_test = torch.tensor([2.0], requires_grad=True)
y_test = grl(x_test)
loss_test = y_test.sum()
loss_test.backward()
expected_grad = -1.0  # GRL should negate: 1 * -1.0 = -1.0
actual_grad = x_test.grad.item()
grl_pass = abs(actual_grad - expected_grad) < 1e-6
print(f"  Input grad without GRL would be: +1.0")
print(f"  Input grad with GRL (λ=1.0): {actual_grad:.4f}")
print(f"  Expected: {expected_grad:.4f}")
print(f"  {'PASS' if grl_pass else 'FAIL'}")
del x_test, y_test, loss_test

# Test 2: Lambda Schedule
print("\n[2] Lambda Schedule Test")
total = 3000
warmup = 0.2
# During warmup
lambda_0 = get_lambda(0, total, warmup_ratio=warmup)
lambda_500 = get_lambda(500, total, warmup_ratio=warmup)
# After warmup
lambda_1500 = get_lambda(1500, total, warmup_ratio=warmup)
lambda_2999 = get_lambda(2999, total, warmup_ratio=warmup)

warmup_pass = lambda_0 == 0.0 and lambda_500 == 0.0
increase_pass = 0 < lambda_1500 < lambda_2999 <= 1.0
lambda_pass = warmup_pass and increase_pass
print(f"  λ(0) = {lambda_0:.4f} (should be 0.0)")
print(f"  λ(500) = {lambda_500:.4f} (should be 0.0, still in warmup)")
print(f"  λ(1500) = {lambda_1500:.4f} (should be > 0)")
print(f"  λ(2999) = {lambda_2999:.4f} (should be ≈ 1.0)")
print(f"  {'PASS' if lambda_pass else 'FAIL'}")

# Test 3: Temperature Bounds
print("\n[3] Temperature Bounds Test")
temp_module = CTKDTemperature(tau_min=1.0, tau_max=5.0, init=2.0).to(DEVICE)
init_temp = temp_module.get_temperature()

# Force extreme raw values to check that the sigmoid saturates at the bounds.
with torch.no_grad():
    temp_module.raw_temp.fill_(-100)
    tau_low = temp_module.get_temperature()
    
    temp_module.raw_temp.fill_(100)
    tau_high = temp_module.get_temperature()
    
    # Reset to init (logit of the normalized init temperature, as in __init__)
    init_normalized = (2.0 - 1.0) / 4.0
    init_raw = math.log(init_normalized / (1 - init_normalized))
    temp_module.raw_temp.fill_(init_raw)

bounds_pass = (1.0 <= tau_low <= 1.01) and (4.99 <= tau_high <= 5.0) and (1.9 <= init_temp <= 2.1)
print(f"  Initial temp: {init_temp:.4f} (should be ≈ 2.0)")
print(f"  Min bound test: {tau_low:.4f} (should be ≈ 1.0)")
print(f"  Max bound test: {tau_high:.4f} (should be ≈ 5.0)")
print(f"  {'PASS' if bounds_pass else 'FAIL'}")

# Test 4: End-to-End Gradient Flow
print("\n[4] End-to-End Gradient Flow Test")
temp_module_test = CTKDTemperature(tau_min=1.0, tau_max=5.0, init=2.0).to(DEVICE)
lambda_test = 0.5

# Simulate forward pass
T = temp_module_test(lambda_test)
fake_kl_loss = T * 2.0  # Gradient ∂L/∂T = 2.0

# Without GRL: optimizer would DECREASE T to minimize loss
# With GRL: optimizer should INCREASE T (because grad is reversed)
fake_kl_loss.backward()

raw_grad = temp_module_test.raw_temp.grad.item()
# The gradient through sigmoid and GRL should be negative (reversed)
# Original: ∂L/∂raw > 0 would decrease raw
# With GRL: ∂L/∂raw < 0 (negated), so optimizer increases raw
grad_flow_pass = raw_grad < 0  # Should be negative due to GRL
print(f"  Loss = T * 2.0, so ∂L/∂T = 2.0 (positive)")
print(f"  Without GRL: raw_grad would be positive (decrease T)")
print(f"  With GRL (λ=0.5): raw_grad = {raw_grad:.4f} (should be negative)")
print(f"  {'PASS' if grad_flow_pass else 'FAIL'}")
del temp_module_test

# Summary
print("\n" + "="*60)
ctkd_test_results = {
    'GRL gradient reversal': grl_pass,
    'lambda schedule': lambda_pass,
    'temperature bounds': bounds_pass,
    'gradient flow': grad_flow_pass,
}
all_pass = all(ctkd_test_results.values())
print(f"CTKD Component Tests: {'ALL PASS' if all_pass else 'SOME FAILED'}")
if not all_pass:
    print("WARNING: Fix failing tests before running training!")
print("="*60)
# Stop Run-All here rather than training on a broken CTKD setup
# (mirrors the RuntimeError raised by the pre-training validation cell).
if not all_pass:
    failed_tests = [name for name, passed in ctkd_test_results.items() if not passed]
    raise RuntimeError(f"CTKD component tests failed: {', '.join(failed_tests)}")


In [None]:
# =============================================================================
# cell 6.5: v14.1 FDD (Feature Dynamics Distillation) with CKA Loss
# =============================================================================
# References:
# - CKA: Kornblith et al., "Similarity of Neural Network Representations Revisited"
# - FDD: Feature Dynamics Distillation (view transformer as ODE)
# - v7 lesson: Hidden alignment with weight=1.0 caused PPL regression to 1655!

# -----------------------------------------------------------------------------
# Centered Kernel Alignment (CKA) Loss
# -----------------------------------------------------------------------------
def cka_loss(X: torch.Tensor, Y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Centered Kernel Alignment (CKA) loss for representation alignment.

    CKA is a similarity measure between feature representations that is:
    - Invariant to orthogonal transformations
    - Invariant to isotropic scaling
    - Does NOT require dimension matching (projector-free!)

    Args:
        X: Student features [n_samples, dim_x]
        Y: Teacher features [n_samples, dim_y]
        eps: Small constant for numerical stability

    Returns:
        Loss = 1 - CKA (minimize to maximize alignment)
        CKA = 1 means perfect alignment, CKA = 0 means no alignment

    Raises:
        AssertionError: if X or Y is not 2-D, or their sample counts differ.

    Note: n_samples must match, but dim_x and dim_y can differ!

    Implementation notes:
    - Computed in float32 with autocast disabled to avoid float16 overflow
      in mixed-precision training. This is training-time only and does NOT
      affect the student's ternary activations.
    - Features are column-centered and then each row is L2-normalized. The
      resulting kernels are cosine-similarity matrices that are not
      re-centered afterwards, so the value is an alignment of cosine kernels
      rather than textbook linear CKA. For example, both kernels have a unit
      diagonal, which gives independent features a positive baseline.
    - sum(K_X * K_Y) with K_X = X Xᵀ equals ||Xᵀ Y||_F² (trace identity). The
      HSIC terms are therefore computed from [dim, dim] cross-products instead
      of [n, n] Gram matrices. The result is the same value, with O(d²) rather
      than O(n²) memory for large n = batch * seq.
    """
    # Force float32: under CUDA autocast, matmuls would otherwise run in float16.
    with torch.cuda.amp.autocast(enabled=False):
        X = X.float()
        Y = Y.float()

        # Validate input shapes
        assert X.dim() == 2 and Y.dim() == 2, f"Expected 2D tensors, got X:{X.dim()}D, Y:{Y.dim()}D"
        assert X.size(0) == Y.size(0), f"Sample count mismatch: X={X.size(0)}, Y={Y.size(0)}"

        # Center each feature column
        X_centered = X - X.mean(dim=0, keepdim=True)
        Y_centered = Y - Y.mean(dim=0, keepdim=True)

        # Row-normalize so every Gram entry (cosine similarity) lies in [-1, 1]
        X_norm = X_centered / (X_centered.norm(dim=1, keepdim=True) + eps)
        Y_norm = Y_centered / (Y_centered.norm(dim=1, keepdim=True) + eps)

        # HSIC-style alignment terms via ||Aᵀ B||_F² == sum((A Aᵀ) * (B Bᵀ))
        hsic_xy = (X_norm.T @ Y_norm).pow(2).sum()
        hsic_xx = (X_norm.T @ X_norm).pow(2).sum()
        hsic_yy = (Y_norm.T @ Y_norm).pow(2).sum()

        # CKA = HSIC(X,Y) / sqrt(HSIC(X,X) * HSIC(Y,Y))
        cka = hsic_xy / (torch.sqrt(hsic_xx * hsic_yy) + eps)

        # Clamp to valid range (numerical safety)
        cka = cka.clamp(0.0, 1.0)

        # Return loss (1 - CKA): minimize loss = maximize alignment
        return 1.0 - cka


# -----------------------------------------------------------------------------
# Layer Mapping for FDD
# -----------------------------------------------------------------------------
def get_fdd_layer_mapping(n_student_layers: int, n_teacher_layers: int,
                          n_align_layers: int = 3) -> Dict[int, int]:
    """
    Create layer mapping for Feature Dynamics Distillation.

    Maps student layers to teacher layers for hidden state alignment.
    Uses even spacing to cover early/middle/late representations.

    Args:
        n_student_layers: Number of student layers (e.g., 5)
        n_teacher_layers: Number of teacher layers (e.g., 12)
        n_align_layers: Number of layer pairs to align (default 3);
                        capped at n_student_layers

    Returns:
        Dict mapping student_layer_idx -> teacher_layer_idx (both 0-based).
        Empty if there is nothing to align (no student/teacher layers, or
        n_align_layers <= 0).

    Example for 5 student, 12 teacher, 3 alignments:
        {0: 2, 2: 7, 4: 11}  # Early, middle, late (with +2 offset)

    The teacher index is the student layer's relative depth scaled onto the
    teacher, shifted two layers deeper, and clamped to the last teacher layer.
    All scaling uses exact integer arithmetic (floor division), so results
    cannot be perturbed by float rounding.
    """
    n_align_layers = min(n_align_layers, n_student_layers)
    if n_align_layers <= 0 or n_teacher_layers <= 0:
        return {}

    student_span = n_student_layers - 1
    teacher_span = n_teacher_layers - 1
    align_span = max(n_align_layers - 1, 1)

    layer_map = {}
    for i in range(n_align_layers):
        # Evenly spaced student layers: 0, 2, 4 for 3 alignments with 5 layers
        s_idx = i * student_span // align_span
        # Relative depth scaled to the teacher; a 1-layer student sits at depth 0
        # (the old float form divided by zero here).
        t_idx = s_idx * teacher_span // student_span if student_span > 0 else 0
        # +2 offset, clamped into [0, n_teacher_layers - 1] (the old max(2, ...)
        # could return 2 for teachers with fewer than 3 layers).
        layer_map[s_idx] = min(t_idx + 2, teacher_span)

    return layer_map


# -----------------------------------------------------------------------------
# Feature Dynamics Distillation Loss
# -----------------------------------------------------------------------------
def compute_fdd_loss(
    student_hiddens: List[torch.Tensor],
    teacher_hiddens: List[torch.Tensor],
    layer_map: Dict[int, int],
    loss_type: str = "cka"
) -> torch.Tensor:
    """
    Compute Feature Dynamics Distillation (FDD) loss.

    FDD views the transformer as solving an ODE: dh/dt = f(h, t)
    where each layer is a discrete time step.

    Instead of matching hidden states directly (which failed in v7),
    we match the DYNAMICS (layer-to-layer changes): delta_h = h_{l+1} - h_l,
    i.e. the update that layer l applies to its input.

    This teaches the student HOW to transform features, not just WHAT features to have.

    Args:
        student_hiddens: List of student hidden states [h_0, h_1, ..., h_L]
                        where h_0 is the embedding output and h_{l+1} is the
                        output of layer l. Each has shape [batch, seq, student_dim]
        teacher_hiddens: Teacher hidden states in the same layout
                        (e.g. from output_hidden_states=True).
                        Each has shape [batch, seq, teacher_dim]
        layer_map: Dict mapping student_layer_idx -> teacher_layer_idx
                   (0-based layer indices)
        loss_type: "cka" (recommended) or "mse"

    Returns:
        FDD loss averaged over usable layer pairs (scalar tensor);
        0.0 when no pair is usable.

    Raises:
        ValueError: if loss_type is unknown (raised at the first usable pair).

    Pairs whose layer index lies outside either hidden-state list are skipped.
    "mse" pairs whose feature dimensions differ are also skipped.
    """
    total_loss = torch.tensor(0.0, device=student_hiddens[0].device)
    n_pairs = 0

    for s_layer, t_layer in layer_map.items():
        # Layer l maps hiddens[l] -> hiddens[l+1], so it needs index l+1 to exist.
        # Negative indices are rejected explicitly; Python would otherwise wrap them.
        # (Earlier versions read hiddens[l+2] - hiddens[l+1]. That is off by one
        # versus delta_h = h_{l+1} - h_l and silently dropped last-layer pairs.)
        if s_layer < 0 or s_layer + 1 >= len(student_hiddens):
            continue
        if t_layer < 0 or t_layer + 1 >= len(teacher_hiddens):
            continue

        # Dynamics (velocity) of each mapped layer: output minus input
        delta_s = student_hiddens[s_layer + 1] - student_hiddens[s_layer]  # [batch, seq, s_dim]
        delta_t = teacher_hiddens[t_layer + 1] - teacher_hiddens[t_layer]  # [batch, seq, t_dim]

        # Flatten tokens into samples: [batch, seq, dim] -> [batch*seq, dim].
        # Each tensor keeps its own sample count, so a seq mismatch is reported
        # by cka_loss instead of being silently mis-reshaped.
        delta_s_flat = delta_s.reshape(-1, delta_s.size(-1))  # [n, s_dim]
        delta_t_flat = delta_t.reshape(-1, delta_t.size(-1))  # [n, t_dim]

        if loss_type == "cka":
            pair_loss = cka_loss(delta_s_flat, delta_t_flat)
        elif loss_type == "mse":
            # MSE requires matching dimensions; no projector is used, so skip otherwise
            if delta_s_flat.size(1) != delta_t_flat.size(1):
                continue
            pair_loss = F.mse_loss(delta_s_flat, delta_t_flat)
        else:
            raise ValueError(f"Unknown loss_type: {loss_type}")

        total_loss = total_loss + pair_loss
        n_pairs += 1

    # Average over pairs
    if n_pairs > 0:
        total_loss = total_loss / n_pairs

    return total_loss


# -----------------------------------------------------------------------------
# FDD Weight Scheduler (with warmup)
# -----------------------------------------------------------------------------
def get_fdd_weight(step: int, fdd_warmup_steps: int, fdd_weight: float) -> float:
    """
    Step-function schedule for the FDD loss weight.

    The weight is held at zero for the first ``fdd_warmup_steps`` steps, then
    switches on at full strength (there is no gradual ramp).

    Args:
        step: Current training step
        fdd_warmup_steps: Number of initial steps with FDD disabled
        fdd_weight: Weight to use once warmup is over

    Returns:
        0.0 while step < fdd_warmup_steps, otherwise fdd_weight
    """
    return 0.0 if step < fdd_warmup_steps else fdd_weight


# -----------------------------------------------------------------------------
# FDD Unit Tests
# -----------------------------------------------------------------------------
print("="*60)
print("v14.1 FDD Component Tests")
print("="*60)

fdd_tests = []

# Test 1: CKA of identical tensors should return 0 loss (CKA=1)
print("\n[1] CKA Identical Tensors Test")
X_test = torch.randn(100, 64)
cka_identical = cka_loss(X_test, X_test)
identical_pass = cka_identical.item() < 0.01  # Should be ~0
print(f"  CKA loss of identical tensors: {cka_identical.item():.6f}")
print(f"  Expected: ~0.0 (CKA=1 means perfect alignment)")
print(f"  {'PASS' if identical_pass else 'FAIL'}")
fdd_tests.append(('CKA identical', identical_pass))

# Test 2: CKA of independent random tensors should return high loss
print("\n[2] CKA Orthogonal Tensors Test")
X_orth = torch.randn(100, 64)
Y_orth = torch.randn(100, 64)  # Independent random features: low, but not zero, alignment
cka_orthogonal = cka_loss(X_orth, Y_orth)
orthogonal_pass = 0.5 < cka_orthogonal.item() <= 1.0  # Should be high
print(f"  CKA loss of orthogonal tensors: {cka_orthogonal.item():.4f}")
print(f"  Expected: 0.5-1.0 (low alignment)")
print(f"  {'PASS' if orthogonal_pass else 'FAIL'}")
fdd_tests.append(('CKA orthogonal', orthogonal_pass))

# Test 3: CKA handles different dimensions
print("\n[3] CKA Dimension Agnostic Test")
X_small = torch.randn(100, 32)   # 32 dims
Y_large = torch.randn(100, 128)  # 128 dims
try:
    cka_diff_dim = cka_loss(X_small, Y_large)
    dim_pass = True
    print(f"  CKA with dims (32, 128): {cka_diff_dim.item():.4f}")
except Exception as e:
    dim_pass = False
    print(f"  ERROR: {e}")
print(f"  {'PASS' if dim_pass else 'FAIL'}")
fdd_tests.append(('CKA dimension agnostic', dim_pass))

# Test 4: Layer mapping correctness
print("\n[4] Layer Mapping Test")
layer_map = get_fdd_layer_mapping(n_student_layers=5, n_teacher_layers=12, n_align_layers=3)
# Actual computation: s_idx=0 -> t_idx=0+2=2, s_idx=2 -> t_idx=5+2=7, s_idx=4 -> t_idx=11 (clamped)
expected_map = {0: 2, 2: 7, 4: 11}  # Corrected expectation
map_pass = layer_map == expected_map
print(f"  Generated map: {layer_map}")
print(f"  Expected map: {expected_map}")
print(f"  {'PASS' if map_pass else 'FAIL'}")
fdd_tests.append(('Layer mapping', map_pass))

# Test 5: FDD loss computation
print("\n[5] FDD Loss Computation Test")
# Mock hidden states
student_hiddens_mock = [torch.randn(2, 16, 320) for _ in range(6)]  # embed + 5 layers
teacher_hiddens_mock = [torch.randn(2, 16, 768) for _ in range(13)]  # embed + 12 layers
try:
    fdd_loss_val = compute_fdd_loss(
        student_hiddens_mock,
        teacher_hiddens_mock,
        layer_map,
        loss_type="cka"
    )
    fdd_pass = 0.0 <= fdd_loss_val.item() <= 1.0
    print(f"  FDD loss: {fdd_loss_val.item():.4f}")
    print(f"  Expected: [0, 1]")
except Exception as e:
    fdd_pass = False
    print(f"  ERROR: {e}")
print(f"  {'PASS' if fdd_pass else 'FAIL'}")
fdd_tests.append(('FDD loss computation', fdd_pass))

# Test 6: FDD weight scheduler (warmup 500 steps, weight 0.1)
print("\n[6] FDD Weight Scheduler Test")
w_0 = get_fdd_weight(0, 500, 0.1)
w_400 = get_fdd_weight(400, 500, 0.1)
w_500 = get_fdd_weight(500, 500, 0.1)
w_1000 = get_fdd_weight(1000, 500, 0.1)
scheduler_pass = (w_0 == 0.0 and w_400 == 0.0 and w_500 == 0.1 and w_1000 == 0.1)
print(f"  weight(0): {w_0} (should be 0)")
print(f"  weight(400): {w_400} (should be 0)")
print(f"  weight(500): {w_500} (should be 0.1)")
print(f"  weight(1000): {w_1000} (should be 0.1)")
print(f"  {'PASS' if scheduler_pass else 'FAIL'}")
fdd_tests.append(('FDD weight scheduler', scheduler_pass))

# Test 7: CKA float32 stability test (simulates mixed precision)
print("\n[7] CKA Float32 Stability Test")
# Simulate large values that would overflow in float16
X_large = torch.randn(2048, 320) * 100  # Large values
Y_large = torch.randn(2048, 768) * 100
try:
    cka_large = cka_loss(X_large, Y_large)
    stability_pass = not (torch.isnan(cka_large) or torch.isinf(cka_large))
    print(f"  CKA with large values (n=2048): {cka_large.item():.4f}")
    print(f"  No NaN/Inf: {stability_pass}")
except Exception as e:
    stability_pass = False
    print(f"  ERROR: {e}")
print(f"  {'PASS' if stability_pass else 'FAIL'}")
fdd_tests.append(('CKA float32 stability', stability_pass))

# Summary
print("\n" + "="*60)
all_fdd_pass = all(p for _, p in fdd_tests)
print(f"FDD Component Tests: {'ALL PASS' if all_fdd_pass else 'SOME FAILED'}")
if not all_fdd_pass:
    failed = [n for n, p in fdd_tests if not p]
    print(f"FAILED: {failed}")
print("="*60)


In [None]:
# =============================================================================
# cell 7: v13 POCL (Progressive Overload Curriculum Learning)
# =============================================================================
# Reference: "POCL: Progressive Overload Curriculum Learning" (2025)
# arXiv:2506.05695

# -----------------------------------------------------------------------------
# Sample Difficulty Scoring
# -----------------------------------------------------------------------------
def compute_sample_difficulty(student, teacher, dataloader, device, max_batches=50):
    """
    Compute difficulty scores for each sample using student-teacher divergence.

    Difficulty = KL(teacher || student) + cross-entropy against the teacher's
    argmax tokens, each averaged over the sequence.
    Higher score = harder sample for the student.

    Intended for a briefly pre-trained student, so that the scores are meaningful.

    Args:
        student: Student model; ``student(ids)`` must return logits [batch, seq, vocab]
        teacher: Teacher model (frozen); ``teacher(ids).logits`` must have the same shape
        dataloader: Iterable of batches whose first element is token ids [batch, seq]
        device: Compute device
        max_batches: Limit batches for efficiency

    Returns:
        Dict with 'indices', 'difficulties' and 'num_samples'.

    NOTE(review): 'indices' are positions in this iteration order (0, 1, 2, ...),
    not dataset indices. They only identify samples if the dataloader yields
    samples in a fixed, unshuffled order — confirm against the caller.

    The teacher is left in eval mode. The student's previous train/eval mode is
    restored on exit, including when an error occurs.
    """
    student_was_training = student.training
    student.eval()
    teacher.eval()

    all_difficulties = []
    all_indices = []
    sample_idx = 0

    try:
        with torch.no_grad():
            for batch_idx, batch in enumerate(dataloader):
                if batch_idx >= max_batches:
                    break

                ids = batch[0].to(device, non_blocking=True)
                batch_size = ids.size(0)

                # Get logits
                s_logits = student(ids)
                t_logits = teacher(ids).logits

                # KL(teacher || student): summed over vocab, averaged over sequence
                kl_div = F.kl_div(
                    F.log_softmax(s_logits, dim=-1),
                    F.softmax(t_logits, dim=-1),
                    reduction='none'
                ).sum(dim=-1).mean(dim=-1)  # [batch_size]

                # Cross-entropy against the teacher's hard (argmax) targets
                t_tokens = t_logits.argmax(dim=-1)
                ce_loss = F.cross_entropy(
                    s_logits.reshape(-1, s_logits.size(-1)),
                    t_tokens.reshape(-1),
                    reduction='none'
                ).view(batch_size, -1).mean(dim=-1)  # [batch_size]

                # Combined difficulty
                difficulty = kl_div + ce_loss  # [batch_size]

                all_difficulties.extend(difficulty.cpu().tolist())
                all_indices.extend(range(sample_idx, sample_idx + batch_size))
                sample_idx += batch_size

                if batch_idx % 10 == 0:
                    print(f"  Scoring batch {batch_idx+1}/{max_batches}...")
    finally:
        student.train(student_was_training)

    return {
        'indices': all_indices,
        'difficulties': all_difficulties,
        'num_samples': len(all_indices)
    }


# -----------------------------------------------------------------------------
# Data Partitioning by Difficulty
# -----------------------------------------------------------------------------
def partition_by_difficulty(difficulties_dict, n_stages=3):
    """
    Split scored samples into cumulative easy-to-hard curriculum stages.

    Samples are ranked by ascending difficulty (a stable sort, so ties keep
    their input order). Stage k (0-based) receives the easiest
    int(n * (k + 1) / n_stages) samples. With the default 3 stages this gives
    the easiest 33%, then 66%, then 100%; each stage contains the previous one.

    Args:
        difficulties_dict: Mapping with parallel 'indices' and 'difficulties'
                           lists, as returned by compute_sample_difficulty()
        n_stages: Number of stages (default 3)

    Returns:
        List of n_stages index lists, each a prefix of the full ranking
    """
    ranked = sorted(
        zip(difficulties_dict['indices'], difficulties_dict['difficulties']),
        key=lambda pair: pair[1],
    )
    ordered = [sample for sample, _score in ranked]
    total = len(ordered)
    return [ordered[:int(total * (k + 1) / n_stages)] for k in range(n_stages)]


# -----------------------------------------------------------------------------
# Brief Pre-training for Difficulty Scoring
# -----------------------------------------------------------------------------
def pretrain_for_difficulty_scoring(student, teacher, train_loader, cfg, device, steps=100):
    """
    Brief pre-training so difficulty scores are meaningful.

    Without pre-training, student predictions are effectively random, which
    makes all samples look equally difficult.

    Trains with a fixed-temperature (T=2.0) KL distillation loss, using AdamW
    (lr=cfg.distill_lr, weight_decay=0.01) and a CUDA GradScaler. The optimizer
    and scaler are local and discarded afterwards.

    Args:
        student: Student model; ``student(ids)`` must return logits
        teacher: Teacher model (frozen); ``teacher(ids).logits`` must return logits
        train_loader: Iterable of batches whose first element is token ids
        cfg: Config object (reads ``distill_lr``)
        device: Compute device
        steps: Maximum number of pre-training steps (stops earlier if the
               loader is exhausted)

    Returns:
        Student model (modified in-place)
    """
    print(f"Pre-training student for {steps} steps (for difficulty scoring)...")

    optimizer = torch.optim.AdamW(student.parameters(), lr=cfg.distill_lr, weight_decay=0.01)
    scaler = torch.cuda.amp.GradScaler()

    student.train()
    teacher.eval()

    T = 2.0  # Fixed distillation temperature (no temperature scheduling here)
    step = 0
    last_loss = None  # Stays None if no batch is processed (steps <= 0 or empty loader)
    pbar = tqdm(total=steps, desc='Pre-training')

    try:
        for batch in train_loader:
            if step >= steps:
                break

            ids = batch[0].to(device, non_blocking=True)

            with torch.cuda.amp.autocast():
                with torch.no_grad():
                    t_logits = teacher(ids).logits

                s_logits = student(ids)

                # Simple KL loss over all tokens, rescaled by T^2
                s_log = F.log_softmax(s_logits / T, dim=-1)
                t_prob = F.softmax(t_logits / T, dim=-1)
                loss = F.kl_div(
                    s_log.view(-1, s_logits.size(-1)),
                    t_prob.view(-1, t_logits.size(-1)),
                    reduction='batchmean'
                ) * (T ** 2)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            last_loss = loss
            step += 1
            pbar.update(1)
            if step % 20 == 0:
                pbar.set_postfix(loss=f"{loss.item():.3f}")
    finally:
        pbar.close()

    if last_loss is None:
        print("Pre-training complete. No batches were processed.")
    else:
        print(f"Pre-training complete. Final loss: {last_loss.item():.3f}")

    return student


# -----------------------------------------------------------------------------
# Get Stage Temperature (Fixed Schedule)
# -----------------------------------------------------------------------------
def get_pocl_temperature(step, total_steps, temp_schedule, n_stages=3):
    """
    Look up the fixed distillation temperature for the current POCL stage.

    The stage comes from get_pocl_stage(); the temperature is the matching
    entry of ``temp_schedule``.

    Args:
        step: Current training step
        total_steps: Total training steps
        temp_schedule: Sequence of per-stage temperatures, e.g. (1.0, 1.5, 2.0)
        n_stages: Number of stages

    Returns:
        temp_schedule[stage] for the stage containing ``step``
    """
    return temp_schedule[get_pocl_stage(step, total_steps, n_stages)]


def get_pocl_stage(step, total_steps, n_stages=3):
    """
    Return the 0-based POCL stage that ``step`` falls in.

    Stage k ends just before round((k + 1) * total_steps / n_stages). Any step
    past the last boundary belongs to the final stage. For example, with 5000
    steps and 3 stages the boundaries are 1667 and 3333:
    - Stage 0: steps 0-1666
    - Stage 1: steps 1667-3332
    - Stage 2: steps 3333-4999
    """
    final_stage = n_stages - 1
    return next(
        (
            stage
            for stage in range(final_stage)
            if step < round((stage + 1) * total_steps / n_stages)
        ),
        final_stage,
    )


# -----------------------------------------------------------------------------
# POCL Unit Tests
# -----------------------------------------------------------------------------
print("="*60)
print("v13 POCL Component Tests")
print("="*60)

# Test 1: Temperature Schedule (boundaries at 1667 and 3333 for 5000 steps)
print("\n[1] Temperature Schedule Test")
temp_schedule = (1.0, 1.5, 2.0)
total = 5000

t_start = get_pocl_temperature(0, total, temp_schedule)
t_stage1_end = get_pocl_temperature(1666, total, temp_schedule)  # End of stage 1
t_stage2_start = get_pocl_temperature(1667, total, temp_schedule)  # Start of stage 2
t_stage2_end = get_pocl_temperature(3332, total, temp_schedule)  # End of stage 2
t_stage3 = get_pocl_temperature(4000, total, temp_schedule)

temp_pass = (t_start == 1.0 and t_stage1_end == 1.0 and t_stage2_start == 1.5
             and t_stage2_end == 1.5 and t_stage3 == 2.0)
print(f"  T(0) = {t_start} (should be 1.0)")
print(f"  T(1666) = {t_stage1_end} (should be 1.0, end of stage 1)")
print(f"  T(1667) = {t_stage2_start} (should be 1.5, start of stage 2)")
print(f"  T(3332) = {t_stage2_end} (should be 1.5, end of stage 2)")
print(f"  T(4000) = {t_stage3} (should be 2.0, stage 3)")
print(f"  {'PASS' if temp_pass else 'FAIL'}")

# Test 2: Stage Boundaries
print("\n[2] Stage Boundaries Test")
stages = [get_pocl_stage(s, total) for s in [0, 1666, 1667, 3332, 3333, 4999]]
stage_pass = stages == [0, 0, 1, 1, 2, 2]
print(f"  Stages at [0, 1666, 1667, 3332, 3333, 4999]: {stages}")
print(f"  Expected: [0, 0, 1, 1, 2, 2]")
print(f"  {'PASS' if stage_pass else 'FAIL'}")

# Test 3: Partition by Difficulty (mock)
print("\n[3] Partition by Difficulty Test (mock data)")
mock_difficulties = {
    'indices': list(range(9)),
    'difficulties': [0.5, 1.5, 0.3, 2.0, 0.8, 1.2, 2.5, 0.1, 1.8],  # Sorted easy->hard: 7,2,0 | 4,5,1 | 8,3,6
    'num_samples': 9
}
partitions = partition_by_difficulty(mock_difficulties, n_stages=3)
# After sorting: [7(0.1), 2(0.3), 0(0.5), 4(0.8), 5(1.2), 1(1.5), 8(1.8), 3(2.0), 6(2.5)]
# Stage 1 (33%): indices [7, 2, 0] -> 3 samples
# Stage 2 (66%): indices [7, 2, 0, 4, 5, 1] -> 6 samples
# Stage 3 (100%): all 9 samples
# Cumulative: every stage must be a prefix of the next one
cumulative_pass = all(
    partitions[i] == partitions[i + 1][:len(partitions[i])]
    for i in range(len(partitions) - 1)
)
partition_pass = (len(partitions[0]) == 3 and len(partitions[1]) == 6 and len(partitions[2]) == 9
                  and partitions[0] == [7, 2, 0] and cumulative_pass)
print(f"  Stage 1 samples: {len(partitions[0])} (should be 3)")
print(f"  Stage 2 samples: {len(partitions[1])} (should be 6)")
print(f"  Stage 3 samples: {len(partitions[2])} (should be 9)")
print(f"  Cumulative check: {cumulative_pass}")
print(f"  {'PASS' if partition_pass else 'FAIL'}")

# Summary
print("\n" + "="*60)
all_pass = temp_pass and stage_pass and partition_pass
print(f"POCL Component Tests: {'ALL PASS' if all_pass else 'SOME FAILED'}")
if not all_pass:
    print("WARNING: Fix failing tests before running training!")
print("="*60)


In [None]:
# =============================================================================
# cell 7: hardware and spike stats collectors (same as v9)
# =============================================================================
class HardwareStatsCollector:
    """collect gpu memory, timing, and throughput metrics.

    Usage: call ``start()`` once before the loop, ``record_step()`` once per
    training step, and ``get_summary()`` at the end.
    """

    def __init__(self):
        self.gpu_memory_history = []  # torch.cuda.memory_allocated() in GB, one sample per record_step
        self.step_times = []          # time.time() at each record_step call
        self.tokens_processed = 0     # batch_size * seq_len summed over all record_step calls
        self.start_time = None        # time.time() when start() was called
        # Tokens of the first recorded step. N timestamps bound only N-1 step
        # intervals, so these tokens are excluded from the throughput numerator.
        self._first_step_tokens = 0

    def start(self):
        """Mark the start of training and reset CUDA peak-memory counters."""
        self.start_time = time.time()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

    def record_step(self, batch_size: int, seq_len: int):
        """Record current GPU memory, the step's token count, and a timestamp."""
        if torch.cuda.is_available():
            self.gpu_memory_history.append(torch.cuda.memory_allocated() / 1e9)
        step_tokens = batch_size * seq_len
        if not self.step_times:
            self._first_step_tokens = step_tokens
        self.tokens_processed += step_tokens
        self.step_times.append(time.time())

    def get_throughput(self) -> float:
        """Tokens/second between the first and last recorded step (0.0 if undefined).

        Only tokens from steps 2..N are counted, matching the N-1 intervals
        spanned by the timestamps. Counting all N steps would overstate
        throughput by a factor of N/(N-1).
        """
        if len(self.step_times) < 2:
            return 0.0
        elapsed = self.step_times[-1] - self.step_times[0]
        windowed_tokens = self.tokens_processed - self._first_step_tokens
        return windowed_tokens / elapsed if elapsed > 0 else 0.0

    def get_summary(self) -> Dict[str, Any]:
        """Summarize the run.

        'peak_gpu_memory_gb' is the maximum of the per-step samples (memory
        still allocated when record_step ran). 'torch_peak_gpu_memory_gb' is
        CUDA's tracked maximum since start() reset it, which includes transient
        activation memory. Both are 0 without CUDA.
        """
        elapsed = time.time() - self.start_time if self.start_time else 0
        return {
            'peak_gpu_memory_gb': max(self.gpu_memory_history) if self.gpu_memory_history else 0,
            'avg_gpu_memory_gb': float(np.mean(self.gpu_memory_history)) if self.gpu_memory_history else 0,
            'total_training_time_s': elapsed,
            'total_training_time_min': elapsed / 60,
            'tokens_processed': self.tokens_processed,
            'throughput_tokens_per_sec': self.get_throughput(),
            'torch_peak_gpu_memory_gb': (
                torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0
            ),
        }


class SpikeStatsCollector:
    """Track per-layer spike density and amplitude across training steps.

    Each ``record`` call reads ``student.get_spike_stats()`` and, for every
    ``layer_{i}`` entry present, appends the 'k'/'v' densities (default 0) and
    the 'k_amp'/'v_amp' amplitudes (default 1.0). ``get_summary`` reduces these
    histories to plain floats.
    """

    def __init__(self, n_layers: int):
        self.n_layers = n_layers

        def fresh_pair():
            return {'k': [], 'v': []}

        self.density_history = {}
        self.amplitude_history = {}
        for layer in range(n_layers):
            self.density_history[layer] = fresh_pair()
            self.amplitude_history[layer] = fresh_pair()
        # One {'step', 'density'} entry per record call that found any layer stats
        self.step_densities = []

    def record(self, student, step: int):
        """Append this step's per-layer densities/amplitudes and the step-mean density."""
        stats = student.get_spike_stats()
        step_values = []
        for layer in range(self.n_layers):
            key = f'layer_{layer}'
            if key not in stats:
                continue
            layer_stats = stats[key]
            density_k = layer_stats.get('k', 0)
            density_v = layer_stats.get('v', 0)
            amp_k = layer_stats.get('k_amp', 1.0)
            amp_v = layer_stats.get('v_amp', 1.0)

            densities = self.density_history[layer]
            amplitudes = self.amplitude_history[layer]
            densities['k'].append(density_k)
            densities['v'].append(density_v)
            amplitudes['k'].append(amp_k)
            amplitudes['v'].append(amp_v)
            step_values += [density_k, density_v]

        if step_values:
            self.step_densities.append({'step': step, 'density': float(np.mean(step_values))})

    def get_summary(self) -> Dict[str, Any]:
        """Per-layer mean/std/final densities, final amplitudes, and overall means."""
        def mean_of(values):
            return float(np.mean(values)) if values else 0

        def std_of(values):
            return float(np.std(values)) if values else 0

        def last_of(values, default):
            return float(values[-1]) if values else default

        per_layer = {}
        k_density_all, v_density_all = [], []
        k_amp_last, v_amp_last = [], []

        for layer in range(self.n_layers):
            k_vals = self.density_history[layer]['k']
            v_vals = self.density_history[layer]['v']
            k_amps = self.amplitude_history[layer]['k']
            v_amps = self.amplitude_history[layer]['v']

            per_layer[f'layer_{layer}'] = {
                'k_mean': mean_of(k_vals),
                'k_std': std_of(k_vals),
                'k_final': last_of(k_vals, 0),
                'v_mean': mean_of(v_vals),
                'v_std': std_of(v_vals),
                'v_final': last_of(v_vals, 0),
                'k_amp_final': last_of(k_amps, 1.0),
                'v_amp_final': last_of(v_amps, 1.0),
            }
            k_density_all.extend(k_vals)
            v_density_all.extend(v_vals)
            if k_amps:
                k_amp_last.append(k_amps[-1])
            if v_amps:
                v_amp_last.append(v_amps[-1])

        combined = k_density_all + v_density_all
        return {
            'per_layer': per_layer,
            'overall_k_density': mean_of(k_density_all),
            'overall_v_density': mean_of(v_density_all),
            'overall_density': mean_of(combined),
            'amplitudes': {'k': k_amp_last, 'v': v_amp_last},
            'density_history': self.step_densities,
        }

# Cell-completion marker: printed once both collector classes above have been defined.
print("collectors defined")


In [None]:
# =============================================================================
# cell 8: spiking goose model (v14 - channel-wise spikes + gradient checkpointing)
# =============================================================================
class SpikingGooseRecurrentLayer(nn.Module):
    """
    RWKV-style recurrence with trainable ternary spiking.

    Supports channel-wise ternary spikes (when use_channel_wise=True)

    Per forward call on x of shape [B, T, D]:
      1. LayerNorm, then token-shift: the k/v/r inputs are per-channel blends
         of the current token and the previous token (zeros at t=0).
      2. Key and value projections pass through the spike modules
         (ChannelWiseTernarySpike or TrainableTernarySpike, defined elsewhere).
      3. S accumulates k*v over time with per-channel decay = sigmoid(decay_weight),
         evaluated in closed form with a cumsum (see the note in forward).
      4. Residual output: x + sigmoid(receptance) * output_proj(S).
    """

    def __init__(self, d_model, layer_idx=0, n_layers=4, spike_alpha=1.0,
                 use_channel_wise: bool = False, threshold_mix: float = 0.35,
                 surrogate_temp: float = 0.10):
        # threshold_mix / surrogate_temp are only forwarded to TrainableTernarySpike;
        # they are ignored when use_channel_wise=True.
        super().__init__()
        self.d_model = d_model
        self.layer_idx = layer_idx
        self.use_channel_wise = use_channel_wise
        self.ln = nn.LayerNorm(d_model)

        # Token-shift mix initialised from layer depth: the first layer starts at 1.0
        # (current token only); the last layer starts at 0.0 (previous token only).
        ratio = layer_idx / max(n_layers - 1, 1)
        self.time_mix_k = nn.Parameter(torch.ones(d_model) * (1 - ratio))
        self.time_mix_v = nn.Parameter(torch.ones(d_model) * (1 - ratio))
        self.time_mix_r = nn.Parameter(torch.ones(d_model) * (1 - ratio))
        # Per-channel decay logit; decay = sigmoid(-0.5) ~= 0.378 at init.
        self.decay_weight = nn.Parameter(torch.zeros(d_model) - 0.5)

        self.key_proj = nn.Linear(d_model, d_model, bias=False)
        self.value_proj = nn.Linear(d_model, d_model, bias=False)
        self.receptance_proj = nn.Linear(d_model, d_model, bias=False)
        self.output_proj = nn.Linear(d_model, d_model, bias=False)

        # v14: Use channel-wise spikes if enabled
        if use_channel_wise:
            self.k_spike = ChannelWiseTernarySpike(d_model, alpha_init=spike_alpha)
            self.v_spike = ChannelWiseTernarySpike(d_model, alpha_init=spike_alpha)
        else:
            self.k_spike = TrainableTernarySpike(
                alpha=spike_alpha,
                threshold_mix=threshold_mix,
                surrogate_temp=surrogate_temp,
            )
            self.v_spike = TrainableTernarySpike(
                alpha=spike_alpha,
                threshold_mix=threshold_mix,
                surrogate_temp=surrogate_temp,
            )

        # EMA of the fraction of nonzero k/v spikes; updated in forward during training only.
        self.register_buffer('running_k_density', torch.tensor(0.0))
        self.register_buffer('running_v_density', torch.tensor(0.0))
        self._init_weights()

    def _init_weights(self):
        """Normal init with small std (0.1 / sqrt(d_model)) for the four projections."""
        std = 0.1 / math.sqrt(self.d_model)
        for m in [self.key_proj, self.value_proj, self.receptance_proj, self.output_proj]:
            nn.init.normal_(m.weight, std=std)

    def forward(self, x, return_spikes: bool = False, detach_spikes: bool = True):
        """Apply the recurrent block to x [B, T, D] and return x plus its residual update.

        With return_spikes=True, also returns a dict holding the k/v spike tensors
        (detached unless detach_spikes=False) and each spike module's
        'soft_activity' aux output (None if the module does not provide it).
        """
        B, T, D = x.shape
        x_norm = self.ln(x)
        # Shift one step along time: prev_x[:, t] = x_norm[:, t-1], zeros at t=0.
        prev_x = F.pad(x_norm[:, :-1, :], (0, 0, 1, 0))

        # Token-shift mixing: time_mix=1 uses only the current token, 0 only the previous.
        xk = x_norm * self.time_mix_k + prev_x * (1 - self.time_mix_k)
        xv = x_norm * self.time_mix_v + prev_x * (1 - self.time_mix_v)
        xr = x_norm * self.time_mix_r + prev_x * (1 - self.time_mix_r)

        k_pre = self.key_proj(xk)
        v_pre = self.value_proj(xv)

        k_aux = {}
        v_aux = {}
        if return_spikes:
            k, k_aux = self.k_spike(k_pre, return_aux=True)
            v, v_aux = self.v_spike(v_pre, return_aux=True)
        else:
            k = self.k_spike(k_pre)
            v = self.v_spike(v_pre)
        r = torch.sigmoid(self.receptance_proj(xr))

        # Intended recurrence: S_t = sum_{i<=t} decay**(t-i) * kv_i, evaluated in
        # closed form as decay**t * cumsum_i(kv_i / decay**i). decay_powers is [T, D].
        # NOTE(review): the 1e-8 guard caps 1/decay**i at about 1e8. Once decay**t
        # falls well below 1e-8, S_t no longer matches the intended sum: the newest
        # token's weight becomes decay**t / (decay**t + 1e-8) and tends to 0, and
        # decay**t eventually underflows to 0 (much earlier if x is float16). At the
        # init decay of ~0.378 the weight is already below 0.5 around t ~= 19.
        # Confirm whether trained decays and sequence lengths keep this in range;
        # a chunked or log-space scan would avoid it.
        kv = k * v
        decay = torch.sigmoid(self.decay_weight)
        t_idx = torch.arange(T, device=x.device, dtype=x.dtype)
        decay_powers = decay.unsqueeze(0) ** t_idx.unsqueeze(1)

        kv_weighted = kv / (decay_powers.unsqueeze(0) + 1e-8)
        S = torch.cumsum(kv_weighted, dim=1) * decay_powers.unsqueeze(0)

        if self.training:
            with torch.no_grad():
                # EMA (momentum 0.99) of the fraction of nonzero spikes in this batch.
                self.running_k_density = 0.99 * self.running_k_density + 0.01 * (k != 0).float().mean()
                self.running_v_density = 0.99 * self.running_v_density + 0.01 * (v != 0).float().mean()

        out = x + r * self.output_proj(S)
        if return_spikes:
            k_out = k.detach() if detach_spikes else k
            v_out = v.detach() if detach_spikes else v
            return out, {
                'k_spikes': k_out,
                'v_spikes': v_out,
                'k_soft_activity': k_aux.get('soft_activity'),
                'v_soft_activity': v_aux.get('soft_activity'),
            }
        return out

    def get_spike_density(self):
        """Running k/v spike densities plus each spike module's get_amplitude() value."""
        return {
            'k': self.running_k_density.item(),
            'v': self.running_v_density.item(),
            'k_amp': self.k_spike.get_amplitude(),
            'v_amp': self.v_spike.get_amplitude(),
        }

    def get_channel_wise_stats(self) -> Optional[dict]:
        """Get channel-wise spike statistics (only available if use_channel_wise=True).

        Returns None when the layer uses scalar TrainableTernarySpike modules.
        """
        if self.use_channel_wise:
            return {
                'k': self.k_spike.get_stats(),
                'v': self.v_spike.get_stats(),
            }
        return None


class GooseFFN(nn.Module):
    """Pre-LayerNorm, bias-free SiLU feed-forward block with a residual add.

    Computes ``x + W2(silu(W1(LN(x))))``, where W1 widens ``d_model`` by a
    factor of ``expand`` and W2 projects back to ``d_model``.
    """

    def __init__(self, d_model, expand=4):
        super().__init__()
        hidden_dim = d_model * expand
        # Attribute names (ln, w1, w2) and creation order are part of the
        # state_dict / seeded-init contract, so they stay fixed.
        self.ln = nn.LayerNorm(d_model)
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, d_model, bias=False)

    def forward(self, x):
        normed = self.ln(x)
        activated = F.silu(self.w1(normed))
        return x + self.w2(activated)


class StudentSpikingGoose(nn.Module):
    """
    Spiking student model with trainable ternary activations.

    Supports channel-wise ternary spikes + gradient checkpointing.

    Layout (as built in ``__init__``): token embedding + learned positional
    embedding, ``cfg.n_layers`` blocks of (SpikingGooseRecurrentLayer ->
    GooseFFN), a final LayerNorm, and an LM head whose weight is tied to the
    token embedding.
    """

    def __init__(self, cfg, use_checkpointing=True):
        super().__init__()
        self.cfg = cfg
        # Checkpointing requires both this flag and the global switch from cell 1.
        self.use_checkpointing = use_checkpointing and USE_GRADIENT_CHECKPOINTING
        
        # v14: Check for channel-wise spikes flag (configs without it default to False)
        use_channel_wise = getattr(cfg, 'use_channel_wise_spikes', False)
        
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_embed = nn.Embedding(cfg.max_seq_len, cfg.d_model)

        self.layers = nn.ModuleList([
            nn.ModuleDict({
                'rec': SpikingGooseRecurrentLayer(
                    cfg.d_model, i, cfg.n_layers, cfg.spike_alpha,
                    use_channel_wise=use_channel_wise,
                    threshold_mix=cfg.spike_threshold_mix,
                    surrogate_temp=cfg.spike_surrogate_temp,
                ),
                'ffn': GooseFFN(cfg.d_model),
            })
            for i in range(cfg.n_layers)
        ])

        self.ln_out = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: head and token embedding share one matrix, so the
        # normal_ init below initializes both.
        self.head.weight = self.embed.weight

        nn.init.normal_(self.embed.weight, std=0.02)
        nn.init.normal_(self.pos_embed.weight, std=0.02)

    def _layer_forward(self, layer, x):
        """helper for gradient checkpointing - processes one layer."""
        x = layer['rec'](x)
        x = layer['ffn'](x)
        return x

    def forward(self, input_ids, return_hiddens=False, return_spike_info=False, detach_spikes: bool = True):
        """Forward pass with optional hidden-state and spike-tensor returns.

        Args:
            input_ids: token ids indexed as (B, T); T must not exceed
                ``cfg.max_seq_len`` (positional embedding size).
            return_hiddens: also return ``[embeddings, layer_1_out, ...]``
                (``n_layers + 1`` entries).
            return_spike_info: also return per-layer spike dicts from the
                recurrent layers, keyed by layer index. Disables gradient
                checkpointing for this call.
            detach_spikes: forwarded to the recurrent layers; when False the
                returned spike tensors stay in the autograd graph.

        Returns:
            ``logits``, optionally followed by ``hiddens`` and/or
            ``{'spike_info': {layer_idx: dict}}`` in that order.
        """
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)
        x = self.embed(input_ids) + self.pos_embed(pos)

        hiddens = [x] if return_hiddens else None
        spike_info = {} if return_spike_info else None

        for layer_idx, layer in enumerate(self.layers):
            # Checkpoint path is tensor-only; skip it when spike tensors are requested.
            if self.use_checkpointing and self.training and not return_spike_info:
                x = checkpoint(self._layer_forward, layer, x, use_reentrant=False)
            else:
                if return_spike_info:
                    x, layer_spikes = layer['rec'](
                        x,
                        return_spikes=True,
                        detach_spikes=detach_spikes,
                    )
                    x = layer['ffn'](x)
                    spike_info[layer_idx] = layer_spikes
                else:
                    x = self._layer_forward(layer, x)

            if return_hiddens:
                hiddens.append(x)

        logits = self.head(self.ln_out(x))

        if return_hiddens and return_spike_info:
            return logits, hiddens, {'spike_info': spike_info}
        if return_hiddens:
            return logits, hiddens
        if return_spike_info:
            return logits, {'spike_info': spike_info}
        return logits

    def get_spike_stats(self):
        """Per-layer ``get_spike_density()`` dicts keyed ``layer_{i}``."""
        return {f'layer_{i}': layer['rec'].get_spike_density() for i, layer in enumerate(self.layers)}

    def get_avg_spike_density(self):
        """Mean of the running k and v spike densities over all layers (0.0 if none)."""
        densities = []
        for layer in self.layers:
            d = layer['rec'].get_spike_density()
            densities.extend([d['k'], d['v']])
        return float(np.mean(densities)) if densities else 0.0

    def get_amplitudes(self):
        """Per-layer ``{'k': amp, 'v': amp}`` from the spike modules' ``get_amplitude()``."""
        return {f'layer_{i}': {'k': layer['rec'].k_spike.get_amplitude(), 'v': layer['rec'].v_spike.get_amplitude()}
                for i, layer in enumerate(self.layers)}
    
    def get_channel_amplitude_variance(self) -> float:
        """Sum of per-module variances of channel-wise spike amplitudes.

        The k and v spike modules of each layer are checked independently. A
        module contributes ``amplitude.var()`` only if it has an ``amplitude``
        with more than one element. Scalar amplitudes have no variance; the
        previous code gated v on k's shape, so a scalar v amplitude could add
        a NaN.

        NOTE: values are read with ``.item()``, so the result is a detached
        Python float. It is suitable for logging and monitoring but carries no
        gradient, so it cannot be used as a differentiable regularizer as-is.
        """
        total_var = 0.0
        for layer in self.layers:
            rec = layer['rec']
            for spike in (rec.k_spike, rec.v_spike):
                amplitude = getattr(spike, 'amplitude', None)
                if amplitude is not None and amplitude.numel() > 1:
                    total_var += amplitude.var().item()
        return total_var

print("student model defined (v14: channel-wise spikes + gradient checkpointing)")
print(f"  gradient checkpointing: {USE_GRADIENT_CHECKPOINTING}")
# Read the flag exactly as StudentSpikingGoose does, so configs without it don't crash here.
print(f"  channel-wise spikes: {getattr(config, 'use_channel_wise_spikes', False)}")


In [None]:
# =============================================================================
# cell 9: hidden-state projector + v14.1 FDD layer mapping
# =============================================================================
class HiddenStateProjector(nn.Module):
    """
    Per-layer linear maps from the student hidden size to the teacher's.

    student: (B, T, 320) -> (B, T, 768)

    One bias-free ``nn.Linear`` is kept per student layer; ``forward`` selects
    it by index.

    NOTE: This is kept for infrastructure but hidden alignment is DISABLED.
    v14 uses FDD with CKA loss which is projector-free.
    """

    def __init__(self, student_dim: int, teacher_dim: int, n_student_layers: int):
        super().__init__()
        # Build every projection first and only then re-initialize them. This
        # two-phase order keeps RNG consumption (and thus seeded weights) stable.
        linears = []
        for _ in range(n_student_layers):
            linears.append(nn.Linear(student_dim, teacher_dim, bias=False))
        self.projectors = nn.ModuleList(linears)
        for linear in self.projectors:
            nn.init.normal_(linear.weight, std=0.02)

    def forward(self, student_hidden: torch.Tensor, layer_idx: int) -> torch.Tensor:
        selected = self.projectors[layer_idx]
        return selected(student_hidden)


def compute_hidden_alignment_loss(
    teacher_hiddens: List[torch.Tensor],
    student_hiddens: List[torch.Tensor],
    projector: 'HiddenStateProjector',
    teacher_layers: int = 12,
    student_layers: int = 8
) -> torch.Tensor:
    """
    Compute mean MSE between projected student and teacher hidden states.

    Student layer ``s`` has its output at ``student_hiddens[s + 1]`` (index 0 is
    the embedding output). It is projected with ``projector(h, s)`` and compared
    to ``teacher_hiddens[t]``, where

        t = max(1, floor((s + 1) * teacher_layers / student_layers - 0.5))

    For the default 12-teacher / 8-student setting this reproduces the previous
    fixed mapping ``[1, 2, 4, 5, 7, 8, 10, 11]``. Other layer counts now get a
    proportional mapping; the old code ignored ``teacher_layers`` and
    ``student_layers``.

    Pairs are consumed in order until either hidden list runs out. The result
    is averaged over the pairs actually used; the old code always divided by 8,
    which under-weighted the loss when fewer pairs were available.

    NOTE: This is DISABLED in v14 (hidden_align_weight=0.0).
    v14 uses FDD with CKA loss instead.

    Returns:
        Scalar loss tensor, or the float ``0.0`` when no pair can be formed.
    """
    if student_layers <= 0:
        return 0.0

    # Integer form of floor((s+1)*T/S - 0.5), avoiding float rounding issues.
    teacher_indices = [
        max(1, ((2 * s + 2) * teacher_layers - student_layers) // (2 * student_layers))
        for s in range(student_layers)
    ]

    total_loss = 0.0
    n_pairs = 0
    for s_idx, t_idx in enumerate(teacher_indices):
        if s_idx >= len(student_hiddens) - 1:
            break
        if t_idx >= len(teacher_hiddens):
            break

        s_proj = projector(student_hiddens[s_idx + 1], s_idx)
        total_loss = total_loss + F.mse_loss(s_proj, teacher_hiddens[t_idx])
        n_pairs += 1

    if n_pairs == 0:
        return total_loss  # 0.0: nothing to align
    return total_loss / n_pairs


# Create projector (even if disabled, keeps infrastructure).
# NOTE: cell 14 builds a fresh `projector` again; this instance mainly feeds the summary below.
projector = HiddenStateProjector(
    student_dim=config.d_model,
    teacher_dim=config.teacher_d_model,
    n_student_layers=config.n_layers
).to(DEVICE)

projector_params = sum(p.numel() for p in projector.parameters())
print(f"hidden-state projector: {projector_params:,} params")
print(f"  student dim: {config.d_model}")
print(f"  teacher dim: {config.teacher_d_model}")
print(f"  student layers: {config.n_layers}")
print(f"  hidden_align_weight: {config.hidden_align_weight}")
# Derive the status from the config instead of always printing "DISABLED".
if config.hidden_align_weight > 0:
    print(f"  STATUS: ENABLED (hidden_align_weight={config.hidden_align_weight})")
else:
    print("  STATUS: DISABLED (v14 uses FDD with CKA instead)")

# =============================================================================
# v14: Create FDD layer mapping
# =============================================================================
fdd_layer_map = get_fdd_layer_mapping(
    n_student_layers=config.n_layers,
    n_teacher_layers=config.teacher_n_layers,
    n_align_layers=config.fdd_n_align_layers
)

print()
print(f"{config.VERSION} FDD Layer Mapping:")
print(f"  Layer pairs to align: {config.fdd_n_align_layers}")
print(f"  Mapping: {fdd_layer_map}")
print("  Strategy: Align early/middle/late semantic layers")


In [None]:
# =============================================================================
# cell 10: cosine lr with warmup (same as v9)
# =============================================================================
def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    warmup_steps: int,
    total_steps: int,
) -> torch.optim.lr_scheduler.LambdaLR:
    """
    Linear warmup then cosine decay to 0.

    The LR multiplier rises linearly from 0 to 1 over ``warmup_steps``. It then
    follows a half cosine from 1 down to 0 at ``total_steps``. Past
    ``total_steps`` it stays at 0; without the clamp, the cosine would swing
    back up towards 1.
    """
    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / max(warmup_steps, 1)
        progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
        progress = min(progress, 1.0)  # hold at 0 after the schedule ends
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


print("cosine lr: {} warmup, {} total".format(config.warmup_steps, config.distill_steps))


In [None]:
# =============================================================================
# cell 11: load gpt-2 teacher (same as v9)
# =============================================================================
from transformers import GPT2LMHeadModel, GPT2Tokenizer

print("loading gpt-2 teacher...")
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# GPT-2 ships without a pad token; reuse EOS.
tokenizer.pad_token = tokenizer.eos_token

teacher = GPT2LMHeadModel.from_pretrained('gpt2').to(DEVICE)
teacher.config.use_cache = False  # Disable KV caching (not needed for distillation)

# The teacher is inference-only: eval mode and frozen weights.
teacher.eval()
for p in teacher.parameters():
    p.requires_grad = False

# Honor the global compile switch from cell 1. Previously the teacher was compiled
# even when USE_TORCH_COMPILE=False, contradicting the "torch.compile: disabled"
# status printed at startup.
# NOTE: torch.compile is lazy, so compilation errors surface on the first forward
# call, not inside this try block.
if USE_TORCH_COMPILE and hasattr(torch, 'compile'):
    try:
        teacher = torch.compile(teacher, mode='reduce-overhead')
        print('teacher compiled with torch.compile')
    except Exception as e:
        print(f'torch.compile not available: {e}')
else:
    print(f"teacher torch.compile skipped (USE_TORCH_COMPILE={USE_TORCH_COMPILE})")

teacher_params = sum(p.numel() for p in teacher.parameters())
print(f"teacher: gpt-2 ({teacher_params:,} params)")


In [None]:
# =============================================================================
# cell 13: data loading (v14 - efficient DataLoader)
# =============================================================================
from datasets import load_dataset

print("loading wikitext-2...")
# Raw (untokenized) WikiText-2; exposes 'train' / 'validation' / 'test' splits.
dataset = load_dataset(path='wikitext', name='wikitext-2-raw-v1')

def pre_tokenize(texts, max_len, stride=None):
    """Tokenize ``texts`` into one token stream and cut it into fixed-length windows.

    Every non-blank text is encoded with the global GPT-2 ``tokenizer`` and
    appended to a single stream. The stream is sliced into windows of exactly
    ``max_len`` tokens, one starting every ``stride`` tokens (default
    ``max_len // 2``, i.e. 50% overlap). A trailing partial window is dropped.

    Args:
        texts: iterable of raw strings.
        max_len: window length in tokens (must be >= 1).
        stride: step between window starts. Defaults to ``max_len // 2`` and
            is clamped to at least 1, so ``max_len == 1`` no longer yields a
            zero-step ``range``.

    Returns:
        LongTensor of shape (n_windows, max_len). The shape is (0, max_len)
        when the stream is shorter than one window.
    """
    if max_len < 1:
        raise ValueError(f"max_len must be >= 1, got {max_len}")
    if stride is None:
        stride = max_len // 2
    stride = max(1, stride)

    all_tokens = []
    for text in tqdm(texts, desc="tokenizing", leave=False):
        if text.strip():
            # Encode the full line. The previous per-line truncation
            # (max_length=2*max_len) silently dropped the tail of long paragraphs
            # before chunking. Lines longer than the model context may trigger a
            # one-time tokenizer length warning; that is harmless because the
            # stream is chunked below.
            all_tokens.extend(tokenizer.encode(text))

    # The range bound guarantees every slice is exactly max_len long.
    chunks = [
        all_tokens[start:start + max_len]
        for start in range(0, len(all_tokens) - max_len + 1, stride)
    ]
    print(f"created {len(chunks)} sequences")
    if not chunks:
        return torch.empty((0, max_len), dtype=torch.long)
    return torch.tensor(chunks, dtype=torch.long)

train_tokens = pre_tokenize(dataset['train']['text'], config.max_seq_len)
val_tokens = pre_tokenize(dataset['validation']['text'], config.max_seq_len)

# v14: DataLoader setup.
# Worker processes are disabled on Kaggle/Colab (num_workers=0) for compatibility.
# With no workers there is no background prefetching, so prefetch_factor must be
# None and persistent_workers False; DataLoader rejects other values in that case.
# pin_memory is only requested when CUDA exists (otherwise it only emits a warning).
_dl_use_workers = not (IS_KAGGLE or IS_COLAB)
dataloader_kwargs = {
    'batch_size': config.batch_size,
    'pin_memory': torch.cuda.is_available(),
    'num_workers': 2 if _dl_use_workers else 0,
    'prefetch_factor': 2 if _dl_use_workers else None,
    'persistent_workers': _dl_use_workers,
}

train_loader = DataLoader(TensorDataset(train_tokens), shuffle=True, **dataloader_kwargs)
val_loader = DataLoader(TensorDataset(val_tokens), shuffle=False, **dataloader_kwargs)

print(f"train: {len(train_loader)} batches, val: {len(val_loader)} batches")
print(f"DataLoader: num_workers={dataloader_kwargs['num_workers']}, pin_memory={dataloader_kwargs['pin_memory']}")
for _loader_name, _loader in (('train_loader', train_loader), ('val_loader', val_loader)):
    if len(_loader) == 0:
        raise RuntimeError(
            f"{_loader_name} is empty after tokenization. "
            "Check dataset availability and max_seq_len/chunking settings."
        )


In [None]:
# =============================================================================
# cell 14: create student model and projector (v14 - with compile)
# =============================================================================
print("creating student model (v14 - v14 baseline + POCL)...")

student = StudentSpikingGoose(config, use_checkpointing=USE_GRADIENT_CHECKPOINTING).to(DEVICE)
student_params = sum(param.numel() for param in student.parameters())

# v14: a fresh projector is built here even when hidden alignment is off, so the
# training cell always receives a projector object (infrastructure preservation).
projector = HiddenStateProjector(
    student_dim=config.d_model,
    teacher_dim=config.teacher_d_model,
    n_student_layers=config.n_layers,
).to(DEVICE)
projector_params = sum(param.numel() for param in projector.parameters())

compression_ratio = teacher_params / student_params

# One print of the joined summary emits the same text as line-by-line prints;
# the trailing "" yields the blank separator line.
print("\n".join([
    f"student: asnn-goose v14 ({student_params:,} params)",
    f"projector: ({projector_params:,} params)",
    f"compression ratio: {compression_ratio:.1f}x",
    "",
    f"{config.VERSION} architecture:",
    f"  d_model: {config.d_model}",
    f"  n_layers: {config.n_layers}",
    f"  params: ~{student_params // 1_000_000}M",
    "",
]))

# v14: compile the student only when the global switch and the runtime both allow it.
compile_success = False
if not (USE_TORCH_COMPILE and TORCH_COMPILE_AVAILABLE):
    print(f"torch.compile skipped (USE_TORCH_COMPILE={USE_TORCH_COMPILE}, available={TORCH_COMPILE_AVAILABLE})")
else:
    try:
        print("compiling student model with torch.compile...")
        # Use the compile() method as recommended by PyTorch docs
        student = torch.compile(student, mode='reduce-overhead')
        compile_success = True
        print("compilation successful!")
    except Exception as e:
        print(f"torch.compile failed: {e}")
        print("continuing without compilation")

print("\n".join([
    "",
    "speedups active:",
    f"  gradient checkpointing: {USE_GRADIENT_CHECKPOINTING}",
    f"  torch.compile: {compile_success}",
    f"  accumulation_steps: {config.accumulation_steps}",
]))


In [None]:
# =============================================================================
# cell 15: evaluation functions (same as v9)
# =============================================================================
@torch.no_grad()
def evaluate(model, loader, device, is_gpt2=False):
    """
    Mean next-token cross-entropy (nats per predicted token) of ``model`` on ``loader``.

    The first element of each batch is used as input ids. Logits at position t
    are scored against token t+1, so each sequence contributes T-1 predictions.
    Forward passes run under CUDA autocast when ``device`` is a CUDA device. The
    summed loss is always computed on float32 logits: with half-precision logits
    the ``reduction='sum'`` result is itself fp16 and can overflow (max 65504)
    for large batch x sequence sizes.

    Args:
        model: callable that returns logits directly or, with ``is_gpt2=True``,
            returns an object with a ``.logits`` attribute (e.g. an HF GPT-2 output).
        loader: iterable of batches whose first element is an id tensor indexed (B, T).
        device: device the ids are moved to.
        is_gpt2: read ``.logits`` from the model output.

    Returns:
        Mean loss as a Python float (pass it to ``get_ppl`` for perplexity).

    Raises:
        RuntimeError: if the loader yields no predicted tokens.

    Note: leaves ``model`` in eval mode. Callers that keep training must call
    ``model.train()`` again; the distillation loop does so every step.
    """
    model.eval()
    # Autocast only on CUDA devices. On CPU this avoids the deprecated
    # torch.cuda.amp.autocast warning and keeps full precision.
    use_cuda_amp = torch.device(device).type == 'cuda'
    total_loss, total_tokens = 0.0, 0
    with torch.inference_mode():
        for batch in loader:
            ids = batch[0].to(device, non_blocking=True)
            with torch.autocast(device_type='cuda', enabled=use_cuda_amp):
                out = model(ids)
            logits = (out.logits if is_gpt2 else out).float()
            loss = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                ids[:, 1:].reshape(-1),
                reduction='sum',
            )
            total_loss += loss.item()
            total_tokens += ids[:, 1:].numel()
    if total_tokens == 0:
        raise RuntimeError(
            "Evaluation loader produced zero tokens; cannot compute loss/PPL."
        )
    return total_loss / total_tokens

def get_ppl(loss, max_loss=10.0):
    """Convert a mean cross-entropy loss (nats/token) into perplexity.

    The loss is clamped to ``max_loss`` before exponentiating, so the result
    never exceeds ``exp(max_loss)`` (about 22026 for the default 10.0). This
    guards ``math.exp`` against overflow on diverged runs, but it also means
    every sufficiently bad model reports the same capped PPL. A NaN loss
    propagates as NaN.
    """
    return math.exp(min(loss, max_loss))

print("evaluation functions defined")


In [None]:
# =============================================================================
# cell 17: distillation training loop (v14.1.1 - FDD+CTKD+HardCE)
# =============================================================================
def get_spike_semantic_weight(step: int, warmup_steps: int, max_weight: float) -> float:
    """Weight schedule for the spike-semantic loss.

    The weight is 0 before ``warmup_steps``. It then ramps linearly to
    ``max_weight`` over a further ``warmup_steps`` steps (at least 1) and
    stays there.
    """
    if step >= warmup_steps:
        elapsed = step - warmup_steps
        fraction = min(1.0, elapsed / max(warmup_steps, 1))
        return max_weight * fraction
    return 0.0


def build_teacher_ternary_target(teacher_hidden: torch.Tensor, threshold_scale: float) -> torch.Tensor:
    """Ternarize teacher hidden states into {-1, 0, +1} targets along the last dim.

    Each vector is mean-centered. Entries above ``threshold_scale * mean(|centered|)``
    become +1 and entries below its negative become -1; everything else is 0.
    If both conditions hold (only possible with a negative ``threshold_scale``),
    -1 wins. The output has the dtype and shape of ``teacher_hidden``.
    """
    centered = teacher_hidden - teacher_hidden.mean(dim=-1, keepdim=True)
    cutoff = threshold_scale * centered.abs().mean(dim=-1, keepdim=True)
    ones = torch.ones_like(centered)
    zeros = torch.zeros_like(centered)
    # The -1 branch is applied last in effect, so it takes precedence over +1.
    return torch.where(
        centered < -cutoff,
        -ones,
        torch.where(centered > cutoff, ones, zeros),
    )


def distill_v14(teacher, student, projector, train_loader, val_loader, cfg, device,
                hw_stats, spike_stats, fdd_layer_map):
    """
    v14 distillation with FDD (Feature Dynamics Distillation) + CTKD.

    Key innovations:
    1. FDD: Align layer dynamics (delta_h) using CKA loss
    2. CTKD: Adversarial temperature learning (proven in v12.1, v13.1)
    3. Safety: FDD kill-switch if PPL regresses

    References:
    - CKA: Kornblith et al., "Similarity of Neural Network Representations"
    - CTKD: https://arxiv.org/abs/2211.16231
    - FDD: Feature Dynamics Distillation (view transformer as ODE)
    """
    training_logs = {
        'loss_history': [],
        'kl_loss_history': [],
        'ce_loss_history': [],  # v14.1: hard distillation
        'fdd_loss_history': [],
        'spike_sem_loss_history': [],
        'align_loss_history': [],
        'ppl_history': [],
        'lr_history': [],
        'temp_history': [],
        'lambda_history': [],
        'fdd_weight_history': [],
        'spike_sem_weight_history': [],
        'stage_history': [],
        'stage_transitions': [],
        'early_stopped': False,
        'early_stop_step': None,
        'fdd_killed': False,
        'fdd_kill_step': None,
    }

    # =========================================================================
    # v14: CTKD Temperature (same as v12.1, v13.1)
    # =========================================================================
    if cfg.use_ctkd:
        temp_module = CTKDTemperature(
            tau_min=cfg.tau_min,
            tau_max=cfg.tau_max,
            init=cfg.tau_init
        ).to(device)
        print(f"{config.VERSION}: CTKD with Gradient Reversal Layer")
        print(f"     Temperature bounds: [{cfg.tau_min}, {cfg.tau_max}]")
        print(f"     Initial temp: {cfg.tau_init}")
        print(f"     Lambda warmup: {cfg.lambda_warmup_ratio*100:.0f}%")
    else:
        temp_module = None
        print(f"Using fixed temperature: {cfg.temperature}")

    # =========================================================================
    # v14: FDD Setup
    # =========================================================================
    fdd_enabled = cfg.use_fdd
    fdd_killed = False
    baseline_ppl = None  # Set at fdd_warmup_steps

    if cfg.use_fdd:
        print(f"")
        print(f"{config.VERSION}: Feature Dynamics Distillation (FDD)")
        print(f"     Layer mapping: {fdd_layer_map}")
        print(f"     Weight: {cfg.fdd_weight}")
        print(f"     Warmup: {cfg.fdd_warmup_steps} steps")
        print(f"     Loss type: {cfg.fdd_loss_type}")
        print(f"     Kill threshold: {cfg.fdd_kill_threshold*100:.0f}% PPL increase")

    if cfg.use_spike_semantic_loss:
        print(f"")
        print(f"{config.VERSION}: Spike Semantic Alignment")
        print(f"     Weight: {cfg.spike_semantic_weight}")
        print(f"     Warmup: {cfg.spike_semantic_warmup_steps} steps")
        print(f"     Target threshold scale: {cfg.spike_target_threshold_scale}")

    # =========================================================================
    # v14: Early Stopping Setup
    # =========================================================================
    best_ppl = float('inf')
    best_step = 0
    no_improve_steps = 0

    if cfg.use_early_stopping:
        print(f"")
        print(f"{config.VERSION}: Early Stopping")
        print(f"     Patience: {cfg.early_stopping_patience} steps")
        print(f"     Min delta: {cfg.min_ppl_delta} PPL")

    # =========================================================================
    # Setup optimizer
    # =========================================================================
    param_groups = [
        {'params': list(student.parameters()), 'lr': cfg.distill_lr}
    ]

    if cfg.hidden_align_weight > 0:
        param_groups.append({'params': list(projector.parameters()), 'lr': cfg.distill_lr})

    if temp_module is not None:
        param_groups.append({'params': list(temp_module.parameters()), 'lr': cfg.distill_lr})

    all_params = []
    for group in param_groups:
        all_params.extend(group['params'])

    try:
        optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01, fused=True)
        print("Using fused AdamW")
    except TypeError:
        optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)

    scheduler = get_cosine_schedule_with_warmup(optimizer, cfg.warmup_steps, cfg.distill_steps)
    scaler = torch.cuda.amp.GradScaler()

    hw_stats.start()
    step = 0
    accum_step = 0
    current_stage = 1

    accumulation_steps = cfg.accumulation_steps
    effective_batch = cfg.batch_size * accumulation_steps
    print(f"Gradient accumulation: {accumulation_steps} (effective batch = {effective_batch})")
    print(f"Extended training: {cfg.distill_steps} steps")

    pbar = tqdm(total=cfg.distill_steps, desc='distilling (v14.1 - FDD+CTKD+HardCE)')

    optimizer.zero_grad(set_to_none=True)

    if len(train_loader) == 0:
        pbar.close()
        raise RuntimeError("train_loader is empty; aborting distillation loop.")
    if len(val_loader) == 0:
        pbar.close()
        raise RuntimeError("val_loader is empty; aborting distillation loop.")

    while step < cfg.distill_steps:
        for batch in train_loader:
            if step >= cfg.distill_steps:
                break

            # Check early stopping
            if cfg.use_early_stopping and no_improve_steps >= cfg.early_stopping_patience:
                print(f"\n  [Early Stopping] No improvement for {cfg.early_stopping_patience} steps")
                print(f"     Best PPL: {best_ppl:.2f} at step {best_step}")
                training_logs['early_stopped'] = True
                training_logs['early_stop_step'] = step
                pbar.close()
                return training_logs

            ids = batch[0].to(device, non_blocking=True)

            # Get lambda for CTKD
            if cfg.use_ctkd:
                current_lambda = get_lambda(
                    step, cfg.distill_steps,
                    lambda_max=cfg.lambda_max,
                    warmup_ratio=cfg.lambda_warmup_ratio
                )
            else:
                current_lambda = 0.0

            # Get current FDD weight
            if fdd_enabled and not fdd_killed:
                current_fdd_weight = get_fdd_weight(step, cfg.fdd_warmup_steps, cfg.fdd_weight)
            else:
                current_fdd_weight = 0.0

            if cfg.use_spike_semantic_loss:
                current_spike_sem_weight = get_spike_semantic_weight(
                    step,
                    cfg.spike_semantic_warmup_steps,
                    cfg.spike_semantic_weight,
                )
            else:
                current_spike_sem_weight = 0.0

            with torch.cuda.amp.autocast():
                # Teacher forward (always get hidden states for FDD)
                with torch.no_grad():
                    t_out = teacher(ids, output_hidden_states=True)
                    t_logits = t_out.logits
                    t_hiddens = t_out.hidden_states  # tuple of tensors

                # Student forward (always get hidden states for FDD)
                student.train()
                s_logits, s_hiddens, spike_aux = student(
                    ids,
                    return_hiddens=True,
                    return_spike_info=True,
                    detach_spikes=False,
                )
                spike_info = spike_aux.get('spike_info', {}) if isinstance(spike_aux, dict) else {}

                # Get temperature
                if cfg.use_ctkd and temp_module is not None:
                    T = temp_module(current_lambda)
                elif temp_module is not None:
                    T = temp_module()
                else:
                    T = cfg.temperature

                # KL divergence loss with temperature
                s_log = F.log_softmax(s_logits / T, dim=-1)
                t_prob = F.softmax(t_logits / T, dim=-1)
                kl_loss = F.kl_div(
                    s_log.view(-1, s_logits.size(-1)),
                    t_prob.view(-1, t_logits.size(-1)),
                    reduction='batchmean'
                ) * (T ** 2)

                # FDD loss (v14.1 with 100x weight increase)
                if current_fdd_weight > 0:
                    fdd_loss = compute_fdd_loss(
                        s_hiddens,
                        list(t_hiddens),  # Convert tuple to list
                        fdd_layer_map,
                        loss_type=cfg.fdd_loss_type
                    )
                else:
                    fdd_loss = torch.tensor(0.0, device=device)

                if current_spike_sem_weight > 0 and spike_info:
                    sem_losses = []
                    for s_layer, t_layer in fdd_layer_map.items():
                        layer_spikes = spike_info.get(s_layer)
                        if not isinstance(layer_spikes, dict):
                            continue

                        k_spikes = layer_spikes.get('k_spikes')
                        v_spikes = layer_spikes.get('v_spikes')
                        if k_spikes is None or v_spikes is None:
                            continue

                        spike_repr = 0.5 * (k_spikes + v_spikes)
                        teacher_hidden = t_hiddens[t_layer + 1]

                        if spike_repr.size(-1) != teacher_hidden.size(-1):
                            min_dim = min(spike_repr.size(-1), teacher_hidden.size(-1))
                            spike_repr = spike_repr[..., :min_dim]
                            teacher_hidden = teacher_hidden[..., :min_dim]

                        teacher_target = build_teacher_ternary_target(
                            teacher_hidden,
                            cfg.spike_target_threshold_scale,
                        )
                        sem_losses.append(F.mse_loss(spike_repr, teacher_target))

                    if sem_losses:
                        spike_sem_loss = torch.stack(sem_losses).mean()
                    else:
                        spike_sem_loss = torch.tensor(0.0, device=device)
                else:
                    spike_sem_loss = torch.tensor(0.0, device=device)

                # v14.1: Hard distillation (CE with ground truth)
                if cfg.ce_hard_weight > 0:
                    shift_logits = s_logits[:, :-1, :].contiguous()
                    shift_labels = ids[:, 1:].contiguous()
                    ce_loss = F.cross_entropy(
                        shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1),
                        ignore_index=-100
                    )
                else:
                    ce_loss = torch.tensor(0.0, device=device)

                # Hidden alignment (usually disabled, kept for infrastructure)
                if cfg.hidden_align_weight > 0:
                    align_loss = compute_hidden_alignment_loss(
                        t_hiddens, s_hiddens, projector,
                        teacher_layers=cfg.teacher_n_layers,
                        student_layers=cfg.n_layers
                    )
                else:
                    align_loss = torch.tensor(0.0, device=device)

                # Total loss (v14.1: added ce_hard_weight * ce_loss)
                loss = (
                    kl_loss
                    + cfg.ce_hard_weight * ce_loss
                    + current_fdd_weight * fdd_loss
                    + current_spike_sem_weight * spike_sem_loss
                    + cfg.hidden_align_weight * align_loss
                )
                loss = loss / accumulation_steps

            scaler.scale(loss).backward()
            accum_step += 1

            if accum_step % accumulation_steps == 0:
                scaler.unscale_(optimizer)
                gn = torch.nn.utils.clip_grad_norm_(all_params, cfg.max_grad_norm)

                if torch.isfinite(gn):
                    scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)

                hw_stats.record_step(ids.size(0) * accumulation_steps, ids.size(1))
                spike_stats.record(student, step)

                current_lr = optimizer.param_groups[0]['lr']
                current_temp = temp_module.get_temperature() if temp_module is not None else cfg.temperature

                # Log
                training_logs['loss_history'].append({'step': step, 'loss': loss.item() * accumulation_steps})
                training_logs['kl_loss_history'].append({'step': step, 'loss': kl_loss.item()})
                training_logs['ce_loss_history'].append({'step': step, 'loss': ce_loss.item() if isinstance(ce_loss, torch.Tensor) else ce_loss})  # v14.1
                training_logs['fdd_loss_history'].append({'step': step, 'loss': fdd_loss.item() if isinstance(fdd_loss, torch.Tensor) else fdd_loss})
                training_logs['spike_sem_loss_history'].append({'step': step, 'loss': spike_sem_loss.item() if isinstance(spike_sem_loss, torch.Tensor) else spike_sem_loss})
                training_logs['align_loss_history'].append({'step': step, 'loss': align_loss.item() if isinstance(align_loss, torch.Tensor) else align_loss})
                training_logs['lr_history'].append({'step': step, 'lr': current_lr})
                training_logs['temp_history'].append({'step': step, 'temperature': current_temp})
                training_logs['lambda_history'].append({'step': step, 'lambda': current_lambda})
                training_logs['fdd_weight_history'].append({'step': step, 'fdd_weight': current_fdd_weight})
                training_logs['spike_sem_weight_history'].append({'step': step, 'weight': current_spike_sem_weight})
                training_logs['stage_history'].append({'step': step, 'stage': 1})

                # Update progress bar (v14.1: added CE loss)
                fdd_str = f"fdd={fdd_loss.item():.3f}" if current_fdd_weight > 0 else "fdd=of"
                ce_str = f"ce={ce_loss.item():.3f}" if cfg.ce_hard_weight > 0 else "ce=of"
                sem_str = f"sem={spike_sem_loss.item():.3f}" if current_spike_sem_weight > 0 else "sem=of"
                pbar.set_postfix(
                    loss=f"{loss.item() * accumulation_steps:.3f}",
                    kl=f"{kl_loss.item():.3f}",
                    ce=ce_str,
                    sem=sem_str,
                    fdd=fdd_str,
                    T=f"{current_temp:.2f}",
                    lr=f"{current_lr:.1e}"
                )
                pbar.update(1)
                step += 1

                if step % cfg.eval_interval == 0:
                    val_loss = evaluate(student, val_loader, device)
                    val_ppl = get_ppl(val_loss)
                    training_logs['ppl_history'].append({'step': step, 'ppl': val_ppl})

                    amps = student.get_amplitudes()
                    amp_str = ', '.join([f"L{i}:{amps[f'layer_{i}']['k']:.2f}" for i in range(min(4, cfg.n_layers))])

                    # =========================================================
                    # v14.1: FDD Kill Switch
                    # =========================================================
                    if step == cfg.fdd_warmup_steps and fdd_enabled:
                        baseline_ppl = val_ppl
                        print(f"\n  [FDD] Baseline PPL at warmup end: {baseline_ppl:.2f}")

                    if fdd_enabled and not fdd_killed and baseline_ppl is not None:
                        ppl_increase = (val_ppl - baseline_ppl) / baseline_ppl
                        if ppl_increase > cfg.fdd_kill_threshold:
                            fdd_killed = True
                            training_logs['fdd_killed'] = True
                            training_logs['fdd_kill_step'] = step
                            print(f"\n  [FDD KILLED] PPL increased {ppl_increase*100:.1f}% > {cfg.fdd_kill_threshold*100:.0f}%")
                            print(f"     Baseline: {baseline_ppl:.2f}, Current: {val_ppl:.2f}")
                            print(f"     Disabling FDD for remaining training")

                    # Early stopping check
                    if val_ppl < best_ppl - cfg.min_ppl_delta:
                        best_ppl = val_ppl
                        best_step = step
                        no_improve_steps = 0
                        save_dict = {
                            'student': student.state_dict(),
                            'projector': projector.state_dict(),
                            'step': step,
                            'ppl': val_ppl,
                        }
                        if temp_module is not None:
                            save_dict['temp_module'] = temp_module.state_dict()
                        torch.save(save_dict, f'{OUTPUT_DIR}/checkpoints/v15_best.pt')
                        improve_str = " [NEW BEST]"
                    else:
                        no_improve_steps += cfg.eval_interval
                        improve_str = f" (no improve: {no_improve_steps}/{cfg.early_stopping_patience})"

                    lambda_str = f", lambda={current_lambda:.2f}" if cfg.use_ctkd else ""
                    fdd_status = "KILLED" if fdd_killed else f"w={current_fdd_weight:.4f}"
                    sem_status = f"w={current_spike_sem_weight:.4f}" if current_spike_sem_weight > 0 else "off"
                    print(
                        f"\n  step {step}: ppl={val_ppl:.1f}, T={current_temp:.2f}{lambda_str}, "
                        f"FDD:{fdd_status}, SEM:{sem_status}, amps=[{amp_str}...]{improve_str}"
                    )

    pbar.close()
    return training_logs

# Signals that this cell finished executing; the training function defined in
# this cell is presumably the `distill_v14` invoked by the run cell below.
print("distillation function defined (v14.1.1 - FDD+CTKD+HardCE)")


In [None]:
# =============================================================================
# cell 18: run distillation (v14.1.1 - FDD+CTKD+HardCE)
# =============================================================================
# Prints the run configuration, runs distill_v14 (FDD + CTKD + hard CE), then
# summarizes the returned training logs. Banners are labelled with
# config.VERSION so the output names the version that actually ran, and the
# parameter count is measured from `student` instead of being hard-coded.
print("="*60)
print(f"{config.VERSION}: FDD (Feature Dynamics Distillation) + CTKD")
print("="*60)
student_param_count = sum(p.numel() for p in student.parameters())
print(f"  Architecture: {config.d_model}d x {config.n_layers}L (~{student_param_count / 1e6:.1f}M params)")
print(f"  Target: PPL < 400 (improve on v13.1's 434.44)")
print()
print(f"{config.VERSION} Configuration:")
print(f"  FDD: {config.use_fdd}")
print(f"    Weight: {config.fdd_weight}")
print(f"    Warmup: {config.fdd_warmup_steps} steps")
print(f"    Layer map: {fdd_layer_map}")
print(f"    Loss type: {config.fdd_loss_type}")
print(f"    Kill threshold: {config.fdd_kill_threshold*100:.0f}%")
print(f"  Spike semantic alignment: {config.use_spike_semantic_loss}")
if config.use_spike_semantic_loss:
    print(f"    Weight: {config.spike_semantic_weight}")
    print(f"    Warmup: {config.spike_semantic_warmup_steps}")
    print(f"    Target threshold scale: {config.spike_target_threshold_scale}")
print(f"  CTKD: {config.use_ctkd} (proven from v12.1, v13.1)")
print(f"  Extended training: {config.distill_steps} steps")
print(f"  Early stopping: patience={config.early_stopping_patience}")
print(f"  POCL: {config.use_pocl} (disabled - caused regression)")
print("")

# Instantiate collectors; both are handed to distill_v14 and summarized by
# later cells (get_summary()).
hw_stats = HardwareStatsCollector()
spike_stats = SpikeStatsCollector(config.n_layers)
print("Initialized HardwareStatsCollector and SpikeStatsCollector")

# Run distillation (FDD + CTKD)
print(f"\nStarting distillation...")

distill_logs = distill_v14(
    teacher, student, projector,
    train_loader, val_loader,
    config, DEVICE,
    hw_stats, spike_stats,
    fdd_layer_map,  # v14: pass FDD layer mapping
)

# Report results
print(f"\n\n" + "="*60)
print(f"{config.VERSION} Distillation Complete!")
print("="*60)

if distill_logs['ppl_history']:
    final_ppl = distill_logs['ppl_history'][-1]['ppl']
    best_ppl_entry = min(distill_logs['ppl_history'], key=lambda x: x['ppl'])
    print(f"\nFinal PPL: {final_ppl:.2f}")
    print(f"Best PPL: {best_ppl_entry['ppl']:.2f} at step {best_ppl_entry['step']}")

if distill_logs['early_stopped']:
    print(f"\nEarly stopped at step {distill_logs['early_stop_step']}")
else:
    print(f"\nCompleted all {config.distill_steps} steps")

if distill_logs['fdd_killed']:
    print(f"\nFDD was KILLED at step {distill_logs['fdd_kill_step']} (PPL regressed)")
else:
    print(f"\nFDD remained active throughout training")

if distill_logs['temp_history']:
    temps = [h['temperature'] for h in distill_logs['temp_history']]
    print(f"\nTemperature evolution:")
    print(f"  Start: {temps[0]:.2f}")
    print(f"  End: {temps[-1]:.2f}")

if distill_logs['lambda_history']:
    lambdas = [h['lambda'] for h in distill_logs['lambda_history']]
    print(f"\nLambda evolution:")
    print(f"  Start: {lambdas[0]:.2f}")
    print(f"  End: {lambdas[-1]:.2f}")

if distill_logs['fdd_loss_history']:
    # Zero-valued entries are skipped (steps where the FDD term contributed nothing).
    fdd_losses = [h['loss'] for h in distill_logs['fdd_loss_history'] if h['loss'] > 0]
    if fdd_losses:
        print(f"\nFDD loss evolution:")
        print(f"  Start (after warmup): {fdd_losses[0]:.4f}")
        print(f"  End: {fdd_losses[-1]:.4f}")

print(f"\n" + "="*60)


In [None]:
# =============================================================================
# cell 20: lora implementation (same as v9)
# =============================================================================
class LoRALinear(nn.Module):
    """Low-rank (LoRA) adapter that produces an additive update for a linear layer.

    Computes ``(x @ A^T @ B^T) * (alpha / rank)``, where ``A`` has shape
    ``(rank, in_features)`` and ``B`` has shape ``(out_features, rank)``.
    ``B`` starts at zero, so a fresh adapter contributes nothing until trained.
    """

    def __init__(self, in_features, out_features, rank=8, alpha=16.0):
        super().__init__()
        self.scaling = alpha / rank
        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        # A gets a random Kaiming-uniform init while B stays zero, so B @ A == 0
        # initially and wrapping a layer leaves its output unchanged.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

    def forward(self, x):
        return (x @ self.lora_A.T @ self.lora_B.T) * self.scaling


def apply_lora(model, rank=8, alpha=16.0, targets=('key_proj', 'value_proj')):
    """Attach LoRA adapters to nn.Linear submodules whose qualified name contains a target.

    Each matched module's ``forward`` is replaced by ``base_forward(x) + lora(x)``.
    The adapters are NOT registered as submodules, so they are not part of
    ``model.parameters()`` / ``model.state_dict()``; train them through the
    returned dict.

    Calling this again on the same model replaces the previous adapter on each
    matched module instead of stacking a new one on top of it (re-running a
    notebook cell used to silently accumulate adapters).

    Args:
        model: module to patch in place.
        rank: LoRA rank.
        alpha: LoRA scaling numerator (effective scale is alpha / rank).
        targets: substrings matched against names from ``model.named_modules()``.
            Any iterable of strings works; the default is a tuple to avoid a
            shared mutable default.

    Returns:
        dict mapping qualified module name -> LoRALinear adapter.
    """
    lora_modules = {}
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear) or not any(t in name for t in targets):
            continue
        device = next(module.parameters()).device
        lora = LoRALinear(module.in_features, module.out_features, rank, alpha).to(device)
        lora_modules[name] = lora
        # Always wrap the original (unpatched) forward. A bound method is a plain
        # attribute for nn.Module, so storing it does not touch state_dict.
        base_forward = getattr(module, '_lora_base_forward', None)
        if base_forward is None:
            base_forward = module.forward
            module._lora_base_forward = base_forward

        # Factory binds the current base/adapter pair (avoids late binding in the loop).
        def make_forward(orig, lora_mod):
            def forward(x):
                return orig(x) + lora_mod(x)
            return forward

        module.forward = make_forward(base_forward, lora)
    print(f"lora: {len(lora_modules)} modules, rank={rank}")
    return lora_modules

print("lora defined")

In [None]:
# =============================================================================
# cell 21: ttt with lora (same as v9)
# =============================================================================
# Test-time training: freeze the distilled student, attach LoRA adapters and
# fine-tune only the adapters with next-token cross-entropy on val batches.
# NOTE(review): adaptation and the post-TTT evaluation both use val_loader, so
# post-TTT PPL is measured on data the adapters were trained on.
print("="*60)
print("phase 2: test-time training with lora")
print("="*60)

for p in student.parameters():
    p.requires_grad = False

lora_modules = apply_lora(student, config.lora_rank, config.lora_alpha)
lora_trainable = [p for m in lora_modules.values() for p in m.parameters()]
if not lora_trainable:
    raise RuntimeError("apply_lora matched no target modules; nothing to train for TTT")
lora_params = sum(p.numel() for p in lora_trainable)

pre_ttt_loss = evaluate(student, val_loader, DEVICE)
pre_ttt_ppl = get_ppl(pre_ttt_loss)
print(f"\npre-ttt ppl: {pre_ttt_ppl:.2f}")

lora_opt = torch.optim.AdamW(lora_trainable, lr=config.ttt_lr)
ttt_logs = {'loss_history': []}
student.train()

# torch.autocast replaces the deprecated torch.cuda.amp.autocast(); it is only
# enabled when running on CUDA, matching the old behaviour.
# NOTE(review): fp16 autocast is used without a GradScaler; small LoRA gradients
# may underflow — consider adding one if TTT loss stalls.
ttt_use_amp = torch.device(DEVICE).type == 'cuda'

for step, batch in enumerate(val_loader):
    if step >= config.ttt_steps:
        break
    ids = batch[0].to(DEVICE)
    with torch.autocast(device_type='cuda', enabled=ttt_use_amp):
        logits = student(ids)
        # Position t predicts token t+1: drop the last logit and the first target.
        loss = F.cross_entropy(logits[:, :-1].reshape(-1, config.vocab_size), ids[:, 1:].reshape(-1))
    lora_opt.zero_grad()
    loss.backward()
    lora_opt.step()
    loss_value = loss.item()
    ttt_logs['loss_history'].append({'step': step, 'loss': loss_value})
    if step % 20 == 0:
        print(f"  ttt {step}: loss={loss_value:.4f}")

post_ttt_loss = evaluate(student, val_loader, DEVICE)
post_ttt_ppl = get_ppl(post_ttt_loss)
print(f"\npost-ttt ppl: {post_ttt_ppl:.2f}")
print(f"ttt improvement: {pre_ttt_ppl - post_ttt_ppl:.1f} ppl")

In [None]:
# =============================================================================
# cell 22: final evaluation (v14.1.1 - FDD+CTKD+HardCE)
# =============================================================================
# Evaluate teacher and student on val_loader and compare against prior runs.
# `student` still carries the LoRA adapters trained in cell 21, so this is the
# post-TTT student.
print("="*60)
print("final evaluation (v14.1.1 - FDD+CTKD+HardCE)")
print("="*60)

teacher_loss = evaluate(teacher, val_loader, DEVICE, is_gpt2=True)
teacher_ppl = get_ppl(teacher_loss)
student_loss = evaluate(student, val_loader, DEVICE)
student_ppl = get_ppl(student_loss)

# v14: Get final temperature and lambda from CTKD (config defaults when no history was logged)
final_temp = distill_logs['temp_history'][-1]['temperature'] if distill_logs['temp_history'] else config.tau_init
final_lambda = distill_logs['temp_history'][-1].get('lambda', config.lambda_max) if distill_logs['temp_history'] else config.lambda_max

# VRAM logging (bytes / 1e9)
vram_peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0

print(f"")
print(f"{'model':<30} {'ppl':>10} {'params':>15}")
print("-" * 55)
print(f"{'gpt-2 (teacher)':<30} {teacher_ppl:>10.2f} {teacher_params:>15,}")
print(f"{'asnn-goose v14.1 (student)':<30} {student_ppl:>10.2f} {student_params:>15,}")
print("-" * 55)
print(f"{'compression':<30} {compression_ratio:>10.1f}x")
print(f"{'ppl gap':<30} {student_ppl - teacher_ppl:>10.2f}")
print(f"{'spike density':<30} {student.get_avg_spike_density():>10.3f}")
print(f"{'VRAM peak':<30} {vram_peak_gb:>10.2f}GB")
print(f"{'final temperature':<30} {final_temp:>10.2f}")
print(f"{'final lambda (GRL)':<30} {final_lambda:>10.3f}")
print("")
print("CTKD Implementation:")
print(f"  tau range: [{config.tau_min:.1f}, {config.tau_max:.1f}]")
print(f"  lambda warmup ratio: {config.lambda_warmup_ratio:.0%}")
print(f"  GRL: Gradient Reversal Layer for adversarial min-max")
print("")
print("version comparison:")
print(f"  v6: 627.3 PPL (baseline)")
print(f"  v7: 1655 PPL (regression!)")
print(f"  v8: 559 PPL (fixed)")
print(f"  v9: 541.7 PPL (capacity increase)")
print(f"  v10: 514.5 PPL (320d/5L baseline)")
print(f"  v14: 512.67 PPL (channel-wise, WITH reg)")
print(f"  v14.1: 512.04 PPL (channel-wise, NO reg)")
print(f"  v12: FAILED (temp runaway without GRL)")
# The current run is labelled with its own version (it was mislabelled "v14 ... POCL").
print(f"  {config.VERSION}: {student_ppl:.2f} PPL (post-TTT, T={final_temp:.2f}, λ={final_lambda:.3f})")

# The margin must be computed against the same reference the branch tests. The old
# code tested `< 512.04` but printed `424.81 - student_ppl`, giving a wrong (often
# negative) improvement.
# NOTE(review): cell 23 records v14 at 424.81 and v14.3 at 306.89; the reference
# numbers here look stale — confirm the intended baseline.
reference_version, reference_ppl = 'v14.1', 512.04
if student_ppl < 500:
    print(f"  {config.VERSION} TARGET MET! PPL < 500")
elif student_ppl < reference_ppl:
    print(f"  {config.VERSION} beats {reference_version} by {reference_ppl - student_ppl:.1f} PPL")
else:
    print(f"  WARNING: {config.VERSION} did not improve over {reference_version}")


In [None]:
# =============================================================================
# cell 23: visualization
# =============================================================================
# Builds a 2x3 summary figure (losses, validation PPL, LR schedule, spike
# density/amplitudes for the first 4 layers, TTT loss, cross-version PPL),
# saves it under OUTPUT_DIR/figures and keeps the path in `figure_path` for
# the results cell.
figure_path = None


def _loss_curve_xy(history, fallback_steps):
    """Return (steps, losses) for a list of {'loss': ...} dicts, or None if unalignable.

    Uses each entry's own 'step' when every entry has one; otherwise reuses
    `fallback_steps` only when the lengths match. Returning None lets the caller
    skip the curve instead of crashing in Axes.plot on mismatched lengths.
    """
    losses = [h['loss'] for h in history]
    if history and all('step' in h for h in history):
        return [h['step'] for h in history], losses
    if len(losses) == len(fallback_steps):
        return list(fallback_steps), losses
    return None


if MATPLOTLIB_AVAILABLE:
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    # Single value for the target line drawn on both PPL panels.
    # NOTE(review): cell 18 / test 1 use PPL < 400 and cell 22 uses < 500;
    # confirm which target this figure should show.
    plot_target_ppl = 300

    # distillation loss
    d_steps = [l['step'] for l in distill_logs['loss_history']]
    d_losses = [l['loss'] for l in distill_logs['loss_history']]
    axes[0,0].plot(d_steps, d_losses, label='total', alpha=0.8)
    kl_xy = _loss_curve_xy(distill_logs['kl_loss_history'], d_steps)
    if kl_xy is not None:
        axes[0,0].plot(*kl_xy, label='kl', alpha=0.7)
    else:
        print("kl loss history not aligned with loss_history; skipped in plot")

    # CE loss (if available)
    ce_history = distill_logs.get('ce_loss_history') or []
    if ce_history:
        ce_xy = _loss_curve_xy(ce_history, d_steps)
        if ce_xy is not None:
            axes[0,0].plot(*ce_xy, label='ce', alpha=0.6)
        else:
            print("ce loss history not aligned with loss_history; skipped in plot")
    axes[0,0].set_xlabel('step')
    axes[0,0].set_ylabel('loss')
    axes[0,0].set_title(f'distillation loss ({config.VERSION})')
    axes[0,0].legend()

    # validation ppl
    p_steps = [l['step'] for l in distill_logs['ppl_history']]
    p_ppls = [l['ppl'] for l in distill_logs['ppl_history']]
    axes[0,1].plot(p_steps, p_ppls, 'orange', marker='o')
    axes[0,1].axhline(y=teacher_ppl, color='green', linestyle='--', label=f'teacher ({teacher_ppl:.1f})')
    axes[0,1].axhline(y=627.3, color='blue', linestyle=':', label='v6 (627.3)')
    axes[0,1].axhline(y=541.7, color='purple', linestyle=':', label='v9 (541.7)')
    axes[0,1].axhline(y=plot_target_ppl, color='red', linestyle='--', label=f'{config.VERSION} target')
    axes[0,1].set_xlabel('step')
    axes[0,1].set_ylabel('ppl')
    axes[0,1].set_title('validation ppl')
    axes[0,1].legend()

    # lr schedule
    lr_steps = [l['step'] for l in distill_logs['lr_history']]
    lr_vals = [l['lr'] for l in distill_logs['lr_history']]
    axes[0,2].plot(lr_steps, lr_vals, 'purple')
    axes[0,2].axvline(x=config.warmup_steps, color='gray', linestyle='--', label=f'warmup ({config.warmup_steps})')
    axes[0,2].set_xlabel('step')
    axes[0,2].set_ylabel('lr')
    axes[0,2].set_title('learning rate')
    axes[0,2].legend()

    # spike density + amplitudes (first 4 layers)
    spike_summary = spike_stats.get_summary()
    layers = [f'layer_{i}' for i in range(min(4, config.n_layers))]
    k_dens = [spike_summary['per_layer'][l]['k_final'] for l in layers]
    v_dens = [spike_summary['per_layer'][l]['v_final'] for l in layers]
    k_amps = [spike_summary['per_layer'][l]['k_amp_final'] for l in layers]
    v_amps = [spike_summary['per_layer'][l]['v_amp_final'] for l in layers]

    x = np.arange(len(layers))
    axes[1,0].bar(x - 0.2, k_dens, 0.4, label='k density')
    axes[1,0].bar(x + 0.2, v_dens, 0.4, label='v density')
    ax2 = axes[1,0].twinx()
    ax2.plot(x, k_amps, 'r-o', label='k amp')
    ax2.plot(x, v_amps, 'b-s', label='v amp')
    axes[1,0].set_xlabel('layer')
    axes[1,0].set_ylabel('density')
    ax2.set_ylabel('amplitude')
    axes[1,0].set_title(f'spike density & amps (first 4/{config.n_layers} layers)')
    axes[1,0].legend(loc='upper left')
    ax2.legend(loc='upper right')

    # ttt loss
    t_steps = [l['step'] for l in ttt_logs['loss_history']]
    t_losses = [l['loss'] for l in ttt_logs['loss_history']]
    axes[1,1].plot(t_steps, t_losses, 'red')
    axes[1,1].set_xlabel('step')
    axes[1,1].set_ylabel('ce loss')
    axes[1,1].set_title('ttt with lora')

    # version comparison
    # Historical versions for comparison (must match in length!)
    versions = ['v6', 'v9', 'v10', 'v12.1', 'v13.1', 'v14', 'v14.3', config.VERSION]
    # Historical runs used a GPT-2 teacher at 44.6 PPL; the current run's bar
    # uses the teacher_ppl actually measured in the final-evaluation cell.
    t_ppls = [44.6] * (len(versions) - 1) + [teacher_ppl]
    # Student PPL history (from changelog)
    s_ppls = [627.3, 541.7, 514.5, 445.61, 434.44, 424.81, 306.89, student_ppl]
    assert len(versions) == len(t_ppls) == len(s_ppls), f"Array length mismatch: {len(versions)}, {len(t_ppls)}, {len(s_ppls)}"
    x = np.arange(len(versions))
    axes[1,2].bar(x - 0.2, t_ppls, 0.4, label='teacher', alpha=0.7)
    axes[1,2].bar(x + 0.2, s_ppls, 0.4, label='student', alpha=0.7)
    axes[1,2].axhline(y=plot_target_ppl, color='red', linestyle='--', label=f'{config.VERSION} target', alpha=0.7)
    axes[1,2].set_xticks(x)
    axes[1,2].set_xticklabels(versions)
    axes[1,2].set_ylabel('ppl')
    axes[1,2].set_title('version comparison')
    axes[1,2].legend()
    axes[1,2].set_yscale('log')

    plt.tight_layout()
    figure_path = f'{OUTPUT_DIR}/figures/v15_training_{RUN_TIMESTAMP}.png'
    plt.savefig(figure_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"saved: {figure_path}")
else:
    print("matplotlib unavailable: skipped visualization cell (cell 23).")


In [None]:
# =============================================================================
# cell 24: build results dict (v13 - Proper CTKD with GRL)
# =============================================================================
# Collects configuration, metrics, training curves, TTT results and the
# (optional) base64-embedded training figure into `results`. Values describing
# the run are read from `config` rather than hard-coded, so the record matches
# what was actually executed.
print("building results (v13 - CTKD with Gradient Reversal Layer)...")

figure_base64 = None
training_plot_filename = None
if 'figure_path' in globals() and figure_path and os.path.exists(figure_path):
    with open(figure_path, 'rb') as f:
        figure_base64 = base64.b64encode(f.read()).decode('utf-8')
    training_plot_filename = os.path.basename(figure_path)
else:
    print("figure not available; continuing without embedded training plot")

# v13: Extract final lambda (recomputed so this cell does not depend on cell 22's globals)
final_lambda = distill_logs['temp_history'][-1].get('lambda', config.lambda_max) if distill_logs['temp_history'] else config.lambda_max

results = {
    'version': config.VERSION,
    'timestamp': datetime.now().isoformat(),
    'run_id': RUN_TIMESTAMP,
    'platform': PLATFORM,
    'description': 'CTKD with Gradient Reversal Layer - adversarial temperature learning',

    f'{config.VERSION}_design': {
        'principle': 'CTKD: Adversarial min-max optimization via GRL',
        'innovation': 'Gradient Reversal Layer makes temperature MAXIMIZE KL while student MINIMIZES',
        'rationale': 'Proper CTKD (ArXiv 2211.16231) requires adversarial training, not simple regularization',
        'why_v12_failed': 'v12 used simple regularization - optimizer pushed T to max for easy KL',
        'techniques': {
            'ctkd_with_grl': 'ENABLED - Gradient Reversal Layer for adversarial min-max (v13 KEY)',
            'lambda_scheduling': f'Cosine warmup 0->{config.lambda_max} with {config.lambda_warmup_ratio:.0%} warmup',
            'sigmoid_bounding': f'T bounded to [{config.tau_min}, {config.tau_max}] via sigmoid (smooth gradients)',
            'no_manual_reg': 'GRL eliminates need for manual temperature regularization',
            'progressive_stages': 'DISABLED',
            'channel_wise_spikes': 'DISABLED (structural symmetry issue)',
        },
        'grl_mechanism': {
            'forward_pass': 'Identity: GRL(x) = x',
            'backward_pass': 'Negation: dGRL/dx = -lambda',
            'effect': 'Temperature gradients reversed -> T maximizes KL loss',
        },
        'temperature_config': {
            'tau_min': config.tau_min,
            'tau_max': config.tau_max,
            'tau_init': config.tau_init,
            'lambda_max': config.lambda_max,
            'lambda_warmup_ratio': config.lambda_warmup_ratio,
        },
        'architecture': {
            'd_model': config.d_model,
            # Was hard-coded to 5; record the layer count actually used.
            'n_layers': config.n_layers,
            'params': '~22M',
        },
        'speedups': {
            'gradient_checkpointing': USE_GRADIENT_CHECKPOINTING,
            'torch_compile': compile_success,
            'fused_optimizer': True,
            'accumulation_steps': config.accumulation_steps,
        },
        # Formatted from config (the literals here had gone stale, e.g. distill_steps: 3000).
        'unchanged': [
            f'hidden_align_weight: {config.hidden_align_weight}',
            f'warmup_steps: {config.warmup_steps}',
            f'distill_steps: {config.distill_steps}',
        ],
    },

    'architecture': {
        'teacher': {'name': 'gpt2', 'params': teacher_params},
        'student': {
            'name': f'asnn-goose-{config.VERSION}',
            'd_model': config.d_model,
            'n_layers': config.n_layers,
            'params': student_params,
        },
        'projector_params': projector_params,
        'compression_ratio': compression_ratio,
        'vram_peak_gb': vram_peak_gb,
    },

    'training_config': {
        'distill_steps': config.distill_steps,
        'tau_min': config.tau_min,
        'tau_max': config.tau_max,
        'tau_init': config.tau_init,
        'final_temperature': final_temp,
        'lambda_max': config.lambda_max,
        'lambda_warmup_ratio': config.lambda_warmup_ratio,
        'final_lambda': final_lambda,
        'hidden_align_weight': config.hidden_align_weight,
        'warmup_steps': config.warmup_steps,
        'batch_size': config.batch_size,
        'accumulation_steps': config.accumulation_steps,
        'effective_batch': config.batch_size * config.accumulation_steps,
        'distill_lr': config.distill_lr,
        'max_grad_norm': config.max_grad_norm,
    },

    'results': {
        'teacher_ppl': teacher_ppl,
        'student_ppl': student_ppl,
        'ppl_gap': student_ppl - teacher_ppl,
        'spike_density': student.get_avg_spike_density(),
        'amplitudes': student.get_amplitudes(),
        'final_temperature': final_temp,
        'final_lambda': final_lambda,
        'target_met': student_ppl < 500,
    },

    'training_curves': {
        'loss_history': distill_logs['loss_history'],
        'kl_loss_history': distill_logs['kl_loss_history'],
        'ce_loss_history': distill_logs.get('ce_loss_history', []),
        'fdd_loss_history': distill_logs.get('fdd_loss_history', []),
        'spike_sem_loss_history': distill_logs.get('spike_sem_loss_history', []),
        'align_loss_history': distill_logs['align_loss_history'],
        'ppl_history': distill_logs['ppl_history'],
        'lr_history': distill_logs['lr_history'],
        'temp_history': distill_logs['temp_history'],  # v13: includes temperature AND lambda
        'lambda_history': distill_logs.get('lambda_history', []),
        'spike_sem_weight_history': distill_logs.get('spike_sem_weight_history', []),
    },

    'hardware_stats': hw_stats.get_summary(),
    'spike_analysis': spike_stats.get_summary(),

    'ttt': {
        'lora_params': lora_params,
        'pre_ppl': pre_ttt_ppl,
        'post_ppl': post_ttt_ppl,
        'improvement': pre_ttt_ppl - post_ttt_ppl,
        'loss_history': ttt_logs['loss_history'],
    },

    'comparison': {
        'v6': {'student_ppl': 627.3, 'note': 'baseline'},
        'v7': {'student_ppl': 1655, 'note': 'regression (align=1.0, T=4)'},
        'v8': {'student_ppl': 559, 'note': 'fixed defaults (align=0, T=2)'},
        'v9': {'student_ppl': 541.7, 'note': 'capacity increase (320d, 5L)'},
        'v10': {'student_ppl': 514.5, 'note': '320d/5L baseline'},
        'v14': {'student_ppl': 512.67, 'note': 'channel-wise WITH reg (bug)'},
        'v14.1': {'student_ppl': 512.04, 'note': 'channel-wise NO reg (symmetry issue)'},
        'v12': {'student_ppl': 'FAILED', 'note': 'temp runaway without GRL'},
        # The current run is keyed by its own version (it was stored as 'v13' with a POCL note).
        config.VERSION: {'student_ppl': student_ppl, 'note': f'current run, post-TTT (T={final_temp:.2f}, L={final_lambda:.3f})'},
    },

    'figures': {
        'training_plot': {
            'filename': training_plot_filename,
            'base64': figure_base64,
        }
    },

    # validation_tests will be added in cell 26
}

print("results dict built (validation_tests pending)")
print(f"  version: {config.VERSION} ({config.VERSION_DESC})")
print(f"  final_temperature: {final_temp:.2f}")
print(f"  final_lambda: {final_lambda:.3f}")


In [None]:
# =============================================================================
# cell 25: validation tests (v14.1 - 12 tests with FDD)
# =============================================================================
# These tests validate correct v14.1 implementation. Each check appends a
# (name, passed, details) tuple to `tests`; counts are stored in
# `validation_results`.

print("="*60)
print(f"{config.VERSION} Validation Test Suite")
print("="*60)

tests = []

# Best validation-PPL entry, shared by tests 1, 2 and 10 (None when there is no history).
best_ppl_entry = (
    min(distill_logs['ppl_history'], key=lambda x: x['ppl'])
    if distill_logs['ppl_history'] else None
)

# =============================================================================
# Test 1: PPL Target (<400)
# =============================================================================
if best_ppl_entry is not None:
    best_ppl = best_ppl_entry['ppl']
    target_ppl = 400  # target for this version
    ppl_pass = best_ppl < target_ppl
    tests.append(('PPL < 400', ppl_pass, f"best_ppl={best_ppl:.2f}, target={target_ppl}"))
else:
    tests.append(('PPL < 400', False, "No PPL history found"))

# =============================================================================
# Test 2: PPL Improvement over v13.1 (434.44)
# =============================================================================
if best_ppl_entry is not None:
    v13_1_ppl = 434.44
    improvement = v13_1_ppl - best_ppl_entry['ppl']
    improve_pass = improvement > 0
    tests.append(('Improved over v13.1', improve_pass, f"improvement={improvement:.2f} PPL"))
else:
    tests.append(('Improved over v13.1', False, "No PPL history"))

# =============================================================================
# Test 3: Spike Density in Valid Range [0.1, 0.9]
# =============================================================================
if spike_stats.step_densities:
    final_density = spike_stats.step_densities[-1]['density']
    density_pass = 0.1 <= final_density <= 0.9
    tests.append(('Spike density [0.1, 0.9]', density_pass, f"density={final_density:.3f}"))
else:
    tests.append(('Spike density [0.1, 0.9]', False, "No spike history"))

# =============================================================================
# Test 4: Amplitudes in Healthy Range [0.3, 3.0]
# =============================================================================
amps = student.get_amplitudes()
amp_values = [layer_amps[key] for layer_amps in amps.values() for key in ('k', 'v')]
if amp_values:
    amp_min, amp_max = min(amp_values), max(amp_values)
    amp_pass = 0.3 <= amp_min and amp_max <= 3.0
    tests.append(('Amplitudes [0.3, 3.0]', amp_pass, f"range=[{amp_min:.2f}, {amp_max:.2f}]"))
else:
    # min()/max() of an empty list would raise; report a failed test instead.
    tests.append(('Amplitudes [0.3, 3.0]', False, "No amplitudes reported"))

# =============================================================================
# Test 5: Training Completed (all steps or early stopped)
# =============================================================================
if distill_logs['early_stopped']:
    training_pass = distill_logs['early_stop_step'] > config.distill_steps * 0.3
    tests.append(('Training completed', training_pass, f"Early stopped at {distill_logs['early_stop_step']} steps"))
else:
    training_pass = len(distill_logs['loss_history']) >= config.distill_steps * 0.95
    tests.append(('Training completed', training_pass, f"Completed {len(distill_logs['loss_history'])}/{config.distill_steps} steps"))

# =============================================================================
# Test 6: No NaN/Inf in Loss
# =============================================================================
nan_inf_found = any(not math.isfinite(h['loss']) for h in distill_logs['loss_history'])
nan_pass = not nan_inf_found
tests.append(('No NaN/Inf loss', nan_pass, "All losses finite" if nan_pass else "Found NaN/Inf"))

# =============================================================================
# Test 7: VRAM Usage Reasonable (<8GB)
# =============================================================================
if hasattr(hw_stats, 'get_summary'):
    hw_summary = hw_stats.get_summary()
    vram_gb = hw_summary.get('peak_gpu_memory_gb', 0)
    vram_pass = vram_gb < 8.0
    tests.append(('VRAM < 8GB', vram_pass, f"peak={vram_gb:.2f}GB"))
else:
    tests.append(('VRAM < 8GB', True, "hw_stats not available"))

# =============================================================================
# Test 8: FDD Was Active (v14.1)
# =============================================================================
if config.use_fdd:
    # FDD should have been active at some point (non-zero logged loss)
    fdd_losses = [h['loss'] for h in distill_logs['fdd_loss_history'] if h.get('loss') is not None and h['loss'] != 0]
    fdd_active_pass = len(fdd_losses) > 0
    if distill_logs['fdd_killed']:
        status = f"Active then KILLED at step {distill_logs['fdd_kill_step']}"
    else:
        status = f"Active for {len(fdd_losses)} steps"
    tests.append(('FDD was active', fdd_active_pass, status))
else:
    tests.append(('FDD was active', True, "FDD disabled in config"))

# =============================================================================
# Test 9: CTKD Temperature Evolved (v14.1)
# =============================================================================
if config.use_ctkd and distill_logs['temp_history']:
    temps = [h['temperature'] for h in distill_logs['temp_history']]
    start_temp = temps[0]
    end_temp = temps[-1]
    temp_evolved = abs(end_temp - start_temp) > 0.1  # Should have moved
    tests.append(('Temperature evolved', temp_evolved, f"start={start_temp:.2f}, end={end_temp:.2f}"))
else:
    tests.append(('Temperature evolved', True, "CTKD disabled or no temp history"))

# =============================================================================
# Test 10: Early Stopping Working (if triggered)
# =============================================================================
if config.use_early_stopping:
    if distill_logs['early_stopped']:
        es_step = distill_logs['early_stop_step']
        es_pass = config.distill_steps * 0.3 < es_step < config.distill_steps
        tests.append(('Early stopping working', es_pass, f"stopped at {es_step}"))
    elif best_ppl_entry is not None:
        # Not triggered: the gap since the last improvement must not exceed patience
        # (plus one eval interval of slack).
        last_improvement_step = best_ppl_entry['step']
        final_step = distill_logs['ppl_history'][-1]['step']
        gap = final_step - last_improvement_step
        es_pass = gap <= config.early_stopping_patience + config.eval_interval
        tests.append(('Early stopping working', es_pass, f"last improvement at step {last_improvement_step}"))
    else:
        tests.append(('Early stopping working', True, "No PPL history"))
else:
    tests.append(('Early stopping working', True, "Early stopping disabled"))

# =============================================================================
# Test 11: FDD Loss Decreased (v14.1)
# =============================================================================
if config.use_fdd and distill_logs['fdd_loss_history']:
    fdd_losses = [h['loss'] for h in distill_logs['fdd_loss_history'] if h.get('loss') is not None and h['loss'] != 0]
    if len(fdd_losses) >= 10:
        # Compare the mean of the first 5 vs the last 5 active FDD losses.
        start_fdd = sum(fdd_losses[:5]) / 5
        end_fdd = sum(fdd_losses[-5:]) / 5
        fdd_decreased = end_fdd < start_fdd
        tests.append(('FDD loss decreased', fdd_decreased, f"start={start_fdd:.4f}, end={end_fdd:.4f}"))
    else:
        tests.append(('FDD loss decreased', True, "Not enough FDD data points"))
else:
    tests.append(('FDD loss decreased', True, "FDD disabled or no history"))

# =============================================================================
# Test 12: Extended Training (5000 steps)
# =============================================================================
extended_pass = config.distill_steps >= 5000
tests.append(('Extended training (5000+)', extended_pass, f"distill_steps={config.distill_steps}"))

# =============================================================================
# Report Results
# =============================================================================
print("\n" + "-"*60)
print("TEST RESULTS")
print("-"*60)

passed = 0
failed = 0
for name, result, details in tests:
    symbol = "V" if result else "X"
    print(f"[{symbol}] {name}: {details}")
    if result:
        passed += 1
    else:
        failed += 1

print("-"*60)
print(f"SUMMARY: {passed}/{len(tests)} tests passed")
if failed > 0:
    print(f"WARNING: {failed} tests failed!")
else:
    print(f"ALL TESTS PASSED! {config.VERSION} implementation validated.")
print("="*60)

# Store results
validation_results = {
    'tests': tests,
    'passed': passed,
    'failed': failed,
    'total': len(tests)
}


In [None]:
# =============================================================================
# cell 27: V15 SpikingBrain - Information Encoding Validation
# =============================================================================
# Validate that spike patterns encode meaningful semantic information
# Prerequisite for v16 (sparse ops)

import numpy as np
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional, Any

# =============================================================================
# INLINE: SpikingBrain Validation Classes
# =============================================================================

@dataclass
class SpikeHealthMetrics:
    """Container for spike health metrics.

    Filled in by ``SpikingBrainValidator._compute_health``. Despite the ``_pct``
    suffix, ``dead_neuron_pct`` and ``saturated_neuron_pct`` hold fractions
    (the summary multiplies them by 100 for display). Per-layer dicts are keyed
    ``'layer_{idx}'``. When no spikes were captured, the validator reports
    ``dead_neuron_pct=1.0`` and ``health_pass=False``.
    """
    # Fraction of channels whose firing rate is below 0.001.
    dead_neuron_pct: float
    # Per-layer indices of dead channels.
    dead_neuron_indices: Dict[str, np.ndarray]
    # Fraction of channels active at every captured position.
    saturated_neuron_pct: float
    # Per-layer indices of saturated channels.
    saturated_neuron_indices: Dict[str, np.ndarray]
    # Mean of per-channel firing rates pooled across all layers.
    firing_rate_mean: float
    # Standard deviation of the same pooled per-channel firing rates.
    firing_rate_std: float
    # Per-layer arrays of per-channel firing rates.
    per_channel_rates: Dict[str, np.ndarray]
    # True when no health alert was raised.
    health_pass: bool
    # Human-readable descriptions of the failed health checks.
    alerts: List[str]

@dataclass
class SpikingBrainValidation:
    """Complete validation results.

    Attributes:
        health: Spike health metrics for the run.
        mutual_information: MI scores per layer pair plus the aggregate under
            ``'mutual_information'`` (as produced by SpikingBrainValidator).
        cka: CKA scores per layer pair plus the aggregate under ``'cka_mean'``.
        overall_pass: Combined verdict of health, MI and CKA gates.
        summary: Human-readable report text.
    """
    health: SpikeHealthMetrics
    mutual_information: Dict[str, float]
    cka: Dict[str, float]
    overall_pass: bool
    summary: str

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly view of the results.

        Only scalar health fields and the alert list are exported; index
        arrays, per-channel rates and ``summary`` are left out.
        """
        exported_health_fields = (
            'dead_neuron_pct',
            'saturated_neuron_pct',
            'firing_rate_mean',
            'firing_rate_std',
            'health_pass',
            'alerts',
        )
        health_view = {attr: getattr(self.health, attr) for attr in exported_health_fields}
        return {
            'health': health_view,
            'mutual_information': self.mutual_information,
            'cka': self.cka,
            'overall_pass': self.overall_pass,
        }


class MutualInformationEstimator:
    """Estimate MI between spikes and teacher hiddens using binning.

    For each of the first ``n_dims`` channels present in both inputs:

    * spikes are discretized by sign into {0: negative, 1: ~zero, 2: positive},
      which is independent of the learned spike amplitude;
    * the paired teacher channel is split into ``n_bins`` equal-width bins over
      its observed range.

    The MI (in bits) of each joint histogram is computed and averaged over
    channels.

    Args:
        n_dims: Number of leading channels to pair and score.
        n_bins: Number of equal-width bins used for teacher values.
        max_samples: Maximum number of rows used per estimate.
    """

    def __init__(self, n_dims: int = 8, n_bins: int = 32, max_samples: int = 10000):
        self.n_dims = n_dims
        self.n_bins = n_bins
        self.max_samples = max_samples

    @torch.no_grad()
    def estimate(
        self,
        spikes: torch.Tensor,
        teacher_hidden: torch.Tensor,
    ) -> float:
        """Binning-based MI estimation.

        Both tensors are flattened to (rows, channels). Row ``i`` of ``spikes``
        is paired with row ``i`` of ``teacher_hidden``, and channel ``d`` with
        channel ``d``.

        Returns:
            Mean per-channel MI in bits, with each channel clamped at >= 0.
            Returns 0.0 when there are no rows or no shared channels.
        """
        if spikes.numel() == 0 or teacher_hidden.numel() == 0 or self.n_dims <= 0:
            return 0.0

        spikes_flat = spikes.reshape(-1, spikes.shape[-1])
        teacher_flat = teacher_hidden.reshape(-1, teacher_hidden.shape[-1])

        # Only rows/channels that exist in BOTH inputs can be paired.
        n_dims = min(self.n_dims, spikes_flat.shape[1], teacher_flat.shape[1])
        n_samples = min(spikes_flat.shape[0], teacher_flat.shape[0], self.max_samples)

        # .float() before .numpy(): numpy has no bfloat16 dtype.
        spikes_np = spikes_flat[:n_samples, :n_dims].float().cpu().numpy()
        teacher_np = teacher_flat[:n_samples, :n_dims].float().cpu().numpy()

        # Bin teacher values per channel over that channel's observed range.
        teacher_binned = np.empty((n_samples, n_dims), dtype=np.int64)
        for d in range(n_dims):
            col = teacher_np[:, d]
            edges = np.linspace(col.min() - 1e-10, col.max() + 1e-10, self.n_bins + 1)
            teacher_binned[:, d] = np.digitize(col, edges) - 1
        # The maximum (or values nudged by float32 rounding of the padding)
        # can land just outside [0, n_bins - 1]; clamp into range.
        np.clip(teacher_binned, 0, self.n_bins - 1, out=teacher_binned)

        # Robust ternary discretization by sign (independent of learned amplitude).
        spikes_discrete = np.ones((n_samples, n_dims), dtype=np.int64)
        spikes_discrete[spikes_np > 1e-6] = 2
        spikes_discrete[spikes_np < -1e-6] = 0

        n_cells = 3 * self.n_bins
        mi_per_dim = []
        for d in range(n_dims):
            # Joint histogram: count flattened (spike_state, teacher_bin) cell ids.
            cell_ids = spikes_discrete[:, d] * self.n_bins + teacher_binned[:, d]
            counts = np.bincount(cell_ids, minlength=n_cells).reshape(3, self.n_bins)
            # Epsilon smoothing keeps log2 finite for empty cells.
            joint = counts / n_samples + 1e-10

            p_s = joint.sum(axis=1, keepdims=True)
            p_t = joint.sum(axis=0, keepdims=True)

            mi = np.sum(joint * np.log2(joint / (p_s * p_t + 1e-10)))
            mi_per_dim.append(max(0.0, float(mi)))

        return float(np.mean(mi_per_dim))


class RepresentationAnalyzer:
    """Compute CKA between spike patterns and teacher representations."""

    @staticmethod
    def linear_cka(X: np.ndarray, Y: np.ndarray) -> float:
        """Compute linear CKA similarity between two row-aligned matrices.

        ``X`` is (n, d1) and ``Y`` is (n, d2); row ``i`` describes the same
        sample in both. After column-centering,
        ``CKA = ||X^T Y||_F^2 / (||X^T X||_F * ||Y^T Y||_F)``.

        The computation is done in float64, and the denominator is a product of
        norms rather than ``sqrt`` of a product of squared norms. Large float32
        activations therefore cannot overflow to ``inf`` and silently collapse
        the score to 0.
        """
        X = np.asarray(X, dtype=np.float64)
        Y = np.asarray(Y, dtype=np.float64)
        X = X - X.mean(axis=0)
        Y = Y - Y.mean(axis=0)

        hsic_xy = np.linalg.norm(X.T @ Y, 'fro') ** 2
        norm_xx = np.linalg.norm(X.T @ X, 'fro')
        norm_yy = np.linalg.norm(Y.T @ Y, 'fro')

        return float(hsic_xy / (norm_xx * norm_yy + 1e-10))

    @torch.no_grad()
    def compute_cka(
        self,
        spikes: torch.Tensor,
        teacher_hidden: torch.Tensor,
        max_samples: int = 5000,
    ) -> float:
        """Compute CKA between spikes and teacher hidden states.

        Both tensors are flattened to (rows, channels), and the first
        ``min(rows_spikes, rows_teacher, max_samples)`` rows are compared.
        Returns 0.0 when there are no rows to compare.
        """
        spikes_flat = spikes.reshape(-1, spikes.shape[-1])
        teacher_flat = teacher_hidden.reshape(-1, teacher_hidden.shape[-1])

        # Truncate both sides to a common row count so rows stay paired.
        n_samples = min(spikes_flat.shape[0], teacher_flat.shape[0], max_samples)
        if n_samples == 0:
            return 0.0
        # .float() first: numpy has no bfloat16 dtype.
        X = spikes_flat[:n_samples].float().cpu().numpy()
        Y = teacher_flat[:n_samples].float().cpu().numpy()

        return self.linear_cka(X, Y)


class SpikingBrainValidator:
    """Main validator for V15 SpikingBrain validation.

    Runs the student (with spike capture) and the teacher (with hidden states)
    over a few batches, then scores:

    * spike health: dead / saturated channel fractions and mean firing rate;
    * mutual information between spikes and the mapped teacher hidden states;
    * linear CKA between spikes and the mapped teacher hidden states.

    Args:
        device: Device that input batches are moved to.
        layer_map: Student layer index -> teacher layer index.
        dead_threshold: Max tolerated fraction of dead channels.
        saturated_threshold: Max tolerated fraction of always-active channels.
        mi_threshold: Minimum mean mutual information to pass.
        cka_threshold: Minimum mean CKA to pass.
        firing_rate_range: Inclusive (low, high) band for the mean firing rate.
    """

    def __init__(
        self,
        device: torch.device,
        layer_map: Dict[int, int],
        dead_threshold: float = 0.05,
        saturated_threshold: float = 0.10,
        mi_threshold: float = 0.1,
        cka_threshold: float = 0.3,
        firing_rate_range: Tuple[float, float] = (0.2, 0.6),
    ):
        self.device = device
        self.layer_map = layer_map
        self.dead_threshold = dead_threshold
        self.saturated_threshold = saturated_threshold
        self.mi_threshold = mi_threshold
        self.cka_threshold = cka_threshold
        self.firing_rate_range = firing_rate_range

        self.mi_estimator = MutualInformationEstimator()
        self.cka_analyzer = RepresentationAnalyzer()

    @staticmethod
    def _flatten_cat(tensors: List[torch.Tensor]) -> torch.Tensor:
        """Reshape each tensor to (rows, last_dim) and concatenate along rows.

        ``reshape`` (not ``view``) accepts non-contiguous tensors. Flattening
        before concatenating also lets batches with different sequence lengths
        be combined, while keeping the same row order as cat-then-reshape.
        """
        return torch.cat([t.reshape(-1, t.shape[-1]) for t in tensors], dim=0)

    def _input_ids(self, batch) -> torch.Tensor:
        """Extract ``input_ids`` from a dict or list/tuple batch and move to ``self.device``."""
        if isinstance(batch, dict):
            return batch['input_ids'].to(self.device)
        if isinstance(batch, (list, tuple)):
            return batch[0].to(self.device)
        raise TypeError(f"Unsupported batch type for validation: {type(batch)}")

    @torch.no_grad()
    def _collect(self, student, teacher, dataloader, max_batches: int):
        """Run both models over up to ``max_batches`` batches and gather outputs on CPU.

        Returns:
            ``(all_spikes, all_teacher_hiddens)``.
            ``all_spikes`` maps a spike_info layer key to
            ``{'k': [tensors], 'v': [tensors]}``.
            ``all_teacher_hiddens`` maps a teacher layer index to a list of
            hidden-state tensors, one per batch.
        """
        all_spikes: Dict[Any, Dict[str, List[torch.Tensor]]] = {}
        all_teacher_hiddens: Dict[int, List[torch.Tensor]] = {}
        # Several student layers may map to the same teacher layer. Capture each
        # teacher layer once per batch so its rows stay aligned with spike rows.
        teacher_layers = list(dict.fromkeys(self.layer_map.values()))

        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= max_batches:
                break
            input_ids = self._input_ids(batch)

            # Student forward with spikes; the aux dict is the last element of a
            # 2- or 3-tuple return.
            student_out = student(input_ids, return_spike_info=True)
            aux = {}
            if isinstance(student_out, tuple):
                if len(student_out) == 2:
                    _, aux = student_out
                elif len(student_out) == 3:
                    _, _, aux = student_out
            spike_info = aux.get('spike_info', {}) if isinstance(aux, dict) else {}

            for layer_idx, layer_spikes in spike_info.items():
                bucket = all_spikes.setdefault(layer_idx, {'k': [], 'v': []})
                # A layer reports either one dict or a list of dicts.
                if isinstance(layer_spikes, dict):
                    records = [layer_spikes]
                elif isinstance(layer_spikes, list):
                    records = layer_spikes
                else:
                    records = []
                for record in records:
                    if not isinstance(record, dict):
                        continue
                    for stream in ('k', 'v'):
                        stream_spikes = record.get(f'{stream}_spikes')
                        if stream_spikes is not None:
                            bucket[stream].append(stream_spikes.cpu())

            # Teacher forward (hidden states from the mapped layers).
            teacher_out = teacher(input_ids, output_hidden_states=True)
            for teacher_layer in teacher_layers:
                # Assumes hidden_states[0] is the embedding output (HF convention),
                # so layer i's output sits at index i + 1.
                all_teacher_hiddens.setdefault(teacher_layer, []).append(
                    teacher_out.hidden_states[teacher_layer + 1].cpu()
                )

        return all_spikes, all_teacher_hiddens

    @torch.no_grad()
    def validate(
        self,
        student: torch.nn.Module,
        teacher: torch.nn.Module,
        dataloader,
        max_batches: int = 20,
    ) -> SpikingBrainValidation:
        """Run complete SpikingBrain validation.

        Both models are put in eval mode, and are left in eval mode afterwards.

        Args:
            student: Model that accepts ``return_spike_info=True``.
            teacher: Model that accepts ``output_hidden_states=True``.
            dataloader: Yields dict batches with ``'input_ids'`` or
                tuple/list batches whose first element is the ids.
            max_batches: Maximum number of batches to consume.

        Returns:
            SpikingBrainValidation with health, MI, CKA, the overall verdict and
            a printable summary.
        """
        student.eval()
        teacher.eval()

        print("Collecting spikes and teacher representations...")
        all_spikes, all_teacher_hiddens = self._collect(student, teacher, dataloader, max_batches)

        # 1. Compute health metrics
        print("Computing health metrics...")
        health = self._compute_health(all_spikes)

        # 2. Compute MI
        print("Estimating mutual information...")
        mi_results = self._compute_mi(all_spikes, all_teacher_hiddens)

        # 3. Compute CKA
        print("Computing CKA similarity...")
        cka_results = self._compute_cka(all_spikes, all_teacher_hiddens)

        # 4. Overall pass check
        overall_pass = (
            health.health_pass and
            mi_results.get('mutual_information', 0) >= self.mi_threshold and
            cka_results.get('cka_mean', 0) >= self.cka_threshold
        )

        # 5. Generate summary
        summary = self._generate_summary(health, mi_results, cka_results, overall_pass)

        return SpikingBrainValidation(
            health=health,
            mutual_information=mi_results,
            cka=cka_results,
            overall_pass=overall_pass,
            summary=summary,
        )

    def _compute_health(self, all_spikes: Dict[int, Dict[str, List[torch.Tensor]]]) -> SpikeHealthMetrics:
        """Compute spike health metrics.

        A channel counts as active at a position when its K or V spike is
        non-zero there. A channel is dead when its firing rate is < 0.001, and
        saturated when it is active at every position. Layers with no captured
        tensors are skipped.
        """
        dead_indices = {}
        saturated_indices = {}
        per_channel_rates = {}
        all_rates = []

        total_dead = 0
        total_saturated = 0
        total_channels = 0

        for layer_idx, layer_spikes in all_spikes.items():
            k_list = layer_spikes.get('k', [])
            v_list = layer_spikes.get('v', [])
            if not k_list and not v_list:
                continue

            if k_list and v_list:
                k_stacked = self._flatten_cat(k_list)
                v_stacked = self._flatten_cat(v_list)
                active = ((k_stacked != 0) | (v_stacked != 0)).float()
            else:
                active = (self._flatten_cat(k_list or v_list) != 0).float()

            d_model = active.shape[-1]
            layer_key = f'layer_{layer_idx}'

            # Per-channel firing rates
            rates = active.mean(dim=0).numpy()
            per_channel_rates[layer_key] = rates
            all_rates.append(rates)

            # Dead neurons (firing rate < 0.001)
            dead_mask = rates < 0.001
            dead_indices[layer_key] = np.where(dead_mask)[0]
            total_dead += dead_mask.sum()

            # Saturated neurons (always fire)
            always_active = (active > 0.999).all(dim=0).numpy()
            saturated_indices[layer_key] = np.where(always_active)[0]
            total_saturated += always_active.sum()

            total_channels += d_model

        if not all_rates:
            return SpikeHealthMetrics(
                dead_neuron_pct=1.0,
                dead_neuron_indices={},
                saturated_neuron_pct=0.0,
                saturated_neuron_indices={},
                firing_rate_mean=0.0,
                firing_rate_std=0.0,
                per_channel_rates={},
                health_pass=False,
                alerts=['No spike tensors captured during validation.'],
            )

        all_rates_flat = np.concatenate(all_rates)
        dead_pct = total_dead / total_channels if total_channels > 0 else 0
        saturated_pct = total_saturated / total_channels if total_channels > 0 else 0

        # Check health: an alert is raised only when a fraction strictly exceeds its threshold.
        alerts = []
        if dead_pct > self.dead_threshold:
            alerts.append(f"Dead neurons: {dead_pct*100:.1f}% > {self.dead_threshold*100:.0f}%")
        if saturated_pct > self.saturated_threshold:
            alerts.append(f"Saturated neurons: {saturated_pct*100:.1f}% > {self.saturated_threshold*100:.0f}%")

        fr_mean = float(np.mean(all_rates_flat))
        if not (self.firing_rate_range[0] <= fr_mean <= self.firing_rate_range[1]):
            alerts.append(f"Firing rate {fr_mean:.3f} outside range {self.firing_rate_range}")

        health_pass = len(alerts) == 0

        return SpikeHealthMetrics(
            dead_neuron_pct=float(dead_pct),
            dead_neuron_indices=dead_indices,
            saturated_neuron_pct=float(saturated_pct),
            saturated_neuron_indices=saturated_indices,
            firing_rate_mean=fr_mean,
            firing_rate_std=float(np.std(all_rates_flat)),
            per_channel_rates=per_channel_rates,
            health_pass=health_pass,
            alerts=alerts,
        )

    def _per_layer_scores(
        self,
        all_spikes: Dict[int, Dict[str, List[torch.Tensor]]],
        all_teacher_hiddens: Dict[int, List[torch.Tensor]],
        score_fn,
    ) -> Tuple[Dict[str, float], List[float]]:
        """Score every mapped (student, teacher) layer pair, per spike stream.

        ``score_fn(spikes_2d, hiddens_2d) -> float`` is applied to the K and V
        streams separately.

        Returns:
            ``(scores, layer_means)``. ``scores`` has the keys
            ``layer_{s}_to_{t}_k``, ``layer_{s}_to_{t}_v`` and
            ``layer_{s}_to_{t}`` (the mean of the available streams).
            ``layer_means`` lists just those per-layer means.
        """
        scores: Dict[str, float] = {}
        layer_means: List[float] = []

        for student_layer, teacher_layer in self.layer_map.items():
            if student_layer not in all_spikes or teacher_layer not in all_teacher_hiddens:
                continue

            streams = {name: all_spikes[student_layer].get(name, []) for name in ('k', 'v')}
            if not any(streams.values()):
                continue
            hiddens = self._flatten_cat(all_teacher_hiddens[teacher_layer])
            prefix = f'layer_{student_layer}_to_{teacher_layer}'

            stream_vals = []
            for name, tensors in streams.items():
                if not tensors:
                    continue
                value = score_fn(self._flatten_cat(tensors), hiddens)
                scores[f'{prefix}_{name}'] = value
                stream_vals.append(value)

            layer_mean = float(np.mean(stream_vals))
            scores[prefix] = layer_mean
            layer_means.append(layer_mean)

        return scores, layer_means

    def _compute_mi(
        self,
        all_spikes: Dict[int, Dict[str, List[torch.Tensor]]],
        all_teacher_hiddens: Dict[int, List[torch.Tensor]],
    ) -> Dict[str, float]:
        """Compute mutual information.

        The aggregate ``'mutual_information'`` is the mean of the per-layer
        values only. Per-stream entries are reported but not averaged again,
        so layers with a single stream are not over-weighted.
        """
        mi_per_layer, layer_means = self._per_layer_scores(
            all_spikes, all_teacher_hiddens, self.mi_estimator.estimate
        )
        mi_mean = float(np.mean(layer_means)) if layer_means else 0.0

        return {
            **mi_per_layer,
            'mutual_information': mi_mean,
        }

    def _compute_cka(
        self,
        all_spikes: Dict[int, Dict[str, List[torch.Tensor]]],
        all_teacher_hiddens: Dict[int, List[torch.Tensor]],
    ) -> Dict[str, float]:
        """Compute CKA similarity.

        The aggregate ``'cka_mean'`` is the mean of the per-layer values only
        (see ``_compute_mi``).
        """
        cka_per_layer, layer_means = self._per_layer_scores(
            all_spikes, all_teacher_hiddens, self.cka_analyzer.compute_cka
        )
        cka_mean = float(np.mean(layer_means)) if layer_means else 0.0

        return {
            **cka_per_layer,
            'cka_mean': cka_mean,
        }

    def _generate_summary(
        self,
        health: SpikeHealthMetrics,
        mi_results: Dict[str, float],
        cka_results: Dict[str, float],
        overall_pass: bool,
    ) -> str:
        """Generate validation summary."""
        # Mirror the alert rule in _compute_health, which alerts only when a
        # fraction strictly exceeds its threshold. A value exactly at the
        # threshold is therefore reported as PASS, matching health_pass.
        dead_ok = health.dead_neuron_pct <= self.dead_threshold
        saturated_ok = health.saturated_neuron_pct <= self.saturated_threshold
        mi_value = mi_results.get('mutual_information', 0)
        cka_value = cka_results.get('cka_mean', 0)

        lines = [
            "=" * 60,
            "SPIKINGBRAIN VALIDATION SUMMARY",
            "=" * 60,
            "",
            "[HEALTH]",
            f"  Dead neurons: {health.dead_neuron_pct*100:.1f}% {'PASS' if dead_ok else 'FAIL'}",
            f"  Saturated neurons: {health.saturated_neuron_pct*100:.1f}% {'PASS' if saturated_ok else 'FAIL'}",
            f"  Firing rate: {health.firing_rate_mean:.3f} +/- {health.firing_rate_std:.3f}",
            "",
            "[INFORMATION]",
            f"  Mutual Information: {mi_value:.4f} {'PASS' if mi_value >= self.mi_threshold else 'FAIL'}",
            "",
            "[REPRESENTATION]",
            f"  CKA (mean): {cka_value:.4f} {'PASS' if cka_value >= self.cka_threshold else 'FAIL'}",
            "",
            "=" * 60,
            f"OVERALL: {'PASS - Ready for v16 (sparse ops)' if overall_pass else 'NEEDS ATTENTION'}",
            "=" * 60,
        ]
        return "\n".join(lines)


# =============================================================================
# RUN VALIDATION
# =============================================================================

print('\n'.join(['='*60, 'V15: SPIKINGBRAIN INFORMATION ENCODING VALIDATION', '='*60]))

# Validator uses the v14 student->teacher layer mapping and the V15 gate values.
validator = SpikingBrainValidator(
    layer_map={0: 2, 2: 7, 4: 11},   # student layer -> teacher layer
    device=DEVICE,
    firing_rate_range=(0.2, 0.6),    # healthy band for the mean firing rate
    dead_threshold=0.05,             # alert if >5% of channels are dead
    saturated_threshold=0.10,        # alert if >10% of channels always fire
    mi_threshold=0.1,                # minimum acceptable mutual information
    cka_threshold=0.3,               # minimum acceptable mean CKA
)

# Score spikes against teacher hiddens on the first 20 validation batches.
v15_results = validator.validate(
    student=student,
    teacher=teacher,
    max_batches=20,
    dataloader=val_loader,
)

print(v15_results.summary)


In [None]:
# =============================================================================
# cell 28: V15 SpikingBrain Visualizations
# =============================================================================

if not MATPLOTLIB_AVAILABLE:
    print("matplotlib unavailable: skipping V15 SpikingBrain visualizations.")
else:
    def plot_firing_rate_histogram(
        firing_rates: np.ndarray,
        target_rate: float = 0.38,
        title: str = 'Firing Rate Distribution',
        show: bool = True,
        healthy_range: Tuple[float, float] = (0.2, 0.6),
    ):
        """Plot a histogram of per-channel firing rates and save it.

        Args:
            firing_rates: 1-D array of per-channel firing rates (non-empty).
            target_rate: Reference rate drawn as a dashed red line.
            title: Figure title.
            show: Display the figure inline before closing it.
            healthy_range: (low, high) band shaded in green.

        Writes ``{OUTPUT_DIR}/figures/v15_firing_rate_dist.png``.
        """
        low, high = healthy_range
        fig, ax = plt.subplots(figsize=(10, 6))

        ax.hist(firing_rates, bins=50, alpha=0.7, color='steelblue', edgecolor='black')
        ax.axvline(target_rate, color='red', linestyle='--', linewidth=2, label=f'Target: {target_rate}')
        ax.axvline(firing_rates.mean(), color='orange', linestyle='-', linewidth=2, label=f'Mean: {firing_rates.mean():.3f}')

        # Healthy range shading
        ax.axvspan(low, high, alpha=0.1, color='green', label=f'Healthy range [{low}, {high}]')

        ax.set_xlabel('Firing Rate')
        ax.set_ylabel('Count')
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(f'{OUTPUT_DIR}/figures/v15_firing_rate_dist.png', dpi=150, bbox_inches='tight')
        if show:
            plt.show()
        plt.close(fig)


    def plot_cka_by_layer(
        cka_values: Dict[str, float],
        threshold: float = 0.3,
        title: str = 'CKA Similarity by Layer',
        show: bool = True,
    ):
        """Plot CKA similarity per layer pair as a bar chart and save it.

        Only keys containing ``'layer_'`` and not containing ``'mean'`` are
        drawn. Bars at or above ``threshold`` are green, the rest red. Writes
        ``{OUTPUT_DIR}/figures/v15_cka_by_layer.png``.
        """
        # Filter to per-layer values only
        layer_cka = {k: v for k, v in cka_values.items() if 'layer_' in k and 'mean' not in k}

        if not layer_cka:
            print("No per-layer CKA values to plot")
            return

        fig, ax = plt.subplots(figsize=(10, 6))

        names = list(layer_cka.keys())
        values = [layer_cka[n] for n in names]
        colors = ['green' if v >= threshold else 'red' for v in values]

        ax.bar(range(len(names)), values, color=colors, alpha=0.7, edgecolor='black')
        ax.axhline(threshold, color='orange', linestyle='--', linewidth=2, label=f'Threshold: {threshold}')
        ax.axhline(cka_values.get('cka_mean', 0), color='blue', linestyle='-', linewidth=2, label=f'Mean: {cka_values.get("cka_mean", 0):.3f}')

        ax.set_xticks(range(len(names)))
        ax.set_xticklabels([n.replace('layer_', 'L').replace('_to_', '->') for n in names], rotation=45, ha='right')
        ax.set_ylabel('CKA Similarity')
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

        fig.tight_layout()
        fig.savefig(f'{OUTPUT_DIR}/figures/v15_cka_by_layer.png', dpi=150, bbox_inches='tight')
        if show:
            plt.show()
        plt.close(fig)


    # Collect all firing rates from health metrics
    all_rates = []
    for key, rates in v15_results.health.per_channel_rates.items():
        if len(rates) > 0:
            all_rates.append(rates)

    if all_rates:
        combined_rates = np.concatenate(all_rates)

        # Firing rate histogram; the band comes from the validator so the plot
        # matches the gate that was actually applied.
        plot_firing_rate_histogram(
            firing_rates=combined_rates,
            target_rate=0.38,
            title=f'V15 Firing Rate Distribution (mean={combined_rates.mean():.3f})',
            show=True,
            healthy_range=validator.firing_rate_range,
        )

    # CKA by layer (threshold taken from the validator for the same reason)
    if v15_results.cka:
        plot_cka_by_layer(
            cka_values=v15_results.cka,
            threshold=validator.cka_threshold,
            title='V15 CKA Similarity: Spikes vs Teacher',
            show=True,
        )

    print(f'\nVisualizations saved to {OUTPUT_DIR}/figures/')


In [None]:
# =============================================================================
# cell 29: V15 Success Criteria Check
# =============================================================================

print('='*60)
print('V15 SUCCESS CRITERIA')
print('='*60)

# Gate values come from the validator built in the validation cell, so these
# criteria cannot drift from the thresholds it was configured with. Test names
# are rendered from the same values (for the current config, e.g. 'Dead neurons < 5%').
tests = []

# Test 1: dead-neuron fraction below the validator's dead threshold
dead_pct = v15_results.health.dead_neuron_pct
test1 = dead_pct < validator.dead_threshold
tests.append((f'Dead neurons < {validator.dead_threshold*100:.0f}%', test1, f'{dead_pct*100:.1f}%'))

# Test 2: saturated-neuron fraction below the validator's saturation threshold
sat_pct = v15_results.health.saturated_neuron_pct
test2 = sat_pct < validator.saturated_threshold
tests.append((f'Saturated neurons < {validator.saturated_threshold*100:.0f}%', test2, f'{sat_pct*100:.1f}%'))

# Test 3: mean mutual information above the validator's MI threshold
mi_val = v15_results.mutual_information.get('mutual_information', 0)
test3 = mi_val > validator.mi_threshold
tests.append((f'MI > {validator.mi_threshold}', test3, f'{mi_val:.4f}'))

# Test 4: mean CKA above the validator's CKA threshold
cka_mean = v15_results.cka.get('cka_mean', 0)
test4 = cka_mean > validator.cka_threshold
tests.append((f'CKA mean > {validator.cka_threshold}', test4, f'{cka_mean:.4f}'))

# Test 5: mean firing rate inside the validator's healthy band (inclusive)
fr_low, fr_high = validator.firing_rate_range
fr_mean = v15_results.health.firing_rate_mean
test5 = fr_low <= fr_mean <= fr_high
tests.append((f'Firing rate [{fr_low}, {fr_high}]', test5, f'{fr_mean:.3f}'))

# Print results. The loop variable is `ok` (not `passed`) so the earlier
# validation cell's global pass count is not overwritten.
for name, ok, value in tests:
    print(f"  [{'PASS' if ok else 'FAIL'}] {name}: {value}")

all_pass = all(t[1] for t in tests)
print(f'\nOverall: {"ALL PASS - Ready for v16 (sparse ops)" if all_pass else "NEEDS ATTENTION"}')
print('='*60)

# Store in results
results['v15_spiking_brain'] = {
    'validation': v15_results.to_dict(),
    'tests': {test_name: {'passed': ok, 'value': value} for test_name, ok, value in tests},
    'all_pass': all_pass,
}

print('\nV15 results added to results dict')


In [None]:
# =============================================================================
# cell 30: FINAL save + autonomous v15 artifact bundle + single-file dossier
# =============================================================================
print("="*60)
print("FINAL SAVE + AUTONOMY ARTIFACTS (v15)")
print("="*60)

# Add validation_tests to results
results['validation_tests'] = validation_results

# Save final legacy results json (kept for backward compatibility)
results_path = f'{OUTPUT_DIR}/results/results_{RUN_TIMESTAMP}.json'
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2, default=str)

print(f"saved legacy results: {results_path}")
print(f"size: {os.path.getsize(results_path) / 1024:.1f} KB")

# -----------------------------------------------------------------------------
# Build canonical per-run artifact pack for autonomous ingestion
# -----------------------------------------------------------------------------
run_id = f"{config.VERSION}_{RUN_TIMESTAMP}".replace(" ", "_").replace(":", "-")
run_artifact_dir = f"{OUTPUT_DIR}/{run_id}"
os.makedirs(run_artifact_dir, exist_ok=True)

best_ppl = None
best_step = None
if distill_logs.get('ppl_history'):
    best_entry = min(distill_logs['ppl_history'], key=lambda x: x['ppl'])
    best_ppl = float(best_entry['ppl'])
    best_step = int(best_entry['step'])

metrics_payload = {
    "run_id": run_id,
    "phase": "B",
    "version": config.VERSION,
    "timestamp_utc": datetime.utcnow().isoformat() + "Z",
    "seed": int(SEED),
    "teacher_ppl": float(teacher_ppl),
    "student_ppl": float(student_ppl),
    "best_student_ppl": best_ppl,
    "best_step": best_step,
    "ppl_gap": float(student_ppl - teacher_ppl),
    "spike_density": float(student.get_avg_spike_density()),
    "v15_overall_pass": bool(v15_results.overall_pass),
    "v15_firing_rate_mean": float(v15_results.health.firing_rate_mean),
    "v15_mutual_information": float(v15_results.mutual_information.get('mutual_information', 0.0)),
    "v15_cka_mean": float(v15_results.cka.get('cka_mean', 0.0)),
}

eval_suite_payload = {
    "schema_version": "1.0",
    "run_id": run_id,
    "phase": "B",
    "version": config.VERSION,
    "git_hash": os.environ.get("GIT_COMMIT", "unknown"),
    "seed": int(SEED),
    "timestamp_utc": datetime.utcnow().isoformat() + "Z",
    "tasks": {
        "lm_ppl": {
            "metric": "ppl",
            "value": float(student_ppl),
            "baseline_reference": 306.89,
        },
        "v15_spike_health": {
            "dead_neuron_pct": float(v15_results.health.dead_neuron_pct),
            "saturated_neuron_pct": float(v15_results.health.saturated_neuron_pct),
            "firing_rate_mean": float(v15_results.health.firing_rate_mean),
            "pass": bool(v15_results.health.health_pass),
        },
        "v15_information": {
            "mutual_information": float(v15_results.mutual_information.get('mutual_information', 0.0)),
            "cka_mean": float(v15_results.cka.get('cka_mean', 0.0)),
            "overall_pass": bool(v15_results.overall_pass),
        },
    },
    "gate_recommendation": "green" if v15_results.overall_pass else "red",
}

v15_payload = {
    "run_id": run_id,
    "phase": "B",
    "version": config.VERSION,
    "validation": v15_results.to_dict(),
    "success_criteria_tests": {name: {"passed": passed, "value": value} for name, passed, value in tests},
    "overall_pass": bool(v15_results.overall_pass),
}

config_payload = {
    "run_id": run_id,
    "phase": "B",
    "version": config.VERSION,
    "version_desc": config.VERSION_DESC,
    "seed": int(SEED),
    "platform": PLATFORM,
    "device": str(DEVICE),
    "output_dir": OUTPUT_DIR,
    "run_timestamp": RUN_TIMESTAMP,
    "config": asdict(config),
}

with open(f"{run_artifact_dir}/metrics.json", "w") as f:
    json.dump(metrics_payload, f, indent=2, default=str)

with open(f"{run_artifact_dir}/eval_suite.json", "w") as f:
    json.dump(eval_suite_payload, f, indent=2, default=str)

with open(f"{run_artifact_dir}/v15_spikingbrain.json", "w") as f:
    json.dump(v15_payload, f, indent=2, default=str)

with open(f"{run_artifact_dir}/seed.txt", "w") as f:
    f.write(str(SEED) + "\n")

config_yaml_path = f"{run_artifact_dir}/config.yaml"
try:
    import yaml
    with open(config_yaml_path, "w") as f:
        yaml.safe_dump(config_payload, f, sort_keys=False)
except Exception as e:
    # Fallback: keep required artifact name, store JSON-formatted content.
    with open(config_yaml_path, "w") as f:
        f.write(json.dumps(config_payload, indent=2, default=str))
    print(f"warning: yaml export fallback used ({e})")

# Copy key legacy outputs into artifact bundle
try:
    import shutil
    if os.path.exists(results_path):
        shutil.copy2(results_path, f"{run_artifact_dir}/results.json")
except Exception as e:
    print(f"warning: could not copy legacy outputs: {e}")

print("")
print("Canonical artifacts written:")
print(f"  {run_artifact_dir}/eval_suite.json")
print(f"  {run_artifact_dir}/metrics.json")
print(f"  {run_artifact_dir}/config.yaml")
print(f"  {run_artifact_dir}/seed.txt")
print(f"  {run_artifact_dir}/v15_spikingbrain.json")

# -----------------------------------------------------------------------------
# Build a single-file detailed dossier (HTML with embedded figures and raw data)
# -----------------------------------------------------------------------------
import base64
import math
import statistics
from io import BytesIO
from html import escape as html_escape

single_file_path = f"{run_artifact_dir}/run_dossier_{run_id}.html"
single_file_primary_output = single_file_path

def _series(history, value_key):
    """Extract parallel (steps, values) lists from a list of log-row dicts.

    Rows are skipped when they are not dicts, lack ``'step'`` or ``value_key``,
    fail int/float conversion, or carry a non-finite value. ``history`` may be
    None.
    """
    steps, values = [], []
    for entry in history or []:
        if not isinstance(entry, dict) or 'step' not in entry or value_key not in entry:
            continue
        try:
            value = float(entry[value_key])
            step = int(entry['step'])
        except Exception:
            continue
        if not math.isfinite(value):
            continue
        steps.append(step)
        values.append(value)
    return steps, values

def _stats(values):
    """Summary statistics for a numeric sequence.

    Returns ``{}`` for an empty sequence. Otherwise returns a dict with count,
    min, max, mean, population std (0.0 for a single value) and the last
    value, as floats except ``count``.
    """
    if not values:
        return {}
    count = len(values)
    lowest = min(values)
    highest = max(values)
    mean_value = sum(values) / count
    spread = statistics.pstdev(values) if count > 1 else 0.0
    return {
        "count": count,
        "min": float(lowest),
        "max": float(highest),
        "mean": float(mean_value),
        "std": float(spread),
        "last": float(values[-1]),
    }

def _moving_avg(values, window=100):
    if not values:
        return []
    out = []
    for i in range(len(values)):
        lo = max(0, i - window + 1)
        chunk = values[lo:i+1]
        out.append(sum(chunk) / len(chunk))
    return out

# Registries filled below: figure name -> {"title", "path", "base64"} and curve name -> _stats() summary.
detailed_figures, detailed_metrics = {}, {}

if MATPLOTLIB_AVAILABLE:
    def _save_fig(fig, name):
        """Render `fig` to an in-memory PNG and return it base64-encoded.

        Single-file mode: figures are only embedded in the dossier and no
        sidecar PNG is written to disk. `name` is accepted for call-site
        symmetry but not used.

        Returns:
            ("embedded", <base64-encoded PNG as str>)
        """
        with BytesIO() as png_buffer:
            fig.savefig(png_buffer, format='png', dpi=180, bbox_inches='tight')
            encoded_png = base64.b64encode(png_buffer.getvalue()).decode('utf-8')
        return "embedded", encoded_png

    def _plot_lines(name, title, xlabel, ylabel, lines, logy=False):
        """Plot one or more (x, y) series on a single axis and embed the figure.

        Args:
            name: key under which the figure is stored in `detailed_figures`.
            title, xlabel, ylabel: axis annotations.
            lines: iterable of dicts with optional keys 'label', 'x', 'y' and
                'color'. Entries whose 'x' or 'y' is empty/missing are skipped.
            logy: use a logarithmic y axis when True.

        Returns:
            True if a figure was produced and registered in
            `detailed_figures[name]`, False if no series had data.
        """
        fig, ax = plt.subplots(figsize=(12, 5))
        try:
            drawn = 0
            for line in lines:
                xs = line.get('x', [])
                ys = line.get('y', [])
                if not (xs and ys):
                    continue
                ax.plot(xs, ys, label=line.get('label'), linewidth=1.5, color=line.get('color'))
                drawn += 1
            if not drawn:
                return False
            if logy:
                ax.set_yscale('log')
            ax.set_title(title)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.grid(True, alpha=0.3)
            ax.legend()
            path, b64 = _save_fig(fig, name)
            detailed_figures[name] = {"title": title, "path": path, "base64": b64}
            return True
        finally:
            # Release the figure on every path, including when plotting or
            # PNG encoding raises, so pyplot does not keep it open.
            plt.close(fig)

    # Extract every logged training curve as parallel (steps, values) lists,
    # then record summary statistics for each in display order.
    tc = results.get('training_curves', {})

    # The five loss components all store their value under the 'loss' field.
    (loss_x, loss_y), (kl_x, kl_y), (align_x, align_y), (ce_x, ce_y), (fdd_x, fdd_y) = (
        _series(tc.get(history_key, []), 'loss')
        for history_key in (
            'loss_history',
            'kl_loss_history',
            'align_loss_history',
            'ce_loss_history',
            'fdd_loss_history',
        )
    )
    ppl_x, ppl_y = _series(tc.get('ppl_history', []), 'ppl')
    lr_x, lr_y = _series(tc.get('lr_history', []), 'lr')
    temp_x, temp_y = _series(tc.get('temp_history', []), 'temperature')
    # Lambda is looked up in temp_history first; a dedicated lambda_history is
    # only consulted when that yields no points.
    lam_x, lam_y = _series(tc.get('temp_history', []), 'lambda')
    if not lam_y:
        lam_x, lam_y = _series(tc.get('lambda_history', []), 'lambda')

    detailed_metrics.update(
        (metric_name, _stats(metric_values))
        for metric_name, metric_values in (
            ("loss", loss_y),
            ("kl_loss", kl_y),
            ("align_loss", align_y),
            ("ce_loss", ce_y),
            ("fdd_loss", fdd_y),
            ("ppl", ppl_y),
            ("lr", lr_y),
            ("temperature", temp_y),
            ("lambda", lam_y),
        )
    )

    # The linear and log-scale loss figures draw the same five series, so the
    # series spec is built once and passed to both calls (it is only read).
    _loss_component_lines = [
        {"label": "total_loss", "x": loss_x, "y": loss_y, "color": "black"},
        {"label": "kl_loss", "x": kl_x, "y": kl_y, "color": "tab:blue"},
        {"label": "ce_loss", "x": ce_x, "y": ce_y, "color": "tab:orange"},
        {"label": "fdd_loss", "x": fdd_x, "y": fdd_y, "color": "tab:green"},
        {"label": "align_loss", "x": align_x, "y": align_y, "color": "tab:red"},
    ]
    _plot_lines("01_loss_components", "Loss Components Over Steps", "step", "loss",
                _loss_component_lines)
    _plot_lines("02_loss_components_log", "Loss Components Over Steps (log scale)", "step", "loss",
                _loss_component_lines, logy=True)

    # Raw total loss in light gray, overlaid with a trailing 100-point moving average.
    loss_ma = _moving_avg(loss_y, window=100)
    _plot_lines("03_total_loss_smoothed", "Total Loss (raw + moving average)", "step", "loss", [
        {"label": "total_loss_raw", "x": loss_x, "y": loss_y, "color": "lightgray"},
        {"label": "total_loss_ma100", "x": loss_x, "y": loss_ma, "color": "black"},
    ])

    # Single-series diagnostic curves.
    _plot_lines("04_ppl_curve", "Validation PPL Over Eval Steps", "step", "ppl",
                [{"label": "val_ppl", "x": ppl_x, "y": ppl_y, "color": "tab:purple"}])
    _plot_lines("05_learning_rate", "Learning Rate Schedule", "step", "lr",
                [{"label": "lr", "x": lr_x, "y": lr_y, "color": "tab:green"}])
    _plot_lines("06_temperature", "CTKD Temperature", "step", "temperature",
                [{"label": "temperature", "x": temp_x, "y": temp_y, "color": "tab:orange"}])
    _plot_lines("07_lambda", "CTKD Lambda / GRL Strength", "step", "lambda",
                [{"label": "lambda", "x": lam_x, "y": lam_y, "color": "tab:red"}])

    # Spike summary figures: per-layer density and amplitude as paired k/v bars.
    spike_summary = results.get('spike_analysis') or {}
    per_layer = spike_summary.get('per_layer') or {}

    def _layer_sort_key(layer_name):
        """Order layers by trailing integer suffix ("layer_10" after "layer_2").

        Names without an integer suffix sort after all numbered layers, keeping
        their original relative order, instead of raising ValueError and
        aborting the whole dossier.
        """
        suffix = str(layer_name).rsplit('_', 1)[-1]
        try:
            return (0, int(suffix))
        except ValueError:
            return (1, 0)

    def _layer_metric(entry, key):
        """Read entry[key] as a float; 0.0 when missing, None, or non-numeric."""
        if not isinstance(entry, dict):
            return 0.0
        try:
            return float(entry.get(key, 0.0))
        except (TypeError, ValueError):
            return 0.0

    def _plot_layer_pair(fig_name, title, ylabel, layer_names, left, right):
        """Draw two side-by-side bar series per layer and embed the figure.

        `left` and `right` are (legend_label, values) pairs, drawn at -0.2 and
        +0.2 around each layer's x position.
        """
        fig, ax = plt.subplots(figsize=(12, 5))
        try:
            positions = np.arange(len(layer_names))
            for offset, (series_label, series_values) in zip((-0.2, 0.2), (left, right)):
                ax.bar(positions + offset, series_values, 0.4, label=series_label)
            ax.set_xticks(positions)
            ax.set_xticklabels(layer_names, rotation=30)
            ax.set_title(title)
            ax.set_ylabel(ylabel)
            ax.grid(True, alpha=0.3)
            ax.legend()
            path, b64 = _save_fig(fig, fig_name)
            detailed_figures[fig_name] = {"title": title, "path": path, "base64": b64}
        finally:
            # Close even if drawing/encoding fails so the figure is not leaked.
            plt.close(fig)

    if per_layer:
        layer_names = sorted(per_layer.keys(), key=_layer_sort_key)
        k_density = [_layer_metric(per_layer[n], 'k_final') for n in layer_names]
        v_density = [_layer_metric(per_layer[n], 'v_final') for n in layer_names]
        k_amp = [_layer_metric(per_layer[n], 'k_amp_final') for n in layer_names]
        v_amp = [_layer_metric(per_layer[n], 'v_amp_final') for n in layer_names]

        _plot_layer_pair(
            "08_spike_density_per_layer", "Per-layer Spike Density", "density",
            layer_names, ('k_density', k_density), ('v_density', v_density),
        )
        _plot_layer_pair(
            "09_spike_amplitude_per_layer", "Per-layer Spike Amplitude", "amplitude",
            layer_names, ('k_amplitude', k_amp), ('v_amplitude', v_amp),
        )

    # Overall spike-density timeline. Rows follow the same {'step', <value>}
    # record shape as the training-curve histories, so the shared _series()
    # parser is reused (dict check, int step, finite float value) instead of a
    # hand-copied loop.
    density_history = spike_summary.get('density_history', [])
    if density_history:
        dx, dy = _series(density_history, 'density')
        _plot_lines(
            name="10_overall_spike_density_timeline",
            title="Overall Spike Density Timeline",
            xlabel="step",
            ylabel="density",
            lines=[{"label": "overall_density", "x": dx, "y": dy, "color": "tab:blue"}],
        )

    # TTT LoRA loss curve, read from results['ttt']['loss_history'].
    ttt = results.get('ttt', {})
    ttt_x, ttt_y = _series(ttt.get('loss_history', []), 'loss')
    _plot_lines("11_ttt_loss", "TTT LoRA Loss", "step", "loss",
                [{"label": "ttt_loss", "x": ttt_x, "y": ttt_y, "color": "tab:brown"}])

    # Validation outcomes as one bar per test: green = pass (1), red = fail (0).
    # Each test row is indexed as row[0] -> name, row[1] -> truthy pass flag.
    if isinstance(validation_results, dict):
        test_rows = validation_results.get('tests', [])
    else:
        test_rows = []
    if test_rows:
        names = [str(row[0]) for row in test_rows]
        vals = [int(bool(row[1])) for row in test_rows]
        fig, ax = plt.subplots(figsize=(14, 6))
        x = np.arange(len(names))
        colors = [('tab:red', 'tab:green')[outcome] for outcome in vals]
        ax.bar(x, vals, color=colors)
        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=45, ha='right')
        ax.set_ylim(0, 1.2)
        ax.set_title("Validation Test Outcomes")
        ax.set_ylabel("pass (1) / fail (0)")
        ax.grid(True, alpha=0.3, axis='y')
        path, b64 = _save_fig(fig, "12_validation_tests")
        detailed_figures["12_validation_tests"] = {
            "title": "Validation Test Outcomes",
            "path": path,
            "base64": b64,
        }
        plt.close(fig)

    # Hardware summary chart
    hw = results.get('hardware_stats', {})
    hw_labels = []
    hw_vals = []
    for k in ["peak_gpu_memory_gb", "avg_gpu_memory_gb", "throughput_tokens_per_sec", "total_training_time_min"]:
        if k in hw:
            try:
                hw_labels.append(k)
                hw_vals.append(float(hw[k]))
            except Exception:
                pass
    if hw_labels:
        fig, ax = plt.subplots(figsize=(12, 5))
        x = np.arange(len(hw_labels))
        ax.bar(x, hw_vals, color='tab:cyan')
        ax.set_xticks(x)
        ax.set_xticklabels(hw_labels, rotation=20, ha='right')
        ax.set_title("Hardware / Runtime Summary")
        ax.grid(True, alpha=0.3, axis='y')
        path, b64 = _save_fig(fig, "13_hardware_summary")
        detailed_figures["13_hardware_summary"] = {"title": "Hardware / Runtime Summary", "path": path, "base64": b64}
        plt.close(fig)
else:
    print("matplotlib unavailable: detailed figure generation skipped in single-file dossier.")

# Carry over the pre-dossier training plot (already base64-encoded in results)
# when it is present and non-empty.
legacy_plot = results.get('figures', {}).get('training_plot', {})
_legacy_b64 = legacy_plot.get('base64') if isinstance(legacy_plot, dict) else None
if isinstance(_legacy_b64, str) and _legacy_b64:
    detailed_figures["00_legacy_training_plot"] = dict(
        title="Legacy Training Plot",
        path=legacy_plot.get("filename", "legacy_training_plot.png"),
        base64=_legacy_b64,
    )

# A single UTC timestamp is shared by the JSON payload and the HTML header so
# the two always agree. datetime.utcnow() is deprecated (Python 3.12+); this
# produces the same naive ISO-8601 string plus "Z" from an aware UTC datetime.
from datetime import timezone as _timezone
dossier_generated_utc = datetime.now(_timezone.utc).replace(tzinfo=None).isoformat() + "Z"

consolidated_payload = {
    "schema_version": "1.0",
    "run_id": run_id,
    "timestamp_utc": dossier_generated_utc,
    "phase": "B",
    "summary_metrics": metrics_payload,
    "eval_suite": eval_suite_payload,
    "v15_validation": v15_payload,
    "curve_stats": detailed_metrics,
    # Index only: the base64 payloads are embedded once, in the figure section.
    "detailed_figures_index": {
        k: {"title": v.get("title"), "path": v.get("path")}
        for k, v in detailed_figures.items()
    },
}

def _json_block(obj):
    """Serialize `obj` as indented JSON, HTML-escaped for embedding in <pre>."""
    return html_escape(json.dumps(obj, indent=2, default=str))

# Executive summary: one table row per entry of metrics_payload.
summary_rows = "".join(
    f"<tr><td>{html_escape(str(k))}</td><td>{html_escape(str(v))}</td></tr>"
    for k, v in metrics_payload.items()
)

# Curve statistics: one row per non-empty curve, columns in _stats() key order.
_curve_stat_columns = ("count", "min", "max", "mean", "std", "last")
curve_rows = "".join(
    f"<tr><td>{html_escape(str(curve_name))}</td>"
    + "".join(f"<td>{html_escape(str(curve_stats.get(col)))}</td>" for col in _curve_stat_columns)
    + "</tr>"
    for curve_name, curve_stats in detailed_metrics.items()
    if curve_stats
)

# Figures are embedded as base64 data URIs, ordered by figure name. The base64
# text is attribute-escaped so a malformed value (e.g. the legacy plot copied
# from results) cannot break out of the src attribute.
fig_blocks = "".join(
    f"<h3>{html_escape(str(meta.get('title', fig_name)))}</h3>"
    f"<p><code>{html_escape(fig_name)}</code></p>"
    f"<img src='data:image/png;base64,{html_escape(str(meta.get('base64', '')))}' style='max-width:100%;border:1px solid #ddd;padding:6px;background:#fff;'/>"
    for fig_name, meta in sorted(detailed_figures.items())
)

html_report = f"""<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8"/>
  <title>Gerhard V15 Dossier - {html_escape(run_id)}</title>
  <style>
    body {{ font-family: Arial, sans-serif; margin: 24px; line-height: 1.45; color: #111; }}
    h1, h2, h3 {{ margin-top: 24px; }}
    table {{ border-collapse: collapse; width: 100%; margin: 12px 0; }}
    th, td {{ border: 1px solid #ccc; padding: 8px; text-align: left; font-size: 13px; }}
    th {{ background: #f3f5f7; }}
    code, pre {{ background: #f5f5f5; padding: 2px 4px; }}
    pre {{ padding: 12px; overflow-x: auto; }}
    details {{ margin: 10px 0; }}
  </style>
</head>
<body>
  <h1>Gerhard V15 Single-File Dossier</h1>
  <p><b>Run ID:</b> {html_escape(run_id)}<br/>
     <b>Generated:</b> {html_escape(dossier_generated_utc)}<br/>
     <b>Phase:</b> B (v15 SpikingBrain validation)</p>

  <h2>Executive Summary Metrics</h2>
  <table>
    <tr><th>Metric</th><th>Value</th></tr>
    {summary_rows}
  </table>

  <h2>Curve Statistics</h2>
  <table>
    <tr>
      <th>Curve</th><th>Count</th><th>Min</th><th>Max</th><th>Mean</th><th>Std</th><th>Last</th>
    </tr>
    {curve_rows}
  </table>

  <h2>Detailed Figures</h2>
  {fig_blocks}

  <h2>Raw Data (Embedded)</h2>
  <details>
    <summary>Consolidated payload</summary>
    <pre>{_json_block(consolidated_payload)}</pre>
  </details>
  <details>
    <summary>results.json snapshot</summary>
    <pre>{_json_block(results)}</pre>
  </details>
</body>
</html>
"""

# Persist the HTML dossier, then report its location and size in MiB.
Path(single_file_path).write_text(html_report, encoding="utf-8")

single_file_size_mb = os.stat(single_file_path).st_size / (1024 * 1024)
print("")
print(f"Single-file dossier saved: {single_file_path} ({single_file_size_mb:.2f} MB)")
print(f"Primary one-file output: {single_file_primary_output}")

# -----------------------------------------------------------------------------
# Optional automatic registration into repo autonomous state/report system
# -----------------------------------------------------------------------------
registration_result = None
registration_error = None

# Directories that may contain scripts/register_notebook_run.py, checked in
# order. If GERHARD_REPO_ROOT is set (same env-var prefix as
# GERHARD_AUTO_INSTALL_DEPS), that checkout is tried first. When it is unset the
# search list is exactly the four defaults below.
candidate_roots = [
    Path.cwd(),
    Path.cwd().parent,
    Path('/workspace/gerhard'),
    Path('/kaggle/working/gerhard'),
]
_env_repo_root = os.environ.get('GERHARD_REPO_ROOT', '').strip()
if _env_repo_root:
    candidate_roots.insert(0, Path(_env_repo_root).expanduser())

# First candidate that actually ships the registration helper, else None.
repo_root = next(
    (root for root in candidate_roots if (root / 'scripts' / 'register_notebook_run.py').exists()),
    None,
)

if repo_root is not None:
    try:
        # Make `scripts.register_notebook_run` importable from the repo root.
        if str(repo_root) not in sys.path:
            sys.path.append(str(repo_root))
        from scripts.register_notebook_run import register_run

        registration_result = register_run(
            run_id=run_id,
            phase='B',
            source_dir=Path(run_artifact_dir),
            repo_root=repo_root,
            summary='v15 notebook pass with canonical artifact bundle and single-file dossier',
            next_action='Proceed according to gate decision (continue on green, pause on red).',
        )
        print("")
        print("Autonomous registration complete:")
        print(registration_result)
    except Exception as e:
        # Best-effort: registration problems are recorded, not raised.
        registration_error = str(e)
        print("")
        print(f"Autonomous registration skipped due to error: {registration_error}")
else:
    registration_error = (
        "register_notebook_run.py not found under candidate repo roots. "
        "Run registration manually later with this source dir."
    )
    print("")
    print("Autonomous registration helper not found in this environment.")
    print(f"Manual source dir for later registration: {run_artifact_dir}")

# Attach artifact-bundle and registration metadata to the legacy results dict.
results['autonomy_artifacts'] = dict(
    run_id=run_id,
    artifact_dir=run_artifact_dir,
    single_file_dossier=single_file_path,
    single_file_primary_output=single_file_primary_output,
    required_files=[
        'eval_suite.json',
        'metrics.json',
        'config.yaml',
        'seed.txt',
        'v15_spikingbrain.json',
        f'run_dossier_{run_id}.html',
    ],
    registration_result=registration_result,
    registration_error=registration_error,
)

# Re-persist the legacy results snapshot so it includes the metadata above.
with open(results_path, mode='w') as results_fh:
    json.dump(results, results_fh, indent=2, default=str)

def _attempt_auto_download(path: str):
    """Best-effort attempt to push `path` to the user's browser.

    On Colab this calls google.colab.files.download. Elsewhere it injects a
    JavaScript window.open() targeting Jupyter's /files/ route for the file's
    absolute path. The JavaScript is fire-and-forget: returning True means the
    request was emitted, not that the browser actually opened the file.

    Args:
        path: local filesystem path of the file to offer.

    Returns:
        True if a download/open was triggered, False if it could not be.
    """
    if IS_COLAB:
        try:
            from google.colab import files
            files.download(path)
            return True
        except Exception as e:
            print(f"colab download failed: {e}")
            return False
    # Try Jupyter front-end auto-open in non-colab notebook environments
    try:
        from urllib.parse import quote
        from IPython.display import Javascript, display
        abs_path = os.path.abspath(path).replace("\\", "/")
        # Percent-encode so spaces, '#' and '?' in the path survive as part of
        # the URL, then JSON-encode the URL so quotes or backslashes cannot
        # terminate the JavaScript string literal early.
        url = "/files/" + quote(abs_path)
        display(Javascript(f"window.open({json.dumps(url)}, '_blank');"))
        print(f"triggered browser download/open for: {url}")
        return True
    except Exception as e:
        print(f"non-colab auto-download not available: {e}")
        return False

# Offer the dossier to the browser (best-effort), then always echo its path.
print("\nAuto-download single-file dossier")
_attempt_auto_download(single_file_primary_output)
print(f"single-file dossier path: {single_file_primary_output}")
