# Project B: Knowledge Distillation for Building Lightweight Deep Learning Models in Visual Classification Tasks

In [1]:
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from typing import Union

tf.enable_v2_behavior()

# Not originally included
import numpy as np
import os
import csv
from keras.preprocessing.image import ImageDataGenerator
from keras_flops import get_flops
import sklearn.metrics as metrics


BATCH_SIZE = 2
NUM_INIT_EPOCHS = 10
NUM_FINE_EPOCHS = 25

ENTROPY_ZERO_FILLER = 1e-15
NUM_CLASSES = 2 #SSA or HP
TRAIN_BATCHES = int(np.floor(2176/BATCH_SIZE))
TEST_BATCHES = int(np.floor(976/BATCH_SIZE))
NUM_TEST_IMAGE = 976

# For quick testing purposes ONLY
"""
TRAIN_BATCHES = 4
TEST_BATCHES = 4
"""

'\nTRAIN_BATCHES = 4\nTEST_BATCHES = 4\n'

# Data Loading/Augmenting

In [2]:
# Adapted from Project A's HMT Dataset loading code to use the ImageDataGenerator to augment the data and preload labels/image batches

img_dir = 'mhist_dataset/images'
train_dir = 'mhist_dataset/images/train'
test_dir = 'mhist_dataset/images/test'
anno_csv = 'mhist_dataset/annotations.csv'


if not os.path.isdir(train_dir):
    os.mkdir(train_dir)
if not os.path.isdir(test_dir):
    os.mkdir(test_dir)
    
if not os.path.isdir(os.path.join(train_dir, '01_HP')):
    os.mkdir(os.path.join(train_dir, '01_HP'))
if not os.path.isdir(os.path.join(train_dir, '02_SSA')):
    os.mkdir(os.path.join(train_dir, '02_SSA'))
if not os.path.isdir(os.path.join(test_dir, '01_HP')):
    os.mkdir(os.path.join(test_dir, '01_HP'))
if not os.path.isdir(os.path.join(test_dir, '02_SSA')):
    os.mkdir(os.path.join(test_dir, '02_SSA'))
    
# load csv
# label struct: [HP, SSA]
# labels as a list are sorted in alphabetical order as per the csv

train_labels = []
test_labels = []
train_img = []
test_img = []

with open(anno_csv, 'r') as csvfile:
    first_row = True
    for row in csv.reader(csvfile):
        if first_row:
            first_row = False
            continue
        if row[3] == 'train':
            train_img.append(row[0])
            if row[1] == 'HP':
                train_labels.append([1, 0])
            elif row[1] == 'SSA':
                train_labels.append([0, 1])
        elif row[3] == 'test':
            test_img.append(row[0])
            if row[1] == 'HP':
                test_labels.append([1, 0])
            elif row[1] == 'SSA':
                test_labels.append([0, 1])
        if row[0] in os.listdir(img_dir):
            if row[1] == 'HP' and row[3] == 'train':
                os.rename(os.path.join(img_dir, row[0]), os.path.join(train_dir, '01_HP', row[0]))
            elif row[1] == 'SSA' and row[3] == 'train':
                os.rename(os.path.join(img_dir, row[0]), os.path.join(train_dir, '02_SSA', row[0]))
            elif row[1] == 'HP' and row[3] == 'test':
                os.rename(os.path.join(img_dir, row[0]), os.path.join(test_dir, '01_HP', row[0]))
            elif row[1] == 'SSA' and row[3] == 'test':
                os.rename(os.path.join(img_dir, row[0]), os.path.join(test_dir, '02_SSA', row[0]))
    
# Data Augmentation using ImageDataGenerator
train_datagen = ImageDataGenerator(rescale=1/255.,
shear_range=0.1,
rotation_range=15,
horizontal_flip=True,
vertical_flip=True)

test_datagen = ImageDataGenerator(rescale=1/255.)

train_generator = train_datagen.flow_from_directory(train_dir,
class_mode='categorical',
interpolation='bilinear',
target_size=(224, 224),
batch_size=BATCH_SIZE,
shuffle=True) # 68 batches of 32

eval_generator = test_datagen.flow_from_directory(test_dir,
class_mode='categorical',
interpolation='bilinear',
target_size=(224, 224),
batch_size=244,
shuffle=True) # single batch half of the test images to get F1 scores over full test dataset

test_generator = test_datagen.flow_from_directory(test_dir,
class_mode='categorical',
interpolation='bilinear',
target_size=(224, 224),
batch_size=BATCH_SIZE,
shuffle=False) # 30.5 batches of 32

Found 2176 images belonging to 2 classes.
Found 976 images belonging to 2 classes.
Found 976 images belonging to 2 classes.


# Model Creation

In [14]:
# citing https://www.tensorflow.org/api_docs/python/tf/keras/applications/resnet_v2/ResNet50V2, https://www.kaggle.com/code/suniliitb96/tutorial-keras-transfer-learning-with-resnet50/notebook,
# https://www.tensorflow.org/api_docs/python/tf/keras/applications/mobilenet_v2/MobileNetV2 and https://github.com/Abhi-T/MNIST-CLASSIFIER-From-Scratch/blob/main/MNIST__handwritten_digit_Model.ipynb
teacher_model = tf.keras.Sequential()
teacher_model.add(tf.keras.applications.resnet_v2.ResNet50V2(include_top=False, pooling='avg', classifier_activation = None, input_shape = (224, 224, 3)))
teacher_model.add(tf.keras.layers.Dense(NUM_CLASSES)) 
print(teacher_model.summary())

# Build fully connected students
student_kd_model = tf.keras.Sequential()
student_kd_model.add(tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=False, pooling='avg',classifier_activation = None, input_shape = (224, 224, 3)))
student_kd_model.add(tf.keras.layers.Dense(NUM_CLASSES)) 
print(student_kd_model.summary())

student_scratch_model = tf.keras.Sequential()
student_scratch_model.add(tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=False, pooling='avg',classifier_activation = None, input_shape = (224, 224, 3)))
student_scratch_model.add(tf.keras.layers.Dense(NUM_CLASSES)) 
print(student_scratch_model.summary())

Model: "sequential_10"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resnet50v2 (Functional)     (None, 2048)              23564800  
                                                                 
 dense_10 (Dense)            (None, 2)                 4098      
                                                                 
Total params: 23,568,898
Trainable params: 23,523,458
Non-trainable params: 45,440
_________________________________________________________________
None
Model: "sequential_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 mobilenetv2_1.00_224 (Funct  (None, 1280)             2257984   
 ional)                                                          
                                                                 
 dense_11 (Dense)            (None, 2)                 2562      
        

# Loss Function

## Teacher Loss Function

In [8]:
def compute_teacher_loss(model, images, labels, **kwargs):
    """Compute class knowledge distillation teacher loss for given images
     and labels.

    Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.
    kwargs: n/a

    Returns:
    Scalar loss Tensor.
    """
    
    class_logits = model(images, training=True)

    # Compute cross-entropy loss for classes.
    
    cross_entropy_loss_value = tf.nn.softmax_cross_entropy_with_logits(labels, class_logits)

    return cross_entropy_loss_value

## Student (KD) Loss Function

In [9]:
# adapted from https://keras.io/examples/vision/knowledge_distillation/

def distillation_loss(teacher_logits: tf.Tensor, student_logits: tf.Tensor,
                      temperature: Union[float, tf.Tensor]):
    """Compute distillation loss.

    This function computes cross entropy between softened logits and softened
    targets. The resulting loss is scaled by the squared temperature so that
    the gradient magnitude remains approximately constant as the temperature is
    changed. For reference, see Hinton et al., 2014, "Distilling the knowledge in
    a neural network."

    Args:
    teacher_logits: A Tensor of logits provided by the teacher.
    student_logits: A Tensor of logits provided by the student, of the same
      shape as `teacher_logits`.
    temperature: Temperature to use for distillation.

    Returns:
    A scalar Tensor containing the distillation loss.
    """
    # your code start from here for step 3
    soft_targets = teacher_logits / temperature

    return tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          soft_targets, student_logits / temperature)) * temperature ** 2

def compute_student_loss(student_model, images, labels, **kwargs):
    """Compute class knowledge distillation student loss for given images
     and labels.

    Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.
    kwargs:
        teacher_model: Teacher model
        temperature: Temperature hyperparameter
        alpha: Alpha hyperparameter

    Returns:
    Scalar loss Tensor.
    """
    
    teacher_model = kwargs['teacher_model']
    temperature = kwargs['temperature']
    alpha = kwargs['alpha']
    
    student_class_logits = student_model(images, training=True)

    # Compute class distillation loss between student class logits and
    # softened teacher class targets probabilities.

    teacher_class_logits = teacher_model(images, training=False)
    distillation_loss_value = distillation_loss(teacher_class_logits, student_class_logits, temperature)

    # Compute cross-entropy loss with hard targets.
    
    cross_entropy_loss_value = tf.nn.softmax_cross_entropy_with_logits(labels, student_class_logits)

    total_loss = alpha * cross_entropy_loss_value + (1 - alpha) * distillation_loss_value

    return total_loss

## Student (Scratch) Loss Function

In [10]:
def compute_student_scratch_loss(model, images, labels, **kwargs):
    """Compute class student (scratch) loss for given images
     and labels.

    Args:
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.
    kwargs:
        temperature: Temperature hyperparameter

    Returns:
    Scalar loss Tensor.
    """
    temperature = kwargs['temperature']
    
    class_logits = model(images, training=True)

    # Compute cross-entropy loss for classes.
    
    cross_entropy_loss_value = tf.nn.softmax_cross_entropy_with_logits(labels, class_logits/temperature) * temperature ** 2

    return cross_entropy_loss_value

# Train and Evaluation

In [11]:
@tf.function
def compute_num_correct(model, images, labels):
    """Compute number of correctly classified images in a batch.

    Args:
    model: Instance of tf.keras.Model.
    images: Tensor representing a batch of images.
    labels: Tensor representing a batch of labels.

    Returns:
    Number of correctly classified images.
    """
    class_logits = model(images, training=False)
    return tf.reduce_sum(
        tf.cast(tf.math.equal(tf.argmax(class_logits, -1), tf.argmax(labels, -1)),
              tf.float32)), tf.argmax(class_logits, -1), tf.argmax(labels, -1)


def train_and_evaluate(model, compute_loss_fn, num_epochs, learning_rate, **kwargs):
    """Perform training and evaluation for a given model.

    Args:
    model: Main Instance of tf.keras.Model.
    compute_loss_fn: A function that computes the training loss given the
        images, and labels.
    num_epochs: Number of epochs to train for
    learning_rate: Optimizer learning rate
    kwargs: Passed through to loss fn
    """

    # your code start from here for step 4
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    train_generator.reset()
    
    accuracy = 0
    
    for epoch in range(1, num_epochs + 1):
        # Run training.
        print('Epoch {}: '.format(epoch), end='')

        #for images, labels in mhist_train:
        for batch in range(TRAIN_BATCHES):
            
            images, labels = train_generator.next()

            with tf.GradientTape() as tape:

                loss_value = compute_loss_fn(model, images, labels, **kwargs)

            grads = tape.gradient(loss_value, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # Run evaluation.
        num_correct = 0
        num_total = BATCH_SIZE*TEST_BATCHES
        
        test_generator.reset()
        
        for batch in range(TEST_BATCHES):
            images, labels = test_generator.next()
            num_correct += compute_num_correct(model,images,labels)[0]
        accuracy = num_correct / num_total * 100
        print("Class_accuracy: " + '{:.2f}%'.format(accuracy))
    
    return accuracy


# Training models

In [19]:
# your code start from here for step 5 

print("Teacher Run: Initial")
train_and_evaluate(teacher_model, compute_teacher_loss, NUM_INIT_EPOCHS, 1e-4)
print("\nTeacher Run: Fine")
teach_acc = train_and_evaluate(teacher_model, compute_teacher_loss, NUM_FINE_EPOCHS, 1e-5)

print("\n\nStudent (KD) Run: Initial")
train_and_evaluate(student_kd_model, compute_student_loss, NUM_INIT_EPOCHS, 1e-3, teacher_model=teacher_model, temperature=4, alpha=0.5)
print("\nStudent (KD) Run: Fine")
skd_acc = train_and_evaluate(student_kd_model, compute_student_loss, NUM_FINE_EPOCHS, 1e-4, teacher_model=teacher_model, temperature=4, alpha=0.5)

print("\n\nStudent (Scratch) Run: Initial")
train_and_evaluate(student_scratch_model, compute_student_scratch_loss, NUM_INIT_EPOCHS, 1e-3, temperature=4)
print("\nStudent (Scratch) Run: Fine")
ss_acc = train_and_evaluate(student_scratch_model, compute_student_scratch_loss, NUM_FINE_EPOCHS, 1e-4, temperature=4)

Teacher Run: Initial
Epoch 1: Class_accuracy: 69.88%
Epoch 2: Class_accuracy: 74.80%
Epoch 3: Class_accuracy: 70.90%
Epoch 4: Class_accuracy: 76.54%
Epoch 5: Class_accuracy: 75.82%
Epoch 6: Class_accuracy: 77.15%
Epoch 7: Class_accuracy: 76.23%
Epoch 8: Class_accuracy: 71.52%
Epoch 9: Class_accuracy: 76.54%
Epoch 10: Class_accuracy: 74.59%

Teacher Run: Fine
Epoch 1: Class_accuracy: 85.76%
Epoch 2: Class_accuracy: 84.73%
Epoch 3: Class_accuracy: 85.25%
Epoch 4: Class_accuracy: 85.66%
Epoch 5: Class_accuracy: 86.27%
Epoch 6: Class_accuracy: 85.45%
Epoch 7: Class_accuracy: 85.14%
Epoch 8: Class_accuracy: 85.25%
Epoch 9: Class_accuracy: 85.55%
Epoch 10: Class_accuracy: 85.25%
Epoch 11: Class_accuracy: 84.84%
Epoch 12: Class_accuracy: 86.27%
Epoch 13: Class_accuracy: 84.63%
Epoch 14: Class_accuracy: 86.27%
Epoch 15: Class_accuracy: 87.40%
Epoch 16: Class_accuracy: 86.99%
Epoch 17: Class_accuracy: 85.14%
Epoch 18: Class_accuracy: 84.32%
Epoch 19: Class_accuracy: 86.37%
Epoch 20: Class_accur

# Test Accuracy vs. Temperature Curve

In [20]:
# Original student uses temperature 4
student_kd_model1 = tf.keras.Sequential()
student_kd_model1.add(tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=False, pooling='avg',classifier_activation = None, input_shape = (224, 224, 3)))
student_kd_model1.add(tf.keras.layers.Dense(NUM_CLASSES)) 

student_kd_model2 = tf.keras.Sequential()
student_kd_model2.add(tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=False, pooling='avg',classifier_activation = None, input_shape = (224, 224, 3)))
student_kd_model2.add(tf.keras.layers.Dense(NUM_CLASSES)) 

student_kd_model16 = tf.keras.Sequential()
student_kd_model16.add(tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=False, pooling='avg',classifier_activation = None, input_shape = (224, 224, 3)))
student_kd_model16.add(tf.keras.layers.Dense(NUM_CLASSES)) 

student_kd_model32 = tf.keras.Sequential()
student_kd_model32.add(tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=False, pooling='avg',classifier_activation = None, input_shape = (224, 224, 3)))
student_kd_model32.add(tf.keras.layers.Dense(NUM_CLASSES)) 

student_kd_model64 = tf.keras.Sequential()
student_kd_model64.add(tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=False, pooling='avg',classifier_activation = None, input_shape = (224, 224, 3)))
student_kd_model64.add(tf.keras.layers.Dense(NUM_CLASSES)) 

student_dict = {1: student_kd_model1, 2: student_kd_model2, 4: student_kd_model, 16: student_kd_model16, 32: student_kd_model32, 64: student_kd_model64}
acc_dict = {4: 63.32}

for temp in [1, 2, 16, 32, 64]:
    print(f'Student (KD, Temperature = {temp}) Run: Initial')
    train_and_evaluate(student_dict[temp], compute_student_loss, NUM_INIT_EPOCHS, 1e-3, teacher_model=teacher_model, temperature=temp, alpha=0.5)
    print(f'\nStudent (KD, Temperature = {temp}) Run: Fine')
    acc_dict[temp] = train_and_evaluate(student_dict[temp], compute_student_loss, NUM_FINE_EPOCHS, 1e-4, teacher_model=teacher_model, temperature=temp, alpha=0.5)
    print('\n')
    



Student (KD, Temperature = 1) Run: Initial
Epoch 1: Class_accuracy: 63.22%
Epoch 2: Class_accuracy: 36.78%
Epoch 3: Class_accuracy: 42.32%
Epoch 4: Class_accuracy: 45.39%
Epoch 5: Class_accuracy: 63.22%
Epoch 6: Class_accuracy: 45.80%
Epoch 7: Class_accuracy: 63.11%
Epoch 8: Class_accuracy: 59.12%
Epoch 9: Class_accuracy: 47.03%
Epoch 10: Class_accuracy: 63.22%

Student (KD, Temperature = 1) Run: Fine
Epoch 1: Class_accuracy: 63.11%
Epoch 2: Class_accuracy: 59.32%
Epoch 3: Class_accuracy: 61.99%
Epoch 4: Class_accuracy: 63.63%
Epoch 5: Class_accuracy: 61.07%
Epoch 6: Class_accuracy: 63.22%
Epoch 7: Class_accuracy: 62.81%
Epoch 8: Class_accuracy: 63.11%
Epoch 9: Class_accuracy: 63.22%
Epoch 10: Class_accuracy: 64.24%
Epoch 11: Class_accuracy: 63.01%
Epoch 12: Class_accuracy: 62.70%
Epoch 13: Class_accuracy: 63.42%
Epoch 14: Class_accuracy: 63.11%
Epoch 15: Class_accuracy: 62.19%
Epoch 16: Class_accuracy: 63.01%
Epoch 17: Class_accuracy: 62.70%
Epoch 18: Class_accuracy: 62.50%
Epoch 19: 

# State of the Art Models

In [10]:
takd_teacher_model = tf.keras.Sequential()
takd_teacher_model.add(tf.keras.applications.resnet_v2.ResNet101V2(include_top=False, pooling='avg',
                                                                   classifier_activation = None, 
                                                                   input_shape = (224, 224, 3)))
takd_teacher_model.add(tf.keras.layers.Dense(NUM_CLASSES)) 
print(takd_teacher_model.summary())

takd_ta_model = tf.keras.Sequential()
takd_ta_model.add(tf.keras.applications.resnet_v2.ResNet50V2(include_top=False, pooling='avg', 
                                                             classifier_activation = None, 
                                                             input_shape = (224, 224, 3)))
takd_ta_model.add(tf.keras.layers.Dense(NUM_CLASSES)) 
print(takd_ta_model.summary())

takd_student_takd_model = tf.keras.Sequential()
takd_student_takd_model.add(tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=False, pooling='avg',classifier_activation = None, input_shape = (224, 224, 3)))
takd_student_takd_model.add(tf.keras.layers.Dense(NUM_CLASSES)) 
print(takd_student_takd_model.summary())

takd_student_teacherkd_model = tf.keras.Sequential()
takd_student_teacherkd_model.add(tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=False, pooling='avg',classifier_activation = None, input_shape = (224, 224, 3)))
takd_student_teacherkd_model.add(tf.keras.layers.Dense(NUM_CLASSES)) 
print(takd_student_teacherkd_model.summary())


Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resnet50v2 (Functional)     (None, 2048)              23564800  
                                                                 
 dense (Dense)               (None, 2)                 4098      
                                                                 
Total params: 23,568,898
Trainable params: 23,523,458
Non-trainable params: 45,440
_________________________________________________________________
None
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 mobilenetv2_1.00_224 (Funct  (None, 1280)             2257984   
 ional)                                                          
                                                                 
 dense_1 (Dense)             (None, 2)                 2562      
            

In [13]:
print("TAKD Teacher Run: Initial")
train_and_evaluate(takd_teacher_model, compute_teacher_loss, NUM_INIT_EPOCHS, 1e-4)
print("\nTAKD Teacher Run: Fine")
takd_teach_acc = train_and_evaluate(takd_teacher_model, compute_teacher_loss, NUM_FINE_EPOCHS, 1e-5)

print("\nTAKD TA Run: Initial")
train_and_evaluate(takd_ta_model, compute_student_loss, NUM_INIT_EPOCHS, 1e-3, teacher_model=takd_teacher_model, temperature=4, alpha=0.5)
print("\nTAKD TA Run: Fine")
takd_ta_acc = train_and_evaluate(takd_ta_model, compute_student_loss, NUM_FINE_EPOCHS, 1e-4, teacher_model=takd_teacher_model, temperature=4, alpha=0.5)

print("\n\nTAKD Student (TA Distilled) Run: Initial")
train_and_evaluate(takd_student_takd_model, compute_student_loss, NUM_INIT_EPOCHS, 1e-3, teacher_model=takd_ta_model, temperature=4, alpha=0.5)
print("\nTAKD Student (TA Distilled) Run: Fine")
takd_stakd_acc = train_and_evaluate(takd_student_takd_model, compute_student_loss, NUM_FINE_EPOCHS, 1e-4, teacher_model=takd_ta_model, temperature=4, alpha=0.5)

print("\n\nTAKD Student (Teacher Distilled) Run: Initial")
train_and_evaluate(takd_student_teacherkd_model, compute_student_loss, NUM_INIT_EPOCHS, 1e-3, teacher_model=takd_teacher_model, temperature=4, alpha=0.5)
print("\nTAKD Student (Teacher Distilled) Run: Fine")
takd_steachkd_acc = train_and_evaluate(takd_student_teacherkd_model, compute_student_loss, NUM_FINE_EPOCHS, 1e-4, teacher_model=takd_teacher_model, temperature=4, alpha=0.5)


TAKD Teacher Run: Initial

TAKD Teacher Run: Fine

TAKD TA Run: Initial
Epoch 1: Class_accuracy: 63.22%
Epoch 2: Class_accuracy: 63.22%
Epoch 3: Class_accuracy: 63.22%
Epoch 4: Class_accuracy: 63.22%
Epoch 5: Class_accuracy: 63.22%
Epoch 6: Class_accuracy: 36.78%
Epoch 7: Class_accuracy: 63.22%
Epoch 8: Class_accuracy: 36.78%
Epoch 9: Class_accuracy: 63.22%
Epoch 10: Class_accuracy: 63.22%

TAKD TA Run: Fine
Epoch 1: Class_accuracy: 63.22%
Epoch 2: Class_accuracy: 63.22%
Epoch 3: Class_accuracy: 63.22%
Epoch 4: Class_accuracy: 59.94%
Epoch 5: Class_accuracy: 63.22%
Epoch 6: Class_accuracy: 63.22%
Epoch 7: Class_accuracy: 63.22%
Epoch 8: Class_accuracy: 63.22%
Epoch 9: Class_accuracy: 63.22%
Epoch 10: Class_accuracy: 59.43%
Epoch 11: Class_accuracy: 63.22%
Epoch 12: Class_accuracy: 63.22%
Epoch 13: Class_accuracy: 63.22%
Epoch 14: Class_accuracy: 63.22%
Epoch 15: Class_accuracy: 63.22%
Epoch 16: Class_accuracy: 49.69%
Epoch 17: Class_accuracy: 63.22%
Epoch 18: Class_accuracy: 63.22%
Epo

## Model without Transfer Learning

In [13]:
student_notf_model = tf.keras.Sequential()
student_notf_model.add(tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=False, pooling='avg', weights=None, classifier_activation = None, input_shape = (224, 224, 3)))
student_notf_model.add(tf.keras.layers.Dense(NUM_CLASSES)) 
print(student_kd_model.summary())

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 mobilenetv2_1.00_224 (Funct  (None, 1280)             2257984   
 ional)                                                          
                                                                 
 dense (Dense)               (None, 2)                 2562      
                                                                 
Total params: 2,260,546
Trainable params: 2,226,434
Non-trainable params: 34,112
_________________________________________________________________
None


In [15]:
print("Student (no TF/KD) Run: Initial")
train_and_evaluate(student_notf_model, compute_teacher_loss, NUM_INIT_EPOCHS, 1)
print("\nStudent (no TF/KD) Run: Fine")
takd_teach_acc = train_and_evaluate(student_notf_model, compute_teacher_loss, NUM_FINE_EPOCHS, 1e-1)

Student (no TF/KD) Run: Initial
Epoch 1: Class_accuracy: 36.78%
Epoch 2: Class_accuracy: 63.22%
Epoch 3: Class_accuracy: 63.22%
Epoch 4: Class_accuracy: 63.22%
Epoch 5: Class_accuracy: 63.22%
Epoch 6: Class_accuracy: 63.22%
Epoch 7: Class_accuracy: 36.78%
Epoch 8: Class_accuracy: 63.22%
Epoch 9: Class_accuracy: 63.22%
Epoch 10: Class_accuracy: 36.78%

Student (no TF/KD) Run: Fine
Epoch 1: Class_accuracy: 63.22%
Epoch 2: Class_accuracy: 36.78%
Epoch 3: Class_accuracy: 63.22%
Epoch 4: Class_accuracy: 63.22%
Epoch 5: Class_accuracy: 36.78%
Epoch 6: Class_accuracy: 63.22%
Epoch 7: Class_accuracy: 63.22%
Epoch 8: Class_accuracy: 36.78%
Epoch 9: Class_accuracy: 36.78%
Epoch 10: Class_accuracy: 63.22%
Epoch 11: Class_accuracy: 63.22%
Epoch 12: Class_accuracy: 36.78%
Epoch 13: Class_accuracy: 36.78%
Epoch 14: Class_accuracy: 36.78%
Epoch 15: Class_accuracy: 63.22%
Epoch 16: Class_accuracy: 36.78%
Epoch 17: Class_accuracy: 63.22%
Epoch 18: Class_accuracy: 36.78%
Epoch 19: Class_accuracy: 63.22%

# Save/Load Models

## Save Models

In [16]:
if not os.path.exists('MHIST_TRAINED_MODEL'):
    os.mkdir('MHIST_TRAINED_MODEL')
teacher_model.save('MHIST_TRAINED_MODEL/teacher.h5')
student_kd_model.save('MHIST_TRAINED_MODEL/student_kd4.h5')
student_scratch_model.save('MHIST_TRAINED_MODEL/student_scratch.h5')
student_kd_model1.save('MHIST_TRAINED_MODEL/student_kd1.h5')
student_kd_model2.save('MHIST_TRAINED_MODEL/student_kd2.h5')
student_kd_model16.save('MHIST_TRAINED_MODEL/student_kd16.h5')
student_kd_model32.save('MHIST_TRAINED_MODEL/student_kd32.h5')
student_kd_model64.save('MHIST_TRAINED_MODEL/student_kd64.h5')
takd_teacher_model.save('MHIST_TRAINED_MODEL/takd_teacher.h5')
takd_ta_model.save('MHIST_TRAINED_MODEL/takd_ta.h5')
takd_student_takd_model.save('MHIST_TRAINED_MODEL/takd_student_takd.h5')
takd_student_teacherkd_model.save('MHIST_TRAINED_MODEL/takd_student_teacherkd.h5')
student_notf_model.save('MHIST_TRAINED_MODEL/student_notf.h5')



## Load Models

In [17]:
# Load the Models

# Initial Teacher/Student Models
teacher_model = tf.keras.models.load_model('MHIST_TRAINED_MODEL/teacher.h5')
student_kd_model = tf.keras.models.load_model('MHIST_TRAINED_MODEL/student_kd4.h5')

# Student from Scratch
student_scratch_model = tf.keras.models.load_model('MHIST_TRAINED_MODEL/student_scratch.h5')

# Student Temperature Sweep Models
student_kd_model1 = tf.keras.models.load_model('MHIST_TRAINED_MODEL/student_kd1.h5')
student_kd_model2 = tf.keras.models.load_model('MHIST_TRAINED_MODEL/student_kd2.h5')
student_kd_model16 = tf.keras.models.load_model('MHIST_TRAINED_MODEL/student_kd16.h5')
student_kd_model32 = tf.keras.models.load_model('MHIST_TRAINED_MODEL/student_kd32.h5')
student_kd_model64 = tf.keras.models.load_model('MHIST_TRAINED_MODEL/student_kd64.h5')

# TAKD Models
takd_teacher_model = tf.keras.models.load_model('MHIST_TRAINED_MODEL/takd_teacher.h5')
takd_ta_model = tf.keras.models.load_model('MHIST_TRAINED_MODEL/takd_ta.h5')
takd_student_takd_model = tf.keras.models.load_model('MHIST_TRAINED_MODEL/takd_student_takd.h5')
takd_student_teacherkd_model = tf.keras.models.load_model('MHIST_TRAINED_MODEL/takd_student_teacherkd.h5')

# No TF
student_notf_model = tf.keras.models.load_model('MHIST_TRAINED_MODEL/student_notf.h5')



# Evaluation

## FLOP/Parameter Generation

In [4]:
for model in [teacher_model, student_kd_model, student_scratch_model, takd_teacher_model, takd_ta_model]:
    print(model.summary());
    flops = get_flops(model, batch_size=1)
    print(f"FLOPS: {flops / 10 ** 9:.03} G")
    print('\n')

Model: "sequential_10"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resnet50v2 (Functional)     (None, 2048)              23564800  
                                                                 
 dense_10 (Dense)            (None, 2)                 4098      
                                                                 
Total params: 23,568,898
Trainable params: 23,523,458
Non-trainable params: 45,440
_________________________________________________________________
None
Instructions for updating:
Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`
FLOPS: 6.99 G


Model: "sequential_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 mobilenetv2_1.00_224 (Funct  (None, 1280)             2257984   
 ional)                                                          
                                   

# UAC Score Generation

In [18]:
eval_generator.reset()
image_batch,label_batch=eval_generator.next()

model_list = [teacher_model, student_kd_model, student_scratch_model, 
              student_kd_model1, student_kd_model2, student_kd_model16,
              student_kd_model32, student_kd_model64, 
              takd_teacher_model, takd_ta_model, takd_student_takd_model,
              takd_student_teacherkd_model, student_notf_model]
model_name_list = ['teacher_model', 'student_kd_model', 'student_scratch_model', 
                   'student_kd_model1', 'student_kd_model2', 'student_kd_model16',
                   'student_kd_model32', 'student_kd_model64', 
                   'takd_teacher_model', 'takd_ta_model', 'takd_student_takd_model',
                   'takd_student_teacherkd_model', 'student_notf_model']

model_auc_dict = dict()

for model_idx, model in enumerate(model_list):
    
    y_score = model.predict(image_batch)
    y_pred = np.argmax(y_score, axis=1)
    
    roc_auc = [];
    for class_idx in range(NUM_CLASSES):   
        fpr, tpr, _ = metrics.roc_curve(label_batch[:, class_idx], y_score[:, class_idx])
        roc_auc.append(metrics.auc(fpr, tpr))
    model_auc_dict[model_name_list[model_idx]] = roc_auc
    print(f'{model_name_list[model_idx]} AUC = {model_auc_dict[model_name_list[model_idx]]}')

teacher_model AUC = [0.9278386008400469, 0.9274943193555051]
student_kd_model AUC = [0.4782414101769607, 0.5207946016663223]
student_scratch_model AUC = [0.7522550437237485, 0.7540453074433656]
student_kd_model1 AUC = [0.7220615575294361, 0.6962060180403498]
student_kd_model2 AUC = [0.4647800041313779, 0.6077945328100255]
student_kd_model16 AUC = [0.4926323762308063, 0.5073676237691936]
student_kd_model32 AUC = [0.4135509192315637, 0.5872065000344282]
student_kd_model64 AUC = [0.48402533911726225, 0.5166287957033671]
takd_teacher_model AUC = [0.9132410658954762, 0.9151690422089102]
takd_ta_model AUC = [0.3626661158162914, 0.6373338841837086]
takd_student_takd_model AUC = [0.6028713075810783, 0.39461543758176687]
takd_student_teacherkd_model AUC = [0.5074709082145562, 0.4926323762308063]
student_notf_model AUC = [0.5, 0.5]
