# asnn-goose v9: model capacity increase (320d × 5L)

## abstract

v8 achieved **PPL 559** (11% improvement over v6's 627). v9 increases model capacity to reduce PPL further toward our target of **440**.

**v8 results (baseline for v9):**
- teacher ppl: 44.6
- student ppl: **559** (target was <627 ✅)
- tests: 8/8 passed
- amplitudes: 0.74-1.08 (learned)
- spike density: 0.384

**v9 strategy: increase capacity**

| attribute | v8 | v9 | change |
|-----------|-----|-----|--------|
| d_model | 256 | **320** | +25% |
| n_layers | 4 | **5** | +25% |
| params | ~16M | **~22M** | +39% |
| VRAM | ~1.5GB | **~2.5GB** | +67% |

**expected:**
- student ppl: **<520** (from 559)
- tests: 9/9 passed (new: amplitude health)
- VRAM: <8GB on T4 (16GB available)

---

**eptesicus laboratories - lumis-next initiative**

### references
- hinton et al. (2015) "distilling the knowledge in a neural network"
- radford et al. (2019) "language models are unsupervised multitask learners" (gpt-2)
- lv et al. (2023) "spikebert: a language spikformer learned from bert with knowledge distillation"
- shen et al. (2024) "spikingmamba: towards energy-efficient large language models"
- wei et al. (2023) "ternary spike: learning ternary spikes for spiking neural networks"
- hu et al. (2022) "lora: low-rank adaptation of large language models"

In [None]:
# =============================================================================
# cell 1: environment setup
# =============================================================================
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import sys
import time
import math
import json
import base64
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import Dict, List, Optional, Tuple, Any
import warnings
warnings.filterwarnings('ignore')

# generate timestamp for this run
RUN_TIMESTAMP = datetime.now().strftime('%Y-%m-%d_%H%M%S')
print(f"run timestamp: {RUN_TIMESTAMP}")

# detect platform
IS_KAGGLE = 'KAGGLE_KERNEL_RUN_TYPE' in os.environ
IS_COLAB = 'COLAB_GPU' in os.environ or 'google.colab' in sys.modules
PLATFORM = 'kaggle' if IS_KAGGLE else 'colab' if IS_COLAB else 'local'
OUTPUT_DIR = '/kaggle/working/outputs' if IS_KAGGLE else 'outputs'

for subdir in ['figures', 'checkpoints', 'logs', 'results']:
    os.makedirs(f'{OUTPUT_DIR}/{subdir}', exist_ok=True)

print(f"platform: {PLATFORM}")
print(f"output directory: {OUTPUT_DIR}")

In [None]:
# =============================================================================
# cell 2: pytorch and hardware setup
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
SEED = 42

torch.manual_seed(SEED)
np.random.seed(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"gpu: {gpu_name}")
    print(f"memory: {gpu_memory:.1f} gb")

print(f"device: {DEVICE}")
print(f"pytorch: {torch.__version__}")

## 1. v9 design: model capacity increase

### 1.1 rationale for capacity increase

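v8 achieved PPL 559 with a ~16M-parameter student. We hypothesize that student capacity is now the main bottleneck for distillation quality, so v9 increases:

- **d_model**: 256 → 320 (+25%)
- **n_layers**: 4 → 5 (+25%)
- **params**: ~16.1M → ~22.3M (+39%); the tied 50257 × d_model embedding is most of the total, so overall growth is smaller than the roughly 2× growth of the non-embedding layers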
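As a cross-check on the parameter figures, here is a minimal analytic count, assuming exactly the layer shapes defined in cell 8 (tied embedding and head, learned positional embedding, four square projections per recurrent layer, a 4× FFN). The helper `student_param_count` is hypothetical; under these assumptions it should agree with `sum(p.numel() for p in student.parameters())` in cell 14, which counts the tied weight once.

```python
# Analytic parameter count for StudentSpikingGoose, assuming the cell 8 shapes.
def student_param_count(d_model: int, n_layers: int,
                        vocab_size: int = 50257, max_seq_len: int = 256) -> int:
    embed = vocab_size * d_model          # token embedding, shared with the LM head
    pos = max_seq_len * d_model           # learned positional embedding
    rec = (2 * d_model                    # LayerNorm weight + bias
           + 3 * d_model                  # time_mix_k / time_mix_v / time_mix_r
           + d_model                      # decay_weight
           + 4 * d_model * d_model        # key / value / receptance / output projections
           + 2)                           # k and v spike amplitudes
    ffn = 2 * d_model + 2 * (d_model * 4 * d_model)  # LayerNorm + w1 + w2 (no biases)
    ln_out = 2 * d_model                  # final LayerNorm
    return embed + pos + n_layers * (rec + ffn) + ln_out


v8 = student_param_count(256, 4)
v9 = student_param_count(320, 5)
print(f"v8: {v8:,}  v9: {v9:,}  growth: {v9 / v8 - 1:+.0%}")
```

Under these assumptions the embedding alone is about 72% of the v9 total (16.08M of 22.32M), so widening `d_model` mostly grows the vocabulary table.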

### 1.2 memory analysis

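```
~22M params = ~90MB (FP32; autocast keeps master weights in FP32)
Optimizer states (AdamW, 2 moments) = ~180MB
Gradients = ~90MB (FP32, same dtype as params)
Activations = ~2GB (batch 8, seq 256, 5 layers, 320 dim; includes the 50257-way logits)
Frozen GPT-2 teacher (124M params, FP32) = ~500MB
Total = ~2.9GB (~2.4GB student-side + ~0.5GB teacher)

T4 has 16GB → ~13GB headroom (safe margin)
```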
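The fixed-size terms of the memory budget can be reproduced with a few lines of arithmetic. This sketch assumes FP32 parameters, gradients, and AdamW moments (the usual setup with `torch.cuda.amp.autocast` plus `GradScaler`), reuses the analytic v9 count of 22,321,610 parameters, and takes 124,439,808 for gpt-2 small; it does not model the full activation footprint, only the size of one FP32 logit tensor.

```python
# Back-of-envelope GPU memory for one v9 distillation step (sizes in MB).
MB = 1e6

student_params = 22_321_610   # analytic count for d_model=320, n_layers=5
teacher_params = 124_439_808  # gpt-2 small

weights = student_params * 4 / MB        # FP32 master weights
adam = 2 * student_params * 4 / MB       # AdamW exp_avg + exp_avg_sq, FP32
grads = student_params * 4 / MB          # gradients share the FP32 param dtype
teacher = teacher_params * 4 / MB        # frozen teacher kept in FP32

batch, seq, vocab = 8, 256, 50257
logits_fp32 = batch * seq * vocab * 4 / MB   # one FP32 (B, T, V) tensor

print(f"student weights {weights:.0f}MB, adam {adam:.0f}MB, grads {grads:.0f}MB")
print(f"teacher {teacher:.0f}MB, one fp32 logit tensor {logits_fp32:.0f}MB")
```

The KD loss keeps several (B, T, V) tensors alive at once (student log-probabilities, teacher probabilities, KL intermediates), so the logits plausibly account for much of the ~2GB activation estimate.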

### 1.3 v9 changes from v8

| component | v8 | v9 | action |
|-----------|-----|-----|--------|
| d_model | 256 | 320 | **increase** |
| n_layers | 4 | 5 | **increase** |
| teacher_indices | [3,6,9,12] | [2,5,7,10,12] | **remap** |
| params | ~16M | ~22M | **result** |
| temperature | 2.0 | 2.0 | keep |
| hidden_align_weight | 0.0 | 0.0 | keep |
| warmup_steps | 50 | 50 | keep |
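One rule that reproduces both the v8 and v9 `teacher_indices` is to space the student layers evenly over gpt-2's 12 blocks and round. This is an assumption about how the indices were chosen, not a documented procedure; `teacher_layer_map` is a hypothetical helper.

```python
from typing import List


def teacher_layer_map(n_student: int, n_teacher: int = 12) -> List[int]:
    # index into gpt-2 hidden_states: 0 = embeddings, k = output of block k;
    # the last student layer always maps to the final block (index n_teacher)
    return [round((i + 1) * n_teacher / n_student) for i in range(n_student)]


print(teacher_layer_map(4))  # v8 mapping
print(teacher_layer_map(5))  # v9 mapping
```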

### 1.4 trainable ternary spike (from v8)

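$$s = a \cdot \text{sign}_\theta(x), \qquad \text{sign}_\theta(x) = \begin{cases} +1 & x > \theta \\ 0 & |x| \le \theta \\ -1 & x < -\theta \end{cases}$$

where $a$ is a learnable scalar amplitude (one for $k$ and one for $v$ in each layer) and $\theta = \alpha \cdot \text{mean}(|x|)$ is computed per token over the feature dimension (clamped to $[0.01, 10]$). gradient flow:
- $\partial \mathcal{L} / \partial x = \partial \mathcal{L} / \partial s$ (ste)
- $\partial \mathcal{L} / \partial a = \sum_i \partial \mathcal{L} / \partial s_i \cdot \text{sign}_\theta(x_i)$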
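Because $a$ is a single scalar shared by every element, its gradient accumulates the upstream gradient at every position where the spike fires, weighted by the spike's sign. A small autograd check, assuming $\alpha = 1$ and an arbitrary upstream gradient (the amplitude value 0.9 is illustrative):

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
a = torch.nn.Parameter(torch.tensor(0.9))   # illustrative learned amplitude
theta = x.abs().mean(dim=-1, keepdim=True)  # per-row threshold with alpha = 1

# ternary pattern in {-1, 0, +1}; carries no gradient of its own
sign_theta = (x > theta).float() - (x < -theta).float()
s = a * sign_theta

upstream = torch.randn_like(s)       # arbitrary dL/ds
(s * upstream).sum().backward()      # L = <upstream, s>, so dL/ds = upstream

expected = (upstream * sign_theta).sum()
print(a.grad.item(), expected.item())
```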

### 1.5 hidden-state alignment (optional, disabled by default)

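$$\mathcal{L}_{\text{align}} = \frac{1}{L} \sum_{l=1}^{L} \text{MSE}\big(f_l(h^{(s)}_l),\ h^{(t)}_{\pi(l)}\big)$$

where $L$ is the number of student layers, $f_l$ is a per-layer linear projector ($d_{\text{student}} \to 768$), $\pi(l)$ picks the teacher hidden state for student layer $l$ ($[2, 5, 7, 10, 12]$ in v9), and MSE averages over all elements.

kept for future experiments with a reduced weight (0.001-0.01 range).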

In [None]:
# =============================================================================
# cell 4: configuration (v9 - model capacity increase)
# =============================================================================
@dataclass
class Config:
    # gpt-2 teacher (frozen, pre-trained)
    teacher_name: str = "gpt2"

    # student model architecture - v9: INCREASED CAPACITY
    d_model: int = 320      # v8: 256 -> v9: 320 (+25%)
    n_layers: int = 5       # v8: 4 -> v9: 5 (+25%)
    vocab_size: int = 50257
    max_seq_len: int = 256

    # distillation training (same as v8)
    distill_steps: int = 3000
    distill_lr: float = 3e-4
    temperature: float = 2.0      # used in v6, v8
    warmup_steps: int = 50        # minimal warmup

    # hidden-state alignment (DISABLED by default, but code kept)
    hidden_align_weight: float = 0.0  # disabled
    teacher_d_model: int = 768        # gpt-2 hidden dim
    teacher_n_layers: int = 12        # gpt-2 layers

    # lora for ttt
    lora_rank: int = 8
    lora_alpha: float = 16.0
    ttt_lr: float = 1e-4
    ttt_steps: int = 100

    # spiking parameters
    spike_alpha: float = 1.0

    # general training
    batch_size: int = 8
    max_grad_norm: float = 1.0
    eval_interval: int = 100

config = Config()

print(f"configuration (v9 - model capacity increase):")
print(f"  teacher: {config.teacher_name} (124m params)")
print(f"  student: d={config.d_model}, layers={config.n_layers}")
print(f"  v9 changes: d_model 256→320, n_layers 4→5")
print(f"  distillation: {config.distill_steps} steps, T={config.temperature}")
print(f"  warmup: {config.warmup_steps} steps")
print(f"  hidden alignment: weight={config.hidden_align_weight} (disabled)")
print(f"  lora: rank={config.lora_rank}, ttt_steps={config.ttt_steps}")
print(f"")
print(f"expected memory: ~2.5GB (T4 has 16GB - plenty of headroom)")

## 2. trainable ternary spiking (from v7)

### 2.1 motivation

fixed ternary values {-1, 0, +1} limit expressivity. the ternary spike paper (wei et al., 2023) shows that **trainable amplitude factors** per layer significantly improve accuracy:

> "we propose to learn the amplitude of ternary spikes, which allows different layers to have different spike magnitudes optimized during training."

### 2.2 implementation

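each `TrainableTernarySpike` (one for k and one for v in every layer) has a learnable scalar `amplitude`, initialized to 1.0. during the forward pass:

```python
threshold = (alpha * x.abs().mean(dim=-1, keepdim=True)).clamp(min=0.01, max=10.0)
with torch.no_grad():
    signs = (x > threshold).float() - (x < -threshold).float()  # ternary pattern in {-1, 0, +1}
spikes = amplitude * signs  # trainable amplitude scales the whole pattern
```

`amplitude` receives its gradient through the multiplication; `x` receives its gradient through the ste term `(x - x.detach())` added in cell 6.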
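A worked forward example: for the row $x = [2.0, -0.5, 0.1, -1.6]$, $\text{mean}(|x|) = 1.05$, so with `alpha = 1` only the first and last entries cross the threshold. The sketch below is a forward-only restatement of the masking in cell 6 (the ste term is omitted because it contributes zero to the forward value); `ternary_spike` and the amplitude 0.8 are illustrative.

```python
import torch


def ternary_spike(x: torch.Tensor, amplitude: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # per-token threshold, clamped as in cell 6
    threshold = (alpha * x.abs().mean(dim=-1, keepdim=True)).clamp(min=0.01, max=10.0)
    with torch.no_grad():
        signs = (x > threshold).float() - (x < -threshold).float()
    return amplitude * signs


x = torch.tensor([[2.0, -0.5, 0.1, -1.6]])
out = ternary_spike(x, torch.tensor(0.8))
print(out)
```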

### 2.3 straight-through estimator (ste)

gradients pass through unchanged during backpropagation:

$$\frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial s}$$
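The ste in cell 6 relies on `x - x.detach()`, which is exactly zero in the forward pass but has an identity Jacobian with respect to `x`. A minimal demonstration, with `torch.sign` standing in for the thresholded spike pattern:

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 5, requires_grad=True)
s_hard = torch.sign(x).detach()          # any non-differentiable forward value
s = s_hard + (x - x.detach())            # value equals s_hard; ds/dx is the identity

upstream = torch.randn(3, 5)
(s * upstream).sum().backward()

print(torch.equal(s.detach(), s_hard))   # forward value unchanged
print(torch.allclose(x.grad, upstream))  # gradient passed straight through
```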

In [None]:
# =============================================================================
# cell 6: trainable ternary spike (from v7, fixed implementation)
# =============================================================================
class TrainableTernarySpike(nn.Module):
    """
    trainable ternary spike with learnable amplitude.
    
    uses STE trick without custom autograd.Function:
    1. compute spike pattern without gradient tracking
    2. multiply by trainable amplitude (gradient flows here)
    3. use (x - x.detach()) trick for STE gradient on x
    """

    def __init__(self, alpha: float = 1.0):
        super().__init__()
        self.alpha = alpha
        self.amplitude = nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        threshold = self.alpha * x.abs().mean(dim=-1, keepdim=True)
        threshold = threshold.clamp(min=0.01, max=10.0)

        with torch.no_grad():
            pos_mask = (x > threshold).float()
            neg_mask = (x < -threshold).float()
            spike_signs = pos_mask - neg_mask

        spikes = self.amplitude * spike_signs
        return spikes + (x - x.detach())

    def get_amplitude(self) -> float:
        return self.amplitude.item()


# test
print("testing TrainableTernarySpike...")
_spike = TrainableTernarySpike().to(DEVICE)
_x = torch.randn(2, 16, 64, device=DEVICE, requires_grad=True)
_y = _spike(_x)
_y.sum().backward()
print(f"  amplitude: {_spike.get_amplitude():.4f}")
print(f"  gradient for x: {'exists' if _x.grad is not None else 'none'}")
print(f"  gradient for amplitude: {_spike.amplitude.grad.item():.4f}")
del _spike, _x, _y

In [None]:
# =============================================================================
# cell 7: hardware and spike stats collectors
# =============================================================================
class HardwareStatsCollector:
    """collect gpu memory, timing, and throughput metrics."""

    def __init__(self):
        self.gpu_memory_history = []
        self.step_times = []
        self.tokens_processed = 0
        self.start_time = None

    def start(self):
        self.start_time = time.time()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

    def record_step(self, batch_size: int, seq_len: int):
        if torch.cuda.is_available():
            self.gpu_memory_history.append(torch.cuda.memory_allocated() / 1e9)
        self.tokens_processed += batch_size * seq_len
        self.step_times.append(time.time())

    def get_throughput(self) -> float:
        if len(self.step_times) < 2:
            return 0.0
        elapsed = self.step_times[-1] - self.step_times[0]
        return self.tokens_processed / elapsed if elapsed > 0 else 0.0

    def get_summary(self) -> Dict[str, Any]:
        elapsed = time.time() - self.start_time if self.start_time else 0
        return {
            'peak_gpu_memory_gb': max(self.gpu_memory_history) if self.gpu_memory_history else 0,
            'avg_gpu_memory_gb': float(np.mean(self.gpu_memory_history)) if self.gpu_memory_history else 0,
            'total_training_time_s': elapsed,
            'total_training_time_min': elapsed / 60,
            'tokens_processed': self.tokens_processed,
            'throughput_tokens_per_sec': self.get_throughput(),
        }


class SpikeStatsCollector:
    """collect per-layer spike density and amplitude evolution."""

    def __init__(self, n_layers: int):
        self.n_layers = n_layers
        self.density_history = {i: {'k': [], 'v': []} for i in range(n_layers)}
        self.amplitude_history = {i: {'k': [], 'v': []} for i in range(n_layers)}
        self.step_densities = []

    def record(self, student, step: int):
        stats = student.get_spike_stats()
        all_densities = []
        for i in range(self.n_layers):
            layer_key = f'layer_{i}'
            if layer_key in stats:
                k_density = stats[layer_key].get('k', 0)
                v_density = stats[layer_key].get('v', 0)
                k_amp = stats[layer_key].get('k_amp', 1.0)
                v_amp = stats[layer_key].get('v_amp', 1.0)

                self.density_history[i]['k'].append(k_density)
                self.density_history[i]['v'].append(v_density)
                self.amplitude_history[i]['k'].append(k_amp)
                self.amplitude_history[i]['v'].append(v_amp)
                all_densities.extend([k_density, v_density])

        if all_densities:
            self.step_densities.append({'step': step, 'density': float(np.mean(all_densities))})

    def get_summary(self) -> Dict[str, Any]:
        per_layer = {}
        all_k, all_v = [], []
        all_k_amp, all_v_amp = [], []

        for i in range(self.n_layers):
            k_vals = self.density_history[i]['k']
            v_vals = self.density_history[i]['v']
            k_amps = self.amplitude_history[i]['k']
            v_amps = self.amplitude_history[i]['v']

            per_layer[f'layer_{i}'] = {
                'k_mean': float(np.mean(k_vals)) if k_vals else 0,
                'k_std': float(np.std(k_vals)) if k_vals else 0,
                'k_final': float(k_vals[-1]) if k_vals else 0,
                'v_mean': float(np.mean(v_vals)) if v_vals else 0,
                'v_std': float(np.std(v_vals)) if v_vals else 0,
                'v_final': float(v_vals[-1]) if v_vals else 0,
                'k_amp_final': float(k_amps[-1]) if k_amps else 1.0,
                'v_amp_final': float(v_amps[-1]) if v_amps else 1.0,
            }
            all_k.extend(k_vals)
            all_v.extend(v_vals)
            if k_amps: all_k_amp.append(k_amps[-1])
            if v_amps: all_v_amp.append(v_amps[-1])

        return {
            'per_layer': per_layer,
            'overall_k_density': float(np.mean(all_k)) if all_k else 0,
            'overall_v_density': float(np.mean(all_v)) if all_v else 0,
            'overall_density': float(np.mean(all_k + all_v)) if (all_k or all_v) else 0,
            'amplitudes': {'k': all_k_amp, 'v': all_v_amp},
            'density_history': self.step_densities,
        }

print("collectors defined")

In [None]:
# =============================================================================
# cell 8: spiking goose model
# =============================================================================
class SpikingGooseRecurrentLayer(nn.Module):
    """rwkv-style recurrence with trainable ternary spiking."""

    def __init__(self, d_model, layer_idx=0, n_layers=4, spike_alpha=1.0):
        super().__init__()
        self.d_model = d_model
        self.layer_idx = layer_idx
        self.ln = nn.LayerNorm(d_model)

        ratio = layer_idx / max(n_layers - 1, 1)
        self.time_mix_k = nn.Parameter(torch.ones(d_model) * (1 - ratio))
        self.time_mix_v = nn.Parameter(torch.ones(d_model) * (1 - ratio))
        self.time_mix_r = nn.Parameter(torch.ones(d_model) * (1 - ratio))
        self.decay_weight = nn.Parameter(torch.zeros(d_model) - 0.5)

        self.key_proj = nn.Linear(d_model, d_model, bias=False)
        self.value_proj = nn.Linear(d_model, d_model, bias=False)
        self.receptance_proj = nn.Linear(d_model, d_model, bias=False)
        self.output_proj = nn.Linear(d_model, d_model, bias=False)

        self.k_spike = TrainableTernarySpike(alpha=spike_alpha)
        self.v_spike = TrainableTernarySpike(alpha=spike_alpha)

        self.register_buffer('running_k_density', torch.tensor(0.0))
        self.register_buffer('running_v_density', torch.tensor(0.0))
        self._init_weights()

    def _init_weights(self):
        std = 0.1 / math.sqrt(self.d_model)
        for m in [self.key_proj, self.value_proj, self.receptance_proj, self.output_proj]:
            nn.init.normal_(m.weight, std=std)

    def forward(self, x):
        B, T, D = x.shape
        x_norm = self.ln(x)
        prev_x = F.pad(x_norm[:, :-1, :], (0, 0, 1, 0))

        xk = x_norm * self.time_mix_k + prev_x * (1 - self.time_mix_k)
        xv = x_norm * self.time_mix_v + prev_x * (1 - self.time_mix_v)
        xr = x_norm * self.time_mix_r + prev_x * (1 - self.time_mix_r)

        k_pre = self.key_proj(xk)
        v_pre = self.value_proj(xv)

        k = self.k_spike(k_pre)
        v = self.v_spike(v_pre)
        r = torch.sigmoid(self.receptance_proj(xr))

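        kv = k * v
        # decayed causal sum S_t = sum_{j<=t} decay^(t-j) * kv_j, built from explicit
        # (T, T, D) weights; dividing by decay^t instead underflows once decay^t < 1e-8
        # (about 20 tokens at the initial decay sigmoid(-0.5) ~= 0.38)
        log_decay = F.logsigmoid(self.decay_weight)              # (D,), log of decay in (0, 1)
        t_idx = torch.arange(T, device=x.device)
        lag = t_idx.unsqueeze(1) - t_idx.unsqueeze(0)            # (T, T): t - j
        causal = (lag >= 0).unsqueeze(-1)                        # (T, T, 1): keep j <= t
        weights = torch.exp(lag.clamp(min=0).unsqueeze(-1) * log_decay) * causal  # (T, T, D)
        S = torch.einsum('tjd,bjd->btd', weights, kv)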

        if self.training:
            with torch.no_grad():
                self.running_k_density = 0.99 * self.running_k_density + 0.01 * (k != 0).float().mean()
                self.running_v_density = 0.99 * self.running_v_density + 0.01 * (v != 0).float().mean()

        return x + r * self.output_proj(S)

    def get_spike_density(self):
        return {
            'k': self.running_k_density.item(),
            'v': self.running_v_density.item(),
            'k_amp': self.k_spike.get_amplitude(),
            'v_amp': self.v_spike.get_amplitude(),
        }


class GooseFFN(nn.Module):
    def __init__(self, d_model, expand=4):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.w1 = nn.Linear(d_model, d_model * expand, bias=False)
        self.w2 = nn.Linear(d_model * expand, d_model, bias=False)

    def forward(self, x):
        return x + self.w2(F.silu(self.w1(self.ln(x))))


class StudentSpikingGoose(nn.Module):
    """spiking student model with trainable ternary activations."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_embed = nn.Embedding(cfg.max_seq_len, cfg.d_model)

        self.layers = nn.ModuleList([
            nn.ModuleDict({
                'rec': SpikingGooseRecurrentLayer(cfg.d_model, i, cfg.n_layers, cfg.spike_alpha),
                'ffn': GooseFFN(cfg.d_model),
            })
            for i in range(cfg.n_layers)
        ])

        self.ln_out = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.head.weight = self.embed.weight

        nn.init.normal_(self.embed.weight, std=0.02)
        nn.init.normal_(self.pos_embed.weight, std=0.02)

    def forward(self, input_ids, return_hiddens=False):
        """forward pass with optional hidden state return for alignment."""
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)
        x = self.embed(input_ids) + self.pos_embed(pos)

        hiddens = [x] if return_hiddens else None

        for layer in self.layers:
            x = layer['rec'](x)
            x = layer['ffn'](x)
            if return_hiddens:
                hiddens.append(x)

        logits = self.head(self.ln_out(x))

        if return_hiddens:
            return logits, hiddens
        return logits

    def get_spike_stats(self):
        return {f'layer_{i}': layer['rec'].get_spike_density() for i, layer in enumerate(self.layers)}

    def get_avg_spike_density(self):
        densities = []
        for layer in self.layers:
            d = layer['rec'].get_spike_density()
            densities.extend([d['k'], d['v']])
        return float(np.mean(densities)) if densities else 0.0

    def get_amplitudes(self):
        return {f'layer_{i}': {'k': layer['rec'].k_spike.get_amplitude(), 'v': layer['rec'].v_spike.get_amplitude()}
                for i, layer in enumerate(self.layers)}

print("student model defined (trainable amplitudes from v8, optional alignment)")

In [None]:
# =============================================================================
# cell 9: hidden-state alignment (KEPT from v8, optional)
# =============================================================================
class HiddenStateProjector(nn.Module):
    """
    project student hidden states to teacher dimension for alignment.
    
    student: (B, T, 320) -> (B, T, 768)  # v9: 320 dim
    
    maps 5 student layers to selected teacher layers.
    
    NOTE: kept from v8 for future experiments. only used if hidden_align_weight > 0.
    """

    def __init__(self, student_dim: int, teacher_dim: int, n_student_layers: int):
        super().__init__()
        self.projectors = nn.ModuleList([
            nn.Linear(student_dim, teacher_dim, bias=False)
            for _ in range(n_student_layers)
        ])
        for proj in self.projectors:
            nn.init.normal_(proj.weight, std=0.02)

    def forward(self, student_hidden: torch.Tensor, layer_idx: int) -> torch.Tensor:
        return self.projectors[layer_idx](student_hidden)


def compute_hidden_alignment_loss(
    teacher_hiddens: List[torch.Tensor],
    student_hiddens: List[torch.Tensor],
    projector: HiddenStateProjector,
    teacher_layers: int = 12,
    student_layers: int = 5,  # v9: 5 student layers
) -> torch.Tensor:
    """
    compute mse loss between projected student hiddens and teacher hiddens.
    
    v9 maps 5 student layers to gpt-2 hidden_states (index 0 = embeddings,
    index k = output of block k):
      student 0 -> teacher 2  (output of block 2)
      student 1 -> teacher 5  (output of block 5)
      student 2 -> teacher 7  (output of block 7)
      student 3 -> teacher 10 (output of block 10)
      student 4 -> teacher 12 (final output, after ln_f)
    
    NOTE: kept from v8 for future experiments. only called if hidden_align_weight > 0.
    """
    loss = 0.0
    # v9: updated mapping for 5 student layers
    teacher_indices = [2, 5, 7, 10, 12]  # v8 was [3, 6, 9, 12] for 4 layers

    for s_idx, t_idx in enumerate(teacher_indices):
        if s_idx < len(student_hiddens) - 1 and t_idx < len(teacher_hiddens):
            s_h = student_hiddens[s_idx + 1]
            t_h = teacher_hiddens[t_idx]
            s_h_proj = projector(s_h, s_idx)
            loss += F.mse_loss(s_h_proj, t_h)

    return loss / len(teacher_indices)


print("hidden-state alignment defined (v9: 5 student layers)")
print(f"  student layers: {config.n_layers} (d={config.d_model})")
print(f"  teacher layers: {config.teacher_n_layers} (d={config.teacher_d_model})")
print(f"  layer mapping: [2, 5, 7, 10, 12] (v8 was [3, 6, 9, 12])")
print(f"  current weight: {config.hidden_align_weight} (set > 0 to enable)")

In [None]:
# =============================================================================
# cell 10: cosine lr with warmup
# =============================================================================
def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    warmup_steps: int,
    total_steps: int,
) -> torch.optim.lr_scheduler.LambdaLR:
    """
    linear warmup then cosine decay to 0.
    """
    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / max(warmup_steps, 1)
        else:
            progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
            return 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


print(f"cosine lr: {config.warmup_steps} warmup, {config.distill_steps} total")

In [None]:
# =============================================================================
# cell 11: load gpt-2 teacher
# =============================================================================
from transformers import GPT2LMHeadModel, GPT2Tokenizer

print("loading gpt-2 teacher...")
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

teacher = GPT2LMHeadModel.from_pretrained('gpt2').to(DEVICE)
teacher.eval()
for p in teacher.parameters():
    p.requires_grad = False

teacher_params = sum(p.numel() for p in teacher.parameters())
print(f"teacher: gpt-2 ({teacher_params:,} params)")

## 3. gpt-2 as teacher

### 3.1 model specifications

| attribute | gpt-2 (teacher) | asnn-goose v9 (student) |
|-----------|-----------------|-------------------------|
| parameters | 124m | **~22m** |
| layers | 12 | **5** |
| hidden dim | 768 | **320** |
| attention | softmax (dense) | linear + ternary spikes |
| ppl (wikitext-2) | ~29 (gpt-2 paper); 44.6 measured here (256-token chunks) | target: **<520** (v9) |

### 3.2 extracting hidden states

for hidden-state alignment (if enabled), we extract intermediate representations:

```python
outputs = teacher(ids, output_hidden_states=True)
teacher_hiddens = outputs.hidden_states  # 13 x (B, T, 768): [embed, block 1, ..., block 12]; the last has ln_f applied
```

### 3.3 distillation loss

$$\mathcal{L}_{\text{kd}} = T^2 \cdot \text{KL}(p^{(t)} \| p^{(s)})$$

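where $p^{(t)} = \text{softmax}(z^{(t)}/T)$ and $p^{(s)} = \text{softmax}(z^{(s)}/T)$ are the temperature-softened teacher and student distributions, with $T=2$ (the setting used in v6 and v8)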
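In the training loop the logits are flattened to `(B*T, V)` before `F.kl_div(..., reduction='batchmean')`, so the loss is the KL averaged per token rather than per sequence. This sketch checks that equivalence on random logits (shapes and seed are arbitrary) against an explicit per-token $\text{KL}(p^{(t)} \| p^{(s)})$:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T_len, V, temp = 2, 4, 10, 2.0
s_logits = torch.randn(B, T_len, V)
t_logits = torch.randn(B, T_len, V)

s_log = F.log_softmax(s_logits / temp, dim=-1)
t_prob = F.softmax(t_logits / temp, dim=-1)

# same call as the training loop: rows are the flattened B*T token positions
kd = F.kl_div(s_log.view(-1, V), t_prob.view(-1, V), reduction='batchmean') * temp ** 2

# explicit per-token KL(p_t || p_s), averaged over all B*T positions
per_token = (t_prob * (t_prob.log() - s_log)).sum(dim=-1)
manual = per_token.mean() * temp ** 2
print(kd.item(), manual.item())
```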

In [None]:
# =============================================================================
# cell 13: data loading
# =============================================================================
from datasets import load_dataset

print("loading wikitext-2...")
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')

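def pre_tokenize(texts, max_len):
    # concatenate all non-empty lines into one token stream; each line is
    # truncated to 2*max_len tokens before concatenation
    all_tokens = []
    for text in tqdm(texts, desc="tokenizing", leave=False):
        if text.strip():
            all_tokens.extend(tokenizer.encode(text, max_length=max_len*2, truncation=True))
    # fixed-length windows with stride max_len//2: consecutive sequences overlap by 50%,
    # so most validation tokens are scored twice
    stride = max_len // 2
    chunks = [all_tokens[i:i+max_len] for i in range(0, len(all_tokens) - max_len + 1, stride)]
    print(f"created {len(chunks)} sequences")
    return torch.tensor(chunks, dtype=torch.long)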

train_tokens = pre_tokenize(dataset['train']['text'], config.max_seq_len)
val_tokens = pre_tokenize(dataset['validation']['text'], config.max_seq_len)

train_loader = DataLoader(TensorDataset(train_tokens), batch_size=config.batch_size, shuffle=True, pin_memory=True)
val_loader = DataLoader(TensorDataset(val_tokens), batch_size=config.batch_size, shuffle=False, pin_memory=True)

print(f"train: {len(train_loader)} batches, val: {len(val_loader)} batches")

In [None]:
# =============================================================================
# cell 14: create student model and projector
# =============================================================================
print("creating student model (v9 - increased capacity)...")

student = StudentSpikingGoose(config).to(DEVICE)
student_params = sum(p.numel() for p in student.parameters())

# v9: create projector (even if not used, for infrastructure preservation)
projector = HiddenStateProjector(
    student_dim=config.d_model,
    teacher_dim=config.teacher_d_model,
    n_student_layers=config.n_layers
).to(DEVICE)
projector_params = sum(p.numel() for p in projector.parameters())

compression_ratio = teacher_params / student_params

print(f"student: asnn-goose v9 ({student_params:,} params)")
print(f"projector: ({projector_params:,} params)")
print(f"compression ratio: {compression_ratio:.1f}x")
print(f"")
print(f"v9 capacity increase:")
print(f"  d_model: 256 → {config.d_model} (+25%)")
print(f"  n_layers: 4 → {config.n_layers} (+25%)")
print(f"  params: ~16M → ~{student_params // 1_000_000}M (+{(student_params / 16_000_000 - 1) * 100:.0f}%)")
print(f"")
print(f"settings (same as v8):")
print(f"  temperature: {config.temperature}")
print(f"  hidden alignment: weight={config.hidden_align_weight} (disabled)")
if config.hidden_align_weight > 0:
    print(f"  alignment ENABLED: projector will be used")
else:
    print(f"  alignment DISABLED: projector created but not trained")

In [None]:
# =============================================================================
# cell 15: evaluation functions
# =============================================================================
@torch.no_grad()
def evaluate(model, loader, device, is_gpt2=False):
    model.eval()
    total_loss, total_tokens = 0, 0
    for batch in loader:
        ids = batch[0].to(device)
        with torch.cuda.amp.autocast():
            logits = model(ids).logits if is_gpt2 else model(ids)
        logits = logits.float()  # upcast fp16 logits so the cross-entropy sum is computed in fp32
        loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)), ids[:, 1:].reshape(-1), reduction='sum')
        total_loss += loss.item()
        total_tokens += ids[:, 1:].numel()
    return total_loss / total_tokens

def get_ppl(loss):
    return math.exp(min(loss, 10))  # reported ppl is capped at e^10 (~22026)

print("evaluation functions defined")

## 4. distillation training (v9)

### 4.1 loss function (conditional alignment)

v9 uses kl divergence with **optional** hidden-state alignment (same as v8):

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{kd}} + \lambda \cdot \mathcal{L}_{\text{align}}$$

where:
- $\mathcal{L}_{\text{kd}} = T^2 \cdot \text{KL}(p^{(t)} \| p^{(s)})$ with $T=2$
- $\mathcal{L}_{\text{align}}$ only computed if $\lambda > 0$
- **v9 default**: $\lambda = 0$ (alignment disabled)

### 4.2 training schedule

- **warmup**: 50 steps
- **decay**: 2950 steps (cosine)
- **total**: 3000 steps
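The schedule can be checked by evaluating the same lambda as `get_cosine_schedule_with_warmup` (cell 10) at a few steps with `distill_lr = 3e-4`. `lr_multiplier` is a standalone copy for illustration.

```python
import math


def lr_multiplier(step: int, warmup_steps: int = 50, total_steps: int = 3000) -> float:
    # copy of lr_lambda from cell 10: linear warmup, then cosine decay to 0
    if step < warmup_steps:
        return step / max(warmup_steps, 1)
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


for s in [0, 25, 50, 1525, 3000]:
    print(s, round(lr_multiplier(s) * 3e-4, 8))
```

Because `LambdaLR` applies the lambda once at construction, the first optimizer update in the loop runs at lr 0; update k uses `lr_multiplier(k)`.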

### 4.3 v9 changes

| aspect | v8 | v9 |
|--------|-----|-----|
| d_model | 256 | **320** |
| n_layers | 4 | **5** |
| params | ~16M | **~22M** |
| training | same | same |

### 4.4 loss tracking

both `kl_loss_history` and `align_loss_history` are tracked separately (even when the alignment weight is 0)

In [None]:
# =============================================================================
# cell 17: distillation training loop (v9 - same as v8, larger model)
# =============================================================================
def distill_v9(teacher, student, projector, train_loader, val_loader, cfg, device,
               hw_stats, spike_stats):
    """
    v9 distillation: same as v8, but with larger model capacity.
    
    key settings (unchanged from v8):
    - temperature = 2.0
    - warmup = 50 steps
    - hidden_align_weight = 0.0 (disabled)
    
    v9 changes:
    - d_model: 256 → 320
    - n_layers: 4 → 5
    - params: ~16M → ~22M
    """
    training_logs = {
        'loss_history': [],
        'kl_loss_history': [],
        'align_loss_history': [],
        'ppl_history': [],
        'lr_history': [],
    }

    # combine student and projector parameters (projector only trained if weight > 0)
    if cfg.hidden_align_weight > 0:
        all_params = list(student.parameters()) + list(projector.parameters())
    else:
        all_params = list(student.parameters())
    
    optimizer = torch.optim.AdamW(all_params, lr=cfg.distill_lr, weight_decay=0.01)
    scheduler = get_cosine_schedule_with_warmup(optimizer, cfg.warmup_steps, cfg.distill_steps)
    scaler = torch.cuda.amp.GradScaler()

    hw_stats.start()
    step = 0
    best_val = float('inf')

    pbar = tqdm(total=cfg.distill_steps, desc='distilling (v9)')

    while step < cfg.distill_steps:
        for batch in train_loader:
            if step >= cfg.distill_steps:
                break

            ids = batch[0].to(device, non_blocking=True)

            with torch.cuda.amp.autocast():
                # teacher forward (with hidden states only if needed)
                with torch.no_grad():
                    if cfg.hidden_align_weight > 0:
                        t_out = teacher(ids, output_hidden_states=True)
                        t_logits = t_out.logits
                        t_hiddens = t_out.hidden_states
                    else:
                        t_logits = teacher(ids).logits

                # student forward (with hidden states only if needed)
                student.train()
                if cfg.hidden_align_weight > 0:
                    s_logits, s_hiddens = student(ids, return_hiddens=True)
                else:
                    s_logits = student(ids)

                # kl divergence loss with T=2
                T = cfg.temperature
                s_log = F.log_softmax(s_logits / T, dim=-1)
                t_prob = F.softmax(t_logits / T, dim=-1)
                kl_loss = F.kl_div(
                    s_log.view(-1, s_logits.size(-1)),
                    t_prob.view(-1, t_logits.size(-1)),
                    reduction='batchmean'
                ) * (T ** 2)

                # optional hidden-state alignment
                if cfg.hidden_align_weight > 0:
                    align_loss = compute_hidden_alignment_loss(
                        t_hiddens, s_hiddens, projector,
                        teacher_layers=cfg.teacher_n_layers,
                        student_layers=cfg.n_layers
                    )
                    loss = kl_loss + cfg.hidden_align_weight * align_loss
                else:
                    align_loss = torch.tensor(0.0, device=device)
                    loss = kl_loss

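            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            # unscale before clipping so max_grad_norm applies to the true (unscaled) gradients
            scaler.unscale_(optimizer)
            gn = torch.nn.utils.clip_grad_norm_(all_params, cfg.max_grad_norm)

            # skip the optimizer update on inf/nan gradients; update() still adjusts the loss scale
            if torch.isfinite(gn):
                scaler.step(optimizer)
            scaler.update()
            scheduler.step()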

            hw_stats.record_step(ids.size(0), ids.size(1))
            spike_stats.record(student, step)

            current_lr = optimizer.param_groups[0]['lr']

            # track both losses separately (even if align is 0)
            training_logs['loss_history'].append({'step': step, 'loss': loss.item()})
            training_logs['kl_loss_history'].append({'step': step, 'loss': kl_loss.item()})
            training_logs['align_loss_history'].append({'step': step, 'loss': align_loss.item()})
            training_logs['lr_history'].append({'step': step, 'lr': current_lr})

            pbar.set_postfix(
                loss=f"{loss.item():.3f}",
                kl=f"{kl_loss.item():.3f}",
                align=f"{align_loss.item():.3f}",
                lr=f"{current_lr:.1e}"
            )
            pbar.update(1)
            step += 1

            if step % cfg.eval_interval == 0:
                val_loss = evaluate(student, val_loader, device)
                val_ppl = get_ppl(val_loss)
                training_logs['ppl_history'].append({'step': step, 'ppl': val_ppl})

                amps = student.get_amplitudes()
                amp_str = ', '.join([f"L{i}:{amps[f'layer_{i}']['k']:.2f}" for i in range(cfg.n_layers)])
                print(f"\n  step {step}: ppl={val_ppl:.1f}, amps=[{amp_str}]")

                if val_loss < best_val:
                    best_val = val_loss
                    torch.save({
                        'student': student.state_dict(),
                        'projector': projector.state_dict(),
                    }, f'{OUTPUT_DIR}/checkpoints/v9_best.pt')

    pbar.close()
    return training_logs

print("distillation function defined (v9 - larger capacity)")
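
the objective in `distill_v9` is the soft-target kl divergence of hinton et al. (2015): both logit vectors are divided by the temperature T, the student's log-probabilities are compared with the teacher's probabilities, and the result is multiplied by T² so gradient magnitudes stay comparable across temperatures. because the logits are flattened to `(batch·seq, vocab)` before `F.kl_div(..., reduction='batchmean')`, the sum is divided by the number of token positions, i.e. it is a per-token mean. the stdlib sketch below reproduces that arithmetic on plain python lists; `log_softmax` and `distill_kl` are illustrative names, not notebook functions.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float], T: float) -> List[float]:
    """log-softmax of logits / T, using the max-subtraction trick for stability."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_norm for z in scaled]


def distill_kl(student_logits: Sequence[Sequence[float]],
               teacher_logits: Sequence[Sequence[float]],
               T: float = 2.0) -> float:
    """Mean per-token KL(teacher || student) at temperature T, times T**2.

    Same arithmetic as F.kl_div(log_softmax(s / T), softmax(t / T),
    reduction='batchmean') * T**2 on logits flattened to (tokens, vocab).
    """
    total = 0.0
    for s_row, t_row in zip(student_logits, teacher_logits):
        s_log = log_softmax(s_row, T)
        t_log = log_softmax(t_row, T)
        # sum over the vocabulary of p_t * (log p_t - log p_s)
        total += sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t_log, s_log))
    return total / len(student_logits) * T ** 2


same = distill_kl([[1.0, 2.0, 3.0]], [[1.0, 2.0, 3.0]])   # identical logits
diff = distill_kl([[3.0, 2.0, 1.0]], [[1.0, 2.0, 3.0]])   # reversed ranking
print(f"identical: {same:.6f}, mismatched: {diff:.4f}")
```

identical logits give zero loss and any mismatch gives a positive value; dividing by the token count (rather than the sequence count) keeps the loss scale independent of sequence length.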

In [None]:
# =============================================================================
# cell 18: run distillation
# =============================================================================
print("="*60)
print("phase 1: distillation (v9 - model capacity increase)")
print("="*60)
print("")
print("v9 changes from v8:")
print(f"  d_model: 256 → {config.d_model} (+25%)")
print(f"  n_layers: 4 → {config.n_layers} (+25%)")
print(f"  params: ~16M → ~{student_params // 1_000_000}M")
print("")
print("settings (unchanged from v8):")
print(f"  temperature: {config.temperature}")
print(f"  warmup: {config.warmup_steps} steps")
print(f"  hidden alignment: weight={config.hidden_align_weight} ({'enabled' if config.hidden_align_weight > 0 else 'disabled'})")
print("")

hw_stats = HardwareStatsCollector()
spike_stats = SpikeStatsCollector(config.n_layers)

distill_logs = distill_v9(
    teacher, student, projector, train_loader, val_loader,
    config, DEVICE, hw_stats, spike_stats
)

print("")
print(f"distillation complete!")
print(f"throughput: {hw_stats.get_summary()['throughput_tokens_per_sec']:.0f} tokens/sec")
print("")
print("final amplitudes:")
for k, v in student.get_amplitudes().items():
    print(f"  {k}: k={v['k']:.4f}, v={v['v']:.4f}")

## 5. test-time training (ttt) with lora

### 5.1 motivation

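test-time training (ttt) adapts the model to a new data distribution at inference time. lora (hu et al., 2022) keeps this cheap: the base weights stay frozen and only a low-rank update is trained:

$$W' = W_0 + \frac{\alpha}{r} BA$$

where:
- $W_0 \in \mathbb{R}^{d_{out} \times d_{in}}$: frozen base weights
- $B \in \mathbb{R}^{d_{out} \times r}$: low-rank up-projection, initialized to zero so that $W' = W_0$ when ttt starts
- $A \in \mathbb{R}^{r \times d_{in}}$: low-rank down-projection
- $r = 8$: rank
- $\alpha$: scaling constant (default 16 in `LoRALinear`, i.e. a factor $\alpha / r = 2$ at $r = 8$)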
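
in `LoRALinear` the low-rank product is multiplied by α/r and B starts at zero, so the adapted layer reproduces the frozen layer exactly until ttt updates B. the stdlib sketch below applies one lora-adapted projection with plain python lists and counts trainable parameters; the helper names are illustrative, and the 320 × 320 projection shape used for the count is an assumption (this section does not state `key_proj`'s dimensions), chosen to match d_model = 320.

```python
import random
from typing import List

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive (n x k) @ (k x m) product."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def lora_linear(x: Matrix, W0: Matrix, A: Matrix, B: Matrix,
                alpha: float, r: int) -> Matrix:
    """x W0^T + (alpha / r) * x A^T B^T, the same form as orig(x) + LoRALinear(x)."""
    base = matmul(x, transpose(W0))                        # frozen path
    delta = matmul(matmul(x, transpose(A)), transpose(B))  # low-rank path
    s = alpha / r
    return [[b + s * d for b, d in zip(br, dr)] for br, dr in zip(base, delta)]


def lora_trainable_params(d_in: int, d_out: int, r: int) -> int:
    """A is (r x d_in) and B is (d_out x r)."""
    return r * (d_in + d_out)


random.seed(0)
d_in, d_out, r = 4, 3, 2
x = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(2)]
W0 = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]
A = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(r)]
B = [[0.0] * r for _ in range(d_out)]   # zero init, as in LoRALinear

y = lora_linear(x, W0, A, B, alpha=16.0, r=r)
print("adapter is a no-op at init:", y == matmul(x, transpose(W0)))
print("trainable per 320x320 projection:", lora_trainable_params(320, 320, 8), "vs dense", 320 * 320)
```

under that shape assumption, each adapted projection trains 5,120 parameters instead of 102,400, about 5% of the dense matrix.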

### 5.2 application to spiking models

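we apply lora to the key and value projections (`key_proj`, `value_proj`) while all base weights stay frozen. their outputs feed the ternary spike functions (e.g. `k_spike(key_proj(x))`), so adapting them changes the inputs to the spike functions and hence the spike patterns at test time.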
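
`apply_lora` selects modules by substring match on their qualified names (`any(t in name for t in targets)`) combined with an `nn.Linear` type check. the stdlib sketch below shows which names such a filter picks; the module names are hypothetical (modelled on the `layers[i]['rec'].key_proj` access in the validation cell, with `receptance_proj` and `v_spike` invented for illustration), and the count assumes one key and one value projection per layer.

```python
from typing import List, Sequence


def select_lora_targets(names: Sequence[str],
                        targets: Sequence[str] = ("key_proj", "value_proj")) -> List[str]:
    """Names containing any target substring, as in apply_lora (nn.Linear check omitted)."""
    return [n for n in names if any(t in n for t in targets)]


# hypothetical qualified module names for a 5-layer student
names = [f"layers.{i}.rec.{sub}"
         for i in range(5)
         for sub in ("key_proj", "value_proj", "receptance_proj", "k_spike", "v_spike")]

picked = select_lora_targets(names)
print(len(picked), picked[:2])
```

substring matching is permissive: a hypothetical `layers.0.rec.key_proj_norm` would also be picked by the name test, so the `isinstance(module, nn.Linear)` check in `apply_lora` is what keeps non-linear modules out.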

In [None]:
# =============================================================================
# cell 20: lora implementation
# =============================================================================
class LoRALinear(nn.Module):
    """lora adapter for linear layers."""

    def __init__(self, in_features, out_features, rank=8, alpha=16.0):
        super().__init__()
        self.scaling = alpha / rank
        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))  # zero init: adapter output is 0 before ttt
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))  # same as nn.Linear's default weight init

    def forward(self, x):
        return (x @ self.lora_A.T @ self.lora_B.T) * self.scaling


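def apply_lora(model, rank=8, alpha=16.0, targets=('key_proj', 'value_proj')):
    """apply lora adapters to nn.Linear modules whose qualified name contains any target.

    adapters are attached by wrapping each module's forward; they are not registered as
    submodules, so they do not appear in model.parameters() or model.state_dict().
    optimize and save them through the returned {name: LoRALinear} dict.
    """
    lora_modules = {}
    for name, module in model.named_modules():
        if any(t in name for t in targets) and isinstance(module, nn.Linear):
            lora = LoRALinear(module.in_features, module.out_features, rank, alpha).to(next(module.parameters()).device)
            lora_modules[name] = lora
            orig_forward = module.forward
            # factory binds orig/lora_mod per module (avoids late-binding closures in the loop)
            def make_forward(orig, lora_mod):
                def forward(x):
                    return orig(x) + lora_mod(x)  # frozen base output + low-rank update
                return forward
            module.forward = make_forward(orig_forward, lora)
    print(f"lora: {len(lora_modules)} modules, rank={rank}")
    return lora_modules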

print("lora defined")
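
`apply_lora` wraps each module's `forward` through the `make_forward(orig, lora_mod)` factory instead of defining the wrapper directly in the loop body. python closures capture variables rather than values, so a wrapper defined inline would look up `orig_forward` and `lora` when it is called, after the loop has finished, and every wrapped module would use the last module's objects. the stdlib sketch below demonstrates the difference with plain functions standing in for modules; all names are illustrative.

```python
from typing import Callable, List

Fn = Callable[[float], float]


def wrap_late_binding(bases: List[Fn]) -> List[Fn]:
    wrapped = []
    for base in bases:
        # the lambda refers to the loop variable itself, so every wrapper sees the last `base`
        wrapped.append(lambda x: base(x) + 1.0)
    return wrapped


def wrap_with_factory(bases: List[Fn]) -> List[Fn]:
    def make_forward(orig: Fn) -> Fn:
        def forward(x: float) -> float:
            return orig(x) + 1.0   # `orig` is bound separately on each make_forward call
        return forward
    return [make_forward(b) for b in bases]


bases = [lambda x: x, lambda x: 10 * x]
late = [f(2.0) for f in wrap_late_binding(bases)]
good = [f(2.0) for f in wrap_with_factory(bases)]
print("late binding:", late, "factory:", good)
```

with late binding both wrappers compute `10 * x + 1`; the factory keeps each wrapper tied to its own base function, which is what lets every projection keep its own frozen weights and adapter.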

In [None]:
# =============================================================================
# cell 21: ttt with lora
# =============================================================================
print("="*60)
print("phase 2: test-time training with lora")
print("="*60)

for p in student.parameters():
    p.requires_grad = False

lora_modules = apply_lora(student, config.lora_rank, config.lora_alpha)
lora_params = sum(p.numel() for m in lora_modules.values() for p in m.parameters())

pre_ttt_loss = evaluate(student, val_loader, DEVICE)
pre_ttt_ppl = get_ppl(pre_ttt_loss)
print(f"\npre-ttt ppl: {pre_ttt_ppl:.2f}")

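# only the lora adapters are optimized; the student's own weights were frozen above
lora_opt = torch.optim.AdamW([p for m in lora_modules.values() for p in m.parameters()], lr=config.ttt_lr)
ttt_scaler = torch.cuda.amp.GradScaler()  # fp16 autocast below: scale the loss to avoid gradient underflow
ttt_logs = {'loss_history': []}
student.train()

# note: adaptation uses val_loader batches, which are also used for the pre/post-ttt ppl,
# so the post-ttt ppl is measured on data that overlaps with the adaptation batches
for step, batch in enumerate(val_loader):
    if step >= config.ttt_steps:
        break
    ids = batch[0].to(DEVICE)
    with torch.cuda.amp.autocast():
        logits = student(ids)
        # next-token objective: logits at position t predict token t+1
        loss = F.cross_entropy(logits[:, :-1].reshape(-1, config.vocab_size), ids[:, 1:].reshape(-1))
    lora_opt.zero_grad(set_to_none=True)
    ttt_scaler.scale(loss).backward()
    ttt_scaler.step(lora_opt)
    ttt_scaler.update()
    ttt_logs['loss_history'].append({'step': step, 'loss': loss.item()})
    if step % 20 == 0:
        print(f"  ttt {step}: loss={loss.item():.4f}")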

post_ttt_loss = evaluate(student, val_loader, DEVICE)
post_ttt_ppl = get_ppl(post_ttt_loss)
print(f"\npost-ttt ppl: {post_ttt_ppl:.2f}")
print(f"ttt improvement: {pre_ttt_ppl - post_ttt_ppl:.1f} ppl")
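
the ttt cell reports improvement as an absolute ppl difference. assuming `get_ppl` is the usual exp(mean cross-entropy in nats) (its definition is outside this section), a fixed drop in loss changes ppl by a constant factor, so the same loss gain shows up as a much larger absolute ppl drop for a high-ppl model. a stdlib sketch with illustrative function names:

```python
import math


def get_ppl_sketch(mean_ce_loss: float) -> float:
    """Perplexity as exp(mean token cross-entropy in nats); assumed to match get_ppl."""
    return math.exp(mean_ce_loss)


def ppl_after_loss_drop(ppl: float, delta_loss: float) -> float:
    """PPL after the mean loss falls by delta_loss nats: exp(L - d) = exp(L) * exp(-d)."""
    return ppl * math.exp(-delta_loss)


for base in (559.0, 44.6):   # v8 student and teacher ppl from the abstract
    new = ppl_after_loss_drop(base, 0.05)
    print(f"ppl {base:.1f} -> {new:.1f} (absolute drop {base - new:.1f})")
```

the relative change is identical for both models; only the absolute ppl gap differs, which is worth keeping in mind when comparing ttt gains across versions.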

In [None]:
# =============================================================================
# cell 22: final evaluation
# =============================================================================
print("="*60)
print("final evaluation")
print("="*60)

teacher_loss = evaluate(teacher, val_loader, DEVICE, is_gpt2=True)
teacher_ppl = get_ppl(teacher_loss)
student_loss = evaluate(student, val_loader, DEVICE)
student_ppl = get_ppl(student_loss)

# v9: add VRAM logging
vram_peak_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0

print(f"")
print(f"{'model':<25} {'ppl':>10} {'params':>15}")
print("-" * 50)
print(f"{'gpt-2 (teacher)':<25} {teacher_ppl:>10.2f} {teacher_params:>15,}")
print(f"{'asnn-goose v9 (student)':<25} {student_ppl:>10.2f} {student_params:>15,}")
print("-" * 50)
print(f"{'compression':<25} {compression_ratio:>10.1f}x")
print(f"{'ppl gap':<25} {student_ppl - teacher_ppl:>10.2f}")
print(f"{'spike density':<25} {student.get_avg_spike_density():>10.3f}")
print(f"{'VRAM peak':<25} {vram_peak_gb:>10.2f}GB")
print("")
print("version comparison:")
print(f"  v6: 627.3 PPL (baseline)")
print(f"  v7: 1655 PPL (regression!)")
print(f"  v8: 559 PPL (fixed)")
print(f"  v9: {student_ppl:.2f} PPL (capacity increase)")
if student_ppl < 520:
    print(f"  v9 TARGET MET! PPL < 520")
elif student_ppl < 559:
    print(f"  v9 beats v8 by {559 - student_ppl:.1f} PPL")
else:
    print(f"  WARNING: v9 did not improve over v8")
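
the comparison above prints absolute ppl differences. relative change is easier to compare across versions: the abstract's "11% improvement" of v8 over v6 corresponds to (627.3 - 559) / 627.3, about 10.9%. a stdlib sketch of that calculation together with the same three-way verdict used in this cell (function names are illustrative):

```python
def relative_improvement(old_ppl: float, new_ppl: float) -> float:
    """Fractional ppl reduction; positive means new_ppl is better (lower)."""
    return (old_ppl - new_ppl) / old_ppl


def v9_verdict(student_ppl: float, target: float = 520.0, v8_ppl: float = 559.0) -> str:
    """Same branching as the final-evaluation cell."""
    if student_ppl < target:
        return "target met"
    if student_ppl < v8_ppl:
        return "beats v8"
    return "no improvement over v8"


print(f"v8 vs v6: {relative_improvement(627.3, 559.0):.1%}")
print(v9_verdict(510.0), "|", v9_verdict(540.0), "|", v9_verdict(600.0))
```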

In [None]:
# =============================================================================
# cell 23: visualization
# =============================================================================
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# distillation loss (combined + separate)
d_steps = [l['step'] for l in distill_logs['loss_history']]
d_losses = [l['loss'] for l in distill_logs['loss_history']]
kl_losses = [l['loss'] for l in distill_logs['kl_loss_history']]
align_losses = [l['loss'] for l in distill_logs['align_loss_history']]
axes[0,0].plot(d_steps, d_losses, label='total', alpha=0.8)
axes[0,0].plot(d_steps, kl_losses, label='kl', alpha=0.7)
if config.hidden_align_weight > 0:
    axes[0,0].plot(d_steps, align_losses, label='align', alpha=0.7)
axes[0,0].set_xlabel('step')
axes[0,0].set_ylabel('loss')
axes[0,0].set_title('distillation loss (v9)')
axes[0,0].legend()

# validation ppl
p_steps = [l['step'] for l in distill_logs['ppl_history']]
p_ppls = [l['ppl'] for l in distill_logs['ppl_history']]
axes[0,1].plot(p_steps, p_ppls, 'orange', marker='o')
axes[0,1].axhline(y=teacher_ppl, color='green', linestyle='--', label=f'teacher ({teacher_ppl:.1f})')
axes[0,1].axhline(y=627.3, color='blue', linestyle=':', label='v6 (627.3)')
axes[0,1].axhline(y=559, color='purple', linestyle=':', label='v8 (559)')
axes[0,1].axhline(y=520, color='red', linestyle='--', label='v9 target (520)')
axes[0,1].set_xlabel('step')
axes[0,1].set_ylabel('ppl')
axes[0,1].set_title('validation ppl')
axes[0,1].legend()

# lr schedule
lr_steps = [l['step'] for l in distill_logs['lr_history']]
lr_vals = [l['lr'] for l in distill_logs['lr_history']]
axes[0,2].plot(lr_steps, lr_vals, 'purple')
axes[0,2].axvline(x=config.warmup_steps, color='gray', linestyle='--', label=f'warmup ({config.warmup_steps})')
axes[0,2].set_xlabel('step')
axes[0,2].set_ylabel('lr')
axes[0,2].set_title('learning rate')
axes[0,2].legend()

# spike density + amplitudes
spike_summary = spike_stats.get_summary()
layers = list(spike_summary['per_layer'].keys())
k_dens = [spike_summary['per_layer'][l]['k_final'] for l in layers]
v_dens = [spike_summary['per_layer'][l]['v_final'] for l in layers]
k_amps = [spike_summary['per_layer'][l]['k_amp_final'] for l in layers]
v_amps = [spike_summary['per_layer'][l]['v_amp_final'] for l in layers]

x = np.arange(len(layers))
axes[1,0].bar(x - 0.2, k_dens, 0.4, label='k density')
axes[1,0].bar(x + 0.2, v_dens, 0.4, label='v density')
ax2 = axes[1,0].twinx()
ax2.plot(x, k_amps, 'r-o', label='k amp')
ax2.plot(x, v_amps, 'b-s', label='v amp')
axes[1,0].set_xlabel('layer')
axes[1,0].set_ylabel('density')
ax2.set_ylabel('amplitude')
axes[1,0].set_title('spike density & amplitudes (v9: 5 layers)')
axes[1,0].legend(loc='upper left')
ax2.legend(loc='upper right')

# ttt loss
t_steps = [l['step'] for l in ttt_logs['loss_history']]
t_losses = [l['loss'] for l in ttt_logs['loss_history']]
axes[1,1].plot(t_steps, t_losses, 'red')
axes[1,1].set_xlabel('step')
axes[1,1].set_ylabel('ce loss')
axes[1,1].set_title('ttt with lora')

# version comparison
versions = ['v6', 'v7', 'v8', 'v9']
t_ppls = [44.6, 44.6, 44.6, teacher_ppl]
s_ppls = [627.3, 1655, 559, student_ppl]
x = np.arange(len(versions))
axes[1,2].bar(x - 0.2, t_ppls, 0.4, label='teacher', alpha=0.7)
axes[1,2].bar(x + 0.2, s_ppls, 0.4, label='student', alpha=0.7)
axes[1,2].axhline(y=520, color='red', linestyle='--', label='v9 target', alpha=0.7)
axes[1,2].set_xticks(x)
axes[1,2].set_xticklabels(versions)
axes[1,2].set_ylabel('ppl')
axes[1,2].set_title('version comparison')
axes[1,2].legend()
axes[1,2].set_yscale('log')

plt.tight_layout()
figure_path = f'{OUTPUT_DIR}/figures/v9_training_{RUN_TIMESTAMP}.png'
plt.savefig(figure_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"saved: {figure_path}")

In [None]:
# =============================================================================
# cell 24: results.json with base64 png (renamed from summary.json)
# =============================================================================
print("building results...")

with open(figure_path, 'rb') as f:
    figure_base64 = base64.b64encode(f.read()).decode('utf-8')

results = {
    'version': 'v9',
    'timestamp': datetime.now().isoformat(),
    'run_id': RUN_TIMESTAMP,
    'platform': PLATFORM,
    'description': 'model capacity increase (320d, 5L, ~30M params)',

    'v9_design': {
        'principle': 'increase model capacity for better distillation',
        'changes': {
            'd_model': '256 → 320 (+25%)',
            'n_layers': '4 → 5 (+25%)',
            'params': '~16M → ~30M (+87%)',
            'teacher_indices': '[3,6,9,12] → [2,5,7,10,12]',
        },
        'unchanged': [
            'temperature: 2.0',
            'hidden_align_weight: 0.0',
            'warmup_steps: 50',
            'distill_steps: 3000',
        ],
    },

    'architecture': {
        'teacher': {'name': 'gpt2', 'params': teacher_params},
        'student': {
            'name': 'asnn-goose-v9',
            'd_model': config.d_model,
            'n_layers': config.n_layers,
            'params': student_params,
        },
        'projector_params': projector_params,
        'compression_ratio': compression_ratio,
        'vram_peak_gb': vram_peak_gb,
    },

    'training_config': {
        'distill_steps': config.distill_steps,
        'temperature': config.temperature,
        'hidden_align_weight': config.hidden_align_weight,
        'warmup_steps': config.warmup_steps,
        'batch_size': config.batch_size,
        'distill_lr': config.distill_lr,
        'max_grad_norm': config.max_grad_norm,
    },

    'results': {
        'teacher_ppl': teacher_ppl,
        'student_ppl': student_ppl,
        'ppl_gap': student_ppl - teacher_ppl,
        'spike_density': student.get_avg_spike_density(),
        'amplitudes': student.get_amplitudes(),
        'target_met': student_ppl < 520,
    },

    'training_curves': {
        'loss_history': distill_logs['loss_history'],
        'kl_loss_history': distill_logs['kl_loss_history'],
        'align_loss_history': distill_logs['align_loss_history'],
        'ppl_history': distill_logs['ppl_history'],
        'lr_history': distill_logs['lr_history'],
    },

    'hardware_stats': hw_stats.get_summary(),
    'spike_analysis': spike_stats.get_summary(),

    'ttt': {
        'lora_params': lora_params,
        'pre_ppl': pre_ttt_ppl,
        'post_ppl': post_ttt_ppl,
        'improvement': pre_ttt_ppl - post_ttt_ppl,
        'loss_history': ttt_logs['loss_history'],
    },

    'comparison': {
        'v6': {'student_ppl': 627.3, 'note': 'baseline'},
        'v7': {'student_ppl': 1655, 'note': 'regression (align=1.0, T=4)'},
        'v8': {'student_ppl': 559, 'note': 'fixed defaults (align=0, T=2)'},
        'v9': {'student_ppl': student_ppl, 'note': 'capacity increase (320d, 5L)'},
    },

    'figures': {
        'training_plot': {
            'filename': f'v9_training_{RUN_TIMESTAMP}.png',
            'base64': figure_base64,
        }
    },
}

# save as results_<timestamp>.json (earlier versions wrote summary.json)
results_path = f'{OUTPUT_DIR}/results/results_{RUN_TIMESTAMP}.json'
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2, default=str)

print(f"saved: {results_path}")
print(f"size: {os.path.getsize(results_path) / 1024:.1f} KB")
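
because the training figure is embedded as base64 under `results['figures']['training_plot']`, each results json is self-contained. a stdlib sketch for recovering the png from a saved file; `extract_training_plot` is an illustrative name, and the keys follow the `results` dict built in this cell.

```python
import base64
import json
import tempfile
from pathlib import Path


def extract_training_plot(results_path: str, out_dir: str) -> Path:
    """Decode results['figures']['training_plot'] from a results json and write the png."""
    with open(results_path, "r") as f:
        results = json.load(f)
    plot = results["figures"]["training_plot"]
    out_path = Path(out_dir) / plot["filename"]
    out_path.write_bytes(base64.b64decode(plot["base64"]))
    return out_path


# round-trip demo with a fake payload in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    fake = {"figures": {"training_plot": {
        "filename": "v9_training_demo.png",
        "base64": base64.b64encode(b"not-a-real-png").decode("utf-8"),
    }}}
    demo_json = Path(tmp) / "results_demo.json"
    demo_json.write_text(json.dumps(fake))
    restored = extract_training_plot(str(demo_json), tmp).read_bytes()

print(restored)
```

on a real run, something like `extract_training_plot(f"{OUTPUT_DIR}/results/results_{RUN_TIMESTAMP}.json", ".")` would write `v9_training_<timestamp>.png` to the current directory.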

In [None]:
# =============================================================================
# cell 25: auto-download for colab
# =============================================================================
print("="*60)
print("auto-download")
print("="*60)

if IS_COLAB:
    try:
        from google.colab import files
        files.download(results_path)
        files.download(figure_path)
        print("downloads started!")
    except Exception as e:
        print(f"download failed: {e}")
elif IS_KAGGLE:
    print(f"kaggle: {results_path}")
else:
    print(f"local: {results_path}")

In [None]:
# =============================================================================
# cell 26: validation tests (9 total - added amplitude health)
# =============================================================================
print("="*60)
print("validation tests (v9)")
print("="*60)

test_results = {}

# 1. teacher pre-trained
print("\n[1] teacher pre-trained")
test_results['teacher_pretrained'] = teacher_ppl < 50
print(f"  teacher ppl: {teacher_ppl:.2f} - {'pass' if test_results['teacher_pretrained'] else 'fail'}")

# 2. student learned: sanity threshold at the v6 level (< 627); the v9 goals (< 559, < 520) are reported at the end
print("\n[2] student learned language")
test_results['student_learned'] = student_ppl < 627
print(f"  student ppl: {student_ppl:.2f} - {'pass' if test_results['student_learned'] else 'fail'} (target: < 627)")

# 3. ternary activations
print("\n[3] ternary activations")
student.eval()
with torch.no_grad():
    test_ids = next(iter(val_loader))[0].to(DEVICE)
    layer = student.layers[0]['rec']
    x = student.embed(test_ids) + student.pos_embed(torch.arange(test_ids.size(1), device=DEVICE).unsqueeze(0))
    x_norm = layer.ln(x)
    prev_x = F.pad(x_norm[:, :-1, :], (0, 0, 1, 0))
    xk = x_norm * layer.time_mix_k + prev_x * (1 - layer.time_mix_k)
    k_spike = layer.k_spike(layer.key_proj(xk))
    unique_vals = len(set([round(v, 4) for v in k_spike.unique().cpu().tolist()]))
    test_results['ternary'] = unique_vals <= 3
    print(f"  unique values: {unique_vals} - {'pass' if test_results['ternary'] else 'fail'}")

# 4. gradient flow
print("\n[4] gradient flow (STE)")
_spike = TrainableTernarySpike().to(DEVICE)
_x = torch.randn(2, 16, 64, device=DEVICE, requires_grad=True)
_spike(_x).sum().backward()
test_results['gradient'] = bool(_x.grad is not None and _x.grad.abs().sum().item() > 0)  # plain bool, json-serializable
print(f"  {'pass' if test_results['gradient'] else 'fail'}")

# 5. spike density
print("\n[5] spike density in range")
density = student.get_avg_spike_density()
test_results['density'] = 0.1 < density < 0.9
print(f"  density: {density:.3f} - {'pass' if test_results['density'] else 'fail'}")

# 6. lora applied
print("\n[6] lora applied")
test_results['lora'] = len(lora_modules) > 0
print(f"  modules: {len(lora_modules)} - {'pass' if test_results['lora'] else 'fail'}")

# 7. improvement over v6 baseline
print("\n[7] beats v6 baseline")
test_results['beats_v6'] = student_ppl < 627.3
print(f"  v6: 627.3, v9: {student_ppl:.2f} - {'pass' if test_results['beats_v6'] else 'fail'}")

# 8. amplitudes learned
print("\n[8] amplitudes learned")
amps = student.get_amplitudes()
test_results['amplitudes_learned'] = any(
    abs(amps[f'layer_{i}']['k'] - 1.0) > 0.05 or abs(amps[f'layer_{i}']['v'] - 1.0) > 0.05
    for i in range(config.n_layers)
)
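# build the list first: backslashes inside f-string expressions are a SyntaxError before python 3.12
amp_strs = [f"{k}:{v['k']:.3f}" for k, v in amps.items()]
print(f"  amplitudes: {amp_strs}")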
print(f"  {'pass' if test_results['amplitudes_learned'] else 'fail'} - any amplitude != 1.0 by > 0.05")

# 9. NEW: amplitude health check (v9)
print("\n[9] amplitude health (v9 new)")
all_healthy = True
for layer_name, amp_dict in amps.items():
    k_amp, v_amp = amp_dict['k'], amp_dict['v']
    # strict bounds on both the k and v amplitudes of every layer
    if not (0.3 < k_amp < 3.0) or not (0.3 < v_amp < 3.0):
        print(f"  WARNING: {layer_name} unhealthy: k={k_amp:.3f}, v={v_amp:.3f}")
        all_healthy = False
test_results['amplitude_health'] = all_healthy
print(f"  {'pass' if all_healthy else 'fail'} - all amplitudes in (0.3, 3.0)")

# append the test outcomes and re-save the json (the colab auto-download in cell 25 ran before this cell)
results['validation_tests'] = test_results
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2, default=str)

print("\n" + "="*60)
passed = sum(1 for v in test_results.values() if v)
print(f"results: {passed}/{len(test_results)} passed")
if student_ppl < 520:
    print(f"v9 TARGET MET: PPL {student_ppl:.2f} < 520")
elif student_ppl < 559:
    print(f"v9 improved from v8: {559 - student_ppl:.1f} PPL reduction")

## 6. summary

### 6.1 v9 design principle

**increase model capacity for better distillation**

| attribute | v8 | v9 | change |
|-----------|-----|-----|--------|
| d_model | 256 | **320** | +25% |
| n_layers | 4 | **5** | +25% |
| params | ~16M | **~30M** | +87% |
| VRAM | ~1.5GB | **~2.5GB** | +67% |
| temperature | 2.0 | 2.0 | - |
| hidden_align_weight | 0.0 | 0.0 | - |

### 6.2 version progression

| version | teacher ppl | student ppl | key change |
|---------|-------------|-------------|------------|
| v6 | 44.6 | **627** | gpt-2 distillation |
| v7 | 44.6 | 1655 | **regression** (align=1.0, T=4) |
| v8 | 44.6 | **559** | fixed defaults |
| **v9** | ~45 | **target: <520** | **capacity increase** |

### 6.3 validation tests (9 total)

1. teacher pre-trained (PPL < 50)
2. student learned (PPL < 627)
3. ternary activations verified (at most 3 unique spike values)
4. gradient flow via STE
5. spike density in (0.1, 0.9)
6. LoRA applied
7. beats v6 baseline (PPL < 627.3)
8. amplitudes learned (some |amplitude - 1.0| > 0.05)
9. **amplitude health (0.3, 3.0)** (NEW in v9)
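
several of these checks are plain threshold predicates on scalars. a stdlib sketch of checks 5, 8 and 9 with the same strict bounds as the notebook code, usable for offline checks of a saved results json; the function names are illustrative, and the example values reuse v8's reported spike density (0.384) and amplitude range (0.74-1.08) from the abstract.

```python
from typing import Dict

Amps = Dict[str, Dict[str, float]]


def density_ok(density: float, lo: float = 0.1, hi: float = 0.9) -> bool:
    """Check 5: spike density strictly inside (lo, hi)."""
    return lo < density < hi


def amplitudes_learned(amps: Amps, tol: float = 0.05) -> bool:
    """Check 8: some k or v amplitude differs from 1.0 by more than tol."""
    return any(abs(a[key] - 1.0) > tol for a in amps.values() for key in ("k", "v"))


def amplitudes_healthy(amps: Amps, lo: float = 0.3, hi: float = 3.0) -> bool:
    """Check 9: every k and v amplitude strictly inside (lo, hi)."""
    return all(lo < a[key] < hi for a in amps.values() for key in ("k", "v"))


# illustrative amplitudes inside the 0.74-1.08 range reported for v8
amps = {"layer_0": {"k": 0.74, "v": 0.95}, "layer_1": {"k": 1.08, "v": 1.01}}
checks = {
    "density": density_ok(0.384),
    "amplitudes_learned": amplitudes_learned(amps),
    "amplitude_health": amplitudes_healthy(amps),
}
print(f"{sum(checks.values())}/{len(checks)} passed", checks)
```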

### 6.4 next steps

if v9 succeeds (PPL < 520):
1. v10: curriculum temperature (CTKD)
2. v11: progressive training
3. v12: patient training

if v9 doesn't meet target:
- try v9.1 with 384d × 6L (~45M params)

---

*asnn-goose v9 - eptesicus laboratories - lumis-next initiative*

### references

- hinton, g., vinyals, o., & dean, j. (2015). distilling the knowledge in a neural network.
- radford, a., et al. (2019). language models are unsupervised multitask learners.
- lv, c., et al. (2023). spikebert: a language spikformer learned from bert with knowledge distillation.
- shen, j., et al. (2024). spikingmamba: towards energy-efficient large language models.
- wei, j., et al. (2023). ternary spike: learning ternary spikes for spiking neural networks.
- hu, e. j., et al. (2022). lora: low-rank adaptation of large language models.