# Pruning
Models are usually designed to work well on a specific (academic) dataset.
Object detection is no different, and as such most models are designed around the Pascal VOC or COCO datasets.
However, operational use cases might contain very different data, which is usually also simpler or more coherent than these academic datasets.

Whilst adapting existing networks on a case-by-case basis is certainly possible, it is a tedious and daunting task.
An easier approach is to take an existing network, train it on your data and call it a day!  
Nevertheless, one might wonder whether the chosen network is optimal for their situation, or whether the model is computationally more expensive than necessary...  
Meet **pruning**: a technique that automatically reduces the number of computations in a network by looking at the importance of the weights in the model.  

Lightnet implements channel-wise soft and hard pruning of convolutions.
This means we either set the weights of certain channels in a convolution to zero (soft), or completely strip those channels from the convolution (hard); the latter results in fewer computations and smaller models.
In this tutorial, we will take a look at how you can use the pruning functionality in Lightnet to automatically reduce the number of computations of your model, without losing accuracy.
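To make the difference between both manners concrete, here is a minimal NumPy sketch operating on a single convolution weight of shape `(out_channels, in_channels, kH, kW)`. It is not how Lightnet implements pruning internally; the helpers `soft_prune` and `hard_prune` are hypothetical and only illustrate zeroing versus removing output channels.

```python
import numpy as np

def soft_prune(weight, channels):
    """Zero the given output channels; the weight keeps its shape."""
    pruned = weight.copy()
    pruned[list(channels)] = 0.0
    return pruned

def hard_prune(weight, channels):
    """Remove the given output channels; the first dimension shrinks."""
    drop = set(channels)
    keep = [c for c in range(weight.shape[0]) if c not in drop]
    return weight[keep]

rng = np.random.default_rng(0)
weight = rng.standard_normal((8, 3, 3, 3))  # (out_channels, in_channels, kH, kW)

soft = soft_prune(weight, [1, 5])
hard = hard_prune(weight, [1, 5])
print(soft.shape, hard.shape)  # (8, 3, 3, 3) (6, 3, 3, 3)
```

A soft-pruned convolution still performs the same number of multiplications (some of them simply involve zeros), which is why only hard pruning actually saves computations and memory.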

<div class="alert alert-warning">

**Warning:**

The pruning functionality in lightnet requires the [onnx](https://github.com/onnx/onnx) library.  
You can install it by running `pip install onnx`.

</div>

```python
# Basic imports
import lightnet as ln
import torch
import torchvision
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import brambox as bb
import os
import warnings

# Settings
ln.logger.setConsoleLevel('ERROR')             # Only show error log messages
bb.logger.setConsoleLevel('ERROR')             # Only show error log messages

# This is only to have a cleaner documentation and should generally not be used
warnings.filterwarnings('ignore')
```

## Soft Pruning

We will first look at soft pruning, which means we set the weights of certain channels of a convolution to zero.  
This can be used as a precursor to hard pruning: iteratively soft prune a certain percentage of the network and retrain until the original accuracy is reached.
Once the network no longer needs retraining to reach the original accuracy, we know that this percentage of the network is unnecessary and we can remove it.  
Another use case for soft pruning is as a regularisation technique, which can be used similarly to dropout in order to increase the accuracy of a network during training.
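The iterative procedure described above can be sketched as a small, framework-agnostic loop. This is a minimal sketch under our own assumptions: `prune`, `retrain` and `evaluate` are hypothetical callables you would provide yourself (for instance, `prune` could wrap a soft-pruning Lightnet pruner), and the function name `soft_prune_until_stable` is ours, not part of Lightnet.

```python
def soft_prune_until_stable(prune, retrain, evaluate, original_accuracy,
                            fraction, max_rounds=20):
    """Hypothetical soft prune / retrain loop.

    prune(fraction) -> soft prunes `fraction` of the channels
    retrain()       -> trains the network for a while
    evaluate()      -> returns the current accuracy
    Returns the round at which pruning no longer hurt accuracy, or None.
    """
    for round_idx in range(max_rounds):
        prune(fraction)                       # zero the least important channels
        if evaluate() >= original_accuracy:
            return round_idx                  # no retraining needed: safe to hard prune
        retrain()                             # let the network recover from the pruning
    return None

# Toy usage with fake callables that mimic a recovering network
accuracies = iter([0.70, 0.74, 0.76])
calls = {'prune': 0, 'retrain': 0}

def prune(fraction):
    calls['prune'] += 1

def retrain():
    calls['retrain'] += 1

def evaluate():
    return next(accuracies)

result = soft_prune_until_stable(prune, retrain, evaluate, original_accuracy=0.75, fraction=0.1)
print(result, calls)
```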

Before pruning a network, you start by training it, which we assume you have already done.
We thus load our network and define some extra bits and bobs that we will need.
Once we have our model, we can create a [Pruner](../api/generated/lightnet.prune.Pruner.rst).
Lightnet comes with a few different pruner implementations; here we will use a basic [L2Pruner](../api/generated/lightnet.prune.L2Pruner.rst).
Once we have built our pruner, we can look at the ``prunable_channels`` property, which shows how many convolutional channels can potentially be pruned.

<div class="alert alert-info">

**Note:**

The ``prunable_channels`` property returns the total number of channels in the prunable convolutions.
Note that we never prune the last channel of a convolution, and can thus never prune all of these channels.

</div>
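To give an idea of what an L2-based pruner ranks on, here is a minimal NumPy sketch that scores every output channel by the L2 norm of its weights and selects the weakest ones. It assumes a single global ranking across all convolutions and raw (unnormalised) norms, which may well differ from what `L2Pruner` actually does; `l2_channel_norms` and `select_channels` are hypothetical helpers. It does respect the rule from the note above: the last channel of a convolution is never selected.

```python
import numpy as np

def l2_channel_norms(weight):
    """L2 norm of every output channel of a conv weight shaped (out, in, kh, kw)."""
    return np.sqrt((weight ** 2).reshape(weight.shape[0], -1).sum(axis=1))

def select_channels(weights, fraction):
    """Pick `fraction` of all output channels with the smallest L2 norm,
    ranked globally over all convolutions, never removing a convolution's last channel."""
    prunable = sum(w.shape[0] for w in weights)  # total output channels, cf. `prunable_channels`
    budget = int(fraction * prunable)
    candidates = sorted(
        (norm, layer, channel)
        for layer, w in enumerate(weights)
        for channel, norm in enumerate(l2_channel_norms(w))
    )
    remaining = [w.shape[0] for w in weights]
    selected = []
    for _norm, layer, channel in candidates:
        if len(selected) == budget:
            break
        if remaining[layer] > 1:  # always keep at least one channel per convolution
            remaining[layer] -= 1
            selected.append((layer, channel))
    return prunable, selected

# Toy network: one "strong" conv with 4 channels and one "weak" conv with 2 channels
weights = [np.ones((4, 3, 3, 3)), np.full((2, 4, 1, 1), 0.01)]
prunable, selected = select_channels(weights, 0.5)
print(prunable, selected)
```

Even though both channels of the weak convolution have the smallest norms, only one of them is selected, because removing both would leave that convolution without any output channel.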

```python
# Network
net = ln.models.YoloV2(20)
net.load('./yolov2-voc.pt')
dimensions = (1, 3, 416, 416)

# Pruner
pruner = ln.prune.L2Pruner(net, input_dimensions=dimensions, manner="soft")
pruner.prunable_channels
```

```
9248
```

As we can see, the pruner reports a total of 9248 prunable channels in this network.  
In order to prune channels, we simply [\_\_call\_\_](../api/generated/lightnet.prune.Pruner.rst) the pruner with a percentage, expressed as a fraction (e.g. `0.10` for 10%).

This call returns the actual number of pruned channels, which is 924 in this particular case.

```python
# Prune 10% of the network
pruner(0.10)
```

```
924
```
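The reported counts appear consistent with truncating `fraction * prunable_channels` to an integer: 924 here, and 1849 for the 20% hard-pruning run further on. This is only our reading of the numbers in this tutorial, not a documented guarantee of the Pruner API.

```python
prunable_channels = 9248  # value reported by pruner.prunable_channels above

for fraction in (0.10, 0.20):
    # Assumption: the requested channel count is truncated to an integer
    print(fraction, int(fraction * prunable_channels))
```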

That was it for soft pruning.
Simple, right?  
Now, let's take a look at hard pruning.

## Hard Pruning

Hard pruning is not so different from soft pruning in Lightnet, but you need to understand that hard pruning effectively modifies your network architecture.
This has a few consequences for the rest of your pipeline.

Any object which holds a reference to your network parameters will need to be updated each time you prune your network.
When (re-)training a network, this usually means your optimizer.  
A second consequence is that you will no longer be able to simply load your saved weights into a freshly constructed model, as the model does not keep track of which channels were pruned.
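Why does hard pruning change the architecture? Removing an output channel of a convolution also removes the matching parameters of its batchnorm, and the matching *input* channel of the next convolution. The following NumPy sketch shows this for a pair of layers shaped like the first two convolutions of YoloV2; `hard_prune_conv_pair` is a hypothetical helper, not Lightnet's implementation, and pruning channels `0..8` is an arbitrary choice for illustration.

```python
import numpy as np

def hard_prune_conv_pair(conv_weight, bn_weight, bn_bias, next_conv_weight, channels):
    """Strip output channels from a convolution and its batchnorm, and strip the
    matching input channels from the convolution that consumes its output.
    (A real BatchNorm also stores running_mean/running_var, which need the same indexing.)"""
    drop = set(channels)
    keep = [c for c in range(conv_weight.shape[0]) if c not in drop]
    return (conv_weight[keep],            # (out, in, kh, kw): fewer outputs
            bn_weight[keep],              # one scale per remaining channel
            bn_bias[keep],                # one shift per remaining channel
            next_conv_weight[:, keep])    # (out, in, kh, kw): fewer inputs

conv = np.zeros((32, 3, 3, 3))
bn_w, bn_b = np.ones(32), np.zeros(32)
next_conv = np.zeros((64, 32, 3, 3))

pruned = hard_prune_conv_pair(conv, bn_w, bn_b, next_conv, channels=range(9))
print([p.shape for p in pruned])
```

The resulting shapes no longer match those of a freshly built network, which is exactly why a pruned weights file cannot be loaded the usual way.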

Let's take a look at the optimizer issue first!  
The solution is quite simple: recreate your optimizer each time you prune your network.
However, as this is quite tedious, you can also pass your optimizer to the pruner, which will then automatically adapt it for you.
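The following plain PyTorch sketch shows the problem that the pruner solves for you. We manually "hard prune" a single convolution by swapping in smaller parameters (keeping the first 23 channels, an arbitrary choice; a real pruner selects channels by importance) and observe that an existing optimizer still references the old tensor.

```python
import torch

conv = torch.nn.Conv2d(3, 32, kernel_size=3, padding=1)
optimizer = torch.optim.SGD(conv.parameters(), lr=0.0001)

# Manual hard prune: keep only the first 23 output channels
keep = torch.arange(23)
conv.weight = torch.nn.Parameter(conv.weight.detach()[keep])
conv.bias = torch.nn.Parameter(conv.bias.detach()[keep])
conv.out_channels = 23

# The optimizer still references the old 32-channel weight
stale_shape = optimizer.param_groups[0]['params'][0].shape
print(stale_shape)  # torch.Size([32, 3, 3, 3])

# Recreating the optimizer picks up the new, smaller parameters
optimizer = torch.optim.SGD(conv.parameters(), lr=0.0001)
print(optimizer.param_groups[0]['params'][0].shape)  # torch.Size([23, 3, 3, 3])
```

Note that recreating an optimizer also discards its per-parameter state, such as momentum buffers.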

```python
# Network
net = ln.models.YoloV2(20)
net.load('./yolov2-voc.pt')
dimensions = (1, 3, 416, 416)
optimizer = torch.optim.SGD(net.parameters(), lr=0.0001)

# Pruner
pruner = ln.prune.L2Pruner(net, input_dimensions=dimensions, optimizer=optimizer, manner="hard")
pruner.prunable_channels
```

```
9248
```

We can now prune our model, and the optimizer will be adapted automatically!  
Let us quickly validate this by printing the shape of the first parameter tracked by the optimizer, before and after pruning.

```python
print(optimizer.param_groups[0]['params'][0].shape)
pruned = pruner(0.2)
print(optimizer.param_groups[0]['params'][0].shape)
pruned
```

```
torch.Size([32, 3, 3, 3])
torch.Size([23, 3, 3, 3])
```

```
1849
```

The second issue we face when hard pruning is loading the pruned weights to perform inference.  
If we compare our pruned network to a new YoloV2 instance, we see that the number of channels in the convolutions does not match.

```python
# Original
net_original = ln.models.YoloV2(20)
net_original
```

```
YoloV2(
  (layers): ModuleList(
    (0): Sequential(
      (1_convbatch): Conv2dBatchReLU(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), LeakyReLU(negative_slope=0.1, inplace=True))
      (2_max): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3_convbatch): Conv2dBatchReLU(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), LeakyReLU(negative_slope=0.1, inplace=True))
      (4_max): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5_convbatch): Conv2dBatchReLU(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), LeakyReLU(negative_slope=0.1, inplace=True))
      (6_convbatch): Conv2dBatchReLU(128, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), LeakyReLU(negative_slope=0.1, inplace=True))
      (7_convbatch): Conv2dBatchReLU(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), LeakyReLU(negative_slope=0.1, inplace=True))
      (8_max): MaxPool2d(kernel_size=2, stride=2, padding=0,
```

```python
# Pruned
net
```

```
YoloV2(
  (layers): ModuleList(
    (0): Sequential(
      (1_convbatch): Conv2dBatchReLU(3, 23, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), LeakyReLU(negative_slope=0.1, inplace=True))
      (2_max): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3_convbatch): Conv2dBatchReLU(23, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), LeakyReLU(negative_slope=0.1, inplace=True))
      (4_max): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5_convbatch): Conv2dBatchReLU(61, 124, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), LeakyReLU(negative_slope=0.1, inplace=True))
      (6_convbatch): Conv2dBatchReLU(124, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), LeakyReLU(negative_slope=0.1, inplace=True))
      (7_convbatch): Conv2dBatchReLU(64, 125, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), LeakyReLU(negative_slope=0.1, inplace=True))
      (8_max): MaxPool2d(kernel_size=2, stride=2, padding=0,
```
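To get a feeling for the computational savings, we can count the multiply-accumulates (MACs) of the first five convolutions shown in the (truncated) printouts above. This is a rough sketch under our own assumptions: a 416x416 input (the `dimensions` used above), each max-pool halving the resolution (stride 2), and ignoring biases, batchnorm, activations and pooling. Notice that pruning the outputs of one layer also shrinks the inputs of the next, so the savings compound.

```python
def conv_macs(c_in, c_out, k, size):
    """Multiply-accumulates of a stride-1, 'same'-padded k x k convolution on a size x size input."""
    return c_in * c_out * k * k * size * size

# (in_channels, out_channels, kernel, spatial size) of the first five convolutions
original = [(3, 32, 3, 416), (32, 64, 3, 208), (64, 128, 3, 104), (128, 64, 1, 104), (64, 128, 3, 104)]
pruned = [(3, 23, 3, 416), (23, 61, 3, 208), (61, 124, 3, 104), (124, 64, 1, 104), (64, 125, 3, 104)]

total_original = sum(conv_macs(*layer) for layer in original)
total_pruned = sum(conv_macs(*layer) for layer in pruned)
print(f"{total_pruned / total_original:.1%} of the original MACs in these layers")
```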

Saving the model works just the same as before: we simply call the [save()](../api/generated/lightnet.network.module.Lightnet.rst#lightnet.network.module.Lightnet.save) method.
When loading the network however, we need to tell the model to reduce the number of channels where necessary.
The [load_pruned()](../api/generated/lightnet.network.module.Lightnet.rst#lightnet.network.module.Lightnet.load_pruned) method will do this for us automatically!
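We do not show how `load_pruned()` works internally, but the underlying problem can be reproduced in plain PyTorch. A regular `load_state_dict()` into a freshly built model fails with a size mismatch, whereas a model rebuilt with channel counts read from the checkpoint's weight shapes loads fine. The toy `make_net` helper and its two-layer architecture are hypothetical and only serve this illustration.

```python
import torch

def make_net(c1=32, c2=64):
    """Hypothetical two-layer network with configurable channel counts."""
    return torch.nn.Sequential(
        torch.nn.Conv2d(3, c1, kernel_size=3, padding=1),
        torch.nn.Conv2d(c1, c2, kernel_size=3, padding=1),
    )

state = make_net(23, 61).state_dict()  # stands in for a pruned weights file

# A freshly built network has the original shapes, so loading fails
mismatch = False
try:
    make_net().load_state_dict(state)
except RuntimeError as err:
    mismatch = 'size mismatch' in str(err)
print('size mismatch detected:', mismatch)

# Rebuild the network with channel counts read from the checkpoint, then load
rebuilt = make_net(state['0.weight'].shape[0], state['1.weight'].shape[0])
rebuilt.load_state_dict(state)
print(rebuilt)
```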

```python
# Save pruned network
net.save('yolov2-voc-pruned.pt')

# Show difference in weights file size
print(os.path.getsize('yolov2-voc.pt'))
print(os.path.getsize('yolov2-voc-pruned.pt'))

# Load pruned weights
net_original.load_pruned('yolov2-voc-pruned.pt')
net_original
```

```
202734087
144584888
```

```
YoloV2(
  (layers): ModuleList(
    (0): Sequential(
      (1_convbatch): Conv2dBatchReLU(3, 23, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), LeakyReLU(negative_slope=0.1, inplace=True))
      (2_max): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3_convbatch): Conv2dBatchReLU(23, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), LeakyReLU(negative_slope=0.1, inplace=True))
      (4_max): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5_convbatch): Conv2dBatchReLU(61, 124, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), LeakyReLU(negative_slope=0.1, inplace=True))
      (6_convbatch): Conv2dBatchReLU(124, 64, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), LeakyReLU(negative_slope=0.1, inplace=True))
      (7_convbatch): Conv2dBatchReLU(64, 125, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), LeakyReLU(negative_slope=0.1, inplace=True))
      (8_max): MaxPool2d(kernel_size=2, stride=2, padding=0,
```

That was it for our pruning tutorial!  
You can check out the [Pascal VOC](./03-A-pascal_voc.rst) guide for an example where we train and prune networks on a real dataset!

Once you have trained and pruned your network, you might want to use it on a device without Python.  
Don't worry, our [Photonnet](https://eavise.gitlab.io/photonnet) C++ library has got your back!

<div class="alert alert-warning">

**Warning:**

Please note that in a real scenario, you will probably want to retrain your network after pruning, in order to retain the accuracy of your model.
For those situations, it is quite important to use separate training, validation and test sets.
Train and prune your model using the training and validation sets, and finally evaluate your final model on the test set to report its accuracy.

</div>
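If your dataset does not come with predefined splits, a reproducible split can be as simple as shuffling the sample indices with a fixed seed. The sketch below uses an 80/10/10 ratio, which is an arbitrary assumption; `split_indices` is a hypothetical helper, not part of Lightnet or Brambox.

```python
import random

def split_indices(n, train=0.8, val=0.1, seed=0):
    """Shuffle range(n) with a fixed seed and split it into train/validation/test index lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_train = int(train * n)
    n_val = int(val * n)
    return (indices[:n_train],
            indices[n_train:n_train + n_val],
            indices[n_train + n_val:])

train_idx, val_idx, test_idx = split_indices(1000)
print(len(train_idx), len(val_idx), len(test_idx))  # 800 100 100
```

Keeping the test indices completely out of the prune-and-retrain loop ensures that the reported accuracy is not biased by the pruning decisions.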