# Lottery Ticket Hypothesis

> How to find winning tickets with fastai

In [None]:
#all_slow

## The Lottery Ticket Hypothesis

The [Lottery Ticket Hypothesis](https://arxiv.org/abs/1803.03635) is an intriguing finding by Frankle & Carbin, first posted on arXiv in 2018 and published at ICLR 2019. It states that:

> A randomly-initialized, dense neural network contains a subnetwork that is initialised such that — when trained in isolation — it can match the test accuracy of the original network after training for at most the same number of iterations.

In other words, once that subnetwork is found, every other parameter in the network can be discarded.

The authors propose to find such a subnetwork as follows:

1. Initialize the neural network
2. Train it to convergence
3. Prune the smallest magnitude weights by creating a mask $m$
4. Reset the remaining (unpruned) weights to their original values, i.e. their values at iteration 0.
5. Repeat from step 2 until reaching the desired level of sparsity.
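
The steps above can be sketched on a toy problem. This is a minimal illustration in plain Python, not fasterai's implementation: `toy_train`, `magnitude_mask` and `find_ticket` are hypothetical names, and "training" is replaced by gradient descent on a simple quadratic loss so that the example runs without a dataset.

```python
def toy_train(weights, mask, target, lr=0.5, steps=20):
    """Hypothetical stand-in for training: gradient descent on 0.5 * (w - target)**2.
    Pruned weights (mask == 0) are held at zero."""
    w = [wi * m for wi, m in zip(weights, mask)]
    for _ in range(steps):
        w = [m * (wi - lr * (wi - ti)) for wi, ti, m in zip(w, target, mask)]
    return w


def magnitude_mask(weights, sparsity):
    """Return a 0/1 mask that drops the `sparsity` fraction of smallest-magnitude weights."""
    n_prune = int(round(sparsity * len(weights)))
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(weights))]


def find_ticket(init, target, final_sparsity=0.5, rounds=2):
    mask = [1] * len(init)                       # step 1: dense network at its initial weights
    for r in range(1, rounds + 1):
        trained = toy_train(init, mask, target)  # step 2: train, always restarting from `init`
        # step 3: prune the smallest magnitudes, raising sparsity a little each round
        mask = magnitude_mask(trained, final_sparsity * r / rounds)
        # step 4: the reset is implicit, since the next round starts again from `init`
    ticket = [w * m for w, m in zip(init, mask)]  # surviving weights at their iteration-0 values
    return ticket, mask


init = [0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8]
target = [1.0, 0.05, -2.0, 0.1, 0.2, 3.0, -0.01, 0.5]
ticket, mask = find_ticket(init, target)
print(mask)  # → [1, 0, 1, 0, 0, 1, 0, 1]
```

The returned `ticket` holds the surviving weights at their initial values, which is the network that then gets trained in isolation.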

In [None]:
#hide
from fastai.vision.all import *

import torch
import torch.nn as nn

%config InlineBackend.figure_format = 'retina'

In [None]:
from fasterai.sparse.all import *

In [None]:
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f): return f[0].isupper()

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64), device=device)

What we want to show is that, within a neural network A, there exists a subnetwork B that reaches an accuracy $a_B \geq a_A$ in a training time $t_B \leq t_A$.

Let's get the baseline for network A:

In [None]:
learn = cnn_learner(dls, resnet18, metrics=accuracy)

In [None]:
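from copy import deepcopy

# Copy the tensors: state_dict() returns references that training would otherwise overwrite
initial_weights = deepcopy(learn.model.state_dict())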
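
A snapshot of the initial weights is only useful if it is independent of the live parameters. PyTorch's `state_dict()` returns tensors that share storage with the model, so in-place optimizer updates show through unless the snapshot is copied. The plain-Python sketch below illustrates the same aliasing behaviour with a dictionary of lists; `params`, `view` and `snapshot` are hypothetical names standing in for the model's tensors.

```python
from copy import deepcopy

# Hypothetical stand-in for a model's parameters: name -> mutable list of values
params = {"layer.weight": [0.5, -0.3]}

view = dict(params)          # new dict, but the value lists are shared with `params`
snapshot = deepcopy(params)  # fully independent copy

params["layer.weight"][0] += 1.0  # an in-place "training" update

print(view["layer.weight"])      # → [1.5, -0.3]  (changed along with the model)
print(snapshot["layer.weight"])  # → [0.5, -0.3]  (still the initial values)
```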

In [None]:
learn.fit(6)

epoch,train_loss,valid_loss,accuracy,time
0,0.686072,0.48912,0.79364,00:07
1,0.499003,0.363835,0.840325,00:06
2,0.396878,0.3339,0.845061,00:07
3,0.331666,0.311159,0.861976,00:06
4,0.282472,0.303883,0.870095,00:06
5,0.257027,0.29612,0.870095,00:07


To find the lottery ticket, we perform iterative pruning, but at each pruning step we reset the remaining weights to their original values (i.e. their values before training).

We restart from the same initialization as the baseline, so that any improvement cannot be attributed to a luckier initialization.

In [None]:
learn.load_state_dict(initial_weights)

<All keys matched successfully>

In [None]:
sp_cb = SparsifyCallback(50, 'weight', 'local', l1_norm, iterative, start_epoch=1, lth=True, reset_end=True)

In [None]:
learn.fit(4, cbs=sp_cb)

Pruning of weight until a sparsity of 50%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.681346,0.40733,0.826116,00:06
1,0.5617,0.391802,0.835589,00:09
2,0.499509,0.366445,0.849797,00:09
3,0.456785,0.336491,0.862652,00:09


Sparsity at the end of epoch 0: 0.00%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 1: 16.67%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 2: 33.33%
Resetting Weights to their epoch 0 values
Sparsity at the end of epoch 3: 50.00%
Final Sparsity: 50.00
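
The sparsities logged above grow in equal increments: nothing is pruned before `start_epoch`, and the target is then reached in even steps over the remaining epochs. The helper below is a hypothetical sketch (`stepwise_sparsity` is not a fasterai function) that only reproduces the logged end-of-epoch values. It makes no claim about how fasterai's `iterative` schedule is implemented internally.

```python
def stepwise_sparsity(target, n_epochs, start_epoch):
    """Sparsity (in %) at the end of each epoch, rising in equal steps from `start_epoch` on."""
    n_steps = n_epochs - start_epoch
    return [
        0.0 if epoch < start_epoch else target * (epoch - start_epoch + 1) / n_steps
        for epoch in range(n_epochs)
    ]


for epoch, s in enumerate(stepwise_sparsity(target=50, n_epochs=4, start_epoch=1)):
    print(f"Sparsity at the end of epoch {epoch}: {s:.2f}%")
```

Calling `stepwise_sparsity(80, 4, 1)` gives 0.00, 26.67, 53.33 and 80.00, the values logged for the 80% run later in this tutorial.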


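The callback has now pruned the network to 50% sparsity. With `reset_end=True`, the surviving weights are reset to their saved epoch-0 values at the end of training, which gives us the candidate ticket. We then train it for 6 epochs, like the baseline, with a `one_shot` schedule so that the 50% mask is applied from the first epoch onward:

In [None]: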
sp_cb = SparsifyCallback(50, 'weight', 'local', l1_norm, one_shot)

In [None]:
learn.fit(6, cbs=sp_cb)

Pruning of weight until a sparsity of 50%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.389218,0.339762,0.863329,00:09
1,0.313411,0.29467,0.879567,00:09
2,0.297718,0.288504,0.884303,00:09
3,0.252114,0.276948,0.881597,00:09
4,0.235741,0.255063,0.895805,00:09
5,0.208339,0.256576,0.896482,00:09


Sparsity at the end of epoch 0: 50.00%
Sparsity at the end of epoch 1: 50.00%
Sparsity at the end of epoch 2: 50.00%
Sparsity at the end of epoch 3: 50.00%
Sparsity at the end of epoch 4: 50.00%
Sparsity at the end of epoch 5: 50.00%
Final Sparsity: 50.00


We indeed obtain a network B whose accuracy exceeds that of A ($a_B > a_A$: 89.6% versus 87.0% after 6 epochs) for the same number of training epochs.

## Lottery Ticket Hypothesis with Rewind

With deeper networks, resetting all the way back to initialization can fail to produce a winning ticket, so a common variant rewinds the weights to an early point in training instead. Here, `rewind_epoch=1` makes the callback save the weights at epoch 1 and reset to those values after each pruning step, this time up to a sparsity of 80%.

In [None]:
learn.load_state_dict(initial_weights)

<All keys matched successfully>

In [None]:
sp_cb = SparsifyCallback(80, 'weight', 'local', l1_norm, iterative, start_epoch=1, lth=True, reset_end=True, rewind_epoch=1)

In [None]:
learn.fit(4, cbs=sp_cb)

Pruning of weight until a sparsity of 80%


epoch,train_loss,valid_loss,accuracy,time
0,0.667609,0.423499,0.803112,00:06
1,0.455044,0.361866,0.833559,00:09
2,0.407168,0.343839,0.847091,00:10
3,0.465871,0.403428,0.804465,00:09


Sparsity at the end of epoch 0: 0.00%
Saving Weights at epoch 1
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 1: 26.67%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 2: 53.33%
Resetting Weights to their epoch 1 values
Sparsity at the end of epoch 3: 80.00%
Final Sparsity: 80.00
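
The only difference from the earlier run is where the weights are reset to: a snapshot taken shortly after training starts, rather than the initialization. The toy sketch below shows this under the same simplifying assumptions as a quadratic "training" problem. The names `train`, `prune_smallest` and `find_ticket_with_rewind` are hypothetical, and this is not how fasterai implements `rewind_epoch`.

```python
def train(weights, mask, target, steps, lr=0.5):
    """Toy stand-in for training: gradient steps on 0.5 * (w - target)**2, pruned weights held at zero."""
    w = [wi * m for wi, m in zip(weights, mask)]
    for _ in range(steps):
        w = [m * (wi - lr * (wi - ti)) for wi, ti, m in zip(w, target, mask)]
    return w


def prune_smallest(weights, sparsity):
    """Return a 0/1 mask that drops the `sparsity` fraction of smallest-magnitude weights."""
    n_prune = int(round(sparsity * len(weights)))
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(weights))]


def find_ticket_with_rewind(init, target, sparsity, rounds, rewind_steps, steps_per_round):
    dense = [1] * len(init)
    # Snapshot taken after a little dense training: this is the rewind point
    rewind_point = train(init, dense, target, rewind_steps)
    mask = dense
    for r in range(1, rounds + 1):
        # Every round restarts from the rewind point, not from `init`
        trained = train(rewind_point, mask, target, steps_per_round)
        mask = prune_smallest(trained, sparsity * r / rounds)
    ticket = [w * m for w, m in zip(rewind_point, mask)]
    return ticket, mask, rewind_point


init = [0.1, -0.2, 0.3, -0.4]
target = [2.0, -0.1, 0.5, -3.0]
ticket, mask, rewind_point = find_ticket_with_rewind(
    init, target, sparsity=0.5, rounds=1, rewind_steps=1, steps_per_round=10
)
print(mask)  # → [1, 0, 0, 1]
```

The surviving entries of `ticket` equal the rewind-point values rather than the initial ones, which is the behaviour the "Resetting Weights to their epoch 1 values" log lines describe.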


In [None]:
sp_cb = SparsifyCallback(80, 'weight', 'local', l1_norm, one_shot)

In [None]:
learn.fit(6, cbs=sp_cb)

Pruning of weight until a sparsity of 80%
Saving Weights at epoch 0


epoch,train_loss,valid_loss,accuracy,time
0,0.419317,0.380852,0.82544,00:10
1,0.35709,0.349349,0.845061,00:09
2,0.32809,0.327959,0.856563,00:09
3,0.299707,0.3153,0.861299,00:09
4,0.275035,0.303642,0.863329,00:09
5,0.257127,0.31021,0.865359,00:10


Sparsity at the end of epoch 0: 80.00%
Sparsity at the end of epoch 1: 80.00%
Sparsity at the end of epoch 2: 80.00%
Sparsity at the end of epoch 3: 80.00%
Sparsity at the end of epoch 4: 80.00%
Sparsity at the end of epoch 5: 80.00%
Final Sparsity: 80.00
