# 模型蒸馏教程

In [1]:
import torch
import os
import sys

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
import lightning as L

## 下载数据集，构建dataloader

In [2]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets.mnist import MNIST
from torch.utils.data import random_split
from torch.utils.data import DataLoader

# %%
'''MNIST DataModule'''


class MNISTDataModule(L.LightningDataModule):

    def __init__(self, dataset_dir, train_batch_size, test_batch_size, train_val_ratio, seed, num_workers):
        super().__init__()

        self.dataset_dir = dataset_dir #路径
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.train_val_ratio = train_val_ratio #训练验证集比例
        self.seed = seed
        self.num_workers = num_workers

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # NOTE ((0.1307,), (0.3081,))，均值是0.1307，标准差是0.3081，由MNIST数据集提供方计算好的
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        pass
        # 已经有了就不用下载了
        # MNIST(self.data_dir, train=True, download=False)
        # MNIST(self.data_dir, train=False, download=False)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        # 分为fit阶段和test阶段

        if stage == 'fit' or stage is None:
            # 载入train数据集
            mnist_train = MNIST(self.dataset_dir, train=True, download=True, transform=self.transform)
            # 划分train数据集的train和val比例
            mnist_train_length = len(mnist_train)
            train_val = [int(mnist_train_length * ratio) for ratio in self.train_val_ratio]
            # 设置seed
            generator = torch.Generator().manual_seed(self.seed)

            self.mnist_train, self.mnist_val = random_split(mnist_train, train_val, generator=generator)

        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.dataset_dir, train=False, download=True, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.train_batch_size, num_workers=self.num_workers,persistent_workers=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.test_batch_size, num_workers=self.num_workers,persistent_workers=True)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.test_batch_size, num_workers=self.num_workers,persistent_workers=True)


In [3]:
class Arguments:
    pass


# NOTE:数据集参数
dataset_args = Arguments()
dataset_args.seed = 42
dataset_args.Dataset_Dir = r"./../../code/torch_data"
dataset_args.train_batch_size = 64
dataset_args.test_batch_size = 1000
dataset_args.train_val_ratio = (0.8, 0.2)
dataset_args.num_workers = 15

# %%
# 实例化mnist数据集对象
mnist = MNISTDataModule(dataset_dir=dataset_args.Dataset_Dir,
                        train_batch_size=dataset_args.train_batch_size,
                        test_batch_size=dataset_args.test_batch_size,
                        train_val_ratio=dataset_args.train_val_ratio,
                        seed=dataset_args.seed,
                        num_workers=dataset_args.num_workers)
mnist.setup()

# 实例化dataloaders
train_dataloader = mnist.train_dataloader()
val_dataloader = mnist.val_dataloader()
test_dataloader = mnist.test_dataloader()

# 训练教师网络

In [4]:

from __future__ import print_function
#import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm,trange


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train( model, device, train_loader, optimizer, epoch):
    model.train()
    with tqdm(total=len(train_loader.dataset)) as pbar:
        pbar.set_description('training')
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target) #output已经log运算了，本质是交叉熵损失
            loss.backward()
            optimizer.step()
            if batch_idx % 10 == 0:
                pbar.update(len(data)*10)
                # print(loss.item())
                # print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                #    epoch, batch_idx * len(data), len(train_loader.dataset),
                #           100. * batch_idx / len(train_loader), loss.item()))
                # if args.dry_run:
                #    break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(test_loader, desc='testing'):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    #parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    #parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        #help='input batch size for training (default: 64)')

    #parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
     #                   help='input batch size for testing (default: 1000)')

    #parser.add_argument('--epochs', type=int, default=5, metavar='N',
     #                   help='number of epochs to train (default: 14)')

    #parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
     #                   help='learning rate (default: 1.0)')

    #parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
     #                   help='Learning rate step gamma (default: 0.7)')

    #parser.add_argument('--no-cuda', action='store_true', default=False,
     #                   help='disables CUDA training')

    #parser.add_argument('--dry-run', action='store_true', default=False,
     #                   help='quickly check a single pass')

    #parser.add_argument('--seed', type=int, default=42, metavar='S',
     #                   help='random seed (default: 1)')

    #parser.add_argument('--log-interval', type=int, default=10, metavar='N',
     #                   help='how many batches to wait before logging training status')

    #parser.add_argument('--save-model', action='store_true', default=True,
     #                   help='For Saving the current Model')

    #args = parser.parse_args()
    batch_size = 64
    test_batch_size = 1000
    lr = 1.0
    gamma = 0.7
    epochs = 5
    no_cuda = False
    seed = 42
    use_cuda = not no_cuda and torch.cuda.is_available()

    torch.manual_seed(seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': batch_size}
    test_kwargs = {'batch_size': test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset1 = datasets.MNIST(dataset_args.Dataset_Dir, train=True, download=True,
                              transform=transform)
    dataset2 = datasets.MNIST(dataset_args.Dataset_Dir, train=False,
                              transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    for epoch in range(1, epochs + 1):
        train( model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    #if args.save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")

main()



training: : 60160it [00:11, 5031.29it/s]                          
testing: 100%|██████████| 10/10 [00:05<00:00,  1.73it/s]



Test set: Average loss: 0.0457, Accuracy: 9853/10000 (99%)



training: : 60160it [00:09, 6095.76it/s]                          
testing: 100%|██████████| 10/10 [00:05<00:00,  1.71it/s]



Test set: Average loss: 0.0380, Accuracy: 9872/10000 (99%)



training: : 60160it [00:09, 6066.63it/s]                          
testing: 100%|██████████| 10/10 [00:05<00:00,  1.69it/s]



Test set: Average loss: 0.0350, Accuracy: 9876/10000 (99%)



training: : 60160it [00:09, 6024.44it/s]                          
testing: 100%|██████████| 10/10 [00:05<00:00,  1.72it/s]



Test set: Average loss: 0.0302, Accuracy: 9902/10000 (99%)



training: : 60160it [00:09, 6089.50it/s]                          
testing: 100%|██████████| 10/10 [00:05<00:00,  1.69it/s]


Test set: Average loss: 0.0275, Accuracy: 9911/10000 (99%)






# 开始蒸馏

In [5]:
import torch
from torch import nn
import torch.nn.functional as F


# %%
class TeacherNet(nn.Module):
    """
    Network architecture taken from https://github.com/pytorch/examples/blob/master/mnist/main.py
    
    98.2% accuracy after 1 epoch
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x


class StudentNet(nn.Module):
    """
    Naive linear model

    92.8% accuracy after 5 epochs, single FC layer
    """

    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(28 * 28, 16)
        self.fc2 = nn.Linear(16, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x


In [6]:
# %%
"""this file is adapted from
    https://github.com/bilunsun/knowledge_distillation  pl_distribution.py"""
# %%
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F

import lightning as L
import lightning.pytorch.callbacks as callbacks
import torchmetrics


class KDMoudle(L.LightningModule):
    def __init__(self, teacher, student, learning_rate, temperature, alpha):
        super().__init__()

        self.teacher = teacher
        self.teacher.requires_grad_(False) #冻结teacher模型，不更新参数
        self.student = student

        self.learning_rate = learning_rate

        self.temperature = temperature
        self.alpha = alpha

    def forward(self, x):
        student_logits = self.student(x)
        teacher_logits = self.teacher(x) #都没有softmax层

        return student_logits, teacher_logits

    def training_step(self, batch, batch_index):
        x, y = batch
        student_logits, teacher_logits = self.forward(x)

        # # NOTE:第一组：直接用hard_loss训练student网络
        # loss = F.cross_entropy(student_logits, y)
        #
        # # NOTE:第二组：用soft_loss训练student网络
        # loss = nn.KLDivLoss()(F.log_softmax(student_logits / self.temperature),
        #                       F.softmax(teacher_logits / self.temperature)) * (
        #                self.alpha * self.temperature * self.temperature)

        # NOTE:第三组：用hard_loss+soft_loss训练student网络
        soft_loss = nn.KLDivLoss()(F.log_softmax(student_logits / self.temperature,dim=1),
                                   F.softmax(teacher_logits / self.temperature,dim=1)) * (
                            self.alpha * self.temperature * self.temperature)
        hard_loss = F.cross_entropy(student_logits, y) * (1.0 - self.alpha)
        loss = hard_loss + soft_loss

        # WHY:student_logits为什么用log_softmax 而 teacher_logits直接用softmax？

        self.log("student_train_loss", loss)
        return loss

    def validation_step(self, batch, batch_index):
        x, y = batch
        student_logits, teacher_logits = self.forward(x)

        student_loss = F.cross_entropy(student_logits, y)

        student_preds = torch.argmax(student_logits, dim=1)
        student_acc = torchmetrics.functional.accuracy(student_preds, y,task='multiclass',num_classes=10)

        teacher_preds = torch.argmax(teacher_logits, dim=1)
        teacher_acc = torchmetrics.functional.accuracy(teacher_preds, y,task='multiclass',num_classes=10)

        self.log("student_val_loss", student_loss, prog_bar=True)
        self.log("student_val_acc", student_acc, prog_bar=True)
        self.log("teacher_val_acc", teacher_acc, prog_bar=True)

        return student_loss

    def test_step(self, batch, batch_index):
        x, y = batch
        student_logits, teacher_logits = self.forward(x)

        student_preds = torch.argmax(student_logits, dim=1)
        student_acc = torchmetrics.functional.accuracy(student_preds, y,task='multiclass',num_classes=10)

        teacher_preds = torch.argmax(teacher_logits, dim=1)
        teacher_acc = torchmetrics.functional.accuracy(teacher_preds, y,task='multiclass',num_classes=10)

        self.log("student_test_acc", student_acc, prog_bar=True)
        self.log("teacher_test_acc", teacher_acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.student.parameters(), lr=self.learning_rate)
        return optimizer


# callbacks
def get_callbacks():
    # 监控student_val_loss，不再减小了就停止
    early_stopping = callbacks.EarlyStopping(monitor='student_val_loss',
                                                            min_delta=1e-4, patience=2,
                                                            verbose=False, mode='min')
    # checkpoint
    model_checkpoint = callbacks.ModelCheckpoint(save_weights_only=True)

    # 监控学习率
    lr_monitor = callbacks.LearningRateMonitor(logging_interval='step')

    return [early_stopping, model_checkpoint, lr_monitor]

In [7]:
# %%
# NOTE:训练参数
train_args = Arguments()

train_args.learning_rate = 1e-3
train_args.max_epochs = 10
train_args.temperature = 2
train_args.alpha = 0.8
# %%
# 实例化pl_moudle
teacher = TeacherNet()
# 载入权重
teacher.load_state_dict(torch.load("./mnist_cnn.pt"))
student = StudentNet()
kd_moudle = KDMoudle(teacher=teacher,
                     student=student,
                     learning_rate=train_args.learning_rate,
                     temperature=train_args.temperature,
                     alpha=train_args.alpha)

In [8]:
trainer = L.Trainer(
    # fast_dev_run=1,  # debug时开启，只跑一个batch的train、val和test
    max_epochs=train_args.max_epochs,
    callbacks=get_callbacks(),

    log_every_n_steps=1)


# %%
# training
trainer.fit(kd_moudle, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
# %%
# testing
trainer.test(dataloaders=test_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4070') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: c:\Users\admin\Desktop\coolShare\edge computing\2.knowledge distilling and low-rank approximation\lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | teacher | TeacherNet | 1.2 M 
1 | student | StudentNet | 12.7 K
---------------------------------------
12.7 K    Trainable params
1.2 M     Non-trainable params
1.2 M     Total params
4.850     Total estimated model params siz

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
Restoring states from the checkpoint path at c:\Users\admin\Desktop\coolShare\edge computing\2.knowledge distilling and low-rank approximation\lightning_logs\version_0\checkpoints\epoch=9-step=7500.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at c:\Users\admin\Desktop\coolShare\edge computing\2.knowledge distilling and low-rank approximation\lightning_logs\version_0\checkpoints\epoch=9-step=7500.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
    student_test_acc        0.9462000131607056
    teacher_test_acc         0.991100013256073
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'student_test_acc': 0.9462000131607056,
  'teacher_test_acc': 0.991100013256073}]