In [None]:
#| include: false
from nbdev.showdoc import *
import warnings
# Ignore every warning for the rest of the session
# (presumably to keep the rendered notebook output uncluttered).
warnings.filterwarnings('ignore')

In [None]:
#| include: false
from fastai.vision.all import *
from fastai.callback.all import *
from fasterai.prune.all import *
from fasterai.core.criteria import *
import torch_pruning as tp
from torch_pruning.pruner import function
import torch_pruning as tp

import torch
import torch.nn as nn
import torch.nn.functional as F

Let's try our `PruneCallback` on the `Pets` dataset

In [None]:
# Resolve the Pets dataset location from its fastai URL constant, then gather
# the image files found under its `images` sub-folder.
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f):
    "Label a file name: `True` when its first character is uppercase, else `False`."
    # NOTE(review): presumably capitalised names mark one class (cats) in the
    # Pets dataset — confirm against the dataset documentation.
    first_char = f[0]
    return first_char.isupper()

# Build the DataLoaders: each file is labelled from its name by `label_func`,
# and `Resize(64)` is applied as the per-item transform.
dls = ImageDataLoaders.from_name_func(
    path,
    files,
    label_func,
    item_tfms=Resize(64),
)

We'll train a vanilla ResNet18 for 5 epochs to get an idea of the expected performance.

In [None]:
# Baseline run: an unpruned ResNet18 learner that reports accuracy.
learn = vision_learner(dls, resnet18, metrics=accuracy)
# Unfreeze before fitting (assumes the learner starts with frozen layer
# groups — confirm against the fastai docs).
learn.unfreeze()
# Train for 5 epochs with the one-cycle policy.
learn.fit_one_cycle(5)

epoch,train_loss,valid_loss,accuracy,time
0,0.674896,0.956743,0.79364,00:40
1,0.411653,0.613217,0.854533,00:11
2,0.248864,0.282094,0.888363,00:10
3,0.133082,0.209491,0.920162,00:08
4,0.070914,0.195288,0.931664,00:08


In [None]:
# Reference cost (ops and parameter count) of the trained, unpruned model,
# measured on a single random 3x224x224 input placed on the default device.
# NOTE(review): the DataLoaders resize images to 64, so these counts describe
# a larger input than the one used for training; the pruned model is measured
# with the same 224 input, so the two counts remain comparable.
base_macs, base_params = tp.utils.count_ops_and_params(
    learn.model,
    torch.randn(1, 3, 224, 224).to(default_device()),
)

Let's now try to remove some filters from our model.

In [None]:
# Rebind `learn` to a newly created ResNet18 learner, so the pruning run does
# not start from the baseline model fitted earlier.
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

We'll set the `sparsity` to 50 (i.e. remove 50% of the filters), the `context` to global (i.e. filters can be removed from anywhere in the network), the `criteria` to large_final (i.e. keep the highest-value filters), and the `schedule` to one_cycle (i.e. follow the One-Cycle schedule to remove filters progressively during training).

In [None]:
# Configure the pruning callback.
pr_cb = PruneCallback(
    sparsity=50,           # target: remove 50% of the filters
    context='global',      # filters may be removed anywhere in the network
    criteria=large_final,  # keep the highest-value filters
    schedule=one_cycle,    # remove filters following the One-Cycle schedule
)
# Train for 10 epochs with the pruning callback attached.
learn.fit_one_cycle(10, cbs=pr_cb)

Pruning until a sparsity of [50]%


epoch,train_loss,valid_loss,accuracy,time
0,0.898704,0.609782,0.742219,00:10
1,0.537376,0.386652,0.870095,00:13
2,0.353923,0.290426,0.889039,00:13
3,0.263464,0.257007,0.905277,00:15
4,0.225723,0.262253,0.893099,00:14
5,0.226365,0.260755,0.895805,00:14
6,0.220827,0.224387,0.903924,00:13
7,0.19691,0.243172,0.896482,00:13
8,0.168376,0.237551,0.902571,00:13
9,0.148228,0.227397,0.905954,00:12


Sparsity at the end of epoch 0: [0.5]%
Sparsity at the end of epoch 1: [1.96]%
Sparsity at the end of epoch 2: [7.09]%
Sparsity at the end of epoch 3: [20.07]%
Sparsity at the end of epoch 4: [36.57]%
Sparsity at the end of epoch 5: [45.86]%
Sparsity at the end of epoch 6: [48.92]%
Sparsity at the end of epoch 7: [49.74]%
Sparsity at the end of epoch 8: [49.95]%
Sparsity at the end of epoch 9: [50.0]%
Final Sparsity: [50.0]%


In [None]:
# Cost (ops and parameter count) of the model after pruning, measured on a
# single random 3x224x224 input — the same input size as the baseline
# measurement, so the two can be compared directly.
pruned_macs, pruned_params = tp.utils.count_ops_and_params(
    learn.model,
    torch.randn(1, 3, 224, 224).to(default_device()),
)

We observe that our network has lost about 2.5 percentage points of accuracy. But how many parameters have we removed, and how much compute does that save?

In [None]:
# Pruned-to-baseline ratio of operation counts, shown with two decimals
# (a value below 1 means the pruned model needs less compute).
print(f'The pruned model has {pruned_macs/base_macs:.2f} the compute of original model')

The pruned model has 0.71 the compute of original model


In [None]:
# Pruned-to-baseline ratio of parameter counts, shown with two decimals
# (a value below 1 means the pruned model is smaller).
print(f'The pruned model has {pruned_params/base_params:.2f} the parameters of original model')

The pruned model has 0.21 the parameters of original model


So at the price of a slight decrease in accuracy, we now have a model that is about 5x smaller and requires about 1.4x less compute.