From f3213b40884904d35b934d497ecc3d562035df8c Mon Sep 17 00:00:00 2001
From: Matteo Hessel
Date: Wed, 27 Mar 2024 03:23:05 -0700
Subject: [PATCH] Upstream multiclass_perceptron_loss from jaxopt to optax.

PiperOrigin-RevId: 619473957
---
 optax/losses/_classification.py      | 27 +++++++++++++++++-
 optax/losses/_classification_test.py | 41 ++++++++++++++++++++++++++++
 2 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/optax/losses/_classification.py b/optax/losses/_classification.py
index a7f878ba..e4a025cb 100644
--- a/optax/losses/_classification.py
+++ b/optax/losses/_classification.py
@@ -78,7 +78,7 @@ def hinge_loss(
     targets: Target values. Target values should be strictly in the set
       {-1, 1}.
 
   Returns:
-    Binary Hinge Loss.
+    loss value.
   """
   return jnp.maximum(0, 1 - predictor_outputs * targets)
@@ -99,6 +99,7 @@ def perceptron_loss(
   Returns:
     loss value.
   """
+  chex.assert_equal_shape([predictor_outputs, targets])
   return jnp.maximum(0, - predictor_outputs * targets)
 
 
@@ -172,6 +173,30 @@ def multiclass_logistic_loss(logits, labels):
   return softmax_cross_entropy_with_integer_labels(logits, labels)
 
 
+def multiclass_perceptron_loss(
+    scores: chex.Array,
+    label: chex.Array,
+) -> chex.Array:
+  """Multiclass perceptron loss.
+
+  References:
+    Michael Collins. Discriminative training methods for Hidden Markov Models:
+    Theory and experiments with perceptron algorithms. EMNLP 2002.
+
+  Args:
+    scores: scores produced by the model.
+    label: ground-truth integer label.
+
+  Returns:
+    loss value.
+
+  .. versionadded:: 0.2.2
+  """
+  one_hot_label = jax.nn.one_hot(label, scores.shape[-1])
+  dot_last_dim = jnp.vectorize(jnp.dot, signature='(n),(n)->()')
+  return jnp.max(scores, axis=-1) - dot_last_dim(scores, one_hot_label)
+
+
 @functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
 def poly_loss_cross_entropy(
     logits: chex.Array,
diff --git a/optax/losses/_classification_test.py b/optax/losses/_classification_test.py
index db1af2d3..ab8229c0 100644
--- a/optax/losses/_classification_test.py
+++ b/optax/losses/_classification_test.py
@@ -259,6 +259,47 @@ def test_batched(self):
     )
 
 
+class PerceptronTest(parameterized.TestCase):
+
+  def test_binary(self):
+    label = jnp.array(1)
+    signed_label = jnp.array(2.0 * label - 1.0)
+    score = jnp.array(10.)
+    def reference_impl(label, logit) -> float:
+      return jax.nn.relu(- logit * (2.0 * label - 1.0))
+    expected = reference_impl(label, score)
+    result = _classification.perceptron_loss(score, signed_label)
+    np.testing.assert_allclose(result, expected, atol=1e-4)
+
+  def test_batched_binary(self):
+    labels = jnp.array([1, 0])
+    signed_labels = jnp.array(2.0 * labels - 1.0)
+    scores = jnp.array([10., 20.])
+    def reference_impl(label, logit) -> float:
+      return jax.nn.relu(- logit * (2.0 * label - 1.0))
+    expected = jax.vmap(reference_impl)(labels, scores)
+    result = _classification.perceptron_loss(scores, signed_labels)
+    np.testing.assert_allclose(result, expected, atol=1e-4)
+
+  def test_multi_class(self):
+    label = jnp.array(1)
+    scores = jnp.array([10., 3.])
+    def reference_impl(label, scores):
+      return jnp.max(scores) - scores[label]
+    expected = reference_impl(label, scores)
+    result = _classification.multiclass_perceptron_loss(scores, label)
+    np.testing.assert_allclose(result, expected, atol=1e-4)
+
+  def test_batched_multi_class(self):
+    label = jnp.array([1, 0])
+    scores = jnp.array([[10., 3.], [11., -2.]])
+    def reference_impl(label, scores):
+      return jnp.max(scores) - scores[label]
+    expected = jax.vmap(reference_impl)(label, scores)
+    result = _classification.multiclass_perceptron_loss(scores, label)
+    np.testing.assert_allclose(result, expected, atol=1e-4)
+
+
 class KLDivergenceTest(parameterized.TestCase):
 
   def setUp(self):
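
Note for reviewers: a minimal usage sketch of the upstreamed loss, assuming it is
re-exported under optax.losses like the other classification losses.

import jax.numpy as jnp
from optax import losses

# Per-class scores for a batch of two examples, and their integer labels.
scores = jnp.array([[10., 3.], [11., -2.]])
labels = jnp.array([1, 0])

# max_i scores[i] - scores[label]: zero exactly when the true class already
# has the highest score, positive otherwise.
loss = losses.multiclass_perceptron_loss(scores, labels)
print(loss)  # [7. 0.]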