## Transfer Learning and Knowledge Distillation
In this notebook, the goal is to apply knowledge distillation to a student whose architecture differs from the teacher's but whose complexity is comparable. To speed up the teacher's training, we use a pre-trained model.

In [None]:
# All the imports
import tensorflow as tf
from BANEnsemble import BANEnsemble
from knowledge_distillation import distill_knowledge, ban
from WideResNet import WideResidualNetwork

In [None]:
# Connection with the dataset in drive
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


## Preprocessing and Augmentation
To stay consistent with the previous experiments, the preprocessing and image augmentation are the same as before.

In [None]:
BATCH_SIZE = 32
IMAGE_SIZE = (96, 96)
data_path = '/content'

train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,           # map pixel values from [0, 255] to [0, 1]
    horizontal_flip=True,
    width_shift_range=0.1,    # shifts of up to 10% of the image width/height
    height_shift_range=0.1,
    rotation_range=0.05)      # expressed in degrees, so this rotation is negligible
# Validation and test images are only rescaled, never augmented
valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    data_path + '/train',
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical')
validation_generator = valid_datagen.flow_from_directory(
    data_path + '/valid',
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical')
test_generator = test_datagen.flow_from_directory(
    data_path + '/test',
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical')

Found 31316 images belonging to 225 classes.
Found 1125 images belonging to 225 classes.
Found 1125 images belonging to 225 classes.
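To make the augmentation settings above concrete, here is a small pure-Python sketch of what `rescale`, `horizontal_flip` and the shift ranges do to a single image. It is an illustration, not Keras' implementation: the helper names are hypothetical, the toy image is single-channel, and the real generator works on NumPy arrays, draws the actual shift at random within the range and fills the uncovered border (by default with `fill_mode='nearest'`).

```python
import random

IMAGE_SIZE = (96, 96)

def rescale(image, factor=1.0 / 255):
    """Multiply every pixel by `factor`, as rescale=1./255 does."""
    return [[px * factor for px in row] for row in image]

def horizontal_flip(image, rng):
    """Mirror the image left-to-right with probability 0.5 (horizontal_flip=True)."""
    if rng.random() < 0.5:
        return [row[::-1] for row in image]
    return image

def max_shift_pixels(shift_range, side):
    """A float shift range below 1 is a fraction of the image side, in pixels."""
    return shift_range * side

rng = random.Random(0)
toy = [[0, 128, 255], [255, 128, 0]]  # hypothetical 2x3 single-channel image
augmented = horizontal_flip(rescale(toy), rng)
print(augmented)
print(max_shift_pixels(0.1, IMAGE_SIZE[1]))  # width_shift_range=0.1 on a 96-pixel side
```

With a 96-pixel side, a shift range of 0.1 therefore allows shifts of at most about 9.6 pixels in each direction.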


## Teacher Model Training

I start by training the teacher, which this time is MobileNetV2 from `tf.keras.applications`; its size makes it a suitable teacher for the purposes of this notebook. To speed up training we rely on transfer learning, which, as the results below show, is enough to obtain good accuracy.
By default, `MobileNetV2` is initialized with the ImageNet weights (`weights='imagenet'`).

In [None]:
# `classes` is only used when include_top=True, so it is not passed here
base_model = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(96, 96, 3))


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_96_no_top.h5
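With `include_top=False`, the backbone returns a spatial feature map rather than class scores. MobileNetV2 downsamples its input by an overall stride of 32, so a 96x96 image becomes a 3x3 grid of 1280-channel features, which matches the `(None, 3, 3, 1280)` shape in the model summary. The head added below collapses that grid with global average pooling; this pure-Python sketch (hypothetical helper, nested lists instead of tensors) shows the operation:

```python
def global_average_pooling_2d(feature_map):
    """Average an H x W x C nested list over its spatial axes, giving C values."""
    height, width = len(feature_map), len(feature_map[0])
    channels = len(feature_map[0][0])
    sums = [0.0] * channels
    for row in feature_map:
        for pixel in row:
            for c, value in enumerate(pixel):
                sums[c] += value
    return [s / (height * width) for s in sums]

INPUT_SIDE, OUTPUT_STRIDE, CHANNELS = 96, 32, 1280
side = INPUT_SIDE // OUTPUT_STRIDE  # spatial side of the backbone output
# Toy feature map: every channel at grid position (r, c) holds the value r * 3 + c
fmap = [[[float(r * side + c)] * CHANNELS for c in range(side)] for r in range(side)]
pooled = global_average_pooling_2d(fmap)
print(side, len(pooled), pooled[0])  # 3 1280 4.0
```

Each of the 1280 channels is reduced to the mean of its 9 spatial values, so the Dense layer that follows sees a fixed-length vector regardless of the grid size.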


In [None]:
# Load a previously saved teacher; note that teacher_model is rebuilt from scratch further below
teacher_model = tf.keras.models.load_model('mobile_net.h5')

In [None]:
# Freeze the whole pretrained backbone for the first training phase
base_model.trainable = False


In [None]:
# Add the classification head on top of the frozen backbone
inputs = tf.keras.layers.Input(shape=(96, 96, 3))
# training=False keeps the backbone's BatchNormalization layers in inference mode
x = base_model(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(225)(x)             # one logit per class
x = tf.keras.layers.Activation('softmax')(x)

teacher_model = tf.keras.models.Model(inputs, x)

teacher_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
teacher_model.summary()

Model: "functional_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_21 (InputLayer)        [(None, 96, 96, 3)]       0         
_________________________________________________________________
mobilenetv2_1.00_96 (Functio (None, 3, 3, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d_10  (None, 1280)              0         
_________________________________________________________________
dense_20 (Dense)             (None, 225)               288225    
_________________________________________________________________
activation_12 (Activation)   (None, 225)               0         
Total params: 2,546,209
Trainable params: 288,225
Non-trainable params: 2,257,984
_________________________________________________________________
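The trainable parameter count in the summary comes entirely from the new Dense layer, since the pooling and activation layers have no weights and the backbone is frozen. A quick check of the arithmetic (the helper name is hypothetical):

```python
def dense_params(in_features, units):
    """A Dense layer has one weight per (input, unit) pair plus one bias per unit."""
    return in_features * units + units

head = dense_params(1280, 225)
backbone = 2_257_984  # frozen MobileNetV2 weights, copied from the summary
print(head, backbone + head)  # 288225 2546209
```

Both numbers match the "Trainable params" and "Total params" lines of the summary.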


In [None]:

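# First phase: train only the new classification head.
# batch_size is not passed because the generator already yields batches of BATCH_SIZE
history = teacher_model.fit(train_generator,
                            validation_data=validation_generator,
                            steps_per_epoch=train_generator.samples // BATCH_SIZE,
                            validation_steps=validation_generator.samples // BATCH_SIZE,
                            epochs=15)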
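The `steps_per_epoch` and `validation_steps` values use floor division, so a few images are left out of each pass. Plugging in the counts reported by `flow_from_directory` (a plain arithmetic sketch, no Keras needed):

```python
import math

BATCH_SIZE = 32
train_samples, valid_samples = 31316, 1125  # counts reported by flow_from_directory

steps_per_epoch = train_samples // BATCH_SIZE
validation_steps = valid_samples // BATCH_SIZE
print(steps_per_epoch, train_samples - steps_per_epoch * BATCH_SIZE)    # 978 20
print(validation_steps, valid_samples - validation_steps * BATCH_SIZE)  # 35 5
print(math.ceil(valid_samples / BATCH_SIZE))                              # 36
```

Each validation pass therefore covers 1,120 of the 1,125 images. Rounding up instead (which is what `len(validation_generator)` returns for a Keras directory iterator) would include the final partial batch of 5 images.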

In [None]:
teacher_model.evaluate(test_generator)

In [None]:
# Fine-tuning: unfreeze the backbone
base_model.trainable = True

In [None]:
print(f"The base_model has: {len(base_model.layers)} layers")

The base_model has: 155 layers


In [None]:
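fine_tune_at = 100  # the first 100 layers of the backbone stay frozen
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False
# Recompile so the new trainable flags take effect (the 1e-5 learning rate is an assumed value)
teacher_model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
                      loss='categorical_crossentropy', metrics=['accuracy'])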
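Unfreezing the whole backbone and then re-freezing the first `fine_tune_at` layers leaves only the top of MobileNetV2 trainable. The following pure-Python simulation of the trainable flags (no Keras objects involved) shows which layers end up being updated:

```python
NUM_LAYERS = 155   # len(base_model.layers), as printed above
FINE_TUNE_AT = 100

# Simulate base_model.trainable = True followed by freezing layers[:FINE_TUNE_AT]
trainable = [True] * NUM_LAYERS
for i in range(FINE_TUNE_AT):
    trainable[i] = False

print(trainable.count(False), trainable.count(True))  # 100 55
print(trainable.index(True))                          # 100
```

Note that the 155 layers include weightless ones (activations, additions, padding), so 55 trainable layers is not the same as 55 weight tensors. Keras also only applies changes to `trainable` once the model is compiled again, and the Keras fine-tuning guide recommends a low learning rate at this stage.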

In [None]:
teacher_model.summary()

Model: "functional_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_21 (InputLayer)        [(None, 96, 96, 3)]       0         
_________________________________________________________________
mobilenetv2_1.00_96 (Functio (None, 3, 3, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d_10  (None, 1280)              0         
_________________________________________________________________
dense_20 (Dense)             (None, 225)               288225    
_________________________________________________________________
activation_12 (Activation)   (None, 225)               0         
Total params: 2,546,209
Trainable params: 288,225
Non-trainable params: 2,257,984
_________________________________________________________________


In [None]:
calls = [
    tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),
    tf.keras.callbacks.ModelCheckpoint('densenet_teacher.h5', save_best_only=True, save_weights_only=True)
]
# Second phase: fine-tune the top of the backbone together with the head
history = teacher_model.fit(train_generator,
                            validation_data=validation_generator,
                            steps_per_epoch=train_generator.samples // BATCH_SIZE,
                            validation_steps=validation_generator.samples // BATCH_SIZE,
                            epochs=100,
                            callbacks=calls)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
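The log above stops after epoch 6. Both callbacks monitor `val_loss` by default, and with `patience=5` a run of exactly 6 epochs is consistent with the first epoch having the best validation loss, which `restore_best_weights=True` then restores. A simplified pure-Python model of that stopping rule (hypothetical helper and made-up loss values, ignoring `min_delta` and `baseline`):

```python
def epochs_until_early_stop(losses, patience):
    """Simplified EarlyStopping: stop once `patience` epochs pass without a new best."""
    best, best_epoch, wait = float("inf"), 0, 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(losses), best_epoch

# Hypothetical val_loss values in which the first epoch is the best one
val_losses = [0.40, 0.42, 0.45, 0.41, 0.44, 0.43, 0.39]
print(epochs_until_early_stop(val_losses, patience=5))  # (6, 1)
```

The improvement at epoch 7 in this made-up sequence is never reached, which is the trade-off a small patience makes.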


In [None]:
teacher_model.evaluate(test_generator)



[0.34798598289489746, 0.9111111164093018]

## Student model

In this section we build a student whose complexity is comparable to the teacher's (about 2.8M parameters versus 2.5M) and apply knowledge distillation as in the paper.
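The notebook does not show how `distill_knowledge` is implemented, so the following is only a minimal pure-Python sketch of the classic soft-target loss of Hinton et al. (2015), assuming that formulation: the student is trained on a mix of the hard-label cross-entropy and the KL divergence between temperature-softened teacher and student distributions. The helper names, the temperature of 4.0 and the weight `alpha=0.1` are illustrative assumptions, not values taken from `distill_knowledge`.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; higher temperatures give softer distributions."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def distillation_loss(student_logits, teacher_logits, label, temperature=4.0, alpha=0.1):
    """Assumed Hinton-style loss: alpha * hard CE + (1 - alpha) * T^2 * soft KL."""
    hard = -math.log(softmax(student_logits)[label])
    soft = kl_divergence(softmax(teacher_logits, temperature),
                         softmax(student_logits, temperature))
    # The T^2 factor keeps soft-target gradients on a comparable scale as T changes
    return alpha * hard + (1 - alpha) * temperature ** 2 * soft

teacher = [4.0, 1.0, 0.5]  # hypothetical teacher logits for 3 classes
student = [2.0, 1.5, 0.2]  # hypothetical student logits
print(softmax(teacher), softmax(teacher, temperature=4.0))
print(distillation_loss(student, teacher, label=0))
```

The softened teacher distribution keeps information about how similar the wrong classes are to the right one, which a one-hot label discards.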

In [None]:
student_network = tf.keras.models.Sequential(
    [
     WideResidualNetwork(225, 16, 4, includeActivation=False),
     tf.keras.layers.Activation('softmax')
    ]
)
student_network.build(input_shape=(None, 96, 96, 3))
student_network.summary()

Model: "sequential_12"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
wide_residual_network_12 (Wi (None, 225)               2836845   
_________________________________________________________________
activation_25 (Activation)   (None, 225)               0         
Total params: 2,836,845
Trainable params: 2,833,639
Non-trainable params: 3,206
_________________________________________________________________
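To quantify "comparable complexity", the parameter totals from the two summaries can be compared directly. This only counts parameters; it says nothing about FLOPs or inference time, which the notebook does not measure.

```python
teacher_params = 2_546_209  # MobileNetV2 + Dense head, from the teacher summary
student_params = 2_836_845  # WideResidualNetwork(225, 16, 4), from the summary above

ratio = student_params / teacher_params
print(f"student/teacher parameters: {ratio:.2f}")  # student/teacher parameters: 1.11
print(2_833_639 + 3_206 == student_params)         # True
```

The student is therefore about 11% larger than the teacher, although, unlike the teacher, almost all of its parameters are trainable from the start.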


In [None]:
student_network.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

In [None]:
BATCH_SIZE = 32
callbacks = [
    tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True, monitor='loss'),
    tf.keras.callbacks.ModelCheckpoint('student.h5', save_weights_only=True, save_best_only=True)
]
# Training arguments passed to distill_knowledge
fit_args = dict(
    epochs=100,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // BATCH_SIZE,
    steps_per_epoch=train_generator.samples // BATCH_SIZE,
    callbacks=callbacks,
)
s_history = distill_knowledge(teacher_model, student_network, train_generator, fit_args=fit_args)

## Results

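The results on the test set are encouraging even though teacher and student have different architectures: the student reaches about 93.3% accuracy, compared with about 91.1% for the fine-tuned teacher.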

In [None]:
teacher_model.evaluate(test_generator)
student_network.evaluate(test_generator)  # only this last result is displayed below




[0.23009125888347626, 0.9333333373069763]
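In a notebook only the value of the last expression in a cell is displayed, so the list above is the student's `[loss, accuracy]`; the teacher's accuracy comes from its earlier evaluation after fine-tuning. Converting both accuracies back into image counts on the 1,125 test images makes the gap concrete:

```python
TEST_IMAGES = 1125  # "Found 1125 images" for the test split
teacher_acc = 0.9111111164093018  # teacher evaluation after fine-tuning
student_acc = 0.9333333373069763  # last value displayed in the cell above

teacher_correct = round(teacher_acc * TEST_IMAGES)
student_correct = round(student_acc * TEST_IMAGES)
print(teacher_correct, student_correct, student_correct - teacher_correct)  # 1025 1050 25
```

The student classifies 25 more test images correctly. This comes from a single run on a small test set, so the difference should be read as indicative rather than conclusive.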