PyTorch model pruning

https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

Create model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [None]:
model = LeNet().to(device=device)

Inspect module

In [None]:
module = model.conv1
print(list(module.named_parameters()))

```conv1``` is a $3\times3$ convolutional layer with $6$ filters. 
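
As a quick check of that description, the sketch below builds a stand-alone layer with the same constructor arguments and inspects its parameter shapes. The name `layer` is illustrative and not part of the notebook.

```python
import torch.nn as nn

# Same constructor arguments as conv1 in LeNet: 1 input channel, 6 filters, 3x3 kernel.
layer = nn.Conv2d(1, 6, 3)

# The weight layout is (out_channels, in_channels, kernel_height, kernel_width).
weight_shape = tuple(layer.weight.shape)
bias_shape = tuple(layer.bias.shape)
print(weight_shape, bias_shape)
```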

In [None]:
# currently no buffer is allocated
print(list(module.named_buffers()))

Pruning the module's ```weight```

In [None]:
prune.random_unstructured(module, name="weight", amount=0.3)

In [None]:
print(list(module.named_buffers()))

After pruning the module's ```weight```, a ```weight_mask``` buffer is registered on the module. The call above zeroes $30\%$ of the entries of ```weight``` in ```conv1```, chosen at random.

In [None]:
print(module.weight)

The new ```weight``` is zero wherever ```weight_mask``` is zero; all other entries keep their original values.
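
The relation between the three tensors can be checked directly. The following is a minimal sketch on a stand-alone layer (the name `layer` and the fixed seed are assumptions made for illustration), showing that the pruned weight equals the original weight multiplied elementwise by the mask.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)  # only to make the random mask reproducible
layer = nn.Conv2d(1, 6, 3)  # stand-in for conv1
prune.random_unstructured(layer, name="weight", amount=0.3)

# The pruned weight is the elementwise product of the original weight and the mask.
recomputed = layer.weight_orig * layer.weight_mask
matches = torch.equal(layer.weight, recomputed)

# 30% of the 54 weight entries, rounded to the nearest integer, are masked out.
num_masked = int((layer.weight_mask == 0).sum())
print(matches, num_masked, layer.weight_mask.numel())
```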

In [None]:
print(module._forward_pre_hooks)

Because we have pruned only ```weight``` so far, the module has a single forward pre-hook. Before each forward pass, this hook applies the mask to the original weight.
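
To see the pre-hook at work, one can change `weight_orig` in place and observe that `weight` is refreshed only when the module is called. This is a sketch under assumptions: the stand-alone `layer`, the in-place update (standing in for an optimizer step), and the 8x8 dummy input are illustrative choices.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Conv2d(1, 6, 3)
prune.random_unstructured(layer, name="weight", amount=0.3)

# Change the underlying parameter in place, as an optimizer step would.
with torch.no_grad():
    layer.weight_orig.add_(1.0)

# layer.weight still holds the product computed before the update.
up_to_date_before = torch.equal(layer.weight, layer.weight_orig * layer.weight_mask)

# Calling the module triggers the pre-hook, which recomputes layer.weight.
_ = layer(torch.randn(1, 1, 8, 8))
up_to_date_after = torch.equal(layer.weight, layer.weight_orig * layer.weight_mask)
print(up_to_date_before, up_to_date_after)
```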

Pruning ```bias```

In [None]:
prune.l1_unstructured(module, name="bias", amount=3)

Here we prune the three bias entries with the smallest L1 norm (absolute value).

In [None]:
print(list(module.named_parameters()))

The module now has the parameters ```weight_orig``` and ```bias_orig``` in place of ```weight``` and ```bias```. They store the original, unpruned weight and bias.

In [None]:
print(list(module.named_buffers()))

```bias_mask``` has been added to the buffers. It contains three zeros, marking the three smallest bias entries as pruned.
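
The claim about the mask can be verified with a short sketch on a stand-alone layer (the names `layer`, `pruned`, and `magnitudes` are illustrative). An integer `amount` is read as an absolute number of entries, and the check confirms that the pruned entries are the ones with the smallest magnitude.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Conv2d(1, 6, 3)
prune.l1_unstructured(layer, name="bias", amount=3)  # integer amount = absolute count

pruned = layer.bias_mask == 0
num_pruned = int(pruned.sum())

# Every pruned entry should be no larger in magnitude than any kept entry.
magnitudes = layer.bias_orig.detach().abs()
smallest_pruned = bool(magnitudes[pruned].max() <= magnitudes[~pruned].min())
print(num_pruned, smallest_pruned)
```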

Let's see the new bias after pruning

In [None]:
print(module.bias)

In [None]:
print(module._forward_pre_hooks)

We now have two forward pre-hooks. Before each forward pass, the masks are applied to the original weight and the original bias.

We can also prune the same parameter iteratively.

In [None]:
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

print(module.weight)

Here we further prune the ```conv1``` weight (```name="weight"```): half (```amount=0.5```) of the output channels (```dim=0```) are removed, selected by lowest L2 norm (```n=2```).
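
Structured pruning removes whole channels rather than individual entries. The sketch below applies the same call to a fresh stand-alone layer, an assumption made so that the count is not affected by an earlier random mask, and counts how many output channels are fully masked.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Conv2d(1, 6, 3)
prune.ln_structured(layer, name="weight", amount=0.5, n=2, dim=0)

# Sum each output channel's mask; a fully pruned channel sums to zero.
per_channel = layer.weight_mask.sum(dim=(1, 2, 3))
zeroed_channels = int((per_channel == 0).sum())
intact_channels = int((per_channel == 9).sum())  # 1 * 3 * 3 = 9 entries per channel
print(zeroed_channels, intact_channels)
```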

In [None]:
for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container

The pre-hook on the ```weight``` parameter now holds two consecutive pruning methods.
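
The same inspection can be reproduced on a stand-alone layer. This sketch relies on the private attributes `_forward_pre_hooks` and `_tensor_name`, just as the cell above does, so it may break if those internals change. The names `layer` and `history` are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Conv2d(1, 6, 3)
prune.random_unstructured(layer, name="weight", amount=0.3)
prune.ln_structured(layer, name="weight", amount=0.5, n=2, dim=0)

# Pick the hook that manages "weight"; repeated pruning wraps the methods in one container.
hook = next(h for h in layer._forward_pre_hooks.values() if h._tensor_name == "weight")
is_container = isinstance(hook, prune.PruningContainer)

# Iterating the container yields the pruning methods in the order they were applied.
history = [type(method).__name__ for method in hook]
print(is_container, history)
```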

In [None]:
print(model.state_dict().keys())

Note that the model's ```state_dict()``` now contains additional entries: the ```_orig``` parameters and the ```_mask``` buffers of the pruned module. It can still be saved with ```torch.save()``` as usual.
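
Loading such a checkpoint needs some care, because a freshly built model does not have the `_orig` and `_mask` entries. One possible approach, sketched here on a small `nn.Linear` with an in-memory buffer instead of a file (both are assumptions for illustration), is to apply `prune.identity` to the new module first so that the keys match.

```python
import io

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
source = nn.Linear(4, 2)
prune.l1_unstructured(source, name="weight", amount=0.5)
saved_keys = list(source.state_dict().keys())

# Save to an in-memory buffer; a file path works the same way.
buffer = io.BytesIO()
torch.save(source.state_dict(), buffer)
buffer.seek(0)

# A fresh module lacks weight_orig / weight_mask, so create them with an all-ones mask.
restored = nn.Linear(4, 2)
prune.identity(restored, name="weight")
restored.load_state_dict(torch.load(buffer))

# The forward pre-hook recomputes restored.weight from the loaded tensors.
_ = restored(torch.randn(1, 4))
masks_match = torch.equal(restored.weight_mask, source.weight_mask)
weights_match = torch.equal(restored.weight, source.weight)
print(saved_keys, masks_match, weights_match)
```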

Make pruning **permanent**!

In [None]:
# remove the re-parameterization on `weight`
prune.remove(module, "weight")
print(list(module.named_parameters()))

Note that ```weight``` is a regular parameter again rather than ```weight_orig```, so the pruning of the module's ```weight``` is now permanent. The ```bias```, however, still has its mask in the buffers because we have not made that change permanent.
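
The effect of `prune.remove` can be summarized in a small sketch on a stand-alone layer (the variable names are illustrative): the zeros stay in `weight`, while the mask buffer, the `_orig` parameter, and the pre-hook disappear.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Conv2d(1, 6, 3)
prune.random_unstructured(layer, name="weight", amount=0.3)
pruned_weight = layer.weight.detach().clone()

prune.remove(layer, "weight")

param_names = sorted(name for name, _ in layer.named_parameters())
buffer_names = [name for name, _ in layer.named_buffers()]
zeros_kept = torch.equal(layer.weight.detach(), pruned_weight)
hooks_left = len(layer._forward_pre_hooks)
print(param_names, buffer_names, zeros_kept, hooks_left)
```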

In [None]:
print(list(module.named_buffers()))

Pruning multiple parameters in model

In [None]:
new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20 percent of connections (weights) in all 2d-conv layers
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=0.2)
    # prune 40 percent of connections (weights) in all linear layers
    elif isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.4)

print(dict(new_model.named_buffers()).keys()) # verify all modules have corresponding mask
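
To confirm that each layer type receives its own pruning amount, the sketch below repeats the loop on a small hypothetical `nn.Sequential` model (not the LeNet above) and records the fraction of masked weights per layer.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
# Hypothetical toy model; only the layer types matter for this check.
net = nn.Sequential(nn.Conv2d(1, 6, 3), nn.ReLU(), nn.Flatten(), nn.Linear(6 * 4 * 4, 10))

amounts = {nn.Conv2d: 0.2, nn.Linear: 0.4}
sparsity = {}
for name, layer in net.named_modules():
    for layer_type, amount in amounts.items():
        if isinstance(layer, layer_type):
            prune.l1_unstructured(layer, name="weight", amount=amount)
            mask = layer.weight_mask
            sparsity[name] = float((mask == 0).sum()) / mask.numel()

print(sparsity)
```

The fractions are close to, but not always exactly, the requested amounts, because the number of pruned entries is rounded to an integer.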

Global pruning

In [None]:
model = LeNet()

parameters_to_prune = (
    (model.conv1, "weight"),
    (model.conv2, "weight"),
    (model.fc1, "weight"),
    (model.fc2, "weight"),
    (model.fc3, "weight"),
)

Sparsity of the original parameters

In [None]:
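def check_module_sparsity(module):
    # percentage of weight entries that are exactly zero
    percentage = 100. * float(torch.sum(module.weight == 0)) / float(module.weight.nelement())
    # print the module repr to identify the layer; module.__module__ would only
    # give the Python module in which the layer class is defined
    print(
        f"Sparsity in {module}: {percentage:.2f}%"
    )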

In [None]:
check_module_sparsity(model.conv1)
check_module_sparsity(model.conv2)
check_module_sparsity(model.fc1)
check_module_sparsity(model.fc2)
check_module_sparsity(model.fc3)

Do pruning

In [None]:
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

We are pruning 20 percent of the listed weights overall. The threshold is global, so the sparsity of individual layers can differ from 20 percent.
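
The difference between global and per-layer sparsity is easiest to see on a toy case. In this sketch, two hypothetical linear layers (`small` and `large`, not part of the notebook) are pruned together; one is scaled so that its weights are clearly larger in magnitude, which means most of the pruned entries come from the other layer.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
small = nn.Linear(20, 10)
large = nn.Linear(20, 10)
with torch.no_grad():
    large.weight.mul_(10.0)  # make this layer's weights clearly larger in magnitude

prune.global_unstructured(
    [(small, "weight"), (large, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

def zero_fraction(mask):
    # fraction of entries masked out
    return float((mask == 0).sum()) / mask.numel()

small_sparsity = zero_fraction(small.weight_mask)
large_sparsity = zero_fraction(large.weight_mask)
total_zeros = int((small.weight_mask == 0).sum()) + int((large.weight_mask == 0).sum())
overall = total_zeros / (small.weight_mask.numel() + large.weight_mask.numel())
print(small_sparsity, large_sparsity, overall)
```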

In [None]:
check_module_sparsity(model.conv1)
check_module_sparsity(model.conv2)
check_module_sparsity(model.fc1)
check_module_sparsity(model.fc2)
check_module_sparsity(model.fc3)