In [11]:
import glob
import logging
import os
import sys

import h5py
import argparse
import numpy as np
import scipy.ndimage
from scipy.optimize import linear_sum_assignment
from skimage.segmentation import relabel_sequential


class Metrics:
    def __init__(self):
        self.metricsDict = {}
        self.metricsArray = []


    def addTable(self, name, dct=None):
        levels = name.split(".")
        if dct is None:
            dct = self.metricsDict
        if levels[0] not in dct:
            dct[levels[0]] = {}
        if len(levels) > 1:
            name = ".".join(levels[1:])
            self.addTable(name, dct[levels[0]])

    def getTable(self, name, dct=None):
        levels = name.split(".")
        if dct is None:
            dct = self.metricsDict
        if len(levels) == 1:
            return dct[levels[0]]
        else:
            name = ".".join(levels[1:])
            return self.getTable(name, dct=dct[levels[0]])

    def addMetric(self, table, name, value):
        self.metricsArray.append(value)
        tbl = self.getTable(table)
        tbl[name] = value


def maybe_crop(pred_labels, gt_labels, overlapping_inst=False):
    if overlapping_inst:
        if gt_labels.shape[1:] == pred_labels.shape[1:]:
            return pred_labels, gt_labels
        else:
            # todo: add other cases
            raise NotImplementedError("Sorry, cropping for overlapping "
                                      "instances not implemented yet!")
    else:
        if gt_labels.shape == pred_labels.shape:
            return pred_labels, gt_labels
        if gt_labels.shape[0] > pred_labels.shape[0]:
            bigger_arr = gt_labels
            smaller_arr = pred_labels
            swapped = False
        else:
            bigger_arr = pred_labels
            smaller_arr = gt_labels
            swapped = True
        begin = (np.array(bigger_arr.shape) -
                 np.array(smaller_arr.shape)) // 2
        end = np.array(bigger_arr.shape) - begin
        if len(bigger_arr.shape) == 2:
            bigger_arr = bigger_arr[begin[0]:end[0],
                                    begin[1]:end[1]]
        else:
            if (np.array(bigger_arr.shape) -
                np.array(smaller_arr.shape))[2] % 2 == 1:
                end[2] -= 1
            bigger_arr = bigger_arr[begin[0]:end[0],
                                    begin[1]:end[1],
                                    begin[2]:end[2]]
        if not swapped:
            gt_labels = bigger_arr
            pred_labels = smaller_arr
        else:
            pred_labels = bigger_arr
            gt_labels = smaller_arr
        print("gt shape cropped %s", gt_labels.shape)
        print("pred shape cropped %s", pred_labels.shape)

        return pred_labels, gt_labels

def evaluate_file(pred_labels, gt_labels, background=0,
                  foreground_only=False, use_linear_sum_assignment=True):
   
    
    print("prediction min %f, max %f, shape %s", np.min(pred_labels),
                 np.max(pred_labels), pred_labels.shape)
    pred_labels = np.squeeze(pred_labels)
    print("prediction shape %s", pred_labels.shape)

    print("gt min %f, max %f, shape %s", np.min(gt_labels),
                 np.max(gt_labels), gt_labels.shape)
    if gt_labels.shape[0] == 1:
        gt_labels.shape = gt_labels.shape[1:]
    gt_labels = np.squeeze(gt_labels)
    if gt_labels.ndim > pred_labels.ndim:
        gt_labels = np.max(gt_labels, axis=0)
    # check if pred.dim < gt.dim
    # if pred_labels.ndim < gt_labels.ndim:
    #     print("WARNING: changing prediction to one instance per channel")
    # lbls = np.unique(pred_labels)
    # pred_1instpch = np.zeros((np.sum(lbls > 0),
    #                                           ) + pred_labels.shape,
    #                                          dtype=pred_labels.dtype)
    # i = 0
    # for lbl in lbls:
    #     if lbl == 0:
    #         continue
    #     pred_1instpch[i][pred_labels == lbl] = i + 1
    #     print("check: ", i, lbl, np.sum(pred_1instpch == i + 1),
    #           np.sum(pred_labels == lbl))
    #     i += 1

    # pred_labels = pred_1instpch
    print("gt shape %s", gt_labels.shape)

    # heads up: should not crop channel dimensions, assuming channels first
    overlapping_inst = False
    pred_labels, gt_labels = maybe_crop(pred_labels, gt_labels,
                                        overlapping_inst)

    # if pred_labels.shape[0] == 536:
    #     print(pred_labels.shape, gt_labels.shape)
    #     pred_labels = pred_labels[12:-12, 12:-12]
    #     gt_labels = gt_labels[12:-12, 12:-12]
    #     print(pred_labels.shape, gt_labels.shape)

    if foreground_only:
        pred_labels[gt_labels==0] = 0

    print("processing")

    # relabel gt labels in case of binary mask per channel
    if overlapping_inst and np.max(gt_labels) == 1:
        for i in range(gt_labels.shape[0]):
            gt_labels[i] = gt_labels[i] * (i + 1)

    if use_linear_sum_assignment:
        return evaluate_linear_sum_assignment(gt_labels, pred_labels,
                                             overlapping_inst=False, filterSz=None,
                                             visualize=False)

    # get gt cell ids and the size of the corresponding cell
    gt_labels_list, gt_counts = np.unique(gt_labels, return_counts=True)
    gt_labels_count_dict = {}
    print("%s %s", gt_labels_list, gt_counts)
    for (l, c) in zip(gt_labels_list, gt_counts):
        gt_labels_count_dict[l] = c

    # get pred cell ids
    pred_labels_list, pred_counts = np.unique(pred_labels,
                                              return_counts=True)
    print("%s %s", pred_labels_list, pred_counts)
    pred_labels_count_dict = {}
    for (l, c) in zip(pred_labels_list, pred_counts):
        pred_labels_count_dict[l] = c

    # get overlapping labels
    if overlapping_inst:
        pred_tile = [1,] * pred_labels.ndim
        pred_tile[0] = gt_labels.shape[0]
        gt_tile = [1,] * gt_labels.ndim
        gt_tile[1] = pred_labels.shape[0]
        pred_tiled = np.tile(pred_labels, pred_tile).flatten()
        gt_tiled = np.tile(gt_labels, gt_tile).flatten()
        mask = np.logical_or(pred_tiled > 0, gt_tiled > 0)
        overlay = np.array([
            pred_tiled[mask],
            gt_tiled[mask]
        ])
        overlay_labels, overlay_labels_counts = np.unique(
            overlay, return_counts=True, axis=1)
        overlay_labels = np.transpose(overlay_labels)
    else:
        overlay = np.array([pred_labels.flatten(),
                            gt_labels.flatten()])
        print("overlay shape %s", overlay.shape)
        # get overlaying cells and the size of the overlap
        overlay_labels, overlay_labels_counts = np.unique(overlay,
                                             return_counts=True, axis=1)
        overlay_labels = np.transpose(overlay_labels)

    # identify overlaying cells where more than 50% of gt cell is covered
    matchesSEG = np.asarray([c > 0.5 * float(gt_counts[gt_labels_list == v])
        for (u,v), c in zip(overlay_labels, overlay_labels_counts)],
                            dtype=np.bool)

    # get their ids
    matches_labels = overlay_labels[matchesSEG]

    # remove background
    if background is not None:
        pred_labels_list = pred_labels_list[pred_labels_list != background]
        gt_labels_list = gt_labels_list[gt_labels_list != background]

    matches_mat = np.zeros((len(pred_labels_list), len(gt_labels_list)))
    for (u, v) in matches_labels:
        if u > 0 and v > 0:
            matches_mat[np.where(pred_labels_list == u),
                        np.where(gt_labels_list == v)] = 1

    diceGT = {}
    iouGT = {}
    segGT = {}
    diceP = {}
    iouP = {}
    segP = {}
    segPrev = {}
    for (u,v), c in zip(overlay_labels, overlay_labels_counts):
        dice = 2.0 * c / (gt_labels_count_dict[v] + pred_labels_count_dict[u])
        iou = c / (gt_labels_count_dict[v] + pred_labels_count_dict[u] - c)

        if c > 0.5 * gt_labels_count_dict[v]:
            seg = iou
        else:
            seg = 0
        if c > 0.5 * pred_labels_count_dict[u]:
            seg2 = iou
        else:
            seg2 = 0

        if v not in diceGT:
            diceGT[v] = []
            iouGT[v] = []
            segGT[v] = []
        if u not in diceP:
            diceP[u] = []
            iouP[u] = []
            segP[u] = []
            segPrev[u] = []
        diceGT[v].append(dice)
        iouGT[v].append(iou)
        segGT[v].append(seg)
        diceP[u].append(dice)
        iouP[u].append(iou)
        segP[u].append(seg)
        segPrev[u].append(seg2)

    if background is not None:
        iouP.pop(background)
        iouGT.pop(background)
        diceP.pop(background)
        diceGT.pop(background)
        segP.pop(background)
        segPrev.pop(background)
        segGT.pop(background)

    dice = 0
    cnt = 0
    for (k, vs) in diceGT.items():
        vs = sorted(vs, reverse=True)
        dice += vs[0]
        cnt += 1
    diceGT = dice/max(1, cnt)

    dice = 0
    cnt = 0
    for (k, vs) in diceP.items():
        vs = sorted(vs, reverse=True)
        dice += vs[0]
        cnt += 1
    diceP = dice/max(1, cnt)

    iou = []
    instances = gt_labels.copy().astype(np.float32)
    for (k, vs) in iouGT.items():
        vs = sorted(vs, reverse=True)
        iou.append(vs[0])
        instances[instances==k] = vs[0]
    iouGT = np.array(iou)
    iouGTMn = np.mean(iouGT)

    iou = []
    iouIDs = []
    instances = pred_labels.copy().astype(np.float32)
    for (k, vs) in iouP.items():
        vs = sorted(vs, reverse=True)
        iou.append(vs[0])
        iouP[k] = vs
        iouIDs.append(k)
        instances[instances==k] = vs[0]

    iouP_2 = np.array(iou)
    iouIDs = np.array(iouIDs)
    iouPMn = np.mean(iouP_2)

    seg = 0
    cnt = 0
    for (k, vs) in segGT.items():
        vs = sorted(vs, reverse=True)
        seg += vs[0]
        cnt += 1
    segGT = seg/max(1, cnt)

    seg = 0
    cnt = 0
    for (k, vs) in segP.items():
        vs = sorted(vs, reverse=True)
        seg += vs[0]
        cnt += 1
    segP = seg/max(1, cnt)

    seg = 0
    cnt = 0
    for (k, vs) in segPrev.items():
        vs = sorted(vs, reverse=True)
        seg += vs[0]
        cnt += 1
    segPrev = seg/max(1, cnt)

    # non-split vertices num non-empty cols - num non-empty rows
    # (more than one entry in col: predicted cell with more than one
    # ground truth cell assigned)
    # (other way around not possible due to 50% rule)
    ns = np.sum(np.count_nonzero(matches_mat, axis=0)) \
            - np.sum(np.count_nonzero(matches_mat, axis=1) > 0)
    ns = int(ns)

    # false negative: empty cols
    # (no predicted cell for ground truth cell)
    fn = np.sum(np.sum(matches_mat, axis=0) == 0)
    # tmp = np.sum(matches_mat, axis=0)==0
    # for i in range(len(tmp)):
    #     print(i, tmp[i], gt_labels_list[i])
    fn = int(fn)

    # false positive: empty rows
    # (predicted cell for non existing ground truth cell)
    fp = np.sum(np.sum(matches_mat, axis=1) == 0)
    # tmp = np.sum(matches_mat, axis=1)==0
    # for i in range(len(tmp)):
    #     print(i, tmp[i], pred_labels_list[i])
    # print(np.sum(matches_mat, axis=1)==0)
    fp = int(fp)

    # true positive: row with single entry (can be 0, 1, or more)
    tpP = np.sum(np.sum(matches_mat, axis=1) == 1)
    tpP = int(tpP)

    # true positive: non-empty col (can only be 0 or 1)
    tpGT = np.sum(np.sum(matches_mat, axis=0) > 0)
    tpGT = int(tpGT)


    metrics = Metrics()
    tblNameGen = "general"
    metrics.addTable(tblNameGen)
    metrics.addMetric(tblNameGen, "Num GT", len(gt_labels_list))
    metrics.addMetric(tblNameGen, "Num Pred", len(pred_labels_list))
    metrics.addMetric(tblNameGen, "GT/Ref -> Pred mean dice", diceGT)
    metrics.addMetric(tblNameGen, "Pred -> GT/Ref mean dice", diceP)
    metrics.addMetric(tblNameGen, "GT/Ref -> Pred mean iou", iouGTMn)
    metrics.addMetric(tblNameGen, "Pred -> GT/Ref mean iou", iouPMn)
    metrics.addMetric(tblNameGen, "GT/Ref -> Pred mean seg", segGT)
    metrics.addMetric(tblNameGen, "Pred -> GT/Ref mean seg", segP)
    metrics.addMetric(tblNameGen, "Pred -> GT/Ref mean seg rev", segPrev)
    metrics.addMetric(tblNameGen, "Pred -> GT/Ref NS", ns)
    metrics.addMetric(tblNameGen, "Pred -> GT/Ref FP", fp)
    metrics.addMetric(tblNameGen, "Pred -> GT/Ref TP", tpP)
    metrics.addMetric(tblNameGen, "GT/Ref -> Pred FN", fn)
    metrics.addMetric(tblNameGen, "GT/Ref -> Pred TP", tpGT)

    ths = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
    aps = []
    metrics.addTable("confusion_matrix")
    for th in ths:
        tblname = "confusion_matrix.th_"+str(th).replace(".","_")
        metrics.addTable(tblname)
        apTP = 0
        for pID in np.nonzero(iouP_2 > th)[0]:
            if len(iouP[iouIDs[pID]]) == 0:
                pass
            elif len(iouP[iouIDs[pID]]) == 1:
                apTP += 1
            elif len(iouP[iouIDs[pID]]) > 1 and iouP[iouIDs[pID]][1] < th:
                apTP += 1
        metrics.addMetric(tblname, "AP_TP", apTP)
        apTP = np.count_nonzero(iouP_2[iouP_2>th])
        apFP = np.count_nonzero(iouP_2[iouP_2<=th])
        apFN = np.count_nonzero(iouGT[iouGT<=th])
        metrics.addMetric(tblname, "AP_TP", apTP)
        metrics.addMetric(tblname, "AP_FP", apFP)
        metrics.addMetric(tblname, "AP_FN", apFN)
        ap = 1.*(apTP) / max(1, apTP + apFN + apFP)
        aps.append(ap)
        metrics.addMetric(tblname, "AP", ap)
        precision = 1.*(apTP) / max(1, len(pred_labels_list))
        metrics.addMetric(tblname, "precision", precision)
        recall = 1.*(apTP) / max(1, len(gt_labels_list))
        metrics.addMetric(tblname, "recall", recall)
        if (precision + recall) > 0:
            fscore = (2. * precision * recall) / max(1, precision + recall)
        else:
            fscore = 0.0
        metrics.addMetric(tblname, 'fscore', fscore)

    avAP = np.mean(aps)
    metrics.addMetric("confusion_matrix", "avAP", avAP)

    return metrics.metricsDict


def evaluate_linear_sum_assignment(gt_labels, pred_labels,
                                   overlapping_inst=False, filterSz=None,
                                   visualize=False):
    if filterSz is not None:
        ls, cs = np.unique(pred_labels, return_counts=True)
        pred_labels2 = np.copy(pred_labels)
        print(sorted(zip(cs, ls)))
        for l, c in zip(ls, cs):
            if c < filterSz:
                pred_labels[pred_labels==l] = 0
            # else:
            #     pred_labels2[pred_labels==l] = 0
        # with h5py.File(outFn + ".hdf", 'w') as f:
        #     f.create_dataset(
        #         'volumes/small_inst',
        #         data=pred_labels2,
        #         compression='gzip')
    pred_labels_rel, _, _ = relabel_sequential(pred_labels)
    gt_labels_rel, _, _ = relabel_sequential(gt_labels)

    if overlapping_inst:
        pred_tile = [1, ] * pred_labels_rel.ndim
        pred_tile[0] = gt_labels_rel.shape[0]
        gt_tile = [1, ] * gt_labels_rel.ndim
        gt_tile[1] = pred_labels_rel.shape[0]
        pred_tiled = np.tile(pred_labels_rel, pred_tile).flatten()
        gt_tiled = np.tile(gt_labels_rel, gt_tile).flatten()
        mask = np.logical_or(pred_tiled > 0, gt_tiled > 0)
        overlay = np.array([
            pred_tiled[mask],
            gt_tiled[mask]
        ])
        overlay_labels, overlay_labels_counts = np.unique(
            overlay, return_counts=True, axis=1)
        overlay_labels = np.transpose(overlay_labels)
    else:
        overlay = np.array([pred_labels_rel.flatten(),
                            gt_labels_rel.flatten()])
        print("overlay shape relabeled %s", overlay.shape)
        # get overlaying cells and the size of the overlap
        overlay_labels, overlay_labels_counts = np.unique(
            overlay, return_counts=True, axis=1)
        overlay_labels = np.transpose(overlay_labels)

    # get gt cell ids and the size of the corresponding cell
    gt_labels_list, gt_counts = np.unique(gt_labels_rel, return_counts=True)
    gt_labels_count_dict = {}
    print("%s %s", gt_labels_list, gt_counts)
    for (l,c) in zip(gt_labels_list, gt_counts):
        gt_labels_count_dict[l] = c

    # get pred cell ids
    pred_labels_list, pred_counts = np.unique(pred_labels_rel,
                                              return_counts=True)
    print("%s %s", pred_labels_list, pred_counts)

    pred_labels_count_dict = {}
    for (l,c) in zip(pred_labels_list, pred_counts):
        pred_labels_count_dict[l] = c

    num_pred_labels = int(np.max(pred_labels_rel))
    num_gt_labels = int(np.max(gt_labels_rel))
    num_matches = min(num_gt_labels, num_pred_labels)
    iouMat = np.zeros((num_gt_labels+1, num_pred_labels+1),
                      dtype=np.float32)
    recallMat = np.zeros((num_gt_labels+1, num_pred_labels+1),
                         dtype=np.float32)
    precMat = np.zeros((num_gt_labels+1, num_pred_labels+1),
                       dtype=np.float32)
    fscoreMat = np.zeros((num_gt_labels+1, num_pred_labels+1),
                         dtype=np.float32)

    for (u,v), c in zip(overlay_labels, overlay_labels_counts):
        iou = c / (gt_labels_count_dict[v] + pred_labels_count_dict[u] - c)

        iouMat[v, u] = iou
        recallMat[v, u] = c / gt_labels_count_dict[v]
        precMat[v, u] = c / pred_labels_count_dict[u]
        fscoreMat[v, u] = 2 * (precMat[v, u] * recallMat[v, u]) / \
                              (precMat[v, u] + recallMat[v, u])
    iouMat = iouMat[1:, 1:]
    recallMat = recallMat[1:, 1:]
    precMat = precMat[1:, 1:]
    fscoreMat = fscoreMat[1:, 1:]

    metrics = Metrics()
    tblNameGen = "general"
    metrics.addTable(tblNameGen)
    metrics.addMetric(tblNameGen, "Num GT", num_gt_labels)
    metrics.addMetric(tblNameGen, "Num Pred", num_pred_labels)

    ths = [0.1, 0.2, 0.3, 0.4, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
    aps = []
    metrics.addTable("confusion_matrix")
    for th in ths:
        tblname = "confusion_matrix.th_"+str(th).replace(".", "_")
        metrics.addTable(tblname)
        fscore = 0
        if num_matches > 0 and np.max(iouMat) > th:
            costs = -(iouMat >= th).astype(float) - iouMat / (2*num_matches)
            print(f'start computing lin sum assign for th {th}')
            gt_ind, pred_ind = linear_sum_assignment(costs)
            assert num_matches == len(gt_ind) == len(pred_ind)
            match_ok = iouMat[gt_ind, pred_ind] >= th
            tp = np.count_nonzero(match_ok)
            fscore_cnt = 0
            for idx, match in enumerate(match_ok):
                if match:
                    fscore = fscoreMat[gt_ind[idx], pred_ind[idx]]
                    if fscore >= 0.8:
                        fscore_cnt += 1
        else:
            tp = 0
            fscore_cnt = 0
        if visualize and tp > 0 and th == 0.5:
            vis_tp = np.zeros_like(gt_labels_rel, dtype=np.float32)
            vis_fp = np.zeros_like(gt_labels_rel, dtype=np.float32)
            vis_fn = np.zeros_like(gt_labels_rel, dtype=np.float32)
            vis_tp_seg = np.zeros_like(gt_labels_rel, dtype=np.float32)
            vis_tp_seg2 = np.zeros_like(gt_labels_rel, dtype=np.float32)
            vis_fp_seg = np.zeros_like(gt_labels_rel, dtype=np.float32)
            vis_fn_seg = np.zeros_like(gt_labels_rel, dtype=np.float32)
            if len(gt_labels_rel.shape) == 3:
                vis_fp_seg_bnd = np.zeros_like(gt_labels_rel, dtype=np.float32)
                vis_fn_seg_bnd = np.zeros_like(gt_labels_rel, dtype=np.float32)

            cntrs_gt = scipy.ndimage.measurements.center_of_mass(
                gt_labels_rel > 0,
                gt_labels_rel, sorted(list(np.unique(gt_labels_rel)))[1:])
            cntrs_pred = scipy.ndimage.measurements.center_of_mass(
                pred_labels_rel > 0,
                pred_labels_rel, sorted(list(np.unique(pred_labels_rel)))[1:])
            sz = 1
            for gti, pi, in zip(gt_ind, pred_ind):
                if iouMat[gti, pi] < th:
                    vis_fn_seg[gt_labels_rel == gti+1] = 1
                    if len(gt_labels_rel.shape) == 3:
                        set_boundary(gt_labels_rel, gti+1,
                                     vis_fn_seg_bnd)
                    vis_fp_seg[pred_labels_rel == pi+1] = 1
                    if len(gt_labels_rel.shape) == 3:
                        set_boundary(pred_labels_rel, pi+1,
                                     vis_fp_seg_bnd)
                    cntr = cntrs_gt[gti]
                    if len(gt_labels_rel.shape) == 3:
                        vis_fn[int(cntr[0]), int(cntr[1]), int(cntr[2])] = 1
                    else:
                        vis_fn[int(cntr[0]), int(cntr[1])] = 1
                    cntr = cntrs_pred[pi]
                    if len(gt_labels_rel.shape) == 3:
                        vis_fp[int(cntr[0]), int(cntr[1]), int(cntr[2])] = 1
                    else:
                        vis_fp[int(cntr[0]), int(cntr[1])] = 1
                else:
                    vis_tp_seg[gt_labels_rel == gti+1] = 1
                    cntr = cntrs_gt[gti]
                    if len(gt_labels_rel.shape) == 3:
                        vis_tp[int(cntr[0]), int(cntr[1]), int(cntr[2])] = 1
                    else:
                        vis_tp[int(cntr[0]), int(cntr[1])] = 1
                    vis_tp_seg2[pred_labels_rel == pi+1] = 1
            vis_tp = scipy.ndimage.gaussian_filter(vis_tp, sz, truncate=sz)
            for gti in range(num_gt_labels):
                if gti in gt_ind:
                    continue
                vis_fn_seg[gt_labels_rel == gti+1] = 1
                if len(gt_labels_rel.shape) == 3:
                    set_boundary(gt_labels_rel, gti+1,
                                 vis_fn_seg_bnd)
                cntr = cntrs_gt[gti]
                if len(gt_labels_rel.shape) == 3:
                    vis_fn[int(cntr[0]), int(cntr[1]), int(cntr[2])] = 1
                else:
                    vis_fn[int(cntr[0]), int(cntr[1])] = 1
            vis_fn = scipy.ndimage.gaussian_filter(vis_fn, sz, truncate=sz)
            for pi in range(num_pred_labels):
                if pi in pred_ind:
                    continue
                vis_fp_seg[pred_labels_rel == pi+1] = 1
                if len(gt_labels_rel.shape) == 3:
                    set_boundary(pred_labels_rel, pi+1,
                                 vis_fp_seg_bnd)
                cntr = cntrs_pred[pi]
                if len(gt_labels_rel.shape) == 3:
                    vis_fp[int(cntr[0]), int(cntr[1]), int(cntr[2])] = 1
                else:
                    vis_fp[int(cntr[0]), int(cntr[1])] = 1
            vis_fp = scipy.ndimage.gaussian_filter(vis_fp, sz, truncate=sz)
            vis_tp = vis_tp/np.max(vis_tp)
            vis_fp = vis_fp/np.max(vis_fp)
            vis_fn = vis_fn/np.max(vis_fn)
           
        metrics.addMetric(tblname, "Fscore_cnt", fscore_cnt)
        fp = num_pred_labels - tp
        fn = num_gt_labels - tp
        metrics.addMetric(tblname, "AP_TP", tp)
        metrics.addMetric(tblname, "AP_FP", fp)
        metrics.addMetric(tblname, "AP_FN", fn)
        ap = tp / max(1, tp + fn + fp)
        aps.append(ap)
        metrics.addMetric(tblname, "AP", ap)
        precision = tp / max(1, tp + fp)
        metrics.addMetric(tblname, "precision", precision)
        recall = tp / max(1, tp + fn)
        metrics.addMetric(tblname, "recall", recall)
        if (precision + recall) > 0:
            fscore = (2. * precision * recall) / max(1, precision + recall)
        else:
            fscore = 0.0
        metrics.addMetric(tblname, 'fscore', fscore)

    avAP19 = np.mean(aps)
    avAP59 = np.mean(aps[4:])
    metrics.addMetric("confusion_matrix", "avAP", avAP19)
    metrics.addMetric("confusion_matrix", "avAP59", avAP59)
    metrics.addMetric("confusion_matrix", "avAP19", avAP19)

    return metrics.metricsDict


def set_boundary(labels_rel, label, target):
    coords_z, coords_y, coords_x = np.nonzero(labels_rel == label)
    coords = {}
    for z,y,x in zip(coords_z, coords_y, coords_x):
        coords.setdefault(z, []).append((z, y, x))
    max_z = -1
    max_z_len = -1
    for z, v in coords.items():
        if len(v) > max_z_len:
            max_z_len = len(v)
            max_z = z
    tmp = np.zeros_like(labels_rel[max_z], dtype=np.float32)
    tmp = labels_rel[max_z]==label
    struct = scipy.ndimage.generate_binary_structure(2, 2)
    eroded_tmp = scipy.ndimage.binary_erosion(
        tmp,
        iterations=1,
        structure=struct,
        border_value=1)
    bnd = np.logical_xor(tmp, eroded_tmp)
    target[max_z][bnd] = 1
    
def get_patch(img, patch_shape, x_start, y_start, z_start):
    '''random tile strategy'''
    x_end = x_start + patch_shape[0]
    y_end = y_start + patch_shape[1]
    z_end = z_start + patch_shape[2]
    return img[x_start:x_end, y_start:y_end, z_start:z_end, :]

def patch(img, patch_shape):
    patches = []
    
    num_x = img.shape[0] // patch_shape[0]
    num_y = img.shape[1] // patch_shape[1]
    num_z = img.shape[2] // patch_shape[2]
    
    for x in range(num_x):
        for y in range(num_y):
            for z in range(num_z):
            
                x_start = patch_shape[0] * x
                y_start = patch_shape[1] * y
                z_start = patch_shape[2] * z

                patch = get_patch(img, patch_shape, x_start, y_start, z_start)
                patches.append(patch)
    
    return np.array(patches)

def re_patch(patches, img_shape):
    patch_shape = patches.shape[1:]
    img = np.zeros(img_shape)
    num_x = img_shape[0] // patch_shape[0]
    num_y = img_shape[1] // patch_shape[1]
    num_z = img_shape[2] // patch_shape[2]
    
    current_patch = 0
    for x in range(num_x):
        for y in range(num_y):
            for z in range(num_z):
                x_start = patch_shape[0] * x
                y_start = patch_shape[1] * y
                z_start = patch_shape[2] * z
                x_end = x_start + patch_shape[0]
                y_end = y_start + patch_shape[1]
                z_end = z_start + patch_shape[2]
                img[x_start:x_end, y_start:y_end, z_start:z_end, :] = patches[current_patch]
                
                current_patch = current_patch+1
    return img

In [13]:
from skimage import io
instance_mask = io.imread('data/masks/leaf3.tiff').astype(np.uint16)
# Alex's mask has several instanevs for the background (value below 200), here we merge them to value 0
instance_mask[instance_mask <= 200] = 0

#pred_labels = np.load(res_file)
#gt_labels  = np.load(gt_file)
pred_labels = instance_mask +100
gt_labels = instance_mask

metr = evaluate_file(pred_labels, gt_labels, background=0,
                  foreground_only=False, use_linear_sum_assignment=True)

print(metr)

prediction min %f, max %f, shape %s 100 2496 (256, 256, 256)
prediction shape %s (256, 256, 256)
gt min %f, max %f, shape %s 0 2396 (256, 256, 256)
gt shape %s (256, 256, 256)
processing
overlay shape relabeled %s (2, 16777216)
%s %s [   0    1    2 ... 1138 1139 1140] [7203224   10730    4733 ...    1973    1849   30085]
%s %s [   1    2    3 ... 1139 1140 1141] [7203224   10730    4733 ...    1973    1849   30085]
start computing lin sum assign for th 0.1
start computing lin sum assign for th 0.2
start computing lin sum assign for th 0.3
start computing lin sum assign for th 0.4
start computing lin sum assign for th 0.5
start computing lin sum assign for th 0.55
start computing lin sum assign for th 0.6
start computing lin sum assign for th 0.65
start computing lin sum assign for th 0.7
start computing lin sum assign for th 0.75
start computing lin sum assign for th 0.8
start computing lin sum assign for th 0.85
start computing lin sum assign for th 0.9
start computing lin sum assign