Merge pull request #467 from acforvs:poly-loss
PiperOrigin-RevId: 540278941
OptaxDev committed Jun 14, 2023
2 parents 802161c + be72990 commit f527be8
Showing 2 changed files with 132 additions and 4 deletions.
51 changes: 47 additions & 4 deletions optax/_src/loss.py
@@ -129,13 +129,11 @@ def smooth_labels(
  [Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf)

  Args:
-    labels: one hot labels to be smoothed.
-    alpha: the smoothing factor, the greedy category with be assigned
-      probability `(1-alpha) + alpha / num_categories`
+    labels: One hot labels to be smoothed.
+    alpha: The smoothing factor.

  Returns:
    a smoothed version of the one hot input labels.
  """
  chex.assert_type([labels], float)
  num_categories = labels.shape[-1]
@@ -582,3 +580,48 @@ def hinge_loss(predictor_outputs: chex.Array,
    Binary Hinge Loss.
  """
  return jnp.maximum(0, 1 - predictor_outputs * targets)


def poly_loss_cross_entropy(
    logits: chex.Array, labels: chex.Array, epsilon: float = 2.0
) -> chex.Array:
  r"""Computes PolyLoss between logits and labels.

  The PolyLoss is a loss function that decomposes commonly
  used classification loss functions into a series of weighted
  polynomial bases. It is inspired by the Taylor expansion of
  cross-entropy loss and focal loss in the bases of :math:`(1 - P_t)^j`.

  .. math::
    L_{Poly} = \sum_1^\infty \alpha_j \cdot (1 - P_t)^j \\
    L_{Poly-N} = (\epsilon_1 + 1) \cdot (1 - P_t) + \ldots + \\
    (\epsilon_N + \frac{1}{N}) \cdot (1 - P_t)^N +
    \frac{1}{N + 1} \cdot (1 - P_t)^{N + 1} + \ldots = \\
    - \log(P_t) + \sum_{j = 1}^N \epsilon_j \cdot (1 - P_t)^j

  This function provides a simplified version of :math:`L_{Poly-N}`
  with only the coefficient of the first polynomial term being changed.

  References:
    [Zhaoqi Leng et al, 2022](https://arxiv.org/pdf/2204.12511.pdf)

  Args:
    logits: Unnormalized log probabilities, with shape `[..., num_classes]`.
    labels: Valid probability distributions (non-negative, sum to 1), e.g. a
      one hot encoding specifying the correct class for each input;
      must have a shape broadcastable to `[..., num_classes]`.
    epsilon: The coefficient of the first polynomial term.
      According to the paper, the following values are recommended:
      - For the ImageNet 2d image classification, epsilon = 2.0.
      - For the 2d Instance Segmentation and object detection, epsilon = -1.0.
      - It is also recommended to adjust this value based on the task, e.g. by
        using grid search.

  Returns:
    Poly loss between each prediction and the corresponding target
    distributions, with shape `[...]`.
  """
  chex.assert_type([logits, labels], float)
  one_minus_pt = jnp.sum(labels * (1 - jax.nn.softmax(logits)), axis=-1)
  cross_entropy = softmax_cross_entropy(logits=logits, labels=labels)
  return cross_entropy + epsilon * one_minus_pt
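
Not part of the commit: a minimal usage sketch of the new loss. It assumes the function is re-exported at the package level as `optax.poly_loss_cross_entropy`; the tests below reach it through `optax._src.loss` instead.

# Usage sketch (assumes a top-level `optax.poly_loss_cross_entropy` alias).
import jax.numpy as jnp
import optax

logits = jnp.array([[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]])
labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]])  # rows sum to 1

# epsilon defaults to 2.0, the paper's recommendation for ImageNet classification.
per_example = optax.poly_loss_cross_entropy(logits, labels)
print(per_example.shape)  # (2,) -- one loss value per example

# With epsilon=0 the extra polynomial term vanishes and the result reduces
# to plain softmax cross-entropy.
ce = optax.softmax_cross_entropy(logits, labels)
print(jnp.allclose(optax.poly_loss_cross_entropy(logits, labels, epsilon=0.0), ce))  # True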
85 changes: 85 additions & 0 deletions optax/_src/loss_test.py
@@ -556,5 +556,90 @@ def test_batched(self):
        self.correct_result,
        atol=1e-4)


class PolyLossTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.logits = np.array([0.14, 1.456, 2.356, -0.124, -2.47])
    self.labels = np.array([0.1, 0.15, 0.2, 0.25, 0.3])

    self.batched_logits = np.array([[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]])
    self.batched_labels = np.array([[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]])
    # all expected values are computed using tf version of `poly1_cross_entropy`
    # see page 10 here https://arxiv.org/pdf/2204.12511.pdf for more

  @chex.all_variants
  @parameterized.parameters(
      dict(eps=2, expected=4.5317),
      dict(eps=1, expected=3.7153),
      dict(eps=-1, expected=2.0827),
      dict(eps=0, expected=2.8990),
      dict(eps=-0.5, expected=2.4908),
      dict(eps=1.15, expected=3.8378),
      dict(eps=1.214, expected=3.8900),
      dict(eps=5.45, expected=7.3480),
  )
  def test_scalar(self, eps, expected):
    np.testing.assert_allclose(
        self.variant(loss.poly_loss_cross_entropy)(
            self.logits, self.labels, eps
        ),
        expected,
        atol=1e-4,
    )

  @chex.all_variants
  @parameterized.parameters(
      dict(eps=2, expected=np.array([0.4823, 1.2567])),
      dict(eps=1, expected=np.array([0.3261, 1.0407])),
      dict(eps=0, expected=np.array([0.1698, 0.8247])),
      dict(eps=-0.5, expected=np.array([0.0917, 0.7168])),
      dict(eps=1.15, expected=np.array([0.3495, 1.0731])),
      dict(eps=1.214, expected=np.array([0.3595, 1.0870])),
      dict(eps=5.45, expected=np.array([1.0211, 2.0018])),
  )
  def test_batched(self, eps, expected):
    np.testing.assert_allclose(
        self.variant(loss.poly_loss_cross_entropy)(
            self.batched_logits, self.batched_labels, eps
        ),
        expected,
        atol=1e-4,
    )

  @chex.all_variants
  @parameterized.parameters(
      dict(
          logits=np.array(
              [[4.0, 2.0, 1.0], [0.0, 5.0, 1.0], [0.134, 1.234, 3.235]]
          ),
          labels=np.array(
              [[1.0, 0.0, 0.0], [0.0, 0.8, 0.2], [0.34, 0.33, 0.33]]
          ),
      ),
      dict(
          logits=np.array([[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]]),
          labels=np.array([[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]]),
      ),
      dict(
          logits=np.array(
              [[4.0, 2.0, 1.0, 0.134, 1.3515], [0.0, 5.0, 1.0, 0.5215, 5.616]]
          ),
          labels=np.array(
              [[0.5, 0.0, 0.0, 0.0, 0.5], [0.0, 0.12, 0.2, 0.56, 0.12]]
          ),
      ),
      dict(logits=np.array([1.89, 2.39]), labels=np.array([0.34, 0.66])),
      dict(logits=np.array([0.314]), labels=np.array([1.0])),
  )
  def test_equals_to_cross_entropy_when_eps0(self, logits, labels):
    np.testing.assert_allclose(
        self.variant(loss.poly_loss_cross_entropy)(logits, labels, epsilon=0.0),
        self.variant(loss.softmax_cross_entropy)(logits, labels),
        atol=1e-4,
    )


if __name__ == '__main__':
  absltest.main()
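
As a sanity check on the expected values above, the first row of the `eps=2` batched case can be reproduced by hand from the simplified Poly-1 formula. A NumPy sketch, not part of the commit:

# Hand-check of the first row of the eps=2 batched case
# (logits [4, 2, 1], one-hot label on class 0).
import numpy as np

logits = np.array([4.0, 2.0, 1.0])   # first row of `batched_logits`
labels = np.array([1.0, 0.0, 0.0])   # one-hot label on class 0
eps = 2.0

p = np.exp(logits) / np.exp(logits).sum()    # softmax probabilities
pt = (labels * p).sum()                      # probability of the target class, ~0.8438
cross_entropy = -(labels * np.log(p)).sum()  # ~0.1698
poly1 = cross_entropy + eps * (1.0 - pt)     # Poly-1 loss
print(round(poly1, 4))                       # 0.4823, matching the expected value above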
