知识蒸馏代码实现（手写数字集）

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
# from torchinfo import summary

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

train_dataset = torchvision.datasets.FashionMNIST(root='../DeepLearning/DL_review_code/Datasets/', train=True, transform=transforms.ToTensor(), download=False)
test_dataset = torchvision.datasets.FashionMNIST(root='../DeepLearning/DL_review_code/Datasets/', train=False, transform=transforms.ToTensor(), download=False)
train_iter = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_iter = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [3]:
#搭建教师网络
class Teacher_model(nn.Module):
    def __init__(self, in_channels=1, num_classes=10) -> None:
        super(Teacher_model, self).__init__()
        self.fc1 = nn.Linear(784, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, X):
        X = X.view(-1, 784)
        X = self.fc1(X)
        X = self.dropout(X)
        X = self.relu(X)
        X = self.fc2(X)
        X = self.dropout(X)
        X = self.relu(X)
        X = self.fc3(X)
        return X

teacher_model = Teacher_model()
teacher_model = teacher_model.to(device)

loss_function = nn.CrossEntropyLoss()
optim = torch.optim.Adam(teacher_model.parameters(), lr=0.0001)

In [5]:
# 教师模型训练
num_epochs = 6
for epoch in range(num_epochs):
    teacher_model.train()
    for X, y in train_iter:
        X, y = X.to(device), y.to(device)
        y_hat = teacher_model(X)
        l = loss_function(y_hat, y)
        optim.zero_grad()
        l.backward()
        optim.step()
    
    teacher_model.eval()
    num_correct, num_samples = 0.0, 0
    with torch.no_grad():
        for X, y in test_iter:
            X, y = X.to(device), y.to(device)
            y_hat = teacher_model(X)
            pre = y_hat.argmax(dim=1)
            num_correct += (pre == y).float().sum().item()
            num_samples += y.shape[0]
        acc = num_correct/num_samples
    teacher_model.train()
    print('epoch:%d, accurate=%.3f' % (epoch, acc))

epoches:0, accurate=0.846
epoches:1, accurate=0.862
epoches:2, accurate=0.870
epoches:3, accurate=0.875
epoches:4, accurate=0.872
epoches:5, accurate=0.872


In [7]:
# 搭建学生网络
class Student_model(nn.Module):
    def __init__(self, in_channels=1, num_class=10) -> None:
        super(Student_model, self).__init__()
        self.fc1 = nn.Linear(784, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 10)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

student_model = Student_model().to(device)
loss_function = nn.CrossEntropyLoss()
optim = torch.optim.Adam(student_model.parameters(), lr=0.0001)

In [8]:
num_epochs = 6
for epoch in range(num_epochs):
    student_model.train()
    for X, y in train_iter:
        X, y = X.to(device), y.to(device)
        y_hat = student_model(X)
        l = loss_function(y_hat, y)
        optim.zero_grad()
        l.backward()
        optim.step()
    
    student_model.eval()
    num_correct, num_samples = 0.0, 0
    with torch.no_grad():
        for X, y in test_iter:
            X, y = X.to(device), y.to(device)
            y_hat = student_model(X)
            num_correct += (y_hat.argmax(dim=1) == y).float().sum().item()
            num_samples += y.shape[0]
        acc = num_correct / num_samples
    student_model.train()
    print('epoch:%d, accurate=%.3f' % (epoch, acc))

epoch:0, accurate=0.689
epoch:1, accurate=0.762
epoch:2, accurate=0.794
epoch:3, accurate=0.814
epoch:4, accurate=0.821
epoch:5, accurate=0.828


In [9]:
# 开始知识蒸馏算法
teacher_model.eval()
T = 7 #蒸馏温度
hard_loss = nn.CrossEntropyLoss()
alpha = 0.3
soft_loss = nn.KLDivLoss(reduction='batchmean')
optim = torch.optim.Adam(student_model.parameters(), lr=0.0001)

In [10]:
epoches = 5
for epoch in range(epoches):
    student_model.train()
    for X, y in train_iter:
        X, y = X.to(device), y.to(device)
        with torch.no_grad():
            teacher_y = teacher_model(X)
        student_y = student_model(X)
        hard_l = hard_loss(student_y, y)
        soft_l = soft_loss(F.softmax(student_y/T, dim=1), F.softmax(teacher_y/T, dim=1))
        all_l = alpha*hard_l + (1-alpha)*soft_l

        optim.zero_grad()
        all_l.backward()
        optim.step()
    
    student_model.eval()
    num_correct, num_samples = 0.0, 0
    with torch.no_grad():
        for X, y in test_iter:
            X, y = X.to(device), y.to(device)
            y_hat = student_model(X)
            num_correct += (y_hat.argmax(dim=1) == y).float().sum().item()
            num_samples += y.shape[0]
        acc = num_correct / num_samples
    student_model.train()
    print('epoch:%d, accurate=%.3f' % (epoch, acc))

epoch:0, accurate=0.827
epoch:1, accurate=0.830
epoch:2, accurate=0.830
epoch:3, accurate=0.833
epoch:4, accurate=0.834
