[Feat] Add tpu support for the losses module (#834)
* improve kornia.enhance docs and coverage

* refactor focal loss test and jit

* refactor tversky loss

* refactor dice_loss

* enable tests for tpu in losses

* few tpu fixes

* refactor inverse_depth_smoothness_loss
edgarriba committed Jan 12, 2021
1 parent 91fa75a commit 7634391
Showing 9 changed files with 347 additions and 137 deletions.
2 changes: 1 addition & 1 deletion docker/tpu-tests/tpu_test_cases.jsonnet
@@ -23,7 +23,7 @@ local tputests = base.BaseTest {
python -c "import torch; print(torch.__version__)"
python -c "import torch_xla; print(torch_xla.__version__)"
python -c "import kornia; print(kornia.__version__)"
-pytest -v kornia/test/color kornia/test/enhance kornia/test/filters --device tpu --dtype float32 -k "not grad"
+pytest -v kornia/test/color kornia/test/enhance kornia/test/filters kornia/test/test_losses.py --device tpu --dtype float32 -k "not grad"
test_exit_code=$?
echo "\nFinished running commands.\n"
test $test_exit_code -eq 0
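For context on what `--device tpu` exercises: the test fixtures build tensors on an XLA device, and the losses then run unchanged on them. A minimal sketch of that placement, assuming a TPU runtime with `torch_xla` installed (variable names are illustrative, not from this commit):

```python
# Minimal sketch: running a kornia loss on a TPU core through torch_xla.
# Assumes a TPU runtime with torch_xla installed; names are illustrative.
import torch
import torch_xla.core.xla_model as xm

import kornia

device = xm.xla_device()  # resolves to an XLA (TPU) device

# kornia losses are plain tensor ops, so they run unchanged on XLA tensors
idepth = torch.rand(1, 1, 4, 5, device=device)
image = torch.rand(1, 3, 4, 5, device=device)
loss = kornia.losses.inverse_depth_smoothness_loss(idepth, image)
print(loss.item())
```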
18 changes: 18 additions & 0 deletions docs/source/references.bib
@@ -33,3 +33,21 @@ @inproceedings{yun2019cutmix
booktitle = {International Conference on Computer Vision (ICCV)},
year={2019},
}

@misc{lin2018focal,
title={Focal Loss for Dense Object Detection},
author={Tsung-Yi Lin and Priya Goyal and Ross Girshick and Kaiming He and Piotr Dollár},
year={2018},
eprint={1708.02002},
archivePrefix={arXiv},
primaryClass={cs.CV}
}

@misc{salehi2017tversky,
title={Tversky loss function for image segmentation using 3D fully convolutional deep networks},
author={Seyed Sadegh Mohseni Salehi and Deniz Erdogmus and Ali Gholipour},
year={2017},
eprint={1706.05721},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
31 changes: 23 additions & 8 deletions kornia/losses/depth_smooth.py
@@ -19,15 +19,31 @@ def _gradient_y(img: torch.Tensor) -> torch.Tensor:
def inverse_depth_smoothness_loss(
idepth: torch.Tensor,
image: torch.Tensor) -> torch.Tensor:
r"""Computes image-aware inverse depth smoothness loss.
r"""Criterion that computes image-aware inverse depth smoothness loss.
.. math::
\text{loss} = \left | \partial_x d_{ij} \right | e^{-\left \|
\partial_x I_{ij} \right \|} + \left |
\partial_y d_{ij} \right | e^{-\left \| \partial_y I_{ij} \right \|}
See :class:`~kornia.losses.InverseDepthSmoothnessLoss` for details.
Args:
idepth (torch.Tensor): tensor with the inverse depth with shape :math:`(N, 1, H, W)`.
image (torch.Tensor): tensor with the input image with shape :math:`(N, 3, H, W)`.
Return:
torch.Tensor: a scalar with the computed loss.
Examples:
>>> idepth = torch.rand(1, 1, 4, 5)
>>> image = torch.rand(1, 3, 4, 5)
>>> loss = inverse_depth_smoothness_loss(idepth, image)
"""
-    if not torch.is_tensor(idepth):
+    if not isinstance(idepth, torch.Tensor):
raise TypeError("Input idepth type is not a torch.Tensor. Got {}"
.format(type(idepth)))

-    if not torch.is_tensor(image):
+    if not isinstance(image, torch.Tensor):
raise TypeError("Input image type is not a torch.Tensor. Got {}"
.format(type(image)))

@@ -68,6 +84,7 @@ def inverse_depth_smoothness_loss(
# apply image weights to depth
smoothness_x: torch.Tensor = torch.abs(idepth_dx * weights_x)
smoothness_y: torch.Tensor = torch.abs(idepth_dy * weights_y)

return torch.mean(smoothness_x) + torch.mean(smoothness_y)


@@ -80,14 +97,12 @@ class InverseDepthSmoothnessLoss(nn.Module):
\partial_x I_{ij} \right \|} + \left |
\partial_y d_{ij} \right | e^{-\left \| \partial_y I_{ij} \right \|}
Shape:
- Inverse Depth: :math:`(N, 1, H, W)`
- Image: :math:`(N, 3, H, W)`
- Output: scalar
-    Examples::
+    Examples:
>>> idepth = torch.rand(1, 1, 4, 5)
>>> image = torch.rand(1, 3, 4, 5)
>>> smooth = InverseDepthSmoothnessLoss()
@@ -97,5 +112,5 @@ class InverseDepthSmoothnessLoss(nn.Module):
def __init__(self) -> None:
super(InverseDepthSmoothnessLoss, self).__init__()

-    def forward(self, idepth: torch.Tensor, image: torch.Tensor) -> torch.Tensor:  # type:ignore
+    def forward(self, idepth: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
return inverse_depth_smoothness_loss(idepth, image)
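The hunks above only show fragments of the computation. As a reading aid, here is a self-contained sketch of the docstring's formula (forward differences of the inverse depth, down-weighted by edge-aware exponential weights from the image); it follows the same recipe but is not the exact kornia implementation, and the function name is illustrative:

```python
# Sketch of image-aware inverse depth smoothness; not the committed kornia code.
import torch


def smoothness_sketch(idepth: torch.Tensor, image: torch.Tensor) -> torch.Tensor:
    # forward differences along x (width) and y (height)
    didx = idepth[..., :, 1:] - idepth[..., :, :-1]
    didy = idepth[..., 1:, :] - idepth[..., :-1, :]
    dimx = image[..., :, 1:] - image[..., :, :-1]
    dimy = image[..., 1:, :] - image[..., :-1, :]
    # edge-aware weights exp(-||dI||): small where the image has strong gradients
    wx = torch.exp(-dimx.abs().mean(dim=1, keepdim=True))
    wy = torch.exp(-dimy.abs().mean(dim=1, keepdim=True))
    return (didx.abs() * wx).mean() + (didy.abs() * wy).mean()


idepth = torch.rand(1, 1, 4, 5)
image = torch.rand(1, 3, 4, 5)
loss = smoothness_sketch(idepth, image)  # scalar tensor
```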
63 changes: 49 additions & 14 deletions kornia/losses/dice.py
@@ -11,11 +11,44 @@
# https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py

def dice_loss(input: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
r"""Function that computes Sørensen-Dice Coefficient loss.
r"""Criterion that computes Sørensen-Dice Coefficient loss.
According to [1], we compute the Sørensen-Dice Coefficient as follows:
.. math::
\text{Dice}(x, class) = \frac{2 |X| \cap |Y|}{|X| + |Y|}
See :class:`~kornia.losses.DiceLoss` for details.
Where:
- :math:`X` expects to be the scores of each class.
- :math:`Y` expects to be the one-hot tensor with the class labels.
the loss, is finally computed as:
.. math::
\text{loss}(x, class) = 1 - \text{Dice}(x, class)
Reference:
[1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
Args:
input (torch.Tensor): logits tensor with shape :math:`(N, C, H, W)` where C = number of classes.
labels (torch.Tensor): labels tensor with shape :math:`(N, H, W)` where each value
is :math:`0 ≤ targets[i] ≤ C−1`.
eps (float, optional): Scalar to enforce numerical stabiliy. Default: 1e-8.
Return:
torch.Tensor: the computed loss.
Example:
>>> N = 5 # num_classes
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = dice_loss(input, target)
>>> output.backward()
"""
-    if not torch.is_tensor(input):
+    if not isinstance(input, torch.Tensor):
raise TypeError("Input type is not a torch.Tensor. Got {}"
.format(type(input)))

@@ -46,6 +79,7 @@ def dice_loss(input: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
cardinality = torch.sum(input_soft + target_one_hot, dims)

dice_score = 2. * intersection / (cardinality + eps)

return torch.mean(-dice_score + 1.)


@@ -58,7 +92,7 @@ class DiceLoss(nn.Module):
\text{Dice}(x, class) = \frac{2 |X| \cap |Y|}{|X| + |Y|}
-    where:
+    Where:
- :math:`X` expects to be the scores of each class.
- :math:`Y` expects to be the one-hot tensor with the class labels.
@@ -68,28 +102,29 @@ class DiceLoss(nn.Module):
\text{loss}(x, class) = 1 - \text{Dice}(x, class)
-    [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
+    Reference:
+        [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
+
+    Args:
+        eps (float, optional): Scalar to enforce numerical stability. Default: 1e-8.
Shape:
- Input: :math:`(N, C, H, W)` where C = number of classes.
- Target: :math:`(N, H, W)` where each value is
:math:`0 ≤ targets[i] ≤ C−1`.
-    Examples:
+    Example:
>>> N = 5 # num_classes
-        >>> loss = DiceLoss()
+        >>> criterion = DiceLoss()
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
-        >>> output = loss(input, target)
+        >>> output = criterion(input, target)
>>> output.backward()
"""

-    def __init__(self) -> None:
+    def __init__(self, eps: float = 1e-8) -> None:
super(DiceLoss, self).__init__()
-        self.eps: float = 1e-6
+        self.eps: float = eps

-    def forward(  # type: ignore
-            self,
-            input: torch.Tensor,
-            target: torch.Tensor) -> torch.Tensor:
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return dice_loss(input, target, self.eps)
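A compact sketch of the full computation the fragments above describe (softmax over class logits, one-hot targets, Dice score per sample). It follows the same recipe as the docstring, though it is not line-for-line the kornia code and the function name is illustrative:

```python
# Sketch of the Sørensen-Dice loss; not the committed kornia implementation.
import torch
import torch.nn.functional as F


def dice_loss_sketch(input: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # input: (N, C, H, W) logits; target: (N, H, W) class indices
    probs = F.softmax(input, dim=1)
    one_hot = F.one_hot(target, num_classes=input.shape[1])  # (N, H, W, C)
    one_hot = one_hot.permute(0, 3, 1, 2).to(probs.dtype)    # (N, C, H, W)
    dims = (1, 2, 3)
    intersection = torch.sum(probs * one_hot, dims)
    cardinality = torch.sum(probs + one_hot, dims)
    dice_score = 2.0 * intersection / (cardinality + eps)
    return torch.mean(1.0 - dice_score)


N = 5  # num_classes
input = torch.randn(1, N, 3, 5, requires_grad=True)
target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
loss = dice_loss_sketch(input, target)
loss.backward()
```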
9 changes: 4 additions & 5 deletions kornia/losses/divergence.py
@@ -5,7 +5,7 @@
import torch.nn.functional as F


-def _kl_div_2d(p, q):
+def _kl_div_2d(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
# D_KL(P || Q)
batch, chans, height, width = p.shape
unsummed_kl = F.kl_div(
@@ -17,19 +17,18 @@ def _kl_div_2d(p, q):
return kl_values


-def _js_div_2d(p, q):
+def _js_div_2d(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
# JSD(P || Q)
m = 0.5 * (p + q)
return 0.5 * _kl_div_2d(p, m) + 0.5 * _kl_div_2d(q, m)

# TODO: add this to the main module


-def _reduce_loss(losses, reduction):
+def _reduce_loss(losses: torch.Tensor, reduction: str) -> torch.Tensor:
if reduction == 'none':
return losses
-    else:
-        return torch.mean(losses) if reduction == 'mean' else torch.sum(losses)
+    return torch.mean(losses) if reduction == 'mean' else torch.sum(losses)


def js_div_loss_2d(
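The helpers here implement the Jensen-Shannon divergence via its KL decomposition: JSD(P||Q) = ½ KL(P||M) + ½ KL(Q||M) with M = ½(P + Q). Below is a hedged, self-contained sketch of that relation using plain tensor ops; kornia's actual helpers go through `F.kl_div`, and the epsilon guard here is an assumption for log stability, not from the source:

```python
# Sketch of JSD from KL over 2D spatial distributions; not the kornia helpers.
import torch


def kl_div_2d_sketch(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # D_KL(P || Q) per (batch, channel), summed over the spatial distribution
    eps = 1e-12  # guard against log(0); an assumption, not from the source
    return torch.sum(p * (torch.log(p + eps) - torch.log(q + eps)), dim=(-2, -1))


def js_div_2d_sketch(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    m = 0.5 * (p + q)
    return 0.5 * kl_div_2d_sketch(p, m) + 0.5 * kl_div_2d_sketch(q, m)


# p and q must be normalized 2D distributions per (batch, channel)
p = torch.softmax(torch.randn(2, 1, 4, 4).flatten(-2), dim=-1).view(2, 1, 4, 4)
q = torch.softmax(torch.randn(2, 1, 4, 4).flatten(-2), dim=-1).view(2, 1, 4, 4)
jsd = js_div_2d_sketch(p, q)  # shape (2, 1)
```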
62 changes: 42 additions & 20 deletions kornia/losses/focal.py
@@ -17,11 +17,39 @@ def focal_loss(
gamma: float = 2.0,
reduction: str = 'none',
eps: float = 1e-8) -> torch.Tensor:
r"""Function that computes Focal loss.
r"""Criterion that computes Focal loss.
According to :cite:`lin2018focal`, the Focal loss is computed as follows:
.. math::
\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
Where:
- :math:`p_t` is the model's estimated probability for each class.
Args:
input (torch.Tensor): logits tensor with shape :math:`(N, C, *)` where C = number of classes.
target (torch.Tensor): labels tensor with shape :math:`(N, *)` where each value is :math:`0 ≤ targets[i] ≤ C−1`.
alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
gamma (float, optional): Focusing parameter :math:`\gamma >= 0`. Default 2.
reduction (str, optional): Specifies the reduction to apply to the
output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied,
‘mean’: the sum of the output will be divided by the number of elements
in the output, ‘sum’: the output will be summed. Default: ‘none’.
eps (float, optional): Scalar to enforce numerical stabiliy. Default: 1e-8.
Return:
torch.Tensor: the computed loss.
See :class:`~kornia.losses.FocalLoss` for details.
Example:
>>> N = 5 # num_classes
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = focal_loss(input, target, alpha=0.5, gamma=2.0, reduction='mean')
>>> output.backward()
"""
-    if not torch.is_tensor(input):
+    if not isinstance(input, torch.Tensor):
raise TypeError("Input type is not a torch.Tensor. Got {}"
.format(type(input)))

@@ -73,54 +101,48 @@ class FocalLoss(nn.Module):
class FocalLoss(nn.Module):
r"""Criterion that computes Focal loss.
-    According to [1], the Focal loss is computed as follows:
+    According to :cite:`lin2018focal`, the Focal loss is computed as follows:
.. math::
\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
-    where:
+    Where:
- :math:`p_t` is the model's estimated probability for each class.
-    Arguments:
+    Args:
alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
-        gamma (float): Focusing parameter :math:`\gamma >= 0`.
+        gamma (float, optional): Focusing parameter :math:`\gamma >= 0`. Default: 2.
reduction (str, optional): Specifies the reduction to apply to the
output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied,
‘mean’: the sum of the output will be divided by the number of elements
in the output, ‘sum’: the output will be summed. Default: ‘none’.
+        eps (float, optional): Scalar to enforce numerical stability. Default: 1e-8.
Shape:
- Input: :math:`(N, C, *)` where C = number of classes.
- Target: :math:`(N, *)` where each value is
:math:`0 ≤ targets[i] ≤ C−1`.
-    Examples:
+    Example:
>>> N = 5 # num_classes
>>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'}
-        >>> loss = FocalLoss(**kwargs)
+        >>> criterion = FocalLoss(**kwargs)
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
-        >>> output = loss(input, target)
+        >>> output = criterion(input, target)
>>> output.backward()
-    References:
-        [1] https://arxiv.org/abs/1708.02002
"""

def __init__(self, alpha: float, gamma: float = 2.0,
-                 reduction: str = 'none') -> None:
+                 reduction: str = 'none', eps: float = 1e-8) -> None:
super(FocalLoss, self).__init__()
self.alpha: float = alpha
self.gamma: float = gamma
self.reduction: str = reduction
-        self.eps: float = 1e-6
+        self.eps: float = eps

-    def forward(  # type: ignore
-            self,
-            input: torch.Tensor,
-            target: torch.Tensor) -> torch.Tensor:
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return focal_loss(input, target, self.alpha, self.gamma, self.reduction, self.eps)


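Finally, a short sketch of the focal-loss formula documented above for dense `(N, C, H, W)` predictions. It mirrors the docstring's recipe (softmax probabilities, `p_t` of the true class, the `(1 - p_t)^gamma` modulator) but is a simplified stand-in with an illustrative name, not the committed kornia code:

```python
# Sketch of FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t); not the kornia code.
import torch
import torch.nn.functional as F


def focal_loss_sketch(input: torch.Tensor, target: torch.Tensor,
                      alpha: float = 0.5, gamma: float = 2.0,
                      reduction: str = 'mean', eps: float = 1e-8) -> torch.Tensor:
    probs = F.softmax(input, dim=1) + eps                    # class probabilities
    one_hot = F.one_hot(target, num_classes=input.shape[1])  # (N, H, W, C)
    one_hot = one_hot.permute(0, 3, 1, 2).to(probs.dtype)    # (N, C, H, W)
    pt = torch.sum(probs * one_hot, dim=1)                   # p_t of the true class
    loss = -alpha * torch.pow(1.0 - pt, gamma) * torch.log(pt)
    if reduction == 'none':
        return loss
    return loss.mean() if reduction == 'mean' else loss.sum()


input = torch.randn(1, 5, 3, 5, requires_grad=True)
target = torch.empty(1, 3, 5, dtype=torch.long).random_(5)
out = focal_loss_sketch(input, target, alpha=0.5, gamma=2.0, reduction='mean')
out.backward()
```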
(Diffs for the remaining 3 changed files are not shown.)
