In [1]:
import tensorflow as tf

In [2]:
class CKD():
    def __init__(
        self,
        teacher: tf.keras.Model,
        student: tf.keras.Model,
        alreadySoftmax: bool = True,
        optimizer: tf.keras.optimizers = tf.keras.optimizers.Adam(),
        studentLoss: tf.keras.losses = tf.keras.losses.CategoricalCrossentropy(),
        distilLoss: tf.keras.losses = tf.keras.losses.KLDivergence(),
        metrics = [tf.keras.metrics.CategoricalAccuracy()]
    ):
        """
        Description
        ---------------
        Initialize the teacher model, the student model and their last layer index.

        Input(s)
        ---------------
        teacher: A trained Keras Sequential or Functional model (Sub-class models are not supported).
        student: An untrained Keras Sequential or Functional model (Sub-class models are not supported).
        alreadySoftmax : If the last layer is softmax it must be true, else it must be false (for teacher and student). By default true.
        optimizer: Optimizer instance. By default Adam.
        distilLoss: Loss instance. By default KLDivergence.
        metrics: List of metrics to be evaluated by the model during training and testing. By default accuracy.
        """
        self.teacher = teacher
        self.student = student
        self.alreadySoftmax = alreadySoftmax
        self.metrics = metrics
        self.optimizer = optimizer
        self.distilLoss = distilLoss
        self.studentLoss = studentLoss

    def distil(
        self,
        trainData: tf.data.Dataset,
        valData: tf.data.Dataset,
        epochs: int = 1,
        trainBatchSize: int = 32,
        valBatchSize: int = 32,
        alpha: float = 0.1,
        temperature: int = 3
    ):
        """
        Description
        ---------------
        Distil the knowledge of the teacher to the student.

        Input(s)
        ---------------
        trainData: TensorFlow Dataset with training images.
        valData: TensorFlow Dataset with validation images.
        epochs: Number of epochs to distil the model. By default 1.
        trainBatchSize: Number of samples per gradient update. By default 32.
        valBatchSize: Number of samples per validation batch. By default 32.
        alpha: Loss weighting factor. By default 0.1 (10% student's loss, 90% distillation's loss).
        temperature: Temperature for softening probability distributions. Larger temperature gives softer distributions. By default 3.

        Output(s)
        ---------------
        distilled_model: Distilled Keras Sequential or Functional student model.
        """
        # Compiling student model
        self.student.compile(
            optimizer=self.optimizer,
            loss=self.studentLoss,
            metrics=self.metrics
        )
        # Prepare the training dataset
        trainData = trainData.shuffle(1024).batch(batch_size=trainBatchSize)
        batchNbTrain = trainData.cardinality().numpy()
        # Prepare the validation dataset
        valData = valData.batch(batch_size=valBatchSize)
        batchNbVal = valData.cardinality().numpy()
        # Getting metrics
        metricsHandlerTruth = []
        metricsHandlerTeacher = []
        for metric in self.metrics:
            metricsHandlerTruth.append(tf.keras.metrics.get(metric))
            metricsHandlerTeacher.append(tf.keras.metrics.get(metric))
        # Training
        for epoch in range(epochs):
            print("Distillation Epoch {}/{}".format(epoch+1, epochs))
            pb_train = tf.keras.utils.Progbar(batchNbTrain)
            for step, (x_batch_train, y_batch_train) in enumerate(trainData):
                # Teacher's forward pass
                teacherPredsTrain = self.teacher(x_batch_train, training=False)
                with tf.GradientTape() as tape:
                    # Student's forward pass
                    studentPredsTrain = self.student(x_batch_train, training=True)
                    # Computing distillation and student losses
                    if self.alreadySoftmax == False:
                        distilLossTrain = self.distilLoss(
                            tf.keras.activations.softmax(teacherPredsTrain) / temperature,
                            tf.keras.activations.softmax(studentPredsTrain) / temperature
                        )
                        studentLossTrain = self.studentLoss(
                            y_batch_train,
                            tf.keras.activations.softmax(studentPredsTrain)
                        )
                    else:
                        distilLossTrain = self.distilLoss(
                            teacherPredsTrain / temperature,
                            studentPredsTrain / temperature
                        )
                        studentLossTrain = self.studentLoss(
                            y_batch_train,
                            tf.keras.activations.softmax(studentPredsTrain)
                        )
                    # Computing loss
                    lossTrain = (alpha * studentLossTrain) + ((1 - alpha) * distilLossTrain)
                # Computing metrics
                metricsTuplesTrain = []
                # Comparing to ground truth values
                for handler in metricsHandlerTruth:
                    handler.reset_state()
                    handler.update_state(y_batch_train.numpy(), studentPredsTrain.numpy())
                    metricsTuplesTrain.append((handler.name, handler.result().numpy()))
                # Comparing to teacher
                for handler in metricsHandlerTeacher:
                    handler.reset_state()
                    handler.update_state(teacherPredsTrain.numpy(), studentPredsTrain.numpy())
                    metricsTuplesTrain.append(("Distillation_" + handler.name, handler.result().numpy()))
                # Updating progress bar losses and metrics
                lossesTuplesTrain = [
                    ('Loss', lossTrain),
                    ('DistilLoss', distilLossTrain),
                    ('StudentLoss', studentLossTrain)
                ]
                globalTuplesTrain = lossesTuplesTrain + metricsTuplesTrain
                pb_train.add(1, values=globalTuplesTrain)
                # Computing gradient
                gradients = tape.gradient(lossTrain, self.student.trainable_variables)
                # Update weights
                self.student.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))
            # Validation
            pb_val = tf.keras.utils.Progbar(batchNbVal)
            for step, (x_batch_val, y_batch_val) in enumerate(valData):
                teacherPredsVal = self.teacher(x_batch_val, training=False)
                studentPredsVal = self.student(x_batch_val, training=False)
                # Computing losses and metrics
                # Computing losses
                if self.alreadySoftmax == False:
                    distilLossVal = self.distilLoss(
                        tf.keras.activations.softmax(teacherPredsVal) / temperature,
                        tf.keras.activations.softmax(studentPredsVal) / temperature
                    )
                    studentLossVal = self.studentLoss(
                        y_batch_val,
                        tf.keras.activations.softmax(studentPredsVal)
                    )
                else:
                    distilLossVal = self.distilLoss(
                        teacherPredsVal / temperature,
                        studentPredsVal / temperature
                    )
                    studentLossVal = self.studentLoss(
                        y_batch_val,
                        tf.keras.activations.softmax(studentPredsVal)
                    )
                lossVal = (alpha * studentLossVal) + ((1 - alpha) * distilLossVal)
                # Updating validation losses and metrics
                lossesTuplesVal = [
                    ('Val_loss', lossVal),
                    ('Val_distilLoss', distilLossVal),
                    ('Val_studentLoss', studentLossVal)
                ]
                # Computing metrics
                metricsTuplesVal = []
                # Comparing to ground truth values
                for handler in metricsHandlerTruth:
                    if step == 0:
                        handler.reset_state()
                    handler.update_state(y_batch_val.numpy(), studentPredsVal.numpy())
                    metricsTuplesVal.append(('Val_' + handler.name, handler.result().numpy()))
                # Comparing to teacher
                for handler in metricsHandlerTeacher:
                    if step == 0:
                        handler.reset_state()
                    handler.update_state(teacherPredsVal.numpy(), studentPredsVal.numpy())
                    metricsTuplesVal.append(("Val_distillation_" + handler.name, handler.result().numpy()))
                globalTuplesVal = lossesTuplesVal + metricsTuplesVal
                pb_val.add(1, values=globalTuplesVal)
        # Returning distilled student
        return self.student



In [3]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

In [4]:
input_shape = x_train.shape[1::]
y_train_cat = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test_cat = tf.keras.utils.to_categorical(y_test, num_classes=10)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train = x_train / 255.0
x_test = x_test / 255.0

nb_classes = 10

In [5]:
train_sparse = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
train_cat = tf.data.Dataset.from_tensor_slices((x_train, y_train_cat))
train_cat_batch = tf.data.Dataset.from_tensor_slices((x_train, y_train_cat)).batch(32)
test_sparse = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
test_cat = tf.data.Dataset.from_tensor_slices((x_test, y_test_cat))
test_cat_batch = tf.data.Dataset.from_tensor_slices((x_test, y_test_cat)).batch(32)


In [6]:
teacher = tf.keras.models.load_model('teacher_model_CKD_presentation.h5')

In [7]:
teacher.summary()

Model: "teacher"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 16, 16, 32)        896       
                                                                 
 dropout (Dropout)           (None, 16, 16, 32)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 16, 16, 32)        9248      
                                                                 
 max_pooling2d (MaxPooling2D  (None, 8, 8, 32)         0         
 )                                                               
                                                                 
 conv2d_2 (Conv2D)           (None, 8, 8, 64)          18496     
                                                                 
 dropout_1 (Dropout)         (None, 8, 8, 64)          0         
                                                           

In [8]:
teacher.evaluate(test_sparse)



[0.7514100074768066, 0.7638000249862671, 0.7638000249862671]

In [9]:
student = tf.keras.Sequential(
        [
            tf.keras.Input(shape=input_shape),
            tf.keras.layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.LeakyReLU(alpha=0.2),
            tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
            tf.keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(10)
        ],
        name="student",
    )


In [10]:
student.summary()

Model: "student"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 16, 16, 16)        448       
                                                                 
 dropout (Dropout)           (None, 16, 16, 16)        0         
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 16, 16, 16)        0         
                                                                 
 max_pooling2d (MaxPooling2D  (None, 16, 16, 16)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 8, 8, 32)          4640      
                                                                 
 dropout_1 (Dropout)         (None, 8, 8, 32)          0         
                                                           

In [11]:
dist = CKD(
    teacher,
    student,
    alreadySoftmax=False,
    optimizer=tf.keras.optimizers.Adam(),
    studentLoss=tf.keras.losses.CategoricalCrossentropy()
)


In [12]:
distilled = dist.distil(
    trainData=train_cat,
    valData=test_cat,
    epochs=10
)


Distillation Epoch 1/10
Distillation Epoch 2/10
Distillation Epoch 3/10
Distillation Epoch 4/10
Distillation Epoch 5/10
Distillation Epoch 6/10
Distillation Epoch 7/10
Distillation Epoch 8/10
Distillation Epoch 9/10
Distillation Epoch 10/10


In [13]:
distilled.evaluate(test_cat_batch)



[5.670470714569092, 0.605379045009613]

In [14]:
distilled.save('student.h5')