In [2]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, ReLU, Input, SeparableConv2D, \
        Add, Dense, BatchNormalization, MaxPool2D, GlobalAvgPool2D

In [20]:
class conv_bn(tf.keras.layers.Layer):
    """Conv2D (no bias, 'same' padding) followed by BatchNormalization.

    The convolution is created with ``use_bias=False``; the following
    BatchNormalization layer supplies its own learned offset.

    Args:
        filters: number of output channels of the convolution.
        kernel_size: convolution kernel size (int or tuple, as accepted by Conv2D).
        strides: convolution stride (int or tuple); defaults to 1.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, filters, kernel_size, strides=1, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config() can rebuild this layer.
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.conv = Conv2D(filters, kernel_size, strides, padding='same', use_bias=False)
        self.bn = BatchNormalization()

    def call(self, inputs, training=None):
        """Apply the convolution, then batch normalization.

        Args:
            inputs: input tensor for the Conv2D layer.
            training: forwarded to BatchNormalization so it switches between
                batch statistics (training) and moving statistics (inference).
        """
        x = self.conv(inputs)
        return self.bn(x, training=training)

    def get_config(self):
        # Required for clone_model / model serialization: the base Layer cannot
        # infer the extra constructor arguments of this subclass.
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
        })
        return config

In [21]:
class sep_bn(tf.keras.layers.Layer):
    """SeparableConv2D (no bias, 'same' padding) followed by BatchNormalization.

    The separable convolution is created with ``use_bias=False``; the following
    BatchNormalization layer supplies its own learned offset.

    Args:
        filters: number of output channels of the separable convolution.
        kernel_size: depthwise kernel size (int or tuple, as accepted by SeparableConv2D).
        strides: convolution stride (int or tuple); defaults to 1.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, filters, kernel_size, strides=1, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config() can rebuild this layer.
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.conv = SeparableConv2D(filters, kernel_size, strides, padding='same', use_bias=False)
        self.bn = BatchNormalization()

    def call(self, inputs, training=None):
        """Apply the separable convolution, then batch normalization.

        Args:
            inputs: input tensor for the SeparableConv2D layer.
            training: forwarded to BatchNormalization so it switches between
                batch statistics (training) and moving statistics (inference).
        """
        x = self.conv(inputs)
        return self.bn(x, training=training)

    def get_config(self):
        # Required for clone_model / model serialization: the base Layer cannot
        # infer the extra constructor arguments of this subclass.
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
        })
        return config

In [10]:
def EntryFlow(inputs):
    """Xception entry flow: a two-conv stem followed by three downsampling residual modules.

    Each residual module has two separable-conv + BN layers and a 3x3 / stride-2
    max-pool on the main branch. Its shortcut is a 1x1 / stride-2 conv + BN
    projection of the module's own input. The two branches are summed.

    Args:
        inputs: Keras image tensor (this notebook passes a (299, 299, 3) ``Input``).

    Returns:
        Keras tensor with 728 channels. Spatial size is halved four times: the
        stride-2 stem conv plus the three stride-2 modules.
    """
    # Stem: two plain conv + BN layers, each followed by ReLU.
    x = conv_bn(filters=32, kernel_size=3, strides=2)(inputs)
    x = ReLU()(x)
    x = conv_bn(filters=64, kernel_size=3)(x)
    x = ReLU()(x)

    # Module 1 (128 filters). The stem already ends in ReLU, so the first
    # separable conv is not preceded by another activation.
    shortcut = conv_bn(filters=128, kernel_size=1, strides=2)(x)
    x = sep_bn(filters=128, kernel_size=3)(x)
    x = ReLU()(x)
    x = sep_bn(filters=128, kernel_size=3)(x)
    # No ReLU before pooling, matching modules 2 and 3.
    x = MaxPool2D(pool_size=3, strides=2, padding='same')(x)
    x = Add()([shortcut, x])

    # Modules 2 and 3. The shortcut branches from the current module's input,
    # which is the previous module's (pre-ReLU) output. It does not branch from
    # the previous shortcut.
    for filters in (256, 728):
        shortcut = conv_bn(filters=filters, kernel_size=1, strides=2)(x)
        x = ReLU()(x)
        x = sep_bn(filters=filters, kernel_size=3)(x)
        x = ReLU()(x)
        x = sep_bn(filters=filters, kernel_size=3)(x)
        x = MaxPool2D(pool_size=3, strides=2, padding='same')(x)
        x = Add()([shortcut, x])

    return x

In [11]:
def MiddleFlow(inputs):
    """Xception middle flow: eight identical 728-filter residual modules.

    Each module applies three (ReLU -> separable conv + BN) stages and adds the
    result to the module's input. The identity Add therefore requires ``inputs``
    to already have 728 channels, and the output has the same shape as ``inputs``.
    """
    x = inputs
    for _module in range(8):
        branch = x
        for _stage in range(3):
            branch = ReLU()(branch)
            branch = sep_bn(filters=728, kernel_size=3)(branch)
        x = Add()([x, branch])
    return x

In [26]:
def ExitFlow(inputs, num_classes=1000):
    """Xception exit flow: a downsampling residual module, two separable convs, and a softmax head.

    Args:
        inputs: Keras tensor produced by the middle flow.
        num_classes: number of units in the final softmax Dense layer;
            defaults to 1000, the previously hard-coded value.

    Returns:
        Keras tensor of shape (batch, num_classes) holding softmax probabilities.
    """
    # Main branch of the residual module: 728 then 1024 filters, then stride-2 pool.
    x = ReLU()(inputs)
    x = sep_bn(filters=728, kernel_size=3)(x)
    x = ReLU()(x)
    x = sep_bn(filters=1024, kernel_size=3)(x)
    x = MaxPool2D(pool_size=3, strides=2, padding='same')(x)

    # Shortcut: 1x1 stride-2 projection of the (pre-ReLU) module input so the
    # channel count and spatial size match the main branch.
    shortcut = conv_bn(filters=1024, kernel_size=1, strides=2)(inputs)
    x = Add()([shortcut, x])

    # After the module, each separable conv is followed by ReLU.
    x = sep_bn(filters=1536, kernel_size=3)(x)
    x = ReLU()(x)
    x = sep_bn(filters=2048, kernel_size=3)(x)
    x = ReLU()(x)

    x = GlobalAvgPool2D()(x)
    return Dense(num_classes, activation='softmax')(x)

In [27]:
# Wire the three Xception stages on a 299x299 RGB input.
# `inputs` and `output` are read by the model-construction cell below.
inputs = Input(shape=(299, 299, 3))
x = MiddleFlow(EntryFlow(inputs))
output = ExitFlow(x)

In [28]:
# Functional model spanning the input tensor to the softmax output.
model = tf.keras.Model(inputs=inputs, outputs=output)

In [29]:
# Print a layer-by-layer table of output shapes and parameter counts.
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_11 (InputLayer)           [(None, 299, 299, 3) 0                                            
__________________________________________________________________________________________________
conv_bn_14 (conv_bn)            (None, 150, 150, 32) 992         input_11[0][0]                   
__________________________________________________________________________________________________
re_lu_36 (ReLU)                 (None, 150, 150, 32) 0           conv_bn_14[0][0]                 
__________________________________________________________________________________________________
conv_bn_15 (conv_bn)            (None, 150, 150, 64) 18688       re_lu_36[0][0]                   
______________________________________________________________________________________________

In [31]:
import numpy as np
import tensorflow.keras.backend as K
# Sum parameter counts over trainable weights only. Non-trainable weights,
# such as BatchNormalization moving statistics, are excluded.
np.sum(list(map(K.count_params, model.trainable_weights)))

22855952