# Knowledge Distillation


## Setup

In [None]:
pip install tensorflow-model-optimization
pip install larq

Note: you may need to restart the kernel to use updated packages.




In [None]:
import os
import tensorflow as tf
import keras
from keras import layers
from tensorflow.python.framework import ops
import numpy as np
from tensorflow.keras import regularizers
import numpy as np
from tensorflow.keras.layers import Lambda, Dropout
import tensorflow_model_optimization as tfmot
import tempfile


import larq as lq
from larq.layers import QuantConv2D, QuantDense



## Distiller


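The `Distiller` wraps a trainable student and a teacher whose weights are never updated. It is configured with:

- A student loss function on the difference between the student predictions and the ground-truth labels
- A distillation loss function, along with a `temperature`, on the difference between the
soft student predictions and the soft teacher labels
- An `alpha` factor that weights the two terms: `alpha * student_loss + (1 - alpha) * distillation_loss`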



In [None]:
class Distiller(keras.Model):
    """Knowledge-distillation wrapper that trains ``student`` to mimic ``teacher``.

    Each training step minimises

        alpha * student_loss_fn(y, p_s)
        + (1 - alpha) * temperature**2 * distillation_loss_fn(soft(p_t), soft(p_s))

    where ``p_s`` / ``p_t`` are the student / teacher outputs and ``soft`` is a
    temperature-scaled softmax. Only the student's trainable variables are
    updated; the teacher is always run with ``training=False``.

    ``alpha`` and ``temperature`` are stored in non-trainable ``tf.Variable``s and
    exposed as properties, so ``distiller.alpha = ...`` between ``fit`` calls
    (e.g. an epoch-wise schedule) changes the loss used by the already-traced
    train step. Plain Python floats would be captured as constants when the
    step is first traced, and later assignments would be silently ignored.
    """

    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student
        self._alpha_var = tf.Variable(0.1, trainable=False, dtype=tf.float32, name="distill_alpha")
        self._temperature_var = tf.Variable(3.0, trainable=False, dtype=tf.float32, name="distill_temperature")
        # Whether student/teacher outputs are logits (True) or probabilities (False).
        self.from_logits = False

    @property
    def alpha(self):
        """Weight of the hard-label student loss; ``1 - alpha`` weights distillation."""
        return self._alpha_var

    @alpha.setter
    def alpha(self, value):
        # float() so numpy scalars (e.g. from np.exp) assign cleanly into float32.
        self._alpha_var.assign(float(value))

    @property
    def temperature(self):
        """Softmax temperature used to soften teacher and student outputs."""
        return self._temperature_var

    @temperature.setter
    def temperature(self, value):
        self._temperature_var.assign(float(value))

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
        from_logits=False,
    ):
        """Configure the distiller.

        Args:
            optimizer: Optimizer applied to the student's trainable variables.
            metrics: Metrics computed on (y, student predictions).
            student_loss_fn: Loss between ground-truth labels and student predictions.
            distillation_loss_fn: Loss between the softened teacher and student
                distributions, called as ``fn(teacher_soft, student_soft)``.
            alpha: Weight of ``student_loss_fn``; ``1 - alpha`` weights the
                distillation term.
            temperature: Softmax temperature for the soft targets.
            from_logits: Whether student/teacher outputs are logits. When False
                (the default; the models in this notebook end in ``softmax``)
                the outputs are treated as probabilities and converted to
                log-probabilities before temperature scaling. Applying softmax
                directly to probabilities would give nearly uniform soft targets.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature
        self.from_logits = from_logits

    def _soften(self, outputs):
        """Temperature-scaled softmax of model outputs (logits or probabilities)."""
        if self.from_logits:
            logits = outputs
        else:
            # log(p) recovers the logits up to an additive constant, which softmax ignores.
            logits = tf.math.log(outputs + 1e-7)
        return tf.nn.softmax(logits / self.temperature, axis=-1)

    def train_step(self, data):
        """One optimisation step on the student; returns metrics plus both loss terms."""
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self.student(x, training=True)
            teacher_pred = self.teacher(x, training=False)
            student_loss = self.student_loss_fn(y, y_pred)

            # The T^2 factor keeps soft-target gradients on a scale comparable to the
            # hard-label gradients as the temperature changes.
            distillation_loss = self.distillation_loss_fn(
                self._soften(teacher_pred),
                self._soften(y_pred),
            ) * (self.temperature**2)

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        self.compiled_metrics.update_state(y, y_pred)
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss, "distillation_loss": distillation_loss})
        return results

    def test_step(self, data):
        """Evaluate the student alone on hard labels (no compiled loss is configured)."""
        x, y = data
        y_pred = self.student(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)
        self.compiled_metrics.update_state(y, y_pred)
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

    def call(self, x, training=None):
        """Forward pass through the student only."""
        return self.student(x, training=training)

## Create student and teacher models



In [None]:
# Create the teacher: three convolutional stages of (Conv-BN) x 2 -> MaxPool -> Dropout
# with 32/64/128 filters, then a dense head ending in a 10-way softmax.
teacher = keras.Sequential(
    [
        # keras.Input fixes the input shape, so the first Conv2D needs no input_shape.
        keras.Input(shape=(32, 32, 1)),
        layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'),
        layers.BatchNormalization(),
        layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPool2D((2, 2)),
        layers.Dropout(0.2),
        layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'),
        layers.BatchNormalization(),
        layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPool2D((2, 2)),
        layers.Dropout(0.3),
        layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'),
        layers.BatchNormalization(),
        layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPool2D((2, 2)),
        layers.Dropout(0.4),
        layers.Flatten(),
        layers.Dense(512, activation='relu', kernel_initializer='he_uniform'),
        layers.Dense(256, activation='relu', kernel_initializer='he_uniform'),
        layers.Dense(128, activation='relu', kernel_initializer='he_uniform'),
        layers.Dropout(0.5),
        # Outputs class probabilities (not logits): losses on this output need from_logits=False.
        layers.Dense(10, activation='softmax'),
    ],
    name="teacher",
)

In [None]:
def build_student():
    """Build a fresh, untrained copy of the compact student CNN.

    Both the distilled student and the from-scratch baseline are built here, so
    the two models being compared always share exactly the same architecture.
    Three valid 3x3 ReLU convolutions (64/128/512 filters) feed global average
    pooling and a 10-way softmax.
    """
    return keras.Sequential([
        layers.Conv2D(64, (3, 3), activation='relu', input_shape=(32, 32, 1)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.Conv2D(512, (3, 3), activation='relu'),
        layers.GlobalAveragePooling2D(),
        layers.Dense(10, activation="softmax"),
    ])


student = build_student()
student_scratch = build_student()

## Prepare the dataset



In [None]:
from sklearn.model_selection import train_test_split

batch_size = 64
num_classes = 10  # CIFAR-10 has 10 classes


def to_grey_normalized(images):
    """Convert RGB images to single-channel float32 greyscale scaled to [0, 1].

    Expects ``images`` shaped (N, H, W, C>=3) with 0-255 pixel values, as
    returned by ``keras.datasets.cifar10.load_data``. Returns an array of shape
    (N, H, W, 1).
    """
    grey = np.dot(images[..., :3], [0.2989, 0.5870, 0.1140])
    return (grey.astype("float32") / 255.0)[..., np.newaxis]


(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Convert images to normalized greyscale
x_train = to_grey_normalized(x_train)
x_test = to_grey_normalized(x_test)

# Hold out 20% of the original training set for validation.
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2, random_state=42)

# Convert integer labels to one-hot vectors for CategoricalCrossentropy.
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_val = tf.keras.utils.to_categorical(y_val, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

## Train the teacher



In [None]:
# Train teacher as usual.
# The teacher's last layer already applies softmax, so the loss must consume
# probabilities (from_logits=False, the default). from_logits=True on a softmax
# output is inconsistent, and Keras warns about it.
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=[keras.metrics.CategoricalAccuracy()],
)

# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=20, validation_data=(x_val, y_val))
teacher.evaluate(x_test, y_test)

Epoch 1/20


  output, from_logits = _get_logits(


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


[0.6708458065986633, 0.8177000284194946]

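# Sort data from easy to hard

Difficulty is measured by the entropy of the teacher's predicted class distribution: low-entropy (confident) predictions count as easy.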

In [None]:
def curriculum_sort(x_train, y_train, model=None, eps=1e-10):
    """Return indices that order ``x_train`` from easiest to hardest.

    Difficulty is the entropy of ``model``'s predicted class distribution, so
    confident (low-entropy) examples come first.

    Args:
        x_train: Inputs accepted by ``model.predict``.
        y_train: Unused; kept so existing calls keep working.
        model: Object whose ``predict`` returns class probabilities, one row per
            example. Defaults to the notebook's global ``teacher``.
        eps: Small constant added inside ``log`` so zero probabilities stay finite.

    Returns:
        1-D integer array of indices in ascending-entropy order. Ties keep their
        original order (stable sort).
    """
    if model is None:
        model = teacher
    probs = model.predict(x_train)
    entropies = -np.sum(probs * np.log(probs + eps), axis=1)
    return np.argsort(entropies, kind="stable")


In [None]:
# Reorder the training set so low-entropy ("easy") examples come first.
# NOTE(review): Keras `fit` shuffles NumPy inputs by default (and `datagen.flow`
# shuffles too), so this ordering only matters where shuffling is disabled.
sorted_indices = curriculum_sort(x_train, y_train)
x_train, y_train = x_train[sorted_indices], y_train[sorted_indices]



## Distill teacher to student



In [None]:
# Pre-train the student on hard labels before distillation.
# NOTE(review): the distilled student therefore gets 10 extra supervised epochs
# compared with `student_scratch`; keep this in mind when comparing results.
student.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.CategoricalCrossentropy(),
    metrics=[keras.metrics.CategoricalAccuracy()],
)

student.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val))

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x1ee95f20700>

In [None]:
# Test-set loss/accuracy of the student after hard-label pre-training.
student.evaluate(x=x_test, y=y_test)



[2.302635431289673, 0.10000000149011612]

In [None]:
# Distillation schedule hyperparameters. final_alpha, decay_rate and loss_scale are
# only relevant to the (currently unused) adaptive schedules defined below.
initial_temp = 6
final_temp = 1
initial_alpha = 0.9
final_alpha = 0.1
decay_rate = 1
loss_scale = 2
num_epochs = 20

from keras.preprocessing.image import ImageDataGenerator


# Initialize and compile the distiller, starting from the initial alpha/temperature.
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.CategoricalAccuracy()],
    student_loss_fn=keras.losses.CategoricalCrossentropy(),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=initial_alpha,
    temperature=initial_temp,
)


def alpha_schedule(epoch):
    """Exponentially decaying hard-label weight: ``initial_alpha * exp(-epoch / num_epochs)``.

    Starts at ``initial_alpha`` and shrinks by a factor of e over ``num_epochs``
    epochs. It does not target ``final_alpha``.
    """
    return initial_alpha * np.exp(-epoch / num_epochs)


def temp_schedule(epoch):
    """Logarithmically *decreasing* temperature.

    Equals ``initial_temp`` at epoch 0 and reaches ``final_temp`` at
    ``epoch == num_epochs``.
    """
    return final_temp + (initial_temp - final_temp) * (1 - np.log1p(epoch) / np.log1p(num_epochs))


def adaptive_temperature(student_loss, initial_temp, final_temp, decay_rate):
    """Loss-driven temperature (currently unused).

    Linear in ``decay_rate * student_loss``: ``initial_temp`` at zero loss,
    ``final_temp`` when ``decay_rate * student_loss == 1``. The result is not clipped.
    """
    return initial_temp - (initial_temp - final_temp) * decay_rate * student_loss


def adaptive_alpha(student_loss, initial_alpha, final_alpha, decay_rate):
    """Loss-driven alpha (currently unused); same linear form as ``adaptive_temperature``."""
    return initial_alpha - (initial_alpha - final_alpha) * decay_rate * student_loss


def mean_batch_loss(model, loss_fn, x, y, batch_size):
    """Average of per-batch ``loss_fn`` values of ``model`` over (x, y), in inference mode."""
    batch_losses = []
    for start in range(0, len(x), batch_size):
        preds = model(x[start:start + batch_size], training=False)
        batch_losses.append(float(loss_fn(y[start:start + batch_size], preds)))
    # Divide by the actual number of batches, including a final partial batch.
    return sum(batch_losses) / max(len(batch_losses), 1)


# Data augmentation
datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
)
# fit() only gathers statistics for featurewise options (none are enabled) and
# depends only on x_train, so one call before the loop is enough.
datagen.fit(x_train)


# Training loop: update alpha/temperature, then train the distiller for one epoch.
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    alpha = alpha_schedule(epoch)
    temperature = temp_schedule(epoch)
    distiller.temperature = temperature
    distiller.alpha = alpha

    # Hard-label loss of the student on the un-augmented training set, measured
    # before this epoch's update. It is only monitored, so no GradientTape is needed.
    epoch_student_loss = mean_batch_loss(
        distiller.student, distiller.student_loss_fn, x_train, y_train, batch_size
    )

    # Train the distiller for one epoch on augmented batches.
    history = distiller.fit(
        datagen.flow(x_train, y_train, batch_size=batch_size),
        steps_per_epoch=len(x_train) // batch_size,
        validation_data=(x_val, y_val),
    )

    print(f"Student Loss: {epoch_student_loss:.4f}, Temperature: {temperature:.2f}, Alpha: {alpha:.2f}")

# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)

Epoch 1/20
Student Loss: 2.3026, Temperature: 6.00, Alpha: 0.90
Epoch 2/20
Student Loss: 2.3026, Temperature: 4.86, Alpha: 0.86
Epoch 3/20
Student Loss: 2.3026, Temperature: 4.20, Alpha: 0.81
Epoch 4/20
Student Loss: 2.3026, Temperature: 3.72, Alpha: 0.77
Epoch 5/20
Student Loss: 2.3026, Temperature: 3.36, Alpha: 0.74
Epoch 6/20
Student Loss: 2.3026, Temperature: 3.06, Alpha: 0.70
Epoch 7/20
Student Loss: 2.3026, Temperature: 2.80, Alpha: 0.67
Epoch 8/20
Student Loss: 2.3026, Temperature: 2.58, Alpha: 0.63
Epoch 9/20
Student Loss: 2.3026, Temperature: 2.39, Alpha: 0.60
Epoch 10/20
Student Loss: 2.3026, Temperature: 2.22, Alpha: 0.57
Epoch 11/20
Student Loss: 2.3026, Temperature: 2.06, Alpha: 0.55
Epoch 12/20
Student Loss: 2.3026, Temperature: 1.92, Alpha: 0.52
Epoch 13/20
Student Loss: 2.3026, Temperature: 1.79, Alpha: 0.49
Epoch 14/20
Student Loss: 2.3026, Temperature: 1.67, Alpha: 0.47
Epoch 15/20
Student Loss: 2.3026, Temperature: 1.55, Alpha: 0.45
Epoch 16/20
Student Loss: 2.3026, 

[0.0, 0.10000000149011612]

# Pruning

In [None]:
logdir = tempfile.mkdtemp()

# Pruning schedule. PolynomialDecay counts optimizer steps (one per batch, advanced
# by UpdatePruningStep), not samples. The end step is therefore derived from the
# number of batches actually trained. A sample-based end step would far exceed the
# run length, and sparsity would never ramp toward its target.
pruning_epochs = 1
pruning_end_step = int(np.ceil(len(x_train) / batch_size)) * pruning_epochs

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=pruning_end_step)
}

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_model = prune_low_magnitude(distiller.student, **pruning_params)

pruning_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.CategoricalCrossentropy(),
    metrics=[keras.metrics.CategoricalAccuracy()],
)

pruning_callbacks = [
    # Required during training: advances the pruning step counter every batch.
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

with tf.device('/cpu:0'):
    # batch_size must match the one used to compute pruning_end_step.
    pruning_model.fit(
        x_train, y_train,
        batch_size=batch_size,
        epochs=pruning_epochs,
        validation_data=(x_val, y_val),
        callbacks=pruning_callbacks,
    )



In [None]:
# Test-set loss/accuracy of the pruned (still wrapped) student.
pruning_model.evaluate(x=x_test, y=y_test)



[2.302705764770508, 0.10000000149011612]

In [None]:
# Remove the pruning wrappers (only needed during training) so the model holds
# the original layer types with their pruned weights.
stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruning_model)

In [None]:
# Inspect the architecture of the stripped, pruned student.
stripped_pruned_model.summary()

Model: "sequential_10"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 depthwise_conv2d_26 (Depthw  (None, 30, 30, 1)        10        
 iseConv2D)                                                      
                                                                 
 conv2d_9 (Conv2D)           (None, 30, 30, 64)        128       
                                                                 
 depthwise_conv2d_27 (Depthw  (None, 28, 28, 64)       640       
 iseConv2D)                                                      
                                                                 
 conv2d_10 (Conv2D)          (None, 28, 28, 128)       8320      
                                                                 
 depthwise_conv2d_28 (Depthw  (None, 26, 26, 128)      1280      
 iseConv2D)                                                      
                                                     

# Quantization

In [None]:
# Quantization-aware training (QAT) of the pruned student.
# Plain `quantize_model` fine-tuning updates every weight and does not keep the
# zeros introduced by pruning, so the sparsity from the previous step would be lost.
# The prune-preserving scheme (PQAT) keeps the sparsity pattern while fine-tuning.
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(stripped_pruned_model)
student_quant = tfmot.quantization.keras.quantize_apply(
    quant_aware_annotate_model,
    tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme(),
)

# Compile the quantized student model
student_quant.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.CategoricalCrossentropy(),
    metrics=[keras.metrics.CategoricalAccuracy()],
)
# Fine-tune the quantized student model
student_quant.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val))

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x1ee8627ee80>

In [None]:
# Test-set loss/accuracy of the quantization-aware student.
student_quant.evaluate(x=x_test, y=y_test)



[2.302885055541992, 0.10000000149011612]

In [None]:
# Inspect the quantization-aware student (each layer appears quantize-wrapped).
student_quant.summary()

Model: "sequential_10"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 quantize_layer (QuantizeLay  (None, 32, 32, 1)        3         
 er)                                                             
                                                                 
 quant_depthwise_conv2d_26 (  (None, 30, 30, 1)        15        
 QuantizeWrapperV2)                                              
                                                                 
 quant_conv2d_9 (QuantizeWra  (None, 30, 30, 64)       259       
 pperV2)                                                         
                                                                 
 quant_depthwise_conv2d_27 (  (None, 28, 28, 64)       645       
 QuantizeWrapperV2)                                              
                                                                 
 quant_conv2d_10 (QuantizeWr  (None, 28, 28, 128)    

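# Penultimate features (student without its final softmax layer)

Dropping the last layer removes the whole `Dense(10, softmax)` classifier. `model_final` therefore outputs the pooled feature vector that feeds the classifier, not the 10 pre-softmax logits.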

In [None]:
# Input shape of the original student, e.g. (None, 32, 32, 1).
input_shape = distiller.student.layers[0].input_shape

# Reuse every layer of the quantized student except the last one.
model_final = keras.models.Sequential(student_quant.layers[:-1])
model_final.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model_final.build(input_shape)

In [None]:
# Inspect the truncated model built from student_quant minus its last layer.
model_final.summary()

Model: "sequential_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 quantize_layer (QuantizeLay  (None, 32, 32, 1)        3         
 er)                                                             
                                                                 
 quant_depthwise_conv2d_26 (  (None, 30, 30, 1)        15        
 QuantizeWrapperV2)                                              
                                                                 
 quant_conv2d_9 (QuantizeWra  (None, 30, 30, 64)       259       
 pperV2)                                                         
                                                                 
 quant_depthwise_conv2d_27 (  (None, 28, 28, 64)       645       
 QuantizeWrapperV2)                                              
                                                                 
 quant_conv2d_10 (QuantizeWr  (None, 28, 28, 128)    

In [None]:
# Run the truncated model on the first test image, as a batch of one.
x_example = np.expand_dims(x_test[0], axis=0)
intermediate_output = model_final.predict(x_example)
print(intermediate_output)

[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 

In [None]:
# Convert the truncated quantization-aware model to TFLite with default optimizations.
converter = tf.lite.TFLiteConverter.from_keras_model(model_final)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()

# Write the quantized model to a temporary file to measure its size.
# mkstemp returns an already-open OS file descriptor; wrap it so it gets closed.
quant_fd, quant_file = tempfile.mkstemp('.tflite')
with os.fdopen(quant_fd, 'wb') as f:
    f.write(quantized_tflite_model)
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))



INFO:tensorflow:Assets written to: C:\Users\CMP3WO~1\AppData\Local\Temp\tmpukcvtp17\assets


INFO:tensorflow:Assets written to: C:\Users\CMP3WO~1\AppData\Local\Temp\tmpukcvtp17\assets


Quantized model in Mb: 0.1021728515625


In [None]:
# Save the TFLite flatbuffer to the working directory. The displayed value is the
# number of bytes written; the with-block guarantees the file is closed.
with open("CNN_512_grey_depth.tflite", "wb") as tflite_file:
    tflite_bytes_written = tflite_file.write(quantized_tflite_model)
tflite_bytes_written

107136

## Train student from scratch for comparison



In [None]:
# Baseline: train the same student architecture from scratch on hard labels
# only, with no teacher involved.
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.CategoricalCrossentropy(),
    metrics=[keras.metrics.CategoricalAccuracy()],
)

# Fit on the training split, then report test-set loss/accuracy.
student_scratch.fit(x=x_train, y=y_train, epochs=20, validation_data=(x_val, y_val))
student_scratch.evaluate(x=x_test, y=y_test)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


[0.7911658883094788, 0.7263000011444092]