# Weight transfer experiments

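In this notebook, we run weight transfer experiments on MNIST. We reuse the weights of a trained small network when growing its hidden layer, and compare the result against networks of the same size trained from scratch.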

## Change working directory to project root

In [None]:
import os
# Both of these directories must be visible from the project root.
ROOT_DIRECTORIES = {'dogwood', 'tests'}
# If either marker directory is missing from the current directory, assume we
# are running from the notebook's own folder (presumably two levels below the
# root) and move up. Re-running this cell at the root is a no-op.
if not ROOT_DIRECTORIES.issubset(os.listdir('.')):
    os.chdir('../..')

## Experiments

In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense
from dogwood.weight_transfer import are_symmetric_dense_neurons, \
    expand_dense_layer, expand_dense_layers, clone_layer

# Largest 8-bit pixel intensity (2**8 - 1 == 255). Dividing raw pixels by it
# rescales the 0..255 range onto [0, 1].
MAX_PIXEL_VALUE = 2 ** 8 - 1
# Height and width of each MNIST image, used as the Flatten layer's input shape.
MNIST_IMAGE_SHAPE = (28, 28)

In [None]:
# TODO should we make MNIST a versioned dataset?
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# Convert both image splits to float32 tensors and divide by the maximum pixel
# value, so intensities land in [0, 1] (assuming raw values in 0..MAX_PIXEL_VALUE).
X_train, X_test = [
    tf.cast(images, tf.float32) / MAX_PIXEL_VALUE
    for images in (X_train, X_test)
]

In [None]:
def get_small_model(num_hidden=1):
    """Build and compile a single-hidden-layer MNIST classifier.

    Args:
        num_hidden: Number of ReLU units in the hidden layer named 'dense_1'.

    Returns:
        A compiled ``Sequential`` model with layers
        flatten -> dense_1 (ReLU, ``num_hidden`` units) -> dense_2 (10-way softmax).
        It is compiled with Adam, sparse categorical cross-entropy and the
        'sparse_categorical_accuracy' metric.
    """
    # Layer names are fixed because the expansion experiment refers to
    # 'dense_1' by name.
    stack = [
        Flatten(input_shape=MNIST_IMAGE_SHAPE, name='flatten'),
        Dense(num_hidden, activation='relu', name='dense_1'),
        Dense(10, activation='softmax', name='dense_2'),
    ]
    model = Sequential(stack)
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy'],
    )
    return model

**One hidden layer weight expansion experiment on MNIST**

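In this experiment, we compare the training progress of models trained from scratch against models with partially pretrained weights. For each hidden-layer width, one model is trained from scratch. The other is produced by expanding the hidden layer of the model trained in the previous iteration, then continuing training, so it starts from the previously learned weights. We also compare their performance on the test data after training.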

In [None]:
batch_size = 32
epochs = 20
max_hidden = 5  # widest hidden layer tried; widths run 1..max_hidden

# All dicts are keyed by hidden-layer width (num_hidden).
# *_histories hold the Keras History objects returned by fit().
# *_eval hold model.evaluate() results on the test set, i.e.
# [test loss, test sparse_categorical_accuracy] given the compiled metric.
from_scratch_histories = {}
from_scratch_eval = {}
expanded_histories = {}
expanded_eval = {}
expanded_model = None
for num_hidden in range(1, max_hidden + 1):
    # Baseline: a freshly initialized model of this width.
    from_scratch_model = get_small_model(num_hidden)
    from_scratch_histories[num_hidden] = from_scratch_model.fit(
        X_train, y_train, batch_size=batch_size, epochs=epochs)
    from_scratch_eval[num_hidden] = from_scratch_model.evaluate(
        X_test, y_test, batch_size=batch_size)

    # First iteration: there is nothing to expand yet, so the expanded model
    # is also created from scratch. Later iterations widen 'dense_1' of the
    # model trained in the previous iteration (presumably by one neuron,
    # matching num_hidden's +1 step; confirm against expand_dense_layer).
    # They then recompile with the same settings as get_small_model.
    if expanded_model is None:
        expanded_model = get_small_model(num_hidden)
    else:
        expanded_model = expand_dense_layer(expanded_model, 'dense_1', 1)
        expanded_model.compile(
            optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['sparse_categorical_accuracy'])
    expanded_histories[num_hidden] = expanded_model.fit(
        X_train, y_train, batch_size=batch_size, epochs=epochs)
    expanded_eval[num_hidden] = expanded_model.evaluate(
        X_test, y_test, batch_size=batch_size)

In [None]:
# At width 1 the expanded model has not been expanded yet, so both curves come
# from freshly initialized models. Their gap indicates run-to-run noise.
scratch_acc = from_scratch_histories[1].history['sparse_categorical_accuracy']
expanded_acc = expanded_histories[1].history['sparse_categorical_accuracy']
_, ax = plt.subplots()
# Use 1-based epoch numbers so the last point sits at x == number of epochs.
ax.plot(range(1, len(scratch_acc) + 1), scratch_acc, label='From scratch')
ax.plot(range(1, len(expanded_acc) + 1), expanded_acc, label='Expanded')
ax.set(title='Training accuracy, 1 hidden neuron',
       xlabel='Epoch', ylabel='Training accuracy')
ax.legend()
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
for num_hidden in sorted(from_scratch_histories):
    scratch_acc = from_scratch_histories[num_hidden].history['sparse_categorical_accuracy']
    expanded_acc = expanded_histories[num_hidden].history['sparse_categorical_accuracy']
    # 1-based epoch numbers, so the final point lands on x == epochs (the xlim edge).
    scratch_line, = ax.plot(range(1, len(scratch_acc) + 1), scratch_acc, '-',
                            label=f'S{num_hidden}')
    # Reuse the from-scratch line's colour so each width's S/E pair is easy to match.
    ax.plot(range(1, len(expanded_acc) + 1), expanded_acc, 'o',
            color=scratch_line.get_color(), label=f'E{num_hidden}')
ax.set(title='Training accuracy by hidden-layer width',
       xlabel='Epoch', ylabel='Training accuracy',
       xlim=(0, epochs), ylim=(0, 1),
       xticks=list(range(epochs + 1)))
ax.legend(title='S = from scratch, E = expanded')
plt.show()