In [1]:
import numpy as np
from numba import jit
from scipy.ndimage import label
from tifffile import imread, imsave
from glob import glob
from scipy import ndimage
from skimage.measure import label

In [2]:
def pixel_sharing_bipartite(lab1, lab2):
    """
    Count the pixels shared by every pair of labels in two label images.

    Parameters
    ----------
    lab1, lab2 : array_like
        Label images of identical shape holding non-negative integer labels.

    Returns
    -------
    psg : integer ndarray of shape (lab1.max() + 1, lab2.max() + 1)
        psg[i, j] is the number of pixels that carry label i in `lab1` and
        label j in `lab2`. Row / column 0 correspond to label 0.

    Raises
    ------
    AssertionError if the shapes differ; ValueError (from np.bincount) if a
    label is negative.
    """
    # Vectorised with np.bincount, so no numba JIT is required. The former
    # per-pixel loop allocated with `np.int`, which NumPy 1.24 removed.
    lab1 = np.asarray(lab1)
    lab2 = np.asarray(lab2)
    assert lab1.shape == lab2.shape
    n_rows = int(lab1.max()) + 1
    n_cols = int(lab2.max()) + 1
    # Encode each (i, j) pair as a single flat index i * n_cols + j. Work in
    # int64: with uint16 label images this product would otherwise overflow.
    flat_pairs = lab1.ravel().astype(np.int64) * n_cols + lab2.ravel().astype(np.int64)
    counts = np.bincount(flat_pairs, minlength=n_rows * n_cols)
    return counts.reshape(n_rows, n_cols)

In [3]:
def matching_iou(psg, fraction=0.5):
    """
    Pair objects of two label images whose intersection-over-union exceeds `fraction`.

    psg : pixel-sharing matrix (rows: labels of image 1, columns: labels of image 2).
    Returns a boolean array shaped like `psg`. Row 0 and column 0 (label 0) are
    always False. For fraction >= 0.5 each object can have at most one partner,
    because of the strict `>` comparison.
    """
    is_match = intersection_over_union(psg) > fraction
    # Label 0 on either side never takes part in the matching.
    is_match[0, :] = False
    is_match[:, 0] = False
    return is_match

In [4]:
def matching_max(psg):
    """
    Matching based on mutual first preference.

    Row i and column j are matched when j is the column with the largest entry
    in row i AND i is the row with the largest entry in column j. Ties are
    resolved by np.argmax (lowest index). Row/column 0 are not excluded.

    Returns an array with psg's dtype: 1 for matched pairs, 0 elsewhere.
    """
    best_row = np.argmax(psg, axis=0)   # for each column: row with the largest entry
    best_col = np.argmax(psg, axis=1)   # for each row: column with the largest entry
    cols = np.arange(psg.shape[1])
    # Keep column j only if its preferred row prefers j back.
    mutual = best_col[best_row] == cols
    matching = np.zeros_like(psg)
    matching[best_row[mutual], cols[mutual]] = 1
    return matching

In [5]:
def intersection_over_union(psg):
    """
    Pairwise intersection-over-union from a pixel-sharing matrix.

    Parameters
    ----------
    psg : 2-D array
        psg[i, j] = number of pixels with label i in image 0 and label j in image 1.

    Returns
    -------
    Float array shaped like `psg` with iou[i, j] = psg[i, j] / |A_i ∪ B_j|,
    where |A_i| and |B_j| are the row and column sums of `psg`.
    Entries whose union is empty (row label i absent from image 0 AND column
    label j absent from image 1) are 0 rather than NaN.
    """
    psg = np.asarray(psg)
    size0 = psg.sum(axis=1, keepdims=True)   # (n0, 1): pixel count of each row label
    size1 = psg.sum(axis=0, keepdims=True)   # (1, n1): pixel count of each column label
    union = size0 + size1 - psg
    # Non-contiguous label ids give empty rows/columns; avoid 0/0 = NaN and the
    # RuntimeWarning it would emit, and report IoU 0 for those pairs.
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = psg / union
    return np.where(union > 0, iou, 0.0)

In [6]:
def matching_overlap(psg, fractions=(0.5,0.5)):
    """
    Create a matching between the objects of two label images from their
    pixel_sharing_bipartite matrix, based on mutually overlapping regions of
    sufficient size.

    Row object i (image 0) and column object j (image 1) are matched when
        psg[i, j] > fractions[0] * size0[i]   and   psg[i, j] > fractions[1] * size1[j]
    where size0 / size1 are the row / column sums of psg (object sizes).

    NOTE: because the comparison is strict, fractions[0] >= 0.5 allows at most one
    column per row, and fractions[1] >= 0.5 at most one row per column. A true
    one-to-one matching therefore needs both fractions >= 0.5; smaller fractions
    can give some objects degree 2 or more.
    NOTE: the test is written as psg > fraction * size instead of psg / size > fraction.
    It never divides, so empty labels (size 0) are simply unmatched, and for
    fraction = 0.5 the comparison is exact.

    Returns a boolean array shaped like psg.
    """
    afrac, bfrac = fractions
    psg = np.asarray(psg)
    set0_object_sizes = psg.sum(axis=1, keepdims=True)  # size of each row object
    set1_object_sizes = psg.sum(axis=0, keepdims=True)  # size of each column object
    m0 = psg > afrac * set0_object_sizes
    m1 = psg > bfrac * set1_object_sizes
    return m0 & m1

In [7]:
def precision(lab_gt, lab, iou=0.5, partial_dataset=False):
    '''
    Object-level matching score between a ground-truth and a predicted label image.

    Objects are paired when their IoU exceeds `iou` (label 0 is ignored). Despite
    the function name, the value returned is
        TP / (TP + FP + FN)
    i.e. the "intersection over union" of the object matching, not TP / (TP + FP).

    Parameters
    ----------
    lab_gt, lab : label images of identical shape (0 = background).
    iou : IoU threshold above which a pair counts as a true positive.
    partial_dataset : if True, return the (TP, TP + FP + FN) pair so counts can be
        pooled across several images.

    Returns
    -------
    The ratio, or the (numerator, denominator) tuple when partial_dataset is True.
    If neither image contains objects the ratio is 0/0 (NaN).
    '''
    matching = matching_iou(pixel_sharing_bipartite(lab_gt, lab), fraction=iou)
    # Each object may have at most one partner (guaranteed for iou >= 0.5).
    assert matching.sum(0).max() < 2
    assert matching.sum(1).max() < 2
    n_gt = len(set(np.unique(lab_gt)) - {0})
    n_hyp = len(set(np.unique(lab)) - {0})
    true_pos = matching.sum()
    tp_fp_fn = n_gt + n_hyp - true_pos
    if partial_dataset:
        return true_pos , tp_fp_fn
    return true_pos / tp_fp_fn

In [8]:
## full scores
def seg(lab_gt, lab, partial_dataset=False):
    """
    SEG measure: the mean, over ground-truth objects, of the IoU with the matched
    predicted object (0 for a GT object without a match).

    Matching: GT object R (a row of the pixel-sharing matrix) is matched to
    predicted object S when more than half of R's pixels lie in S
    (fractions=(0.5, 0), i.e. no condition on the size of S). Hence each GT object
    has at most one match, while one predicted object may match several GT
    objects. Label 0 is never matched.
    Because |S| is unconstrained, a matched pair can have an arbitrarily small IoU.
    Conversely, any pair with IoU > 1/2 is always matched.

    Parameters
    ----------
    lab_gt : ground-truth label image (0 = background).
    lab : predicted label image of the same shape.
    partial_dataset : if True, return (sum of matched IoUs, number of GT objects)
        so that scores can be pooled across images.

    Returns
    -------
    (seg, dice), where dice = 2 * seg / (1 + seg) is the Dice value equivalent to an
    IoU of `seg`. It is derived from the mean IoU, not averaged per object.
    With partial_dataset=True: (iou_sum, n_gt).
    If lab_gt contains no objects the mean is 0/0 (NaN).
    """
    psg = pixel_sharing_bipartite(lab_gt, lab)
    iou = intersection_over_union(psg)
    matched = matching_overlap(psg, fractions=(0.5, 0))
    matched[:, 0] = False
    matched[0, :] = False
    n_gt = len(set(np.unique(lab_gt)) - {0})
    iou_sum = iou[matched].sum()
    if partial_dataset:
        return iou_sum , n_gt
    mean_iou = iou_sum / n_gt
    return mean_iou, (2 * mean_iou / (1 + mean_iou))

In [87]:
# SEG and Dice-like scores of the run1 "temp" predictions against the GOWT1 ground truth.
# glob() returns files in arbitrary filesystem order, so both lists are sorted to pair
# each GT image with the prediction at the same position.
# NOTE(review): pairing relies on both folders sorting into the same frame order — confirm
# the two file naming schemes agree.
gt_segs_path = sorted(glob("/Users/prakash/Desktop/MetaSeg_Data_Version1/CTC_GOWT1/SEG/*.tif"))
label_segs_path = sorted(glob("/Users/prakash/Desktop/MetaSeg_Data_Version1/CTC_GOWT1/BIC-MetaSeg/run1/temp/*.tif"))
assert len(gt_segs_path) == len(label_segs_path) > 0, "GT and prediction folders must contain the same, non-zero number of images"

seg_scores_per_time = []
dice_scores_per_time = []
for gt_path, label_path in zip(gt_segs_path, label_segs_path):
    seg_score, dice_score = seg(imread(gt_path), imread(label_path).astype(np.uint16))
    seg_scores_per_time.append(seg_score)
    dice_scores_per_time.append(dice_score)

print("SEG score :", np.mean(np.array(seg_scores_per_time) ))
print("DICE score :", np.mean(np.array(dice_scores_per_time) ) )

SEG score : 0.827714473899
DICE score : 0.905352369934


In [110]:
# AP-style score (precision() averaged over images) for each IoU threshold, evaluated on a
# single image (index 14 of the sorted file lists) of the run1 / 67-annotation predictions.
# NOTE(review): np.arange(0.5, 0.95, 0.05) yields the 9 thresholds 0.5 ... 0.9 (0.95 is
# excluded); the common AP definition also includes 0.95 — confirm this is intended.
thresholds = np.arange(0.5,0.95,0.05)

# Paths and images do not depend on the threshold: find, sort and load them once.
# Sorting matters because glob() order is arbitrary, so [14:15] would otherwise pick an
# unpredictable (and possibly mismatched) file from each folder.
gt_segs_path = sorted(glob("/Users/prakash/Desktop/MetaSeg_Data_Version1/CTC_GOWT1/SEG/*.tif"))[14:15]
label_segs_path = sorted(glob("/Users/prakash/Desktop/MetaSeg_Data_Version1/CTC_GOWT1/BIC-MetaSeg/run1/67annotations/*.tif"))[14:15]
assert len(gt_segs_path) == len(label_segs_path) > 0, "GT and prediction folders must contain the selected image"
image_pairs = [(imread(gt_path), imread(label_path).astype(np.uint16))
               for gt_path, label_path in zip(gt_segs_path, label_segs_path)]

average_precision_score_list = []
for thresh in thresholds:
    sum_seg = 0
    for lab_gt, lab_pred in image_pairs:
        score = precision(lab_gt, lab_pred, thresh)
        # Scores lie in [0, 1] except NaN (GT and prediction both empty); such images
        # add nothing here but still count in the denominator below.
        if 0 <= score <= 1:
            sum_seg = sum_seg + score

    average_precision_score = sum_seg/len(gt_segs_path)
    average_precision_score_list.append(average_precision_score)

    print("AP score for threshold "+str(thresh)+" :", average_precision_score )

print("Mean AP score is: ", np.mean(np.array(average_precision_score_list)))

AP score for threshold 0.5 : 0.923076923077
AP score for threshold 0.55 : 0.923076923077
AP score for threshold 0.6 : 0.851851851852
AP score for threshold 0.65 : 0.851851851852
AP score for threshold 0.7 : 0.785714285714
AP score for threshold 0.75 : 0.666666666667
AP score for threshold 0.8 : 0.515151515152
AP score for threshold 0.85 : 0.25
AP score for threshold 0.9 : 0.0416666666667
Mean AP score is:  0.645450742673


In [None]:
# Results of earlier executions of the AP cell, pasted here for reference only
# (kept as comments so that "Run All" does not hit a SyntaxError).
# NOTE(review): the runs / data these numbers came from are not recorded.
# AP score for threshold 0.5 : 0.923076923077
# AP score for threshold 0.55 : 0.923076923077
# AP score for threshold 0.6 : 0.923076923077
# AP score for threshold 0.65 : 0.851851851852
# AP score for threshold 0.7 : 0.724137931034
# AP score for threshold 0.75 : 0.724137931034
# AP score for threshold 0.8 : 0.5625
# AP score for threshold 0.85 : 0.190476190476
# AP score for threshold 0.9 : 0.063829787234
# Mean AP score is:  0.654018273429
#
# AP score for threshold 0.5 : 0.807692307692
# AP score for threshold 0.55 : 0.807692307692
# AP score for threshold 0.6 : 0.740740740741
# AP score for threshold 0.65 : 0.740740740741
# AP score for threshold 0.7 : 0.740740740741
# AP score for threshold 0.75 : 0.620689655172
# AP score for threshold 0.8 : 0.424242424242
# AP score for threshold 0.85 : 0.146341463415
# AP score for threshold 0.9 : 0.0444444444444
# Mean AP score is:  0.56370275832

In [88]:
# AP-style score per IoU threshold for each MetaSeg run (67 annotations): averaged over
# all images of a run, then averaged over runs for each threshold.
# NOTE(review): np.arange(0.5, 0.95, 0.05) yields the 9 thresholds 0.5 ... 0.9 (0.95 excluded).
thresholds = np.arange(0.5,0.95,0.05)
runs = [1,2,3,4]

# glob() order is arbitrary; sort so GT and prediction files are paired by position.
# The GT images are the same for every run and threshold, so load them once.
gt_segs_path = sorted(glob("/Users/prakash/Desktop/MetaSeg_Data_Version1/CTC_GOWT1/SEG/*.tif"))
gt_images = [imread(path) for path in gt_segs_path]
assert gt_images, "no ground-truth images found"

average_precision_score_per_run = []
for run_idx in runs:
    label_segs_path = sorted(glob("/Users/prakash/Desktop/MetaSeg_Data_Version1/CTC_GOWT1/BIC-MetaSeg/run"+str(run_idx)+"/67annotations/*.tif"))
    assert len(label_segs_path) == len(gt_segs_path), "run %d: GT/prediction image count mismatch" % run_idx
    # Predictions depend only on the run: load them once per run, not once per threshold.
    label_images = [imread(path).astype(np.uint16) for path in label_segs_path]

    average_precision_score_list = []
    for thresh in thresholds:
        sum_seg = 0
        for lab_gt, lab_pred in zip(gt_images, label_images):
            score = precision(lab_gt, lab_pred, thresh)
            # NaN scores (GT and prediction both empty) add nothing but still count
            # in the denominator below.
            if 0 <= score <= 1:
                sum_seg = sum_seg + score
        average_precision_score_list.append(sum_seg/len(gt_segs_path))
    average_precision_score_per_run.append(average_precision_score_list)

# Rows = runs, columns = thresholds: average over runs for each threshold.
mean_average_precision_score_per_threshold = list(np.mean(np.array(average_precision_score_per_run), axis=0))
print("AP over multiple runs for MetaSeg is:", mean_average_precision_score_per_threshold)

AP over multiple runs for MetaSeg is: [0.95450889982757214, 0.95077826549312294, 0.94659663872276345, 0.9388713306622426, 0.9130383431717114, 0.83221739785150184, 0.65271381498766978, 0.35376628510352454, 0.096076682513008188]


In [89]:
# SEG and Dice-like scores for each MetaSeg run (67 annotations), then averaged over runs.
runs = [1,2,3,4]

# glob() order is arbitrary; sort so GT and prediction files are paired by position.
# The GT images are identical for every run, so load them once.
gt_segs_path = sorted(glob("/Users/prakash/Desktop/MetaSeg_Data_Version1/CTC_GOWT1/SEG/*.tif"))
gt_images = [imread(path) for path in gt_segs_path]
assert gt_images, "no ground-truth images found"

seg_score_per_run = []
dice_score_per_run = []
for run_idx in runs:
    label_segs_path = sorted(glob("/Users/prakash/Desktop/MetaSeg_Data_Version1/CTC_GOWT1/BIC-MetaSeg/run"+str(run_idx)+"/67annotations/*.tif"))
    assert len(label_segs_path) == len(gt_segs_path), "run %d: GT/prediction image count mismatch" % run_idx

    seg_scores_per_time = []
    dice_scores_per_time = []
    for lab_gt, label_path in zip(gt_images, label_segs_path):
        seg_score, dice_score = seg(lab_gt, imread(label_path).astype(np.uint16))
        seg_scores_per_time.append(seg_score)
        dice_scores_per_time.append(dice_score)

    run_seg_score = np.mean(np.array(seg_scores_per_time))
    run_dice_score = np.mean(np.array(dice_scores_per_time))
    print("SEG score :", run_seg_score)
    print("DICE score :", run_dice_score)
    seg_score_per_run.append(run_seg_score)
    dice_score_per_run.append(run_dice_score)

print("SEG score over multiple runs for MetaSeg:", np.mean(np.array(seg_score_per_run)))
print("DICE score over multiple runs for MetaSeg:", np.mean(np.array(dice_score_per_run)))

SEG score : 0.834807124178
DICE score : 0.909687929998
SEG score : 0.829868240698
DICE score : 0.906709900285
SEG score : 0.824392690661
DICE score : 0.90330722363
SEG score : 0.830044766094
DICE score : 0.906806216752
SEG score over multiple runs for MetaSeg: 0.829778205408
DICE score over multiple runs for MetaSeg: 0.906627817666
