## 导入相关库和网络

使用 `res32x4` 去蒸馏 `res8x4`

In [1]:
import os
import random
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets

# 知识蒸馏 KD 的损失函数
from loss.kd import loss

  Referenced from: <2BD1B165-EC09-3F68-BCE4-8FE4E70CA7E2> /Users/terry/miniconda3/lib/python3.11/site-packages/torchvision/image.so
  warn(


In [2]:
# 设置随机数种子, 从而可以复现
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

setup_seed(42)

In [3]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

device = "cpu"

## 定义超参数

In [4]:
T = 4               # temperature : 知识蒸馏中的温度
ALPHA = 0.1         # alpha : hard_loss(硬损失交叉熵)的loss weight 
BETA = 0.9          # beta : soft_loss(软损失KL散度)的loss weight

## 加载教师模型, 以及定义学生网络

In [5]:
class LeNet(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))

        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x
    
class LeNetHalfChannel(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNetHalfChannel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=5)   
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=3 * 12 * 12, out_features=10)   

    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))

        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        
        return x
    

teacher_net = LeNet().to(device=device)
student_net = LeNetHalfChannel().to(device=device)

teacher_net.load_state_dict(torch.load('../ch02/model.pt'))

<All keys matched successfully>

## 加载数据集
第一次使用会先进行下载, 如果下载的很慢, 可以手动下载数据集然后拖入到 data 文件夹下

In [6]:
# 设置归一化
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# 获取数据集
# 这里直接读取 ch02 中下载好的数据
train_dataset = datasets.MNIST(root='../ch02/data/mnist/', train=True, download=False, transform=transform)  
test_dataset = datasets.MNIST(root='../ch02/data/mnist/', train=False, download=False, transform=transform)  # train=True训练集，=False测试集

# 设置DataLoader
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## 优化器

In [7]:
lr = 0.01
momentum = 0.5
num_epoch = 5
optimizer = torch.optim.SGD(student_net.parameters(),  lr=lr, momentum=momentum)  # lr学习率，momentum冲量

## Train函数和Test函数

In [8]:
# 分别定义训练集和测试集上的最佳Acc, 使用 global 修饰为全局变量, 然后再训练期间更新
best_train_acc = 0
best_test_acc = 0

In [9]:
from tqdm import tqdm

def train(epoch):
    global best_train_acc

    # 设置学生模型为训练模式
    student_net.train()

    print('\nEpoch: %d' % epoch)

    train_loss = 0
    correct = 0
    total = 0

    # 使用 tqdm 包装 trainloader 以显示进度条
    with tqdm(train_loader, desc=f"Training Epoch {epoch}", total=len(train_loader)) as pbar:
        for batch_idx, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            logits_student = student_net(inputs)
            with torch.no_grad():
                logits_teacher = teacher_net(inputs)

            # 硬损失
            ce_loss = nn.CrossEntropyLoss()(logits_student, targets)
            # 软损失
            kd_loss = loss(logits_student, logits_teacher, temperature=T)
            total_loss = ALPHA * ce_loss + BETA * kd_loss

            total_loss.backward()
            optimizer.step()

            train_loss += total_loss.item()
            _, predicted = logits_student.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # 使用 set_postfix 更新进度条的后缀
            pbar.set_postfix(loss=train_loss / (batch_idx + 1), acc=f"{100. * correct / total:.1f}%")

    # 如果当前训练集上的准确率高于 best_test_acc，则更新 best_test_acc
    acc = 100 * correct / total
    if acc > best_train_acc:
        best_train_acc = acc


In [10]:
def test(net, epoch):
    global best_test_acc
    net.eval()

    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        # 使用 tqdm 包装 testloader 以显示进度条
        with tqdm(test_loader, desc=f"Testing Epoch {epoch}", total=len(test_loader)) as pbar:
            for batch_idx, (inputs, targets) in enumerate(pbar):

                inputs, targets = inputs.to(device), targets.to(device)
                logits_student = net(inputs)

                loss = nn.CrossEntropyLoss()(logits_student, targets)

                test_loss += loss.item()
                _, predicted = logits_student.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                # 在 tqdm 进度条的后缀中显示当前损失和准确率
                pbar.set_postfix(loss=test_loss / (batch_idx + 1), acc=f"{100. * correct / total:.1f}%")

        # 计算当前测试集上的准确率
        acc = 100. * correct / total

        # 如果当前测试集上的准确率高于 best_test_acc，则更新 best_test_acc
        # 并且将学生模型保存下来
        if acc > best_test_acc:
            print('Saving..')
            torch.save(student_net, 'checkpoints/distillation_kd.pt')
            best_test_acc = acc

## 训练

In [11]:
for epoch in range(1, num_epoch + 1) :
    train(epoch)
    test(student_net, epoch)


Epoch: 1


Training Epoch 1:   0%|          | 0/938 [00:00<?, ?it/s, acc=9.4%, loss=14.2]

Training Epoch 1: 100%|██████████| 938/938 [00:16<00:00, 58.15it/s, acc=76.4%, loss=3.58]
Testing Epoch 1: 100%|██████████| 157/157 [00:01<00:00, 116.71it/s, acc=95.6%, loss=0.14] 


Saving..

Epoch: 2


Training Epoch 2: 100%|██████████| 938/938 [00:19<00:00, 49.07it/s, acc=95.5%, loss=0.639]
Testing Epoch 2: 100%|██████████| 157/157 [00:01<00:00, 108.94it/s, acc=96.5%, loss=0.111]


Saving..

Epoch: 3


Training Epoch 3: 100%|██████████| 938/938 [00:12<00:00, 72.31it/s, acc=96.1%, loss=0.529]
Testing Epoch 3: 100%|██████████| 157/157 [00:00<00:00, 158.06it/s, acc=96.4%, loss=0.111]



Epoch: 4


Training Epoch 4: 100%|██████████| 938/938 [00:12<00:00, 76.12it/s, acc=96.4%, loss=0.496]
Testing Epoch 4: 100%|██████████| 157/157 [00:00<00:00, 171.95it/s, acc=96.8%, loss=0.102] 


Saving..

Epoch: 5


Training Epoch 5: 100%|██████████| 938/938 [00:11<00:00, 82.18it/s, acc=96.3%, loss=0.482]
Testing Epoch 5: 100%|██████████| 157/157 [00:00<00:00, 169.48it/s, acc=96.7%, loss=0.103]


In [12]:
print('best_Train_Acc = ', best_train_acc)
print('best_Test_Acc = ', best_test_acc)

best_Train_Acc =  96.39
best_Test_Acc =  96.85
