
Merge pull request #96 from abailoni/generalized-dice-loss
Added generalized dice loss
nasimrahaman committed May 24, 2018
2 parents a96dd32 + 9500933 commit 6289c24
Showing 2 changed files with 111 additions and 3 deletions.
89 changes: 87 additions & 2 deletions inferno/extensions/criteria/set_similarity_measures.py
@@ -2,8 +2,7 @@
from ...utils.torch_utils import flatten_samples
from torch.autograd import Variable


__all__ = ['SorensenDiceLoss']
__all__ = ['SorensenDiceLoss', 'GeneralizedDiceLoss']


class SorensenDiceLoss(nn.Module):
@@ -28,6 +27,13 @@ def __init__(self, weight=None, channelwise=True, eps=1e-6):
self.eps = eps

def forward(self, input, target):
"""
input: torch.FloatTensor or torch.cuda.FloatTensor
target: torch.FloatTensor or torch.cuda.FloatTensor
Expected shape of the inputs: (batch_size, nb_channels, ...)
"""
assert input.size() == target.size()
if not self.channelwise:
numerator = (input * target).sum()
denominator = (input * input).sum() + (target * target).sum()
@@ -49,8 +55,87 @@ def forward(self, input, target):
channelwise_loss = channelwise_loss.squeeze(1)
# Wrap weights in a variable
weight = Variable(self.weight, requires_grad=False)
assert weight.size() == channelwise_loss.size()
# Apply weight
channelwise_loss = weight * channelwise_loss
# Sum over the channels to compute the total loss
loss = channelwise_loss.sum()
return loss


class GeneralizedDiceLoss(nn.Module):
"""
    Computes the scalar Generalized Dice Loss defined in https://arxiv.org/abs/1707.03237.
    This version supports multiple classes and expects per-class predictions
    (e.g. a softmax output) and one-hot targets for every class.
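
    The loss computed in forward (applied per channel when channelwise is set)
    follows the weighted formulation of the paper above:

        GDL = 1 - 2 * sum_l( w_l * sum_n(p_ln * r_ln) ) / sum_l( w_l * sum_n(p_ln + r_ln) ),
        with w_l = 1 / (sum_n r_ln)**2,

    where l indexes classes, n indexes the remaining (flattened) entries,
    p is the prediction and r the one-hot target.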
"""
def __init__(self, weight=None, channelwise=False, eps=1e-6):
super(GeneralizedDiceLoss, self).__init__()
self.register_buffer('weight', weight)
self.channelwise = channelwise
self.eps = eps

def forward(self, input, target):
"""
input: torch.FloatTensor or torch.cuda.FloatTensor
target: torch.FloatTensor or torch.cuda.FloatTensor
Expected shape of the inputs:
- if not channelwise: (batch_size, nb_classes, ...)
- if channelwise: (batch_size, nb_channels, nb_classes, ...)
"""
assert input.size() == target.size()
if not self.channelwise:
# Flatten input and target to have the shape (nb_classes, N),
# where N is the number of samples
input = flatten_samples(input)
target = flatten_samples(target)

            # Compute the class weights:
            sum_targets = target.sum(-1)
            class_weights = 1. / (sum_targets * sum_targets).clamp(min=self.eps)

            # Compute generalized Dice loss:
            numer = ((input * target).sum(-1) * class_weights).sum()
            denom = ((input + target).sum(-1) * class_weights).sum()

loss = 1. - 2. * numer / denom.clamp(min=self.eps)
else:
def flatten_and_preserve_channels(tensor):
tensor_dim = tensor.dim()
assert tensor_dim >= 3
num_channels = tensor.size(1)
num_classes = tensor.size(2)
                # Move the channel and class axes to the front (the batch axis goes third)
permute_axes = list(range(tensor_dim))
permute_axes[0], permute_axes[1], permute_axes[2] = permute_axes[1], permute_axes[2], permute_axes[0]
permuted = tensor.permute(*permute_axes).contiguous()
flattened = permuted.view(num_channels, num_classes, -1)
return flattened

# Flatten input and target to have the shape (nb_channels, nb_classes, N)
input = flatten_and_preserve_channels(input)
target = flatten_and_preserve_channels(target)

            # Compute the class weights:
            sum_targets = target.sum(-1)
            class_weights = 1. / (sum_targets * sum_targets).clamp(min=self.eps)

            # Compute generalized Dice loss:
            numer = ((input * target).sum(-1) * class_weights).sum(-1)
            denom = ((input + target).sum(-1) * class_weights).sum(-1)

channelwise_loss = 1. - 2. * numer / denom.clamp(min=self.eps)

if self.weight is not None:
if channelwise_loss.dim() == 2:
channelwise_loss = channelwise_loss.squeeze(1)
channel_weights = Variable(self.weight, requires_grad=False)
assert channel_weights.size() == channelwise_loss.size(), "`weight` should have shape (nb_channels, ), `target` should have shape (batch_size, nb_channels, nb_classes, ...)"
# Apply channel weights:
channelwise_loss = channel_weights * channelwise_loss

loss = channelwise_loss.sum()

return loss
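
For reference, a minimal usage sketch (the variable names below are illustrative, not part of this commit), assuming softmax predictions and matching one-hot targets as the docstring above describes:

import torch
import torch.nn.functional as F
from torch.autograd import Variable
from inferno.extensions.criteria.set_similarity_measures import GeneralizedDiceLoss

criterion = GeneralizedDiceLoss(channelwise=False)
logits = Variable(torch.randn(4, 5, 64, 64))                # (batch_size, nb_classes, H, W)
prediction = F.softmax(logits, dim=1)                       # per-pixel class probabilities
labels = torch.LongTensor(4, 64, 64).random_(0, 5)          # integer class labels
one_hot = torch.zeros(4, 5, 64, 64).scatter_(1, labels.unsqueeze(1), 1.)
target = Variable(one_hot)                                  # one-hot targets, same shape as prediction
loss = criterion(prediction, target)                        # scalar loss
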
25 changes: 24 additions & 1 deletion tests/extensions/criteria/set_similarity_measures.py
@@ -3,12 +3,20 @@
from torch.autograd import Variable


class TestSorensenDice(unittest.TestCase):
class SetSimilarityTest(unittest.TestCase):
def get_dummy_variables(self):
x = Variable(torch.zeros(3, 2, 100, 100).uniform_())
y = Variable(torch.zeros(3, 2, 100, 100).uniform_())
return x, y

def get_dummy_variables_with_channels_and_classes(self):
# (batch_size, channels, classes, ...)
x = Variable(torch.zeros(3, 2, 5, 100, 100).uniform_())
y = Variable(torch.zeros(3, 2, 5, 100, 100).uniform_())
return x, y


class TestSorensenDice(SetSimilarityTest):
# noinspection PyCallingNonCallable
def test_channelwise(self):
from inferno.extensions.criteria.set_similarity_measures import SorensenDiceLoss
@@ -25,5 +33,20 @@ def test_channelwise(self):
self.assertAlmostEqual(expected_channelwise_loss.data[0], channelwise_loss.data[0])


class TestGeneralizedSorensenDice(SetSimilarityTest):
def test_channelwise(self):
from inferno.extensions.criteria.set_similarity_measures import GeneralizedDiceLoss
x, y = self.get_dummy_variables_with_channels_and_classes()
channelwise = GeneralizedDiceLoss(channelwise=True)
not_channelwise = GeneralizedDiceLoss(channelwise=False)
# Compute channelwise loss and expected one:
channelwise_loss = channelwise(x, y)
expected_channelwise_loss = \
not_channelwise(x[:, 0, ...], y[:, 0, ...]) + \
not_channelwise(x[:, 1, ...], y[:, 1, ...])
# Compare
self.assertAlmostEqual(expected_channelwise_loss.data[0], channelwise_loss.data[0])


if __name__ == '__main__':
unittest.main()
