In [1]:
import numpy as np
from skimage.segmentation import relabel_sequential
from scipy.optimize import linear_sum_assignment

def merge_label_slices(imgs, iou_threshold = 0.0):
    # relabel first plane -> first "corrected" plane of result
    res = [relabel_sequential(imgs[0])[0]]
    max_label = 0
    for img in imgs[1:]:
        # correct each other plane in comparison to last (corrected) plane
        next_corr, max_label = correct_next_plane(res[-1], img, iou_threshold, max_label)
        res.append(next_corr)
    # relabel everything, as we might have skipped some numbers
    return relabel_sequential(np.stack(res))[0]


def correct_next_plane(img1, img2, iou_threshold = 0.0, max_label=0):
    
    
    labs1 = np.sort(np.unique(img1))[1:]
    
    if len(labs1) < 1:
        img2_relab, _, _ = relabel_sequential(img2, max_label+1)
        return img2_relab, np.max([max_label, np.max(img2_relab)])
    
    # relabel img2 so everything has a higher label than img1
    img2_relab, _, _ = relabel_sequential(img2, np.max([max_label, labs1.max()])+1)
    labs2 = np.sort(np.unique(img2_relab))[1:]

    # maximum number of objects in one of the two images
    n_labs = np.max([len(labs1), len(labs2)])

    # init IOU matrix
    ious = np.zeros((n_labs, n_labs))

    # fill IOU matrix (rows: objects in img1, cols: objects in img2)
    for i, lab1 in enumerate(labs1):
        for j, lab2 in enumerate(labs2):

            bin1 = img1 == lab1
            bin2 = img2_relab == lab2

            intersection = bin1 & bin2
            union = bin1 | bin2
            iou = intersection.sum() / union.sum()

            ious[i, j] = iou

    # find maximum IOU matching
    row_idx, col_idx = linear_sum_assignment(ious, True)

    img2_corr = img2_relab.copy()

    for ri, ci in zip(row_idx, col_idx):
        
        # if we have a valid match
        # (object exists in img1 and img2)
        valid_match = ri < len(labs1) and ci < len(labs2)
        # and the IOU of the matched objects is bigger than threshold
        if valid_match and ious[ri,ci] > iou_threshold:
            # copy the label from img1 to result
            img2_corr[img2_relab == labs2[ci]] = labs1[ri]
        
    return img2_corr, np.max([max_label, np.max(img2_corr)])

In [3]:
from skimage.io import imread
from glob import glob

# read some test data
Groundtruth = [imread(f) for f in sorted(glob("/scratch/leonhardt_ba_ss21/gabriel/BA/3D/*corrected.tif"))]
Stardist_21_05_split = [imread(f) for f in sorted(glob("/scratch/leonhardt_ba_ss21/gabriel/BA/3D/*_model_21_05.tif"))]

# test on one pair
gt_mask = Groundtruth[0]
pred_mask = Stardist_21_05_split[0]

In [5]:
### calculate precision/recall as well as IOU for true positives
### for one gt_mask / pred_mask pair
### -> do in loop for many pairs

# use merge_label_slices to match GT with prediction
gt_mask_matched, pred_mask_matched = merge_label_slices([gt_mask, pred_mask])

# get sets of all labels in matched GT and pred
labels_gt = set(np.unique(gt_mask_matched))
labels_pred = set(np.unique(pred_mask_matched))

true_positive_labels = labels_gt & labels_pred # labels that appear in both GT and prediction
false_positive_labels = labels_pred - labels_gt # labels that appear only in prediction
false_negative_labels = labels_gt - labels_pred # labels that appear only in GT

# calculate precision/recall: https://en.wikipedia.org/wiki/Precision_and_recall
precision = len(true_positive_labels) / len(true_positive_labels | false_positive_labels)
recall = len(true_positive_labels) / len(true_positive_labels | false_negative_labels)

# calculate IOUs for all true positives
ious = []
for label in true_positive_labels:
    
    # same as old code, but we have already found matching
    # so we do not need to loop over all combinations
    gt_binary = gt_mask_matched == label
    pred_binary = pred_mask_matched == label

    intersection = gt_binary & pred_binary
    union = gt_binary | pred_binary

    iou = intersection.sum() / union.sum()
    ious.append(iou)
    
print(precision, recall, np.mean(ious))

0.6071428571428571 0.9444444444444444 0.4239890182000488
