In [1]:
import tensorflow as tf

In [8]:
# Depthwise separable convolution block: a per-channel (depthwise) spatial
# convolution followed by a 1x1 pointwise convolution that mixes channels.
def depthwise_seperable(x, e_1x1, padding="VALID", stride=1, kernel_size=3, use_batchnorm=False):
  """Apply one depthwise separable convolution block to a Keras tensor.

  Args:
    x: Input Keras tensor (assumed 4-D NHWC image features — TODO confirm for
      any non-image use).
    e_1x1: Number of output channels (filters) of the 1x1 pointwise convolution.
    padding: Padding mode of the depthwise convolution ("VALID" or "SAME").
      The pointwise convolution always uses "SAME".
    stride: Stride of the depthwise convolution; 2 halves the spatial size.
    kernel_size: Spatial kernel size of the depthwise convolution. Defaults to
      3, which is what this block has always used.
    use_batchnorm: If True, each convolution is built without a bias and is
      followed by BatchNormalization and then ReLU (conv -> BN -> ReLU, the
      ordering used in MobileNet v1). If False (default), each convolution
      applies ReLU directly, exactly as before.

  Returns:
    Output Keras tensor with `e_1x1` channels.
  """
  # With batch norm the conv bias is redundant (BN's learned offset replaces
  # it), and ReLU has to come after normalization rather than inside the conv.
  conv_activation = None if use_batchnorm else "relu"
  use_bias = not use_batchnorm

  def _post_conv(t):
    # Optional BN + ReLU applied after a convolution when use_batchnorm=True;
    # a no-op otherwise (the conv already applied ReLU itself).
    if use_batchnorm:
      t = tf.keras.layers.BatchNormalization()(t)
      t = tf.keras.layers.ReLU()(t)
    return t

  x = tf.keras.layers.DepthwiseConv2D(
      kernel_size, strides=stride, padding=padding,
      activation=conv_activation, use_bias=use_bias)(x)
  x = _post_conv(x)
  x = tf.keras.layers.Conv2D(
      e_1x1, 1, padding="SAME",
      activation=conv_activation, use_bias=use_bias)(x)
  x = _post_conv(x)

  return x

<img src="https://production-media.paperswithcode.com/methods/Screen_Shot_2020-06-22_at_4.26.15_PM_ko4FqXD.png" />

In [33]:
# MobileNet-v1-style classifier for small 32x32x3 images (presumably a
# CIFAR-10-like task given the 10 output classes — confirm against the data).
# Inputs are upsampled to 224x224 so that the stride-2 stages produce
# 224 -> 112 -> 56 -> 28 -> 14 -> 7 spatial maps.
INPUT_SHAPE = (32, 32, 3)
TARGET_SIZE = (224, 224)
NUM_CLASSES = 10

resize_layer = tf.keras.layers.Lambda(
    lambda image: tf.image.resize(
        image,
        TARGET_SIZE,
        method=tf.image.ResizeMethod.BICUBIC,
        preserve_aspect_ratio=True,
    )
)

inp = tf.keras.layers.Input(INPUT_SHAPE)
x = resize_layer(inp)
# Pad one pixel at the bottom/right so the 3x3 stride-2 VALID stem conv maps
# 225 -> 112.
x = tf.keras.layers.ZeroPadding2D(padding=((0, 1), (0, 1)), name='conv1_pad')(x)
# Stem conv with ReLU, like every other conv in the network; without it the
# stem and the next depthwise conv compose into one linear map before any ReLU.
x = tf.keras.layers.Conv2D(32, 3, strides=2, activation="relu")(x)

x = depthwise_seperable(x, 64, padding="SAME")

x = depthwise_seperable(x, 128, stride=2, padding="SAME")
x = depthwise_seperable(x, 128, padding="SAME")

x = depthwise_seperable(x, 256, stride=2, padding="SAME")
x = depthwise_seperable(x, 256, padding="SAME")

x = depthwise_seperable(x, 512, stride=2, padding="SAME")

# MobileNet v1 repeats the 512-channel stride-1 block 5 times: one block per
# iteration.
for _ in range(5):
  x = depthwise_seperable(x, 512, padding="SAME")

x = depthwise_seperable(x, 1024, stride=2, padding="SAME")
x = depthwise_seperable(x, 1024, padding="SAME")

# Global pooling collapses the final 7x7 map into a flat (batch, 1024) vector,
# so the classification head below outputs (batch, NUM_CLASSES) rather than a
# 4-D (batch, 1, 1, NUM_CLASSES) tensor.
x = tf.keras.layers.GlobalAveragePooling2D()(x)

# Single softmax head. A linear (activation-free) Dense layer in front of it
# would compose with it into one affine map and add parameters without adding
# expressive power, so the features go straight into the classifier.
out = tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")(x)

model = tf.keras.Model(inputs=inp, outputs=out)

In [34]:
# Smoke test: push one random 32x32x3 image through the model, check the
# output has one row of 10 class scores, then show the layer summary.
ts_inp = tf.random.normal((1, 32, 32, 3))
ts_out = model(ts_inp)
assert ts_out.shape[0] == 1 and ts_out.shape[-1] == 10, ts_out.shape
# summary() prints the table itself and returns None, so call it directly
# instead of wrapping it in print() (which would emit a stray "None").
model.summary()

Model: "model_10"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_19 (InputLayer)       [(None, 32, 32, 3)]       0         
                                                                 
 lambda_18 (Lambda)          (None, 224, 224, 3)       0         
                                                                 
 conv1_pad (ZeroPadding2D)   (None, 225, 225, 3)       0         
                                                                 
 conv2d_256 (Conv2D)         (None, 112, 112, 32)      896       
                                                                 
 depthwise_conv2d_244 (Depth  (None, 112, 112, 32)     320       
 wiseConv2D)                                                     
                                                                 
 conv2d_257 (Conv2D)         (None, 112, 112, 64)      2112      
                                                          