# Network Sparsifying

The sparsifier lets you sparsify your network in several ways:
- Local or Global: the parameters to remove are selected either per layer (e.g. sparsity=50% removes half of the parameters in each layer) or across the whole network (e.g. sparsity=50% removes half of the network's parameters, regardless of which layer they come from).
- Weight, Kernel, Filter: the granularity of the sparsification.
- Scheduling Function: the schedule applied to the removal of parameters. The schedules supported by default can be found [here](https://docs.fast.ai/callback.html#Annealing-functions).
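
To make the local/global distinction concrete, here is a minimal sketch of magnitude-based masking at the weight granularity. The helper `magnitude_masks` is hypothetical and written only for illustration; it is an assumption about how such a selection could work, not fasterai's actual implementation.

```python
import torch

def _threshold(flat_abs, frac):
    # Value of the k-th smallest magnitude, where k = frac * number of weights.
    k = int(frac * flat_abs.numel())
    if k == 0:
        return flat_abs.new_tensor(-1.0)  # nothing is pruned
    return flat_abs.kthvalue(k).values

def magnitude_masks(weights, sparsity, method="local"):
    """Return one 0/1 mask per tensor, zeroing the `sparsity`% smallest |w|.

    Hypothetical helper for illustration; not fasterai's implementation.
    """
    frac = sparsity / 100.0
    if method == "global":
        # One threshold shared by every layer of the network.
        thresh = _threshold(torch.cat([w.abs().flatten() for w in weights]), frac)
        return [(w.abs() > thresh).float() for w in weights]
    # Local: each layer gets its own threshold.
    return [(w.abs() > _threshold(w.abs().flatten(), frac)).float() for w in weights]

layer_a = torch.tensor([0.1, 0.2, 0.3, 0.4])  # a layer with small weights
layer_b = torch.tensor([1.0, 2.0, 3.0, 4.0])  # a layer with large weights

local_masks = magnitude_masks([layer_a, layer_b], sparsity=50, method="local")
global_masks = magnitude_masks([layer_a, layer_b], sparsity=50, method="global")
print(local_masks)
print(global_masks)
```

With `method="local"` each toy layer loses half of its weights, whereas with `method="global"` the layer with small weights is removed entirely and the other one is left untouched. Ties at the threshold are all pruned in this sketch, so the achieved sparsity can slightly exceed the target.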

In [1]:
from fastai.vision import *
from fastai.callbacks import *

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
import warnings
warnings.filterwarnings("ignore")

In [4]:
path = untar_data(URLs.IMAGENETTE_160)

In [5]:
data = (ImageList.from_folder(path)
                .split_by_folder(train='train', valid='val')
                .label_from_folder()
                .transform(get_transforms(), size=64)
                .databunch(bs=64)
                .normalize(imagenet_stats))

In [6]:
import sys
sys.path.append('../')

from fasterai.sparsifier_test import *

## VGG16

In [7]:
learn = Learner(data, models.vgg16_bn(), metrics=[accuracy])

In [8]:
learn.fit_one_cycle(3, 1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,2.231394,5.494659,0.204841,00:15
1,1.858624,1.850971,0.338344,00:14
2,1.463806,1.214066,0.608662,00:15


You can prune after training by using the `prune` method. Don't expect great results this way, though: even if the removed parameters are less important than the others, they still contribute to the network's output, and removing them in one shot gives the network no chance to recover from their loss.

In [9]:
# Example: remove 50% of the least important parameters
sparsifier = Sparsifier(learn.model, granularity='weight', method='local', criteria='l1')

In [11]:
sparsifier.prune(sparsity=50)

Generally, pruning your network this way requires retraining your model so that it can recover from the removal of the parameters.

In [12]:
learn.validate()

[1.7763458, tensor(0.4104)]

A common way to work is to perform several iterations of **pruning->fine-tuning**. This process can be long and sensitive: at each iteration you have to choose how many parameters to remove, and a bad choice can leave you with a completely broken network that has no chance to recover.

The goal of the `Sparsifier` is instead to include pruning **into** the training process, which greatly reduces the time the whole process takes.

## Local Weight Sparsifying

Let's remove 50% of the parameters of VGG16 and see how the training behaves.

In [13]:
learn = Learner(data, models.vgg16_bn(), metrics=[accuracy])

In [14]:
learn.fit_one_cycle(3, 1e-3, callbacks=[SparsifyCallback(learn, sparsity=50, granularity='weight', method='local', criteria='l1', sched_func=annealing_cos)])

Pruning of weight until a sparsity of 50%


epoch,train_loss,valid_loss,accuracy,time
0,2.241648,2.585599,0.181656,00:21
1,1.89552,2.031944,0.326369,00:22
2,1.55008,1.363338,0.551083,00:21


Saving Weights at epoch 0
Sparsity at the end of epoch 0: 12.50%
Sparsity at the end of epoch 1: 37.50%
Sparsity at the end of epoch 2: 50.00%
Final Sparsity: 50.00
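
The sparsity values logged at the end of each epoch are consistent with a cosine schedule going from 0% to the target. The sketch below re-implements the cosine annealing formula in plain Python (same form as fastai's `annealing_cos`) and evaluates it at the end of each epoch. The helper `sparsity_per_epoch` is hypothetical, and evaluating the schedule only at epoch boundaries is an assumption made for illustration; the callback may update the sparsity more often.

```python
import math

def annealing_cos(start, end, pct):
    """Cosine interpolation from `start` to `end` as `pct` goes from 0 to 1."""
    cos_out = math.cos(math.pi * pct) + 1
    return end + (start - end) / 2 * cos_out

def sparsity_per_epoch(target, n_epochs):
    # Assumption: we only read the schedule at the end of each epoch.
    return [annealing_cos(0.0, target, (e + 1) / n_epochs) for e in range(n_epochs)]

schedule = sparsity_per_epoch(target=50, n_epochs=3)
print([f"{s:.2f}" for s in schedule])  # ['12.50', '37.50', '50.00']
```

The cosine shape prunes slowly at the start and at the end of training, and fastest in the middle.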


Our network now uses only half of its parameters and still reaches an accuracy close to that of the dense model trained above (55.1% versus 60.9% after 3 epochs)!

Let's double check that we correctly removed the parameters:

In [15]:
for k,m in enumerate(learn.model.modules()):
    if isinstance(m, nn.Conv2d):
        print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")

Sparsity in Conv2d 2: 50.00%
Sparsity in Conv2d 5: 50.00%
Sparsity in Conv2d 9: 50.00%
Sparsity in Conv2d 12: 50.00%
Sparsity in Conv2d 16: 50.00%
Sparsity in Conv2d 19: 50.00%
Sparsity in Conv2d 22: 50.00%
Sparsity in Conv2d 26: 50.00%
Sparsity in Conv2d 29: 50.00%
Sparsity in Conv2d 32: 50.00%
Sparsity in Conv2d 36: 50.00%
Sparsity in Conv2d 39: 50.00%
Sparsity in Conv2d 42: 50.00%


And if we look more closely at a single convolution filter, we expect roughly half of its values to be zero:

In [16]:
print(learn.model.features[0].weight[0].data)

tensor([[[-0.0887, -0.0848, -0.0687],
         [ 0.0000, -0.0705,  0.0000],
         [ 0.0427,  0.0000, -0.0550]],

        [[-0.0509, -0.0000, -0.0000],
         [-0.0804, -0.0552, -0.0000],
         [ 0.1522,  0.0000,  0.0000]],

        [[-0.0000, -0.0000,  0.0827],
         [ 0.0540, -0.0000,  0.0697],
         [-0.0000,  0.0000,  0.0566]]], device='cuda:0')


## Global Filter Pruning

Let's now try another way to prune our network: this time we will remove the 20% of filters that are least important globally. We will also try another architecture, ResNet18.
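
As a rough illustration of what "globally least important filters" can mean, the sketch below scores every filter of every convolution, ranks all of them together, and zeroes the lowest-scoring ones. The helper `global_filter_masks` is hypothetical, and scoring a filter by its mean absolute weight (so that filters of different sizes are comparable) is an assumption; fasterai's `l1` criterion may be computed differently.

```python
import torch
import torch.nn as nn

def global_filter_masks(convs, sparsity):
    """Return one 0/1 mask per layer, with one entry per output filter.

    Assumption: a filter's score is its mean absolute weight.
    """
    scores = [c.weight.detach().abs().mean(dim=(1, 2, 3)) for c in convs]
    all_scores = torch.cat(scores)  # every filter of the network, ranked together
    k = int(sparsity / 100.0 * all_scores.numel())
    if k == 0:
        return [torch.ones_like(s) for s in scores]
    thresh = all_scores.kthvalue(k).values  # k-th smallest score overall
    return [(s > thresh).float() for s in scores]

torch.manual_seed(0)
convs = [nn.Conv2d(3, 8, 3), nn.Conv2d(8, 16, 3)]  # toy two-layer example
masks = global_filter_masks(convs, sparsity=25)

with torch.no_grad():
    for conv, mask in zip(convs, masks):
        conv.weight.mul_(mask.view(-1, 1, 1, 1))  # zero whole filters

removed = sum(int((m == 0).sum()) for m in masks)
print(f"{removed} of {sum(m.numel() for m in masks)} filters zeroed")
```

Because the threshold is shared, nothing forces each layer to lose the same share of its filters.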

In [17]:
learn = Learner(data, models.resnet18(), metrics=[accuracy])

In [18]:
learn.fit_one_cycle(3, 1e-3, callbacks=[SparsifyCallback(learn, sparsity=20, granularity='filter', method='global', criteria='l1', sched_func=annealing_cos)])

Pruning of filter until a sparsity of 20%


epoch,train_loss,valid_loss,accuracy,time
0,1.894946,1.99504,0.424713,00:17
1,1.277624,1.340366,0.572994,00:17
2,0.976437,0.914715,0.707261,00:18


Saving Weights at epoch 0
Sparsity at the end of epoch 0: 5.00%
Sparsity at the end of epoch 1: 15.00%
Sparsity at the end of epoch 2: 20.00%
Final Sparsity: 20.00


This time, we expect to have different sparsities across layers. This can give us a good indication of how deep in the network the important features are extracted.

In [19]:
for k,m in enumerate(learn.model.modules()):
    if isinstance(m, nn.Conv2d):
        print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")

Sparsity in Conv2d 1: 0.00%
Sparsity in Conv2d 7: 0.00%
Sparsity in Conv2d 10: 0.00%
Sparsity in Conv2d 13: 0.00%
Sparsity in Conv2d 16: 0.00%
Sparsity in Conv2d 20: 0.00%
Sparsity in Conv2d 23: 0.00%
Sparsity in Conv2d 26: 0.00%
Sparsity in Conv2d 29: 0.00%
Sparsity in Conv2d 32: 0.00%
Sparsity in Conv2d 36: 0.00%
Sparsity in Conv2d 39: 0.00%
Sparsity in Conv2d 42: 0.00%
Sparsity in Conv2d 45: 0.00%
Sparsity in Conv2d 48: 0.00%
Sparsity in Conv2d 52: 30.08%
Sparsity in Conv2d 55: 27.73%
Sparsity in Conv2d 58: 0.00%
Sparsity in Conv2d 61: 69.34%
Sparsity in Conv2d 64: 60.35%
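
The percentages above count zero weights. To check that pruning really happened at the filter granularity, one can also count the filters whose weights are all zero. The helper `zeroed_filters` below is a hypothetical sketch, demonstrated on a toy layer rather than on the trained ResNet18.

```python
import torch
import torch.nn as nn

def zeroed_filters(conv):
    """Return (number of all-zero filters, total number of filters)."""
    w = conv.weight.detach()
    per_filter_sum = w.abs().flatten(1).sum(dim=1)  # one value per output filter
    return int((per_filter_sum == 0).sum()), w.shape[0]

# Toy layer with two filters zeroed by hand.
conv = nn.Conv2d(3, 4, 3)
with torch.no_grad():
    conv.weight[1].zero_()
    conv.weight[3].zero_()

n_zero, n_total = zeroed_filters(conv)
print(f"{n_zero}/{n_total} filters are entirely zero")
```

If pruning was done filter by filter, the ratio returned by this helper should match the weight-level sparsity printed for that layer.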
