diff --git a/rax/__init__.py b/rax/__init__.py index 87437e1..1b675ab 100644 --- a/rax/__init__.py +++ b/rax/__init__.py @@ -43,7 +43,6 @@ __version__ = "0.2.0" -# pyformat: disable __all__ = [ "dcg2_lambdaweight", "dcg_lambdaweight", @@ -70,7 +69,6 @@ "gumbel_t12n", "segment_t12n", ] -# pyformat: enable # copybara: stripped(1) diff --git a/rax/_src/lambdaweights_test.py b/rax/_src/lambdaweights_test.py index 317bf09..71b55b0 100644 --- a/rax/_src/lambdaweights_test.py +++ b/rax/_src/lambdaweights_test.py @@ -34,36 +34,40 @@ class LambdaweightsTest(parameterized.TestCase): - @parameterized.parameters([{ - "lambdaweight_fn": lambdaweights.labeldiff_lambdaweight, - "expected": [0.0, 1.0, 0.3, 1.0, 0.0, 0.7, 0.3, 0.7, 0.0] - }, { - "lambdaweight_fn": - lambdaweights.dcg_lambdaweight, - "expected": [ - 0.0, - 3.0 * abs(discount(2) - discount(3)) * abs(gain(0.0) - gain(1.0)), - 3.0 * abs(discount(1) - discount(3)) * abs(gain(0.0) - gain(0.3)), - 3.0 * abs(discount(3) - discount(2)) * abs(gain(1.0) - gain(0.0)), - 0.0, - 3.0 * abs(discount(1) - discount(2)) * abs(gain(1.0) - gain(0.3)), - 3.0 * abs(discount(3) - discount(1)) * abs(gain(0.3) - gain(0.0)), - 3.0 * abs(discount(2) - discount(1)) * abs(gain(0.3) - gain(1.0)), 0.0 - ] - }, { - "lambdaweight_fn": - lambdaweights.dcg2_lambdaweight, - "expected": [ - 0.0, - 3.0 * abs(discount(1) - discount(2)) * abs(gain(0.0) - gain(1.0)), - 3.0 * abs(discount(2) - discount(3)) * abs(gain(0.0) - gain(0.3)), - 3.0 * abs(discount(1) - discount(2)) * abs(gain(1.0) - gain(0.0)), - 0.0, - 3.0 * abs(discount(1) - discount(2)) * abs(gain(1.0) - gain(0.3)), - 3.0 * abs(discount(2) - discount(3)) * abs(gain(0.3) - gain(0.0)), - 3.0 * abs(discount(1) - discount(2)) * abs(gain(0.3) - gain(1.0)), 0.0 - ] - }]) + @parameterized.parameters([ + { + "lambdaweight_fn": lambdaweights.labeldiff_lambdaweight, + "expected": [0.0, 1.0, 0.3, 1.0, 0.0, 0.7, 0.3, 0.7, 0.0], + }, + { + "lambdaweight_fn": lambdaweights.dcg_lambdaweight, + "expected": [ + 0.0, + 3.0 * abs(discount(2) - discount(3)) * abs(gain(0.0) - gain(1.0)), + 3.0 * abs(discount(1) - discount(3)) * abs(gain(0.0) - gain(0.3)), + 3.0 * abs(discount(3) - discount(2)) * abs(gain(1.0) - gain(0.0)), + 0.0, + 3.0 * abs(discount(1) - discount(2)) * abs(gain(1.0) - gain(0.3)), + 3.0 * abs(discount(3) - discount(1)) * abs(gain(0.3) - gain(0.0)), + 3.0 * abs(discount(2) - discount(1)) * abs(gain(0.3) - gain(1.0)), + 0.0, + ], + }, + { + "lambdaweight_fn": lambdaweights.dcg2_lambdaweight, + "expected": [ + 0.0, + 3.0 * abs(discount(1) - discount(2)) * abs(gain(0.0) - gain(1.0)), + 3.0 * abs(discount(2) - discount(3)) * abs(gain(0.0) - gain(0.3)), + 3.0 * abs(discount(1) - discount(2)) * abs(gain(1.0) - gain(0.0)), + 0.0, + 3.0 * abs(discount(1) - discount(2)) * abs(gain(1.0) - gain(0.3)), + 3.0 * abs(discount(2) - discount(3)) * abs(gain(0.3) - gain(0.0)), + 3.0 * abs(discount(1) - discount(2)) * abs(gain(0.3) - gain(1.0)), + 0.0, + ], + }, + ]) def test_computes_lambdaweights(self, lambdaweight_fn, expected): scores = jnp.array([0.0, 1.0, 2.0]) labels = jnp.array([0.0, 1.0, 0.3]) @@ -94,7 +98,7 @@ def test_computes_lambdaweights(self, lambdaweight_fn, expected): ], [ gain(2.0) * discount(1) + gain(1.0) * discount(2) + gain(0.0) * discount(3) - ] + ], ] }]) # pyformat: disable def test_computes_normalized_lambdaweights(self, lambdaweight_fn, normalizer): @@ -105,67 +109,113 @@ def test_computes_normalized_lambdaweights(self, lambdaweight_fn, normalizer): result_unnormalized = lambdaweight_fn(scores, labels) 
np.testing.assert_allclose( - result, result_unnormalized / jnp.array(normalizer), rtol=1e-5) + result, result_unnormalized / jnp.array(normalizer), rtol=1e-5 + ) - @parameterized.parameters([{ - "lambdaweight_fn": - lambdaweights.labeldiff_lambdaweight, - "expected": [[0.0, 2.0, 1.0, 2.0, 0.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0]] - }, { - "lambdaweight_fn": - lambdaweights.dcg_lambdaweight, - "expected": [ - [ - 0.0, - 3.0 * abs(discount(2) - discount(3)) * abs(gain(0.0) - gain(2.0)), - 3.0 * abs(discount(1) - discount(3)) * abs(gain(0.0) - gain(1.0)), - 3.0 * abs(discount(3) - discount(2)) * abs(gain(2.0) - gain(0.0)), - 0.0, - 3.0 * abs(discount(1) - discount(2)) * abs(gain(2.0) - gain(1.0)), - 3.0 * abs(discount(3) - discount(1)) * abs(gain(1.0) - gain(0.0)), - 3.0 * abs(discount(2) - discount(1)) * abs(gain(1.0) - gain(2.0)), - 0.0 + @parameterized.parameters([ + { + "lambdaweight_fn": lambdaweights.labeldiff_lambdaweight, + "expected": [ + [0.0, 2.0, 1.0, 2.0, 0.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0], ], - [ - 0.0, 0.0, - 3.0 * abs(discount(2) - discount(3)) * abs(gain(1.0) - gain(0.0)), - 0.0, 0.0, - 3.0 * abs(discount(2) - discount(1)) * abs(gain(0.0) - gain(1.0)), - 3.0 * abs(discount(3) - discount(2)) * abs(gain(1.0) - gain(0.0)), - 3.0 * abs(discount(1) - discount(2)) * abs(gain(1.0) - gain(0.0)), - 0.0 + }, + { + "lambdaweight_fn": lambdaweights.dcg_lambdaweight, + "expected": [ + [ + 0.0, + 3.0 + * abs(discount(2) - discount(3)) + * abs(gain(0.0) - gain(2.0)), + 3.0 + * abs(discount(1) - discount(3)) + * abs(gain(0.0) - gain(1.0)), + 3.0 + * abs(discount(3) - discount(2)) + * abs(gain(2.0) - gain(0.0)), + 0.0, + 3.0 + * abs(discount(1) - discount(2)) + * abs(gain(2.0) - gain(1.0)), + 3.0 + * abs(discount(3) - discount(1)) + * abs(gain(1.0) - gain(0.0)), + 3.0 + * abs(discount(2) - discount(1)) + * abs(gain(1.0) - gain(2.0)), + 0.0, + ], + [ + 0.0, + 0.0, + 3.0 + * abs(discount(2) - discount(3)) + * abs(gain(1.0) - gain(0.0)), + 0.0, + 0.0, + 3.0 + * abs(discount(2) - discount(1)) + * abs(gain(0.0) - gain(1.0)), + 3.0 + * abs(discount(3) - discount(2)) + * abs(gain(1.0) - gain(0.0)), + 3.0 + * abs(discount(1) - discount(2)) + * abs(gain(1.0) - gain(0.0)), + 0.0, + ], ], - ] - }, { - "lambdaweight_fn": - lambdaweights.dcg2_lambdaweight, - "expected": [ - [ - 0.0, - 3.0 * abs(discount(1) - discount(2)) * abs(gain(0.0) - gain(2.0)), - 3.0 * abs(discount(2) - discount(3)) * abs(gain(0.0) - gain(1.0)), - 3.0 * abs(discount(1) - discount(2)) * abs(gain(2.0) - gain(0.0)), - 0.0, - 3.0 * abs(discount(1) - discount(2)) * abs(gain(2.0) - gain(1.0)), - 3.0 * abs(discount(2) - discount(3)) * abs(gain(1.0) - gain(0.0)), - 3.0 * abs(discount(1) - discount(2)) * abs(gain(1.0) - gain(2.0)), - 0.0 + }, + { + "lambdaweight_fn": lambdaweights.dcg2_lambdaweight, + "expected": [ + [ + 0.0, + 3.0 + * abs(discount(1) - discount(2)) + * abs(gain(0.0) - gain(2.0)), + 3.0 + * abs(discount(2) - discount(3)) + * abs(gain(0.0) - gain(1.0)), + 3.0 + * abs(discount(1) - discount(2)) + * abs(gain(2.0) - gain(0.0)), + 0.0, + 3.0 + * abs(discount(1) - discount(2)) + * abs(gain(2.0) - gain(1.0)), + 3.0 + * abs(discount(2) - discount(3)) + * abs(gain(1.0) - gain(0.0)), + 3.0 + * abs(discount(1) - discount(2)) + * abs(gain(1.0) - gain(2.0)), + 0.0, + ], + [ + 0.0, + 0.0, + 3.0 + * abs(discount(1) - discount(2)) + * abs(gain(1.0) - gain(0.0)), + 0.0, + 0.0, + 3.0 + * abs(discount(1) - discount(2)) + * abs(gain(0.0) - gain(1.0)), + 3.0 + * 
abs(discount(1) - discount(2)) + * abs(gain(1.0) - gain(0.0)), + 3.0 + * abs(discount(1) - discount(2)) + * abs(gain(1.0) - gain(0.0)), + 0.0, + ], ], - [ - 0.0, 0.0, - 3.0 * abs(discount(1) - discount(2)) * abs(gain(1.0) - gain(0.0)), - 0.0, 0.0, - 3.0 * abs(discount(1) - discount(2)) * abs(gain(0.0) - gain(1.0)), - 3.0 * abs(discount(1) - discount(2)) * abs(gain(1.0) - gain(0.0)), - 3.0 * abs(discount(1) - discount(2)) * abs(gain(1.0) - gain(0.0)), - 0.0 - ], - ] - }]) - def test_lambdaweights_with_batchdim( - self, lambdaweight_fn, expected - ): + }, + ]) + def test_lambdaweights_with_batchdim(self, lambdaweight_fn, expected): scores = jnp.array([[0.0, 1.0, 2.0], [0.5, 1.5, 1.0]]) labels = jnp.array([[0.0, 2.0, 1.0], [0.0, 0.0, 1.0]]) @@ -214,30 +264,40 @@ def test_lambdaweights_with_empty_list(self, lambdaweight_fn): np.testing.assert_allclose(result, expected) - @parameterized.parameters([{ - "lambdaweight_fn": lambdaweights.labeldiff_lambdaweight, - "expected": [0.0, 0.0, 0.3, 0.0, 0.0, 0.0, 0.3, 0.0, 0.0] - }, { - "lambdaweight_fn": - lambdaweights.dcg_lambdaweight, - "expected": [ - 0.0, 0.0, - 3.0 * abs(discount(1) - discount(2)) * abs(gain(0.0) - gain(0.3)), - 0.0, 0.0, 0.0, - 3.0 * abs(discount(2) - discount(1)) * abs(gain(0.3) - gain(0.0)), - 0.0, 0.0 - ] - }, { - "lambdaweight_fn": - lambdaweights.dcg2_lambdaweight, - "expected": [ - 0.0, 0.0, - 3.0 * abs(discount(1) - discount(2)) * abs(gain(0.0) - gain(0.3)), - 0.0, 0.0, 0.0, - 3.0 * abs(discount(1) - discount(2)) * abs(gain(0.3) - gain(0.0)), - 0.0, 0.0 - ] - }]) + @parameterized.parameters([ + { + "lambdaweight_fn": lambdaweights.labeldiff_lambdaweight, + "expected": [0.0, 0.0, 0.3, 0.0, 0.0, 0.0, 0.3, 0.0, 0.0], + }, + { + "lambdaweight_fn": lambdaweights.dcg_lambdaweight, + "expected": [ + 0.0, + 0.0, + 3.0 * abs(discount(1) - discount(2)) * abs(gain(0.0) - gain(0.3)), + 0.0, + 0.0, + 0.0, + 3.0 * abs(discount(2) - discount(1)) * abs(gain(0.3) - gain(0.0)), + 0.0, + 0.0, + ], + }, + { + "lambdaweight_fn": lambdaweights.dcg2_lambdaweight, + "expected": [ + 0.0, + 0.0, + 3.0 * abs(discount(1) - discount(2)) * abs(gain(0.0) - gain(0.3)), + 0.0, + 0.0, + 0.0, + 3.0 * abs(discount(1) - discount(2)) * abs(gain(0.3) - gain(0.0)), + 0.0, + 0.0, + ], + }, + ]) def test_lambdaweights_with_where_mask(self, lambdaweight_fn, expected): scores = jnp.array([0.0, 1.0, 2.0]) labels = jnp.array([0.0, 1.0, 0.3]) @@ -247,39 +307,60 @@ def test_lambdaweights_with_where_mask(self, lambdaweight_fn, expected): np.testing.assert_allclose(result, expected, rtol=1e-5) - @parameterized.parameters([{ - "lambdaweight_fn": lambdaweights.labeldiff_lambdaweight, - "expected": [0.0, 1.0, 0.3, 1.0, 0.0, 0.7, 0.3, 0.7, 0.0] - }, { - "lambdaweight_fn": - lambdaweights.dcg_lambdaweight, - "expected": [ - 0.0, - 3.0 * abs(discount(2) - discount(3)) * abs(0.0 - 0.5 * gain(1.0)), - 3.0 * abs(discount(1) - discount(3)) * - abs(1.5 * gain(0.0) - gain(0.3)), - 3.0 * abs(discount(3) - discount(2)) * abs(0.5 * gain(1.0) - 0.0), - 0.0, 3.0 * abs(discount(1) - discount(2)) * - abs(0.5 * gain(1.0) - gain(0.3)), 3.0 * - abs(discount(3) - discount(1)) * abs(gain(0.3) - 1.5 * gain(0.0)), - 3.0 * abs(discount(2) - discount(1)) * - abs(gain(0.3) - 0.5 * gain(1.0)), 0.0 - ] - }, { - "lambdaweight_fn": - lambdaweights.dcg2_lambdaweight, - "expected": [ - 0.0, 3.0 * abs(discount(1) - discount(2)) * - abs(gain(0.0) - 0.5 * gain(1.0)), 3.0 * - abs(discount(2) - discount(3)) * abs(1.5 * gain(0.0) - gain(0.3)), - 3.0 * abs(discount(1) - discount(2)) * - abs(0.5 * gain(1.0) - 
gain(0.0)), 0.0, 3.0 * - abs(discount(1) - discount(2)) * abs(0.5 * gain(1.0) - gain(0.3)), - 3.0 * abs(discount(2) - discount(3)) * - abs(gain(0.3) - 1.5 * gain(0.0)), 3.0 * - abs(discount(2) - discount(1)) * abs(gain(0.3) - 0.5 * gain(1.0)), 0.0 - ] - }]) + @parameterized.parameters([ + { + "lambdaweight_fn": lambdaweights.labeldiff_lambdaweight, + "expected": [0.0, 1.0, 0.3, 1.0, 0.0, 0.7, 0.3, 0.7, 0.0], + }, + { + "lambdaweight_fn": lambdaweights.dcg_lambdaweight, + "expected": [ + 0.0, + 3.0 * abs(discount(2) - discount(3)) * abs(0.0 - 0.5 * gain(1.0)), + 3.0 + * abs(discount(1) - discount(3)) + * abs(1.5 * gain(0.0) - gain(0.3)), + 3.0 * abs(discount(3) - discount(2)) * abs(0.5 * gain(1.0) - 0.0), + 0.0, + 3.0 + * abs(discount(1) - discount(2)) + * abs(0.5 * gain(1.0) - gain(0.3)), + 3.0 + * abs(discount(3) - discount(1)) + * abs(gain(0.3) - 1.5 * gain(0.0)), + 3.0 + * abs(discount(2) - discount(1)) + * abs(gain(0.3) - 0.5 * gain(1.0)), + 0.0, + ], + }, + { + "lambdaweight_fn": lambdaweights.dcg2_lambdaweight, + "expected": [ + 0.0, + 3.0 + * abs(discount(1) - discount(2)) + * abs(gain(0.0) - 0.5 * gain(1.0)), + 3.0 + * abs(discount(2) - discount(3)) + * abs(1.5 * gain(0.0) - gain(0.3)), + 3.0 + * abs(discount(1) - discount(2)) + * abs(0.5 * gain(1.0) - gain(0.0)), + 0.0, + 3.0 + * abs(discount(1) - discount(2)) + * abs(0.5 * gain(1.0) - gain(0.3)), + 3.0 + * abs(discount(2) - discount(3)) + * abs(gain(0.3) - 1.5 * gain(0.0)), + 3.0 + * abs(discount(2) - discount(1)) + * abs(gain(0.3) - 0.5 * gain(1.0)), + 0.0, + ], + }, + ]) def test_lambdaweights_with_weights(self, lambdaweight_fn, expected): scores = jnp.array([0.0, 1.0, 2.0]) labels = jnp.array([0.0, 1.0, 0.3]) @@ -289,31 +370,54 @@ def test_lambdaweights_with_weights(self, lambdaweight_fn, expected): np.testing.assert_allclose(result, expected, rtol=1e-5) - @parameterized.parameters([{ - "lambdaweight_fn": - lambdaweights.dcg_lambdaweight, - "expected": [ - 0.0, 0.0, 3.0 * abs(discount(1) - 0.0) * abs(gain(0.0) - gain(0.3)), - 0.0, 0.0, 3.0 * abs(discount(1) - 0.0) * abs(gain(1.0) - gain(0.3)), - 3.0 * abs(0.0 - discount(1)) * abs(gain(0.3) - gain(0.0)), - 3.0 * abs(0.0 - discount(1)) * abs(gain(0.3) - gain(1.0)), 0.0 - ] - }, { - "lambdaweight_fn": - lambdaweights.dcg2_lambdaweight, - "expected": [ - 0.0, 3.0 * (1.0 / (1.0 - discount(3))) * - abs(discount(1) - discount(2)) * abs(gain(0.0) - gain(1.0)), - 3.0 * (1.0 / (1.0 - discount(3))) * abs(discount(2) - discount(3)) * - abs(gain(0.0) - gain(0.3)), 3.0 * (1.0 / (1.0 - discount(3))) * - abs(discount(1) - discount(2)) * abs(gain(1.0) - gain(0.0)), 0.0, - 3.0 * (1.0 / (1.0 - discount(2))) * abs(discount(1) - discount(2)) * - abs(gain(1.0) - gain(0.3)), 3.0 * (1.0 / (1.0 - discount(3))) * - abs(discount(2) - discount(3)) * abs(gain(0.3) - gain(0.0)), - 3.0 * (1.0 / (1.0 - discount(2))) * abs(discount(2) - discount(1)) * - abs(gain(0.3) - gain(1.0)), 0.0 - ] - }]) + @parameterized.parameters([ + { + "lambdaweight_fn": lambdaweights.dcg_lambdaweight, + "expected": [ + 0.0, + 0.0, + 3.0 * abs(discount(1) - 0.0) * abs(gain(0.0) - gain(0.3)), + 0.0, + 0.0, + 3.0 * abs(discount(1) - 0.0) * abs(gain(1.0) - gain(0.3)), + 3.0 * abs(0.0 - discount(1)) * abs(gain(0.3) - gain(0.0)), + 3.0 * abs(0.0 - discount(1)) * abs(gain(0.3) - gain(1.0)), + 0.0, + ], + }, + { + "lambdaweight_fn": lambdaweights.dcg2_lambdaweight, + "expected": [ + 0.0, + 3.0 + * (1.0 / (1.0 - discount(3))) + * abs(discount(1) - discount(2)) + * abs(gain(0.0) - gain(1.0)), + 3.0 + * (1.0 / (1.0 - discount(3))) + * 
abs(discount(2) - discount(3)) + * abs(gain(0.0) - gain(0.3)), + 3.0 + * (1.0 / (1.0 - discount(3))) + * abs(discount(1) - discount(2)) + * abs(gain(1.0) - gain(0.0)), + 0.0, + 3.0 + * (1.0 / (1.0 - discount(2))) + * abs(discount(1) - discount(2)) + * abs(gain(1.0) - gain(0.3)), + 3.0 + * (1.0 / (1.0 - discount(3))) + * abs(discount(2) - discount(3)) + * abs(gain(0.3) - gain(0.0)), + 3.0 + * (1.0 / (1.0 - discount(2))) + * abs(discount(2) - discount(1)) + * abs(gain(0.3) - gain(1.0)), + 0.0, + ], + }, + ]) def test_lambdaweights_with_topn(self, lambdaweight_fn, expected): scores = jnp.array([0.0, 1.0, 2.0]) labels = jnp.array([0.0, 1.0, 0.3]) @@ -323,43 +427,53 @@ def test_lambdaweights_with_topn(self, lambdaweight_fn, expected): np.testing.assert_allclose(result, expected, rtol=1e-5) - @parameterized.parameters([{ - "loss_fn": losses.pairwise_hinge_loss, - "lambdaweight_fn": lambdaweights.labeldiff_lambdaweight, - "expected": 0.63333327 - }, { - "loss_fn": losses.pairwise_logistic_loss, - "lambdaweight_fn": lambdaweights.labeldiff_lambdaweight, - "expected": 0.6564648 - }, { - "loss_fn": losses.pairwise_hinge_loss, - "lambdaweight_fn": lambdaweights.dcg_lambdaweight, - "expected": 0.43137252 * 4.0 - }, { - "loss_fn": losses.pairwise_logistic_loss, - "lambdaweight_fn": lambdaweights.dcg_lambdaweight, - "expected": 0.34675273 * 4.0 - }, { - "loss_fn": losses.pairwise_hinge_loss, - "lambdaweight_fn": lambdaweights.dcg2_lambdaweight, - "expected": 0.45518658 * 4.0 - }, { - "loss_fn": losses.pairwise_logistic_loss, - "lambdaweight_fn": lambdaweights.dcg2_lambdaweight, - "expected": 0.4053712 * 4.0 - }, { - "loss_fn": losses.pairwise_mse_loss, - "lambdaweight_fn": lambdaweights.dcg_lambdaweight, - "expected": 0.61966689551 * 4.0 - }]) - def test_computes_with_pairwise_loss(self, loss_fn, lambdaweight_fn, - expected): + @parameterized.parameters([ + { + "loss_fn": losses.pairwise_hinge_loss, + "lambdaweight_fn": lambdaweights.labeldiff_lambdaweight, + "expected": 0.63333327, + }, + { + "loss_fn": losses.pairwise_logistic_loss, + "lambdaweight_fn": lambdaweights.labeldiff_lambdaweight, + "expected": 0.6564648, + }, + { + "loss_fn": losses.pairwise_hinge_loss, + "lambdaweight_fn": lambdaweights.dcg_lambdaweight, + "expected": 0.43137252 * 4.0, + }, + { + "loss_fn": losses.pairwise_logistic_loss, + "lambdaweight_fn": lambdaweights.dcg_lambdaweight, + "expected": 0.34675273 * 4.0, + }, + { + "loss_fn": losses.pairwise_hinge_loss, + "lambdaweight_fn": lambdaweights.dcg2_lambdaweight, + "expected": 0.45518658 * 4.0, + }, + { + "loss_fn": losses.pairwise_logistic_loss, + "lambdaweight_fn": lambdaweights.dcg2_lambdaweight, + "expected": 0.4053712 * 4.0, + }, + { + "loss_fn": losses.pairwise_mse_loss, + "lambdaweight_fn": lambdaweights.dcg_lambdaweight, + "expected": 0.61966689551 * 4.0, + }, + ]) + def test_computes_with_pairwise_loss( + self, loss_fn, lambdaweight_fn, expected + ): scores = jnp.array([0.3, 1.9, 1.5, 1.2]) labels = jnp.array([0.0, 1.0, 1.0, 2.0]) where = jnp.array([1, 1, 0, 1], dtype=jnp.bool_) result = loss_fn( - scores, labels, where=where, lambdaweight_fn=lambdaweight_fn) + scores, labels, where=where, lambdaweight_fn=lambdaweight_fn + ) np.testing.assert_allclose(result, expected, rtol=1e-5) @@ -368,11 +482,9 @@ def load_tests(loader, tests, ignore): del loader, ignore # Unused. 
tests.addTests( doctest.DocTestSuite( - lambdaweights, globs={ - "jax": jax, - "jnp": jnp, - "rax": rax - })) + lambdaweights, globs={"jax": jax, "jnp": jnp, "rax": rax} + ) + ) return tests diff --git a/rax/_src/losses.py b/rax/_src/losses.py index b8af830..562fdea 100644 --- a/rax/_src/losses.py +++ b/rax/_src/losses.py @@ -225,7 +225,7 @@ def poly1_softmax_loss( # For lists where all items are masked, this sets pt to 1 so that the term # (1 - pt) is set to 0 for the loss computation. if where is not None: - pt = jnp.where(jnp.all(jnp.logical_not(where), axis=-1), 1., pt) + pt = jnp.where(jnp.all(jnp.logical_not(where), axis=-1), 1.0, pt) # In the segmented case, values retain their list dimension. This constructs # a mask so that only the first item per segment is used in reduce_fn. @@ -289,7 +289,8 @@ def unique_softmax_loss( # indicate, for each item, which other items have a smaller label. labels_lt = jnp.expand_dims(labels, -2) < jnp.expand_dims(labels, -1) scores_repeated = jnp.repeat( - jnp.expand_dims(scores, -2), scores.shape[-1], axis=-2) + jnp.expand_dims(scores, -2), scores.shape[-1], axis=-2 + ) if segments is not None: labels_lt = labels_lt & segment_utils.same_segment_mask(segments) @@ -309,7 +310,8 @@ def unique_softmax_loss( scores_repeated, axis=-1, where=identity_mask | labels_lt, - initial=jnp.min(scores)) + initial=jnp.min(scores), + ) log_softmax = jnp.diagonal(log_softmax, axis1=-2, axis2=-1) # Apply per-item weights. @@ -420,15 +422,17 @@ def listmle_loss( return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn) -def pairwise_loss(scores: Array, - labels: Array, - *, - pair_loss_fn: Callable[[Array, Array], Tuple[Array, Array]], - lambdaweight_fn: Optional[LambdaweightFn] = None, - where: Optional[Array] = None, - segments: Optional[Array] = None, - weights: Optional[Array] = None, - reduce_fn: Optional[ReduceFn] = jnp.mean) -> Array: +def pairwise_loss( + scores: Array, + labels: Array, + *, + pair_loss_fn: Callable[[Array, Array], Tuple[Array, Array]], + lambdaweight_fn: Optional[LambdaweightFn] = None, + where: Optional[Array] = None, + segments: Optional[Array] = None, + weights: Optional[Array] = None, + reduce_fn: Optional[ReduceFn] = jnp.mean, +) -> Array: r"""Generic pairwise loss. The ``pair_loss_fn`` takes ``(scores_diff, labels_diff)`` and returns the loss @@ -485,14 +489,16 @@ def pairwise_loss(scores: Array, return utils.safe_reduce(pair_losses, where=valid_pairs, reduce_fn=reduce_fn) -def pairwise_hinge_loss(scores: Array, - labels: Array, - *, - where: Optional[Array] = None, - segments: Optional[Array] = None, - weights: Optional[Array] = None, - lambdaweight_fn: Optional[LambdaweightFn] = None, - reduce_fn: Optional[ReduceFn] = jnp.mean) -> Array: +def pairwise_hinge_loss( + scores: Array, + labels: Array, + *, + where: Optional[Array] = None, + segments: Optional[Array] = None, + weights: Optional[Array] = None, + lambdaweight_fn: Optional[LambdaweightFn] = None, + reduce_fn: Optional[ReduceFn] = jnp.mean, +) -> Array: r"""Pairwise hinge loss. Definition: @@ -522,9 +528,10 @@ def pairwise_hinge_loss(scores: Array, The pairwise hinge loss. """ - def _hinge_loss(scores_diff: Array, - labels_diff: Array) -> Tuple[Array, Array]: - return jax.nn.relu(1. 
- scores_diff), labels_diff > 0 + def _hinge_loss( + scores_diff: Array, labels_diff: Array + ) -> Tuple[Array, Array]: + return jax.nn.relu(1.0 - scores_diff), labels_diff > 0 return pairwise_loss( scores, @@ -534,17 +541,20 @@ def _hinge_loss(scores_diff: Array, where=where, weights=weights, lambdaweight_fn=lambdaweight_fn, - reduce_fn=reduce_fn) + reduce_fn=reduce_fn, + ) -def pairwise_logistic_loss(scores: Array, - labels: Array, - *, - where: Optional[Array] = None, - segments: Optional[Array] = None, - weights: Optional[Array] = None, - lambdaweight_fn: Optional[LambdaweightFn] = None, - reduce_fn: Optional[ReduceFn] = jnp.mean) -> Array: +def pairwise_logistic_loss( + scores: Array, + labels: Array, + *, + where: Optional[Array] = None, + segments: Optional[Array] = None, + weights: Optional[Array] = None, + lambdaweight_fn: Optional[LambdaweightFn] = None, + reduce_fn: Optional[ReduceFn] = jnp.mean, +) -> Array: r"""Pairwise logistic loss. Definition :cite:p:`burges2005learning`: @@ -574,10 +584,13 @@ def pairwise_logistic_loss(scores: Array, The pairwise logistic loss. """ - def _logistic_loss(scores_diff: Array, - labels_diff: Array) -> Tuple[Array, Array]: - return (jax.nn.relu(-scores_diff) + - jnp.log1p(jnp.exp(-jnp.abs(scores_diff))), labels_diff > 0) + def _logistic_loss( + scores_diff: Array, labels_diff: Array + ) -> Tuple[Array, Array]: + return ( + jax.nn.relu(-scores_diff) + jnp.log1p(jnp.exp(-jnp.abs(scores_diff))), + labels_diff > 0, + ) return pairwise_loss( scores, @@ -587,7 +600,8 @@ def _logistic_loss(scores_diff: Array, segments=segments, weights=weights, lambdaweight_fn=lambdaweight_fn, - reduce_fn=reduce_fn) + reduce_fn=reduce_fn, + ) def pairwise_soft_zero_one_loss( @@ -653,13 +667,15 @@ def _soft_zero_one_loss( ) -def pointwise_sigmoid_loss(scores: Array, - labels: Array, - *, - where: Optional[Array] = None, - segments: Optional[Array] = None, - weights: Optional[Array] = None, - reduce_fn: Optional[ReduceFn] = jnp.mean) -> Array: +def pointwise_sigmoid_loss( + scores: Array, + labels: Array, + *, + where: Optional[Array] = None, + segments: Optional[Array] = None, + weights: Optional[Array] = None, + reduce_fn: Optional[ReduceFn] = jnp.mean, +) -> Array: r"""Sigmoid cross entropy loss. .. note:: @@ -702,8 +718,10 @@ def pointwise_sigmoid_loss(scores: Array, # A numerically stable version of sigmoid cross entropy. loss = ( - jax.nn.relu(scores) - scores * labels + - jnp.log(1. + jnp.exp(-jnp.abs(scores)))) + jax.nn.relu(scores) + - scores * labels + + jnp.log(1.0 + jnp.exp(-jnp.abs(scores))) + ) if weights is not None: loss *= weights @@ -711,13 +729,15 @@ def pointwise_sigmoid_loss(scores: Array, return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn) -def pointwise_mse_loss(scores: Array, - labels: Array, - *, - where: Optional[Array] = None, - segments: Optional[Array] = None, - weights: Optional[Array] = None, - reduce_fn: Optional[ReduceFn] = jnp.mean) -> Array: +def pointwise_mse_loss( + scores: Array, + labels: Array, + *, + where: Optional[Array] = None, + segments: Optional[Array] = None, + weights: Optional[Array] = None, + reduce_fn: Optional[ReduceFn] = jnp.mean, +) -> Array: r"""Mean squared error loss. 
Definition: @@ -758,14 +778,16 @@ def pointwise_mse_loss(scores: Array, return utils.safe_reduce(loss, where=where, reduce_fn=reduce_fn) -def pairwise_mse_loss(scores: Array, - labels: Array, - *, - where: Optional[Array] = None, - segments: Optional[Array] = None, - weights: Optional[Array] = None, - lambdaweight_fn: Optional[LambdaweightFn] = None, - reduce_fn: Optional[ReduceFn] = jnp.mean) -> Array: +def pairwise_mse_loss( + scores: Array, + labels: Array, + *, + where: Optional[Array] = None, + segments: Optional[Array] = None, + weights: Optional[Array] = None, + lambdaweight_fn: Optional[LambdaweightFn] = None, + reduce_fn: Optional[ReduceFn] = jnp.mean, +) -> Array: r"""Pairwise mean squared error loss. Definition: @@ -796,8 +818,10 @@ def pairwise_mse_loss(scores: Array, """ def _mse_loss(scores_diff: Array, labels_diff: Array) -> Tuple[Array, Array]: - return (jnp.square(scores_diff - labels_diff), - jnp.ones_like(labels_diff > 0)) + return ( + jnp.square(scores_diff - labels_diff), + jnp.ones_like(labels_diff > 0), + ) return pairwise_loss( scores, @@ -807,19 +831,22 @@ def _mse_loss(scores_diff: Array, labels_diff: Array) -> Tuple[Array, Array]: segments=segments, weights=weights, lambdaweight_fn=lambdaweight_fn, - reduce_fn=reduce_fn) - - -def pairwise_qr_loss(scores: Array, - labels: Array, - *, - where: Optional[Array] = None, - segments: Optional[Array] = None, - weights: Optional[Array] = None, - tau: float = 0.5, - squared: bool = False, - lambdaweight_fn: Optional[LambdaweightFn] = None, - reduce_fn: Optional[ReduceFn] = jnp.mean) -> Array: + reduce_fn=reduce_fn, + ) + + +def pairwise_qr_loss( + scores: Array, + labels: Array, + *, + where: Optional[Array] = None, + segments: Optional[Array] = None, + weights: Optional[Array] = None, + tau: float = 0.5, + squared: bool = False, + lambdaweight_fn: Optional[LambdaweightFn] = None, + reduce_fn: Optional[ReduceFn] = jnp.mean, +) -> Array: r"""Pairwise quantile regression loss. Definition: @@ -866,9 +893,10 @@ def _qr_loss(scores_diff: Array, labels_diff: Array) -> Tuple[Array, Array]: loss_1, loss_2 = jnp.square(loss_1), jnp.square(loss_2) return tau * loss_1 + (1 - tau) * loss_2, labels_diff > 0 - if not (tau > 0. and tau <= 1.): + if not (tau > 0.0 and tau <= 1.0): raise ValueError( - f'tau should be in the range of (0.0, 1.0], but {tau} is given.') + f'tau should be in the range of (0.0, 1.0], but {tau} is given.' + ) return pairwise_loss( scores, @@ -878,4 +906,5 @@ def _qr_loss(scores_diff: Array, labels_diff: Array) -> Tuple[Array, Array]: segments=segments, weights=weights, lambdaweight_fn=lambdaweight_fn, - reduce_fn=reduce_fn) + reduce_fn=reduce_fn, + ) diff --git a/rax/_src/losses_test.py b/rax/_src/losses_test.py index 9e4fa0c..f69e4e8 100644 --- a/rax/_src/losses_test.py +++ b/rax/_src/losses_test.py @@ -30,292 +30,397 @@ # Export symbols from math for conciser test value definitions. exp = math.exp log = math.log -logloss = lambda x: log(1. + exp(-x)) -sigmoid = lambda x: 1. / (1. + exp(-x)) +logloss = lambda x: log(1.0 + exp(-x)) +sigmoid = lambda x: 1.0 / (1.0 + exp(-x)) class LossesTest(parameterized.TestCase): - @parameterized.parameters([{ - "loss_fn": - losses.softmax_loss, - "expected_value": - -(log(exp(2.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.))) + - log(exp(1.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.)))) - }, { - "loss_fn": - losses.listmle_loss, - "expected_value": - -sum([ - log(exp(1.) / (exp(1.) + exp(2.) + exp(0.) + exp(3.))), - log(exp(2.) / (exp(2.) + exp(0.) + exp(3.))), - log(exp(0.) 
/ (exp(0.) + exp(3.))), - log(exp(3.) / (exp(3.))), - ]) - }, { - "loss_fn": - losses.poly1_softmax_loss, - "expected_value": - -(log(exp(2.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.))) + - log(exp(1.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.)))) + - (1. - (0.5 * exp(2.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.)) + 0.5 * - (exp(1.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.))))) - }, { - "loss_fn": - losses.unique_softmax_loss, - "expected_value": - -(log(exp(2.) / (exp(0.) + exp(3.) + exp(2.))) + - log(exp(1.) / (exp(0.) + exp(3.) + exp(1.)))) - }, { - "loss_fn": - functools.partial(losses.poly1_softmax_loss, epsilon=0.1), - "expected_value": - -(log(exp(2.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.))) + - log(exp(1.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.)))) + 0.1 * - (1. - (0.5 * exp(2.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.)) + 0.5 * - (exp(1.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.))))) - }, { - "loss_fn": losses.pairwise_hinge_loss, - "expected_value": (3. - 1. + 1.) + (3. - 2. + 1.) - }, { - "loss_fn": - losses.pairwise_logistic_loss, - "expected_value": - logloss(1. - 0.) + logloss(1. - 3.) + logloss(2. - 3.) + - logloss(2. - 0.) - }, { - "loss_fn": - losses.pairwise_soft_zero_one_loss, - "expected_value": - sigmoid(-(1. - 0.)) + sigmoid(-(1. - 3.)) + sigmoid(-(2. - 3.)) + - sigmoid(-(2. - 0.)) - }, { - "loss_fn": - losses.pointwise_sigmoid_loss, - "expected_value": - -log(1. - sigmoid(0.)) - log(1. - sigmoid(3.)) - log(sigmoid(1.)) - - log(sigmoid(2.)) - }, { - "loss_fn": - losses.pointwise_mse_loss, - "expected_value": - (0. - 0.)**2 + (3. - 0.)**2 + (1. - 1.)**2 + (2. - 1.)**2 - }, { - "loss_fn": - losses.pairwise_mse_loss, - "expected_value": - ((0. - 3.) - (0. - 0.))**2 + ((0. - 1.) - (0. - 1.))**2 + - ((0. - 2.) - (0. - 1.))**2 + ((3. - 0.) - (0. - 0.))**2 + - ((3. - 1.) - (0. - 1.))**2 + ((3. - 2.) - (0. - 1.))**2 + - ((1. - 0.) - (1. - 0.))**2 + ((1. - 3.) - (1. - 0.))**2 + - ((1. - 2.) - (1. - 1.))**2 + ((2. - 0.) - (1. - 0.))**2 + - ((2. - 3.) - (1. - 0.))**2 + ((2. - 1.) - (1. - 1.))**2 - }, { - "loss_fn": - losses.pairwise_qr_loss, - "expected_value": - 0.5 * (((1. - 0.) - (1. - 0.)) + ((1. - 0.) - (1. - 3.)) + - ((1. - 0.) - (2. - 3.))) + 0.5 * (((2. - 0.) - (1. - 0.))) - }, { - "loss_fn": - functools.partial(losses.pairwise_qr_loss, tau=1.0), - "expected_value": - 1. * (((1. - 0.) - (1. - 0.)) + ((1. - 0.) - (1. - 3.)) + - ((1. - 0.) - (2. - 3.))) + 0. * (((2. - 0.) - (1. - 0.))) - }, { - "loss_fn": - functools.partial(losses.pairwise_qr_loss, squared=True), - "expected_value": - 0.5 * (((1. - 0.) - (1. - 0.))**2 + ((1. - 0.) - (1. - 3.))**2 + - ((1. - 0.) - (2. - 3.))**2) + 0.5 * (((2. - 0.) - - (1. 
- 0.))**2) - }]) + @parameterized.parameters([ + { + "loss_fn": losses.softmax_loss, + "expected_value": -( + log(exp(2.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))) + + log(exp(1.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))) + ), + }, + { + "loss_fn": losses.listmle_loss, + "expected_value": -sum([ + log(exp(1.0) / (exp(1.0) + exp(2.0) + exp(0.0) + exp(3.0))), + log(exp(2.0) / (exp(2.0) + exp(0.0) + exp(3.0))), + log(exp(0.0) / (exp(0.0) + exp(3.0))), + log(exp(3.0) / (exp(3.0))), + ]), + }, + { + "loss_fn": losses.poly1_softmax_loss, + "expected_value": -( + log(exp(2.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))) + + log(exp(1.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))) + ) + ( + 1.0 + - ( + 0.5 * exp(2.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0)) + + 0.5 + * (exp(1.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))) + ) + ), + }, + { + "loss_fn": losses.unique_softmax_loss, + "expected_value": -( + log(exp(2.0) / (exp(0.0) + exp(3.0) + exp(2.0))) + + log(exp(1.0) / (exp(0.0) + exp(3.0) + exp(1.0))) + ), + }, + { + "loss_fn": functools.partial(losses.poly1_softmax_loss, epsilon=0.1), + "expected_value": -( + log(exp(2.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))) + + log(exp(1.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))) + ) + 0.1 * ( + 1.0 + - ( + 0.5 * exp(2.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0)) + + 0.5 + * (exp(1.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))) + ) + ), + }, + { + "loss_fn": losses.pairwise_hinge_loss, + "expected_value": (3.0 - 1.0 + 1.0) + (3.0 - 2.0 + 1.0), + }, + { + "loss_fn": losses.pairwise_logistic_loss, + "expected_value": ( + logloss(1.0 - 0.0) + + logloss(1.0 - 3.0) + + logloss(2.0 - 3.0) + + logloss(2.0 - 0.0) + ), + }, + { + "loss_fn": losses.pairwise_soft_zero_one_loss, + "expected_value": ( + sigmoid(-(1.0 - 0.0)) + + sigmoid(-(1.0 - 3.0)) + + sigmoid(-(2.0 - 3.0)) + + sigmoid(-(2.0 - 0.0)) + ), + }, + { + "loss_fn": losses.pointwise_sigmoid_loss, + "expected_value": ( + -log(1.0 - sigmoid(0.0)) + - log(1.0 - sigmoid(3.0)) + - log(sigmoid(1.0)) + - log(sigmoid(2.0)) + ), + }, + { + "loss_fn": losses.pointwise_mse_loss, + "expected_value": ( + (0.0 - 0.0) ** 2 + + (3.0 - 0.0) ** 2 + + (1.0 - 1.0) ** 2 + + (2.0 - 1.0) ** 2 + ), + }, + { + "loss_fn": losses.pairwise_mse_loss, + "expected_value": ( + ((0.0 - 3.0) - (0.0 - 0.0)) ** 2 + + ((0.0 - 1.0) - (0.0 - 1.0)) ** 2 + + ((0.0 - 2.0) - (0.0 - 1.0)) ** 2 + + ((3.0 - 0.0) - (0.0 - 0.0)) ** 2 + + ((3.0 - 1.0) - (0.0 - 1.0)) ** 2 + + ((3.0 - 2.0) - (0.0 - 1.0)) ** 2 + + ((1.0 - 0.0) - (1.0 - 0.0)) ** 2 + + ((1.0 - 3.0) - (1.0 - 0.0)) ** 2 + + ((1.0 - 2.0) - (1.0 - 1.0)) ** 2 + + ((2.0 - 0.0) - (1.0 - 0.0)) ** 2 + + ((2.0 - 3.0) - (1.0 - 0.0)) ** 2 + + ((2.0 - 1.0) - (1.0 - 1.0)) ** 2 + ), + }, + { + "loss_fn": losses.pairwise_qr_loss, + "expected_value": 0.5 * ( + ((1.0 - 0.0) - (1.0 - 0.0)) + + ((1.0 - 0.0) - (1.0 - 3.0)) + + ((1.0 - 0.0) - (2.0 - 3.0)) + ) + 0.5 * (((2.0 - 0.0) - (1.0 - 0.0))), + }, + { + "loss_fn": functools.partial(losses.pairwise_qr_loss, tau=1.0), + "expected_value": 1.0 * ( + ((1.0 - 0.0) - (1.0 - 0.0)) + + ((1.0 - 0.0) - (1.0 - 3.0)) + + ((1.0 - 0.0) - (2.0 - 3.0)) + ) + 0.0 * (((2.0 - 0.0) - (1.0 - 0.0))), + }, + { + "loss_fn": functools.partial(losses.pairwise_qr_loss, squared=True), + "expected_value": 0.5 * ( + ((1.0 - 0.0) - (1.0 - 0.0)) ** 2 + + ((1.0 - 0.0) - (1.0 - 3.0)) ** 2 + + ((1.0 - 0.0) - (2.0 - 3.0)) ** 2 + ) + 0.5 * (((2.0 - 0.0) - (1.0 - 0.0)) ** 2), + }, + ]) def test_computes_loss_value(self, loss_fn, expected_value): - 
scores = jnp.asarray([0., 3., 1., 2.]) - labels = jnp.asarray([0., 0., 1., 1.]) + scores = jnp.asarray([0.0, 3.0, 1.0, 2.0]) + labels = jnp.asarray([0.0, 0.0, 1.0, 1.0]) loss = loss_fn(scores, labels, reduce_fn=jnp.sum) np.testing.assert_allclose(jnp.asarray(expected_value), loss) - @parameterized.parameters([{ - "loss_fn": - losses.softmax_loss, - "expected_value": - -((-2.1e26 - (0. + -2.1e26 + 3.4e37 + 42.)) + - (3.4e37 - (0. + -2.1e26 + 3.4e37 + 42.))) - }, { - "loss_fn": losses.listmle_loss, - "expected_value": 3.4e37 - }, { - "loss_fn": - losses.poly1_softmax_loss, - "expected_value": - -((-2.1e26 - (0. + -2.1e26 + 3.4e37 + 42.)) + - (3.4e37 - (0. + -2.1e26 + 3.4e37 + 42.))) - }, { - "loss_fn": - losses.unique_softmax_loss, - "expected_value": - -((-2.1e26 - (0. + 42.)) + (3.4e37 - (0. + 3.4e37 + 42.))) - }, { - "loss_fn": losses.pairwise_hinge_loss, - "expected_value": (1. - (-2.1e26 - 0.)) + (1. - (-2.1e26 - 42.0)) - }, { - "loss_fn": losses.pairwise_logistic_loss, - "expected_value": 2.1e26 + 2.1e26 - }, { - "loss_fn": losses.pairwise_soft_zero_one_loss, - "expected_value": 1.0 + 1.0 - }, { - "loss_fn": losses.pointwise_sigmoid_loss, - "expected_value": 2.1e26 - log(1. - sigmoid(0.)) + 42.0 - }, { - "loss_fn": - losses.pointwise_mse_loss, - "expected_value": - (0. - 0.)**2 + (-2.1e26 - 1.)**2 + (3.4e37 - 1.)**2 + (42. - 0.)**2 - }, { - "loss_fn": - losses.pairwise_mse_loss, - "expected_value": - (2.1e26 - -1.)**2 + (-3.4e37 - -1.)**2 + (-42. - 0.)**2 + - (-2.1e26 - 1.)**2 + ((-2.1e26 - 3.4e37) - 0.)**2 + - ((-2.1e26 - 42.) - 1.)**2 + (3.4e37 - 1.)**2 + - ((3.4e37 - -2.1e26) - 0.)**2 + ((3.4e37 - 42.) - 1.)**2 + - (42. - 0.)**2 + ((42. - -2.1e26) - -1.)**2 + ((42. - 3.4e37) - -1.)**2 - }, { - "loss_fn": - losses.pairwise_qr_loss, - "expected_value": - 0.5 * (((1. - 0.) - (-2.1e26 - 0.)) + ((1. - 0.) - (-2.1e26 - 42.))) + - 0.5 * (((3.4e37 - 0.) - (1. - 0.)) + ((3.4e37 - 42.) - (1. 
- 0.))) - }]) + @parameterized.parameters([ + { + "loss_fn": losses.softmax_loss, + "expected_value": -( + (-2.1e26 - (0.0 + -2.1e26 + 3.4e37 + 42.0)) + + (3.4e37 - (0.0 + -2.1e26 + 3.4e37 + 42.0)) + ), + }, + {"loss_fn": losses.listmle_loss, "expected_value": 3.4e37}, + { + "loss_fn": losses.poly1_softmax_loss, + "expected_value": -( + (-2.1e26 - (0.0 + -2.1e26 + 3.4e37 + 42.0)) + + (3.4e37 - (0.0 + -2.1e26 + 3.4e37 + 42.0)) + ), + }, + { + "loss_fn": losses.unique_softmax_loss, + "expected_value": -( + (-2.1e26 - (0.0 + 42.0)) + (3.4e37 - (0.0 + 3.4e37 + 42.0)) + ), + }, + { + "loss_fn": losses.pairwise_hinge_loss, + "expected_value": (1.0 - (-2.1e26 - 0.0)) + (1.0 - (-2.1e26 - 42.0)), + }, + { + "loss_fn": losses.pairwise_logistic_loss, + "expected_value": 2.1e26 + 2.1e26, + }, + { + "loss_fn": losses.pairwise_soft_zero_one_loss, + "expected_value": 1.0 + 1.0, + }, + { + "loss_fn": losses.pointwise_sigmoid_loss, + "expected_value": 2.1e26 - log(1.0 - sigmoid(0.0)) + 42.0, + }, + { + "loss_fn": losses.pointwise_mse_loss, + "expected_value": ( + (0.0 - 0.0) ** 2 + + (-2.1e26 - 1.0) ** 2 + + (3.4e37 - 1.0) ** 2 + + (42.0 - 0.0) ** 2 + ), + }, + { + "loss_fn": losses.pairwise_mse_loss, + "expected_value": ( + (2.1e26 - -1.0) ** 2 + + (-3.4e37 - -1.0) ** 2 + + (-42.0 - 0.0) ** 2 + + (-2.1e26 - 1.0) ** 2 + + ((-2.1e26 - 3.4e37) - 0.0) ** 2 + + ((-2.1e26 - 42.0) - 1.0) ** 2 + + (3.4e37 - 1.0) ** 2 + + ((3.4e37 - -2.1e26) - 0.0) ** 2 + + ((3.4e37 - 42.0) - 1.0) ** 2 + + (42.0 - 0.0) ** 2 + + ((42.0 - -2.1e26) - -1.0) ** 2 + + ((42.0 - 3.4e37) - -1.0) ** 2 + ), + }, + { + "loss_fn": losses.pairwise_qr_loss, + "expected_value": 0.5 * ( + ((1.0 - 0.0) - (-2.1e26 - 0.0)) + ((1.0 - 0.0) - (-2.1e26 - 42.0)) + ) + 0.5 * ( + ((3.4e37 - 0.0) - (1.0 - 0.0)) + ((3.4e37 - 42.0) - (1.0 - 0.0)) + ), + }, + ]) def test_computes_loss_with_extreme_inputs(self, loss_fn, expected_value): - scores = jnp.asarray([0., -2.1e26, 3.4e37, 42.0]) - labels = jnp.asarray([0., 1., 1., 0.]) + scores = jnp.asarray([0.0, -2.1e26, 3.4e37, 42.0]) + labels = jnp.asarray([0.0, 1.0, 1.0, 0.0]) loss = loss_fn(scores, labels, reduce_fn=jnp.sum) np.testing.assert_allclose(jnp.asarray(expected_value), loss) - @parameterized.parameters([{ - "loss_fn": losses.softmax_loss, - "expected_value": 0. - }, { - "loss_fn": - losses.listmle_loss, - "expected_value": - -sum([ - log(exp(0.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.))), - log(exp(3.) / (exp(3.) + exp(1.) + exp(2.))), - log(exp(1.) / (exp(1.) + exp(2.))), - log(exp(2.) / (exp(2.))), - ]) - }, { - "loss_fn": - losses.poly1_softmax_loss, - "expected_value": - 1. - sum([ - 0.25 * (exp(0.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.))), - 0.25 * (exp(3.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.))), - 0.25 * (exp(1.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.))), - 0.25 * (exp(2.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.))), - ]) - }, { - "loss_fn": losses.unique_softmax_loss, - "expected_value": 0. - }, { - "loss_fn": losses.pairwise_hinge_loss, - "expected_value": 0. - }, { - "loss_fn": losses.pairwise_logistic_loss, - "expected_value": 0. - }, { - "loss_fn": losses.pairwise_soft_zero_one_loss, - "expected_value": 0. - }, { - "loss_fn": - losses.pointwise_sigmoid_loss, - "expected_value": - -log(1. - sigmoid(0.)) - log(1. - sigmoid(3.)) - - log(1. - sigmoid(1.)) - log(1. - sigmoid(2.)) - }, { - "loss_fn": - losses.pointwise_mse_loss, - "expected_value": - (0. - 0.)**2 + (3. - 0.)**2 + (1. - 0.)**2 + (2. 
- 0.)**2 - }, { - "loss_fn": - losses.pairwise_mse_loss, - "expected_value": (-3.)**2 + (-1.)**2 + (-2.)**2 + 3.**2 + 2.**2 + 1.**2 + - 1.**2 + (-2.)**2 + (-1.)**2 + 2.**2 + (-1.)**2 + 1.**2 - }, { - "loss_fn": losses.pairwise_qr_loss, - "expected_value": 0. - }]) + @parameterized.parameters([ + {"loss_fn": losses.softmax_loss, "expected_value": 0.0}, + { + "loss_fn": losses.listmle_loss, + "expected_value": -sum([ + log(exp(0.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))), + log(exp(3.0) / (exp(3.0) + exp(1.0) + exp(2.0))), + log(exp(1.0) / (exp(1.0) + exp(2.0))), + log(exp(2.0) / (exp(2.0))), + ]), + }, + { + "loss_fn": losses.poly1_softmax_loss, + "expected_value": 1.0 - sum([ + 0.25 * (exp(0.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))), + 0.25 * (exp(3.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))), + 0.25 * (exp(1.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))), + 0.25 * (exp(2.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))), + ]), + }, + {"loss_fn": losses.unique_softmax_loss, "expected_value": 0.0}, + {"loss_fn": losses.pairwise_hinge_loss, "expected_value": 0.0}, + {"loss_fn": losses.pairwise_logistic_loss, "expected_value": 0.0}, + { + "loss_fn": losses.pairwise_soft_zero_one_loss, + "expected_value": 0.0, + }, + { + "loss_fn": losses.pointwise_sigmoid_loss, + "expected_value": ( + -log(1.0 - sigmoid(0.0)) + - log(1.0 - sigmoid(3.0)) + - log(1.0 - sigmoid(1.0)) + - log(1.0 - sigmoid(2.0)) + ), + }, + { + "loss_fn": losses.pointwise_mse_loss, + "expected_value": ( + (0.0 - 0.0) ** 2 + + (3.0 - 0.0) ** 2 + + (1.0 - 0.0) ** 2 + + (2.0 - 0.0) ** 2 + ), + }, + { + "loss_fn": losses.pairwise_mse_loss, + "expected_value": ( + (-3.0) ** 2 + + (-1.0) ** 2 + + (-2.0) ** 2 + + 3.0**2 + + 2.0**2 + + 1.0**2 + + 1.0**2 + + (-2.0) ** 2 + + (-1.0) ** 2 + + 2.0**2 + + (-1.0) ** 2 + + 1.0**2 + ), + }, + {"loss_fn": losses.pairwise_qr_loss, "expected_value": 0.0}, + ]) def test_computes_loss_for_zero_labels(self, loss_fn, expected_value): - scores = jnp.asarray([0., 3., 1., 2.]) - labels = jnp.asarray([0., 0., 0., 0.]) + scores = jnp.asarray([0.0, 3.0, 1.0, 2.0]) + labels = jnp.asarray([0.0, 0.0, 0.0, 0.0]) loss = loss_fn(scores, labels, reduce_fn=jnp.sum) np.testing.assert_allclose(jnp.asarray(expected_value), loss) - @parameterized.parameters([{ - "loss_fn": - losses.softmax_loss, - "expected_value": - -(2. * log(exp(2.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.))) + - log(exp(1.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.)))) - }, { - "loss_fn": - losses.poly1_softmax_loss, - "expected_value": - -(2. * log(exp(2.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.))) + - log(exp(1.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.)))) + - (1. - (2. / 3. * exp(2.) / - (exp(0.) + exp(3.) + exp(1.) + exp(2.)) + 1. / 3. * - (exp(1.) / (exp(0.) + exp(3.) + exp(1.) + exp(2.))))) - }, { - "loss_fn": - losses.unique_softmax_loss, - "expected_value": - -(2. * log(exp(2.) / (exp(0.) + exp(3.) + exp(2.))) + - log(exp(1.) / (exp(0.) + exp(3.) + exp(1.)))) - }, { - "loss_fn": losses.pairwise_hinge_loss, - "expected_value": 7. - }, { - "loss_fn": losses.pairwise_logistic_loss, - "expected_value": 5.320569 - }, { - "loss_fn": losses.pairwise_soft_zero_one_loss, - "expected_value": 2.850261 - }, { - "loss_fn": - losses.pointwise_sigmoid_loss, - "expected_value": - -log(1. - sigmoid(0.)) - log(1. - sigmoid(3.)) - - 2. * log(sigmoid(2.)) - log(sigmoid(1.)) - }, { - "loss_fn": - losses.pointwise_mse_loss, - "expected_value": - (0. - 0.)**2 + (3. - 0.)**2 + 2. * (2. - 1.)**2 + (1. 
- 1.)**2 - }, { - "loss_fn": - losses.pairwise_mse_loss, - "expected_value": - (1. * ((-3. - 0.)**2 + (-2. - -1.)**2 + (-1. - -1.)**2)) + - (1. * ((3. - 0.)**2 + (1. - -1.)**2 + (2. - -1.)**2)) + - (2. * ((2. - 1.)**2 + (-1. - 1.)**2 + (1. - 0.)**2)) + - (1. * ((1. - 1.)**2 + (-2. - 1.)**2 + (-1. - 0.)**2)) - }]) + @parameterized.parameters([ + { + "loss_fn": losses.softmax_loss, + "expected_value": -( + 2.0 * log(exp(2.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))) + + log(exp(1.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))) + ), + }, + { + "loss_fn": losses.poly1_softmax_loss, + "expected_value": -( + 2.0 * log(exp(2.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))) + + log(exp(1.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))) + ) + ( + 1.0 + - ( + 2.0 + / 3.0 + * exp(2.0) + / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0)) + + 1.0 + / 3.0 + * (exp(1.0) / (exp(0.0) + exp(3.0) + exp(1.0) + exp(2.0))) + ) + ), + }, + { + "loss_fn": losses.unique_softmax_loss, + "expected_value": -( + 2.0 * log(exp(2.0) / (exp(0.0) + exp(3.0) + exp(2.0))) + + log(exp(1.0) / (exp(0.0) + exp(3.0) + exp(1.0))) + ), + }, + {"loss_fn": losses.pairwise_hinge_loss, "expected_value": 7.0}, + { + "loss_fn": losses.pairwise_logistic_loss, + "expected_value": 5.320569, + }, + { + "loss_fn": losses.pairwise_soft_zero_one_loss, + "expected_value": 2.850261, + }, + { + "loss_fn": losses.pointwise_sigmoid_loss, + "expected_value": ( + -log(1.0 - sigmoid(0.0)) + - log(1.0 - sigmoid(3.0)) + - 2.0 * log(sigmoid(2.0)) + - log(sigmoid(1.0)) + ), + }, + { + "loss_fn": losses.pointwise_mse_loss, + "expected_value": ( + (0.0 - 0.0) ** 2 + + (3.0 - 0.0) ** 2 + + 2.0 * (2.0 - 1.0) ** 2 + + (1.0 - 1.0) ** 2 + ), + }, + { + "loss_fn": losses.pairwise_mse_loss, + "expected_value": ( + ( + 1.0 + * ( + (-3.0 - 0.0) ** 2 + + (-2.0 - -1.0) ** 2 + + (-1.0 - -1.0) ** 2 + ) + ) + + ( + 1.0 + * ((3.0 - 0.0) ** 2 + (1.0 - -1.0) ** 2 + (2.0 - -1.0) ** 2) + ) + + ( + 2.0 + * ((2.0 - 1.0) ** 2 + (-1.0 - 1.0) ** 2 + (1.0 - 0.0) ** 2) + ) + + ( + 1.0 + * ((1.0 - 1.0) ** 2 + (-2.0 - 1.0) ** 2 + (-1.0 - 0.0) ** 2) + ) + ), + }, + ]) def test_computes_weighted_loss_value(self, loss_fn, expected_value): - scores = jnp.asarray([0., 3., 2., 1.]) - labels = jnp.asarray([0., 0., 1., 1.]) - weights = jnp.asarray([1., 1., 2., 1.]) + scores = jnp.asarray([0.0, 3.0, 2.0, 1.0]) + labels = jnp.asarray([0.0, 0.0, 1.0, 1.0]) + weights = jnp.asarray([1.0, 1.0, 2.0, 1.0]) loss = loss_fn(scores, labels, weights=weights, reduce_fn=jnp.sum) @@ -411,8 +516,8 @@ def test_computes_weighted_loss_value(self, loss_fn, expected_value): ] }]) # pyformat: disable def test_computes_loss_value_with_vmap(self, loss_fn, expected_value): - scores = jnp.asarray([[0., 3., 1., 2.], [3., 1., 4., 2.]]) - labels = jnp.asarray([[0., 0., 1., 1.], [2., 0., 1., 0.]]) + scores = jnp.asarray([[0.0, 3.0, 1.0, 2.0], [3.0, 1.0, 4.0, 2.0]]) + labels = jnp.asarray([[0.0, 0.0, 1.0, 1.0], [2.0, 0.0, 1.0, 0.0]]) loss_fn = functools.partial(loss_fn, reduce_fn=jnp.sum) vmap_loss_fn = jax.vmap(loss_fn, in_axes=(0, 0), out_axes=0) @@ -420,137 +525,149 @@ def test_computes_loss_value_with_vmap(self, loss_fn, expected_value): np.testing.assert_allclose(jnp.asarray(expected_value), loss) - @parameterized.parameters([{ - "loss_fn": losses.softmax_loss, - "expected_value": [ - -log(exp(2.) / (exp(2.) + exp(1.) + exp(3.))), - -log(exp(1.5) / (exp(1.) + exp(0.5) + exp(1.5))) - ], - "normalizer": 2. - }, { - "loss_fn": losses.listmle_loss, - "expected_value": [ - -sum([ - log(exp(2.) / (exp(2.) + exp(1.) 
+ exp(3.))), - log(exp(1.) / (exp(1.) + exp(3.))), - log(exp(3.) / (exp(3.))), - ]), -sum([ - log(exp(1.5) / (exp(1.5) + exp(1.) + exp(0.5))), - log(exp(1.) / (exp(1.) + exp(0.5))), - log(exp(0.5) / (exp(0.5))), - ]) - ], - "normalizer": 2. - }, { - "loss_fn": losses.poly1_softmax_loss, - "expected_value": [ - -log(exp(2.) / (exp(2.) + exp(1.) + exp(3.))) + - (1. - (exp(2.) / (exp(2.) + exp(1.) + exp(3.)))), - -log(exp(1.5) / (exp(1.) + exp(0.5) + exp(1.5))) + - (1. - (exp(1.5) / (exp(1.) + exp(0.5) + exp(1.5)))) - ], - "normalizer": 2. - }, { - "loss_fn": losses.unique_softmax_loss, - "expected_value": [ - -log(exp(2.) / (exp(2.) + exp(1.) + exp(3.))), - -log(exp(1.5) / (exp(1.) + exp(0.5) + exp(1.5))) - ], - "normalizer": 2. - }, { - "loss_fn": losses.pairwise_hinge_loss, - "expected_value": [2., .5], - "normalizer": 4. - }, { - "loss_fn": losses.pairwise_logistic_loss, - "expected_value": [ - logloss(2. - 1.) + logloss(2. - 3.), - logloss(1.5 - 1.) + logloss(1.5 - 0.5) - ], - "normalizer": 4. - }, { - "loss_fn": losses.pairwise_soft_zero_one_loss, - "expected_value": [ - sigmoid(-(2. - 1.)) + sigmoid(-(2. - 3.)), - sigmoid(-(1.5 - 1.)) + sigmoid(-(1.5 - 0.5)) - ], - "normalizer": 4. - }, { - "loss_fn": losses.pointwise_sigmoid_loss, - "expected_value": [ - -log(sigmoid(2.)) - log(1. - sigmoid(1.)) - log(1. - sigmoid(3.)), - -log(sigmoid(1.5)) - log(1. - sigmoid(1.)) - log(1. - sigmoid(0.5)) - ], - "normalizer": 6. - }, { - "loss_fn": losses.pointwise_mse_loss, - "expected_value": [(2. - 1.)**2 + (1. - 0.)**2 + (3. - 0.)**2, - (1. - 0.)**2 + (0.5 - 0.)**2 + (1.5 - 1.)**2], - "normalizer": 6. - }, { - "loss_fn": losses.pairwise_mse_loss, - "expected_value": [(1. - 1.)**2 + (-1. - 1.)**2 + (-1. - -1.)**2 + - (-2. - 0.)**2 + (1. - -1.)**2 + (2. - 0.)**2, - (0.5 - 0.)**2 + (-0.5 - -1.)**2 + (-0.5 - 0.)**2 + - (-1. - -1.)**2 + (0.5 - 1.)**2 + (1. - 1.)**2], - "normalizer": 9. + 9. - }, { - "loss_fn": losses.pairwise_qr_loss, - "expected_value": [ - 0.5 * (((1. - 0.) - (2. - 1.)) + ((1. - 0.) - (2. - 3.))), - 0.5 * (((1. - 0.) - (1.5 - 1.)) + ((1. - 0.) - (1.5 - 0.5))) - ], - "normalizer": 4. 
- }]) + @parameterized.parameters([ + { + "loss_fn": losses.softmax_loss, + "expected_value": [ + -log(exp(2.0) / (exp(2.0) + exp(1.0) + exp(3.0))), + -log(exp(1.5) / (exp(1.0) + exp(0.5) + exp(1.5))), + ], + "normalizer": 2.0, + }, + { + "loss_fn": losses.listmle_loss, + "expected_value": [ + -sum([ + log(exp(2.0) / (exp(2.0) + exp(1.0) + exp(3.0))), + log(exp(1.0) / (exp(1.0) + exp(3.0))), + log(exp(3.0) / (exp(3.0))), + ]), + -sum([ + log(exp(1.5) / (exp(1.5) + exp(1.0) + exp(0.5))), + log(exp(1.0) / (exp(1.0) + exp(0.5))), + log(exp(0.5) / (exp(0.5))), + ]), + ], + "normalizer": 2.0, + }, + { + "loss_fn": losses.poly1_softmax_loss, + "expected_value": [ + -log(exp(2.0) / (exp(2.0) + exp(1.0) + exp(3.0))) + + (1.0 - (exp(2.0) / (exp(2.0) + exp(1.0) + exp(3.0)))), + -log(exp(1.5) / (exp(1.0) + exp(0.5) + exp(1.5))) + + (1.0 - (exp(1.5) / (exp(1.0) + exp(0.5) + exp(1.5)))), + ], + "normalizer": 2.0, + }, + { + "loss_fn": losses.unique_softmax_loss, + "expected_value": [ + -log(exp(2.0) / (exp(2.0) + exp(1.0) + exp(3.0))), + -log(exp(1.5) / (exp(1.0) + exp(0.5) + exp(1.5))), + ], + "normalizer": 2.0, + }, + { + "loss_fn": losses.pairwise_hinge_loss, + "expected_value": [2.0, 0.5], + "normalizer": 4.0, + }, + { + "loss_fn": losses.pairwise_logistic_loss, + "expected_value": [ + logloss(2.0 - 1.0) + logloss(2.0 - 3.0), + logloss(1.5 - 1.0) + logloss(1.5 - 0.5), + ], + "normalizer": 4.0, + }, + { + "loss_fn": losses.pairwise_soft_zero_one_loss, + "expected_value": [ + sigmoid(-(2.0 - 1.0)) + sigmoid(-(2.0 - 3.0)), + sigmoid(-(1.5 - 1.0)) + sigmoid(-(1.5 - 0.5)), + ], + "normalizer": 4.0, + }, + { + "loss_fn": losses.pointwise_sigmoid_loss, + "expected_value": [ + -log(sigmoid(2.0)) + - log(1.0 - sigmoid(1.0)) + - log(1.0 - sigmoid(3.0)), + -log(sigmoid(1.5)) + - log(1.0 - sigmoid(1.0)) + - log(1.0 - sigmoid(0.5)), + ], + "normalizer": 6.0, + }, + { + "loss_fn": losses.pointwise_mse_loss, + "expected_value": [ + (2.0 - 1.0) ** 2 + (1.0 - 0.0) ** 2 + (3.0 - 0.0) ** 2, + (1.0 - 0.0) ** 2 + (0.5 - 0.0) ** 2 + (1.5 - 1.0) ** 2, + ], + "normalizer": 6.0, + }, + { + "loss_fn": losses.pairwise_mse_loss, + "expected_value": [ + (1.0 - 1.0) ** 2 + + (-1.0 - 1.0) ** 2 + + (-1.0 - -1.0) ** 2 + + (-2.0 - 0.0) ** 2 + + (1.0 - -1.0) ** 2 + + (2.0 - 0.0) ** 2, + (0.5 - 0.0) ** 2 + + (-0.5 - -1.0) ** 2 + + (-0.5 - 0.0) ** 2 + + (-1.0 - -1.0) ** 2 + + (0.5 - 1.0) ** 2 + + (1.0 - 1.0) ** 2, + ], + "normalizer": 9.0 + 9.0, + }, + { + "loss_fn": losses.pairwise_qr_loss, + "expected_value": [ + 0.5 * (((1.0 - 0.0) - (2.0 - 1.0)) + ((1.0 - 0.0) - (2.0 - 3.0))), + 0.5 * (((1.0 - 0.0) - (1.5 - 1.0)) + ((1.0 - 0.0) - (1.5 - 0.5))), + ], + "normalizer": 4.0, + }, + ]) def test_computes_reduced_loss(self, loss_fn, expected_value, normalizer): - scores = jnp.array([[2., 1., 3.], [1., 0.5, 1.5]]) - labels = jnp.array([[1., 0., 0.], [0., 0., 1.]]) + scores = jnp.array([[2.0, 1.0, 3.0], [1.0, 0.5, 1.5]]) + labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) expected_value = jnp.asarray(expected_value) mean_loss = loss_fn(scores, labels, reduce_fn=jnp.mean) sum_loss = loss_fn(scores, labels, reduce_fn=jnp.sum) np.testing.assert_allclose( - mean_loss, jnp.sum(expected_value) / normalizer, rtol=1E-5) + mean_loss, jnp.sum(expected_value) / normalizer, rtol=1e-5 + ) np.testing.assert_allclose(sum_loss, jnp.sum(expected_value)) - @parameterized.parameters([{ - "loss_fn": losses.softmax_loss, - "expected_shape": (2,) - }, { - "loss_fn": losses.listmle_loss, - "expected_shape": (2,) - }, { - "loss_fn": losses.poly1_softmax_loss, 
- "expected_shape": (2,) - }, { - "loss_fn": losses.unique_softmax_loss, - "expected_shape": (2,) - }, { - "loss_fn": losses.pairwise_hinge_loss, - "expected_shape": (2, 9) - }, { - "loss_fn": losses.pairwise_logistic_loss, - "expected_shape": (2, 9) - }, { - "loss_fn": losses.pairwise_soft_zero_one_loss, - "expected_shape": (2, 9) - }, { - "loss_fn": losses.pairwise_mse_loss, - "expected_shape": (2, 9) - }, { - "loss_fn": losses.pairwise_qr_loss, - "expected_shape": (2, 9) - }, { - "loss_fn": losses.pointwise_sigmoid_loss, - "expected_shape": (2, 3) - }, { - "loss_fn": losses.pointwise_mse_loss, - "expected_shape": (2, 3) - }]) + @parameterized.parameters([ + {"loss_fn": losses.softmax_loss, "expected_shape": (2,)}, + {"loss_fn": losses.listmle_loss, "expected_shape": (2,)}, + {"loss_fn": losses.poly1_softmax_loss, "expected_shape": (2,)}, + {"loss_fn": losses.unique_softmax_loss, "expected_shape": (2,)}, + {"loss_fn": losses.pairwise_hinge_loss, "expected_shape": (2, 9)}, + {"loss_fn": losses.pairwise_logistic_loss, "expected_shape": (2, 9)}, + { + "loss_fn": losses.pairwise_soft_zero_one_loss, + "expected_shape": (2, 9), + }, + {"loss_fn": losses.pairwise_mse_loss, "expected_shape": (2, 9)}, + {"loss_fn": losses.pairwise_qr_loss, "expected_shape": (2, 9)}, + {"loss_fn": losses.pointwise_sigmoid_loss, "expected_shape": (2, 3)}, + {"loss_fn": losses.pointwise_mse_loss, "expected_shape": (2, 3)}, + ]) def test_computes_unreduced_loss(self, loss_fn, expected_shape): - scores = jnp.array([[2., 1., 3.], [1., 0.5, 1.5]]) - labels = jnp.array([[1., 0., 0.], [0., 0., 1.]]) + scores = jnp.array([[2.0, 1.0, 3.0], [1.0, 0.5, 1.5]]) + labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) none_loss = loss_fn(scores, labels, reduce_fn=None) sum_loss = loss_fn(scores, labels, reduce_fn=jnp.sum) @@ -587,11 +704,9 @@ def test_computes_loss_value_with_segments(self, loss_fn): ) loss = loss_fn(scores, labels, segments=segments) - expected_loss = loss_fn( - list_scores, list_labels, where=list_mask - ) + expected_loss = loss_fn(list_scores, list_labels, where=list_mask) - np.testing.assert_allclose(expected_loss, loss, rtol=1E-5) + np.testing.assert_allclose(expected_loss, loss, rtol=1e-5) @parameterized.parameters([ losses.pointwise_mse_loss, @@ -623,11 +738,9 @@ def test_computes_loss_value_with_segments_and_mask(self, loss_fn): ) loss = loss_fn(scores, labels, segments=segments, where=where) - expected_loss = loss_fn( - list_scores, list_labels, where=list_mask - ) + expected_loss = loss_fn(list_scores, list_labels, where=list_mask) - np.testing.assert_allclose(expected_loss, loss, rtol=1E-5) + np.testing.assert_allclose(expected_loss, loss, rtol=1e-5) @parameterized.parameters([ losses.softmax_loss, @@ -643,11 +756,11 @@ def test_computes_loss_value_with_segments_and_mask(self, loss_fn): losses.unique_softmax_loss, ]) def test_computes_loss_value_with_where(self, loss_fn): - scores = jnp.asarray([0., 3., 1., 2.]) - labels = jnp.asarray([0., 0., 2., 1.]) + scores = jnp.asarray([0.0, 3.0, 1.0, 2.0]) + labels = jnp.asarray([0.0, 0.0, 2.0, 1.0]) where = jnp.asarray([True, True, True, False]) - expected_scores = jnp.asarray([0., 3., 1.]) - expected_labels = jnp.asarray([0., 0., 2.]) + expected_scores = jnp.asarray([0.0, 3.0, 1.0]) + expected_labels = jnp.asarray([0.0, 0.0, 2.0]) loss = loss_fn(scores, labels, where=where) expected_loss = loss_fn(expected_scores, expected_labels) @@ -668,13 +781,13 @@ def test_computes_loss_value_with_where(self, loss_fn): losses.unique_softmax_loss, ]) def 
test_computes_loss_value_with_all_masked(self, loss_fn): - scores = jnp.asarray([0., 3., 1., 2.]) - labels = jnp.asarray([0., 0., 1., 1.]) + scores = jnp.asarray([0.0, 3.0, 1.0, 2.0]) + labels = jnp.asarray([0.0, 0.0, 1.0, 1.0]) where = jnp.asarray([False, False, False, False]) loss = loss_fn(scores, labels, where=where) - np.testing.assert_allclose(jnp.asarray(0.), loss, atol=1E-7) + np.testing.assert_allclose(jnp.asarray(0.0), loss, atol=1e-7) @parameterized.parameters([ losses.softmax_loss, @@ -690,13 +803,13 @@ def test_computes_loss_value_with_all_masked(self, loss_fn): losses.unique_softmax_loss, ]) def test_computes_loss_with_arbitrary_batch_dimensions(self, loss_fn): - scores = jnp.asarray([2., 3., 1.]) - labels = jnp.asarray([0., 0., 1.]) + scores = jnp.asarray([2.0, 3.0, 1.0]) + labels = jnp.asarray([0.0, 0.0, 1.0]) where = jnp.asarray([False, True, True]) original_loss = loss_fn(scores, labels, where=where) - scores = jnp.asarray([[[[2., 3., 1.]]]]) - labels = jnp.asarray([[[[0., 0., 1.]]]]) + scores = jnp.asarray([[[[2.0, 3.0, 1.0]]]]) + labels = jnp.asarray([[[[0.0, 0.0, 1.0]]]]) where = jnp.asarray([[[[False, True, True]]]]) batched_loss = loss_fn(scores, labels, where=where) @@ -716,13 +829,14 @@ def test_computes_loss_with_arbitrary_batch_dimensions(self, loss_fn): losses.unique_softmax_loss, ]) def test_grad_does_not_return_nan_for_zero_labels(self, loss_fn): - scores = jnp.asarray([0., 3., 1., 2.]) - labels = jnp.asarray([0., 0., 0., 0.]) + scores = jnp.asarray([0.0, 3.0, 1.0, 2.0]) + labels = jnp.asarray([0.0, 0.0, 0.0, 0.0]) grads = jax.grad(loss_fn)(scores, labels, reduce_fn=jnp.mean) np.testing.assert_array_equal( - jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads))) + jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads)) + ) @parameterized.parameters([ losses.softmax_loss, @@ -738,14 +852,15 @@ def test_grad_does_not_return_nan_for_zero_labels(self, loss_fn): losses.unique_softmax_loss, ]) def test_grad_does_not_return_nan_with_all_masked(self, loss_fn): - scores = jnp.asarray([0., 3., 1., 2.]) - labels = jnp.asarray([0., 0., 1., 1.]) + scores = jnp.asarray([0.0, 3.0, 1.0, 2.0]) + labels = jnp.asarray([0.0, 0.0, 1.0, 1.0]) where = jnp.asarray([False, False, False, False]) grads = jax.grad(loss_fn)(scores, labels, where=where, reduce_fn=jnp.mean) np.testing.assert_array_equal( - jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads))) + jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads)) + ) @parameterized.parameters([ losses.softmax_loss, @@ -774,11 +889,8 @@ def test_ignores_lists_containing_only_invalid_items(self, loss_fn): def load_tests(loader, tests, ignore): del loader, ignore # Unused. tests.addTests( - doctest.DocTestSuite(losses, globs={ - "jax": jax, - "jnp": jnp, - "rax": rax - })) + doctest.DocTestSuite(losses, globs={"jax": jax, "jnp": jnp, "rax": rax}) + ) return tests diff --git a/rax/_src/metrics.py b/rax/_src/metrics.py index 28b9e1e..ebfa3a5 100644 --- a/rax/_src/metrics.py +++ b/rax/_src/metrics.py @@ -122,7 +122,7 @@ def default_gain_fn(label: Array) -> Array: Returns: The gain value for given label. """ - return jnp.power(2., label) - 1. + return jnp.power(2.0, label) - 1.0 def default_discount_fn(rank: Array) -> Array: @@ -139,7 +139,7 @@ def default_discount_fn(rank: Array) -> Array: Returns: The discount value for given rank. """ - return 1. / jnp.log2(rank + 1) + return 1.0 / jnp.log2(rank + 1) def mrr_metric( @@ -196,8 +196,9 @@ def mrr_metric( The MRR metric. """ # Get the relevant items. 
- relevant_items = jnp.where(labels >= 1, jnp.ones_like(labels), - jnp.zeros_like(labels)) + relevant_items = jnp.where( + labels >= 1, jnp.ones_like(labels), jnp.zeros_like(labels) + ) # Get the retrieved items. ranks = rank_fn(scores, where=where, segments=segments, key=key) @@ -211,7 +212,7 @@ def mrr_metric( ) # Compute reciprocal ranks. - reciprocal_ranks = jnp.reciprocal(jnp.where(ranks == 0., jnp.inf, ranks)) + reciprocal_ranks = jnp.reciprocal(jnp.where(ranks == 0.0, jnp.inf, ranks)) # Get the maximum reciprocal rank. if segments is not None: @@ -219,7 +220,7 @@ def mrr_metric( relevant_items * retrieved_items * reciprocal_ranks, segments, where=where, - initial=0.0 + initial=0.0, ) else: values = jnp.max( @@ -296,8 +297,9 @@ def recall_metric( The recall metric. """ # Get the relevant items. - relevant_items = jnp.where(labels >= 1, jnp.ones_like(labels), - jnp.zeros_like(labels)) + relevant_items = jnp.where( + labels >= 1, jnp.ones_like(labels), jnp.zeros_like(labels) + ) # Get the retrieved items. ranks = rank_fn(scores, where=where, segments=segments, key=key) @@ -325,7 +327,7 @@ def recall_metric( n_relevant = jnp.sum(relevant_items, where=where, axis=-1) # Compute recall but prevent division by zero. - n_relevant = jnp.where(n_relevant == 0, 1., n_relevant) + n_relevant = jnp.where(n_relevant == 0, 1.0, n_relevant) values = n_retrieved_relevant / n_relevant # In the segmented case, values retain their list dimension. This constructs @@ -395,8 +397,9 @@ def precision_metric( The precision metric. """ # Get the relevant items. - relevant_items = jnp.where(labels >= 1, jnp.ones_like(labels), - jnp.zeros_like(labels)) + relevant_items = jnp.where( + labels >= 1, jnp.ones_like(labels), jnp.zeros_like(labels) + ) # Get the retrieved items. ranks = rank_fn(scores, where=where, segments=segments, key=key) @@ -424,7 +427,7 @@ def precision_metric( n_retrieved = jnp.sum(retrieved_items, where=where, axis=-1) # Compute precision but prevent division by zero. - n_retrieved = jnp.where(n_retrieved == 0, 1., n_retrieved) + n_retrieved = jnp.where(n_retrieved == 0, 1.0, n_retrieved) values = n_retrieved_relevant / n_retrieved # In the segmented case, values retain their list dimension. This constructs @@ -494,8 +497,9 @@ def ap_metric( The average precision metric. """ # Get the relevant items. - relevant_items = jnp.where(labels >= 1, jnp.ones_like(labels), - jnp.zeros_like(labels)) + relevant_items = jnp.where( + labels >= 1, jnp.ones_like(labels), jnp.zeros_like(labels) + ) # Get the retrieved items. ranks = rank_fn(scores, where=where, segments=segments, key=key) @@ -534,7 +538,7 @@ def ap_metric( n_relevant = jnp.sum(relevant_items, where=where, axis=-1) # Compute average precision but prevent division by zero. - n_relevant = jnp.where(n_relevant == 0, 1., n_relevant) + n_relevant = jnp.where(n_relevant == 0, 1.0, n_relevant) values = sum_prec_at_k / n_relevant # In the segmented case, values retain their list dimension. This constructs @@ -711,7 +715,8 @@ def ndcg_metric( discount_fn=discount_fn, rank_fn=rank_fn, cutoff_fn=cutoff_fn, - reduce_fn=None) + reduce_fn=None, + ) # The ideal dcg is computed by ordering items by their (weighted) gains. ideal_scores = gain_fn(labels) @@ -729,10 +734,11 @@ def ndcg_metric( discount_fn=discount_fn, rank_fn=utils.ranks, cutoff_fn=utils.cutoff, - reduce_fn=None) + reduce_fn=None, + ) # Compute the result as `dcg / ideal_dcg` while preventing division by zero. 
- ideal_dcg = jnp.where(ideal_dcg == 0., 1., ideal_dcg) + ideal_dcg = jnp.where(ideal_dcg == 0.0, 1.0, ideal_dcg) values = regular_dcg / ideal_dcg # In the segmented case, values retain their list dimension. This constructs diff --git a/rax/_src/metrics_test.py b/rax/_src/metrics_test.py index 0487be0..11d70d1 100644 --- a/rax/_src/metrics_test.py +++ b/rax/_src/metrics_test.py @@ -33,298 +33,309 @@ class MetricsTest(parameterized.TestCase): - @parameterized.parameters([{ - "metric_fn": metrics.mrr_metric, - "expected_value": 1 / 2 - }, { - "metric_fn": functools.partial(metrics.recall_metric, topn=None), - "expected_value": 1. - }, { - "metric_fn": functools.partial(metrics.recall_metric, topn=3), - "expected_value": 1. - }, { - "metric_fn": functools.partial(metrics.precision_metric, topn=None), - "expected_value": 0.5 - }, { - "metric_fn": functools.partial(metrics.precision_metric, topn=3), - "expected_value": 2. / 3. - }, { - "metric_fn": metrics.ap_metric, - "expected_value": (0.5 + 2. / 3.) / 2. - }, { - "metric_fn": metrics.dcg_metric, - "expected_value": 1 / log2(1 + 2) + 1 / log2(1 + 3) - }, { - "metric_fn": - metrics.ndcg_metric, - "expected_value": (1 / log2(1 + 2) + 1 / log2(1 + 3)) / - (1 / log2(1 + 1) + 1 / log2(1 + 2)) - }]) + @parameterized.parameters([ + {"metric_fn": metrics.mrr_metric, "expected_value": 1 / 2}, + { + "metric_fn": functools.partial(metrics.recall_metric, topn=None), + "expected_value": 1.0, + }, + { + "metric_fn": functools.partial(metrics.recall_metric, topn=3), + "expected_value": 1.0, + }, + { + "metric_fn": functools.partial(metrics.precision_metric, topn=None), + "expected_value": 0.5, + }, + { + "metric_fn": functools.partial(metrics.precision_metric, topn=3), + "expected_value": 2.0 / 3.0, + }, + { + "metric_fn": metrics.ap_metric, + "expected_value": (0.5 + 2.0 / 3.0) / 2.0, + }, + { + "metric_fn": metrics.dcg_metric, + "expected_value": 1 / log2(1 + 2) + 1 / log2(1 + 3), + }, + { + "metric_fn": metrics.ndcg_metric, + "expected_value": (1 / log2(1 + 2) + 1 / log2(1 + 3)) / ( + 1 / log2(1 + 1) + 1 / log2(1 + 2) + ), + }, + ]) def test_computes_metric_value(self, metric_fn, expected_value): - scores = jnp.asarray([0., 3., 1., 2.]) - labels = jnp.asarray([0., 0., 1., 1.]) + scores = jnp.asarray([0.0, 3.0, 1.0, 2.0]) + labels = jnp.asarray([0.0, 0.0, 1.0, 1.0]) metric = metric_fn(scores, labels) - np.testing.assert_allclose(jnp.asarray(expected_value), metric, rtol=1E-5) - - @parameterized.parameters([{ - "metric_fn": metrics.mrr_metric, - "expected_value": [1 / 2, 1 / 3] - }, { - "metric_fn": functools.partial(metrics.recall_metric, topn=None), - "expected_value": [1., 1.] - }, { - "metric_fn": functools.partial(metrics.recall_metric, topn=3), - "expected_value": [1., 1. / 2.] - }, { - "metric_fn": functools.partial(metrics.precision_metric, topn=None), - "expected_value": [0.5, 0.5] - }, { - "metric_fn": functools.partial(metrics.precision_metric, topn=3), - "expected_value": [2. / 3., 1. / 3.] - }, { - "metric_fn": metrics.ap_metric, - "expected_value": [(0.5 + 2. / 3.) / 2., (1. / 3. + 0.5) / 2.] 
- }, { - "metric_fn": - metrics.dcg_metric, - "expected_value": [ - 1 / log2(1 + 2) + 1 / log2(1 + 3), - 1 / log2(1 + 3) + (2**2 - 1) / log2(1 + 4) - ] - }, { - "metric_fn": - metrics.ndcg_metric, - "expected_value": [(1 / log2(1 + 2) + 1 / log2(1 + 3)) / - (1 / log2(1 + 1) + 1 / log2(1 + 2)), - (1 / log2(1 + 3) + (2**2 - 1) / log2(1 + 4)) / - ((2**2 - 1) / log2(1 + 1) + 1 / log2(1 + 2))] - }]) - def test_computes_metric_value_on_batch_with_vmap(self, metric_fn, - expected_value): - scores = jnp.asarray([[0., 3., 1., 2.], [1., 4., 3., 2.]]) - labels = jnp.asarray([[0., 0., 1., 1.], [2., 0., 0., 1.]]) + np.testing.assert_allclose(jnp.asarray(expected_value), metric, rtol=1e-5) + + @parameterized.parameters([ + {"metric_fn": metrics.mrr_metric, "expected_value": [1 / 2, 1 / 3]}, + { + "metric_fn": functools.partial(metrics.recall_metric, topn=None), + "expected_value": [1.0, 1.0], + }, + { + "metric_fn": functools.partial(metrics.recall_metric, topn=3), + "expected_value": [1.0, 1.0 / 2.0], + }, + { + "metric_fn": functools.partial(metrics.precision_metric, topn=None), + "expected_value": [0.5, 0.5], + }, + { + "metric_fn": functools.partial(metrics.precision_metric, topn=3), + "expected_value": [2.0 / 3.0, 1.0 / 3.0], + }, + { + "metric_fn": metrics.ap_metric, + "expected_value": [ + (0.5 + 2.0 / 3.0) / 2.0, + (1.0 / 3.0 + 0.5) / 2.0, + ], + }, + { + "metric_fn": metrics.dcg_metric, + "expected_value": [ + 1 / log2(1 + 2) + 1 / log2(1 + 3), + 1 / log2(1 + 3) + (2**2 - 1) / log2(1 + 4), + ], + }, + { + "metric_fn": metrics.ndcg_metric, + "expected_value": [ + (1 / log2(1 + 2) + 1 / log2(1 + 3)) + / (1 / log2(1 + 1) + 1 / log2(1 + 2)), + (1 / log2(1 + 3) + (2**2 - 1) / log2(1 + 4)) + / ((2**2 - 1) / log2(1 + 1) + 1 / log2(1 + 2)), + ], + }, + ]) + def test_computes_metric_value_on_batch_with_vmap( + self, metric_fn, expected_value + ): + scores = jnp.asarray([[0.0, 3.0, 1.0, 2.0], [1.0, 4.0, 3.0, 2.0]]) + labels = jnp.asarray([[0.0, 0.0, 1.0, 1.0], [2.0, 0.0, 0.0, 1.0]]) vmap_metric_fn = jax.vmap(metric_fn, in_axes=(0, 0), out_axes=0) metric = vmap_metric_fn(scores, labels) - np.testing.assert_allclose(jnp.asarray(expected_value), metric, rtol=1E-5) - - @parameterized.parameters([{ - "metric_fn": metrics.mrr_metric, - "expected_value": [1. / 2., 1.], - "normalizer": 2. - }, { - "metric_fn": functools.partial(metrics.precision_metric, topn=2), - "expected_value": [1. / 2., 1. / 2.], - "normalizer": 2. - }, { - "metric_fn": functools.partial(metrics.recall_metric, topn=2), - "expected_value": [1. / 1., 1. / 2.], - "normalizer": 2. - }, { - "metric_fn": metrics.ap_metric, - "expected_value": [1. / 2., (1. + 2. / 3.) / 2.], - "normalizer": 2. - }, { - "metric_fn": metrics.dcg_metric, - "expected_value": [1. / log2(1 + 2), 3. / log2(1 + 1) + 1. / log2(1 + 3)], - "normalizer": 2. - }, { - "metric_fn": metrics.ndcg_metric, - "expected_value": [ - (1. / log2(1 + 2)), - (3. / log2(1 + 1) + 1. / log2(1 + 3)) / (3. + 1. / log2(1 + 2)) - ], - "normalizer": 2. 
- }]) + np.testing.assert_allclose(jnp.asarray(expected_value), metric, rtol=1e-5) + + @parameterized.parameters([ + { + "metric_fn": metrics.mrr_metric, + "expected_value": [1.0 / 2.0, 1.0], + "normalizer": 2.0, + }, + { + "metric_fn": functools.partial(metrics.precision_metric, topn=2), + "expected_value": [1.0 / 2.0, 1.0 / 2.0], + "normalizer": 2.0, + }, + { + "metric_fn": functools.partial(metrics.recall_metric, topn=2), + "expected_value": [1.0 / 1.0, 1.0 / 2.0], + "normalizer": 2.0, + }, + { + "metric_fn": metrics.ap_metric, + "expected_value": [1.0 / 2.0, (1.0 + 2.0 / 3.0) / 2.0], + "normalizer": 2.0, + }, + { + "metric_fn": metrics.dcg_metric, + "expected_value": [ + 1.0 / log2(1 + 2), + 3.0 / log2(1 + 1) + 1.0 / log2(1 + 3), + ], + "normalizer": 2.0, + }, + { + "metric_fn": metrics.ndcg_metric, + "expected_value": [ + (1.0 / log2(1 + 2)), + (3.0 / log2(1 + 1) + 1.0 / log2(1 + 3)) + / (3.0 + 1.0 / log2(1 + 2)), + ], + "normalizer": 2.0, + }, + ]) def test_computes_reduced_metric(self, metric_fn, expected_value, normalizer): - scores = jnp.array([[2., 1., 3.], [1., 0.5, 1.5]]) - labels = jnp.array([[1., 0., 0.], [0., 1., 2.]]) + scores = jnp.array([[2.0, 1.0, 3.0], [1.0, 0.5, 1.5]]) + labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 2.0]]) expected_value = jnp.asarray(expected_value) mean_metric = metric_fn(scores, labels, reduce_fn=jnp.mean) sum_metric = metric_fn(scores, labels, reduce_fn=jnp.sum) - np.testing.assert_allclose(mean_metric, - jnp.sum(expected_value) / normalizer) + np.testing.assert_allclose( + mean_metric, jnp.sum(expected_value) / normalizer + ) np.testing.assert_allclose(sum_metric, jnp.sum(expected_value)) @parameterized.parameters([(metrics.mrr_metric, (2,))]) def test_computes_unreduced_metric(self, metric_fn, expected_shape): - scores = jnp.array([[2., 1., 3.], [1., 0.5, 1.5]]) - labels = jnp.array([[1., 0., 0.], [0., 1., 2.]]) + scores = jnp.array([[2.0, 1.0, 3.0], [1.0, 0.5, 1.5]]) + labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 2.0]]) result = metric_fn(scores, labels, reduce_fn=None) self.assertEqual(result.shape, expected_shape) - @parameterized.parameters([{ - "metric_fn": metrics.dcg_metric, - "expected_value": 2 / log2(1 + 2) + 1 / log2(1 + 3) - }, { - "metric_fn": - metrics.ndcg_metric, - "expected_value": (2 / log2(1 + 2) + 1 / log2(1 + 3)) / - (2 / log2(1 + 1) + 1 / log2(1 + 2)) - }]) + @parameterized.parameters([ + { + "metric_fn": metrics.dcg_metric, + "expected_value": 2 / log2(1 + 2) + 1 / log2(1 + 3), + }, + { + "metric_fn": metrics.ndcg_metric, + "expected_value": (2 / log2(1 + 2) + 1 / log2(1 + 3)) / ( + 2 / log2(1 + 1) + 1 / log2(1 + 2) + ), + }, + ]) def test_computes_weighted_metric(self, metric_fn, expected_value): - scores = jnp.asarray([0., 3., 1., 2.]) - labels = jnp.asarray([0., 0., 1., 1.]) - weights = jnp.asarray([1., 2., 1., 2.]) + scores = jnp.asarray([0.0, 3.0, 1.0, 2.0]) + labels = jnp.asarray([0.0, 0.0, 1.0, 1.0]) + weights = jnp.asarray([1.0, 2.0, 1.0, 2.0]) metric = metric_fn(scores, labels, weights=weights) np.testing.assert_allclose(jnp.asarray(expected_value), metric) - @parameterized.parameters([{ - "metric_fn": metrics.mrr_metric, - "expected_value": 0.0 - }, { - "metric_fn": metrics.recall_metric, - "expected_value": 0.0 - }, { - "metric_fn": metrics.precision_metric, - "expected_value": 0.0 - }, { - "metric_fn": metrics.ap_metric, - "expected_value": 0.0 - }, { - "metric_fn": metrics.dcg_metric, - "expected_value": 0.0 - }, { - "metric_fn": metrics.ndcg_metric, - "expected_value": 0.0 - }]) + 
@parameterized.parameters([ + {"metric_fn": metrics.mrr_metric, "expected_value": 0.0}, + {"metric_fn": metrics.recall_metric, "expected_value": 0.0}, + {"metric_fn": metrics.precision_metric, "expected_value": 0.0}, + {"metric_fn": metrics.ap_metric, "expected_value": 0.0}, + {"metric_fn": metrics.dcg_metric, "expected_value": 0.0}, + {"metric_fn": metrics.ndcg_metric, "expected_value": 0.0}, + ]) def test_computes_metric_with_topn_1(self, metric_fn, expected_value): - scores = jnp.asarray([0., 3., 1., 2.]) - labels = jnp.asarray([1., 0., 1., 1.]) + scores = jnp.asarray([0.0, 3.0, 1.0, 2.0]) + labels = jnp.asarray([1.0, 0.0, 1.0, 1.0]) metric = metric_fn(scores, labels, topn=1) np.testing.assert_allclose(jnp.asarray(expected_value), metric) - @parameterized.parameters([{ - "metric_fn": metrics.mrr_metric, - "expected_value": 1 / 2 - }, { - "metric_fn": metrics.recall_metric, - "expected_value": 2 / 3 - }, { - "metric_fn": metrics.precision_metric, - "expected_value": 2 / 3 - }, { - "metric_fn": metrics.ap_metric, - "expected_value": (1 / 2 + 2 / 3) / 3 - }, { - "metric_fn": metrics.dcg_metric, - "expected_value": 1 / log2(1 + 2) + 1 / log2(1 + 3) - }, { - "metric_fn": - metrics.ndcg_metric, - "expected_value": (1 / log2(1 + 2) + 1 / log2(1 + 3)) / - (1 / log2(1 + 1) + 1 / log2(1 + 2) + 1 / log2(1 + 3)) - }]) + @parameterized.parameters([ + {"metric_fn": metrics.mrr_metric, "expected_value": 1 / 2}, + {"metric_fn": metrics.recall_metric, "expected_value": 2 / 3}, + {"metric_fn": metrics.precision_metric, "expected_value": 2 / 3}, + { + "metric_fn": metrics.ap_metric, + "expected_value": (1 / 2 + 2 / 3) / 3, + }, + { + "metric_fn": metrics.dcg_metric, + "expected_value": 1 / log2(1 + 2) + 1 / log2(1 + 3), + }, + { + "metric_fn": metrics.ndcg_metric, + "expected_value": (1 / log2(1 + 2) + 1 / log2(1 + 3)) / ( + 1 / log2(1 + 1) + 1 / log2(1 + 2) + 1 / log2(1 + 3) + ), + }, + ]) def test_computes_metric_with_topn_3(self, metric_fn, expected_value): - scores = jnp.asarray([0., 3., 1., 2.]) - labels = jnp.asarray([1., 0., 1., 1.]) + scores = jnp.asarray([0.0, 3.0, 1.0, 2.0]) + labels = jnp.asarray([1.0, 0.0, 1.0, 1.0]) metric = metric_fn(scores, labels, topn=3) np.testing.assert_allclose(jnp.asarray(expected_value), metric) - @parameterized.parameters([{ - "metric_fn": metrics.mrr_metric, - "expected_value": 0. - }, { - "metric_fn": functools.partial(metrics.recall_metric, topn=None), - "expected_value": 0. - }, { - "metric_fn": functools.partial(metrics.recall_metric, topn=3), - "expected_value": 0. - }, { - "metric_fn": functools.partial(metrics.precision_metric, topn=None), - "expected_value": 0. - }, { - "metric_fn": functools.partial(metrics.precision_metric, topn=3), - "expected_value": 0. - }, { - "metric_fn": metrics.ap_metric, - "expected_value": 0. - }, { - "metric_fn": metrics.dcg_metric, - "expected_value": 0. - }, { - "metric_fn": metrics.ndcg_metric, - "expected_value": 0. 
- }]) - def test_computes_metric_value_with_all_masked(self, metric_fn, - expected_value): - scores = jnp.asarray([0., 3., 1., 2.]) - labels = jnp.asarray([0., 0., 1., 1.]) + @parameterized.parameters([ + {"metric_fn": metrics.mrr_metric, "expected_value": 0.0}, + { + "metric_fn": functools.partial(metrics.recall_metric, topn=None), + "expected_value": 0.0, + }, + { + "metric_fn": functools.partial(metrics.recall_metric, topn=3), + "expected_value": 0.0, + }, + { + "metric_fn": functools.partial(metrics.precision_metric, topn=None), + "expected_value": 0.0, + }, + { + "metric_fn": functools.partial(metrics.precision_metric, topn=3), + "expected_value": 0.0, + }, + {"metric_fn": metrics.ap_metric, "expected_value": 0.0}, + {"metric_fn": metrics.dcg_metric, "expected_value": 0.0}, + {"metric_fn": metrics.ndcg_metric, "expected_value": 0.0}, + ]) + def test_computes_metric_value_with_all_masked( + self, metric_fn, expected_value + ): + scores = jnp.asarray([0.0, 3.0, 1.0, 2.0]) + labels = jnp.asarray([0.0, 0.0, 1.0, 1.0]) where = jnp.asarray([False, False, False, False]) metric = metric_fn(scores, labels, where=where) np.testing.assert_allclose(jnp.asarray(expected_value), metric) - @parameterized.parameters([{ - "metric_fn": metrics.mrr_metric, - "expected_value": 0. - }, { - "metric_fn": functools.partial(metrics.recall_metric, topn=None), - "expected_value": 0. - }, { - "metric_fn": functools.partial(metrics.recall_metric, topn=3), - "expected_value": 0. - }, { - "metric_fn": functools.partial(metrics.precision_metric, topn=None), - "expected_value": 0. - }, { - "metric_fn": functools.partial(metrics.precision_metric, topn=3), - "expected_value": 0. - }, { - "metric_fn": metrics.ap_metric, - "expected_value": 0. - }, { - "metric_fn": metrics.dcg_metric, - "expected_value": 0. - }, { - "metric_fn": metrics.ndcg_metric, - "expected_value": 0. - }]) + @parameterized.parameters([ + {"metric_fn": metrics.mrr_metric, "expected_value": 0.0}, + { + "metric_fn": functools.partial(metrics.recall_metric, topn=None), + "expected_value": 0.0, + }, + { + "metric_fn": functools.partial(metrics.recall_metric, topn=3), + "expected_value": 0.0, + }, + { + "metric_fn": functools.partial(metrics.precision_metric, topn=None), + "expected_value": 0.0, + }, + { + "metric_fn": functools.partial(metrics.precision_metric, topn=3), + "expected_value": 0.0, + }, + {"metric_fn": metrics.ap_metric, "expected_value": 0.0}, + {"metric_fn": metrics.dcg_metric, "expected_value": 0.0}, + {"metric_fn": metrics.ndcg_metric, "expected_value": 0.0}, + ]) def test_computes_metric_value_with_no_relevant_labels( - self, metric_fn, expected_value): - scores = jnp.asarray([0., 3., 1., 2.]) - labels = jnp.asarray([0., 0., 0., 0.]) + self, metric_fn, expected_value + ): + scores = jnp.asarray([0.0, 3.0, 1.0, 2.0]) + labels = jnp.asarray([0.0, 0.0, 0.0, 0.0]) metric = metric_fn(scores, labels) np.testing.assert_allclose(jnp.asarray(expected_value), metric) - @parameterized.parameters([{ - "metric_fn": metrics.mrr_metric, - "expected_value": 1. / 2. - }, { - "metric_fn": metrics.precision_metric, - "expected_value": 1. / 3. - }, { - "metric_fn": metrics.recall_metric, - "expected_value": 1. / 2. - }, { - "metric_fn": metrics.ap_metric, - "expected_value": (1. / 2.) / 2. - }, { - "metric_fn": metrics.dcg_metric, - "expected_value": 1. / log2(1 + 2) - }, { - "metric_fn": metrics.ndcg_metric, - "expected_value": 1. / log2(1 + 2) / (1. / log2(1 + 1) + 1. 
/ log2(1 + 2)) - }]) + @parameterized.parameters([ + {"metric_fn": metrics.mrr_metric, "expected_value": 1.0 / 2.0}, + {"metric_fn": metrics.precision_metric, "expected_value": 1.0 / 3.0}, + {"metric_fn": metrics.recall_metric, "expected_value": 1.0 / 2.0}, + {"metric_fn": metrics.ap_metric, "expected_value": (1.0 / 2.0) / 2.0}, + { + "metric_fn": metrics.dcg_metric, + "expected_value": 1.0 / log2(1 + 2), + }, + { + "metric_fn": metrics.ndcg_metric, + "expected_value": ( + 1.0 / log2(1 + 2) / (1.0 / log2(1 + 1) + 1.0 / log2(1 + 2)) + ), + }, + ]) def test_treats_neginf_as_unranked_items(self, metric_fn, expected_value): scores = jnp.array([-jnp.inf, 5, 2, -jnp.inf, 1]) - labels = jnp.array([1., 0, 1, 0, 0]) + labels = jnp.array([1.0, 0, 1, 0, 0]) metric = metric_fn(scores, labels) @@ -457,29 +468,34 @@ def test_computes_metric_with_segments_and_cutoff(self, metric_fn): class RetrievedItemsTest(parameterized.TestCase): def test_does_not_retrieve_items_with_neginf_scores(self): - scores = jnp.array([-2., -jnp.inf, 4., 3.]) - ranks = jnp.array([3., 4., 1., 2.]) + scores = jnp.array([-2.0, -jnp.inf, 4.0, 3.0]) + ranks = jnp.array([3.0, 4.0, 1.0, 2.0]) retrieved_items = metrics._retrieved_items(scores, ranks) - np.testing.assert_array_equal(jnp.array([1., 0, 1, 1]), retrieved_items) + np.testing.assert_array_equal(jnp.array([1.0, 0, 1, 1]), retrieved_items) def test_does_not_retrieve_masked_items(self): - scores = jnp.array([-2., 1., 4., 3.]) - ranks = jnp.array([4., 3., 1., 2.]) + scores = jnp.array([-2.0, 1.0, 4.0, 3.0]) + ranks = jnp.array([4.0, 3.0, 1.0, 2.0]) where = jnp.array([True, False, True, True]) retrieved_items = metrics._retrieved_items(scores, ranks, where=where) - np.testing.assert_array_equal(jnp.array([1., 0, 1, 1]), retrieved_items) + np.testing.assert_array_equal(jnp.array([1.0, 0, 1, 1]), retrieved_items) - @parameterized.parameters([(0, [0., 0, 0, 0]), (1, [0., 0, 1, 0]), - (2, [0., 0, 1, 1]), (3, [0., 1, 1, 1]), - (4, [1., 1, 1, 1]), (10, [1., 1, 1, 1]), - (None, [1., 1, 1, 1])]) + @parameterized.parameters([ + (0, [0.0, 0, 0, 0]), + (1, [0.0, 0, 1, 0]), + (2, [0.0, 0, 1, 1]), + (3, [0.0, 1, 1, 1]), + (4, [1.0, 1, 1, 1]), + (10, [1.0, 1, 1, 1]), + (None, [1.0, 1, 1, 1]), + ]) def test_does_not_retrieve_items_beyond_topn(self, topn, expected): - scores = jnp.array([-2., 1., 4., 3.]) - ranks = jnp.array([4., 3., 1., 2.]) + scores = jnp.array([-2.0, 1.0, 4.0, 3.0]) + ranks = jnp.array([4.0, 3.0, 1.0, 2.0]) retrieved_items = metrics._retrieved_items(scores, ranks, topn=topn) @@ -491,12 +507,9 @@ def load_tests(loader, tests, ignore): tests.addTests( doctest.DocTestSuite( metrics, - globs={ - "functools": functools, - "jax": jax, - "jnp": jnp, - "rax": rax - })) + globs={"functools": functools, "jax": jax, "jnp": jnp, "rax": rax}, + ) + ) return tests diff --git a/rax/_src/segment_utils.py b/rax/_src/segment_utils.py index 75716fe..9083f15 100644 --- a/rax/_src/segment_utils.py +++ b/rax/_src/segment_utils.py @@ -55,14 +55,12 @@ def segment_max( jnp.broadcast_to(jnp.expand_dims(a, -2), mask.shape), axis=-1, where=mask, - initial=initial + initial=initial, ) def segment_log_softmax( - a: Array, - segments: Array, - where: Optional[Array] = None + a: Array, segments: Array, where: Optional[Array] = None ) -> Array: """Returns segment log-softmax.""" a_max = segment_max(a, segments, where=where, initial=jnp.min(a)) @@ -74,9 +72,7 @@ def segment_log_softmax( def segment_softmax( - a: Array, - segments: Array, - where: Optional[Array] = None + a: Array, segments: Array, where: 
Optional[Array] = None ) -> Array: """Returns segment softmax.""" a_max = segment_max(a, segments, where=where, initial=jnp.min(a)) diff --git a/rax/_src/t12n.py b/rax/_src/t12n.py index 235e64d..8b19652 100644 --- a/rax/_src/t12n.py +++ b/rax/_src/t12n.py @@ -26,7 +26,6 @@ >>> approx_ndcg_loss_fn = rax.approx_t12n(rax.ndcg_metric) >>> print(approx_ndcg_loss_fn(scores, labels)) -0.71789175 - """ import functools @@ -92,10 +91,12 @@ def approx_t12n(metric_fn: MetricFn, temperature: float = 1.0) -> LossFn: approx_kwargs = {} if "rank_fn" in parameters: approx_kwargs["rank_fn"] = functools.partial( - utils.approx_ranks, step_fn=step_fn) + utils.approx_ranks, step_fn=step_fn + ) if "cutoff_fn" in parameters: approx_kwargs["cutoff_fn"] = functools.partial( - utils.approx_cutoff, step_fn=step_fn) + utils.approx_cutoff, step_fn=step_fn + ) @jax.util.wraps(metric_fn, namestr="approx_{fun}", docstr="Approx {doc}") def approx_metric_loss(scores, labels, **kwargs): @@ -137,18 +138,20 @@ def bound_t12n(metric_fn: MetricFn): A loss function that computes the lower-bound version of ``metric_fn``. """ # Define lower and upper bound step_fn. - upper_bound_step_fn = lambda x: jax.nn.relu(x + 1.) - lower_bound_step_fn = lambda x: 1. - jax.nn.relu(1. - x) + upper_bound_step_fn = lambda x: jax.nn.relu(x + 1.0) + lower_bound_step_fn = lambda x: 1.0 - jax.nn.relu(1.0 - x) # Construct kwargs for rank and cutoff functions. parameters = inspect.signature(metric_fn).parameters approx_kwargs = {} if "rank_fn" in parameters: approx_kwargs["rank_fn"] = functools.partial( - utils.approx_ranks, step_fn=upper_bound_step_fn) + utils.approx_ranks, step_fn=upper_bound_step_fn + ) if "cutoff_fn" in parameters: approx_kwargs["cutoff_fn"] = functools.partial( - utils.approx_cutoff, step_fn=lower_bound_step_fn) + utils.approx_cutoff, step_fn=lower_bound_step_fn + ) @jax.util.wraps(metric_fn, namestr="bounded_{fun}", docstr="Bounded {doc}") def bounded_metric_loss(scores, labels, **kwargs): @@ -160,11 +163,13 @@ def bounded_metric_loss(scores, labels, **kwargs): return bounded_metric_loss -def gumbel_t12n(loss_or_metric_fn: LossOrMetricFn, - *, - samples: int = 8, - beta: float = 1.0, - smoothing_factor: Optional[float] = None) -> LossOrMetricFn: +def gumbel_t12n( + loss_or_metric_fn: LossOrMetricFn, + *, + samples: int = 8, + beta: float = 1.0, + smoothing_factor: Optional[float] = None +) -> LossOrMetricFn: """Transforms ``loss_or_metric_fn`` to operate on Gumbel-sampled scores. This transformation changes given ``loss_or_metric_fn`` so that it samples @@ -205,8 +210,9 @@ def expand_and_repeat_dim(a: Array, axis: int = 0): loss_or_metric_fn, namestr="gumbel_{fun}", docstr="Gumbel {doc}" ) @utils.update_signature(loss_or_metric_fn, "key") - def _loss_or_metric_fn_with_gumbel_scores(scores: Array, labels: Array, *, - key: Array, **kwargs): + def _loss_or_metric_fn_with_gumbel_scores( + scores: Array, labels: Array, *, key: Array, **kwargs + ): # Repeat scores and labels `n` times by adding a new batch dim. 
scores = expand_and_repeat_dim(scores) labels = expand_and_repeat_dim(labels) diff --git a/rax/_src/t12n_test.py b/rax/_src/t12n_test.py index de213bc..f9e5bff 100644 --- a/rax/_src/t12n_test.py +++ b/rax/_src/t12n_test.py @@ -42,16 +42,18 @@ class ApproxT12nTest(parameterized.TestCase): metrics.ndcg_metric, ]) def test_approx_t12n_metric_has_nonzero_nonnan_loss(self, metric_fn): - scores = jnp.asarray([-2., 1., 3., 9.]) - labels = jnp.asarray([1., 0., 1., 0.]) + scores = jnp.asarray([-2.0, 1.0, 3.0, 9.0]) + labels = jnp.asarray([1.0, 0.0, 1.0, 0.0]) loss_fn = t12n.approx_t12n(metric_fn) loss = loss_fn(scores, labels) np.testing.assert_array_equal( - jnp.isnan(loss), jnp.zeros_like(jnp.isnan(loss))) - np.testing.assert_array_equal(loss != 0., - jnp.ones_like(loss, dtype=jnp.bool_)) + jnp.isnan(loss), jnp.zeros_like(jnp.isnan(loss)) + ) + np.testing.assert_array_equal( + loss != 0.0, jnp.ones_like(loss, dtype=jnp.bool_) + ) @parameterized.parameters([ metrics.mrr_metric, @@ -62,16 +64,18 @@ def test_approx_t12n_metric_has_nonzero_nonnan_loss(self, metric_fn): metrics.ndcg_metric, ]) def test_approx_t12n_metric_has_nonzero_nonnan_grads(self, metric_fn): - scores = jnp.asarray([-2., 1., 3., 9.]) - labels = jnp.asarray([1., 0., 1., 0.]) + scores = jnp.asarray([-2.0, 1.0, 3.0, 9.0]) + labels = jnp.asarray([1.0, 0.0, 1.0, 0.0]) loss_fn = t12n.approx_t12n(metric_fn) grads = jax.grad(loss_fn)(scores, labels) np.testing.assert_array_equal( - jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads))) - np.testing.assert_array_equal(grads != 0., - jnp.ones_like(grads, dtype=jnp.bool_)) + jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads)) + ) + np.testing.assert_array_equal( + grads != 0.0, jnp.ones_like(grads, dtype=jnp.bool_) + ) @parameterized.parameters([ metrics.mrr_metric, @@ -82,15 +86,16 @@ def test_approx_t12n_metric_has_nonzero_nonnan_grads(self, metric_fn): metrics.ndcg_metric, ]) def test_approx_t12n_metric_has_nonnan_grads_with_all_where(self, metric_fn): - scores = jnp.asarray([-2., 1., 3., 9.]) - labels = jnp.asarray([1., 0., 1., 0.]) + scores = jnp.asarray([-2.0, 1.0, 3.0, 9.0]) + labels = jnp.asarray([1.0, 0.0, 1.0, 0.0]) where = jnp.asarray([False, False, False, False]) loss_fn = t12n.approx_t12n(metric_fn) grads = jax.grad(loss_fn)(scores, labels, where=where) np.testing.assert_array_equal( - jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads))) + jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads)) + ) @parameterized.parameters([ metrics.mrr_metric, @@ -101,21 +106,23 @@ def test_approx_t12n_metric_has_nonnan_grads_with_all_where(self, metric_fn): metrics.ndcg_metric, ]) def test_approx_t12n_metric_has_nonnan_grads_with_zero_labels( - self, metric_fn): - scores = jnp.asarray([-2., 1., 3., 9.]) - labels = jnp.asarray([0., 0., 0., 0.]) + self, metric_fn + ): + scores = jnp.asarray([-2.0, 1.0, 3.0, 9.0]) + labels = jnp.asarray([0.0, 0.0, 0.0, 0.0]) loss_fn = t12n.approx_t12n(metric_fn) grads = jax.grad(loss_fn)(scores, labels) np.testing.assert_array_equal( - jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads))) + jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads)) + ) class BoundT12nTest(parameterized.TestCase): def test_computes_upper_bound_on_ranks(self): - scores = jnp.array([2., -1.5, 0.9]) + scores = jnp.array([2.0, -1.5, 0.9]) labels = jnp.ones_like(scores) def fn(scores, labels, *, rank_fn): @@ -125,11 +132,13 @@ def fn(scores, labels, *, rank_fn): bound_fn = t12n.bound_t12n(fn) ranks = bound_fn(scores, labels) - expected = jnp.array([(1. + 0. + 0.), (1. + 4.5 + 3.4), (1. 
+ 2.1 + 0.)]) + expected = jnp.array( + [(1.0 + 0.0 + 0.0), (1.0 + 4.5 + 3.4), (1.0 + 2.1 + 0.0)] + ) np.testing.assert_allclose(ranks, expected) def test_computes_lower_bound_on_cutoffs(self): - scores = jnp.array([2., -1.5, 0.9]) + scores = jnp.array([2.0, -1.5, 0.9]) labels = jnp.ones_like(scores) def fn(scores, labels, *, cutoff_fn): @@ -139,7 +148,7 @@ def fn(scores, labels, *, cutoff_fn): bound_fn = t12n.bound_t12n(fn) ranks = bound_fn(scores, labels) - expected = jnp.array([1., -1.5 - (-1.5 + 0.9) / 2., 1.]) + expected = jnp.array([1.0, -1.5 - (-1.5 + 0.9) / 2.0, 1.0]) np.testing.assert_allclose(ranks, expected) @parameterized.parameters([ @@ -151,16 +160,18 @@ def fn(scores, labels, *, cutoff_fn): metrics.ndcg_metric, ]) def test_bound_t12n_metric_has_nonzero_nonnan_loss(self, metric_fn): - scores = jnp.asarray([-2., 1., 3., 9.]) - labels = jnp.asarray([1., 0., 1., 0.]) + scores = jnp.asarray([-2.0, 1.0, 3.0, 9.0]) + labels = jnp.asarray([1.0, 0.0, 1.0, 0.0]) loss_fn = t12n.bound_t12n(metric_fn) loss = loss_fn(scores, labels) np.testing.assert_array_equal( - jnp.isnan(loss), jnp.zeros_like(jnp.isnan(loss))) - np.testing.assert_array_equal(loss != 0., - jnp.ones_like(loss, dtype=jnp.bool_)) + jnp.isnan(loss), jnp.zeros_like(jnp.isnan(loss)) + ) + np.testing.assert_array_equal( + loss != 0.0, jnp.ones_like(loss, dtype=jnp.bool_) + ) @parameterized.parameters([ metrics.mrr_metric, @@ -171,15 +182,16 @@ def test_bound_t12n_metric_has_nonzero_nonnan_loss(self, metric_fn): metrics.ndcg_metric, ]) def test_bound_t12n_metric_has_nonzero_nonnan_grads(self, metric_fn): - scores = jnp.asarray([-2., 1., 3., 9.]) - labels = jnp.asarray([1., 0., 1., 0.]) + scores = jnp.asarray([-2.0, 1.0, 3.0, 9.0]) + labels = jnp.asarray([1.0, 0.0, 1.0, 0.0]) loss_fn = t12n.bound_t12n(metric_fn) grads = jax.grad(loss_fn)(scores, labels) np.testing.assert_array_equal( - jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads))) - self.assertGreater(jnp.sum(jnp.abs(grads)), 0.) 
+ jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads)) + ) + self.assertGreater(jnp.sum(jnp.abs(grads)), 0.0) @parameterized.parameters([ metrics.mrr_metric, @@ -190,15 +202,16 @@ def test_bound_t12n_metric_has_nonzero_nonnan_grads(self, metric_fn): metrics.ndcg_metric, ]) def test_bound_t12n_metric_has_nonnan_grads_with_all_where(self, metric_fn): - scores = jnp.asarray([-2., 1., 3., 9.]) - labels = jnp.asarray([1., 0., 1., 0.]) + scores = jnp.asarray([-2.0, 1.0, 3.0, 9.0]) + labels = jnp.asarray([1.0, 0.0, 1.0, 0.0]) where = jnp.asarray([False, False, False, False]) loss_fn = t12n.bound_t12n(metric_fn) grads = jax.grad(loss_fn)(scores, labels, where=where) np.testing.assert_array_equal( - jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads))) + jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads)) + ) @parameterized.parameters([ metrics.mrr_metric, @@ -209,32 +222,34 @@ def test_bound_t12n_metric_has_nonnan_grads_with_all_where(self, metric_fn): metrics.ndcg_metric, ]) def test_bound_t12n_metric_has_nonnan_grads_with_zero_labels(self, metric_fn): - scores = jnp.asarray([-2., 1., 3., 9.]) - labels = jnp.asarray([0., 0., 0., 0.]) + scores = jnp.asarray([-2.0, 1.0, 3.0, 9.0]) + labels = jnp.asarray([0.0, 0.0, 0.0, 0.0]) loss_fn = t12n.bound_t12n(metric_fn) grads = jax.grad(loss_fn)(scores, labels) np.testing.assert_array_equal( - jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads))) + jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads)) + ) class GumbelT12nTest(parameterized.TestCase): def test_samples_scores_using_key(self): - scores = jnp.asarray([0., 1., 2.]) - labels = jnp.asarray([0., 1., 0.]) + scores = jnp.asarray([0.0, 1.0, 2.0]) + labels = jnp.asarray([0.0, 1.0, 0.0]) mock_loss_fn = lambda scores, labels: scores new_loss_fn = t12n.gumbel_t12n(mock_loss_fn, samples=1) loss = new_loss_fn(scores, labels, key=jax.random.PRNGKey(42)) np.testing.assert_allclose( - loss, jnp.asarray([[0.589013, 0.166654, 0.962401]]), rtol=1E-5) + loss, jnp.asarray([[0.589013, 0.166654, 0.962401]]), rtol=1e-5 + ) def test_repeats_inputs_n_times(self): - scores = jnp.asarray([0., 1., 2.]) - labels = jnp.asarray([0., 1., 0.]) + scores = jnp.asarray([0.0, 1.0, 2.0]) + labels = jnp.asarray([0.0, 1.0, 0.0]) where = jnp.asarray([True, True, False]) n = 32 mock_loss_fn = lambda scores, labels, where: (scores, labels, where) @@ -242,14 +257,15 @@ def test_repeats_inputs_n_times(self): new_loss_fn = t12n.gumbel_t12n(mock_loss_fn, samples=n) new_scores, new_labels, new_where = new_loss_fn( - scores, labels, where=where, key=jax.random.PRNGKey(42)) + scores, labels, where=where, key=jax.random.PRNGKey(42) + ) self.assertEqual(new_scores.shape, (n, 3)) self.assertEqual(new_labels.shape, (n, 3)) self.assertEqual(new_where.shape, (n, 3)) def test_samples_scores_using_gumbel_beta_shape(self): - scores = jnp.asarray([0., 1., 2.]) - labels = jnp.asarray([0., 1., 0.]) + scores = jnp.asarray([0.0, 1.0, 2.0]) + labels = jnp.asarray([0.0, 1.0, 0.0]) mock_loss_fn = lambda scores, labels: scores new_loss_fn = t12n.gumbel_t12n(mock_loss_fn, samples=1, beta=0.00001) @@ -258,19 +274,20 @@ def test_samples_scores_using_gumbel_beta_shape(self): np.testing.assert_allclose(loss, jnp.expand_dims(scores, 0), atol=1e-3) def test_handles_extreme_scores(self): - scores = jnp.asarray([-3e18, 1., 2e22]) - labels = jnp.asarray([0., 1., 0.]) + scores = jnp.asarray([-3e18, 1.0, 2e22]) + labels = jnp.asarray([0.0, 1.0, 0.0]) mock_loss_fn = lambda scores, labels: scores new_loss_fn = t12n.gumbel_t12n(mock_loss_fn, samples=1) loss = new_loss_fn(scores, labels, 
key=jax.random.PRNGKey(42)) np.testing.assert_allclose( - loss, jnp.asarray([[-3e18, 1.666543e-01, 2e22]]), rtol=1E-5) + loss, jnp.asarray([[-3e18, 1.666543e-01, 2e22]]), rtol=1e-5 + ) def test_raises_an_error_if_no_key_is_provided(self): - scores = jnp.asarray([-3e18, 1., 2e22]) - labels = jnp.asarray([0., 1., 0.]) + scores = jnp.asarray([-3e18, 1.0, 2e22]) + labels = jnp.asarray([0.0, 1.0, 0.0]) mock_loss_fn = lambda scores, labels: scores new_loss_fn = t12n.gumbel_t12n(mock_loss_fn) @@ -279,8 +296,8 @@ def test_raises_an_error_if_no_key_is_provided(self): new_loss_fn(scores, labels) def test_applies_log_softmax_transformation(self): - scores = jnp.asarray([3., -2., 5.5, 1.]) - labels = jnp.asarray([0., 1., 2., 0.]) + scores = jnp.asarray([3.0, -2.0, 5.5, 1.0]) + labels = jnp.asarray([0.0, 1.0, 2.0, 0.0]) mock_loss_fn = lambda scores, labels: scores gumbel_loss_fn = t12n.gumbel_t12n(mock_loss_fn) @@ -288,10 +305,12 @@ def test_applies_log_softmax_transformation(self): output_scores = gumbel_loss_fn(scores, labels, key=jax.random.PRNGKey(42)) logsoftmax_scores = logsoftmax_loss_fn( - scores, labels, key=jax.random.PRNGKey(42)) + scores, labels, key=jax.random.PRNGKey(42) + ) np.testing.assert_allclose( - jnp.log(jax.nn.softmax(output_scores) + 1e-20), logsoftmax_scores) + jnp.log(jax.nn.softmax(output_scores) + 1e-20), logsoftmax_scores + ) def test_smoothing_factor_should_handle_extreme_values(self): scores = jnp.asarray([-1e34, 1e34]) @@ -304,11 +323,13 @@ def mock_loss_fn(scores, labels, where=None): gumbel_loss_fn = rax.gumbel_t12n(mock_loss_fn, smoothing_factor=1e-20) grads = jax.grad(gumbel_loss_fn)( - scores, labels, where=where, key=jax.random.PRNGKey(42)) + scores, labels, where=where, key=jax.random.PRNGKey(42) + ) # Grads should not be NaN. np.testing.assert_array_equal( - jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads))) + jnp.isnan(grads), jnp.zeros_like(jnp.isnan(grads)) + ) def test_returns_function_with_key_in_signature(self): def loss_fn(scores, labels, *, where=None): @@ -511,7 +532,6 @@ def test_computes_fn_with_random_key(self, fn): ) def test_uses_segmented_implementation_when_available(self): - # Construct a mocked loss fn that accepts `segments` as a kwarg. def loss_fn_that_supports_segments(scores, labels, *, segments=None): del scores, labels, segments # Unused by mocked function. @@ -549,11 +569,8 @@ def loss_fn(scores, labels, *, where=None): def load_tests(loader, tests, ignore): del loader, ignore # Unused. tests.addTests( - doctest.DocTestSuite(t12n, globs={ - "jax": jax, - "jnp": jnp, - "rax": rax - })) + doctest.DocTestSuite(t12n, globs={"jax": jax, "jnp": jnp, "rax": rax}) + ) return tests diff --git a/rax/_src/types.py b/rax/_src/types.py index 7825edc..957251c 100644 --- a/rax/_src/types.py +++ b/rax/_src/types.py @@ -85,8 +85,12 @@ def __call__( class ReduceFn(Protocol): """:class:`typing.Protocol` for reduce functions.""" - def __call__(self, a: Array, where: Optional[Array], - axis: Optional[Union[int, Tuple[int, ...]]]) -> Array: + def __call__( + self, + a: Array, + where: Optional[Array], + axis: Optional[Union[int, Tuple[int, ...]]], + ) -> Array: """Reduces an array across one or more dimensions. 
Args: @@ -106,8 +110,9 @@ def __call__(self, a: Array, where: Optional[Array], class LossFn(Protocol): """:class:`typing.Protocol` for loss functions.""" - def __call__(self, scores: Array, labels: Array, *, where: Optional[Array], - **kwargs) -> Array: + def __call__( + self, scores: Array, labels: Array, *, where: Optional[Array], **kwargs + ) -> Array: """Computes a loss. Args: @@ -127,8 +132,9 @@ def __call__(self, scores: Array, labels: Array, *, where: Optional[Array], class MetricFn(Protocol): """:class:`typing.Protocol` for metric functions.""" - def __call__(self, scores: Array, labels: Array, *, where: Optional[Array], - **kwargs) -> Array: + def __call__( + self, scores: Array, labels: Array, *, where: Optional[Array], **kwargs + ) -> Array: """Computes a metric. Args: @@ -148,8 +154,15 @@ def __call__(self, scores: Array, labels: Array, *, where: Optional[Array], class LambdaweightFn(Protocol): """:class:`typing.Protocol` for lambdaweight functions.""" - def __call__(self, scores: Array, labels: Array, *, where: Optional[Array], - weights: Optional[Array], **kwargs) -> Array: + def __call__( + self, + scores: Array, + labels: Array, + *, + where: Optional[Array], + weights: Optional[Array], + **kwargs + ) -> Array: """Computes lambdaweights. Args: diff --git a/rax/_src/utils.py b/rax/_src/utils.py index 78d9160..8e32982 100644 --- a/rax/_src/utils.py +++ b/rax/_src/utils.py @@ -28,9 +28,11 @@ T = TypeVar("T") -def safe_reduce(a: Array, - where: Optional[Array] = None, - reduce_fn: Optional[Callable[..., Array]] = None) -> Array: +def safe_reduce( + a: Array, + where: Optional[Array] = None, + reduce_fn: Optional[Callable[..., Array]] = None, +) -> Array: """Reduces the values of given array while preventing NaN in the output. For :func:`jax.numpy.mean` reduction, this additionally prevents ``NaN`` in @@ -64,7 +66,7 @@ def safe_reduce(a: Array, # valid pairs. Instead, we prefer that the loss returns 0 in these cases. # Note that this only hides those NaN values if the input did not contain # any NaN values. Otherwise it just returns the output as-is. - output = jnp.where(jnp.isnan(output) & is_input_valid, 0., output) + output = jnp.where(jnp.isnan(output) & is_input_valid, 0.0, output) if reduce_fn is None and where is not None: # When there is no reduce_fn (i.e. we are returning an unreduced @@ -72,7 +74,7 @@ def safe_reduce(a: Array, # This makes sure that manual sum reduction on an unreduced loss works as # expected: # `jnp.sum(loss_fn(reduce_fn=None)) == loss_fn(reduce_fn=jnp.sum)` - output = jnp.where(where, output, 0.) + output = jnp.where(where, output, 0.0) return output @@ -152,11 +154,13 @@ def normalize_probabilities( return output -def logcumsumexp(x: Array, - *, - axis: int = -1, - where: Optional[Array] = None, - reverse: bool = False): +def logcumsumexp( + x: Array, + *, + axis: int = -1, + where: Optional[Array] = None, + reverse: bool = False, +): """Computes the cumulative logsumexp. This is a numerically safe and efficient implementation of a cumulative @@ -165,8 +169,8 @@ def logcumsumexp(x: Array, Args: x: The :class:`~jax.Array` to compute the cumulative logsumexp for. axis: The axis over which the cumulative sum should take place. - where: An optional :class:`~jax.Array` of the same shape as ``x`` - indicating which items are valid for computing the cumulative logsumexp. + where: An optional :class:`~jax.Array` of the same shape as ``x`` indicating + which items are valid for computing the cumulative logsumexp. 
reverse: Whether to compute the cumulative sum in reverse. Returns: @@ -194,14 +198,14 @@ def logcumsumexp(x: Array, # Compute `exp(x_i - m_i)` for each i. x_shifted = jnp.exp(x - m) - x_shifted = jnp.where(where, x_shifted, 0.) + x_shifted = jnp.where(where, x_shifted, 0.0) # Compute `exp(m_{i-1} - m_i)` for each i. This is used to perform an # efficient version of the internal cumulative sumation (see below). # Note that `m_{i-1} <= m_i` for all i because m_i is a cumulative maximum, so # this is numerically safe. - m_diffs = jnp.exp(jnp.minimum(0., jnp.roll(m, 1, axis=0) - m)) - m_diffs = jnp.where(where, m_diffs, 1.) + m_diffs = jnp.exp(jnp.minimum(0.0, jnp.roll(m, 1, axis=0) - m)) + m_diffs = jnp.where(where, m_diffs, 1.0) # We wish to compute the following output values (for each i): # @@ -304,7 +308,7 @@ def ranks( where: Optional[Array] = None, segments: Optional[Array] = None, axis: int = -1, - key: Optional[Array] = None + key: Optional[Array] = None, ) -> Array: """Computes the ranks for given scores. @@ -368,7 +372,7 @@ def approx_ranks( where: Optional[Array] = None, segments: Optional[Array] = None, key: Optional[Array] = None, - step_fn: Callable[[Array], Array] = jax.nn.sigmoid + step_fn: Callable[[Array], Array] = jax.nn.sigmoid, ) -> Array: """Computes approximate ranks. @@ -457,7 +461,7 @@ def approx_cutoff( *, where: Optional[Array] = None, segments: Optional[Array] = None, - step_fn: Callable[[Array], Array] = jax.nn.sigmoid + step_fn: Callable[[Array], Array] = jax.nn.sigmoid, ) -> Array: """Approximately select the largest ``n`` values of ``a``. diff --git a/rax/_src/utils_test.py b/rax/_src/utils_test.py index c683b45..e86da4a 100644 --- a/rax/_src/utils_test.py +++ b/rax/_src/utils_test.py @@ -29,34 +29,43 @@ class NormalizeProbabilitiesTest(absltest.TestCase): def test_sums_to_one_for_given_axis(self): - arr = jnp.asarray([[0., 1., 2.], [3., 4., 5.]]) + arr = jnp.asarray([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]) result1 = utils.normalize_probabilities(arr, axis=0) result2 = utils.normalize_probabilities(arr, axis=1) np.testing.assert_array_equal( - result1, jnp.asarray([[0., 1. / 5., 2. / 7.], [1., 4. / 5., 5. / 7.]])) + result1, + jnp.asarray([[0.0, 1.0 / 5.0, 2.0 / 7.0], [1.0, 4.0 / 5.0, 5.0 / 7.0]]), + ) np.testing.assert_array_equal( result2, - jnp.asarray([[0., 1. / 3., 2. / 3.], [3. / 12., 4. / 12., 5. / 12.]])) + jnp.asarray( + [[0.0, 1.0 / 3.0, 2.0 / 3.0], [3.0 / 12.0, 4.0 / 12.0, 5.0 / 12.0]] + ), + ) def test_sums_to_one_for_default_axis(self): - arr = jnp.asarray([[0., 1., 2.], [3., 4., 5.]]) + arr = jnp.asarray([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]) result = utils.normalize_probabilities(arr) np.testing.assert_array_equal( result, - jnp.asarray([[0., 1. / 3., 2. / 3.], [3. / 12., 4. / 12., 5. / 12.]])) + jnp.asarray( + [[0.0, 1.0 / 3.0, 2.0 / 3.0], [3.0 / 12.0, 4.0 / 12.0, 5.0 / 12.0]] + ), + ) def test_handles_where(self): - arr = jnp.asarray([[0., 1., 2.], [3., 4., 5.]]) + arr = jnp.asarray([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]) where = jnp.asarray([[True, False, True], [True, True, True]]) result = utils.normalize_probabilities(arr, where=where, axis=1) np.testing.assert_array_equal( - jnp.sum(result, axis=1, where=where), jnp.asarray([1., 1.])) + jnp.sum(result, axis=1, where=where), jnp.asarray([1.0, 1.0]) + ) def test_handles_segments(self): arr = jnp.asarray([0.0, 1.0, 2.0, 5.0, 7.0, 9.0]) @@ -87,59 +96,64 @@ def test_handles_where_and_segments_with_any_axis(self): # Assert non-masked values sum to the number of segments in each axis. 
np.testing.assert_array_equal( jnp.sum(result1, where=where, axis=0, keepdims=True), - jnp.array([[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0]]]) + jnp.array([[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0]]]), ) np.testing.assert_array_equal( jnp.sum(result2, where=where, axis=1, keepdims=True), - jnp.array([[[1.0, 1.0, 2.0, 1.0]]]) + jnp.array([[[1.0, 1.0, 2.0, 1.0]]]), ) np.testing.assert_array_equal( jnp.sum(result3, where=where, axis=2, keepdims=True), - jnp.array([[[2.0], [2.0]]]) + jnp.array([[[2.0], [2.0]]]), ) def test_correctly_sets_all_zeros(self): - arr = jnp.asarray([[0., 0., 0.], [0., 0., 0.]]) + arr = jnp.asarray([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) result1 = utils.normalize_probabilities(arr, axis=0) result2 = utils.normalize_probabilities(arr, axis=1) np.testing.assert_array_equal( - jnp.sum(result1, axis=0), jnp.asarray([1., 1., 1.])) + jnp.sum(result1, axis=0), jnp.asarray([1.0, 1.0, 1.0]) + ) np.testing.assert_array_equal( - jnp.sum(result2, axis=1), jnp.asarray([1., 1.])) + jnp.sum(result2, axis=1), jnp.asarray([1.0, 1.0]) + ) def test_correctly_handles_all_masked(self): - arr = jnp.asarray([[2., 1., 3.], [1., 1., 1.]]) + arr = jnp.asarray([[2.0, 1.0, 3.0], [1.0, 1.0, 1.0]]) where = jnp.asarray([[False, False, False], [False, False, False]]) result1 = utils.normalize_probabilities(arr, where=where, axis=0) result2 = utils.normalize_probabilities(arr, where=where, axis=1) np.testing.assert_array_equal( - jnp.sum(result1, axis=0), jnp.asarray([1., 1., 1.])) + jnp.sum(result1, axis=0), jnp.asarray([1.0, 1.0, 1.0]) + ) np.testing.assert_array_equal( - jnp.sum(result2, axis=1), jnp.asarray([1., 1.])) + jnp.sum(result2, axis=1), jnp.asarray([1.0, 1.0]) + ) class LogCumsumExp(absltest.TestCase): def test_computes_logcumsumexp(self): - x = jnp.asarray([-4., 5., 2.3, 0.]) + x = jnp.asarray([-4.0, 5.0, 2.3, 0.0]) result = utils.logcumsumexp(x) np.testing.assert_array_equal( result, jnp.asarray([ - jnp.log(jnp.exp(-4.)), - jnp.log(jnp.exp(-4.) + jnp.exp(5.)), - jnp.log(jnp.exp(-4.) + jnp.exp(5.) + jnp.exp(2.3)), - jnp.log(jnp.exp(-4.) + jnp.exp(5.) 
+ jnp.exp(2.3) + jnp.exp(0.)) - ])) + jnp.log(jnp.exp(-4.0)), + jnp.log(jnp.exp(-4.0) + jnp.exp(5.0)), + jnp.log(jnp.exp(-4.0) + jnp.exp(5.0) + jnp.exp(2.3)), + jnp.log(jnp.exp(-4.0) + jnp.exp(5.0) + jnp.exp(2.3) + jnp.exp(0.0)), + ]), + ) def test_computes_over_specified_axis(self): - x = jnp.asarray([[-4., 2.3, 0.], [2.2, -1.2, 1.1]]) + x = jnp.asarray([[-4.0, 2.3, 0.0], [2.2, -1.2, 1.1]]) result = utils.logcumsumexp(x, axis=-1) np.testing.assert_array_equal(result[0, :], utils.logcumsumexp(x[0, :])) @@ -151,8 +165,8 @@ def test_computes_over_specified_axis(self): np.testing.assert_array_equal(result[:, 2], utils.logcumsumexp(x[:, 2])) def test_computes_reversed(self): - x = jnp.asarray([-4., 5., 2.3, 0.]) - x_flipped = jnp.asarray([0., 2.3, 5., -4.]) + x = jnp.asarray([-4.0, 5.0, 2.3, 0.0]) + x_flipped = jnp.asarray([0.0, 2.3, 5.0, -4.0]) result_reverse = utils.logcumsumexp(x, reverse=True) result_flipped = jnp.flip(utils.logcumsumexp(x_flipped)) @@ -160,9 +174,9 @@ def test_computes_reversed(self): np.testing.assert_array_equal(result_reverse, result_flipped) def test_computes_with_where_mask(self): - x = jnp.asarray([-4., 5., 2.3, 0.]) + x = jnp.asarray([-4.0, 5.0, 2.3, 0.0]) where = jnp.asarray([True, False, True, True]) - x_masked = jnp.asarray([-4., 2.3, 0.]) + x_masked = jnp.asarray([-4.0, 2.3, 0.0]) result_where = utils.logcumsumexp(x, where=where) result_masked = utils.logcumsumexp(x_masked) @@ -172,70 +186,83 @@ def test_computes_with_where_mask(self): np.testing.assert_array_equal(result_where[3], result_masked[2]) def test_handles_extreme_values(self): - x = jnp.asarray([-4., -2.1e26, 5., 3.4e38, 10., -2.99e26]) + x = jnp.asarray([-4.0, -2.1e26, 5.0, 3.4e38, 10.0, -2.99e26]) result = utils.logcumsumexp(x) np.testing.assert_array_equal( - result, jnp.asarray([-4., -4., 5.0001235, 3.4e38, 3.4e38, 3.4e38])) + result, jnp.asarray([-4.0, -4.0, 5.0001235, 3.4e38, 3.4e38, 3.4e38]) + ) class SortByTest(absltest.TestCase): def test_sorts_by_scores(self): - scores = jnp.asarray([0., 3., 1., 2.]) - tensors_to_sort = [jnp.asarray([10., 13., 11., 12.])] + scores = jnp.asarray([0.0, 3.0, 1.0, 2.0]) + tensors_to_sort = [jnp.asarray([10.0, 13.0, 11.0, 12.0])] result = utils.sort_by(scores, tensors_to_sort)[0] - np.testing.assert_array_equal(result, jnp.asarray([13., 12., 11., 10.])) + np.testing.assert_array_equal(result, jnp.asarray([13.0, 12.0, 11.0, 10.0])) def test_sorts_by_given_axis(self): - scores = jnp.asarray([[3., 1., 2.], [1., 5., 3.]]) - tensors_to_sort = [jnp.asarray([[0., 1., 2.], [3., 4., 5.]])] + scores = jnp.asarray([[3.0, 1.0, 2.0], [1.0, 5.0, 3.0]]) + tensors_to_sort = [jnp.asarray([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])] result_0 = utils.sort_by(scores, tensors_to_sort, axis=0)[0] result_1 = utils.sort_by(scores, tensors_to_sort, axis=1)[0] - np.testing.assert_array_equal(result_0, - jnp.asarray([[0., 4., 5.], [3., 1., 2.]])) - np.testing.assert_array_equal(result_1, - jnp.asarray([[0., 2., 1.], [4., 5., 3.]])) + np.testing.assert_array_equal( + result_0, jnp.asarray([[0.0, 4.0, 5.0], [3.0, 1.0, 2.0]]) + ) + np.testing.assert_array_equal( + result_1, jnp.asarray([[0.0, 2.0, 1.0], [4.0, 5.0, 3.0]]) + ) def test_sorts_multiple_tensors(self): - scores = jnp.asarray([0., 3., 1., 2.]) + scores = jnp.asarray([0.0, 3.0, 1.0, 2.0]) tensors_to_sort = [ - jnp.asarray([10., 13., 11., 12.]), - jnp.asarray([50., 56., 52., 54.]), - jnp.asarray([75., 78., 76., 77.]) + jnp.asarray([10.0, 13.0, 11.0, 12.0]), + jnp.asarray([50.0, 56.0, 52.0, 54.0]), + jnp.asarray([75.0, 78.0, 76.0, 77.0]), ] 
result = utils.sort_by(scores, tensors_to_sort) - np.testing.assert_array_equal(result[0], jnp.asarray([13., 12., 11., 10.])) - np.testing.assert_array_equal(result[1], jnp.asarray([56., 54., 52., 50.])) - np.testing.assert_array_equal(result[2], jnp.asarray([78., 77., 76., 75.])) + np.testing.assert_array_equal( + result[0], jnp.asarray([13.0, 12.0, 11.0, 10.0]) + ) + np.testing.assert_array_equal( + result[1], jnp.asarray([56.0, 54.0, 52.0, 50.0]) + ) + np.testing.assert_array_equal( + result[2], jnp.asarray([78.0, 77.0, 76.0, 75.0]) + ) def test_places_masked_values_last(self): - scores = jnp.asarray([0., 3., 1., 2.]) - tensors_to_sort = [jnp.asarray([10., 13., 11., 12.])] + scores = jnp.asarray([0.0, 3.0, 1.0, 2.0]) + tensors_to_sort = [jnp.asarray([10.0, 13.0, 11.0, 12.0])] where = jnp.asarray([True, True, False, False]) result = utils.sort_by(scores, tensors_to_sort, where=where)[0] - np.testing.assert_array_equal(result, jnp.asarray([13., 10., 12., 11.])) + np.testing.assert_array_equal(result, jnp.asarray([13.0, 10.0, 12.0, 11.0])) def test_breaks_ties_randomly_when_key_is_provided(self): - scores = jnp.asarray([0., 1., 1., 2.]) - tensors_to_sort = [jnp.asarray([10., 11.1, 11.2, 12.])] + scores = jnp.asarray([0.0, 1.0, 1.0, 2.0]) + tensors_to_sort = [jnp.asarray([10.0, 11.1, 11.2, 12.0])] key = jax.random.PRNGKey(4242) key1, key2 = jax.random.split(key) result1 = utils.sort_by(scores, tensors_to_sort, key=key1)[0] result2 = utils.sort_by(scores, tensors_to_sort, key=key2)[0] - np.testing.assert_array_equal(result1, jnp.asarray([12., 11.2, 11.1, 10.])) - np.testing.assert_array_equal(result2, jnp.asarray([12., 11.1, 11.2, 10.])) + np.testing.assert_array_equal( + result1, jnp.asarray([12.0, 11.2, 11.1, 10.0]) + ) + np.testing.assert_array_equal( + result2, jnp.asarray([12.0, 11.1, 11.2, 10.0]) + ) def test_sorts_within_segments(self): scores = jnp.asarray([[0.0, 3.0, 1.0, 2.0, 3.5, -2.0, 5.0]]) @@ -367,21 +394,21 @@ def test_approx_cutoff_n_is_zero(self): class RanksTest(absltest.TestCase): def test_ranks_by_sorting_scores(self): - scores = jnp.asarray([[0., 1., 2.], [2., 1., 3.]]) + scores = jnp.asarray([[0.0, 1.0, 2.0], [2.0, 1.0, 3.0]]) ranks = utils.ranks(scores) np.testing.assert_array_equal(ranks, jnp.asarray([[3, 2, 1], [2, 3, 1]])) def test_ranks_along_given_axis(self): - scores = jnp.asarray([[0., 1., 2.], [1., 2., 0.]]) + scores = jnp.asarray([[0.0, 1.0, 2.0], [1.0, 2.0, 0.0]]) ranks = utils.ranks(scores, axis=0) np.testing.assert_array_equal(ranks, jnp.asarray([[2, 2, 1], [1, 1, 2]])) def test_ranks_with_ties_broken_randomly(self): - scores = jnp.asarray([2., 1., 1.]) + scores = jnp.asarray([2.0, 1.0, 1.0]) key = jax.random.PRNGKey(1) key1, key2 = jax.random.split(key) @@ -410,7 +437,7 @@ def test_ranks_with_segments_and_where(self): class ApproxRanksTest(absltest.TestCase): def test_computes_approx_ranks(self): - scores = jnp.asarray([-3., 1., 2.]) + scores = jnp.asarray([-3.0, 1.0, 2.0]) ranks = utils.approx_ranks(scores) @@ -418,13 +445,14 @@ def test_computes_approx_ranks(self): np.testing.assert_array_equal( ranks, jnp.asarray([ - sigmoid(3. + 1.) + sigmoid(3. + 2.) + 1.0, - sigmoid(-1. - 3.) + sigmoid(-1. + 2.) + 1.0, - sigmoid(-2. - 3.) + sigmoid(-2. + 1.) 
+ 1.0 - ])) + sigmoid(3.0 + 1.0) + sigmoid(3.0 + 2.0) + 1.0, + sigmoid(-1.0 - 3.0) + sigmoid(-1.0 + 2.0) + 1.0, + sigmoid(-2.0 - 3.0) + sigmoid(-2.0 + 1.0) + 1.0, + ]), + ) def test_maintains_order(self): - scores = jnp.asarray([-4., 1., -3., 2.]) + scores = jnp.asarray([-4.0, 1.0, -3.0, 2.0]) ranks = utils.approx_ranks(scores) true_ranks = utils.ranks(scores) @@ -440,7 +468,8 @@ def test_computes_approx_ranks_with_where(self): ranks_with_where = utils.approx_ranks(scores, where=where) np.testing.assert_array_equal( - ranks, jnp.asarray([ranks_with_where[0], ranks_with_where[2]])) + ranks, jnp.asarray([ranks_with_where[0], ranks_with_where[2]]) + ) def test_computes_approx_ranks_with_segments(self): scores_segment_0 = jnp.asarray([3.33, 1.125]) @@ -458,7 +487,7 @@ def test_computes_approx_ranks_with_segments(self): class SafeReduceTest(absltest.TestCase): def test_reduces_values_according_to_fn(self): - a = jnp.array([[3., 2.], [4.5, 1.2]]) + a = jnp.array([[3.0, 2.0], [4.5, 1.2]]) res_mean = utils.safe_reduce(a, reduce_fn=jnp.mean) res_sum = utils.safe_reduce(a, reduce_fn=jnp.sum) @@ -469,7 +498,7 @@ def test_reduces_values_according_to_fn(self): np.testing.assert_allclose(res_none, a) def test_reduces_values_with_mask(self): - a = jnp.array([[3., 2., 0.01], [4.5, 1.2, 0.9]]) + a = jnp.array([[3.0, 2.0, 0.01], [4.5, 1.2, 0.9]]) where = jnp.array([[True, False, True], [True, True, False]]) res_mean = utils.safe_reduce(a, where=where, reduce_fn=jnp.mean) @@ -478,21 +507,21 @@ def test_reduces_values_with_mask(self): np.testing.assert_allclose(res_mean, jnp.mean(a, where=where)) np.testing.assert_allclose(res_sum, jnp.sum(a, where=where)) - np.testing.assert_allclose(res_none, jnp.where(where, a, 0.)) + np.testing.assert_allclose(res_none, jnp.where(where, a, 0.0)) def test_reduces_mean_with_all_masked(self): - a = jnp.array([[3., 2., 0.01], [4.5, 1.2, 0.9]]) + a = jnp.array([[3.0, 2.0, 0.01], [4.5, 1.2, 0.9]]) where = jnp.array([[False, False, False], [False, False, False]]) res_mean = utils.safe_reduce(a, where=where, reduce_fn=jnp.mean) - np.testing.assert_allclose(res_mean, jnp.array(0.)) + np.testing.assert_allclose(res_mean, jnp.array(0.0)) class ComputePairsTest(absltest.TestCase): def test_computes_all_pairs(self): - a = jnp.array([1., 2., 3.]) + a = jnp.array([1.0, 2.0, 3.0]) expected = jnp.array([11.0, 21.0, 31.0, 12.0, 22.0, 32.0, 13.0, 23.0, 33.0]) result = utils.compute_pairs(a, lambda a, b: a + b * 10.0) @@ -508,7 +537,7 @@ def test_computes_all_pairs_on_empty_array(self): np.testing.assert_allclose(result, expected) def test_computes_all_pairs_with_batch_dimension(self): - a = jnp.array([[1., 2.], [3., 4.]]) + a = jnp.array([[1.0, 2.0], [3.0, 4.0]]) expected = jnp.array([[1.0, 2.0, 2.0, 4.0], [9.0, 12.0, 12.0, 16.0]]) result = utils.compute_pairs(a, lambda a, b: a * b) @@ -599,12 +628,15 @@ def load_tests(loader, tests, ignore): del loader, ignore # Unused. 
tests.addTests( doctest.DocTestSuite( - utils, extraglobs={ + utils, + extraglobs={ "jax": jax, "jnp": jnp, "rax": rax, "utils": utils, - })) + }, + ) + ) return tests diff --git a/rax/types.py b/rax/types.py index a1fc82b..75a324b 100644 --- a/rax/types.py +++ b/rax/types.py @@ -21,7 +21,6 @@ from rax._src.types import RankFn from rax._src.types import ReduceFn -# pyformat: disable __all__ = [ "CutoffFn", "LambdaweightFn", @@ -30,5 +29,3 @@ "RankFn", "ReduceFn", ] -# pyformat: enable - diff --git a/rax/utils.py b/rax/utils.py index 7e12794..f537692 100644 --- a/rax/utils.py +++ b/rax/utils.py @@ -19,12 +19,9 @@ from rax._src.utils import cutoff from rax._src.utils import ranks -# pyformat: disable __all__ = [ "approx_cutoff", "approx_ranks", "cutoff", "ranks", ] -# pyformat: enable -
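The losses_test.py hunk above reflows `test_computes_unreduced_loss`, which relies on the invariant (also called out in the `safe_reduce` comments in rax/_src/utils.py) that manually summing an unreduced loss matches the sum-reduced call. A minimal sketch of that invariant, assuming the public `rax.softmax_loss` entry point (illustrative only, not part of this diff):

import jax.numpy as jnp
import rax

# Two ranking lists of three items each, mirroring the reformatted tests.
scores = jnp.array([[2.0, 1.0, 3.0], [1.0, 0.5, 1.5]])
labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

# reduce_fn=None returns one loss value per list (shape (2,) here).
unreduced = rax.softmax_loss(scores, labels, reduce_fn=None)
# reduce_fn=jnp.sum reduces inside the loss via safe_reduce.
summed = rax.softmax_loss(scores, labels, reduce_fn=jnp.sum)

# jnp.sum(loss_fn(reduce_fn=None)) == loss_fn(reduce_fn=jnp.sum)
assert jnp.allclose(jnp.sum(unreduced), summed)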