In [1]:
import tensorflow as tf

from tensorflow.keras.layers import Activation, MaxPool2D, Concatenate, ConvLSTM2D, Dropout, TimeDistributed, Reshape
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, BatchNormalization, Add, UpSampling2D, ConvLSTM2D
from tensorflow.keras.models import Model

from tensorflow.keras.applications import MobileNetV3Large, MobileNetV2
import numpy as np

In [2]:
# NOTE(review): presumably an intended sequence length for ConvLSTM input, but
# no visible cell reads it (conv_block wraps each feature map as a length-1
# sequence instead). Confirm whether it is still needed, or remove it.
time_steps = 30

In [3]:
def conv_block(inputs, num_filters):
    """Conv -> BN -> ReLU, followed by a single-step ConvLSTM2D refinement.

    Args:
        inputs: 4-D feature map; presumably (batch, H, W, C) channels-last,
            which is what the Reshape below assumes. H and W must be static.
        num_filters: output channels of both the Conv2D and the ConvLSTM2D.

    Returns:
        Tensor of shape (batch, H, W, num_filters).

    Raises:
        ValueError: if the spatial dimensions are not statically known.
    """
    x = Conv2D(filters=num_filters, kernel_size=3, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # ConvLSTM2D expects (batch, time, H, W, C), so wrap the feature map as a
    # length-1 sequence. Height and width are read separately; reusing the
    # width for both axes breaks on non-square feature maps.
    height, width = x.shape[1], x.shape[2]
    if height is None or width is None:
        raise ValueError(
            'conv_block requires static spatial dimensions, got '
            f'height={height}, width={width}')
    x = Reshape(target_shape=(1, int(height), int(width), num_filters))(x)

    # With a single time step, go_backwards=True has no effect; kept for
    # parity with the existing configuration.
    x = ConvLSTM2D(filters=num_filters, kernel_size=(3, 3), padding='same',
                   return_sequences=False, go_backwards=True,
                   kernel_initializer='he_normal')(x)
    return x

In [4]:
def upsample_block(inputs, skip_features, num_filters, dropout_rate=0.3):
    """Decoder stage: 2x bilinear upsample, fuse a skip connection, refine.

    Args:
        inputs: feature map from the previous (coarser) decoder stage.
        skip_features: encoder feature map whose spatial size must equal the
            upsampled size of ``inputs`` so the two can be concatenated.
        num_filters: channel count passed to ``conv_block``.
        dropout_rate: fraction of units dropped after refinement. Defaults to
            0.3, the value previously hard-coded.

    Returns:
        The refined feature map at the upsampled resolution.
    """
    x = UpSampling2D(interpolation='bilinear')(inputs)
    # Concatenate along the channel (last) axis.
    x = Concatenate()([x, skip_features])
    x = conv_block(x, num_filters)
    x = Dropout(dropout_rate)(x)
    return x

In [5]:
def LRASPP(x, out_channels: int):
    """Lite R-ASPP style head: a 1x1-conv branch gated by a global-context branch.

    Args:
        x: 4-D feature map; presumably (batch, H, W, C) channels-last, since
            the global branch averages over axes 1 and 2.
        out_channels: channels produced by both 1x1 convolutions.

    Returns:
        The ReLU'd 1x1-conv features multiplied by a per-channel sigmoid gate
        (broadcast over the spatial axes); ``out_channels`` channels.
    """
    # Local branch: 1x1 conv -> BN -> ReLU. The activation must consume the BN
    # output; applying it to `x` discards the conv/BN and keeps the input's
    # channel count, which no longer matches the gate below.
    x1 = Conv2D(out_channels, 1, use_bias=False)(x)
    # Keras momentum is the weight on the *old* moving statistics, so 0.9 here
    # corresponds to PyTorch's momentum=0.1 (0.1 would track only the last batch).
    x1 = BatchNormalization(momentum=0.9, epsilon=1e-5)(x1)
    x1 = Activation('relu')(x1)

    # Global branch: spatial mean -> 1x1 conv -> sigmoid, shape (batch, 1, 1, out_channels).
    pooled = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
    x2 = Conv2D(out_channels, 1, use_bias=False, activation='sigmoid')(pooled)
    return x1 * x2

In [6]:
def output_block(inputs):
    """Two stacked 1x1 Conv -> BN -> ReLU stages producing 3 channels.

    Args:
        inputs: decoder feature map (channels-last).

    Returns:
        Tensor with the same spatial size as ``inputs`` and 3 channels.
    """
    x = inputs
    for _ in range(2):
        # Each stage consumes the previous stage's output. Feeding `inputs` to
        # the second conv would leave the first stage disconnected (dead).
        x = Conv2D(3, (1, 1), 1, 'same', use_bias=False)(x)
        # Keras momentum weights the old moving statistics; 0.9 is the
        # equivalent of PyTorch's default momentum=0.1.
        x = BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
        x = Activation('relu')(x)
    return x

In [7]:
def _encoder_relu_output(encoder, index):
    """Return the output of the encoder's ``index``-th auto-named ReLU layer.

    The encoder's ReLU layers carry auto-generated names (``re_lu``,
    ``re_lu_1``, ...), drawn from a session-global counter. Building the model
    a second time in the same kernel shifts those names, so a hard-coded
    ``get_layer('re_lu_2')`` stops resolving. This helper offsets ``index`` by
    the smallest ReLU suffix present in ``encoder``; in a fresh session
    ``index=2`` therefore resolves to ``re_lu_2``.

    Assumes the encoder's ReLU layers were numbered contiguously while it was
    constructed.

    Raises:
        ValueError: if no matching ReLU layer exists.
    """
    import re

    relu_by_suffix = {}
    for layer in encoder.layers:
        match = re.fullmatch(r're_lu(?:_(\d+))?', layer.name)
        if match:
            # The first instance is named plain 're_lu', i.e. suffix 0.
            relu_by_suffix[int(match.group(1) or 0)] = layer
    if not relu_by_suffix:
        raise ValueError('encoder contains no auto-named ReLU layers')
    target = min(relu_by_suffix) + index
    if target not in relu_by_suffix:
        raise ValueError(
            f'encoder has no ReLU at index {index} (looked for suffix {target})')
    return relu_by_suffix[target].output


def get_model():
    """Build the MobileNetV3Large-encoder / ConvLSTM-decoder segmentation model.

    The ImageNet-initialised encoder is frozen. Four decoder stages upsample
    the 35x35 bridge back to the 560x560 input resolution using encoder skip
    connections, and a 1x1 sigmoid conv produces a single-channel output map.

    Returns:
        A ``Model`` mapping (None, 560, 560, 3) inputs to (None, 560, 560, 1)
        sigmoid outputs.
    """
    inputs = Input(shape=(560, 560, 3), name='input')
    encoder = MobileNetV3Large(input_tensor=inputs, weights="imagenet", include_top=False)

    # Freeze the pretrained encoder; only the newly added decoder layers train.
    for layer in encoder.layers:
        layer.trainable = False

    # Encoder skip connections. ReLU indices correspond to the layer names
    # re_lu_2 / re_lu_6 / re_lu_15 / re_lu_29 of a fresh session, resolved
    # relative to the encoder so the cell can be re-run in the same kernel.
    e1 = encoder.get_layer('input').output      # (None, 560, 560, 3)
    e2 = _encoder_relu_output(encoder, 2)       # (None, 280, 280, 64)
    e3 = _encoder_relu_output(encoder, 6)       # (None, 140, 140, 72)
    e4 = _encoder_relu_output(encoder, 15)      # (None, 70, 70, 240)

    # Bridge
    b4 = _encoder_relu_output(encoder, 29)      # (None, 35, 35, 672)

    # Decoder: each stage doubles the resolution and fuses the matching skip.
    d1 = upsample_block(b4, e4, 672)            # 70x70
    d2 = upsample_block(d1, e3, 256)            # 140x140
    d3 = upsample_block(d2, e2, 128)            # 280x280
    d4 = upsample_block(d3, e1, 64)             # 560x560

    # Output head: 3-channel refinement, then a 1-channel sigmoid map.
    outputs = output_block(d4)
    outputs = Conv2D(1, (1, 1), 1, 'same', activation='sigmoid')(outputs)

    return Model(inputs, outputs)

In [8]:
# Build the full encoder-decoder model defined above.
model = get_model()



In [9]:
# Print the layer-by-layer architecture (output shapes, parameter counts, wiring).
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input (InputLayer)             [(None, 560, 560, 3  0           []                               
                                )]                                                                
                                                                                                  
 rescaling (Rescaling)          (None, 560, 560, 3)  0           ['input[0][0]']                  
                                                                                                  
 Conv (Conv2D)                  (None, 280, 280, 16  432         ['rescaling[0][0]']              
                                )                                                                 
                                                                                              

                                                                                                  
 expanded_conv_2/depthwise/Batc  (None, 140, 140, 72  288        ['expanded_conv_2/depthwise[0][0]
 hNorm (BatchNormalization)     )                                ']                               
                                                                                                  
 re_lu_5 (ReLU)                 (None, 140, 140, 72  0           ['expanded_conv_2/depthwise/Batch
                                )                                Norm[0][0]']                     
                                                                                                  
 expanded_conv_2/project (Conv2  (None, 140, 140, 24  1728       ['re_lu_5[0][0]']                
 D)                             )                                                                 
                                                                                                  
 expanded_

                                                                                                  
 re_lu_10 (ReLU)                (None, 70, 70, 120)  0           ['expanded_conv_4/depthwise/Batch
                                                                 Norm[0][0]']                     
                                                                                                  
 expanded_conv_4/squeeze_excite  (None, 1, 1, 120)   0           ['re_lu_10[0][0]']               
 /AvgPool (GlobalAveragePooling                                                                   
 2D)                                                                                              
                                                                                                  
 expanded_conv_4/squeeze_excite  (None, 1, 1, 32)    3872        ['expanded_conv_4/squeeze_excite/
 /Conv (Conv2D)                                                  AvgPool[0][0]']                  
          

                                                                                                  
 expanded_conv_5/project/BatchN  (None, 70, 70, 40)  160         ['expanded_conv_5/project[0][0]']
 orm (BatchNormalization)                                                                         
                                                                                                  
 expanded_conv_5/Add (Add)      (None, 70, 70, 40)   0           ['expanded_conv_4/Add[0][0]',    
                                                                  'expanded_conv_5/project/BatchNo
                                                                 rm[0][0]']                       
                                                                                                  
 expanded_conv_6/expand (Conv2D  (None, 70, 70, 240)  9600       ['expanded_conv_5/Add[0][0]']    
 )                                                                                                
          

 )                                                                                                
                                                                                                  
 multiply_4 (Multiply)          (None, 35, 35, 200)  0           ['expanded_conv_7/depthwise/Batch
                                                                 Norm[0][0]',                     
                                                                  'tf.math.multiply_7[0][0]']     
                                                                                                  
 expanded_conv_7/project (Conv2  (None, 35, 35, 80)  16000       ['multiply_4[0][0]']             
 D)                                                                                               
                                                                                                  
 expanded_conv_7/project/BatchN  (None, 35, 35, 80)  320         ['expanded_conv_7/project[0][0]']
 orm (Batc

                                                                                                  
 expanded_conv_9/depthwise/Batc  (None, 35, 35, 184)  736        ['expanded_conv_9/depthwise[0][0]
 hNorm (BatchNormalization)                                      ']                               
                                                                                                  
 tf.__operators__.add_11 (TFOpL  (None, 35, 35, 184)  0          ['expanded_conv_9/depthwise/Batch
 ambda)                                                          Norm[0][0]']                     
                                                                                                  
 re_lu_22 (ReLU)                (None, 35, 35, 184)  0           ['tf.__operators__.add_11[0][0]']
                                                                                                  
 tf.math.multiply_11 (TFOpLambd  (None, 35, 35, 184)  0          ['re_lu_22[0][0]']               
 a)       

 expanded_conv_10/squeeze_excit  (None, 35, 35, 480)  0          ['multiply_10[0][0]',            
 e/Mul (Multiply)                                                 'tf.math.multiply_14[0][0]']    
                                                                                                  
 expanded_conv_10/project (Conv  (None, 35, 35, 112)  53760      ['expanded_conv_10/squeeze_excite
 2D)                                                             /Mul[0][0]']                     
                                                                                                  
 expanded_conv_10/project/Batch  (None, 35, 35, 112)  448        ['expanded_conv_10/project[0][0]'
 Norm (BatchNormalization)                                       ]                                
                                                                                                  
 expanded_conv_11/expand (Conv2  (None, 35, 35, 672)  75264      ['expanded_conv_10/project/BatchN
 D)       

 expanded_conv_12/expand/BatchN  (None, 35, 35, 672)  2688       ['expanded_conv_12/expand[0][0]']
 orm (BatchNormalization)                                                                         
                                                                                                  
 tf.__operators__.add_18 (TFOpL  (None, 35, 35, 672)  0          ['expanded_conv_12/expand/BatchNo
 ambda)                                                          rm[0][0]']                       
                                                                                                  
 re_lu_29 (ReLU)                (None, 35, 35, 672)  0           ['tf.__operators__.add_18[0][0]']
                                                                                                  
 up_sampling2d (UpSampling2D)   (None, 70, 70, 672)  0           ['re_lu_29[0][0]']               
                                                                                                  
 concatena

                                                                                                  
 batch_normalization_3 (BatchNo  (None, 560, 560, 64  256        ['conv2d_3[0][0]']               
 rmalization)                   )                                                                 
                                                                                                  
 activation_3 (Activation)      (None, 560, 560, 64  0           ['batch_normalization_3[0][0]']  
                                )                                                                 
                                                                                                  
 reshape_3 (Reshape)            (None, 1, 560, 560,  0           ['activation_3[0][0]']           
                                 64)                                                              
                                                                                                  
 conv_lstm