In [1]:
import numpy as np
from numba import jit
from scipy.ndimage import label
from tifffile import imread, imsave
from glob import glob
from scipy import ndimage
from skimage.measure import label

In [2]:
def pixel_sharing_bipartite(lab1, lab2):
    """
    Count, for every pair of labels, how many pixels two label images share.

    Parameters
    ----------
    lab1, lab2 : array_like of non-negative integers
        Label images of identical shape.

    Returns
    -------
    psg : ndarray of int, shape (lab1.max() + 1, lab2.max() + 1)
        psg[a, b] is the number of positions where lab1 == a and lab2 == b.

    Raises
    ------
    AssertionError
        If the two images do not have the same shape.
    ValueError
        If an image is empty or contains negative labels.
    """
    lab1 = np.asarray(lab1)
    lab2 = np.asarray(lab2)
    assert lab1.shape == lab2.shape, "label images must have the same shape"

    # Convert the maxima to Python ints first: adding 1 to e.g. a uint16
    # maximum of 65535 would otherwise wrap around.
    n1 = int(lab1.max()) + 1
    n2 = int(lab2.max()) + 1

    # Encode every (lab1, lab2) pixel pair as one flat index into the
    # (n1, n2) table. The int64 cast keeps the product from overflowing
    # narrow input dtypes such as uint16.
    pair_index = lab1.ravel().astype(np.int64) * n2 + lab2.ravel().astype(np.int64)

    # One vectorised histogram replaces the per-pixel loop, so no JIT
    # compilation is needed (and the removed `np.int` alias is not used).
    return np.bincount(pair_index, minlength=n1 * n2).reshape(n1, n2)

In [3]:
def matching_iou(psg, fraction=0.5):
    """
    Boolean matching of label pairs whose IoU is strictly above `fraction`.

    `psg` is a pixel-sharing table (rows: labels of the first image,
    columns: labels of the second). Row 0 and column 0 are never matched.
    """
    matching = intersection_over_union(psg) > fraction
    # index 0 is excluded from the matching on both sides
    matching[0, :] = False
    matching[:, 0] = False
    return matching

In [4]:
def matching_max(psg):
    """
    Matching based on mutual first preference.

    A (row, column) pair is matched when the row holds the largest entry of
    that column and, at the same time, the column holds the largest entry of
    that row. Returns an array shaped and typed like `psg` with 1 at matched
    pairs and 0 elsewhere.
    """
    best_row_of_col = psg.argmax(axis=0)   # one row index per column
    best_col_of_row = psg.argmax(axis=1)   # one column index per row

    columns = np.arange(psg.shape[1])
    # a column is mutually preferred if its favourite row points back at it
    is_mutual = best_col_of_row[best_row_of_col] == columns

    matching = np.zeros_like(psg)
    matching[best_row_of_col[is_mutual], columns[is_mutual]] = 1
    return matching

In [5]:
def intersection_over_union(psg):
    """
    Intersection over union (IoU) for every label pair of a pixel-sharing table.

    Parameters
    ----------
    psg : array_like, shape (n0, n1)
        psg[a, b] is the number of pixels shared by label a of the first
        image and label b of the second.

    Returns
    -------
    iou : ndarray of float64, shape (n0, n1)
        psg[a, b] / (row_total[a] + column_total[b] - psg[a, b]).
        Pairs whose union is empty get an IoU of 0 instead of NaN.
    """
    psg = np.asarray(psg)
    col_totals = np.sum(psg, 0, keepdims=True)
    row_totals = np.sum(psg, 1, keepdims=True)
    union = row_totals + col_totals - psg

    # A label id that occurs in neither image yields an all-zero row and
    # column, hence union == 0. Divide only where the union is non-empty so
    # that no 0/0 -> NaN (and no RuntimeWarning) is produced.
    iou = np.zeros(union.shape, dtype=np.float64)
    np.divide(psg, union, out=iou, where=union != 0)
    return iou

In [6]:
def matching_overlap(psg, fractions=(0.5,0.5)):
    """
    Create a matching from the pixel_sharing_bipartite of two label images,
    based on mutually overlapping regions of sufficient size.

    A pair (a, b) is matched when the shared pixels make up strictly more
    than `fractions[0]` of object a (row total) AND strictly more than
    `fractions[1]` of object b (column total).

    Parameters
    ----------
    psg : array_like, shape (n0, n1)
        Pixel-sharing table of the two label images.
    fractions : (float, float)
        Minimum overlap fraction for the row objects and the column objects.

    Returns
    -------
    matching : ndarray of bool, shape (n0, n1)

    NOTE: a true matching is only guaranteed for fractions > 0.5. Otherwise
    some objects might have degree 2 or more.
    NOTE: the comparison is a strict `>` on float64 ratios, so an overlap
    only marginally greater than the threshold may round to the threshold
    itself and then not match.
    """
    afrac, bfrac = fractions
    psg = np.asarray(psg)
    set0_object_sizes = np.sum(psg, axis=1, keepdims=True)
    set1_object_sizes = np.sum(psg, axis=0, keepdims=True)

    # Divide only where the object is non-empty. (Wrapping the quotient in
    # np.where would still evaluate the division by zero first and emit a
    # RuntimeWarning.) Empty objects keep an overlap fraction of 0.
    frac0 = np.zeros(psg.shape, dtype=np.float64)
    np.divide(psg, set0_object_sizes, out=frac0, where=set0_object_sizes != 0)
    frac1 = np.zeros(psg.shape, dtype=np.float64)
    np.divide(psg, set1_object_sizes, out=frac1, where=set1_object_sizes != 0)

    return (frac0 > afrac) & (frac1 > bfrac)

In [7]:
def precision(lab_gt, lab, iou=0.5, partial_dataset=False):
    '''
    Object-level score TP / (TP + FP + FN) between two label images,
    i.e. "intersection over union" for a graph matching.

    Objects are paired with `matching_iou` at the IoU threshold `iou`.
    TP is the number of matched pairs; the denominator is the number of
    non-zero labels in `lab_gt` plus the number in `lab`, minus TP.

    With `partial_dataset=True` the numerator and denominator are returned
    as a tuple instead of their quotient.
    '''
    matching = matching_iou(pixel_sharing_bipartite(lab_gt, lab), fraction=iou)

    # every object may take part in at most one matched pair
    assert matching.sum(0).max() < 2
    assert matching.sum(1).max() < 2

    gt_ids = set(np.unique(lab_gt))
    gt_ids.discard(0)
    hyp_ids = set(np.unique(lab))
    hyp_ids.discard(0)

    true_positives = matching.sum()
    denominator = len(gt_ids) + len(hyp_ids) - true_positives

    if partial_dataset:
        return true_positives , denominator
    return true_positives / denominator

In [8]:
## full scores
def seg(lab_gt, lab, partial_dataset=False):
    """
    SEG score computed from the pixel_sharing_bipartite of two label images.

    The score is the sum of the IoU values of all matched (ground truth,
    prediction) pairs divided by the number of non-zero ground-truth labels;
    a ground-truth object without a match contributes zero.

    Pairs are taken from `matching_overlap(psg, fractions=(0.5, 0))`: the
    shared pixels must cover strictly more than half of the ground-truth
    object and be non-empty. Label 0 is excluded on both sides.

    With `partial_dataset=True` the summed IoU and the ground-truth object
    count are returned as a tuple instead of their quotient.
    """
    psg = pixel_sharing_bipartite(lab_gt, lab)
    iou = intersection_over_union(psg)

    matching = matching_overlap(psg, fractions=(0.5, 0))
    # index 0 never takes part in the matching
    matching[:, 0] = False
    matching[0, :] = False
    iou_of_matches = iou[matching].sum()

    gt_ids = set(np.unique(lab_gt))
    gt_ids.discard(0)
    n_gt = len(gt_ids)

    if partial_dataset:
        return iou_of_matches , n_gt
    return iou_of_matches / n_gt

In [20]:
# Average precision (AP) of the SIMPLE segmentations against the ground
# truth, evaluated at a range of IoU thresholds.
thresholds = np.arange(0.5,0.95,0.05)

# glob() returns paths in arbitrary, filesystem-dependent order. Sort both
# listings so the i-th ground-truth frame is paired with the i-th predicted
# frame (assumes both folders follow the same file-naming scheme -- TODO confirm).
gt_segs_path = sorted(glob("/Users/prakash/Desktop/metasegData/drosophila/SEG/*.tif"))
label_segs_path = sorted(glob("/Users/prakash/Desktop/MetaSeg_Data_Version1/drosophila/SIMPLE/*.tif"))

if not gt_segs_path:
    raise FileNotFoundError("No ground-truth .tif files found; check the SEG folder path")
if len(label_segs_path) < len(gt_segs_path):
    raise ValueError(
        "Found %d ground-truth frames but only %d predicted frames"
        % (len(gt_segs_path), len(label_segs_path))
    )

# scores[k, t] is the precision of frame t at thresholds[k]. Frames form the
# outer loop so that every image is read from disk once instead of once per
# threshold.
scores = np.empty((len(thresholds), len(gt_segs_path)))
for t, (gt_path, label_path) in enumerate(zip(gt_segs_path, label_segs_path)):
    lab_gt = imread(gt_path)
    lab_hyp = imread(label_path).astype(np.uint16)
    for k, thresh in enumerate(thresholds):
        # precision() is symmetric in its two label images; they are passed
        # in signature order (ground truth first) for readability.
        scores[k, t] = precision(lab_gt, lab_hyp, thresh)

average_precision_score_list = []
for thresh, scores_per_time in zip(thresholds, scores.tolist()):
    # Scores outside [0, 1] (e.g. NaN when neither image contains an object)
    # add nothing to the sum, but the frame still counts in the denominator.
    sum_seg = sum(score for score in scores_per_time if 0 <= score <= 1)

    average_precision_score = sum_seg/len(gt_segs_path)
    average_precision_score_list.append(average_precision_score)
    print("AP score for threshold "+str(thresh)+" :", average_precision_score )

# After the loop `scores_per_time` holds the per-frame scores of the last
# (highest) threshold.
print("Mean AP score is: ", np.mean(np.array(average_precision_score_list)))

AP score for threshold 0.5 : 0.195159181789
AP score for threshold 0.55 : 0.188435277739
AP score for threshold 0.6 : 0.161478589935
AP score for threshold 0.65 : 0.106172115275
AP score for threshold 0.7 : 0.0605930386525
AP score for threshold 0.75 : 0.0419643715145
AP score for threshold 0.8 : 0.0322283418199
AP score for threshold 0.85 : 0.0242938213457
AP score for threshold 0.9 : 0.0159752897147
Mean AP score is:  0.0918111141983


In [10]:
# Per-frame precision scores left over from the final iteration of the
# threshold loop, i.e. for the last threshold evaluated.
scores_per_time

[0.020899053627760251,
 0.017327557480839719,
 0.016174183514774496,
 0.016129032258064516,
 0.015985244389794036,
 0.016103568045468898,
 0.016528925619834711,
 0.016938110749185668,
 0.019658753709198812,
 0.020359281437125749,
 0.020310633213859019,
 0.020985010706638114,
 0.02202845341899954,
 0.021830004644681839,
 0.021217712177121772,
 0.021016617790811338,
 0.021158690176322419,
 0.017830609212481426,
 0.016058394160583942,
 0.013768485466598673,
 0.0095142714071106659,
 0.006422924901185771,
 0.0071174377224199285,
 0.0060851926977687626,
 0.010047593865679535,
 0.0120415982484948,
 0.014061654948620876,
 0.017213555675094135,
 0.020273972602739727,
 0.02001000500250125,
 0.02133194588969823,
 0.022267206477732792,
 0.024594453165881738,
 0.023845007451564829,
 0.023868722028841372,
 0.023928215353938187,
 0.023715415019762844,
 0.022515907978463045,
 0.021749408983451537,
 0.024365987071108902,
 0.025276461295418641,
 0.030494216614090432,
 0.039144050104384133,
 0.0420212765