In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import copy

In [2]:
# Use the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class LeNet(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

    Both convolutions use 3x3 kernels and each is followed by ReLU and a 2x2
    max-pool. ``fc1`` takes 16 * 5 * 5 flattened features, so the input's
    spatial size must reduce to 5x5 after the two conv/pool stages
    (e.g. a 1-channel 28x28 image). The network emits 10 values per sample.
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run a batch through the network.

        Args:
            x: input tensor with the batch as dimension 0 and one channel.

        Returns:
            Tensor with one row of 10 un-normalised scores per input sample.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Collapse everything except the batch dimension. Unlike dividing the
        # element count by the batch size, this also works for an empty batch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Build the network and move it to the selected device
# (nn.Module.to relocates the parameters in place and returns the module).
model = LeNet()
model.to(device)
# Handle to the first convolution layer, used by the pruning cells.
module = model.conv1
# Independent copy of the freshly initialised parameters.
proto = copy.deepcopy(model.state_dict())

In [3]:
# Prune 10% of conv1's weight entries with the L1 (magnitude-based)
# unstructured criterion; the pruned entries show up as zeros in the output.
prune.l1_unstructured(module, name="weight", amount=0.1)
# Snapshot the state dict AFTER pruning. This rebinds `proto`, replacing the
# copy taken right after the model was constructed.
proto = copy.deepcopy(model.state_dict())
# Show the pruned weight. The grad_fn=<MulBackward0> in the output indicates
# it is the result of a multiplication — presumably original weight * mask.
print(module.weight)

tensor([[[[-0.1056,  0.0837, -0.1221],
          [ 0.1951, -0.1545,  0.1362],
          [ 0.2606,  0.2672,  0.2394]]],


        [[[-0.1456,  0.1818, -0.1699],
          [ 0.2134,  0.2520,  0.1077],
          [-0.0732, -0.3249, -0.2332]]],


        [[[-0.2482,  0.0892, -0.1317],
          [-0.0000,  0.3186, -0.1989],
          [-0.3071,  0.2926, -0.1845]]],


        [[[ 0.0795,  0.0696,  0.2734],
          [ 0.2202, -0.2917,  0.0801],
          [ 0.1092,  0.1431,  0.0000]]],


        [[[-0.1142, -0.2588,  0.1941],
          [ 0.3320, -0.1400,  0.1245],
          [-0.2392, -0.1399, -0.1429]]],


        [[[-0.1619,  0.0000, -0.3299],
          [-0.2776,  0.1654, -0.2576],
          [-0.0000,  0.0000,  0.0913]]]], grad_fn=<MulBackward0>)


In [4]:
# Prune conv1.weight a second time, now with amount=1.0 (all entries).
prune.l1_unstructured(module, name="weight", amount=1.0)
# Load the snapshot captured after the 10% pruning back into the model.
model.load_state_dict(proto)
# NOTE(review): the recorded output is still all zeros after the load.
# Presumably `module.weight` is a derived attribute that torch's pruning
# forward pre-hook only recomputes (from weight_orig and weight_mask) on the
# next forward pass, so the value printed here is stale. Confirm against the
# torch.nn.utils.prune documentation; inspecting module.weight_orig and
# module.weight_mask, or running a forward pass first, should show the
# restored values.
print(module.weight)

tensor([[[[-0., 0., -0.],
          [0., -0., 0.],
          [0., 0., 0.]]],


        [[[-0., 0., -0.],
          [0., 0., 0.],
          [-0., -0., -0.]]],


        [[[-0., 0., -0.],
          [-0., 0., -0.],
          [-0., 0., -0.]]],


        [[[0., 0., 0.],
          [0., -0., 0.],
          [0., 0., 0.]]],


        [[[-0., -0., 0.],
          [0., -0., 0.],
          [-0., -0., -0.]]],


        [[[-0., 0., -0.],
          [-0., 0., -0.],
          [-0., 0., 0.]]]], grad_fn=<MulBackward0>)


In [10]:
# NOTE(review): this cell's execution count (10) does not follow the previous
# cell (4), and the dump below has a plain 'conv1.weight' entry with no zeroed
# values. Presumably it was run against a snapshot taken before any pruning,
# i.e. from a different kernel state than the cells above. Restart the kernel
# and run all cells to confirm what `proto` holds at this point.
print(proto)



OrderedDict([('conv1.weight', tensor([[[[-0.2817, -0.3254, -0.1070],
          [ 0.3194, -0.2878, -0.0804],
          [-0.0353,  0.0052, -0.2725]]],


        [[[ 0.0510,  0.3107, -0.2943],
          [-0.2766,  0.2221,  0.0837],
          [-0.0537,  0.1197, -0.1054]]],


        [[[ 0.2520, -0.2222, -0.3022],
          [ 0.2559,  0.1009, -0.0644],
          [ 0.2686,  0.0617,  0.1090]]],


        [[[ 0.1900,  0.1891,  0.1088],
          [-0.0662,  0.0357, -0.1480],
          [-0.0417,  0.1976,  0.0017]]],


        [[[ 0.2696, -0.0906, -0.2381],
          [-0.0719, -0.2938,  0.1240],
          [-0.1816,  0.0054, -0.0804]]],


        [[[-0.2979,  0.0344,  0.2138],
          [-0.0797,  0.1555, -0.3224],
          [-0.1193, -0.2894,  0.2713]]]])), ('conv1.bias', tensor([ 0.1579,  0.0987, -0.2007,  0.0988, -0.0862,  0.0190])), ('conv2.weight', tensor([[[[-0.1072, -0.1075, -0.0682],
          [-0.0137,  0.0624,  0.1208],
          [-0.0076, -0.0858, -0.0185]],

         [[-0.0430, -0.1360