Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upstream sparsemax jaxopt loss to optax. #899

Merged
merged 1 commit into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions optax/losses/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,37 @@ def perceptron_loss(
return jnp.maximum(0, - predictor_outputs * targets)


def sparsemax_loss(
logits: chex.Array,
labels: chex.Array,
) -> chex.Array:
"""Binary sparsemax loss.

This loss is zero if and only if `jax.nn.sparse_sigmoid(logits) == labels`.

References:
Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins,
Vlad Niculae. JMLR 2020. (Sec. 4.4)

Args:
logits: score produced by the model (float).
labels: ground-truth integer label (0 or 1).

Returns:
loss value

.. versionadded:: 0.2.3
"""
return jax.nn.sparse_plus(jnp.where(labels, -logits, logits))


@functools.partial(
chex.warn_deprecated_function,
replacement='sparsemax_loss')
def binary_sparsemax_loss(logits, labels):
return sparsemax_loss(logits, labels)


def softmax_cross_entropy(
logits: chex.Array,
labels: chex.Array,
Expand Down Expand Up @@ -183,16 +214,16 @@ def multiclass_hinge_loss(
) -> chex.Array:
"""Multiclass hinge loss.

References:
https://en.wikipedia.org/wiki/Hinge_loss

Args:
scores: scores produced by the model (floats).
labels: ground-truth integer label.

Returns:
loss value

References:
https://en.wikipedia.org/wiki/Hinge_loss

.. versionadded:: 0.2.3
"""
one_hot_labels = jax.nn.one_hot(labels, scores.shape[-1])
Expand Down
37 changes: 37 additions & 0 deletions optax/losses/_classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,43 @@ def reference_impl(label, scores):
np.testing.assert_allclose(result, expected, atol=1e-4)


class SparsemaxTest(parameterized.TestCase):

def test_binary(self):
label = 1
score = 10.
def reference_impl(label, logit):
scores = -(2*label-1)*logit
if scores <= -1.0:
return 0.0
elif scores >= 1.0:
return scores
else:
return (scores + 1.0) ** 2 / 4
expected = reference_impl(label, score)
result = _classification.sparsemax_loss(
jnp.asarray(score), jnp.asarray(label))
np.testing.assert_allclose(result, expected, atol=1e-4)

def test_batched_binary(self):
labels = jnp.array([1, 0])
scores = jnp.array([10., 20.])
def reference_impl(label, logit):
scores = -(2*label-1)*logit
if scores <= -1.0:
return 0.0
elif scores >= 1.0:
return scores
else:
return (scores + 1.0) ** 2 / 4
expected = jnp.asarray([
reference_impl(labels[0], scores[0]),
reference_impl(labels[1], scores[1])])
# in the optax loss the leading dimensions are automatically handled.
result = _classification.sparsemax_loss(scores, labels)
np.testing.assert_allclose(result, expected, atol=1e-4)


class ConvexKLDivergenceTest(parameterized.TestCase):

def setUp(self):
Expand Down
Loading