-
-
Notifications
You must be signed in to change notification settings - Fork 949
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #59 from arraiyopensource/feat/dice_loss
implement multiclass dice loss
- Loading branch information
Showing
5 changed files
with
208 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
torchgeometry/image/gaussian.py | ||
torchgeometry/losses/ssim.py | ||
torchgeometry/losses/dice.py | ||
torchgeometry/losses/depth_smooth.py | ||
torchgeometry/homography_warper.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .ssim import SSIM, ssim | ||
from .dice import DiceLoss, dice_loss, one_hot | ||
from .depth_smooth import DepthSmoothnessLoss, depth_smoothness_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
from typing import Optional | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
def one_hot(labels: torch.Tensor, | ||
num_classes: int, | ||
device: Optional[torch.device] = None, | ||
dtype: Optional[torch.dtype] = None) -> torch.Tensor: | ||
r"""Converts an integer label 2D tensor to a one-hot 3D tensor. | ||
Args: | ||
labels (torch.Tensor) : tensor with labels of shape :math:`(N, H, W)`, | ||
where N is batch siz. Each value is an integer | ||
representing correct classification. | ||
num_classes (int): number of classes in labels. | ||
device (Optional[torch.device]): the desired device of returned tensor. | ||
Default: if None, uses the current device for the default tensor type | ||
(see torch.set_default_tensor_type()). device will be the CPU for CPU | ||
tensor types and the current CUDA device for CUDA tensor types. | ||
dtype (Optional[torch.dtype]): the desired data type of returned | ||
tensor. Default: if None, infers data type from values. | ||
Returns: | ||
torch.Tensor: the labels in one hot tensor. | ||
Examples:: | ||
>>> labels = torch.LongTensor([[[0, 1], [2, 0]]]) | ||
>>> tgm.losses.one_hot(labels, num_classes=3) | ||
tensor([[[[1., 0.], | ||
[0., 1.]], | ||
[[0., 1.], | ||
[0., 0.]], | ||
[[0., 0.], | ||
[1., 0.]]]] | ||
""" | ||
if not torch.is_tensor(labels): | ||
raise TypeError("Input labels type is not a torch.Tensor. Got {}" | ||
.format(type(labels))) | ||
if not len(labels.shape) == 3: | ||
raise ValueError("Invalid depth shape, we expect BxHxW. Got: {}" | ||
.format(labels.shape)) | ||
if not labels.dtype == torch.int64: | ||
raise ValueError( | ||
"labels must be of the same dtype torch.int64. Got: {}" .format( | ||
labels.dtype)) | ||
if num_classes < 1: | ||
raise ValueError("The number of classes must be bigger than one." | ||
" Got: {}".format(num_classes)) | ||
batch_size, height, width = labels.shape | ||
one_hot = torch.zeros(batch_size, num_classes, height, width, | ||
device=device, dtype=dtype) | ||
return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) | ||
|
||
|
||
# based on: | ||
# https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py | ||
|
||
class DiceLoss(nn.Module): | ||
r"""Criterion that computes Sørensen-Dice Coefficient loss. | ||
According to [1], we compute the Sørensen-Dice Coefficient as follows: | ||
.. math:: | ||
\text{Dice}(x, class) = \frac{2 |X| \cap |Y|}{|X| + |Y|} | ||
where: | ||
- :math:`X` expects to be the scores of each class. | ||
- :math:`Y` expects to be the one-hot tensor with the class labels. | ||
the loss, is finally computed as: | ||
.. math:: | ||
\text{loss}(x, class) = 1 - \text{Dice}(x, class) | ||
[1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient | ||
Shape: | ||
- Input: :math:`(N, C, H, W)` where C = number of classes. | ||
- Target: :math:`(N, H, W)` where each value is | ||
:math:`0 ≤ targets[i] ≤ C−1`. | ||
Examples: | ||
>>> N = 5 # num_classes | ||
>>> loss = tgm.losses.DiceLoss() | ||
>>> input = torch.randn(1, N, 3, 5, requires_grad=True) | ||
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N) | ||
>>> output = loss(input, target) | ||
>>> output.backward() | ||
""" | ||
def __init__(self) -> None: | ||
super(DiceLoss, self).__init__() | ||
self.eps: float = 1e-6 | ||
|
||
def forward( | ||
self, | ||
input: torch.Tensor, | ||
target: torch.Tensor) -> torch.Tensor: | ||
if not torch.is_tensor(input): | ||
raise TypeError("Input type is not a torch.Tensor. Got {}" | ||
.format(type(input))) | ||
if not len(input.shape) == 4: | ||
raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}" | ||
.format(input.shape)) | ||
if not input.shape[-2:] == target.shape[-2:]: | ||
raise ValueError("input and target shapes must be the same. Got: {}" | ||
.format(input.shape, input.shape)) | ||
if not input.device == target.device: | ||
raise ValueError( | ||
"input and target must be in the same device. Got: {}" .format( | ||
input.device, target.device)) | ||
# compute softmax over the classes axis | ||
input_soft = F.softmax(input, dim=1) | ||
|
||
# create the labels one hot tensor | ||
target_one_hot = one_hot(target, num_classes=input.shape[1], | ||
device=input.device, dtype=input.dtype) | ||
|
||
# compute the actual dice score | ||
intersection = torch.sum(input_soft * target_one_hot, dim=(1, 2, 3)) | ||
cardinality = torch.sum(input_soft + target_one_hot, dim=(1, 2, 3)) | ||
|
||
dice_score = 2. * intersection / (cardinality + self.eps) | ||
return torch.mean(1. - dice_score) | ||
|
||
|
||
###################### | ||
# functional interface | ||
###################### | ||
|
||
|
||
def dice_loss( | ||
input: torch.Tensor, | ||
target: torch.Tensor) -> torch.Tensor: | ||
r"""Function that computes Sørensen-Dice Coefficient loss. | ||
See :class:`~torchgeometry.losses.DiceLoss` for details. | ||
""" | ||
return DiceLoss()(input, target) |