From 674f01687f538a2832a8cc8940d096e1b61d0218 Mon Sep 17 00:00:00 2001
From: carlosgmartin
Date: Wed, 12 Jun 2024 22:13:01 -0400
Subject: [PATCH] Add more mask tests.

---
 optax/losses/_classification_test.py | 47 ++++++++++++++++++++++------
 optax/losses/_regression_test.py     | 18 +++++++++++
 2 files changed, 55 insertions(+), 10 deletions(-)

diff --git a/optax/losses/_classification_test.py b/optax/losses/_classification_test.py
index 916d25e4..ee9b5421 100644
--- a/optax/losses/_classification_test.py
+++ b/optax/losses/_classification_test.py
@@ -84,12 +84,12 @@ def test_gradient(self):
 
   @parameterized.parameters(dict(size=5), dict(size=10))
   def test_mask(self, size):
-    logits = np.random.normal(size=size)
-    labels = np.random.dirichlet(np.ones(size))
+    preds = np.random.normal(size=size)
+    targets = np.random.dirichlet(np.ones(size))
     mask = np.random.randint(2, size=size, dtype=bool)
-    outs_1 = _classification.softmax_cross_entropy(logits[mask], labels[mask])
-    outs_2 = _classification.softmax_cross_entropy(logits, labels, where=mask)
-    np.testing.assert_allclose(outs_1, outs_2)
+    x = _classification.softmax_cross_entropy(preds[mask], targets[mask])
+    y = _classification.softmax_cross_entropy(preds, targets, where=mask)
+    np.testing.assert_allclose(x, y)
 
 
 class SafeSoftmaxCrossEntropyTest(parameterized.TestCase):
@@ -364,12 +364,12 @@ def test_equals_to_cross_entropy_when_eps0(self, logits, labels):
 
   @parameterized.parameters(dict(size=5), dict(size=10))
   def test_mask(self, size):
-    logits = np.random.normal(size=size)
-    labels = np.random.dirichlet(np.ones(size))
+    preds = np.random.normal(size=size)
+    targets = np.random.dirichlet(np.ones(size))
     mask = np.random.randint(2, size=size, dtype=bool)
-    outs_1 = _classification.poly_loss_cross_entropy(logits[mask], labels[mask])
-    outs_2 = _classification.poly_loss_cross_entropy(logits, labels, where=mask)
-    np.testing.assert_allclose(outs_1, outs_2)
+    x = _classification.poly_loss_cross_entropy(preds[mask], targets[mask])
+    y = _classification.poly_loss_cross_entropy(preds, targets, where=mask)
+    np.testing.assert_allclose(x, y)
 
 
 class HingeTest(parameterized.TestCase):
@@ -526,6 +526,15 @@ def test_batched(self):
         atol=1e-4,
     )
 
+  @parameterized.parameters(dict(size=5), dict(size=10))
+  def test_mask(self, size):
+    preds = np.random.normal(size=size)
+    targets = np.random.dirichlet(np.ones(size))
+    mask = np.random.randint(2, size=size, dtype=bool)
+    x = _classification.convex_kl_divergence(preds[mask], targets[mask])
+    y = _classification.convex_kl_divergence(preds, targets, where=mask)
+    np.testing.assert_allclose(x, y)
+
 
 class PerceptronTest(parameterized.TestCase):
 
@@ -608,6 +617,15 @@ def test_batched(self):
         atol=1e-4,
     )
 
+  @parameterized.parameters(dict(size=5), dict(size=10))
+  def test_mask(self, size):
+    preds = np.random.normal(size=size)
+    targets = np.random.dirichlet(np.ones(size))
+    mask = np.random.randint(2, size=size, dtype=bool)
+    x = _classification.kl_divergence(preds[mask], targets[mask])
+    y = _classification.kl_divergence(preds, targets, where=mask)
+    np.testing.assert_allclose(x, y)
+
 
 class KLDivergenceWithLogTargetsTest(parameterized.TestCase):
 
@@ -644,6 +662,15 @@ def test_batched(self):
         atol=1e-4,
     )
 
+  @parameterized.parameters(dict(size=5), dict(size=10))
+  def test_mask(self, size):
+    preds = np.random.normal(size=size)
+    targets = np.log(np.random.dirichlet(np.ones(size)))
+    mask = np.random.randint(2, size=size, dtype=bool)
+    x = _classification.kl_divergence_with_log_targets(preds[mask], targets[mask])
+    y = _classification.kl_divergence_with_log_targets(preds, targets, where=mask)
+    np.testing.assert_allclose(x, y)
+
 
 def _lengths_to_paddings(lengths: chex.Array, maxlength: int) -> chex.Array:
   indices = jnp.arange(maxlength).reshape((1,) * lengths.ndim + (maxlength,))
diff --git a/optax/losses/_regression_test.py b/optax/losses/_regression_test.py
index 8369117e..b1c613bf 100644
--- a/optax/losses/_regression_test.py
+++ b/optax/losses/_regression_test.py
@@ -173,6 +173,24 @@ def test_batched_similarity(self):
         self.variant(_regression.cosine_similarity)(self.ys, self.ts),
         1. - self.exp, atol=1e-4)
 
+  @parameterized.parameters(dict(size=5), dict(size=10))
+  def test_mask_distance(self, size):
+    preds = np.random.normal(size=size)
+    targets = np.random.normal(size=size)
+    mask = np.random.randint(2, size=size, dtype=bool)
+    x = _regression.cosine_distance(preds[mask], targets[mask])
+    y = _regression.cosine_distance(preds, targets, where=mask)
+    np.testing.assert_allclose(x, y)
+
+  @parameterized.parameters(dict(size=5), dict(size=10))
+  def test_mask_similarity(self, size):
+    preds = np.random.normal(size=size)
+    targets = np.random.normal(size=size)
+    mask = np.random.randint(2, size=size, dtype=bool)
+    x = _regression.cosine_similarity(preds[mask], targets[mask])
+    y = _regression.cosine_similarity(preds, targets, where=mask)
+    np.testing.assert_allclose(x, y, atol=1e-4)
+
 
 if __name__ == '__main__':
   absltest.main()