# Pruning Schedules

> Make your neural network sparse with fastai


The simplest way to perform pruning is called One-Shot Pruning. It consists of the following three steps:

![One-Shot Pruning](imgs/one_shot.pdf "One-Shot Pruning")

1. You first need to train a network
2. You then need to remove some weights (which ones depends on your criteria, your needs, etc.)
3. You fine-tune the remaining weights to recover from the loss of parameters.
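
As a rough illustration of step 2, the sketch below prunes a flat list of weights by magnitude, zeroing the 80% with the smallest absolute value. The helper `prune_by_magnitude` is hypothetical and not part of fasterai; it only mirrors the idea of an L1-norm criterion applied to individual weights.

```python
def prune_by_magnitude(weights, sparsity):
    """Zero out the `sparsity` percent of weights with the smallest |w|.

    Hypothetical helper for illustration only; not fasterai's implementation.
    """
    n_prune = int(len(weights) * sparsity / 100)
    # Indices sorted from smallest to largest magnitude
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    to_prune = set(order[:n_prune])
    return [0.0 if i in to_prune else w for i, w in enumerate(weights)]

weights = [0.9, -0.05, 0.3, -0.7, 0.01, 0.2, -0.1, 0.6, 0.02, -0.4]
pruned = prune_by_magnitude(weights, 80)
print(pruned)  # only the two largest-magnitude weights (0.9 and -0.7) survive
```

In a real network the surviving weights would then be fine-tuned (step 3) while the zeroed ones stay at zero.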

With fasterai, this is really easy to do. Let's illustrate it with an example:

In [None]:
#hide
from fastai.vision.all import *
from fastai.callback.all import *

from fasterai.sparsifier import *
from fasterai.criteria import *
from fasterai.sparsify_callback import *
from fasterai.schedule import *

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
path = untar_data(URLs.PETS)

In [None]:
files = get_image_files(path/"images")

In [None]:
def label_func(f): return f[0].isupper()

In [None]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64), device=device)

We will first train a network without any pruning, which will serve as a baseline.

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In [None]:
learn.fit_one_cycle(3)

epoch,train_loss,valid_loss,accuracy,time
0,0.655453,0.456962,0.830176,00:21
1,0.347105,0.249341,0.904601,00:20
2,0.175726,0.205904,0.920839,00:20


## One-Shot Pruning

There are two main ways that you can perform One-Shot Pruning with fasterai. 

1. You already possess a trained network and want to prune it
2. You don't possess such a network and have to train it from scratch

### 1. You possess a trained network

In this case, step 1 of the One-Shot Pruning process is already done, but you still need to prune the network and then fine-tune it.

Let's say we want to remove $80 \%$ of the weights of our network. This can be done as follows:

In [None]:
sp = Sparsifier(learn.model, 'weight', 'global', l1_norm)
sp.prune(80)

In [None]:
_, acc = learn.validate(); acc

0.8464140892028809

Unsurprisingly, since we removed a large share of the trained weights, the performance of the network is degraded. This can be addressed by retraining the pruned network while making sure that the pruned weights keep their value of 0.
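
One common way to keep pruned weights at zero during fine-tuning is to store a binary mask and re-apply it after every update. The toy example below sketches this on plain Python lists; `masked_sgd_step` is a hypothetical helper, the gradients are made up, and fasterai's callback may handle this differently.

```python
def masked_sgd_step(weights, grads, mask, lr=0.1):
    """One SGD step followed by re-applying the pruning mask.

    Hypothetical helper: the mask holds 0 for pruned weights and 1 otherwise.
    """
    updated = [w - lr * g for w, g in zip(weights, grads)]
    # Multiplying by the mask forces pruned positions back to zero
    return [u * m for u, m in zip(updated, mask)]

weights = [0.9, 0.0, 0.0, -0.7]
mask    = [1, 0, 0, 1]
grads   = [0.5, -0.2, 0.3, 0.1]  # made-up gradients, non-zero everywhere

new_weights = masked_sgd_step(weights, grads, mask)
print(new_weights)
```

Without the mask, the two pruned weights would drift away from zero after a single step and the sparsity would be lost.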

Since we don't want the sparsity level to change anymore, we need a schedule that returns a constant value. fasterai provides such a schedule, called `one_shot`, which is defined as:

In [None]:
def one_shot(start, end, pos): return end

We can pass the same arguments to our callback as those used by the Sparsifier.

In [None]:
sp_cb=SparsifyCallback(sparsity=80, granularity='weight', method='global', criteria=l1_norm, sched_func=one_shot)

In [None]:
learn.fit_one_cycle(3, cbs=sp_cb)

Pruning of weight until a sparsity of 80%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.189826,0.228184,0.901218,01:41
1,0.174893,0.206082,0.923545,01:45
2,0.087448,0.189777,0.932341,01:41


Sparsity at the end of epoch 0: 80.00%
Sparsity at the end of epoch 1: 80.00%
Sparsity at the end of epoch 2: 80.00%
Final Sparsity: 80.00


We can also check where the pruned weights are in the network:

In [None]:
for k,m in enumerate(learn.model.modules()):
    if isinstance(m, nn.Conv2d):
        print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")

Sparsity in Conv2d 2: 38.84%
Sparsity in Conv2d 8: 55.45%
Sparsity in Conv2d 11: 50.12%
Sparsity in Conv2d 14: 48.30%
Sparsity in Conv2d 17: 49.84%
Sparsity in Conv2d 21: 53.56%
Sparsity in Conv2d 24: 60.95%
Sparsity in Conv2d 27: 42.13%
Sparsity in Conv2d 30: 60.09%
Sparsity in Conv2d 33: 62.34%
Sparsity in Conv2d 37: 65.65%
Sparsity in Conv2d 40: 70.78%
Sparsity in Conv2d 43: 60.77%
Sparsity in Conv2d 46: 74.65%
Sparsity in Conv2d 49: 77.25%
Sparsity in Conv2d 53: 77.68%
Sparsity in Conv2d 56: 82.73%
Sparsity in Conv2d 59: 59.90%
Sparsity in Conv2d 62: 80.71%
Sparsity in Conv2d 65: 91.69%
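
A likely reason the per-layer numbers differ from the 80% target is the `'global'` method: if a single magnitude threshold is computed over all weights at once, layers with many small weights lose more of them than layers with large weights. The toy sketch below illustrates this with two made-up "layers"; the helper names are hypothetical and the code does not use fasterai.

```python
def global_threshold(layers, sparsity):
    """Magnitude below which weights are pruned, computed over all layers at once."""
    all_mags = sorted(abs(w) for layer in layers.values() for w in layer)
    n_prune = int(len(all_mags) * sparsity / 100)
    return all_mags[n_prune]  # weights strictly below this value are pruned

def layer_sparsity(layer, threshold):
    """Percentage of weights in one layer that fall below the threshold."""
    return 100 * sum(abs(w) < threshold for w in layer) / len(layer)

# Two made-up layers: "early" mostly holds larger weights than "late"
layers = {
    "early": [0.8, -0.6, 0.5, 0.05],
    "late":  [0.01, -0.02, 0.03, 0.04, -0.06, 0.7],
}

thr = global_threshold(layers, 50)
for name, layer in layers.items():
    print(f"{name}: {layer_sparsity(layer, thr):.2f}%")
```

Half of all weights are pruned overall, yet the two layers end up with very different sparsity levels, much like the convolutional layers listed above.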


> Note: Using Sparsifier to prune the network is not necessary as it will also be called in the Callback. This was used here to better illustrate all the steps.

### 2. You don't possess a trained network

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In this case, your network needs to be trained before pruning.

You only need to create the Callback with the `one_shot` schedule and set the `start_epoch` argument, i.e. how many epochs you want to train your network before pruning it.

In [None]:
sp_cb=SparsifyCallback(sparsity=80, granularity='weight', method='global', criteria=l1_norm, sched_func=one_shot, start_epoch=3)

Let's start pruning after 3 epochs and train our model for 6 epochs, to have the same total amount of training as before:

In [None]:
learn.fit_one_cycle(6, cbs=sp_cb)

Pruning of weight until a sparsity of 80%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.668028,0.732009,0.832206,00:11
1,0.40628,0.355843,0.864005,00:11
2,0.254012,0.234607,0.901894,00:12
3,0.156341,0.181618,0.933694,01:40
4,0.072821,0.194896,0.934371,01:40
5,0.048527,0.203621,0.926928,01:44


Sparsity at the end of epoch 0: 0.00%
Sparsity at the end of epoch 1: 0.00%
Sparsity at the end of epoch 2: 0.00%
Sparsity at the end of epoch 3: 80.00%
Sparsity at the end of epoch 4: 80.00%
Sparsity at the end of epoch 5: 80.00%
Final Sparsity: 80.00
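
The log above is consistent with a simple picture: sparsity stays at 0 until `start_epoch`, after which the schedule is evaluated on the fraction of the remaining training that has elapsed. The sketch below encodes that picture; `sparsity_after_epoch` and its position formula are assumptions made for illustration, not fasterai's actual code.

```python
def one_shot(start, end, pos):
    return end

def sparsity_after_epoch(epoch, n_epochs, start_epoch, sched_func, end_sparsity):
    """Assumed model of the sparsity reached at the end of `epoch` (0-indexed)."""
    if epoch < start_epoch:
        return 0.0
    # Fraction of the pruning phase completed at the end of this epoch
    pos = (epoch + 1 - start_epoch) / (n_epochs - start_epoch)
    return float(sched_func(0, end_sparsity, pos))

for epoch in range(6):
    s = sparsity_after_epoch(epoch, n_epochs=6, start_epoch=3,
                             sched_func=one_shot, end_sparsity=80)
    print(f"Sparsity at the end of epoch {epoch}: {s:.2f}%")
```

Because `one_shot` ignores `pos`, the sparsity jumps straight to the target as soon as pruning begins.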


Training and pruning in a single cycle works about as well: accuracy peaks at 93.4% at epoch 4 and ends at 92.7%, compared with 93.2% for the two-stage approach above.

We can check if we get similar sparsity values across the layers:

In [None]:
for k,m in enumerate(learn.model.modules()):
    if isinstance(m, nn.Conv2d):
        print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")

Sparsity in Conv2d 2: 38.58%
Sparsity in Conv2d 8: 55.26%
Sparsity in Conv2d 11: 49.96%
Sparsity in Conv2d 14: 47.94%
Sparsity in Conv2d 17: 49.40%
Sparsity in Conv2d 21: 53.08%
Sparsity in Conv2d 24: 60.51%
Sparsity in Conv2d 27: 41.58%
Sparsity in Conv2d 30: 59.75%
Sparsity in Conv2d 33: 62.18%
Sparsity in Conv2d 37: 65.47%
Sparsity in Conv2d 40: 70.64%
Sparsity in Conv2d 43: 60.69%
Sparsity in Conv2d 46: 74.53%
Sparsity in Conv2d 49: 77.17%
Sparsity in Conv2d 53: 77.64%
Sparsity in Conv2d 56: 82.78%
Sparsity in Conv2d 59: 59.83%
Sparsity in Conv2d 62: 80.76%
Sparsity in Conv2d 65: 91.83%


---

## Iterative Pruning

Researchers have come up with a better way to prune than removing all the weights at once (as in One-Shot Pruning). The idea is to perform several iterations of pruning and fine-tuning, hence the name Iterative Pruning.

![Iterative Pruning](imgs/iterative.pdf "Iterative Pruning")

1. You first need to train a network
2. You then need to remove a part of the weights (which ones depends on your criteria, your needs, etc.)
3. You fine-tune the remaining weights to recover from the loss of parameters.
4. Back to step 2.

There are two main ways that you can perform Iterative Pruning with fasterai. 

1. You already possess a trained network and want to prune it
2. You don't possess such a network and have to train it from scratch

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

We illustrate the second case here, in which the network needs to be trained before pruning.

You only need to create the Callback with the `iterative` schedule and set the `start_epoch` argument, i.e. how many epochs you want to train your network before pruning it.

In [None]:
def iterative(start, end, pos, n_steps=3):
    "Perform iterative pruning, and pruning in `n_steps` steps"
    return start + ((end-start)/n_steps)*(np.ceil((pos)*n_steps))

The `iterative` schedule has an `n_steps` parameter, i.e. how many iterations of pruning/fine-tuning you want to perform. To modify its value, we can use the `partial` function like this:

```
iterative = partial(iterative, n_steps=5)
```
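
To see what the schedule produces, the sketch below re-implements the formula with `math.ceil` (so it runs without NumPy) under the hypothetical name `iterative_sketch`, and evaluates it at a few arbitrary positions in the pruning phase. With the default `n_steps=3` and a target of 80%, sparsity rises in three equal jumps, and `partial` switches this to five.

```python
import math
from functools import partial

def iterative_sketch(start, end, pos, n_steps=3):
    """Same formula as `iterative`, using math.ceil instead of np.ceil."""
    return start + ((end - start) / n_steps) * math.ceil(pos * n_steps)

# Default: 3 steps towards 80% sparsity
for pos in (0.2, 0.5, 0.9):
    print(f"pos={pos}: {iterative_sketch(0, 80, pos):.2f}%")

# 5 steps, configured with partial
iterative_5 = partial(iterative_sketch, n_steps=5)
for pos in (0.1, 0.3, 0.5, 0.7, 0.9):
    print(f"pos={pos}: {iterative_5(0, 80, pos):.2f}%")
```

The three-step variant yields 26.67%, 53.33% and 80.00%, the same levels reported in the training log of this section; the five-step variant moves in increments of 16%.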

In [None]:
sp_cb=SparsifyCallback(sparsity=80, granularity='weight', method='global', criteria=l1_norm, sched_func=iterative, start_epoch=3)

Let's start pruning after 3 epochs and train our model for 6 epochs, to have the same total amount of training as before:

In [None]:
learn.fit_one_cycle(6, cbs=sp_cb)

Pruning of weight until a sparsity of 80%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.707019,0.753476,0.792963,00:20
1,0.421993,0.341725,0.857916,00:13
2,0.271612,0.234841,0.901218,00:11
3,0.157437,0.185066,0.933694,01:44
4,0.093744,0.164622,0.93843,01:41
5,0.074952,0.179888,0.928958,01:42


Sparsity at the end of epoch 0: 0.00%
Sparsity at the end of epoch 1: 0.00%
Sparsity at the end of epoch 2: 0.00%
Sparsity at the end of epoch 3: 26.67%
Sparsity at the end of epoch 4: 53.33%
Sparsity at the end of epoch 5: 80.00%
Final Sparsity: 80.00


As you can see, the network sparsity increases in three steps over the course of training: 26.67%, 53.33%, then 80%.

---

## Gradual Pruning

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In [None]:
sp_cb=SparsifyCallback(sparsity=80, granularity='weight', method='global', criteria=l1_norm, sched_func=sched_agp)

Here no `start_epoch` is set, so pruning starts from the first epoch. We again train our model for 6 epochs, to have the same total amount of training as before:

In [None]:
learn.fit_one_cycle(6, cbs=sp_cb)

Pruning of weight until a sparsity of 80%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.666568,0.774074,0.841678,01:40
1,0.408887,0.26985,0.888363,01:41
2,0.256823,0.250144,0.891069,01:42
3,0.162672,0.195018,0.928958,01:37
4,0.106435,0.176864,0.935724,01:40
5,0.057881,0.166701,0.94046,01:38


Sparsity at the end of epoch 0: 33.70%
Sparsity at the end of epoch 1: 56.30%
Sparsity at the end of epoch 2: 70.00%
Sparsity at the end of epoch 3: 77.04%
Sparsity at the end of epoch 4: 79.63%
Sparsity at the end of epoch 5: 80.00%
Final Sparsity: 80.00
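
The sparsity values logged above match a cubic schedule of the form $s(pos) = s_{end} + (s_{start} - s_{end})(1 - pos)^3$, the shape used by Automated Gradual Pruning. The sketch below reproduces the logged numbers under the assumption that `pos` is the fraction of training completed at the end of each epoch; `agp_sketch` is a hypothetical name and not fasterai's `sched_agp`.

```python
def agp_sketch(start, end, pos):
    """Cubic schedule: fast pruning early on, slowing down towards the end."""
    return end + (start - end) * (1 - pos) ** 3

n_epochs = 6
for epoch in range(n_epochs):
    pos = (epoch + 1) / n_epochs  # assumed position at the end of the epoch
    print(f"Sparsity at the end of epoch {epoch}: {agp_sketch(0, 80, pos):.2f}%")
```

Most of the pruning happens in the first epochs, when the network still has plenty of training left to recover, and the last epochs remove only a small fraction of additional weights.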
