## Transfer Learning and Knowledge Distillation
 In this notebook, the goal is to apply knowledge distillation to a student other than the master but with comparable complexity. To speed up the teacher's training process we will use a pre-trained model.

In [1]:
#All the imports
import tensorflow as tf
from BANEnsemble import BANEnsemble
from knowledge_distillation import distill_knowledge, ban
from WideResNet import WideResidualNetwork

In [2]:
# Connection with the dataset in drive
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [3]:
! unzip /content/drive/My\ Drive/datasets/bird_species.zip 

[1;30;43mOutput streaming troncato alle ultime 5000 righe.[0m
  inflating: train/SUPERB STARLING/069.jpg  
  inflating: train/SUPERB STARLING/070.jpg  
  inflating: train/SUPERB STARLING/071.jpg  
  inflating: train/SUPERB STARLING/072.jpg  
  inflating: train/SUPERB STARLING/073.jpg  
  inflating: train/SUPERB STARLING/074.jpg  
  inflating: train/SUPERB STARLING/075.jpg  
  inflating: train/SUPERB STARLING/076.jpg  
  inflating: train/SUPERB STARLING/077.jpg  
  inflating: train/SUPERB STARLING/078.jpg  
  inflating: train/SUPERB STARLING/079.jpg  
  inflating: train/SUPERB STARLING/080.jpg  
  inflating: train/SUPERB STARLING/081.jpg  
  inflating: train/SUPERB STARLING/082.jpg  
  inflating: train/SUPERB STARLING/083.jpg  
  inflating: train/SUPERB STARLING/084.jpg  
  inflating: train/SUPERB STARLING/085.jpg  
  inflating: train/SUPERB STARLING/086.jpg  
  inflating: train/SUPERB STARLING/087.jpg  
  inflating: train/SUPERB STARLING/088.jpg  
  inflating: train/SUPERB STARLING/0

## Preprocessing and Augmentation
In order to maintain consistency with previous tests, the preprocessing and image augmentation will be the same

In [4]:
BATCH_SIZE = 32
IMAGE_SIZE = (96, 96)
data_path = '/content'

train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        rescale=1./255,
        horizontal_flip=True,
        width_shift_range=0.1,
        height_shift_range=0.1,
        rotation_range=0.05        )  
valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    )

test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255
)
train_generator = train_datagen.flow_from_directory(
        data_path+'/train',
        target_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        )
validation_generator = valid_datagen.flow_from_directory(
        data_path+'/valid',
        target_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        )
test_generator = test_datagen.flow_from_directory(
        data_path+'/test',
        target_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical')

Found 31316 images belonging to 225 classes.
Found 1125 images belonging to 225 classes.
Found 1125 images belonging to 225 classes.


In [None]:
base_model = tf.keras.applications.MobileNetV2(include_top=False, input_shape=(96, 96, 3), classes=225)


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_96_no_top.h5


In [6]:
teacher_model = tf.keras.models.load_model('mobile_net.h5')

In [None]:
base_model.trainable=False


In [None]:
inputs = tf.keras.layers.Input(shape=(96, 96, 3))
x = base_model(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(225)(x)
x = tf.keras.layers.Activation('softmax')(x)

teacher_model = tf.keras.models.Model(inputs, x)

teacher_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
teacher_model.summary()

Model: "functional_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_21 (InputLayer)        [(None, 96, 96, 3)]       0         
_________________________________________________________________
mobilenetv2_1.00_96 (Functio (None, 3, 3, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d_10  (None, 1280)              0         
_________________________________________________________________
dense_20 (Dense)             (None, 225)               288225    
_________________________________________________________________
activation_12 (Activation)   (None, 225)               0         
Total params: 2,546,209
Trainable params: 288,225
Non-trainable params: 2,257,984
_________________________________________________________________


In [None]:

history = teacher_model.fit(train_generator, 
                            batch_size=BATCH_SIZE,
                            validation_data=validation_generator,
                            steps_per_epoch=train_generator.samples//BATCH_SIZE,
                            validation_steps=validation_generator.samples//BATCH_SIZE,
                            epochs=15)

Epoch 1/15
 22/978 [..............................] - ETA: 1:46 - loss: 0.1592 - accuracy: 0.9517

KeyboardInterrupt: ignored

In [7]:
#teacher_model.save('mobile_net.h5')
teacher_model.evaluate(test_generator)



[0.2923920154571533, 0.9182222485542297]

In [None]:
base_model.trainable=True

In [None]:
print(f"The base_model has: {len(base_model.layers)} layers")

The base_model has: 155 layers


In [None]:
fine_tune_at = 100

for layer in base_model.layers[:fine_tune_at]:
  layer.trainable=False

In [None]:
teacher_model.summary()

Model: "functional_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_21 (InputLayer)        [(None, 96, 96, 3)]       0         
_________________________________________________________________
mobilenetv2_1.00_96 (Functio (None, 3, 3, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d_10  (None, 1280)              0         
_________________________________________________________________
dense_20 (Dense)             (None, 225)               288225    
_________________________________________________________________
activation_12 (Activation)   (None, 225)               0         
Total params: 2,546,209
Trainable params: 288,225
Non-trainable params: 2,257,984
_________________________________________________________________


In [None]:
calls = [
             tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),
             tf.keras.callbacks.ModelCheckpoint('densenet_teacher.h5', save_best_only=True, save_weights_only=True)
]
history = teacher_model.fit(train_generator, 
                            batch_size=BATCH_SIZE,
                            validation_data=validation_generator,
                            steps_per_epoch=train_generator.samples//BATCH_SIZE,
                            validation_steps=validation_generator.samples//BATCH_SIZE,
                            epochs=100,
                            callbacks=calls)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100


In [None]:
teacher_model.evaluate(test_generator)



[0.34798598289489746, 0.9111111164093018]

In [None]:
teacher_model.layers[-4].trainable= True
teacher_model.summary()
for layer in teacher_model.layers[-4].layers[:fine_tune_at]:
  layer.trainable=False
teacher_model.summary()

Model: "functional_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_21 (InputLayer)        [(None, 96, 96, 3)]       0         
_________________________________________________________________
mobilenetv2_1.00_96 (Functio (None, 3, 3, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d_10  (None, 1280)              0         
_________________________________________________________________
dense_20 (Dense)             (None, 225)               288225    
_________________________________________________________________
activation_12 (Activation)   (None, 225)               0         
Total params: 2,546,209
Trainable params: 2,512,097
Non-trainable params: 34,112
_________________________________________________________________
Model: "functional_11"
_________________________________________________________________
Layer (type)   

In [20]:
student_network = tf.keras.models.Sequential(
    [
     WideResidualNetwork(225, 16, 4, includeActivation=False),
     tf.keras.layers.Activation('softmax')
    ]
)
student_network.build(input_shape=(None, 96, 96, 3))
student_network.summary()

Model: "sequential_12"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
wide_residual_network_12 (Wi (None, 225)               2836845   
_________________________________________________________________
activation_25 (Activation)   (None, 225)               0         
Total params: 2,836,845
Trainable params: 2,833,639
Non-trainable params: 3,206
_________________________________________________________________


In [24]:
student_network.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

In [27]:
BATCH_SIZE=32
callbacks = [
             tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True, monitor='loss'),
             tf.keras.callbacks.ModelCheckpoint('student.h5', save_weights_only=True, save_best_only=True)
]
fit_args = dict(
    epochs=100,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples//BATCH_SIZE,
    steps_per_epoch=train_generator.samples//BATCH_SIZE,
    callbacks = callbacks

)
distill_knowledge(teacher_model, student_network, train_generator, fit_args=fit_args)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

<tensorflow.python.keras.callbacks.History at 0x7f0b694da160>

In [31]:
teacher_model.evaluate(test_generator)
student_network.load_weights('student.h5')
student_network.evaluate(test_generator)




[0.23009125888347626, 0.9333333373069763]