In [2]:
from keras.models import Model
from keras.layers import ReLU, Input, Conv2D, MaxPooling2D, concatenate, Conv2DTranspose, BatchNormalization, Dropout, Lambda
from tensorflow.keras.optimizers import Adam
from keras.metrics import MeanIoU

kernel_initializer =  'he_uniform' #Try others if you want

def conv_block(num_filter, in_layer, drop_out):
    
    x = Conv2D(num_filter,  (3, 3), kernel_initializer=kernel_initializer, padding='same')(in_layer)
    x = BatchNormalization(axis = -1)(x)
    x = ReLU()(x)
    x = Dropout(drop_out)(x)
    x = Conv2D(num_filter,  (3, 3), kernel_initializer=kernel_initializer, padding='same')(x)
    x = BatchNormalization(axis = -1)(x)
    x = ReLU()(x)
    return x

def upConv_block(num_filter, in_layer, concate_list, drop_out):
    
    x = Conv2DTranspose(num_filter, (2, 2), strides=(2, 2), padding='same')(in_layer)
    x = concatenate([x]+concate_list)
    x = ReLU()(x)    
    x = Conv2D(num_filter,  (3, 3), kernel_initializer=kernel_initializer, padding='same')(x)
    x = BatchNormalization(axis = -1)(x)
    x = ReLU()(x)    
    x = Dropout(drop_out)(x)
    x = Conv2D(num_filter,  (3, 3), kernel_initializer=kernel_initializer, padding='same')(x)
    x = BatchNormalization(axis = -1)(x)
    x = ReLU()(x)
    return x
    
################################################################
def unet_pp(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS, num_classes):
#Build the model
    num_filter = 16
    inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
    #s = Lambda(lambda x: x / 255)(inputs)   #No need for this if we normalize our inputs beforehand
    s = inputs

    #Contraction path
    x00 = conv_block(num_filter, s, 0.1)
    p1 = MaxPooling2D((2, 2))(x00)
    
    x10 = conv_block(2*num_filter, p1, 0.1)   
    p2 = MaxPooling2D((2, 2))(x10)

    x01 = upConv_block(num_filter, x10, [x00], 0.2)

    x20 = conv_block(4*num_filter, p2, 0.2)   
    p3 = MaxPooling2D((2, 2))(x20)
    
    x11 = upConv_block(2*num_filter, x20, [x10], 0.2)

    x02 = upConv_block(num_filter, x11, [x01, x00], 0.2)

    x30 = conv_block(4*num_filter, p3, 0.2)   
    p4 = MaxPooling2D(pool_size=(2, 2))(x30)
    
    x21 = upConv_block(4*num_filter, x30, [x20], 0.2)
        
    x12 = upConv_block(4*num_filter, x21, [x11, x10], 0.2)
    
    x03 = upConv_block(num_filter, x12, [x02, x01,  x00], 0.2)
    
    x40 = conv_block(16*num_filter, p4, 0.3)   

    #Expansive path 
    x31 = upConv_block(8*num_filter, x40, [x30], 0.2)
 
    x22 = upConv_block(4*num_filter, x31, [x21, x20], 0.2)
    
    x13 = upConv_block(2*num_filter, x22, [x12, x11, x10], 0.1)
     
    x04 = upConv_block(num_filter, x13, [x03, x02, x01, x00], 0.1)

    outputs = Conv2D(num_classes, (1, 1), activation='sigmoid')(x04)
     
    model = Model(inputs=[inputs], outputs=[outputs])
    #compile model outside of this function to make it flexible. 
    
    return model

def main():
    #Test if everything is working ok. 
    model = unet_pp(IMG_HEIGHT=448, 
                    IMG_WIDTH=336, 
                    IMG_CHANNELS=1, 
                    num_classes=1)


    print(model.input_shape)
    print(model.output_shape)
    model.summary()
    
if __name__=="__main__":
    main()

2023-01-16 22:33:20.772684: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-01-16 22:33:21.846162: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 822 MB memory:  -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:17:00.0, compute capability: 8.6
2023-01-16 22:33:21.846705: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 39704 MB memory:  -> device: 1, name: NVIDIA RTX A6000, pci bus id: 0000:65:00.0, compute capability: 8.6


(None, 448, 336, 1)
(None, 448, 336, 1)
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 448, 336, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 448, 336, 16  160         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 448, 336, 16  64         ['conv2d[0][0]']                 
 alization)                     )                     

 ormalization)                                                                                    
                                                                                                  
 re_lu_16 (ReLU)                (None, 56, 42, 64)   0           ['batch_normalization_13[0][0]'] 
                                                                                                  
 max_pooling2d_3 (MaxPooling2D)  (None, 28, 21, 64)  0           ['re_lu_16[0][0]']               
                                                                                                  
 conv2d_20 (Conv2D)             (None, 28, 21, 256)  147712      ['max_pooling2d_3[0][0]']        
                                                                                                  
 batch_normalization_20 (BatchN  (None, 28, 21, 256)  1024       ['conv2d_20[0][0]']              
 ormalization)                                                                                    
          

                                                                                                  
 dropout_7 (Dropout)            (None, 112, 84, 64)  0           ['re_lu_18[0][0]']               
                                                                                                  
 re_lu_10 (ReLU)                (None, 224, 168, 32  0           ['batch_normalization_8[0][0]']  
                                )                                                                 
                                                                                                  
 batch_normalization_4 (BatchNo  (None, 448, 336, 16  64         ['conv2d_4[0][0]']               
 rmalization)                   )                                                                 
                                                                                                  
 batch_normalization_23 (BatchN  (None, 56, 42, 128)  512        ['conv2d_23[0][0]']              
 ormalizat

 re_lu_12 (ReLU)                (None, 448, 336, 48  0           ['concatenate_2[0][0]']          
                                )                                                                 
                                                                                                  
 re_lu_32 (ReLU)                (None, 112, 84, 64)  0           ['batch_normalization_24[0][0]'] 
                                                                                                  
 batch_normalization_16 (BatchN  (None, 224, 168, 64  256        ['conv2d_16[0][0]']              
 ormalization)                  )                                                                 
                                                                                                  
 conv2d_10 (Conv2D)             (None, 448, 336, 16  6928        ['re_lu_12[0][0]']               
                                )                                                                 
          

                                                                                                  
 re_lu_35 (ReLU)                (None, 224, 168, 32  0           ['batch_normalization_26[0][0]'] 
                                )                                                                 
                                                                                                  
 batch_normalization_18 (BatchN  (None, 448, 336, 16  64         ['conv2d_18[0][0]']              
 ormalization)                  )                                                                 
                                                                                                  
 dropout_13 (Dropout)           (None, 224, 168, 32  0           ['re_lu_35[0][0]']               
                                )                                                                 
                                                                                                  
 re_lu_24 