# Winning the Lottery with fastai
> How to find winning tickets in your neural network

- toc: true
- badges: false
- categories: [Deep Learning]
- comments: true
- image: images/pruning.png
- hide: true

<br>

<br>

## **Lottery Ticket Hypothesis**

The Lottery Ticket Hypothesis is a fascinating property of neural networks, identified by Frankle and Carbin in 2019. It states that a neural network contains a subnetwork that can be trained to an accuracy comparable to that of the whole network, in a comparable training time. The only condition is that the subnetwork starts from the same initial values it had as part of the whole network.
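
Stated a bit more formally, paraphrasing the notation of the original paper: let $f(x; \theta)$ be the dense network with initial parameters $\theta_0$, reaching accuracy $a$ after $j$ training iterations. The hypothesis says that there exists a binary mask $m \in \{0, 1\}^{|\theta|}$ such that training $f(x; m \odot \theta_0)$ reaches accuracy $a'$ after $j'$ iterations, with:

$$
j' \le j, \qquad a' \ge a, \qquad \|m\|_0 \ll |\theta|
$$

In words: the masked network trains at least as fast, ends up at least as accurate, and uses far fewer parameters.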

In practice, this subnetwork, called a "winning ticket", can be found by pruning the network, i.e. removing its least useful connections.

The steps to isolate this winning ticket are: 
1. Get a freshly initialized network
2. Train it to convergence
3. Prune the smallest weights, i.e. the weights with the lowest $l_1$-norm
4. Reset the remaining weights to their original values, i.e. their values at step 1
5. Repeat from step 2 until the desired sparsity is reached
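
To make the steps above concrete, here is a minimal sketch in plain PyTorch. It is not fasterai's implementation: the toy data, the small fully connected model, the helper names `train` and `prune`, and the sparsity levels are all assumptions chosen to keep the example short.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy data and model (hypothetical stand-ins for a real dataset and network)
x = torch.randn(256, 20)
y = (x[:, 0] > 0).long()
model = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 2))

# Step 1: remember the freshly initialized weights
init_state = copy.deepcopy(model.state_dict())
layers = [m for m in model if isinstance(m, nn.Linear)]
masks = [torch.ones_like(m.weight) for m in layers]

def apply_masks():
    # Force pruned weights to zero
    with torch.no_grad():
        for m, mask in zip(layers, masks):
            m.weight.mul_(mask)

def train(steps=100):
    # Step 2: train, keeping pruned weights at zero after every update
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(steps):
        opt.zero_grad()
        nn.functional.cross_entropy(model(x), y).backward()
        opt.step()
        apply_masks()

def prune(sparsity):
    # Step 3: global magnitude pruning (already pruned weights have magnitude 0)
    scores = [m.weight.detach().abs() * mask for m, mask in zip(layers, masks)]
    threshold = torch.quantile(torch.cat([s.flatten() for s in scores]), sparsity)
    return [(s > threshold).float() for s in scores]

for sparsity in (0.2, 0.4, 0.6):       # Step 5: repeat with increasing sparsity
    train()                             # Step 2
    masks = prune(sparsity)             # Step 3
    model.load_state_dict(init_state)   # Step 4: back to the initial values
    apply_masks()

train()  # final training of the candidate winning ticket

kept = sum(int(mask.sum()) for mask in masks)
total = sum(mask.numel() for mask in masks)
final_sparsity = 1 - kept / total
print(f"Final sparsity: {100 * final_sparsity:.2f}%")
```

Note that the mask is reapplied after every optimizer step, so pruned weights stay at zero. The callback shown later in this post achieves the same effect by masking both the weights and their gradients.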

![Alt Text](images/LTH/test2.gif)

Using fasterai, we already know how to prune a network. The only change here is that we have to keep track of initialization since we want to start from the initial conditions each time.

In the original paper, the idea was to iteratively prune the network, resetting the remaining weights to their initial value after each pruning step.

In [1]:
#hide
from fastai.vision.all import *
from fasterai.sparse.all import *

In [2]:
class Sparsifier():

    def __init__(self, model, granularity, method, criteria, layer_type=nn.Conv2d):
        store_attr()
        self._save_weights() # Save the original weights

    def prune_layer(self, m, sparsity, round_to=None):
        weight = self.criteria(m, self.granularity)
        mask = self._compute_mask(weight, sparsity, round_to)
        m.register_buffer("_mask", mask) # Put the mask into a buffer
        self._apply(m)

    def prune_model(self, sparsity, round_to=None):
        self.threshold=None
        for m in self.model.modules():
            if isinstance(m, self.layer_type): self.prune_layer(m, sparsity, round_to)

    def _apply(self, m):
        mask = getattr(m, "_mask", None)
        if mask is not None: m.weight.data.mul_(mask)
        if self.granularity == 'filter' and m.bias is not None:
            if mask is not None: m.bias.data.mul_(mask.squeeze()) # We want to prune the bias when pruning filters

    def _mask_grad(self):
        for m in self.model.modules():
            if isinstance(m, self.layer_type) and hasattr(m, '_mask'):
                mask = getattr(m, "_mask")
                if m.weight.grad is not None: m.weight.grad.mul_(mask)
                if self.granularity == 'filter' and m.bias is not None:
                    if m.bias.grad is not None: m.bias.grad.mul_(mask.squeeze())

    def _reset_weights(self): # Reset non-pruned weights
        for m in self.model.modules():
            if hasattr(m, 'weight'):
                init_weights = getattr(m, "_init_weights", m.weight)
                init_biases = getattr(m, "_init_biases", m.bias)
                with torch.no_grad():
                    if m.weight is not None: m.weight.copy_(init_weights)
                    if m.bias is not None: m.bias.copy_(init_biases)
                self._apply(m)
            if isinstance(m, nn.modules.batchnorm._BatchNorm): m.reset_parameters()

    def _save_weights(self):
        for m in self.model.modules():
            if hasattr(m, 'weight'):
                m.register_buffer("_init_weights", m.weight.detach().clone()) # Detached copy, so the buffer carries no autograd history
                b = getattr(m, 'bias', None)
                if b is not None: m.register_buffer("_init_biases", b.detach().clone())

    def _clean_buffers(self):
        for m in self.model.modules():
            if hasattr(m, 'weight'):
                if hasattr(m, '_mask'): del m._buffers["_mask"]
                if hasattr(m, '_init_weights'): del m._buffers["_init_weights"]
                if hasattr(m, '_init_biases'): del m._buffers["_init_biases"]

    def _compute_threshold(self, weight, sparsity):
        if self.method == 'global':
            global_weight = torch.cat([self.criteria(m, self.granularity).view(-1) for m in self.model.modules() if isinstance(m, self.layer_type)])
            if self.threshold is None: self.threshold = torch.quantile(global_weight, sparsity/100) # Compute the threshold globally (only once per model pruning)
            return self.threshold
        elif self.method == 'local':
            return torch.quantile(weight.view(-1), sparsity/100) # Compute the threshold locally
        else: raise NameError('Invalid Method')

    def _rounded_sparsity(self, n_to_prune, round_to):
        return max(round_to*torch.ceil(n_to_prune/round_to), round_to)

    def _compute_mask(self, weight, sparsity, round_to):
        threshold = self._compute_threshold(weight, sparsity)
        if round_to:
            n_to_keep = sum(weight.ge(threshold)).squeeze()
            threshold = torch.topk(weight.squeeze(), int(self._rounded_sparsity(n_to_keep, round_to)))[0].min()
        if threshold > weight.max(): threshold = weight.max() # Make sure we don't remove every weight of a given layer
        return weight.ge(threshold).to(dtype=weight.dtype)

    def print_sparsity(self):
        for k,m in enumerate(self.model.modules()):
            if isinstance(m, self.layer_type):
                print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")

In [3]:
class SparsifyCallback(Callback):

    def __init__(self, end_sparsity, granularity, method, criteria, sched_func, start_sparsity=0, start_epoch=0, end_epoch=None, lth=False, rewind_epoch=0, reset_end=False, model=None, round_to=None, layer_type=nn.Conv2d):
        store_attr()
        self.current_sparsity, self.previous_sparsity = 0, 0

        assert self.start_epoch>=self.rewind_epoch, 'You must rewind to an epoch before the start of the pruning process'

    def before_fit(self):
        print(f'Pruning of {self.granularity} until a sparsity of {self.end_sparsity}%')
        self.end_epoch = self.n_epoch if self.end_epoch is None else self.end_epoch
        assert self.end_epoch <= self.n_epoch, 'Your end_epoch must not exceed the total number of epochs'

        model = self.learn.model if self.model is None else self.model # Pass a model if you don't want the whole model to be pruned
        self.sparsifier = Sparsifier(model, self.granularity, self.method, self.criteria, self.layer_type)
        self.n_batches = math.floor(len(self.learn.dls.dataset)/self.learn.dls.bs)
        self.total_iters = self.end_epoch * self.n_batches
        self.start_iter = self.start_epoch * self.n_batches

    def before_epoch(self):
        if self.epoch == self.rewind_epoch:
            print(f'Saving Weights at epoch {self.epoch}')
            self.sparsifier._save_weights()

    def before_batch(self):
        if self.epoch>=self.start_epoch:
            if self.epoch < self.end_epoch: self._set_sparsity()
            self.sparsifier.prune_model(self.current_sparsity, self.round_to)

            if self.lth and self.current_sparsity!=self.previous_sparsity: # If sparsity has changed, the network has been pruned
                    print(f'Resetting Weights to their epoch {self.rewind_epoch} values')
                    self.sparsifier._reset_weights()
                    #self.sparsifier.model = resnet18(num_classes=10)

            self.previous_sparsity = self.current_sparsity

    def before_step(self):
        if self.epoch>=self.start_epoch:
            self.sparsifier._mask_grad()

    def after_epoch(self):
        print(f'Sparsity at the end of epoch {self.epoch}: {self.current_sparsity:.2f}%')

    def after_fit(self):
        print(f'Final Sparsity: {self.current_sparsity:.2f}')
        if self.reset_end:
            self.sparsifier._reset_weights()
        self.sparsifier._clean_buffers() # Remove buffers at the end of training
        #self.sparsifier.print_sparsity()

    def _set_sparsity(self):
        self.current_sparsity = self.sched_func(start=self.start_sparsity, end=self.end_sparsity, pos=(self.train_iter-self.start_iter)/(self.total_iters-self.start_iter))

In [4]:
#hide
def get_dls(size, pct_noise, bs, device):
    assert pct_noise in [0,5,50], '`pct_noise` must be 0,5 or 50.'
    path = URLs.IMAGENETTE_320
    source = untar_data(path)
    blocks=(ImageBlock, CategoryBlock)
    tfms = [RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)]
    batch_tfms = [Normalize.from_stats(*imagenet_stats)]
    
    csv_file = 'noisy_imagenette.csv'
    inp = pd.read_csv(source/csv_file)
    dblock = DataBlock(blocks=blocks,
               splitter=ColSplitter(),
               get_x=ColReader('path', pref=source), 
               get_y=ColReader(f'noisy_labels_{pct_noise}'),
               item_tfms=tfms,
               batch_tfms=batch_tfms)
    
    return dblock.dataloaders(inp, path=source, bs=bs, device=device)

In [5]:
#hide
def count_parameters(model):
    num_params = sum(p.numel() for p in model.parameters())
    print(f'Total parameters : {num_params:,}' )

In [6]:
#hide
def print_sparsity(model):
    for k,m in enumerate(model.modules()):
        if isinstance(m, nn.Conv2d):
            print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")

In [7]:
#hide
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available(): torch.cuda.set_device(0) # Only select a GPU when one is present

In [8]:
#hide
dls = get_dls(128, 0, 64, device=device)

Let's first get our baseline:

In [12]:
learn = Learner(dls, resnet18(num_classes=10), metrics=accuracy)

In [13]:
initial_weights = deepcopy(learn.model.state_dict())

In [14]:
learn.fit(5)

epoch,train_loss,valid_loss,accuracy,time
0,1.536754,1.709699,0.481529,00:11
1,1.254531,1.314451,0.578089,00:11
2,1.116412,1.168404,0.634904,00:11
3,1.023481,1.156428,0.633376,00:11
4,0.946494,0.998459,0.677962,00:11


In [15]:
learn = Learner(dls, resnet18(num_classes=10), metrics=accuracy)
learn.model.load_state_dict(initial_weights)

<All keys matched successfully>

In fasterai, this is done by passing `lth=True` to the `SparsifyCallback`:

In [16]:
sp_cb = SparsifyCallback(50, 'weight', 'global', large_final, iterative, start_epoch=5, lth=True)

In [17]:
learn.fit(20, cbs=sp_cb)

Pruning of weight until a sparsity of 50%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,1.54152,1.568734,0.501911,00:11
1,1.258532,1.62822,0.50879,00:11
2,1.111838,1.29268,0.596688,00:11
3,1.024304,1.385538,0.581146,00:11
4,0.930883,1.041547,0.672102,00:11
5,1.33093,1.39527,0.52051,00:20
6,1.141437,1.135004,0.620637,00:20
7,1.040761,1.267395,0.581656,00:20
8,0.952175,1.272328,0.59465,00:20
9,0.909871,1.207141,0.629554,00:20


Sparsity at the end of epoch 0: 0.00%
Sparsity at the end of epoch 1: 0.00%
Sparsity at the end of epoch 2: 0.00%
Sparsity at the end of epoch 3: 0.00%
Sparsity at the end of epoch 4: 0.00%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 5: 16.67%
Sparsity at the end of epoch 6: 16.67%
Sparsity at the end of epoch 7: 16.67%
Sparsity at the end of epoch 8: 16.67%
Sparsity at the end of epoch 9: 16.67%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 10: 33.33%
Sparsity at the end of epoch 11: 33.33%
Sparsity at the end of epoch 12: 33.33%
Sparsity at the end of epoch 13: 33.33%
Sparsity at the end of epoch 14: 33.33%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 15: 50.00%
Sparsity at the end of epoch 16: 50.00%
Sparsity at the end of epoch 17: 50.00%
Sparsity at the end of epoch 18: 50.00%
Sparsity at the end of epoch 19: 50.00%
Final Sparsity: 50.00
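
The sparsity values in the log above (16.67%, 33.33%, then 50%) are what a schedule with three equal steps would produce over the 15 pruning epochs. The sketch below reproduces those numbers. The function `stepwise_sparsity` and its `n_steps=3` default are assumptions made to match the log, not fasterai's actual `iterative` implementation.

```python
import math

def stepwise_sparsity(start, end, pos, n_steps=3):
    """Hypothetical step schedule: sparsity rises in `n_steps` equal jumps as `pos` goes from 0 to 1."""
    return start + (end - start) / n_steps * math.ceil(pos * n_steps)

n_epochs, start_epoch = 20, 5
per_epoch = []
for epoch in range(start_epoch, n_epochs):
    # Position reached at the end of this epoch, as a fraction of the pruning phase
    pos = (epoch + 1 - start_epoch) / (n_epochs - start_epoch)
    per_epoch.append(round(stepwise_sparsity(0, 50, pos), 2))

print(per_epoch)
```

Each jump in sparsity is also the moment where the callback resets the remaining weights, which explains the three "Resetting Weights" messages in the log.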


In [18]:
learn = Learner(dls, resnet18(num_classes=10), metrics=accuracy)
learn.model.load_state_dict(initial_weights)

<All keys matched successfully>

In [19]:
sp_cb = SparsifyCallback(50, 'weight', 'global', large_final, iterative, start_epoch=5, lth=True, rewind_epoch=1)

In [20]:
learn.fit(20, cbs=sp_cb)

Pruning of weight until a sparsity of 50%


epoch,train_loss,valid_loss,accuracy,time
0,1.529935,1.430763,0.522548,00:11
1,1.268891,1.251196,0.603822,00:11
2,1.141558,1.176961,0.626497,00:11
3,1.013069,1.312681,0.607134,00:11
4,0.933651,0.914163,0.695796,00:11
5,1.183302,1.339694,0.553121,00:20
6,1.027278,1.148169,0.634904,00:20
7,0.919856,1.031522,0.672866,00:20
8,0.890848,0.910739,0.713885,00:20
9,0.824205,0.932853,0.69758,00:20


Sparsity at the end of epoch 0: 0.00%
Saving Weights at epoch 1
Sparsity at the end of epoch 1: 0.00%
Sparsity at the end of epoch 2: 0.00%
Sparsity at the end of epoch 3: 0.00%
Sparsity at the end of epoch 4: 0.00%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 5: 16.67%
Sparsity at the end of epoch 6: 16.67%
Sparsity at the end of epoch 7: 16.67%
Sparsity at the end of epoch 8: 16.67%
Sparsity at the end of epoch 9: 16.67%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 10: 33.33%
Sparsity at the end of epoch 11: 33.33%
Sparsity at the end of epoch 12: 33.33%
Sparsity at the end of epoch 13: 33.33%
Sparsity at the end of epoch 14: 33.33%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 15: 50.00%
Sparsity at the end of epoch 16: 50.00%
Sparsity at the end of epoch 17: 50.00%
Sparsity at the end of epoch 18: 50.00%
Sparsity at the end of epoch 19: 50.00%
Final Sparsity: 50.00


In [18]:
print_sparsity(learn.model)

Sparsity in Conv2d 1: 41.28%
Sparsity in Conv2d 7: 25.14%
Sparsity in Conv2d 10: 25.66%
Sparsity in Conv2d 13: 25.76%
Sparsity in Conv2d 16: 25.46%
Sparsity in Conv2d 20: 33.97%
Sparsity in Conv2d 23: 34.10%
Sparsity in Conv2d 26: 12.82%
Sparsity in Conv2d 29: 34.61%
Sparsity in Conv2d 32: 33.99%
Sparsity in Conv2d 36: 43.95%
Sparsity in Conv2d 39: 44.27%
Sparsity in Conv2d 42: 18.27%
Sparsity in Conv2d 45: 43.93%
Sparsity in Conv2d 48: 43.33%
Sparsity in Conv2d 52: 54.12%
Sparsity in Conv2d 55: 54.15%
Sparsity in Conv2d 58: 24.68%
Sparsity in Conv2d 61: 54.24%
Sparsity in Conv2d 64: 51.93%
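
The per-layer figures above vary a lot even though the overall target was 50%. This is expected with the `'global'` method, which computes a single threshold over all layers, whereas `'local'` computes one threshold per layer. Here is a small sketch of the difference on two made-up weight tensors (the tensors and their scales are assumptions for illustration only):

```python
import torch

torch.manual_seed(0)

# Two hypothetical layers whose weights live on different scales
small = torch.randn(1000) * 0.1
large = torch.randn(1000) * 1.0
sparsity = 0.5

def pruned_fraction(w, threshold):
    # Weights whose magnitude is >= threshold are kept, the rest are pruned
    return 1 - w.abs().ge(threshold).float().mean().item()

# Local: one threshold per layer, so every layer loses the same fraction
local = [pruned_fraction(w, torch.quantile(w.abs(), sparsity)) for w in (small, large)]

# Global: a single threshold computed over all layers together
global_threshold = torch.quantile(torch.cat([small, large]).abs(), sparsity)
global_ = [pruned_fraction(w, global_threshold) for w in (small, large)]

print("local :", local)    # same sparsity in both layers
print("global:", global_)  # the small-scale layer is pruned much more heavily
```

With a global threshold only the overall sparsity is fixed, and layers with smaller weights absorb more of the pruning.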


---

<br>

<br>

**That's all! Thank you for reading, I hope you'll like FasterAI. I do not claim that it is perfect, and you'll probably find a lot of bugs. If you do, please tell me so I can try to fix them 😌**

<br>

---

<br>

<p style="font-size: 15px"><i>If you notice any mistake or possible improvement, please contact me! If you found this post useful, please consider citing it as:</i></p>

```
@article{hubens2020fasterai,
  title   = "Winning the Lottery with fastai",
  author  = "Hubens, Nathan",
  journal = "nathanhubens.github.io",
  year    = "2020",
  url     = "https://nathanhubens.github.io/posts/deep%20learning/2020/08/17/FasterAI.html"
}
```
