In [1]:
from keras.models import Model
from keras.layers import ReLU, Input, Conv3D, MaxPooling3D, concatenate, Conv3DTranspose, BatchNormalization, Dropout, Lambda
from tensorflow.keras.optimizers import Adam
from keras.metrics import MeanIoU

kernel_initializer =  'he_uniform' #Try others if you want

def conv_block(num_filter, in_layer, drop_out):
    
    x = Conv3D(num_filter, (3, 3, 3), kernel_initializer=kernel_initializer, padding='same')(in_layer)
    x = BatchNormalization(axis = -1)(x)
    x = ReLU()(x)
    x = Dropout(drop_out)(x)
    x = Conv3D(num_filter, (3, 3, 3), kernel_initializer=kernel_initializer, padding='same')(x)
    x = BatchNormalization(axis = -1)(x)
    x = ReLU()(x)
    return x

def upConv_block(num_filter, in_layer, concate_list, drop_out):
    
    x = Conv3DTranspose(num_filter, (2, 2, 2), strides=(2, 2, 2), padding='same')(in_layer)
    x = concatenate([x]+concate_list)
    x = ReLU()(x)    
    x = Conv3D(num_filter, (3, 3, 3), kernel_initializer=kernel_initializer, padding='same')(x)
    x = BatchNormalization(axis = -1)(x)
    x = ReLU()(x)    
    x = Dropout(drop_out)(x)
    x = Conv3D(num_filter, (3, 3, 3), kernel_initializer=kernel_initializer, padding='same')(x)
    x = BatchNormalization(axis = -1)(x)
    x = ReLU()(x)
    return x
    
################################################################
def unet_pp(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS, num_classes):
#Build the model
    num_filter = 16
    inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS))
    #s = Lambda(lambda x: x / 255)(inputs)   #No need for this if we normalize our inputs beforehand
    s = inputs

    #Contraction path
    x00 = conv_block(num_filter, s, 0.1)
    p1 = MaxPooling3D((2, 2, 2))(x00)
    
    x10 = conv_block(2*num_filter, p1, 0.1)   
    p2 = MaxPooling3D((2, 2, 2))(x10)

    x01 = upConv_block(num_filter, x10, [x00], 0.2)

    x20 = conv_block(4*num_filter, p2, 0.2)   
    p3 = MaxPooling3D((2, 2, 2))(x20)
    
    x11 = upConv_block(2*num_filter, x20, [x10], 0.2)

    x02 = upConv_block(num_filter, x11, [x01, x00], 0.2)

    x30 = conv_block(4*num_filter, p3, 0.2)   
    p4 = MaxPooling3D(pool_size=(2, 2, 2))(x30)
    
    x21 = upConv_block(4*num_filter, x30, [x20], 0.2)
        
    x12 = upConv_block(4*num_filter, x21, [x11, x10], 0.2)
    
    x03 = upConv_block(num_filter, x12, [x02, x01,  x00], 0.2)
    
    x40 = conv_block(16*num_filter, p4, 0.3)   

    #Expansive path 
    x31 = upConv_block(8*num_filter, x40, [x30], 0.2)
 
    x22 = upConv_block(4*num_filter, x31, [x21, x20], 0.2)
    
    x13 = upConv_block(2*num_filter, x22, [x12, x11, x10], 0.1)
     
    x04 = upConv_block(num_filter, x13, [x03, x02, x01, x00], 0.1)

    outputs = Conv3D(num_classes, (1, 1, 1), activation='sigmoid')(x04)
     
    model = Model(inputs=[inputs], outputs=[outputs])
    #compile model outside of this function to make it flexible. 
    
    return model

#Test if everything is working ok. 
model = unet_pp(IMG_HEIGHT=128, 
                          IMG_WIDTH=128, 
                          IMG_DEPTH=128, 
                          IMG_CHANNELS=4, 
                          num_classes=1)


print(model.input_shape)
print(model.output_shape)
model.summary()

2022-10-30 11:21:31.716938: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2022-10-30 11:21:33.051151: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2022-10-30 11:21:33.051237: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (mig): /proc/driver/nvidia/version does not exist
2022-10-30 11:21:33.052258: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate comp

(None, 128, 128, 128, 4)
(None, 128, 128, 128, 1)
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 128, 1  0           []                               
                                28, 4)]                                                           
                                                                                                  
 conv3d (Conv3D)                (None, 128, 128, 12  1744        ['input_1[0][0]']                
                                8, 16)                                                            
                                                                                                  
 batch_normalization (BatchNorm  (None, 128, 128, 12  64         ['conv3d[0][0]']                 
 alization)                     8, 16)      

                                                                                                  
 re_lu_15 (ReLU)                (None, 16, 16, 16,   0           ['batch_normalization_12[0][0]'] 
                                64)                                                               
                                                                                                  
 dropout_6 (Dropout)            (None, 16, 16, 16,   0           ['re_lu_15[0][0]']               
                                64)                                                               
                                                                                                  
 conv3d_13 (Conv3D)             (None, 16, 16, 16,   110656      ['dropout_6[0][0]']              
                                64)                                                               
                                                                                                  
 batch_nor

                                                                                                  
 batch_normalization_14 (BatchN  (None, 32, 32, 32,   256        ['conv3d_14[0][0]']              
 ormalization)                  64)                                                               
                                                                                                  
 conv3d_8 (Conv3D)              (None, 64, 64, 64,   55328       ['re_lu_9[0][0]']                
                                32)                                                               
                                                                                                  
 re_lu_4 (ReLU)                 (None, 128, 128, 12  0           ['concatenate[0][0]']            
                                8, 32)                                                            
                                                                                                  
 dropout_1

 re_lu_31 (ReLU)                (None, 32, 32, 32,   0           ['concatenate_7[0][0]']          
                                192)                                                              
                                                                                                  
 concatenate_4 (Concatenate)    (None, 64, 64, 64,   0           ['conv3d_transpose_4[0][0]',     
                                128)                              're_lu_11[0][0]',               
                                                                  're_lu_3[0][0]']                
                                                                                                  
 conv3d_transpose_2 (Conv3DTran  (None, 128, 128, 12  4112       ['re_lu_11[0][0]']               
 spose)                         8, 16)                                                            
                                                                                                  
 re_lu_6 (

 batch_normalization_11 (BatchN  (None, 128, 128, 12  64         ['conv3d_11[0][0]']              
 ormalization)                  8, 16)                                                            
                                                                                                  
 concatenate_8 (Concatenate)    (None, 64, 64, 64,   0           ['conv3d_transpose_8[0][0]',     
                                160)                              're_lu_22[0][0]',               
                                                                  're_lu_11[0][0]',               
                                                                  're_lu_3[0][0]']                
                                                                                                  
 conv3d_transpose_5 (Conv3DTran  (None, 128, 128, 12  8208       ['re_lu_22[0][0]']               
 spose)                         8, 16)                                                            
          

 ormalization)                  8, 16)                                                            
                                                                                                  
 re_lu_38 (ReLU)                (None, 128, 128, 12  0           ['batch_normalization_28[0][0]'] 
                                8, 16)                                                            
                                                                                                  
 dropout_14 (Dropout)           (None, 128, 128, 12  0           ['re_lu_38[0][0]']               
                                8, 16)                                                            
                                                                                                  
 conv3d_29 (Conv3D)             (None, 128, 128, 12  6928        ['dropout_14[0][0]']             
                                8, 16)                                                            
          