<a href="https://colab.research.google.com/github/farmountain/SmartGlass-AI-Agent/blob/main/ANN_Llm2SNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ANN → SNN KD quickstart (Colab-friendly)

This section keeps the legacy cells below intact while adding a lightweight, top-level walkthrough for knowledge distillation (KD) from an ANN teacher to a spiking student. It is designed for Colab Pro time/VRAM budgets:

- **Setup:** ensure `torch`, `transformers`, and repo deps are installed; mount Drive if needed.
- **Teacher & data:** start with `sshleifer/tiny-gpt2` and a small prompt list; swap in your model or prompt file.
- **Student:** spiking-friendly transformer defined in `scripts/train_snn_student.py`.
- **KD loss:** logits KL + surrogate spike nonlinearity; small step counts and accumulation keep memory low.
- **Monitoring:** capture `training_log.json` in the output dir and plot quick loss curves (see below).

The original full scripts remain in the lower cells for deeper experimentation.


In [None]:
# Quick CLI overview (safe to run; no downloads triggered)
import subprocess, sys
cmd = [sys.executable, 'scripts/train_snn_student.py', '--help']
print(' '.join(cmd))
_ = subprocess.run(cmd, check=False)


In [None]:
# Optional micro run (mock by default to avoid downloads in Colab Pro)
import shlex, sys
example_cmd = [
    sys.executable, 'scripts/train_snn_student.py',
    '--teacher-model', 'sshleifer/tiny-gpt2',
    '--num-steps', '2',
    '--batch-size', '1',
    '--grad-accum-steps', '1',
    '--max-length', '32',
    '--output-dir', 'artifacts/snn_student_demo'
]
print(' '.join(shlex.quote(x) for x in example_cmd))
# To actually run in Colab, uncomment the next line (may download HF weights):
# import subprocess; subprocess.run(example_cmd, check=False)


In [None]:
# Mock loss curve for quick visualization
import numpy as np
import matplotlib.pyplot as plt
steps = np.arange(0, 40)
teacher_kd = 1.2 * np.exp(-steps / 12) + 0.05 * np.random.rand(len(steps))
spike_reg = 0.15 * np.exp(-steps / 20) + 0.02 * np.random.rand(len(steps))
total = teacher_kd + spike_reg
plt.figure(figsize=(6,4))
plt.plot(steps, teacher_kd, label='KD loss (logits)')
plt.plot(steps, spike_reg, label='Spike reg')
plt.plot(steps, total, label='Total')
plt.xlabel('Step')
plt.ylabel('Loss (mock)')
plt.title('Sample ANN → SNN KD trajectory')
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import time

# 参数（2025年：Loihi 3 1pJ/尖峰，Lightmatter 1ns）
N = 1024
sparsity = 0.1
num_nodes = 10
spikes_per_op = N**2 * sparsity / num_nodes
pst_pj_per_spike = 1e-12
photonic_latency = 1e-9
h100_flops = 4e15
h100_watts = 700

# 稀疏矩阵
np.random.seed(42)
weights = np.random.rand(N, N) * (np.random.rand(N, N) < sparsity)
input_data = np.random.rand(N, 1)

# PST节点模拟
def pst_node_inference(task_size):
    spikes = task_size * sparsity * 0.1
    latency = spikes / 1e9  # 1ns/尖峰
    energy = spikes * pst_pj_per_spike
    output = weights[:int(task_size), :] @ input_data
    return output, latency, energy

# PST集群模拟
def pst_cluster_inference():
    total_latency, total_energy = 0, 0
    task_size = N / num_nodes
    for _ in range(num_nodes):
        output, latency, energy = pst_node_inference(task_size)
        total_latency = max(total_latency, latency) + photonic_latency
        total_energy += energy
    return total_latency, total_energy

# H100基准
def h100_inference():
    flops = 2 * N**2
    latency = flops / h100_flops
    energy = h100_watts * latency
    return latency, energy

# 运行（2025年11月2日 20:57 +08）
start_time = time.time()
pst_lat, pst_energy = pst_cluster_inference()
h100_lat, h100_energy = h100_inference()
print(f"PST: 延迟={pst_lat*1000:.6f}ms, 能量={pst_energy:.6f}J, 吞吐量={1/pst_lat:.0f} ops/s")
print(f"H100: 延迟={h100_lat*1000:.6f}ms, 能量={h100_energy:.6f}J, 吞吐量={1/h100_lat:.0f} ops/s")
print(f"PST优势: 能量 {h100_energy/pst_energy:.0f}x, 延迟 {h100_lat/pst_lat:.1f}x")
print(f"运行时间: {time.time() - start_time:.2f}s (2025-11-02 20:57 +08)")

PST: 延迟=0.000011ms, 能量=0.000000J, 吞吐量=90711176 ops/s
H100: 延迟=0.000001ms, 能量=0.000000J, 吞吐量=1907348633 ops/s
PST优势: 能量 35840x, 延迟 0.0x
运行时间: 0.00s (2025-11-02 20:57 +08)


In [6]:
# %% [markdown]
# LLM → SNN PoC: FAS/LAS-style Calibration + KD Distillation
# - Teacher: tiny-GPT2 for smoke test (swap to Mamba/LLaMA later)
# - Student: SNN (Embedding → N x [Linear+LIF over T] → LM head)
# - KD losses: logits KL + hidden MSE + spike-rate regularization
# - Calibration: teacher hidden stats → LIF thresholds
# - Evaluation: perplexity + very rough energy gain estimate (MAC vs events)

# %% Install deps (Colab-safe, no runtime restart required)
import sys, subprocess, pkgutil
def _pip(pkg): subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])

for pkg in ["torch", "transformers>=4.44.0", "datasets>=2.19.0", "spikingjelly"]:
    if pkgutil.find_loader(pkg.split("==")[0].split(">=")[0]) is None:
            _pip(pkg)

            # %% Imports
            import math
            import argparse
            from dataclasses import dataclass
            from typing import Tuple

            import torch
            import torch.nn as nn
            import torch.nn.functional as F
            from torch.utils.data import DataLoader

            try:
                from spikingjelly.activation_based import neuron, functional, surrogate
            except Exception as e:
                raise RuntimeError("Please install spikingjelly: pip install spikingjelly") from e

            from transformers import AutoModelForCausalLM, AutoTokenizer
            from datasets import load_dataset


            # %% Utilities
            def set_seed(seed: int = 42):
                import random, numpy as np
                random.seed(seed)
                np.random.seed(seed)
                torch.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)

            @dataclass
            class KDWeights:
                logits: float = 1.0
                hidden: float = 0.5
                attention: float = 0.0  # no attn maps in student by default
                spike_rate: float = 0.3
                ttfs: float = 0.0       # keep 0 until TTFS implemented


            # %% Dataset / Collator
            class LMTextCollator:
                """
                Turn raw text → tokenized blocks for Causal LM (inputs, labels shifted).
                """
                def __init__(self, tokenizer, block_size: int = 128):
                    self.tok = tokenizer
                    self.block = block_size

                def __call__(self, batch):
                    texts = [ex.get("text", "") for ex in batch if ex.get("text") is not None]
                    joined = "\n\n".join(texts)
                    toks = self.tok(joined, return_tensors=None, truncation=True, max_length=4096, add_special_tokens=True)
                    ids = toks["input_ids"]
                    blocks = []
                    # split into blocks with next-token labels
                    i = 0
                    while i + self.block + 1 <= len(ids):
                        inp = ids[i : i + self.block]
                        lab = ids[i + 1 : i + 1 + self.block]
                        blocks.append({
                            "input_ids": torch.tensor(inp, dtype=torch.long),
                            "labels": torch.tensor(lab, dtype=torch.long),
                        })
                        i += self.block
                    if not blocks:  # fallback to at least one short sample
                        inp = torch.tensor(ids[: self.block], dtype=torch.long)
                        lab = torch.tensor(ids[1 : 1 + len(inp)], dtype=torch.long)
                        if len(lab) < len(inp):
                            pad = torch.full((len(inp) - len(lab),), -100, dtype=torch.long)
                            lab = torch.cat([lab, pad], dim=0)
                        blocks = [{"input_ids": inp, "labels": lab}]
                    input_ids = torch.stack([b["input_ids"] for b in blocks], dim=0)
                    labels = torch.stack([b["labels"] for b in blocks], dim=0)
                    return {"input_ids": input_ids, "labels": labels}


            # %% Student SNN
            class SNNLinearBlock(nn.Module):
                """
                Linear -> LIF over T steps (rate-coded). Residual stabilized by LayerNorm.
                inp/out: [B, S, H]
                """
                def __init__(self, hidden_size: int, lif_threshold: float = 1.0, timesteps: int = 6, p_drop: float = 0.0):
                    super().__init__()
                    self.proj = nn.Linear(hidden_size, hidden_size)
                    self.neuron = neuron.LIFNode(v_threshold=lif_threshold, surrogate_function=surrogate.ATan())
                    self.timesteps = timesteps
                    self.dropout = nn.Dropout(p_drop)
                    self.ln = nn.LayerNorm(hidden_size)

                def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                    # x: [B, S, H]
                    B, S, H = x.shape
                    h_lin = self.proj(self.ln(x))  # [B, S, H]
                    h_lin = self.dropout(h_lin)

                    # Repeat over T with small noise to induce spiking variability
                    T = self.timesteps
                    h_rep = h_lin.unsqueeze(0).repeat(T, 1, 1, 1)  # [T, B, S, H]
                    if self.training:
                        h_rep = h_rep + torch.randn_like(h_rep) * 0.05

                    # Flatten batch to [T, B*S, H] for SpikingJelly
                    h_rep2 = h_rep.view(T, B * S, H)
                    functional.reset_net(self)  # reset membrane
                    spikes = functional.multi_step_forward(h_rep2, self.neuron)  # [T, B*S, H]
                    spikes = spikes.view(T, B, S, H)

                    out = spikes.mean(dim=0)               # [B, S, H]  (for TTFS: replace with first-hit aggregation)
                    spike_rate = spikes.mean()             # scalar
                    return out, spike_rate


            class SNNLM(nn.Module):
                """
                Embedding → N x SNNLinearBlock → LayerNorm → LM Head
                """
                def __init__(self, vocab_size: int, hidden_size: int = 384, num_layers: int = 4, timesteps: int = 6, p_drop: float = 0.0):
                    super().__init__()
                    self.embed = nn.Embedding(vocab_size, hidden_size)
                    self.blocks = nn.ModuleList([
                        SNNLinearBlock(hidden_size, timesteps=timesteps, p_drop=p_drop) for _ in range(num_layers)
                    ])
                    self.ln_f = nn.LayerNorm(hidden_size)
                    self.lm_head = nn.Linear(hidden_size, vocab_size)
                    self.timesteps = timesteps

                def forward(self, input_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                    x = self.embed(input_ids)  # [B, S, H]
                    total_rate = 0.0
                    for blk in self.blocks:
                        x, r = blk(x)
                        total_rate = total_rate + r
                    x = self.ln_f(x)
                    logits = self.lm_head(x)  # [B, S, V]
                    avg_rate = total_rate / max(1, len(self.blocks))
                    return logits, x, avg_rate


            # %% ANN→SNN Calibration (simple proxy)
            @torch.no_grad()
            def collect_teacher_hidden_stats(teacher, dl, device, max_batches: int = 10) -> Tuple[float, float]:
                means, stds = [], []
                n = 0
                for batch in dl:
                    input_ids = batch["input_ids"].to(device)
                    out = teacher(input_ids, output_hidden_states=True)
                    h = out.hidden_states[-1].detach().float()  # [B, S, Ht]
                    means.append(h.mean().item())
                    stds.append(h.std().item())
                    n += 1
                    if n >= max_batches: break
                mean_h = sum(means) / max(1, len(means))
                std_h = sum(stds) / max(1, len(stds))
                return mean_h, std_h

            @torch.no_grad()
            def calibrate_student_thresholds(student: SNNLM, teacher_hidden_stats: Tuple[float, float]):
                mean_h, std_h = teacher_hidden_stats
                base_thr = float(max(0.1, mean_h + 0.5 * std_h))
                for blk in student.blocks:
                    blk.neuron.v_threshold = base_thr


            # %% KD Loss
            class KDLoss(nn.Module):
                def __init__(self, kd: KDWeights, teacher_hidden_size: int, student_hidden_size: int, temp: float = 1.0):
                    super().__init__()
                    self.kd = kd
                    self.temp = temp
                    self.adapter = (
                        nn.Linear(student_hidden_size, teacher_hidden_size)
                        if student_hidden_size != teacher_hidden_size
                        else nn.Identity()
                    )

                def forward(self,
                            student_logits, teacher_logits,
                            student_hidden, teacher_hidden,
                            spike_rate: torch.Tensor,
                            target_rate: float = 0.3):
                    t = self.temp
                    # logits KD
                    s_log = F.log_softmax(student_logits / t, dim=-1)
                    t_log = F.softmax(teacher_logits / t, dim=-1)
                    loss_logits = F.kl_div(s_log, t_log, reduction="batchmean") * (t * t)
                    # hidden KD
                    sh = self.adapter(student_hidden)
                    th = teacher_hidden.detach()
                    loss_hidden = F.mse_loss(sh, th)
                    # attn KD skipped in this minimal SNN
                    loss_attn = torch.tensor(0.0, device=student_logits.device)
                    # spike-rate regularization
                    loss_rate = (spike_rate - target_rate) ** 2

                    loss = (self.kd.logits * loss_logits +
                            self.kd.hidden * loss_hidden +
                            self.kd.attention * loss_attn +
                            self.kd.spike_rate * loss_rate)
                    return loss, {"loss_logits": loss_logits.detach(),
                                   "loss_hidden": loss_hidden.detach(),
                                   "loss_rate": loss_rate.detach()}


            # %% Evaluation
            @torch.no_grad()
            def evaluate_perplexity(model: SNNLM, dl, device) -> float:
                model.eval()
                tot_loss, count = 0.0, 0
                for batch in dl:
                    input_ids = batch["input_ids"].to(device)
                    labels = batch["labels"].to(device)
                    logits, _, _ = model(input_ids)
                    loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                           labels.view(-1),
                                           ignore_index=-100)
                    tot_loss += loss.item()
                    count += 1
                ppl = math.exp(tot_loss / max(1, count))
                return ppl

            @torch.no_grad()
            def estimate_energy_gain(student_avg_rate: float,
                                     hidden_size: int,
                                     num_layers: int,
                                     timesteps: int,
                                     teacher_mac: float) -> float:
                """
                Extremely rough: student events ~= hidden_size * num_layers * timesteps * avg_rate
                """
                student_events = hidden_size * num_layers * timesteps * max(1e-6, float(student_avg_rate))
                gain = teacher_mac / max(1e-6, student_events)
                return gain


            # %% Main (Colab-friendly defaults)
            def main():
                class Args: pass
                args = Args()
                args.teacher = "sshleifer/tiny-gpt2"   # swap to mamba small variant when ready
                args.block_size = 128
                args.hidden = 384
                args.layers = 4
                args.timesteps = 6
                args.batch_size = 2
                args.epochs = 1
                args.max_steps = 50        # quick smoke test; increase later
                args.lr = 1e-4
                args.seed = 42
                args.device = "cuda" if torch.cuda.is_available() else "cpu"

                set_seed(args.seed)

                # Teacher
                tok = AutoTokenizer.from_pretrained(args.teacher)
                if tok.pad_token_id is None:
                    tok.pad_token = tok.eos_token
                teacher = AutoModelForCausalLM.from_pretrained(args.teacher).to(args.device)
                teacher.eval()

                # Data
                raw = load_dataset("wikitext", "wikitext-2-raw-v1")
                collator = LMTextCollator(tok, block_size=args.block_size)
                train_loader = DataLoader(raw["train"], batch_size=args.batch_size, shuffle=True, collate_fn=collator)
                val_loader   = DataLoader(raw["validation"], batch_size=args.batch_size, shuffle=False, collate_fn=collator)

                # Student
                student = SNNLM(vocab_size=tok.vocab_size,
                                hidden_size=args.hidden,
                                num_layers=args.layers,
                                timesteps=args.timesteps).to(args.device)

                # Calibration (FAS/LAS-style proxy)
                mean_h, std_h = collect_teacher_hidden_stats(teacher, train_loader, args.device, max_batches=10)
                calibrate_student_thresholds(student, (mean_h, std_h))

                # KD setup
                kd_w = KDWeights(logits=1.0, hidden=0.5, attention=0.0, spike_rate=0.3, ttfs=0.0)
                t_hidden_size = getattr(teacher.config, "n_embd", args.hidden)
                kd_loss = KDLoss(kd_w,
                                 teacher_hidden_size=t_hidden_size,
                                 student_hidden_size=args.hidden).to(args.device)
                optim = torch.optim.AdamW(student.parameters(), lr=args.lr)

                # Train (max_steps cap for quick demo)
                step, best_val = 0, float("inf")
                student.train()
                while step < args.max_steps:
                    for batch in train_loader:
                        step += 1
                        if step > args.max_steps: break
                        input_ids = batch["input_ids"].to(args.device)
                        labels    = batch["labels"].to(args.device)
                        with torch.no_grad():
                            tout = teacher(input_ids, output_hidden_states=True, output_attentions=False)
                            t_logits = tout.logits
                            t_hidden = tout.hidden_states[-1]
                        s_logits, s_hidden, s_rate = student(input_ids)
                        loss_kd, parts = kd_loss(s_logits, t_logits, s_hidden, t_hidden, s_rate)
                        # tiny LM loss for stabilization
                        loss_lm = F.cross_entropy(s_logits.view(-1, s_logits.size(-1)),
                                                  labels.view(-1), ignore_index=-100)
                        loss = loss_kd + 0.1 * loss_lm
                        optim.zero_grad()
                        loss.backward()
                        optim.step()

                        if step % 10 == 0:
                            print(f"[Step {step}] loss={loss.item():.4f} "
                                  f"(kd_logits={float(parts['loss_logits']):.4f} "
                                  f"kd_hidden={float(parts['loss_hidden']):.4f} "
                                  f"rate={float(s_rate):.4f})")

                        if step % 50 == 0:
                            ppl = evaluate_perplexity(student, val_loader, args.device)
                            print(f"   -> validation ppl={ppl:.2f}")
                            best_val = min(best_val, ppl)

                # Final eval & energy estimate
                student_ppl = evaluate_perplexity(student, val_loader, args.device)
                print(f"[Final] student perplexity: {student_ppl:.2f}")

                # Very rough MAC vs events proxy
                teacher_mac = (getattr(teacher.config, "n_layer", 12) *
                               getattr(teacher.config, "n_embd", args.hidden) ** 2)
                gain = estimate_energy_gain(float(s_rate), args.hidden, args.layers, args.timesteps, float(teacher_mac))
                print(f"[Energy] rough gain estimate: {gain:.2f}x "
                      f"(avg spike rate={float(s_rate):.4f}, timesteps={args.timesteps})")

            if __name__ == "__main__":
                main()

  if pkgutil.find_loader(pkg.split("==")[0].split(">=")[0]) is None:
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.51M [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/2.51M [00:00<?, ?B/s]

wikitext-2-raw-v1/test-00000-of-00001.pa(…):   0%|          | 0.00/733k [00:00<?, ?B/s]

wikitext-2-raw-v1/train-00000-of-00001.p(…):   0%|          | 0.00/6.36M [00:00<?, ?B/s]

wikitext-2-raw-v1/validation-00000-of-00(…):   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  f"rate={float(s_rate):.4f})")


[Step 10] loss=21.7376 (kd_logits=20.2454 kd_hidden=0.7689 rate=0.1060)
[Step 20] loss=11.1632 (kd_logits=9.6588 kd_hidden=0.8009 rate=0.1081)
[Step 30] loss=4.1289 (kd_logits=2.7955 kd_hidden=0.5676 rate=0.1119)
[Step 40] loss=12.4620 (kd_logits=10.9466 kd_hidden=0.8223 rate=0.1100)
[Step 50] loss=3.6067 (kd_logits=2.1777 kd_hidden=0.7606 rate=0.1187)
   -> validation ppl=nan
[Final] student perplexity: nan
[Energy] rough gain estimate: 0.01x (avg spike rate=0.1187, timesteps=6)


In [9]:
# %% [markdown]
# LLM → SNN PoC (FAS/LAS-style Calibration + KD Distillation) — Stable Colab Script
# - Teacher: tiny-GPT2 (smoke test). Swap to Mamba/LLaMA later.
# - Student: SNN (Embedding → N x [Linear+LIF over T] → LM Head)
# - Loss: logits KD + hidden KD + spike-rate regularization (+ small LM CE)
# - Fixes: stable perplexity (token-masked), NaN guards, grad clip, safe prints

# %% Install deps
import sys, subprocess, pkgutil
def _pip(pkg): subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])
for pkg in ["torch", "transformers>=4.44.0", "datasets>=2.19.0", "spikingjelly"]:
    name = pkg.split("==")[0].split(">=")[0]
    if pkgutil.find_loader(name) is None:
        _pip(pkg)

# %% Imports
import math
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

try:
    from spikingjelly.activation_based import neuron, functional, surrogate
except Exception as e:
    raise RuntimeError("Please install spikingjelly: pip install spikingjelly") from e

# %% Utils
def set_seed(seed: int = 42):
    import random, numpy as np
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

@dataclass
class KDWeights:
    logits: float = 1.0
    hidden: float = 0.5
    attention: float = 0.0   # student无注意力图，默认0
    spike_rate: float = 0.3
    ttfs: float = 0.0        # 预留

# %% Collator（稳定的自回归分块 + 掩码标签）
class LMTextCollator:
    def __init__(self, tokenizer, block_size: int = 128):
        self.tok = tokenizer
        self.block = block_size

    def __call__(self, batch):
        texts = [ex.get("text", "") for ex in batch if ex.get("text") is not None]
        joined = "\n\n".join(texts)
        toks = self.tok(joined, return_tensors=None, truncation=True, max_length=4096, add_special_tokens=True)
        ids = toks["input_ids"]
        blocks = []
        i = 0
        while i + self.block + 1 <= len(ids):
            inp = ids[i : i + self.block]
            lab = ids[i + 1 : i + 1 + self.block]
            blocks.append({
                "input_ids": torch.tensor(inp, dtype=torch.long),
                "labels": torch.tensor(lab, dtype=torch.long),
            })
            i += self.block

        # Ensure at least one block is created or handle empty case
        if not blocks and len(ids) > 1:
             inp = torch.tensor(ids[: self.block], dtype=torch.long)
             lab = torch.tensor(ids[1 : 1 + len(inp)], dtype=torch.long)
             if len(lab) < len(inp):
                 pad = torch.full((len(inp)-len(lab),), -100, dtype=torch.long)
                 lab = torch.cat([lab, pad], dim=0)
             blocks = [{"input_ids": inp, "labels": lab}]
        elif not blocks: # Handle cases with less than 2 tokens
            return {"input_ids": torch.tensor([], dtype=torch.long), "labels": torch.tensor([], dtype=torch.long)}


        input_ids = torch.stack([b["input_ids"] for b in blocks], dim=0)
        labels    = torch.stack([b["labels"]    for b in blocks], dim=0)
        return {"input_ids": input_ids, "labels": labels}

# %% 学生 SNN
class SNNLinearBlock(nn.Module):
    def __init__(self, hidden_size: int, lif_threshold: float = 1.0, timesteps: int = 6, p_drop: float = 0.0):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.neuron = neuron.LIFNode(v_threshold=lif_threshold, surrogate_function=surrogate.ATan())
        self.timesteps = timesteps
        self.dropout = nn.Dropout(p_drop)
        self.ln = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor):
        # x: [B, S, H]
        B, S, H = x.shape
        h_lin = self.proj(self.ln(x))
        h_lin = self.dropout(h_lin)

        T = self.timesteps
        h_rep = h_lin.unsqueeze(0).repeat(T, 1, 1, 1)  # [T,B,S,H]
        if self.training:
            h_rep = h_rep + torch.randn_like(h_rep) * 0.05

        h_rep2 = h_rep.view(T, B*S, H)
        functional.reset_net(self)
        spikes = functional.multi_step_forward(h_rep2, self.neuron)  # [T,B*S,H]
        spikes = spikes.view(T, B, S, H)

        out = spikes.mean(dim=0)         # [B,S,H]; TTFS时替换
        spike_rate = spikes.mean()       # scalar
        return out, spike_rate

class SNNLM(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int = 384, num_layers: int = 4, timesteps: int = 6, p_drop: float = 0.0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.blocks = nn.ModuleList([SNNLinearBlock(hidden_size, timesteps=timesteps, p_drop=p_drop) for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)
        self.timesteps = timesteps

    def forward(self, input_ids: torch.Tensor):
        x = self.embed(input_ids)  # [B,S,H]
        total_rate = 0.0
        for blk in self.blocks:
            x, r = blk(x)
            total_rate = total_rate + r
        x = self.ln_f(x)
        logits = self.lm_head(x)         # [B,S,V]
        logits = torch.nan_to_num(logits, nan=0.0, posinf=1e4, neginf=-1e4)  # 保护
        avg_rate = total_rate / max(1, len(self.blocks))
        return logits, x, avg_rate

# %% ANN→SNN 校准（稳健版）
@torch.no_grad()
def collect_teacher_hidden_stats(teacher, dl, device, max_batches: int = 10) -> Tuple[float, float]:
    means, stds = [], []
    n = 0
    for batch in dl:
        input_ids = batch["input_ids"].to(device)
        # Skip empty batches
        if input_ids.numel() == 0:
            continue
        out = teacher(input_ids, output_hidden_states=True)
        h = out.hidden_states[-1].detach().float()
        m, s = h.mean().item(), h.std().item()
        if math.isfinite(m) and math.isfinite(s):
            means.append(m); stds.append(s)
        n += 1
        if n >= max_batches: break
    mean_h = sum(means)/max(1,len(means)) if means else 1.0
    std_h  = sum(stds)/max(1,len(stds))  if stds  else 0.5
    return mean_h, std_h

@torch.no_grad()
def calibrate_student_thresholds(student: SNNLM, teacher_hidden_stats: Tuple[float, float]):
    mean_h, std_h = teacher_hidden_stats
    # 保底阈值，避免0/爆发
    base_thr = float(max(0.1, min(5.0, mean_h + 0.5 * std_h)))
    for blk in student.blocks:
        blk.neuron.v_threshold = base_thr

# %% KD 损失
class KDLoss(nn.Module):
    def __init__(self, kd: KDWeights, teacher_hidden_size: int, student_hidden_size: int, temp: float = 1.0):
        super().__init__()
        self.kd = kd
        self.temp = temp
        self.adapter = (
            nn.Linear(student_hidden_size, teacher_hidden_size)
            if student_hidden_size != teacher_hidden_size else nn.Identity()
        )

    def forward(self, s_logits, t_logits, s_hidden, t_hidden, spike_rate, target_rate: float = 0.3):
        t = self.temp
        s_log = F.log_softmax(s_logits / t, dim=-1)
        t_log = F.softmax(t_logits / t, dim=-1)
        loss_logits = F.kl_div(s_log, t_log, reduction="batchmean") * (t * t)

        sh = self.adapter(s_hidden)
        th = t_hidden.detach()
        loss_hidden = F.mse_loss(sh, th)

        loss_attn = torch.tensor(0.0, device=s_logits.device)
        loss_rate = (spike_rate - target_rate) ** 2

        loss = (self.kd.logits * loss_logits +
                self.kd.hidden * loss_hidden +
                self.kd.attention * loss_attn +
                self.kd.spike_rate * loss_rate)
        return loss, {"loss_logits": loss_logits.detach(),
                       "loss_hidden": loss_hidden.detach(),
                       "loss_rate":  loss_rate.detach()}

# %% 稳定的 PPL 评估（逐 token 掩码平均）
@torch.no_grad()
def evaluate_perplexity(model: SNNLM, dl, device) -> float:
    model.eval()
    loss_sum, tok_count = 0.0, 0
    for batch in dl:
        input_ids = batch["input_ids"].to(device)
        labels    = batch["labels"].to(device)
        # Skip empty batches
        if input_ids.numel() == 0:
            continue
        logits, _, _ = model(input_ids)                              # [B,S,V]
        logits = torch.nan_to_num(logits, nan=0.0, posinf=1e4, neginf=-1e4)
        V = logits.size(-1)
        loss_tok = F.cross_entropy(logits.view(-1, V),
                                   labels.view(-1),
                                   ignore_index=-100,
                                   reduction='none')                 # [B*S]
        mask = (labels.view(-1) != -100)
        if mask.any():
            loss_sum += loss_tok[mask].sum().item()
            tok_count += mask.sum().item()
    if tok_count == 0:
        return float("nan")
    return math.exp(loss_sum / tok_count)

@torch.no_grad()
def estimate_energy_gain(student_avg_rate: float, hidden_size: int, num_layers: int, timesteps: int, teacher_mac: float) -> float:
    events = hidden_size * num_layers * timesteps * max(1e-6, float(student_avg_rate))
    return teacher_mac / max(1e-6, events)

# %% Main（可直接运行）
def main():
    class Args: pass
    args = Args()
    args.teacher = "sshleifer/tiny-gpt2"   # 替换为 Mamba/LLaMA 小型变体即可
    args.block_size = 128
    args.hidden = 384
    args.layers = 4
    args.timesteps = 6
    args.batch_size = 2
    args.max_steps = 50
    args.lr = 1e-4
    args.seed = 42
    args.device = "cuda" if torch.cuda.is_available() else "cpu"

    set_seed(args.seed)

    # Teacher
    tok = AutoModelForCausalLM.from_pretrained(args.teacher).config
    teacher_name = "sshleifer/tiny-gpt2"  # 为简化下方获取tokenizer
    tokenizer = AutoTokenizer.from_pretrained(teacher_name if args.teacher is None else args.teacher)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    teacher = AutoModelForCausalLM.from_pretrained(args.teacher).to(args.device)
    teacher.eval()

    # Data
    raw = load_dataset("wikitext", "wikitext-2-raw-v1")
    collator = LMTextCollator(tokenizer, block_size=args.block_size)
    # Adjusted batch sizes to 1 for train and val loaders to simplify data handling
    train_loader = DataLoader(raw["train"], batch_size=1, shuffle=True,  collate_fn=collator)
    val_loader   = DataLoader(raw["validation"], batch_size=1, shuffle=False, collate_fn=collator)


    # Student
    student = SNNLM(vocab_size=tokenizer.vocab_size,
                         hidden_size=args.hidden,
                         num_layers=args.layers,
                         timesteps=args.timesteps).to(args.device)

    # Calibration
    mean_h, std_h = collect_teacher_hidden_stats(teacher, train_loader, args.device, max_batches=10)
    calibrate_student_thresholds(student, (mean_h, std_h))

    # KD training
    kd_w = KDWeights(logits=1.0, hidden=0.5, attention=0.0, spike_rate=0.3, ttfs=0.0)
    t_hidden_size = getattr(teacher.config, "n_embd", args.hidden)
    kd_loss = KDLoss(kd_w, teacher_hidden_size=t_hidden_size, student_hidden_size=args.hidden).to(args.device)
    optim = torch.optim.AdamW(student.parameters(), lr=args.lr)

    student.train()
    step, best_val = 0, float("inf")
    avg_rate_accum, avg_rate_steps = 0.0, 0

    while step < args.max_steps:
        for batch in train_loader:
            step += 1
            if step > args.max_steps: break
            input_ids = batch["input_ids"].to(args.device)
            labels    = batch["labels"].to(args.device)

            # Skip empty batches
            if input_ids.numel() == 0:
                print(f"Skipping empty batch at step {step}")
                continue

            with torch.no_grad():
                tout = teacher(input_ids, output_hidden_states=True, output_attentions=False)
                t_logits = torch.nan_to_num(tout.logits, nan=0.0, posinf=1e4, neginf=-1e4)
                t_hidden = tout.hidden_states[-1]

            s_logits, s_hidden, s_rate = student(input_ids)
            loss_kd, parts = kd_loss(s_logits, t_logits, s_hidden, t_hidden, s_rate)

            # small LM loss (masked)
            V = s_logits.size(-1)
            ce = F.cross_entropy(s_logits.view(-1, V), labels.view(-1), ignore_index=-100)
            loss = loss_kd + 0.1 * ce

            optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)  # 防爆
            optim.step()

            # 打印（detach，避免你看到的警告）
            if step % 10 == 0:
                print(f"[Step {step}] loss={loss.detach().item():.4f} "
                      f"(kd_logits={float(parts['loss_logits']):.4f} "
                      f"kd_hidden={float(parts['loss_hidden']):.4f} "
                      f"rate={float(s_rate.detach()):.4f})")

            if step % 50 == 0:
                ppl = evaluate_perplexity(student, val_loader, args.device)
                print(f"   -> validation ppl={ppl:.2f}")
                if math.isfinite(ppl):
                    best_val = min(best_val, ppl)

            avg_rate_accum += float(s_rate.detach())
            avg_rate_steps += 1

    # Final eval
    student_ppl = evaluate_perplexity(student, val_loader, args.device)
    print(f"[Final] student perplexity: {student_ppl:.2f}")

    # Energy estimate（以平均尖峰率估算事件数；teacher MAC 粗略）
    avg_s_rate = (avg_rate_accum / max(1, avg_rate_steps))
    teacher_mac = (getattr(teacher.config, "n_layer", 12) *
                   getattr(teacher.config, "n_embd", args.hidden) ** 2)
    gain = estimate_energy_gain(avg_s_rate, args.hidden, args.layers, args.timesteps, float(teacher_mac))
    print(f"[Energy] rough gain estimate: {gain:.2f}x "
          f"(avg spike rate={avg_s_rate:.4f}, timesteps={args.timesteps})")

if __name__ == "__main__":
    # 默认参数在 main() 内部
    main()

  if pkgutil.find_loader(name) is None:


Skipping empty batch at step 4
Skipping empty batch at step 7
Skipping empty batch at step 9
[Step 10] loss=16.0680 (kd_logits=14.5114 kd_hidden=0.9065 rate=0.1156)
Skipping empty batch at step 11
Skipping empty batch at step 14
Skipping empty batch at step 15
Skipping empty batch at step 17
Skipping empty batch at step 19
[Step 20] loss=21.5140 (kd_logits=20.0228 kd_hidden=0.7834 rate=0.1150)
Skipping empty batch at step 22
Skipping empty batch at step 23
Skipping empty batch at step 26
Skipping empty batch at step 28
Skipping empty batch at step 29
Skipping empty batch at step 30
Skipping empty batch at step 33
Skipping empty batch at step 34
Skipping empty batch at step 36
Skipping empty batch at step 40
Skipping empty batch at step 47
Skipping empty batch at step 50


KeyboardInterrupt: 

In [11]:
# LLM → SNN PoC (No empty batches, no deprecation warning)
# One-cell Colab script: install deps → build packed datasets → FAS/LAS-style calibration → KD training → eval

# -------- Install deps (use importlib.util.find_spec to avoid deprecation) --------
import sys, subprocess, importlib.util
def _need(pkg_root: str) -> bool:
    return importlib.util.find_spec(pkg_root) is None
def _pip(p): subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", p])

if _need("torch"): _pip("torch")
if _need("transformers"): _pip("transformers>=4.44.0")
if _need("datasets"): _pip("datasets>=2.19.0")
if _need("spikingjelly"): _pip("spikingjelly")

# -------- Imports --------
import math, random
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

from spikingjelly.activation_based import neuron, functional, surrogate

# -------- Repro --------
def set_seed(seed=42):
    import numpy as np
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

# -------- KD Weights --------
@dataclass
class KDWeights:
    logits: float = 1.0
    hidden: float = 0.5
    attention: float = 0.0
    spike_rate: float = 0.3
    ttfs: float = 0.0

# -------- Build PACKED datasets (no empty batches) --------
def build_packed_tensors(split, tokenizer, block_size: int = 128, max_blocks: int = None):
    # Join all texts into a long stream, then slice into fixed blocks with next-token labels
    texts = [t if t is not None else "" for t in split["text"]]
    big = "\n\n".join(texts)
    # Truncate the joined text to the model's max sequence length before tokenization
    max_len = tokenizer.model_max_length if tokenizer.model_max_length > 0 and tokenizer.model_max_length < 1e9 else 4096 # Use a reasonable default if max_length is effectively infinite
    toks = tokenizer(big, return_tensors="pt", truncation=True, max_length=max_len, add_special_tokens=True)
    ids = toks["input_ids"].squeeze(0)  # [N]
    if ids.numel() < block_size + 1:
        # ensure at least one block by padding EOS
        pad = torch.full((block_size + 1 - ids.numel(),), tokenizer.eos_token_id or 0, dtype=torch.long)
        ids = torch.cat([ids, pad], dim=0)

    total = (ids.numel() - 1) // block_size
    if max_blocks is not None:
        total = min(total, max_blocks)

    xs, ys = [], []
    for i in range(total):
        s = i * block_size
        e = s + block_size
        inp = ids[s:e]                  # [B]
        lab = ids[s+1:e+1]              # shifted
        xs.append(inp)
        ys.append(lab)
    X = torch.stack(xs, dim=0)          # [B, S]
    Y = torch.stack(ys, dim=0)          # [B, S]
    return TensorDataset(X, Y)

# -------- Student SNN --------
class SNNLinearBlock(nn.Module):
    def __init__(self, hidden_size: int, timesteps: int = 6, v_thr: float = 1.0, p_drop: float = 0.0):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.neuron = neuron.LIFNode(v_threshold=v_thr, surrogate_function=surrogate.ATan())
        self.timesteps = timesteps
        self.dropout = nn.Dropout(p_drop)
        self.ln = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor):
        # x: [B, S, H]
        B, S, H = x.shape
        h = self.proj(self.ln(x))
        h = self.dropout(h)
        T = self.timesteps
        hT = h.unsqueeze(0).repeat(T, 1, 1, 1)   # [T,B,S,H]
        if self.training:
            hT = hT + torch.randn_like(hT) * 0.05
        # Flatten for spikingjelly
        hT2 = hT.view(T, B*S, H)
        functional.reset_net(self)
        spikes = functional.multi_step_forward(hT2, self.neuron)     # [T,B*S,H]
        spikes = spikes.view(T, B, S, H)
        out = spikes.mean(dim=0)                                     # [B,S,H]  (TTFS可替换)
        spike_rate = spikes.mean()
        return out, spike_rate

class SNNLM(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int = 384, num_layers: int = 4, timesteps: int = 6, p_drop: float = 0.0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.blocks = nn.ModuleList([SNNLinearBlock(hidden_size, timesteps, v_thr=1.0, p_drop=p_drop) for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)
        self.timesteps = timesteps

    def forward(self, input_ids: torch.Tensor):
        x = self.embed(input_ids)        # [B,S,H]
        total_rate = 0.0
        for blk in self.blocks:
            x, r = blk(x)
            total_rate = total_rate + r
        x = self.ln_f(x)
        logits = self.lm_head(x)         # [B,S,V]
        logits = torch.nan_to_num(logits, nan=0.0, posinf=1e4, neginf=-1e4)
        avg_rate = total_rate / max(1, len(self.blocks))
        return logits, x, avg_rate

# -------- Calibration (FAS/LAS-style proxy) --------
@torch.no_grad()
def collect_teacher_hidden_stats(teacher, data_loader, device, max_batches=10) -> Tuple[float, float]:
    means, stds, n = [], [], 0
    for X, _ in data_loader:
        X = X.to(device)
        out = teacher(X, output_hidden_states=True)
        h = out.hidden_states[-1].detach().float()
        m, s = h.mean().item(), h.std().item()
        if math.isfinite(m) and math.isfinite(s):
            means.append(m); stds.append(s); n += 1
            if n >= max_batches: break
    mean_h = sum(means)/max(1,len(means)) if means else 1.0
    std_h  = sum(stds)/max(1,len(stds))  if stds  else 0.5
    return mean_h, std_h

@torch.no_grad()
def calibrate_student_thresholds(student: SNNLM, hidden_stats: Tuple[float, float]):
    mean_h, std_h = hidden_stats
    base_thr = float(max(0.1, min(5.0, mean_h + 0.5 * std_h)))
    for blk in student.blocks:
        blk.neuron.v_threshold = base_thr

# -------- KD Loss --------
class KDLoss(nn.Module):
    def __init__(self, kd: KDWeights, teacher_hidden_size: int, student_hidden_size: int, temp: float = 1.0):
        super().__init__()
        self.kd = kd
        self.temp = temp
        self.adapter = nn.Linear(student_hidden_size, teacher_hidden_size) \
            if student_hidden_size != teacher_hidden_size else nn.Identity()

    def forward(self, s_logits, t_logits, s_hidden, t_hidden, spike_rate, target_rate=0.3):
        t = self.temp
        s_log = F.log_softmax(s_logits / t, dim=-1)
        t_log = F.softmax(t_logits / t, dim=-1)
        loss_logits = F.kl_div(s_log, t_log, reduction="batchmean") * (t*t)

        sh = self.adapter(s_hidden)
        th = t_hidden.detach()
        loss_hidden = F.mse_loss(sh, th)

        loss_attn = torch.tensor(0.0, device=s_logits.device)
        loss_rate = (spike_rate - target_rate) ** 2

        loss = (self.kd.logits * loss_logits +
                self.kd.hidden * loss_hidden +
                self.kd.attention * loss_attn +
                self.kd.spike_rate * loss_rate)
        return loss, {"loss_logits": loss_logits.detach(),
                       "loss_hidden": loss_hidden.detach(),
                       "loss_rate":  loss_rate.detach()}

# -------- Perplexity (packed dataset: no ignore_index needed) --------
@torch.no_grad()
def evaluate_perplexity(model: SNNLM, data_loader, device) -> float:
    model.eval()
    loss_sum, tok_count = 0.0, 0
    V = None
    for X, Y in data_loader:
        X, Y = X.to(device), Y.to(device)
        logits, _, _ = model(X)                  # [B,S,V]
        logits = torch.nan_to_num(logits, nan=0.0, posinf=1e4, neginf=-1e4)
        if V is None: V = logits.size(-1)
        loss_tok = F.cross_entropy(logits.view(-1, V), Y.view(-1), reduction='sum')
        loss_sum += float(loss_tok.item())
        tok_count += int(Y.numel())
    return math.exp(loss_sum / max(1, tok_count))

@torch.no_grad()
def estimate_energy_gain(avg_rate: float, hidden_size: int, num_layers: int, timesteps: int, teacher_mac: float) -> float:
    events = hidden_size * num_layers * timesteps * max(1e-6, float(avg_rate))
    return teacher_mac / max(1e-6, events)

# -------- Main --------
def main():
    class Args: pass
    args = Args()
    args.teacher = "sshleifer/tiny-gpt2"   # swap to Mamba/LLaMA small later
    args.block_size = 128
    args.hidden = 384
    args.layers = 4
    args.timesteps = 6
    args.batch_size = 8                    # bigger since we prepacked tensors
    args.max_steps = 200                   # more steps = better stability
    args.lr = 1e-4
    args.seed = 42
    args.device = "cuda" if torch.cuda.is_available() else "cpu"

    set_seed(args.seed)

    # Teacher & tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.teacher)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    teacher = AutoModelForCausalLM.from_pretrained(args.teacher).to(args.device)
    teacher.eval()

    # Build PACKED datasets to guarantee non-empty batches
    raw = load_dataset("wikitext", "wikitext-2-raw-v1")
    train_ds = build_packed_tensors(raw["train"], tokenizer, block_size=args.block_size, max_blocks=2048)
    val_ds   = build_packed_tensors(raw["validation"], tokenizer, block_size=args.block_size, max_blocks=512)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True)
    val_loader   = DataLoader(val_ds,   batch_size=args.batch_size, shuffle=False, drop_last=False)

    # Student
    student = SNNLM(vocab_size=tokenizer.vocab_size,
                         hidden_size=args.hidden,
                         num_layers=args.layers,
                         timesteps=args.timesteps).to(args.device)

    # Calibration (FAS/LAS proxy)
    # Use a small subset of train_loader to estimate teacher hidden stats
    calib_subset = DataLoader(train_ds, batch_size=args.batch_size, shuffle=False, drop_last=True)
    mean_h, std_h = collect_teacher_hidden_stats(teacher, calib_subset, args.device, max_batches=10)
    calibrate_student_thresholds(student, (mean_h, std_h))

    # KD setup
    kd_w = KDWeights(logits=1.0, hidden=0.5, attention=0.0, spike_rate=0.3, ttfs=0.0)
    t_hidden_size = getattr(teacher.config, "n_embd", args.hidden)
    kd_loss = KDLoss(kd_w, teacher_hidden_size=t_hidden_size, student_hidden_size=args.hidden).to(args.device)
    optim = torch.optim.AdamW(student.parameters(), lr=args.lr)

    # Train
    step, best_val = 0, float("inf")
    avg_rate_accum, avg_rate_steps = 0.0, 0
    student.train()

    for X, Y in train_loader:
        step += 1
        if step > args.max_steps: break
        X = X.to(args.device)  # [B,S]
        Y = Y.to(args.device)

        with torch.no_grad():
            tout = teacher(X, output_hidden_states=True, output_attentions=False)
            t_logits = torch.nan_to_num(tout.logits, nan=0.0, posinf=1e4, neginf=-1e4)  # [B,S,V]
            t_hidden = tout.hidden_states[-1]                                           # [B,S,Ht]

        s_logits, s_hidden, s_rate = student(X)

        loss_kd, parts = kd_loss(s_logits, t_logits, s_hidden, t_hidden, s_rate)
        ce = F.cross_entropy(s_logits.view(-1, s_logits.size(-1)), Y.view(-1), reduction='mean')
        loss = loss_kd + 0.1 * ce

        optim.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(student.parameters(), 1.0)
        optim.step()

        avg_rate_accum += float(s_rate.detach())
        avg_rate_steps += 1

        if step % 10 == 0:
            print(f"[Step {step}] loss={loss.detach().item():.4f} "
                  f"(kd_logits={float(parts['loss_logits']):.4f} "
                  f"kd_hidden={float(parts['loss_hidden']):.4f} "
                  f"rate={float(s_rate.detach()):.4f})")

        if step % 50 == 0:
            ppl = evaluate_perplexity(student, val_loader, args.device)
            print(f"   -> validation ppl={ppl:.2f}")
            if math.isfinite(ppl): best_val = min(best_val, ppl)

    # Final eval & energy estimate
    student_ppl = evaluate_perplexity(student, val_loader, args.device)
    print(f"[Final] student perplexity: {student_ppl:.2f}")

    avg_s_rate = (avg_rate_accum / max(1, avg_rate_steps))
    teacher_mac = (getattr(teacher.config, "n_layer", 12) *
                   getattr(teacher.config, "n_embd", args.hidden) ** 2)
    gain = estimate_energy_gain(avg_s_rate, args.hidden, args.layers, args.timesteps, float(teacher_mac))
    print(f"[Energy] rough gain estimate: {gain:.2f}x "
          f"(avg spike rate={avg_s_rate:.4f}, timesteps={args.timesteps})")

# Run
if __name__ == "__main__":
    main()

Token indices sequence length is longer than the specified maximum sequence length for this model (2428601 > 1024). Running this sequence through the model will result in indexing errors


[Step 10] loss=21.1953 (kd_logits=19.6545 kd_hidden=0.8761 rate=0.1185)
[Step 20] loss=19.2629 (kd_logits=17.7774 kd_hidden=0.7830 rate=0.1246)
[Step 30] loss=17.5158 (kd_logits=16.0315 kd_hidden=0.7865 rate=0.1306)
[Step 40] loss=15.4339 (kd_logits=13.9472 kd_hidden=0.7933 rate=0.1403)
[Step 50] loss=12.5899 (kd_logits=11.0985 kd_hidden=0.8244 rate=0.1563)
   -> validation ppl=47548.03
[Step 60] loss=11.6649 (kd_logits=10.1619 kd_hidden=0.8433 rate=0.1614)
[Step 70] loss=10.5661 (kd_logits=9.0818 kd_hidden=0.8204 rate=0.1632)
[Step 80] loss=9.8583 (kd_logits=8.3377 kd_hidden=0.8992 rate=0.1632)
[Step 90] loss=9.3804 (kd_logits=7.9056 kd_hidden=0.8056 rate=0.1616)
[Step 100] loss=8.5295 (kd_logits=7.0588 kd_hidden=0.8116 rate=0.1653)
   -> validation ppl=40874.58
[Step 110] loss=8.1506 (kd_logits=6.6665 kd_hidden=0.8347 rate=0.1635)
[Step 120] loss=7.9348 (kd_logits=6.4465 kd_hidden=0.8456 rate=0.1637)
[Step 130] loss=7.9128 (kd_logits=6.4176 kd_hidden=0.8492 rate=0.1598)
[Step 140] lo