# Project B: Knowledge Distillation for Building Lightweight Deep Learning Models in Visual Classification Tasks

In [None]:
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from typing import Union

tf.enable_v2_behavior()

builder = tfds.builder('mnist')
BATCH_SIZE = 256
NUM_EPOCHS = 12
NUM_CLASSES = 10  # 10 total classes.

In [None]:
from tensorflow import keras
from tensorflow.keras import layers
#from keras.layers import Dense, Dropout, Activation, Flatten
from tensorflow.keras.optimizers import Adam

# Data loading

In [None]:
# Load train and test splits.
def preprocess(x):
  image = tf.image.convert_image_dtype(x['image'], tf.float32)
  class_labels = tf.one_hot(x['label'], builder.info.features['label'].num_classes)
  return image, class_labels


mnist_train = tfds.load('mnist', split='train', shuffle_files=False).cache()
mnist_train = mnist_train.map(preprocess)
mnist_train = mnist_train.shuffle(builder.info.splits['train'].num_examples)
mnist_train = mnist_train.batch(BATCH_SIZE, drop_remainder=True)

mnist_test = tfds.load('mnist', split='test').cache()
mnist_test = mnist_test.map(preprocess).batch(BATCH_SIZE)

[1mDownloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.



Dl Completed...:   0%|          | 0/4 [00:00<?, ? file/s]


[1mDataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m


In [None]:
mnist_train

<BatchDataset element_spec=(TensorSpec(shape=(256, 28, 28, 1), dtype=tf.float32, name=None), TensorSpec(shape=(256, 10), dtype=tf.float32, name=None))>

In [None]:
image.shape

NameError: ignored

# Model creation

In [None]:


# Build CNN teacher.
#cnn_model = tf.keras.Sequential()

# your code start from here for stpe 2
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, (3, 3), strides=(1, 1), padding="same"),
        #layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        layers.ReLU(),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        #layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        layers.Conv2D(64, (3, 3), strides=(1, 1), padding="same"),
        layers.ReLU(),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(128, activation="relu"),
        layers.Dropout(0.5),
        #layers.Dense(10),
        #layers.Dense(10, activation="softmax")
        layers.Dense(10)
    ],
    name="teacher",
)


# Build fully connected student.
#fc_model = tf.keras.Sequential()


# your code start from here for step 2

student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Flatten(),
        layers.Dense(784, activation="relu"),
        layers.Dense(784, activation="relu"),
        #layers.Dense(10),
        #layers.Dense(10, activation="softmax")
        layers.Dense(10)
        # model.add(Activation('relu'))
    ],
    name="student",
)




In [None]:
# teacher.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])

In [None]:
# student.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])

In [None]:
# teacher.fit(mnist_train, epochs=5)


In [None]:
student.summary()

Model: "student"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten_7 (Flatten)         (None, 784)               0         
                                                                 
 dense_16 (Dense)            (None, 784)               615440    
                                                                 
 dense_17 (Dense)            (None, 784)               615440    
                                                                 
 dense_18 (Dense)            (None, 10)                7850      
                                                                 
Total params: 1,238,730
Trainable params: 1,238,730
Non-trainable params: 0
_________________________________________________________________


In [None]:
teacher.summary()

Model: "teacher"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_8 (Conv2D)           (None, 28, 28, 32)        320       
                                                                 
 re_lu_8 (ReLU)              (None, 28, 28, 32)        0         
                                                                 
 max_pooling2d_8 (MaxPooling  (None, 28, 28, 32)       0         
 2D)                                                             
                                                                 
 conv2d_9 (Conv2D)           (None, 28, 28, 64)        18496     
                                                                 
 re_lu_9 (ReLU)              (None, 28, 28, 64)        0         
                                                                 
 max_pooling2d_9 (MaxPooling  (None, 14, 14, 64)       0         
 2D)                                                       

In [None]:
mnist_train

<BatchDataset element_spec=(TensorSpec(shape=(256, 28, 28, 1), dtype=tf.float32, name=None), TensorSpec(shape=(256, 10), dtype=tf.float32, name=None))>

In [None]:
# Train the teacher model
# teacher.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True), optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5), metrics=["accuracy"])


In [None]:
# teacher.fit(mnist_train, epochs=5)

### **NEW TEACHER MODEL**

In [None]:
teacher2 = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, (3, 3), strides=(1, 1), padding="same"),
        #layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        layers.ReLU(),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        #layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        layers.Conv2D(64, (3, 3), strides=(1, 1), padding="same"),
        layers.ReLU(),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(128, activation="relu"),
        layers.Dropout(0.5),
        #layers.Dense(10),
        #layers.Dense(10, activation="softmax")
        layers.Dense(10)
    ],
    name="teacher2",
)

### Unpack data

In [None]:
mnist_example, =mnist_train.take(1)
image, label = mnist_example


In [None]:
image.shape

TensorShape([256, 28, 28, 1])

In [171]:
mnist_train

<BatchDataset element_spec=(TensorSpec(shape=(256, 28, 28, 1), dtype=tf.float32, name=None), TensorSpec(shape=(256, 10), dtype=tf.float32, name=None))>

# Teacher loss function

In [173]:
@tf.function
def compute_teacher_loss(images, labels):
  """Compute class knowledge distillation teacher loss for given images
     and labels.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.

  Returns:
    Scalar loss Tensor.
  """
  class_logits = teacher(images, training=True)

  # Compute cross-entropy loss for classes.

  # your code start from here for step 3
  cross_entropy_loss_value =tf.reduce_mean(tf.keras.losses.categorical_crossentropy(labels, tf.nn.softmax(class_logits)))


  return cross_entropy_loss_value

In [None]:
@tf.function
def compute_teacher2_loss(images, labels, model):
  """Compute class knowledge distillation teacher loss for given images
     and labels.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.

  Returns:
    Scalar loss Tensor.
  """
  # class_logits = teacher2(images, training=True)
  class_logits2 = model(images, training=True)
  # Compute cross-entropy loss for classes.

  # your code start from here for step 3
  cross_entropy_loss_value2 =tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels, class_logits2))


  return cross_entropy_loss_value2

### Experiments

In [None]:
teacher_wo_hd_loss=compute_teacher_loss(image,label)

In [None]:
teacher_wo_hd_loss

<tf.Tensor: shape=(), dtype=float32, numpy=2.2940745>

In [None]:
logits = teacher(image, training=True)
cross_entropy_loss_value1=tf.reduce_mean(tf.keras.losses.categorical_crossentropy(label, logits, from_logits=True))

In [None]:
cross_entropy_loss_value1

<tf.Tensor: shape=(), dtype=float32, numpy=0.32171333>

In [None]:
logits = teacher(image, training=True)
cross_entropy_loss_value2 =tf.reduce_mean(tf.keras.losses.categorical_crossentropy(label, tf.nn.softmax(logits)))

In [None]:
cross_entropy_loss_value2

<tf.Tensor: shape=(), dtype=float32, numpy=0.38943234>

In [None]:
logits = teacher(image, training=True)
teacher_wo_conven_loss= tf.keras.losses.CategoricalCrossentropy(from_logits=True)


In [None]:
logits

<tf.Tensor: shape=(256, 10), dtype=float32, numpy=
array([[-0.1397345 , -0.04447453,  0.12862381, ..., -0.08678841,
         0.21981353,  0.09755316],
       [-0.02464676, -0.1277153 ,  0.22527118, ...,  0.2484115 ,
         0.01940316,  0.11200459],
       [-0.09397954,  0.01585651,  0.2389846 , ..., -0.16748519,
        -0.07179485,  0.13059744],
       ...,
       [-0.01045809, -0.11735284,  0.03315504, ..., -0.07616577,
         0.06614104, -0.0382463 ],
       [-0.01081269,  0.04974911, -0.03781566, ..., -0.05184472,
        -0.09699267, -0.05853181],
       [-0.18763806, -0.04817468,  0.12804027, ..., -0.00710503,
         0.05203816,  0.1680861 ]], dtype=float32)>

In [None]:
label

<tf.Tensor: shape=(256, 10), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 1., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 1.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.]], dtype=float32)>

In [None]:
label.get_shape()[0]

256

In [None]:
compute_teacher_loss(image,label)

<tf.Tensor: shape=(), dtype=float32, numpy=2.2930107>

In [None]:
compute_teacher2_loss(image,label,teacher2)

<tf.Tensor: shape=(), dtype=float32, numpy=2.3127394>

In [None]:
# Define loass function and optimizer
# loss_func = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
# optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5) # Low since we are fine-tuning


In [None]:
# teacher.compile(loss=compute_teacher_loss(), optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5), metrics=["accuracy"])

### experiment ends

# Student loss function

In [None]:


# Hyperparameters for distillation (need to be tuned).
ALPHA = 0.5 # task balance between cross-entropy and distillation loss
DISTILLATION_TEMPERATURE = 4. #temperature hyperparameter

def distillation_loss(teacher_logits: tf.Tensor, student_logits: tf.Tensor,
                      temperature: Union[float, tf.Tensor]):
  """Compute distillation loss.

  This function computes cross entropy between softened logits and softened
  targets. The resulting loss is scaled by the squared temperature so that
  the gradient magnitude remains approximately constant as the temperature is
  changed. For reference, see Hinton et al., 2014, "Distilling the knowledge in
  a neural network."

  Args:
    teacher_logits: A Tensor of logits provided by the teacher.
    student_logits: A Tensor of logits provided by the student, of the same
      shape as `teacher_logits`.
    temperature: Temperature to use for distillation.

  Returns:
    A scalar Tensor containing the distillation loss.
  """
 # your code start from here for step 3
  soft_targets = teacher_logits / temperature

  return tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          soft_targets, student_logits / temperature)) * temperature ** 2

def compute_student_loss(images, labels):
  """Compute class knowledge distillation student loss for given images
     and labels.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.

  Returns:
    Scalar loss Tensor.
  """
  student_class_logits = student(images, training=True)

  # Compute class distillation loss between student class logits and
  # softened teacher class targets probabilities.

  # your code start from here for step 3

  teacher_class_logits = teacher(images, training=False)
  distillation_loss_value =distillation_loss(teacher_class_logits,student_class_logits,DISTILLATION_TEMPERATURE)

  # Compute cross-entropy loss with hard targets.

  # your code start from here for step 3

  # cross_entropy_loss_value = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(labels, student_class_logits, from_logits=True))
  cross_entropy_loss_value = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels, student_class_logits, from_logits=True))

  total_loss =ALPHA*cross_entropy_loss_value + (1-ALPHA)*distillation_loss_value

  return total_loss

# Train and evaluation

In [None]:
@tf.function
def compute_num_correct(model, images, labels):
  """Compute number of correctly classified images in a batch.

  Args:
    model: Instance of tf.keras.Model.
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.

  Returns:
    Number of correctly classified images.
  """
  class_logits = model(images, training=False)
  return tf.reduce_sum(
      tf.cast(tf.math.equal(tf.argmax(class_logits, -1), tf.argmax(labels, -1)),
              tf.float32)), tf.argmax(class_logits, -1), tf.argmax(labels, -1)

In [None]:
def train_and_evaluate(model, compute_loss_fn):
  """Perform training and evaluation for a given model.

  Args:
    model: Instance of tf.keras.Model.
    compute_loss_fn: A function that computes the training loss given the
      images, and labels.
  """

  # your code start from here for step 4
  optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)

  for epoch in range(1, NUM_EPOCHS + 1):
    # Run training.
    print('Epoch {}: '.format(epoch), end='')
    for images, labels in mnist_train:
      with tf.GradientTape() as tape:
         # your code start from here for step 4
        logits = model(images, training=True)
        loss_value = tf.reduce_mean(compute_loss_fn(labels,logits,from_logits=True))
        # loss_value = compute_loss_fn
      grads = tape.gradient(loss_value, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      train_acc_metric.update_state(labels, logits)
    # Run evaluation.
    num_correct = 0
    num_total = builder.info.splits['test'].num_examples
    # for images, labels in mnist_test:
      # your code start from here for step 4
    print("Training loss : %.4f" % (float(loss_value)))
    #   num_correct += compute_num_correct(model,images,labels)
    # print("Class_accuracy: " + '{:.2f}%'.format(
    #     num_correct / num_total * 100))
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

In [None]:
train_acc_metric = keras.metrics.CategoricalAccuracy()
logits = teacher(image, training=True)

train_and_evaluate(teacher,tf.keras.losses.categorical_crossentropy)

Epoch 1: Training loss : 1.4006
Training acc over epoch: 0.4195
Epoch 2: Training loss : 0.9280
Training acc over epoch: 0.6909
Epoch 3: Training loss : 0.7246
Training acc over epoch: 0.7699
Epoch 4: Training loss : 0.6009
Training acc over epoch: 0.8103
Epoch 5: Training loss : 0.5128
Training acc over epoch: 0.8349
Epoch 6: Training loss : 0.4921
Training acc over epoch: 0.8504
Epoch 7: Training loss : 0.4141
Training acc over epoch: 0.8636
Epoch 8: Training loss : 0.3289
Training acc over epoch: 0.8747
Epoch 9: Training loss : 0.4116
Training acc over epoch: 0.8828
Epoch 10: Training loss : 0.3446
Training acc over epoch: 0.8892
Epoch 11: Training loss : 0.2574
Training acc over epoch: 0.8963
Epoch 12: Training loss : 0.3591
Training acc over epoch: 0.9036


### Experiments

In [None]:
t=compute_num_correct(teacher,image,label)[0]

In [None]:
t

<tf.Tensor: shape=(), dtype=float32, numpy=228.0>

### Experiment ends

In [None]:
def train_and_evaluate2(model, compute_loss_fn):
  """Perform training and evaluation for a given model.

  Args:
    model: Instance of tf.keras.Model.
    compute_loss_fn: A function that computes the training loss given the
      images, and labels.
  """

  # your code start from here for step 4
  optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)

  for epoch in range(1, NUM_EPOCHS + 1):
    # Run training.
    print('Epoch {}: '.format(epoch), end='')
    for images, labels in mnist_train:
      with tf.GradientTape() as tape:
         # your code start from here for step 4
        
          loss_value = compute_loss_fn(images,labels,model)

      grads = tape.gradient(loss_value, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      
    # Run evaluation.
    num_correct = 0
    num_total = builder.info.splits['test'].num_examples
    for images, labels in mnist_test:
    #   # your code start from here for step 4

      num_correct += compute_num_correct(model,images,labels)[0]
      
    print("Class_accuracy: " + '{:.2f}%'.format(
        num_correct / num_total * 100))




### Experiments

In [None]:
num_total = builder.info.splits['test'].num_examples

In [None]:
num_total

10000

In [None]:
for images, labels in mnist_train:
  loss_value = compute_teacher2_loss(images,labels,teacher2)


In [None]:
loss_value

<tf.Tensor: shape=(), dtype=float32, numpy=2.3341565>

In [None]:
for images, labels in mnist_train:
   logits = teacher2(images, training=True)
   loss_value2 = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(labels,logits,from_logits=True))

In [None]:
loss_value2

<tf.Tensor: shape=(), dtype=float32, numpy=2.3091621>

### Training teacher2 using **compute_teacher2_loss**

In [None]:
train_and_evaluate2(teacher2,compute_teacher2_loss)

Epoch 1: Class_accuracy: 75.72%
Epoch 2: Class_accuracy: 84.07%
Epoch 3: Class_accuracy: 87.31%
Epoch 4: Class_accuracy: 89.30%
Epoch 5: Class_accuracy: 90.34%
Epoch 6: Class_accuracy: 91.11%
Epoch 7: Class_accuracy: 91.70%
Epoch 8: Class_accuracy: 92.20%
Epoch 9: Class_accuracy: 92.45%
Epoch 10: Class_accuracy: 92.84%
Epoch 11: Class_accuracy: 93.22%
Epoch 12: Class_accuracy: 93.53%


# Training models

In [None]:
# your code start from here for step 5 



# Test accuracy vs. tempreture curve

In [None]:
# your code start from here for step 6


# Train student from scratch

In [None]:
# Build fully connected student.
# fc_model_no_distillation = tf.keras.Sequential()
student_no_distil = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Flatten(),
        layers.Dense(784, activation="relu"),
        layers.Dense(784, activation="relu"),
        #layers.Dense(10),
        #layers.Dense(10, activation="softmax")
        layers.Dense(10)
        # model.add(Activation('relu'))
    ],
    name="student_no_distil",
)
# your code start from here for step 7






def compute_plain_cross_entropy_loss(images, labels, student_model):
  """Compute plain loss for given images and labels.

  For fair comparison and convenience, this function also performs a
  LogSumExp over classes, but does not perform class distillation.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.

  Returns:
    Scalar loss Tensor.
  """
  # your code start from here for step 7

  student_class_logits = student_model(images, training=True)
  cross_entropy_loss =tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels, student_class_logits))
  
  return cross_entropy_loss


train_and_evaluate(student_no_distil, compute_plain_cross_entropy_loss)

NameError: ignored

# Comparing the teacher and student model (number of of parameters and FLOPs) 

In [None]:
# your code start from here for step 8


# XAI method to explain models

In [None]:
# your code start from here for step 9


# Implementing the state-of-the-art KD algorithm

In [None]:
# your code start from here for step 13
