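# Pruning, Quantizing, and Distilling a Model with NNI
This section uses the NNI framework to apply fused compression (pruning, quantization, and distillation) to a ResNet18 model, with CIFAR-10 as the dataset.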

>Reference: https://github.com/microsoft/nni/blob/master/examples/compression/fusion/pqd_fuse.py

In [17]:
import pickle

import torch

from models import (
    build_resnet18,
    prepare_dataloader,
    prepare_optimizer,
    train,
    training_step,
    evaluate,
    device
)

In [18]:
from nni.compression import TorchEvaluator
from nni.compression.base.compressor import Quantizer
from nni.compression.distillation import DynamicLayerwiseDistiller
from nni.compression.pruning import TaylorPruner, AGPPruner
from nni.compression.quantization import QATQuantizer
from nni.compression.utils import auto_set_denpendency_group_ids
from nni.compression.speedup import ModelSpeedup

## Using the ResNet18 model with the CIFAR-10 dataset

In [19]:
#  Resnet18 on Cifar10
model = build_resnet18()
_, test_loader = prepare_dataloader()
print('Original model parameter number: ', sum([param.numel() for param in model.parameters()]))
print('Original model acc: ', evaluate(model, test_loader), '%')


Files already downloaded and verified
Original model parameter number:  11181642
Original model acc:  12.24 %


In [None]:
print(model)

## Fine-tuning ResNet18

In [20]:
optimizer = prepare_optimizer(model)
train(model, optimizer, training_step, lr_scheduler=None, max_steps=None, max_epochs=30)
print('Original model parameter number: ', sum([param.numel() for param in model.parameters()]))
print('Original model after 30 epochs finetuning acc: ', evaluate(model, test_loader), '%')

Files already downloaded and verified
[Training Epoch 0 / Step 391] Final Acc: 71.12%
[Training Epoch 1 / Step 782] Final Acc: 52.1%
[Training Epoch 2 / Step 1173] Final Acc: 70.93%
[Training Epoch 3 / Step 1564] Final Acc: 73.81%
[Training Epoch 4 / Step 1955] Final Acc: 77.49%
[Training Epoch 5 / Step 2346] Final Acc: 78.85%
[Training Epoch 6 / Step 2737] Final Acc: 79.18%
[Training Epoch 7 / Step 3128] Final Acc: 79.7%
[Training Epoch 8 / Step 3519] Final Acc: 81.5%
[Training Epoch 9 / Step 3910] Final Acc: 80.98%
[Training Epoch 10 / Step 4301] Final Acc: 81.38%
[Training Epoch 11 / Step 4692] Final Acc: 82.42%
[Training Epoch 12 / Step 5083] Final Acc: 81.52%
[Training Epoch 13 / Step 5474] Final Acc: 81.56%
[Training Epoch 14 / Step 5865] Final Acc: 81.85%
[Training Epoch 15 / Step 6256] Final Acc: 82.57%
[Training Epoch 16 / Step 6647] Final Acc: 83.29%
[Training Epoch 17 / Step 7038] Final Acc: 83.21%
[Training Epoch 18 / Step 7429] Final Acc: 83.47%
[Training Epoch 19 / Step 7

Create a teacher model with the same architecture as the student model, and copy the fine-tuned student weights into it.

In [21]:
# build a teacher model
teacher_model = build_resnet18()
teacher_model.load_state_dict(pickle.loads(pickle.dumps(model.state_dict())))

<All keys matched successfully>
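
The pickle round trip in the cell above yields an independent deep copy of the state dict before it is loaded into the teacher. A minimal sketch of that copy semantics on a plain nested dict follows; the dict contents are made up for illustration, and `copy.deepcopy` is shown only as an equivalent alternative, not as what the notebook uses.

```python
import copy
import pickle

# Stand-in for a state dict: parameter names mapped to mutable values (illustrative only).
student_state = {"conv1.weight": [0.1, 0.2], "bn1.weight": [1.0]}

# Round trip through pickle, as in the cell above.
teacher_state = pickle.loads(pickle.dumps(student_state))
# copy.deepcopy produces the same kind of independent copy.
teacher_state_alt = copy.deepcopy(student_state)

# A later change to the student's values does not reach either copy.
student_state["conv1.weight"][0] = 9.9
print(teacher_state["conv1.weight"])
print(teacher_state_alt["conv1.weight"])
```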

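## Pruning

Prune every Conv2d layer in the model with a sparsity ratio of 0.5, and configure each BatchNorm2d layer so that its output mask is aligned with the weight mask of the corresponding convolution.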

In [22]:
# create pruner
bn_list = [module_name for module_name, module in model.named_modules() if isinstance(module, torch.nn.BatchNorm2d)]
p_config_list = [{
    'op_types': ['Conv2d'],
    'sparse_ratio': 0.5
}, *[{
    'op_names': [name],
    'target_names': ['_output_'],
    'target_settings': {
        '_output_': {
            'align': {
                'module_name': name.replace('bn', 'conv') if 'bn' in name else name.replace('downsample.1', 'downsample.0'),
                'target_name': 'weight',
                'dims': [0],
            },
            'granularity': 'per_channel'
        }
    }
} for name in bn_list]]
dummy_input = torch.rand(8, 3, 224, 224).to(device)
p_config_list = auto_set_denpendency_group_ids(model, p_config_list, dummy_input)
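
The `align` entry above derives the name of the matching convolution from each BatchNorm2d name by string replacement. The sketch below isolates that mapping so it can be checked by hand. The helper name `aligned_conv_name` is hypothetical, and the sample names assume torchvision-style ResNet18 module naming.

```python
def aligned_conv_name(bn_name: str) -> str:
    """Map a BatchNorm2d module name to the Conv2d whose output it normalizes."""
    if 'bn' in bn_name:
        # e.g. 'layer1.0.bn2' -> 'layer1.0.conv2'
        return bn_name.replace('bn', 'conv')
    # Downsample branches are Sequential(conv, bn), indexed 0 and 1.
    return bn_name.replace('downsample.1', 'downsample.0')


# Assumed torchvision-style names; build_resnet18 may name modules differently.
for bn_name in ['bn1', 'layer1.0.bn2', 'layer2.0.downsample.1']:
    print(bn_name, '->', aligned_conv_name(bn_name))
```

Because the rule relies on substring replacement, it only holds while no other part of a module name contains `bn`.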

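Initialize the optimizer and the evaluator, then prune with TaylorPruner and AGPPruner. TaylorPruner estimates the importance of weights from a Taylor expansion of the loss, and AGPPruner is a gradual pruning method that raises the sparsity step by step.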
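
To make the Taylor criterion concrete, here is a minimal sketch of one common first-order form: each filter is scored by the sum of `(weight * gradient)^2` over its weights, and the lowest-scoring filters are masked. The function names and the toy numbers are assumptions for illustration; NNI's own reduction and accumulation over training steps may differ.

```python
def taylor_importance(weights, grads):
    """Per-filter score: sum of (w * dL/dw)^2 over the filter's weights."""
    return [
        sum((w * g) ** 2 for w, g in zip(filter_w, filter_g))
        for filter_w, filter_g in zip(weights, grads)
    ]


def prune_mask(scores, sparse_ratio):
    """Return a keep-mask (1 = keep) that drops the lowest-scoring filters."""
    n_prune = int(len(scores) * sparse_ratio)
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    pruned = set(order[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(scores))]


# Four toy filters with two weights each, plus their gradients.
weights = [[0.5, -0.5], [0.01, 0.02], [1.0, 1.0], [0.1, 0.0]]
grads = [[0.2, 0.2], [0.5, 0.5], [0.1, -0.1], [0.0, 0.3]]

scores = taylor_importance(weights, grads)
mask = prune_mask(scores, sparse_ratio=0.5)
print(scores)
print(mask)
```

A filter with large weights can still score low if its gradients are near zero, which is what separates this criterion from plain magnitude pruning.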

In [23]:
optimizer = prepare_optimizer(model)
evaluator = TorchEvaluator(train, optimizer, training_step)
sub_pruner = TaylorPruner(model, p_config_list, evaluator, training_steps=100)
scheduled_pruner = AGPPruner(sub_pruner, interval_steps=100, total_times=30)
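
With `total_times=30`, the sparsity target of 0.5 is reached over 30 pruning iterations rather than in one shot. A sketch of the cubic schedule from the AGP paper is shown below; the function name is hypothetical, and NNI's implementation may index iterations slightly differently.

```python
def agp_sparsity(t, total_times, final_sparsity, initial_sparsity=0.0):
    """Cubic AGP schedule: target sparsity after the t-th of total_times iterations."""
    progress = min(max(t / total_times, 0.0), 1.0)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3


# Sparsity rises quickly at first and flattens out near the final target.
for t in (0, 1, 10, 20, 30):
    print(t, round(agp_sparsity(t, total_times=30, final_sparsity=0.5), 4))
```

The schedule prunes aggressively while the network still has a lot of redundancy and slows down as the target is approached, leaving more training steps per unit of added sparsity.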

## Quantization

Create a quantizer with QATQuantizer that quantizes the inputs and weights of Conv2d layers and the outputs of BatchNorm2d layers to int8.

In [24]:
# create quantizer
q_config_list = [{
    'op_types': ['Conv2d'],
    'quant_dtype': 'int8',
    'target_names': ['_input_'],
    'granularity': 'per_channel'
}, {
    'op_types': ['Conv2d'],
    'quant_dtype': 'int8',
    'target_names': ['weight'],
    'granularity': 'out_channel'
}, {
    'op_types': ['BatchNorm2d'],
    'quant_dtype': 'int8',
    'target_names': ['_output_'],
    'granularity': 'per_channel'
}]

quantizer = QATQuantizer.from_compressor(scheduled_pruner, q_config_list, quant_start_step=100)
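
Quantization-aware training generally works by inserting a quantize-then-dequantize step into the forward pass, so the network learns under int8 rounding error. The sketch below shows that step for a single scalar with an affine int8 scheme. The function names, `scale`, and `zero_point` values are assumptions for illustration and are not taken from NNI.

```python
def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Map a float to a signed 8-bit integer, clamping to the int8 range."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """Map the integer back to the float grid defined by scale and zero_point."""
    return (q - zero_point) * scale


def fake_quantize(x, scale, zero_point):
    """Quantize then dequantize: the value the rest of the network would see."""
    return dequantize(quantize(x, scale, zero_point), scale, zero_point)


scale, zero_point = 0.02, 0
for x in (0.013, 1.0, 5.0):
    print(x, quantize(x, scale, zero_point), fake_quantize(x, scale, zero_point))
```

Values outside the representable range saturate (5.0 above), which is why per-channel scales, as requested by the `granularity` settings, can reduce error when channels have very different ranges.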

## Distillation
Define a teacher_predict function that produces predictions from the teacher model, and create a DynamicLayerwiseDistiller for layer-wise distillation.
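
As a rough picture of what layer-wise MSE distillation optimizes, the sketch below combines a task loss with per-layer MSE terms between student and teacher feature maps. It assumes that `lambda` in the distillation config weights each layer term and that the trailing `0.1` passed to `from_compressor` weights the original task loss. Both the function names and that reading of the arguments are assumptions, and the dynamic teacher-layer matching done by NNI is not modeled here.

```python
def mse(a, b):
    """Mean squared error between two equally sized feature vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def distillation_loss(task_loss, student_feats, teacher_feats,
                      layer_lambda=0.1, origin_loss_lambda=0.1):
    """Weighted sum of the task loss and per-layer student/teacher MSE terms."""
    layer_term = sum(mse(s, t) for s, t in zip(student_feats, teacher_feats))
    return origin_loss_lambda * task_loss + layer_lambda * layer_term


# Two toy layers, each with a two-element flattened feature map.
student_feats = [[1.0, 2.0], [0.0, 0.0]]
teacher_feats = [[1.0, 4.0], [1.0, 1.0]]
print(distillation_loss(2.0, student_feats, teacher_feats))
```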

In [25]:
# create distiller
def teacher_predict(batch, teacher_model):
    return teacher_model(batch[0])

d_config_list = [{
    'op_types': ['Conv2d'],
    'lambda': 0.1,
    'apply_method': 'mse',
}]
distiller = DynamicLayerwiseDistiller.from_compressor(quantizer, d_config_list, teacher_model, teacher_predict, 0.1)

# max_steps covers 30 AGP/Taylor pruning iterations of 100 steps each (3000 steps), followed by 3000 fine-tuning steps
distiller.compress(max_steps=100 * 60, max_epochs=None)
distiller.unwrap_model()
distiller.unwrap_teacher_model()

Files already downloaded and verified
[Training Epoch 0 / Step 391] Final Acc: 80.23%
[Training Epoch 1 / Step 782] Final Acc: 79.04%
[Training Epoch 2 / Step 1173] Final Acc: 78.22%
[Training Epoch 3 / Step 1564] Final Acc: 78.4%
[Training Epoch 4 / Step 1955] Final Acc: 78.97%
[Training Epoch 5 / Step 2346] Final Acc: 79.02%
[Training Epoch 6 / Step 2737] Final Acc: 78.42%
[Training Epoch 7 / Step 3128] Final Acc: 79.56%
[Training Epoch 8 / Step 3519] Final Acc: 78.64%
[Training Epoch 9 / Step 3910] Final Acc: 79.56%
[Training Epoch 10 / Step 4301] Final Acc: 79.06%
[Training Epoch 11 / Step 4692] Final Acc: 79.5%
[Training Epoch 12 / Step 5083] Final Acc: 79.39%
[Training Epoch 13 / Step 5474] Final Acc: 80.31%
[Training Epoch 14 / Step 5865] Final Acc: 79.6%
[Training Epoch 15 / Step 6000] Final Acc: 80.44%
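
The step counts in the log can be reconciled with the settings above by a short calculation. The batch size of 128 is an assumption inferred from the 391 steps per epoch shown in the log; the CIFAR-10 training set has 50,000 images.

```python
import math

train_samples = 50_000   # CIFAR-10 training set size
batch_size = 128         # assumption: inferred from 391 steps per epoch in the log
steps_per_epoch = math.ceil(train_samples / batch_size)

pruning_steps = 30 * 100          # total_times * interval_steps of AGPPruner
max_steps = 100 * 60              # value passed to distiller.compress
finetune_steps = max_steps - pruning_steps

# Number of complete epochs and the steps left over in the last, partial epoch.
full_epochs, leftover = divmod(max_steps, steps_per_epoch)
print(steps_per_epoch, pruning_steps, finetune_steps, full_epochs, leftover)
```

Fifteen complete epochs end at step 5865, and the remaining 135 steps form the partial final epoch that stops at step 6000.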


## Speeding up the model

In [26]:
# speed up model
masks = scheduled_pruner.get_masks()
speedup = ModelSpeedup(model, dummy_input, masks)
model = speedup.speedup_model()

print('Compressed model parameter number: ', sum([param.numel() for param in model.parameters()]))
print('Compressed model without finetuning & qsim acc: ', evaluate(model, test_loader), '%')

[2024-09-23 17:38:35] Start to speedup the model...
[2024-09-23 17:38:35] Resolve the mask conflict before mask propagate...
[2024-09-23 17:38:35] dim0 sparsity: 0.500000
[2024-09-23 17:38:35] dim1 sparsity: 0.000000
0 Filter
[2024-09-23 17:38:35] dim0 sparsity: 0.500000
[2024-09-23 17:38:35] dim1 sparsity: 0.000000
[2024-09-23 17:38:35] Infer module masks...
[2024-09-23 17:38:35] Propagate original variables
[2024-09-23 17:38:35] Propagate variables for placeholder: x, output mask:  0.0000
[2024-09-23 17:38:35] Propagate variables for call_module: conv1, weight:  0.5000 , output mask:  0.0000
[2024-09-23 17:38:35] Propagate variables for call_module: bn1, _output_0:  0.5000 , output mask:  0.0000
[2024-09-23 17:38:35] Propagate variables for call_module: relu, , output mask:  0.0000
[2024-09-23 17:38:35] Propagate variables for call_module: maxpool, , output mask:  0.00

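Fetch the quantization calibration config, then use the trans function to adapt it to the sped-up model so that quantization can be simulated on the compressed network.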

In [27]:
# simulate quantization
calibration_config = quantizer.get_calibration_config()

In [28]:
def trans(calibration_config, speedup: ModelSpeedup):
    for node, node_info in speedup.node_infos.items():
        if node.op == 'call_module' and node.target in calibration_config:
            # assume the module only has one input and one output
            input_mask = speedup.node_infos[node.args[0]].output_masks
            param_mask = node_info.param_masks
            output_mask = node_info.output_masks

            module_cali_config = calibration_config[node.target]
            if '_input_0' in module_cali_config:
                reduce_dims = list(range(len(input_mask.shape)))
                reduce_dims.remove(1)
                idxs = torch.nonzero(input_mask.sum(reduce_dims), as_tuple=True)[0].cpu()
                module_cali_config['_input_0']['scale'] = module_cali_config['_input_0']['scale'].index_select(1, idxs)
                module_cali_config['_input_0']['zero_point'] = module_cali_config['_input_0']['zero_point'].index_select(1, idxs)
            if '_output_0' in module_cali_config:
                reduce_dims = list(range(len(output_mask.shape)))
                reduce_dims.remove(1)
                idxs = torch.nonzero(output_mask.sum(reduce_dims), as_tuple=True)[0].cpu()
                module_cali_config['_output_0']['scale'] = module_cali_config['_output_0']['scale'].index_select(1, idxs)
                module_cali_config['_output_0']['zero_point'] = module_cali_config['_output_0']['zero_point'].index_select(1, idxs)
            if 'weight' in module_cali_config:
                reduce_dims = list(range(len(param_mask['weight'].shape)))
                reduce_dims.remove(0)
                idxs = torch.nonzero(param_mask['weight'].sum(reduce_dims), as_tuple=True)[0].cpu()
                module_cali_config['weight']['scale'] = module_cali_config['weight']['scale'].index_select(0, idxs)
                module_cali_config['weight']['zero_point'] = module_cali_config['weight']['zero_point'].index_select(0, idxs)
            if 'bias' in module_cali_config:
                idxs = torch.nonzero(param_mask['bias'], as_tuple=True)[0].cpu()
                module_cali_config['bias']['scale'] = module_cali_config['bias']['scale'].index_select(0, idxs)
                module_cali_config['bias']['zero_point'] = module_cali_config['bias']['zero_point'].index_select(0, idxs)
    return calibration_config
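
The core of `trans` is the same operation repeated for each target: find the channels that survived pruning, then keep only their `scale` and `zero_point` entries. The sketch below reproduces that selection with plain lists instead of tensors, mirroring the `torch.nonzero` and `index_select` calls above. The helper names and the toy mask are hypothetical.

```python
def kept_channels(channel_mask):
    """Indices of channels whose mask still contains a non-zero entry."""
    return [i for i, channel in enumerate(channel_mask) if sum(channel) > 0]


def select(values, idxs):
    """Keep only the per-channel entries at the given indices."""
    return [values[i] for i in idxs]


# Four output channels; channels 1 and 3 were pruned away entirely.
weight_mask = [[1, 1], [0, 0], [1, 0], [0, 0]]
scale = [0.02, 0.05, 0.01, 0.03]
zero_point = [0, 0, 0, 0]

idxs = kept_channels(weight_mask)
print(idxs, select(scale, idxs), select(zero_point, idxs))
```

After speedup the layer physically has fewer channels, so per-channel calibration tensors of the original length would no longer match the layer's shape without this step.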

In [29]:
calibration_config = trans(calibration_config, speedup)

In [30]:
sim_quantizer = Quantizer(model, q_config_list)
sim_quantizer.update_calibration_config(calibration_config)

In [31]:
print('Compressed model parameter number: ', sum([param.numel() for param in model.parameters()]))
print('Compressed model without finetuning acc: ', evaluate(model, test_loader), '%')

Compressed model parameter number:  2801450
Compressed model without finetuning acc:  80.44 %
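
The drop from 11181642 to 2801450 parameters is close to a factor of four, which is what a 0.5 channel sparsity would suggest: removing half of a convolution's output channels also halves the input channels of the next convolution. A small worked check follows; the helper `conv_params` and the 64-channel layer are illustrative assumptions.

```python
def conv_params(c_in, c_out, k, bias=False):
    """Parameter count of a k x k convolution."""
    return c_out * c_in * k * k + (c_out if bias else 0)


# A 3x3 convolution with 64 input and 64 output channels, before and after halving both.
before = conv_params(64, 64, 3)
after = conv_params(32, 32, 3)
print(before, after, after / before)

# Ratio reported by the notebook for the whole model.
print(2801450 / 11181642)
```

The whole-model ratio sits slightly above 0.25, plausibly because some dimensions cannot shrink on both sides (the 3-channel input of the first convolution, the 10-class output of the classifier) and because BatchNorm parameters scale linearly with channel count rather than quadratically.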
