In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

创建一个经典的CNN网络

In [2]:
# 定义一个CNN网络
class CNN(nn.Module):
    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=64 * 8 * 8, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=num_classes)

    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))

        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device=device)

In [3]:
module = model.conv1
print(list(module.named_parameters()))

[('weight', Parameter containing:
tensor([[[[ 1.5348e-01, -1.4900e-01, -1.1842e-01],
          [ 2.4871e-02,  1.5275e-01,  1.1659e-01],
          [ 1.5592e-01, -1.3514e-01, -7.5043e-03]],

         [[ 1.2052e-02,  1.0355e-01, -7.6704e-02],
          [ 4.9685e-02,  1.3783e-01,  1.2183e-01],
          [-1.2113e-01, -1.0311e-01,  8.8395e-02]],

         [[-9.9745e-02, -5.7398e-02,  1.7571e-01],
          [ 2.3906e-02,  4.4312e-02,  9.1364e-02],
          [-1.2461e-01,  1.4030e-01, -7.7255e-02]]],


        [[[-3.0464e-02, -9.5161e-02, -6.0011e-02],
          [ 1.3134e-01, -1.7915e-01,  3.2132e-02],
          [ 7.6910e-02,  2.6616e-02,  3.2158e-02]],

         [[-1.7325e-01,  5.8404e-03, -1.8473e-02],
          [ 1.4755e-01,  1.9016e-01, -1.6464e-01],
          [ 1.4077e-01,  8.5270e-02, -1.8393e-01]],

         [[-4.3645e-02,  7.0406e-02, -1.7177e-01],
          [ 1.0164e-01,  5.5904e-02,  1.1319e-01],
          [-9.6777e-02, -1.6566e-01,  1.0832e-02]]],


        [[[-8.5839e-02, -1.1341e

In [4]:
print(list(module.named_buffers()))

[]


In [5]:
# 第一个参数: module, 代表要进行剪枝的特定模块, 这里指的是module=model.conv1,
#             说明这里要对第一个卷积层执行剪枝.
# 第二个参数: name, 指定要对选中的模块中的哪些参数执行剪枝.
#             这里设定为name="weight", 意味着对连接网络中的weight剪枝, 而不对bias剪枝.
# 第三个参数: amount, 指定要对模型中多大比例的参数执行剪枝.
#             amount是一个介于0.0-1.0的float数值, 或者一个正整数指定剪裁掉多少条连接边.

prune.random_unstructured(module, name="weight", amount=0.3)

Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

In [6]:
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([ 0.1137, -0.1016,  0.1808, -0.1531, -0.1572, -0.1555,  0.1419,  0.1511,
        -0.1444, -0.0748, -0.1287, -0.0134, -0.1542,  0.0316, -0.1716, -0.0265,
         0.1560, -0.0430, -0.1248,  0.0514,  0.0445, -0.1000,  0.0385, -0.1035,
        -0.0260,  0.0857, -0.0459, -0.0254, -0.0110, -0.1802,  0.0544, -0.0310],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 1.5348e-01, -1.4900e-01, -1.1842e-01],
          [ 2.4871e-02,  1.5275e-01,  1.1659e-01],
          [ 1.5592e-01, -1.3514e-01, -7.5043e-03]],

         [[ 1.2052e-02,  1.0355e-01, -7.6704e-02],
          [ 4.9685e-02,  1.3783e-01,  1.2183e-01],
          [-1.2113e-01, -1.0311e-01,  8.8395e-02]],

         [[-9.9745e-02, -5.7398e-02,  1.7571e-01],
          [ 2.3906e-02,  4.4312e-02,  9.1364e-02],
          [-1.2461e-01,  1.4030e-01, -7.7255e-02]]],


        [[[-3.0464e-02, -9.5161e-02, -6.0011e-02],
          [ 1.3134e-01, -1.7915e-01,  3.2132e-02],
          [

In [7]:
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [0., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 0., 1.]]],


        [[[1., 0., 1.],
          [1., 0., 1.],
          [0., 1., 0.]],

         [[0., 1., 1.],
          [0., 1., 1.],
          [1., 1., 0.]],

         [[1., 0., 0.],
          [0., 1., 0.],
          [1., 1., 1.]]],


        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 1.]],

         [[1., 0., 0.],
          [0., 1., 0.],
          [1., 0., 1.]]],


        [[[0., 1., 1.],
          [1., 1., 1.],
          [0., 1., 1.]],

         [[0., 0., 1.],
          [1., 1., 0.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 0.],
          [0., 0., 1.]]],


        [[[0., 0., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1.

结论: 模型经历剪枝操作后, 原始的权重矩阵weight参数不见了, 变成了weight_orig. 并且刚刚打印为空列表的module.named_buffers(), 此时拥有了一个weight_mask参数.

这时打印module.weight属性值, 看看有什么启发?

In [8]:
print(module.weight)

tensor([[[[ 0.1535, -0.1490, -0.1184],
          [ 0.0249,  0.1527,  0.1166],
          [ 0.1559, -0.1351, -0.0075]],

         [[ 0.0121,  0.1035, -0.0767],
          [ 0.0497,  0.1378,  0.1218],
          [-0.0000, -0.1031,  0.0000]],

         [[-0.0997, -0.0574,  0.1757],
          [ 0.0239,  0.0443,  0.0914],
          [-0.1246,  0.0000, -0.0773]]],


        [[[-0.0305, -0.0000, -0.0600],
          [ 0.1313, -0.0000,  0.0321],
          [ 0.0000,  0.0266,  0.0000]],

         [[-0.0000,  0.0058, -0.0185],
          [ 0.0000,  0.1902, -0.1646],
          [ 0.1408,  0.0853, -0.0000]],

         [[-0.0436,  0.0000, -0.0000],
          [ 0.0000,  0.0559,  0.0000],
          [-0.0968, -0.1657,  0.0108]]],


        [[[-0.0858, -0.0000,  0.0000],
          [ 0.0087, -0.0030, -0.1767],
          [ 0.1128, -0.0203, -0.0000]],

         [[ 0.1185, -0.0912, -0.0550],
          [ 0.1865,  0.0960,  0.0000],
          [-0.0425,  0.1350,  0.0423]],

         [[-0.0585, -0.0000,  0.0000],
     

结论: 经过剪枝操作后的模型, 原始的参数存放在了weight_orig中, 对应的剪枝矩阵存放在weight_mask中, 而将weight_mask视作掩码张量, 再和weight_orig相乘的结果就存放在了weight中.

注意: 剪枝操作后的weight已经不再是module的参数(parameter), 而只是module的一个属性(attribute).

对于每一次剪枝操作, 模型都会对应一个具体的_forward_pre_hooks函数用于剪枝.

In [9]:
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x000001FC64B87B80>)])


In [10]:
# 第一个参数: module, 代表剪枝的对象, 此处代表LeNet中的conv1
# 第二个参数: name, 代表剪枝对象中的具体参数, 此处代表偏置量
# 第三个参数: amount, 代表剪枝的数量, 可以设置为0.0-1.0之间表示比例, 也可以用正整数表示剪枝的参数绝对数量
prune.l1_unstructured(module, name="bias", amount=3)

# 再次打印模型参数
print(list(module.named_parameters()))
print('*'*50)
print(list(module.named_buffers()))
print('*'*50)
print(module.bias)
print('*'*50)
print(module._forward_pre_hooks)

[('weight_orig', Parameter containing:
tensor([[[[ 1.5348e-01, -1.4900e-01, -1.1842e-01],
          [ 2.4871e-02,  1.5275e-01,  1.1659e-01],
          [ 1.5592e-01, -1.3514e-01, -7.5043e-03]],

         [[ 1.2052e-02,  1.0355e-01, -7.6704e-02],
          [ 4.9685e-02,  1.3783e-01,  1.2183e-01],
          [-1.2113e-01, -1.0311e-01,  8.8395e-02]],

         [[-9.9745e-02, -5.7398e-02,  1.7571e-01],
          [ 2.3906e-02,  4.4312e-02,  9.1364e-02],
          [-1.2461e-01,  1.4030e-01, -7.7255e-02]]],


        [[[-3.0464e-02, -9.5161e-02, -6.0011e-02],
          [ 1.3134e-01, -1.7915e-01,  3.2132e-02],
          [ 7.6910e-02,  2.6616e-02,  3.2158e-02]],

         [[-1.7325e-01,  5.8404e-03, -1.8473e-02],
          [ 1.4755e-01,  1.9016e-01, -1.6464e-01],
          [ 1.4077e-01,  8.5270e-02, -1.8393e-01]],

         [[-4.3645e-02,  7.0406e-02, -1.7177e-01],
          [ 1.0164e-01,  5.5904e-02,  1.1319e-01],
          [-9.6777e-02, -1.6566e-01,  1.0832e-02]]],


        [[[-8.5839e-02, -1.

结论: 在module的不同参数集合上应用不同的剪枝策略, 我们发现模型参数中不仅仅有了weight_orig, 也有了bias_orig. 在起到掩码张量作用的named_buffers中, 也同时出现了weight_mask和bias_mask. 最后, 因为我们在两类参数上应用了两种不同的剪枝函数, 因此_forward_pre_hooks中也打印出了2个不同的函数结果.

### 序列化一个剪枝模型

In [11]:
# 对于一个模型来说, 不管是它原始的参数, 拥有的属性值, 还是剪枝的mask buffers参数
# 全部都存储在模型的状态字典中, 即state_dict()中.
# 将模型初始的状态字典打印出来
print(model.state_dict().keys())


odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])


In [12]:
# 对模型进行剪枝操作, 分别在weight和bias上剪枝
module = model.conv1
prune.random_unstructured(module, name="weight", amount=0.3)
prune.l1_unstructured(module, name="bias", amount=3)

# 再将剪枝后的模型的状态字典打印出来
print(model.state_dict().keys())

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])


In [13]:
# 打印剪枝后的模型参数
print(list(module.named_parameters()))
print('*'*50)

# 打印剪枝后的模型mask buffers参数
print(list(module.named_buffers()))
print('*'*50)

# 打印剪枝后的模型weight属性值
print(module.weight)
print('*'*50)

# 打印模型的_forward_pre_hooks
print(module._forward_pre_hooks)
print('*'*50)


[('weight_orig', Parameter containing:
tensor([[[[ 1.5348e-01, -1.4900e-01, -1.1842e-01],
          [ 2.4871e-02,  1.5275e-01,  1.1659e-01],
          [ 1.5592e-01, -1.3514e-01, -7.5043e-03]],

         [[ 1.2052e-02,  1.0355e-01, -7.6704e-02],
          [ 4.9685e-02,  1.3783e-01,  1.2183e-01],
          [-1.2113e-01, -1.0311e-01,  8.8395e-02]],

         [[-9.9745e-02, -5.7398e-02,  1.7571e-01],
          [ 2.3906e-02,  4.4312e-02,  9.1364e-02],
          [-1.2461e-01,  1.4030e-01, -7.7255e-02]]],


        [[[-3.0464e-02, -9.5161e-02, -6.0011e-02],
          [ 1.3134e-01, -1.7915e-01,  3.2132e-02],
          [ 7.6910e-02,  2.6616e-02,  3.2158e-02]],

         [[-1.7325e-01,  5.8404e-03, -1.8473e-02],
          [ 1.4755e-01,  1.9016e-01, -1.6464e-01],
          [ 1.4077e-01,  8.5270e-02, -1.8393e-01]],

         [[-4.3645e-02,  7.0406e-02, -1.7177e-01],
          [ 1.0164e-01,  5.5904e-02,  1.1319e-01],
          [-9.6777e-02, -1.6566e-01,  1.0832e-02]]],


        [[[-8.5839e-02, -1.

In [14]:
# 执行剪枝永久化操作remove
prune.remove(module, 'weight')
print('*'*50)

# 将剪枝后的模型的状态字典打印出来
print(model.state_dict().keys())

# remove后再次打印模型参数
print(list(module.named_parameters()))
print('*'*50)

# remove后再次打印模型mask buffers参数
print(list(module.named_buffers()))
print('*'*50)

# remove后再次打印模型的_forward_pre_hooks
print(module._forward_pre_hooks)

**************************************************
odict_keys(['conv1.bias_orig', 'conv1.weight', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])
[('bias_orig', Parameter containing:
tensor([ 0.1137, -0.1016,  0.1808, -0.1531, -0.1572, -0.1555,  0.1419,  0.1511,
        -0.1444, -0.0748, -0.1287, -0.0134, -0.1542,  0.0316, -0.1716, -0.0265,
         0.1560, -0.0430, -0.1248,  0.0514,  0.0445, -0.1000,  0.0385, -0.1035,
        -0.0260,  0.0857, -0.0459, -0.0254, -0.0110, -0.1802,  0.0544, -0.0310],
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.1535, -0.1490, -0.1184],
          [ 0.0249,  0.1527,  0.0000],
          [ 0.1559, -0.0000, -0.0075]],

         [[ 0.0121,  0.1035, -0.0767],
          [ 0.0497,  0.1378,  0.1218],
          [-0.0000, -0.1031,  0.0000]],

         [[-0.0997, -0.0574,  0.1757],
          [ 0.0000,  0.0443,  0.0914],
          [-0.1246,  0.0000, -0.0773]]],


        [[[-0.0305, -

结论: 对模型的weight执行remove操作后, 模型参数集合中只剩下bias_orig了, weight_orig消失, 变成了weight, 说明针对weight的剪枝已经永久化生效. 对于named_buffers张量打印可以看出, 只剩下bias_mask了, 因为针对weight做掩码的weight_mask已经生效完毕, 不再需要保留了. 同理, 在_forward_pre_hooks中也只剩下针对bias做剪枝的函数了.

## 2.多参数模块的剪枝(Pruning multiple parameters).¶

In [15]:
model = CNN().to(device=device)

# 打印初始模型的所有状态字典
print(model.state_dict().keys())
print('*'*50)

# 打印初始模型的mask buffers张量字典名称
print(dict(model.named_buffers()).keys())
print('*'*50)

# 对于模型进行分模块参数的剪枝
for name, module in model.named_modules():
    # 对模型中所有的卷积层执行l1_unstructured剪枝操作, 选取20%的参数剪枝
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=0.2)
    # 对模型中所有全连接层执行ln_structured剪枝操作, 选取40%的参数剪枝
    elif isinstance(module, torch.nn.Linear):
        prune.ln_structured(module, name="weight", amount=0.4, n=2, dim=0)

# 打印多参数模块剪枝后的mask buffers张量字典名称
print(dict(model.named_buffers()).keys())
print('*'*50)

# 打印多参数模块剪枝后模型的所有状态字典名称
print(model.state_dict().keys())

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])
**************************************************
dict_keys([])
**************************************************
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask'])
**************************************************
odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'fc1.bias', 'fc1.weight_orig', 'fc1.weight_mask', 'fc2.bias', 'fc2.weight_orig', 'fc2.weight_mask'])


## 3.全局剪枝(GLobal pruning)

更普遍也更通用的剪枝策略是采用全局剪枝(global pruning), 比如在整体网络的视角下剪枝掉20%的权重参数, 而不是在每一层上都剪枝掉20%的权重参数. 采用全局剪枝后, 不同的层被剪掉的百分比不同.

In [16]:
model = CNN().to(device=device)

# 首先打印初始化模型的状态字典
print(model.state_dict().keys())
print('*'*50)

# 构建参数集合, 决定哪些层, 哪些参数集合参与剪枝
parameters_to_prune = (
            (model.conv1, 'weight'),
            (model.conv2, 'weight'),
            (model.fc1, 'weight'),
            (model.fc2, 'weight'))

# 调用prune中的全局剪枝函数global_unstructured执行剪枝操作, 此处针对整体模型中的20%参数量进行剪枝
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)

# 最后打印剪枝后的模型的状态字典
print(model.state_dict().keys())

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])
**************************************************


odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'fc1.bias', 'fc1.weight_orig', 'fc1.weight_mask', 'fc2.bias', 'fc2.weight_orig', 'fc2.weight_mask'])


针对模型剪枝后, 不同的层会有不同比例的权重参数被剪掉, 利用代码打印出来看看:

In [17]:
model = LeNet().to(device=device)

parameters_to_prune = (
            (model.conv1, 'weight'),
            (model.conv2, 'weight'),
            (model.fc1, 'weight'),
            (model.fc2, 'weight'))

prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
    100. * float(torch.sum(model.conv1.weight == 0))
    / float(model.conv1.weight.nelement())
    ))

print(
    "Sparsity in conv2.weight: {:.2f}%".format(
    100. * float(torch.sum(model.conv2.weight == 0))
    / float(model.conv2.weight.nelement())
    ))

print(
    "Sparsity in fc1.weight: {:.2f}%".format(
    100. * float(torch.sum(model.fc1.weight == 0))
    / float(model.fc1.weight.nelement())
    ))

print(
    "Sparsity in fc2.weight: {:.2f}%".format(
    100. * float(torch.sum(model.fc2.weight == 0))
    / float(model.fc2.weight.nelement())
    ))


print(
    "Global sparsity: {:.2f}%".format(
    100. * float(torch.sum(model.conv1.weight == 0)
               + torch.sum(model.conv2.weight == 0)
               + torch.sum(model.fc1.weight == 0)
               + torch.sum(model.fc2.weight == 0))
         / float(model.conv1.weight.nelement()
               + model.conv2.weight.nelement()
               + model.fc1.weight.nelement()
               + model.fc2.weight.nelement())
    ))

NameError: name 'LeNet' is not defined

结论: 当采用全局剪枝策略的时候(假定20%比例参数参与剪枝), 仅保证模型总体参数量的20%被剪枝掉, 具体到每一层的情况则由模型的具体参数分布情况来定.


## 4.用户自定义剪枝(Custom pruning).

剪枝模型通过继承class BasePruningMethod()来执行剪枝, 内部有若干方法: call, apply_mask, apply, prune, remove等等. 一般来说, 用户只需要实现__init__, 和compute_mask两个函数即可完成自定义的剪枝规则设定.

In [None]:
# 自定义剪枝方法的类, 一定要继承prune.BasePruningMethod
class custom_prune(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    # 内部实现compute_mask函数, 完成程序员自己定义的剪枝规则, 本质上就是如何去mask掉权重参数
    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        # 此处定义的规则是每隔一个参数就遮掩掉一个, 最终参与剪枝的参数量的50%被mask掉
        mask.view(-1)[::2] = 0
        return mask

# 自定义剪枝方法的函数, 内部直接调用剪枝类的方法apply
def custome_unstructured_pruning(module, name):
    custom_prune.apply(module, name)
    return module

In [None]:
import time
# 实例化模型类
model = CNN().to(device=device)

start = time.time()
# 调用自定义剪枝方法的函数, 对model中的第2个全连接层fc2中的偏置bias执行自定义剪枝
custome_unstructured_pruning(model.fc2, name="bias")

# 剪枝成功的最大标志, 就是拥有了bias_mask参数
print(model.fc2.bias_mask)

# 打印一下自定义剪枝的耗时
duration = time.time() - start
print(duration * 1000, 'ms')

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
1.9989013671875 ms


结论: 打印出来的bias_mask张量, 完全是按照预定义的方式每隔一位遮掩掉一位, 0和1交替出现, 后续执行remove操作的时候, 原始的bias_orig中的权重就会同样的被每隔一位剪枝掉一位. 在GPU机器上执行自定义剪枝速度特别快, 仅需1.7ms.