# 知识蒸馏

在trainStudent中，我们直接使用简单的网络进行训练，当epoch为5时，模型准确率为58.66%，在本节中，我们使用知识蒸馏训练简单网络来对比使用知识蒸馏前后的性能。

In [1]:
import os
import random
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets

# 知识蒸馏 KD 的损失函数
from loss.kd import loss

In [2]:
# 设置随机数种子, 从而可以复现
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

setup_seed(42)

In [3]:
# device = (
#     "cuda"
#     if torch.cuda.is_available()
#     else "mps"
#     if torch.backends.mps.is_available()
#     else "cpu"
# )

device = "cuda" # 'cpu'

## 定义超参数

In [4]:
T = 4               # temperature : 知识蒸馏中的温度
ALPHA = 0.1         # alpha : hard_loss(硬损失交叉熵)的loss weight 
BETA = 0.9          # beta : soft_loss(软损失KL散度)的loss weight

## 加载教师模型, 以及定义学生网络

In [6]:
class LeNet(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))

        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x
    
class LeNetHalfChannel(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNetHalfChannel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=5)   
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=3 * 12 * 12, out_features=num_classes)   

    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))

        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        
        return x
    

teacher_net = LeNet().to(device=device)
student_net = LeNetHalfChannel().to(device=device)

data_dir = "../../01prune/notebook/0.minist_classify"
teacher_net.load_state_dict(torch.load(data_dir+'/model.pt'))

  teacher_net.load_state_dict(torch.load(data_dir+'/model.pt'))


<All keys matched successfully>

## 加载数据集
第一次使用会先进行下载, 如果下载的很慢, 可以手动下载数据集然后拖入到 data 文件夹下

In [7]:
# 设置归一化
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# 获取数据集
# 这里直接读取 ch02 中下载好的数据
train_dataset = datasets.MNIST(root=data_dir+'/data/mnist/', train=True, download=False, transform=transform)  
test_dataset = datasets.MNIST(root=data_dir+'/data/mnist/', train=False, download=False, transform=transform)  # train=True训练集，=False测试集

# 设置DataLoader
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## 优化器

In [8]:
lr = 0.01
momentum = 0.5
num_epoch = 5
optimizer = torch.optim.SGD(student_net.parameters(),  lr=lr, momentum=momentum)  # lr学习率，momentum冲量

## Train函数和Test函数

In [9]:
# 分别定义训练集和测试集上的最佳Acc, 使用 global 修饰为全局变量, 然后再训练期间更新
best_train_acc = 0
best_test_acc = 0

In [10]:

def train(epoch):
    global best_train_acc

    # 设置学生模型为训练模式
    student_net.train()

    print('\nEpoch: %d' % epoch)

    train_loss = 0
    correct = 0
    total = 0

    # 使用 tqdm 包装 trainloader 以显示进度条
    with tqdm(train_loader, desc=f"Training Epoch {epoch}", total=len(train_loader)) as pbar:
        for batch_idx, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            logits_student = student_net(inputs)
            with torch.no_grad():
                logits_teacher = teacher_net(inputs)

            # 硬损失
            ce_loss = nn.CrossEntropyLoss()(logits_student, targets)
            # 软损失
            kd_loss = loss(logits_student, logits_teacher, temperature=T)
            total_loss = ALPHA * ce_loss + BETA * kd_loss

            total_loss.backward()
            optimizer.step()

            train_loss += total_loss.item()
            _, predicted = logits_student.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # 使用 set_postfix 更新进度条的后缀
            pbar.set_postfix(loss=train_loss / (batch_idx + 1), acc=f"{100. * correct / total:.1f}%")

    # 如果当前训练集上的准确率高于 best_test_acc，则更新 best_test_acc
    acc = 100 * correct / total
    if acc > best_train_acc:
        best_train_acc = acc


In [11]:
def test(net, epoch):
    global best_test_acc
    net.eval()

    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        # 使用 tqdm 包装 testloader 以显示进度条
        with tqdm(test_loader, desc=f"Testing Epoch {epoch}", total=len(test_loader)) as pbar:
            for batch_idx, (inputs, targets) in enumerate(pbar):

                inputs, targets = inputs.to(device), targets.to(device)
                logits_student = net(inputs)

                loss = nn.CrossEntropyLoss()(logits_student, targets)

                test_loss += loss.item()
                _, predicted = logits_student.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                # 在 tqdm 进度条的后缀中显示当前损失和准确率
                pbar.set_postfix(loss=test_loss / (batch_idx + 1), acc=f"{100. * correct / total:.1f}%")

        # 计算当前测试集上的准确率
        acc = 100. * correct / total

        # 如果当前测试集上的准确率高于 best_test_acc，则更新 best_test_acc
        # 并且将学生模型保存下来
        if acc > best_test_acc:
            print('Saving..')
            torch.save(student_net, 'checkpoints/distillation_kd.pt')
            best_test_acc = acc

## 使用训练

In [12]:
for epoch in range(1, num_epoch + 1) :
    train(epoch)
    test(student_net, epoch)


Epoch: 1


Training Epoch 1: 100%|█████████████████████████████████████████| 938/938 [00:18<00:00, 50.97it/s, acc=75.1%, loss=5.23]
Testing Epoch 1: 100%|█████████████████████████████████████████| 157/157 [00:02<00:00, 67.40it/s, acc=79.8%, loss=0.789]


Saving..

Epoch: 2


Training Epoch 2: 100%|█████████████████████████████████████████| 938/938 [00:18<00:00, 51.74it/s, acc=80.7%, loss=4.32]
Testing Epoch 2: 100%|█████████████████████████████████████████| 157/157 [00:02<00:00, 67.98it/s, acc=81.5%, loss=0.685]


Saving..

Epoch: 3


Training Epoch 3: 100%|█████████████████████████████████████████| 938/938 [00:18<00:00, 51.65it/s, acc=82.4%, loss=4.05]
Testing Epoch 3: 100%|█████████████████████████████████████████| 157/157 [00:02<00:00, 60.45it/s, acc=86.1%, loss=0.416]


Saving..

Epoch: 4


Training Epoch 4: 100%|█████████████████████████████████████████| 938/938 [00:18<00:00, 51.54it/s, acc=86.4%, loss=2.23]
Testing Epoch 4: 100%|█████████████████████████████████████████| 157/157 [00:02<00:00, 64.38it/s, acc=86.8%, loss=0.376]


Saving..

Epoch: 5


Training Epoch 5: 100%|█████████████████████████████████████████| 938/938 [00:18<00:00, 51.32it/s, acc=86.7%, loss=2.14]
Testing Epoch 5: 100%|█████████████████████████████████████████| 157/157 [00:02<00:00, 62.24it/s, acc=87.0%, loss=0.357]

Saving..





In [13]:
print('best_Train_Acc = ', best_train_acc)
print('best_Test_Acc = ', best_test_acc)

best_Train_Acc =  86.695
best_Test_Acc =  87.03


直接训练学生模型准确率为58.66%；使用知识蒸馏的方式训练学生模型，测试集上的准确率为96.85%。