In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchinfo import summary
print(torch.__version__)

1.12.0+cu113


In [None]:
# 设置随机种子的目的是为了确保实验的可重复性。在许多深度学习任务中，随机性是一个重要的因素。例如，初始化神经网络的权重或在训练时打乱数据集。通过固定随机种子，您可以确保每次运行代码时都会得到相同的结果。
# 设置随机种子
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 使用cuda进行加速卷积运算
torch.backends.cudnn.benchmark = True
# 载入训练集

# batch_size是指用于训练神经网络时一次性输入模型的样本数量。Batch size的设置会影响模型训练的速度和准确性。较小的batch size可能会导致训练速度较慢，但可以提供更精确的梯度估计。较大的batch size可以加速训练，但可能导致梯度估计不准确。选择合适的batch size取决于任务、模型和计算资源。
train_dataset = torchvision.datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
test_dateset = torchvision.datasets.MNIST(root="dataset/", train=False, transform=transforms.ToTensor(), download=True)
train_dataloder = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloder = DataLoader(test_dateset, batch_size=32, shuffle=True)


In [4]:
# teacher model build
class Teacher_model(nn.Module):
    def __init__(self,in_channels=1,num_class=10):
        super(Teacher_model,self).__init__()
        self.fc1 = nn.Linear(784,1200)
        self.fc2 = nn.Linear(1200,1200)
        self.fc3 = nn.Linear(1200,10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        
    def forward(self,x):
        x = x.view(-1,784)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

model = Teacher_model()
model = model.to(device)
loss_function = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=0.0001)

In [5]:
# begin training teacher model
epoches=6
for epoch in range(epoches):
    model.train()
    for image,label in train_dataloder:
        # load the label to GPU
        image,label = image.to(device),label.to(device)
        optim.zero_grad()
        out = model(image)
        loss = loss_function(out,label)
        loss.backward()
        optim.step()
    
#     evalute the model
    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for image,label in test_dataloder:
            image = image .to(device)
            label = label.to(device)
            out = model(image)
            pre = out.max(1).indices
            num_correct += (pre == label).sum()
            num_samples += pre.size(0)
        acc = (num_correct / num_samples).item()
    model.train()
    print("epoches:{},accurate={}".format(epoch, acc))

teacher_model = model


epoches:0,accurate=0.9407999515533447
epoches:1,accurate=0.9608999490737915
epoches:2,accurate=0.9681999683380127
epoches:3,accurate=0.9736999869346619
epoches:4,accurate=0.9777999520301819
epoches:5,accurate=0.9789999723434448


In [6]:
# student model 
class Student_model(nn.Module):
    def __init__(self,in_channels = 1,num_class = 10):
        super(Student_model, self).__init__()
        self.fc1 = nn.Linear(784, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 10)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        # x = self.dropout(x)
        x = self.relu(x)
        x = self.fc2(x)
        # x = self.dropout(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
    
model = Student_model()
model = model.to(device)

# 损失函数和优化器
loss_function = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=0.0001)



In [8]:
epoches = 6
for epoch in range(epoches):
    model.train()
    for image,label in train_dataloder:
        image,label = image.to(device),label.to(device)
        optim.zero_grad()
        out = model(image)
        loss = loss_function(out,label)
        loss.backward()
        optim.step()

    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for image,label in test_dataloder:
            image = image .to(device)
            label = label.to(device)
            out = model(image)
            pre = out.max(1).indices
            num_correct += (pre == label).sum()
            num_samples += pre.size(0)
        acc = (num_correct / num_samples).item()
    model.train()
    print("epoches:{},accurate={}".format(epoch, acc))

epoches:0,accurate=0.8671000003814697
epoches:1,accurate=0.8970999717712402
epoches:2,accurate=0.9088000059127808
epoches:3,accurate=0.9139999747276306
epoches:4,accurate=0.9182999730110168
epoches:5,accurate=0.9218999743461609


In [9]:
# 开始KD
# set the param
teacher_model.eval()
model=Student_model()
model=model.to(device)

# distillation temperature
T = 7
hard_loss = nn.CrossEntropyLoss()
alpha = 0.3
soft_loss = nn.KLDivLoss(reduction="batchmean")
optim=torch.optim.Adam(model.parameters(),lr=0.0001)

In [11]:
# training KD
epoches=5
for epoch in range(epoches):
    model.train()
    for image,label in train_dataloder:
        image,label = image.to(device),label.to(device)
        with torch.no_grad():
            teacher_output = teacher_model(image)
        optim.zero_grad()
        out=model(image)
        loss=hard_loss(out,label)
        ditillation_loss=soft_loss(F.softmax(out/T,dim=1),F.softmax(teacher_output/T,dim=1))
        loss_all=loss*alpha+ditillation_loss*(1-alpha)
        loss.backward()
        optim.step()
    
    model.eval()
    num_correct=0
    num_samples=0
    with torch.no_grad():
        for image,label in test_dataloder:
            image=image.to(device)
            label=label.to(device)
            out=model(image)
            pre=out.max(1).indices
            num_correct+=(pre==label).sum()
            num_samples+=pre.size(0)
        acc=(num_correct/num_samples).item()
 
    model.train()
    print("epoches:{},accurate={}".format(epoch,acc))
    

epoches:0,accurate=0.8514999747276306
epoches:1,accurate=0.8851000070571899
epoches:2,accurate=0.8989999890327454
epoches:3,accurate=0.9062999486923218
epoches:4,accurate=0.9132999777793884
