# Feature Distillation

**SOTA educational standard** | Covers FitNets, attention transfer, and intermediate-layer distillation

---

## 1. How Feature Distillation Works

**Core idea**: distill not only the teacher's outputs but also its intermediate-layer features.

**FitNets loss**:

$$\mathcal{L}_{hint} = \|W_r \cdot F_s - F_t\|_2^2$$

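where $W_r$ is a learned regressor (for example, a 1×1 convolution) that maps the student feature $F_s$ to the dimensions of the teacher feature $F_t$ so the two can be compared.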
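To connect the formula to code, the sketch below evaluates $\|W_r \cdot F_s - F_t\|_2^2$ literally, with $W_r$ assumed to be a 1×1 convolution, and compares it with `F.mse_loss`. `F.mse_loss` averages over every element instead of summing, so it equals the per-sample squared norm divided by $C_t \cdot H \cdot W$. The two differ only by a constant factor that can be folded into the loss weight. The tensor shapes are arbitrary illustration values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative shapes: 64 student channels, 128 teacher channels, same 8x8 spatial size.
B, C_s, C_t, H, W = 4, 64, 128, 8, 8
f_s = torch.randn(B, C_s, H, W)  # student feature F_s
f_t = torch.randn(B, C_t, H, W)  # teacher feature F_t

# W_r: assumed here to be a 1x1 convolution mapping C_s -> C_t channels.
w_r = nn.Conv2d(C_s, C_t, kernel_size=1)

projected = w_r(f_s)
# Literal formula: squared L2 norm per sample, then averaged over the batch.
hint_sq_norm = (projected - f_t).pow(2).flatten(1).sum(dim=1).mean()
# F.mse_loss averages over all B*C_t*H*W elements instead of summing per sample.
hint_mse = F.mse_loss(projected, f_t)

# The two printed values agree up to floating-point rounding.
print(hint_sq_norm.item(), (hint_mse * C_t * H * W).item())
```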

In [None]:
from __future__ import annotations
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

---

## 2. Feature Distillation Loss

In [None]:
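class FeatureDistillationLoss(nn.Module):
    """Feature distillation (FitNets hint) loss.

    Core idea: minimize the difference between the student's and the teacher's
    intermediate-layer features.
    """

    def __init__(self, student_channels: int, teacher_channels: int):
        super().__init__()
        # W_r: a 1x1 convolution maps student channels to teacher channels;
        # no projection is needed when the channel counts already match.
        if student_channels != teacher_channels:
            self.regressor = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)
        else:
            self.regressor = nn.Identity()

    def forward(self, student_feat: Tensor, teacher_feat: Tensor) -> Tensor:
        # Assumes the student and teacher feature maps share the same spatial size.
        student_feat = self.regressor(student_feat)
        # Detach the teacher feature so no gradient flows into the teacher.
        return F.mse_loss(student_feat, teacher_feat.detach())


# Quick test with random features
loss_fn = FeatureDistillationLoss(64, 128)
student_feat = torch.randn(4, 64, 8, 8)
teacher_feat = torch.randn(4, 128, 8, 8)
loss = loss_fn(student_feat, teacher_feat)
print(f"Feature Loss: {loss.item():.4f}")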
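
In practice the hint loss is added to the ordinary task loss while the student trains and the teacher stays frozen. The sketch below shows one way to wire this up, combining the task loss, a standard soft-label (output) distillation term, and the hint loss in a single step. `TinyNet`, the channel widths, the weights `alpha` and `beta`, the temperature `T`, and returning the intermediate feature from `forward` are all hypothetical choices for illustration. The regressor is optimized together with the student because $W_r$ is learned. The original FitNets recipe instead trains in two stages (hint pretraining first, then output distillation); the joint objective here is a common simplification.

```python
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

torch.manual_seed(0)


class TinyNet(nn.Module):
    """Hypothetical classifier that also returns an intermediate feature map."""

    def __init__(self, width: int, num_classes: int = 10):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(3, width, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv2d(width, width, kernel_size=3, padding=1), nn.ReLU(),
        )
        self.head = nn.Linear(width, num_classes)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        feat = self.body(x)                        # hint (teacher) / guided (student) layer
        logits = self.head(feat.mean(dim=(2, 3)))  # global average pooling + linear head
        return logits, feat


teacher = TinyNet(width=128).eval()  # assumed pretrained; kept frozen
student = TinyNet(width=64)
regressor = nn.Conv2d(64, 128, kernel_size=1)  # W_r: 64 -> 128 channels
alpha, beta, T = 0.5, 0.5, 4.0  # hypothetical loss weights and temperature

optimizer = torch.optim.SGD(
    list(student.parameters()) + list(regressor.parameters()), lr=0.01
)

x = torch.randn(8, 3, 16, 16)  # dummy input batch
y = torch.randint(0, 10, (8,))  # dummy labels

with torch.no_grad():  # the teacher receives no gradients
    t_logits, t_feat = teacher(x)

s_logits, s_feat = student(x)
task_loss = F.cross_entropy(s_logits, y)
# Output distillation: KL divergence between temperature-softened distributions, scaled by T^2.
kd_loss = F.kl_div(
    F.log_softmax(s_logits / T, dim=1),
    F.softmax(t_logits / T, dim=1),
    reduction="batchmean",
) * T * T
# Feature distillation: FitNets hint loss after projecting the student feature with W_r.
hint_loss = F.mse_loss(regressor(s_feat), t_feat)
loss = task_loss + alpha * kd_loss + beta * hint_loss

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"task={task_loss.item():.4f} kd={kd_loss.item():.4f} hint={hint_loss.item():.4f}")
```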

---

## 3. Attention Transfer

In [None]:
class AttentionTransferLoss(nn.Module):
    """Attention transfer (AT) loss.

    Core idea: transfer the teacher's spatial attention maps to the student.
    """

    def __init__(self, p: int = 2):
        super().__init__()
        self.p = p

    def attention_map(self, feat: Tensor) -> Tensor:
        """Attention map: channel mean of |feat|^p, flattened and L2-normalized per sample."""
        attn = feat.abs().pow(self.p).mean(dim=1)   # (B, H, W); abs() keeps odd p non-negative
        return F.normalize(attn.flatten(1), dim=1)  # (B, H*W), unit L2 norm as in the AT paper

    def forward(self, student_feat: Tensor, teacher_feat: Tensor) -> Tensor:
        # Channel counts may differ, but the spatial sizes must match.
        s_attn = self.attention_map(student_feat)
        t_attn = self.attention_map(teacher_feat.detach())
        return (s_attn - t_attn).pow(2).mean()


# Quick test: channel counts differ (64 vs 128), spatial size matches (8x8)
loss_fn = AttentionTransferLoss()
student_feat = torch.randn(4, 64, 8, 8)
teacher_feat = torch.randn(4, 128, 8, 8)
loss = loss_fn(student_feat, teacher_feat)
print(f"Attention Transfer Loss: {loss.item():.4f}")
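
For reference, the AT formulation of Zagoruyko and Komodakis (2017) builds an attention map from a feature tensor $A \in \mathbb{R}^{C \times H \times W}$ by aggregating $|A_c|^p$ over channels, flattens it, and compares L2-normalized student and teacher maps. The loss is summed over a set $\mathcal{I}$ of matched student/teacher layer pairs with weight $\beta$; the outer norm is written here as the 2-norm used in the paper's experiments:

$$Q = \operatorname{vec}\left(\sum_{c=1}^{C} |A_c|^p\right), \qquad \mathcal{L}_{AT} = \frac{\beta}{2} \sum_{j \in \mathcal{I}} \left\| \frac{Q_S^j}{\|Q_S^j\|_2} - \frac{Q_T^j}{\|Q_T^j\|_2} \right\|_2$$

Averaging over channels instead of summing scales $Q$ by $1/C$, which the normalization cancels, so student and teacher layers with different channel counts remain comparable. Code implementations, including the cell above, often replace the outer norm with a mean of squared differences; this changes the loss scale, and both versions are zero exactly when the normalized maps match.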

---

## 4. Visualization

In [None]:
def visualize_attention_maps():
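    """Visualize teacher and student attention maps side by side."""
    # Random features stand in for real activations, so the maps show noise only.
    teacher_feat = torch.randn(1, 64, 16, 16)
    student_feat = torch.randn(1, 32, 16, 16)

    # Unnormalized attention maps with p = 2: channel mean of squared activations.
    t_attn = teacher_feat.pow(2).mean(dim=1)[0]
    s_attn = student_feat.pow(2).mean(dim=1)[0]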
    
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    axes[0].imshow(t_attn.numpy(), cmap='hot')
    axes[0].set_title('Teacher Attention')
    axes[1].imshow(s_attn.numpy(), cmap='hot')
    axes[1].set_title('Student Attention')
    plt.tight_layout()
    plt.show()


visualize_attention_maps()

---

## 5. Summary

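| Method | Loss | Characteristics |
|:-----|:-----|:-----|
| **FitNets** | MSE (after a regressor $W_r$) | Matches intermediate features directly |
| **AT** | Difference between normalized attention maps | Transfers spatial attention |
| **FSP** | Distance between FSP (Gram-style) matrices | Transfers relationships between features of two layers |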
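
FSP appears in the table but is not implemented in this notebook. As a rough sketch, assuming the formulation of Yim et al. (2017), the FSP ("flow of solution procedure") matrix between two feature maps $F^1 \in \mathbb{R}^{m \times h \times w}$ and $F^2 \in \mathbb{R}^{n \times h \times w}$ from the same network is $G_{ij} = \frac{1}{hw}\sum_{x,y} F^1_{i,x,y}\,F^2_{j,x,y}$, and the student is trained to match the teacher's $G$ in squared L2 distance. The function names `fsp_matrix` and `fsp_loss` are hypothetical, the student and teacher layer pairs are assumed to have identical channel counts so the $m \times n$ matrices line up, and the paper's per-pair weights are omitted.

```python
from __future__ import annotations

import torch
from torch import Tensor


def fsp_matrix(f1: Tensor, f2: Tensor) -> Tensor:
    """FSP matrix between two feature maps taken from the same network.

    f1: (B, m, h, w), f2: (B, n, h, w)  ->  (B, m, n)
    G[i, j] is the mean over spatial positions of f1[i] * f2[j].
    """
    assert f1.shape[2:] == f2.shape[2:], "this sketch requires equal spatial sizes"
    b, m, h, w = f1.shape
    n = f2.shape[1]
    a = f1.reshape(b, m, h * w)
    c = f2.reshape(b, n, h * w)
    return torch.bmm(a, c.transpose(1, 2)) / (h * w)


def fsp_loss(student_pair: tuple[Tensor, Tensor], teacher_pair: tuple[Tensor, Tensor]) -> Tensor:
    """Squared Frobenius distance between student and teacher FSP matrices, batch-averaged."""
    g_s = fsp_matrix(*student_pair)
    g_t = fsp_matrix(*teacher_pair).detach()  # the teacher is a fixed target
    return (g_s - g_t).pow(2).sum(dim=(1, 2)).mean()


# Usage with random features: layer 1 has 16 channels, layer 2 has 32, both 8x8.
s1, s2 = torch.randn(2, 16, 8, 8), torch.randn(2, 32, 8, 8)
t1, t2 = torch.randn(2, 16, 8, 8), torch.randn(2, 32, 8, 8)
print(fsp_matrix(s1, s2).shape)  # torch.Size([2, 16, 32])
print(fsp_loss((s1, s2), (t1, t2)).item())
```

Using `torch.bmm` over flattened spatial positions computes all $m \times n$ channel inner products in one batched call, which is the same quantity as a Gram matrix taken across two different layers.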