# Knowledge Distillation Basics

**SOTA educational standard** | Covers Hinton distillation, the temperature parameter, and soft labels

---

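## 1. Principles of Knowledge Distillation

**Core idea**: A large model (the teacher) guides the training of a small model (the student).

**Hinton distillation loss**: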

$$\mathcal{L} = \alpha \cdot \mathcal{L}_{CE}(y, p_s) + (1-\alpha) \cdot T^2 \cdot \text{KL}(p_t^T \| p_s^T)$$

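Here $p^T = \text{softmax}(z/T)$ is the temperature-softened probability distribution computed from the logits $z$ (the superscript $T$ marks the temperature and is not an exponent). $p_t^T$ and $p_s^T$ are the softened teacher and student distributions, $p_s$ is the student distribution at $T = 1$, $y$ is the ground-truth label, and $\alpha \in [0, 1]$ weights the hard-label term against the soft-label term. The factor $T^2$ compensates for the gradients of the soft term, which shrink roughly as $1/T^2$.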
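The role of the $T^2$ factor in the loss above can be checked numerically. The following is a minimal sketch, not part of the original notebook: the helper `soft_loss_grad_norm`, the random logits, and the seed are all assumptions made for illustration. It compares the gradient norm of the KL term with respect to the student logits, with and without the $T^2$ scaling.

```python
import torch
import torch.nn.functional as F


def soft_loss_grad_norm(student_logits, teacher_logits, T, scale_by_t2):
    """Gradient norm of the soft (KL) term w.r.t. the student logits.

    Hypothetical helper for illustration only.
    """
    z_s = student_logits.clone().requires_grad_(True)
    log_p_s = F.log_softmax(z_s / T, dim=-1)      # softened student (log-probs)
    p_t = F.softmax(teacher_logits / T, dim=-1)   # softened teacher (probs)
    loss = F.kl_div(log_p_s, p_t, reduction="batchmean")
    if scale_by_t2:
        loss = loss * T ** 2
    (grad,) = torch.autograd.grad(loss, z_s)
    return grad.norm().item()


torch.manual_seed(0)  # arbitrary seed, only for reproducibility
student_logits = torch.randn(8, 10)
teacher_logits = torch.randn(8, 10)

grad_norms = {}
for T in (1.0, 2.0, 4.0, 8.0):
    raw = soft_loss_grad_norm(student_logits, teacher_logits, T, scale_by_t2=False)
    scaled = soft_loss_grad_norm(student_logits, teacher_logits, T, scale_by_t2=True)
    grad_norms[T] = (raw, scaled)
    print(f"T={T}: raw grad norm={raw:.5f}, with T^2 factor={scaled:.5f}")
```

Without the factor, the gradient norm should drop sharply as $T$ grows. With it, the soft term stays on a scale comparable to the hard-label term, so a single $\alpha$ remains meaningful across temperatures.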

In [None]:
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

---

## 2. Implementing the Distillation Loss

In [None]:
@dataclass
class DistillationConfig:
    """Distillation configuration."""
    temperature: float = 4.0
    alpha: float = 0.5


class DistillationLoss(nn.Module):
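    """Knowledge distillation loss.

    Core idea: combine a hard-label loss with a soft-label loss.
    """

    def __init__(self, config: Optional[DistillationConfig] = None):
        super().__init__()
        # Build the default per instance instead of sharing one default object.
        config = config if config is not None else DistillationConfig()
        self.temperature = config.temperature
        self.alpha = config.alpha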

    def forward(self, student_logits: Tensor, teacher_logits: Tensor, labels: Tensor) -> Tensor:
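        # Hard-label loss: cross-entropy against the ground-truth labels (T = 1).
        hard_loss = F.cross_entropy(student_logits, labels)

        # Soft-label loss: KL(teacher || student) on temperature-softened distributions.
        # F.kl_div expects log-probabilities as input and probabilities as target.
        T = self.temperature
        soft_student = F.log_softmax(student_logits / T, dim=-1)
        soft_teacher = F.softmax(teacher_logits / T, dim=-1)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T ** 2)

        # Weighted combination of the two terms.
        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss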


# Quick test with random logits
loss_fn = DistillationLoss()
student_logits = torch.randn(8, 10)
teacher_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))

loss = loss_fn(student_logits, teacher_logits, labels)
print(f"Distillation Loss: {loss.item():.4f}")
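
The argument order of `F.kl_div` is easy to get backwards, so it is worth checking the soft term against a manual computation. The sketch below is an illustrative sanity check under assumed random inputs (the seed, shapes, and variable names are not from the original). It compares `F.kl_div` with a hand-written $\text{KL}(p_t^T \| p_s^T)$ and confirms the term vanishes when the student matches the teacher.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
T = 4.0
student_logits = torch.randn(8, 10)
teacher_logits = torch.randn(8, 10)

p_t = F.softmax(teacher_logits / T, dim=-1)          # softened teacher probs
log_p_t = F.log_softmax(teacher_logits / T, dim=-1)  # softened teacher log-probs
log_p_s = F.log_softmax(student_logits / T, dim=-1)  # softened student log-probs

# Manual KL(p_t || p_s): sum over classes, then average over the batch.
kl_manual = (p_t * (log_p_t - log_p_s)).sum(dim=-1).mean()

# Built-in: input is the student's log-probs, target is the teacher's probs.
kl_builtin = F.kl_div(log_p_s, p_t, reduction="batchmean")

# If the student reproduces the teacher exactly, the soft term is (numerically) zero.
kl_same = F.kl_div(log_p_t, p_t, reduction="batchmean")

print(f"manual={kl_manual.item():.6f}, builtin={kl_builtin.item():.6f}, "
      f"identical logits={kl_same.item():.2e}")
```

`reduction="batchmean"` divides the summed KL by the batch size, which matches the mathematical definition. `reduction="mean"` would also divide by the number of classes.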

---

## 3. Analyzing the Temperature Parameter

In [None]:
def visualize_temperature_effect():
    """Visualize how the temperature affects the soft labels."""
    logits = torch.tensor([2.0, 1.0, 0.5, 0.1, -0.5])
    temperatures = [1.0, 2.0, 4.0, 8.0]
    
    fig, axes = plt.subplots(1, len(temperatures), figsize=(16, 4))
    
    for i, T in enumerate(temperatures):
        probs = F.softmax(logits / T, dim=0).numpy()
        axes[i].bar(range(len(probs)), probs)
        axes[i].set_title(f'T = {T}')
        axes[i].set_ylim(0, 1)
        axes[i].grid(True, alpha=0.3)
    
    plt.suptitle('Temperature Effect on Softmax')
    plt.tight_layout()
    plt.show()


visualize_temperature_effect()
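
The flattening seen in the bar charts can also be quantified. As a rough sketch using only the standard library (the helper names `softmax_with_temperature` and `entropy` are hypothetical), the snippet below computes the entropy and the top-class probability for the same logits and temperatures used in the plot.

```python
import math


def softmax_with_temperature(logits, T):
    """Softmax of logits / T, with max-subtraction for numerical stability."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats."""
    return -sum(p * math.log(p) for p in probs if p > 0)


logits = [2.0, 1.0, 0.5, 0.1, -0.5]  # same logits as in the plot
temperatures = [1.0, 2.0, 4.0, 8.0]

entropies = []
top_probs = []
for T in temperatures:
    probs = softmax_with_temperature(logits, T)
    entropies.append(entropy(probs))
    top_probs.append(max(probs))
    print(f"T={T}: entropy={entropies[-1]:.4f} nats, top prob={top_probs[-1]:.4f}")

print(f"Uniform upper bound: {math.log(len(logits)):.4f} nats")
```

Entropy rises toward the uniform bound $\ln 5$ as $T$ increases, while the top-class probability falls. A higher temperature therefore exposes more of the teacher's information about the non-target classes.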

---

## 4. Distillation Training Loop

In [None]:
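class Distiller:
    """Knowledge distillation trainer."""

    def __init__(self, teacher: nn.Module, student: nn.Module, config: DistillationConfig,
                 lr: float = 1e-3):
        self.teacher = teacher
        self.student = student
        self.loss_fn = DistillationLoss(config)
        # Assumption: Adam with a default learning rate; any optimizer over the
        # student's parameters would work here.
        self.optimizer = torch.optim.Adam(self.student.parameters(), lr=lr)
        # The teacher is frozen: eval mode and no gradient tracking.
        self.teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad_(False)

    def train_step(self, x: Tensor, labels: Tensor) -> float:
        """Run one optimization step on the student and return the loss value."""
        self.student.train()
        with torch.no_grad():
            teacher_logits = self.teacher(x)
        student_logits = self.student(x)
        loss = self.loss_fn(student_logits, teacher_logits, labels)
        # Update only the student; the original step computed the loss without backprop.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()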


# Quick test with linear teacher and student models
teacher = nn.Linear(64, 10)
student = nn.Linear(64, 10)
distiller = Distiller(teacher, student, DistillationConfig())

x = torch.randn(8, 64)
labels = torch.randint(0, 10, (8,))
loss = distiller.train_step(x, labels)
print(f"Training Loss: {loss:.4f}")
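
In practice a training step is repeated over many batches. The following self-contained sketch shows one way a full loop could look on synthetic data. Everything specific in it is an assumption for illustration: the Adam optimizer, the learning rate, the 200 full-batch steps, and the use of the teacher's argmax as stand-in ground-truth labels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
T, alpha = 4.0, 0.5   # same defaults as DistillationConfig in this notebook

teacher = nn.Linear(64, 10)
student = nn.Linear(64, 10)
teacher.eval()
optimizer = torch.optim.Adam(student.parameters(), lr=1e-2)  # assumed optimizer

# Synthetic inputs; the labels are the teacher's argmax (a stand-in for real labels).
x = torch.randn(256, 64)
with torch.no_grad():
    teacher_logits = teacher(x)
    labels = teacher_logits.argmax(dim=-1)

losses = []
for step in range(200):
    student_logits = student(x)
    hard_loss = F.cross_entropy(student_logits, labels)
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T ** 2)
    loss = alpha * hard_loss + (1 - alpha) * soft_loss

    optimizer.zero_grad()
    loss.backward()   # gradients flow only into the student
    optimizer.step()
    losses.append(loss.item())

print(f"first loss={losses[0]:.4f}, last loss={losses[-1]:.4f}")
```

The teacher's logits are computed once under `torch.no_grad()` because the inputs are fixed. With real data loaders they would be recomputed per batch, still without gradient tracking.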

---

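## 5. Summary

| Concept | Formula | Notes |
|:-----|:-----|:-----|
| **Soft labels** | $p^T = \text{softmax}(z/T)$ | Temperature-softened distribution |
| **Distillation loss** | $\alpha \cdot L_{hard} + (1-\alpha) \cdot L_{soft}$ | Weighted mix of hard and soft losses ($L_{soft}$ includes the $T^2$ factor) |
| **Temperature** | $T > 1$ | Produces a smoother distribution |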