In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 设置随机数种子，便于复现
torch.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
device

device(type='cuda')

In [4]:
# 使用cuDNN加速卷积运算
torch.backends.cudnn.benchmark = True

In [5]:
# 载入训练集
train_dataset = torchvision.datasets.MNIST(
    root="../mnist/",
    train=True,
    transform=transforms.ToTensor(),
    download=False
)

# 载入测试集
test_dataset = torchvision.datasets.MNIST(
    root="../mnist/",
    train=False,
    transform=transforms.ToTensor(),
    download=False
)

# 生成dataloader
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [6]:
class TeacherModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10) -> None:
        super(TeacherModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.view(-1, 784)

        x = self.fc1(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc3(x)

        return x

In [7]:
model = TeacherModel()
model = model.to(device)

In [8]:
summary(model)

Layer (type:depth-idx)                   Param #
TeacherModel                             --
├─ReLU: 1-1                              --
├─Linear: 1-2                            942,000
├─Linear: 1-3                            1,441,200
├─Linear: 1-4                            12,010
├─Dropout: 1-5                           --
Total params: 2,395,210
Trainable params: 2,395,210
Non-trainable params: 0

In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [10]:
epochs = 6
for epoch in range(epochs):
    model.train()

    # 训练集上训练模型权重
    for data, targets in tqdm(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # 前向预测
        preds = model(data)
        loss = criterion(preds, targets)

        # 反向传播，优化权重
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # 测试集上评估模型
    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)

            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()

    model.train()
    print('Epoch:{}\t Accuracy:{:4f}'.format(epoch+1, acc))

100%|██████████| 1875/1875 [00:07<00:00, 260.25it/s]


Epoch:1	 Accuracy:0.940800


100%|██████████| 1875/1875 [00:05<00:00, 355.91it/s]


Epoch:2	 Accuracy:0.959900


100%|██████████| 1875/1875 [00:05<00:00, 347.34it/s]


Epoch:3	 Accuracy:0.969300


100%|██████████| 1875/1875 [00:05<00:00, 337.21it/s]


Epoch:4	 Accuracy:0.974800


100%|██████████| 1875/1875 [00:05<00:00, 313.56it/s]


Epoch:5	 Accuracy:0.977100


100%|██████████| 1875/1875 [00:05<00:00, 341.03it/s]


Epoch:6	 Accuracy:0.976600


In [11]:
teacher_model = model

In [12]:
class StudentModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10) -> None:
        super(StudentModel, self).__init__()
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(784, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, num_classes)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.relu(x)

        x = self.fc3(x)

        return x

In [13]:
# 从头训练学生模型
model = StudentModel()
model = model.to(device)

In [14]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [15]:
epochs = 3
for epoch in range(epochs):
    model.train()

    # 训练集上训练模型权重
    for data, targets in tqdm(train_loader):
        data =  data.to(device)
        targets = targets.to(device)

        # 前向预测
        preds = model(data)
        loss = criterion(preds, targets)

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # 测试集上评估模型
    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)

            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()

    model.train()
    print('Epoch:{}\t Accuracy:{:4f}'.format(epoch+1, acc))

100%|██████████| 1875/1875 [00:05<00:00, 314.92it/s]


Epoch:1	 Accuracy:0.835100


100%|██████████| 1875/1875 [00:05<00:00, 330.29it/s]


Epoch:2	 Accuracy:0.878600


100%|██████████| 1875/1875 [00:05<00:00, 342.10it/s]


Epoch:3	 Accuracy:0.894900


In [16]:
student_model_scratch = model

In [17]:
# 知识蒸馏训练学生模型
# 准备预训练好的教师模型
teacher_model.eval()

# 准备新的学生模型
model = StudentModel()
model = model.to(device)
model.train()

# 蒸馏温度
temp = 7

In [18]:
# hard_loss
hard_loss = nn.CrossEntropyLoss()
# hard_loss权重
alpha = 0.3

# soft_loss
soft_loss = nn.KLDivLoss(reduction="batchmean")
# soft_loss = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [19]:
epochs = 3
for epoch in range(epochs):
    model.train()

    # 训练集上训练模型权重
    for data, targets in tqdm(train_loader):
        data =  data.to(device)
        targets = targets.to(device)

        # 教师模型预测
        with torch.no_grad():
            teacher_preds = teacher_model(data)

        # 前向预测
        student_preds = model(data)
        student_loss = hard_loss(student_preds, targets)

        # 综合hard_loss, soft_loss
        ditillation_loss = soft_loss(
            F.softmax(student_preds/temp, dim=1).log(),
            F.softmax(teacher_preds/temp, dim=1)
        )

        loss = alpha * student_loss + (1-alpha) * ditillation_loss

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # 测试集上评估模型
    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)

            preds = model(x)
            predictions = preds.max(1).indices
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        acc = (num_correct/num_samples).item()

    model.train()
    print('Epoch:{}\t Accuracy:{:4f}'.format(epoch+1, acc))

100%|██████████| 1875/1875 [00:07<00:00, 256.02it/s]


Epoch:1	 Accuracy:0.845300


100%|██████████| 1875/1875 [00:07<00:00, 244.48it/s]


Epoch:2	 Accuracy:0.884100


100%|██████████| 1875/1875 [00:06<00:00, 270.75it/s]


Epoch:3	 Accuracy:0.897500


In [None]:
x = torch.tensor([[0.88, 0.12], [0.92, 0.08]])
y = torch.tensor([[0.7, 0.3], [0.6, 0.4]])
loss = soft_loss(x, y)
loss

In [None]:
loss = y * (y.log() - x)
loss = loss.sum()/2
loss