# Lottery ticket exploration

This notebook explores lottery tickets and how they could be manipulated.

## Change working directory to project root

In [None]:
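import os
# Directories expected at the project root.
ROOT_DIRECTORIES = {'dogwood', 'tests'}
# If they are not all visible, assume the notebook is running two levels below
# the project root and move up.
if set(os.listdir('.')).intersection(ROOT_DIRECTORIES) != ROOT_DIRECTORIES:
    os.chdir('../..')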

## Exploration

### Train a model on MNIST

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Layer

In [None]:
MNIST_IMAGE_SHAPE = (28, 28)
MAX_PIXEL_VALUE = 255
MODEL_SAVE_DIR = '/tmp/dogwood/mnist'

In [None]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = tf.cast(X_train, tf.float32) / MAX_PIXEL_VALUE
X_test = tf.cast(X_test, tf.float32) / MAX_PIXEL_VALUE
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)

In [None]:
model = Sequential([
    Flatten(input_shape=MNIST_IMAGE_SHAPE, name='flatten'),
    Dense(128, activation='relu', name='dense_1'),
    Dense(10, activation='softmax', name='dense_2')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['sparse_categorical_accuracy'])

In [None]:
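# Snapshot the untrained weights so they can be restored when building the
# winning ticket. For a Dense layer, get_weights() returns [kernel, bias].
dense_1_initial_weights = model.get_layer('dense_1').get_weights()
dense_2_initial_weights = model.get_layer('dense_2').get_weights()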

In [None]:
model.fit(X_train, y_train, epochs=10, batch_size=32)

In [None]:
model.evaluate(X_test, y_test)

### Identify lottery tickets

According to Frankle and Carbin (2019), a winning ticket can be identified with the following simple procedure.

1. Randomly initialize a neural network.
1. Train the network.
1. Prune p% of the parameters of the trained network, creating a mask. The lowest-magnitude trained parameters are pruned.
1. Reset the remaining (unpruned) parameters to their randomly initialized values from step 1, creating the winning ticket.
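Steps 3 and 4 can be sketched in NumPy for a single weight matrix. This is a minimal, one-shot sketch under stated assumptions: `make_winning_ticket`, `initial`, and `trained` are hypothetical names, random arrays stand in for real layer weights, the number of pruned weights is rounded to the nearest integer, and biases are not handled.

```python
from typing import Tuple

import numpy as np


def make_winning_ticket(initial: np.ndarray, trained: np.ndarray,
                        prune_rate: float) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: one-shot magnitude pruning plus reset to init.

    Returns (mask, ticket). The mask is 0 for the lowest-magnitude *trained*
    weights and 1 elsewhere; the ticket holds the *initial* values of the
    surviving weights and 0 at pruned positions.
    """
    if initial.shape != trained.shape:
        raise ValueError('initial and trained must have the same shape.')
    # Step 3: rank the trained weights by magnitude and prune the smallest.
    num_to_prune = int(round(prune_rate * trained.size))
    mask = np.ones(trained.shape, dtype=initial.dtype)
    prune_indices = np.argsort(np.abs(trained), axis=None)[:num_to_prune]
    mask.flat[prune_indices] = 0
    # Step 4: reset the surviving weights to their initial values.
    ticket = mask * initial
    return mask, ticket


# Random stand-ins for a layer's initial and trained kernels.
rng = np.random.default_rng(0)
initial = rng.normal(size=(4, 3))
trained = initial + rng.normal(scale=0.5, size=(4, 3))
mask, ticket = make_winning_ticket(initial, trained, prune_rate=0.5)
print(int(mask.sum()))  # 6 of the 12 weights survive
```

In this notebook, `initial` would correspond to the kernel `dense_1_initial_weights[0]` and `trained` to `model.layers[1].get_weights()[0]` after `model.fit`. Note that the mask is computed from the trained magnitudes, while the values kept in the ticket come from the initialization.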

In [None]:
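def get_prune_mask(weights: np.ndarray, prune_rate: float) -> np.ndarray:
    """Returns a mask representing the pruned parameters.

    Pruned parameters are given a value of 0 in the mask; retained parameters
    are given a value of 1. The lowest-magnitude parameters are pruned.

    :param weights: The weights of one layer in a neural network.
    :param prune_rate: The fraction of parameters to prune in the range [0, 1].
    :return: A binary mask representing the pruned parameters.
    """
    if not 0 <= prune_rate <= 1:
        raise ValueError('prune_rate must be in the range [0, 1].')
    # Number of parameters to prune, rounded to the nearest integer.
    num_to_prune = int(round(prune_rate * weights.size))
    mask = np.ones(weights.shape, dtype=np.float32)
    # Flat indices of the num_to_prune lowest-magnitude parameters.
    prune_indices = np.argsort(np.abs(weights), axis=None)[:num_to_prune]
    mask.flat[prune_indices] = 0
    return mask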