# Project B: Knowledge Distillation for Building Lightweight Deep Learning Models in Visual Classification Tasks

In [49]:
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from typing import Union

tf.enable_v2_behavior()

builder = tfds.builder('mnist')
BATCH_SIZE = 256
NUM_EPOCHS = 12
NUM_CLASSES = 10  # 10 total classes.

# Not originally included
import numpy as np
import os
import csv
from keras.preprocessing.image import ImageDataGenerator

ENTROPY_ZERO_FILLER = 1e-15
NUM_CLASSES = 2 #SSA or HP
WEIGHT_DECAY = 5e-4

## Data loading/augmenting

In [56]:
# Adapted from Project A's HMT Dataset loading code to use the ImageDataGenerator to augment the data and preload labels/image batches

img_dir = 'mhist_dataset/images'
train_dir = 'mhist_dataset/images/train'
test_dir = 'mhist_dataset/images/test'
anno_csv = 'mhist_dataset/annotations.csv'


if not os.path.isdir(train_dir):
    os.mkdir(train_dir)
if not os.path.isdir(test_dir):
    os.mkdir(test_dir)
    
if not os.path.isdir(os.path.join(train_dir, '01_HP')):
    os.mkdir(os.path.join(train_dir, '01_HP'))
if not os.path.isdir(os.path.join(train_dir, '02_SSA')):
    os.mkdir(os.path.join(train_dir, '02_SSA'))
if not os.path.isdir(os.path.join(test_dir, '01_HP')):
    os.mkdir(os.path.join(test_dir, '01_HP'))
if not os.path.isdir(os.path.join(test_dir, '02_SSA')):
    os.mkdir(os.path.join(test_dir, '02_SSA'))
    
# load csv
# label struct: [HP, SSA]
# labels as a list are sorted in alphabetical order as per the csv

train_labels = []
test_labels = []
train_img = []
test_img = []

with open(anno_csv, 'r') as csvfile:
    first_row = True
    for row in csv.reader(csvfile):
        if first_row:
            first_row = False
            continue
        if row[3] == 'train':
            train_img.append(row[0])
            if row[1] == 'HP':
                train_labels.append([1, 0])
            elif row[1] == 'SSA':
                train_labels.append([0, 1])
        elif row[3] == 'test':
            test_img.append(row[0])
            if row[1] == 'HP':
                test_labels.append([1, 0])
            elif row[1] == 'SSA':
                test_labels.append([0, 1])
        if row[0] in os.listdir(img_dir):
            if row[1] == 'HP' and row[3] == 'train':
                os.rename(os.path.join(img_dir, row[0]), os.path.join(train_dir, '01_HP', row[0]))
            elif row[1] == 'SSA' and row[3] == 'train':
                os.rename(os.path.join(img_dir, row[0]), os.path.join(train_dir, '02_SSA', row[0]))
            elif row[1] == 'HP' and row[3] == 'test':
                os.rename(os.path.join(img_dir, row[0]), os.path.join(test_dir, '01_HP', row[0]))
            elif row[1] == 'SSA' and row[3] == 'test':
                os.rename(os.path.join(img_dir, row[0]), os.path.join(test_dir, '02_SSA', row[0]))
    
# Data Augmentation using ImageDataGenerator
train_datagen = ImageDataGenerator(rescale=1/255.,
shear_range=0.1,
rotation_range=15,
horizontal_flip=True,
vertical_flip=True)

test_datagen = ImageDataGenerator(rescale=1/255.)

train_generator = train_datagen.flow_from_directory(train_dir,
class_mode='categorical',
interpolation='bilinear',
target_size=(224, 224),
batch_size=32,
shuffle=True)

# changed batch size to 62 from 32
test_generator = test_datagen.flow_from_directory(test_dir,
class_mode='categorical',
interpolation='bilinear',
target_size=(224, 224),
batch_size=62,
shuffle=False)

Found 2176 images belonging to 2 classes.
Found 976 images belonging to 2 classes.
(array([[[[0.7842671 , 0.6063359 , 0.7762172 ],
         [0.7860314 , 0.6082515 , 0.7773262 ],
         [0.8075136 , 0.6302881 , 0.7876909 ],
         ...,
         [0.61543876, 0.42457235, 0.6388687 ],
         [0.65154874, 0.44940823, 0.6786403 ],
         [0.52542275, 0.36240715, 0.6184585 ]],

        [[0.75464535, 0.55239505, 0.7419674 ],
         [0.7533851 , 0.5516389 , 0.7416145 ],
         [0.80159956, 0.6255937 , 0.7806602 ],
         ...,
         [0.618838  , 0.42702428, 0.6409862 ],
         [0.6475922 , 0.4465105 , 0.6768014 ],
         [0.5269831 , 0.36374456, 0.6191272 ]],

        [[0.83653533, 0.6506643 , 0.7909347 ],
         [0.83572876, 0.6487992 , 0.78997695],
         [0.8756891 , 0.75205815, 0.84797704],
         ...,
         [0.5549314 , 0.3809961 , 0.61192864],
         [0.5341586 , 0.3668991 , 0.60647553],
         [0.42290118, 0.28249508, 0.53915805]],

        ...,

        

# Model creation

In [14]:
#@test {"output": "ignore"}
# citing https://www.tensorflow.org/api_docs/python/tf/keras/applications/resnet_v2/ResNet50V2, https://www.kaggle.com/code/suniliitb96/tutorial-keras-transfer-learning-with-resnet50/notebook,
# https://www.tensorflow.org/api_docs/python/tf/keras/applications/mobilenet_v2/MobileNetV2 and https://github.com/Abhi-T/MNIST-CLASSIFIER-From-Scratch/blob/main/MNIST__handwritten_digit_Model.ipynb

# Build CNN teacher.

teacher_model = tf.keras.Sequential()

# your code start from here for step 2

teacher_model.add(tf.keras.applications.resnet_v2.ResNet50V2(include_top = False, classifier_activation = None))
teacher_model.add(tf.keras.layers.Dense(NUM_CLASSES)) 
print(teacher_model.summary())

# Build fully connected students
student_kd_model = tf.keras.Sequential()
student_kd_model.add(tf.keras.applications.mobilenet_v2.MobileNetV2(include_top = False, classifier_activation = None))
student_kd_model.add(tf.keras.layers.Dense(NUM_CLASSES)) 
print(student_kd_model.summary())

student_scratch_model = tf.keras.Sequential()
student_scratch_model.add(tf.keras.applications.mobilenet_v2.MobileNetV2(include_top = False, classifier_activation = None))
student_scratch_model.add(tf.keras.layers.Dense(NUM_CLASSES)) 
print(student_scratch_model.summary())


# your code start from here for step 2




Model: "sequential_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resnet50v2 (Functional)     (None, None, None, 2048)  23564800  
                                                                 
 dense_3 (Dense)             (None, None, None, 2)     4098      
                                                                 
Total params: 23,568,898
Trainable params: 23,523,458
Non-trainable params: 45,440
_________________________________________________________________
None
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
Model: "sequential_12"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 mobilenetv2_1.00_224 (Funct  (None, None, None, 1280)  2257984  
 ional)                                     

# Teacher loss function

In [None]:
def compute_teacher_loss(images, labels):
    """Compute class knowledge distillation teacher loss for given images
     and labels.

    Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.

    Returns:
    Scalar loss Tensor.
    """
    # already in probability form
    class_logits = cnn_model(images, training=True)

    # Compute cross-entropy loss for classes.

    # your code start from here for step 3
    
    cross_entropy_loss_value = tf.keras.losses.categorical_crossentropy(labels, tf.nn.softmax(class_logits))

    return cross_entropy_loss_value

# Student loss function

In [None]:
# adapted from https://keras.io/examples/vision/knowledge_distillation/

#@test {"output": "ignore"}

# Hyperparameters for distillation (need to be tuned).
ALPHA = 0.5 # task balance between cross-entropy and distillation loss
DISTILLATION_TEMPERATURE = 4. #temperature hyperparameter

def distillation_loss(teacher_logits: tf.Tensor, student_logits: tf.Tensor,
                      temperature: Union[float, tf.Tensor]):
    """Compute distillation loss.

    This function computes cross entropy between softened logits and softened
    targets. The resulting loss is scaled by the squared temperature so that
    the gradient magnitude remains approximately constant as the temperature is
    changed. For reference, see Hinton et al., 2014, "Distilling the knowledge in
    a neural network."

    Args:
    teacher_logits: A Tensor of logits provided by the teacher.
    student_logits: A Tensor of logits provided by the student, of the same
      shape as `teacher_logits`.
    temperature: Temperature to use for distillation.

    Returns:
    A scalar Tensor containing the distillation loss.
    """
    # your code start from here for step 3
    soft_targets = teacher_logits / temperature

    return tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          soft_targets, student_logits / temperature)) * temperature ** 2

def compute_student_loss(images, labels):
    """Compute class knowledge distillation student loss for given images
     and labels.

    Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.

    Returns:
    Scalar loss Tensor.
    """
    student_class_logits = fc_model(images, training=True)

    # Compute class distillation loss between student class logits and
    # softened teacher class targets probabilities.

    # your code start from here for step 3

    teacher_class_logits = cnn_model(images, training=False)
    distillation_loss_value = distillation_loss(teacher_class_logits, student_class_logits, DISTILLATION_TEMPERATURE)

    # Compute cross-entropy loss with hard targets.

    # your code start from here for step 3
    
    cross_entropy_loss_value = tf.keras.losses.categorical_crossentropy(labels, tf.nn.softmax(student_class_logits))

    total_loss = ALPHA * cross_entropy_loss_value + (1 - ALPHA) * distillation_loss_value

    return total_loss

# Train and evaluation

In [None]:
@tf.function
    def compute_num_correct(model, images, labels):
    """Compute number of correctly classified images in a batch.

    Args:
    model: Instance of tf.keras.Model.
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.

    Returns:
    Number of correctly classified images.
    """
    class_logits = model(images, training=False)
    return tf.reduce_sum(
        tf.cast(tf.math.equal(tf.argmax(class_logits, -1), tf.argmax(labels, -1)),
              tf.float32)), tf.argmax(class_logits, -1), tf.argmax(labels, -1)


def train_and_evaluate(model, compute_loss_fn, num_epochs, learning_rate):
    """Perform training and evaluation for a given model.

    Args:
    model: Instance of tf.keras.Model.
    compute_loss_fn: A function that computes the training loss given the
        images, and labels.
    num_epochs: Number of epochs to train for
    learning_rate: Optimizer learning rate
    """

    # your code start from here for step 4
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    for epoch in range(1, num_epochs + 1):
    # Run training.
    print('Epoch {}: '.format(epoch), end='')
    for images, labels in mhist_train:
        with tf.GradientTape() as tape:
             # your code start from here for step 4

            loss_value = compute_loss_fn(images, labels)

        grads = tape.gradient(loss_value, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # Run evaluation.
    num_correct = 0
    num_total = builder.info.splits['test'].num_examples
    for images, labels in mnist_test:
        # your code start from here for step 4
        num_correct += compute_num_correct(model,images,labels)[0]
    print("Class_accuracy: " + '{:.2f}%'.format(num_correct / num_total * 100))


# Training models

In [None]:
# your code start from here for step 5 



# Test accuracy vs. tempreture curve

In [None]:
# your code start from here for step 6


# Train student from scratch

In [None]:
# Build fully connected student.
fc_model_no_distillation = tf.keras.Sequential()

# your code start from here for step 7



#@test {"output": "ignore"}

def compute_plain_cross_entropy_loss(images, labels):
  """Compute plain loss for given images and labels.

  For fair comparison and convenience, this function also performs a
  LogSumExp over classes, but does not perform class distillation.

  Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.

  Returns:
    Scalar loss Tensor.
  """
  # your code start from here for step 7

  student_class_logits = fc_model_no_distillation(images, training=True)
  cross_entropy_loss = 
  
  return cross_entropy_loss


train_and_evaluate(fc_model_no_distillation, compute_plain_cross_entropy_loss)

# Comparing the teacher and student model (number of of parameters and FLOPs) 

In [None]:
# your code start from here for step 8


# XAI method to explain models

In [None]:
# your code start from here for step 9


# Implementing the state-of-the-art KD algorithm

In [None]:
# your code start from here for step 13
