In [1]:
# import tensorflow as tf

In [2]:
# config = tf.compat.v1.ConfigProto()
# config.gpu_options.allow_growth = True
# sess= tf.compat.v1.Session(config=config)

In [3]:
from keras.models import Model
from keras.layers import ReLU, Input, Conv3D, MaxPooling3D, concatenate, Conv3DTranspose, BatchNormalization, Dropout, Lambda
from tensorflow.keras.optimizers import Adam
from keras.metrics import MeanIoU

# Name of the Keras weight initializer shared by every Conv3D built below.
# Try others if you want.
kernel_initializer = "he_uniform"

def conv_block(num_filter, in_layer):
    """Apply a 3x3x3 Conv3D -> BatchNormalization -> ReLU unit to ``in_layer``.

    Args:
        num_filter: number of output channels of the convolution.
        in_layer: Keras tensor fed into the convolution.

    Returns:
        Keras tensor with ``num_filter`` channels. ``padding='same'`` with the
        default strides keeps the spatial size of ``in_layer``.
    """
    # Layers are created in the same order as they are applied, so Keras'
    # automatic layer names are assigned exactly as before.
    stages = (
        Conv3D(num_filter, (3, 3, 3), kernel_initializer=kernel_initializer, padding='same'),
        BatchNormalization(axis=-1),  # normalise over the channel axis
        ReLU(),
    )
    tensor = in_layer
    for stage in stages:
        tensor = stage(tensor)
    return tensor

def Bottleneck(num_filter, in_layer, reduction=4):
    """Residual bottleneck: 1x1x1 reduce -> 3x3x3 -> 1x1x1 expand, added to the input.

    Every convolution is followed by BatchNormalization and ReLU. The sum
    ``in_layer + branch`` is then passed through a final ReLU.

    Args:
        num_filter: channel count of the expanding 1x1x1 convolution, i.e. of
            the block output. It must equal the channel count of ``in_layer``
            because the skip connection is a plain element-wise addition.
        in_layer: Keras tensor entering the block.
        reduction: divisor applied to ``num_filter`` to get the width of the
            two inner convolutions (default 4).

    Returns:
        Keras tensor with the same shape as ``in_layer`` (all convolutions use
        ``padding='same'`` and the default strides).

    Raises:
        ValueError: if ``num_filter // reduction`` is below 1, or if the last
            axis of ``in_layer`` is statically known and differs from
            ``num_filter``.
    """
    # Integer division: Conv3D's `filters` is an integer. `num_filter / 4`
    # produced a float that only worked through an implicit cast.
    inner_filters = num_filter // reduction
    if inner_filters < 1:
        raise ValueError(
            f"num_filter={num_filter} is too small for reduction={reduction}: "
            "the inner convolutions would have no filters"
        )
    # Fail early with a clear message instead of a shape error in the addition.
    in_channels = in_layer.shape[-1]
    if in_channels is not None and in_channels != num_filter:
        raise ValueError(
            f"Bottleneck expects in_layer with {num_filter} channels "
            f"(got {in_channels}) so the residual addition is well defined"
        )

    ext = Conv3D(inner_filters, (1, 1, 1), kernel_initializer=kernel_initializer, padding='same')(in_layer)
    ext = BatchNormalization(axis=-1)(ext)
    ext = ReLU()(ext)
    ext = Conv3D(inner_filters, (3, 3, 3), kernel_initializer=kernel_initializer, padding='same')(ext)
    ext = BatchNormalization(axis=-1)(ext)
    ext = ReLU()(ext)
    ext = Conv3D(num_filter, (1, 1, 1), kernel_initializer=kernel_initializer, padding='same')(ext)
    ext = BatchNormalization(axis=-1)(ext)
    # NOTE(review): the branch is ReLU'd before the addition as well as after
    # it. Standard ResNet bottlenecks apply ReLU only after the sum; confirm
    # this is intended.
    ext = ReLU()(ext)
    out = in_layer + ext
    out = ReLU()(out)
    return out
    
def Anam_net(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS, num_classes,
             num_filter=16, final_activation='sigmoid'):
    """Build a 3D encoder-decoder segmentation model with Bottleneck blocks.

    The encoder applies four ``conv_block`` + 2x2x2 max-pool stages, and the
    first three are followed by a ``Bottleneck``. The decoder upsamples four
    times with ``Conv3DTranspose`` and, at each level, concatenates the
    matching encoder feature map before another ``conv_block``. The model is
    returned uncompiled so the caller can choose the optimizer, loss and
    metrics.

    Args:
        IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH: spatial size of the input volume.
            Each known size must be divisible by 16: four 2x poolings are
            later undone and concatenated with the encoder outputs. ``None``
            is accepted for a dynamic size.
        IMG_CHANNELS: number of input channels.
        num_classes: number of output channels of the final 1x1x1 convolution.
        num_filter: base channel width (default 16). Deeper levels use 2x and
            4x this width. ``Bottleneck`` divides it by 4 for its inner
            convolutions, so it should be at least 4.
        final_activation: activation of the output convolution (default
            ``'sigmoid'``). For example, use ``'softmax'`` for mutually
            exclusive classes.

    Returns:
        An uncompiled ``Model`` mapping per-sample inputs of shape
        ``(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS)`` to outputs of
        shape ``(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, num_classes)``.

    Raises:
        ValueError: if a known spatial size is not divisible by 16.
    """
    # Four 2x2x2 poolings followed by four 2x upsamplings: every concatenation
    # below only lines up when each known spatial size is a multiple of 2**4.
    downsample_factor = 2 ** 4
    for dim_name, dim in (('IMG_HEIGHT', IMG_HEIGHT), ('IMG_WIDTH', IMG_WIDTH), ('IMG_DEPTH', IMG_DEPTH)):
        if dim is not None and dim % downsample_factor != 0:
            raise ValueError(
                f"{dim_name}={dim} must be divisible by {downsample_factor} so encoder "
                "and decoder feature maps align at every skip connection"
            )

    inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS))
    # No rescaling layer: inputs are expected to be normalised beforehand.
    s = inputs

    # Contraction path
    c1 = conv_block(num_filter, s)
    p1 = MaxPooling3D((2, 2, 2))(c1)
    # Was hard-coded to 16; it must match c1's channel count (num_filter)
    # for Bottleneck's residual addition.
    ad1 = Bottleneck(num_filter, p1)

    c2 = conv_block(num_filter * 2, ad1)
    p2 = MaxPooling3D((2, 2, 2))(c2)
    ad2 = Bottleneck(num_filter * 2, p2)

    c3 = conv_block(num_filter * 4, ad2)
    p3 = MaxPooling3D((2, 2, 2))(c3)
    ad3 = Bottleneck(num_filter * 4, p3)

    c4 = conv_block(num_filter * 4, ad3)
    p4 = MaxPooling3D((2, 2, 2))(c4)

    # Expansive path: upsample, refine with a Bottleneck, then fuse the skip.
    u4 = Conv3DTranspose(num_filter * 4, (2, 2, 2), strides=(2, 2, 2), padding='same')(p4)
    ad4 = Bottleneck(num_filter * 4, u4)
    u4 = concatenate([ad4, c4])
    u4 = conv_block(num_filter * 4, u4)

    u3 = Conv3DTranspose(num_filter * 4, (2, 2, 2), strides=(2, 2, 2), padding='same')(u4)
    ad5 = Bottleneck(num_filter * 4, u3)
    u3 = concatenate([ad5, c3])
    u3 = conv_block(num_filter * 4, u3)

    u2 = Conv3DTranspose(num_filter * 2, (2, 2, 2), strides=(2, 2, 2), padding='same')(u3)
    ad6 = Bottleneck(num_filter * 2, u2)
    u2 = concatenate([ad6, c2])
    u2 = conv_block(num_filter * 2, u2)

    # NOTE(review): unlike the three deeper decoder stages, this one has no
    # Bottleneck before the skip concatenation; confirm this is intended.
    u1 = Conv3DTranspose(num_filter, (2, 2, 2), strides=(2, 2, 2), padding='same')(u2)
    u1 = concatenate([u1, c1])
    u1 = conv_block(num_filter, u1)

    outputs = Conv3D(num_classes, (1, 1, 1), activation=final_activation)(u1)

    model = Model(inputs=[inputs], outputs=[outputs])
    # Compile the model outside of this function to keep it flexible.
    return model


2022-10-29 23:04:21.410959: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [4]:
# Build the network for 128x128x128 volumes with 4 input channels and a single
# output channel, then inspect its layers and its input/output shapes.
model = Anam_net(IMG_HEIGHT=128, IMG_WIDTH=128, IMG_DEPTH=128,
                 IMG_CHANNELS=4, num_classes=1)
model.summary()
print(model.input_shape, model.output_shape, sep="\n")

2022-10-29 23:04:22.680310: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-10-29 23:04:23.414620: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 34711 MB memory:  -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:17:00.0, compute capability: 8.6
2022-10-29 23:04:23.415167: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 41875 MB memory:  -> device: 1, name: NVIDIA RTX A6000, pci bus id: 0000:65:00.0, compute capability: 8.6


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 128, 1  0           []                               
                                28, 4)]                                                           
                                                                                                  
 conv3d (Conv3D)                (None, 128, 128, 12  1744        ['input_1[0][0]']                
                                8, 16)                                                            
                                                                                                  
 batch_normalization (BatchNorm  (None, 128, 128, 12  64         ['conv3d[0][0]']                 
 alization)                     8, 16)                                                        

                                                                                                  
 batch_normalization_7 (BatchNo  (None, 32, 32, 32,   128        ['conv3d_7[0][0]']               
 rmalization)                   32)                                                               
                                                                                                  
 re_lu_8 (ReLU)                 (None, 32, 32, 32,   0           ['batch_normalization_7[0][0]']  
                                32)                                                               
                                                                                                  
 tf.__operators__.add_1 (TFOpLa  (None, 32, 32, 32,   0          ['max_pooling3d_1[0][0]',        
 mbda)                          32)                               're_lu_8[0][0]']                
                                                                                                  
 re_lu_9 (

                                                                                                  
 batch_normalization_14 (BatchN  (None, 16, 16, 16,   64         ['conv3d_14[0][0]']              
 ormalization)                  16)                                                               
                                                                                                  
 re_lu_17 (ReLU)                (None, 16, 16, 16,   0           ['batch_normalization_14[0][0]'] 
                                16)                                                               
                                                                                                  
 conv3d_15 (Conv3D)             (None, 16, 16, 16,   1088        ['re_lu_17[0][0]']               
                                64)                                                               
                                                                                                  
 batch_nor

 spose)                         32)                                                               
                                                                                                  
 conv3d_21 (Conv3D)             (None, 64, 64, 64,   264         ['conv3d_transpose_2[0][0]']     
                                8)                                                                
                                                                                                  
 batch_normalization_21 (BatchN  (None, 64, 64, 64,   32         ['conv3d_21[0][0]']              
 ormalization)                  8)                                                                
                                                                                                  
 re_lu_26 (ReLU)                (None, 64, 64, 64,   0           ['batch_normalization_21[0][0]'] 
                                8)                                                                
          