# plaq-u-net: multi-patch consensus U-Net for automated detection and segmentation of the carotid arteries on black-blood MRI sequences

E. Lavrova, 2022

This notebook contains the code supporting the corresponding paper.

Package imports:

In [1]:
import os
import numpy as np

from xml.etree import ElementTree
from numpy import zeros
from numpy import asarray
from numpy import expand_dims
from numpy import mean
import pydicom
import random
import matplotlib.pyplot as plt
import glob
import matplotlib.patches as patches

import cv2
import SimpleITK as sitk

from skimage import exposure
from skimage import img_as_float
from skimage.io import imread, imshow, concatenate_images
from skimage.transform import resize
from skimage.morphology import label
from skimage import morphology
from skimage.filters import threshold_otsu, threshold_local

from keras.models import Model, load_model
from keras.layers import Input, BatchNormalization, Activation, Dense, Dropout
from keras.layers.core import Lambda, RepeatVector, Reshape
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D, GlobalMaxPool2D
from keras.layers.merge import concatenate, add
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img

import pandas as pd

plt.style.use("ggplot")

from tqdm import tqdm_notebook, tnrange
from itertools import chain
from scipy.ndimage import zoom

Using TensorFlow backend.


In [2]:
# Expose a single physical GPU (index 1 in PCI bus order) to CUDA; the output below shows
# it listed as GPU:0.
# NOTE(review): Keras (and hence TensorFlow) is already imported in the first cell; these
# variables only take effect if set before TensorFlow initialises its GPU devices — confirm.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": '1',
})
import tensorflow as tf
import keras.backend as K

# Display the GPUs visible to the Keras TensorFlow backend (private helper).
K.tensorflow_backend._get_available_gpus()

W1115 20:39:58.012665 48480 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.

W1115 20:39:58.013662 48480 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:181: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.

W1115 20:39:58.015666 48480 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:186: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.

W1115 20:40:00.620081 48480 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:190: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.



['/job:localhost/replica:0/task:0/device:GPU:0']

In [3]:
import nibabel as nib
import pickle

## 1. Loading the models

Defining model blocks:

In [4]:
def conv2d_block(input_tensor, n_filters, kernel_size = 3, batchnorm = True, chain_convs = False):
    """Add two (Conv2D -> optional BatchNormalization -> ELU) stages to a Keras graph.

    Args:
        input_tensor: Keras tensor the block is applied to.
        n_filters: number of filters of both Conv2D layers.
        kernel_size: side of the square convolution kernel.
        batchnorm: if True, a BatchNormalization layer follows each Conv2D.
        chain_convs: if True, the second stage consumes the output of the first stage
            (a true double-convolution block).
            The default False reproduces the historical wiring: BOTH stages read
            ``input_tensor`` and the first stage's output is discarded, so only the
            second stage reaches the returned tensor. The layers of the unused first
            stage are still created, which keeps automatic layer names unchanged.
            The weight files loaded in this notebook load into that wiring. Enabling
            ``chain_convs`` adds the first stage's weights to the model, so those files
            presumably will not load any more.

    Returns:
        Output tensor of the second ELU activation.
    """
    def _conv_bn_elu(tensor):
        # One Conv2D -> [BatchNormalization] -> ELU stage, same layer order as before.
        out = Conv2D(filters = n_filters, kernel_size = (kernel_size, kernel_size),
                     kernel_initializer = 'he_normal', padding = 'same')(tensor)
        if batchnorm:
            out = BatchNormalization()(out)
        return Activation('elu')(out)

    first = _conv_bn_elu(input_tensor)
    # Historical behaviour (chain_convs=False) feeds the block input to the second stage.
    return _conv_bn_elu(first if chain_convs else input_tensor)

def get_unet(input_img, n_filters = 8, dropout = 0.10, batchnorm = True, mc_dropout = True):
    """Build a 4-level U-Net mapping ``input_img`` to a 1-channel sigmoid map.

    Contracting path: 4 x (conv2d_block -> 2x2 max-pool -> Dropout), with
    n_filters * 1, 2, 4, 8 filters. A bottleneck conv2d_block with n_filters * 16 follows.

    Expansive path: 4 x (stride-2 Conv2DTranspose -> concatenate with the matching
    contracting-path feature map -> Dropout -> conv2d_block), mirroring the filter counts.
    A final 1x1 Conv2D with sigmoid activation produces the output.

    Args:
        input_img: Keras ``Input`` tensor. Its height and width should be divisible by 16
            (four 2x2 poolings) so the skip connections can be concatenated.
        n_filters: number of filters at the first level (doubled at each level).
        dropout: rate used by every Dropout layer.
        batchnorm: forwarded to ``conv2d_block``.
        mc_dropout: if True (default, historical behaviour), every Dropout layer is called
            with ``training=True``. Dropout is then also applied in ``predict``, and
            repeated predictions differ (Monte Carlo dropout). If False, dropout follows
            the Keras learning phase, i.e. it is disabled at inference.

    Returns:
        An uncompiled Keras ``Model`` with inputs ``[input_img]`` and one sigmoid output.
    """
    # None lets Keras use the learning phase; True forces dropout on at inference.
    dropout_training = True if mc_dropout else None
    level_multipliers = (1, 2, 4, 8)

    # Contracting path. Layers are created in the same order as the unrolled original,
    # so automatic layer names and the weight-loading order are unchanged.
    skips = []
    x = input_img
    for mult in level_multipliers:
        features = conv2d_block(x, n_filters * mult, kernel_size = 3, batchnorm = batchnorm)
        skips.append(features)
        x = MaxPooling2D((2, 2))(features)
        x = Dropout(dropout)(x, training=dropout_training)

    # Bottleneck
    x = conv2d_block(x, n_filters = n_filters * 16, kernel_size = 3, batchnorm = batchnorm)

    # Expansive path, deepest level first
    for mult, skip in zip(reversed(level_multipliers), reversed(skips)):
        x = Conv2DTranspose(n_filters * mult, (3, 3), strides = (2, 2), padding = 'same')(x)
        x = concatenate([x, skip])
        x = Dropout(dropout)(x, training=dropout_training)
        x = conv2d_block(x, n_filters * mult, kernel_size = 3, batchnorm = batchnorm)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    return Model(inputs=[input_img], outputs=[outputs])

Building the models and loading their trained weights:

In [5]:
# The networks take 64x64 patches with 2 channels: the normalised image slice and the
# 255-scaled nnU-Net probability, as assembled by fill_array / detect4multipatches_0.
im_height = im_width = 64

input_img = Input((im_height, im_width, 2), name='img')

# Two networks with identical architecture, built on the shared input tensor. They
# differ only in the weight file they load.
plaqunet_simple, plaqunet_aug = [
    get_unet(input_img, n_filters=16, dropout=0.05, batchnorm=True) for _ in range(2)
]
plaqunet_simple.load_weights('../res/plaq-u-net_simple_dce_2.h5')
plaqunet_aug.load_weights('../res/plaq-u-net_aug_dce_2.h5')

W1115 20:40:00.871916 48480 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W1115 20:40:00.875916 48480 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W1115 20:40:00.876917 48480 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:4185: The name tf.truncated_normal is deprecated. Please use tf.random.truncated_normal instead.

W1115 20:40:00.903922 48480 module_wrapper.py:139] From C:\ProgramData\Anaconda3\envs\segway\lib\site-packages\keras\backend\tensorflow_backend.py:1834: The name tf.nn.fused_batch_norm is deprecated. Please use tf.compat.v1.nn.fused_batch_norm instead.

W1115 20:40:00.9

In [6]:
# Second pair of networks: same architecture, weights from the "*_ul_4.h5" files.
# plaquncertaintynet_aug is used below for the repeated (dropout) predictions.
plaquncertaintynet_simple, plaquncertaintynet_aug = [
    get_unet(input_img, n_filters=16, dropout=0.05, batchnorm=True) for _ in range(2)
]
plaquncertaintynet_simple.load_weights('../res/plaq-u-net_simple_ul_4.h5')
plaquncertaintynet_aug.load_weights('../res/plaq-u-net_aug_ul_4.h5')

In [7]:
# Slice size in pixels: rows (IMG_HEIGHT) x columns (IMG_WIDTH).
IMG_HEIGHT, IMG_WIDTH = 512, 256
# Side of the square network patch (presumably; matches the 64x64 model input).
IMG_SIDE = 64

In [8]:
def norm_img(img_arr):
    """Min-max rescale an array to the uint8 range [0, 255].

    The minimum maps to 0 and the maximum to 255. The uint8 cast truncates rather than
    rounds.

    Args:
        img_arr: numeric array of any shape (here, a 2D image slice).

    Returns:
        A new uint8 array with the same shape as ``img_arr``. A constant input
        (max == min, e.g. an all-zero slice) yields all zeros. Previously it divided by
        zero and cast NaN to uint8.
    """
    img_min = np.min(img_arr)
    img_range = np.max(img_arr) - img_min
    if img_range == 0:
        return np.zeros(np.shape(img_arr), dtype=np.uint8)
    # astype already returns a new array, so no extra copy is needed.
    return ((img_arr - img_min) / img_range * 255).astype(np.uint8)

In [9]:
def fill_arrays(X, y, sub_names, dirname_img, dirname_nnunet, dirname_gt):
    """Write the slices of several subjects into pre-allocated arrays, in place.

    Subjects are processed in the order of ``sub_names``. Their slices (first axis of the
    transposed volumes) go to consecutive rows of ``X``/``y``, starting at row 0:
        X[row, ..., 0] = norm_img(image slice)
        X[row, ..., 1] = 255 * channel 1 of the nnU-Net softmax, pasted back into a
                         full-size (otherwise zero) volume at the stored crop box
        y[row, ..., 0] = ground-truth slice

    Args:
        X, y: pre-allocated arrays with at least as many rows as there are slices in
            total; modified in place.
        sub_names: iterable of subject identifiers used to build the file names.
        dirname_img, dirname_nnunet, dirname_gt: directory prefixes. File names are
            appended by plain string concatenation, so each must end with a separator.

    Returns:
        None.
    """
    row = 0
    for sub_name in sub_names:
        img = nib.load(dirname_img + sub_name + '_0000.nii.gz').get_fdata().T
        gt = nib.load(dirname_gt + sub_name + '.nii.gz').get_fdata().T
        softmax_cropped = np.load(dirname_nnunet + sub_name + '.npz')['softmax']
        # NOTE: pickle.load can execute arbitrary code; only use trusted nnU-Net outputs.
        with open(dirname_nnunet + sub_name + '.pkl', 'rb') as pkl_file:
            crop_box = pickle.load(pkl_file)['crop_bbox']

        softmax = np.zeros(img.shape, dtype=np.float16)
        crop_region = tuple(slice(crop_box[axis][0], crop_box[axis][1]) for axis in range(3))
        softmax[crop_region] = softmax_cropped[1, ...]

        for slice_idx in range(img.shape[0]):
            X[row, ..., 0] = norm_img(img[slice_idx, ...])
            X[row, ..., 1] = 255 * softmax[slice_idx, ...]
            y[row, ..., 0] = gt[slice_idx, ...]
            row += 1

    return None
    

In [10]:
def detect4multipatches_0(img, model, patch_size=64):
    """Segment one slice by running ``model`` on a single patch around the nnU-Net detection.

    Channel 1 of ``img`` is treated as the 0-255 scaled nnU-Net probability map. Its
    non-zero pixels are split into connected components (skimage ``label``). Each
    component is weighted by the sum of its channel-1 values; ties go to the lowest
    label. A ``patch_size`` x ``patch_size`` patch of both channels is cut around the
    (truncated) centroid of the heaviest component and passed to the model. The
    prediction is written back at the same location, and every other pixel is 0.

    The slice is zero-padded by ``patch_size`` on every side, so the patch never leaves
    the array, even for detections at the border.

    Args:
        img: (H, W, 2) array, cast to uint8. Channel 0 is the normalised image and
            channel 1 the scaled probability, as produced by ``fill_array``. Any H, W is
            accepted (previously it had to be IMG_HEIGHT x IMG_WIDTH).
        model: object with a Keras-style ``predict(batch, verbose=0)``. It takes a
            (1, patch_size, patch_size, 2) batch and returns
            (1, patch_size, patch_size, >=1); channel 0 is used.
        patch_size: side of the square patch. It must match the model input (64 here).

    Returns:
        (H, W) float64 array holding the patch prediction and zeros elsewhere. The array
        is all zeros when channel 1 has no non-zero pixel; the model is not called then.
    """
    pad = patch_size
    half = patch_size // 2
    height, width = img.shape[0], img.shape[1]

    img_padded = np.zeros((height + 2 * pad, width + 2 * pad, 2), dtype=np.uint8)
    img_padded[pad:-pad, pad:-pad, :] = img
    pred_padded = np.zeros((img_padded.shape[0], img_padded.shape[1]))

    sm = img_padded[..., 1]
    sm_bin_label = label((sm > 0).astype(np.uint8))
    n_components = int(np.max(sm_bin_label))
    if n_components > 0:
        # Sum of probability values per component label (index 0 = background, weight 0).
        component_weights = np.bincount(sm_bin_label.ravel(), weights=sm.ravel(),
                                        minlength=n_components + 1)
        best_label = 1 + int(np.argmax(component_weights[1:]))
        rows, cols = np.where(sm_bin_label == best_label)
        top = int(np.mean(rows)) - half
        left = int(np.mean(cols)) - half

        img_patch = np.zeros((1, patch_size, patch_size, 2))
        img_patch[0, ...] = img_padded[top:top + patch_size, left:left + patch_size, :]
        img_patch_pred = model.predict(img_patch, verbose=0)
        pred_padded[top:top + patch_size, left:left + patch_size] = img_patch_pred[0, ..., 0]

    return pred_padded[pad:-pad, pad:-pad]

## 2. Calculation of carotid artery (CA) probability maps

In [11]:
def fill_array(sub_name, dirname_img, dirname_nnunet):
    """Load one subject and stack its slices into a 2-channel uint8 array.

    Args:
        sub_name: subject identifier used to build the file names.
        dirname_img, dirname_nnunet: directory prefixes. File names are appended by plain
            string concatenation, so each must end with a separator.

    Returns:
        (X, affine):
            X: uint8 array of shape (n_slices, H, W, 2), taken from the transposed image
                volume. X[k, ..., 0] is norm_img(slice k). X[k, ..., 1] is 255 * channel 1
                of the nnU-Net softmax, pasted back into a full-size (otherwise zero)
                volume at the stored crop box.
            affine: affine of the input NIfTI image.
    """
    img_nii = nib.load(dirname_img + sub_name + '_0000.nii.gz')
    img = img_nii.get_fdata().T
    softmax_cropped = np.load(dirname_nnunet + sub_name + '.npz')['softmax']
    # NOTE: pickle.load can execute arbitrary code; only use trusted nnU-Net outputs.
    with open(dirname_nnunet + sub_name + '.pkl', 'rb') as pkl_file:
        crop_box = pickle.load(pkl_file)['crop_bbox']

    softmax = np.zeros(img.shape, dtype=np.float16)
    crop_region = tuple(slice(crop_box[axis][0], crop_box[axis][1]) for axis in range(3))
    softmax[crop_region] = softmax_cropped[1, ...]

    X = np.zeros((img.shape[0], img.shape[1], img.shape[2], 2), dtype = np.uint8)
    for k in range(img.shape[0]):
        X[k, ..., 0] = norm_img(img[k, ...])
        X[k, ..., 1] = 255 * softmax[k, ...]

    return X, img_nii.affine

In [12]:
# Subjects processed in the Ts1/Ts2/Ts3 cells below, grouped by ID prefix.
sub_names_test = (
    'AMC012 AMC006 '
    'MUMC094 MUMC027 MUMC079 MUMC052 MUMC127 MUMC071 MUMC038 MUMC093 MUMC107 '
    'MUMC022 MUMC114 MUMC115 MUMC069 MUMC130 MUMC036 MUMC007 MUMC059 MUMC080 '
    'UMCU036 UMCU025 UMCU008 UMCU034'
).split()

In [13]:
# Subjects processed in the Ts4 cells below: 'EMC' followed by a zero-padded 3-digit number.
sub_names_emc = ['EMC%03d' % number for number in (
    3, 4, 5, 7, 8, 9, 11,
    15, 18, 20, 24, 27, 29, 31,
    32, 34, 35, 36, 38, 41, 42,
    43, 45, 46, 47, 48, 49, 50,
    51, 52, 54, 55, 56, 57,
)]

Probability maps and binary masks (threshold 0.5) for each test set:

In [18]:
dirname_imgdata_test = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/imagesTs1/'
dirname_gtdata_test = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/labelsTs1/'
dirname_nnunetdata_test = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_results/Ts1/'
dirname_results_test = '../res/nifti_compare/test_plaqunet_epochs/'
dirname_results_test_p = '../res/nifti_compare/test_plaqunet_epochs_p/'

# nib.save fails if the output folder does not exist: create the folders up front so the
# cell also runs on a fresh checkout (no-op when they already exist).
os.makedirs(dirname_results_test, exist_ok=True)
os.makedirs(dirname_results_test_p, exist_ok=True)

# For each subject: predict every slice with plaqunet_aug, then save the probability map
# (float32) and its 0.5-thresholded mask (uint8) with the input image's affine.
# .T undoes the transpose applied when the volume is loaded in fill_array.
# NOTE: get_unet builds its Dropout layers with training=True, so dropout is also active
# here and re-running this cell gives slightly different maps.
for sub_name in sub_names_test:
    X, affine_nii = fill_array(sub_name, dirname_imgdata_test, dirname_nnunetdata_test)
    pred = np.zeros((X.shape[0], X.shape[1], X.shape[2]), dtype=np.float32)
    for i in range(X.shape[0]):
        pred_slice = detect4multipatches_0(X[i, ...], plaqunet_aug)
        pred[i, ...] = pred_slice.astype(np.float32)
    nifti_pred = nib.Nifti1Image((pred > 0.5).astype(np.uint8).T, affine=affine_nii)
    nifti_pred_p = nib.Nifti1Image(pred.T, affine=affine_nii)
    nib.save(nifti_pred, dirname_results_test + sub_name + '.nii.gz')
    nib.save(nifti_pred_p, dirname_results_test_p + sub_name + '.nii.gz')

In [19]:
dirname_imgdata_t2w = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/imagesTs2/'
dirname_gtdata_t2w = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/labelsTs2/'
dirname_nnunetdata_t2w = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_results/Ts2/'
dirname_results_t2w = '../res/nifti_compare/t2w_plaqunet_epochs/'
dirname_results_t2w_p = '../res/nifti_compare/t2w_plaqunet_epochs_p/'

# nib.save fails if the output folder does not exist: create the folders up front so the
# cell also runs on a fresh checkout (no-op when they already exist).
os.makedirs(dirname_results_t2w, exist_ok=True)
os.makedirs(dirname_results_t2w_p, exist_ok=True)

# Same processing as the Ts1 cell, applied to the Ts2 data of the same subjects.
# NOTE: get_unet builds its Dropout layers with training=True, so dropout is also active
# here and re-running this cell gives slightly different maps.
for sub_name in sub_names_test:
    X, affine_nii = fill_array(sub_name, dirname_imgdata_t2w, dirname_nnunetdata_t2w)
    pred = np.zeros((X.shape[0], X.shape[1], X.shape[2]), dtype=np.float32)
    for i in range(X.shape[0]):
        pred_slice = detect4multipatches_0(X[i, ...], plaqunet_aug)
        pred[i, ...] = pred_slice.astype(np.float32)
    nifti_pred = nib.Nifti1Image((pred > 0.5).astype(np.uint8).T, affine=affine_nii)
    nifti_pred_p = nib.Nifti1Image(pred.T, affine=affine_nii)
    nib.save(nifti_pred, dirname_results_t2w + sub_name + '.nii.gz')
    nib.save(nifti_pred_p, dirname_results_t2w_p + sub_name + '.nii.gz')

  


In [20]:
dirname_imgdata_t1wce = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/imagesTs3/'
dirname_gtdata_t1wce = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/labelsTs3/'
dirname_nnunetdata_t1wce = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_results/Ts3/'
dirname_results_t1wce = '../res/nifti_compare/t1wce_plaqunet_epochs/'
dirname_results_t1wce_p = '../res/nifti_compare/t1wce_plaqunet_epochs_p/'

# nib.save fails if the output folder does not exist: create the folders up front so the
# cell also runs on a fresh checkout (no-op when they already exist).
os.makedirs(dirname_results_t1wce, exist_ok=True)
os.makedirs(dirname_results_t1wce_p, exist_ok=True)

# Same processing as the Ts1 cell, applied to the Ts3 data of the same subjects.
# NOTE: get_unet builds its Dropout layers with training=True, so dropout is also active
# here and re-running this cell gives slightly different maps.
for sub_name in sub_names_test:
    X, affine_nii = fill_array(sub_name, dirname_imgdata_t1wce, dirname_nnunetdata_t1wce)
    pred = np.zeros((X.shape[0], X.shape[1], X.shape[2]), dtype=np.float32)
    for i in range(X.shape[0]):
        pred_slice = detect4multipatches_0(X[i, ...], plaqunet_aug)
        pred[i, ...] = pred_slice.astype(np.float32)
    nifti_pred = nib.Nifti1Image((pred > 0.5).astype(np.uint8).T, affine=affine_nii)
    nifti_pred_p = nib.Nifti1Image(pred.T, affine=affine_nii)
    nib.save(nifti_pred, dirname_results_t1wce + sub_name + '.nii.gz')
    nib.save(nifti_pred_p, dirname_results_t1wce_p + sub_name + '.nii.gz')

  


In [21]:
dirname_imgdata_emc = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/imagesTs4/'
dirname_gtdata_emc = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/labelsTs4/'
dirname_nnunetdata_emc = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_results/Ts4/'
dirname_results_emc = '../res/nifti_compare/emc_plaqunet_epochs/'
dirname_results_emc_p = '../res/nifti_compare/emc_plaqunet_epochs_p/'

# nib.save fails if the output folder does not exist: create the folders up front so the
# cell also runs on a fresh checkout (no-op when they already exist).
os.makedirs(dirname_results_emc, exist_ok=True)
os.makedirs(dirname_results_emc_p, exist_ok=True)

# Same processing as the Ts1 cell, applied to the Ts4 (EMC) subjects.
# NOTE: get_unet builds its Dropout layers with training=True, so dropout is also active
# here and re-running this cell gives slightly different maps.
for sub_name in sub_names_emc:
    X, affine_nii = fill_array(sub_name, dirname_imgdata_emc, dirname_nnunetdata_emc)
    pred = np.zeros((X.shape[0], X.shape[1], X.shape[2]), dtype=np.float32)
    for i in range(X.shape[0]):
        pred_slice = detect4multipatches_0(X[i, ...], plaqunet_aug)
        pred[i, ...] = pred_slice.astype(np.float32)
    nifti_pred = nib.Nifti1Image((pred > 0.5).astype(np.uint8).T, affine=affine_nii)
    nifti_pred_p = nib.Nifti1Image(pred.T, affine=affine_nii)
    nib.save(nifti_pred, dirname_results_emc + sub_name + '.nii.gz')
    nib.save(nifti_pred_p, dirname_results_emc_p + sub_name + '.nii.gz')

Monte Carlo dropout: 10 stochastic predictions per subject (dropout kept active at inference):

In [34]:
# Graph-level TensorFlow seed for the repeated dropout predictions below.
# NOTE(review): in TF1 graph mode this seed only affects random ops created after this
# call. The Dropout ops were built earlier, together with the models, so this probably
# does not make the samples below reproducible. TODO confirm; seeding before the models
# are built would be needed.
tf.compat.v1.set_random_seed(seed=0)

In [35]:
dirname_imgdata_test = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/imagesTs1/'
dirname_gtdata_test = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/labelsTs1/'
dirname_nnunetdata_test = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_results/Ts1/'
dirname_results_test = '../res/nifti_compare/test_plaqumap_dropout/'

# nib.save fails if the output folder does not exist: create it up front (no-op if present).
os.makedirs(dirname_results_test, exist_ok=True)

# get_unet builds its Dropout layers with training=True, so each of the 10 passes over a
# subject gives a different probability map. Each map is saved as <subject>_<j>.nii.gz.
for sub_name in sub_names_test:
    X, affine_nii = fill_array(sub_name, dirname_imgdata_test, dirname_nnunetdata_test)
    for j in range(10):
        pred = np.zeros((X.shape[0], X.shape[1], X.shape[2]), dtype=np.float32)
        for i in range(X.shape[0]):
            pred_slice = detect4multipatches_0(X[i, ...], plaquncertaintynet_aug)
            pred[i, ...] = pred_slice.astype(np.float32)
        nifti_pred = nib.Nifti1Image(pred.T, affine=affine_nii)
        nib.save(nifti_pred, dirname_results_test + sub_name + '_' + str(j) + '.nii.gz')

In [40]:
# NOTE: this cell reuses (overwrites) the *_test variable names for the Ts2 (T2w) data.
dirname_imgdata_test = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/imagesTs2/'
dirname_gtdata_test = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/labelsTs2/'
dirname_nnunetdata_test = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_results/Ts2/'
dirname_results_test = '../res/nifti_compare/t2w_plaqumap_dropout/'

# nib.save fails if the output folder does not exist: create it up front (no-op if present).
os.makedirs(dirname_results_test, exist_ok=True)

# get_unet builds its Dropout layers with training=True, so each of the 10 passes over a
# subject gives a different probability map. Each map is saved as <subject>_<j>.nii.gz.
for sub_name in sub_names_test:
    X, affine_nii = fill_array(sub_name, dirname_imgdata_test, dirname_nnunetdata_test)
    for j in range(10):
        pred = np.zeros((X.shape[0], X.shape[1], X.shape[2]), dtype=np.float32)
        for i in range(X.shape[0]):
            pred_slice = detect4multipatches_0(X[i, ...], plaquncertaintynet_aug)
            pred[i, ...] = pred_slice.astype(np.float32)
        nifti_pred = nib.Nifti1Image(pred.T, affine=affine_nii)
        nib.save(nifti_pred, dirname_results_test + sub_name + '_' + str(j) + '.nii.gz')

  


In [41]:
# NOTE: this cell reuses (overwrites) the *_test variable names for the Ts3 (T1w CE) data.
dirname_imgdata_test = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/imagesTs3/'
dirname_gtdata_test = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/labelsTs3/'
dirname_nnunetdata_test = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_results/Ts3/'
dirname_results_test = '../res/nifti_compare/t1wce_plaqumap_dropout/'

# nib.save fails if the output folder does not exist: create it up front (no-op if present).
os.makedirs(dirname_results_test, exist_ok=True)

# get_unet builds its Dropout layers with training=True, so each of the 10 passes over a
# subject gives a different probability map. Each map is saved as <subject>_<j>.nii.gz.
for sub_name in sub_names_test:
    X, affine_nii = fill_array(sub_name, dirname_imgdata_test, dirname_nnunetdata_test)
    for j in range(10):
        pred = np.zeros((X.shape[0], X.shape[1], X.shape[2]), dtype=np.float32)
        for i in range(X.shape[0]):
            pred_slice = detect4multipatches_0(X[i, ...], plaquncertaintynet_aug)
            pred[i, ...] = pred_slice.astype(np.float32)
        nifti_pred = nib.Nifti1Image(pred.T, affine=affine_nii)
        nib.save(nifti_pred, dirname_results_test + sub_name + '_' + str(j) + '.nii.gz')

  


In [43]:
# NOTE: this cell reuses (overwrites) the *_test variable names for the Ts4 (EMC) data.
dirname_imgdata_test = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/imagesTs4/'
dirname_gtdata_test = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_raw_data/Task001_CA/labelsTs4/'
dirname_nnunetdata_test = 'D:/Lisa/nnUNet_raw_data_base/nnUNet_results/Ts4/'
dirname_results_test = '../res/nifti_compare/emc_plaqumap_dropout/'

# nib.save fails if the output folder does not exist: create it up front (no-op if present).
os.makedirs(dirname_results_test, exist_ok=True)

# get_unet builds its Dropout layers with training=True, so each of the 10 passes over a
# subject gives a different probability map. Each map is saved as <subject>_<j>.nii.gz.
for sub_name in sub_names_emc:
    X, affine_nii = fill_array(sub_name, dirname_imgdata_test, dirname_nnunetdata_test)
    for j in range(10):
        pred = np.zeros((X.shape[0], X.shape[1], X.shape[2]), dtype=np.float32)
        for i in range(X.shape[0]):
            pred_slice = detect4multipatches_0(X[i, ...], plaquncertaintynet_aug)
            pred[i, ...] = pred_slice.astype(np.float32)
        nifti_pred = nib.Nifti1Image(pred.T, affine=affine_nii)
        nib.save(nifti_pred, dirname_results_test + sub_name + '_' + str(j) + '.nii.gz')