# Mini Project 2 - DL Skills - Distiller

## Import libraries and load data

In [None]:
import keras
import tensorflow as tf
from utils import get_prepared_data
from model_utils import (
    get_student_vgg,
    get_vgg
)

In [None]:
# Instantiate the two networks used throughout the notebook: the teacher
# comes from get_vgg(), the student from get_student_vgg().
teacher, student = get_vgg(), get_student_vgg()

# Second copy of the student architecture, trained later without distillation
# as the comparison baseline. Per the Keras docs, clone_model creates new
# layers (and therefore new weights) instead of sharing them with `student`.
student_scratch = keras.models.clone_model(student)

In [None]:
# Load the prepared data. The unpacking expects a pair of pairs: the first
# pair is bound to the train/test inputs, the second to the train/test labels.
# NOTE(review): encoding of the labels is not visible here; the models below
# are compiled with categorical_crossentropy — presumably one-hot, confirm in
# utils.get_prepared_data.
(x_train, x_test), (y_train, y_test) = get_prepared_data()

In [None]:
batch_size = 8

# Input pipelines built from the in-memory arrays, following
# https://www.tensorflow.org/tutorials/load_data/numpy
#
# Both pipelines apply the same stages in the same order:
#   1. repeat(8)  - the data is cycled 8 times before anything else happens
#   2. shuffle    - applied after repeat, so the buffer draws from the
#                   repeated stream; it is re-shuffled on every iteration
#   3. batch      - groups `batch_size` elements per step
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .repeat(8)
    .shuffle(400000, reshuffle_each_iteration=True)
    .batch(batch_size, num_parallel_calls=tf.data.AUTOTUNE)
)

# Same recipe for the held-out split, with a smaller shuffle buffer.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .repeat(8)
    .shuffle(80000, reshuffle_each_iteration=True)
    .batch(batch_size, num_parallel_calls=tf.data.AUTOTUNE)
)

## Compile and train teacher model

In [None]:
from keras.optimizers import adam_v2

# Adam with the AMSGrad variant enabled; all hyperparameters are spelled out
# explicitly so the configuration is visible in the notebook.
opt = adam_v2.Adam(
    learning_rate=1e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-7, amsgrad=True
)

# The teacher is trained as an ordinary classifier: categorical cross-entropy
# loss, with accuracy tracked as the only metric.
teacher.compile(
    loss=keras.losses.categorical_crossentropy,
    optimizer=opt,
    metrics=["accuracy"],
)

In [None]:
from keras.callbacks import (ModelCheckpoint, EarlyStopping)

# Train the teacher. An "epoch" here is a fixed budget of 100 batches (not a
# full pass over the data), and each validation round evaluates 10 batches.
# The returned History object is kept for the learning-curve plots.
teacher_hist = teacher.fit(
    train_dataset,
    epochs=200,
    steps_per_epoch=100,
    validation_data=test_dataset,
    validation_steps=10,
)

## Compile and fit student

In [None]:
from model_utils import Distiller

# `opt` has already been used to train the teacher, so it carries that run's
# step counter and per-variable optimizer state. Reusing the same instance
# would start the student with stale state (and newer Keras versions refuse
# to reuse an optimizer on a different set of variables). Build an
# independent optimizer with identical hyperparameters instead.
distiller_opt = type(opt).from_config(opt.get_config())

# Initialize and compile distiller
# NOTE(review): the meaning of `alpha` and `temperature` is defined by
# model_utils.Distiller, which is not visible here — presumably alpha weights
# the student loss against the distillation loss and temperature softens the
# logits, as in the Keras knowledge-distillation example; confirm.
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=distiller_opt,
    metrics=['accuracy'],
    student_loss_fn=keras.losses.categorical_crossentropy,
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=30
)

# Distill teacher to student: 100 epochs of 100 batches each, validating on
# 10 batches per epoch. The History object is kept for plotting.
dist_hist = distiller.fit(
    x=train_dataset,
    steps_per_epoch=100,
    validation_data=test_dataset,
    validation_steps=10,
    epochs=100
)

In [None]:
# Train the student as done usually (plain supervised training, no teacher).
#
# `opt` was already used in earlier training runs, so it carries their step
# counter and per-variable optimizer state. For a fair baseline the scratch
# student gets its own optimizer with the same hyperparameters rather than
# the shared instance.
scratch_opt = type(opt).from_config(opt.get_config())

student_scratch.compile(
    optimizer=scratch_opt,
    loss=keras.losses.categorical_crossentropy,
    metrics=['accuracy']
)

# Train and evaluate student trained from scratch, using the same schedule as
# the distilled student (100 epochs x 100 batches, 10 validation batches).
st_hist = student_scratch.fit(
    x=train_dataset,
    steps_per_epoch=100,
    validation_data=test_dataset,
    validation_steps=10,
    epochs=100
)

### Plot Learning Curves (Loss and Accuracy)

In [None]:
import matplotlib.pyplot as plt

# 2x3 grid: top row holds the three accuracy panels, bottom-left holds the
# teacher loss; the remaining two bottom cells are hidden.
fig, ((ax1, ax2, ax3), (ax4, ax5, ax6)) = plt.subplots(2, 3, figsize=(12, 12))

# Shared y-range so the three accuracy panels are directly comparable.
min_y, max_y = 0.1, 0.85


def _draw_curves(axis, history, metric, title, y_limits=None):
    """Plot the train and validation series of one metric on one axis.

    `history` is a Keras ``History.history`` dict; the validation series is
    read from the ``"val_" + metric`` key. When `y_limits` is given it is
    applied as the y-axis range, otherwise matplotlib autoscaling is kept.
    """
    axis.plot(history[metric])
    axis.plot(history["val_" + metric])
    axis.set_title(title)
    axis.set_ylabel(metric)
    axis.set_xlabel("epoch")
    axis.legend(["train", "test"], loc="upper left")
    if y_limits is not None:
        axis.set_ylim(y_limits)


# Teacher: loss (autoscaled) and accuracy
_draw_curves(ax4, teacher_hist.history, "loss", "teacher loss")
_draw_curves(ax1, teacher_hist.history, "accuracy", "teacher accuracy",
             [min_y, max_y])

# Student trained through distillation
_draw_curves(ax2, dist_hist.history, "accuracy", "student distilled accuracy",
             [min_y, max_y])

# Student trained from scratch
_draw_curves(ax3, st_hist.history, "accuracy",
             "student not distilled accuracy", [min_y, max_y])

# Unused grid cells
for unused_axis in (ax5, ax6):
    unused_axis.set_visible(False)

fig.tight_layout()
plt.show()