In [9]:
import os

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torchvision import models
from torchsummary import summary

## Check CUDA

In [10]:
# Choose the compute device once; cuDNN autotuning (benchmark mode) is only
# switched on when a GPU is actually present.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    cudnn.benchmark = True
    print(torch.cuda.get_device_name())
else:
    print("Use CPU")

Quadro RTX 3000 with Max-Q Design


## Load Model

In [12]:
# Checkpoint path. os.path.join gives the same string as before on Windows and
# also resolves on Linux/macOS, where a literal backslash would be part of the
# filename instead of a directory separator.
PATH = os.path.join("my_weights", "Resnet18_e20_b5_t70_v30.pth")
# The checkpoint holds a whole pickled model object (it is used directly as a
# Module below). map_location lets a GPU-saved checkpoint load on a CPU-only
# machine; weights_only=False is required for full-model pickles on PyTorch
# versions whose default is weights_only=True. Unpickling can execute
# arbitrary code, so only load checkpoints you trust.
model = torch.load(PATH, map_location=device, weights_only=False).to(device)

  model = torch.load(PATH).to(device)


In [13]:
# Print the names of the model's direct children on one line, each followed by " / ".
for name, layer in model.named_children():
    print(f"{name} / ", end="")

conv1 / bn1 / relu / maxpool / layer1 / layer2 / layer3 / layer4 / avgpool / fc / 

## Inspect Module: conv1

In [5]:
# Walk every module recursively (the root model comes first, under the empty
# name "") and print "<qualified name> / <module repr>".
for name, module in model.named_modules():
    print(name, module, sep=" / ")

 / ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)

In [6]:
# Select the stem convolution and show its learnable parameters and buffers
# before any pruning is applied.
module = model.conv1
print("## parameters: ", [*module.named_parameters()])
print("## buffers: ", [*module.named_buffers()])

## parameters:  [('weight', Parameter containing:
tensor([[[[-1.0069e-02, -1.3557e-02, -9.1355e-03,  ...,  5.3608e-02,
            1.6570e-02, -1.4646e-02],
          [ 1.0844e-02,  3.3984e-03, -1.1661e-01,  ..., -2.7360e-01,
           -1.3079e-01,  2.8700e-03],
          [-1.1874e-02,  5.4071e-02,  2.9604e-01,  ...,  5.1904e-01,
            2.5732e-01,  6.2824e-02],
          ...,
          [-3.4586e-02,  6.7463e-03,  7.2761e-02,  ..., -3.3935e-01,
           -4.2843e-01, -2.6827e-01],
          [ 2.6184e-02,  3.4247e-02,  5.9439e-02,  ...,  4.0480e-01,
            3.8420e-01,  1.5571e-01],
          [-1.6383e-02, -7.2828e-03, -2.6395e-02,  ..., -1.5636e-01,
           -8.6250e-02, -1.2880e-02]],

         [[-1.5416e-02, -4.1712e-02, -5.0123e-02,  ...,  2.0020e-02,
           -1.0096e-02, -3.7144e-02],
          [ 3.9152e-02,  1.9342e-02, -1.1925e-01,  ..., -3.2379e-01,
           -1.7146e-01, -1.1399e-02],
          [-1.1461e-02,  8.5079e-02,  3.9374e-01,  ...,  6.9820e-01,
        

## Pruning a Module

### Local Pruning of a Single Module (conv1): Random Unstructured Pruning

In [6]:
# Zero a random 30% of conv1's weights. Afterwards the dense tensor lives in the
# `weight_orig` parameter, the 0/1 mask in the `weight_mask` buffer, and
# `weight` holds their product (see the next three cells).
# - torch.manual_seed makes the random mask reproducible on Restart & Run All.
# - The is_pruned guard keeps an accidental re-run from pruning a further 30%
#   of the weights that survived the first pass.
if not prune.is_pruned(module):
    torch.manual_seed(0)
    prune.random_unstructured(module, name="weight", amount=0.3)
module

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In [7]:
# After pruning, the original tensor is registered as `weight_orig`.
print([*module.named_parameters()])

[('weight_orig', Parameter containing:
tensor([[[[-1.0069e-02, -1.3557e-02, -9.1355e-03,  ...,  5.3608e-02,
            1.6570e-02, -1.4646e-02],
          [ 1.0844e-02,  3.3984e-03, -1.1661e-01,  ..., -2.7360e-01,
           -1.3079e-01,  2.8700e-03],
          [-1.1874e-02,  5.4071e-02,  2.9604e-01,  ...,  5.1904e-01,
            2.5732e-01,  6.2824e-02],
          ...,
          [-3.4586e-02,  6.7463e-03,  7.2761e-02,  ..., -3.3935e-01,
           -4.2843e-01, -2.6827e-01],
          [ 2.6184e-02,  3.4247e-02,  5.9439e-02,  ...,  4.0480e-01,
            3.8420e-01,  1.5571e-01],
          [-1.6383e-02, -7.2828e-03, -2.6395e-02,  ..., -1.5636e-01,
           -8.6250e-02, -1.2880e-02]],

         [[-1.5416e-02, -4.1712e-02, -5.0123e-02,  ...,  2.0020e-02,
           -1.0096e-02, -3.7144e-02],
          [ 3.9152e-02,  1.9342e-02, -1.1925e-01,  ..., -3.2379e-01,
           -1.7146e-01, -1.1399e-02],
          [-1.1461e-02,  8.5079e-02,  3.9374e-01,  ...,  6.9820e-01,
            3.6105e

In [8]:
# The binary pruning mask is stored as the `weight_mask` buffer.
print([*module.named_buffers()])

[('weight_mask', tensor([[[[1., 0., 1.,  ..., 0., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 0.],
          [0., 1., 0.,  ..., 0., 0., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 0., 0.],
          [0., 1., 1.,  ..., 1., 0., 1.],
          [0., 0., 1.,  ..., 1., 1., 1.]],

         [[1., 0., 0.,  ..., 0., 0., 1.],
          [1., 0., 1.,  ..., 0., 1., 0.],
          [0., 0., 0.,  ..., 0., 1., 1.],
          ...,
          [1., 1., 0.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 0., 0.],
          [0., 0., 1.,  ..., 1., 0., 0.]],

         [[1., 1., 1.,  ..., 0., 1., 1.],
          [1., 1., 0.,  ..., 0., 0., 1.],
          [0., 1., 1.,  ..., 1., 0., 0.],
          ...,
          [1., 0., 1.,  ..., 1., 1., 0.],
          [1., 1., 0.,  ..., 0., 1., 1.],
          [1., 1., 0.,  ..., 1., 0., 1.]]],


        [[[1., 1., 0.,  ..., 1., 1., 0.],
          [1., 0., 1.,  ..., 0., 0., 1.],
          [0., 0., 1.,  ..., 1., 0., 0.],
          ...,
          [1., 1., 1.,  ..., 1., 

In [9]:
# `weight` is now weight_orig * weight_mask: masked entries print as (-)0.
print(str(module.weight))

tensor([[[[-1.0069e-02, -0.0000e+00, -9.1355e-03,  ...,  0.0000e+00,
            1.6570e-02, -1.4646e-02],
          [ 1.0844e-02,  3.3984e-03, -1.1661e-01,  ..., -2.7360e-01,
           -1.3079e-01,  0.0000e+00],
          [-0.0000e+00,  5.4071e-02,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  6.2824e-02],
          ...,
          [-3.4586e-02,  6.7463e-03,  7.2761e-02,  ..., -3.3935e-01,
           -0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  3.4247e-02,  5.9439e-02,  ...,  4.0480e-01,
            0.0000e+00,  1.5571e-01],
          [-0.0000e+00, -0.0000e+00, -2.6395e-02,  ..., -1.5636e-01,
           -8.6250e-02, -1.2880e-02]],

         [[-1.5416e-02, -0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,
           -0.0000e+00, -3.7144e-02],
          [ 3.9152e-02,  0.0000e+00, -1.1925e-01,  ..., -0.0000e+00,
           -1.7146e-01, -0.0000e+00],
          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            3.6105e-01,  1.1450e-01],
          ...,
     

In [None]:
# Make the pruning permanent: delete `weight_orig` and `weight_mask`, and
# re-register `weight` as a plain parameter that keeps the zeroed entries.
prune.remove(module=module, name='weight')

In [17]:
# Only `weight` remains, now containing the pruned (zero) values directly.
print([*module.named_parameters()])

[('weight', Parameter containing:
tensor([[[[-1.0069e-02, -0.0000e+00, -9.1355e-03,  ...,  0.0000e+00,
            1.6570e-02, -1.4646e-02],
          [ 1.0844e-02,  3.3984e-03, -1.1661e-01,  ..., -2.7360e-01,
           -1.3079e-01,  0.0000e+00],
          [-0.0000e+00,  5.4071e-02,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  6.2824e-02],
          ...,
          [-3.4586e-02,  6.7463e-03,  7.2761e-02,  ..., -3.3935e-01,
           -0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  3.4247e-02,  5.9439e-02,  ...,  4.0480e-01,
            0.0000e+00,  1.5571e-01],
          [-0.0000e+00, -0.0000e+00, -2.6395e-02,  ..., -1.5636e-01,
           -8.6250e-02, -1.2880e-02]],

         [[-1.5416e-02, -0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,
           -0.0000e+00, -3.7144e-02],
          [ 3.9152e-02,  0.0000e+00, -1.1925e-01,  ..., -0.0000e+00,
           -1.7146e-01, -0.0000e+00],
          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            3.6105e-01, 

In [18]:
# Save next to the source checkpoint as "<name>(conv1 sp=0.3)<ext>".
# os.path.splitext only touches the final extension and always yields a name
# different from PATH. str.replace would rewrite every ".pth" in the path and
# would return PATH unchanged (overwriting the original weights) if PATH had
# no ".pth".
# NOTE(review): this writes a state_dict, whereas PATH itself is loaded as a
# full pickled model; consumers must rebuild the architecture and call
# load_state_dict.
root, ext = os.path.splitext(PATH)
model_path = f"{root}(conv1 sp=0.3){ext}"
torch.save(model.state_dict(), model_path)

### Global Pruning
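Resnet18: 對 conv1 / layer1 / layer2 / layer3 / layer4 / fc 做剪枝 (prune conv1, every Conv2d inside layer1–layer4, and fc together as one global pool).

Pruning methods (select via `pruning_method` below):
1. `prune.RandomUnstructured`
2. `prune.L1Unstructured`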

In [None]:
# Reload the unpruned checkpoint so global pruning starts from the original
# weights (the local-pruning demo above modified conv1 in place).
PATH = os.path.join("my_weights", "Resnet18_e20_b5_t70_v30.pth")
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# weights_only=False is required because the file stores a whole pickled model
# object; unpickling can execute code, so only load trusted checkpoints.
model = torch.load(PATH, map_location=device, weights_only=False).to(device)

# torchsummary's summary() defaults to device="cuda", so pass the active
# device to keep this working on CPU-only machines.
# NOTE(review): per its implementation, summary() runs a forward pass on random
# input without switching to eval mode. Run it in eval mode so BatchNorm running
# statistics are not updated with noise, then restore the previous mode.
was_training = model.training
model.eval()
summary(model, input_size=(3, 32, 32), device=device)
model.train(was_training)

  model = torch.load(PATH).to(device)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 61, 16, 16]           8,967
       BatchNorm2d-2           [-1, 61, 16, 16]             122
              ReLU-3           [-1, 61, 16, 16]               0
         MaxPool2d-4             [-1, 61, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]          35,136
       BatchNorm2d-6             [-1, 64, 8, 8]             128
              ReLU-7             [-1, 64, 8, 8]               0
            Conv2d-8             [-1, 61, 8, 8]          35,136
       BatchNorm2d-9             [-1, 61, 8, 8]             122
             ReLU-10             [-1, 61, 8, 8]               0
       BasicBlock-11             [-1, 61, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          35,136
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [-1, 6

In [22]:
# Collect (module, parameter-name) pairs for global pruning, in this order:
#   1. the stem conv1 (Conv2d) and the classifier fc (Linear);
#   2. every Conv2d found inside each residual block of layer1-layer4.
# block.modules() recurses, so a nested Conv2d (e.g. a downsample branch, if
# a block has one) is included as well.
parameters_to_prune = [
    (model.conv1, "weight"),  # Conv2d
    (model.fc, "weight"),  # Linear
]

for layer in (model.layer1, model.layer2, model.layer3, model.layer4):
    for block in layer:
        for module in block.modules():
            if not isinstance(module, torch.nn.Conv2d):
                continue
            parameters_to_prune.append((module, 'weight'))

In [23]:
## Pruning
## Pruning
# Remove 30% of all collected weights as a single pool: entries from every
# layer compete together, so the resulting sparsity differs per layer.
# To compare against random selection, set pruning_method=prune.RandomUnstructured.
prune.global_unstructured(
    parameters=parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.3,
)

In [24]:
def weight_sparsity_pct(tensor):
    """Return the percentage (0-100) of entries in `tensor` that are exactly zero."""
    return 100. * float(torch.sum(tensor == 0)) / float(tensor.nelement())


# Per-layer sparsity for a few representative layers. Global pruning draws
# from one pool across layers, so these percentages differ per layer.
for label, weight in [
    ("conv1.weight", model.conv1.weight),
    ("layer1[0].conv1.weight", model.layer1[0].conv1.weight),
    ("layer1[0].conv2.weight", model.layer1[0].conv2.weight),
    ("layer1[1].conv1.weight", model.layer1[1].conv1.weight),
    ("layer1[1].conv2.weight", model.layer1[1].conv2.weight),
]:
    print("Sparsity in {}: {:.2f}%".format(label, weight_sparsity_pct(weight)))
print("     ......      ")
print("Sparsity in fc.weight: {:.2f}%".format(weight_sparsity_pct(model.fc.weight)))
########################################################################
# Sparsity over every tensor registered for global pruning. Each tensor is
# looked up by the parameter name stored alongside its module, rather than a
# hard-coded `.weight`.
pruned_tensors = [getattr(m, pname) for m, pname in parameters_to_prune]
count_0 = sum(int(torch.sum(t == 0)) for t in pruned_tensors)
count_total = sum(t.nelement() for t in pruned_tensors)
print("Global sparsity: {:.2f}%".format(100. * count_0 / count_total))

Sparsity in conv1.weight: 19.96%
Sparsity in layer1[0].conv1.weight: 26.13%
Sparsity in layer1[0].conv2.weight: 15.75%
Sparsity in layer1[1].conv1.weight: 15.26%
Sparsity in layer1[1].conv2.weight: 15.62%
     ......      
Sparsity in fc.weight: 6.59%
Global sparsity: 30.00%


In [25]:
# Bake each global-pruning mask into its tensor: every pruned module loses
# weight_orig / weight_mask and gets a plain `weight` parameter back.
for module, _ in parameters_to_prune:
    prune.remove(module=module, name='weight')

In [26]:
# With unstructured pruning the Param # column does not change: weights are
# zeroed in place, not removed from the tensors.
# - torchsummary's summary() defaults to device="cuda", so pass the active
#   device to keep this working on CPU-only machines.
# - NOTE(review): per its implementation, summary() runs a forward pass on
#   random input without switching to eval mode. Run it in eval mode so
#   BatchNorm running statistics (saved with the model below) are not updated
#   with noise, then restore the previous mode.
was_training = model.training
model.eval()
summary(model, input_size=(3, 32, 32), device=device)
model.train(was_training)
# print(model.conv1.weight)  # confirm conv1 was actually pruned

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 16, 16]           9,408
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
         MaxPool2d-4             [-1, 64, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]          36,864
       BatchNorm2d-6             [-1, 64, 8, 8]             128
              ReLU-7             [-1, 64, 8, 8]               0
            Conv2d-8             [-1, 64, 8, 8]          36,864
       BatchNorm2d-9             [-1, 64, 8, 8]             128
             ReLU-10             [-1, 64, 8, 8]               0
       BasicBlock-11             [-1, 64, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          36,864
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [-1, 6

In [27]:
# Save next to the source checkpoint as "<name>(global L1-norm sp=0.3)<ext>".
# os.path.splitext only touches the final extension and always yields a name
# different from PATH, so the unpruned checkpoint cannot be overwritten.
# The whole model object is pickled, matching how PATH itself is loaded.
# NOTE(review): the "L1-norm sp=0.3" label is hard-coded; update it if
# pruning_method or amount in the global pruning cell is changed.
root, ext = os.path.splitext(PATH)
model_path = f"{root}(global L1-norm sp=0.3){ext}"
torch.save(model, model_path)