# Knowledge distillation

This notebook explores knowledge distillation techniques on MNIST. Knowledge distillation uses a trained network (the teacher) to train a new network (the student) that performs almost as well. The student is traditionally much smaller than the teacher, which makes the small loss in accuracy a worthwhile tradeoff.
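Here is a sketch of what the experiments below optimize. The notation is introduced here and is not taken from a specific reference. Write $p_T(k \mid x)$ for the teacher's softmax probability of digit class $k$ on input $x$, and $p_S(k \mid x; \theta)$ for the student's. Training the student with `categorical_crossentropy` on the teacher's outputs over $N$ inputs minimizes

$$
\mathcal{L}(\theta) = -\frac{1}{N} \sum_{i=1}^{N} \sum_{k=0}^{9} p_T(k \mid x_i) \, \log p_S(k \mid x_i; \theta)
$$

With hard labels $\hat{y}_i = \arg\max_k p_T(k \mid x_i)$, as in the `sparse_categorical_crossentropy` experiments, the teacher's distribution is replaced by a one-hot vector. The inner sum then collapses to $-\log p_S(\hat{y}_i \mid x_i; \theta)$.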

## Change working directory to project root

In [None]:
import os
ROOT_DIRECTORIES = {'dogwood', 'tests'}
if set(os.listdir('.')).intersection(ROOT_DIRECTORIES) != ROOT_DIRECTORIES:
    os.chdir('../..')

## Exploration

In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import h5py
from sklearn.linear_model import LogisticRegression
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.initializers import Constant, GlorotUniform

In [None]:
MNIST_IMAGE_SHAPE = (28, 28)
MAX_PIXEL_VALUE = 255
MODEL_SAVE_DIR = '/tmp/dogwood/mnist'

In [None]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = tf.cast(X_train, tf.float32) / MAX_PIXEL_VALUE
X_test = tf.cast(X_test, tf.float32) / MAX_PIXEL_VALUE
print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)

In [None]:
model = Sequential([
    Flatten(input_shape=(MNIST_IMAGE_SHAPE), name='flatten'),
    Dense(128, activation='relu', name='dense_1'),
    Dense(10, activation='softmax', name='dense_2')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['sparse_categorical_accuracy'])

In [None]:
model.fit(X_train, y_train, epochs=10, batch_size=32)

In [None]:
model.evaluate(X_test, y_test)

## Student trained on original features with model's labels

First, let's train a student network that uses the teacher model's predicted classes (the argmax of its softmax output) as labels, in place of the true MNIST labels.
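The relabeling step turns each row of teacher probabilities into a single class index. The minimal NumPy sketch below uses made-up probabilities for three inputs over four classes, not real MNIST predictions:

```python
import numpy as np

# Toy teacher softmax outputs (illustrative numbers only).
teacher_probs = np.array([
    [0.70, 0.10, 0.10, 0.10],
    [0.05, 0.05, 0.80, 0.10],
    [0.20, 0.30, 0.25, 0.25],
])

# Hard labels keep only the index of the most probable class per row,
# which is what tf.argmax(model(X_train), axis=1) does in the next cell.
hard_labels = np.argmax(teacher_probs, axis=1)
print(hard_labels)  # [0 2 1]
```

The third row shows what hard labels discard. The teacher is nearly undecided between classes 1, 2 and 3, but the student only ever sees class 1.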

In [None]:
y_student = tf.argmax(model(X_train), axis=1)
print(y_student.shape)
print(y_student[0])

In [None]:
student = Sequential([
    Flatten(input_shape=(MNIST_IMAGE_SHAPE), name='flatten'),
    Dense(128, activation='relu', name='dense_1'),
    Dense(10, activation='softmax', name='dense_2')
])
student.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['sparse_categorical_accuracy'])

In [None]:
student.fit(X_train, y_student, epochs=10, batch_size=32)

In [None]:
student.evaluate(X_test, y_test)

So we can successfully train a new network using only the teacher's predictions as labels, and it performs well on the test set, which it never saw during training.
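Test accuracy compares the student with the true labels. A second check that suits distillation is how often the student agrees with the teacher. The sketch below measures this agreement; `agreement_rate` is a hypothetical helper, not part of Keras, and the toy arrays are made-up numbers:

```python
import numpy as np

def agreement_rate(teacher_probs, student_probs):
    """Fraction of inputs on which the student's top class matches the teacher's.

    Hypothetical helper. Both arguments have shape (num_examples, num_classes)
    and hold class probabilities (logits also work, since only argmax is used).
    """
    teacher_pred = np.argmax(teacher_probs, axis=1)
    student_pred = np.argmax(student_probs, axis=1)
    return float(np.mean(teacher_pred == student_pred))

# Toy probabilities for four inputs over two classes.
teacher_toy = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
student_toy = np.array([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.1, 0.9]])
rate = agreement_rate(teacher_toy, student_toy)
print(rate)  # 0.75
```

In this notebook, you could call it as `agreement_rate(model.predict(X_test), student.predict(X_test))`.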

## Random inputs

What if we train a student model on random noise images labeled by the teacher, instead of on real digits? Can that work?
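Uniform noise looks nothing like a handwritten digit, so the teacher might assign random images to only a few classes. You can check this by counting how often each class appears among the teacher's hard labels. The sketch below assumes the labels are integer class indices in 0..9, and the toy labels are illustrative only:

```python
import numpy as np

def class_histogram(labels, num_classes=10):
    """Count occurrences of each class index in a 1-D array of integer labels."""
    return np.bincount(np.asarray(labels), minlength=num_classes)

# Toy labels; in the notebook these could come from
# tf.argmax(model(X_student), axis=1).numpy().
counts = class_histogram([8, 8, 3, 8, 5, 3])
print(counts)  # [0 0 0 2 0 1 0 0 3 0]
```

A heavily skewed histogram would mean that the student sees few examples of some digits. That could help explain its test accuracy.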

### Student model with softmax labels

We now train a student model on uniform random images, using the teacher's full softmax output (soft labels) as its targets.
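With soft labels, the loss rewards the student for matching the teacher's whole distribution, not just its top class. The NumPy sketch of categorical cross-entropy below uses made-up probabilities. The clipping constant `eps` is an assumption here, and Keras's own clipping details may differ.

```python
import numpy as np

def cross_entropy(targets, predictions, eps=1e-7):
    """Mean categorical cross-entropy between target and predicted distributions.

    Both arrays have shape (num_examples, num_classes). Predictions are
    clipped away from 0 so that log() stays finite.
    """
    predictions = np.clip(predictions, eps, 1.0)
    return float(np.mean(-np.sum(targets * np.log(predictions), axis=1)))

teacher = np.array([[0.6, 0.3, 0.1]])      # soft target from the teacher
hard = np.array([[1.0, 0.0, 0.0]])         # one-hot of the teacher's argmax
student_a = np.array([[0.6, 0.3, 0.1]])    # copies the teacher's distribution
student_b = np.array([[0.6, 0.05, 0.35]])  # same top class, different runner-up

# Against the hard label, the two students get the same loss...
print(cross_entropy(hard, student_a), cross_entropy(hard, student_b))
# ...but against the soft label, the one matching the teacher scores lower.
print(cross_entropy(teacher, student_a), cross_entropy(teacher, student_b))
```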

In [None]:
# Here we train the student on the softmax output.
# We could do the sparse representation instead.
X_student = np.random.rand(*X_train.numpy().shape)
y_student = model(X_student)
print(X_student.shape)
print(y_student.shape)
print(y_student[0])

In [None]:
y_test_one_hot = tf.keras.utils.to_categorical(y_test)
y_test_one_hot.shape

In [None]:
student_1 = Sequential([
    Flatten(input_shape=(MNIST_IMAGE_SHAPE), name='flatten'),
    Dense(128, activation='relu', name='dense_1'),
    Dense(10, activation='softmax', name='dense_2')
])
student_1.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['categorical_accuracy'])

In [None]:
student_1.fit(X_student, y_student, epochs=10, batch_size=32)

In [None]:
student_1.evaluate(X_test, y_test_one_hot)

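### Student model with sparse labels

Next, we keep only the teacher's top class for each random image (the argmax of its softmax output). We train with `sparse_categorical_crossentropy`, which takes integer class indices rather than probability vectors.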
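Sparse cross-entropy with an integer label gives the same value as categorical cross-entropy with the matching one-hot vector. The sparse form just picks out the predicted probability of the labeled class. The NumPy sketch below demonstrates this with toy numbers, building the one-hot rows by indexing an identity matrix:

```python
import numpy as np

def sparse_cross_entropy(labels, predictions, eps=1e-7):
    """Mean cross-entropy when targets are integer class indices."""
    predictions = np.clip(predictions, eps, 1.0)
    picked = predictions[np.arange(len(labels)), labels]  # p(labeled class) per row
    return float(np.mean(-np.log(picked)))

def one_hot_cross_entropy(one_hot_targets, predictions, eps=1e-7):
    """Mean cross-entropy when targets are one-hot rows."""
    predictions = np.clip(predictions, eps, 1.0)
    return float(np.mean(-np.sum(one_hot_targets * np.log(predictions), axis=1)))

labels = np.array([2, 0])
predictions = np.array([[0.1, 0.2, 0.7], [0.5, 0.25, 0.25]])
one_hot_targets = np.eye(3)[labels]  # row i has a 1 in column labels[i]

sparse_loss = sparse_cross_entropy(labels, predictions)
dense_loss = one_hot_cross_entropy(one_hot_targets, predictions)
print(sparse_loss, dense_loss)
```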

In [None]:
y_student_sparse = tf.argmax(y_student, axis=1)
print(y_student_sparse.shape)
print(y_student_sparse[0])

In [None]:
student_2 = Sequential([
    Flatten(input_shape=(MNIST_IMAGE_SHAPE), name='flatten'),
    Dense(128, activation='relu', name='dense_1'),
    Dense(10, activation='softmax', name='dense_2')
])
student_2.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['sparse_categorical_accuracy'])

In [None]:
student_2.fit(X_student, y_student_sparse, epochs=10, batch_size=32)

In [None]:
student_2.evaluate(X_test, y_test)

### Lots of data

What if we give the student ten times as much random data, 600,000 noise images instead of 60,000?
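Holding 600,000 random 28×28 images in memory is costly. The 470,400,000 values take about 3.76 GB as float64, which is NumPy's default for `np.random.rand`, and about 1.88 GB as float32. One alternative, sketched here, generates and labels the noise lazily, one batch at a time. It assumes the teacher can be called on a single batch; `noise_batches` and `teacher_fn` are hypothetical names. In the notebook, `teacher_fn` could be `lambda x: model(x).numpy()`.

```python
import numpy as np

def noise_batches(teacher_fn, num_batches, batch_size, image_shape=(28, 28), seed=0):
    """Yield (images, teacher_outputs) pairs of uniform noise, one batch at a time.

    teacher_fn maps a float32 array of shape (batch_size, *image_shape) to an
    array of class probabilities of shape (batch_size, num_classes).
    """
    rng = np.random.default_rng(seed)
    for _ in range(num_batches):
        # Generate float32 directly so that no float64 copy is allocated.
        images = rng.random((batch_size, *image_shape), dtype=np.float32)
        yield images, teacher_fn(images)

# Stand-in teacher for illustration: uniform probabilities over 10 classes.
def dummy_teacher(images):
    return np.full((len(images), 10), 0.1, dtype=np.float32)

for images, targets in noise_batches(dummy_teacher, num_batches=2, batch_size=4):
    print(images.shape, targets.shape)
```

This changes the experiment slightly: the student sees fresh noise in every batch rather than a fixed dataset. Keras `Model.fit` also accepts a Python generator of `(inputs, targets)` batches. For several epochs, though, the generator would need to yield indefinitely, with `steps_per_epoch` set, rather than stop after `num_batches`.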

In [None]:
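# Generate float32 noise directly and label it batch by batch to limit memory use.
X_student = np.random.default_rng().random((600000, *MNIST_IMAGE_SHAPE), dtype=np.float32)
y_student = model.predict(X_student, batch_size=1024)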
y_student_sparse = tf.argmax(y_student, axis=1)
print(X_student.shape)
print(y_student.shape)
print(y_student[0])
print(y_student_sparse.shape)
print(y_student_sparse[0])

In [None]:
student_3 = Sequential([
    Flatten(input_shape=(MNIST_IMAGE_SHAPE), name='flatten'),
    Dense(128, activation='relu', name='dense_1'),
    Dense(10, activation='softmax', name='dense_2')
])
student_3.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['categorical_accuracy'])

In [None]:
student_3.fit(X_student, y_student, epochs=10, batch_size=32)

In [None]:
student_3.evaluate(X_test, y_test_one_hot)

### Smaller student network

What if we force the student to compress what it learns from the teacher by shrinking its hidden layer from 128 units to 16?
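Shrinking the hidden layer from 128 to 16 units cuts the parameter count by roughly a factor of eight. The sketch below works through the arithmetic. `dense_param_count` is a hypothetical helper; for these models, Keras's `model.count_params()` should report the same totals, since `Flatten` has no parameters.

```python
def dense_param_count(layer_sizes):
    """Count weights plus biases in a stack of fully connected layers.

    layer_sizes lists the width of each layer, starting with the flattened input.
    """
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))

print(dense_param_count([784, 128, 10]))  # 101770: teacher and earlier students
print(dense_param_count([784, 16, 10]))   # 12730: student_4
```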

In [None]:
student_4 = Sequential([
    Flatten(input_shape=(MNIST_IMAGE_SHAPE), name='flatten'),
    Dense(16, activation='relu', name='dense_1'),
    Dense(10, activation='softmax', name='dense_2')
])
student_4.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['categorical_accuracy'])

In [None]:
student_4.fit(X_student, y_student, epochs=10, batch_size=32)

In [None]:
student_4.evaluate(X_test, y_test_one_hot)