In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [2]:
# Both the LeNet weight initialisation and `prune.random_unstructured` draw
# random numbers; seeding up front makes Restart & Run All reproduce the same
# weights and pruning masks. torch.manual_seed seeds the generators on all devices.
SEED = 0
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    """Small LeNet-style CNN used as the subject of the pruning examples.

    Two 5x5 convolutions (no padding), each followed by ReLU and 2x2 max-pooling,
    then three fully connected layers producing 10 raw scores per sample.
    fc1 expects a 16 x 5 x 5 feature map, which is what a single-channel 32x32
    input produces (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 channels x 5x5 spatial map left after conv2 + pooling (for 32x32 inputs)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a batch of images to class scores.

        Args:
            x: batched input of shape (N, 1, H, W); fc1 only fits when the conv/pool
               stack leaves a 5x5 map (e.g. H = W = 32).

        Returns:
            Tensor of shape (N, 10) with unnormalised scores (no softmax applied).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten every dimension except the batch one. Unlike dividing
        # nelement() by shape[0], this works for an empty batch and uses no float math.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Build the network and move its parameters to the selected device.
model = LeNet().to(device)

In [4]:
# Work on the first convolution; before pruning it exposes plain `weight` and `bias` parameters.
module = model.conv1
print([*module.named_parameters()])

[('weight', Parameter containing:
tensor([[[[-0.1311, -0.0106, -0.1297,  0.0740,  0.1125],
          [ 0.0003, -0.1689,  0.1007,  0.0911,  0.1673],
          [ 0.0724,  0.1513, -0.1970,  0.1194, -0.1911],
          [ 0.0301, -0.1660,  0.0438,  0.1178,  0.0236],
          [-0.1019,  0.1640,  0.1223, -0.1617,  0.0311]]],


        [[[-0.1381, -0.1333, -0.1951, -0.0924,  0.0854],
          [-0.0095,  0.1625,  0.1237,  0.0638,  0.0023],
          [ 0.1473, -0.0118,  0.0777,  0.0388,  0.0515],
          [ 0.0467, -0.1149, -0.1511, -0.0465,  0.1774],
          [ 0.0618, -0.1462, -0.0340, -0.0032, -0.1512]]],


        [[[-0.0478, -0.0595,  0.1983,  0.1490,  0.1821],
          [-0.0525,  0.0850, -0.0960, -0.1044, -0.0358],
          [ 0.1889, -0.0765, -0.0149, -0.0304,  0.0293],
          [ 0.0116,  0.1574,  0.0045, -0.1183,  0.1506],
          [ 0.1668,  0.0018,  0.0139, -0.1282, -0.0285]]],


        [[[ 0.1932,  0.1258,  0.0306, -0.0968,  0.0240],
          [ 0.0491,  0.1337,  0.0060, -0.1

In [5]:
# No buffers yet: the layer has not been pruned.
print([*module.named_buffers()])

[]


In [6]:
# Prune 30% (amount=0.3, a fraction) of conv1's weight entries, chosen at random.
# Afterwards the parameter is stored as `weight_orig` and the 0/1 mask as the
# buffer `weight_mask`. The guard keeps the cell idempotent: re-running it would
# otherwise stack a second random pruning on top of the first.
if not hasattr(module, "weight_mask"):
    prune.random_unstructured(module, name="weight", amount=0.3)
module

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [7]:
# After pruning, `weight` is no longer a parameter; its original values live in `weight_orig`.
print([*module.named_parameters()])

[('bias', Parameter containing:
tensor([-0.1185, -0.0826, -0.1663, -0.0797,  0.0712,  0.0842],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.1311, -0.0106, -0.1297,  0.0740,  0.1125],
          [ 0.0003, -0.1689,  0.1007,  0.0911,  0.1673],
          [ 0.0724,  0.1513, -0.1970,  0.1194, -0.1911],
          [ 0.0301, -0.1660,  0.0438,  0.1178,  0.0236],
          [-0.1019,  0.1640,  0.1223, -0.1617,  0.0311]]],


        [[[-0.1381, -0.1333, -0.1951, -0.0924,  0.0854],
          [-0.0095,  0.1625,  0.1237,  0.0638,  0.0023],
          [ 0.1473, -0.0118,  0.0777,  0.0388,  0.0515],
          [ 0.0467, -0.1149, -0.1511, -0.0465,  0.1774],
          [ 0.0618, -0.1462, -0.0340, -0.0032, -0.1512]]],


        [[[-0.0478, -0.0595,  0.1983,  0.1490,  0.1821],
          [-0.0525,  0.0850, -0.0960, -0.1044, -0.0358],
          [ 0.1889, -0.0765, -0.0149, -0.0304,  0.0293],
          [ 0.0116,  0.1574,  0.0045, -0.1183,  0.1506],
          [ 0.1668,  0.0018,  0.

In [8]:
# The pruning mask is now registered as the buffer `weight_mask` (0 = pruned entry).
print([*module.named_buffers()])

[('weight_mask', tensor([[[[0., 1., 1., 0., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.],
          [1., 1., 1., 0., 1.],
          [0., 1., 1., 1., 1.]]],


        [[[0., 0., 1., 1., 0.],
          [1., 0., 1., 1., 0.],
          [1., 1., 0., 0., 0.],
          [0., 1., 1., 1., 1.],
          [0., 1., 1., 0., 1.]]],


        [[[1., 1., 1., 0., 0.],
          [0., 1., 1., 1., 1.],
          [0., 1., 1., 0., 1.],
          [1., 0., 1., 1., 0.],
          [1., 1., 1., 1., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 0.],
          [0., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 0.],
          [0., 0., 1., 1., 1.],
          [1., 0., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 0., 1., 1., 0.]]],


        [[[0., 0., 1., 0., 0.],
          [1., 1., 1., 1., 0.],
          [1., 0., 1., 1., 1.],
          [1., 1., 0., 1., 0.],
          [1., 1., 0., 0., 1.]]]]))

In [9]:
# `weight` is the masked tensor: entries where `weight_mask` is 0 show up as zeros.
print(str(module.weight))

tensor([[[[-0.0000, -0.0106, -0.1297,  0.0000,  0.0000],
          [ 0.0003, -0.1689,  0.1007,  0.0911,  0.1673],
          [ 0.0724,  0.1513, -0.1970,  0.0000, -0.1911],
          [ 0.0301, -0.1660,  0.0438,  0.0000,  0.0236],
          [-0.0000,  0.1640,  0.1223, -0.1617,  0.0311]]],


        [[[-0.0000, -0.0000, -0.1951, -0.0924,  0.0000],
          [-0.0095,  0.0000,  0.1237,  0.0638,  0.0000],
          [ 0.1473, -0.0118,  0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.1149, -0.1511, -0.0465,  0.1774],
          [ 0.0000, -0.1462, -0.0340, -0.0000, -0.1512]]],


        [[[-0.0478, -0.0595,  0.1983,  0.0000,  0.0000],
          [-0.0000,  0.0850, -0.0960, -0.1044, -0.0358],
          [ 0.0000, -0.0765, -0.0149, -0.0000,  0.0293],
          [ 0.0116,  0.0000,  0.0045, -0.1183,  0.0000],
          [ 0.1668,  0.0018,  0.0139, -0.1282, -0.0000]]],


        [[[ 0.1932,  0.1258,  0.0306, -0.0968,  0.0240],
          [ 0.0491,  0.1337,  0.0060, -0.1103,  0.0267],
          [ 0.1390,

In [10]:
# Pruning is attached as a forward pre-hook on the module (one RandomUnstructured entry here).
print(repr(module._forward_pre_hooks))

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7efa41b40d10>)])


In [11]:
# Prune the bias entries with the smallest absolute values. Here `amount` is an
# int, so it is an absolute count (3 entries), not a fraction as in the 0.3 used
# for the weight. The guard keeps the cell idempotent on re-run.
if not hasattr(module, "bias_mask"):
    prune.l1_unstructured(module, name="bias", amount=3)
module

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [12]:
# With both tensors pruned, the parameters are `weight_orig` and `bias_orig`.
print([*module.named_parameters()])

[('weight_orig', Parameter containing:
tensor([[[[-0.1311, -0.0106, -0.1297,  0.0740,  0.1125],
          [ 0.0003, -0.1689,  0.1007,  0.0911,  0.1673],
          [ 0.0724,  0.1513, -0.1970,  0.1194, -0.1911],
          [ 0.0301, -0.1660,  0.0438,  0.1178,  0.0236],
          [-0.1019,  0.1640,  0.1223, -0.1617,  0.0311]]],


        [[[-0.1381, -0.1333, -0.1951, -0.0924,  0.0854],
          [-0.0095,  0.1625,  0.1237,  0.0638,  0.0023],
          [ 0.1473, -0.0118,  0.0777,  0.0388,  0.0515],
          [ 0.0467, -0.1149, -0.1511, -0.0465,  0.1774],
          [ 0.0618, -0.1462, -0.0340, -0.0032, -0.1512]]],


        [[[-0.0478, -0.0595,  0.1983,  0.1490,  0.1821],
          [-0.0525,  0.0850, -0.0960, -0.1044, -0.0358],
          [ 0.1889, -0.0765, -0.0149, -0.0304,  0.0293],
          [ 0.0116,  0.1574,  0.0045, -0.1183,  0.1506],
          [ 0.1668,  0.0018,  0.0139, -0.1282, -0.0285]]],


        [[[ 0.1932,  0.1258,  0.0306, -0.0968,  0.0240],
          [ 0.0491,  0.1337,  0.0060,