In [2]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [4]:
# Seed PyTorch's RNG so the random weight initialisation of the model and the
# random unstructured pruning below produce the same result on every
# "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

    Takes single-channel images and returns 10 unnormalised scores per sample.
    ``fc1`` expects the conv stack to emit 16 x 5 x 5 features, which holds for
    1 x 32 x 32 inputs (32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5).
    NOTE(review): other input sizes will not match ``fc1`` -- confirm the data
    pipeline produces 32x32 images.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # After conv2 + pooling the feature map is 16 channels of 5x5 (for 32x32 inputs).
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class scores.

        Args:
            x: batch of images, assumed (batch, 1, 32, 32) -- see class docstring.

        Returns:
            Tensor of shape (batch, 10) with unnormalised class scores.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten all but the batch dimension. Unlike view(-1, nelement / batch),
        # this never divides by the batch size, so an empty batch works too.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
# Instantiate the network and move its parameters to the selected device.
model = LeNet().to(device)

In [5]:
# Work on the first convolution layer; list its (name, Parameter) pairs before pruning.
module = model.conv1
print([*module.named_parameters()])

[('weight', Parameter containing:
tensor([[[[ 4.8688e-02,  9.8835e-02, -1.9668e-01,  6.0688e-03, -1.3360e-01],
          [-7.1013e-02,  1.8015e-01, -1.9867e-01, -3.3947e-02,  1.2073e-01],
          [-9.5582e-02,  1.9731e-02, -1.2488e-01,  5.9975e-02, -1.8136e-01],
          [-2.0051e-05, -1.3049e-01, -1.7201e-01, -7.7237e-02,  4.4175e-02],
          [-1.9649e-01,  8.8465e-02,  4.6810e-02,  8.1950e-02,  1.3690e-01]]],


        [[[-4.5177e-02,  1.7487e-01, -8.4083e-03, -8.8938e-02, -1.5656e-01],
          [ 1.8796e-01,  9.2614e-02,  8.5879e-02, -6.5901e-02,  3.8790e-02],
          [-3.8070e-02,  2.9147e-02,  1.7186e-02, -8.3590e-02,  1.0324e-01],
          [ 1.8732e-01, -1.1453e-01, -1.4025e-02, -7.1236e-02, -1.6233e-01],
          [-9.8225e-03,  7.4507e-02, -8.9893e-02,  1.0371e-01,  1.8202e-01]]],


        [[[ 2.9800e-02,  2.5867e-02,  8.2250e-02, -1.0001e-01,  6.7538e-02],
          [-3.3414e-02,  6.6873e-02, -1.3290e-02, -8.9606e-02, -9.0914e-02],
          [ 6.2345e-02, -2.9635e-0

In [6]:
# Buffers of conv1 before any pruning (no mask registered yet).
print([*module.named_buffers()])

[]


In [7]:
# Randomly zero 30% of conv1's weight entries. Afterwards the layer holds a
# `weight_orig` parameter and a `weight_mask` buffer instead of a plain `weight`.
prune.random_unstructured(module, "weight", amount=0.3)

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [8]:
# Parameters after pruning: `weight` has been replaced by `weight_orig`.
print([*module.named_parameters()])

[('bias', Parameter containing:
tensor([ 0.0854,  0.0388,  0.1642, -0.0698,  0.1370,  0.0374], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 4.8688e-02,  9.8835e-02, -1.9668e-01,  6.0688e-03, -1.3360e-01],
          [-7.1013e-02,  1.8015e-01, -1.9867e-01, -3.3947e-02,  1.2073e-01],
          [-9.5582e-02,  1.9731e-02, -1.2488e-01,  5.9975e-02, -1.8136e-01],
          [-2.0051e-05, -1.3049e-01, -1.7201e-01, -7.7237e-02,  4.4175e-02],
          [-1.9649e-01,  8.8465e-02,  4.6810e-02,  8.1950e-02,  1.3690e-01]]],


        [[[-4.5177e-02,  1.7487e-01, -8.4083e-03, -8.8938e-02, -1.5656e-01],
          [ 1.8796e-01,  9.2614e-02,  8.5879e-02, -6.5901e-02,  3.8790e-02],
          [-3.8070e-02,  2.9147e-02,  1.7186e-02, -8.3590e-02,  1.0324e-01],
          [ 1.8732e-01, -1.1453e-01, -1.4025e-02, -7.1236e-02, -1.6233e-01],
          [-9.8225e-03,  7.4507e-02, -8.9893e-02,  1.0371e-01,  1.8202e-01]]],


        [[[ 2.9800e-02,  2.5867e-02,  8.225

In [9]:
# Buffers after pruning: the binary `weight_mask` now appears here.
print([*module.named_buffers()])

[('weight_mask', tensor([[[[1., 0., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.],
          [1., 0., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [0., 1., 1., 0., 1.],
          [1., 0., 0., 1., 0.],
          [0., 0., 1., 0., 1.],
          [1., 1., 0., 1., 1.]]],


        [[[0., 0., 1., 1., 1.],
          [1., 1., 0., 1., 0.],
          [0., 0., 0., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 0.],
          [1., 1., 0., 1., 0.]]],


        [[[0., 1., 0., 1., 1.],
          [0., 0., 0., 1., 1.],
          [1., 1., 1., 1., 0.],
          [0., 1., 1., 1., 1.],
          [0., 1., 1., 1., 0.]]],


        [[[1., 0., 1., 1., 1.],
          [1., 0., 0., 1., 0.],
          [1., 0., 1., 0., 1.],
          [1., 1., 1., 0., 1.],
          [0., 0., 0., 1., 1.]]]], 

In [10]:
# The pruning method is registered as a forward pre-hook on the layer, i.e. it runs
# before each forward pass to rebuild `weight` from `weight_orig` and `weight_mask`.
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fdd23bf4280>)])


In [11]:
# Prune the bias by L1 magnitude. An integer `amount` is an absolute count, so the
# 3 bias entries with the smallest absolute value are zeroed.
prune.l1_unstructured(module, "bias", amount=3)

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [12]:
# Parameters after pruning the bias as well.
print([*module.named_parameters()])

[('weight_orig', Parameter containing:
tensor([[[[ 4.8688e-02,  9.8835e-02, -1.9668e-01,  6.0688e-03, -1.3360e-01],
          [-7.1013e-02,  1.8015e-01, -1.9867e-01, -3.3947e-02,  1.2073e-01],
          [-9.5582e-02,  1.9731e-02, -1.2488e-01,  5.9975e-02, -1.8136e-01],
          [-2.0051e-05, -1.3049e-01, -1.7201e-01, -7.7237e-02,  4.4175e-02],
          [-1.9649e-01,  8.8465e-02,  4.6810e-02,  8.1950e-02,  1.3690e-01]]],


        [[[-4.5177e-02,  1.7487e-01, -8.4083e-03, -8.8938e-02, -1.5656e-01],
          [ 1.8796e-01,  9.2614e-02,  8.5879e-02, -6.5901e-02,  3.8790e-02],
          [-3.8070e-02,  2.9147e-02,  1.7186e-02, -8.3590e-02,  1.0324e-01],
          [ 1.8732e-01, -1.1453e-01, -1.4025e-02, -7.1236e-02, -1.6233e-01],
          [-9.8225e-03,  7.4507e-02, -8.9893e-02,  1.0371e-01,  1.8202e-01]]],


        [[[ 2.9800e-02,  2.5867e-02,  8.2250e-02, -1.0001e-01,  6.7538e-02],
          [-3.3414e-02,  6.6873e-02, -1.3290e-02, -8.9606e-02, -9.0914e-02],
          [ 6.2345e-02, -2.96

In [13]:
# Buffers after pruning the bias as well.
print([*module.named_buffers()])

[('weight_mask', tensor([[[[1., 0., 1., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 0., 1.],
          [1., 0., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [0., 1., 1., 0., 1.],
          [1., 0., 0., 1., 0.],
          [0., 0., 1., 0., 1.],
          [1., 1., 0., 1., 1.]]],


        [[[0., 0., 1., 1., 1.],
          [1., 1., 0., 1., 0.],
          [0., 0., 0., 1., 1.],
          [1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 0., 1., 0.],
          [1., 1., 0., 1., 0.]]],


        [[[0., 1., 0., 1., 1.],
          [0., 0., 0., 1., 1.],
          [1., 1., 1., 1., 0.],
          [0., 1., 1., 1., 1.],
          [0., 1., 1., 1., 0.]]],


        [[[1., 0., 1., 1., 1.],
          [1., 0., 0., 1., 0.],
          [1., 0., 1., 0., 1.],
          [1., 1., 1., 0., 1.],
          [0., 0., 0., 1., 1.]]]], 

In [14]:
# The pruned bias is now computed as a product (note grad_fn=<MulBackward0>),
# presumably bias_orig * bias_mask, mirroring the weight re-parametrisation.
print(module.bias)

tensor([0.0854, 0.0000, 0.1642, -0.0000, 0.1370, 0.0000], device='cuda:0',
       grad_fn=<MulBackward0>)


In [15]:
# One pruning hook per pruned tensor is now registered (weight and bias).
print(module._forward_pre_hooks)

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fdd23bf4280>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7fdd2448b9d0>)])


In [16]:
# Structured pruning: along dim=0 (output channels), zero the 50% of channels with
# the smallest L2 norm (n=2). This stacks on top of the earlier random weight mask.
prune.ln_structured(module, "weight", amount=0.5, n=2, dim=0)

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

In [17]:
# The effective weight now reflects both masks: the earlier random zeros plus whole
# output channels zeroed by the structured (channel-wise) pruning.
print(module.weight)

tensor([[[[ 4.8688e-02,  0.0000e+00, -1.9668e-01,  6.0688e-03, -1.3360e-01],
          [-7.1013e-02,  1.8015e-01, -1.9867e-01, -3.3947e-02,  0.0000e+00],
          [-9.5582e-02,  1.9731e-02, -1.2488e-01,  5.9975e-02, -1.8136e-01],
          [-2.0051e-05, -1.3049e-01, -1.7201e-01, -0.0000e+00,  4.4175e-02],
          [-1.9649e-01,  0.0000e+00,  4.6810e-02,  8.1950e-02,  1.3690e-01]]],


        [[[-0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],
          [-0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],
          [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
          [-0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00]]],


        [[[ 0.0000e+00,  0.0000e+00,  8.2250e-02, -1.0001e-01,  6.7538e-02],
          [-3.3414e-02,  6.6873e-02, -0.0000e+00, -8.9606e-02, -0.0000e+00],
          [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -1.1755e-01, -1.70

In [18]:
# Locate the pruning hook that manages `weight`. After pruning the same tensor twice
# (random_unstructured, then ln_structured), a single hook holds both methods and
# can be iterated to list them.
hook = next(
    (
        candidate
        for candidate in module._forward_pre_hooks.values()
        # Pre-hooks that are not pruning methods have no `_tensor_name`; skip them.
        if getattr(candidate, "_tensor_name", None) == "weight"
    ),
    None,
)
# Fail loudly instead of silently falling through to an unrelated hook.
if hook is None:
    raise LookupError("no pruning hook registered for module.weight")

print(list(hook))

[<torch.nn.utils.prune.RandomUnstructured object at 0x7fdd23bf4280>, <torch.nn.utils.prune.LnStructured object at 0x7fddec225610>]
