In [1]:
''' Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016'''

import tensorflow as tf
import tensorflow.keras as keras

#Allow memory growth
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)
physical_devices

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [21]:
class Centre_loss(keras.layers.Layer):
    def __init__(self, num_classes, feature_dim):
        super(Centre_loss, self).__init__()
        self.num_classes=num_classes
        self.feature_dim=feature_dim
        self.centres = tf.Variable(tf.random.normal(shape=[self.num_classes, self.feature_dim], trainable=True))

    def call(self, inputs, y_True):
        #inputs -> shape (batch_size, feature_dim)
        #y_True -> the ground true labels of the inputs

        batch_size = inputs.shape[0]
        
        #For inputs
        #[batch_size, feature_dim] -> [batch_size, 1]
        feature_square_sum = tf.reduce_sum(tf.math.pow(inputs, 2), axis=1, keepdims=True)
        #[batch_size, 1] -> [batch_size, num_classes]
        feature_ss_broadcast = tf.broadcast_to(input=feature_square_sum, \
            shape=[batch_size, self.num_classes])

        #For centres
        #[num_classes, feature_dim] -> [num_classes, 1]
        centre_square_sum = tf.reduce_sum(tf.math.pow(self.centres, 2), axis=1, keepdims=True)
        #[num_classes, 1] -> [num_classes, batch_size] -> [batch_size, num_classes]
        centre_ss_broadcast = tf.transpose(tf.broadcast_to(input=centre_square_sum, shape=[self.num_classes, batch_size]))

        #[batch_size, num_classes]
        feature_ss_broadcast = tf.cast(feature_ss_broadcast, dtype=tf.float32)
        
        fea_cen = feature_ss_broadcast + centre_ss_broadcast
        fea_cen = fea_cen*1 + (-2)*(inputs @ tf.transpose(self.centres))

        #Class mask
        classes_broadcast = tf.broadcast_to(tf.cast(tf.range(self.num_classes), dtype=tf.float32),\
            shape=[batch_size, self.num_classes])
        y_True_broadcast = tf.broadcast_to(tf.expand_dims(y_True, axis=1), shape=[batch_size, self.num_classes])
        y_True_broadcast = tf.cast(y_True_broadcast, dtype=tf.float32)

        mask = tf.cast(tf.equal(classes_broadcast, y_True_broadcast), dtype=tf.float32)

        loss_matrix = tf.clip_by_value(t=fea_cen * mask, clip_value_min=1e-12, clip_value_max=1e+12)

        loss = tf.reduce_sum(loss_matrix)/batch_size

        return loss

In [24]:
if __name__ == "__main__":
    loss = CentreLoss(num_classes=10, feature_dim=7)
    print("Done!")

<class '__main__.CentreLoss'>
Done!
