In [1]:


from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, BatchNormalization,concatenate,Conv3DTranspose,Dropout
from tensorflow.keras.models import Model

# Default network input: a 128^3 spatial volume followed by a single channel.
input_shape = (128, 128, 128) + (1,)

def UNET_3d_v2(input_shape=input_shape, IC=8, last_activation='sigmoid'):
    """Build a five-level 3D U-Net as a Keras functional ``Model``.

    Encoder: at depth ``k`` (0..4) two 3x3x3 ReLU convolutions with
    ``IC * 2**k`` filters, separated by dropout (rate 0.1 at the top rising
    to 0.5 at the bottleneck), followed by batch normalization. Levels are
    joined by 2x2x2 max-pooling.

    Decoder: at each level a stride-2 4x4x4 transposed convolution upsamples,
    its output is concatenated with the matching encoder features (taken
    *before* batch normalization), and another conv/dropout/conv pair follows.
    All decoder levels except the top one end in batch normalization.

    Because there are four 2x poolings, spatial dimensions should be
    divisible by 16 so the skip concatenations line up.

    Args:
        input_shape: Shape of one sample, channels last.
        IC: Filter count of the top level; deeper levels double it.
        last_activation: Activation of the final 1x1x1 single-channel conv.

    Returns:
        An uncompiled ``tensorflow.keras.Model``.
    """
    # Settings shared by every 3x3x3 convolution and transposed convolution.
    conv_kwargs = dict(activation='relu', padding='same', kernel_initializer='he_normal')
    # One dropout rate per depth, top level first.
    dropout_rates = (0.1, 0.2, 0.3, 0.4, 0.5)

    def double_conv(tensor, filters, rate):
        # conv -> dropout -> conv, the basic unit used at every level.
        first = Conv3D(filters, (3, 3, 3), **conv_kwargs)(tensor)
        dropped = Dropout(rate)(first)
        return Conv3D(filters, (3, 3, 3), **conv_kwargs)(dropped)

    inputs = Input(shape=input_shape)

    # Contracting path: remember each level's pre-normalization output for
    # the skip connections.
    skips = []
    features = inputs
    for level, rate in enumerate(dropout_rates):
        if level > 0:
            features = MaxPooling3D(pool_size=(2, 2, 2))(features)
        pre_norm = double_conv(features, IC * 2 ** level, rate)
        skips.append(pre_norm)
        features = BatchNormalization()(pre_norm)

    # Expanding path: walk back up from the level just above the bottleneck.
    for level in reversed(range(len(dropout_rates) - 1)):
        upsampled = Conv3DTranspose(IC * 2 ** (level + 1), (4, 4, 4),
                                    strides=(2, 2, 2), **conv_kwargs)(features)
        merged = concatenate([upsampled, skips[level]])
        features = double_conv(merged, IC * 2 ** level, dropout_rates[level])
        if level > 0:
            features = BatchNormalization()(features)

    outputs = Conv3D(1, (1, 1, 1), activation=last_activation, padding='same',
                     kernel_initializer='he_normal')(features)
    return Model(inputs=inputs, outputs=outputs)

In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

# Dataset folders, relative to the notebook's working directory. The training
# cell pairs volumes from these two folders purely by position.
fd_img = './data_CBCT/images/'
lab_dir = './data_CBCT/labels/'

import os
import nibabel as nib
import numpy as np
import skimage.transform as skTrans

def _minmax_normalize(volume):
    """Linearly rescale ``volume`` to [0, 1].

    A constant volume (max == min, e.g. a label map with no foreground) would
    otherwise divide by zero and become all-NaN; it is mapped to zeros instead.
    """
    lo = np.min(volume)
    span = np.max(volume) - lo
    if span > 0:
        return (volume - lo) / span
    return np.zeros_like(volume)


def preprocess_data_cbct(folder_directory, desired_shape=(128,128,128), order=0):
    """Load every file in a folder with nibabel into one normalized float32 array.

    Files are read in sorted name order (``os.listdir`` order is arbitrary), so
    an image folder and its label folder yield volumes at matching indices,
    provided corresponding files sort the same way in both folders.
    Sub-directories are skipped.

    Args:
        folder_directory: Directory containing the volume files.
        desired_shape: Spatial shape every volume is resized to when its shape
            differs.
        order: Interpolation order passed to ``skimage.transform.resize``. The
            default 0 (nearest neighbour) keeps discrete label values intact.

    Returns:
        float32 array of shape ``(n_files, *desired_shape, 1)`` with each volume
        min-max scaled to [0, 1] (constant volumes become all zeros). An empty
        folder gives ``n_files == 0`` with the same trailing dimensions.
    """
    # Compare against a tuple: a list would never equal ndarray.shape.
    desired_shape = tuple(desired_shape)

    file_names = sorted(
        name for name in os.listdir(folder_directory)
        if os.path.isfile(os.path.join(folder_directory, name))
    )

    volumes = []
    for file_name in file_names:
        file_path = os.path.join(folder_directory, file_name)
        data_array = nib.load(file_path).get_fdata()

        if data_array.shape != desired_shape:
            data_array = skTrans.resize(data_array, desired_shape, mode='constant', order=order,
                                        preserve_range=False, anti_aliasing=False)

        data_array = _minmax_normalize(data_array).astype(np.float32)
        # Trailing channel axis, as expected by a channels-last Conv3D input.
        volumes.append(np.expand_dims(data_array, axis=-1))

    if not volumes:
        # Keep the 5-D layout so downstream [:, :, :, :, :] slicing still works.
        return np.empty((0, *desired_shape, 1), dtype=np.float32)
    return np.stack(volumes)

X = preprocess_data_cbct(fd_img, desired_shape=(128, 128, 128))
Y = preprocess_data_cbct(lab_dir, desired_shape=(128, 128, 128))
# Images and labels are paired purely by position, so both folders must hold
# the same number of volumes, listed in corresponding order.
assert len(X) == len(Y), f"found {len(X)} images but {len(Y)} labels"

N_TRAIN = 100  # first N_TRAIN volumes are used for fit (incl. validation); the rest are held out
x_train = X[:N_TRAIN]
y_train = Y[:N_TRAIN]
x_test = X[N_TRAIN:]
y_test = Y[N_TRAIN:]

# Mirror over every visible GPU rather than hard-coding eight device names,
# which breaks on machines with fewer GPUs.
mirrored_strategy = tf.distribute.MirroredStrategy()

# Under a distribution strategy, fit()'s batch_size is the GLOBAL batch and is
# split across replicas; a global batch of 1 cannot feed several GPUs, so scale
# a per-replica batch by the replica count.
BATCH_SIZE_PER_REPLICA = 1
global_batch_size = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

with mirrored_strategy.scope():
    # Model weights and optimizer state must be created inside the scope.
    model = UNET_3d_v2(input_shape=(128, 128, 128, 1), IC=8, last_activation='sigmoid')
    # NOTE(review): voxel accuracy is dominated by background in segmentation;
    # consider a Dice/IoU metric.
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

history = model.fit(x_train, y_train, batch_size=global_batch_size, epochs=100,
                    validation_split=0.2, verbose=1)