# 使用DKD方式进行蒸馏

现有的知识蒸馏方法主要关注于中间层的深度特征蒸馏，而对logit蒸馏的重要性认识不足。[DKD]()重新定义了传统的知识蒸馏损失函数，将其分解为目标类知识蒸馏（TCKD）和非目标类知识蒸馏（NCKD）。
- 目标类知识蒸馏（TCKD）：关注于目标类的知识传递。
- 非目标类知识蒸馏（NCKD）：关注于非目标类之间的知识传递。

![DKD](../images/dkd.png)

传统的知识蒸馏损失函数可以表示为:

$$
K D=K L\left(p_T \| p_S\right)
$$


其中， $p_T$ 和 $p_S$ 分别是教师模型和学生模型的预测概率。
在DKD中，损失函数被重构为:

$$
K D=T C K D+\left(1-p_T^t\right) \cdot N C K D
$$


这里， $p_T^t$ 是教师模型对目标类的预测概率。



在DKD中，引入了两个超参数：
- $\alpha$ ：用于TCKD的权重。
- $\boldsymbol{\beta}$ ：用于NCKD的权重。

因此，DKD的损失函数可以表示为:

$$
D K D=\alpha \cdot T C K D+\beta \cdot N C K D
$$


通过调整这两个超参数，可以灵活地平衡TCKD和NCKD的重要性。


在训练过程中，DKD的实现步骤如下:
1. 计算logits: 从教师模型和学生模型中获取输出logits。
2. 应用softmax: 将logits转换为概率分布。
3. 计算 $T C K D$ 和NCKD:
   - TCKD: 计算教师和学生在目标类上的KL散度。
   - NCKD：计算教师和学生在非目标类上的KL散度。
4. 合并损失：根据超参数 $\alpha$ 和 $\beta$ 合并TCKD和NCKD的损失，得到最终的DKD损失。

In [1]:
import os
import random
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets

# 知识蒸馏 KD 的损失函数
from loss.dkd import dkd_loss

In [2]:
# 设置随机数种子, 从而可以复现
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

setup_seed(42)

In [3]:
# device = (
#     "cuda"
#     if torch.cuda.is_available()
#     else "mps"
#     if torch.backends.mps.is_available()
#     else "cpu"
# )
device = "cuda" # cpu

## 定义超参数

In [4]:
T = 4               # temperature : 知识蒸馏中的温度
ALPHA = 1.0         # alpha : TCKD 部分的loss weight
BETA = 2.0          # beta : NCKD 部分的loss weight
LOSS_CE = 1.0       # loss_ce : 交叉熵的loss weight

## 加载教师模型, 以及定义学生网络

In [6]:
class LeNet(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))

        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x
    
class LeNetHalfChannel(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNetHalfChannel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=5)   
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=3 * 12 * 12, out_features=10)   

    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))

        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        
        return x
    

teacher_net = LeNet().to(device=device)
student_net = LeNetHalfChannel().to(device=device)

data_dir = "../../01prune/notebook/0.minist_classify"
teacher_net.load_state_dict(torch.load(data_dir+'/model.pt'))

  teacher_net.load_state_dict(torch.load(data_dir+'/model.pt'))


<All keys matched successfully>

## 加载数据集


In [7]:
# 设置归一化
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# 获取数据集
# 这里直接读取 ch02 中下载好的数据
train_dataset = datasets.MNIST(root=data_dir+'/data/mnist/', train=True, download=False, transform=transform)  
test_dataset = datasets.MNIST(root=data_dir+'/data/mnist/', train=False, download=False, transform=transform)  # train=True训练集，=False测试集

# 设置DataLoader
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## 优化器

In [8]:
lr = 0.01
momentum = 0.5
num_epoch = 5
optimizer = torch.optim.SGD(student_net.parameters(),  lr=lr, momentum=momentum)  # lr学习率，momentum冲量

## Train函数和Test函数

In [9]:
# 分别定义训练集和测试集上的最佳Acc, 使用 global 修饰为全局变量, 然后再训练期间更新
best_train_acc = 0
best_test_acc = 0

In [10]:
from tqdm import tqdm

def train(epoch):
    global best_train_acc

    # 设置学生模型为训练模式
    student_net.train()

    print('\nEpoch: %d' % epoch)

    train_loss = 0
    correct = 0
    total = 0

    # 使用 tqdm 包装 trainloader 以显示进度条
    with tqdm(train_loader, desc=f"Training Epoch {epoch}", total=len(train_loader)) as pbar:
        for batch_idx, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            logits_student = student_net(inputs)
            with torch.no_grad():
                logits_teacher = teacher_net(inputs)

            # 硬损失
            ce_loss = nn.CrossEntropyLoss()(logits_student, targets)
            # 软损失
            kd_loss = dkd_loss(logits_student, logits_teacher, targets, ALPHA, BETA, T)
            total_loss = ALPHA * ce_loss + BETA * kd_loss

            total_loss.backward()
            optimizer.step()

            train_loss += total_loss.item()
            _, predicted = logits_student.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # 使用 set_postfix 更新进度条的后缀
            pbar.set_postfix(loss=train_loss / (batch_idx + 1), acc=f"{100. * correct / total:.1f}%")

    # 如果当前训练集上的准确率高于 best_test_acc，则更新 best_test_acc
    acc = 100 * correct / total
    if acc > best_train_acc:
        best_train_acc = acc


In [11]:
def test(net, epoch):
    global best_test_acc
    net.eval()

    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        # 使用 tqdm 包装 testloader 以显示进度条
        with tqdm(test_loader, desc=f"Testing Epoch {epoch}", total=len(test_loader)) as pbar:
            for batch_idx, (inputs, targets) in enumerate(pbar):

                inputs, targets = inputs.to(device), targets.to(device)
                logits_student = net(inputs)

                loss = nn.CrossEntropyLoss()(logits_student, targets)

                test_loss += loss.item()
                _, predicted = logits_student.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                # 在 tqdm 进度条的后缀中显示当前损失和准确率
                pbar.set_postfix(loss=test_loss / (batch_idx + 1), acc=f"{100. * correct / total:.1f}%")

        # 计算当前测试集上的准确率
        acc = 100. * correct / total

        # 如果当前测试集上的准确率高于 best_test_acc，则更新 best_test_acc
        # 并且将学生模型保存下来
        if acc > best_test_acc:
            print('Saving..')
            torch.save(student_net, 'checkpoints/distillation_dkd.pt')
            best_test_acc = acc

## 训练

In [12]:
for epoch in range(1, num_epoch + 1) :
    train(epoch)
    test(student_net, epoch)


Epoch: 1


Training Epoch 1: 100%|█████████████████████████████████████████| 938/938 [00:17<00:00, 52.51it/s, acc=47.9%, loss=34.6]
Testing Epoch 1: 100%|███████████████████████████████████████████| 157/157 [00:02<00:00, 69.17it/s, acc=50.3%, loss=1.4]


Saving..

Epoch: 2


Training Epoch 2: 100%|█████████████████████████████████████████| 938/938 [00:17<00:00, 54.71it/s, acc=55.7%, loss=30.6]
Testing Epoch 2: 100%|██████████████████████████████████████████| 157/157 [00:02<00:00, 71.65it/s, acc=59.0%, loss=1.26]


Saving..

Epoch: 3


Training Epoch 3: 100%|█████████████████████████████████████████| 938/938 [00:17<00:00, 54.13it/s, acc=58.9%, loss=29.3]
Testing Epoch 3: 100%|██████████████████████████████████████████| 157/157 [00:02<00:00, 74.74it/s, acc=59.2%, loss=1.36]


Saving..

Epoch: 4


Training Epoch 4: 100%|█████████████████████████████████████████| 938/938 [00:17<00:00, 55.02it/s, acc=58.9%, loss=29.2]
Testing Epoch 4: 100%|██████████████████████████████████████████| 157/157 [00:02<00:00, 74.88it/s, acc=59.1%, loss=1.28]



Epoch: 5


Training Epoch 5: 100%|█████████████████████████████████████████| 938/938 [00:16<00:00, 57.67it/s, acc=58.9%, loss=29.2]
Testing Epoch 5: 100%|███████████████████████████████████████████| 157/157 [00:02<00:00, 74.99it/s, acc=59.1%, loss=1.3]


In [13]:
print('best_Train_Acc = ', best_train_acc)
print('best_Test_Acc = ', best_test_acc)

best_Train_Acc =  58.915
best_Test_Acc =  59.17
