Skip to content

Commit

Permalink
Merge pull request #59 from arraiyopensource/feat/dice_loss
Browse files Browse the repository at this point in the history
implement multiclass dice loss
  • Loading branch information
edgarriba committed Feb 4, 2019
2 parents 7175b4f + d410f9c commit 9b0fddf
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 1 deletion.
5 changes: 4 additions & 1 deletion docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ torchgeometry.losses

.. currentmodule:: torchgeometry.losses

.. autofunction:: ssim
.. autofunction:: one_hot
.. autofunction:: dice_loss
.. autofunction:: ssim
.. autofunction:: depth_smoothness_loss

.. autoclass:: DiceLoss
.. autoclass:: SSIM
.. autoclass:: DepthSmoothnessLoss
1 change: 1 addition & 0 deletions mypy_files.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
torchgeometry/image/gaussian.py
torchgeometry/losses/ssim.py
torchgeometry/losses/dice.py
torchgeometry/losses/depth_smooth.py
torchgeometry/homography_warper.py
59 changes: 59 additions & 0 deletions test/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,65 @@
from common import TEST_DEVICES


class TestDiceLoss:
def test_one_hot(self):
num_classes = 4
labels = torch.zeros(2, 2, 1, dtype=torch.int64)
labels[0, 0, 0] = 0
labels[0, 1, 0] = 1
labels[1, 0, 0] = 2
labels[1, 1, 0] = 3

# convert labels to one hot tensor
one_hot = tgm.losses.one_hot(labels, num_classes)

assert pytest.approx(one_hot[0, labels[0, 0, 0], 0, 0].item(), 1.0)
assert pytest.approx(one_hot[0, labels[0, 1, 0], 1, 0].item(), 1.0)
assert pytest.approx(one_hot[1, labels[1, 0, 0], 0, 0].item(), 1.0)
assert pytest.approx(one_hot[1, labels[1, 1, 0], 1, 0].item(), 1.0)

def _test_smoke(self):
num_classes = 3
logits = torch.rand(2, num_classes, 3, 2)
labels = torch.rand(2, 3, 2) * num_classes
labels = labels.long()

criterion = tgm.losses.DiceLoss()
loss = criterion(logits, labels)

# TODO: implement me
def _test_all_zeros(self):
num_classes = 3
logits = torch.zeros(2, num_classes, 1, 2)
logits[:, 0] = 10.0
logits[:, 1] = 1.0
logits[:, 2] = 1.0
labels = torch.zeros(2, 1, 2, dtype=torch.int64)

criterion = tgm.losses.DiceLoss()
loss = criterion(logits, labels)
assert pytest.approx(loss.item(), 0.0)

# TODO: implement me
def _test_jit(self):
pass

def _test_gradcheck(self):
num_classes = 3
logits = torch.rand(2, num_classes, 3, 2)
labels = torch.rand(2, 3, 2) * num_classes
labels = labels.long()

logits = utils.tensor_to_gradcheck_var(logits) # to var
assert gradcheck(tgm.losses.dice_loss,
(logits, labels,), raise_exception=True)

def test_run_all(self):
self._test_smoke()
self._test_all_zeros()
self._test_gradcheck()


class TestDepthSmoothnessLoss:
def _test_smoke(self):
image = self.image.clone()
Expand Down
1 change: 1 addition & 0 deletions torchgeometry/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .ssim import SSIM, ssim
from .dice import DiceLoss, dice_loss, one_hot
from .depth_smooth import DepthSmoothnessLoss, depth_smoothness_loss
143 changes: 143 additions & 0 deletions torchgeometry/losses/dice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def one_hot(labels: torch.Tensor,
num_classes: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None) -> torch.Tensor:
r"""Converts an integer label 2D tensor to a one-hot 3D tensor.
Args:
labels (torch.Tensor) : tensor with labels of shape :math:`(N, H, W)`,
where N is batch siz. Each value is an integer
representing correct classification.
num_classes (int): number of classes in labels.
device (Optional[torch.device]): the desired device of returned tensor.
Default: if None, uses the current device for the default tensor type
(see torch.set_default_tensor_type()). device will be the CPU for CPU
tensor types and the current CUDA device for CUDA tensor types.
dtype (Optional[torch.dtype]): the desired data type of returned
tensor. Default: if None, infers data type from values.
Returns:
torch.Tensor: the labels in one hot tensor.
Examples::
>>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
>>> tgm.losses.one_hot(labels, num_classes=3)
tensor([[[[1., 0.],
[0., 1.]],
[[0., 1.],
[0., 0.]],
[[0., 0.],
[1., 0.]]]]
"""
if not torch.is_tensor(labels):
raise TypeError("Input labels type is not a torch.Tensor. Got {}"
.format(type(labels)))
if not len(labels.shape) == 3:
raise ValueError("Invalid depth shape, we expect BxHxW. Got: {}"
.format(labels.shape))
if not labels.dtype == torch.int64:
raise ValueError(
"labels must be of the same dtype torch.int64. Got: {}" .format(
labels.dtype))
if num_classes < 1:
raise ValueError("The number of classes must be bigger than one."
" Got: {}".format(num_classes))
batch_size, height, width = labels.shape
one_hot = torch.zeros(batch_size, num_classes, height, width,
device=device, dtype=dtype)
return one_hot.scatter_(1, labels.unsqueeze(1), 1.0)


# based on:
# https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py

class DiceLoss(nn.Module):
r"""Criterion that computes Sørensen-Dice Coefficient loss.
According to [1], we compute the Sørensen-Dice Coefficient as follows:
.. math::
\text{Dice}(x, class) = \frac{2 |X| \cap |Y|}{|X| + |Y|}
where:
- :math:`X` expects to be the scores of each class.
- :math:`Y` expects to be the one-hot tensor with the class labels.
the loss, is finally computed as:
.. math::
\text{loss}(x, class) = 1 - \text{Dice}(x, class)
[1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
Shape:
- Input: :math:`(N, C, H, W)` where C = number of classes.
- Target: :math:`(N, H, W)` where each value is
:math:`0 ≤ targets[i] ≤ C−1`.
Examples:
>>> N = 5 # num_classes
>>> loss = tgm.losses.DiceLoss()
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = loss(input, target)
>>> output.backward()
"""
def __init__(self) -> None:
super(DiceLoss, self).__init__()
self.eps: float = 1e-6

def forward(
self,
input: torch.Tensor,
target: torch.Tensor) -> torch.Tensor:
if not torch.is_tensor(input):
raise TypeError("Input type is not a torch.Tensor. Got {}"
.format(type(input)))
if not len(input.shape) == 4:
raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}"
.format(input.shape))
if not input.shape[-2:] == target.shape[-2:]:
raise ValueError("input and target shapes must be the same. Got: {}"
.format(input.shape, input.shape))
if not input.device == target.device:
raise ValueError(
"input and target must be in the same device. Got: {}" .format(
input.device, target.device))
# compute softmax over the classes axis
input_soft = F.softmax(input, dim=1)

# create the labels one hot tensor
target_one_hot = one_hot(target, num_classes=input.shape[1],
device=input.device, dtype=input.dtype)

# compute the actual dice score
intersection = torch.sum(input_soft * target_one_hot, dim=(1, 2, 3))
cardinality = torch.sum(input_soft + target_one_hot, dim=(1, 2, 3))

dice_score = 2. * intersection / (cardinality + self.eps)
return torch.mean(1. - dice_score)


######################
# functional interface
######################


def dice_loss(
input: torch.Tensor,
target: torch.Tensor) -> torch.Tensor:
r"""Function that computes Sørensen-Dice Coefficient loss.
See :class:`~torchgeometry.losses.DiceLoss` for details.
"""
return DiceLoss()(input, target)

0 comments on commit 9b0fddf

Please sign in to comment.