Distilling the Knowledge in a Neural Network: https://arxiv.org/pdf/1503.02531.pdf

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# define the model architecture for the student model
def create_student_model(input_shape=(32, 32, 3), num_classes=10):
  inputs = keras.Input(shape=input_shape)
  x = layers.Rescaling(1./255)(inputs)
  x = layers.Conv2D(filters=64, kernel_size=3)(x)
  x = layers.BatchNormalization()(x)
  x = layers.Activation("relu")(x)
  x = layers.Flatten(name='flatten')(x)
  x = layers.Dense(units=16, activation = "relu")(x)
  outputs = layers.Dense(num_classes, activation="softmax")(x)
  student_model = keras.Model(inputs=inputs, outputs=outputs)
  return student_model

# define the model architecture for the teacher model
def create_teacher_model(input_shape=(32, 32, 3), num_classes=10):
  inputs = keras.Input(shape=input_shape)
  x = layers.Rescaling(1./255)(inputs)
  x = layers.Conv2D(filters=64, kernel_size=3)(x)
  x = layers.BatchNormalization()(x)
  x = layers.Activation("relu")(x)
  x = layers.Conv2D(filters=32, kernel_size=3)(x)
  x = layers.BatchNormalization()(x)
  x = layers.Activation("relu")(x)
  x = layers.Conv2D(filters=16, kernel_size=3)(x)
  x = layers.BatchNormalization()(x)
  x = layers.Activation("relu")(x)
  x = layers.Flatten(name='flatten')(x)
  x = layers.Dense(units=16, activation = "relu")(x)
  outputs = layers.Dense(num_classes, activation="softmax")(x)
  teacher_model = keras.Model(inputs=inputs, outputs=outputs)
  return teacher_model

In [2]:
#This code was taken from: https://keras.io/examples/vision/knowledge_distillation/
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """ Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    @tf.function
    def train_step(self, data):
        # Unpack data
        x, y = data

        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)

            # Compute scaled distillation loss from https://arxiv.org/abs/1503.02531
            # The magnitudes of the gradients produced by the soft targets scale
            # as 1/T^2, multiply them by T^2 when using both hard and soft targets.
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    @tf.function
    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

In [3]:
student_model = create_student_model()
teacher_model = create_teacher_model()

In [4]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

In [5]:
teacher_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

In [6]:
teacher_model.fit(x_train, y_train, epochs=40)

Epoch 1/40


  return dispatch_target(*args, **kwargs)


Epoch 2/40
Epoch 3/40
Epoch 4/40
Epoch 5/40
Epoch 6/40
Epoch 7/40
Epoch 8/40
Epoch 9/40
Epoch 10/40
Epoch 11/40
Epoch 12/40
Epoch 13/40
Epoch 14/40
Epoch 15/40
Epoch 16/40
Epoch 17/40
Epoch 18/40
Epoch 19/40
Epoch 20/40
Epoch 21/40
Epoch 22/40
Epoch 23/40
Epoch 24/40
Epoch 25/40
Epoch 26/40
Epoch 27/40
Epoch 28/40
Epoch 29/40
Epoch 30/40
Epoch 31/40
Epoch 32/40
Epoch 33/40
Epoch 34/40
Epoch 35/40
Epoch 36/40
Epoch 37/40
Epoch 38/40
Epoch 39/40
Epoch 40/40


<keras.callbacks.History at 0x7f60d665dc10>

In [7]:
teacher_model.evaluate(x_test, y_test)



[2.1458559036254883, 0.5145999789237976]

In [8]:
distiller = Distiller(student=student_model, teacher=teacher_model)

In [9]:
import keras.backend as K
from keras.losses import Loss

class KLDivergence(Loss):
    def __init__(self, *args, **kwargs):
        super(KLDivergence, self).__init__(*args, **kwargs)
        
    def call(self, y_true, y_pred):
        y_true = K.clip(y_true, K.epsilon(), 1)
        y_pred = K.clip(y_pred, K.epsilon(), 1)
        return K.sum(y_true * K.log(y_true / y_pred), axis=-1)

In [10]:
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=KLDivergence(),
    alpha=0.1,
    temperature=10,
)

In [11]:
distiller.fit(x_train, y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7f60d63a5eb0>

In [12]:
distiller.evaluate(x_test, y_test)



[0.10000000149011612, 2.316351890563965]