Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
Merge pull request #134 from inferno-pytorch/retain-graph
Browse files Browse the repository at this point in the history
Enable 3d rand metrics
  • Loading branch information
constantinpape committed Aug 14, 2018
2 parents 1d2dc58 + becaa62 commit ebeea3c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
28 changes: 22 additions & 6 deletions inferno/extensions/metrics/arand.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,25 @@ class ArandScore(Metric):
----------
[1]: http://journal.frontiersin.org/article/10.3389/fnana.2015.00142/full#h3
"""
def __init__(self, average_slices=True):
self.average_slices = average_slices

def forward(self, prediction, target):
assert(len(prediction) == len(target)), "%i, %i" % (len(prediction), len(target))
prediction = prediction.cpu().numpy().squeeze()
target = target.cpu().numpy().squeeze()
return np.mean([adapted_rand(pred, targ)[0]
for pred, targ in zip(prediction, target)])
if self.average_slices:
return np.mean([adapted_rand(pred, targ)[0]
for pred, targ in zip(prediction, target)])
else:
return adapted_rand(prediction, target)[0]


class ArandError(ArandScore):
"""Arand Error = 1 - <arand score>"""
def __init__(self, **super_kwargs):
super(ArandError, self).__init__(**super_kwargs)

def forward(self, prediction, target):
return 1. - super(ArandError, self).forward(prediction, target)

Expand Down Expand Up @@ -63,9 +72,16 @@ def adapted_rand(seg, gt):
if np.any(gt == 0):
logger.debug("Zeros in ground truth, 0's will be ignored.")

if np.all(seg == 0) or np.all(gt == 0):
logger.error("Either segmentation or groundtruth are all zeros.")
return [0, 0, 0]
seg_zeros = np.all(seg == 0)
gt_zeros = np.all(gt == 0)
if seg_zeros or gt_zeros:
if seg_zeros:
logger.error("Segmentation is all zeros.")
else:
print(gt.shape)
print(np.unique(gt))
logger.error("Groundtruth is all zeros.")
return 0, 0, 0

# segA is truth, segB is query
segA = np.ravel(gt)
Expand Down Expand Up @@ -112,4 +128,4 @@ def adapted_rand(seg, gt):
precision = float(sum_p_ij) / sum_b
recall = float(sum_p_ij) / sum_a
f_score = 2.0 * precision * recall / (precision + recall)
return [f_score, precision, recall]
return f_score, precision, recall
3 changes: 2 additions & 1 deletion inferno/extensions/metrics/cremi_score.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from .voi import voi
from .arand import adapted_rand

Expand All @@ -11,5 +12,5 @@ def cremi_metrics(seg, gt, no_seg_ignore=True):
seg += 1
vi_s, vi_m = voi(seg, gt)
rand = 1. - adapted_rand(seg, gt)[0]
cs = (vi_s + vi_m + rand) / 3
cs = np.sqrt((vi_s + vi_m) * rand)
return cs, vi_s, vi_m, rand

0 comments on commit ebeea3c

Please sign in to comment.