# 剪枝与量化结合

<!-- ## 前提条件

1. 建议创建虚拟环境
2. 配置基础环境依赖
3. 如果有支持CUDA的GPU, 加入 -->

## 准备
1. 采用**CIFAR-10**数据集
>CIFAR-10 是机器学习和计算机视觉中广泛使用的基准数据集，由 60,000 张 32×32 的彩色图像组成，分为 10 个类别，每个类别 6,000 张；其中 50,000 张用作训练集，10,000 张用作测试集。
2. 采用**ResNet18**预训练模型
>残差神经网络（Residual Neural Network，也称残差网络或 ResNet）是一种开创性的深度学习模型，其权重层学习的是以层输入为参照的残差函数。它于 2015 年被提出用于图像识别，并赢得了当年的 ImageNet 大规模视觉识别挑战赛（ILSVRC）。
对于ResNet-18的模型结构可视化和其他信息，可以参考李沐老师的[d2l](https://d2l.ai/chapter_convolutional-modern/resnet.html)

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.utils.prune as prune
import torch.optim as optim


In [None]:
# Load CIFAR-10. ToTensor scales pixels to [0, 1]; Normalize with
# mean = std = 0.5 then maps every RGB channel to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Build the training split (train=True) and the test split (train=False);
# both are downloaded into ./data on first use.
trainset, testset = (
    torchvision.datasets.CIFAR10(
        root='./data', train=is_train, download=True, transform=transform
    )
    for is_train in (True, False)
)

# Mini-batches of 4 images; only the training loader reshuffles each epoch.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=4, shuffle=True, num_workers=2
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=4, shuffle=False, num_workers=2
)

In [9]:
# Load ResNet-18 with ImageNet-pretrained weights (downloaded on first use).
# Without `weights=...` torchvision returns a randomly initialised network,
# which is not the pretrained model this notebook intends to fine-tune.
resnet18 = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.DEFAULT
)
# Replace the 1000-way ImageNet classifier with a 10-way head, since
# CIFAR-10 has 10 classes.
resnet18.fc = torch.nn.Linear(resnet18.fc.in_features, 10)
# Display the ResNet-18 model structure
resnet18

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [11]:
def prune_model(model, pruning_rate=0.1, *, permanent=True):
    """Apply L1 (magnitude) unstructured pruning to every Conv2d/Linear weight.

    Args:
        model: network whose Conv2d and Linear submodules are pruned in place.
        pruning_rate: fraction of each layer's weights to zero out; forwarded
            as ``amount`` to ``torch.nn.utils.prune.l1_unstructured``.
        permanent: if True (the default, matching the previous behavior), fold
            the mask into the weight with ``prune.remove`` so each layer keeps a
            plain ``weight`` parameter. The zeroed entries are then ordinary
            parameters: further training can make them non-zero again, and a
            later call re-prunes the smallest-magnitude entries of the whole
            tensor, so sparsity does not accumulate across calls.
            Pass ``permanent=False`` to keep the ``weight_orig``/``weight_mask``
            reparametrization instead: masked weights then stay zero during
            training and each repeated call prunes ``pruning_rate`` of the
            *remaining* weights. Call ``prune.remove(module, 'weight')`` on
            each layer once pruning is finished.
    """
    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    for module in model.modules():
        if not isinstance(module, prunable_types):
            continue
        # Zero the `pruning_rate` fraction of weights with the smallest |w|.
        prune.l1_unstructured(module, name='weight', amount=pruning_rate)
        if permanent:
            prune.remove(module, 'weight')

In [None]:
# Optimizer and loss used when fine-tuning resnet18.
optimizer = optim.SGD(params=resnet18.parameters(), lr=1e-3, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()

def train_model(model, epochs=10, prune_every_n_epochs=5, *, optimizer=None,
                criterion=None, loader=None, pruning_rate=0.1, log_every=2000):
    """Train ``model`` in place, pruning it every ``prune_every_n_epochs`` epochs.

    Args:
        model: network to train.
        epochs: number of passes over ``loader``.
        prune_every_n_epochs: call ``prune_model`` after every this many epochs.
        optimizer: optimizer over ``model``'s parameters. Defaults to a fresh
            ``SGD(lr=0.001, momentum=0.9)`` built from ``model.parameters()``,
            so the optimizer always updates the model actually being trained.
        criterion: loss function; defaults to ``torch.nn.CrossEntropyLoss()``.
        loader: iterable of ``(inputs, labels)`` batches; defaults to the
            module-level ``trainloader``.
        pruning_rate: fraction forwarded to ``prune_model`` at each pruning step.
        log_every: print the mean loss every this many mini-batches.
    """
    if optimizer is None:
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    if criterion is None:
        criterion = torch.nn.CrossEntropyLoss()
    if loader is None:
        loader = trainloader

    # Move every batch to wherever the model's parameters live (CPU or CUDA).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Training mode: BatchNorm uses batch statistics and updates running stats.
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            if i % log_every == log_every - 1:
                print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / log_every}')
                running_loss = 0.0

        if (epoch + 1) % prune_every_n_epochs == 0:
            print(f'Pruning after epoch {epoch + 1}')
            prune_model(model, pruning_rate=pruning_rate)
            print('Pruning done.')

# Fine-tune resnet18 for 10 epochs, pruning after every 5th epoch.
train_model(resnet18, epochs=10, prune_every_n_epochs=5)

In [None]:
# Convert the pruned float ResNet-18 into an int8 model with eager-mode
# post-training static quantization, save it, reload it and evaluate it.
#
# Why not `torch.quantization.convert(resnet18)` directly: without a qconfig
# and a prepare/calibration pass, convert() leaves the model unchanged (still
# float). In addition, torchvision's plain ResNet has no QuantStub/DeQuantStub
# and adds the residual branch with a float `+=`, so it cannot run on quantized
# tensors. torchvision.models.quantization.resnet18 is the quantizable variant
# of the same architecture, so the float weights are loaded into it instead.
QUANT_BACKEND = 'fbgemm'              # x86 backend; 'qnnpack' targets ARM
QUANTIZED_PATH = 'resnet18_int8.pth'  # the same file is written and read back
NUM_CALIBRATION_BATCHES = 32          # training batches used to calibrate observers


def _prepare_quantizable_resnet18(num_classes, float_state_dict=None):
    """Build a quantizable ResNet-18, fuse its layers and insert observers.

    If ``float_state_dict`` is given it is loaded *before* fusion, because
    fusion folds BatchNorm into the preceding convolution and changes the
    module layout. Returns the prepared (still float) model in eval mode.
    """
    model = torchvision.models.quantization.resnet18(
        quantize=False, num_classes=num_classes
    )
    if float_state_dict is not None:
        model.load_state_dict(float_state_dict)
    model.eval()  # fusion for post-training quantization requires eval mode
    model.fuse_model()
    model.qconfig = torch.quantization.get_default_qconfig(QUANT_BACKEND)
    return torch.quantization.prepare(model)


# Quantize the trained model. Assumes pruning was made permanent, i.e. the
# state dict has plain `weight` keys rather than `weight_orig`/`weight_mask`.
num_classes = resnet18.fc.out_features
float_state = {k: v.detach().cpu() for k, v in resnet18.state_dict().items()}
resnet18_int8 = _prepare_quantizable_resnet18(num_classes, float_state)

# Calibration: run a few float batches so the observers record activation ranges.
with torch.no_grad():
    for batch_idx, (inputs, _) in enumerate(trainloader):
        if batch_idx >= NUM_CALIBRATION_BATCHES:
            break
        resnet18_int8(inputs)
resnet18_int8 = torch.quantization.convert(resnet18_int8)

# Save the quantized model
torch.save(resnet18_int8.state_dict(), QUANTIZED_PATH)

# Load the quantized model: rebuild the identical quantized architecture first,
# then load the saved int8 parameters (scales/zero-points included).
resnet18_int8_loaded = torch.quantization.convert(
    _prepare_quantizable_resnet18(num_classes)
)
resnet18_int8_loaded.load_state_dict(torch.load(QUANTIZED_PATH))
resnet18_int8_loaded.eval()

# Evaluate the reloaded int8 model on the CIFAR-10 test split (CPU only).
correct = total = 0
with torch.no_grad():
    for inputs, labels in testloader:
        predicted = resnet18_int8_loaded(inputs).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Quantized model test accuracy: {100 * correct / max(total, 1):.2f}%')
