In [None]:
#| default_exp sparse.sparsify_callback

In [None]:
#| include: false
from nbdev.showdoc import *

In [None]:
#| export
from fastai.vision.all import *
from fastai.callback.all import *
from fasterai.sparse.sparsifier import *
from fasterai.sparse.criteria import *
from fasterai.sparse.schedule import *

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Fetch the fastai PETS dataset (`URLs.PETS`) with `untar_data` and collect every image file
# found under its `images` folder.
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f):
    """Binary label for an image, derived from its file name.

    `f` is expected to be a file-name string (indexed by character). Returns True when the
    name starts with an uppercase letter and False otherwise; in the PETS dataset this
    presumably separates cat breeds (capitalised) from dog breeds.
    """
    first_char = f[0]
    return first_char.isupper()

In [None]:
# Build the DataLoaders from file names (labels come from `label_func`), resizing images to 64px.
# The train/validation split is random: fix its seed so that the accuracies below are reproducible
# across runs (valid_pct=0.2 is the fastai default, spelled out for clarity).
dls = ImageDataLoaders.from_name_func(path, files, label_func, valid_pct=0.2, seed=42, item_tfms=Resize(64))

In [None]:
#| export
class SparsifyCallback(Callback):
    """Progressively sparsify a model during training, following a pruning `schedule`.

    Args:
        sparsity: target sparsity in percent; a single value or a list with one value per layer
            of `layer_type` (always stored as a list via `listify`).
        granularity: granularity of the pruned structures, forwarded to `Sparsifier`.
        context: pruning context (e.g. 'local'), forwarded to `Sparsifier`.
        criteria: criteria used to select the parameters to remove, forwarded to `Sparsifier`.
        schedule: callable `schedule(sparsity, pct_train)` returning the sparsity to apply at the
            current point of training; this callback also uses its `pruned`, `start_pct` and
            `current_sparsity` attributes and its `after_pruned()` / `reset()` methods.
        lth: if True, reset the weights to their `rewind_epoch` values after each pruning step
            (Lottery Ticket Hypothesis).
        rewind_epoch: epoch at whose start the weights are saved for later resets.
        reset_end: if True, call `Sparsifier._reset_weights()` at the end of training.
        save_tickets: if True, save the final model to `winning_ticket_<sparsity>.pth`; when `lth`
            is also True, an intermediate ticket is saved before every pruning step as well.
        model: module to prune; defaults to `learn.model` when None.
        round_to: forwarded to `Sparsifier.prune_model`.
        layer_type: type of the layers to prune (default `nn.Conv2d`).
    """
    def __init__(self, sparsity, granularity, context, criteria, schedule, lth=False, rewind_epoch=0, reset_end=False, save_tickets=False, model=None, round_to=None, layer_type=nn.Conv2d):
        store_attr()
        self.sparsity = listify(self.sparsity)

    def before_fit(self):
        "Check the rewind epoch, build the `Sparsifier` and initialise the sparsity bookkeeping."
        print(f'Pruning of {self.granularity} until a sparsity of {self.sparsity}%')
        assert self.schedule.start_pct*self.n_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'
        # `is not None`: some modules (e.g. an empty nn.Sequential) are falsy, so truthiness is unreliable.
        model = self.model if self.model is not None else self.learn.model
        self.sparsifier = Sparsifier(model, self.granularity, self.context, self.criteria, self.layer_type)
        # Sparsity last applied by `prune_model`, used to name the saved tickets.
        # Assumes no pruning has been applied before training starts.
        self.previous_sparsity = [0.] * len(self.sparsity)

    def before_epoch(self):
        "Save the weights at `rewind_epoch` so that they can be restored later."
        if self.epoch == self.rewind_epoch:
            print(f'Saving Weights at epoch {self.epoch}')
            self.sparsifier._save_weights()

    def before_batch(self):
        "Query the schedule and, on training batches where it asks for it, prune the model."
        # Computed on every batch (validation included); pruning only happens while training.
        self.current_sparsity = self.schedule(self.sparsity, round(self.pct_train,3))
        if self.schedule.pruned and self.training:
            if self.lth and self.save_tickets:
                print('Saving Intermediate Ticket')
                # Saved *before* pruning further, so it is named after the sparsity it actually has.
                self.sparsifier.save_model(f'winning_ticket_{self.previous_sparsity[0]:.2f}.pth', self.learn.model)
            self.sparsifier.prune_model(self.current_sparsity, self.round_to)
            # Copy so later schedule outputs cannot alias the recorded value.
            self.previous_sparsity = list(self.current_sparsity)

    def after_step(self):
        "Optionally rewind the weights (LTH), then re-apply the masks after the optimizer step."
        if self.lth and self.schedule.pruned:
            print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
            self.sparsifier._reset_weights(self.learn.model)
        self.schedule.after_pruned()
        # The optimizer step may have updated pruned weights; masks are re-applied after it.
        self.sparsifier._apply_masks()

    def after_epoch(self):
        "Report the current sparsity, rounded to 2 decimals."
        sparsity_str = [float("%0.2f"%sp) for sp in self.current_sparsity]
        print(f'Sparsity at the end of epoch {self.epoch}: {sparsity_str}%')

    def after_fit(self):
        "Save the final ticket if requested, report the sparsity and clean up."
        if self.save_tickets:
            print('Saving Final Ticket')
            # Named after the last sparsity actually applied by `prune_model`.
            self.sparsifier.save_model(f'winning_ticket_{self.previous_sparsity[0]:.2f}.pth', self.learn.model)
        print(f'Final Sparsity: {self.schedule.current_sparsity:}%')
        if self.reset_end: self.sparsifier._reset_weights()
        self.sparsifier._clean_buffers()
        self.schedule.reset()
        self.sparsifier.print_sparsity()

The most important part of our `Callback` happens in `before_batch`. There, we first compute the sparsity of our network according to our schedule and then we remove the parameters accordingly.

In [None]:
# Dense baseline. `vision_learner` is the current name of the deprecated `cnn_learner`.
learn = vision_learner(dls, resnet18, metrics=accuracy)
# Train every layer, not only the newly added head.
learn.unfreeze()

  warn("`cnn_learner` has been renamed to `vision_learner` -- please update your code")


In [None]:
# Train the dense baseline for 5 epochs with the one-cycle policy.
learn.fit_one_cycle(n_epoch=5)

epoch,train_loss,valid_loss,accuracy,time
0,0.71417,0.534177,0.802436,00:08
1,0.405863,0.46695,0.861976,00:07
2,0.229647,0.234999,0.902571,00:07
3,0.141966,0.198904,0.924222,00:07
4,0.073327,0.191152,0.930988,00:07


Let's now try adding some sparsity to our model.

In [None]:
# NOTE(review): `learn` is reassigned to a plain `Learner` below before any training,
# so this learner is never used; consider removing this cell.
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

  warn("`cnn_learner` has been renamed to `vision_learner` -- please update your code")


Compared to the `Sparsifier`, the `SparsifyCallback` requires an additional argument: `schedule`. Indeed, we need to know the pruning schedule that we should follow during training in order to prune the parameters accordingly.

You can use any scheduling function already [available](https://docs.fast.ai/callback.schedule.html#Annealing) in fastai or come up with your own! For more information about pruning schedules, take a look at the [Schedules section](https://nathanhubens.github.io/fasterai/schedules.html).

In [None]:
# Learner for the sparse run: a ResNet-18 built with 2 output classes, wrapped in a plain `Learner`.
# NOTE: unlike the baseline learner above, no pretrained weights are requested here, so the
# accuracies of the two runs are not directly comparable.
learn = Learner(dls, resnet18(num_classes=2), metrics=accuracy)

In [None]:
# Prune individual weights of each layer (local context) up to 50%, following a cosine schedule.
sp_cb = SparsifyCallback(
    sparsity=50,
    granularity='weight',
    context='local',
    criteria=large_final,
    schedule=cos,
)

In [None]:
# Train for 10 epochs while `sp_cb` progressively prunes the model.
learn.fit(n_epoch=10, cbs=sp_cb)

Pruning of weight until a sparsity of [50]%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.597135,0.568558,0.694181,00:09
1,0.543739,0.527585,0.730717,00:07
2,0.508932,0.507831,0.748309,00:07
3,0.451922,0.454692,0.799053,00:07
4,0.427453,0.434664,0.801759,00:07
5,0.377218,0.402817,0.82341,00:07
6,0.340924,0.410856,0.820027,00:07
7,0.319503,0.363846,0.837618,00:07
8,0.271233,0.377996,0.85318,00:07
9,0.228336,0.334722,0.865359,00:07


Sparsity at the end of epoch 0: [1.22]%
Sparsity at the end of epoch 1: [4.77]%
Sparsity at the end of epoch 2: [10.31]%
Sparsity at the end of epoch 3: [17.27]%
Sparsity at the end of epoch 4: [25.0]%
Sparsity at the end of epoch 5: [32.73]%
Sparsity at the end of epoch 6: [39.69]%
Sparsity at the end of epoch 7: [45.23]%
Sparsity at the end of epoch 8: [48.78]%
Sparsity at the end of epoch 9: [50.0]%
Final Sparsity: [50.0]%
Sparsity in Conv2d 1: 50.00%
Sparsity in Conv2d 7: 50.00%
Sparsity in Conv2d 10: 50.00%
Sparsity in Conv2d 13: 50.00%
Sparsity in Conv2d 16: 50.00%
Sparsity in Conv2d 20: 50.00%
Sparsity in Conv2d 23: 50.00%
Sparsity in Conv2d 26: 50.00%
Sparsity in Conv2d 29: 50.00%
Sparsity in Conv2d 32: 50.00%
Sparsity in Conv2d 36: 50.00%
Sparsity in Conv2d 39: 50.00%
Sparsity in Conv2d 42: 50.00%
Sparsity in Conv2d 45: 50.00%
Sparsity in Conv2d 48: 50.00%
Sparsity in Conv2d 52: 50.00%
Sparsity in Conv2d 55: 50.00%
Sparsity in Conv2d 58: 50.00%
Sparsity in Conv2d 61: 50.00%
Sp

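Surprisingly, our network, whose pruned layers are composed of $50\%$ zeroes, still performs reasonably well compared to our plain, dense network. Keep in mind, however, that the two runs are not directly comparable. The dense baseline starts from a pretrained learner and is trained with `fit_one_cycle` for 5 epochs. The sparse network, by contrast, is created with `resnet18(num_classes=2)` (no pretrained weights requested) and trained with `fit` for 10 epochs.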

The `SparsifyCallback` also accepts a list of sparsities, corresponding to each layer of `layer_type` to be pruned. Below, we show how to prune only the intermediate layers of ResNet-18.

In [None]:
# `vision_learner` is the current name of the deprecated `cnn_learner`.
learn = vision_learner(dls, resnet18, metrics=accuracy)
# Train every layer, not only the newly added head.
learn.unfreeze()

In [None]:
# One target per layer of `layer_type` (20 entries): keep the first 6 and last 6 layers dense,
# and prune the 8 intermediate ones to 50%.
sparsities = [0] * 6 + [50] * 8 + [0] * 6

In [None]:
# Same setup as before, but with a per-layer list of target sparsities.
sp_cb = SparsifyCallback(
    sparsity=sparsities,
    granularity='weight',
    context='local',
    criteria=large_final,
    schedule=cos,
)

In [None]:
# One-cycle training for 5 epochs, pruning only the selected layers.
learn.fit_one_cycle(n_epoch=5, cbs=sp_cb)

Pruning of weight until a sparsity of [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.748184,0.876642,0.826116,00:08
1,0.422033,0.255813,0.889039,00:08
2,0.262884,0.2341,0.904601,00:08
3,0.132767,0.228366,0.921516,00:08
4,0.07511,0.210104,0.930311,00:08


Sparsity at the end of epoch 0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 1: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 2: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 3: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 4: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Final Sparsity: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity in Conv2d 2: 0.00%
Sparsity in Conv2d 8: 0.00%
Sparsity in Conv2d 11: 0.00%
Sparsity in Conv2d 14: 0.00%
Sparsity in Conv2d 17: 0.0

On top of that, the `SparsifyCallback` can also take several optional arguments:

- `start_sparsity`, `start_epoch`, `end_epoch`: these are **not** arguments of `SparsifyCallback`, and passing them raises a `TypeError`. The timing of pruning is instead controlled by the `schedule` object (the callback reads `schedule.start_pct`, for instance).
- `lth`: whether to train using the Lottery Ticket Hypothesis, i.e. reset the weights to their `rewind_epoch` values after each pruning step (more information in the Lottery Ticket Hypothesis section)
- `rewind_epoch`: the epoch used as a reference for the Lottery Ticket Hypothesis with Rewinding (defaults to 0)
- `reset_end`: whether you want to reset the weights to their original values after training (pruning masks are still applied)
- `save_tickets`: whether to save winning tickets. A final ticket is saved at the end of training; intermediate tickets are only saved when `lth` is also enabled.
- `model`: pass a model or a part of the model if you don't want to apply pruning on the whole model trained.
- `round_to`: if specified, the weights will be pruned to the closest multiple value of `round_to`.
- `layer_type`: the type of layer that you want to apply pruning to (defaults to `nn.Conv2d`)

For example, we pruned the convolution layers of our model. By changing `layer_type`, we could instead prune the Linear layers, or even only the BatchNorm ones!