
Commit

Merge pull request #63 from arraiyopensource/feat/focal_loss
implement focal loss
edgarriba committed Feb 4, 2019
2 parents 89246d2 + 467dba2 commit ffe4cb1
Showing 6 changed files with 188 additions and 2 deletions.
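
For context, the loss added in this commit is from Lin et al., "Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002): FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), where p_t is the model's predicted probability for the ground-truth class. With gamma = 0 the modulating factor vanishes and the loss reduces to alpha-weighted cross-entropy; larger gamma down-weights well-classified pixels and focuses training on hard, misclassified ones.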
2 changes: 2 additions & 0 deletions docs/source/losses.rst
@@ -6,10 +6,12 @@ torchgeometry.losses
.. autofunction:: one_hot
.. autofunction:: dice_loss
.. autofunction:: tversky_loss
.. autofunction:: focal_loss
.. autofunction:: ssim
.. autofunction:: depth_smoothness_loss

.. autoclass:: DiceLoss
.. autoclass:: TverskyLoss
.. autoclass:: FocalLoss
.. autoclass:: SSIM
.. autoclass:: DepthSmoothnessLoss
1 change: 1 addition & 0 deletions mypy_files.txt
@@ -1,6 +1,7 @@
torchgeometry/image/gaussian.py
torchgeometry/losses/ssim.py
torchgeometry/losses/dice.py
torchgeometry/losses/focal.py
torchgeometry/losses/tversky.py
torchgeometry/losses/depth_smooth.py
torchgeometry/homography_warper.py
65 changes: 65 additions & 0 deletions test/test_losses.py
@@ -8,6 +8,71 @@
from common import TEST_DEVICES


class TestFocalLoss:
def _test_smoke_none(self):
num_classes = 3
logits = torch.rand(2, num_classes, 3, 2)
labels = torch.rand(2, 3, 2) * num_classes
labels = labels.long()

        assert tgm.losses.focal_loss(
            logits, labels, alpha=0.5, gamma=2.0,
            reduction='none').shape == (2, 3, 2)

def _test_smoke_sum(self):
num_classes = 3
logits = torch.rand(2, num_classes, 3, 2)
labels = torch.rand(2, 3, 2) * num_classes
labels = labels.long()

assert tgm.losses.focal_loss(
logits,
labels,
alpha=0.5,
gamma=2.0,
reduction='sum').shape == ()

def _test_smoke_mean(self):
num_classes = 3
logits = torch.rand(2, num_classes, 3, 2)
labels = torch.rand(2, 3, 2) * num_classes
labels = labels.long()

assert tgm.losses.focal_loss(
logits,
labels,
alpha=0.5,
gamma=2.0,
reduction='mean').shape == ()

# TODO: implement me
def _test_jit(self):
pass

def _test_gradcheck(self):
num_classes = 3
alpha, gamma = 0.5, 2.0 # for focal loss
logits = torch.rand(2, num_classes, 3, 2)
labels = torch.rand(2, 3, 2) * num_classes
labels = labels.long()

logits = utils.tensor_to_gradcheck_var(logits) # to var
assert gradcheck(tgm.losses.focal_loss,
(logits, labels, alpha, gamma), raise_exception=True)

def test_run_all(self):
self._test_smoke_none()
self._test_smoke_sum()
self._test_smoke_mean()
self._test_gradcheck()


class TestTverskyLoss:
def _test_smoke(self):
num_classes = 3
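
The new tests can be run in isolation with, e.g., pytest -v test/test_losses.py -k TestFocalLoss (assuming the repository's standard pytest setup).
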
1 change: 1 addition & 0 deletions torchgeometry/losses/__init__.py
@@ -2,4 +2,5 @@
from .ssim import SSIM, ssim
from .dice import DiceLoss, dice_loss
from .tversky import TverskyLoss, tversky_loss
from .focal import FocalLoss, focal_loss
from .depth_smooth import DepthSmoothnessLoss, depth_smoothness_loss
116 changes: 116 additions & 0 deletions torchgeometry/losses/focal.py
@@ -0,0 +1,116 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .one_hot import one_hot


# based on:
# https://github.com/zhezh/focalloss/blob/master/focalloss.py

class FocalLoss(nn.Module):
r"""Criterion that computes Focal loss.
According to [1], the Focal loss is computed as follows:
.. math::
\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
where:
- :math:`p_t` is the model's estimated probability for each class.
Arguments:
        alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
        gamma (float): Focusing parameter :math:`\gamma \geq 0`.
        reduction (Optional[str]): Specifies the reduction to apply to the
            output: 'none' | 'mean' | 'sum'. 'none': no reduction will be
            applied, 'mean': the sum of the output will be divided by the
            number of elements in the output, 'sum': the output will be
            summed. Default: 'none'.
Shape:
- Input: :math:`(N, C, H, W)` where C = number of classes.
        - Target: :math:`(N, H, W)` where each value is
          :math:`0 \leq \text{targets}[i] \leq C-1`.
Examples:
>>> N = 5 # num_classes
>>> loss = tgm.losses.FocalLoss(alpha=0.5, gamma=2.0, reduction='mean')
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = loss(input, target)
>>> output.backward()
References:
[1] https://arxiv.org/abs/1708.02002
"""

def __init__(self, alpha: float, gamma: Optional[float] = 2.0,
reduction: Optional[str] = 'none') -> None:
super(FocalLoss, self).__init__()
self.alpha: float = alpha
self.gamma: Optional[float] = gamma
self.reduction: Optional[str] = reduction
self.eps: float = 1e-6

def forward(
self,
input: torch.Tensor,
target: torch.Tensor) -> torch.Tensor:
if not torch.is_tensor(input):
raise TypeError("Input type is not a torch.Tensor. Got {}"
.format(type(input)))
if not len(input.shape) == 4:
raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}"
.format(input.shape))
        if not input.shape[-2:] == target.shape[-2:]:
            raise ValueError("input and target shapes must be the same. Got: {} and {}"
                             .format(input.shape, target.shape))
        if not input.device == target.device:
            raise ValueError(
                "input and target must be on the same device. Got: {} and {}".format(
                    input.device, target.device))
# compute softmax over the classes axis
input_soft = F.softmax(input, dim=1)

# create the labels one hot tensor
target_one_hot = one_hot(target, num_classes=input.shape[1],
device=input.device, dtype=input.dtype)

        # compute the focal term: -alpha * (1 - p)^gamma * log(p), per class
        weight = torch.pow(1. - input_soft, self.gamma)
        focal = -self.alpha * weight * torch.log(input_soft + self.eps)

        # keep only the term of the ground-truth class and sum over the
        # class axis, yielding a (B, H, W) loss map
        loss_tmp = torch.sum(target_one_hot * focal, dim=1)

        if self.reduction == 'none':
            loss = loss_tmp
        elif self.reduction == 'mean':
            loss = torch.mean(loss_tmp)
        elif self.reduction == 'sum':
            loss = torch.sum(loss_tmp)
        else:
            raise NotImplementedError("Invalid reduction mode: {}"
                                      .format(self.reduction))
        return loss


######################
# functional interface
######################


def focal_loss(
input: torch.Tensor,
target: torch.Tensor,
alpha: float,
gamma: Optional[float] = 2.0,
reduction: Optional[str] = 'none') -> torch.Tensor:
r"""Function that computes Focal loss.
See :class:`~torchgeometry.losses.FocalLoss` for details.
"""
return FocalLoss(alpha, gamma, reduction)(input, target)
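
For reference, a minimal usage sketch of the new functional interface, mirroring the docstring example above (assumes torchgeometry is installed and importable as tgm):

import torch
import torchgeometry as tgm

num_classes = 5
logits = torch.randn(1, num_classes, 3, 5, requires_grad=True)
labels = torch.randint(0, num_classes, (1, 3, 5))

# 'none' keeps the per-pixel loss map: shape (1, 3, 5)
loss_map = tgm.losses.focal_loss(
    logits, labels, alpha=0.5, gamma=2.0, reduction='none')

# 'mean' collapses to a scalar suitable for backprop
loss = tgm.losses.focal_loss(
    logits, labels, alpha=0.5, gamma=2.0, reduction='mean')
loss.backward()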
5 changes: 3 additions & 2 deletions torchgeometry/losses/one_hot.py
@@ -6,7 +6,8 @@
def one_hot(labels: torch.Tensor,
num_classes: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None) -> torch.Tensor:
dtype: Optional[torch.dtype] = None,
eps: Optional[float] = 1e-6) -> torch.Tensor:
r"""Converts an integer label 2D tensor to a one-hot 3D tensor.
Args:
@@ -50,4 +51,4 @@ def one_hot(labels: torch.Tensor,
batch_size, height, width = labels.shape
one_hot = torch.zeros(batch_size, num_classes, height, width,
device=device, dtype=dtype)
return one_hot.scatter_(1, labels.unsqueeze(1), 1.0)
return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) + eps
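
To illustrate the new eps term, a small sketch (one_hot's module path is taken from the diff above; the example values are hypothetical):

import torch
from torchgeometry.losses.one_hot import one_hot

labels = torch.tensor([[[0, 1], [2, 0]]])  # shape (1, 2, 2)
oh = one_hot(labels, num_classes=3, device=labels.device, dtype=torch.float32)
# oh has shape (1, 3, 2, 2): 1 + eps at each label's channel, eps elsewhere,
# so a downstream torch.log(oh) never evaluates log(0)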
