# Knowledge Distillation

In [12]:
import os

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers as L
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [13]:
# Record the TensorFlow release this notebook was executed with.
print(tf.__version__)

2.4.1


# Data Preparation

In [2]:
# Load the CIFAR-10 train/validation splits bundled with Keras.
(x_train, y_train), (x_valid, y_valid) = keras.datasets.cifar10.load_data()

# Scale pixel values to [0, 1]. Casting to float32 *before* dividing avoids the
# float64 arrays that `integer_array / 255.0` would create: those take twice the
# memory and would be cast back to float32 by the model anyway.
x_train = x_train.astype("float32") / 255.0
x_valid = x_valid.astype("float32") / 255.0

# One-hot encode the integer class ids; the losses used later are
# CategoricalCrossentropy, which expects one-hot targets.
y_train = keras.utils.to_categorical(y_train)
y_valid = keras.utils.to_categorical(y_valid)

In [3]:
# Upper bounds on training length; early stopping may end a run sooner.
T_EPOCHS = 25  # teacher
S_EPOCHS = 20  # student
BATCH_SIZE = 512

# Derived from the loaded arrays rather than hard-coded.
IMAGE_SIZE = x_train.shape[1:]  # per-sample input shape, channels included
N_CLASSES = y_train.shape[-1]   # width of the one-hot label vectors

IMAGE_SIZE, N_CLASSES

((32, 32, 3), 10)

In [4]:
def nn_callbacks():
    """Create a fresh pair of training callbacks.

    A new list is built on every call so that separate `fit` runs do not
    share callback state.

    Returns:
        list: ``[EarlyStopping, ReduceLROnPlateau]``. Neither callback sets
        ``monitor``, so both watch the Keras default quantity.
    """
    early_stopping = keras.callbacks.EarlyStopping(
        min_delta=1e-4,
        patience=5,
        restore_best_weights=True,
        verbose=1,
    )
    lr_on_plateau = keras.callbacks.ReduceLROnPlateau(patience=2, verbose=1)
    return [early_stopping, lr_on_plateau]

In [5]:
# Wrap each (images, one-hot labels) pair as an unbatched tf.data pipeline;
# shuffling and batching are applied where the datasets are consumed.
d_train, d_valid = (
    tf.data.Dataset.from_tensor_slices(arrays)
    for arrays in ((x_train, y_train), (x_valid, y_valid))
)

# Drop the notebook-level references to the NumPy arrays now that the
# datasets have been built from them.
del x_train, x_valid, y_train, y_valid

# Building the Models

**Teacher Model**

In [6]:
def build_teacher_model(name='teacher'):
    """Assemble the teacher: a VGG19 backbone with a linear classification head.

    Args:
        name: Name given to the returned Sequential model.

    Returns:
        A Sequential model mapping images of shape ``IMAGE_SIZE`` to
        ``N_CLASSES`` raw logits (the final Dense layer has no activation).
    """
    # `weights` is left at the Keras default for VGG19.
    # NOTE(review): inputs are scaled to [0, 1] upstream rather than passed
    # through `keras.applications.vgg19.preprocess_input` - confirm this is intended.
    backbone = keras.applications.VGG19(include_top=False, input_shape=IMAGE_SIZE)
    backbone.trainable = True  # fine-tune every backbone layer

    model = keras.models.Sequential(name=name)
    model.add(backbone)
    model.add(L.GlobalAvgPool2D())
    model.add(L.Dense(N_CLASSES))
    return model


teacher_model = build_teacher_model()
teacher_model.summary()

Model: "teacher"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
vgg19 (Functional)           (None, 1, 1, 512)         20024384  
_________________________________________________________________
global_average_pooling2d (Gl (None, 512)               0         
_________________________________________________________________
dense (Dense)                (None, 10)                5130      
Total params: 20,029,514
Trainable params: 20,029,514
Non-trainable params: 0
_________________________________________________________________


**Student Model**

In [7]:
def build_student_model(name='student'):
    """Assemble the student: a small plain CNN with a linear classification head.

    The network is five identical stages of three 3x3, 64-filter, ReLU
    convolutions followed by a 2x2 max-pool, then global average pooling and
    a Dense layer producing ``N_CLASSES`` raw logits (no activation).

    Args:
        name: Name given to the returned Sequential model.

    Returns:
        The uncompiled Sequential student model.
    """
    n_stages, convs_per_stage = 5, 3

    stack = []
    for stage in range(n_stages):
        for position in range(convs_per_stage):
            # Only the very first layer declares the input shape.
            extra = {'input_shape': IMAGE_SIZE} if stage == 0 and position == 0 else {}
            stack.append(L.Conv2D(64, 3, padding='same', activation='relu', **extra))
        stack.append(L.MaxPool2D(pool_size=2))

    stack.append(L.GlobalAvgPool2D())
    stack.append(L.Dense(N_CLASSES))
    return keras.models.Sequential(stack, name=name)


student_model = build_student_model()
student_model.summary()

Model: "student"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 32, 32, 64)        1792      
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 32, 32, 64)        36928     
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 32, 64)        36928     
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 16, 16, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 16, 64)        36928     
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 16, 16, 64)        36928     
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 16, 16, 64)        3692

# Training Teacher

In [8]:
# The teacher outputs raw logits, hence from_logits=True on the loss.
teacher_model.compile(
    optimizer=keras.optimizers.Adam(1e-5),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# The datasets are batched explicitly, so `batch_size` is not passed to `fit`
# (Keras documents that it must not be specified for Dataset inputs).
# Only the training stream is shuffled: validation results do not depend on
# sample order, so shuffling it only adds work.
history = teacher_model.fit(
    d_train.shuffle(buffer_size=1024, seed=19).batch(BATCH_SIZE),
    validation_data=d_valid.batch(BATCH_SIZE),
    epochs=T_EPOCHS,
    callbacks=nn_callbacks(),
)

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25

Epoch 00013: ReduceLROnPlateau reducing learning rate to 9.999999747378752e-07.
Epoch 14/25
Epoch 15/25
Epoch 16/25

Epoch 00016: ReduceLROnPlateau reducing learning rate to 9.999999974752428e-08.
Epoch 17/25
Epoch 18/25

Epoch 00018: ReduceLROnPlateau reducing learning rate to 1.0000000116860975e-08.
Epoch 19/25
Restoring model weights from the end of the best epoch.
Epoch 00019: early stopping


# Distillation in Action

In [9]:
class Distiller(keras.Model):
    """Trains a student model against ground-truth labels and a fixed teacher.

    The total loss is
    ``alpha * student_loss + (1 - alpha) * distillation_loss`` where the
    distillation loss compares temperature-softened teacher and student
    outputs. Only the student's weights are updated; the teacher is always
    called with ``training=False``.

    Args:
        student: Model being trained; expected to output logits.
        teacher: Already-trained model providing soft targets (logits).
        activation: Callable ``f(logits, axis=...)`` turning logits into a
            probability distribution, e.g. ``tf.nn.softmax``.
    """

    def __init__(self, student, teacher, activation):
        super().__init__()
        self.teacher = teacher
        self.student = student
        self.activation = activation
        # Running means of each loss term. Reporting these (instead of the raw
        # per-batch tensors) makes the logged `loss` / `val_loss` an average
        # over the whole epoch, which is what callbacks such as EarlyStopping
        # and ReduceLROnPlateau should be monitoring.
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
        self.distillation_loss_tracker = keras.metrics.Mean(name="distillation_loss")

    @property
    def metrics(self):
        """Compiled metrics plus the loss trackers.

        Keras resets every metric listed here at the start of each epoch and
        of each evaluation run. The identity check avoids listing a tracker
        twice when Keras has already picked it up via attribute tracking.
        """
        tracked = list(super().metrics)
        for tracker in (
            self.loss_tracker,
            self.student_loss_tracker,
            self.distillation_loss_tracker,
        ):
            if not any(tracker is m for m in tracked):
                tracked.append(tracker)
        return tracked

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=10,
    ):
        """ Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: Weight of student_loss_fn in the total loss;
                distillation_loss_fn is weighted by 1 - alpha.
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        # Also compile the wrapped student so it can be evaluated on its own
        # after distillation.
        self.student.compile(optimizer=optimizer, metrics=metrics, loss=student_loss_fn)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def _compute_losses(self, y, teacher_predictions, student_predictions):
        """Return ``(student_loss, distillation_loss, total_loss)`` for one batch."""
        student_loss = self.student_loss_fn(y, student_predictions)
        # Gradients through the softened distributions shrink roughly as
        # 1/T^2, so the soft-target term is scaled by T^2 to keep its weight
        # relative to the hard-label term stable when the temperature changes
        # (Hinton et al., 2015). For temperature == 1 this is a no-op.
        distillation_loss = self.distillation_loss_fn(
            self.activation(teacher_predictions / self.temperature, axis=1),
            self.activation(student_predictions / self.temperature, axis=1),
        ) * (self.temperature ** 2)
        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return student_loss, distillation_loss, loss

    def _update_and_report(self, y, student_predictions, student_loss, distillation_loss, loss):
        """Update all metrics with this batch and return their running values."""
        self.compiled_metrics.update_state(y, student_predictions)
        self.loss_tracker.update_state(loss)
        self.student_loss_tracker.update_state(student_loss)
        self.distillation_loss_tracker.update_state(distillation_loss)
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        """Run one optimisation step on the student and return metric values."""
        x, y = data
        # Computed outside the tape: the teacher is never differentiated.
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)
            student_loss, distillation_loss, loss = self._compute_losses(
                y, teacher_predictions, student_predictions
            )

        # Only the student's variables receive gradients.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return self._update_and_report(
            y, student_predictions, student_loss, distillation_loss, loss
        )

    def test_step(self, data):
        """Evaluate one batch without updating weights and return metric values."""
        x, y = data
        teacher_predictions = self.teacher(x, training=False)
        student_predictions = self.student(x, training=False)

        student_loss, distillation_loss, loss = self._compute_losses(
            y, teacher_predictions, student_predictions
        )
        return self._update_and_report(
            y, student_predictions, student_loss, distillation_loss, loss
        )

    def call(self, x, training=None):
        """Forward pass: delegate to the student model."""
        return self.student(x, training=training)

In [10]:
distiller = Distiller(student_model, teacher_model, tf.nn.softmax)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=['accuracy'],
    student_loss_fn=keras.losses.CategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.7,        # weight of the hard-label loss; 0.3 goes to the distillation loss
    temperature=1,    # logits are divided by 1, i.e. the distributions are not softened
)

# The datasets are batched explicitly, so `batch_size` is not passed to `fit`
# (Keras documents that it must not be specified for Dataset inputs).
# Only the training stream is shuffled; validation does not need it.
history_distillation = distiller.fit(
    d_train.shuffle(buffer_size=1024, seed=19).batch(BATCH_SIZE),
    validation_data=d_valid.batch(BATCH_SIZE),
    epochs=S_EPOCHS,
    callbacks=nn_callbacks(),
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20

Epoch 00016: ReduceLROnPlateau reducing learning rate to 0.00010000000474974513.
Epoch 17/20
Epoch 18/20

Epoch 00018: ReduceLROnPlateau reducing learning rate to 1.0000000474974514e-05.
Epoch 19/20
Restoring model weights from the end of the best epoch.
Epoch 00019: early stopping


# Comparison

In [11]:
def report_model(title, model, path):
    """Save a model, evaluate it on the validation set and print its file size.

    Args:
        title: Heading printed before the report.
        model: Compiled Keras model to save and evaluate.
        path: Destination file for ``model.save``; its size on disk is
            printed in MB.
    """
    print(title)
    model.save(path)
    # Evaluation does not depend on sample order, so no shuffling is needed.
    model.evaluate(d_valid.batch(BATCH_SIZE))
    print("File Size is :", round(os.path.getsize(path)/1024**2, 2), "MB")


report_model('Teacher Model:', teacher_model, 'teacher.h5')
report_model('Distilled Model:', student_model, 'student.h5')

Teacher Model:
File Size is : 229.35 MB
Distilled Model:
File Size is : 6.09 MB


**References**

* [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531)
* [Implementation of classical Knowledge Distillation](https://keras.io/examples/vision/knowledge_distillation/)