In [None]:
import numpy as np
import pydicom
import nibabel as nib
import os
import cv2
import matplotlib.pyplot as plt
from scipy.ndimage import label
import tensorflow as tf
from tensorflow import keras

In [None]:
def preprocess_TotalSeg(mask_path):
    """Load a label volume and resample it onto a 220 x 220 x 380 grid.

    Slices 0..379 along axis 2 are visited in reverse order: output slice ``i``
    comes from input slice ``379 - i``. Each slice is cropped to rows and
    columns 70:442 and resized to 220 x 220. The stacked result is flipped
    along axis 1.

    Args:
        mask_path: path of a volume readable by ``nibabel.load``.

    Returns:
        float64 array of shape (220, 220, 380) holding the label values.

    NOTE(review): the crop window and slice count are hard-coded. Presumably
    the input is 512 x 512 in-plane with at least 380 slices; confirm against
    the data.
    """
    volume = nib.load(mask_path).get_fdata()
    resampled = np.zeros((220, 220, 380))
    for i in range(380):
        # Nearest-neighbour keeps every voxel on an existing label ID.
        # The default bilinear resize blends neighbouring IDs, and rounding
        # then turns boundary voxels into unrelated classes.
        resampled[:, :, i] = cv2.resize(
            volume[70:442, 70:442, 379 - i],
            (220, 220),
            interpolation=cv2.INTER_NEAREST,
        )
    return np.flip(resampled, axis=1).copy()

def mask_into_binary(mask,index_list):
    """Keep only the largest connected component of each requested label.

    For every label in ``index_list`` that occurs in ``mask``, the voxels equal
    to that label are split into components. Connectivity uses a full 3x3x3
    structuring element, so diagonal neighbours count. The biggest component
    is written into the result with value 1. Results for the different labels
    are summed.

    Args:
        mask: 3-D label array.
        index_list: iterable of label values to extract.

    Returns:
        float64 array of ``mask.shape``. It is 1 inside the kept components and
        0 elsewhere, or higher where components of repeated labels overlap.
    """
    connectivity = np.ones((3, 3, 3), dtype=int)
    result = np.zeros(mask.shape)
    for index in index_list:
        target = (mask == index).astype(int)
        if np.max(target) != 1:
            # label absent from this volume
            continue
        components, _ = label(target, connectivity)
        # bincount over the non-background component ids only: entry 0 stays 0,
        # so argmax picks the id of the largest component
        sizes = np.bincount(components[components > 0])
        result += (components == np.argmax(sizes)).astype(int)
    return result

def find_bounding_box(mask):
    """Return the inclusive index range of the 1-valued voxels along each axis.

    Args:
        mask: 3-D array. Only entries exactly equal to 1 are counted.

    Returns:
        ``(min0, max0, min1, max1, min2, max2)``.

    Raises:
        ValueError: if no entry equals 1, via ``np.min`` on an empty array.
    """
    bounds = []
    for axis in range(3):
        # project onto `axis` by taking the max over the two other axes
        others = tuple(a for a in range(3) if a != axis)
        profile = np.max(mask, axis=others)
        positions = np.flatnonzero(profile == 1)
        bounds.extend((np.min(positions), np.max(positions)))
    return tuple(bounds)

def _dicom_time_to_seconds(value):
    """Convert a DICOM TM value such as ``'101530.250000'`` to seconds after midnight.

    Accepts the truncated forms ``HH``/``HHMM``/``HHMMSS``, an optional
    ``.FFFFFF`` fraction, and the legacy ``HH:MM:SS`` form (colons are dropped).
    """
    text = str(value).strip().replace(':', '')
    whole, _, fraction = text.partition('.')
    hours = int(whole[0:2]) if len(whole) >= 2 else 0
    minutes = int(whole[2:4]) if len(whole) >= 4 else 0
    seconds = int(whole[4:6]) if len(whole) >= 6 else 0
    return hours * 3600 + minutes * 60 + seconds + (float('0.' + fraction) if fraction else 0.0)


def read_pet_image(path_to_dcm_folder):
    """Read a folder of DICOM slices and sum its time frames into one weighted 3-D array.

    Every ``*.dcm`` file is read. When both RescaleSlope and RescaleIntercept
    are present, pixel values are rescaled. Slices are grouped into frames by
    AcquisitionTime. Within a frame they are ordered by
    ``ImagePositionPatient[2]`` and stacked along the last axis; then the first
    two axes are swapped and the last axis is reversed.

    Frame ``i > 0`` is weighted by ``(t_i - t_{i-1}) / (t_last - t_first)``.
    The first frame uses the gap to the second frame, and a single frame gets
    weight 1. Times are converted to seconds first, so gaps that cross a
    minute or hour boundary are measured correctly.

    NOTE(review): with more than one frame these weights sum to
    ``1 + (t_1 - t_0) / span``, not 1. This is kept from the original; confirm
    it is intended. Acquisitions that cross midnight are not handled.

    Args:
        path_to_dcm_folder: folder containing the DICOM slice files.

    Returns:
        float64 array whose shape is that of one stacked frame.

    Raises:
        ValueError: if the folder contains no ``.dcm`` files.
    """
    dcm_filenames = [name for name in os.listdir(path_to_dcm_folder) if name.endswith('.dcm')]
    if not dcm_filenames:
        raise ValueError('no .dcm files found in {}'.format(path_to_dcm_folder))

    slice_list = []
    locations = []
    times = []
    for name in dcm_filenames:
        ds = pydicom.dcmread(os.path.join(path_to_dcm_folder, name))
        locations.append(float(ds.ImagePositionPatient[2]))
        times.append(_dicom_time_to_seconds(ds.AcquisitionTime))
        pixels = ds.pixel_array
        if 'RescaleSlope' in ds and 'RescaleIntercept' in ds:
            pixels = pixels * ds.RescaleSlope + ds.RescaleIntercept
        slice_list.append(pixels)

    # Build the arrays once instead of once per frame.
    locations = np.array(locations)
    times = np.array(times)
    slices = np.array(slice_list)

    frame_times = np.unique(times)  # sorted ascending, in seconds
    span = frame_times[-1] - frame_times[0]
    img = None
    for i, frame_time in enumerate(frame_times):
        in_frame = times == frame_time
        frame_locations = locations[in_frame].tolist()
        merged = dict(zip(frame_locations, slices[in_frame]))
        ordered_slice_list = [merged[loc] for loc in sorted(frame_locations)]
        img_3d = np.moveaxis(np.array(ordered_slice_list), 0, 2)
        img_3d = np.moveaxis(img_3d, 0, 1)
        img_3d = np.flip(img_3d, axis=2)
        if len(frame_times) == 1:
            k = 1.0
        elif i > 0:
            k = (frame_times[i] - frame_times[i - 1]) / span
        else:
            k = (frame_times[1] - frame_times[0]) / span
        if img is None:
            img = np.zeros(img_3d.shape)
        img += k * img_3d
    return img

def preprocess_slice(img,img_height):
    """Resize a 2-D slice to ``img_height`` x ``img_height`` and cast it to float16.

    The resize uses ``cv2.resize`` with its default interpolation.
    """
    target_size = (img_height, img_height)
    resized = cv2.resize(img, target_size)
    return resized.astype('float16')

#function for training the encoderCNN for classification
#input: x_train, y_train, the number of epochs, and the organ and the view for naming the trained model
def train_encoder(x_train,y_train,numberOfEpochs,organ,view):
    """Train a binary slice classifier and save it as ``classifier_<organ>_<view>.keras``.

    The network has four stages. Each stage is two unpadded 3x3 ReLU
    convolutions followed by 2x2 / stride-2 max pooling, with 16, 32, 64 and
    128 filters. These are followed by dense ReLU layers of 128, 64 and 32
    units and a single sigmoid output. Training uses Adam (lr 1e-3) on binary
    cross-entropy with shuffling. The per-epoch training loss is plotted in
    blue on the current matplotlib figure.

    Args:
        x_train: square slices; the side length is read from ``x_train[0].shape[0]``.
            The model is declared for ``(side, side, 1)`` inputs.
            NOTE(review): callers in this file pass ``(N, side, side)`` without a
            channel axis; confirm your Keras version adds it.
        y_train: one binary label per slice.
        numberOfEpochs: number of training epochs.
        organ, view: used only to name the saved model file.
    """
    side = x_train[0].shape[0]

    stack = []
    for stage, filters in enumerate((16, 32, 64, 128)):
        # only the very first layer declares the input shape
        first_kwargs = {'input_shape': (side, side, 1)} if stage == 0 else {}
        stack.append(keras.layers.Conv2D(filters, 3, activation='relu', **first_kwargs))
        stack.append(keras.layers.Conv2D(filters, 3, activation='relu'))
        stack.append(keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
    stack.append(keras.layers.Flatten())
    for units in (128, 64, 32):
        stack.append(keras.layers.Dense(units, activation='relu'))
    stack.append(keras.layers.Dense(1, activation='sigmoid'))
    model = keras.Sequential(stack)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=[tf.keras.metrics.BinaryAccuracy()],
    )

    history = model.fit(x=x_train, y=y_train, epochs=numberOfEpochs, shuffle=True)
    model.save('classifier_{}_{}.keras'.format(organ, view), include_optimizer=True)
    losses = history.history['loss']
    plt.plot(range(len(losses)), losses, color='blue')

#function for post-processing predictions by taking means of five consecutive predictions
def postProcess_new(predictions,patientIndexes_test):
    """Smooth per-slice scores with a 5-wide moving average that stays within one patient.

    Each output is the mean of the scores at offsets -2..+2. A neighbour
    counts as 0 when it lies outside the array or belongs to a different
    patient (``patientIndexes_test`` differs), so the divisor is always 5.

    Args:
        predictions: sequence of per-slice scores.
        patientIndexes_test: patient id for each entry of ``predictions``.

    Returns:
        numpy array of the smoothed scores, same length as ``predictions``.
    """
    count = len(predictions)

    def neighbour(pos, offset):
        other = pos + offset
        if 0 <= other < count and patientIndexes_test[other] == patientIndexes_test[pos]:
            return predictions[other]
        return 0

    smoothed = [
        np.mean([neighbour(pos, -2), neighbour(pos, -1), predictions[pos],
                 neighbour(pos, 1), neighbour(pos, 2)])
        for pos in range(count)
    ]
    return np.array(smoothed)

def chooseIndices(i_1, patch_size=64):
    """Return start offsets of ``patch_size``-wide windows covering the 1-valued run of ``i_1``.

    Only the first and last positions equal to 1 matter. With extent
    ``e = last - first``:

    * ``e < patch_size``: one window centred on the run, clamped inside the array.
    * ``e < 2*patch_size``: one window starting at the first 1, one ending at the last 1.
    * ``e < 3*patch_size``: those two plus one window centred on the run.
    * otherwise the whole axis is tiled. Length 220 with 64-wide windows keeps
      the original fixed offsets; other lengths get evenly spaced offsets
      whose last window ends exactly at the array end.

    Every position from ``first`` to ``last`` lies inside at least one window.
    Every window lies inside the array when ``len(i_1) >= patch_size``.

    Args:
        i_1: 1-D profile, e.g. a max-projection of a binary candidate mask.
        patch_size: window width (default 64, as used by the segmentators).

    Returns:
        list of int start offsets.

    Raises:
        ValueError: if no element of ``i_1`` equals 1.
    """
    length = len(i_1)
    hits = np.flatnonzero(np.asarray(i_1) == 1)
    if hits.size == 0:
        raise ValueError('chooseIndices needs at least one element equal to 1')
    first, last = int(hits[0]), int(hits[-1])
    extent = last - first
    # Start of a window centred on [first, last]. Floor division guarantees
    # that window covers the whole run when extent < patch_size, and covers
    # the gap between the two end windows when extent < 3*patch_size.
    centred = (first + last - (patch_size - 1)) // 2
    # Start of the window whose last element is `last` (the old `last - 64`
    # window stopped one short of it).
    ending = last - (patch_size - 1)
    if extent < patch_size:
        return [max(0, min(centred, length - patch_size))]
    if extent < 2 * patch_size:
        return [first, ending]
    if extent < 3 * patch_size:
        return [first, centred, ending]
    if length == 220 and patch_size == 64:
        return [0, 55, 101, 156]
    # Tile the whole axis: ceil(length / patch_size) windows, evenly spaced.
    # Consecutive starts differ by at most patch_size, so there are no gaps.
    n_windows = -(-length // patch_size)
    return [k * (length - patch_size) // (n_windows - 1) for k in range(n_windows)]

#the segmentation model (a standard U-Net with some layers removed so that it suits 64*64 input patches)
#input: x_train, y_train, the number of epochs, and the organ and the view for naming the trained model
def train_unet(x_train,y_train,numberOfEpochs,organ,view):
    """Train a two-level U-Net and save it as ``segmentator_<organ>_<view>.keras``.

    The encoder has two conv-pair + 2x2 max-pool levels (16 and 32 filters)
    and a 64-filter bottleneck. The decoder upsamples twice with stride-2
    transposed convolutions, concatenating the matching encoder output each
    time, and ends in a 1x1 sigmoid convolution. Every conv pair is two
    'same'-padded 3x3 ReLU convolutions (he_normal init) with dropout between
    them. The model is compiled with Adam and the 'Dice' loss, trained with
    shuffling, and the per-epoch loss is plotted in blue on the current
    matplotlib figure.

    Args:
        x_train: square patches. The input depth is 1 when ``x_train[0]`` is
            2-D, otherwise ``x_train[0].shape[2]``.
        y_train: target masks, one per patch.
        numberOfEpochs: number of training epochs.
        organ, view: used only to name the saved model file.
    """
    layers = tf.keras.layers
    side = x_train[0].shape[0]
    channels = 1 if x_train[0].shape == (side, side) else x_train[0].shape[2]

    def conv_pair(tensor, filters, drop_rate):
        """Two same-padded 3x3 ReLU convolutions with dropout in between."""
        tensor = layers.Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(tensor)
        tensor = layers.Dropout(drop_rate)(tensor)
        return layers.Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(tensor)

    inputs = layers.Input(shape=(side, side, channels))

    # contraction path
    skip1 = conv_pair(inputs, 16, 0.1)
    down1 = layers.MaxPooling2D((2, 2))(skip1)
    skip2 = conv_pair(down1, 32, 0.1)
    down2 = layers.MaxPooling2D((2, 2))(skip2)
    bottom = conv_pair(down2, 64, 0.2)

    # expansive path
    up2 = layers.Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(bottom)
    up2 = layers.concatenate([up2, skip2])
    decoded2 = conv_pair(up2, 32, 0.1)
    up1 = layers.Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same')(decoded2)
    up1 = layers.concatenate([up1, skip1], axis=3)
    decoded1 = conv_pair(up1, 16, 0.1)

    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(decoded1)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='adam', loss='Dice')

    history = model.fit(x=x_train, y=y_train, epochs=numberOfEpochs, shuffle=True)
    model.save('segmentator_{}_{}.keras'.format(organ, view), include_optimizer=True)
    losses = history.history['loss']
    plt.plot(range(len(losses)), losses, color='blue')

In [54]:
index=1

organ=['-','brain','aorta','heart','myocardium','left_ventricle','kidneys','liver','lungs'][index]
folder_path='C:/Users/oonar/Documents/kov/files'
filenames=os.listdir(folder_path)
# Study numbers left out of the dataset for each organ index. Index 2 keeps
# every study; an index missing from this table selects no studies.
excluded_studies={
    1:[61,67,81,119,121],
    2:[],
    3:[16,17,49,60,62,81,119,121,131,133],
    4:[16,17,35,49,60,62,81,119,121,131,133],
    5:[16,17,49,60,62,81,119,121,131,133],
    6:[1,16,27,29,35,42,49,56,61,67,68,70,81,83,88,109,119,121,130,131,144,149,157],
    7:[1,3,25,32,33,34,37,61,81,83,84,109,119,121,131,133,149],
    8:[24,42,64,65],
}
study_ids=[]
for filename in filenames:
    if 'img' not in filename:
        continue
    # file names look like img<number>.nii; strip the 3-char prefix and 4-char suffix
    number=int(filename[3:-4])
    if index in excluded_studies and number not in excluded_studies[index]:
        study_ids.append(number)

In [None]:
#train classifiers
# For each view, every slice of every study becomes one training example.
# The image is clipped at 50000, scaled to [<=1], and resized to img_height.
# The label is whether the slice contains the organ label `index`.
img_height=128
view_axes={'sagittal':0,'coronal':1,'transaxial':2}
for view in ['sagittal','coronal','transaxial']:
    axis=view_axes[view]
    patientIndexes=[]
    slices=[]
    labels=[]
    for i,study_id in enumerate(study_ids):
        img=nib.load('C:/Users/oonar/Documents/kov/files/img{}.nii'.format(study_id)).get_fdata()
        mask=nib.load('C:/Users/oonar/Documents/kov/files/mask{}.nii'.format(study_id)).get_fdata()
        if index==3:
            # labels 4 and 5 are merged into label 3 for this organ index
            mask[(mask==4)|(mask==5)]=3
        img=np.minimum(img,50000)/50000
        for j in range(img.shape[axis]):
            patientIndexes.append(i)
            slices.append(preprocess_slice(np.take(img,j,axis=axis),img_height))
            labels.append(np.max(np.take(mask,j,axis=axis)==index))

    x_train=np.array(slices).astype('float16')
    y_train=np.array(labels)
    print(x_train.shape)
    numberOfEpochs=5
    train_encoder(x_train,y_train,numberOfEpochs,organ,view)

In [57]:
#load and compile classifiers
# Each view's classifier is read from classifier_<organ>_<view>.keras and
# recompiled with Adam / BinaryCrossentropy before prediction.
_loaded_classifiers={}
for _view in ('sagittal','coronal','transaxial'):
    _model=keras.models.load_model('classifier_{}_{}.keras'.format(organ,_view))
    _model.compile(optimizer='adam',loss='BinaryCrossentropy')
    _loaded_classifiers[_view]=_model
classifier_sag=_loaded_classifiers['sagittal']
classifier_cor=_loaded_classifiers['coronal']
classifier_trans=_loaded_classifiers['transaxial']

In [None]:
#predict classification labels
# Produces predicted_labels_sag / _cor / _trans: one smoothed score per slice,
# with studies concatenated in study_ids order.
# NOTE(review): this uses the same study_ids the classifiers were trained on
# in the training cell.
img_height=128
view_axes={'sagittal':0,'coronal':1,'transaxial':2}
view_classifiers={'sagittal':classifier_sag,'coronal':classifier_cor,'transaxial':classifier_trans}
smoothed_by_view={}
for view in ['sagittal','coronal','transaxial']:
    axis=view_axes[view]
    patientIndexes=[]
    slices=[]
    labels=[]
    for i,study_id in enumerate(study_ids):
        img=nib.load('C:/Users/oonar/Documents/kov/files/img{}.nii'.format(study_id)).get_fdata()
        mask=nib.load('C:/Users/oonar/Documents/kov/files/mask{}.nii'.format(study_id)).get_fdata()
        if index==3:
            mask[(mask==4)|(mask==5)]=3
        img=np.minimum(img,50000)/50000
        for j in range(img.shape[axis]):
            patientIndexes.append(i)
            slices.append(preprocess_slice(np.take(img,j,axis=axis),img_height))
            labels.append(np.max(np.take(mask,j,axis=axis)==index))

    x_train=np.array(slices).astype('float16')
    raw_scores=view_classifiers[view].predict(x_train)[:,0]
    smoothed_by_view[view]=postProcess_new(raw_scores,patientIndexes)

predicted_labels_sag=smoothed_by_view['sagittal']
predicted_labels_cor=smoothed_by_view['coronal']
predicted_labels_trans=smoothed_by_view['transaxial']

In [None]:
#train segmentators
# For every study, a voxel is a candidate only when the smoothed classifier
# scores of its sagittal, coronal and transaxial slices are all >= 0.5.
# In each candidate slice of the current view, windows chosen by chooseIndices
# are cut from the image and from the organ mask to train a U-Net.
view_axes={'sagittal':0,'coronal':1,'transaxial':2}
for view in ['sagittal','coronal','transaxial']:
    axis=view_axes[view]
    patientIndexes=[]
    slices=[]
    masks=[]
    # The prediction arrays hold one score per slice, with studies concatenated
    # in study_ids order. Keep a running offset per axis instead of assuming
    # every volume has the same shape (the old img.shape[k]*i+j indexing
    # silently misaligned otherwise).
    view_predictions=[predicted_labels_sag,predicted_labels_cor,predicted_labels_trans]
    offsets=[0,0,0]
    for i,study_id in enumerate(study_ids):
        img=nib.load('{}/img{}.nii'.format(folder_path,study_id)).get_fdata()
        mask=nib.load('{}/mask{}.nii'.format(folder_path,study_id)).get_fdata()
        if index==3:
            mask[mask==4]=3
            mask[mask==5]=3
        img[img>50000]=50000
        img=img/50000

        votes=[]
        for ax in range(3):
            n_slices=img.shape[ax]
            scores=view_predictions[ax][offsets[ax]:offsets[ax]+n_slices]
            if len(scores)!=n_slices:
                raise ValueError('predictions for axis {} run out at study {}; re-run the prediction cell with the same study_ids'.format(ax,study_id))
            votes.append(np.asarray(scores)>=0.5)
            offsets[ax]+=n_slices
        # all three views must agree (same as summing the votes and testing > 2.5)
        mask_1=(votes[0][:,None,None]&votes[1][None,:,None]&votes[2][None,None,:]).astype(int)

        # bring the current view's slicing axis to the front so one loop serves all views
        img_v=np.moveaxis(img,axis,0)
        mask_v=np.moveaxis(mask,axis,0)
        cand_v=np.moveaxis(mask_1,axis,0)
        for j in range(cand_v.shape[0]):
            cand=cand_v[j]
            if np.max(cand)!=1:
                continue
            for i1 in chooseIndices(np.max(cand,axis=1)):
                for j1 in chooseIndices(np.max(cand,axis=0)):
                    if np.max(cand[i1:(i1+64),j1:(j1+64)])==1:
                        patientIndexes.append(i)
                        slices.append(img_v[j,i1:(i1+64),j1:(j1+64)].astype('float16'))
                        masks.append(np.array(mask_v[j,i1:(i1+64),j1:(j1+64)]==index,dtype=int))
    x_train=np.array(slices).astype('float16')
    y_train=np.array(masks)
    print(x_train.shape)
    numberOfEpochs=5
    train_unet(x_train,y_train,numberOfEpochs,organ,view)