In [1]:
import numpy as np
import os
from keras.models import load_model
from PIL import Image
import time
from model_residual_selu_deeplysup import Tversky, Tversky_loss
import matplotlib.pyplot as plt
import ipywidgets as widgets
import copy
import matplotlib.cm as cm

Using TensorFlow backend.


In [2]:
#helper function to evaluate the segmentation accuracy
def dice(y_true, y_pred, smooth=0.0001):
    """
    Dice evaluation of segmentation accuracy: 2*|A & B| / (|A| + |B|).

    Both arrays are flattened completely, so passing a whole stack of slices
    scores it as one pooled volume rather than averaging per-slice scores.

    y_true -- numpy array of ground truth image (presumably a binary 0/1 or bool mask)
    y_pred -- numpy array of predicted image with the same number of elements as y_true
    smooth -- small constant added to numerator and denominator; avoids division
              by zero and makes two empty masks score 1.0
    Returns the Dice coefficient as a float.
    """
    # Cast to float64 before multiplying: with small integer dtypes such as uint8
    # (common for TIFF masks) the element-wise product could otherwise wrap around.
    y_true_f = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred_f = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth)

### Load data and model

In [3]:
# Example CT slices: the "Images" sub-folder holds the preprocessed CT slices and the
# "Ground_truth" sub-folder holds the manually segmented masks of the cartilaginous nasal capsule.
path_to_data = "../data/Sample_slices/"
# HDF5 (.h5) file containing a trained segmentation model
path_to_model = "../data/Trained_models/val0_residual_selu_deeplysup.h5"

In [4]:
#initialize folder to store the prediction results
results_dir = "../results/Predicted_masks"
if os.path.isdir(results_dir):
    print("Folder already exists.")
# exist_ok=True tolerates an existing folder and also creates a missing ../results parent;
# real failures (e.g. permission denied, a file with that name) now raise instead of being hidden.
os.makedirs(results_dir, exist_ok=True)
# This handle stays open on purpose: the "save the predictions" cell writes to it and closes it,
# so this cell must be re-run before that cell is run again.
file = open(results_dir + "/results.txt", "w")

Folder already exists.


In [5]:
# Load the trained network. The Tversky metric and loss are project functions, not Keras
# built-ins, so they are handed to load_model by the names stored in the .h5 file.
tversky_objects = {'Tversky': Tversky, 'Tversky_loss': Tversky_loss}
model = load_model(path_to_model, custom_objects=tversky_objects)

In [6]:
#load the paths to images
def _list_tifs(subfolder):
    """Return the sorted paths of all files ending in ".tif" inside path_to_data/subfolder."""
    folder = os.path.join(path_to_data, subfolder).replace("\\", "/")
    return sorted(os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".tif"))

paths_to_images_test = _list_tifs("Images")
paths_to_masks_test = _list_tifs("Ground_truth")
# Images and masks are paired purely by sorted position, so both folders must hold the same
# number of slices (and presumably file names that sort into the same order).
assert len(paths_to_images_test) == len(paths_to_masks_test), (
    "Found %d images but %d ground-truth masks" % (len(paths_to_images_test), len(paths_to_masks_test)))

In [7]:
#load the image data
def _load_stack(paths):
    """Read every image in `paths` and stack them along a new first axis plus a trailing channel axis.

    Assumes each file is a single-channel 2-D slice of identical size (TODO confirm),
    which yields an array of shape (n_slices, height, width, 1).
    """
    slices = []
    for p in paths:
        # The context manager releases the file handle PIL keeps open after opening the image.
        with Image.open(p) as img:
            slices.append(np.array(img))
    return np.expand_dims(np.stack(slices, axis = 0), axis = -1)

x_test = _load_stack(paths_to_images_test)
# Masks are read from their own path list, so a count mismatch is not silently truncated.
y_test = _load_stack(paths_to_masks_test)
# binarize the manual segmentation: every non-zero label becomes 1
y_test[y_test > 0] = 1

### Segment the cartilage and evaluate the segmentation accuracy

In [21]:
#segment the cartilage
PREDICT_BATCH_SIZE = 8   # slices per forward pass
MASK_THRESHOLD = 0.5     # probability cut-off that turns the network output into a binary mask
# perf_counter is monotonic and high-resolution, which suits interval timing better than time.time
start = time.perf_counter()
outputs = model.predict(x_test, batch_size = PREDICT_BATCH_SIZE)
# The deeply-supervised model appears to return one output per supervision level, and the
# last one is used as the final segmentation. Only index into it when it really is a list;
# for a single-output model, [-1] would silently select the last *slice* instead.
if isinstance(outputs, (list, tuple)):
    outputs = outputs[-1]
predicted = outputs > MASK_THRESHOLD
end = time.perf_counter()
elapsed_time = end - start
print("Prediction finished in " + '{0:.2f}'.format(elapsed_time) + " s")

Prediction finished in 4.81 s


In [22]:
# Pooled Dice score: all test slices are flattened together and scored as one volume
dice_eval = dice(y_test, predicted)
print("Total dice:", dice_eval)

Total dice: 0.8830282181973035


In [23]:
#save the predictions
# `file` is the results.txt handle opened in the folder-initialisation cell. Using it as a
# context manager guarantees it is closed even if saving one of the masks fails part-way.
with file:
    file.write("Time for prediction [s]: " + '{0:.2f}'.format(elapsed_time) + "\n")
    for i in range(predicted.shape[0]):
        name = "predicted_" + str(i).zfill(4)
        # one binary TIFF per slice, plus that slice's own Dice score in the text report
        Image.fromarray(predicted[i, :, :, 0].astype("bool")).save("../results/Predicted_masks/" + name + ".tif")
        file.write(name + ": Dice: " + str(dice(y_test[i], predicted[i])) + "\n")

### Interactive visualisation of the segmentation (requires a Jupyter frontend with ipywidgets support)

In [24]:
#set the colormaps
# pyplot.get_cmap is used because matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7
# and removed in 3.9. The colormaps are copied so set_under does not modify shared instances.
# set_under(..., alpha=0) draws values below the lower colour limit fully transparent, so
# background pixels of a mask overlay do not hide the CT slice underneath.
cmap_ground_truth = copy.copy(plt.get_cmap("inferno"))
cmap_ground_truth.set_under('k', alpha=0)
cmap_prediction = copy.copy(plt.get_cmap("cool"))
cmap_prediction.set_under('k', alpha=0)

def display_segmentation(index, show_prediction, show_ground_truth):
    """
    Plot one CT slice in grey scale with optional mask overlays.

    index -- slice index into x_test, y_test, predicted and paths_to_images_test
    show_prediction -- if truthy, overlay the predicted mask using cmap_prediction
    show_ground_truth -- if truthy, overlay the manual segmentation using cmap_ground_truth
    """
    cross_section = x_test[index, :, :, 0]
    ground_truth = y_test[index, :, :, 0]
    prediction = predicted[index, :, :, 0]
    title = os.path.basename(os.path.normpath(paths_to_images_test[index]))
    fig, ax = plt.subplots(figsize = (15, 15))
    ax.imshow(cross_section, cmap = 'gray')
    # clim starts at 0.1 so mask pixels equal to 0 fall below the colormap range and use its
    # "under" colour. interpolation="none" (the string) keeps the mask edges crisp; the Python
    # value None would instead fall back to the rcParams default interpolation.
    if show_prediction:
        ax.imshow(prediction, cmap = cmap_prediction, clim = [0.1, 1], alpha = 0.5, interpolation = "none")
    if show_ground_truth:
        ax.imshow(ground_truth, cmap = cmap_ground_truth, clim = [0.1, 1], alpha = 0.5, interpolation = "none")
    ax.set_title(title)
    plt.show()
widgets.interact(
    display_segmentation,
    # continuous_update=False redraws only when the slider is released, not on every drag step
    index = widgets.IntSlider(min = 0, max = x_test.shape[0]-1, step = 1, value = 0, continuous_update = False),
    show_prediction = widgets.Checkbox(value=False, description = "Show prediction"),
    show_ground_truth = widgets.Checkbox(value=False, description = "Show ground truth"),
);  # trailing ';' suppresses the echoed <function ...> repr



interactive(children=(IntSlider(value=0, description='index', max=49), Checkbox(value=False, description='Show…

<function __main__.display_segmentation(index, show_prediction, show_ground_truth)>