# U-Net

## Imports

In [1]:
from tensorflow.keras import layers, optimizers, backend, Model
from tensorflow.keras.applications import EfficientNetB4, EfficientNetV2S, ResNet50V2, DenseNet201

## Full model

<img src='images/unet.png' width='600'/>

In [14]:
# U-Net
def unet(input_shape=(512, 512, 3)):
    """
    Creates a neural network using the U-Net architecture.
    Args:
        input_shape: The size of the input image.
    Returns:
        A U-Net model.
    """
    # Encoder Part
    # Layer 1
    inputs = layers.Input(input_shape)
    inputs_rescaled = layers.Lambda(lambda x: x / 255)(inputs)  # Rescale input pixel values to floating point values
    c1 = layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(inputs_rescaled)
    c1 = layers.BatchNormalization()(c1)
    c1 = layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c1)
    c1 = layers.BatchNormalization()(c1)
    p1 = layers.MaxPooling2D((2, 2))(c1)

    # Layer 2
    c2 = layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p1)
    c2 = layers.BatchNormalization()(c2)
    c2 = layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c2)
    c2 = layers.BatchNormalization()(c2)
    p2 = layers.MaxPooling2D((2, 2))(c2)

    # Layer 3
    c3 = layers.Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p2)
    c3 = layers.BatchNormalization()(c3)
    c3 = layers.Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c3)
    c3 = layers.BatchNormalization()(c3)
    p3 = layers.MaxPooling2D((2, 2))(c3)

    # Layer 4
    c4 = layers.Conv2D(512, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p3)
    c4 = layers.BatchNormalization()(c4)
    c4 = layers.Conv2D(512, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c4)
    c4 = layers.BatchNormalization()(c4)
    p4 = layers.MaxPooling2D((2, 2))(c4)

    # Layer 5
    c5 = layers.Conv2D(1024, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(p4)
    c5 = layers.BatchNormalization()(c5)
    c5 = layers.Conv2D(1024, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c5)
    c5 = layers.BatchNormalization()(c5)

    # Decoder Part
    # Layer 6
    u6 = layers.Conv2DTranspose(512, (3, 3), strides=(2, 2), padding='same')(c5)
    u6 = layers.concatenate([u6, c4])
    c6 = layers.Conv2D(512, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u6)
    c6 = layers.BatchNormalization()(c6)
    c6 = layers.Conv2D(512, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c6)
    c6 = layers.BatchNormalization()(c6)

    # Layer 7
    u7 = layers.Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same')(c6)
    u7 = layers.concatenate([u7, c3])
    c7 = layers.Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u7)
    c7 = layers.BatchNormalization()(c7)
    c7 = layers.Conv2D(256, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c7)
    c7 = layers.BatchNormalization()(c7)

    # Layer 8
    u8 = layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same')(c7)
    u8 = layers.concatenate([u8, c2])
    c8 = layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u8)
    c8 = layers.BatchNormalization()(c8)
    c8 = layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c8)
    c8 = layers.BatchNormalization()(c8)

    # Layer 9
    u9 = layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same')(c8)
    u9 = layers.concatenate([u9, c1], axis=3)
    c9 = layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(u9)
    c9 = layers.BatchNormalization()(c9)
    c9 = layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c9)
    c9 = layers.BatchNormalization()(c9)

    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(c9)

    # Compiling model
    model = Model(inputs=[inputs], outputs=[outputs], name='UNet')
    return model

In [15]:
model = unet()

In [16]:
model.summary()

Model: "UNet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_5 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 lambda_3 (Lambda)              (None, 512, 512, 3)  0           ['input_5[0][0]']                
                                                                                                  
 conv2d_48 (Conv2D)             (None, 512, 512, 64  1792        ['lambda_3[0][0]']               
                                )                                                                 
                                                                                               

                                                                                                  
 conv2d_58 (Conv2D)             (None, 64, 64, 512)  4719104     ['concatenate_9[0][0]']          
                                                                                                  
 batch_normalization_56 (BatchN  (None, 64, 64, 512)  2048       ['conv2d_58[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 conv2d_59 (Conv2D)             (None, 64, 64, 512)  2359808     ['batch_normalization_56[0][0]'] 
                                                                                                  
 batch_normalization_57 (BatchN  (None, 64, 64, 512)  2048       ['conv2d_59[0][0]']              
 ormalization)                                                                                    
          

## Block functions

In [2]:
def conv_block(input, num_filters):
    """
    Creates a block consisting of two convolutional layers.
    Args:
        input: The output of the previous up-convolution.
        num_filters: Number of filters to be used in the convolution.
    Returns:
        The result of the final calculation of the current layer.
    """
    x = layers.Conv2D(num_filters, 3, activation='relu', kernel_initializer='he_normal',  padding='same')(input)
    x = layers.BatchNormalization()(x)

    x = layers.Conv2D(num_filters, 3, activation='relu', kernel_initializer='he_normal', padding='same')(x)
    x = layers.BatchNormalization()(x)
    return x


def decoder_block(input, skip_features, num_filters):
    """
    Create a decoder block used in the decoder part of the network.
    Args:
        input: The output of the previous layer.
        skip_features: The features of the same size from the encoder part of the network.
        num_filters: Number of filters to be used in the convolution.
    Returns:
        The result of the final calculation of the current layer.
    """
    x = layers.Conv2DTranspose(num_filters, (3, 3), strides=2, activation='relu', kernel_initializer='he_normal', padding='same')(input)
    x = layers.BatchNormalization()(x)
    x = layers.concatenate([x, skip_features])
    x = conv_block(x, num_filters)
    return x


## EfficientNetB4

In [3]:
# EfficientNetB4 UNet
def EfficientNetB4_unet(input_shape=(512, 512, 3), weight='imagenet'):
    """
    Creates a neural network using the U-Net architecture and EfficientNetB4 as backbone.
    Args:
        input_shape: The size of the input image.
        weight: Pre-trained weights.
    Returns:
        A U-Net model using EfficientNetB4 as backbone.
    """
    # Input
    inputs = layers.Input(input_shape)

    # Loading pre trained model
    EffNetB4 = EfficientNetB4(include_top=False, weights=weight, input_tensor=inputs)

    # Encoder
    s1 = EffNetB4.get_layer('rescaling').output  # 512 x 512
    s2 = EffNetB4.get_layer('block2a_expand_activation').output  # 256 x 256
    s3 = EffNetB4.get_layer('block3a_expand_activation').output  # 128 x 128
    s4 = EffNetB4.get_layer('block4a_expand_activation').output  # 64 x 64

    # Bottleneck
    b1 = EffNetB4.get_layer('block6a_expand_activation').output  # 32 x 32

    # Decoder
    d1 = decoder_block(b1, s4, 512)  # 64 x 64
    d2 = decoder_block(d1, s3, 256)    # 128 x 128
    d3 = decoder_block(d2, s2, 128)   # 256 x 256
    d4 = decoder_block(d3, s1, 64)   # 512 x 512

    # Output
    outputs = layers.Conv2D(1, 1, padding='same', activation='sigmoid')(d4)

    model = Model(inputs, outputs, name='EfficientNetB4_U-Net')
    return model

In [4]:
model = EfficientNetB4_unet()

In [5]:
model.summary()

Model: "EfficientNetB4_U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 rescaling (Rescaling)          (None, 512, 512, 3)  0           ['input_1[0][0]']                
                                                                                                  
 normalization (Normalization)  (None, 512, 512, 3)  7           ['rescaling[0][0]']              
                                                                                                  
 rescaling_1 (Rescaling)        (None, 512, 512, 3)  0           ['normalizatio

 block2a_expand_conv (Conv2D)   (None, 256, 256, 14  3456        ['block1b_add[0][0]']            
                                4)                                                                
                                                                                                  
 block2a_expand_bn (BatchNormal  (None, 256, 256, 14  576        ['block2a_expand_conv[0][0]']    
 ization)                       4)                                                                
                                                                                                  
 block2a_expand_activation (Act  (None, 256, 256, 14  0          ['block2a_expand_bn[0][0]']      
 ivation)                       4)                                                                
                                                                                                  
 block2a_dwconv_pad (ZeroPaddin  (None, 257, 257, 14  0          ['block2a_expand_activation[0][0]
 g2D)     

                                                                                                  
 block2c_expand_bn (BatchNormal  (None, 128, 128, 19  768        ['block2c_expand_conv[0][0]']    
 ization)                       2)                                                                
                                                                                                  
 block2c_expand_activation (Act  (None, 128, 128, 19  0          ['block2c_expand_bn[0][0]']      
 ivation)                       2)                                                                
                                                                                                  
 block2c_dwconv (DepthwiseConv2  (None, 128, 128, 19  1728       ['block2c_expand_activation[0][0]
 D)                             2)                               ']                               
                                                                                                  
 block2c_b

                                2)                                                                
                                                                                                  
 block3a_expand_bn (BatchNormal  (None, 128, 128, 19  768        ['block3a_expand_conv[0][0]']    
 ization)                       2)                                                                
                                                                                                  
 block3a_expand_activation (Act  (None, 128, 128, 19  0          ['block3a_expand_bn[0][0]']      
 ivation)                       2)                                                                
                                                                                                  
 block3a_dwconv_pad (ZeroPaddin  (None, 131, 131, 19  0          ['block3a_expand_activation[0][0]
 g2D)                           2)                               ']                               
          

                                                                                                  
 block3c_dwconv (DepthwiseConv2  (None, 64, 64, 336)  8400       ['block3c_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block3c_bn (BatchNormalization  (None, 64, 64, 336)  1344       ['block3c_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block3c_activation (Activation  (None, 64, 64, 336)  0          ['block3c_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block3c_s

 D)                                                                                               
                                                                                                  
 block4a_bn (BatchNormalization  (None, 32, 32, 336)  1344       ['block4a_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block4a_activation (Activation  (None, 32, 32, 336)  0          ['block4a_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block4a_se_squeeze (GlobalAver  (None, 336)         0           ['block4a_activation[0][0]']     
 agePooling2D)                                                                                    
          

                                                                                                  
 block4c_se_reshape (Reshape)   (None, 1, 1, 672)    0           ['block4c_se_squeeze[0][0]']     
                                                                                                  
 block4c_se_reduce (Conv2D)     (None, 1, 1, 28)     18844       ['block4c_se_reshape[0][0]']     
                                                                                                  
 block4c_se_expand (Conv2D)     (None, 1, 1, 672)    19488       ['block4c_se_reduce[0][0]']      
                                                                                                  
 block4c_se_excite (Multiply)   (None, 32, 32, 672)  0           ['block4c_activation[0][0]',     
                                                                  'block4c_se_expand[0][0]']      
                                                                                                  
 block4c_p

 block4e_se_expand (Conv2D)     (None, 1, 1, 672)    19488       ['block4e_se_reduce[0][0]']      
                                                                                                  
 block4e_se_excite (Multiply)   (None, 32, 32, 672)  0           ['block4e_activation[0][0]',     
                                                                  'block4e_se_expand[0][0]']      
                                                                                                  
 block4e_project_conv (Conv2D)  (None, 32, 32, 112)  75264       ['block4e_se_excite[0][0]']      
                                                                                                  
 block4e_project_bn (BatchNorma  (None, 32, 32, 112)  448        ['block4e_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block4e_d

 block5a_project_conv (Conv2D)  (None, 32, 32, 160)  107520      ['block5a_se_excite[0][0]']      
                                                                                                  
 block5a_project_bn (BatchNorma  (None, 32, 32, 160)  640        ['block5a_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block5b_expand_conv (Conv2D)   (None, 32, 32, 960)  153600      ['block5a_project_bn[0][0]']     
                                                                                                  
 block5b_expand_bn (BatchNormal  (None, 32, 32, 960)  3840       ['block5b_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block5b_e

 block5d_expand_conv (Conv2D)   (None, 32, 32, 960)  153600      ['block5c_add[0][0]']            
                                                                                                  
 block5d_expand_bn (BatchNormal  (None, 32, 32, 960)  3840       ['block5d_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block5d_expand_activation (Act  (None, 32, 32, 960)  0          ['block5d_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block5d_dwconv (DepthwiseConv2  (None, 32, 32, 960)  24000      ['block5d_expand_activation[0][0]
 D)                                                              ']                               
          

 block5f_expand_activation (Act  (None, 32, 32, 960)  0          ['block5f_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block5f_dwconv (DepthwiseConv2  (None, 32, 32, 960)  24000      ['block5f_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block5f_bn (BatchNormalization  (None, 32, 32, 960)  3840       ['block5f_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block5f_activation (Activation  (None, 32, 32, 960)  0          ['block5f_bn[0][0]']             
 )        

                                                                                                  
 conv2d_transpose_2 (Conv2DTran  (None, 256, 256, 12  295040     ['batch_normalization_5[0][0]']  
 spose)                         8)                                                                
                                                                                                  
 batch_normalization_6 (BatchNo  (None, 256, 256, 12  512        ['conv2d_transpose_2[0][0]']     
 rmalization)                   8)                                                                
                                                                                                  
 concatenate_2 (Concatenate)    (None, 256, 256, 27  0           ['batch_normalization_6[0][0]',  
                                2)                                'block2a_expand_activation[0][0]
                                                                 ']                               
          

## EfficientNetV2S

In [3]:
def EfficientNetV2S_unet(input_shape=(512, 512, 3), weight='imagenet'):
    """
    Creates a neural network using the U-Net architecture and EfficientNetV2S as backbone.
    Args:
        input_shape: The size of the input image.
        weight: Pre-trained weights.
    Returns:
        A U-Net model using EfficientNetB4 as backbone.
    """
    # Input
    inputs = layers.Input(input_shape)

    # Loading pre trained model
    EffNetV2S = EfficientNetV2S(include_top=False, weights=weight, input_tensor=inputs)

    # Encoder
    s1 = EffNetV2S.get_layer('rescaling').output  # 512 x 512
    s2 = EffNetV2S.get_layer('block1b_add').output  # 256 x 256
    s3 = EffNetV2S.get_layer('block2d_add').output  # 128 x 128
    s4 = EffNetV2S.get_layer('block4a_expand_activation').output  # 64 x 64

    # Bottleneck
    b1 = EffNetV2S.get_layer('block6a_expand_activation').output  # 32 x 32

    # Decoder
    d1 = decoder_block(b1, s4, 512)  # 64 x 64
    d2 = decoder_block(d1, s3, 256)    # 128 x 128
    d3 = decoder_block(d2, s2, 128)   # 256 x 256
    d4 = decoder_block(d3, s1, 64)   # 512 x 512

    # Output
    outputs = layers.Conv2D(1, 1, padding='same', activation='sigmoid')(d4)

    model = Model(inputs, outputs, name='EfficientNetV2S_U-Net')
    return model

In [4]:
model = EfficientNetV2S_unet()

In [5]:
model.summary()

Model: "EfficientNetV2S_U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 rescaling (Rescaling)          (None, 512, 512, 3)  0           ['input_1[0][0]']                
                                                                                                  
 stem_conv (Conv2D)             (None, 256, 256, 24  648         ['rescaling[0][0]']              
                                )                                                                 
                                                                              

                                                                                                  
 block2c_expand_bn (BatchNormal  (None, 128, 128, 19  768        ['block2c_expand_conv[0][0]']    
 ization)                       2)                                                                
                                                                                                  
 block2c_expand_activation (Act  (None, 128, 128, 19  0          ['block2c_expand_bn[0][0]']      
 ivation)                       2)                                                                
                                                                                                  
 block2c_project_conv (Conv2D)  (None, 128, 128, 48  9216        ['block2c_expand_activation[0][0]
                                )                                ']                               
                                                                                                  
 block2c_p

                                                                                                  
 block3c_project_bn (BatchNorma  (None, 64, 64, 64)  256         ['block3c_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block3c_drop (Dropout)         (None, 64, 64, 64)   0           ['block3c_project_bn[0][0]']     
                                                                                                  
 block3c_add (Add)              (None, 64, 64, 64)   0           ['block3c_drop[0][0]',           
                                                                  'block3b_add[0][0]']            
                                                                                                  
 block3d_expand_conv (Conv2D)   (None, 64, 64, 256)  147456      ['block3c_add[0][0]']            
          

                                                                                                  
 block4b_se_reduce (Conv2D)     (None, 1, 1, 32)     16416       ['block4b_se_reshape[0][0]']     
                                                                                                  
 block4b_se_expand (Conv2D)     (None, 1, 1, 512)    16896       ['block4b_se_reduce[0][0]']      
                                                                                                  
 block4b_se_excite (Multiply)   (None, 32, 32, 512)  0           ['block4b_activation[0][0]',     
                                                                  'block4b_se_expand[0][0]']      
                                                                                                  
 block4b_project_conv (Conv2D)  (None, 32, 32, 128)  65536       ['block4b_se_excite[0][0]']      
                                                                                                  
 block4b_p

 block4d_se_excite (Multiply)   (None, 32, 32, 512)  0           ['block4d_activation[0][0]',     
                                                                  'block4d_se_expand[0][0]']      
                                                                                                  
 block4d_project_conv (Conv2D)  (None, 32, 32, 128)  65536       ['block4d_se_excite[0][0]']      
                                                                                                  
 block4d_project_bn (BatchNorma  (None, 32, 32, 128)  512        ['block4d_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block4d_drop (Dropout)         (None, 32, 32, 128)  0           ['block4d_project_bn[0][0]']     
                                                                                                  
 block4d_a

 block4f_project_bn (BatchNorma  (None, 32, 32, 128)  512        ['block4f_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block4f_drop (Dropout)         (None, 32, 32, 128)  0           ['block4f_project_bn[0][0]']     
                                                                                                  
 block4f_add (Add)              (None, 32, 32, 128)  0           ['block4f_drop[0][0]',           
                                                                  'block4e_add[0][0]']            
                                                                                                  
 block5a_expand_conv (Conv2D)   (None, 32, 32, 768)  98304       ['block4f_add[0][0]']            
                                                                                                  
 block5a_e

 block5c_expand_bn (BatchNormal  (None, 32, 32, 960)  3840       ['block5c_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block5c_expand_activation (Act  (None, 32, 32, 960)  0          ['block5c_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block5c_dwconv2 (DepthwiseConv  (None, 32, 32, 960)  8640       ['block5c_expand_activation[0][0]
 2D)                                                             ']                               
                                                                                                  
 block5c_bn (BatchNormalization  (None, 32, 32, 960)  3840       ['block5c_dwconv2[0][0]']        
 )        

                                                                                                  
 block5e_dwconv2 (DepthwiseConv  (None, 32, 32, 960)  8640       ['block5e_expand_activation[0][0]
 2D)                                                             ']                               
                                                                                                  
 block5e_bn (BatchNormalization  (None, 32, 32, 960)  3840       ['block5e_dwconv2[0][0]']        
 )                                                                                                
                                                                                                  
 block5e_activation (Activation  (None, 32, 32, 960)  0          ['block5e_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block5e_s

 )                                                                                                
                                                                                                  
 block5g_activation (Activation  (None, 32, 32, 960)  0          ['block5g_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block5g_se_squeeze (GlobalAver  (None, 960)         0           ['block5g_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block5g_se_reshape (Reshape)   (None, 1, 1, 960)    0           ['block5g_se_squeeze[0][0]']     
                                                                                                  
 block5g_s

 block5i_se_squeeze (GlobalAver  (None, 960)         0           ['block5i_activation[0][0]']     
 agePooling2D)                                                                                    
                                                                                                  
 block5i_se_reshape (Reshape)   (None, 1, 1, 960)    0           ['block5i_se_squeeze[0][0]']     
                                                                                                  
 block5i_se_reduce (Conv2D)     (None, 1, 1, 40)     38440       ['block5i_se_reshape[0][0]']     
                                                                                                  
 block5i_se_expand (Conv2D)     (None, 1, 1, 960)    39360       ['block5i_se_reduce[0][0]']      
                                                                                                  
 block5i_se_excite (Multiply)   (None, 32, 32, 960)  0           ['block5i_activation[0][0]',     
          

 batch_normalization_7 (BatchNo  (None, 256, 256, 12  512        ['conv2d_4[0][0]']               
 rmalization)                   8)                                                                
                                                                                                  
 conv2d_5 (Conv2D)              (None, 256, 256, 12  147584      ['batch_normalization_7[0][0]']  
                                8)                                                                
                                                                                                  
 batch_normalization_8 (BatchNo  (None, 256, 256, 12  512        ['conv2d_5[0][0]']               
 rmalization)                   8)                                                                
                                                                                                  
 conv2d_transpose_3 (Conv2DTran  (None, 512, 512, 64  73792      ['batch_normalization_8[0][0]']  
 spose)   

## ResNet50V2

In [3]:
def ResNet50V2_unet(input_shape=(512, 512, 3), weight='imagenet'):
    """
    Creates a neural network using the U-Net architecture and ResNet50V2 as backbone.
    Args:
        input_shape: The size of the input image.
        weight: Pre-trained weights.
    Returns:
        A U-Net model using EfficientNetB4 as backbone.
    """
    # Input
    inputs = layers.Input(input_shape)

    # Loading pre trained model
    ResNet50 = ResNet50V2(include_top=False, weights=weight, input_tensor=inputs)

    # Encoder
    s1 = ResNet50.get_layer('input_1').output  # 512 x 512
    s2 = ResNet50.get_layer('conv1_conv').output  # 256 x 256
    s3 = ResNet50.get_layer('conv2_block3_1_relu').output  # 128 x 128
    s4 = ResNet50.get_layer('conv3_block4_1_relu').output  # 64 x 64

    # Bottleneck
    b1 = ResNet50.get_layer('conv4_block6_1_relu').output  # 32 x 32

    # Decoder
    d1 = decoder_block(b1, s4, 512)  # 64 x 64
    d2 = decoder_block(d1, s3, 256)    # 128 x 128
    d3 = decoder_block(d2, s2, 128)   # 256 x 256
    d4 = decoder_block(d3, s1, 64)   # 512 x 512

    # Output
    outputs = layers.Conv2D(1, 1, padding='same', activation='sigmoid')(d4)

    model = Model(inputs, outputs, name='ResNet50V2_U-Net')
    return model

In [4]:
model = ResNet50V2_unet()

In [5]:
model.summary()

Model: "ResNet50V2_U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, 518, 518, 3)  0           ['input_1[0][0]']                
                                                                                                  
 conv1_conv (Conv2D)            (None, 256, 256, 64  9472        ['conv1_pad[0][0]']              
                                )                                                                 
                                                                                   

 conv2_block2_out (Add)         (None, 128, 128, 25  0           ['conv2_block1_out[0][0]',       
                                6)                                'conv2_block2_3_conv[0][0]']    
                                                                                                  
 conv2_block3_preact_bn (BatchN  (None, 128, 128, 25  1024       ['conv2_block2_out[0][0]']       
 ormalization)                  6)                                                                
                                                                                                  
 conv2_block3_preact_relu (Acti  (None, 128, 128, 25  0          ['conv2_block3_preact_bn[0][0]'] 
 vation)                        6)                                                                
                                                                                                  
 conv2_block3_1_conv (Conv2D)   (None, 128, 128, 64  16384       ['conv2_block3_preact_relu[0][0]'
          

 n)                                                                                               
                                                                                                  
 conv3_block2_2_pad (ZeroPaddin  (None, 66, 66, 128)  0          ['conv3_block2_1_relu[0][0]']    
 g2D)                                                                                             
                                                                                                  
 conv3_block2_2_conv (Conv2D)   (None, 64, 64, 128)  147456      ['conv3_block2_2_pad[0][0]']     
                                                                                                  
 conv3_block2_2_bn (BatchNormal  (None, 64, 64, 128)  512        ['conv3_block2_2_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv3_blo

 ormalization)                                                                                    
                                                                                                  
 conv4_block1_preact_relu (Acti  (None, 32, 32, 512)  0          ['conv4_block1_preact_bn[0][0]'] 
 vation)                                                                                          
                                                                                                  
 conv4_block1_1_conv (Conv2D)   (None, 32, 32, 256)  131072      ['conv4_block1_preact_relu[0][0]'
                                                                 ]                                
                                                                                                  
 conv4_block1_1_bn (BatchNormal  (None, 32, 32, 256)  1024       ['conv4_block1_1_conv[0][0]']    
 ization)                                                                                         
          

                                                                                                  
 conv4_block3_2_conv (Conv2D)   (None, 32, 32, 256)  589824      ['conv4_block3_2_pad[0][0]']     
                                                                                                  
 conv4_block3_2_bn (BatchNormal  (None, 32, 32, 256)  1024       ['conv4_block3_2_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv4_block3_2_relu (Activatio  (None, 32, 32, 256)  0          ['conv4_block3_2_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv4_block3_3_conv (Conv2D)   (None, 32, 32, 1024  263168      ['conv4_block3_2_relu[0][0]']    
          

 vation)                        )                                                                 
                                                                                                  
 conv4_block6_1_conv (Conv2D)   (None, 32, 32, 256)  262144      ['conv4_block6_preact_relu[0][0]'
                                                                 ]                                
                                                                                                  
 conv4_block6_1_bn (BatchNormal  (None, 32, 32, 256)  1024       ['conv4_block6_1_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv4_block6_1_relu (Activatio  (None, 32, 32, 256)  0          ['conv4_block6_1_bn[0][0]']      
 n)                                                                                               
          

                                                                                                  
 batch_normalization_10 (BatchN  (None, 512, 512, 64  256        ['conv2d_6[0][0]']               
 ormalization)                  )                                                                 
                                                                                                  
 conv2d_7 (Conv2D)              (None, 512, 512, 64  36928       ['batch_normalization_10[0][0]'] 
                                )                                                                 
                                                                                                  
 batch_normalization_11 (BatchN  (None, 512, 512, 64  256        ['conv2d_7[0][0]']               
 ormalization)                  )                                                                 
                                                                                                  
 conv2d_8 

## DenseNet201

In [3]:
def DenseNet201_unet(input_shape=(512, 512, 3), weight='imagenet'):
    """
    Creates a neural network using the U-Net architecture and DenseNet201 as backbone.
    Args:
        input_shape: The size of the input image.
        weight: Pre-trained weights.
    Returns:
        A U-Net model using EfficientNetB4 as backbone.
    """
    # Input
    inputs = layers.Input(input_shape)

    # Loading pre trained model
    DenseNet = DenseNet201(include_top=False, weights=weight, input_tensor=inputs)

    # Encoder
    s1 = DenseNet.get_layer('input_1').output  # 512 x 512
    s2 = DenseNet.get_layer('conv1/relu').output  # 256 x 256
    s3 = DenseNet.get_layer('pool2_conv').output  # 128 x 128
    s4 = DenseNet.get_layer('pool3_conv').output  # 64 x 64

    # Bottleneck
    b1 = DenseNet.get_layer('pool4_conv').output  # 32 x 32

    # Decoder
    d1 = decoder_block(b1, s4, 512)  # 64 x 64
    d2 = decoder_block(d1, s3, 256)    # 128 x 128
    d3 = decoder_block(d2, s2, 128)   # 256 x 256
    d4 = decoder_block(d3, s1, 64)   # 512 x 512

    # Output
    outputs = layers.Conv2D(1, 1, padding='same', activation='sigmoid')(d4)

    model = Model(inputs, outputs, name='DenseNet201_U-Net')
    return model

In [4]:
model = DenseNet201_unet()

In [5]:
model.summary()

Model: "DenseNet201_U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 zero_padding2d (ZeroPadding2D)  (None, 518, 518, 3)  0          ['input_1[0][0]']                
                                                                                                  
 conv1/conv (Conv2D)            (None, 256, 256, 64  9408        ['zero_padding2d[0][0]']         
                                )                                                                 
                                                                                  

 conv2_block3_concat (Concatena  (None, 128, 128, 16  0          ['conv2_block2_concat[0][0]',    
 te)                            0)                                'conv2_block3_2_conv[0][0]']    
                                                                                                  
 conv2_block4_0_bn (BatchNormal  (None, 128, 128, 16  640        ['conv2_block3_concat[0][0]']    
 ization)                       0)                                                                
                                                                                                  
 conv2_block4_0_relu (Activatio  (None, 128, 128, 16  0          ['conv2_block4_0_bn[0][0]']      
 n)                             0)                                                                
                                                                                                  
 conv2_block4_1_conv (Conv2D)   (None, 128, 128, 12  20480       ['conv2_block4_0_relu[0][0]']    
          

 conv3_block1_1_conv (Conv2D)   (None, 64, 64, 128)  16384       ['conv3_block1_0_relu[0][0]']    
                                                                                                  
 conv3_block1_1_bn (BatchNormal  (None, 64, 64, 128)  512        ['conv3_block1_1_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv3_block1_1_relu (Activatio  (None, 64, 64, 128)  0          ['conv3_block1_1_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv3_block1_2_conv (Conv2D)   (None, 64, 64, 32)   36864       ['conv3_block1_1_relu[0][0]']    
                                                                                                  
 conv3_blo

                                                                                                  
 conv3_block5_2_conv (Conv2D)   (None, 64, 64, 32)   36864       ['conv3_block5_1_relu[0][0]']    
                                                                                                  
 conv3_block5_concat (Concatena  (None, 64, 64, 288)  0          ['conv3_block4_concat[0][0]',    
 te)                                                              'conv3_block5_2_conv[0][0]']    
                                                                                                  
 conv3_block6_0_bn (BatchNormal  (None, 64, 64, 288)  1152       ['conv3_block5_concat[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv3_block6_0_relu (Activatio  (None, 64, 64, 288)  0          ['conv3_block6_0_bn[0][0]']      
 n)       

 lization)                                                                                        
                                                                                                  
 conv3_block10_0_relu (Activati  (None, 64, 64, 416)  0          ['conv3_block10_0_bn[0][0]']     
 on)                                                                                              
                                                                                                  
 conv3_block10_1_conv (Conv2D)  (None, 64, 64, 128)  53248       ['conv3_block10_0_relu[0][0]']   
                                                                                                  
 conv3_block10_1_bn (BatchNorma  (None, 64, 64, 128)  512        ['conv3_block10_1_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 conv3_blo

 conv4_block2_0_bn (BatchNormal  (None, 32, 32, 288)  1152       ['conv4_block1_concat[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv4_block2_0_relu (Activatio  (None, 32, 32, 288)  0          ['conv4_block2_0_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv4_block2_1_conv (Conv2D)   (None, 32, 32, 128)  36864       ['conv4_block2_0_relu[0][0]']    
                                                                                                  
 conv4_block2_1_bn (BatchNormal  (None, 32, 32, 128)  512        ['conv4_block2_1_conv[0][0]']    
 ization)                                                                                         
          

                                                                                                  
 conv4_block6_1_bn (BatchNormal  (None, 32, 32, 128)  512        ['conv4_block6_1_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv4_block6_1_relu (Activatio  (None, 32, 32, 128)  0          ['conv4_block6_1_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv4_block6_2_conv (Conv2D)   (None, 32, 32, 32)   36864       ['conv4_block6_1_relu[0][0]']    
                                                                                                  
 conv4_block6_concat (Concatena  (None, 32, 32, 448)  0          ['conv4_block5_concat[0][0]',    
 te)      

 conv4_block10_2_conv (Conv2D)  (None, 32, 32, 32)   36864       ['conv4_block10_1_relu[0][0]']   
                                                                                                  
 conv4_block10_concat (Concaten  (None, 32, 32, 576)  0          ['conv4_block9_concat[0][0]',    
 ate)                                                             'conv4_block10_2_conv[0][0]']   
                                                                                                  
 conv4_block11_0_bn (BatchNorma  (None, 32, 32, 576)  2304       ['conv4_block10_concat[0][0]']   
 lization)                                                                                        
                                                                                                  
 conv4_block11_0_relu (Activati  (None, 32, 32, 576)  0          ['conv4_block11_0_bn[0][0]']     
 on)                                                                                              
          

                                                                                                  
 conv4_block15_0_relu (Activati  (None, 32, 32, 704)  0          ['conv4_block15_0_bn[0][0]']     
 on)                                                                                              
                                                                                                  
 conv4_block15_1_conv (Conv2D)  (None, 32, 32, 128)  90112       ['conv4_block15_0_relu[0][0]']   
                                                                                                  
 conv4_block15_1_bn (BatchNorma  (None, 32, 32, 128)  512        ['conv4_block15_1_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 conv4_block15_1_relu (Activati  (None, 32, 32, 128)  0          ['conv4_block15_1_bn[0][0]']     
 on)      

 lization)                                                                                        
                                                                                                  
 conv4_block19_1_relu (Activati  (None, 32, 32, 128)  0          ['conv4_block19_1_bn[0][0]']     
 on)                                                                                              
                                                                                                  
 conv4_block19_2_conv (Conv2D)  (None, 32, 32, 32)   36864       ['conv4_block19_1_relu[0][0]']   
                                                                                                  
 conv4_block19_concat (Concaten  (None, 32, 32, 864)  0          ['conv4_block18_concat[0][0]',   
 ate)                                                             'conv4_block19_2_conv[0][0]']   
                                                                                                  
 conv4_blo

 conv4_block23_concat (Concaten  (None, 32, 32, 992)  0          ['conv4_block22_concat[0][0]',   
 ate)                                                             'conv4_block23_2_conv[0][0]']   
                                                                                                  
 conv4_block24_0_bn (BatchNorma  (None, 32, 32, 992)  3968       ['conv4_block23_concat[0][0]']   
 lization)                                                                                        
                                                                                                  
 conv4_block24_0_relu (Activati  (None, 32, 32, 992)  0          ['conv4_block24_0_bn[0][0]']     
 on)                                                                                              
                                                                                                  
 conv4_block24_1_conv (Conv2D)  (None, 32, 32, 128)  126976      ['conv4_block24_0_relu[0][0]']   
          

 on)                            )                                                                 
                                                                                                  
 conv4_block28_1_conv (Conv2D)  (None, 32, 32, 128)  143360      ['conv4_block28_0_relu[0][0]']   
                                                                                                  
 conv4_block28_1_bn (BatchNorma  (None, 32, 32, 128)  512        ['conv4_block28_1_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 conv4_block28_1_relu (Activati  (None, 32, 32, 128)  0          ['conv4_block28_1_bn[0][0]']     
 on)                                                                                              
                                                                                                  
 conv4_blo

 conv4_block32_1_relu (Activati  (None, 32, 32, 128)  0          ['conv4_block32_1_bn[0][0]']     
 on)                                                                                              
                                                                                                  
 conv4_block32_2_conv (Conv2D)  (None, 32, 32, 32)   36864       ['conv4_block32_1_relu[0][0]']   
                                                                                                  
 conv4_block32_concat (Concaten  (None, 32, 32, 1280  0          ['conv4_block31_concat[0][0]',   
 ate)                           )                                 'conv4_block32_2_conv[0][0]']   
                                                                                                  
 conv4_block33_0_bn (BatchNorma  (None, 32, 32, 1280  5120       ['conv4_block32_concat[0][0]']   
 lization)                      )                                                                 
          

                                                                                                  
 conv4_block37_0_bn (BatchNorma  (None, 32, 32, 1408  5632       ['conv4_block36_concat[0][0]']   
 lization)                      )                                                                 
                                                                                                  
 conv4_block37_0_relu (Activati  (None, 32, 32, 1408  0          ['conv4_block37_0_bn[0][0]']     
 on)                            )                                                                 
                                                                                                  
 conv4_block37_1_conv (Conv2D)  (None, 32, 32, 128)  180224      ['conv4_block37_0_relu[0][0]']   
                                                                                                  
 conv4_block37_1_bn (BatchNorma  (None, 32, 32, 128)  512        ['conv4_block37_1_conv[0][0]']   
 lization)

 conv4_block41_1_conv (Conv2D)  (None, 32, 32, 128)  196608      ['conv4_block41_0_relu[0][0]']   
                                                                                                  
 conv4_block41_1_bn (BatchNorma  (None, 32, 32, 128)  512        ['conv4_block41_1_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 conv4_block41_1_relu (Activati  (None, 32, 32, 128)  0          ['conv4_block41_1_bn[0][0]']     
 on)                                                                                              
                                                                                                  
 conv4_block41_2_conv (Conv2D)  (None, 32, 32, 32)   36864       ['conv4_block41_1_relu[0][0]']   
                                                                                                  
 conv4_blo

                                                                                                  
 conv4_block45_2_conv (Conv2D)  (None, 32, 32, 32)   36864       ['conv4_block45_1_relu[0][0]']   
                                                                                                  
 conv4_block45_concat (Concaten  (None, 32, 32, 1696  0          ['conv4_block44_concat[0][0]',   
 ate)                           )                                 'conv4_block45_2_conv[0][0]']   
                                                                                                  
 conv4_block46_0_bn (BatchNorma  (None, 32, 32, 1696  6784       ['conv4_block45_concat[0][0]']   
 lization)                      )                                                                 
                                                                                                  
 conv4_block46_0_relu (Activati  (None, 32, 32, 1696  0          ['conv4_block46_0_bn[0][0]']     
 on)      

 rmalization)                                                                                     
                                                                                                  
 conv2d_1 (Conv2D)              (None, 64, 64, 512)  2359808     ['batch_normalization_1[0][0]']  
                                                                                                  
 batch_normalization_2 (BatchNo  (None, 64, 64, 512)  2048       ['conv2d_1[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 conv2d_transpose_1 (Conv2DTran  (None, 128, 128, 25  1179904    ['batch_normalization_2[0][0]']  
 spose)                         6)                                                                
                                                                                                  
 batch_nor