In [None]:
# default_exp sparse.sparsify_callback

# SparsifyCallback

> Use the sparsifier in fastai Callback system

In [None]:
#all_slow

In [2]:
#hide
from nbdev.showdoc import *

%config InlineBackend.figure_format = 'retina'


Bad key "text.kerning_factor" on line 4 in
/Users/nathan/opt/miniconda3/envs/deep/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test_patch.mplstyle.
You probably need to get an updated matplotlibrc file from
http://github.com/matplotlib/matplotlib/blob/master/matplotlibrc.template
or from the matplotlib source distribution


In [3]:
#export
from fastai.vision.all import *
from fastai.callback.all import *
from fasterai.sparse.sparsifier import *
from fasterai.sparse.criteria import *

import torch
import torch.nn as nn
import torch.nn.functional as F

In [6]:
model1 = resnet18()

In [7]:
model2 = None

In [11]:
model3 = vgg16_bn()

In [12]:
model = model3 if model1 is None else model1

In [16]:
model = model2 if model2 else model3

In [None]:
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()

In [None]:
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

In [1]:
#export
class SparsifyCallback(Callback):
        
    def __init__(self, end_sparsity, granularity, method, criteria, sched_func, start_sparsity=0, start_epoch=0, end_epoch=None, lth=False, rewind_epoch=0, reset_end=False, model=None, round_to=None, save_tickets=False, layer_type=nn.Conv2d):
        store_attr()
        self.end_sparsity, self.current_sparsity, self.previous_sparsity = map(listify, [self.end_sparsity, self.start_sparsity, self.start_sparsity])
        
        assert self.start_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'
    
    def before_fit(self):
        print(f'Pruning of {self.granularity} until a sparsity of {self.end_sparsity}%')
        self.end_epoch = self.n_epoch if self.end_epoch is None else self.end_epoch
        assert self.end_epoch <= self.n_epoch, 'Your end_epoch must be smaller than total number of epoch'
        
        model = self.learn.model if self.model is None else self.model # Pass a model if you don't want the whole model to be pruned
        self.sparsifier = Sparsifier(model, self.granularity, self.method, self.criteria, self.layer_type)
    
    def before_epoch(self):
        if self.epoch == self.rewind_epoch:
            print(f'Saving Weights at epoch {self.epoch}')
            self.sparsifier._save_weights()
        
    def before_batch(self):
        if self.epoch>=self.start_epoch and self.epoch < self.end_epoch: 
            self._set_sparsity()
            if self.current_sparsity!=self.previous_sparsity and self.training:
                if self.lth and self.save_tickets:
                    print('Saving Intermediate Ticket')
                    self.sparsifier.save_model(f'winning_ticket_{self.previous_sparsity[0]:.2f}.pth', self.learn.model)
                self.sparsifier.prune_model(self.current_sparsity, self.round_to)
            
    def after_step(self):
        if self.epoch>=self.start_epoch:
            if self.lth and self.current_sparsity!=self.previous_sparsity:
                print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
                self.sparsifier._reset_weights(self.learn.model)

            self.previous_sparsity = self.current_sparsity
            self.sparsifier._apply_masks()
            
    def after_epoch(self):
        sparsity_str = [float(f"%0.2f"%sp) for sp in self.current_sparsity]
        print(f'Sparsity at the end of epoch {self.epoch}: {sparsity_str}%')
        
    def after_fit(self):
        if self.save_tickets:
            print('Saving Final Ticket')
            self.sparsifier.save_model(f'winning_ticket_{self.previous_sparsity[0]:.2f}.pth', self.learn.model)
            
        print(f'Final Sparsity: {self.current_sparsity:}%')
        if self.reset_end: self.sparsifier._reset_weights()
        self.sparsifier._clean_buffers()
        self.sparsifier.print_sparsity()
        
    def _set_sparsity(self):
        self.current_sparsity = [self.sched_func(start=self.start_sparsity, end=end_sp, pos=(round(self.pct_train,5)*self.n_epoch-self.start_epoch)/(self.end_epoch-self.start_epoch)) for end_sp in self.end_sparsity]

NameError: name 'Callback' is not defined

The most important part of our `Callback` happens in `before_batch`. There, we first compute the sparsity of our network according to our schedule and then we remove the parameters accordingly.

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

  warn("`cnn_learner` has been renamed to `vision_learner` -- please update your code")


In [None]:
learn.fit_one_cycle(5)

epoch,train_loss,valid_loss,accuracy,time
0,0.666736,1.887411,0.753721,00:11
1,0.43754,0.276486,0.881597,00:10
2,0.258372,0.291492,0.87889,00:10
3,0.146018,0.20028,0.924222,00:10
4,0.07515,0.212859,0.925575,00:10


Let's now try adding some sparsity in our model

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

The `SparsifyCallback` requires a new argument compared to the `Sparsifier`. Indeed, we need to know the pruning schedule that we should follow during training in order to prune the parameters accordingly.

You can use any scheduling function already [available](https://docs.fast.ai/callback.schedule.html#Annealing) in fastai or come up with your own ! For more information about the pruning schedules, take a look at the [Schedules section](https://nathanhubens.github.io/fasterai/schedules.html).

In [None]:
sp_cb = SparsifyCallback(end_sparsity=50, granularity='weight', method='local', criteria=large_final, sched_func=sched_cos)

In [None]:
learn.fit_one_cycle(5, cbs=sp_cb)

Pruning of weight until a sparsity of [50]%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.672893,0.45721,0.832882,00:10
1,0.406809,0.249998,0.899188,00:10
2,0.234348,0.46169,0.835589,00:10
3,0.132843,0.218986,0.920839,00:10
4,0.080099,0.211927,0.919486,00:10


Sparsity at the end of epoch 0: [4.77]%
Sparsity at the end of epoch 1: [17.27]%
Sparsity at the end of epoch 2: [32.73]%
Sparsity at the end of epoch 3: [45.23]%
Sparsity at the end of epoch 4: [50.0]%
Final Sparsity: [50.0]%
Sparsity in Conv2d 2: 50.00%
Sparsity in Conv2d 8: 50.00%
Sparsity in Conv2d 11: 50.00%
Sparsity in Conv2d 14: 50.00%
Sparsity in Conv2d 17: 50.00%
Sparsity in Conv2d 21: 50.00%
Sparsity in Conv2d 24: 50.00%
Sparsity in Conv2d 27: 50.00%
Sparsity in Conv2d 30: 50.00%
Sparsity in Conv2d 33: 50.00%
Sparsity in Conv2d 37: 50.00%
Sparsity in Conv2d 40: 50.00%
Sparsity in Conv2d 43: 50.00%
Sparsity in Conv2d 46: 50.00%
Sparsity in Conv2d 49: 50.00%
Sparsity in Conv2d 53: 50.00%
Sparsity in Conv2d 56: 50.00%
Sparsity in Conv2d 59: 50.00%
Sparsity in Conv2d 62: 50.00%
Sparsity in Conv2d 65: 50.00%


Surprisingly, our network that is composed of $50 \%$ of zeroes performs reasonnably well when compared to our plain and dense network.

The `SparsifyCallback` also accepts a list of sparsities, corresponding to each layer of `layer_type` to be pruned. Below, we show how to prune only the intermediate layers of ResNet-18.

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()

In [None]:
sparsities = [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]

In [None]:
sp_cb = SparsifyCallback(end_sparsity=sparsities, granularity='weight', method='local', criteria=large_final, sched_func=sched_cos)

In [None]:
learn.fit_one_cycle(5, cbs=sp_cb)

Pruning of weight until a sparsity of [0, 0, 0, 0, 0, 0, 50, 50, 50, 50, 50, 50, 50, 50, 0, 0, 0, 0, 0, 0]%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.663173,0.546195,0.790934,00:10
1,0.403149,0.366974,0.876184,00:10
2,0.2379,0.250653,0.904601,00:11
3,0.13862,0.214972,0.924899,00:11
4,0.067375,0.212548,0.920839,00:10


Sparsity at the end of epoch 0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 4.77, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 1: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 17.27, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 2: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 32.73, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 3: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 45.23, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity at the end of epoch 4: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Final Sparsity: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
Sparsity in Conv2d 2: 0.00%
Sparsity in Conv2d 8: 0.00%
Sparsity in Conv2d 11: 0.00%
Sparsity in Conv2d 14: 0.00%
Sparsity in Conv2d 17: 0.0

On top of that, the `SparsifyCallback`can also take many optionnal arguments: 

- `start_sparsity`: the sparsity that the schedule will use as a starting point (default to 0)
- `start_epoch`: the epoch at which the schedule will start pruning (default to 0)
- `end_epoch`: the epoch at which the schedule will stop pruning (default to the training epochs passed in `fit`)
- `lth`: whether training using the Lottery Ticket Hypothesis, i.e. reset the weights to their original value at each pruning step (more information in the Lottery Ticket Hypothesis section)
- `rewind_epoch`: the epoch used as a reference for the Lottery Ticket Hypothesis with Rewinding (default to 0)
- `reset_end`: whether you want to reset the weights to their original values after training (pruning masks are still applied)
- `model`: pass a model or a part of the model if you don't want to apply pruning on the whole model trained.
- `layer_type`: specify the type of layer that you want to apply pruning to (default to nn.Conv2d)

For example, we correctly pruned the convolution layers of our model, but we could imagine pruning the Linear Layers of even only the BatchNorm ones !