In [1]:
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model
from tensorflow.keras.applications import EfficientNetB4

In [2]:
def conv_block(input, num_filters):
    """Apply two consecutive (3x3 Conv2D -> BatchNormalization -> ReLU) stages.

    Args:
        input: Keras tensor to transform.
        num_filters: number of output channels for both convolutions.

    Returns:
        Keras tensor with ``num_filters`` channels. ``padding="same"`` with the
        default stride of 1 keeps the spatial size of ``input``.
    """
    features = input
    # Two identical stages; layers are created in the same order as an
    # unrolled version, so auto-generated layer names are unaffected.
    for _ in range(2):
        features = Conv2D(num_filters, 3, padding="same")(features)
        features = BatchNormalization()(features)
        features = Activation("relu")(features)
    return features

def decoder_block(input, skip_features, num_filters):
    """One U-Net decoder step: upsample, merge the skip connection, refine.

    Args:
        input: Keras tensor from the previous (coarser) decoder/bottleneck stage.
        skip_features: encoder tensor whose spatial size must equal twice that
            of ``input`` so the concatenation lines up.
        num_filters: channel count for the transposed conv and ``conv_block``.

    Returns:
        Keras tensor at twice the spatial resolution of ``input`` with
        ``num_filters`` channels.
    """
    # 2x2 transposed conv with stride 2 doubles height and width.
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
    # Channel-axis concat (Keras default axis=-1) with the encoder features.
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

In [3]:
def EfficientNetB4_unet(input_shape, weights="imagenet"):
    """Build a U-Net whose encoder is a Keras ``EfficientNetB4`` backbone.

    Skip connections are taken from intermediate EfficientNetB4 activations,
    and four ``decoder_block`` stages upsample back to the input resolution.
    The resolution comments below assume a 512x512 input.

    Args:
        input_shape: shape of one input image, e.g. ``(512, 512, 3)``.
        weights: forwarded to ``EfficientNetB4``. The default ``"imagenet"``
            loads pretrained weights; ``None`` uses random initialization
            (useful offline or for quick smoke tests).

    Returns:
        The assembled ``Model``. It outputs a 1-channel sigmoid map.
        As a side effect, ``model.summary()`` is printed.
    """
    # Input
    inputs = Input(input_shape)

    # Backbone (encoder) built directly on our input tensor.
    backbone = EfficientNetB4(include_top=False, weights=weights, input_tensor=inputs)

    # Encoder skip connections.
    # NOTE(review): the 'tf.math.truediv' layer name depends on the installed
    # TF/Keras version; confirm it exists if get_layer raises ValueError.
    s1 = backbone.get_layer('tf.math.truediv').output            # 512 x 512
    s2 = backbone.get_layer('block2a_expand_activation').output  # 256 x 256
    s3 = backbone.get_layer('block3a_expand_activation').output  # 128 x 128
    s4 = backbone.get_layer('block4a_expand_activation').output  # 64 x 64

    # Bottleneck
    b1 = backbone.get_layer('block6a_expand_activation').output  # 32 x 32

    # Decoder: each step doubles the spatial size and merges one skip.
    d1 = decoder_block(b1, s4, 512)  # 64 x 64
    d2 = decoder_block(d1, s3, 256)  # 128 x 128
    d3 = decoder_block(d2, s2, 128)  # 256 x 256
    d4 = decoder_block(d3, s1, 64)   # 512 x 512

    # Output: per-pixel probability for a single class.
    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)

    model = Model(inputs, outputs, name="EfficientNetB4_U-Net")
    model.summary()
    # Previously a bare `return` discarded the model and callers got None.
    return model

In [4]:
# Build the U-Net for 512x512 RGB inputs (the builder prints model.summary()).
# Keep a reference so later cells (compile / fit / predict) can use the model.
unet_model = EfficientNetB4_unet((512, 512, 3))

Model: "EfficientNetB4_U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 rescaling (Rescaling)          (None, 512, 512, 3)  0           ['input_1[0][0]']                
                                                                                                  
 normalization (Normalization)  (None, 512, 512, 3)  7           ['rescaling[0][0]']              
                                                                                                  
 tf.math.truediv (TFOpLambda)   (None, 512, 512, 3)  0           ['normalization[0][0

 block2a_expand_conv (Conv2D)   (None, 256, 256, 14  3456        ['block1b_add[0][0]']            
                                4)                                                                
                                                                                                  
 block2a_expand_bn (BatchNormal  (None, 256, 256, 14  576        ['block2a_expand_conv[0][0]']    
 ization)                       4)                                                                
                                                                                                  
 block2a_expand_activation (Act  (None, 256, 256, 14  0          ['block2a_expand_bn[0][0]']      
 ivation)                       4)                                                                
                                                                                                  
 block2a_dwconv_pad (ZeroPaddin  (None, 257, 257, 14  0          ['block2a_expand_activation[0][0]
 g2D)     

                                                                                                  
 block2c_expand_bn (BatchNormal  (None, 128, 128, 19  768        ['block2c_expand_conv[0][0]']    
 ization)                       2)                                                                
                                                                                                  
 block2c_expand_activation (Act  (None, 128, 128, 19  0          ['block2c_expand_bn[0][0]']      
 ivation)                       2)                                                                
                                                                                                  
 block2c_dwconv (DepthwiseConv2  (None, 128, 128, 19  1728       ['block2c_expand_activation[0][0]
 D)                             2)                               ']                               
                                                                                                  
 block2c_b

                                2)                                                                
                                                                                                  
 block3a_expand_bn (BatchNormal  (None, 128, 128, 19  768        ['block3a_expand_conv[0][0]']    
 ization)                       2)                                                                
                                                                                                  
 block3a_expand_activation (Act  (None, 128, 128, 19  0          ['block3a_expand_bn[0][0]']      
 ivation)                       2)                                                                
                                                                                                  
 block3a_dwconv_pad (ZeroPaddin  (None, 131, 131, 19  0          ['block3a_expand_activation[0][0]
 g2D)                           2)                               ']                               
          

                                                                                                  
 block3c_dwconv (DepthwiseConv2  (None, 64, 64, 336)  8400       ['block3c_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block3c_bn (BatchNormalization  (None, 64, 64, 336)  1344       ['block3c_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block3c_activation (Activation  (None, 64, 64, 336)  0          ['block3c_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block3c_s

 D)                                                                                               
                                                                                                  
 block4a_bn (BatchNormalization  (None, 32, 32, 336)  1344       ['block4a_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block4a_activation (Activation  (None, 32, 32, 336)  0          ['block4a_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block4a_se_squeeze (GlobalAver  (None, 336)         0           ['block4a_activation[0][0]']     
 agePooling2D)                                                                                    
          

                                                                                                  
 block4c_se_reshape (Reshape)   (None, 1, 1, 672)    0           ['block4c_se_squeeze[0][0]']     
                                                                                                  
 block4c_se_reduce (Conv2D)     (None, 1, 1, 28)     18844       ['block4c_se_reshape[0][0]']     
                                                                                                  
 block4c_se_expand (Conv2D)     (None, 1, 1, 672)    19488       ['block4c_se_reduce[0][0]']      
                                                                                                  
 block4c_se_excite (Multiply)   (None, 32, 32, 672)  0           ['block4c_activation[0][0]',     
                                                                  'block4c_se_expand[0][0]']      
                                                                                                  
 block4c_p

 block4e_se_expand (Conv2D)     (None, 1, 1, 672)    19488       ['block4e_se_reduce[0][0]']      
                                                                                                  
 block4e_se_excite (Multiply)   (None, 32, 32, 672)  0           ['block4e_activation[0][0]',     
                                                                  'block4e_se_expand[0][0]']      
                                                                                                  
 block4e_project_conv (Conv2D)  (None, 32, 32, 112)  75264       ['block4e_se_excite[0][0]']      
                                                                                                  
 block4e_project_bn (BatchNorma  (None, 32, 32, 112)  448        ['block4e_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block4e_d

 block5a_project_conv (Conv2D)  (None, 32, 32, 160)  107520      ['block5a_se_excite[0][0]']      
                                                                                                  
 block5a_project_bn (BatchNorma  (None, 32, 32, 160)  640        ['block5a_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block5b_expand_conv (Conv2D)   (None, 32, 32, 960)  153600      ['block5a_project_bn[0][0]']     
                                                                                                  
 block5b_expand_bn (BatchNormal  (None, 32, 32, 960)  3840       ['block5b_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block5b_e

 block5d_expand_conv (Conv2D)   (None, 32, 32, 960)  153600      ['block5c_add[0][0]']            
                                                                                                  
 block5d_expand_bn (BatchNormal  (None, 32, 32, 960)  3840       ['block5d_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block5d_expand_activation (Act  (None, 32, 32, 960)  0          ['block5d_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block5d_dwconv (DepthwiseConv2  (None, 32, 32, 960)  24000      ['block5d_expand_activation[0][0]
 D)                                                              ']                               
          

 block5f_expand_activation (Act  (None, 32, 32, 960)  0          ['block5f_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block5f_dwconv (DepthwiseConv2  (None, 32, 32, 960)  24000      ['block5f_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block5f_bn (BatchNormalization  (None, 32, 32, 960)  3840       ['block5f_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block5f_activation (Activation  (None, 32, 32, 960)  0          ['block5f_bn[0][0]']             
 )        

 rmalization)                   6)                                                                
                                                                                                  
 activation_3 (Activation)      (None, 128, 128, 25  0           ['batch_normalization_3[0][0]']  
                                6)                                                                
                                                                                                  
 conv2d_transpose_2 (Conv2DTran  (None, 256, 256, 12  131200     ['activation_3[0][0]']           
 spose)                         8)                                                                
                                                                                                  
 concatenate_2 (Concatenate)    (None, 256, 256, 27  0           ['conv2d_transpose_2[0][0]',     
                                2)                                'block2a_expand_activation[0][0]
          