# One-Shot Pruning

> Make your neural network sparse with fastai

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#all_slow

The simplest way to perform pruning is called One-Shot Pruning. It consists of the following three steps:

![alt text](imgs/one_shot.pdf "Title")

1. Train a network.
2. Remove some of its weights (which ones depends on your criteria and needs).
3. Fine-tune the remaining weights to recover from the loss of parameters.

With fasterai, this is really easy to do. Let's illustrate it with an example:

In [None]:
#hide
from fastai.vision.all import *
from fastai.callback.all import *

from fasterai.sparsifier import *
from fasterai.criteria import *
from fasterai.sparsify_callback import *
from fasterai.schedule import one_shot

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
path = untar_data(URLs.PETS)

In [None]:
files = get_image_files(path/"images")

In [None]:
def label_func(f): return f[0].isupper()

In [None]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64), device=device)

We will first train a network without any pruning, which will serve as a baseline.

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In [None]:
learn.fit_one_cycle(3)

epoch,train_loss,valid_loss,accuracy,time
0,0.67931,0.442378,0.832206,00:09
1,0.370691,0.304751,0.861299,00:09
2,0.212456,0.226414,0.910014,00:09


## One-Shot Pruning

There are two main ways that you can perform One-Shot Pruning with fasterai. 

1. You already possess a trained network and want to prune it
2. You don't possess such a network and have to train it from scratch

### 1. You possess a trained network

In this case, step 1 of the One-Shot Pruning process is already done, but you still need to prune the network and then fine-tune it.

Let's say we want to remove $80 \%$ of the weights of our network. This can be done as:

In [None]:
sp = Sparsifier(learn.model, 'weight', 'global', l1_norm)
sp.prune(80)
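To make the `'global'` setting with a magnitude criterion more concrete, here is a minimal, framework-free sketch of global magnitude pruning. It assumes that "global" means a single threshold is computed over the weights of all layers together, so each layer can end up with a different sparsity. The helper names `global_magnitude_prune` and `layer_sparsity` and the toy weights are hypothetical, and this is not fasterai's actual implementation.

```python
def global_magnitude_prune(layers, sparsity):
    """Zero the `sparsity` percent of weights with the smallest |w|,
    ranked across all layers at once (hypothetical helper)."""
    all_mags = sorted(abs(w) for ws in layers.values() for w in ws)
    n_prune = int(len(all_mags) * sparsity / 100)
    if n_prune == 0:
        return {name: list(ws) for name, ws in layers.items()}
    # Largest magnitude that still gets pruned; ties would be pruned together.
    threshold = all_mags[n_prune - 1]
    return {name: [0.0 if abs(w) <= threshold else w for w in ws]
            for name, ws in layers.items()}


def layer_sparsity(ws):
    """Percentage of exactly-zero weights in one layer."""
    return 100.0 * sum(w == 0 for w in ws) / len(ws)


# Toy weights (made up for illustration): conv1 has mostly large weights,
# conv2 mostly small ones.
layers = {"conv1": [0.9, -0.8, 0.7, 0.05],
          "conv2": [0.01, -0.02, 0.03, 0.6]}

pruned = global_magnitude_prune(layers, 50)
for name, ws in pruned.items():
    print(f"{name}: {ws} -> {layer_sparsity(ws):.0f}% sparse")
```

In this toy case the overall sparsity is 50%, but the layer holding the small weights absorbs most of the pruning.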

In [None]:
_, acc = learn.validate(); acc

0.8037889003753662

Unsurprisingly, since we removed a large share of the trained weights, the performance of the network is degraded (accuracy drops from about 0.910 to about 0.804). This can be recovered by retraining the pruned network while making sure that the pruned weights keep their value of 0.
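One common way to keep pruned weights at 0 during retraining is to store a binary mask and re-apply it after every update. The sketch below shows the idea on plain Python lists. The function `masked_sgd_step` and all numbers are hypothetical, and it is only an illustration of the general technique, not a description of how fasterai does it internally.

```python
def masked_sgd_step(weights, grads, mask, lr=0.1):
    """One SGD step, then re-apply the pruning mask so that
    pruned positions (mask == 0) stay exactly at zero."""
    updated = [w - lr * g for w, g in zip(weights, grads)]
    return [w if m else 0.0 for w, m in zip(updated, mask)]


# Toy example: positions 1 and 3 were pruned.
weights = [0.5, 0.0, -0.3, 0.0]
mask = [1, 0, 1, 0]
grads = [0.2, 0.4, -0.1, 0.3]   # gradients are non-zero even for pruned weights

new_weights = masked_sgd_step(weights, grads, mask)
print(new_weights)
```

Without the mask, the non-zero gradients at positions 1 and 3 would move the pruned weights away from zero and the sparsity would be lost.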

We don't want to change the sparsity level anymore, so we need a schedule that returns a constant value. Such a schedule exists in fasterai: it is called `one_shot` and is defined as:

In [None]:
def one_shot(start, end, pos): return end
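To see what a constant schedule means in practice, the sketch below evaluates `one_shot` at several points of training. It assumes that `pos` is the fraction of training completed (between 0 and 1) and that the schedule is called with `start=0` and `end=80`. The `linear` function is a hypothetical gradual schedule included only for contrast; it is not taken from fasterai.

```python
def one_shot(start, end, pos):
    # Definition from the text: always return the target value.
    return end


def linear(start, end, pos):
    # Hypothetical gradual schedule, for comparison only.
    return start + (end - start) * pos


positions = [0.0, 0.25, 0.5, 0.75, 1.0]   # assumed: fraction of training done

one_shot_values = [one_shot(0, 80, p) for p in positions]
linear_values = [linear(0, 80, p) for p in positions]

print("one_shot:", one_shot_values)
print("linear:  ", linear_values)
```

Whatever the position, `one_shot` asks for the full target sparsity, so all the pruning happens in one step.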

We can pass the callback the same arguments as those used for the Sparsifier.

In [None]:
sp_cb=SparsifyCallback(sparsity=80, granularity='weight', method='global', criteria=l1_norm, sched_func=one_shot)

In [None]:
learn.fit_one_cycle(3, cbs=sp_cb)

Pruning of weight until a sparsity of 80%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.224453,0.354158,0.866712,00:56
1,0.183741,0.218558,0.919486,00:55
2,0.101941,0.22559,0.920839,00:55


Sparsity at the end of epoch 0: 80.00%
Sparsity at the end of epoch 1: 80.00%
Sparsity at the end of epoch 2: 80.00%
Final Sparsity: 80.00


We can also check where the pruned weights are in the network:

In [None]:
for k, m in enumerate(learn.model.modules()):
    if isinstance(m, nn.Conv2d):
        # Percentage of weights in this layer that are exactly zero
        sparsity = 100. * float(torch.sum(m.weight == 0)) / float(m.weight.nelement())
        print(f"Sparsity in {m.__class__.__name__} {k}: {sparsity:.2f}%")

Sparsity in Conv2d 2: 38.65%
Sparsity in Conv2d 8: 55.38%
Sparsity in Conv2d 11: 50.37%
Sparsity in Conv2d 14: 48.50%
Sparsity in Conv2d 17: 50.10%
Sparsity in Conv2d 21: 53.52%
Sparsity in Conv2d 24: 61.05%
Sparsity in Conv2d 27: 42.68%
Sparsity in Conv2d 30: 60.19%
Sparsity in Conv2d 33: 62.56%
Sparsity in Conv2d 37: 65.79%
Sparsity in Conv2d 40: 70.87%
Sparsity in Conv2d 43: 60.93%
Sparsity in Conv2d 46: 74.63%
Sparsity in Conv2d 49: 77.23%
Sparsity in Conv2d 53: 77.70%
Sparsity in Conv2d 56: 82.72%
Sparsity in Conv2d 59: 60.01%
Sparsity in Conv2d 62: 80.68%
Sparsity in Conv2d 65: 91.67%


> Note: Pruning the network with Sparsifier beforehand is not necessary, since the Callback also calls it. It was done here to illustrate each step explicitly.

### 2. You don't possess a trained network

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In this case, your network needs to be trained before pruning.

You only need to create the Callback with the `one_shot` schedule and set the `start_epoch` argument, i.e. how many epochs you want to train your network before pruning it.

In [None]:
sp_cb=SparsifyCallback(sparsity=80, granularity='weight', method='global', criteria=l1_norm, sched_func=one_shot, start_epoch=3)
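The combined effect of `start_epoch` and a constant schedule can be sketched as follows. The helper `sparsity_per_epoch` is hypothetical, and the way it maps epochs to a schedule position is an assumption made for illustration; fasterai's internal bookkeeping may differ.

```python
def one_shot(start, end, pos):
    # Definition from the text: always return the target value.
    return end


def sparsity_per_epoch(n_epochs, start_epoch, target, sched_func):
    """Hypothetical helper: sparsity expected at the end of each epoch."""
    levels = []
    for epoch in range(n_epochs):
        if epoch < start_epoch:
            # Before start_epoch the network is trained densely.
            levels.append(0.0)
        else:
            # Assumed: progress through the pruning phase, in (0, 1].
            pos = (epoch - start_epoch + 1) / (n_epochs - start_epoch)
            levels.append(float(sched_func(0, target, pos)))
    return levels


levels = sparsity_per_epoch(n_epochs=6, start_epoch=3, target=80,
                            sched_func=one_shot)
for epoch, level in enumerate(levels):
    print(f"Sparsity at the end of epoch {epoch}: {level:.2f}%")
```

With these settings, the first three epochs are ordinary dense training and the full 80% sparsity is applied from epoch 3 onwards.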

Let's start pruning after 3 epochs and train our model for 6 epochs, so that the total amount of training is the same as before:

In [None]:
learn.fit_one_cycle(6, cbs=sp_cb)

Pruning of weight until a sparsity of 80%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.714634,0.769432,0.811908,00:09
1,0.439999,0.663998,0.786874,00:09
2,0.265372,0.231254,0.907984,00:09
3,0.148409,0.208199,0.923545,00:55
4,0.092647,0.218885,0.932341,00:55
5,0.054311,0.206575,0.934371,00:55


Sparsity at the end of epoch 0: 0.00%
Sparsity at the end of epoch 1: 0.00%
Sparsity at the end of epoch 2: 0.00%
Sparsity at the end of epoch 3: 80.00%
Sparsity at the end of epoch 4: 80.00%
Sparsity at the end of epoch 5: 80.00%
Final Sparsity: 80.00


In this run, doing the training and pruning in a single cycle works even better: the final accuracy is about 0.934, against about 0.921 when pruning the already trained network.

We can check if we get similar sparsity values across the layers:

In [None]:
for k, m in enumerate(learn.model.modules()):
    if isinstance(m, nn.Conv2d):
        # Percentage of weights in this layer that are exactly zero
        sparsity = 100. * float(torch.sum(m.weight == 0)) / float(m.weight.nelement())
        print(f"Sparsity in {m.__class__.__name__} {k}: {sparsity:.2f}%")

Sparsity in Conv2d 2: 38.33%
Sparsity in Conv2d 8: 55.25%
Sparsity in Conv2d 11: 50.01%
Sparsity in Conv2d 14: 47.87%
Sparsity in Conv2d 17: 49.38%
Sparsity in Conv2d 21: 53.15%
Sparsity in Conv2d 24: 60.62%
Sparsity in Conv2d 27: 42.43%
Sparsity in Conv2d 30: 59.82%
Sparsity in Conv2d 33: 62.18%
Sparsity in Conv2d 37: 65.48%
Sparsity in Conv2d 40: 70.72%
Sparsity in Conv2d 43: 60.64%
Sparsity in Conv2d 46: 74.58%
Sparsity in Conv2d 49: 77.17%
Sparsity in Conv2d 53: 77.68%
Sparsity in Conv2d 56: 82.79%
Sparsity in Conv2d 59: 59.78%
Sparsity in Conv2d 62: 80.73%
Sparsity in Conv2d 65: 91.79%
