In [None]:
#LOAD DEPENDENCIES
import logging
import math
import os
import pickle
import time

import numpy as np
import tensorflow as tf

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model, model_from_json
from tensorflow.keras import applications, Model, layers
from tensorflow.keras.metrics import CategoricalAccuracy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.losses import KLDivergence, CategoricalCrossentropy
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.layers import Input, Conv2D, GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.layers import DepthwiseConv2D, Activation, BatchNormalization, Layer

# Run name: save_m / save_h / load_h use it as the file stem of the saved model and history
model_architecture = "proposed_model_kd"

#SUPPRESS UNNECESSARY LOG MESSAGES (only errors are shown)
tf.get_logger().setLevel(logging.ERROR)
# NOTE(review): this variable is assigned after `import tensorflow` has already run; it presumably
# has to be set before the import to silence TensorFlow's native start-up logs -- TODO confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

In [None]:
#Dataset root folder (the train/validation/test sub-folders are derived from it)
data_path = "dataset"

#Names of the student and teacher networks (also used as file stems when loading them)
student_name = 'proposed_model'
teacher_name = 'EfficientNetB7'

#Model folders, all below one common root
models_root = 'models/'
distilled_student_model_path = models_root + 'Distilled_Student_Models/'
teacher_model_path = models_root + 'teacher_model/'
student_model_path = models_root + 'proposed_model/'
kd_model_path = models_root + 'kd_model/'

#Kind of model handled by this notebook; model_kd is an alias of the same value
model_kind = 'KD_model'
model_kd = model_kind

# NOTE(review): this resolves to "models/KD_model/", which differs in letter case from kd_model_path
model_path = models_root + model_kind + '/'

#Custom Functions

#Load Model Functions
def _load_json_model(directory, name):
    """Rebuild a model from `<directory>/<name>.json` and load `<directory>/<name>.h5` into it.

    Args:
        directory: folder holding both the architecture JSON and the weights file.
        name: shared file stem of the `.json` and `.h5` files.

    Returns:
        The model created by `model_from_json` with the weights loaded.
    """
    with open(directory + '/' + name + '.json', "r") as json_file:
        model = model_from_json(json_file.read())
    # The JSON handle is no longer needed once the architecture has been parsed
    model.load_weights(directory + '/' + name + '.h5')
    return model

def load_m_teacher(directory):
    """Load the teacher network (file stem `teacher_name`) from `directory`."""
    return _load_json_model(directory, teacher_name)

def load_m_student(directory):
    """Load the student network (file stem `student_name`) from `directory`."""
    return _load_json_model(directory, student_name)

#Load History
def load_h(file):
    """Load a pickled training history previously written by `save_h`.

    Args:
        file: run name; the history is read from
            `<kd_model_path><file>/<file>.history`, i.e. the location that
            `save_h(kd_model_path + file, ...)` writes to.

    Returns:
        The unpickled history object.
    """
    # Built from kd_model_path and the argument so that the read location matches the
    # save location (the previous 'models/' + model_kind prefix differed in letter case).
    history_file = kd_model_path + file + '/' + file + '.history'
    # NOTE: pickle can execute arbitrary code on load -- only open history files you created yourself.
    with open(history_file, 'rb') as file_pi:
        his = pickle.load(file_pi)
    return his

#Save Model Function
def save_m(directory, model):
    """Save `model` as `<directory>/<model_architecture>.h5`, creating `directory` when it is missing."""
    directory_missing = not os.path.exists(directory)
    if directory_missing:
        os.makedirs(directory)
    file_stem = directory + '/' + model_architecture
    model.save(file_stem + '.h5')
    print("model saved")

#Save History Function
def save_h(directory, his):
    """Pickle `his` to `<directory>/<model_architecture>.history`, creating `directory` when it is missing."""
    directory_missing = not os.path.exists(directory)
    if directory_missing:
        os.makedirs(directory)
    history_file = directory + '/' + model_architecture + '.history'
    with open(history_file, 'wb') as out_stream:
        pickle.dump(his, out_stream)
    print("history saved")

#Save Figure Function
def save_fig(directory, fig_name):
    """Save the current figure as `<directory>/<fig_name>.tiff` (600 dpi, tight bounding box)."""
    # NOTE(review): `plt` is not imported in any visible cell; this call needs
    # matplotlib.pyplot to be available under that name -- TODO confirm.
    directory_missing = not os.path.exists(directory)
    if directory_missing:
        os.makedirs(directory)
    figure_file = directory + '/' + fig_name + '.tiff'
    plt.savefig(figure_file, bbox_inches='tight', dpi=600, format='tiff')
    
#Get data from generator function
def get_data(generator, nb_samples):
    """Collect the batches of an indexable generator into two in-memory numpy arrays.

    Args:
        generator: object indexable by batch number whose items start with
            `(inputs, targets)`.
        nb_samples: total number of samples; together with the global
            `batch_size` it determines how many batches are read.

    Returns:
        Tuple `(x, y)` of numpy arrays holding every collected input and target.
    """
    from tqdm.notebook import tqdm

    nb_batches = math.ceil(nb_samples / batch_size)
    x = []
    y = []

    for i in tqdm(range(nb_batches)):
        # Fetch each batch once: indexing the generator a second time would load
        # and preprocess the same images again just to read the targets.
        batch = generator[i]
        x.extend(batch[0])
        y.extend(batch[1])

    return np.array(x), np.array(y)

#Create generator from data function
def get_generator(x, y, preprocessing_function=None, rescale=None, shuffle=True,):
    """Wrap in-memory arrays `x`, `y` in an `ImageDataGenerator` flow.

    Batches are sized by the global `batch_size`; `rescale` and
    `preprocessing_function` are forwarded to `ImageDataGenerator`, and
    `shuffle` is forwarded to `flow`.
    """
    image_pipeline = ImageDataGenerator(
        preprocessing_function=preprocessing_function,
        rescale=rescale,
    )
    return image_pipeline.flow(x, y, shuffle=shuffle, batch_size=batch_size)

In [None]:
#LOAD THE DATA

# Class folder names; this list is handed to the directory generators as `classes=`
class_names = [
    '0_Non_vectors',
    '1_Aedes_albopictus',
    '2_Aedes_vexans',
    '3_Anopheles_sinensis',
    '4_Culex_pipiens',
    '5_Culex_tritaeniorhynchus',
]

# One short numeric label per class (presumably confusion-matrix tick labels)
cm_target_names = [str(class_index) for class_index in range(len(class_names))]

print("Class names:", class_names, end="\n\n")

# Split folders below the dataset root
train_data_dir, validation_data_dir, test_data_dir = (
    data_path + split_folder for split_folder in ("/train/", "/validation/", "/test/")
)

#Image specifications and handling
batch_size = 16
img_rows = img_cols = 224
input_shape = (img_rows, img_cols, 3)

model_input = Input(shape=input_shape)
# NOTE(review): this message is printed unconditionally; no folder check happens in this cell
print("Data folders found!", end="\n\n")
print("The Input size is set to ", model_input)

In [None]:
#Output folder for figures, grouped by model kind
fig_path = 'figures/' + model_kind + '/'

In [None]:
#DATA GENERATORS
epochs = 30

# One generator object per split; all three apply the same preprocess_input function
train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# Keyword arguments shared by the train and validation directory flows
flow_kwargs = dict(
    target_size=(img_rows, img_cols),
    batch_size=batch_size,
    class_mode='categorical',
    seed=42,
    classes=class_names,
)

train_generator = train_datagen.flow_from_directory(train_data_dir, **flow_kwargs)

# shuffle=False keeps the validation files in a fixed order
validation_generator = val_datagen.flow_from_directory(
    validation_data_dir,
    shuffle=False,
    **flow_kwargs,
)

#CHECK THE NUMBER OF SAMPLES
nb_train_samples = len(train_generator.filenames)
nb_validation_samples = len(validation_generator.filenames)

if nb_train_samples == 0:
    print("NO DATA TRAIN FOUND! Please check your train data path and folders!")
else:
    print(nb_train_samples, "Train samples found!")

if nb_validation_samples == 0:
    print("NO DATA VALIDATION FOUND! Please check your validation data path and folders!")
    print("Check the data folders first!")
else:
    print(nb_validation_samples, "Validation samples found!")

#check the class indices
# A bare expression in the middle of a cell displays nothing, so print the mapping
# and verify that both generators agree on it.
print("Class indices:", train_generator.class_indices)
assert train_generator.class_indices == validation_generator.class_indices, \
    "Train and validation generators disagree on the class-to-index mapping"

#true labels of the validation split (integer class index per file)
Y_test=validation_generator.classes

num_classes = len(train_generator.class_indices)

print('Model set to train', num_classes, 'classes')

# Compare both counts explicitly; `a and b > 0` only compares b against zero
if nb_train_samples > 0 and nb_validation_samples > 0:
    print("Generators are set!")
    print("Check if dataset is complete and has no problems before proceeding.")

# Choose the loss name and classifier activation from the number of classes:
# two classes -> binary setup, anything else -> categorical setup.
loss, activation_classifier = (
    ('binary_crossentropy', 'sigmoid')
    if num_classes == 2
    else ('categorical_crossentropy', 'softmax')
)
print("loss function is set to:", loss)
print("activation classifier is set to:", activation_classifier)

In [None]:
#Credits to Kenneth Borup https://keras.io/examples/vision/knowledge_distillation/

class KDistiller(Model):
    """Knowledge-distillation trainer that wraps a student and a teacher model.

    `train_step` updates only the student's trainable variables; the teacher is
    called with `training=False` and merely provides targets for the
    distillation loss. `test_step` and `call` use the student alone.
    """

    def __init__(self, student, teacher):
        """Store the student (to be trained) and the teacher (used for soft targets)."""
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha,
        temperature,
    ):
        """Configure the distiller.

        Args:
            optimizer: optimizer applied to the student's trainable variables.
            metrics: metrics updated with the true labels and student predictions.
            student_loss_fn: loss between true labels and student predictions.
            distillation_loss_fn: loss between the temperature-softened teacher
                and student predictions.
            alpha: weight of `student_loss_fn`; the distillation loss is
                weighted by `1 - alpha`.
            temperature: divisor applied to both predictions before the softmax.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        """Run one optimisation step on the student and return the logged values."""
        # Unpack data
        x, y = data

        # Forward pass of teacher (outside the tape: no gradients are taken for it)
        teacher_preds = self.teacher(x, training=False) #Soft labels

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_preds = self.student(x, training=True) #Soft predictions

            # Compute losses
            student_loss = self.student_loss_fn(y, student_preds) #Loss against the true labels
            # NOTE(review): dividing by the temperature and applying softmax treats both
            # predictions as logits, while student_loss_fn receives student_preds unchanged --
            # confirm whether the models output logits or probabilities.
            distillation_loss = self.distillation_loss_fn( #Loss against the softened teacher output
                tf.nn.softmax(teacher_preds / self.temperature, axis=1),
                tf.nn.softmax(student_preds / self.temperature, axis=1),
            )
            # Total loss: weighted sum of both terms
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients (student variables only, so the teacher is never updated)
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics
        self.compiled_metrics.update_state(y, student_preds)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss, 'combined_loss':loss}
        )

        return results

    def test_step(self, data):
        """Evaluate the student on one batch; only the student loss is reported."""
        # Unpack the data
        x, y = data

        # Compute predictions
        y_preds = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_preds)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_preds)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

    def call(self, inputs, training=None):
        """Forward `inputs` through the student.

        `training` defaults to None so the distiller can also be called without
        an explicit training flag.
        """
        return self.student(inputs, training=training)

In [None]:
#Load the saved teacher model from <teacher_model_path><teacher_name>/<teacher_name>.h5
teacher_model_file = teacher_model_path + teacher_name + '/' + teacher_name + '.h5'
teacher_model = load_model(teacher_model_file)

print("Teacher model", teacher_model.name, "successfully loaded")

In [None]:
#Inspect the teacher architecture (layers and parameter counts)
teacher_model.summary()

In [None]:
#Prepare validation generator for sanity check (fresh, unshuffled pass over the validation folder)
validation_generator = val_datagen.flow_from_directory(
        validation_data_dir,
        target_size=(img_rows,img_cols),
        batch_size=batch_size,
        class_mode='categorical',
        seed=42,
        shuffle=False,
        classes=class_names)

#Sanity check the teacher model first
# steps must be a whole number of batches: round up so the last partial batch is included.
# batch_size is not passed because the generator already yields batches.
teacher_model.evaluate(validation_generator,
                    steps=math.ceil(nb_validation_samples / batch_size))

In [None]:
#Load the saved student model from <student_model_path><student_name>/<student_name>.h5
student_model_file = student_model_path + student_name + '/' + student_name + '.h5'
student_model = load_model(student_model_file)

print("Student model", student_model.name, "successfully loaded")

In [None]:
#Inspect the student architecture (layers and parameter counts)
student_model.summary()

In [None]:
#Get training data in x, y format for distillation
# get_data walks each generator batch by batch and stacks everything into numpy arrays,
# so both splits are held in memory after this cell.
print("Loading Training Data")
x_train, y_train = get_data(train_generator, nb_train_samples)

print("Loading Validation Data")
x_val, y_val = get_data(validation_generator, nb_validation_samples)

In [None]:
#Set training constants
epochs = 30
optimizer = Adam            # optimizer class; instantiated later with learning_rate
learning_rate = 0.001
alpha=0.3                   # weight of the student loss; (1 - alpha) weights the distillation loss
temperature = 2             # divisor applied to predictions before the softmax in the distiller

print("Batch size is set to:", batch_size)
print("Epoch is set to:", epochs)
print("Loss is set to:", loss)

print("Temperature is set to:", temperature)
print("Alpha is set to:", alpha)
print("Learning rate is set to:", learning_rate)

print("Optimizer is set to:", optimizer.__name__)

In [None]:
# Distill teacher to student

#Create Knowledge distiller
distiller = KDistiller(student=student_model, teacher=teacher_model)

#Compile Knowledge distiller
kd_accuracy = CategoricalAccuracy()
distiller.compile(
    optimizer = optimizer(learning_rate=learning_rate),
    metrics=[kd_accuracy],
    student_loss_fn=CategoricalCrossentropy(),
    distillation_loss_fn= KLDivergence(),
    alpha=alpha,
    temperature=temperature,
    )

# The distiller logs its metric under the metric's own name, so the validation key is
# 'val_' + that name (there is no 'val_acc'). Accuracy should increase, hence mode='max'.
reduce_lr = ReduceLROnPlateau(monitor='val_' + kd_accuracy.name, factor=0.5, patience=2,
                              verbose=1, mode='max', min_lr=0.000001)

callbacks = [reduce_lr]

#Get train and val generator
# NOTE(review): x_train / x_val were read from generators that already applied preprocess_input,
# so it is applied a second time here -- presumably harmless only if it is a pass-through; confirm.
train_generator = get_generator(x_train, y_train, preprocessing_function=preprocess_input)
# Validation data does not need shuffling
validation_generator = get_generator(x_val, y_val, preprocessing_function=preprocess_input,
                                     shuffle=False)

#Training
distiller_history = distiller.fit(train_generator,
                                  validation_data = validation_generator,
                                  steps_per_epoch = nb_train_samples // batch_size,
                                  # round up so the last partial validation batch is evaluated too
                                  validation_steps = math.ceil(nb_validation_samples / batch_size),
                                  callbacks=callbacks,
                                  epochs=epochs)

In [None]:
#Take the trained student network out of the distiller wrapper
KD_student = distiller.student

In [None]:
#Compile the KD student for a sanity check
# NOTE(review): the loss here uses from_logits=True, whereas the distiller's student loss was
# created as CategoricalCrossentropy() without it. Both are applied to the same student output,
# so one of the two settings is presumably wrong -- confirm against the student's final activation.
KD_student.compile(
          optimizer = optimizer(learning_rate=learning_rate),
          loss = CategoricalCrossentropy(from_logits=True),
          metrics = [CategoricalAccuracy()]
    )

In [None]:
#Prepare the validation generator for a sanity check
# Re-created from the validation folder because the distillation cell rebound
# `validation_generator` to an in-memory flow.
validation_generator = val_datagen.flow_from_directory(
        validation_data_dir,
        target_size=(img_rows,img_cols),
        batch_size=batch_size,
        class_mode='categorical',
        seed=42,
        shuffle=False,
        classes=class_names)

In [None]:
#Evaluate the KD student
# steps must be a whole number of batches: round up so the last partial batch is included.
# batch_size is not passed because the generator already yields batches.
KD_student.evaluate(validation_generator,
                    steps=math.ceil(nb_validation_samples / batch_size))

In [None]:
#Save the distilled student as <kd_model_path><model_architecture>/<model_architecture>.h5
save_m(kd_model_path + model_architecture, KD_student)

In [None]:
#Reload the saved file so the following cells check the persisted model, not the in-memory student
model = load_model(kd_model_path + model_architecture + '/' + model_architecture + '.h5')

In [None]:
#Pickle the per-epoch training history next to the saved model
save_h(kd_model_path + model_architecture, distiller_history.history)

In [None]:
#Read the pickled training history back from disk via load_h
history = load_h(model_architecture)

In [None]:
#Inspect the architecture of the reloaded KD model
model.summary()

In [None]:
#Perform another sanity check to make sure that the results did not change
# steps must be a whole number of batches: round up so the last partial batch is included.
# batch_size is not passed because the generator already yields batches.
model.evaluate(validation_generator,
               steps=math.ceil(nb_validation_samples / batch_size))