# Anatomical segmentation of the aorta with a 3D U-Net

In [1]:
import numpy as np
import os

from PIL import Image
import SimpleITK as sitk
from skimage.transform import resize
from skimage.morphology import binary_erosion
from skimage.measure import label, regionprops_table
import scipy

import keras
from keras.models import load_model
from keras.layers import Input, Dense, Conv3D, MaxPooling3D, UpSampling3D, Conv3DTranspose, concatenate
from keras.models import Model
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.optimizers import Adam
from keras import backend as K
from keras import metrics
from keras import optimizers
from skimage import io

import tensorflow as tf

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

Using TensorFlow backend.


### Set up the model

In [3]:
def dice_coef(y_true, y_pred, smooth=1.0):
    """Mean soft Dice coefficient over the classes stored on the last axis.

    For every channel ``c`` both tensors are flattened and
    ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)`` is computed;
    the result is the unweighted mean of these per-channel scores.

    Parameters
    ----------
    y_true, y_pred : tensors
        Ground truth and prediction with the classes on the last axis (the
        model built in this notebook outputs 6 channels).
    smooth : float, optional
        Additive smoothing; keeps the per-channel score defined (equal to 1)
        when a channel is empty in both tensors. Defaults to 1.0, the value
        previously hard-coded.

    Returns
    -------
    Scalar tensor: the mean per-channel Dice score.
    """
    # Number of classes from the static shape of the prediction. Fall back to
    # the six classes of this notebook's model when the shape is unknown at
    # graph-build time.
    shape = K.int_shape(y_pred)
    n_classes = (shape[-1] if shape else None) or 6

    per_class = []
    for c in range(n_classes):
        t = K.flatten(y_true[..., c])
        p = K.flatten(y_pred[..., c])
        per_class.append((2 * K.sum(t * p) + smooth) / (K.sum(t) + K.sum(p) + smooth))
    return sum(per_class) / float(n_classes)
def dice_coef_loss(y_true, y_pred):
    """Training loss: one minus ``dice_coef``, so perfect overlap gives 0."""
    score = dice_coef(y_true, y_pred)
    return 1 - score
def _conv_pair(x, filters):
    """Apply two successive 3x3x3 ReLU convolutions with 'same' padding."""
    for _ in range(2):
        x = Conv3D(filters, (3, 3, 3), activation='relu', padding='same')(x)
    return x


def _up_concat(x, skip, filters):
    """Upsample ``x`` 2x with a transposed conv and concatenate ``skip`` on the channel axis."""
    upsampled = Conv3DTranspose(filters, (2, 2, 2), strides=(2, 2, 2), padding='same')(x)
    return concatenate([upsampled, skip], axis=4)


# Single-channel volume of arbitrary spatial size. Each side should be divisible
# by 16 so that the four 2x poolings/upsamplings line up with the skip tensors.
inputs = Input((None, None, None, 1))

# Encoder: remember the output of each level for the skip connections.
# Layers are created in the same order as a hand-written U-Net (conv, conv,
# pool per level), which keeps auto-generated layer names and weight order stable.
skips = []
x = inputs
for filters in (32, 64, 128, 256):
    x = _conv_pair(x, filters)
    skips.append(x)
    x = MaxPooling3D(pool_size=(2, 2, 2))(x)

# Bottleneck.
x = _conv_pair(x, 512)

# Decoder: mirror the encoder, consuming the skip tensors deepest-first.
for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
    x = _up_concat(x, skip, filters)
    x = _conv_pair(x, filters)

# One independent sigmoid output per anatomical class.
outputs = Conv3D(6, (3, 3, 3), activation='sigmoid', padding='same')(x)

model = Model(inputs=[inputs], outputs=[outputs])
model.compile(optimizer=Adam(lr=1e-5), loss=dice_coef_loss, metrics=[dice_coef])

In [4]:
def dice_c(y_true, y_pred):
    """Smoothed Dice coefficient between two NumPy arrays, pooled over all elements.

    Both arrays are flattened completely, so for one-hot (..., C) inputs this is a
    single Dice over all classes together rather than a mean of per-class scores.
    A +1 is added to numerator and denominator, so two empty arrays score 1.0.
    """
    truth = np.ravel(y_true)
    guess = np.ravel(y_pred)
    overlap = np.sum(truth * guess)
    numerator = 2 * overlap + 1
    denominator = np.sum(truth) + np.sum(guess) + 1
    return numerator / denominator

### Post-processing of predictions

In [5]:
def _keep_largest_component(mask):
    """Return a boolean volume holding only the largest connected component of ``mask``.

    Returns ``None`` when ``mask`` has no foreground voxels, in which case the
    caller leaves the channel as it is.
    """
    labelled = label(mask)
    # Checked explicitly: some scikit-image versions raise inside
    # regionprops_table when there are no regions at all.
    if labelled.max() == 0:
        return None
    props = regionprops_table(labelled, properties=('label', 'area'))
    largest = props['label'][np.argmax(props['area'])]
    return labelled == largest


def _resolve_overlaps(seg):
    """Collapse a (X, Y, Z, C) 0/1 volume into a (X, Y, Z) float label map.

    Voxels claimed by exactly one channel ``c`` get label ``c + 1``. All other
    voxels (claimed by no channel, or by several) take the label of the nearest
    uniquely claimed voxel (Euclidean distance). Returns all zeros if no voxel
    is uniquely claimed.
    """
    from scipy import ndimage

    votes = np.sum(seg, axis=3)
    unique = votes == 1.0
    if not unique.any():
        return np.zeros(votes.shape)

    # Label of the single claiming channel at uniquely claimed voxels, 0 elsewhere.
    label_img = np.where(unique, np.argmax(seg, axis=3) + 1, 0).astype(float)

    # distance_transform_edt measures the distance to the nearest *zero* element,
    # so the uniquely claimed voxels are passed in as zeros. return_indices then
    # gives, for every voxel, the coordinates of its nearest uniquely claimed voxel.
    nearest = ndimage.distance_transform_edt(
        ~unique, return_distances=False, return_indices=True)
    filled = label_img[tuple(nearest)]
    return np.where(unique, label_img, filled)


def _one_hot(label_map, n_classes):
    """Stack ``label_map == k`` for k = 1..n_classes into a (..., n_classes) float 0/1 array."""
    return np.stack(
        [(label_map == k).astype(float) for k in range(1, n_classes + 1)], axis=-1)


def post_process(model_weights, aorta, atlas, input_shape=(64, 64, 64)):
    """Segment an aorta volume into anatomical regions with the 3D U-Net and score it.

    Side effect: loads ``model_weights`` into the module-level ``model``.

    Parameters
    ----------
    model_weights : str
        Path to the weights file passed to ``model.load_weights``.
    aorta : str
        Path to the aorta image (read with SimpleITK). Voxels > 0 are set to 1.
        NOTE(review): other voxel values are kept as they are, so the image is
        assumed to be a non-negative mask.
    atlas : str
        Path to the reference label image. Labels 1..C are compared, where C is
        the number of model output channels. It must have the same shape as the
        aorta image.
    input_shape : tuple of int, optional
        Size the aorta volume is resized to before prediction. Each side should
        be divisible by 16 for this U-Net. Defaults to (64, 64, 64).

    Returns
    -------
    final_seg : ndarray
        Label map on the aorta image's voxel grid: 1..C inside the aorta, 0
        wherever the aorta image is 0.
    dice : float
        ``dice_c`` between one-hot encodings of the atlas and ``final_seg``.

    Raises
    ------
    ValueError
        If the atlas and aorta images have different shapes.
    """
    input_shape = tuple(input_shape)

    # --- read aorta; build the mask that crops predictions later ---
    im = sitk.GetArrayFromImage(sitk.ReadImage(aorta))
    aorta_mask = np.copy(im)
    aorta_mask[aorta_mask > 0.0] = 1.0

    # --- downsize and binarize the network input ---
    net_in = resize(im, input_shape, 1, preserve_range=True)
    net_in[net_in > 0] = 1.0
    net_in = net_in.reshape((1,) + input_shape + (1,))

    model.load_weights(model_weights)
    pred = model.predict(net_in, batch_size=1)

    # --- threshold and erode each predicted class ---
    seg = np.copy(pred[0])
    n_classes = seg.shape[-1]
    seg[seg >= 0.5] = 1.0
    seg[seg < 0.5] = 0.0
    for c in range(n_classes):
        seg[..., c] = binary_erosion(seg[..., c])

    # --- enlarge back to the original grid (previously hard-coded to 300x260x190) ---
    seg = resize(seg, aorta_mask.shape + (n_classes,), 1, preserve_range=True)
    seg[seg >= 0.5] = 1.0
    seg[seg < 0.5] = 0.0

    # --- remove predictions outside the aorta ---
    seg = seg * aorta_mask[..., np.newaxis]

    # --- keep only the largest connected component of each class; the last
    # channel (the unlabelled part of the aorta) is deliberately left as is ---
    for c in range(n_classes - 1):
        largest = _keep_largest_component(seg[..., c])
        if largest is not None:
            seg[..., c] = largest

    # --- one label per voxel, then clear the background ---
    final_seg = _resolve_overlaps(seg) * aorta_mask

    # --- compare against the atlas ---
    reference = sitk.GetArrayFromImage(sitk.ReadImage(atlas))
    if reference.shape != final_seg.shape:
        raise ValueError(
            "atlas shape %s does not match aorta shape %s"
            % (reference.shape, final_seg.shape))
    dice = dice_c(_one_hot(reference, n_classes), _one_hot(final_seg, n_classes))

    return final_seg, dice

### Apply prediction and post-processing to an aorta

In [6]:
# Segment case ID20 and score it against its reference atlas.
s, d = post_process(
    model_weights="../models/aorta_3d_model_weights.hdf5",
    aorta="../data/ID20_aorta.nii.gz",
    atlas="../data/ID20_atlas.nii.gz",
)

In [7]:
message = "Dice coefficient: " + str(d)
print(message)

Dice coefficient: 0.8568815552378631
