In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
backend = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(backend)

class LeNet(nn.Module):
    """LeNet-style CNN: two 3x3 convolutions followed by three linear layers.

    ``fc1`` expects 16*5*5 = 400 flattened features. With the unpadded 3x3
    convolutions and 2x2 max-pooling used in ``forward``, that is what a
    single-channel 28x28 input produces (28 -> 26 -> 13 -> 11 -> 5).
    """

    def __init__(self):
        super().__init__()

        # Feature extractor: 1 input channel -> 6 -> 16 feature maps.
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        # Classifier head: 400 -> 120 -> 84 -> 10 output scores.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Compute the 10 output scores for each sample in the batch ``x``.

        Args:
            x: input batch laid out as (batch, 1, height, width); the spatial
                size must reduce to 5x5 after the two conv/pool stages.

        Returns:
            Tensor of shape (batch, 10) with the raw outputs of ``fc3``
            (no softmax is applied).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Collapse every non-batch dimension into one feature vector per sample.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [3]:
# Instantiate the network and move its parameters to the selected device.
model = LeNet().to(device)

In [4]:
# Use the first convolution layer for the single-module pruning examples.
module = model.conv1

# Before pruning, the layer's only parameters are `weight` and `bias`.
print(list(module.named_parameters()))

[('weight', Parameter containing:
tensor([[[[-0.1344, -0.2570, -0.0766],
          [ 0.1902,  0.2730, -0.0570],
          [ 0.3233,  0.0250, -0.1096]]],


        [[[ 0.3006, -0.1488,  0.1619],
          [ 0.1015,  0.1041, -0.0207],
          [-0.0058, -0.2466,  0.1211]]],


        [[[-0.0090, -0.1213,  0.1847],
          [ 0.2169,  0.3054,  0.1787],
          [-0.2001, -0.1974, -0.3088]]],


        [[[-0.2924,  0.1331,  0.2393],
          [ 0.1412, -0.2841, -0.1181],
          [ 0.3100, -0.3310,  0.2429]]],


        [[[-0.0296, -0.1413,  0.3116],
          [ 0.3214,  0.0974, -0.1454],
          [-0.2844, -0.1496,  0.0366]]],


        [[[ 0.1690,  0.1320, -0.1280],
          [-0.1170,  0.0237,  0.2908],
          [ 0.0753, -0.1256,  0.2720]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.1460, -0.0215, -0.1873,  0.1574,  0.1332,  0.3207], device='cuda:0',
       requires_grad=True))]


## Pruning a Module

In [5]:
# Prune 30% of the entries of conv1's `weight`, chosen at random
# (a float `amount` is a fraction of the tensor's elements).
prune.random_unstructured(module, name='weight', amount=0.3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [6]:
# The unpruned tensor is now the parameter `weight_orig`; `weight` itself no
# longer appears among the module's parameters.
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([-0.1460, -0.0215, -0.1873,  0.1574,  0.1332,  0.3207], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.1344, -0.2570, -0.0766],
          [ 0.1902,  0.2730, -0.0570],
          [ 0.3233,  0.0250, -0.1096]]],


        [[[ 0.3006, -0.1488,  0.1619],
          [ 0.1015,  0.1041, -0.0207],
          [-0.0058, -0.2466,  0.1211]]],


        [[[-0.0090, -0.1213,  0.1847],
          [ 0.2169,  0.3054,  0.1787],
          [-0.2001, -0.1974, -0.3088]]],


        [[[-0.2924,  0.1331,  0.2393],
          [ 0.1412, -0.2841, -0.1181],
          [ 0.3100, -0.3310,  0.2429]]],


        [[[-0.0296, -0.1413,  0.3116],
          [ 0.3214,  0.0974, -0.1454],
          [-0.2844, -0.1496,  0.0366]]],


        [[[ 0.1690,  0.1320, -0.1280],
          [-0.1170,  0.0237,  0.2908],
          [ 0.0753, -0.1256,  0.2720]]]], device='cuda:0', requires_grad=True))]


In [7]:
# The binary pruning mask is stored as a buffer named `weight_mask`.
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]]],


        [[[1., 1., 1.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [0., 1., 1.],
          [0., 1., 1.]]],


        [[[0., 0., 1.],
          [0., 1., 0.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [0., 1., 1.],
          [0., 0., 0.]]],


        [[[1., 0., 1.],
          [1., 1., 0.],
          [1., 1., 0.]]]], device='cuda:0'))]


In [8]:
# Pruning is applied before each forward pass through the module's
# forward pre-hooks.
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x000002016B395088>)])


In [9]:
# An integer `amount` is an absolute count: prune the 3 bias entries with the
# smallest absolute value (L1 magnitude).
prune.l1_unstructured(module, name='bias', amount=3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [10]:
# A second pre-hook is now registered: one hook per pruned parameter.
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x000002016B395088>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x000002016B3BA8C8>)])


### Iterative Pruning

In [12]:
# Prune `weight` a second time: structured L2-norm pruning (n=2) that removes
# 50% of the channels along dim 0, applied on top of the earlier random mask.
prune.ln_structured(module, name='weight', amount=0.5,n=2, dim=0)
print(module.weight)

tensor([[[[-0.1344, -0.2570, -0.0766],
          [ 0.1902,  0.2730, -0.0570],
          [ 0.3233,  0.0250, -0.0000]]],


        [[[ 0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000]]],


        [[[-0.0090, -0.0000,  0.1847],
          [ 0.0000,  0.3054,  0.1787],
          [-0.0000, -0.1974, -0.3088]]],


        [[[-0.0000,  0.0000,  0.2393],
          [ 0.0000, -0.2841, -0.0000],
          [ 0.3100, -0.3310,  0.2429]]],


        [[[-0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)


In [14]:
# Locate the pre-hook that handles `weight`. After repeated pruning of the same
# parameter it is a container holding every pruning step applied so far.
for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == 'weight':
        break
else:
    # Without this guard the loop would fall through and silently leave `hook`
    # bound to the last, unrelated hook (e.g. the one for `bias`).
    raise LookupError("no pruning hook registered for 'weight'")

# Iterating the container yields the individual pruning methods in order.
print(list(hook))

[<torch.nn.utils.prune.RandomUnstructured object at 0x000002016B395088>, <torch.nn.utils.prune.LnStructured object at 0x0000020133ADE688>]


In [15]:
# conv1 is now serialized as its `*_orig` parameters plus `*_mask` buffers;
# the unpruned layers keep their plain `weight`/`bias` keys.
model.state_dict().keys()

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

In [16]:
# Make the pruning of `weight` permanent: `weight` becomes a regular parameter
# again (holding the masked values) and `weight_orig` disappears. `bias` is
# still re-parametrized, so `bias_orig` remains.
prune.remove(module, 'weight')
print(list(module.named_parameters()))

[('bias_orig', Parameter containing:
tensor([-0.1460, -0.0215, -0.1873,  0.1574,  0.1332,  0.3207], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[-0.1344, -0.2570, -0.0766],
          [ 0.1902,  0.2730, -0.0570],
          [ 0.3233,  0.0250, -0.0000]]],


        [[[ 0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000]]],


        [[[-0.0090, -0.0000,  0.1847],
          [ 0.0000,  0.3054,  0.1787],
          [-0.0000, -0.1974, -0.3088]]],


        [[[-0.0000,  0.0000,  0.2393],
          [ 0.0000, -0.2841, -0.0000],
          [ 0.3100, -0.3310,  0.2429]]],


        [[[-0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000,  0.0000]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]]], device='cuda:0', requires_grad=True))]


In [17]:
# Pruning multiple layers: build a fresh LeNet and prune each layer's `weight`
# by L1 magnitude -- 20% of the entries in conv layers, 40% in linear layers.
new_model = LeNet()
for name, module in new_model.named_modules():
    if isinstance(module, torch.nn.Linear):
        fraction = 0.4
    elif isinstance(module, torch.nn.Conv2d):
        fraction = 0.2
    else:
        # Anything else (e.g. the top-level LeNet module itself) is left alone.
        continue
    prune.l1_unstructured(module, name='weight', amount=fraction)

# Every pruned layer now exposes a `weight_mask` buffer.
print(dict(new_model.named_buffers()).keys())

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])


## Global Pruning