In [2]:
from tensorflow.keras.metrics import CategoricalAccuracy
import tensorflow as tf
from data_loader import get_loader
from tensorflow import keras
import matplotlib.pyplot as plt
from tensorflow.keras.utils import to_categorical
import numpy as np
from models.distances import Euclidean_Distance
from models.stn import BilinearInterpolation,Localization
from sklearn.neighbors import KNeighborsClassifier
from time import time
import os
from models.metrics import accuracy
from tensorflow.keras.metrics import Mean
from sklearn.metrics.pairwise import euclidean_distances as dist
from models.distances import Weighted_Euclidean_Distance
from models.metrics import loss_mse

In [3]:
class Distiller(keras.Model):
    """Train a student model to mimic a teacher via knowledge distillation.

    The teacher is only ever called in inference mode, and gradients are
    taken with respect to the student's trainable variables alone, so only
    the student's weights are updated.
    """

    def __init__(self, student, teacher):
        """Store the two wrapped models.

        Args:
            student: Model whose weights are optimised.
            teacher: Model that supplies the soft targets.
        """
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """ Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        # No `loss` is handed to Keras: both losses are combined by hand in
        # `train_step`, so only the optimizer and metrics are registered.
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        """Run one optimisation step on the student.

        Args:
            data: An `(x, y)` pair; `x` is forwarded unchanged to both the
                teacher and the student.

        Returns:
            Dict with the current value of every compiled metric plus the
            per-batch `student_loss` and `distillation_loss`.
        """
        x, y = data

        # The teacher runs outside the gradient tape and in inference mode,
        # so it contributes targets only.
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)

            # Hard-label loss against the ground truth.
            student_loss = self.student_loss_fn(y, student_predictions)

            # Compute scaled distillation loss from https://arxiv.org/abs/1503.02531
            # The magnitudes of the gradients produced by the soft targets scale
            # as 1/T^2, multiply them by T^2 when using both hard and soft targets.
            #
            # NOTE(review): this treats the model outputs as logits. The
            # `create_model` builder in this notebook ends in a softmax
            # activation, so the student (and presumably the teacher) already
            # emit probabilities -- confirm that re-applying a tempered
            # softmax here is intended.
            soft_teacher = tf.nn.softmax(
                teacher_predictions / self.temperature, axis=1
            )
            soft_student = tf.nn.softmax(
                student_predictions / self.temperature, axis=1
            )
            distillation_loss = (
                self.distillation_loss_fn(soft_teacher, soft_student)
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Only the student's variables are differentiated and updated.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # The two losses are reported as this batch's values, not running means.
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        """Evaluate the student on one batch.

        Args:
            data: An `(x, y)` pair; `x` is forwarded unchanged to the student,
                exactly as in `train_step`.

        Returns:
            Dict with the current value of every compiled metric plus the
            per-batch `student_loss`.
        """
        # Mirror `train_step`: do not assume a particular number of inputs,
        # so single- and multi-input students are handled the same way.
        x, y = data

        y_prediction = self.student(x, training=False)

        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

In [27]:
# Builders for the encoders used for feature extraction
from tensorflow import keras
from tensorflow.keras.layers import Input,Flatten,Dense,MaxPooling2D,Conv2D
from tensorflow.keras.layers import BatchNormalization, Activation, Dropout
from models.blocks import conv_block,dcp
from models.distances import Weighted_Euclidean_Distance
from models.stn import stn
from tensorflow.keras.applications import DenseNet121
from models.senet import Senet

def create_densenet(input_shape=(64, 64, 3)):
    """Build the image encoder: a truncated DenseNet121 followed by a small head.

    Args:
        input_shape: Shape of a single input image (height, width, channels).

    Returns:
        A `keras.Model` named 'encoder' whose output comes from a 150-unit
        dense layer followed by batch normalisation.
    """
    encoder_input = Input(shape=input_shape)

    # DenseNet121 initialised with ImageNet weights, cut off at the
    # 'conv3_block2_concat' layer so only the early stages are used.
    backbone = DenseNet121(
        include_top=False,
        weights='imagenet',
        input_shape=input_shape,
    )
    truncated_backbone = keras.Model(
        inputs=backbone.inputs,
        outputs=backbone.get_layer('conv3_block2_concat').output,
    )

    features = truncated_backbone(encoder_input)

    # 1x1 convolution reduces the channel count to 32 before the conv block.
    features = Conv2D(
        filters=32,
        kernel_size=(1, 1),
        padding='same',
        kernel_initializer='he_normal',
    )(features)
    features = conv_block(
        features, kernel_size=(3, 3), n_filters=32, strides=(1, 1)
    )
    features = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(features)

    # Embedding head: flatten -> dense(150) -> batch norm.
    features = Flatten()(features)
    features = Dense(units=150, kernel_initializer='he_normal')(features)
    features = BatchNormalization(axis=-1)(features)
    # 'linear' is the identity activation; the embedding is left unsquashed.
    features = Activation('linear')(features)

    return keras.Model(inputs=encoder_input, outputs=features, name='encoder')

def create_model(input_shape=(64, 64, 3)):
    """Build the two-input support/query model around one shared encoder.

    The same encoder instance embeds both inputs, the embeddings are compared
    by a `Weighted_Euclidean_Distance` layer, and a softmax is applied to the
    result. The encoder's summary is printed as a side effect.

    Args:
        input_shape: Shape of a single input image, used for both inputs.

    Returns:
        A `Senet` model taking `[support, query]` and returning the softmax
        output.
    """
    support_input = Input(shape=input_shape)
    query_input = Input(shape=input_shape)

    # One encoder, applied twice, so both branches share their weights.
    shared_encoder = create_densenet(input_shape=input_shape)
    shared_encoder.summary()

    embeddings = [shared_encoder(support_input), shared_encoder(query_input)]
    distances = Weighted_Euclidean_Distance()(embeddings)
    probabilities = Activation("softmax")(distances)

    return Senet(inputs=[support_input, query_input], outputs=probabilities)


In [38]:
# Build the student with create_model's default 64x64x3 input shape.
# create_model also prints the encoder summary as a side effect.
student = create_model()

Model: "encoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_19 (InputLayer)       [(None, 64, 64, 3)]       0         
                                                                 
 model_4 (Functional)        (None, 8, 8, 192)         494528    
                                                                 
 conv2d_8 (Conv2D)           (None, 8, 8, 32)          6176      
                                                                 
 conv2d_9 (Conv2D)           (None, 6, 6, 32)          9248      
                                                                 
 batch_normalization_20 (Bat  (None, 6, 6, 32)         128       
 chNormalization)                                                
                                                                 
 activation_24 (Activation)  (None, 6, 6, 32)          0         
                                                           

In [39]:
# Print the layer-by-layer summary of the full two-input student model.
student.summary()

Model: "senet_4"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_17 (InputLayer)          [(None, 64, 64, 3)]  0           []                               
                                                                                                  
 input_18 (InputLayer)          [(None, 64, 64, 3)]  0           []                               
                                                                                                  
 encoder (Functional)           (None, 150)          554030      ['input_17[0][0]',               
                                                                  'input_18[0][0]']               
                                                                                                  
 weighted__euclidean__distance_  (None, None)        1           ['encoder[0][0]',          

In [40]:
# Restore the saved teacher. The custom layers/model class must be supplied so
# Keras can deserialise them; compile=False skips restoring the saved
# training configuration (the teacher is compiled in a later cell).
teacher = keras.models.load_model(
    'model_files/best_models/densenet_gtsrb2tt100k_whole.h5',
    compile=False,
    custom_objects={
        'Weighted_Euclidean_Distance': Weighted_Euclidean_Distance,
        'BilinearInterpolation': BilinearInterpolation,
        'Localization': Localization,
        'Senet': Senet,
    },
)

In [41]:
# A single Adam instance; the same object is also passed to student.compile
# and distiller.compile in later cells.
optimizer_fn = keras.optimizers.Adam(learning_rate=1e-4, epsilon=1.0e-8)

# NOTE(review): `loss_fn` is not a standard keras.Model.compile argument;
# presumably Senet overrides compile -- confirm against models/senet.
teacher.compile(
    optimizer=optimizer_fn,
    loss_fn=loss_mse,
    metrics=CategoricalAccuracy(name='accuracy'),
)

In [42]:
# Compile the student with the same optimizer instance, loss and accuracy
# metric as the teacher.
student.compile(
    optimizer=optimizer_fn,
    loss_fn=loss_mse,
    metrics=CategoricalAccuracy(name='accuracy'),
)

In [43]:
# Print the teacher's summary for comparison with the student's.
teacher.summary()

Model: "senet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 64, 64, 3)]  0           []                               
                                                                                                  
 input_2 (InputLayer)           [(None, 64, 64, 3)]  0           []                               
                                                                                                  
 encoder (Functional)           (None, 300)          1585680     ['input_1[0][0]',                
                                                                  'input_2[0][0]']                
                                                                                                  
 weighted__euclidean__distance   (None, None)        1           ['encoder[0][0]',            

In [44]:
# Look up the 'gtsrb2tt100k' loader and build the train/test generators
# with batch=128 and dim=64.
loader = get_loader('gtsrb2tt100k')
train_gen, test_gen = loader.get_generator(
    batch=128,
    dim=64,
)

In [45]:
# Wrap the student/teacher pair for distillation training.
distiller = Distiller(teacher=teacher, student=student)

distiller.compile(
    # Hard-label loss and soft-target loss, mixed as
    # alpha * student + (1 - alpha) * distillation.
    student_loss_fn=loss_mse,
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    # Temperature used to soften both prediction distributions.
    temperature=3,
    optimizer=optimizer_fn,
    metrics=[keras.metrics.CategoricalAccuracy()],
)

In [46]:
# Train for `ep` epochs. After each epoch the student is evaluated on the test
# generator and saved whenever its accuracy beats the best seen so far.
best_test_acc = 0
ep = 10
for epoch in range(ep):
    print(f'====epoch{epoch+1}/{ep}====')
    distiller.fit(train_gen)

    # Read the accuracy by name rather than by position: the order and number
    # of values in evaluate()'s flat list depend on which metrics and extra
    # log entries test_step returns. 'categorical_accuracy' is the default
    # name of the CategoricalAccuracy metric passed to distiller.compile.
    eval_logs = distiller.evaluate(test_gen, verbose=0, return_dict=True)
    te_acc = eval_logs['categorical_accuracy']

    if te_acc > best_test_acc:
        best_test_acc = te_acc
        # Only the student is saved; the teacher is unchanged by training.
        student.save('best_student.h5')
    print(f'test accuracy: {te_acc:.4f}')
    print(f'best test accuracy: {best_test_acc:.4f}')

====epoch1/10====
test accuracy: 0.7112
best test accuracy: 0.7112
====epoch2/10====
test accuracy: 0.7385
best test accuracy: 0.7385
====epoch3/10====
test accuracy: 0.7447
best test accuracy: 0.7447
====epoch4/10====
test accuracy: 0.7382
best test accuracy: 0.7447
====epoch5/10====
test accuracy: 0.7214
best test accuracy: 0.7447
====epoch6/10====
test accuracy: 0.7245
best test accuracy: 0.7447
====epoch7/10====
test accuracy: 0.7539
best test accuracy: 0.7539
====epoch8/10====
test accuracy: 0.6993
best test accuracy: 0.7539
====epoch9/10====
test accuracy: 0.7267
best test accuracy: 0.7539
====epoch10/10====
test accuracy: 0.7260
best test accuracy: 0.7539


In [30]:
# Evaluate the teacher on the test generator as a reference for the student.
# NOTE(review): this cell's execution count (In [30]) predates the cells above
# it (In [38]-In [46]); re-run after Restart & Run All to confirm the figure.
teacher.evaluate(test_gen)



0.9348958134651184