![DeepLabV3+ architecture diagram: encoder with ASPP, decoder fusing low-level features](https://github.com/lattice-ai/DeepLabV3-Plus/raw/master/assets/deeplabv3_plus_diagram.png)

In [1]:
import os

import tensorflow as tf
from tensorflow import keras

In [2]:
def ASPP(inputs, filters=256, dilation_rates=(6, 12, 18)):
    """Atrous Spatial Pyramid Pooling (ASPP) head used by DeepLabV3+.

    Runs parallel branches over ``inputs`` and fuses them:
      * image pooling: average-pool over the whole feature map, 1x1 conv,
        then bilinearly upsample back to the feature map's spatial size;
      * a 1x1 conv branch;
      * one 3x3 atrous conv branch per entry of ``dilation_rates``.
    Every conv is followed by BatchNormalization and ReLU. The branch outputs
    are concatenated along the last axis and projected with a 1x1 conv.

    Args:
        inputs: Keras feature-map tensor indexed as (batch, height, width, ...).
            Its spatial dimensions ``inputs.shape[1]`` and ``inputs.shape[2]``
            must be statically known because they size the pooling window
            and the upsampling factor.
        filters: Output channels of every branch and of the final projection.
            Defaults to 256.
        dilation_rates: Dilation rates of the 3x3 atrous branches.
            Defaults to (6, 12, 18).

    Returns:
        Keras tensor with the same spatial size as ``inputs`` and ``filters``
        channels.

    Raises:
        ValueError: if the spatial dimensions of ``inputs`` are unknown.
    """
    height, width = inputs.shape[1], inputs.shape[2]
    if height is None or width is None:
        raise ValueError(
            "ASPP needs statically known spatial dimensions, got shape "
            f"{tuple(inputs.shape)}."
        )

    def conv_bn_relu(x, kernel_size, dilation_rate=1):
        # Conv -> BatchNorm -> ReLU. The conv bias is disabled because
        # BatchNormalization follows immediately.
        x = tf.keras.layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            dilation_rate=dilation_rate,
            padding="same",
            use_bias=False,
        )(x)
        x = tf.keras.layers.BatchNormalization()(x)
        return tf.keras.layers.Activation("relu")(x)

    # Image-pooling branch: collapse the whole map to 1x1, then stretch it back.
    y_pool = tf.keras.layers.AveragePooling2D(pool_size=(height, width))(inputs)
    y_pool = conv_bn_relu(y_pool, kernel_size=1)
    y_pool = tf.keras.layers.UpSampling2D((height, width), interpolation="bilinear")(y_pool)

    # 1x1 branch followed by the atrous 3x3 branches (built in this order so
    # the auto-generated layer names match the original notebook).
    branches = [y_pool, conv_bn_relu(inputs, kernel_size=1)]
    branches += [
        conv_bn_relu(inputs, kernel_size=3, dilation_rate=rate)
        for rate in dilation_rates
    ]

    # Fuse all branches and project back to `filters` channels.
    y = tf.keras.layers.Concatenate()(branches)
    return conv_bn_relu(y, kernel_size=1)
    
def Deeplabv3plus(shape, num_classes=1, weights="imagenet"):
    """Build a DeepLabV3+ segmentation model on a ResNet50 backbone.

    Encoder: ``tf.keras.applications.ResNet50``. Its ``conv4_block6_out``
    features go through ``ASPP`` and are upsampled 4x.

    Decoder: the ``conv2_block2_out`` features are reduced to 48 channels with
    a 1x1 conv and concatenated with the ASPP branch. The result is refined by
    two 3x3 conv-BN-ReLU layers and upsampled 4x. A final 1x1 conv named
    ``"output_layer"`` produces ``num_classes`` channels.

    Args:
        shape: Input shape without the batch axis, e.g. ``(224, 224, 3)``.
            Known spatial dimensions must be multiples of 16 so that both
            fixed 4x upsamplings line up with the backbone feature maps.
        num_classes: Number of output channels. With the default of 1 the
            output goes through a sigmoid (the original behaviour). With more
            than one class a softmax is applied instead.
        weights: Forwarded to ``ResNet50``. ``"imagenet"`` (default) loads
            pre-trained weights; ``None`` gives a randomly initialised backbone.

    Returns:
        An uncompiled ``tf.keras.Model`` whose output has shape
        ``(batch, height, width, num_classes)``.

    Raises:
        ValueError: if ``num_classes`` < 1 or a known spatial dimension of
            ``shape`` is not a multiple of 16.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}.")

    inputs = tf.keras.layers.Input(shape)

    # conv4_block6_out has stride 16 and conv2_block*_out has stride 4 (14x14
    # and 56x56 for a 224x224 input). The decoder bridges them with two fixed
    # 4x upsamplings, which only matches when the input is a multiple of 16.
    for dim in inputs.shape[1:3]:
        if dim is not None and dim % 16 != 0:
            raise ValueError(
                "Deeplabv3plus needs spatial dimensions that are multiples of 16, "
                f"got input shape {tuple(shape)}."
            )

    # Encoder: ResNet50 backbone sharing our input tensor.
    base_model = tf.keras.applications.ResNet50(
        weights=weights, include_top=False, input_tensor=inputs
    )

    # High-level features -> ASPP -> 4x upsampling.
    image_features = base_model.get_layer("conv4_block6_out").output
    x_a = ASPP(image_features)
    x_a = tf.keras.layers.UpSampling2D((4, 4), interpolation="bilinear")(x_a)

    # Low-level features reduced to 48 channels.
    # NOTE(review): stage 2 of ResNet50 also has conv2_block3_* layers, so
    # conv2_block2_out is not the stage's last output. Confirm that this is
    # intentional.
    x_b = base_model.get_layer("conv2_block2_out").output
    x_b = tf.keras.layers.Conv2D(filters=48, kernel_size=1, padding="same", use_bias=False)(x_b)
    x_b = tf.keras.layers.BatchNormalization()(x_b)
    x_b = tf.keras.layers.Activation("relu")(x_b)

    # Fuse high- and low-level features.
    x = tf.keras.layers.Concatenate()([x_a, x_b])

    # Two 3x3 conv-BN-ReLU refinement layers.
    for _ in range(2):
        x = tf.keras.layers.Conv2D(filters=256, kernel_size=3, padding="same", use_bias=False)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation("relu")(x)

    # Upsample 4x back to the input resolution.
    x = tf.keras.layers.UpSampling2D((4, 4), interpolation="bilinear")(x)

    # Per-pixel prediction head.
    x = tf.keras.layers.Conv2D(num_classes, (1, 1), name="output_layer")(x)
    x = tf.keras.layers.Activation("sigmoid" if num_classes == 1 else "softmax")(x)

    return tf.keras.models.Model(inputs=inputs, outputs=x)

In [3]:
# Build the network for 224x224 inputs with 3 channels (height, width,
# channels) and print its layer table.
input_shape = (224, 224, 3)
model = Deeplabv3plus(shape=input_shape)
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, 230, 230, 3)  0           ['input_1[0][0]']                
                                                                                                  
 conv1_conv (Conv2D)            (None, 112, 112, 64  9472        ['conv1_pad[0][0]']              
                                )                                                                 
                                                                                              

                                                                                                  
 conv2_block3_1_relu (Activatio  (None, 56, 56, 64)  0           ['conv2_block3_1_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv2_block3_2_conv (Conv2D)   (None, 56, 56, 64)   36928       ['conv2_block3_1_relu[0][0]']    
                                                                                                  
 conv2_block3_2_bn (BatchNormal  (None, 56, 56, 64)  256         ['conv2_block3_2_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv2_block3_2_relu (Activatio  (None, 56, 56, 64)  0           ['conv2_block3_2_bn[0][0]']      
 n)       

                                                                                                  
 conv3_block3_1_relu (Activatio  (None, 28, 28, 128)  0          ['conv3_block3_1_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv3_block3_2_conv (Conv2D)   (None, 28, 28, 128)  147584      ['conv3_block3_1_relu[0][0]']    
                                                                                                  
 conv3_block3_2_bn (BatchNormal  (None, 28, 28, 128)  512        ['conv3_block3_2_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv3_block3_2_relu (Activatio  (None, 28, 28, 128)  0          ['conv3_block3_2_bn[0][0]']      
 n)       

                                                                                                  
 conv4_block2_1_bn (BatchNormal  (None, 14, 14, 256)  1024       ['conv4_block2_1_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv4_block2_1_relu (Activatio  (None, 14, 14, 256)  0          ['conv4_block2_1_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv4_block2_2_conv (Conv2D)   (None, 14, 14, 256)  590080      ['conv4_block2_1_relu[0][0]']    
                                                                                                  
 conv4_block2_2_bn (BatchNormal  (None, 14, 14, 256)  1024       ['conv4_block2_2_conv[0][0]']    
 ization) 

 conv4_block5_1_conv (Conv2D)   (None, 14, 14, 256)  262400      ['conv4_block4_out[0][0]']       
                                                                                                  
 conv4_block5_1_bn (BatchNormal  (None, 14, 14, 256)  1024       ['conv4_block5_1_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv4_block5_1_relu (Activatio  (None, 14, 14, 256)  0          ['conv4_block5_1_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv4_block5_2_conv (Conv2D)   (None, 14, 14, 256)  590080      ['conv4_block5_1_relu[0][0]']    
                                                                                                  
 conv4_blo

 batch_normalization_4 (BatchNo  (None, 14, 14, 256)  1024       ['conv2d_4[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 up_sampling2d (UpSampling2D)   (None, 14, 14, 256)  0           ['activation[0][0]']             
                                                                                                  
 activation_1 (Activation)      (None, 14, 14, 256)  0           ['batch_normalization_1[0][0]']  
                                                                                                  
 activation_2 (Activation)      (None, 14, 14, 256)  0           ['batch_normalization_2[0][0]']  
                                                                                                  
 activation_3 (Activation)      (None, 14, 14, 256)  0           ['batch_normalization_3[0][0]']  
          