In [None]:
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/Colab\ Notebooks/

Mounted at /content/drive
/content/drive/MyDrive/Colab Notebooks


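# Knowledge Distillation on the MHIST Dataset

Distil a fine-tuned ResNet50V2 teacher into a MobileNetV2 student for binary HP vs. SSA classification of MHIST histology images, optionally via a teacher-assistant (TA) model.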

In [None]:
import os
import sys
import numpy as np
import random
import matplotlib.pyplot as plt
import pandas as pd
from keras.preprocessing.image import ImageDataGenerator
from PIL import Image
from numpy import expand_dims
from sklearn import preprocessing
from tensorflow.keras.models import Model
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics import roc_auc_score
import tensorflow.compat.v2 as tf
from typing import Union
from keras_flops import get_flops

# Load Dataset

In [None]:
# Prepare data
path = "/content/drive/MyDrive/ECE1512/images"
CSVfile = "/content/drive/MyDrive/ECE1512/annotations.csv"

# One row per image, indexed by file name so each file's label/partition can be
# looked up directly with `.loc[file_name]`.
annotations = pd.read_csv(CSVfile).set_index('Image Name')
annotations.head(10)

Unnamed: 0_level_0,Majority Vote Label,Number of Annotators who Selected SSA (Out of 7),Partition
Image Name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
MHIST_aaa.png,SSA,6,train
MHIST_aab.png,HP,0,train
MHIST_aac.png,SSA,5,train
MHIST_aae.png,HP,1,train
MHIST_aaf.png,SSA,5,train
MHIST_aag.png,HP,2,test
MHIST_aah.png,HP,2,test
MHIST_aai.png,HP,3,train
MHIST_aaj.png,HP,0,train
MHIST_aak.png,HP,2,train


In [None]:
# Sort so the sample order is identical across runs (os.listdir order is
# arbitrary), and keep only PNG files so any stray non-image file in the folder
# never reaches the `annotations.loc` lookup, where it would raise a KeyError.
images_name = sorted(f for f in os.listdir(path) if f.lower().endswith('.png'))

In [None]:
mhist_train_img = []
mhist_train_label = []
mhist_test_img = []
mhist_test_label = []
# Random augmentation for the TRAINING split only. The test split must stay
# untouched, otherwise every reported metric is computed on randomly darkened,
# zoomed and rotated images instead of the real slides.
datagen = ImageDataGenerator(brightness_range=[0.2, 1.0],
                             zoom_range=[0.5, 1.0],
                             rotation_range=180)
for img in images_name:
    row = annotations.loc[img]
    partition = row['Partition']
    if partition not in ('train', 'test'):
        continue

    # `with` closes the file handle; the pixels are copied out into a numpy array
    # first. convert('RGB') guarantees 3 channels for the (224, 224, 3) models.
    with Image.open(os.path.join(path, img)) as pil_img:
        img_data = np.asarray(pil_img.convert('RGB'))
    label = row['Majority Vote Label']

    if partition == 'train':
        # NOTE(review): each training image is *replaced* by one random augmented
        # copy (originals are not kept) - preserved from the original notebook.
        augmented = next(datagen.flow(expand_dims(img_data, 0), batch_size=1))
        mhist_train_img.append(augmented[0].astype('uint8'))
        mhist_train_label.append(label)
    else:
        mhist_test_img.append(img_data)
        mhist_test_label.append(label)

In [None]:
def _to_backbone_input(images):
    """Stack a list of uint8 RGB images into a float32 array scaled to [-1, 1].

    ResNet50V2 and MobileNetV2 from keras.applications are both pretrained on
    inputs in [-1, 1] (their `preprocess_input` uses "tf" mode), so scaling to
    [0, 1] would feed the frozen ImageNet features out-of-distribution inputs.
    float32 matches what Keras feeds the model and halves memory compared with
    the float64 array that `/255` produces; the in-place ops avoid extra copies.
    """
    arr = np.asarray(images, dtype=np.float32)
    arr /= 127.5
    arr -= 1.0
    return arr

mhist_train_img = _to_backbone_input(mhist_train_img)
mhist_train_label = np.array(mhist_train_label)
mhist_test_img = _to_backbone_input(mhist_test_img)
mhist_test_label = np.array(mhist_test_label)

In [None]:
# Map the string labels to integers. LabelEncoder sorts its classes, so
# 'HP' -> 0 and 'SSA' -> 1; label 1 is the positive class for F1/AUC below.
# The encoder is fitted on the training labels and reused for the test split.
le = preprocessing.LabelEncoder().fit(mhist_train_label)
mhist_train_label_le, mhist_test_label_le = (
    le.transform(mhist_train_label),
    le.transform(mhist_test_label),
)

In [None]:
# Build teacher.
def teacher_initial():
  """Build the teacher: ResNet50V2 backbone + global average pool + 2-way softmax.

  The first 185 backbone layers are frozen and the last 5 are trainable; the
  new Dense head is trainable.

  Returns:
    A Keras functional Model mapping (224, 224, 3) images to 2 class probabilities.
  """
  backbone = tf.keras.applications.resnet_v2.ResNet50V2(
      include_top=False, input_shape=(224, 224, 3))
  for frozen_layer in backbone.layers[:185]:
    frozen_layer.trainable = False
  for tuned_layer in backbone.layers[-5:]:
    tuned_layer.trainable = True

  pooled = tf.keras.layers.GlobalAveragePooling2D()(backbone.output)
  class_probs = tf.keras.layers.Dense(2, activation='softmax')(pooled)
  return Model(inputs=backbone.input, outputs=class_probs)

In [None]:
# Build student.
def student_initial():
  """Build the student: fully frozen MobileNetV2 backbone + GAP + 2-way softmax.

  Only the new Dense head is trainable.

  Returns:
    A Keras functional Model mapping (224, 224, 3) images to 2 class probabilities.
  """
  backbone = tf.keras.applications.mobilenet_v2.MobileNetV2(
      include_top=False, input_shape=(224, 224, 3))
  backbone.trainable = False

  pooled = tf.keras.layers.GlobalAveragePooling2D()(backbone.output)
  class_probs = tf.keras.layers.Dense(2, activation='softmax')(pooled)
  return Model(inputs=backbone.input, outputs=class_probs)

In [None]:
# Instantiate both networks (this downloads the pretrained backbone weights).
teacher_model, student_model = teacher_initial(), student_initial()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50v2_weights_tf_dim_ordering_tf_kernels_notop.h5
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5


In [None]:
# Fine-tune the teacher on the integer-encoded training labels. The model ends
# in a softmax layer, so the loss is used with its default (probability) inputs.
teacher_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
teacher_model.compile(
    optimizer=teacher_optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
teacher_model.fit(mhist_train_img, mhist_train_label_le, epochs=25)

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25


<keras.src.callbacks.History at 0x7a454546fc10>

In [None]:
# F1 of the teacher's hard (argmax) predictions on the test split.
tea_pred_labels = np.argmax(teacher_model.predict(mhist_test_img), axis=1)
tea_f1_score = f1_score(mhist_test_label_le, tea_pred_labels)
print('Teachers f1 score is', tea_f1_score)

In [None]:
# ROC AUC is a ranking metric: score it on the predicted probability of the
# positive class (column 1, label 'SSA'). Argmax hard labels would collapse the
# ROC curve to a single operating point.
tea_test_probs = teacher_model.predict(mhist_test_img)
tea_auc = roc_auc_score(mhist_test_label_le, tea_test_probs[:, 1])
print('Teachers auc is',tea_auc)

In [None]:
def distillation_loss(teacher_logits: tf.Tensor, student_logits: tf.Tensor,
                      temperature: Union[float, tf.Tensor]):
  """Temperature-softened cross-entropy between teacher and student.

  Args:
    teacher_logits: teacher logits; softened with a softmax at `temperature`.
    student_logits: student logits, compared at the same temperature.
    temperature: softening temperature T.

  Returns:
    Batch-mean cross-entropy scaled by T**2, which keeps gradient magnitudes
    comparable across temperatures.
  """
  scaled_teacher = teacher_logits / temperature
  scaled_student = student_logits / temperature
  per_example = tf.nn.softmax_cross_entropy_with_logits(
      labels=tf.nn.softmax(scaled_teacher), logits=scaled_student)
  return tf.reduce_mean(per_example) * temperature ** 2

In [None]:
class KD(tf.keras.Model):
    """Knowledge-distillation wrapper that trains `student` to imitate `teacher`.

    Only the student's trainable variables are updated. The teacher is run with
    training=False and outside the gradient tape. Both sub-models are expected
    to end in a softmax layer (as `teacher_initial` / `student_initial` do), so
    their outputs are class probabilities, not logits.

    Args:
        student: Keras model being trained.
        teacher: Keras model providing the soft targets.
    """

    # Floor applied to probabilities before the log in `_to_logits`, so a
    # probability that underflows to 0 does not become -inf.
    _PROB_FLOOR = 1e-7

    def __init__(self, student, teacher):
        super(KD, self).__init__()
        self.teacher = teacher
        self.student = student

    def compile(self, optimizer, metrics, student_loss_fn, alpha, temperature):
        """Configure distillation training.

        Args:
            optimizer: optimizer applied to the student's trainable variables.
            metrics: metrics updated on (labels, student probabilities).
            student_loss_fn: hard-label loss, called as fn(y, student_probs).
            alpha: weight of the hard-label loss; (1 - alpha) weights the
                distillation loss.
            temperature: softening temperature passed to `distillation_loss`.
        """
        super(KD, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def call(self, inputs, training=False):
        """The wrapper's forward pass is the student's (enables `predict`)."""
        return self.student(inputs, training=training)

    def _to_logits(self, probabilities):
        """Recover logits from softmax outputs, up to a per-row additive constant.

        softmax(log(p) / T) is the temperature-softened distribution, because
        softmax ignores a per-row constant shift.
        """
        return tf.math.log(tf.clip_by_value(probabilities, self._PROB_FLOOR, 1.0))

    def train_step(self, data):
        x, y = data

        # Teacher runs outside the tape: its variables never receive gradients.
        teacher_predictions = self.teacher(x, training=False)
        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)
            student_loss = self.student_loss_fn(y, student_predictions)
            # Both models already apply softmax. `distillation_loss` expects
            # logits and applies temperature + softmax itself, so pass
            # log-probabilities rather than re-softmaxed probabilities (the old
            # code softmaxed three times, flattening both distributions).
            dis_loss = distillation_loss(
                self._to_logits(teacher_predictions),
                self._to_logits(student_predictions),
                temperature=self.temperature,
            )
            loss_value = self.alpha * student_loss + (1 - self.alpha) * dis_loss

        # Compute gradients for the student only.
        gradients = tape.gradient(loss_value, self.student.trainable_variables)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))
        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)
        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": dis_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Evaluation uses the student alone.
        y_prediction = self.student(x, training=False)
        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)
        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)
        # Keras requires test_step to return a dict of metric values; returning
        # the raw predictions broke `evaluate()`.
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

In [None]:
# Rebuild a fresh student (pretrained frozen backbone + new head) for the
# single distillation run below.
student_model = student_initial()

In [None]:
# Teacher -> student distillation: equal weight on the hard-label loss and the
# distillation loss (alpha=0.5), temperature 1.
distiller = KD(student=student_model, teacher=teacher_model)
distiller.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
    alpha=0.5,
    temperature=1,
)

In [None]:
# Train the student through the distiller; only the student's weights are updated.
distiller.fit(mhist_train_img, mhist_train_label_le, epochs=25)

In [None]:
# F1 of the distilled student's hard (argmax) predictions on the test split.
stu_pred_labels = np.argmax(student_model.predict(mhist_test_img), axis=1)
stu_f1_score = f1_score(mhist_test_label_le, stu_pred_labels)
print('Student f1 score is', stu_f1_score)

In [None]:
# ROC AUC must be scored on the positive-class probability (column 1, 'SSA'),
# not on argmax hard labels.
stu_test_probs = student_model.predict(mhist_test_img)
stu_auc = roc_auc_score(mhist_test_label_le, stu_test_probs[:, 1])
print('Student auc is',stu_auc)

In [None]:
# Hyperparameter tuning: distillation temperature (alpha fixed at 0.5).
# NOTE(review): there is no validation split, so the temperature is chosen on
# the test split and the resulting scores are optimistic.

T = [1, 2, 4, 16, 32, 64]
student_f1_list = []
student_auc_list = []
for temp in T:
  student_model = student_initial()
  distiller = KD(student=student_model, teacher=teacher_model)
  distiller.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
      student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
      alpha=0.5,
      temperature = temp,
  )
  distiller.fit(mhist_train_img, mhist_train_label_le, epochs=25)

  # Predict once: hard labels for F1, positive-class probability for ROC AUC.
  stu_test_probs = student_model.predict(mhist_test_img)
  stu_f1_score = f1_score(mhist_test_label_le, np.argmax(stu_test_probs, axis=1))
  stu_auc = roc_auc_score(mhist_test_label_le, stu_test_probs[:, 1])

  student_f1_list.append(stu_f1_score)
  student_auc_list.append(stu_auc)

  print(temp,'Temperature done')
print('student_f1_list is',student_f1_list)
print('student_auc_list is',student_auc_list)
plt.plot(T, student_f1_list, label='f1')
plt.plot(T, student_auc_list, label='auc')
plt.xlabel("Temperature")
plt.ylabel("F1 score and AUC score")
plt.title("Distillation Performance vs. Temperature Hyperparameters")
plt.legend()
plt.show()

In [None]:
# Hyperparameter tuning: task-balance weight alpha (temperature fixed at 2).
# NOTE(review): alpha is chosen on the test split (no validation split exists),
# so the resulting scores are optimistic.
alpha_list = [0.1,0.3,0.5,0.7,0.9]
student_f1_list = []
student_auc_list = []
for alpha in alpha_list:
  student_model = student_initial()
  distiller = KD(student=student_model, teacher=teacher_model)
  distiller.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
      student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
      alpha=alpha,
      temperature = 2,
  )
  distiller.fit(mhist_train_img, mhist_train_label_le, epochs=25)

  # Predict once: hard labels for F1, positive-class probability for ROC AUC.
  stu_test_probs = student_model.predict(mhist_test_img)
  stu_f1_score = f1_score(mhist_test_label_le, np.argmax(stu_test_probs, axis=1))
  stu_auc = roc_auc_score(mhist_test_label_le, stu_test_probs[:, 1])

  student_f1_list.append(stu_f1_score)
  student_auc_list.append(stu_auc)

  print(alpha,'alpha done')
print('student_f1_list is',student_f1_list)
print('student_auc_list is',student_auc_list)
plt.plot(alpha_list, student_f1_list, label='f1')
plt.plot(alpha_list, student_auc_list, label='auc')
plt.xlabel("alpha")
plt.ylabel("F1 score and AUC score")
plt.title("Distillation Performance vs. Task Balance Hyperparameters")
plt.legend()
plt.show()

In [None]:
# Baseline: train a student on the hard labels alone, without distillation.
# It must be rebuilt here. At this point `student_model` is the last student of
# the alpha sweep, which has already been distilled, so reusing it would not be
# a no-distillation baseline. (The backbone still starts from pretrained weights.)
student_model = student_initial()
student_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# One full pass over the training arrays per epoch. The former hard-coded
# steps_per_epoch=68 equals ceil(n_train / 32) only for exactly this dataset size.
student_model.fit(mhist_train_img, mhist_train_label_le, epochs=25)

In [None]:
# F1 of the non-distilled baseline student on the test split.
stu_pred_labels = np.argmax(student_model.predict(mhist_test_img), axis=1)
stu_f1_score = f1_score(mhist_test_label_le, stu_pred_labels)
print('Student f1 score is', stu_f1_score)

In [None]:
# ROC AUC on the positive-class probability (column 1, 'SSA'), not on argmax labels.
stu_test_probs = student_model.predict(mhist_test_img)
stu_auc = roc_auc_score(mhist_test_label_le, stu_test_probs[:, 1])
print('Student auc is',stu_auc)

In [None]:
# Flops: compare the compute cost of the two architectures.
teacherflops, studentflops = get_flops(teacher_model), get_flops(student_model)
print('flops for teacher ResNet model is', teacherflops)
print('flops for student MobileNet model is', studentflops)

In [None]:
# TA Model

def ta_initial():
  """Build a teacher-assistant: fully frozen ResNet50V2 + GAP + 2-way softmax.

  Only the new Dense head is trainable.
  NOTE(review): this uses the same backbone architecture as the teacher, so it is
  not an intermediate-capacity model between teacher and student.

  Returns:
    A Keras functional Model mapping (224, 224, 3) images to 2 class probabilities.
  """
  backbone = tf.keras.applications.resnet_v2.ResNet50V2(
      include_top=False, input_shape=(224, 224, 3))
  backbone.trainable = False

  pooled = tf.keras.layers.GlobalAveragePooling2D()(backbone.output)
  class_probs = tf.keras.layers.Dense(2, activation='softmax')(pooled)
  return Model(inputs=backbone.input, outputs=class_probs)

def ta_initial_mob():
  """Alternative TA: MobileNetV2 backbone with only its last 3 layers trainable.

  Not used by the cells below.

  Returns:
    A Keras functional Model mapping (224, 224, 3) images to 2 class probabilities.
  """
  backbone = tf.keras.applications.mobilenet_v2.MobileNetV2(
      include_top=False, input_shape=(224, 224, 3))
  # Setting trainable=False on the model freezes every layer, so a separate
  # freeze loop over the early layers is unnecessary; just re-enable the tail.
  backbone.trainable = False
  for tuned_layer in backbone.layers[-3:]:
    tuned_layer.trainable = True

  pooled = tf.keras.layers.GlobalAveragePooling2D()(backbone.output)
  class_probs = tf.keras.layers.Dense(2, activation='softmax')(pooled)
  return Model(inputs=backbone.input, outputs=class_probs)

In [None]:
# Stage 1 of teacher -> TA -> student: the TA takes the student role and learns
# from the teacher (alpha=0.5, temperature 1).
ta_model = ta_initial()
distiller = KD(student=ta_model, teacher=teacher_model)
distiller.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
    alpha=0.5,
    temperature=1,
)


In [None]:
# Stage 1: distil the teacher into the TA (only the TA's weights are updated).
distiller.fit(mhist_train_img, mhist_train_label_le, epochs=25)

In [None]:
# Stage 2: a fresh student learns from the TA trained above (alpha=0.5,
# temperature 1).
student_model = student_initial()
distiller = KD(student=student_model, teacher=ta_model)
distiller.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
    alpha=0.5,
    temperature=1,
)

In [None]:
# Stage 2: distil the trained TA into the fresh student.
distiller.fit(mhist_train_img, mhist_train_label_le, epochs=25)

In [None]:
# F1 of the TA-distilled student's hard (argmax) predictions on the test split.
stu_pred_labels = np.argmax(student_model.predict(mhist_test_img), axis=1)
stu_f1_score = f1_score(mhist_test_label_le, stu_pred_labels)
print('Student f1 score is', stu_f1_score)

In [None]:
# ROC AUC on the positive-class probability (column 1, 'SSA'), not on argmax labels.
stu_test_probs = student_model.predict(mhist_test_img)
stu_auc = roc_auc_score(mhist_test_label_le, stu_test_probs[:, 1])
print('Student auc is',stu_auc)