# Lottery Ticket Hypothesis

> How to find winning tickets with fastai

In [None]:
#all_slow

## The Lottery Ticket Hypothesis

The [Lottery Ticket Hypothesis](https://arxiv.org/abs/1803.03635) is an intriguing discovery by Frankle & Carbin, published at ICLR 2019. It states that:

> A randomly-initialized, dense neural network contains a subnetwork that is initialised such that — when trained in isolation — it can match the test accuracy of the original network after training for at most the same number of iterations.

In other words, once we have found that subnetwork, every other parameter in the network is unnecessary.

The authors propose to find such a subnetwork as follows:

1. Initialize the neural network
2. Train it to convergence
3. Prune the smallest-magnitude weights by creating a mask $m$
4. Reset the remaining weights to their original values, i.e. at iteration 0.
5. Repeat from step 2 until reaching the desired level of sparsity.
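The procedure above can be sketched in plain Python on a toy problem. This is only an illustration under stated assumptions, not fasterai's implementation: the weights are a flat list, "train to convergence" is replaced by a hypothetical `toy_train` (gradient descent on a quadratic loss that pulls the weights toward a fixed `target`), and the helper names `magnitude_mask` and `find_ticket` are made up for this example.

```python
def toy_train(weights, mask, target, steps=50, lr=0.1):
    """Hypothetical stand-in for 'train to convergence': gradient descent on
    0.5 * sum((w - target)**2), with pruned weights held at zero."""
    w = [wi * m for wi, m in zip(weights, mask)]
    for _ in range(steps):
        w = [(wi - lr * (wi - ti)) * m for wi, ti, m in zip(w, target, mask)]
    return w


def magnitude_mask(weights, mask, sparsity):
    """Keep the largest-magnitude weights among those still active, so that a
    `sparsity` fraction of all weights ends up pruned (mask value 0)."""
    n_keep = round(len(weights) * (1 - sparsity))
    active = [i for i, m in enumerate(mask) if m]
    keep = set(sorted(active, key=lambda i: abs(weights[i]), reverse=True)[:n_keep])
    return [1 if i in keep else 0 for i in range(len(weights))]


def find_ticket(init, target, final_sparsity=0.5, rounds=3):
    """Iterative magnitude pruning, resetting survivors to their initial values."""
    mask = [1] * len(init)              # step 1: `init` holds the initial weights
    weights = list(init)
    for r in range(1, rounds + 1):
        trained = toy_train(weights, mask, target)                         # step 2
        mask = magnitude_mask(trained, mask, final_sparsity * r / rounds)  # step 3
        weights = [w0 * m for w0, m in zip(init, mask)]                    # step 4
    return mask, weights                # step 5: stop once final_sparsity is reached


init = [0.5, -0.2, 0.1, 0.9, -0.7, 0.3]
target = [2.0, -0.1, 0.05, -1.5, 0.3, 1.0]
mask, ticket = find_ticket(init, target)
print(mask)  # [1, 0, 0, 1, 0, 1]
```

Each round prunes a growing fraction of the weights (1/6, 1/3, then 1/2 here), and the resulting ticket keeps the initial values 0.5, 0.9 and 0.3 at the surviving positions; it would then be trained from there with the mask held fixed.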

In [None]:
#hide
from fastai.vision.all import *

import torch
import torch.nn as nn

%config InlineBackend.figure_format = 'retina'

In [None]:
from fasterai.sparse.all import *

In [None]:
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64), device=device)

What we want to show empirically is that a neural network A contains a subnetwork B able to reach an accuracy $a_B \geq a_A$ in a training time $t_B \leq t_A$.

Let's get the baseline for network A:

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)

Let's save a copy of the original weights, so that we can restart from them later:

In [None]:
from copy import deepcopy
initial_weights = deepcopy(learn.model.state_dict())  # state_dict() tensors share storage with the live parameters

In [None]:
learn.fit(6)

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.668828 | 0.434854 | 0.804465 | 00:09 |
| 1 | 0.479476 | 0.370044 | 0.847091 | 00:10 |
| 2 | 0.390932 | 0.329315 | 0.854533 | 00:10 |
| 3 | 0.324513 | 0.288489 | 0.868742 | 00:09 |
| 4 | 0.286272 | 0.259737 | 0.880244 | 00:10 |
| 5 | 0.246245 | 0.250105 | 0.887686 | 00:10 |


To find the lottery ticket, we will perform iterative pruning, but at each pruning step we will reset the remaining weights to their original values (i.e. before training).

We restart from the same initialization as the baseline, so that a better result cannot simply come from a luckier random initialization.

In [None]:
learn.model.load_state_dict(initial_weights)

<All keys matched successfully>

We can pass the parameter `lth=True` so that, after each pruning step, the remaining weights are reset to their original values, i.e. step 4 of the procedure above. To empirically validate the LTH, we then need to retrain the found "lottery ticket" after the pruning phase. This is controlled by the `end_epoch` parameter, which sets the epoch at which the pruning process stops. The fine-tuning phase thus lasts `total_epoch - end_epoch` epochs, where `total_epoch` is the number of epochs passed to `fit`.

In [None]:
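# 50% target sparsity, 'weight' granularity, 'local' (per-layer) pruning, large_final criterion, iterative schedule
sp_cb = SparsifyCallback(50, 'weight', 'local', large_final, iterative, start_epoch=1, end_epoch=4, lth=True)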

So here, we first pretrain our model for 1 epoch (`start_epoch=1`), perform the LTH process for 3 epochs (`end_epoch - start_epoch`), then train the found lottery ticket for 6 epochs (`total_epoch - end_epoch`). If the final accuracy is higher than the baseline, we have found a "winning lottery ticket".
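The epoch bookkeeping can be checked with a small, hypothetical helper `sparsity_per_epoch`. It is not fasterai's `iterative` schedule; it simply assumes the target sparsity is reached in equal steps, one per epoch, between `start_epoch` and `end_epoch` (defaulting to the last epoch), which reproduces the sparsity values the callbacks log in this notebook.

```python
def sparsity_per_epoch(total_epochs, target, start_epoch=0, end_epoch=None):
    """Sparsity (in %) at the end of each epoch, assuming the target is reached
    in equal steps, one per epoch, between start_epoch and end_epoch."""
    if end_epoch is None:
        end_epoch = total_epochs
    n_steps = end_epoch - start_epoch
    schedule = []
    for epoch in range(total_epochs):
        # number of pruning steps completed by the end of this epoch
        done = min(max(epoch - start_epoch + 1, 0), n_steps)
        schedule.append(target * done / n_steps)
    return schedule


total, start, end = 10, 1, 4
print("pretraining epochs:", start)            # pretraining epochs: 1
print("pruning epochs:", end - start)          # pruning epochs: 3
print("ticket training epochs:", total - end)  # ticket training epochs: 6
print([round(s, 2) for s in sparsity_per_epoch(total, 50, start, end)])
# [0.0, 16.67, 33.33, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0]
```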

In [None]:
learn.fit(10, cbs=sp_cb)

Pruning of weight until a sparsity of 50%


| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.627306 | 0.399069 | 0.826116 | 00:09 |
| 1 | 0.527028 | 0.367685 | 0.832882 | 00:11 |
| 2 | 0.475066 | 0.343404 | 0.847767 | 00:11 |
| 3 | 0.423313 | 0.332563 | 0.859269 | 00:11 |
| 4 | 0.342591 | 0.29417 | 0.867388 | 00:11 |
| 5 | 0.289098 | 0.279289 | 0.875507 | 00:11 |
| 6 | 0.254679 | 0.27023 | 0.883627 | 00:11 |
| 7 | 0.237374 | 0.244046 | 0.895805 | 00:11 |
| 8 | 0.219428 | 0.242725 | 0.897835 | 00:11 |
| 9 | 0.200471 | 0.240898 | 0.897158 | 00:11 |


Saving Weights at epoch 0
Sparsity at the end of epoch 0: 0.00%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 1: 16.67%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 2: 33.33%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 3: 50.00%
Sparsity at the end of epoch 4: 50.00%
Sparsity at the end of epoch 5: 50.00%
Sparsity at the end of epoch 6: 50.00%
Sparsity at the end of epoch 7: 50.00%
Sparsity at the end of epoch 8: 50.00%
Sparsity at the end of epoch 9: 50.00%
Final Sparsity: 50.00


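On this run, we indeed obtain a network B, pruned to 50% sparsity, whose final accuracy ($a_B \approx 0.897$) exceeds that of the baseline ($a_A \approx 0.888$) for the same number of training epochs (6 after the last reset). Note that finding the ticket itself took 4 additional epochs.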

## Lottery Ticket Hypothesis with Rewinding

In some cases, the LTH fails for deeper networks. The authors then proposed a [solution](https://arxiv.org/pdf/1903.01611.pdf): rewind the weights to their values at an early iteration of training instead of their initialization values.
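Rewinding only changes which snapshot the surviving weights are reset to. In a toy setting (hypothetical `toy_train`, `magnitude_mask` and `find_ticket_with_rewind` helpers, with a quadratic loss standing in for real training), the sketch below trains the dense weights for a few steps, stores that snapshot, and resets to it after each pruning round instead of to the initialization. Here `rewind_steps` loosely plays the role of a rewind epoch, counted in toy steps.

```python
def toy_train(weights, mask, target, steps=50, lr=0.1):
    """Hypothetical stand-in for training: gradient descent on
    0.5 * sum((w - target)**2), with pruned weights held at zero."""
    w = [wi * m for wi, m in zip(weights, mask)]
    for _ in range(steps):
        w = [(wi - lr * (wi - ti)) * m for wi, ti, m in zip(w, target, mask)]
    return w


def magnitude_mask(weights, mask, sparsity):
    """Keep the largest-magnitude active weights; prune a `sparsity` fraction overall."""
    n_keep = round(len(weights) * (1 - sparsity))
    active = [i for i, m in enumerate(mask) if m]
    keep = set(sorted(active, key=lambda i: abs(weights[i]), reverse=True)[:n_keep])
    return [1 if i in keep else 0 for i in range(len(weights))]


def find_ticket_with_rewind(init, target, rewind_steps=5, final_sparsity=0.5, rounds=3):
    mask = [1] * len(init)
    # Train the dense network briefly and keep this snapshot: it, not `init`,
    # is what the surviving weights are reset to after each pruning round.
    snapshot = toy_train(init, mask, target, steps=rewind_steps)
    weights = list(snapshot)
    for r in range(1, rounds + 1):
        trained = toy_train(weights, mask, target)
        mask = magnitude_mask(trained, mask, final_sparsity * r / rounds)
        weights = [ws * m for ws, m in zip(snapshot, mask)]
    return mask, weights, snapshot


init = [0.5, -0.2, 0.1, 0.9, -0.7, 0.3]
target = [2.0, -0.1, 0.05, -1.5, 0.3, 1.0]
mask, ticket, snapshot = find_ticket_with_rewind(init, target)
print(mask)  # [1, 0, 0, 1, 0, 1]
```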

In [None]:
learn.model.load_state_dict(initial_weights)

<All keys matched successfully>

In fasterai, this is done by passing the `rewind_epoch` parameter: the weights are saved at that epoch, and each subsequent reset restores them to these saved values instead of the initial ones.

In [None]:
# rewind_epoch=1: reset the remaining weights to their epoch-1 values instead of their initial ones
sp_cb = SparsifyCallback(50, 'weight', 'local', large_final, iterative, start_epoch=1, end_epoch=4, lth=True, rewind_epoch=1)

In [None]:
learn.fit(10, cbs=sp_cb)

Pruning of weight until a sparsity of 50%


| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.637804 | 0.389029 | 0.822057 | 00:09 |
| 1 | 0.414081 | 0.324026 | 0.859946 | 00:11 |
| 2 | 0.354582 | 0.309311 | 0.866035 | 00:11 |
| 3 | 0.329014 | 0.273554 | 0.876861 | 00:11 |
| 4 | 0.291184 | 0.25973 | 0.885656 | 00:11 |
| 5 | 0.259071 | 0.251663 | 0.895129 | 00:11 |
| 6 | 0.233578 | 0.249048 | 0.901894 | 00:11 |
| 7 | 0.209886 | 0.244723 | 0.899865 | 00:11 |
| 8 | 0.196974 | 0.238308 | 0.898512 | 00:11 |
| 9 | 0.19281 | 0.222618 | 0.905954 | 00:11 |


Sparsity at the end of epoch 0: 0.00%
Saving Weights at epoch 1
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 1: 16.67%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 2: 33.33%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 3: 50.00%
Sparsity at the end of epoch 4: 50.00%
Sparsity at the end of epoch 5: 50.00%
Sparsity at the end of epoch 6: 50.00%
Sparsity at the end of epoch 7: 50.00%
Sparsity at the end of epoch 8: 50.00%
Sparsity at the end of epoch 9: 50.00%
Final Sparsity: 50.00


## Super-Masks

Researchers from Uber AI [investigated](https://arxiv.org/pdf/1905.01067.pdf) the LTH and found the existence of what they call "Super-Masks", i.e. masks that, when applied to an untrained neural network, yield better-than-random accuracy.

In [None]:
learn.model.load_state_dict(initial_weights)

<All keys matched successfully>

To find supermasks, the authors run the LTH procedure, then apply the resulting mask to the original, untrained network. In fasterai, you can pass the parameter `reset_end=True`, which resets the weights to their original values at the end of training while keeping the pruned weights (i.e. the mask) unchanged.
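The mask-then-reset idea fits in a few lines of plain Python. In this sketch, the hypothetical helper `large_final_mask` keeps the weights with the largest magnitude after training (the idea behind the `large_final` criterion name), and `apply_supermask` puts that mask on the untrained initial weights. The weight values are made up for illustration; this is not fasterai's code.

```python
def large_final_mask(trained, sparsity):
    """1 for the weights with the largest final magnitude, 0 for the rest."""
    n_keep = round(len(trained) * (1 - sparsity))
    order = sorted(range(len(trained)), key=lambda i: abs(trained[i]), reverse=True)
    keep = set(order[:n_keep])
    return [1 if i in keep else 0 for i in range(len(trained))]


def apply_supermask(init, mask):
    """Untrained weights with the mask applied; no training happens here."""
    return [w0 if m else 0.0 for w0, m in zip(init, mask)]


init = [0.5, -0.2, 0.1, 0.9, -0.7, 0.3]       # weights at initialization
trained = [2.0, -0.1, 0.05, -1.5, 0.3, 1.0]   # the same weights after (toy) training
mask = large_final_mask(trained, sparsity=0.5)
print(mask)                         # [1, 0, 0, 1, 0, 1]
print(apply_supermask(init, mask))  # [0.5, 0.0, 0.0, 0.9, 0.0, 0.3]
```

The mask is chosen using information from training, but the values it leaves in place are the untrained ones.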

In [None]:
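# no end_epoch: prune until the end of training; reset_end=True: restore the initial weights but keep the mask
sp_cb = SparsifyCallback(50, 'weight', 'local', large_final, iterative, start_epoch=1, lth=True, reset_end=True)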

In [None]:
learn.fit(4, cbs=sp_cb)

Pruning of weight until a sparsity of 50%


| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.644316 | 0.405834 | 0.815291 | 00:09 |
| 1 | 0.541251 | 0.364345 | 0.844384 | 00:11 |
| 2 | 0.493762 | 0.341148 | 0.85318 | 00:11 |
| 3 | 0.441781 | 0.319401 | 0.860622 | 00:11 |


Saving Weights at epoch 0
Sparsity at the end of epoch 0: 0.00%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 1: 16.67%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 2: 33.33%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 3: 50.00%
Final Sparsity: 50.00


In [None]:
learn.validate()

(#2) [1.0403801202774048,0.6698240637779236]

So we now have an untrained model that reaches about 67% validation accuracy, which is better than uniformly random guessing (50%).
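Because the two classes are not balanced (the Oxford-IIIT Pet dataset has 12 cat breeds and 25 dog breeds), a stricter reference point than 50% is the accuracy of a model that always predicts the most frequent class. A minimal sketch, with a hypothetical helper `majority_class_accuracy`:

```python
from collections import Counter


def majority_class_accuracy(labels):
    """Accuracy obtained by always predicting the most frequent label."""
    if not labels:
        raise ValueError("need at least one label")
    counts = Counter(labels)
    return max(counts.values()) / len(labels)


# Toy labels: True = cat, False = dog (the same convention as label_func)
print(majority_class_accuracy([True, False, False, False]))  # 0.75
```

In this notebook, one could call it on `[label_func(f.name) for f in files]`; `f.name` is needed because `get_image_files` returns `Path` objects while `label_func` indexes a string. This counts the whole dataset rather than the random validation split, so it only approximates the validation baseline.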