# U-Net

## Imports

In [1]:
from tensorflow.keras import layers, optimizers, backend, Model
from tensorflow.keras.applications import EfficientNetB4, EfficientNetV2S, ResNet50V2, DenseNet201

## Block functions

In [2]:
def conv_block(input, num_filters):
    """
    Apply two consecutive Conv2D -> BatchNormalization -> ReLU stages.
    Args:
        input: Tensor fed into the first convolution.
        num_filters: Number of filters used by both 3x3 convolutions.
    Returns:
        The tensor produced by the second ReLU activation.
    """
    features = input
    # Both stages are identical: 3x3 'same' convolution, batch norm, ReLU.
    for _ in range(2):
        features = layers.Conv2D(num_filters, 3, padding='same')(features)
        features = layers.BatchNormalization()(features)
        features = layers.Activation('relu')(features)
    return features


def decoder_block(input, skip_features, num_filters):
    """
    Build one decoder stage: upsample, merge with the skip connection, convolve.
    Args:
        input: Tensor coming from the previous (coarser) layer.
        skip_features: Encoder tensor that is concatenated with the upsampled input.
        num_filters: Number of filters for the transposed convolution and the conv block.
    Returns:
        The output tensor of the conv block applied to the merged features.
    """
    # A stride-2 transposed convolution upsamples the incoming feature map.
    upsample = layers.Conv2DTranspose(
        filters=num_filters,
        kernel_size=(2, 2),
        strides=2,
        padding='same',
    )
    upsampled = upsample(input)

    # Join the upsampled features with the encoder skip features, then refine.
    merged = layers.concatenate([upsampled, skip_features])
    return conv_block(merged, num_filters)

## EfficientNetB4

In [3]:
# EfficientNetB4 UNet
def EfficientNetB4_unet(input_shape=(512, 512, 3), weight='imagenet'):
    """
    Creates a neural network using the U-Net architecture and EfficientNetB4 as backbone.
    Args:
        input_shape: The size of the input image as (height, width, channels).
        weight: Value passed to the backbone's `weights` argument
            (e.g. 'imagenet' for pre-trained weights).
    Returns:
        A U-Net model using EfficientNetB4 as backbone, with a single-channel
        sigmoid output.
    Raises:
        ValueError: If the backbone contains no Rescaling layer to use as the
            full-resolution skip connection.
    """
    # Input
    inputs = layers.Input(input_shape)

    # Loading pre trained model
    EffNetB4 = EfficientNetB4(include_top=False, weights=weight, input_tensor=inputs)

    # Encoder (spatial sizes in the comments are for the default 512 x 512 input)
    # The backbone's first Rescaling layer has an auto-generated name
    # ('rescaling', 'rescaling_1', ...), which changes when other models were
    # built earlier in the same session. Look it up by type instead of by name
    # so the function can be called more than once.
    rescaling = next(
        (layer for layer in EffNetB4.layers if isinstance(layer, layers.Rescaling)),
        None,
    )
    if rescaling is None:
        raise ValueError('No Rescaling layer found in the EfficientNetB4 backbone.')

    s1 = rescaling.output  # 512 x 512
    s2 = EffNetB4.get_layer('block2a_expand_activation').output  # 256 x 256
    s3 = EffNetB4.get_layer('block3a_expand_activation').output  # 128 x 128
    s4 = EffNetB4.get_layer('block4a_expand_activation').output  # 64 x 64

    # Bottleneck
    b1 = EffNetB4.get_layer('block6a_expand_activation').output  # 32 x 32

    # Decoder: each block doubles the resolution and merges the matching skip
    d1 = decoder_block(b1, s4, 1024)  # 64 x 64
    d2 = decoder_block(d1, s3, 512)   # 128 x 128
    d3 = decoder_block(d2, s2, 256)   # 256 x 256
    d4 = decoder_block(d3, s1, 128)   # 512 x 512

    # Output: 1x1 convolution producing one sigmoid channel per pixel
    outputs = layers.Conv2D(1, 1, padding='same', activation='sigmoid')(d4)

    model = Model(inputs, outputs, name='EfficientNetB4_U-Net')
    return model

In [4]:
# Build the EfficientNetB4-backed U-Net with the default input shape and weights.
model = EfficientNetB4_unet()

In [5]:
# Print the layer-by-layer summary of the model (output shown below).
model.summary()

Model: "EfficientNetB4_U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 rescaling (Rescaling)          (None, 512, 512, 3)  0           ['input_1[0][0]']                
                                                                                                  
 normalization (Normalization)  (None, 512, 512, 3)  7           ['rescaling[0][0]']              
                                                                                                  
 rescaling_1 (Rescaling)        (None, 512, 512, 3)  0           ['normalizatio

 block2a_expand_conv (Conv2D)   (None, 256, 256, 14  3456        ['block1b_add[0][0]']            
                                4)                                                                
                                                                                                  
 block2a_expand_bn (BatchNormal  (None, 256, 256, 14  576        ['block2a_expand_conv[0][0]']    
 ization)                       4)                                                                
                                                                                                  
 block2a_expand_activation (Act  (None, 256, 256, 14  0          ['block2a_expand_bn[0][0]']      
 ivation)                       4)                                                                
                                                                                                  
 block2a_dwconv_pad (ZeroPaddin  (None, 257, 257, 14  0          ['block2a_expand_activation[0][0]
 g2D)     

                                                                                                  
 block2c_expand_bn (BatchNormal  (None, 128, 128, 19  768        ['block2c_expand_conv[0][0]']    
 ization)                       2)                                                                
                                                                                                  
 block2c_expand_activation (Act  (None, 128, 128, 19  0          ['block2c_expand_bn[0][0]']      
 ivation)                       2)                                                                
                                                                                                  
 block2c_dwconv (DepthwiseConv2  (None, 128, 128, 19  1728       ['block2c_expand_activation[0][0]
 D)                             2)                               ']                               
                                                                                                  
 block2c_b

                                2)                                                                
                                                                                                  
 block3a_expand_bn (BatchNormal  (None, 128, 128, 19  768        ['block3a_expand_conv[0][0]']    
 ization)                       2)                                                                
                                                                                                  
 block3a_expand_activation (Act  (None, 128, 128, 19  0          ['block3a_expand_bn[0][0]']      
 ivation)                       2)                                                                
                                                                                                  
 block3a_dwconv_pad (ZeroPaddin  (None, 131, 131, 19  0          ['block3a_expand_activation[0][0]
 g2D)                           2)                               ']                               
          

                                                                                                  
 block3c_dwconv (DepthwiseConv2  (None, 64, 64, 336)  8400       ['block3c_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block3c_bn (BatchNormalization  (None, 64, 64, 336)  1344       ['block3c_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block3c_activation (Activation  (None, 64, 64, 336)  0          ['block3c_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block3c_s

 D)                                                                                               
                                                                                                  
 block4a_bn (BatchNormalization  (None, 32, 32, 336)  1344       ['block4a_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block4a_activation (Activation  (None, 32, 32, 336)  0          ['block4a_bn[0][0]']             
 )                                                                                                
                                                                                                  
 block4a_se_squeeze (GlobalAver  (None, 336)         0           ['block4a_activation[0][0]']     
 agePooling2D)                                                                                    
          

                                                                                                  
 block4c_se_reshape (Reshape)   (None, 1, 1, 672)    0           ['block4c_se_squeeze[0][0]']     
                                                                                                  
 block4c_se_reduce (Conv2D)     (None, 1, 1, 28)     18844       ['block4c_se_reshape[0][0]']     
                                                                                                  
 block4c_se_expand (Conv2D)     (None, 1, 1, 672)    19488       ['block4c_se_reduce[0][0]']      
                                                                                                  
 block4c_se_excite (Multiply)   (None, 32, 32, 672)  0           ['block4c_activation[0][0]',     
                                                                  'block4c_se_expand[0][0]']      
                                                                                                  
 block4c_p

 block4e_se_expand (Conv2D)     (None, 1, 1, 672)    19488       ['block4e_se_reduce[0][0]']      
                                                                                                  
 block4e_se_excite (Multiply)   (None, 32, 32, 672)  0           ['block4e_activation[0][0]',     
                                                                  'block4e_se_expand[0][0]']      
                                                                                                  
 block4e_project_conv (Conv2D)  (None, 32, 32, 112)  75264       ['block4e_se_excite[0][0]']      
                                                                                                  
 block4e_project_bn (BatchNorma  (None, 32, 32, 112)  448        ['block4e_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block4e_d

 block5a_project_conv (Conv2D)  (None, 32, 32, 160)  107520      ['block5a_se_excite[0][0]']      
                                                                                                  
 block5a_project_bn (BatchNorma  (None, 32, 32, 160)  640        ['block5a_project_conv[0][0]']   
 lization)                                                                                        
                                                                                                  
 block5b_expand_conv (Conv2D)   (None, 32, 32, 960)  153600      ['block5a_project_bn[0][0]']     
                                                                                                  
 block5b_expand_bn (BatchNormal  (None, 32, 32, 960)  3840       ['block5b_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block5b_e

 block5d_expand_conv (Conv2D)   (None, 32, 32, 960)  153600      ['block5c_add[0][0]']            
                                                                                                  
 block5d_expand_bn (BatchNormal  (None, 32, 32, 960)  3840       ['block5d_expand_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 block5d_expand_activation (Act  (None, 32, 32, 960)  0          ['block5d_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block5d_dwconv (DepthwiseConv2  (None, 32, 32, 960)  24000      ['block5d_expand_activation[0][0]
 D)                                                              ']                               
          

 block5f_expand_activation (Act  (None, 32, 32, 960)  0          ['block5f_expand_bn[0][0]']      
 ivation)                                                                                         
                                                                                                  
 block5f_dwconv (DepthwiseConv2  (None, 32, 32, 960)  24000      ['block5f_expand_activation[0][0]
 D)                                                              ']                               
                                                                                                  
 block5f_bn (BatchNormalization  (None, 32, 32, 960)  3840       ['block5f_dwconv[0][0]']         
 )                                                                                                
                                                                                                  
 block5f_activation (Activation  (None, 32, 32, 960)  0          ['block5f_bn[0][0]']             
 )        

 conv2d_3 (Conv2D)              (None, 128, 128, 51  2359808     ['activation_2[0][0]']           
                                2)                                                                
                                                                                                  
 batch_normalization_3 (BatchNo  (None, 128, 128, 51  2048       ['conv2d_3[0][0]']               
 rmalization)                   2)                                                                
                                                                                                  
 activation_3 (Activation)      (None, 128, 128, 51  0           ['batch_normalization_3[0][0]']  
                                2)                                                                
                                                                                                  
 conv2d_transpose_2 (Conv2DTran  (None, 256, 256, 25  524544     ['activation_3[0][0]']           
 spose)   

## ResNet50V2

In [3]:
def ResNet50V2_unet(input_shape=(512, 512, 3), weight='imagenet'):
    """
    Creates a neural network using the U-Net architecture and ResNet50V2 as backbone.
    Args:
        input_shape: The size of the input image as (height, width, channels).
        weight: Value passed to the backbone's `weights` argument
            (e.g. 'imagenet' for pre-trained weights).
    Returns:
        A U-Net model using ResNet50V2 as backbone, with a single-channel
        sigmoid output.
    """
    # Input
    inputs = layers.Input(input_shape)

    # Loading pre trained model
    ResNet50 = ResNet50V2(include_top=False, weights=weight, input_tensor=inputs)

    # Encoder (spatial sizes in the comments are for the default 512 x 512 input)
    # The full-resolution skip is the model input itself. Use the tensor
    # directly rather than get_layer('input_1'): the input layer's name is
    # auto-generated and differs ('input_2', ...) whenever another Input was
    # created earlier in the same session.
    s1 = inputs  # 512 x 512
    s2 = ResNet50.get_layer('conv1_conv').output  # 256 x 256
    s3 = ResNet50.get_layer('conv2_block3_1_relu').output  # 128 x 128
    s4 = ResNet50.get_layer('conv3_block4_1_relu').output  # 64 x 64

    # Bottleneck
    b1 = ResNet50.get_layer('conv4_block6_1_relu').output  # 32 x 32

    # Decoder: each block doubles the resolution and merges the matching skip
    d1 = decoder_block(b1, s4, 1024)  # 64 x 64
    d2 = decoder_block(d1, s3, 512)   # 128 x 128
    d3 = decoder_block(d2, s2, 256)   # 256 x 256
    d4 = decoder_block(d3, s1, 128)   # 512 x 512

    # Output: 1x1 convolution producing one sigmoid channel per pixel
    outputs = layers.Conv2D(1, 1, padding='same', activation='sigmoid')(d4)

    model = Model(inputs, outputs, name='ResNet50V2_U-Net')
    return model

In [4]:
# Build the ResNet50V2-backed U-Net with the default input shape and weights.
model = ResNet50V2_unet()

In [5]:
# Print the layer-by-layer summary of the model (output shown below).
model.summary()

Model: "ResNet50V2_U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, 518, 518, 3)  0           ['input_1[0][0]']                
                                                                                                  
 conv1_conv (Conv2D)            (None, 256, 256, 64  9472        ['conv1_pad[0][0]']              
                                )                                                                 
                                                                                   

 conv2_block2_out (Add)         (None, 128, 128, 25  0           ['conv2_block1_out[0][0]',       
                                6)                                'conv2_block2_3_conv[0][0]']    
                                                                                                  
 conv2_block3_preact_bn (BatchN  (None, 128, 128, 25  1024       ['conv2_block2_out[0][0]']       
 ormalization)                  6)                                                                
                                                                                                  
 conv2_block3_preact_relu (Acti  (None, 128, 128, 25  0          ['conv2_block3_preact_bn[0][0]'] 
 vation)                        6)                                                                
                                                                                                  
 conv2_block3_1_conv (Conv2D)   (None, 128, 128, 64  16384       ['conv2_block3_preact_relu[0][0]'
          

 n)                                                                                               
                                                                                                  
 conv3_block2_2_pad (ZeroPaddin  (None, 66, 66, 128)  0          ['conv3_block2_1_relu[0][0]']    
 g2D)                                                                                             
                                                                                                  
 conv3_block2_2_conv (Conv2D)   (None, 64, 64, 128)  147456      ['conv3_block2_2_pad[0][0]']     
                                                                                                  
 conv3_block2_2_bn (BatchNormal  (None, 64, 64, 128)  512        ['conv3_block2_2_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv3_blo

 ormalization)                                                                                    
                                                                                                  
 conv4_block1_preact_relu (Acti  (None, 32, 32, 512)  0          ['conv4_block1_preact_bn[0][0]'] 
 vation)                                                                                          
                                                                                                  
 conv4_block1_1_conv (Conv2D)   (None, 32, 32, 256)  131072      ['conv4_block1_preact_relu[0][0]'
                                                                 ]                                
                                                                                                  
 conv4_block1_1_bn (BatchNormal  (None, 32, 32, 256)  1024       ['conv4_block1_1_conv[0][0]']    
 ization)                                                                                         
          

                                                                                                  
 conv4_block3_2_conv (Conv2D)   (None, 32, 32, 256)  589824      ['conv4_block3_2_pad[0][0]']     
                                                                                                  
 conv4_block3_2_bn (BatchNormal  (None, 32, 32, 256)  1024       ['conv4_block3_2_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv4_block3_2_relu (Activatio  (None, 32, 32, 256)  0          ['conv4_block3_2_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv4_block3_3_conv (Conv2D)   (None, 32, 32, 1024  263168      ['conv4_block3_2_relu[0][0]']    
          

 vation)                        )                                                                 
                                                                                                  
 conv4_block6_1_conv (Conv2D)   (None, 32, 32, 256)  262144      ['conv4_block6_preact_relu[0][0]'
                                                                 ]                                
                                                                                                  
 conv4_block6_1_bn (BatchNormal  (None, 32, 32, 256)  1024       ['conv4_block6_1_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv4_block6_1_relu (Activatio  (None, 32, 32, 256)  0          ['conv4_block6_1_bn[0][0]']      
 n)                                                                                               
          

 conv2d_transpose_3 (Conv2DTran  (None, 512, 512, 12  131200     ['activation_5[0][0]']           
 spose)                         8)                                                                
                                                                                                  
 concatenate_3 (Concatenate)    (None, 512, 512, 13  0           ['conv2d_transpose_3[0][0]',     
                                1)                                'input_1[0][0]']                
                                                                                                  
 conv2d_6 (Conv2D)              (None, 512, 512, 12  151040      ['concatenate_3[0][0]']          
                                8)                                                                
                                                                                                  
 batch_normalization_6 (BatchNo  (None, 512, 512, 12  512        ['conv2d_6[0][0]']               
 rmalizati