Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
Enable use of SorensenDiceLoss as metric
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Jan 7, 2019
1 parent f4b8fe8 commit 94bcd01
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions inferno/extensions/criteria/set_similarity_measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class SorensenDiceLoss(nn.Module):
between the input and the target. For both inputs and targets it must be the case that
`input_or_target.size(1) = num_channels`.
"""
def __init__(self, weight=None, channelwise=True, eps=1e-6, use_as_metric=False):
    """
    Parameters
    ----------
    weight : torch.Tensor or None
        Optional per-channel weight applied to the channelwise losses.
        Registered as a buffer so it follows the module across devices.
    channelwise : bool
        Whether to apply the loss channelwise and sum the results (True)
        or to apply it on all channels jointly (False).
    eps : float
        Small constant used to clamp the denominator and avoid
        division by zero.
    use_as_metric : bool
        Whether to return the Sorensen-Dice score as a metric (range 0 to 1)
        instead of as a loss (range -1 to 0). Default is False.
    """
    super(SorensenDiceLoss, self).__init__()
    self.register_buffer('weight', weight)
    self.channelwise = channelwise
    self.eps = eps
    # Dice score = 2 * |A ∩ B| / (|A|^2 + |B|^2); keep the positive
    # factor when used as a metric, negate it when used as a loss so
    # that minimizing the loss maximizes the score.
    self.factor = 2. if use_as_metric else -2.

def forward(self, input, target):
"""
Expand All @@ -37,9 +41,8 @@ def forward(self, input, target):
if not self.channelwise:
numerator = (input * target).sum()
denominator = (input * input).sum() + (target * target).sum()
loss = -2. * (numerator / denominator.clamp(min=self.eps))
loss = self.factor * (numerator / denominator.clamp(min=self.eps))
else:
# TODO This should be compatible with Pytorch 0.2, but check
# Flatten input and target to have the shape (C, N),
# where N is the number of samples
input = flatten_samples(input)
Expand All @@ -48,7 +51,7 @@ def forward(self, input, target):
# leaving the channels intact)
numerator = (input * target).sum(-1)
denominator = (input * input).sum(-1) + (target * target).sum(-1)
channelwise_loss = -2 * (numerator / denominator.clamp(min=self.eps))
channelwise_loss = self.factor * (numerator / denominator.clamp(min=self.eps))
if self.weight is not None:
# With pytorch < 0.2, channelwise_loss.size = (C, 1).
if channelwise_loss.dim() == 2:
Expand Down

0 comments on commit 94bcd01

Please sign in to comment.