In [None]:
import runpy
from itertools import combinations
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tifffile as tiff
from skimage.segmentation import mark_boundaries, find_boundaries

In [None]:
%cd ..
from evaluate import read_roi
%cd -

In [None]:
# This notebook compares the ROIs drawn by individual annotators and merges
# them by majority voting. All annotators should review the pairwise
# comparisons as well as the merged masks, and decide together whether the
# masks can be used as "ground truth" labels. Any significant discrepancy
# between individual ROIs (typically, a pairwise IoU of 0.4 or less) must be
# resolved through discussion and redrawing.

# Shared path settings live in ../params/paths.py; run it and keep its globals.
paths = runpy.run_path('../params/paths.py')

# Annotation root: one subdirectory per annotator, each holding that
# annotator's ROI files for the data sets.
ROI_PATH = Path(paths['HPC2_DATASETS'] + '_GT')

# Microscope images the ROIs are drawn on for visual inspection; each data set
# is expected at <IMG_PATH>/<name>/<name>_avg.tif. Run summarize.py first so
# that the *_avg.tif files exist.
IMG_PATH = Path(paths['OUTPUT_BASE_PATH']) / 'results' / 'voltage_HPC2'

In [None]:
def calc_IoU(mask1, mask2):
    """Return the intersection-over-union (Jaccard index) of two masks.

    Parameters
    ----------
    mask1, mask2 : array_like
        Masks of the same (or broadcast-compatible) shape; nonzero / True
        entries are treated as foreground.

    Returns
    -------
    float
        |mask1 AND mask2| / |mask1 OR mask2|, in [0, 1]. When both masks are
        empty the union is 0; 0.0 is returned instead of the NaN (and
        RuntimeWarning) that 0/0 would give, so callers that take max() of
        IoUs or format them with '%d' never receive NaN.
    """
    union = np.count_nonzero(np.logical_or(mask1, mask2))
    if union == 0:
        return 0.0
    return np.count_nonzero(np.logical_and(mask1, mask2)) / union

In [None]:
def _find_roi_file(gt_dir, data_name):
    """Return the ROI file for `data_name` inside one annotator's directory."""
    # Sort so the choice does not depend on filesystem listing order.
    candidates = sorted(Path(gt_dir).glob(data_name + '*'))
    if not candidates:
        raise FileNotFoundError(
            'No ROI file starting with %r in %s' % (data_name, gt_dir))
    # The glob also matches longer data-set names sharing this prefix
    # (e.g. 'cell1' matches 'cell10.zip'); prefer files whose name is not
    # continued by a letter or digit right after `data_name`.
    delimited = [c for c in candidates
                 if not c.name[len(data_name):len(data_name) + 1].isalnum()]
    return (delimited or candidates)[0]


def _load_annotations(img_path, roi_path, gt_list, data_name):
    """Read the average image, each annotator's masks and their union of contours."""
    img = tiff.imread(Path(img_path, data_name, data_name + '_avg.tif'))
    masks_dict = {}
    contour_dict = {}
    for gt_id in gt_list:
        masks = read_roi(_find_roi_file(Path(roi_path, gt_id), data_name), img.shape)
        contour_image = np.zeros(img.shape, dtype=bool)
        for mask in masks:
            contour_image |= find_boundaries(mask, mode='outer')
        masks_dict[gt_id] = masks
        contour_dict[gt_id] = contour_image
    return img, masks_dict, contour_dict


def _best_matches(masks_dict, gt_list):
    """For every mask of every annotator, find each other annotator's best-IoU mask.

    Returns {ref_id: [ {tgt_id: [best_iou, best_index]}, ... one per ref mask ]}.
    If the other annotator has no masks, the entry is [0, -1] (treated as
    "no match" downstream because the IoU is 0).
    """
    match_dict = {}
    for ref_id in gt_list:
        matches = []
        for ref_mask in masks_dict[ref_id]:
            match = {}
            for tgt_id in gt_list:
                if tgt_id == ref_id:
                    continue
                ious = [calc_IoU(ref_mask, tgt_mask) for tgt_mask in masks_dict[tgt_id]]
                if ious:
                    best_iou = max(ious)
                    match[tgt_id] = [best_iou, ious.index(best_iou)]
                else:
                    match[tgt_id] = [0, -1]
            matches.append(match)
        match_dict[ref_id] = matches
    return match_dict


def _collect_majorities(match_dict, gt_list):
    """Print a match report and return the unique groups of mutually matched masks
    that more than half of the annotators agree on."""
    majority_list = []
    for ref_id in gt_list:
        for ref_index, ref_match in enumerate(match_dict[ref_id]):
            majority = {ref_id: ref_index}
            print(ref_id + ' %d ' % ref_index, end='')
            for tgt_id, (tgt_iou, tgt_index) in ref_match.items():
                print('- ' + tgt_id, end='')
                if tgt_iou == 0:
                    print(' _ (0%) [no match] ', end='')
                    continue
                print(' %d (%d%%) [- ' % (tgt_index, tgt_iou * 100), end='')
                # Look the pair up from the target's side to require a mutual match.
                ref_iou_by_tgt, ref_index_by_tgt = match_dict[tgt_id][tgt_index][ref_id]
                print(ref_id + ' %d' % ref_index_by_tgt, end='')
                if ref_index != ref_index_by_tgt:
                    print(': mismatch] ', end='')
                elif tgt_iou != ref_iou_by_tgt:
                    print(': inconsistent IoU %d%%] ' % (ref_iou_by_tgt * 100), end='')
                else:
                    print('] ', end='')
                    majority[tgt_id] = tgt_index
            print('')

            # Keep only true majorities; the same group is found once from each
            # member's side, so skip groups already recorded (dict equality).
            if len(majority) > len(gt_list) / 2 and majority not in majority_list:
                majority_list.append(majority)
    return majority_list


def calc_majority_vote(img_path, roi_path, gt_list, data_name):
    """Match ROIs across annotators and collect majority-agreed ROI groups.

    Parameters
    ----------
    img_path : str or Path
        Directory containing `<data_name>/<data_name>_avg.tif`.
    roi_path : str or Path
        Directory with one subdirectory per annotator id in `gt_list`; each
        must hold a file whose name starts with `data_name`, readable by
        `read_roi`.
    gt_list : list of str
        Annotator ids (subdirectory names of `roi_path`).
    data_name : str
        Name of the data set to process.

    Returns
    -------
    majority_list : list of dict
        Each dict maps annotator id -> index into that annotator's masks; it
        contains more than half of the annotators, all mutually best-matched
        (by IoU) with the reference mask that produced the group.
    masks_dict : dict
        Annotator id -> masks as returned by `read_roi`.
    contour_dict : dict
        Annotator id -> boolean image of the union of all that annotator's
        outer mask boundaries.
    img : numpy.ndarray
        The average image.

    Raises
    ------
    FileNotFoundError
        If an annotator directory has no file matching `data_name`.

    A per-mask match report is printed to stdout.
    """
    img, masks_dict, contour_dict = _load_annotations(img_path, roi_path, gt_list, data_name)
    match_dict = _best_matches(masks_dict, gt_list)
    majority_list = _collect_majorities(match_dict, gt_list)
    return majority_list, masks_dict, contour_dict, img

In [None]:
def create_overlay(mask, color):
    """Return an RGBA image that draws the outer boundary of `mask` in `color`.

    Parameters
    ----------
    mask : array_like, 2-D
        Binary (or label) image; boundaries between 0 and nonzero regions are drawn.
    color : sequence of 3 floats in [0, 1]
        RGB line colour; each component is scaled by 255 and truncated to int.

    Returns
    -------
    numpy.ndarray, shape (H, W, 4), dtype uint8
        Every pixel carries `color`; alpha is 255 on the boundary pixels and
        0 elsewhere, so only the contour shows when drawn with imshow.
    """
    # Compute the boundary directly instead of rendering a dummy RGB image
    # with mark_boundaries and reading back one channel (same 'outer' mode).
    boundary = find_boundaries(mask, mode='outer')
    overlay = np.zeros(boundary.shape + (4,), dtype='uint8')
    overlay[:, :, :3] = [int(c * 255) for c in color[:3]]
    overlay[:, :, 3] = boundary * 255
    return overlay

In [None]:
def _draw_separator(fig_width, dpi):
    """Draw a horizontal rule that visually separates groups of figures."""
    plt.figure(figsize=(fig_width, 1), dpi=dpi)
    plt.axis('off')
    plt.hlines(0, 0, fig_width)
    plt.show()


def _plot_all_annotator_masks(img, contour_dict, gt_list, data_name, layout):
    """Show every annotator's complete set of ROI contours side by side."""
    width, height, dpi, gap = layout
    line_color = (255, 255, 0)
    plt.figure(figsize=(width * len(gt_list), height + 1), dpi=dpi)
    plt.suptitle('All masks of data set ' + data_name)
    for i, gt_id in enumerate(gt_list):
        # Build a fresh overlay per panel instead of mutating a shared array.
        overlay = np.zeros((img.shape[0], img.shape[1], 4), dtype='uint8')
        overlay[:, :, :3] = line_color
        overlay[:, :, 3] = contour_dict[gt_id] * 255
        plt.subplot(1, len(gt_list), i + 1)
        plt.axis('off')
        plt.imshow(img, interpolation='bilinear', cmap='gray')
        plt.imshow(overlay, interpolation='bilinear')
        plt.title(gt_id)
    plt.tight_layout(pad=gap)
    plt.show()


def _plot_majority_members(img, masks_dict, majority, gt_list, index, layout):
    """Show, per annotator, the mask that belongs to one majority group."""
    width, height, dpi, gap = layout
    plt.figure(figsize=(width * len(gt_list), height + 1), dpi=dpi)
    plt.suptitle('Majority mask No. %d' % index)
    for i, gt_id in enumerate(gt_list):
        plt.subplot(1, len(gt_list), i + 1)
        plt.axis('off')
        plt.title(gt_id)
        plt.imshow(img, interpolation='bilinear', cmap='gray')
        if gt_id in majority:
            mask = masks_dict[gt_id][majority[gt_id]]
            plt.imshow(create_overlay(mask, (1, 0, 0)), interpolation='bilinear')
    plt.tight_layout(pad=gap)
    plt.show()


def _plot_pairwise_comparison(img, masks_dict, majority, gt_list, index, layout):
    """Overlay each unordered pair of annotators' masks; return the pair IoUs."""
    width, height, dpi, gap = layout
    # Every unordered pair exactly once: no duplicated IoUs for 2 annotators,
    # no missing pairs for 4 or more.
    pairs = list(combinations(gt_list, 2))
    ious = []
    if not pairs:  # a single annotator has nobody to be compared with
        return ious
    plt.figure(figsize=(width * len(pairs), height + 1), dpi=dpi)
    plt.suptitle('Majority mask No. %d comparison' % index)
    for k, (gt1, gt2) in enumerate(pairs):
        plt.subplot(1, len(pairs), k + 1)
        plt.axis('off')
        plt.imshow(img, interpolation='bilinear', cmap='gray')
        if gt1 in majority and gt2 in majority:
            mask1 = masks_dict[gt1][majority[gt1]]
            mask2 = masks_dict[gt2][majority[gt2]]
            iou = calc_IoU(mask1, mask2)
            ious.append(iou)
            plt.imshow(create_overlay(mask1, (1, 0, 1)), interpolation='bilinear')
            plt.imshow(create_overlay(mask2, (0.2, 0.7, 0)), interpolation='bilinear')
            plt.title(gt1 + ' - ' + gt2 + ' (IoU %.2f)' % iou, fontsize=10)
        else:
            plt.title(gt1 + ' - ' + gt2 + ' (no match)', fontsize=10)
    plt.tight_layout(pad=gap)
    plt.show()
    return ious


def _plot_vote_result(img, masks_dict, majority, gt_list, index, layout):
    """Show the per-pixel vote count and return the pixel-wise majority mask."""
    width, height, dpi, gap = layout
    num_gt = len(gt_list)
    sum_mask = np.zeros(img.shape, dtype='uint8')
    for gt_id in gt_list:
        if gt_id in majority:
            sum_mask += masks_dict[gt_id][majority[gt_id]]
    # A pixel is kept when more than half of ALL annotators voted for it.
    majority_mask = sum_mask > num_gt / 2

    plt.figure(figsize=(width * 2, height + 1), dpi=dpi)
    plt.suptitle('Majority mask No. %d voting result' % index)
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.imshow(sum_mask, interpolation='bilinear', vmin=0, vmax=num_gt, cmap='inferno')
    plt.title('Heat map of # votes')

    plt.subplot(1, 2, 2)
    plt.axis('off')
    plt.imshow(img, interpolation='bilinear', cmap='gray')
    plt.imshow(create_overlay(majority_mask, (1, 1, 0)), interpolation='bilinear')
    plt.title('Majority mask')

    plt.tight_layout(pad=gap)
    plt.show()
    return majority_mask


def show_majority_votes(img_path, roi_path, data_name, out_dir='.'):
    """Visualize annotator agreement for one data set and save merged masks.

    Parameters
    ----------
    img_path : str or Path
        Directory containing `<data_name>/<data_name>_avg.tif`.
    roi_path : str or Path
        Directory with one (non-hidden) subdirectory per annotator.
    data_name : str
        Data set to process.
    out_dir : str or Path, optional
        Where `<data_name>.tif` (stack of merged majority masks, 0/255) is
        written. Defaults to the current directory, as before.

    Returns
    -------
    list of float
        IoUs of every annotator pair that both contributed to a majority group.
    """
    roi_path = Path(roi_path)
    # Use the full directory name (Path.stem would drop anything after a dot)
    # and ignore hidden directories such as .ipynb_checkpoints.
    gt_list = sorted(x.name for x in roi_path.iterdir()
                     if x.is_dir() and not x.name.startswith('.'))
    majority_list, masks_dict, contour_dict, img = calc_majority_vote(
        img_path, roi_path, gt_list, data_name)

    num_gt = len(gt_list)
    width = 3
    height = width * img.shape[0] / img.shape[1]
    dpi = 100
    gap = 0.1
    layout = (width, height, dpi, gap)

    IoUs = []
    merged_masks = np.zeros((len(majority_list), img.shape[0], img.shape[1]), dtype='uint8')
    for j, majority in enumerate(majority_list):
        _draw_separator(width * num_gt, dpi)
        _plot_all_annotator_masks(img, contour_dict, gt_list, data_name, layout)
        _plot_majority_members(img, masks_dict, majority, gt_list, j, layout)
        IoUs += _plot_pairwise_comparison(img, masks_dict, majority, gt_list, j, layout)
        merged_masks[j] = _plot_vote_result(img, masks_dict, majority, gt_list, j, layout)

    _draw_separator(width * num_gt, dpi)
    _plot_all_annotator_masks(img, contour_dict, gt_list, data_name, layout)

    plt.figure(figsize=(width, height + 1), dpi=dpi)
    plt.axis('off')
    plt.imshow(img, interpolation='bilinear', cmap='gray')
    for mask in merged_masks:
        plt.imshow(create_overlay(mask, (1, 1, 0)), interpolation='bilinear')
    plt.title('Merged majority masks')
    plt.tight_layout(pad=gap)
    plt.show()

    tiff.imwrite(Path(out_dir, data_name + '.tif'), merged_masks * 255, photometric='minisblack')

    return IoUs

In [None]:
# Process every data set (one subdirectory of IMG_PATH each) in a stable,
# alphabetical order, skipping hidden directories such as .ipynb_checkpoints,
# and pool all pairwise IoUs for the histogram below.
IoUs = []
for d in sorted(Path(IMG_PATH).iterdir()):
    if d.is_dir() and not d.name.startswith('.'):
        print('Data set ' + d.name)
        IoUs += show_majority_votes(IMG_PATH, ROI_PATH, d.name)

In [None]:
# Distribution of pairwise IoUs across all data sets. Pairs at or below the
# review threshold (see the notes at the top) need discussion / redrawing.
IOU_REVIEW_THRESHOLD = 0.4
fig, ax = plt.subplots()
ax.hist(IoUs, bins=40, range=(0, 1))
ax.axvline(IOU_REVIEW_THRESHOLD, color='r', linestyle='--',
           label='review threshold (%.1f)' % IOU_REVIEW_THRESHOLD)
ax.set(title='Histogram of IoU values', xlabel='Pairwise IoU', ylabel='Count')
ax.legend()
plt.show()