# Model Pruning in Practice
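Starting with version 1.4.0, PyTorch provides pruning utilities in the `torch.nn.utils.prune` module. Grouped by pruning scope, this project covers the following methods:
- Local pruning
  - Structured pruning
    - Random structured pruning (random_structured)
    - Ln-norm structured pruning (ln_structured)
  - Unstructured pruning
    - Random unstructured pruning (random_unstructured)
    - L1-norm unstructured pruning (l1_unstructured)
- Global pruning
  - Unstructured pruning (global_unstructured)
- Custom pruning

**Note:** `torch.nn.utils.prune` offers global pruning only in unstructured form (`global_unstructured`).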

## 1. Local Pruning
We start with local pruning, which prunes a single layer, or some other local part of the network, on its own.

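### 1.1 Structured Pruning
By pruning pattern, pruning is either structured or unstructured. Unstructured pruning sets individual weights to zero, choosing them at random or by magnitude. Structured pruning sets entire channels along a chosen dimension of the tensor to zero, for example whole convolution filters.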
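The two patterns are easy to tell apart in their masks. The sketch below uses a small `nn.Linear` layer, which is hypothetical and not part of LeNet, and only the standard `torch.nn.utils.prune` functions. `random_unstructured` zeroes individual entries, while `random_structured` with `dim=0` zeroes whole rows.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)  # fixed seed so the random masks are reproducible

# Unstructured: zero out 50% of the individual weights, chosen at random
fc_u = nn.Linear(in_features=4, out_features=6)
prune.random_unstructured(fc_u, name="weight", amount=0.5)

# Structured: zero out 50% of the output rows (dim=0) of the weight matrix
fc_s = nn.Linear(in_features=4, out_features=6)
prune.random_structured(fc_s, name="weight", amount=0.5, dim=0)

print(fc_u.weight_mask)  # zeros scattered over the matrix
print(fc_s.weight_mask)  # every row is either all ones or all zeros

# In the structured mask, each row is constant
row_is_constant = (fc_s.weight_mask == fc_s.weight_mask[:, :1]).all(dim=1)
print(bool(row_is_constant.all()))  # True

# Both remove 12 of the 24 weights here: 12 single entries vs. 3 rows x 4 columns
print(int((fc_u.weight_mask == 0).sum()), int((fc_s.weight_mask == 0).sum()))  # 12 12
```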

#### 1.1.1 Random Structured Pruning (random_structured)

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from torchsummary import summary  # third-party package: pip install torchsummary
```

Create a classic LeNet network:

```python
# Define a LeNet network
class LeNet(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))

        x = x.view(x.size()[0], -1)  # flatten to (batch, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet().to(device=device)
```

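```python
# Print the model structure.
# torchsummary's summary() defaults to device="cuda", so pass the actual device type
summary(model, input_size=(1, 28, 28), device=device.type)
```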

```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 24, 24]             156
         MaxPool2d-2            [-1, 6, 12, 12]               0
            Conv2d-3             [-1, 16, 8, 8]           2,416
         MaxPool2d-4             [-1, 16, 4, 4]               0
            Linear-5                  [-1, 120]          30,840
            Linear-6                   [-1, 84]          10,164
            Linear-7                   [-1, 10]             850
Total params: 44,426
Trainable params: 44,426
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.04
Params size (MB): 0.17
Estimated Total Size (MB): 0.22
----------------------------------------------------------------
```


```python
# Print the parameters of the first convolutional layer
module1 = model.conv1
print(list(module1.named_parameters()))
```

```
[('weight', Parameter containing:
tensor([[[[ 0.0199, -0.1327, -0.1888,  0.0135, -0.1062],
          [-0.1044, -0.1658,  0.1823, -0.0305, -0.1942],
          [-0.1320,  0.0258, -0.0003, -0.0305, -0.1435],
          [-0.0992, -0.0019,  0.1422, -0.1876, -0.1113],
          [ 0.0927, -0.0115,  0.1514,  0.0101,  0.1839]]],


        [[[ 0.1230, -0.1329,  0.0667,  0.0866, -0.1326],
          [ 0.1693,  0.0348,  0.0219,  0.0047, -0.0020],
          [ 0.1351, -0.0226,  0.1412, -0.0835,  0.1451],
          [-0.1025, -0.0730, -0.0170, -0.0361, -0.0915],
          [ 0.0103,  0.0657, -0.0003, -0.1348,  0.1867]]],


        [[[-0.0462, -0.1972,  0.1027,  0.0306,  0.0201],
          [ 0.1154,  0.1172, -0.0053,  0.1465,  0.1623],
          [-0.0238, -0.0572,  0.1679,  0.0517, -0.0271],
          [-0.0193, -0.0926, -0.1101,  0.1252,  0.1987],
          [-0.1187,  0.0380, -0.1373,  0.1509, -0.1453]]],


        [[[-0.1933,  0.1175,  0.1616, -0.0805,  0.0404],
          [-0.1942, -0.1721,  0.1894,  0.0
...
```

```python
# Print module1's buffers (named_buffers); before pruning this is an empty list
print(list(module1.named_buffers()))
```

```
[]
```


```python
# Print the keys of the model's state dict, which holds all parameters and persistent buffers
print(model.state_dict().keys())
```

```
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
```


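```python
# Argument 1: module, the module to prune. Here module1 = model.conv1,
#             so the first convolutional layer is pruned.
# Argument 2: name, which parameter of that module to prune.
#             name="weight" prunes the weights and leaves the bias untouched.
# Argument 3: amount, how much to prune. A float between 0.0 and 1.0 is a fraction;
#             a non-negative int is an absolute count. For structured pruning the
#             amount counts channels along `dim`, not individual weights.
# Argument 4: dim, the index of the dimension along which channels are defined
#             (dim=0 means output channels, i.e. whole filters of the conv layer).
prune.random_structured(module1, name="weight", amount=0.2, dim=0)
```

```
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
```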
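In structured pruning, `amount` is measured in channels along `dim`. A float is turned into a channel count by rounding `amount * number_of_channels`. The minimal check below uses a fresh convolution with conv1's shape, so it does not touch `model`. With 6 output channels, `amount=0.2` gives `round(1.2) = 1` pruned filter, and an integer `amount=2` prunes exactly 2.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

def count_pruned_channels(mask: torch.Tensor) -> int:
    """Number of slices along dim 0 whose mask is entirely zero."""
    return int((mask.flatten(start_dim=1).sum(dim=1) == 0).sum())

# Float amount: a fraction of the 6 output channels, round(0.2 * 6) = 1
conv_a = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
prune.random_structured(conv_a, name="weight", amount=0.2, dim=0)
print(count_pruned_channels(conv_a.weight_mask))  # 1

# Int amount: an absolute number of channels
conv_b = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
prune.random_structured(conv_b, name="weight", amount=2, dim=0)
print(count_pruned_channels(conv_b.weight_mask))  # 2
```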

```python
# Print the state-dict keys again and look at the conv1 entries
print(model.state_dict().keys())
```

```
odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
```


```python
# Print module1's parameters (named_parameters) again
print(list(module1.named_parameters()))
```

```
[('bias', Parameter containing:
tensor([ 0.1717, -0.1737, -0.0202,  0.1787,  0.0203, -0.1731],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.0199, -0.1327, -0.1888,  0.0135, -0.1062],
          [-0.1044, -0.1658,  0.1823, -0.0305, -0.1942],
          [-0.1320,  0.0258, -0.0003, -0.0305, -0.1435],
          [-0.0992, -0.0019,  0.1422, -0.1876, -0.1113],
          [ 0.0927, -0.0115,  0.1514,  0.0101,  0.1839]]],


        [[[ 0.1230, -0.1329,  0.0667,  0.0866, -0.1326],
          [ 0.1693,  0.0348,  0.0219,  0.0047, -0.0020],
          [ 0.1351, -0.0226,  0.1412, -0.0835,  0.1451],
          [-0.1025, -0.0730, -0.0170, -0.0361, -0.0915],
          [ 0.0103,  0.0657, -0.0003, -0.1348,  0.1867]]],


        [[[-0.0462, -0.1972,  0.1027,  0.0306,  0.0201],
          [ 0.1154,  0.1172, -0.0053,  0.1465,  0.1623],
          [-0.0238, -0.0572,  0.1679,  0.0517, -0.0271],
          [-0.0193, -0.0926, -0.1101,  0.1252,  0.1987],
          [-0.1187,  0.0380, -0.
...
```

```python
# Print module1's buffers (named_buffers) again
print(list(module1.named_buffers()))
```

```
[('weight_mask', tensor([[[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]]]))
```

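Conclusion: after pruning, the original weight tensor `weight` has been renamed to `weight_orig`. `module.named_buffers()` was an empty list before pruning and now contains `weight_mask`. In this mask the entire fourth output channel (index 3) is zero. That is the one filter (20% of 6, rounded) that was selected at random.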

```python
# Print module1.weight: what do you notice?
print(module1.weight)
```

```
tensor([[[[ 0.0199, -0.1327, -0.1888,  0.0135, -0.1062],
          [-0.1044, -0.1658,  0.1823, -0.0305, -0.1942],
          [-0.1320,  0.0258, -0.0003, -0.0305, -0.1435],
          [-0.0992, -0.0019,  0.1422, -0.1876, -0.1113],
          [ 0.0927, -0.0115,  0.1514,  0.0101,  0.1839]]],


        [[[ 0.1230, -0.1329,  0.0667,  0.0866, -0.1326],
          [ 0.1693,  0.0348,  0.0219,  0.0047, -0.0020],
          [ 0.1351, -0.0226,  0.1412, -0.0835,  0.1451],
          [-0.1025, -0.0730, -0.0170, -0.0361, -0.0915],
          [ 0.0103,  0.0657, -0.0003, -0.1348,  0.1867]]],


        [[[-0.0462, -0.1972,  0.1027,  0.0306,  0.0201],
          [ 0.1154,  0.1172, -0.0053,  0.1465,  0.1623],
          [-0.0238, -0.0572,  0.1679,  0.0517, -0.0271],
          [-0.0193, -0.0926, -0.1101,  0.1252,  0.1987],
          [-0.1187,  0.0380, -0.1373,  0.1509, -0.1453]]],


        [[[-0.0000,  0.0000,  0.0000, -0.0000,  0.0000],
          [-0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,
...
```

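Conclusion: after pruning, the original `weight` becomes `weight_orig` and is stored in `named_parameters`. The pruning mask is stored in `weight_mask`, which acts as a mask tensor. Multiplying `weight_mask` element-wise with `weight_orig` gives the values stored in `weight`, which is why the fourth channel above is all zeros.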

**Note:** after pruning, `weight` is no longer a parameter of the module. It is only an ordinary attribute.
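The points above can be checked directly. The minimal sketch below uses a standalone convolution with conv1's shape (an assumption for illustration, independent of `model`). It confirms that `weight_orig` is a `Parameter`, that `weight_mask` is a buffer, that `weight` is no longer a registered parameter, and that `weight == weight_orig * weight_mask`.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
prune.random_structured(conv, name="weight", amount=0.2, dim=0)

# weight_orig is the trainable Parameter, weight_mask is a buffer
print(isinstance(conv.weight_orig, nn.Parameter))   # True
print("weight_mask" in dict(conv.named_buffers()))  # True

# weight itself is a plain tensor attribute, not a registered Parameter
print("weight" in dict(conv.named_parameters()))    # False

# The pruned weight is the element-wise product of the original weight and the mask
print(torch.equal(conv.weight, conv.weight_orig * conv.weight_mask))  # True
```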

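Each pruning operation also registers a forward pre-hook on the module, stored in `_forward_pre_hooks`. Before every forward pass the hook recomputes the pruned weight as `weight_orig * weight_mask`. The hook is the pruning-method object itself, so `_forward_pre_hooks` also records which pruning operations have been applied.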
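The hook's effect becomes visible when `weight_orig` changes, for example after an optimizer step. In the sketch below, `fill_(1.0)` is an arbitrary stand-in for such an update and the convolution is a standalone example. `conv.weight` stays stale until the next forward pass runs the hook and recomputes it.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
prune.random_structured(conv, name="weight", amount=0.2, dim=0)

# The registered hook is the pruning-method object
hooks = list(conv._forward_pre_hooks.values())
print(type(hooks[0]).__name__)  # RandomStructured

# Change the underlying Parameter in place (stand-in for an optimizer update)
with torch.no_grad():
    conv.weight_orig.fill_(1.0)

# conv.weight still holds the values computed at pruning time
print(torch.equal(conv.weight, conv.weight_mask))  # False

# A forward pass triggers the hook: weight = weight_orig * weight_mask = weight_mask
conv(torch.randn(1, 1, 28, 28))
print(torch.equal(conv.weight, conv.weight_mask))  # True
```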

```python
# Print _forward_pre_hooks
print(module1._forward_pre_hooks)
```

```
OrderedDict([(56, <torch.nn.utils.prune.RandomStructured object at 0x0000026F817EA310>)])
```


#### 1.1.2 Ln-Norm Structured Pruning (ln_structured)

```python
# Apply Ln-norm structured pruning to conv2
module2 = model.conv2
# Print the model's parameters before pruning conv2
print(model.state_dict().keys())
print('*'*50)
print(list(module2.named_parameters()))
print('*'*50)
print(list(module2.named_buffers()))
print('*'*50)
print(module2.bias)
print('*'*50)
print(module2._forward_pre_hooks)
```

```
odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
**************************************************
[('weight', Parameter containing:
tensor([[[[ 0.0210, -0.0655,  0.0230,  0.0613, -0.0271],
          [-0.0699,  0.0637, -0.0151, -0.0677,  0.0258],
          [ 0.0524,  0.0122,  0.0715,  0.0532,  0.0003],
          [-0.0535,  0.0701,  0.0092, -0.0380, -0.0733],
          [-0.0700,  0.0787, -0.0032,  0.0699,  0.0131]],

         [[-0.0752, -0.0405, -0.0301, -0.0688,  0.0344],
          [ 0.0741, -0.0329, -0.0034, -0.0423,  0.0569],
          [-0.0753,  0.0278,  0.0411,  0.0265,  0.0094],
          [-0.0207,  0.0377, -0.0015,  0.0758, -0.0226],
          [-0.0267, -0.0606, -0.0489,  0.0108,  0.0403]],

         [[ 0.0536,  0.0718,  0.0239,  0.0475,  0.0764],
          [ 0.0763, -0.0240, -0.0424, -0.0742, -0.0398],
          [-0.0459, -0.0795,  0.0269,  0.0177, -0.
...
```

```python
# Argument 1: module, the module to prune. Here module2 = model.conv2,
#             so the second convolutional layer is pruned.
# Argument 2: name, which parameter of that module to prune.
#             name="weight" prunes the weights and leaves the bias untouched.
# Argument 3: amount, how much to prune. A float between 0.0 and 1.0 is a fraction;
#             a non-negative int is an absolute count. Here amount=3 prunes
#             3 whole channels along `dim`.
# Argument 4: n, the norm order; n=2 selects the L2 norm.
# Argument 5: dim, the index of the dimension along which channels are defined.
prune.ln_structured(module2, name="weight", amount=3, n=2, dim=0)

# Print the model's parameters again
print(model.state_dict().keys())
print('*'*50)
print(list(module2.named_parameters()))
print('*'*50)
print(list(module2.named_buffers()))
print('*'*50)
print(module2.weight)
print('*'*50)
print(module2._forward_pre_hooks)
```

```
odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
**************************************************
[('bias', Parameter containing:
tensor([-0.0397, -0.0521, -0.0037, -0.0105,  0.0705, -0.0071,  0.0337,  0.0399,
         0.0788, -0.0610,  0.0180,  0.0376, -0.0044,  0.0465,  0.0462, -0.0157],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.0210, -0.0655,  0.0230,  0.0613, -0.0271],
          [-0.0699,  0.0637, -0.0151, -0.0677,  0.0258],
          [ 0.0524,  0.0122,  0.0715,  0.0532,  0.0003],
          [-0.0535,  0.0701,  0.0092, -0.0380, -0.0733],
          [-0.0700,  0.0787, -0.0032,  0.0699,  0.0131]],

         [[-0.0752, -0.0405, -0.0301, -0.0688,  0.0344],
          [ 0.0741, -0.0329, -0.0034, -0.0423,  0.0569],
          [-0.0753,  0.0278,  0.0411,  0.0265,  0.0094],
          [-0.0207,  0.0377,
...
```

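Conclusion: conv2 now follows the same pattern as conv1. Its `weight` has become `weight_orig` in the parameter list, and `weight_mask` appears among the buffers (the mask tensors). `bias` is left untouched because only `name="weight"` was pruned. With `amount=3`, `n=2` and `dim=0`, the three output channels with the smallest L2 norm are masked out.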
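The selection rule can be verified by computing the L2 norm of every output channel of `weight_orig` yourself. The minimal sketch below uses a standalone convolution with conv2's shape and a fixed seed, both chosen for illustration. The channels masked by `ln_structured` should be exactly the three with the smallest norms.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
# Prune the 3 output channels (dim=0) with the smallest L2 norm
prune.ln_structured(conv, name="weight", amount=3, n=2, dim=0)

# L2 norm of each output channel of the original (unpruned) weights
norms = conv.weight_orig.detach().flatten(start_dim=1).norm(p=2, dim=1)
smallest3 = set(torch.argsort(norms)[:3].tolist())

# Indices of the channels whose mask slice is entirely zero
channel_is_pruned = conv.weight_mask.flatten(start_dim=1).sum(dim=1) == 0
pruned = set(torch.nonzero(channel_is_pruned).flatten().tolist())

print(pruned == smallest3)  # True
```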

```python
# Make the pruning of module1 permanent with remove
prune.remove(module1, 'weight')
print('*'*50)

# Print the state-dict keys of the pruned model
print(model.state_dict().keys())

# Print the module's parameters again
print(list(module1.named_parameters()))
print('*'*50)

# Print the module's mask buffers again
print(list(module1.named_buffers()))
print('*'*50)

# Print the module's _forward_pre_hooks again
print(module1._forward_pre_hooks)
```

```
**************************************************
odict_keys(['conv1.bias', 'conv1.weight', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
[('bias', Parameter containing:
tensor([ 0.1717, -0.1737, -0.0202,  0.1787,  0.0203, -0.1731],
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0199, -0.1327, -0.1888,  0.0135, -0.1062],
          [-0.1044, -0.1658,  0.1823, -0.0305, -0.1942],
          [-0.1320,  0.0258, -0.0003, -0.0305, -0.1435],
          [-0.0992, -0.0019,  0.1422, -0.1876, -0.1113],
          [ 0.0927, -0.0115,  0.1514,  0.0101,  0.1839]]],


        [[[ 0.1230, -0.1329,  0.0667,  0.0866, -0.1326],
          [ 0.1693,  0.0348,  0.0219,  0.0047, -0.0020],
          [ 0.1351, -0.0226,  0.1412, -0.0835,  0.1451],
          [-0.1025, -0.0730, -0.0170, -0.0361, -0.0915],
          [ 0.0103,  0.0657, -0.0003, -0.1348,  0.1867]]],


        [[[-0.0462, -0.1972,  0.1027,
...
```

```python
# Make the pruning of module2 permanent with remove
prune.remove(module2, 'weight')
print('*'*50)

# Print the state-dict keys of the pruned model
print(model.state_dict().keys())

# Print the module's parameters again
print(list(module2.named_parameters()))
print('*'*50)

# Print the module's mask buffers again
print(list(module2.named_buffers()))
print('*'*50)

# Print the module's _forward_pre_hooks again
print(module2._forward_pre_hooks)
```

```
**************************************************
odict_keys(['conv1.bias', 'conv1.weight', 'conv2.bias', 'conv2.weight', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
[('bias', Parameter containing:
tensor([-0.0397, -0.0521, -0.0037, -0.0105,  0.0705, -0.0071,  0.0337,  0.0399,
         0.0788, -0.0610,  0.0180,  0.0376, -0.0044,  0.0465,  0.0462, -0.0157],
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0210, -0.0655,  0.0230,  0.0613, -0.0271],
          [-0.0699,  0.0637, -0.0151, -0.0677,  0.0258],
          [ 0.0524,  0.0122,  0.0715,  0.0532,  0.0003],
          [-0.0535,  0.0701,  0.0092, -0.0380, -0.0733],
          [-0.0700,  0.0787, -0.0032,  0.0699,  0.0131]],

         [[-0.0752, -0.0405, -0.0301, -0.0688,  0.0344],
          [ 0.0741, -0.0329, -0.0034, -0.0423,  0.0569],
          [-0.0753,  0.0278,  0.0411,  0.0265,  0.0094],
          [-0.0207,  0.0377, -0.0015,  0.0758, -0.0226],
          [-0.0267, -0.0606,
...
```

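Conclusion: after `remove` is called on `weight`, `weight_orig` disappears from the parameter list. It is replaced by an ordinary `weight` parameter that holds the pruned values, so the pruning of `weight` is now permanent. The state dict confirms this: `conv1.weight_orig`/`conv1.weight_mask` and `conv2.weight_orig`/`conv2.weight_mask` have been replaced by `conv1.weight` and `conv2.weight`. Because the mask has already been applied, `weight_mask` is no longer kept in `named_buffers`. The corresponding pruning hook is also removed from `_forward_pre_hooks`. The `bias` parameters were never pruned and are unchanged.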
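The effect of `remove` can be checked programmatically with `prune.is_pruned`, which reports whether a module still has a pruning hook. The sketch below uses a standalone convolution with conv2's shape and a fixed seed, both assumptions for illustration. After `remove`, the zeros live in an ordinary `Parameter`: 3 channels × 6 × 5 × 5 = 450 zero weights. This count assumes that no kept weight happens to be exactly 0, which is practically certain with random initialization.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
prune.ln_structured(conv, name="weight", amount=3, n=2, dim=0)
print(prune.is_pruned(conv))  # True

prune.remove(conv, "weight")
print(prune.is_pruned(conv))                    # False
print(isinstance(conv.weight, nn.Parameter))    # True
print(list(dict(conv.named_buffers()).keys()))  # []
print(len(conv._forward_pre_hooks))             # 0

# The zeros are now baked into the Parameter: 3 * 6 * 5 * 5 = 450
print(int((conv.weight == 0).sum()))  # 450
```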

## 2. Pruning Multiple Parameters

```python
model = LeNet().to(device=device)

# Print all state-dict keys of the freshly initialized model
print(model.state_dict().keys())
print('*'*50)

# Print the names of the initial model's mask buffers
print(dict(model.named_buffers()).keys())
print('*'*50)

# Prune the model's parameters module by module
for name, module in model.named_modules():
    # Apply l1_unstructured to every convolutional layer, pruning 20% of its weights
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=0.2)
    # Apply ln_structured (L2 norm) to every fully connected layer,
    # pruning 40% of its output rows (dim=0)
    elif isinstance(module, torch.nn.Linear):
        prune.ln_structured(module, name="weight", amount=0.4, n=2, dim=0)

# Print the mask buffer names after pruning
print(dict(model.named_buffers()).keys())
print('*'*50)

# Print all state-dict keys after pruning
print(model.state_dict().keys())
```

```
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
**************************************************
dict_keys([])
**************************************************
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])
**************************************************
odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'fc1.bias', 'fc1.weight_orig', 'fc1.weight_mask', 'fc2.bias', 'fc2.weight_orig', 'fc2.weight_mask', 'fc3.bias', 'fc3.weight_orig', 'fc3.weight_mask'])
```
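For the fully connected layers, `amount=0.4` with `dim=0` removes whole output rows (neurons), `round(0.4 * out_features)` per layer. The sketch below uses hypothetical standalone `nn.Linear` layers with the same shapes as LeNet's fc1 to fc3 and counts the all-zero mask rows in each.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

# Hypothetical stand-ins with the same shapes as LeNet's fc1, fc2 and fc3
fcs = nn.ModuleDict({
    "fc1": nn.Linear(256, 120),
    "fc2": nn.Linear(120, 84),
    "fc3": nn.Linear(84, 10),
})

pruned_rows = {}
for name, fc in fcs.items():
    prune.ln_structured(fc, name="weight", amount=0.4, n=2, dim=0)
    # A row (one output neuron) is pruned when its whole mask row is zero
    pruned_rows[name] = int((fc.weight_mask.sum(dim=1) == 0).sum())

# round(0.4 * 120) = 48, round(0.4 * 84) = round(33.6) = 34, round(0.4 * 10) = 4
print(pruned_rows)  # {'fc1': 48, 'fc2': 34, 'fc3': 4}
```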


## 3. Global Pruning

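A more common and more general strategy is global pruning. For example, it removes 20% of the weights measured across the whole network instead of 20% from every layer. As a result, different layers can lose different fractions of their weights.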

```python
model = LeNet().to(device=device)

# First print the state dict of the freshly initialized model
print(model.state_dict().keys())
print('*'*50)

# Build the collection of (module, parameter name) pairs that take part in pruning
parameters_to_prune = (
            (model.conv1, 'weight'),
            (model.conv2, 'weight'),
            (model.fc1, 'weight'),
            (model.fc2, 'weight'))

# Call the global pruning function global_unstructured: prune 20% of all weights
# listed in parameters_to_prune (fc3 is not listed, so it is left untouched)
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)

# Finally, print the state dict of the pruned model
print(model.state_dict().keys())
```

```
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
**************************************************
odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'fc1.bias', 'fc1.weight_orig', 'fc1.weight_mask', 'fc2.bias', 'fc2.weight_orig', 'fc2.weight_mask', 'fc3.weight', 'fc3.bias'])
```


After global pruning, different layers lose different fractions of their weights. The following code prints them:

```python
model = LeNet().to(device=device)

parameters_to_prune = (
            (model.conv1, 'weight'),
            (model.conv2, 'weight'),
            (model.fc1, 'weight'),
            (model.fc2, 'weight'))

prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
    100. * float(torch.sum(model.conv1.weight == 0))
    / float(model.conv1.weight.nelement())
    ))

print(
    "Sparsity in conv2.weight: {:.2f}%".format(
    100. * float(torch.sum(model.conv2.weight == 0))
    / float(model.conv2.weight.nelement())
    ))

print(
    "Sparsity in fc1.weight: {:.2f}%".format(
    100. * float(torch.sum(model.fc1.weight == 0))
    / float(model.fc1.weight.nelement())
    ))

print(
    "Sparsity in fc2.weight: {:.2f}%".format(
    100. * float(torch.sum(model.fc2.weight == 0))
    / float(model.fc2.weight.nelement())
    ))


print(
    "Global sparsity: {:.2f}%".format(
    100. * float(torch.sum(model.conv1.weight == 0)
               + torch.sum(model.conv2.weight == 0)
               + torch.sum(model.fc1.weight == 0)
               + torch.sum(model.fc2.weight == 0))
         / float(model.conv1.weight.nelement()
               + model.conv2.weight.nelement()
               + model.fc1.weight.nelement()
               + model.fc2.weight.nelement())
    ))
```

```
Sparsity in conv1.weight: 8.67%
Sparsity in conv2.weight: 16.67%
Sparsity in fc1.weight: 21.78%
Sparsity in fc2.weight: 15.55%
Global sparsity: 20.00%
```


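Conclusion: with a global pruning strategy (here with a 20% ratio), only the overall fraction is guaranteed. Exactly 20% of the weights in the selected parameters are pruned. How much each layer loses depends on how weight magnitudes are distributed across the layers, because `L1Unstructured` removes the smallest-magnitude weights from the combined set.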
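The repeated sparsity prints above can be folded into a small helper. The sketch below applies it to a hypothetical two-layer model rather than LeNet. `sparsity` and the layer names `fc_a`/`fc_b` are illustrative, not part of PyTorch. With 5,000 + 500 = 5,500 weights, global pruning at 20% removes exactly 1,100 of them, so the global figure is exactly 20%. The per-layer figures depend on the random initialization.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune


def sparsity(t: torch.Tensor) -> float:
    """Percentage of entries of t that are exactly zero (hypothetical helper)."""
    return 100.0 * float((t == 0).sum()) / t.nelement()


torch.manual_seed(0)
# Hypothetical small model; any modules that own a "weight" work the same way
model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10))
layers = {"fc_a": model[0], "fc_b": model[2]}
parameters_to_prune = [(m, "weight") for m in layers.values()]

# Prune 20% of all listed weights by global L1 magnitude: round(0.2 * 5500) = 1100
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)

per_layer = {name: sparsity(m.weight) for name, m in layers.items()}
total_zeros = sum(int((m.weight == 0).sum()) for m in layers.values())
total_params = sum(m.weight.nelement() for m in layers.values())
global_sparsity = 100.0 * total_zeros / total_params

for name, s in per_layer.items():
    print(f"Sparsity in {name}.weight: {s:.2f}%")
print(f"Global sparsity: {global_sparsity:.2f}%")  # Global sparsity: 20.00%
```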


## 4. Custom Pruning

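Pruning methods are implemented by subclassing `BasePruningMethod`. It provides several methods, such as `__call__`, `apply_mask`, `apply`, `prune` and `remove`. In general, you only need to implement `__init__` and `compute_mask` (and set the class attribute `PRUNING_TYPE`) to define your own pruning rule.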

```python
# A custom pruning method must subclass prune.BasePruningMethod
class custom_prune(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    # compute_mask implements the user-defined pruning rule,
    # i.e. decides which entries of the parameter are masked out
    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        # Rule used here: mask out every other entry, so about 50% of the tensor is masked
        mask.view(-1)[::2] = 0
        return mask

# Convenience function that applies the custom method through the class method apply
def custom_unstructured_pruning(module, name):
    custom_prune.apply(module, name)
    return module
```

```python
import time
# Instantiate the model
model = LeNet().to(device=device)

start = time.time()
# Apply the custom pruning to the bias of the second fully connected layer, fc2
custom_unstructured_pruning(model.fc2, name="bias")

# The clearest sign that pruning succeeded: the module now has a bias_mask buffer
print(model.fc2.bias_mask)

# Print how long the custom pruning took
duration = time.time() - start
print(duration * 1000, 'ms')
```

```
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
        0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
        0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
        0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
        0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
1.9998550415039062 ms
```


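Conclusion: the printed `bias_mask` follows the predefined rule exactly. Zeros and ones alternate, so every other entry is masked. The mask already takes effect in every forward pass. A later call to `prune.remove` makes it permanent, so every other entry of `bias_orig` ends up zeroed in the resulting `bias` parameter. On the GPU machine used here, applying the custom pruning was very fast, about 2 ms in the run above. A single wall-clock measurement like this is only a rough indication.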
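The custom method above needs no `__init__` because its rule has no settings. When the rule has a parameter, `__init__` stores it. `apply` forwards extra keyword arguments to the constructor. Below is a minimal sketch of a hypothetical threshold-based method; `ThresholdPruning` and `threshold_pruning` are illustrative names, not part of PyTorch. It masks every entry whose magnitude is below a given threshold and applies it to a small layer with hand-picked weights.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune


class ThresholdPruning(prune.BasePruningMethod):
    """Hypothetical method: mask every entry whose magnitude is below `threshold`."""
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, t, default_mask):
        # Keep entries with |t| >= threshold; entries already masked stay masked
        keep = (t.abs() >= self.threshold).to(default_mask.dtype)
        return default_mask * keep


def threshold_pruning(module, name, threshold):
    # apply() passes the extra keyword argument on to ThresholdPruning.__init__
    ThresholdPruning.apply(module, name, threshold=threshold)
    return module


fc = nn.Linear(4, 3)
with torch.no_grad():
    fc.weight.copy_(torch.tensor([[0.05, -0.5, 0.2, -0.01],
                                  [0.3, 0.0, -0.08, 0.9],
                                  [-0.2, 0.15, 0.02, -0.6]]))

threshold_pruning(fc, "weight", threshold=0.1)
print(fc.weight_mask)
# tensor([[0., 1., 1., 0.],
#         [1., 0., 0., 1.],
#         [1., 1., 0., 1.]])
```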