In [226]:
import os
import cv2
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage.measure import regionprops


from skimage.filters import sobel
from skimage.measure import label
from skimage.util import img_as_float
from skimage.morphology import remove_small_objects, footprints
from skimage.segmentation import (morphological_geodesic_active_contour,
                                  inverse_gaussian_gradient,
                                  mark_boundaries)
from timeit import default_timer as timer
from scipy import ndimage as ndi

from skimage.morphology import disk
from skimage.segmentation import watershed
from skimage import data
from skimage.filters import rank
from skimage.util import img_as_ubyte

In [227]:
def accuracy(groundtruth_mask, pred_mask):
    """Return the fraction of pixels on which two binary masks agree.

    Both inputs are binarised first (any non-zero value counts as
    foreground), so 0/1, 0/255 and boolean masks all give the same result.
    Working on booleans also avoids uint8 wrap-around when masks use 255
    as the foreground value.

    Parameters
    ----------
    groundtruth_mask : array_like
        Reference mask.
    pred_mask : array_like
        Predicted mask, same shape as ``groundtruth_mask``.

    Returns
    -------
    numpy.float64
        (agreeing pixels) / (all pixels), rounded to 3 decimals.
        NaN when the masks contain no pixels.
    """
    gt = np.asarray(groundtruth_mask).astype(bool)
    pred = np.asarray(pred_mask).astype(bool)

    # Pixels where both masks say foreground or both say background.
    agree = np.logical_not(np.logical_xor(gt, pred))
    if agree.size == 0:
        # Undefined for empty input; be explicit instead of dividing 0 by 0.
        return np.float64("nan")

    acc = np.float64(np.count_nonzero(agree)) / agree.size
    return round(acc, 3)

def precision_score_(groundtruth_mask, pred_mask):
    """Return the precision of a predicted binary mask.

    Precision = (pixels foreground in both masks) / (pixels predicted as
    foreground). Inputs are binarised first (non-zero means foreground), so
    0/1, 0/255 and boolean masks are all handled identically.

    Parameters
    ----------
    groundtruth_mask : array_like
        Reference mask.
    pred_mask : array_like
        Predicted mask, same shape as ``groundtruth_mask``.

    Returns
    -------
    numpy.float64
        Precision rounded to 3 decimals, or NaN when the prediction
        contains no foreground pixel (the ratio is undefined).
    """
    gt = np.asarray(groundtruth_mask).astype(bool)
    pred = np.asarray(pred_mask).astype(bool)

    predicted_positive = np.count_nonzero(pred)
    if predicted_positive == 0:
        # Nothing was predicted: explicit NaN instead of a 0/0 warning.
        return np.float64("nan")

    true_positive = np.count_nonzero(np.logical_and(gt, pred))
    return round(np.float64(true_positive) / predicted_positive, 3)

def recall_score_(groundtruth_mask, pred_mask):
    """Return the recall (sensitivity) of a predicted binary mask.

    Recall = (pixels foreground in both masks) / (pixels foreground in the
    ground truth). Inputs are binarised first (non-zero means foreground),
    so 0/1, 0/255 and boolean masks are all handled identically.

    Parameters
    ----------
    groundtruth_mask : array_like
        Reference mask.
    pred_mask : array_like
        Predicted mask, same shape as ``groundtruth_mask``.

    Returns
    -------
    numpy.float64
        Recall rounded to 3 decimals, or NaN when the ground truth
        contains no foreground pixel (the ratio is undefined).
    """
    gt = np.asarray(groundtruth_mask).astype(bool)
    pred = np.asarray(pred_mask).astype(bool)

    actual_positive = np.count_nonzero(gt)
    if actual_positive == 0:
        # No foreground to recover: explicit NaN instead of a 0/0 warning.
        return np.float64("nan")

    true_positive = np.count_nonzero(np.logical_and(gt, pred))
    return round(np.float64(true_positive) / actual_positive, 3)

def iou(groundtruth_mask, pred_mask):
    """Return the intersection-over-union (Jaccard index) of two masks.

    Inputs are binarised first (non-zero means foreground), so 0/1, 0/255
    and boolean masks are all handled identically.

    Parameters
    ----------
    groundtruth_mask : array_like
        Reference mask.
    pred_mask : array_like
        Predicted mask, same shape as ``groundtruth_mask``.

    Returns
    -------
    numpy.float64
        |gt AND pred| / |gt OR pred| rounded to 3 decimals, or NaN when
        neither mask has any foreground pixel (the ratio is undefined).
    """
    gt = np.asarray(groundtruth_mask).astype(bool)
    pred = np.asarray(pred_mask).astype(bool)

    union = np.count_nonzero(np.logical_or(gt, pred))
    if union == 0:
        # Both masks are empty: explicit NaN instead of a 0/0 warning.
        return np.float64("nan")

    intersection = np.count_nonzero(np.logical_and(gt, pred))
    return round(np.float64(intersection) / union, 3)

def dice_coef(groundtruth_mask, pred_mask):
    """Return the Dice coefficient of two binary masks.

    Dice = 2 * |gt AND pred| / (|gt| + |pred|). Inputs are binarised first
    (non-zero means foreground), so 0/1, 0/255 and boolean masks are all
    handled identically.

    Parameters
    ----------
    groundtruth_mask : array_like
        Reference mask.
    pred_mask : array_like
        Predicted mask, same shape as ``groundtruth_mask``.

    Returns
    -------
    numpy.float64
        Dice coefficient rounded to 3 decimals, or NaN when neither mask
        has any foreground pixel (the ratio is undefined).
    """
    gt = np.asarray(groundtruth_mask).astype(bool)
    pred = np.asarray(pred_mask).astype(bool)

    total_foreground = np.count_nonzero(gt) + np.count_nonzero(pred)
    if total_foreground == 0:
        # Both masks are empty: explicit NaN instead of a 0/0 warning.
        return np.float64("nan")

    intersection = np.count_nonzero(np.logical_and(gt, pred))
    return round(np.float64(2 * intersection) / total_foreground, 3)

def metrics_table(gt_masks, pred_masks):
    """Score every (ground truth, prediction) pair and average the results.

    Masks are paired by position, so both sequences must be in the same
    order and have the same length.

    Parameters
    ----------
    gt_masks : iterable of array_like
        Ground-truth masks.
    pred_masks : iterable of array_like
        Predicted masks, one per ground-truth mask.

    Returns
    -------
    tuple
        ``(df, avg_precision, avg_recall, avg_accuracy, avg_dice, avg_iou)``
        where ``df`` has one row per pair and the columns
        'Precision', 'Recall', 'Accuracy', 'Dice', 'IoU', and each average
        is the mean of the matching column.

    Raises
    ------
    ValueError
        If the two sequences differ in length (``zip`` would otherwise
        drop the unmatched masks without any warning).
    """
    gt_masks = list(gt_masks)
    pred_masks = list(pred_masks)
    if len(gt_masks) != len(pred_masks):
        raise ValueError(
            'gt_masks and pred_masks must have the same length, got '
            + str(len(gt_masks)) + ' and ' + str(len(pred_masks))
        )

    # Column name -> metric function; insertion order fixes the column order.
    scorers = {
        'Precision': precision_score_,
        'Recall': recall_score_,
        'Accuracy': accuracy,
        'Dice': dice_coef,
        'IoU': iou,
    }
    columns = list(scorers)

    rows = [
        {name: score(mask, pred) for name, score in scorers.items()}
        for mask, pred in zip(gt_masks, pred_masks)
    ]
    df = pd.DataFrame(rows, columns=columns)

    avg_precision, avg_recall, avg_accuracy, avg_dice, avg_iou = (
        df[name].mean() for name in columns
    )
    return df, avg_precision, avg_recall, avg_accuracy, avg_dice, avg_iou

In [228]:
def img_skull_strip(img):
    """Keep only the largest Otsu-foreground region of a grayscale image.

    Steps:
      1. Otsu-threshold the (median-filtered) image into a binary mask.
      2. Morphologically open the mask with an octagon footprint.
      3. Keep the largest connected component of the opened mask.
      4. Morphologically close that component with a larger octagon.
      5. Zero every pixel of ``img`` outside the resulting mask.

    The previous version also built an unsharp-masked image, a bright-pixel
    threshold (``filtered``) and ``strip = img_thres - filtered``; none of
    those values fed into the returned image, so they were removed. The
    output is unchanged.

    Parameters
    ----------
    img : numpy.ndarray
        Single-channel 8-bit image (required by the Otsu threshold call).

    Returns
    -------
    numpy.ndarray
        Copy of ``img`` with everything outside the selected region set to 0.

    Raises
    ------
    AssertionError
        If the opened Otsu mask contains no foreground component.
    """

    def getLargestCC(segmentation):
        """Return a boolean mask of the largest non-background component."""
        labels = label(segmentation)
        assert labels.max() != 0, 'no foreground component found'  # assume at least 1 CC
        # bincount index 0 is the background; skip it, then shift back by 1.
        largestCC = labels == np.argmax(np.bincount(labels.flat)[1:]) + 1
        return largestCC

    # NOTE(review): an aperture of 1 most likely leaves the pixels unchanged;
    # confirm whether a larger odd kernel size was intended here.
    img_blur = cv2.medianBlur(img, 1)

    # Otsu picks the threshold itself; the explicit 0 is ignored.
    _, img_thres = cv2.threshold(
        img_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
    )

    # Opening removes thin connections before the component search.
    img_morph = cv2.morphologyEx(
        img_thres, cv2.MORPH_OPEN, footprints.octagon(6, 6)
    )

    img_cc = getLargestCC(img_morph).astype('uint8') * 255

    # Closing fills holes inside the selected component.
    img_cc = cv2.morphologyEx(
        img_cc, cv2.MORPH_CLOSE, footprints.octagon(15, 15)
    )

    return cv2.bitwise_and(img, img, mask=img_cc)

def watershed_seg(img):
    """Label an image with a marker-based watershed on its local gradient.

    Markers are the connected areas where the gradient over a radius-7 disk
    is below 40; the surface being flooded is the gradient over a radius-2
    disk.

    Parameters
    ----------
    img : numpy.ndarray
        Image passed to ``skimage.filters.rank.gradient``.

    Returns
    -------
    numpy.ndarray
        Integer label image produced by ``watershed``.
    """
    # Fine-scale gradient: the small disk keeps edges thin.
    edge_strength = rank.gradient(img, disk(2))

    # Low-gradient zones at a coarser scale become the seeds.
    flat_zones = rank.gradient(img, disk(7)) < 40
    seeds, _ = ndi.label(flat_zones)

    # Flood the gradient surface starting from the seeds.
    return watershed(edge_strength, seeds)


def segment(label1, img, min_area=100, max_area=9000, intensity_margin=100):
    """Build a binary mask from the bright, medium-sized labelled regions.

    A region is kept when its area lies strictly between ``min_area`` and
    ``max_area`` and its mean intensity is greater than the highest mean
    intensity of any region (of any size) minus ``intensity_margin``.
    The kept regions are merged into one mask, which is then
    morphologically closed with an octagon footprint.

    Parameters
    ----------
    label1 : numpy.ndarray
        Integer label image (e.g. the output of ``watershed_seg``).
    img : numpy.ndarray
        Intensity image with the same shape as ``label1``.
    min_area : int, optional
        Exclusive lower bound on region area in pixels (default 100).
    max_area : int, optional
        Exclusive upper bound on region area in pixels (default 9000).
    intensity_margin : float, optional
        How far below the brightest region's mean intensity a region may
        be and still be kept (default 100).

    Returns
    -------
    numpy.ndarray
        uint8 mask with 1 on selected pixels and 0 elsewhere.
    """
    props = regionprops(label1, intensity_image=img)

    # Brightest mean intensity over ALL regions, never below 0.
    max_intensity = max(0, max((prop.mean_intensity for prop in props), default=0))
    min_intensity = max_intensity - intensity_margin  # allow for different ranges

    tumor_labels = [
        prop.label
        for prop in props
        if min_area < prop.area < max_area and prop.mean_intensity > min_intensity
    ]

    # One membership test over the label image instead of one full-image
    # comparison per selected region.
    tumor_mask = np.isin(label1, tumor_labels).astype(np.uint8)

    kernel = footprints.octagon(13, 13)
    imgOut = cv2.morphologyEx(tumor_mask, cv2.MORPH_CLOSE, kernel)

    return imgOut



In [230]:
pathImage = 'brain_tumor_dataset/images/'
pathMask = 'brain_tumor_dataset/masks/'
output_dir = 'brain_tumor_dataset/segment/'

# Load every PNG in `pathImage` as a single-channel 8-bit image.
# `imgNames[i]` is the file name of `images[i]`.
images = []
imgNames = []
for fileName in os.listdir(pathImage):
    if fileName.endswith(".png"):
        # cv2.imread expects an IMREAD_* flag. cv2.COLOR_BGR2GRAY is a
        # cvtColor conversion code and does not request grayscale loading.
        img = cv2.imread(pathImage + fileName, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # imread returns None instead of raising on unreadable files.
            raise IOError('Could not read image: ' + pathImage + fileName)
        imgNames.append(fileName)
        images.append(img)

In [231]:
# Segment every loaded image, keep the 0/1 masks in `ws_masks`, and save
# each mask as a 0/255 grayscale PNG under `output_dir`.
ws_masks = []

# plt.imsave does not create missing folders; make sure the target exists
# so the cell also runs on a fresh checkout.
os.makedirs(output_dir, exist_ok=True)

# The measured time covers the whole loop, including the file writes.
start = timer()
for imgName, img in zip(imgNames, images):
    img_ss = img_skull_strip(img)
    img_ws = watershed_seg(img_ss)
    img_segment = segment(img_ws, img_ss).astype('uint8')
    ws_masks.append(img_segment)
    plt.imsave(output_dir + imgName, 255*img_segment, cmap=cm.gray)
end = timer()
timeTaken = end - start
print('Computation time: ', timeTaken)

Computation time:  7.640610999999808


In [232]:
# Sanity check: list the distinct pixel values across all predicted masks.
np.unique(np.asarray(ws_masks))

array([0, 1], dtype=uint8)

In [233]:
# Load every PNG in `pathMask` as a single-channel image and rescale it
# from 0/255 to 0/1.
# NOTE(review): these masks are later paired with the predictions by list
# position, and os.listdir returns entries in arbitrary order - confirm that
# the image and mask folders are listed in the same order.
gt_masks = []
for maskName in os.listdir(pathMask):
    if maskName.endswith(".png"):
        # cv2.imread expects an IMREAD_* flag. cv2.COLOR_BGR2GRAY is a
        # cvtColor conversion code and does not request grayscale loading.
        img = cv2.imread(pathMask + maskName, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # imread returns None instead of raising on unreadable files.
            raise IOError('Could not read mask: ' + pathMask + maskName)
        img = (img / 255).astype('uint8')
        gt_masks.append(img)

In [234]:
df, prec, rec, acc, dice, iou = metrics_table(gt_masks, ws_masks)
df

Unnamed: 0,Precision,Recall,Accuracy,Dice,IoU
0,0.997,0.797,0.997,0.886,0.795
1,0.314,0.926,0.974,0.47,0.307
2,0.651,0.862,0.992,0.742,0.59
3,0.993,0.92,0.998,0.955,0.914
4,0.896,0.875,0.994,0.886,0.795
5,0.715,0.938,0.986,0.811,0.683
6,0.69,0.937,0.985,0.794,0.659
7,0.991,0.94,0.998,0.965,0.932
8,0.928,0.91,0.998,0.918,0.849
9,0.99,0.826,0.997,0.901,0.819


In [235]:
print('Average Precision: ', prec)
print('Average Recall: ', rec)
print('Average Accuracy: ', acc)
print('Average Dice: ', dice)
print('Average IoU: ', iou)

Average Precision:  0.34173469387755107
Average Recall:  0.6013061224489796
Average Accuracy:  0.9427142857142858
Average Dice:  0.38122448979591844
Average IoU:  0.3006122448979592


In [236]:
#fig, axs = plt.subplots(1, 2, figsize=(9, 6), constrained_layout=True)
#ax = axs.flatten()

#ax[0].imshow(mark_boundaries(img, im_true))
#ax[0].set_title('Groundtruth')
#ax[0].set_axis_off()

#ax[1].imshow(mark_boundaries(img, im_test3))
#ax[1].imshow(mark_boundaries(img, gac))
#ax[1].set_title('Morphological GAC')
#ax[1].set_axis_off()

#plt.show()