# **MobileNetV2: Inverted Residuals and Linear Bottlenecks**

Sandler, M., Howard, A., Zhu, M., Zhmoginov, A., & Chen, L. C. (2018). MobileNetV2: Inverted residuals and linear bottlenecks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 4510-4520).

In [1]:
import tensorflow as tf
tf.__version__

'2.4.1'

In [41]:
def ConvBNReLU(
    x,
    layer_type,
    output_channels = None,
    kernel_size = 3,
    strides = 1,
    activation_fn = tf.nn.relu6,
    expansion_factor = 6,
):
    """Apply one convolution -> batch-norm (-> activation) stage to ``x``.

    The kind of convolution is selected by ``layer_type`` (case-insensitive):

    * ``"expansion"``: 1x1 ``Conv2D`` producing ``x.shape[-1] * expansion_factor``
      channels, followed by batch-norm and ``activation_fn``.
    * ``"depthwise"``: ``DepthwiseConv2D`` with ``kernel_size`` and ``strides``,
      followed by batch-norm and ``activation_fn``.
    * ``"pointwise"``: 1x1 ``Conv2D`` producing ``output_channels`` channels,
      followed by batch-norm only (linear, no activation).
    * ``"naive"``: ``Conv2D`` with ``kernel_size`` and ``strides`` producing
      ``output_channels`` channels, followed by batch-norm and ``activation_fn``.

    All convolutions use ``padding = "same"``.

    Args:
        x: Input Keras tensor; its last axis is used as the channel axis.
        layer_type: One of "expansion", "depthwise", "pointwise", "naive".
        output_channels: Number of output channels; required for
            "pointwise" and "naive", ignored otherwise.
        kernel_size: Kernel size for "depthwise" and "naive" layers.
        strides: Strides for "depthwise" and "naive" layers.
        activation_fn: Activation applied after batch-norm (not used for
            "pointwise").
        expansion_factor: Channel multiplier, used only by "expansion".

    Returns:
        The output tensor of the last layer applied.

    Raises:
        AssertionError: If ``layer_type`` is not one of the supported values,
            or ``output_channels`` is missing where it is required.
    """
    layer_type = layer_type.lower()
    assert layer_type in ["expansion", "depthwise", "pointwise", "naive"], \
        f"Unsupported layer_type: {layer_type!r}."

    if layer_type == "expansion":
        ## Conv 1x1
        conv = tf.keras.layers.Conv2D(x.shape[-1] * expansion_factor, 1, padding = "same")

    elif layer_type == "depthwise":
        ## Dwise conv; honours `kernel_size` (default 3) instead of a hard-coded 3.
        conv = tf.keras.layers.DepthwiseConv2D(kernel_size, strides = strides, padding = "same")

    else:
        assert output_channels is not None, \
            f"Argument 'output_channels' is required for layer_type {layer_type!r}."
        if layer_type == "pointwise":
            ## Conv 1x1
            conv = tf.keras.layers.Conv2D(output_channels, 1, padding = "same")
        else:
            conv = tf.keras.layers.Conv2D(output_channels, kernel_size, strides = strides, padding = "same")

    x = conv(x)
    x = tf.keras.layers.BatchNormalization()(x)

    ## The pointwise projection stays linear, i.e. no activation.
    if layer_type != "pointwise":
        x = tf.keras.layers.Activation(activation_fn)(x)

    return x


def InvertResidualBlock(
    x, 
    output_channels, 
    strides = 1,
    expansion_factor = 6,
):
    """Chain expansion -> depthwise -> pointwise stages, with an optional skip.

    Args:
        x: Input Keras tensor.
        output_channels: Channel count produced by the pointwise stage.
        strides: Strides of the depthwise stage; must be 1 or 2.
        expansion_factor: Channel multiplier forwarded to ``ConvBNReLU``.

    Returns:
        The pointwise output, added to the block input when ``strides`` is 1
        and both tensors have the same size on their last axis; otherwise
        the pointwise output alone.
    """
    assert strides in [1, 2], f"Argument 'strides' must be 1 or 2, not {strides}."
    block_input = x

    ## Each stage is (layer type, stage-specific keyword arguments).
    stages = (
        ("expansion", {}),
        ("depthwise", {"strides": strides}),
        ("pointwise", {"output_channels": output_channels}),
    )
    for stage_type, stage_kwargs in stages:
        x = ConvBNReLU(x, stage_type, expansion_factor = expansion_factor, **stage_kwargs)

    ## The shortcut is only valid when the block keeps resolution and width.
    skip_allowed = strides == 1 and x.shape[-1] == block_input.shape[-1]
    if not skip_allowed:
        return x

    return tf.keras.layers.Add()([x, block_input])

In [42]:
IMAGE_SIZE = [224, 224]

def MobileNetV2_224(model_name = "MobileNetV2_224", embedding_dims = 1_000, apply_classifier = True):
    """Build the MobileNetV2-style Keras model for ``IMAGE_SIZE`` 3-channel inputs.

    Args:
        model_name: Name given to the returned ``tf.keras.Model``.
        embedding_dims: Number of units of the final ``Dense`` layer.
        apply_classifier: When true, a ``Softmax`` layer is applied on top of
            the ``Dense`` output; otherwise the ``Dense`` output is returned.

    Returns:
        A ``tf.keras.Model`` mapping the image input to the model output.
    """
    ## One entry per stage: (output channels, number of blocks, stride of first block).
    stage_configs = (
        (24, 2, 2),
        (32, 3, 2),
        (64, 4, 2),
        (96, 3, 1),
        (160, 3, 2),
        (320, 1, 1),
    )

    model_input = tf.keras.layers.Input(shape = (*IMAGE_SIZE, 3))

    ## Entry flow (stem).
    features = ConvBNReLU(model_input, "naive", kernel_size = 3, strides = 2, output_channels = 32)
    features = InvertResidualBlock(features, 16, expansion_factor = 1)

    ## Middle flow.
    for channels, num_blocks, first_stride in stage_configs:
        ## Only the first block of a stage uses the stage stride; the rest use 1.
        block_strides = [first_stride] + [1] * (num_blocks - 1)
        for block_stride in block_strides:
            features = InvertResidualBlock(features, channels, strides = block_stride)

    ## Exit flow.
    features = ConvBNReLU(features, "naive", kernel_size = 1, output_channels = 1_280)
    pooled = tf.keras.layers.GlobalAveragePooling2D()(features)
    logits = tf.keras.layers.Dense(embedding_dims)(pooled)

    model_output = tf.keras.layers.Softmax()(logits) if apply_classifier else logits

    return tf.keras.Model(inputs = model_input, outputs = model_output, name = model_name)

In [43]:
## Build the model with its default arguments and print the layer summary.
tmp = MobileNetV2_224()
tmp.summary()

Model: "MobileNetV2_224"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_15 (InputLayer)           [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv2d_166 (Conv2D)             (None, 112, 112, 32) 896         input_15[0][0]                   
__________________________________________________________________________________________________
batch_normalization_232 (BatchN (None, 112, 112, 32) 128         conv2d_166[0][0]                 
__________________________________________________________________________________________________
activation_158 (Activation)     (None, 112, 112, 32) 0           batch_normalization_232[0][0]    
____________________________________________________________________________________