In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
pip install tensorflow==2.3

In [None]:
import tensorflow as tf
tf.__version__

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import io
import cv2
import numpy as np
from os import listdir
from os.path import isfile, join

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.preprocessing import OneHotEncoder
from sklearn import svm
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV

#import keras
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.applications import ResNet50

from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, Conv2D, MaxPooling2D
from skimage.feature import hog
from skimage import data, exposure

import tensorflow_addons as tfa
import random
from tqdm import tqdm

In [None]:
# !pip install tensorflow_addons
# !pip install keras_applications 
# !pip install keras_preprocessing 
# !pip install git+https://github.com/rcmalli/keras-vggface.git

In [None]:
from tensorflow.keras import backend as K
def preprocess_input(x, data_format=None, version=1):
    """Reverse the channel axis of an image array and subtract per-channel means.

    Args:
        x: Image data. For ``'channels_last'`` the channel axis is the last
            one; for ``'channels_first'`` the array must be 4-D with channels
            on axis 1. Integer (or boolean) input is promoted to float64 so
            the float means can be subtracted; floating-point input keeps its
            dtype.
        data_format: ``'channels_last'`` or ``'channels_first'``. When None the
            value of ``K.image_data_format()`` is used.
        version: Selects the set of channel means (1 or 2).

    Returns:
        A new array (the input is never modified) with the channel axis
        reversed and the selected means subtracted.

    Raises:
        AssertionError: if ``data_format`` is not one of the two known values.
        NotImplementedError: if ``version`` is neither 1 nor 2.
    """
    # Means are listed in the order they are applied AFTER the channel flip.
    channel_means = {
        1: (93.5940, 104.7624, 129.1863),
        2: (91.4953, 103.8827, 131.0912),
    }

    x_temp = np.copy(x)
    # In-place subtraction of a float from an integer array raises a casting
    # error, so promote anything that is not already floating/complex.
    if not np.issubdtype(x_temp.dtype, np.inexact):
        x_temp = x_temp.astype(np.float64)

    if data_format is None:
        data_format = K.image_data_format()
    assert data_format in {'channels_last', 'channels_first'}

    if version not in channel_means:
        raise NotImplementedError

    means = channel_means[version]
    if data_format == 'channels_first':
        x_temp = x_temp[:, ::-1, ...]
        for channel, mean in enumerate(means):
            x_temp[:, channel, :, :] -= mean
    else:
        x_temp = x_temp[..., ::-1]
        for channel, mean in enumerate(means):
            x_temp[..., channel] -= mean

    return x_temp

In [None]:
class SiameseNetwork(tf.keras.Model):
    """Runs an (anchor, positive, negative) triplet through one shared backbone.

    The same `vgg_face` model is applied to all three inputs and each output
    is L2-normalised along its last axis.
    """

    def __init__(self, vgg_face):
        super().__init__()
        self.vgg_face = vgg_face

    @tf.function
    def call(self, inputs):
        """Return `[anchor, positive, negative]` normalised embeddings."""
        anchor_batch, positive_batch, negative_batch = inputs
        branches = (
            ("Anchor", anchor_batch),
            ("Positive", positive_batch),
            ("Negative", negative_batch),
        )
        embeddings = []
        for scope_name, batch in branches:
            # One name scope per branch, in the same order as the inputs.
            with tf.name_scope(scope_name):
                raw = self.vgg_face(batch)
                embeddings.append(tf.math.l2_normalize(raw, axis=-1))
        return embeddings

    @tf.function
    def get_features(self, inputs):
        """Return the normalised embedding of `inputs` with `training=False`."""
        raw = self.vgg_face(inputs, training=False)
        return tf.math.l2_normalize(raw, axis=-1)

In [None]:
class DataGenerator(tf.keras.utils.Sequence):
    """Yields random (anchor, positive, negative) image triplets, one per person.

    `dataset_path` must contain one sub-directory per person; the
    sub-directory name is used as that person's key. Each batch covers
    `batch_size` consecutive people (in the current key order) and draws one
    triplet for each of them.
    """

    def __init__(self, dataset_path, batch_size=5, shuffle=True):
        """Index the images under `dataset_path` and optionally shuffle the people."""
        self.dataset_path = dataset_path
        self.dataset = self.curate_dataset(dataset_path)
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.no_of_people = len(self.dataset)
        self.on_epoch_end()
        print(self.dataset.keys())

    def __getitem__(self, index):
        """Build batch `index` and return `[A, P, N]` as three numpy arrays.

        Raises:
            ValueError: if the batch is non-empty but fewer than two people are
                indexed, because no negative example could ever be drawn.
        """
        # Materialise the key list once instead of on every random draw.
        all_people = list(self.dataset.keys())
        people = all_people[index * self.batch_size: (index + 1) * self.batch_size]
        if people and len(all_people) < 2:
            raise ValueError(
                "At least two people are required to sample a negative example"
            )

        A = []
        P = []
        N = []
        for person in people:
            n_images = len(self.dataset[person])
            anchor_index = random.randint(0, n_images - 1)

            # With a single image the positive necessarily equals the anchor.
            positive_index = anchor_index
            if n_images > 1:
                while positive_index == anchor_index:
                    positive_index = random.randint(0, n_images - 1)

            # Rejection-sample a different person; terminates because
            # len(all_people) >= 2 was checked above.
            negative_person = random.choice(all_people)
            while negative_person == person:
                negative_person = random.choice(all_people)
            negative_index = random.randint(0, len(self.dataset[negative_person]) - 1)

            A.append(self.get_image(person, anchor_index))
            P.append(self.get_image(person, positive_index))
            N.append(self.get_image(negative_person, negative_index))

        return [np.asarray(A), np.asarray(P), np.asarray(N)]

    def __len__(self):
        """Number of full batches; people that do not fill a batch are dropped."""
        return self.no_of_people // self.batch_size

    def curate_dataset(self, dataset_path):
        """Map each person directory to the list of its image file names.

        Only file names containing "jpeg" or "png" are kept (so ".jpg" files
        are not picked up). Entries of `dataset_path` that are not directories
        are skipped, and people without any matching file get no key.
        """
        dataset = {}
        for person in listdir(dataset_path):
            person_dir = os.path.join(dataset_path, person)
            if not os.path.isdir(person_dir):
                continue
            image_files = [f for f in listdir(person_dir) if "jpeg" in f or "png" in f]
            if image_files:
                dataset[person] = image_files
        return dataset

    def on_epoch_end(self):
        """Re-order the people randomly when shuffling is enabled."""
        if self.shuffle:
            keys = list(self.dataset.keys())
            random.shuffle(keys)
            # Dicts preserve insertion order, so rebuilding reorders the batches.
            self.dataset = {key: self.dataset[key] for key in keys}

    def get_image(self, person, index):
        """Load image `index` of `person`, resize to 224x224 and preprocess it.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        image_path = os.path.join(self.dataset_path, person, self.dataset[person][index])
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError("Could not read image: " + image_path)
        img = cv2.resize(img, (224, 224))
        img = np.asarray(img, dtype=np.float64)
        img = preprocess_input(img)
        return img

In [None]:
K = tf.keras.backend
def loss_function(x, alpha = 0.2):
    """Triplet loss over a batch of (anchor, positive, negative) embeddings.

    Args:
        x: Sequence of three tensors `(anchor, positive, negative)`.
        alpha: Margin added to the positive/negative distance gap.

    Returns:
        Mean over the batch of `max(d(a, p) - d(a, n) + alpha, 0)`, where `d`
        is the squared distance summed over axis 1.
    """
    anchor, positive, negative = x

    def squared_distance(left, right):
        return K.sum(K.square(left - right), axis=1)

    margin_violation = (
        squared_distance(anchor, positive) - squared_distance(anchor, negative) + alpha
    )
    return K.mean(K.maximum(margin_violation, 0.0))

In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import ZeroPadding2D, Convolution2D, MaxPooling2D, Dropout, Flatten, Activation

def vgg_face():
    """Build the VGG-Face style Sequential network for 224x224x3 inputs.

    Five convolutional stages (each a run of zero-padded 3x3 ReLU convolutions
    followed by 2x2 max-pooling) feed a convolutional head that ends in a
    2622-way softmax.
    """
    # (filters, number of padded 3x3 convolutions) for each stage.
    conv_stages = [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)]

    model = Sequential()
    first_layer = True
    for filters, repeats in conv_stages:
        for _ in range(repeats):
            if first_layer:
                # Only the very first layer declares the input shape.
                model.add(ZeroPadding2D((1,1),input_shape=(224,224, 3)))
                first_layer = False
            else:
                model.add(ZeroPadding2D((1,1)))
            model.add(Convolution2D(filters, (3, 3), activation='relu'))
        model.add(MaxPooling2D((2,2), strides=(2,2)))

    # Head expressed as convolutions, each hidden layer followed by dropout.
    for kernel in [(7, 7), (1, 1)]:
        model.add(Convolution2D(4096, kernel, activation='relu'))
        model.add(Dropout(0.5))
    model.add(Convolution2D(2622, (1, 1)))
    model.add(Flatten())
    model.add(Activation('softmax'))
    return model

# Knowledge distillation

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00006)
#binary_cross_entropy = tf.keras.losses.BinaryCrossentropy()
def train(X):
    """Run one optimisation step on a triplet batch and return its loss.

    Uses the module-level `model`, `optimizer` and `loss_function`; the output
    of `model(X)` is passed straight to `loss_function`.
    """
    with tf.GradientTape() as tape:
        loss = loss_function(model(X))
    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss

In [None]:
data_generator = DataGenerator(dataset_path='../input/dataset4/dataset3/train/', batch_size=5)

losses = []
accuracy = []
epochs = 5
no_of_batches = len(data_generator)
if no_of_batches == 0:
    # Otherwise the per-epoch average below would divide by zero.
    raise ValueError("The data generator yields no full batch; lower batch_size or add people")

for i in range(1, epochs + 1):
    loss = 0
    with tqdm(total=no_of_batches) as pbar:
        pbar.set_description_str("Epoch " + str(i) + "/" + str(epochs))

        for j in range(no_of_batches):
            # Named `batch` rather than `data`, which is also the name imported from skimage.
            batch = data_generator[j]
            temp = train(batch)
            loss += temp

            pbar.update()
            pbar.set_postfix_str("Loss :" + str(temp.numpy()))

        # Average loss over the epoch.
        loss /= no_of_batches
        losses.append(loss.numpy())
        pbar.set_postfix_str("Loss :" + str(loss.numpy()))

    # Keras only invokes Sequence.on_epoch_end() from Model.fit; in a hand-written
    # loop it must be called explicitly or the people are never reshuffled.
    data_generator.on_epoch_end()



In [None]:
train_dir = '../input/real-time-dataset/dataset4/train/'
data_generator = DataGenerator(dataset_path=train_dir)
train_dict = data_generator.curate_dataset(train_dir)
labels_train = []
features_train = []
images_train = []

# One preprocessed image and one label (the person's directory name) per file.
for k, v in train_dict.items():
    for e in v:
        image_path = os.path.join(train_dir, str(k), str(e))
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising when a file cannot be read.
            raise FileNotFoundError("Could not read image: " + image_path)
        # The networks in this notebook are built for 224x224x3 inputs, and
        # np.asarray below needs every image to share one shape.
        image = cv2.resize(image, (224, 224))
        image = np.asarray(image, dtype=np.float64)
        image = preprocess_input(image)
        images_train.append(image)
        # NOTE(review): feature extraction is disabled, so features_train stays empty.
#         img_features = model.get_features(np.expand_dims(image, axis=0))
#         features_train.append(img_features[0].numpy())
        labels_train.append(k)

images_train = np.asarray(images_train)
#features_train = np.asarray(features_train)

In [None]:
test_dir = '../input/real-time-dataset/dataset4/test/'
data_generator = DataGenerator(dataset_path=test_dir)
test_dict = data_generator.curate_dataset(test_dir)
labels_test = []
features_test = []
images_test = []

# Only people that also appear in the training set are kept.
for k, v in test_dict.items():
    if k not in train_dict:
        continue
    for e in v:
        image_path = os.path.join(test_dir, str(k), str(e))
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising when a file cannot be read.
            raise FileNotFoundError("Could not read image: " + image_path)
        # The networks in this notebook are built for 224x224x3 inputs, and
        # np.asarray below needs every image to share one shape.
        image = cv2.resize(image, (224, 224))
        image = np.asarray(image, dtype=np.float64)
        image = preprocess_input(image)
        images_test.append(image)
        # NOTE(review): feature extraction is disabled, so features_test stays empty.
#         img_features = model.get_features(np.expand_dims(image, axis=0))
#         features_test.append(img_features[0].numpy())
        labels_test.append(k)

images_test = np.asarray(images_test)
#features_test = np.asarray(features_test)

In [None]:
from sklearn import preprocessing
# Convert the directory-name labels into integer class ids.
# The encoder is fitted on the training labels only and then applied to both splits.
le = preprocessing.LabelEncoder()
le.fit(labels_train)
labels_train = le.transform(labels_train)
labels_test = le.transform(labels_test)

In [None]:
from sklearn.utils import shuffle
# Shuffle images, features and labels with one shared permutation per split.
# Shuffling only features and labels would leave images_train / images_test in
# their original order and misalign them with the labels used later for fitting.
# NOTE(review): all three inputs of a split must have the same length, so the
# feature lists must have been filled before this cell runs.
images_train, features_train, labels_train = shuffle(images_train, features_train, labels_train)
images_test, features_test, labels_test = shuffle(images_test, features_test, labels_test)

In [None]:
# Number of distinct classes in the training labels; read as a global by
# model_from_scratch() and classifier() to size their output layers.
n_classes = len(set(labels_train))

In [None]:
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
# RBF-kernel SVM fitted on the training features and integer labels.
# NOTE(review): GridSearchCV is imported here but not used in this cell.
clf = SVC(C=10, gamma=1, kernel='rbf',  probability=True)
clf.fit(features_train, labels_train)

In [None]:
from sklearn.metrics import accuracy_score
# Predict the test features with the fitted classifier and display the accuracy
# against the encoded test labels.
preds = clf.predict(features_test)
accuracy_score(labels_test, preds)

In [None]:
# NOTE(review): no cell above this one assigns `model`; in the visible notebook it is
# first created further down (`model = vgg_face()`), so this fails on a fresh top-to-bottom run.
model.summary()

In [None]:
def model_from_scratch(num_classes=None):
    """Build the smaller student CNN for 224x224x3 inputs.

    The network has five stages of zero-padded 3x3 ReLU convolutions, each
    followed by 2x2 max-pooling, then three Dense+Dropout blocks and a final
    Dense layer without activation (its outputs are logits).

    Args:
        num_classes: Size of the output layer. When None, the module-level
            `n_classes` is used, which is what the original zero-argument
            call relied on.

    Returns:
        An uncompiled Sequential model.
    """
    if num_classes is None:
        num_classes = n_classes

    # (filters, number of padded 3x3 convolutions) for each stage.
    conv_stages = [(64, 2), (128, 2), (256, 2), (512, 2), (512, 1)]

    model = Sequential()
    first_layer = True
    for filters, repeats in conv_stages:
        for _ in range(repeats):
            if first_layer:
                # Only the very first layer declares the input shape.
                model.add(ZeroPadding2D((1,1),input_shape=(224,224, 3)))
                first_layer = False
            else:
                model.add(ZeroPadding2D((1,1)))
            model.add(Convolution2D(filters, (3, 3), activation='relu'))
        model.add(MaxPooling2D((2,2), strides=(2,2)))

    model.add(Flatten())
    for units in (2*1024, 1024, 1024//2):
        model.add(Dense(units, activation='relu'))
        model.add(Dropout(0.5))

    # One output unit per class, no activation.
    model.add(Dense(num_classes))
    return model

In [None]:
# Instantiate the student network and display its layer summary.
student = model_from_scratch()
student.summary()

In [None]:
# class Distiller(keras.Model):
#     def __init__(self, student, teacher_part1, teacher_part2):
#         super(Distiller, self).__init__()
#         self.teacher_part1 = teacher_part1
#         self.teacher_part2 = teacher_part2
#         self.student = student

#     def compile(
#         self,
#         optimizer,
#         metrics,
#         student_loss_fn,
#         distillation_loss_fn,
#         alpha=0.1,
#         temperature=3,
#     ):
#         """ Configure the distiller.

#         Args:
#             optimizer: Keras optimizer for the student weights
#             metrics: Keras metrics for evaluation
#             student_loss_fn: Loss function of difference between student
#                 predictions and ground-truth
#             distillation_loss_fn: Loss function of difference between soft
#                 student predictions and soft teacher predictions
#             alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
#             temperature: Temperature for softening probability distributions.
#                 Larger temperature gives softer distributions.
#         """
#         super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
#         self.student_loss_fn = student_loss_fn
#         self.distillation_loss_fn = distillation_loss_fn
#         self.alpha = alpha
#         self.temperature = temperature
    
#     def call(self, x):
#         return self.student(x)

#     def train_step(self, data):
#         # Unpack data
#         x, y = data
#         print(x)
#         # Forward pass of teacher
#         features = self.teacher_part1.get_features(x)
#         teacher_predictions = self.teacher_part2(features, training=False)
#         print(features.shape)
#         print(teacher_predictions.shape)
#         #teacher_predictions = tf.convert_to_tensor(teacher_predictions)
        
#         with tf.GradientTape() as tape:
#             # Forward pass of student
#             student_predictions = self.student(x, training=True)
#             # Compute losses
#             student_loss = self.student_loss_fn(y, student_predictions)
#             distillation_loss = self.distillation_loss_fn(
#                 tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
#                 tf.nn.softmax(student_predictions / self.temperature, axis=1),
#             )
#             loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

#         # Compute gradients
#         trainable_vars = self.student.trainable_variables
#         gradients = tape.gradient(loss, trainable_vars)

#         # Update weights
#         self.optimizer.apply_gradients(zip(gradients, trainable_vars))

#         # Update the metrics configured in `compile()`.
#         self.compiled_metrics.update_state(y, student_predictions)

#         # Return a dict of performance
#         results = {m.name: m.result() for m in self.metrics}
#         results.update(
#             {"student_loss": student_loss, "distillation_loss": distillation_loss}
#         )
#         return results

#     def test_step(self, data):
#         # Unpack the data
#         x, y = data

#         # Compute predictions
#         y_prediction = self.student(x, training=False)

#         # Calculate the loss
#         student_loss = self.student_loss_fn(y, y_prediction)

#         # Update the metrics.
#         self.compiled_metrics.update_state(y, y_prediction)

#         # Return a dict of performance
#         results = {m.name: m.result() for m in self.metrics}
#         results.update({"student_loss": student_loss})
#         return results

In [None]:
# distiller = Distiller(student=student, teacher_part1=model, teacher_part2=clf)

In [None]:
def classifier(input_dim=128, num_classes=None):
    """Build a small dense classification head.

    Two ReLU Dense layers (512 and 256 units), each followed by dropout, and a
    final Dense layer without activation (its outputs are logits).

    Args:
        input_dim: Length of the input feature vector.
        num_classes: Size of the output layer. When None, the module-level
            `n_classes` is used, which is what the original zero-argument
            call relied on.

    Returns:
        An uncompiled Sequential model.
    """
    if num_classes is None:
        num_classes = n_classes

    # Use the Sequential already imported from tensorflow.keras: the standalone
    # `keras` package is not imported at this point in the notebook.
    head = Sequential()
    head.add(layers.Dense(1024//2, activation='relu', input_dim=input_dim))
    head.add(layers.Dropout(0.5))
    head.add(layers.Dense(1024//4, activation='relu'))
    head.add(layers.Dropout(0.5))
    # One output unit per class, no activation.
    head.add(layers.Dense(num_classes))
    return head
clf = classifier()
clf.summary()

In [None]:
# The head outputs logits, hence from_logits=True. `learning_rate` is used instead of
# the deprecated `lr` alias, matching the other optimizers in this notebook.
clf.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.00006),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
clf_hist = clf.fit(features_train, labels_train, validation_split=0.2, epochs=10)

In [None]:
from sklearn.metrics import accuracy_score
# Evaluate the compiled classifier head on the test features and labels.
# NOTE(review): accuracy_score is imported here but not used in this cell.
clf.evaluate(features_test, labels_test)

# DataGenerator for fit

In [None]:
import numpy as np
import keras

class DataGenerator(tf.keras.utils.Sequence):
    """Generates `(X, y)` batches for Keras, loading images from disk on demand.

    Args:
        list_IDs: Image file names (relative to `image_dir`).
        labels: Mapping from an image ID to its integer label.
        batch_size: Number of samples per batch.
        dim: Spatial size of each sample; `X` has shape
            `(batch, *dim, n_channels)`.
        n_channels: Number of image channels.
        n_classes: Stored on the instance; not used by this class itself.
        shuffle: Whether to reshuffle the sample order after each epoch.
        image_dir: Directory the image IDs are resolved against.
    """

    def __init__(self, list_IDs, labels, batch_size=15, dim=(224,224), n_channels=3,
                 n_classes=10, shuffle=True,
                 image_dir='../input/big-dataset/img_align_celeba/'):
        self.dim = dim
        self.batch_size = batch_size
        self.labels = labels
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.image_dir = image_dir
        self.on_epoch_end()

    def __len__(self):
        """Number of full batches per epoch (a trailing partial batch is dropped)."""
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return batch `index` as `(X, y)`."""
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        return self.__data_generation(list_IDs_temp)

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when enabled."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        """Load, resize and preprocess the given IDs.

        Returns:
            `X` of shape `(len(list_IDs_temp), *dim, n_channels)` and the
            matching integer label vector `y`.

        Raises:
            FileNotFoundError: if OpenCV cannot read one of the images.
        """
        X = np.empty((len(list_IDs_temp), *self.dim, self.n_channels))
        y = np.empty((len(list_IDs_temp)), dtype=int)

        # cv2.resize takes (width, height) while `dim` is laid out like the first
        # two axes of a sample, i.e. (height, width).
        target_size = (self.dim[1], self.dim[0])

        for i, ID in enumerate(list_IDs_temp):
            image_path = os.path.join(self.image_dir, str(ID))
            img = cv2.imread(image_path)
            if img is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise FileNotFoundError("Could not read image: " + image_path)
            img = cv2.resize(img, target_size)
            img = np.asarray(img, dtype=np.float64)
            X[i,] = preprocess_input(img)
            y[i] = self.labels[ID]

        return X, y

In [None]:
with open('../input/labels-boxes/identity_CelebA.txt') as f:
    lines_id = f.readlines()

# For every line, the first whitespace-separated field is the image ID and the
# second is parsed as its integer label.
labels, Ids = {}, []
for e in lines_id:
    fields = e.split()
    image_id = fields[0]
    labels[image_id] = int(fields[1])
    Ids.append(image_id)


In [None]:
with open('../input/labels-boxes/identity_CelebA.txt') as f:
    lines_id = f.readlines()

# Eagerly load one preprocessed 224x224 image and one integer label per line.
# NOTE(review): every image is kept in memory as float64 (224*224*3*8 bytes, about
# 1.2 MB each), and `labels` is rebound here from the dict built above to a list.
labels, images = [], []
for e in lines_id:
    fields = e.split()
    labels.append(int(fields[1]))
    img = cv2.imread('../input/big-dataset/img_align_celeba/' + fields[0])
    img = np.asarray(cv2.resize(img, (224, 224)), dtype=np.float64)
    images.append(preprocess_input(img))
# with open('../input/labels-boxes/list_bbox_celeba.txt') as f:
#     lines_b = f.readlines()
# boxes = [[int(e.split()[i]) for i in range(1,5)] for e in lines_b[2:]]

In [None]:
from sklearn.utils import shuffle
# Shuffle images and labels with the same permutation so the pairs stay aligned.
images, labels = shuffle(images, labels)

In [None]:
class Teacher(tf.keras.Model):
    """Chains two models into a single inference-only teacher.

    Subclasses `tf.keras.Model`, like the other models in this notebook,
    rather than the standalone `keras` package.

    Args:
        teacher_part1: Model applied to the raw inputs.
        teacher_part2: Model applied to the output of `teacher_part1`.
    """

    def __init__(self, teacher_part1, teacher_part2):
        super(Teacher, self).__init__()
        self.teacher_part1 = teacher_part1
        self.teacher_part2 = teacher_part2

    def call(self, inputs):
        """Run both parts with `training=False` and return the second part's output."""
        x = self.teacher_part1(inputs, training=False)
        return self.teacher_part2(x, training=False)
    

In [None]:
# Build the teacher backbone from vgg_face() and load its weights from the .h5 file.
model = vgg_face()
model.load_weights('../input/weights/vgg_face_weights.h5')
# Remove the last layer added by vgg_face() (the softmax Activation) ...
model.pop()
# ... and append a bias-free 128-unit Dense layer in its place.
model.add(tf.keras.layers.Dense(128, use_bias=False))
# Freeze every layer except the last two (the Flatten and the new Dense).
for layer in model.layers[:-2]:
    layer.trainable = False

In [None]:
# Teacher = backbone `model` followed by the classifier currently bound to `clf`.
teacher = Teacher(teacher_part1 = model, teacher_part2=clf)

In [None]:

class Distiller(tf.keras.Model):
    """Knowledge-distillation wrapper around a student and a teacher model.

    During training the teacher is only run forward; gradients are computed
    for, and applied to, the student's trainable variables only.

    Args:
        student: Model being trained.
        teacher: Model whose predictions are used as soft targets.
    """

    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """ Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        """Run one distillation step on an `(x, y)` batch.

        Returns:
            Dict with the compiled metrics plus `student_loss` and
            `distillation_loss`.
        """
        x, y = data

        # Teacher forward pass happens outside the tape: no teacher gradients are needed.
        teacher_predictions = self.teacher(x)

        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)

            # Hard-label loss on the student's outputs.
            student_loss = self.student_loss_fn(y, student_predictions)
            # Soft-label loss between temperature-softened distributions.
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Only the student's variables are updated.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        """Evaluate the student alone on an `(x, y)` batch.

        Returns:
            Dict with the compiled metrics plus `student_loss`.
        """
        x, y = data

        y_prediction = self.student(x, training=False)
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, y_prediction)

        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results


In [None]:
# Pair the student network with the teacher for distillation.
distiller = Distiller(student=student, teacher=teacher)

In [None]:
#datagen = DataGenerator(list_IDs=Ids, labels=labels, n_classes=len(set(labels.values())))

# Every Keras object comes from tf.keras, consistent with the optimizer and the models.
distiller.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.00006),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    # The student outputs logits, hence from_logits=True.
    student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=tf.keras.losses.KLDivergence(),
    alpha=0.5,
    temperature=10,
)
# Distill teacher to student
distiller.fit(images_train, labels_train, epochs=100, validation_split=0.2)

# Evaluate student on test dataset
distiller.evaluate(images_test, labels_test)
