# Pruning Demo


In [1]:
import sys 
sys.path.append('../src/')

from models import LeNetFC 

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F


In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Instantiate the project's LeNetFC model (imported from ../src/) on that device.
model = LeNetFC().to(device)

## Inspect a module

In [3]:
# Work on the first fully connected layer. Before any pruning its parameters
# are simply 'weight' and 'bias'.
module = model.fc1
print(list(module.named_parameters()))



[('weight', Parameter containing:
tensor([[-0.0072,  0.0125, -0.0114,  ...,  0.0336, -0.0078,  0.0305],
        [ 0.0037,  0.0083, -0.0224,  ...,  0.0031,  0.0111,  0.0004],
        [ 0.0186, -0.0054,  0.0023,  ...,  0.0094, -0.0224, -0.0021],
        ...,
        [-0.0325,  0.0182,  0.0283,  ...,  0.0329,  0.0241, -0.0234],
        [ 0.0230, -0.0012, -0.0095,  ...,  0.0044, -0.0146, -0.0155],
        [ 0.0194,  0.0081, -0.0092,  ..., -0.0185, -0.0040, -0.0191]],
       requires_grad=True)), ('bias', Parameter containing:
tensor([-3.1015e-02,  2.3865e-02,  2.6890e-02, -3.1784e-02,  3.0637e-03,
         3.9912e-03,  1.5114e-02,  2.7278e-02, -9.3997e-03, -8.3075e-03,
        -2.4223e-03, -1.0882e-02, -2.8613e-02,  2.5166e-02,  1.3207e-02,
         2.3371e-02,  1.8772e-03,  2.6723e-02, -2.3492e-02,  3.3078e-02,
         1.8470e-02, -1.8554e-02, -1.8562e-02,  5.0716e-03,  3.5132e-02,
        -1.9280e-02,  1.1390e-02,  3.0036e-02, -1.6311e-02,  3.2229e-02,
        -7.9022e-03, -2.8195e-02, 

In [4]:
# Randomly prune entries of fc1's weight (the bias is left alone for now).
# A float `amount` is a fraction of the tensor's entries, so p = 0.3 masks 30%.
p = 0.3
prune.random_unstructured(module, name='weight', amount=p)


Linear(in_features=784, out_features=300, bias=True)

In [5]:
# 'weight' is no longer a parameter: the untouched values now live in the
# parameter 'weight_orig' (appended after 'bias'), while module.weight is an
# attribute holding the masked (pruned) values.
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([ 0.0208,  0.0041, -0.0104, -0.0334,  0.0135, -0.0234, -0.0034, -0.0069,
        -0.0122,  0.0261,  0.0019,  0.0091, -0.0218,  0.0020, -0.0077,  0.0352,
         0.0352,  0.0292, -0.0031, -0.0191, -0.0020,  0.0271,  0.0223, -0.0271,
        -0.0354,  0.0079, -0.0193, -0.0305, -0.0163,  0.0302, -0.0009, -0.0099,
        -0.0023,  0.0276,  0.0137,  0.0104, -0.0131, -0.0293,  0.0302, -0.0299,
        -0.0143,  0.0234, -0.0095,  0.0079, -0.0148,  0.0045, -0.0010,  0.0357,
        -0.0239, -0.0020,  0.0252,  0.0020, -0.0008, -0.0265,  0.0161, -0.0172,
        -0.0338,  0.0101,  0.0002, -0.0014, -0.0280,  0.0115,  0.0237, -0.0202,
        -0.0197,  0.0016,  0.0285, -0.0100, -0.0252,  0.0166, -0.0026,  0.0219,
        -0.0257, -0.0307,  0.0342,  0.0321,  0.0306, -0.0112, -0.0102, -0.0183,
        -0.0088, -0.0322,  0.0121,  0.0078,  0.0011,  0.0281,  0.0172, -0.0124,
         0.0080, -0.0259, -0.0216,  0.0236, -0.0061,  0.0152,  0.0257,  0.0233,
        

In [6]:
# The pruning mask is stored on the module as 'weight_mask'
# (1 = keep, 0 = pruned). Show its first two rows.
print(module.weight_mask[:2])


tensor([[1., 0., 1.,  ..., 1., 1., 0.],
        [1., 0., 0.,  ..., 1., 1., 1.]])


In [7]:
# module.weight is weight_orig with weight_mask applied: every position where
# the mask printed above is 0 shows up as (+/-)0.0000 here.
print(module.weight[:2])

tensor([[ 0.0259, -0.0000, -0.0148,  ..., -0.0338, -0.0061, -0.0000],
        [-0.0279,  0.0000,  0.0000,  ..., -0.0110,  0.0140, -0.0108]],
       grad_fn=<SliceBackward>)


In [8]:
# Prune the 3 bias entries with the smallest absolute value (L1 magnitude).
# An integer `amount` is an absolute number of entries, not a fraction.
prune.l1_unstructured(module, name="bias", amount=3)


Linear(in_features=784, out_features=300, bias=True)

In [9]:
# Both parameters are re-parametrized now: 'weight_orig' and 'bias_orig'.
print(list(module.named_parameters()))



[('weight_orig', Parameter containing:
tensor([[ 0.0259, -0.0029, -0.0148,  ..., -0.0338, -0.0061, -0.0222],
        [-0.0279,  0.0280,  0.0118,  ..., -0.0110,  0.0140, -0.0108],
        [-0.0050,  0.0167, -0.0258,  ..., -0.0210, -0.0264,  0.0171],
        ...,
        [ 0.0336, -0.0113,  0.0214,  ..., -0.0060, -0.0026,  0.0189],
        [-0.0337, -0.0168,  0.0072,  ..., -0.0021,  0.0215,  0.0010],
        [-0.0235, -0.0148, -0.0253,  ...,  0.0350, -0.0245, -0.0002]],
       requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.0208,  0.0041, -0.0104, -0.0334,  0.0135, -0.0234, -0.0034, -0.0069,
        -0.0122,  0.0261,  0.0019,  0.0091, -0.0218,  0.0020, -0.0077,  0.0352,
         0.0352,  0.0292, -0.0031, -0.0191, -0.0020,  0.0271,  0.0223, -0.0271,
        -0.0354,  0.0079, -0.0193, -0.0305, -0.0163,  0.0302, -0.0009, -0.0099,
        -0.0023,  0.0276,  0.0137,  0.0104, -0.0131, -0.0293,  0.0302, -0.0299,
        -0.0143,  0.0234, -0.0095,  0.0079, -0.0148,  0.0045, 

## Iterative Pruning

In [10]:
# Sum of the absolute values of the (already masked) weight; serves as the
# baseline for the next pruning step.
module.weight.abs().sum()

tensor(2940.2478, grad_fn=<SumBackward0>)

In [11]:
# Second pruning pass on the same tensor: remove p + .2 = 50% of the rows
# (dim=0) with the smallest L2 norm (n=2). This stacks on top of the random
# mask applied earlier.
prune.ln_structured(module, name='weight', amount=p+.2, n=2, dim=0)


Linear(in_features=784, out_features=300, bias=True)

In [12]:
# The total magnitude roughly halves: half of the rows are now zeroed out
# (pruned) entirely, on top of the random pruning done before.
module.weight.abs().sum()



tensor(1511.8176, grad_fn=<SumBackward0>)

In [13]:
# History of the pruning methods applied to the 'weight' parameter.
# Pruning is implemented as a forward pre-hook per pruned tensor; look the hook
# up by tensor name instead of relying on the loop variable left over after a
# `break` (which would silently be the wrong hook if nothing matched).
# getattr() skips unrelated pre-hooks that have no `_tensor_name`.
hook = next(
    (
        candidate
        for candidate in module._forward_pre_hooks.values()
        if getattr(candidate, '_tensor_name', None) == 'weight'
    ),
    None,
)

if hook is None:
    print("no pruning has been applied to 'weight'")
else:
    # Repeated pruning of one tensor yields an iterable PruningContainer; a
    # single pruning leaves a bare method object, so wrap it to keep the loop
    # working in both cases.
    history = hook if isinstance(hook, prune.PruningContainer) else [hook]
    for h in history:
        print(h)

<torch.nn.utils.prune.RandomUnstructured object at 0x7fae6835a590>
<torch.nn.utils.prune.LnStructured object at 0x7faea8835490>


In [14]:
# The pruned state is serializable: state_dict() contains the '*_orig' tensors
# and the '*_mask' tensors, so it can be saved and restored with the model.
print(model.state_dict().keys())


odict_keys(['fc1.weight_orig', 'fc1.bias_orig', 'fc1.weight_mask', 'fc1.bias_mask', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])


## Make Pruning Permanent
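Remove the pruning re-parametrization: `prune.remove` makes the pruning permanent, so `weight` becomes a regular parameter again (with the pruned entries left at zero) and `weight_orig` / `weight_mask` are dropped.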

In [15]:
# Before prune.remove: the parameters are still 'weight_orig' and 'bias_orig'.
print(list(module.named_parameters()))



[('weight_orig', Parameter containing:
tensor([[ 0.0259, -0.0029, -0.0148,  ..., -0.0338, -0.0061, -0.0222],
        [-0.0279,  0.0280,  0.0118,  ..., -0.0110,  0.0140, -0.0108],
        [-0.0050,  0.0167, -0.0258,  ..., -0.0210, -0.0264,  0.0171],
        ...,
        [ 0.0336, -0.0113,  0.0214,  ..., -0.0060, -0.0026,  0.0189],
        [-0.0337, -0.0168,  0.0072,  ..., -0.0021,  0.0215,  0.0010],
        [-0.0235, -0.0148, -0.0253,  ...,  0.0350, -0.0245, -0.0002]],
       requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.0208,  0.0041, -0.0104, -0.0334,  0.0135, -0.0234, -0.0034, -0.0069,
        -0.0122,  0.0261,  0.0019,  0.0091, -0.0218,  0.0020, -0.0077,  0.0352,
         0.0352,  0.0292, -0.0031, -0.0191, -0.0020,  0.0271,  0.0223, -0.0271,
        -0.0354,  0.0079, -0.0193, -0.0305, -0.0163,  0.0302, -0.0009, -0.0099,
        -0.0023,  0.0276,  0.0137,  0.0104, -0.0131, -0.0293,  0.0302, -0.0299,
        -0.0143,  0.0234, -0.0095,  0.0079, -0.0148,  0.0045, 

In [16]:
# Before prune.remove: both masks ('weight_mask', 'bias_mask') are buffers.
print(list(module.named_buffers()))

[('weight_mask', tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 1., 0.,  ..., 1., 0., 0.],
        [1., 1., 0.,  ..., 1., 1., 0.]])), ('bias_mask', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1

In [17]:
# Before prune.remove: module.weight is still computed from weight_orig and
# weight_mask (note grad_fn=<MulBackward0> in the output).
print(module.weight)




tensor([[ 0.0000, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
        [-0.0279,  0.0000,  0.0000,  ..., -0.0110,  0.0140, -0.0108],
        [-0.0000,  0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000],
        ...,
        [ 0.0000, -0.0000,  0.0000,  ..., -0.0000, -0.0000,  0.0000],
        [-0.0337, -0.0168,  0.0000,  ..., -0.0021,  0.0000,  0.0000],
        [-0.0235, -0.0148, -0.0000,  ...,  0.0350, -0.0245, -0.0000]],
       grad_fn=<MulBackward0>)


In [18]:
# prune.remove makes the pruning of 'weight' permanent: 'weight_orig' is gone
# and 'weight' is a plain parameter again (keeping the zeros at the pruned
# positions). The bias is not affected and is still listed as 'bias_orig'.
prune.remove(module, 'weight')
print(list(module.named_parameters()))


[('bias_orig', Parameter containing:
tensor([ 0.0208,  0.0041, -0.0104, -0.0334,  0.0135, -0.0234, -0.0034, -0.0069,
        -0.0122,  0.0261,  0.0019,  0.0091, -0.0218,  0.0020, -0.0077,  0.0352,
         0.0352,  0.0292, -0.0031, -0.0191, -0.0020,  0.0271,  0.0223, -0.0271,
        -0.0354,  0.0079, -0.0193, -0.0305, -0.0163,  0.0302, -0.0009, -0.0099,
        -0.0023,  0.0276,  0.0137,  0.0104, -0.0131, -0.0293,  0.0302, -0.0299,
        -0.0143,  0.0234, -0.0095,  0.0079, -0.0148,  0.0045, -0.0010,  0.0357,
        -0.0239, -0.0020,  0.0252,  0.0020, -0.0008, -0.0265,  0.0161, -0.0172,
        -0.0338,  0.0101,  0.0002, -0.0014, -0.0280,  0.0115,  0.0237, -0.0202,
        -0.0197,  0.0016,  0.0285, -0.0100, -0.0252,  0.0166, -0.0026,  0.0219,
        -0.0257, -0.0307,  0.0342,  0.0321,  0.0306, -0.0112, -0.0102, -0.0183,
        -0.0088, -0.0322,  0.0121,  0.0078,  0.0011,  0.0281,  0.0172, -0.0124,
         0.0080, -0.0259, -0.0216,  0.0236, -0.0061,  0.0152,  0.0257,  0.0233,
   

In [19]:
# 'weight_mask' has been removed as well; only 'bias_mask' is left.
print(list(module.named_buffers()))


[('bias_mask', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1

## Prune Layers

In [20]:
# List the three layers of the model. named_modules() yields the model itself
# first (index 0), so only positions 1..3 are printed.
for idx, (name, module) in enumerate(model.named_modules()):
    if not 0 < idx < 4:
        continue
    print(name, module)
        

fc1 Linear(in_features=784, out_features=300, bias=True)
fc2 Linear(in_features=300, out_features=100, bias=True)
fc3 Linear(in_features=100, out_features=10, bias=True)


In [21]:
# Prune every fully connected layer of the model:
#   * weight: the fraction p of entries with the smallest absolute value
#   * bias:   half of the entries (an integer amount is an absolute count)
# named_modules() also yields the model itself, hence the isinstance filter.
for name, module in model.named_modules():
    if not isinstance(module, torch.nn.Linear):
        continue
    prune.l1_unstructured(module, name='weight', amount=p)
    # A Linear layer created with bias=False has module.bias set to None;
    # skip it instead of failing on len(None).
    if module.bias is not None:
        prune.l1_unstructured(module, name='bias', amount=len(module.bias)//2)
        


In [23]:
# Every layer should now have both a 'weight_mask' and a 'bias_mask' buffer.
print(dict(model.named_buffers()).keys())  # to verify that all masks exist

dict_keys(['fc1.bias_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc2.bias_mask', 'fc3.weight_mask', 'fc3.bias_mask'])


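## TODO
* Global pruning (`prune.global_unstructured`, i.e. pruning across all layers at once)?
* Custom pruning functions (subclassing `prune.BasePruningMethod`)?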