Upstream the multiclass_perceptron_loss from jaxopt to optax.
PiperOrigin-RevId: 619473957
mtthss authored and OptaxDev committed Mar 27, 2024
1 parent e930df7 commit f3213b4
Showing 2 changed files with 67 additions and 1 deletion.
optax/losses/_classification.py (26 additions & 1 deletion)
@@ -78,7 +78,7 @@ def hinge_loss(
    targets: Target values. Target values should be strictly in the set {-1, 1}.

  Returns:
-    Binary Hinge Loss.
+    loss value.
  """
  return jnp.maximum(0, 1 - predictor_outputs * targets)

@@ -99,6 +99,7 @@ def perceptron_loss(
  Returns:
    loss value.
  """
+  chex.assert_equal_shape([predictor_outputs, targets])
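+  # max(0, -predictor_outputs * targets): zero whenever the prediction and the target have the same sign.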
  return jnp.maximum(0, - predictor_outputs * targets)


@@ -172,6 +173,30 @@ def multiclass_logistic_loss(logits, labels):
  return softmax_cross_entropy_with_integer_labels(logits, labels)


+def multiclass_perceptron_loss(
+    scores: chex.Array,
+    label: chex.Array,
+) -> chex.Array:
+  """Multiclass perceptron loss.
+
+  References:
+    Michael Collins. Discriminative training methods for Hidden Markov Models:
+    Theory and experiments with perceptron algorithms. EMNLP 2002
+
+  Args:
+    scores: score produced by the model.
+    label: ground-truth integer label.
+
+  Returns:
+    loss value.
+
+  .. versionadded:: 0.2.2
+  """
+  one_hot_label = jax.nn.one_hot(label, scores.shape[-1])
+  dot_last_dim = jnp.vectorize(jnp.dot, signature='(n),(n)->()')
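+  # Loss = max_i scores[i] - scores[label]; it is zero exactly when the true label attains the top score.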
+  return jnp.max(scores, axis=-1) - dot_last_dim(scores, one_hot_label)


@functools.partial(chex.warn_only_n_pos_args_in_future, n=2)
def poly_loss_cross_entropy(
    logits: chex.Array,
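
As a quick illustration (not part of this diff), the new loss can be called directly on batched scores and integer labels. The sketch below imports the private module the same way the tests below do, since a public re-export under optax.losses is not shown in this change; it assumes a build of optax that already contains this commit (0.2.2 or later).

import jax.numpy as jnp
from optax.losses import _classification

# Scores for two examples over three classes, with integer class labels.
scores = jnp.array([[2.0, 0.5, -1.0],
                    [0.1, 0.2, 0.3]])
labels = jnp.array([0, 1])

# Per-example loss: max_i scores[i] - scores[label].
losses = _classification.multiclass_perceptron_loss(scores, labels)
print(losses)  # [0. 0.1] up to float rounding; the first example's true class already has the top score.
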
optax/losses/_classification_test.py (41 additions & 0 deletions)
@@ -259,6 +259,47 @@ def test_batched(self):
)


+class PerceptronTest(parameterized.TestCase):
+
+  def test_binary(self):
+    label = jnp.array(1)
+    signed_label = jnp.array(2.0 * label - 1.0)
+    score = jnp.array(10.)
+    def reference_impl(label, logit) -> float:
+      return jax.nn.relu(- logit * (2.0 * label - 1.0))
+    expected = reference_impl(label, score)
+    result = _classification.perceptron_loss(score, signed_label)
+    np.testing.assert_allclose(result, expected, atol=1e-4)
+
+  def test_batched_binary(self):
+    labels = jnp.array([1, 0])
+    signed_labels = jnp.array(2.0 * labels - 1.0)
+    scores = jnp.array([10., 20.])
+    def reference_impl(label, logit) -> float:
+      return jax.nn.relu(- logit * (2.0 * label - 1.0))
+    expected = jax.vmap(reference_impl)(labels, scores)
+    result = _classification.perceptron_loss(scores, signed_labels)
+    np.testing.assert_allclose(result, expected, atol=1e-4)
+
+  def test_multi_class(self):
+    label = jnp.array(1)
+    scores = jnp.array([10., 3.])
+    def reference_impl(label, scores):
+      return jnp.max(scores) - scores[label]
+    expected = reference_impl(label, scores)
+    result = _classification.multiclass_perceptron_loss(scores, label)
+    np.testing.assert_allclose(result, expected, atol=1e-4)
+
+  def test_batched_multi_class(self):
+    label = jnp.array([1, 0])
+    scores = jnp.array([[10., 3.], [11., -2.]])
+    def reference_impl(label, scores):
+      return jnp.max(scores) - scores[label]
+    expected = jax.vmap(reference_impl)(label, scores)
+    result = _classification.multiclass_perceptron_loss(scores, label)
+    np.testing.assert_allclose(result, expected, atol=1e-4)


class KLDivergenceTest(parameterized.TestCase):

  def setUp(self):
