# Pruning Schedules

> Make your neural network sparse with fastai

The simplest way to perform pruning is called One-Shot Pruning. It consists of the following three steps:

![One-Shot Pruning](imgs/one_shot.pdf "One-Shot Pruning")

1. You first need to train a network
2. You then need to remove some weights (which ones depends on your criteria and needs)
3. You fine-tune the remaining weights to recover from the loss of parameters.

With fasterai, this is really easy to do. Let's illustrate it with an example:

In [None]:
#hide
from fastai.vision.all import *
from fastai.callback.all import *

from fasterai.sparsifier import *
from fasterai.criteria import *
from fasterai.sparsify_callback import *
from fasterai.schedule import *

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
path = untar_data(URLs.PETS)

files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64), device=device)

We will first train a network without any pruning, which will serve as a baseline.

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

learn.fit_one_cycle(3)

epoch,train_loss,valid_loss,accuracy,time
0,0.640964,0.633472,0.846414,00:12
1,0.33134,0.272124,0.895129,00:12
2,0.18117,0.211167,0.91272,00:12


## One-Shot Pruning

There are two main scenarios in which you can perform One-Shot Pruning with fasterai:

1. You already possess a trained network and want to prune it
2. You don't possess such a network and have to train it from scratch

### 1. You possess a trained network

In this case, step 1 of the One-Shot Pruning process is already done, but you still need to prune the network and then fine-tune it.

Let's say we want to remove $80 \%$ of the weights of our network. This can be done as:

In [None]:
sp = Sparsifier(learn.model, 'weight', 'global', l1_norm)
sp.prune(80)

In [None]:
_, acc = learn.validate(); acc

0.8179972767829895

As we removed a large share of the trained weights, the performance of the network is degraded. This can be solved by retraining the pruned network while making sure that the pruned weights keep their value of 0.

We don't want the sparsity level to change anymore, so we need a schedule that returns a constant value. fasterai provides such a schedule, called `one_shot`, which is defined as:

In [None]:
def one_shot(start, end, pos): return end
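
As a quick illustration of what a constant schedule means, the sketch below evaluates `one_shot` at a few training positions. It assumes, as the signature suggests, that `pos` is the training progress in `[0, 1]` and that `start` and `end` are the initial and target sparsity; the sampled positions are arbitrary.

```python
def one_shot(start, end, pos):
    # Same definition as above: ignore the progress, always return the target.
    return end

# Hypothetical sample points: beginning, middle and end of training.
positions = [0.0, 0.5, 1.0]
sparsities = [one_shot(0, 80, pos) for pos in positions]
print(sparsities)  # [80, 80, 80]
```

Whatever the position, the requested sparsity is the final one, which is why the network is pruned once and then only fine-tuned.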

We can pass the same arguments to our callback as those used by the `Sparsifier`.

In [None]:
sp_cb=SparsifyCallback(sparsity=80, granularity='weight', method='global', criteria=l1_norm, sched_func=one_shot)

learn.fit_one_cycle(3, cbs=sp_cb)

Pruning of weight until a sparsity of 80%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.199193,0.281678,0.903248,01:36
1,0.167149,0.207365,0.916103,01:35
2,0.08382,0.196808,0.928281,01:34


Sparsity at the end of epoch 0: 80.00%
Sparsity at the end of epoch 1: 80.00%
Sparsity at the end of epoch 2: 80.00%
Final Sparsity: 80.00


We can also check where the pruned weights are located in the network:

In [None]:
for k,m in enumerate(learn.model.modules()):
    if isinstance(m, nn.Conv2d):
        print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")

Sparsity in Conv2d 2: 38.85%
Sparsity in Conv2d 8: 55.18%
Sparsity in Conv2d 11: 50.22%
Sparsity in Conv2d 14: 48.10%
Sparsity in Conv2d 17: 49.92%
Sparsity in Conv2d 21: 53.43%
Sparsity in Conv2d 24: 60.85%
Sparsity in Conv2d 27: 42.36%
Sparsity in Conv2d 30: 60.02%
Sparsity in Conv2d 33: 62.29%
Sparsity in Conv2d 37: 65.71%
Sparsity in Conv2d 40: 70.78%
Sparsity in Conv2d 43: 61.05%
Sparsity in Conv2d 46: 74.60%
Sparsity in Conv2d 49: 77.16%
Sparsity in Conv2d 53: 77.67%
Sparsity in Conv2d 56: 82.73%
Sparsity in Conv2d 59: 59.82%
Sparsity in Conv2d 62: 80.74%
Sparsity in Conv2d 65: 91.73%
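
The layers end up with very different sparsity levels, presumably because `method='global'` ranks the weights of the whole network together instead of layer by layer. The toy sketch below illustrates that idea with plain Python lists. It is not fasterai's implementation: `global_magnitude_masks` and the two toy layers are hypothetical, and it assumes that global magnitude pruning simply zeroes the smallest-magnitude weights across all layers.

```python
def global_magnitude_masks(layers, sparsity):
    """Return one 0/1 mask per layer, pruning `sparsity` percent of all weights."""
    # Rank every weight in the network by absolute value.
    all_mags = sorted(abs(w) for layer in layers for w in layer)
    n_prune = int(len(all_mags) * sparsity / 100)
    if n_prune == 0:
        return [[1] * len(layer) for layer in layers]
    # Weights at or below this magnitude are pruned (ties are pruned too).
    threshold = all_mags[n_prune - 1]
    return [[0 if abs(w) <= threshold else 1 for w in layer] for layer in layers]

layers = [
    [0.9, -0.8, 0.7, 0.05, -0.6],     # mostly large weights
    [0.01, -0.02, 0.03, 0.5, -0.04],  # mostly small weights
]
masks = global_magnitude_masks(layers, 50)
layer_sparsity = [100 * m.count(0) / len(m) for m in masks]
print(layer_sparsity)  # [20.0, 80.0]
```

Half of all weights are removed, but the layer holding mostly small weights loses far more of them than the other one.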


> Note: Pruning the network with `Sparsifier` beforehand is not necessary, as the callback also calls it. It was done here to better illustrate all the steps.



### 2. You don't possess a trained network

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In this case, your network needs to be trained before pruning.

You only need to create the Callback with the `one_shot` schedule and set the `start_epoch` argument, i.e. how many epochs you want to train your network before pruning it.



In [None]:
sp_cb=SparsifyCallback(sparsity=80, granularity='weight', method='global', criteria=l1_norm, sched_func=one_shot, start_epoch=3)

Let's start pruning after 3 epochs and train our model for 6 epochs, to have the same total amount of training as before.

In [None]:
learn.fit_one_cycle(6, cbs=sp_cb)

Pruning of weight until a sparsity of 80%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.702169,0.456472,0.870095,00:11
1,0.410011,0.288117,0.881597,00:11
2,0.250258,0.252269,0.889716,00:12
3,0.145373,0.176909,0.933694,01:37
4,0.083379,0.201312,0.929635,01:34
5,0.054683,0.208249,0.933694,01:36


Sparsity at the end of epoch 0: 0.00%
Sparsity at the end of epoch 1: 0.00%
Sparsity at the end of epoch 2: 0.00%
Sparsity at the end of epoch 3: 80.00%
Sparsity at the end of epoch 4: 80.00%
Sparsity at the end of epoch 5: 80.00%
Final Sparsity: 80.00


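In this run, doing the training and pruning in a single cycle works even better: the final accuracy is 0.9337, against 0.9283 when pruning the already-trained network and fine-tuning it.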

---

## Iterative Pruning

Researchers have come up with a better way to prune than removing all the weights at once (as in One-Shot Pruning). The idea is to perform several iterations of pruning and fine-tuning, hence the name Iterative Pruning.

![Iterative Pruning](imgs/iterative.pdf "Iterative Pruning")

1. You first need to train a network
2. You then need to remove a part of the weights (which ones depends on your criteria and needs)
3. You fine-tune the remaining weights to recover from the loss of parameters.
4. Back to step 2.

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In this case, your network needs to be trained before pruning.

You only need to create the Callback with the `iterative` schedule and set the `start_epoch` argument, i.e. how many epochs you want to train your network before pruning it.

In [None]:
def iterative(start, end, pos, n_steps=3):
    "Perform iterative pruning in `n_steps` steps"
    return start + ((end-start)/n_steps)*(np.ceil((pos)*n_steps))

The `iterative` schedule has an `n_steps` parameter, i.e. how many iterations of pruning/fine-tuning you want to perform. To modify its value, we can use the `partial` function like this:

```
iterative = partial(iterative, n_steps=5)
```
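
To make the staircase shape concrete, the sketch below evaluates the schedule at a few positions for 3 and 5 steps. It re-declares `iterative` with `math.ceil` in place of `np.ceil` so that it runs without NumPy; the sampled positions and the names `three_steps` and `five_steps` are only illustrative.

```python
import math
from functools import partial

def iterative(start, end, pos, n_steps=3):
    # Same staircase as above, with math.ceil standing in for np.ceil.
    return start + ((end - start) / n_steps) * math.ceil(pos * n_steps)

iterative_5 = partial(iterative, n_steps=5)

positions = [0.1, 0.3, 0.5, 0.7, 0.9]  # arbitrary points in training
three_steps = [round(iterative(0, 80, p), 2) for p in positions]
five_steps = [round(iterative_5(0, 80, p), 2) for p in positions]

print(three_steps)  # [26.67, 26.67, 53.33, 80.0, 80.0]
print(five_steps)   # [16.0, 32.0, 48.0, 64.0, 80.0]
```

With more steps, each pruning iteration removes a smaller share of the weights, leaving less damage to recover from at every fine-tuning phase.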

In [None]:
sp_cb=SparsifyCallback(sparsity=80, granularity='weight', method='global', criteria=l1_norm, sched_func=iterative, start_epoch=3)

Let's start pruning after 3 epochs and train our model for 6 epochs, to have the same total amount of training as before.

In [None]:
learn.fit_one_cycle(6, cbs=sp_cb)

Pruning of weight until a sparsity of 80%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.678416,0.811682,0.843031,00:12
1,0.44831,0.305697,0.878214,00:12
2,0.24335,0.22305,0.905954,00:11
3,0.140957,0.207141,0.929635,01:37
4,0.082162,0.19937,0.927605,01:34
5,0.068106,0.171238,0.930988,01:36


Sparsity at the end of epoch 0: 0.00%
Sparsity at the end of epoch 1: 0.00%
Sparsity at the end of epoch 2: 0.00%
Sparsity at the end of epoch 3: 26.67%
Sparsity at the end of epoch 4: 53.33%
Sparsity at the end of epoch 5: 80.00%
Final Sparsity: 80.00


---

## Gradual Pruning

![Gradual Pruning](imgs/gradual.pdf "Gradual Pruning")

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In [None]:
sp_cb=SparsifyCallback(sparsity=80, granularity='weight', method='global', criteria=l1_norm, sched_func=sched_agp)

Here no `start_epoch` is set, so pruning starts from the first epoch. Let's train our model for 6 epochs, to have the same total amount of training as before.

In [None]:
learn.fit_one_cycle(6, cbs=sp_cb)

Pruning of weight until a sparsity of 80%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.620968,0.493531,0.855886,01:34
1,0.395395,0.336614,0.877537,01:36
2,0.255663,0.19909,0.921516,01:33
3,0.157263,0.181541,0.924222,01:35
4,0.099781,0.169471,0.933694,01:33
5,0.062961,0.17536,0.937077,01:36


Sparsity at the end of epoch 0: 33.70%
Sparsity at the end of epoch 1: 56.30%
Sparsity at the end of epoch 2: 70.00%
Sparsity at the end of epoch 3: 77.04%
Sparsity at the end of epoch 4: 79.63%
Sparsity at the end of epoch 5: 80.00%
Final Sparsity: 80.00
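
The definition of `sched_agp` is not shown in this section. The sparsities logged above are consistent with the cubic schedule used in automated gradual pruning (Zhu and Gupta), so the sketch below assumes that form. `agp_sketch` is a hypothetical stand-in and not necessarily fasterai's exact code.

```python
def agp_sketch(start, end, pos):
    # Cubic interpolation from `start` to `end`:
    # prune aggressively early on, then slow down as the target approaches.
    return end + (start - end) * (1 - pos) ** 3

n_epochs = 6
# Sparsity at the end of each epoch, i.e. at pos = 1/6, 2/6, ..., 6/6.
per_epoch = [round(agp_sketch(0, 80, (e + 1) / n_epochs), 2) for e in range(n_epochs)]
print(per_epoch)  # [33.7, 56.3, 70.0, 77.04, 79.63, 80.0]
```

Most of the weights are removed in the first epochs, when the network can still adapt easily, and the last epochs only make small adjustments to the sparsity.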
