# 自蒸馏 (Self-Distillation)

**SOTA 教育标准** | 包含 Born-Again Networks、深度监督、辅助分类器

---

## 1. 自蒸馏理论

**传统蒸馏**: 大教师 → 小学生 | **自蒸馏**: 模型自己教自己

**Born-Again Networks**: 用训练好的模型作为教师，训练相同架构的学生。

$$\text{Acc}(M_1) < \text{Acc}(M_2) < \text{Acc}(M_3)$$

其中 $M_k$ 表示第 $k$ 代模型，$M_{k+1}$ 以 $M_k$ 为教师训练。该不等式是经验上常观察到的趋势，而非理论保证，且每代的增益通常随代数递减。

In [None]:
from __future__ import annotations
import copy
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

# Pick the compute device once for the whole notebook: GPU when CUDA is usable, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

---

## 2. Born-Again 配置与训练器

In [None]:
@dataclass
class BornAgainConfig:
    """Hyper-parameters for Born-Again Networks self-distillation.

    Attributes are plain fields; ``__post_init__`` rejects values that cannot
    be valid for any training run.
    """
    num_generations: int = 3  # number of student generations to train (>= 1)
    temperature: float = 4.0  # softmax temperature for distillation (> 0)
    alpha: float = 0.5  # presumably the weight of the distillation term — TODO confirm with the training loop
    epochs_per_gen: int = 10  # training epochs per generation (>= 1)

    def __post_init__(self) -> None:
        if self.num_generations < 1:
            raise ValueError(f"num_generations must be >= 1, got {self.num_generations}")
        if self.temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {self.temperature}")
        if self.epochs_per_gen < 1:
            raise ValueError(f"epochs_per_gen must be >= 1, got {self.epochs_per_gen}")


class BornAgainTrainer:
    """Bookkeeping for Born-Again Networks (iterative self-distillation).

    Each generation is a fresh model from ``model_fn`` (same architecture),
    trained with the previous generation as its teacher. The aim is for every
    generation to improve on the last, but this is an empirical trend, not a
    guarantee. This class stores the generations and history and hands out
    the teacher. It does not define a training loop itself.

    Args:
        model_fn: Zero-argument factory that builds a new student model.
        config: Hyper-parameters. When omitted, a new ``BornAgainConfig()`` is
            created for this trainer, so trainers never share a config object.
    """

    def __init__(self, model_fn: Callable[[], nn.Module], config: Optional[BornAgainConfig] = None):
        self.model_fn = model_fn
        # A default instance in the signature would be built once at definition
        # time and shared (mutably) by every trainer; create one per trainer instead.
        self.config = config if config is not None else BornAgainConfig()
        self.generations: List[nn.Module] = []
        self.history: Dict[str, List] = {"generation": [], "test_acc": []}

    def get_teacher(self) -> Optional[nn.Module]:
        """Return the most recent generation as a frozen teacher, or None if there is none.

        The stored model is modified in place: it is switched to eval mode and
        all of its parameters get ``requires_grad = False``.
        """
        if not self.generations:
            return None
        teacher = self.generations[-1]
        teacher.eval()
        for p in teacher.parameters():
            p.requires_grad = False
        return teacher


# Smoke test: build a trainer around a tiny linear model and show its configuration.
trainer = BornAgainTrainer(model_fn=lambda: nn.Linear(64, 10))
print(f"配置: {trainer.config}")

---

## 3. 深度监督网络

In [None]:
class DeepSupervisionNet(nn.Module):
    """Three-stage CNN with auxiliary classifiers for deep-supervision self-distillation.

    Core idea: let shallow stages learn from the deepest stage.

    Each stage is Conv3x3 -> BatchNorm -> ReLU -> MaxPool2d(2), so it halves
    H and W. Input H and W must therefore be at least 8 for the third stage to
    produce a non-empty map. Stages 1 and 2 feed auxiliary classifiers, and
    stage 3 feeds the main classifier.

    Args:
        num_classes: Number of output classes for every head.
        in_channels: Channels of the input images. Defaults to 3, the value
            that used to be hard-coded.
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 3):
        super().__init__()
        # Construction order matches the original so seeded init is unchanged.
        self.layer1 = self._conv_stage(in_channels, 64)
        self.layer2 = self._conv_stage(64, 128)
        self.layer3 = self._conv_stage(128, 256)

        self.aux1 = self._classifier_head(64, num_classes)
        self.aux2 = self._classifier_head(128, num_classes)
        self.main = self._classifier_head(256, num_classes)

    @staticmethod
    def _conv_stage(c_in: int, c_out: int) -> nn.Sequential:
        """Conv3x3 (same padding) -> BatchNorm -> ReLU -> 2x2 max-pool."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU(), nn.MaxPool2d(2))

    @staticmethod
    def _classifier_head(c_in: int, num_classes: int) -> nn.Sequential:
        """Global average pool -> flatten -> linear layer producing class logits."""
        return nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(c_in, num_classes))

    def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
        """Run all stages.

        Args:
            x: Image batch, expected as (batch, in_channels, H, W).

        Returns:
            ``(main_logits, [aux1_logits, aux2_logits])``. Each has shape
            (batch, num_classes). The auxiliary logits come from the outputs
            of stage 1 and stage 2 respectively.
        """
        f1 = self.layer1(x)
        f2 = self.layer2(f1)
        f3 = self.layer3(f2)
        return self.main(f3), [self.aux1(f1), self.aux2(f2)]


# Smoke test: push a random batch of two 3x32x32 images through the network.
model = DeepSupervisionNet(num_classes=10)
x = torch.randn(2, 3, 32, 32)
main_out, aux_outs = model(x)
print(f"主输出: {main_out.shape}, 辅助输出: {list(a.shape for a in aux_outs)}")

---

## 4. 深度监督损失

In [None]:
class DeepSupervisedLoss(nn.Module):
    """Deep-supervision loss: the main (deepest) head supervises the auxiliary heads.

    total = CE(main_logits, targets)
          + sum_i  w_i * T^2 * KL( softmax(main / T) || softmax(aux_i / T) )

    The main logits are detached before forming the soft targets. The
    distillation terms therefore do not back-propagate through the main head's
    logits. The T^2 factor keeps the gradient scale of the soft terms
    comparable across temperatures.

    Args:
        temperature: Softmax temperature T. Must be > 0.
        aux_weights: One weight per auxiliary head. Defaults to ``[0.3, 0.3]``.
            The values are copied, so later edits to the caller's list do not
            leak into this loss.

    Raises:
        ValueError: If ``temperature <= 0``.
    """

    def __init__(self, temperature: float = 4.0, aux_weights: Optional[List[float]] = None):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.T = temperature
        # Copy (or build) a private list: a mutable default in the signature would
        # be one list object shared by every instance.
        self.aux_weights = list(aux_weights) if aux_weights is not None else [0.3, 0.3]

    def forward(self, main_logits: Tensor, aux_logits: List[Tensor], targets: Tensor) -> Tensor:
        """Compute the combined loss.

        Args:
            main_logits: Logits of the main head; class dimension is dim 1.
            aux_logits: Logits of the auxiliary heads, one per entry of ``aux_weights``.
            targets: Class indices for ``F.cross_entropy``.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If the number of auxiliary outputs differs from the
                number of weights. ``zip`` would otherwise silently drop the extras.
        """
        if len(aux_logits) != len(self.aux_weights):
            raise ValueError(
                f"got {len(aux_logits)} auxiliary outputs but {len(self.aux_weights)} aux_weights")

        loss_main = F.cross_entropy(main_logits, targets)
        # Teacher distribution from the deepest head; detached so it acts as a fixed target.
        main_soft = F.softmax(main_logits.detach() / self.T, dim=1)

        loss_aux = 0.0
        for aux, w in zip(aux_logits, self.aux_weights):
            aux_log = F.log_softmax(aux / self.T, dim=1)
            loss_aux += w * F.kl_div(aux_log, main_soft, reduction='batchmean') * (self.T ** 2)

        return loss_main + loss_aux


# Smoke test: evaluate the loss on main_out / aux_outs from the previous cell,
# using random class labels for the 2-sample batch.
loss_fn = DeepSupervisedLoss()
targets = torch.randint(0, 10, (2,))
loss = loss_fn(main_out, aux_outs, targets=targets)
print(f"深度监督损失: {loss.item():.4f}")

---

## 5. 可视化

In [None]:
def visualize_born_again_effect(test_acc: Optional[List[float]] = None) -> None:
    """Plot test accuracy per Born-Again generation and each generation's gain.

    Left panel: accuracy vs. generation. Right panel: accuracy change relative
    to the previous generation. Positive gains are green and non-positive gains
    are red; generation 1 is gray. Finally prints the total change from the
    first to the last generation.

    Args:
        test_acc: Test accuracy (in %) of generations 1, 2, ... in order, for
            example values collected in a trainer's ``history["test_acc"]``.
            When omitted, a built-in set of illustrative (not measured)
            numbers is used.

    Raises:
        ValueError: If ``test_acc`` is empty.
    """
    if test_acc is None:
        # Illustrative numbers only; they are not produced by any code in this notebook.
        test_acc = [85.0, 87.2, 88.5, 89.1, 89.4]
    test_acc = list(test_acc)
    if not test_acc:
        raise ValueError("test_acc must contain at least one generation")

    generations = list(range(1, len(test_acc) + 1))
    # Generation 1 has no predecessor, so its gain is defined as 0.
    gains = [0.0] + [cur - prev for prev, cur in zip(test_acc, test_acc[1:])]

    _, axes = plt.subplots(1, 2, figsize=(12, 4))

    axes[0].plot(generations, test_acc, 'o-', linewidth=2, markersize=8)
    axes[0].set_xlabel('Generation')
    axes[0].set_ylabel('Test Accuracy (%)')
    axes[0].set_title('Born-Again: Accuracy vs Generation')
    axes[0].grid(True, alpha=0.3)

    colors = ['gray'] + ['green' if g > 0 else 'red' for g in gains[1:]]
    axes[1].bar(generations, gains, color=colors, alpha=0.7)
    axes[1].axhline(0, color='black', linestyle='-')
    axes[1].set_xlabel('Generation')
    axes[1].set_ylabel('Accuracy Gain (%)')
    axes[1].set_title('Improvement per Generation')
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
    print(f"总提升: {test_acc[-1] - test_acc[0]:+.1f}%")


# Render the Born-Again accuracy and per-generation gain plots using the
# function's built-in illustrative accuracy values.
visualize_born_again_effect()

---

## 6. 总结

| 方法 | 特点 | 适用场景 |
|:-----|:-----|:---------|
| **Born-Again** | 迭代蒸馏 | 有足够训练时间 |
| **深度监督** | 深层→浅层 | 深层网络 |
| **辅助分类器** | 多出口 | 推理加速 |