# PyTorch 从零实现知识蒸馏

使用 Qwen2.5-0.5B 作为教师模型，从零实现 ~123M 参数的 Decoder-Only Transformer 学生模型，
在 Wikipedia 中文子集上进行在线知识蒸馏训练。

**运行环境**: Google Colab T4 GPU (15GB VRAM)

**模块结构**:
- `src/config.py` — 配置数据类
- `src/model.py` — 模型架构（RMSNorm, RoPE, GQA, SwiGLU）
- `src/data.py` — 数据加载与预处理
- `src/trainer.py` — 蒸馏训练循环
- `src/generate.py` — 文本生成推理

## Cell 1: 环境检查 & 依赖安装

In [1]:
# Install dependencies (Colab environment)
# !pip install torch transformers datasets matplotlib

import sys
import torch

# Query CUDA availability once and reuse the answer below.
has_cuda = torch.cuda.is_available()

print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA 可用: {has_cuda}")
if has_cuda:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"显存: {gpu_props.total_memory / 1024**3:.1f} GB")

# Device used by every later cell: the GPU when present, otherwise the CPU.
DEVICE = torch.device("cuda") if has_cuda else torch.device("cpu")
print(f"\n使用设备: {DEVICE}")
print("begin11")

Python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
PyTorch: 2.9.0+cu126
CUDA 可用: True
GPU: Tesla T4
显存: 14.7 GB

使用设备: cuda
begin


## Cell 2: 挂载 Google Drive（可选）

In [2]:
# Mount Google Drive to persist checkpoints (optional).
# Do not skip this cell when running locally: the configuration cell reads
# CHECKPOINT_DIR. Outside Colab this cell falls back to a local directory.

import os

USE_DRIVE = True  # Set to True to store checkpoints on Google Drive

if USE_DRIVE:
    try:
        from google.colab import drive
    except ImportError:
        # google.colab is only importable inside Colab; degrade to local
        # storage instead of aborting the notebook on the import.
        print("未检测到 google.colab，改用本地检查点目录")
        USE_DRIVE = False

if USE_DRIVE:
    drive.mount('/content/drive')
    CHECKPOINT_DIR = '/content/drive/MyDrive/distillation_checkpoints/'
else:
    CHECKPOINT_DIR = 'checkpoints/'

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print(f"检查点目录: {CHECKPOINT_DIR}")

Mounted at /content/drive
检查点目录: /content/drive/MyDrive/distillation_checkpoints/


## Cell 3: 导入模块 & 配置

In [16]:
# 如果在 Colab 中，需要先克隆项目或上传 src/ 目录
# !git clone <repo-url> && cd LLM

import sys
# !git clone https://github.com/linshengli/LLM.git
%cd /content/LLM
!git checkout stage1
os.listdir('.')
sys.path.insert(0, os.getcwd())  # 确保可以导入 src 模块
print(os.getcwd())
print(os.listdir('.'))

from src.config import ModelConfig, TrainingConfig
from src.model import StudentModel
from src.data import load_tokenizer, create_dataloaders
from src.trainer import DistillationTrainer, load_teacher_model
from src.generate import TextGenerator, load_trained_model

# 模型配置
model_config = ModelConfig()
print(f"模型配置: {model_config}")
print(f"预期参数量: ~{151665 * 512 + 3802112 * 12 + 512:,}")

# 训练配置
training_config = TrainingConfig(checkpoint_dir=CHECKPOINT_DIR)
print(f"\n训练配置: {training_config}")

/content/LLM
Already on 'stage1'
Your branch is up to date with 'origin/stage1'.
/content/LLM
['.claude', 'models', '.specify', 'train_qwenold_colab.ipynb', 'tests', 'CLAUDE.md', '.qoder', 'notebooks', 'src', 'train_qwenold_colab.py', '__pycache__', 'requirements.txt', '.git', 'data', 'specs']
模型配置: ModelConfig(hidden_size=512, num_layers=12, num_heads=8, num_kv_heads=2, intermediate_size=2048, vocab_size=151665, max_seq_len=512, rope_theta=1000000.0, norm_eps=1e-06, dropout=0.0)
预期参数量: ~123,278,336

训练配置: TrainingConfig(batch_size=8, learning_rate=0.0003, weight_decay=0.01, warmup_steps=500, num_epochs=3, gradient_clip=1.0, alpha=0.5, temperature=2.0, checkpoint_dir='/content/drive/MyDrive/distillation_checkpoints/', log_interval=50, eval_interval=500, save_interval=1000)


## Cell 4: 加载 Tokenizer & 构建数据集

In [17]:
# Load the tokenizer
tokenizer = load_tokenizer()
print(f"Tokenizer vocab_size: {tokenizer.vocab_size}")
print(f"pad_token: {tokenizer.pad_token}")

# Build the DataLoaders (5000 articles, roughly 50 MB of text)
loader_kwargs = {
    "tokenizer": tokenizer,
    "config": training_config,
    "model_config": model_config,
    "max_samples": 5000,
}
train_loader, val_loader = create_dataloaders(**loader_kwargs)

print(f"\n训练集 batch 数: {len(train_loader)}")
print(f"验证集 batch 数: {len(val_loader)}")

# Pull a single batch and report the tensor shapes as a sanity check
sample_batch = next(iter(train_loader))
print("\n样本 batch:")
for field in ("input_ids", "labels"):
    print(f"  {field} shape: {sample_batch[field].shape}")

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


config.json:   0%|          | 0.00/681 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]



Tokenizer vocab_size: 151643
pad_token: <|endoftext|>


README.md: 0.00B [00:00, ?B/s]

20231101.zh/train-00000-of-00006.parquet:   0%|          | 0.00/587M [00:00<?, ?B/s]

20231101.zh/train-00001-of-00006.parquet:   0%|          | 0.00/264M [00:00<?, ?B/s]

20231101.zh/train-00002-of-00006.parquet:   0%|          | 0.00/127M [00:00<?, ?B/s]

20231101.zh/train-00003-of-00006.parquet:   0%|          | 0.00/262M [00:00<?, ?B/s]

20231101.zh/train-00004-of-00006.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

20231101.zh/train-00005-of-00006.parquet:   0%|          | 0.00/225M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1384748 [00:00<?, ? examples/s]


训练集 batch 数: 3305
验证集 batch 数: 351

样本 batch:
  input_ids shape: torch.Size([8, 512])
  labels shape: torch.Size([8, 512])


## Cell 5: 构建学生模型

In [18]:
# Build the student model and report its size
student_model = StudentModel(model_config)
param_count = student_model.count_parameters()
print(f"学生模型参数量: {param_count:,} ({param_count / 1e6:.1f}M)")
print("\n模型结构:")
print(student_model)

# Forward-pass sanity check on random token ids; the model is only moved to
# DEVICE afterwards, so this runs wherever StudentModel was constructed.
batch_size, seq_len = 2, 32
expected_shape = (batch_size, seq_len, model_config.vocab_size)
with torch.no_grad():
    test_input = torch.randint(0, model_config.vocab_size, (batch_size, seq_len))
    test_output = student_model(test_input)
    print("\n前向传播验证:")
    print(f"  输入: {test_input.shape}")
    print(f"  输出: {test_output.shape}")
    assert test_output.shape == expected_shape
    print("  ✓ 输出形状正确")

# Move to the GPU (when there is one) and report memory usage
if torch.cuda.is_available():
    student_model = student_model.to(DEVICE)
    gib = 1024**3
    print("\n显存占用:")
    print(f"  已分配: {torch.cuda.memory_allocated() / gib:.2f} GB")
    print(f"  已缓存: {torch.cuda.memory_reserved() / gib:.2f} GB")

学生模型参数量: 123,278,336 (123.3M)

模型结构:
StudentModel(
  (embedding): Embedding(151665, 512)
  (layers): ModuleList(
    (0-11): 12 x TransformerBlock(
      (attention_norm): RMSNorm()
      (attention): GQAAttention(
        (q_proj): Linear(in_features=512, out_features=512, bias=False)
        (k_proj): Linear(in_features=512, out_features=128, bias=False)
        (v_proj): Linear(in_features=512, out_features=128, bias=False)
        (o_proj): Linear(in_features=512, out_features=512, bias=False)
        (rotary_emb): RotaryEmbedding()
      )
      (ffn_norm): RMSNorm()
      (ffn): SwiGLUFFN(
        (gate_proj): Linear(in_features=512, out_features=2048, bias=False)
        (up_proj): Linear(in_features=512, out_features=2048, bias=False)
        (down_proj): Linear(in_features=2048, out_features=512, bias=False)
      )
    )
  )
  (norm): RMSNorm()
  (lm_head): Linear(in_features=512, out_features=151665, bias=False)
)

前向传播验证:
  输入: torch.Size([2, 32])
  输出: torch.Size([2, 32, 151665])

## Cell 6: 加载教师模型

In [20]:
# Load the teacher model (Qwen2.5-0.5B, FP16)
teacher_model = load_teacher_model(device=DEVICE)
teacher_params = sum(weight.numel() for weight in teacher_model.parameters())
print(f"教师模型参数量: {teacher_params:,} ({teacher_params / 1e6:.1f}M)")

# GPU memory with both models resident
if torch.cuda.is_available():
    gib = 1024**3
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print("\n显存占用（学生 + 教师）:")
    print(f"  已分配: {torch.cuda.memory_allocated() / gib:.2f} GB")
    print(f"  已缓存: {torch.cuda.memory_reserved() / gib:.2f} GB")
    # Device capacity minus what PyTorch has allocated right now.
    free_mem = total_bytes - torch.cuda.memory_allocated()
    print(f"  剩余显存: {free_mem / gib:.2f} GB")

Loading weights:   0%|          | 0/290 [00:00<?, ?it/s]

教师模型参数量: 494,032,768 (494.0M)

显存占用（学生 + 教师）:
  已分配: 1.39 GB
  已缓存: 2.41 GB
  剩余显存: 13.35 GB


## Cell 7: 执行蒸馏训练

In [21]:
# Assemble the distillation trainer from the objects built in earlier cells
trainer_kwargs = {
    "student_model": student_model,
    "teacher_model": teacher_model,
    "train_loader": train_loader,
    "val_loader": val_loader,
    "config": training_config,
    "device": DEVICE,
}
trainer = DistillationTrainer(**trainer_kwargs)

# Run training.
# If this raises a CUDA out-of-memory error, lower TrainingConfig.batch_size
# and rebuild the DataLoaders before retrying.
print("开始蒸馏训练...")
history = trainer.train()
num_steps = len(history['train_loss'])
print(f"\n训练完成！共 {num_steps} 步")

开始蒸馏训练...


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.31 GiB. GPU 0 has a total capacity of 14.74 GiB of which 1.58 GiB is free. Process 15601 has 13.16 GiB memory in use. Of the allocated memory 12.79 GiB is allocated by PyTorch, and 253.04 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## Cell 8: 绘制 Loss 曲线

In [None]:
import matplotlib.pyplot as plt


def _trailing_mean(values, width):
    """Return, for each position, the mean of the last `width` values.

    Positions with fewer than `width` predecessors average what is available.
    """
    means = []
    for end in range(1, len(values) + 1):
        chunk = values[max(0, end - width):end]
        means.append(sum(chunk) / len(chunk))
    return means


# NOTE(review): Matplotlib's default font may lack CJK glyphs, in which case
# the Chinese labels below render as empty boxes - confirm on Colab.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Training loss curve, smoothed with a trailing moving average.
# The window covers exactly `window` points so that it matches the legend;
# the previous slice spanned window + 1 points.
train_loss = history["train_loss"]
window = min(50, len(train_loss) // 5 + 1)
if len(train_loss) > window:
    smoothed = _trailing_mean(train_loss, window)
else:
    smoothed = train_loss

axes[0].plot(train_loss, alpha=0.3, label="原始")
axes[0].plot(smoothed, label=f"滑动平均 (window={window})")
axes[0].set_xlabel("训练步数")
axes[0].set_ylabel("Loss")
axes[0].set_title("训练 Loss 曲线")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Validation loss and perplexity share the x axis but use separate y axes.
if history["val_loss"]:
    ax2 = axes[1]
    ax2.plot(history["val_loss"], "o-", color="orange", label="验证 Loss")
    ax2.set_xlabel("评估次数")
    ax2.set_ylabel("Loss")
    ax2.set_title("验证集指标")
    ax2_twin = ax2.twinx()
    ax2_twin.plot(history["val_ppl"], "s--", color="green", label="PPL")
    ax2_twin.set_ylabel("Perplexity")
    ax2.legend(loc="upper left")
    ax2_twin.legend(loc="upper right")
    ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("training_curves.png", dpi=150)
plt.show()
print("训练曲线已保存至 training_curves.png")

## Cell 9: 文本生成演示

In [None]:
import os

# Prefer the best checkpoint; use the final one when it does not exist
best_ckpt = os.path.join(CHECKPOINT_DIR, "best.pt")
best_ckpt = best_ckpt if os.path.exists(best_ckpt) else os.path.join(CHECKPOINT_DIR, "final.pt")

gen_model = load_trained_model(best_ckpt, model_config, DEVICE)
generator = TextGenerator(model=gen_model, tokenizer=tokenizer, device=DEVICE)

# Prompts to try
prompts = [
    "中国的首都是",
    "人工智能的发展",
    "数学是研究",
]

strategies = ["greedy", "top_k", "top_p"]

# Decoding settings shared by every strategy
sampling_kwargs = {
    "max_new_tokens": 100,
    "temperature": 0.8,
    "top_k": 50,
    "top_p": 0.9,
}
banner = "=" * 60

# Generate one continuation per (prompt, strategy) pair
for prompt in prompts:
    print("\n" + banner)
    print(f"提示词: {prompt}")
    print(banner)
    for strategy in strategies:
        result = generator.generate(prompt, strategy=strategy, **sampling_kwargs)
        print(f"\n[{strategy}]:")
        print(result)

## Cell 10: 保存最终模型

In [None]:
import dataclasses
import os


def _plain_config(cfg):
    """Return a dataclass config as a plain dict; pass anything else through."""
    if dataclasses.is_dataclass(cfg) and not isinstance(cfg, type):
        return dataclasses.asdict(cfg)
    return cfg


# Save the final model and its configuration
final_save_dir = os.path.join(CHECKPOINT_DIR, "final_model")
os.makedirs(final_save_dir, exist_ok=True)
final_model_path = os.path.join(final_save_dir, "model.pt")

# Store the configs as plain dicts rather than pickled project objects:
# torch.load defaults to weights_only=True (PyTorch >= 2.6) and refuses
# classes that are not allow-listed, so a pickled config object would make the
# file unloadable with default arguments. Rebuild the objects with
# ModelConfig(**ckpt["model_config"]) and
# TrainingConfig(**ckpt["training_config"]).
torch.save(
    {
        "model_state_dict": gen_model.state_dict(),
        "model_config": _plain_config(model_config),
        "training_config": _plain_config(training_config),
    },
    final_model_path,
)

print(f"模型已保存至: {final_save_dir}")
model_size = os.path.getsize(final_model_path) / 1024**2
print(f"模型文件大小: {model_size:.1f} MB")
print(f"\n训练摘要:")
print(f"  学生模型参数量: {param_count:,}")
print(f"  训练步数: {len(history['train_loss'])}")
if history['val_loss']:
    print(f"  最佳验证 Loss: {min(history['val_loss']):.4f}")
    print(f"  最佳验证 PPL: {min(history['val_ppl']):.2f}")
print("\n蒸馏训练完成！")