# NNI剪枝实践

模型剪枝是一种通过减少模型权重大小或中间状态大小来减少模型大小和计算量的技术。修剪 DNN 模型有以下三种常见做法：
- 预训练模型 -> 修剪模型 -> 微调修剪后的模型
- 在训练期间修剪模型（即修剪感知训练）-> 微调修剪后的模型
- 修剪模型 -> 从头开始训练修剪后的模型

NNI 支持上述所有方式，本节以第一种方式（预训练模型 -> 修剪模型 -> 微调修剪后的模型）为例展示 NNI 的用法。

> 参考文档：https://nni.readthedocs.io/zh/stable/tutorials/pruning_quick_start.html

In [51]:
import random
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchvision.transforms import *
from tqdm.auto import tqdm
import torch.nn.functional as F
from torchvision import datasets

In [52]:
# Seed the Python, NumPy and PyTorch random number generators with a fixed value
# so that repeated runs of the notebook start from the same random state.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7fcd038a0670>

## 加载ch02中训练好的model

In [53]:
# Import `transforms` explicitly: otherwise this cell only works because the name
# happens to leak in through `from torchvision.transforms import *`.
from torchvision import transforms

# Preprocessing: convert each image to a tensor, then normalize it
# (0.1307 / 0.3081 are the commonly used MNIST mean / std).
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Datasets, downloaded into data_root when missing.
# train=True selects the training split, train=False the test split.
data_root = '../../ch02/data/mnist'
train_dataset = datasets.MNIST(root=data_root, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=data_root, train=False, download=True, transform=transform)

# DataLoaders: shuffle only the training data.
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [54]:
# 定义一个LeNet网络
class LeNet(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

    ``fc1`` expects 16 * 4 * 4 flattened features, which is what the two
    5x5 conv / 2x2 pool stages produce for 1x28x28 inputs.

    Args:
        num_classes: number of output logits produced by the last layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Feature extractor (attribute names and registration order are part of
        # the state-dict layout, so they are kept as-is).
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return the raw class logits for a batch of images ``x``."""
        features = self.maxpool(F.relu(self.conv1(x)))
        features = self.maxpool(F.relu(self.conv2(features)))

        # Flatten everything except the batch dimension.
        flat = features.view(features.size()[0], -1)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
# Device selection: the automatic CUDA/CPU choice below is commented out,
# so the notebook always runs on the CPU.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
# Instantiate the network and place it on the selected device.
model = LeNet().to(device=device)

In [55]:
def train(
  model: nn.Module,
  dataloader: DataLoader,
  criterion: nn.Module,
  optimizer: Optimizer,
  callbacks = None
) -> None:
  """Train ``model`` for one pass over ``dataloader``.

  Args:
    model: network to optimize; it is switched to training mode.
    dataloader: yields ``(inputs, targets)`` batches.
    criterion: loss called as ``criterion(outputs, targets)``.
    optimizer: optimizer stepped once per batch.
    callbacks: optional iterable of zero-argument callables, each invoked
      after every optimizer step.
  """
  model.train()

  for inputs, targets in tqdm(dataloader, desc='train', leave=False):
    inputs = inputs.to(device)
    targets = targets.to(device)

    # Reset the gradients left over from the previous iteration
    optimizer.zero_grad()

    # Forward pass. The logits stay on `device`: moving them to the CPU while
    # `targets` remain on `device` would make the loss fail on a non-CPU device.
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Backward propagation
    loss.backward()

    # Update the parameters
    optimizer.step()

    if callbacks is not None:
      for callback in callbacks:
        callback()

In [56]:
@torch.inference_mode()
def evaluate(
  model: nn.Module,
  dataloader: DataLoader,
  verbose=True,
) -> float:
  """Return the top-1 accuracy of ``model`` on ``dataloader``, in percent.

  Args:
    model: network to evaluate; it is switched to eval mode.
    dataloader: yields ``(inputs, targets)`` batches.
    verbose: when False, the progress bar is disabled.

  Raises:
    ZeroDivisionError: if ``dataloader`` yields no samples.
  """
  model.eval()

  num_samples = 0
  num_correct = 0

  for inputs, targets in tqdm(dataloader, desc="eval", leave=False,
                              disable=not verbose):
    inputs = inputs.to(device)
    targets = targets.to(device)

    # Convert logits to class indices. Predictions stay on `device` so the
    # comparison with `targets` never mixes tensors from different devices.
    predictions = model(inputs).argmax(dim=1)

    # Update metrics with plain Python ints
    num_samples += targets.size(0)
    num_correct += (predictions == targets).sum().item()

  return num_correct / num_samples * 100

In [57]:
# Load the saved parameters. map_location places the tensors on `device`, so a
# checkpoint saved on a GPU machine still loads when running on the CPU.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load('../../ch02/model.pt', map_location=device)
# Copy the loaded state dict into the model
model.load_state_dict(checkpoint)
# Show the structure of the original model
print(model)

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [58]:
# Report the total number of parameters of the original model
print(
    'Original model paramater number: ',
    sum(p.numel() for p in model.parameters()),
)

Original model paramater number:  44426


In [59]:
# Evaluate the original (unpruned) model on the test set to get a baseline accuracy
model_accuracy = evaluate(model, test_loader)
print(f"Model has accuracy={model_accuracy:.2f}%")

eval:   0%|          | 0/157 [00:00<?, ?it/s]

Model has accuracy=97.99%


## 模型剪枝
使用 L1NormPruner 对除 fc3 以外的 Conv2d 和 Linear 层进行剪枝，每层的目标稀疏率设为 80%（即每层剪掉约 80% 的输出通道/神经元）。

In [60]:
# Pruning configuration handed to the pruner:
#   'op_types'         - layer types to prune (Linear and Conv2d)
#   'exclude_op_names' - layers left untouched (fc3, the last layer of the network)
#   'sparse_ratio'     - target fraction to prune in each selected layer (0.8 = 80%)
config_list = [{
    'op_types': ['Linear', 'Conv2d'],
    'exclude_op_names': ['fc3'],
    'sparse_ratio': 0.8
}]

In [61]:
from nni.compression.pruning import L1NormPruner
# Build the pruner from the model and the pruning configuration.
pruner = L1NormPruner(model, config_list)

# Show the wrapped model structure: every layer selected by config_list now carries an
# `_nni_wrapper` (ModuleWrapper) submodule, while the excluded fc3 is left unwrapped.
print(model)

LeNet(
  (conv1): Conv2d(
    1, 6, kernel_size=(5, 5), stride=(1, 1)
    (_nni_wrapper): ModuleWrapper(module=Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1)), module_name=conv1)
  )
  (conv2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (_nni_wrapper): ModuleWrapper(module=Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)), module_name=conv2)
  )
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(
    in_features=256, out_features=120, bias=True
    (_nni_wrapper): ModuleWrapper(module=Linear(in_features=256, out_features=120, bias=True), module_name=fc1)
  )
  (fc2): Linear(
    in_features=120, out_features=84, bias=True
    (_nni_wrapper): ModuleWrapper(module=Linear(in_features=120, out_features=84, bias=True), module_name=fc2)
  )
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [62]:
# Run the pruner; only the generated masks are kept, the first return value is ignored.
_, masks = pruner.compress()
# Report one ratio per masked layer.
# NOTE(review): the value is mask.sum() / mask.numel(); if mask entries are 1 for kept
# weights (presumably so, given the 0.8 sparse_ratio and the ~0.2 values printed), this is
# the fraction of weights KEPT, i.e. 1 - sparsity, despite the "sparsity" label.
for layer_name, layer_masks in masks.items():
    weight_mask = layer_masks['weight']
    ratio = weight_mask.sum() / weight_mask.numel()
    print(layer_name, ' sparsity : ', f'{ratio:.2}')

conv2  sparsity :  0.25
conv1  sparsity :  0.33
fc1  sparsity :  0.2
fc2  sparsity :  0.2


## 模型加速
剪枝算法通常使用权重掩码来模拟真实的剪枝。掩码可用于检验模型在特定剪枝（或稀疏度）下的性能，但并不会带来真正的加速。NNI 提供了实际的加速功能。
要加速模型，需要替换被剪枝的层：对于粗粒度掩码，用更小的层替换；对于细粒度掩码，用稀疏内核替换。粗粒度掩码通常会改变权重或输入/输出张量的形状，因此 NNI 利用形状推断来检查其他未被剪枝的层是否也需要因形状变化而被替换。因此，在设计上主要分为两个步骤：首先进行形状推断，找出所有需要替换的模块；其次替换这些模块。第一步需要模型的拓扑结构（即连接关系），NNI 使用基于 torch.fx 的跟踪器来获取 PyTorch 模型的计算图。模块的新形状由 NNI 自动推断：前向输出和反向输入中保持不变的部分将被裁减掉。

In [63]:
# The pruner wrapped the model's layers earlier; unwrap the model before running speedup.
pruner.unwrap_model()

# Speed up the model: ModelSpeedup receives the model, a dummy input batch
# (3 random samples of shape 1x28x28, placed on `device`) and the masks from the pruner.
from nni.compression.speedup import ModelSpeedup

ModelSpeedup(model, torch.rand(3, 1, 28, 28).to(device), masks).speedup_model()

[2024-09-23 21:25:34] [32mStart to speedup the model...[0m
[2024-09-23 21:25:34] [32mResolve the mask conflict before mask propagate...[0m
[2024-09-23 21:25:34] [32mdim0 sparsity: 0.727273[0m
[2024-09-23 21:25:34] [32mdim1 sparsity: 0.000000[0m
0 Filter
[2024-09-23 21:25:34] [32mdim0 sparsity: 0.727273[0m
[2024-09-23 21:25:34] [32mdim1 sparsity: 0.000000[0m
[2024-09-23 21:25:34] [32mInfer module masks...[0m
[2024-09-23 21:25:34] [32mPropagate original variables[0m
[2024-09-23 21:25:34] [32mPropagate variables for placeholder: x, output mask:  0.0000 [0m
[2024-09-23 21:25:34] [32mPropagate variables for call_module: conv1, weight:  0.6667 bias:  0.6667 , output mask:  0.0000 [0m
[2024-09-23 21:25:34] [32mPropagate variables for call_function: relu, output mask:  0.0000 [0m
[2024-09-23 21:25:34] [32mPropagate variables for call_module: maxpool, , output mask:  0.0000 [0m
[2024-09-23 21:25:34] [32mPropagate variables for call_module: conv2, weight:  0.7500 bias:  

LeNet(
  (conv1): Conv2d(1, 2, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(2, 4, kernel_size=(5, 5), stride=(1, 1))
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=64, out_features=24, bias=True)
  (fc2): Linear(in_features=24, out_features=17, bias=True)
  (fc3): Linear(in_features=17, out_features=10, bias=True)
)

In [64]:
# Print the pruned model's structure to compare it with the structure before pruning
print(model)

LeNet(
  (conv1): Conv2d(1, 2, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(2, 4, kernel_size=(5, 5), stride=(1, 1))
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=64, out_features=24, bias=True)
  (fc2): Linear(in_features=24, out_features=17, bias=True)
  (fc3): Linear(in_features=17, out_features=10, bias=True)
)


In [65]:
# Report the total number of parameters of the pruned model
print(
    'Pruned model paramater number: ',
    sum(p.numel() for p in model.parameters()),
)

Pruned model paramater number:  2421


In [66]:
# Evaluate the pruned model's accuracy on the test set
model_accuracy = evaluate(model, test_loader)
print(f"Model has accuracy={model_accuracy:.2f}%")

eval:   0%|          | 0/157 [00:00<?, ?it/s]

Model has accuracy=30.50%


## 微调模型

In [67]:
# Fine-tuning hyper-parameters: learning rate, SGD momentum, number of epochs
lr, momentum, num_epoch = 0.01, 0.5, 5

# Optimizer over the model's current parameters, and cross-entropy loss
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
criterion = nn.CrossEntropyLoss()

# Best test accuracy seen so far
best_accuracy = 0

for epoch in range(num_epoch):
    # One training pass, then measure accuracy on the test set
    train(model, train_loader, criterion, optimizer)
    accuracy = evaluate(model, test_loader)

    is_best = accuracy > best_accuracy
    best_accuracy = accuracy if is_best else best_accuracy

    print(f'Epoch{epoch+1:>2d} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%')

train:   0%|          | 0/938 [00:00<?, ?it/s]

eval:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 1 Accuracy 94.66% / Best Accuracy: 94.66%


train:   0%|          | 0/938 [00:00<?, ?it/s]

eval:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 2 Accuracy 96.36% / Best Accuracy: 96.36%


train:   0%|          | 0/938 [00:00<?, ?it/s]

eval:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 3 Accuracy 97.05% / Best Accuracy: 97.05%


train:   0%|          | 0/938 [00:00<?, ?it/s]

eval:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 4 Accuracy 96.16% / Best Accuracy: 97.05%


train:   0%|          | 0/938 [00:00<?, ?it/s]

eval:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 5 Accuracy 97.25% / Best Accuracy: 97.25%


In [68]:
# Print the structure of the fine-tuned model
print(model)

LeNet(
  (conv1): Conv2d(1, 2, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(2, 4, kernel_size=(5, 5), stride=(1, 1))
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=64, out_features=24, bias=True)
  (fc2): Linear(in_features=24, out_features=17, bias=True)
  (fc3): Linear(in_features=17, out_features=10, bias=True)
)


In [69]:
# Report the total number of parameters of the fine-tuned model
print(
    'Fine-tuning  model paramater number: ',
    sum(p.numel() for p in model.parameters()),
)

Fine-tuning  model paramater number:  2421
