# Iterative Pruning

> Make your neural network sparse with fastai

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#all_slow

Researchers have found a better way to prune than removing all the weights at once (as in One-Shot Pruning): performing several iterations of pruning and fine-tuning. This approach is called Iterative Pruning.

![Iterative pruning: train, prune, fine-tune, repeat](imgs/iterative.pdf)

1. You first need to train a network.
2. You then remove a part of the weights (depending on your criteria, needs, ...).
3. You fine-tune the remaining weights to recover from the loss of parameters.
4. Go back to step 2.
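
The loop above can be sketched in plain Python on a flat list of weights. This is a minimal illustration under stated assumptions, not fasterai's implementation: `magnitude_prune`, `iterative_pruning` and the toy `fine_tune` (which only rescales the surviving weights in place of real training) are hypothetical helpers, and the pruning criterion is the absolute value (L1 norm) of each weight.

```python
def magnitude_prune(weights, sparsity):
    """Return a 0/1 mask that zeroes the `sparsity` fraction of weights with the smallest |w|."""
    n_prune = int(round(len(weights) * sparsity))
    # Indices sorted by magnitude; already-pruned weights (0.0) come first
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    mask = [1] * len(weights)
    for i in order[:n_prune]:
        mask[i] = 0
    return mask


def iterative_pruning(weights, target, n_steps, fine_tune):
    """Step 1 (training) is assumed done; repeat steps 2 and 3 `n_steps` times."""
    mask = [1] * len(weights)
    for step in range(1, n_steps + 1):
        # Step 2: raise the sparsity linearly towards `target` and prune
        mask = magnitude_prune(weights, target * step / n_steps)
        weights = [w * m for w, m in zip(weights, mask)]
        # Step 3: fine-tune the remaining weights (pruned ones stay at zero)
        weights = fine_tune(weights, mask)
    return weights, mask


def fine_tune(weights, mask):
    """Hypothetical stand-in for training: rescale the surviving weights."""
    return [w * 1.1 * m for w, m in zip(weights, mask)]


trained = [0.5, -0.1, 0.3, -0.8, 0.05, 0.9, -0.2, 0.4, 0.01, -0.6]
pruned, mask = iterative_pruning(trained, target=0.8, n_steps=4, fine_tune=fine_tune)
print(mask)  # → [0, 0, 0, 1, 0, 1, 0, 0, 0, 0]
```

Only the two largest-magnitude weights survive the four rounds, and sparsity grows from 20% to 80% instead of jumping there in one shot.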

With fasterai, this is really easy to do. Let's illustrate it with an example:

In [None]:
#hide
from fastai.vision.all import *
from fastai.callback.all import *

from fasterai.sparsifier import *
from fasterai.criteria import *
from fasterai.sparsify_callback import *
from fasterai.schedule import iterative

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
path = untar_data(URLs.PETS)

In [None]:
files = get_image_files(path/"images")

In [None]:
def label_func(f): return f[0].isupper()
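
`from_name_func` applies `label_func` to each file name. In the Oxford-IIIT Pet dataset, cat breeds are written with a capital first letter and dog breeds in lowercase, so the function returns `True` for cats and `False` for dogs. The two file names below are illustrative examples of that naming scheme.

```python
def label_func(f): return f[0].isupper()

# Illustrative file names following the Pets naming convention
print(label_func("Abyssinian_1.jpg"))         # → True  (cat)
print(label_func("american_bulldog_10.jpg"))  # → False (dog)
```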

In [None]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64), device=device)

We will first train a network without any pruning, which will serve as a baseline.

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In [None]:
learn.fit_one_cycle(3)

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.66117 | 0.411315 | 0.862652 | 00:13 |
| 1 | 0.344798 | 0.235212 | 0.901218 | 00:12 |
| 2 | 0.184619 | 0.195194 | 0.923545 | 00:12 |


## Iterative Pruning

There are two main ways that you can perform Iterative Pruning with fasterai. 

1. You already possess a trained network and want to prune it
2. You don't possess such a network and have to train it from scratch

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

Here we illustrate the second case: the network has not been trained yet, so it needs to be trained before pruning starts.

You only need to create the Callback with the `iterative` schedule and set the `start_epoch` argument, i.e. how many epochs you want to train your network before pruning it.

In [None]:
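def iterative(start, end, pos, n_steps=3):
    "Perform iterative pruning in `n_steps` steps"
    # `pos` is the training progress in [0, 1]; ceil() turns it into a step index,
    # so the sparsity increases by (end-start)/n_steps at each step
    return start + ((end-start)/n_steps)*(np.ceil((pos)*n_steps))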

The `iterative` schedule has an `n_steps` parameter, i.e. how many iterations of pruning/fine-tuning you want to perform. To modify its value, we can use the `partial` function like this:

```
iterative = partial(iterative, n_steps=5)
```
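
To see what the schedule produces, the formula can be evaluated directly. This sketch reimplements it with `math.ceil` instead of `np.ceil` so it runs without NumPy, and assumes `pos` is the fraction of the pruning phase completed. With the default `n_steps=3`, the sparsity goes from 0 to 80% in three equal jumps; with `partial(..., n_steps=5)` it moves in five steps of 16%.

```python
import math
from functools import partial

def iterative(start, end, pos, n_steps=3):
    "Go from `start` to `end` sparsity in `n_steps` equal jumps; `pos` is progress in [0, 1]"
    return start + ((end - start) / n_steps) * math.ceil(pos * n_steps)

# Default schedule: three steps of 80/3 ≈ 26.67%
print([round(iterative(0, 80, p), 2) for p in [0.0, 0.1, 0.5, 0.9, 1.0]])
# → [0.0, 26.67, 53.33, 80.0, 80.0]

# Same schedule with five steps of 16%
iterative5 = partial(iterative, n_steps=5)
print([iterative5(0, 80, p) for p in [0.1, 0.3, 0.5, 0.7, 0.9]])
# → [16.0, 32.0, 48.0, 64.0, 80.0]
```

More steps mean smaller jumps in sparsity between fine-tuning phases, at the cost of spreading the same training budget over more pruning rounds.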

In [None]:
sp_cb=SparsifyCallback(sparsity=80, granularity='weight', method='global', criteria=l1_norm, sched_func=iterative, start_epoch=3)
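
As we understand it, `method='global'` with `criteria=l1_norm` means that weights are ranked by magnitude across all layers at once, rather than each layer being pruned to the target sparsity on its own. The plain-Python sketch below contrasts the two on two toy layers; `prune_local` and `prune_global` are hypothetical helpers, not fasterai functions.

```python
def prune_local(layers, sparsity):
    """Prune each layer separately: every layer ends up with the same sparsity."""
    out = []
    for w in layers:
        k = int(round(len(w) * sparsity))
        pruned_idx = set(sorted(range(len(w)), key=lambda i: abs(w[i]))[:k])
        out.append([0.0 if i in pruned_idx else x for i, x in enumerate(w)])
    return out


def prune_global(layers, sparsity):
    """Rank the weights of all layers together: layers can end up unevenly sparse."""
    flat = [(abs(x), li, i) for li, w in enumerate(layers) for i, x in enumerate(w)]
    k = int(round(len(flat) * sparsity))
    to_prune = {(li, i) for _, li, i in sorted(flat)[:k]}
    return [[0.0 if (li, i) in to_prune else x for i, x in enumerate(w)]
            for li, w in enumerate(layers)]


layers = [[0.9, -0.7, 0.8, 0.6], [0.05, -0.1, 0.2, 0.3]]
print(prune_local(layers, 0.5))   # → [[0.9, 0.0, 0.8, 0.0], [0.0, 0.0, 0.2, 0.3]]
print(prune_global(layers, 0.5))  # → [[0.9, -0.7, 0.8, 0.6], [0.0, 0.0, 0.0, 0.0]]
```

Both reach 50% overall sparsity, but the global version removes the second layer entirely because all of its weights are small, which shows why the choice of method matters.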

Let's start pruning after 3 epochs and train our model for 6 epochs in total: the first 3 epochs match the training of the baseline, and the remaining 3 epochs are used for pruning and fine-tuning.
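
The sparsity that the callback reports at the end of each epoch (shown in the output below) can be predicted with a little arithmetic. This sketch assumes the sparsity is 0 before `start_epoch` and that the schedule is then evaluated on the fraction of the pruning phase elapsed at the end of each epoch; `sparsity_at_epoch_end` is a hypothetical helper, not part of fasterai, and `SparsifyCallback` may compute its progress differently internally.

```python
import math

def iterative(start, end, pos, n_steps=3):
    "Go from `start` to `end` sparsity in `n_steps` equal jumps; `pos` is progress in [0, 1]"
    return start + ((end - start) / n_steps) * math.ceil(pos * n_steps)

def sparsity_at_epoch_end(epoch, n_epoch=6, start_epoch=3, target=80):
    """Sparsity (%) after 0-based `epoch`, assuming pruning spans epochs [start_epoch, n_epoch)."""
    if epoch < start_epoch:
        return 0.0  # dense training phase, no pruning yet
    # Fraction of the pruning phase completed at the end of this epoch
    pos = (epoch + 1 - start_epoch) / (n_epoch - start_epoch)
    return iterative(0, target, pos)

print([round(sparsity_at_epoch_end(e), 2) for e in range(6)])
# → [0.0, 0.0, 0.0, 26.67, 53.33, 80.0]
```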

In [None]:
learn.fit_one_cycle(6, cbs=sp_cb)

Pruning of weight until a sparsity of 80%
Saving Weights at epoch 0


| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.705333 | 0.590168 | 0.826793 | 00:12 |
| 1 | 0.376754 | 0.285279 | 0.892422 | 00:11 |
| 2 | 0.238135 | 0.217662 | 0.916779 | 00:11 |
| 3 | 0.144444 | 0.191392 | 0.926928 | 01:39 |
| 4 | 0.088618 | 0.181956 | 0.936401 | 01:43 |
| 5 | 0.072037 | 0.189369 | 0.935047 | 01:39 |


Sparsity at the end of epoch 0: 0.00%
Sparsity at the end of epoch 1: 0.00%
Sparsity at the end of epoch 2: 0.00%
Sparsity at the end of epoch 3: 26.67%
Sparsity at the end of epoch 4: 53.33%
Sparsity at the end of epoch 5: 80.00%
Final Sparsity: 80.00


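As you can see, the sparsity stays at 0% during the first 3 epochs, then increases in three equal steps (26.67%, 53.33%, 80.00%) until it reaches the 80% target. The final accuracy (93.5%) is even slightly higher than that of the unpruned baseline (92.4%), although the pruned model was trained for more epochs overall.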