Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
- modify arand metrics for 2d and 3d inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Aug 16, 2018
1 parent 57cf4dd commit f5a1d70
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions inferno/extensions/metrics/arand.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,33 @@ class ArandScore(Metric):
def __init__(self, average_slices=True):
self.average_slices = average_slices

def forward(self, prediction, target):
assert(len(prediction) == len(target)), "%i, %i" % (len(prediction), len(target))
prediction = prediction.cpu().numpy().squeeze()
target = target.cpu().numpy().squeeze()
if self.average_slices:
# compute the arand score for a prediction target pair
def _arand_for_tensor(self, prediction, target):
ndim = prediction.ndim
average_slices = self.average_slices and ndim == 3

if average_slices:
# average the arand values over the 3d slices
evaluation_values = [adapted_rand(pred, targ)
for pred, targ in zip(prediction, target)]
return np.mean([eval_val[0] for eval_val in evaluation_values if eval_val is not None])
else:
return adapted_rand(prediction, target)[0]

def forward(self, prediction, target):
assert(prediction.shape == target.shape), "%s, %s" % (str(prediction.shape),
str(target.shape))
assert prediction.shape[1] == 1, "Expect singleton channel axis"
prediction = prediction.cpu().numpy()
target = target.cpu().numpy()

ndim = prediction.ndim
assert ndim in (4, 5), "Expect 2 or 3d input with additional batch and channel axis"

# return the average arand error over the batches
return np.mean([self._arand_for_tensor(pred[0], targ[0])
for pred, targ in zip(prediction, target)])


class ArandError(ArandScore):
"""Arand Error = 1 - <arand score>"""
Expand Down Expand Up @@ -77,12 +93,12 @@ def adapted_rand(seg, gt):
gt_zeros = np.all(gt == 0)
if seg_zeros or gt_zeros:
if seg_zeros:
logger.warn("Segmentation is all zeros, ignoring for eval.")
logger.warning("Segmentation is all zeros, ignoring for eval.")
return None
else:
print(gt.shape)
print(np.unique(gt))
logger.warn("Groundtruth is all zeros, ignoring for eval.")
logger.warning("Groundtruth is all zeros, ignoring for eval.")
return None

# segA is truth, segB is query
Expand Down

0 comments on commit f5a1d70

Please sign in to comment.