In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [2]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

'''搭建类LeNet网络'''
class LeNet(nn.Module):
    """LeNet-style CNN for single-channel images, producing 10 output scores.

    The layer sizes assume 32x32 inputs: each unpadded 5x5 convolution shrinks
    every spatial dimension by 4 and each 2x2 max-pool halves it
    (32 -> 28 -> 14 -> 10 -> 5), which yields the ``16 * 5 * 5`` features
    that ``fc1`` expects.
    """

    def __init__(self):
        super().__init__()
        # Single-channel image input, 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 3, 5)
        self.conv2 = nn.Conv2d(3, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Compute the raw (unnormalised) class scores for a batch of images.

        Args:
            x: Tensor of shape ``(batch, 1, H, W)`` whose spatial size reduces
                to 5x5 after the two conv/pool stages (e.g. 32x32).

        Returns:
            Tensor of shape ``(batch, 10)``.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension. The previous
        # ``nelement() / shape[0]`` computation divided by zero on an empty
        # batch; ``flatten`` gives the same result otherwise and also copes
        # with ``batch == 0``.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [3]:
# Instantiate the network on the selected device and pick its first
# convolution as the layer that the following cells will prune.
model = LeNet().to(device=device)
module = model.conv1
# conv1 is Conv2d(1, 3, 5), so this prints a 3x1x5x5 weight (3 filters) and a
# bias with 3 entries.
print(list(module.named_parameters()))

[('weight', Parameter containing:
tensor([[[[-0.0904,  0.0711, -0.0759,  0.0595, -0.1563],
          [-0.1085,  0.0188,  0.1486,  0.1716, -0.0291],
          [ 0.1657,  0.0028, -0.0296, -0.1090, -0.1923],
          [ 0.1616,  0.0829, -0.1888,  0.0814,  0.0476],
          [ 0.1709, -0.0672,  0.0499, -0.1713, -0.1772]]],


        [[[-0.1454, -0.1347, -0.0237,  0.1693, -0.1326],
          [-0.1398,  0.0666,  0.0255,  0.1590,  0.1502],
          [ 0.1559,  0.0760, -0.0553,  0.0052, -0.0605],
          [ 0.0317, -0.0124,  0.1549, -0.1674,  0.0629],
          [-0.1922,  0.0164, -0.1942,  0.1078,  0.1314]]],


        [[[-0.0781,  0.0147, -0.1763,  0.0891, -0.0867],
          [ 0.0728,  0.1424,  0.1553,  0.1235,  0.0102],
          [ 0.1219, -0.0986, -0.0100,  0.0934, -0.1461],
          [ 0.1884,  0.1153,  0.0886, -0.1132,  0.1022],
          [ 0.1443,  0.1996,  0.0964, -0.1735, -0.1491]]]], device='cuda:0',
       requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.1852, -0.058

In [4]:
# Structured pruning of conv1.weight: whole slices along dim=0 (one slice per
# output filter) are ranked by their L2 norm (n=2) and the requested fraction
# (amount=0.33) is masked out. In the recorded output one of the three filters
# is zeroed.
# NOTE: this mutates `module` in place, so re-running the cell prunes again.
prune.ln_structured(module, name="weight", amount=0.33, n=2, dim=0)
# `module.weight` is now a computed tensor (the printed grad_fn is
# MulBackward0) rather than a Parameter.
print(module.weight)

tensor([[[[-0.0904,  0.0711, -0.0759,  0.0595, -0.1563],
          [-0.1085,  0.0188,  0.1486,  0.1716, -0.0291],
          [ 0.1657,  0.0028, -0.0296, -0.1090, -0.1923],
          [ 0.1616,  0.0829, -0.1888,  0.0814,  0.0476],
          [ 0.1709, -0.0672,  0.0499, -0.1713, -0.1772]]],


        [[[-0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000,  0.0000]]],


        [[[-0.0781,  0.0147, -0.1763,  0.0891, -0.0867],
          [ 0.0728,  0.1424,  0.1553,  0.1235,  0.0102],
          [ 0.1219, -0.0986, -0.0100,  0.0934, -0.1461],
          [ 0.1884,  0.1153,  0.0886, -0.1132,  0.1022],
          [ 0.1443,  0.1996,  0.0964, -0.1735, -0.1491]]]], device='cuda:0',
       grad_fn=<MulBackward0>)


In [5]:
# After pruning, the parameters are 'bias' and 'weight_orig' (the untouched
# values); 'weight' is no longer listed as a parameter.
print(list(module.named_parameters()))
# The pruning mask is expected to appear among the buffers (as 'weight_mask'
# per torch.nn.utils.prune) -- the recorded output is truncated, so confirm.
print(list(module.named_buffers()))
# Shows which entries would be saved for the whole model at this point.
print(model.state_dict().keys())

[('bias', Parameter containing:
tensor([ 0.1852, -0.0585,  0.1808], device='cuda:0', requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.0904,  0.0711, -0.0759,  0.0595, -0.1563],
          [-0.1085,  0.0188,  0.1486,  0.1716, -0.0291],
          [ 0.1657,  0.0028, -0.0296, -0.1090, -0.1923],
          [ 0.1616,  0.0829, -0.1888,  0.0814,  0.0476],
          [ 0.1709, -0.0672,  0.0499, -0.1713, -0.1772]]],


        [[[-0.1454, -0.1347, -0.0237,  0.1693, -0.1326],
          [-0.1398,  0.0666,  0.0255,  0.1590,  0.1502],
          [ 0.1559,  0.0760, -0.0553,  0.0052, -0.0605],
          [ 0.0317, -0.0124,  0.1549, -0.1674,  0.0629],
          [-0.1922,  0.0164, -0.1942,  0.1078,  0.1314]]],


        [[[-0.0781,  0.0147, -0.1763,  0.0891, -0.0867],
          [ 0.0728,  0.1424,  0.1553,  0.1235,  0.0102],
          [ 0.1219, -0.0986, -0.0100,  0.0934, -0.1461],
          [ 0.1884,  0.1153,  0.0886, -0.1132,  0.1022],
          [ 0.1443,  0.1996,  0.0964, -0.1735, -0

In [6]:
# Make the pruning permanent: in the recorded output 'weight' is a plain
# Parameter again and the pruned filter is stored as zeros.
# NOTE: calling this a second time on an already-finalised module is expected
# to raise, because there is no pruning left to remove -- confirm before
# re-running this cell in isolation.
prune.remove(module, 'weight')
print(list(module.named_parameters()))
# Presumably empty now that the mask has been folded into 'weight' -- the
# recorded output is truncated, so confirm.
print(list(module.named_buffers()))
print(model.state_dict().keys())

[('bias', Parameter containing:
tensor([ 0.1852, -0.0585,  0.1808], device='cuda:0', requires_grad=True)), ('weight', Parameter containing:
tensor([[[[-0.0904,  0.0711, -0.0759,  0.0595, -0.1563],
          [-0.1085,  0.0188,  0.1486,  0.1716, -0.0291],
          [ 0.1657,  0.0028, -0.0296, -0.1090, -0.1923],
          [ 0.1616,  0.0829, -0.1888,  0.0814,  0.0476],
          [ 0.1709, -0.0672,  0.0499, -0.1713, -0.1772]]],


        [[[-0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000,  0.0000,  0.0000]]],


        [[[-0.0781,  0.0147, -0.1763,  0.0891, -0.0867],
          [ 0.0728,  0.1424,  0.1553,  0.1235,  0.0102],
          [ 0.1219, -0.0986, -0.0100,  0.0934, -0.1461],
          [ 0.1884,  0.1153,  0.0886, -0.1132,  0.1022],
          [ 0.1443,  0.1996,  0.0964, -0.1735, -0.1491

In [8]:
# List every parameter of the full model; conv1.weight keeps the zeroed filter
# from the pruning step.
# NOTE(review): the execution counter jumps from 6 to 8 -- verify with
# "Restart Kernel -> Run All" that no deleted cell is required.
list(model.named_parameters())

[('conv1.bias',
  Parameter containing:
  tensor([ 0.1852, -0.0585,  0.1808], device='cuda:0', requires_grad=True)),
 ('conv1.weight',
  Parameter containing:
  tensor([[[[-0.0904,  0.0711, -0.0759,  0.0595, -0.1563],
            [-0.1085,  0.0188,  0.1486,  0.1716, -0.0291],
            [ 0.1657,  0.0028, -0.0296, -0.1090, -0.1923],
            [ 0.1616,  0.0829, -0.1888,  0.0814,  0.0476],
            [ 0.1709, -0.0672,  0.0499, -0.1713, -0.1772]]],
  
  
          [[[-0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
            [-0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
            [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
            [ 0.0000, -0.0000,  0.0000, -0.0000,  0.0000],
            [-0.0000,  0.0000, -0.0000,  0.0000,  0.0000]]],
  
  
          [[[-0.0781,  0.0147, -0.1763,  0.0891, -0.0867],
            [ 0.0728,  0.1424,  0.1553,  0.1235,  0.0102],
            [ 0.1219, -0.0986, -0.0100,  0.0934, -0.1461],
            [ 0.1884,  0.1153,  0.0886, -0.1132,  0.1022],