In [1]:
import tensorflow as tf

In [4]:
class MaxoutLayer(tf.keras.layers.Layer):
    """Fully connected maxout layer (Goodfellow et al., 2013).

    Computes ``maxout_p`` affine projections per output unit and keeps the
    largest one:

        z = inputs @ w + b                                  # (..., units * maxout_p)
        output[..., k] = max_j z[..., k * maxout_p + j]     # (..., units)

    In this notebook it is used for the ``fc1``/``fc2`` layers with ``maxout_p=2``.

    Args:
        units: Number of output features after the max reduction.
        maxout_p: Number of linear pieces competing for each output unit.
        **kwargs: Standard ``tf.keras.layers.Layer`` arguments (e.g. ``name``).

    Raises:
        ValueError: If ``units`` or ``maxout_p`` is smaller than 1.
    """

    def __init__(self, units, maxout_p=2, **kwargs):
        super(MaxoutLayer, self).__init__(**kwargs)
        # Fail early with a clear message; values < 1 can never produce a
        # valid (..., units, maxout_p) reshape in call().
        if units < 1:
            raise ValueError(f"units must be a positive integer, got {units!r}")
        if maxout_p < 1:
            raise ValueError(f"maxout_p must be a positive integer, got {maxout_p!r}")
        self.units = units
        self.maxout_p = maxout_p

    def build(self, input_shape):
        """Create kernel ``w`` of shape (in_features, units * maxout_p) and bias ``b``."""
        in_features = input_shape[-1]
        self.w = self.add_weight(
            shape=(in_features, self.units * self.maxout_p),
            initializer="glorot_uniform",
            trainable=True,
            name="weights",
        )
        self.b = self.add_weight(
            shape=(self.units * self.maxout_p,),
            initializer="zeros",
            trainable=True,
            name="bias",
        )
        super(MaxoutLayer, self).build(input_shape)

    def call(self, inputs):
        """Project ``inputs`` and take the max over each unit's ``maxout_p`` pieces.

        Rank-2 inputs (batch, features) use the original matmul/reshape path.
        Higher-rank inputs (batch, ..., features) keep every leading axis;
        reshaping them to ``[-1, units, maxout_p]`` would fold the extra axes
        into the batch dimension.
        """
        rank = inputs.shape.rank
        if rank is None or rank == 2:
            z = tf.matmul(inputs, self.w) + self.b
            z = tf.reshape(z, [-1, self.units, self.maxout_p])
        else:
            # Contract the last input axis with the kernel's first axis.
            z = tf.tensordot(inputs, self.w, axes=[[rank - 1], [0]]) + self.b
            # Split only the last axis into (units, maxout_p).
            pieces_shape = tf.concat(
                [tf.shape(z)[:-1], [self.units, self.maxout_p]], axis=0
            )
            z = tf.reshape(z, pieces_shape)
        return tf.reduce_max(z, axis=-1)

    def get_config(self):
        """Return constructor arguments so Keras can re-create the layer."""
        config = super(MaxoutLayer, self).get_config()
        config.update({"units": self.units, "maxout_p": self.maxout_p})
        return config

In [2]:
class LRNLayer(tf.keras.layers.Layer):
    """Local Response Normalization (LRN) as a Keras layer.

    Thin wrapper around ``tf.nn.local_response_normalization``, which lets the
    ``rnorm`` stages be placed in a functional model. The layer creates no weights.

    Per the TensorFlow documentation, for a 4-D input the op computes::

        sqr_sum[a, b, c, d] = sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
        output = input / (bias + alpha * sqr_sum) ** beta

    Args:
        depth_radius: Half-width of the normalization window.
        bias: Offset added inside the denominator.
        alpha: Scale applied to the windowed sum of squares.
        beta: Exponent applied to the denominator.
        **kwargs: Standard ``tf.keras.layers.Layer`` arguments (e.g. ``name``).
    """

    def __init__(self, depth_radius=5, bias=1.0, alpha=1e-4, beta=0.75, **kwargs):
        super(LRNLayer, self).__init__(**kwargs)
        self.depth_radius = depth_radius
        self.bias = bias
        self.alpha = alpha
        self.beta = beta

    def call(self, inputs):
        """Normalize ``inputs`` with the hyperparameters given at construction."""
        return tf.nn.local_response_normalization(
            inputs,
            depth_radius=self.depth_radius,
            bias=self.bias,
            alpha=self.alpha,
            beta=self.beta,
        )

    def get_config(self):
        """Return constructor arguments so Keras can re-create the layer.

        Without this override, tf.keras (Keras 2) cannot build a config for a
        layer whose ``__init__`` takes extra arguments, so saving a model that
        contains this layer fails.
        """
        config = super(LRNLayer, self).get_config()
        config.update({
            "depth_radius": self.depth_radius,
            "bias": self.bias,
            "alpha": self.alpha,
            "beta": self.beta,
        })
        return config

In [9]:
# FaceNet NN1 (Zeiler & Fergus based) embedding network. Layer names mirror the
# rows of Table 1 in Schroff et al., "FaceNet: A Unified Embedding for Face
# Recognition and Clustering" (2015).
#
# Fix: every convolution now applies ReLU, matching the ReLU non-linearity the
# FaceNet paper uses. Previously only conv1 had an activation. Conv2D's default
# activation is linear, so consecutive convolutions with no pooling or
# normalization between them collapsed into a single affine map; for example,
# conv4a through conv6 did. Activations add no weights, so the parameter counts
# reported by model.summary() are unchanged.

INPUT_SHAPE = (220, 220, 3)  # height, width, channels of the input crop
EMBEDDING_DIM = 128          # size of the fc7128 embedding

# Input
input_tensor = tf.keras.Input(shape=INPUT_SHAPE, name='input')

# conv1 (7x7, stride 2) -> pool1 -> rnorm1
conv1 = tf.keras.layers.Conv2D(64, (7, 7), (2, 2), padding='same', activation='relu', name='conv1')(input_tensor)
pool1 = tf.keras.layers.MaxPooling2D(pool_size=3, strides=2, padding='same', name='pool1')(conv1)
rnorm1 = LRNLayer(name='rnorm1')(pool1)

# conv2a (1x1) -> conv2 (3x3) -> rnorm2 -> pool2
conv2a = tf.keras.layers.Conv2D(filters=64, kernel_size=1, activation='relu', name='conv2a')(rnorm1)
conv2 = tf.keras.layers.Conv2D(filters=192, kernel_size=3, strides=1, padding='same', activation='relu', name='conv2')(conv2a)
rnorm2 = LRNLayer(name='rnorm2')(conv2)
pool2 = tf.keras.layers.MaxPooling2D(pool_size=3, strides=2, padding='same', name='pool2')(rnorm2)

# conv3a (1x1) -> conv3 (3x3) -> pool3
conv3a = tf.keras.layers.Conv2D(192, 1, activation='relu', name='conv3a')(pool2)
conv3 = tf.keras.layers.Conv2D(384, 3, padding='same', activation='relu', name='conv3')(conv3a)
pool3 = tf.keras.layers.MaxPooling2D(pool_size=3, strides=2, padding='same', name='pool3')(conv3)

# conv4a (1x1) -> conv4 (3x3)
conv4a = tf.keras.layers.Conv2D(384, 1, activation='relu', name='conv4a')(pool3)
conv4 = tf.keras.layers.Conv2D(256, 3, padding='same', activation='relu', name='conv4')(conv4a)

# conv5a (1x1) -> conv5 (3x3)
conv5a = tf.keras.layers.Conv2D(256, 1, activation='relu', name='conv5a')(conv4)
conv5 = tf.keras.layers.Conv2D(256, 3, padding='same', activation='relu', name='conv5')(conv5a)

# conv6a (1x1) -> conv6 (3x3)
conv6a = tf.keras.layers.Conv2D(256, 1, activation='relu', name='conv6a')(conv5)
conv6 = tf.keras.layers.Conv2D(256, 3, padding='same', activation='relu', name='conv6')(conv6a)

# pool4
pool4 = tf.keras.layers.MaxPooling2D(pool_size=3, strides=2, padding='same', name='pool4')(conv6)

# concat: with a single input this is an identity op; it is kept so the layer
# list matches the paper's table.
concat = tf.keras.layers.Concatenate(name='concat')([pool4])

# Flatten the feature map for the fully connected layers.
flatten = tf.keras.layers.Flatten(name="flatten")(concat)

# fc1, fc2: maxout fully connected layers (p=2), 4096 outputs each.
fc1 = MaxoutLayer(units=4096, maxout_p=2, name='fc1')(flatten)
fc2 = MaxoutLayer(units=4096, maxout_p=2, name='fc2')(fc1)

# fc7128: linear projection to the embedding (no activation before L2).
fc7128 = tf.keras.layers.Dense(EMBEDDING_DIM, name='fc7128')(fc2)

# L2: scale each embedding to unit Euclidean norm along the last axis.
# NOTE(review): Lambda layers serialize Python code, so saving and reloading
# this model may need extra care (e.g. safe_mode=False in Keras 3). Consider a
# small custom layer if the model will be persisted.
L2 = tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=-1), name='L2')(fc7128)

model = tf.keras.Model(inputs=input_tensor, outputs=L2, name='faceNetModel')

model.summary()

Model: "faceNetModel"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input (InputLayer)          [(None, 220, 220, 3)]     0         
                                                                 
 conv1 (Conv2D)              (None, 110, 110, 64)      9472      
                                                                 
 pool1 (MaxPooling2D)        (None, 55, 55, 64)        0         
                                                                 
 rnorm1 (LRNLayer)           (None, 55, 55, 64)        0         
                                                                 
 conv2a (Conv2D)             (None, 55, 55, 64)        4160      
                                                                 
 conv2 (Conv2D)              (None, 55, 55, 192)       110784    
                                                                 
 rnorm2 (LRNLayer)           (None, 55, 55, 192)      