adds weight parameter to dice and lovasz_softmax losses (#2879)
* adds weight parameter to dice and lovasz_softmax losses, similar to what we already have for the focal loss

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lint

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ducha-aiki and pre-commit-ci[bot] committed Apr 9, 2024
1 parent d1a1cc0 commit 2387a2d
Showing 4 changed files with 71 additions and 10 deletions.
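For orientation, a minimal usage sketch of the new parameter (tensor shapes borrowed from the tests below; the weight values are illustrative):

    import torch
    import kornia

    num_classes = 3
    pred = torch.randn(2, num_classes, 4, 1, requires_grad=True)  # logits, (N, C, H, W)
    target = torch.randint(0, num_classes, (2, 4, 1))             # labels, (N, H, W)
    weight = torch.tensor([0.0, 1.0, 1.0])                        # one weight per class

    dice = kornia.losses.DiceLoss(average="micro", weight=weight)
    lovasz = kornia.losses.LovaszSoftmaxLoss(weight=weight)
    loss = dice(pred, target) + lovasz(pred, target)
    loss.backward()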
27 changes: 23 additions & 4 deletions kornia/losses/dice.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from typing import Optional
+
 import torch
 from torch import nn
 
@@ -12,7 +14,9 @@
 # https://github.com/Lightning-AI/metrics/blob/v0.11.3/src/torchmetrics/functional/classification/dice.py#L66-L207
 
 
-def dice_loss(pred: Tensor, target: Tensor, average: str = "micro", eps: float = 1e-8) -> Tensor:
+def dice_loss(
+    pred: Tensor, target: Tensor, average: str = "micro", eps: float = 1e-8, weight: Optional[Tensor] = None
+) -> Tensor:
     r"""Criterion that computes Sørensen-Dice Coefficient loss.
 
     According to [1], we compute the Sørensen-Dice Coefficient as follows:
@@ -43,6 +47,7 @@ def dice_loss(pred: Tensor, target: Tensor, average: str = "micro", eps: float =
             - ``'micro'`` [default]: Calculate the loss across all classes.
             - ``'macro'``: Calculate the loss for each class separately and average the metrics across classes.
         eps: Scalar to enforce numerical stabiliy.
+        weight: weights for classes with shape :math:`(num\_of\_classes,)`.
 
     Return:
         One-element tensor of the computed loss.
@@ -64,7 +69,7 @@ def dice_loss(pred: Tensor, target: Tensor, average: str = "micro", eps: float =
 
     if not pred.device == target.device:
         raise ValueError(f"pred and target must be in the same device. Got: {pred.device} and {target.device}")
-
+    num_of_classes = pred.shape[1]
     possible_average = {"micro", "macro"}
     KORNIA_CHECK(average in possible_average, f"The `average` has to be one of {possible_average}. Got: {average}")
 
@@ -80,6 +85,18 @@ def dice_loss(pred: Tensor, target: Tensor, average: str = "micro", eps: float =
         dims = (1, *dims)
 
     # compute the actual dice score
+    if weight is not None:
+        KORNIA_CHECK_IS_TENSOR(weight, "weight must be Tensor or None.")
+        KORNIA_CHECK(
+            (weight.shape[0] == num_of_classes and weight.numel() == num_of_classes),
+            f"weight shape must be (num_of_classes,): ({num_of_classes},), got {weight.shape}",
+        )
+        KORNIA_CHECK(
+            weight.device == pred.device,
+            f"weight and pred must be in the same device. Got: {weight.device} and {pred.device}",
+        )
+        pred_soft = pred_soft * weight
+        target_one_hot = target_one_hot * weight
     intersection = torch.sum(pred_soft * target_one_hot, dims)
     cardinality = torch.sum(pred_soft + target_one_hot, dims)
 
@@ -120,6 +137,7 @@ class DiceLoss(nn.Module):
             - ``'micro'`` [default]: Calculate the loss across all classes.
            - ``'macro'``: Calculate the loss for each class separately and average the metrics across classes.
         eps: Scalar to enforce numerical stabiliy.
+        weight: weights for classes with shape :math:`(num\_of\_classes,)`.
 
     Shape:
         - Pred: :math:`(N, C, H, W)` where C = number of classes.
@@ -135,10 +153,11 @@ class DiceLoss(nn.Module):
         >>> output.backward()
     """
 
-    def __init__(self, average: str = "micro", eps: float = 1e-8) -> None:
+    def __init__(self, average: str = "micro", eps: float = 1e-8, weight: Optional[Tensor] = None) -> None:
         super().__init__()
         self.average = average
         self.eps = eps
+        self.weight = weight
 
     def forward(self, pred: Tensor, target: Tensor) -> Tensor:
-        return dice_loss(pred, target, self.average, self.eps)
+        return dice_loss(pred, target, self.average, self.eps, self.weight)
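The new validation mirrors the existing device check: `weight` must be a Tensor of shape `(num_of_classes,)` on the same device as `pred`. A quick illustrative script of what a wrong shape triggers (the error text comes from the check in the diff above):

    import torch
    from kornia.losses import dice_loss

    pred = torch.zeros(2, 3, 4, 1)  # 3 classes
    target = torch.zeros(2, 4, 1, dtype=torch.int64)

    try:
        dice_loss(pred, target, weight=torch.tensor([1.0, 1.0]))  # only 2 weights
    except Exception as err:
        print(err)  # weight shape must be (num_of_classes,): (3,), got torch.Size([2])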
30 changes: 24 additions & 6 deletions kornia/losses/lovasz_softmax.py
@@ -1,15 +1,17 @@
 from __future__ import annotations
 
+from typing import Optional
+
 import torch
 from torch import Tensor, nn
 
-from kornia.core.check import KORNIA_CHECK_SHAPE
+from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
 
 # based on:
 # https://github.com/bermanmaxim/LovaszSoftmax
 
 
-def lovasz_softmax_loss(pred: Tensor, target: Tensor) -> Tensor:
+def lovasz_softmax_loss(pred: Tensor, target: Tensor, weight: Optional[Tensor] = None) -> Tensor:
     r"""Criterion that computes a surrogate multi-class intersection-over-union (IoU) loss.
 
     According to [1], we compute the IoU as follows:
@@ -22,7 +24,7 @@ def lovasz_softmax_loss(pred: Tensor, target: Tensor) -> Tensor:
 
     Where:
     - :math:`X` expects to be the scores of each class.
-    - :math:`Y` expects to be the binary tensor with the class labels.
+    - :math:`Y` expects to be the long tensor with the class labels.
 
     the loss, is finally computed as:
 
@@ -41,6 +43,7 @@ def lovasz_softmax_loss(pred: Tensor, target: Tensor) -> Tensor:
         pred: logits tensor with shape :math:`(N, C, H, W)` where C = number of classes > 1.
         labels: labels tensor with shape :math:`(N, H, W)` where each value
           is :math:`0 ≤ targets[i] ≤ C-1`.
+        weight: weights for classes with shape :math:`(num\_of\_classes,)`.
 
     Return:
         a scalar with the computed loss.
@@ -65,6 +68,19 @@ def lovasz_softmax_loss(pred: Tensor, target: Tensor) -> Tensor:
     if not pred.device == target.device:
         raise ValueError(f"pred and target must be in the same device. Got: {pred.device} and {target.device}")
 
+    num_of_classes = pred.shape[1]
+    # compute the actual dice score
+    if weight is not None:
+        KORNIA_CHECK_IS_TENSOR(weight, "weight must be Tensor or None.")
+        KORNIA_CHECK(
+            (weight.shape[0] == num_of_classes and weight.numel() == num_of_classes),
+            f"weight shape must be (num_of_classes,): ({num_of_classes},), got {weight.shape}",
+        )
+        KORNIA_CHECK(
+            weight.device == pred.device,
+            f"weight and pred must be in the same device. Got: {weight.device} and {pred.device}",
+        )
+
     # flatten pred [B, C, -1] and target [B, -1] and to float
     pred_flatten: Tensor = pred.reshape(pred.shape[0], pred.shape[1], -1)
     target_flatten: Tensor = target.reshape(target.shape[0], -1).float()
@@ -91,7 +107,7 @@ def lovasz_softmax_loss(pred: Tensor, target: Tensor) -> Tensor:
         gradient: Tensor = 1.0 - intersection / union
         if N > 1:
             gradient[..., 1:] = gradient[..., 1:] - gradient[..., :-1]
-        loss: Tensor = (errors_sorted.relu() * gradient).sum(1).mean()
+        loss: Tensor = (errors_sorted.relu() * gradient).sum(1).mean() * (1.0 if weight is None else weight[c])
         losses.append(loss)
     final_loss: Tensor = torch.stack(losses, dim=0).mean()
     return final_loss
@@ -129,6 +145,7 @@ class LovaszSoftmaxLoss(nn.Module):
         pred: logits tensor with shape :math:`(N, C, H, W)` where C = number of classes > 1.
         labels: labels tensor with shape :math:`(N, H, W)` where each value
           is :math:`0 ≤ targets[i] ≤ C-1`.
+        weight: weights for classes with shape :math:`(num\_of\_classes,)`.
 
     Return:
         a scalar with the computed loss.
@@ -142,8 +159,9 @@ class LovaszSoftmaxLoss(nn.Module):
         >>> output.backward()
     """
 
-    def __init__(self) -> None:
+    def __init__(self, weight: Optional[Tensor] = None) -> None:
         super().__init__()
+        self.weight = weight
 
     def forward(self, pred: Tensor, target: Tensor) -> Tensor:
-        return lovasz_softmax_loss(pred=pred, target=target)
+        return lovasz_softmax_loss(pred=pred, target=target, weight=self.weight)
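Note the two losses consume `weight` differently: the dice path scales the per-class probability and target channels before the intersection/cardinality sums, while the Lovász path multiplies each per-class loss by the scalar `weight[c]` before the final mean over classes. Because that `torch.stack(losses, dim=0).mean()` still divides by the number of classes rather than by the sum of the weights, zero weights pull the average down instead of renormalizing it. A standalone sketch of the Lovász-side reduction (an illustration of the arithmetic, not the kornia internals verbatim):

    from typing import Optional

    import torch

    def weighted_class_mean(per_class_losses: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor:
        # per_class_losses: (C,), one surrogate-IoU loss per class
        if weight is None:
            return per_class_losses.mean()
        return (per_class_losses * weight).mean()  # divides by C, not by weight.sum()

    print(weighted_class_mean(torch.tensor([1.0, 1.0]), torch.tensor([1.0, 0.0])))  # tensor(0.5000)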
11 changes: 11 additions & 0 deletions tests/losses/test_dice.py
@@ -65,6 +65,17 @@ def test_averaging_micro(self, device, dtype):
         loss = criterion(logits, labels)
         self.assert_close(loss, expected_loss, rtol=1e-3, atol=1e-3)
 
+    def test_weight(self, device, dtype):
+        num_classes = 3
+        eps = 1e-8
+        logits = torch.zeros(2, num_classes, 4, 1, device=device, dtype=dtype)
+        labels = torch.zeros(2, 4, 1, device=device, dtype=torch.int64)
+        expected_loss = torch.tensor([2.0 / 3.0], device=device, dtype=dtype).squeeze()
+        weight = torch.tensor([0.0, 1.0, 1.0], device=device, dtype=dtype)
+        criterion = kornia.losses.DiceLoss(average="micro", eps=eps, weight=weight)
+        loss = criterion(logits, labels)
+        self.assert_close(loss, expected_loss, rtol=1e-3, atol=1e-3)
+
     def test_averaging_macro(self, device, dtype):
         num_classes = 2
         eps = 1e-8
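The same check as a standalone script, for reproducing it outside the pytest fixtures (device and dtype arguments dropped):

    import torch
    import kornia

    logits = torch.zeros(2, 3, 4, 1)
    labels = torch.zeros(2, 4, 1, dtype=torch.int64)
    weight = torch.tensor([0.0, 1.0, 1.0])

    criterion = kornia.losses.DiceLoss(average="micro", eps=1e-8, weight=weight)
    print(criterion(logits, labels))  # tensor(0.6667), i.e. the expected 2/3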
13 changes: 13 additions & 0 deletions tests/losses/test_lovaz_softmax.py
@@ -62,6 +62,19 @@ def test_all_ones(self, device, dtype):
 
         self.assert_close(loss, torch.zeros_like(loss), rtol=1e-3, atol=1e-3)
 
+    def test_weight(self, device, dtype):
+        num_classes = 2
+        # make perfect prediction
+        # note that softmax(prediction[:, 1]) == 1. softmax(prediction[:, 0]) == 0.
+        prediction = torch.zeros(2, num_classes, 1, 2, device=device, dtype=dtype)
+        prediction[:, 0] = 100.0
+        labels = torch.ones(2, 1, 2, device=device, dtype=torch.int64)
+
+        criterion = kornia.losses.LovaszSoftmaxLoss(weight=torch.tensor([1.0, 0.0], device=device, dtype=dtype))
+        loss = criterion(prediction, labels)
+
+        self.assert_close(loss, 0.5 * torch.ones_like(loss), rtol=1e-3, atol=1e-3)
+
     def test_gradcheck(self, device):
         num_classes = 4
         logits = torch.rand(2, num_classes, 3, 2, device=device, dtype=torch.float64)
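Why 0.5: as written, the prediction assigns everything to class 0 while the labels are all class 1, so (up to softmax saturation) the unweighted per-class Lovász terms are 1.0 for class 0 and 1.0 for class 1; the weight vector [1.0, 0.0] zeroes the class-1 term, and the mean over the two classes gives (1.0 + 0.0) / 2 = 0.5. Standalone reproduction (fixtures dropped):

    import torch
    import kornia

    prediction = torch.zeros(2, 2, 1, 2)
    prediction[:, 0] = 100.0  # class 0 wins everywhere
    labels = torch.ones(2, 1, 2, dtype=torch.int64)

    criterion = kornia.losses.LovaszSoftmaxLoss(weight=torch.tensor([1.0, 0.0]))
    print(criterion(prediction, labels))  # tensor(0.5000) per the test expectation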
