feat: Added implementation of PolyLoss (#209)
* feat: Added implementation of PolyLoss

* docs: Added PolyLoss to the documentaton

* docs: Updated README

* test: Updated unittests

* test: Fixed unittest

* test: Updated assertion tolerance

* test: Fixed unittest
frgfm committed May 1, 2022
1 parent f0f6379 commit 94cda24
Showing 6 changed files with 115 additions and 5 deletions.
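For reference, the Poly-1 loss from the linked paper augments cross entropy with a single polynomial term: L_Poly-1 = -log(p_t) + eps * (1 - p_t), where p_t is the softmax probability of the target class and eps defaults to 2 in this implementation. A minimal standalone sketch of that formula (illustrative only; `poly1_reference` is a hypothetical name, independent of the Holocron code changed below):

```python
import torch
import torch.nn.functional as F

def poly1_reference(logits: torch.Tensor, target: torch.Tensor, eps: float = 2.) -> torch.Tensor:
    # log p_t: log-probability of the target class for each sample
    logpt = F.log_softmax(logits, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)
    # Poly-1 = cross entropy + eps * (1 - p_t), averaged over the batch
    return (-logpt + eps * (1 - logpt.exp())).mean()
```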
2 changes: 1 addition & 1 deletion README.md
@@ -79,7 +79,7 @@ pip install -e Holocron/.

### PyTorch layers for every need
- Activation: [HardMish](https://github.com/digantamisra98/H-Mish), [NLReLU](https://arxiv.org/abs/1908.03682), [FReLU](https://arxiv.org/abs/2007.11824)
- Loss: [Focal Loss](https://arxiv.org/abs/1708.02002), MultiLabelCrossEntropy, [MixupLoss](https://arxiv.org/pdf/1710.09412.pdf), [ClassBalancedWrapper](https://arxiv.org/abs/1901.05555), [ComplementCrossEntropy](https://arxiv.org/abs/2009.02189), [MutualChannelLoss](https://arxiv.org/abs/2002.04264), [DiceLoss](https://arxiv.org/abs/1606.04797)
- Loss: [Focal Loss](https://arxiv.org/abs/1708.02002), MultiLabelCrossEntropy, [MixupLoss](https://arxiv.org/pdf/1710.09412.pdf), [ClassBalancedWrapper](https://arxiv.org/abs/1901.05555), [ComplementCrossEntropy](https://arxiv.org/abs/2009.02189), [MutualChannelLoss](https://arxiv.org/abs/2002.04264), [DiceLoss](https://arxiv.org/abs/1606.04797), [PolyLoss](https://arxiv.org/abs/2204.12511)
- Convolutions: [NormConv2d](https://arxiv.org/pdf/2005.05274v2.pdf), [Add2d](https://arxiv.org/pdf/1912.13200.pdf), [SlimConv2d](https://arxiv.org/pdf/2003.07469.pdf), [PyConv2d](https://arxiv.org/abs/2006.11538), [Involution](https://arxiv.org/abs/2103.06255)
- Regularization: [DropBlock](https://arxiv.org/abs/1810.12890)
- Pooling: [BlurPool2d](https://arxiv.org/abs/1904.11486), [SPP](https://arxiv.org/abs/1406.4729), [ZPool](https://arxiv.org/abs/2010.03045)
2 changes: 2 additions & 0 deletions docs/source/nn.functional.rst
@@ -28,6 +28,8 @@ Loss functions

.. autofunction:: dice_loss

.. autofunction:: poly_loss

Convolutions
------------

2 changes: 2 additions & 0 deletions docs/source/nn.rst
@@ -28,6 +28,8 @@ Loss functions

.. autoclass:: DiceLoss

.. autoclass:: PolyLoss


Loss wrappers
--------------
57 changes: 56 additions & 1 deletion holocron/nn/functional.py
@@ -12,7 +12,7 @@

__all__ = ['hard_mish', 'nl_relu', 'focal_loss', 'multilabel_cross_entropy', 'complement_cross_entropy',
'mutual_channel_loss', 'norm_conv2d', 'add2d', 'dropblock2d', 'z_pool', 'concat_downsample2d',
'dice_loss']
'dice_loss', 'poly_loss']


def hard_mish(x: Tensor, inplace: bool = False) -> Tensor:
@@ -547,3 +547,58 @@ def dice_loss(
    loss = 1 - (1 + 1 / gamma) * (weight * dice_coeff).sum() / weight.sum()

    return loss


def poly_loss(
    x: Tensor,
    target: Tensor,
    eps: float = 2.,
    weight: Optional[Tensor] = None,
    ignore_index: int = -100,
    reduction: str = 'mean',
) -> Tensor:
    """Implements the Poly1 loss from `"PolyLoss: A Polynomial Expansion Perspective of Classification Loss
    Functions" <https://arxiv.org/pdf/2204.12511.pdf>`_.

    Args:
        x (torch.Tensor[N, K, ...]): predicted class scores (logits)
        target (torch.Tensor[N, ...]): target class indices
        eps (float, optional): epsilon 1 from the paper
        weight (torch.Tensor[K], optional): manual rescaling of each class
        ignore_index (int, optional): specifies a target value that is ignored and does not contribute to the gradient
        reduction (str, optional): reduction method

    Returns:
        torch.Tensor: loss reduced with `reduction` method
    """

    # log(P[class]) = log_softmax(score)[class]
    logpt = F.log_softmax(x, dim=1)

    # Compute logpt only for target classes (the remaining will have a 0 coefficient)
    logpt = logpt.transpose(1, 0).flatten(1).gather(0, target.view(1, -1)).squeeze()
    # Ignore index (set loss contribution to 0)
    valid_idxs = torch.ones(target.view(-1).shape[0], dtype=torch.bool, device=x.device)
    if ignore_index >= 0 and ignore_index < x.shape[1]:
        valid_idxs[target.view(-1) == ignore_index] = False

    # Poly1 loss: cross entropy + eps * (1 - P[class])
    loss = -1 * logpt + eps * (1 - logpt.exp())

    # Weight
    if weight is not None:
        # Cast class weights to the same tensor type as the input
        if weight.type() != x.data.type():
            weight = weight.type_as(x.data)
        # Rescale each sample's loss by the weight of its target class
        loss = weight.gather(0, target.data.view(-1)) * loss

    # Loss reduction
    if reduction == 'sum':
        loss = loss[valid_idxs].sum()
    elif reduction == 'mean':
        loss = loss[valid_idxs].mean()
    else:
        # if no reduction, reshape tensor like target
        loss = loss.view(*target.shape)

    return loss
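A minimal usage sketch of the new functional API (the values are illustrative; the import path follows the `holocron/nn/functional.py` file added above):

```python
import torch
from holocron.nn import functional as F

# Raw class scores (logits) for 8 samples over 4 classes
scores = torch.rand(8, 4, requires_grad=True)
target = torch.randint(0, 4, (8,))

# Poly-1 loss with the default eps=2 and mean reduction
loss = F.poly_loss(scores, target)
loss.backward()
```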
31 changes: 29 additions & 2 deletions holocron/nn/modules/loss.py
@@ -12,7 +12,7 @@
from .. import functional as F

__all__ = ['FocalLoss', 'MultiLabelCrossEntropy', 'ComplementCrossEntropy',
'ClassBalancedWrapper', 'MutualChannelLoss', 'DiceLoss']
'ClassBalancedWrapper', 'MutualChannelLoss', 'DiceLoss', 'PolyLoss']


class _Loss(nn.Module):
@@ -21,7 +21,7 @@ def __init__(
self,
weight: Optional[Union[float, List[float], Tensor]] = None,
ignore_index: int = -100,
reduction: str = 'mean'
reduction: str = 'mean',
) -> None:
super().__init__()
# Cast class weights if possible
@@ -211,3 +211,30 @@ def forward(self, x: Tensor, target: Tensor) -> Tensor:

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(reduction='{self.reduction}', gamma={self.gamma}, eps={self.eps})"


class PolyLoss(_Loss):
    """Implements the Poly1 loss from `"PolyLoss: A Polynomial Expansion Perspective of Classification Loss
    Functions" <https://arxiv.org/pdf/2204.12511.pdf>`_.

    Args:
        weight (torch.Tensor[K], optional): class weight for loss computation
        eps (float, optional): epsilon 1 from the paper
        ignore_index (int, optional): specifies a target value that is ignored and does not contribute to the gradient
        reduction (str, optional): reduction method
    """

    def __init__(
        self,
        *args: Any,
        eps: float = 2.,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.eps = eps

    def forward(self, x: Tensor, target: Tensor) -> Tensor:
        return F.poly_loss(x, target, self.eps, self.weight, self.ignore_index, self.reduction)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(eps={self.eps}, reduction='{self.reduction}')"
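And a similar sketch for the module wrapper, assuming `PolyLoss` is re-exported from `holocron.nn` the same way the other losses are used in the tests below:

```python
import torch
from holocron import nn

criterion = nn.PolyLoss(eps=2.)
# Dense predictions, e.g. a 2-sample batch of 4-class 32x32 maps
scores = torch.rand(2, 4, 32, 32, requires_grad=True)
target = torch.randint(0, 4, (2, 32, 32))
loss = criterion(scores, target)
loss.backward()
```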
26 changes: 25 additions & 1 deletion tests/test_nn_loss.py
@@ -12,7 +12,7 @@ def _test_loss_function(loss_fn, same_loss=0., multi_label=False):
num_classes = 4
# 4 classes
x = torch.ones(num_batches, num_classes)
x[:, 0, ...] = 10
x[:, 0, ...] = 100
x.requires_grad_(True)

# Identical target
@@ -204,3 +204,27 @@ def test_dice_loss():
    out.backward()

    assert repr(nn.DiceLoss()) == "DiceLoss(reduction='mean', gamma=1.0, eps=1e-08)"


def test_poly_loss():

    _test_loss_function(F.poly_loss)

    num_batches = 2
    num_classes = 4

    x = torch.rand((num_batches, num_classes, 20, 20), requires_grad=True)
    target = (num_classes * torch.rand(num_batches, 20, 20)).to(torch.long)

    # Backprop
    out = F.poly_loss(x, target)
    out.backward()

    # Weighted loss
    class_weights = torch.ones(num_classes)
    class_weights[0] = 2
    out = F.poly_loss(x, target, weight=class_weights)
    out.backward()

    assert repr(nn.PolyLoss()) == "PolyLoss(eps=2.0, reduction='mean')"
