# plaq-u-net: multi-patch consensus U-Net for automated detection and segmentation of the carotid arteries on black blood MRI sequences

E. Lavrova, 2022

This notebook contains the code supporting the corresponding paper.

Packages import:

In [1]:
import os
import numpy as np

from xml.etree import ElementTree
from numpy import zeros
from numpy import asarray
from numpy import expand_dims
from numpy import mean
import pydicom
import random
import matplotlib.pyplot as plt
import glob
import matplotlib.patches as patches

import cv2
import SimpleITK as sitk

from skimage import exposure
from skimage import img_as_float
from skimage.io import imread, imshow, concatenate_images
from skimage.transform import resize
from skimage.morphology import label
from skimage import morphology
from skimage.filters import threshold_otsu, threshold_local

from keras.models import Model, load_model
from keras.layers import Input, BatchNormalization, Activation, Dense, Dropout
from keras.layers.core import Lambda, RepeatVector, Reshape
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D, GlobalMaxPool2D
from keras.layers.merge import concatenate, add
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img

import pandas as pd

plt.style.use("ggplot")

from tqdm import tqdm_notebook, tnrange
from itertools import chain
from scipy.ndimage import zoom

Using TensorFlow backend.


In [2]:
# Ask CUDA to enumerate GPUs by PCI bus ID so the indices used below are stable.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# NOTE(review): GPU indices '4,7' are specific to the original workstation — adjust for your machine.
os.environ["CUDA_VISIBLE_DEVICES"] = '4,7'                        
import tensorflow as tf
import keras.backend as K
# NOTE(review): `_get_available_gpus` is a private helper of the Keras TensorFlow backend;
# it is used here only to display which GPUs are visible and may not exist in other Keras versions.
K.tensorflow_backend._get_available_gpus()

W0722 13:25:15.134429 38892 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.

W0722 13:25:15.137429 38892 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:181: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.

W0722 13:25:15.139430 38892 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:186: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.

W0722 13:25:20.696840 38892 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:190: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.



['/job:localhost/replica:0/task:0/device:GPU:0',
 '/job:localhost/replica:0/task:0/device:GPU:1']

## 1. Loading the models

Defining model blocks:

In [3]:
def conv2d_block(input_tensor, n_filters, kernel_size = 3, batchnorm = True):
    """Convolution block used at every level of the U-Net.

    Parameters
    ----------
    input_tensor : Keras tensor the block is applied to.
    n_filters : number of convolution filters.
    kernel_size : side of the square convolution kernel.
    batchnorm : if True, insert BatchNormalization before the ReLU.

    Returns
    -------
    Keras tensor produced by Conv2D -> (BatchNormalization) -> ReLU.

    NOTE(review): two Conv2D layers are created, but the second one is applied
    to `input_tensor` rather than to the output of the first, so the value
    returned does not depend on the first layer at all. The block therefore
    behaves as a single convolution. Presumably the weight files loaded later in
    this notebook were trained with this topology, so do not "fix" the wiring
    without retraining or re-exporting the weights — TODO confirm.
    """
    # first layer (its output is overwritten below and never reaches the result)
    x = Conv2D(filters = n_filters, kernel_size = (kernel_size, kernel_size),\
              kernel_initializer = 'he_normal', padding = 'same')(input_tensor)
    if batchnorm:
        x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # second layer (applied to `input_tensor`, not to `x` — see the note in the docstring)
    x = Conv2D(filters = n_filters, kernel_size = (kernel_size, kernel_size),\
              kernel_initializer = 'he_normal', padding = 'same')(input_tensor)
    if batchnorm:
        x = BatchNormalization()(x)
    x = Activation('relu')(x)

    return x

def get_unet(input_img, n_filters = 8, dropout = 0.2, batchnorm = True):
    """Assemble the U-Net as a Keras ``Model``.

    Parameters
    ----------
    input_img : Keras input tensor.
    n_filters : filter count of the first level; deeper levels use 2x, 4x, 8x
        and (at the bottleneck) 16x this value.
    dropout : dropout rate applied after every pooling step and after every
        skip-connection concatenation.
    batchnorm : forwarded to ``conv2d_block``.

    Returns
    -------
    Model mapping ``input_img`` to a single-channel sigmoid output.
    """
    encoder_multipliers = (1, 2, 4, 8)
    skip_tensors = []

    # Contracting path: conv block -> 2x2 max-pool -> dropout, keeping each
    # pre-pooling tensor for the skip connections.
    x = input_img
    for mult in encoder_multipliers:
        features = conv2d_block(x, n_filters * mult, kernel_size = 3, batchnorm = batchnorm)
        skip_tensors.append(features)
        x = MaxPooling2D((2, 2))(features)
        x = Dropout(dropout)(x)

    # Bottleneck
    x = conv2d_block(x, n_filters = n_filters * 16, kernel_size = 3, batchnorm = batchnorm)

    # Expansive path: transposed conv (x2 upsampling) -> concatenate with the
    # matching encoder tensor -> dropout -> conv block.
    for mult, skip in zip(reversed(encoder_multipliers), reversed(skip_tensors)):
        x = Conv2DTranspose(n_filters * mult, (3, 3), strides = (2, 2), padding = 'same')(x)
        x = concatenate([x, skip])
        x = Dropout(dropout)(x)
        x = conv2d_block(x, n_filters * mult, kernel_size = 3, batchnorm = batchnorm)

    # 1x1 convolution with sigmoid gives the per-pixel output map.
    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    return Model(inputs=[input_img], outputs=[outputs])

Defining metrics:

In [4]:
def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient computed over all elements of both tensors.

    Both inputs are flattened, so the score is a single value for the whole
    batch: (2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1).
    """
    from keras import backend as K
    eps = 1  # smoothing term; keeps the ratio defined when both inputs are all zeros
    truth_flat = K.flatten(y_true)
    pred_flat = K.flatten(y_pred)
    overlap = K.sum(truth_flat * pred_flat)
    numerator = 2. * overlap + eps
    denominator = K.sum(truth_flat) + K.sum(pred_flat) + eps
    return numerator / denominator

def dice_loss(y_true, y_pred):
    """Dice loss: one minus the smoothed Dice coefficient of the two tensors."""
    overlap_score = dice_coef(y_true, y_pred)
    return 1 - overlap_score

def custom_loss(y_true, y_pred):
    """Equal-weight combination of binary cross-entropy and Dice loss.

    Parameters
    ----------
    y_true, y_pred : Keras tensors with the reference and predicted masks.

    Returns
    -------
    0.5 * binary_crossentropy(y_true, y_pred) + 0.5 * dice_loss(y_true, y_pred)
    """
    from keras.losses import binary_crossentropy
    # Use the function imported just above: the notebook never binds the bare
    # name `keras` (it only uses `from keras... import` and `import keras.backend as K`),
    # so the previous `keras.losses.binary_crossentropy` lookup raised NameError.
    return 0.5*binary_crossentropy(y_true,y_pred)+0.5*dice_loss(y_true,y_pred)

Models compilation + loading weights:

In [5]:
# Spatial size of the network input; detect4multipatches feeds patches of this size.
im_height = 64
im_width = 64

# Single-channel input tensor; note that the same tensor is passed to both models below.
input_img = Input((im_height, im_width, 1), name='img')

# Model whose weights come from '../res/plaq-u-net_simple.h5'.
# The architecture arguments must match the ones used when the weights were saved.
model_simple = get_unet(input_img, n_filters=16, dropout=0.05, batchnorm=True)
model_simple.compile(optimizer=Adam(), loss=dice_loss, metrics=['accuracy', dice_coef])
model_simple.load_weights('../res/plaq-u-net_simple.h5')

# Model whose weights come from '../res/plaq-u-net_aug.h5' (same architecture arguments).
model_aug = get_unet(input_img, n_filters=16, dropout=0.05, batchnorm=True)
model_aug.compile(optimizer=Adam(), loss=dice_loss, metrics=['accuracy', dice_coef])
model_aug.load_weights('../res/plaq-u-net_aug.h5')

W0722 13:25:28.973867 38892 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0722 13:25:28.980867 38892 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0722 13:25:28.983871 38892 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:4185: The name tf.truncated_normal is deprecated. Please use tf.random.truncated_normal instead.

W0722 13:25:29.016867 38892 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:1834: The name tf.nn.fused_batch_norm is deprecated. Please use tf.compat.v1.nn.fused_batch_norm instead.

W0722 13:25:29.0

Patching and consensus map calculation:

In [6]:
def detect4multipatches(img, model, patch_size=64, stride=4):
    """Multi-patch consensus map: average the model output over overlapping patches.

    The image is tiled with ``patch_size`` x ``patch_size`` windows whose
    top-left corners lie on a regular grid with step ``stride``. All patches are
    predicted in one batch, and every pixel of the result is the mean of the
    predictions of all patches that cover it.

    Parameters
    ----------
    img : 2-D array (rows, cols).
    model : object exposing ``predict(batch, verbose=0)``; it receives an array
        of shape (n_patches, patch_size, patch_size, 1) and must return an array
        whose ``[k, ..., 0]`` slice is the map for patch ``k``.
    patch_size : side of the square patch (default 64).
    stride : step between neighbouring patches in pixels (default 4).

    Returns
    -------
    2-D float array with the shape of ``img``. Pixels not covered by any patch
    (possible at the bottom/right border when the size minus ``patch_size`` is
    not a multiple of ``stride``) are NaN.

    Raises
    ------
    ValueError : if the image is smaller than one patch along either axis.
    """
    height, width = img.shape[0], img.shape[1]
    if height < patch_size or width < patch_size:
        raise ValueError(
            'Image of shape %s is smaller than the %dx%d patch'
            % (img.shape, patch_size, patch_size))

    # The grid is computed per axis; deriving both counts from the number of
    # rows only breaks on non-square images.
    n_rows = (height - patch_size) // stride + 1
    n_cols = (width - patch_size) // stride + 1
    corners = [(stride * i, stride * j) for i in range(n_rows) for j in range(n_cols)]

    img_patch = np.zeros((len(corners), patch_size, patch_size, 1))
    for c, (top, left) in enumerate(corners):
        img_patch[c, ..., 0] = img[top:top + patch_size, left:left + patch_size]

    img_patch_pred = model.predict(img_patch, verbose=0)

    # Running sum and hit count need only two image-sized arrays; stacking one
    # full-size NaN-filled layer per patch needs rows * cols * n_patches values.
    pred_sum = np.zeros((height, width))
    hits = np.zeros((height, width))
    for c, (top, left) in enumerate(corners):
        pred_sum[top:top + patch_size, left:left + patch_size] += img_patch_pred[c, ..., 0]
        hits[top:top + patch_size, left:left + patch_size] += 1

    M_concord = np.full((height, width), np.nan)
    covered = hits > 0
    M_concord[covered] = pred_sum[covered] / hits[covered]

    return M_concord

## 2. CA probability maps calculation

Some data loading and pre-processing functions:

In [7]:
# loading DICOM to array from the file path
def path2array(dcm_path):
    """Read the DICOM file at ``dcm_path`` and return its pixel data as an array.

    The file is read with ``force = True`` and its transfer syntax is set to
    Implicit VR Little Endian before the pixel data is decoded.
    NOTE(review): the transfer syntax is overwritten unconditionally, which
    presumably targets files lacking that header field — verify for your data.
    """
    dataset = pydicom.read_file(dcm_path, force = True)
    dataset.file_meta.TransferSyntaxUID = pydicom.uid.ImplicitVRLittleEndian
    return dataset.pixel_array

# N4 bias field correction
def correctBiasField(img_input, max_attempts=3):
    """Apply SimpleITK N4 bias field correction to an image array.

    Parameters
    ----------
    img_input : array accepted by ``sitk.GetImageFromArray``; it is cast to
        float32 before the filter is executed.
    max_attempts : how many times the correction is tried before giving up.

    Returns
    -------
    Array returned by ``sitk.GetArrayFromImage`` for the corrected image.

    Raises
    ------
    RuntimeError : if every attempt fails. The previous implementation retried
        forever inside a bare ``except``, which hangs on a reproducible failure
        and also swallows KeyboardInterrupt.
    """
    last_error = None

    for _ in range(max_attempts):
        try:
            corrector = sitk.N4BiasFieldCorrectionImageFilter()
            inputImage = sitk.GetImageFromArray(img_input)
            inputImage = sitk.Cast(inputImage, sitk.sitkFloat32)
            output = corrector.Execute(inputImage)
            return sitk.GetArrayFromImage(output)
        except Exception as err:
            # `Exception` (not a bare except) lets Ctrl-C / kernel interrupt through.
            last_error = err
            print ('BFC failed')

    raise RuntimeError(
        'N4 bias field correction failed after %d attempt(s)' % max_attempts
    ) from last_error

# zooming images to defined voxel size and array shape (with cropping/padding)
# from: https://stackoverflow.com/questions/37119071/scipy-rotate-and-zoom-an-image-without-changing-its-dimensions
def clipped_zoom(img, zoom_factor, **kwargs):
    """Zoom an image about its centre while keeping the array shape unchanged.

    Parameters
    ----------
    img : array whose first two axes are height and width; any trailing axes
        (e.g. channels) are not zoomed.
    zoom_factor : values below 1 shrink the content and zero-pad around it,
        values above 1 magnify a central crop, and 1 returns ``img`` itself.
    **kwargs : forwarded to ``scipy.ndimage.zoom`` (e.g. ``order``).

    Returns
    -------
    Array with the same height and width as ``img``.
    """

    h, w = img.shape[:2]

    # For multichannel images we don't want to apply the zoom factor to the RGB
    # dimension, so instead we create a tuple of zoom factors, one per array
    # dimension, with 1's for any trailing dimensions after the width and height.
    zoom_tuple = (zoom_factor,) * 2 + (1,) * (img.ndim - 2)

    # Zooming out
    if zoom_factor < 1:

        # Bounding box of the zoomed-out image within the output array
        zh = int(np.round(h * zoom_factor))
        zw = int(np.round(w * zoom_factor))
        top = (h - zh) // 2
        left = (w - zw) // 2

        # Zero-padding
        out = np.zeros_like(img)
        out[top:top+zh, left:left+zw] = zoom(img, zoom_tuple, **kwargs)

    # Zooming in
    elif zoom_factor > 1:

        # Bounding box of the zoomed-in region within the input array
        zh = int(np.round(h / zoom_factor))
        zw = int(np.round(w / zoom_factor))
        top = (h - zh) // 2
        left = (w - zw) // 2

        out = zoom(img[top:top+zh, left:left+zw], zoom_tuple, **kwargs)

        # Rounding can leave `out` slightly SMALLER than `img` (e.g. h = 9 and
        # zoom_factor = 1.7 give 8 rows). Without padding, the trim offsets below
        # become negative and the slice returns a wrong-sized array, so missing
        # rows/columns are zero-padded, as the zoom-out branch does.
        pad_h = max(h - out.shape[0], 0)
        pad_w = max(w - out.shape[1], 0)
        if pad_h or pad_w:
            pad_spec = ((pad_h // 2, pad_h - pad_h // 2),
                        (pad_w // 2, pad_w - pad_w // 2)) + ((0, 0),) * (out.ndim - 2)
            out = np.pad(out, pad_spec, mode='constant')

        # `out` might also be slightly larger than `img` due to rounding, so
        # trim off any extra pixels at the edges
        trim_top = ((out.shape[0] - h) // 2)
        trim_left = ((out.shape[1] - w) // 2)
        out = out[trim_top:trim_top+h, trim_left:trim_left+w]

    # If zoom_factor == 1, just return the input array
    else:
        out = img
    return out

### 2.1. Test set

Getting patient names from the test set (from training script):

In [8]:
# Subject IDs of the test set; each ID is used below as a folder-name prefix
# when globbing for DICOM files and as the sub-folder name for the saved maps.
sub_names_test = ['AMC012', 'AMC006', 'MUMC094', 'MUMC027', 'MUMC079', 'MUMC052', 'MUMC127', 'MUMC071', 'MUMC038',
                  'MUMC093', 'MUMC107', 'MUMC022', 'MUMC114', 'MUMC115', 'MUMC069', 'MUMC130', 'MUMC036', 'MUMC007', 
                  'MUMC059', 'MUMC080', 'UMCU036', 'UMCU025', 'UMCU008', 'UMCU034']

In [None]:
# Consensus probability maps for the T1w slices of the test subjects,
# computed with both models and saved as one .npy file per slice.
ds_dir = '../data/'
results_dir_simple = '../res/maps/T1w/plaq-u-net_simple/'
results_dir_aug = '../res/maps/T1w/plaq-u-net_aug/'

for sub_name in sub_names_test:

    # np.save does not create missing folders
    os.makedirs(results_dir_simple + sub_name, exist_ok=True)
    os.makedirs(results_dir_aug + sub_name, exist_ok=True)

    sub_img_names = glob.glob(ds_dir+sub_name+'*/T1W_*.dcm')

    for sub_img_name in sub_img_names:

        img = path2array(sub_img_name)
        # drop an 8-pixel border; float avoids integer overflow in the range computation
        img_test = img[8:-8,8:-8].astype(np.float64)
        img_min = np.min(img_test)
        img_max = np.max(img_test)
        # min-max scaling to 0..255; a constant slice would divide by zero, so map it to zeros
        if img_max > img_min:
            img_norm = ((img_test - img_min)/(img_max - img_min)*255).astype(np.uint8)
        else:
            img_norm = np.zeros(img_test.shape, dtype=np.uint8)

        vessels_pred_multi_simple = detect4multipatches(img_norm, model_simple)
        vessels_pred_multi_aug = detect4multipatches(img_norm, model_aug)

        # Slice ID = the 6 characters before '.dcm'. basename() is used because
        # indexing split(os.sep) selects a different path component on POSIX and Windows.
        slice_id = os.path.basename(sub_img_name)[-10:-4]

        np.save(results_dir_simple + sub_name + '/' + slice_id + '.npy',
                vessels_pred_multi_simple)
        np.save(results_dir_aug + sub_name + '/' + slice_id + '.npy',
                vessels_pred_multi_aug)

In [12]:
# Consensus probability maps for the contrast-enhanced T1w slices of the test
# subjects, computed with both models and saved as one .npy file per slice.
ds_dir = '../data/'
results_dir_simple = '../res/maps/T1wCE/plaq-u-net_simple/'
results_dir_aug = '../res/maps/T1wCE/plaq-u-net_aug/'

for sub_name in sub_names_test:

    # np.save does not create missing folders
    os.makedirs(results_dir_simple + sub_name, exist_ok=True)
    os.makedirs(results_dir_aug + sub_name, exist_ok=True)

    sub_img_names = glob.glob(ds_dir+sub_name+'*/T1W-contrast_*.dcm')

    for sub_img_name in sub_img_names:

        img = path2array(sub_img_name)
        # drop an 8-pixel border; float avoids integer overflow in the range computation
        img_test = img[8:-8,8:-8].astype(np.float64)
        img_min = np.min(img_test)
        img_max = np.max(img_test)
        # min-max scaling to 0..255; a constant slice would divide by zero, so map it to zeros
        if img_max > img_min:
            img_norm = ((img_test - img_min)/(img_max - img_min)*255).astype(np.uint8)
        else:
            img_norm = np.zeros(img_test.shape, dtype=np.uint8)

        vessels_pred_multi_simple = detect4multipatches(img_norm, model_simple)
        vessels_pred_multi_aug = detect4multipatches(img_norm, model_aug)

        # Slice ID = the 6 characters before '.dcm'. basename() is used because
        # indexing split(os.sep) selects a different path component on POSIX and Windows.
        slice_id = os.path.basename(sub_img_name)[-10:-4]

        np.save(results_dir_simple + sub_name + '/' + slice_id + '.npy',
                vessels_pred_multi_simple)
        np.save(results_dir_aug + sub_name + '/' + slice_id + '.npy',
                vessels_pred_multi_aug)

  from ipykernel import kernelapp as app


In [13]:
# Consensus probability maps for the T2w slices of the test subjects,
# computed with both models and saved as one .npy file per slice.
ds_dir = '../data/'
results_dir_simple = '../res/maps/T2w/plaq-u-net_simple/'
results_dir_aug = '../res/maps/T2w/plaq-u-net_aug/'

for sub_name in sub_names_test:

    # np.save does not create missing folders
    os.makedirs(results_dir_simple + sub_name, exist_ok=True)
    os.makedirs(results_dir_aug + sub_name, exist_ok=True)

    sub_img_names = glob.glob(ds_dir+sub_name+'*/T2W_*.dcm')

    for sub_img_name in sub_img_names:

        img = path2array(sub_img_name)
        # drop an 8-pixel border; float avoids integer overflow in the range computation
        img_test = img[8:-8,8:-8].astype(np.float64)
        img_min = np.min(img_test)
        img_max = np.max(img_test)
        # min-max scaling to 0..255; a constant slice would divide by zero, so map it to zeros
        if img_max > img_min:
            img_norm = ((img_test - img_min)/(img_max - img_min)*255).astype(np.uint8)
        else:
            img_norm = np.zeros(img_test.shape, dtype=np.uint8)

        vessels_pred_multi_simple = detect4multipatches(img_norm, model_simple)
        vessels_pred_multi_aug = detect4multipatches(img_norm, model_aug)

        # Slice ID = the 6 characters before '.dcm'. basename() is used because
        # indexing split(os.sep) selects a different path component on POSIX and Windows.
        slice_id = os.path.basename(sub_img_name)[-10:-4]

        np.save(results_dir_simple + sub_name + '/' + slice_id + '.npy',
                vessels_pred_multi_simple)
        np.save(results_dir_aug + sub_name + '/' + slice_id + '.npy',
                vessels_pred_multi_aug)

  from ipykernel import kernelapp as app


### 2.2 EMC

Patient names from EMC:

In [13]:
# Subject IDs of the EMC cohort; each ID is used below as a folder-name prefix
# when globbing for DICOM files and as the sub-folder name for the saved maps.
sub_names_emc = ['EMC003', 'EMC004', 'EMC005', 'EMC007', 'EMC008', 'EMC009', 'EMC011', 
                 'EMC015', 'EMC018', 'EMC020', 'EMC024', 'EMC027', 'EMC029', 'EMC031', 
                 'EMC032', 'EMC034', 'EMC035', 'EMC036', 'EMC038', 'EMC041', 'EMC042', 
                 'EMC043', 'EMC045', 'EMC046', 'EMC047', 'EMC048', 'EMC049', 'EMC050', 
                 'EMC051', 'EMC052', 'EMC054', 'EMC055', 'EMC056', 'EMC057']

Calculating CA probability maps and saving to the results folder:

In [14]:
# Consensus probability maps for the T1w slices of the EMC subjects. Slices are
# resampled to 512x512 before patching and the maps are resampled to 256x256
# before being saved as one .npy file per slice.
ds_dir = '../data/'
results_dir_simple = '../res/maps/plaq-u-net_simple/'
results_dir_aug = '../res/maps/plaq-u-net_aug/'

for sub_name in sub_names_emc:

    # np.save does not create missing folders
    os.makedirs(results_dir_simple + sub_name, exist_ok=True)
    os.makedirs(results_dir_aug + sub_name, exist_ok=True)

    sub_img_names = glob.glob(ds_dir+sub_name+'*/T1W_*.dcm')

    for sub_img_name in sub_img_names:

        img = path2array(sub_img_name)
        # Bias field correction is currently disabled: img_test = correctBiasField(img)
        # float avoids integer overflow in the range computation
        img_test = img.astype(np.float64)
        img_min = np.min(img_test)
        img_max = np.max(img_test)
        # min-max scaling to 0..255; a constant slice would divide by zero, so map it to zeros
        if img_max > img_min:
            img_norm = ((img_test - img_min)/(img_max - img_min)*255).astype(np.uint8)
        else:
            img_norm = np.zeros(img_test.shape, dtype=np.uint8)

        img_res = cv2.resize(img_norm.copy(), dsize=(512, 512), interpolation=cv2.INTER_CUBIC)

        vessels_pred_multi_simple = detect4multipatches(img_res, model_simple)
        vessels_pred_multi_aug = detect4multipatches(img_res, model_aug)

        # NOTE(review): the 256x256 target presumably equals the native slice size — confirm.
        vessels_pred_multi_simple_res = cv2.resize(vessels_pred_multi_simple.copy(),
                                                   dsize=(256, 256), interpolation=cv2.INTER_CUBIC)
        vessels_pred_multi_aug_res = cv2.resize(vessels_pred_multi_aug.copy(),
                                                dsize=(256, 256), interpolation=cv2.INTER_CUBIC)

        # Slice ID = characters [-17:-11] of the file name. basename() is used because
        # indexing split(os.sep) selects a different path component on POSIX and Windows.
        slice_id = os.path.basename(sub_img_name)[-17:-11]

        np.save(results_dir_simple + sub_name + '/' + slice_id + '.npy',
                vessels_pred_multi_simple_res)
        np.save(results_dir_aug + sub_name + '/' + slice_id + '.npy',
                vessels_pred_multi_aug_res)

BFC failed
BFC failed
BFC failed
BFC failed
BFC failed
BFC failed
BFC failed
BFC failed
BFC failed
BFC failed
