In [1]:
import tensorflow as tf
print(tf.__version__)

2.9.0


In [2]:
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

In [3]:
class Distiller(keras.Model):
    def __init__(self,student,teacher):
        super().__init__()
        self.teacher=teacher
        self.student=student
    
    def compile(self,optimizer,metrics,student_loss_fn,distillation_loss_fn,alpha=0.1,temperature=3):
        super().compile(optimizer=optimizer,metrics=metrics)
        self.student_loss_fn=student_loss_fn
        self.distillation_loss_fn=distillation_loss_fn
        self.alpha=alpha
        self.temperature=temperature
    
    def train_step(self,data):
        x,y=data # unpacking the data

        # calculate the loss
        teacher_predictions=self.teacher(x,training=False)
        with tf.GradientTape() as tape:
            student_predictions=self.student(x,training=True)
            student_loss=self.student_loss_fn(y,student_predictions)
            distillation_loss=self.distillation_loss_fn(tf.nn.softmax(teacher_predictions/self.temperature,axis=1),tf.nn.softmax(student_predictions/self.temperature,axis=1)*self.temperature**2)
            loss=self.alpha*student_loss+(1-self.alpha)*distillation_loss
        
        # compute gradients
        trainable_vars=self.student.trainable_variables
        gradients=tape.gradient(loss,trainable_vars)

        # update weights
        self.optimizer.apply_gradients(zip(gradients,trainable_vars))

        # update metrics 
        self.compiled_metrics.update_state(y,student_predictions)

        # return performance
        results={m.name: m.result() for m in self.metrics}
        results.update({'student_loss':student_loss,'distillation_loss':distillation_loss})
        
        return results
    
    def test_step(self,data):
        # unpack the data
        x,y=data

        # compute predictions
        y_prediction=self.student(x,training=False)

        # calculate the loss
        student_loss=self.student_loss_fn(y,y_prediction)

        # update the metrics
        self.compiled_metrics.update_state(y,y_prediction)

        # return performance
        results={m.name:m.result() for m in self.metrics}
        results.update({'student_loss':student_loss})

        return results

In [4]:
# Create Teacher Model

teacher=keras.Sequential(
    [
        keras.Input(shape=(28,28,1)),
        layers.Conv2D(256,(3,3),strides=(2,2),padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2,2),strides=(1,1),padding='same'),
        layers.Conv2D(512,(3,3),strides=(2,2),padding='same'),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name='teacher'
)

In [5]:
# Create Student Model

student=keras.Sequential(
    [
        keras.Input(shape=(28,28,1)),
        layers.Conv2D(16,(3,3),strides=(2,2),padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2,2),strides=(1,1),padding='same'),
        layers.Conv2D(32,(3,3),strides=(2,2),padding='same'),
        layers.Flatten(),
        layers.Dense(10)
    ],
    name='student'
)

In [6]:
# clone the student model for comparision
student_scratch=keras.models.clone_model(student)

In [7]:
# prepare the dataset - MNIST - Digit Dataset

batch_size=64

(x_train,y_train),(x_test,y_test)=keras.datasets.mnist.load_data()

# Normalization - Feature Scaling
x_train=x_train.astype('float32')/255.0

# Reshape the data after normalization 1D to 2D
x_train=np.reshape(x_train,(-1,28,28,1))

x_test=x_test.astype('float32')/255.0
x_test=np.reshape(x_test,(-1,28,28,1))

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [8]:
# Train the teacher model as usual 

# Compile the model
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()]
)

# Train the Model
teacher.fit(x_train,y_train,epochs=5)

# Evaluate the Model
teacher.evaluate(x_test,y_test)

Epoch 1/5


2023-02-01 11:57:36.110620: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


[0.11484106630086899, 0.9724000096321106]

In [9]:
# initialize the distiller 
distiller=Distiller(student=student,teacher=teacher)

# compile the distiller
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10
)

# Train the Student with Distiller
distiller.fit(x_train,y_train,epochs=3)

# Evaluate student with distiller on test dataset
distiller.evaluate(x_test,y_test)

Epoch 1/3
Epoch 2/3
Epoch 3/3


[0.9822999835014343, 0.0023729107342660427]

In [10]:
# We have trained Teacher Model - Training Acc (97.24%), Testing Acc (97.24%)
# We have trained Student Model with Distillation - Training Acc (98.23%), Testing Acc (98.22%)

# We have to train Student Model without Distillation

student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()]
)

# Train the Model
student_scratch.fit(x_train,y_train,epochs=3)

# Evaluate the Model
student_scratch.evaluate(x_test,y_test)

Epoch 1/3
Epoch 2/3
Epoch 3/3


[0.0572688914835453, 0.9815999865531921]

# Accuracies of Models

1. Student with Distillation - High Accuracy - 98.23%
2. Student without Distillation - 98.16%
3. Teacher Model - 97.24%

Distllation models can boost the performance of Teacher Model and Student Model without Distillation.