Add more mask tests.
carlosgmartin committed Jun 13, 2024
1 parent 53c39e5 commit 674f016
Showing 2 changed files with 55 additions and 10 deletions.
47 changes: 37 additions & 10 deletions optax/losses/_classification_test.py
@@ -84,12 +84,12 @@ def test_gradient(self):

  @parameterized.parameters(dict(size=5), dict(size=10))
  def test_mask(self, size):
-    logits = np.random.normal(size=size)
-    labels = np.random.dirichlet(np.ones(size))
+    preds = np.random.normal(size=size)
+    targets = np.random.dirichlet(np.ones(size))
    mask = np.random.randint(2, size=size, dtype=bool)
-    outs_1 = _classification.softmax_cross_entropy(logits[mask], labels[mask])
-    outs_2 = _classification.softmax_cross_entropy(logits, labels, where=mask)
-    np.testing.assert_allclose(outs_1, outs_2)
+    x = _classification.softmax_cross_entropy(preds[mask], targets[mask])
+    y = _classification.softmax_cross_entropy(preds, targets, where=mask)
+    np.testing.assert_allclose(x, y)


class SafeSoftmaxCrossEntropyTest(parameterized.TestCase):
@@ -364,12 +364,12 @@ def test_equals_to_cross_entropy_when_eps0(self, logits, labels):

  @parameterized.parameters(dict(size=5), dict(size=10))
  def test_mask(self, size):
-    logits = np.random.normal(size=size)
-    labels = np.random.dirichlet(np.ones(size))
+    preds = np.random.normal(size=size)
+    targets = np.random.dirichlet(np.ones(size))
    mask = np.random.randint(2, size=size, dtype=bool)
-    outs_1 = _classification.poly_loss_cross_entropy(logits[mask], labels[mask])
-    outs_2 = _classification.poly_loss_cross_entropy(logits, labels, where=mask)
-    np.testing.assert_allclose(outs_1, outs_2)
+    x = _classification.poly_loss_cross_entropy(preds[mask], targets[mask])
+    y = _classification.poly_loss_cross_entropy(preds, targets, where=mask)
+    np.testing.assert_allclose(x, y)


class HingeTest(parameterized.TestCase):
@@ -526,6 +526,15 @@ def test_batched(self):
        atol=1e-4,
    )

+  @parameterized.parameters(dict(size=5), dict(size=10))
+  def test_mask(self, size):
+    preds = np.random.normal(size=size)
+    targets = np.random.dirichlet(np.ones(size))
+    mask = np.random.randint(2, size=size, dtype=bool)
+    x = _classification.convex_kl_divergence(preds[mask], targets[mask])
+    y = _classification.convex_kl_divergence(preds, targets, where=mask)
+    np.testing.assert_allclose(x, y)
+

class PerceptronTest(parameterized.TestCase):

@@ -608,6 +617,15 @@ def test_batched(self):
        atol=1e-4,
    )

+  @parameterized.parameters(dict(size=5), dict(size=10))
+  def test_mask(self, size):
+    preds = np.random.normal(size=size)
+    targets = np.random.dirichlet(np.ones(size))
+    mask = np.random.randint(2, size=size, dtype=bool)
+    x = _classification.kl_divergence(preds[mask], targets[mask])
+    y = _classification.kl_divergence(preds, targets, where=mask)
+    np.testing.assert_allclose(x, y)
+

class KLDivergenceWithLogTargetsTest(parameterized.TestCase):

@@ -644,6 +662,15 @@ def test_batched(self):
        atol=1e-4,
    )

+  @parameterized.parameters(dict(size=5), dict(size=10))
+  def test_mask(self, size):
+    preds = np.random.normal(size=size)
+    targets = np.log(np.random.dirichlet(np.ones(size)))
+    mask = np.random.randint(2, size=size, dtype=bool)
+    x = _classification.kl_divergence_with_log_targets(preds[mask], targets[mask])
+    y = _classification.kl_divergence_with_log_targets(preds, targets, where=mask)
+    np.testing.assert_allclose(x, y)
+

def _lengths_to_paddings(lengths: chex.Array, maxlength: int) -> chex.Array:
indices = jnp.arange(maxlength).reshape((1,) * lengths.ndim + (maxlength,))
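For reference, every classification test added here checks the same invariant: computing a loss with `where=mask` on the full arrays should match computing it on the masked slices. Below is a minimal standalone sketch of that property, assuming the public `optax.losses` aliases expose the same `where` argument as the private `_classification` module used in the tests:

    import numpy as np
    from optax import losses

    preds = np.random.normal(size=5)           # arbitrary logits
    targets = np.random.dirichlet(np.ones(5))  # a valid probability vector
    mask = np.array([True, False, True, True, False])

    # Loss computed on the masked slices...
    subset = losses.softmax_cross_entropy(preds[mask], targets[mask])
    # ...should equal the full-array call that ignores masked-out entries.
    full = losses.softmax_cross_entropy(preds, targets, where=mask)
    np.testing.assert_allclose(subset, full)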
18 changes: 18 additions & 0 deletions optax/losses/_regression_test.py
@@ -173,6 +173,24 @@ def test_batched_similarity(self):
        self.variant(_regression.cosine_similarity)(self.ys, self.ts),
        1. - self.exp, atol=1e-4)

+  @parameterized.parameters(dict(size=5), dict(size=10))
+  def test_mask_distance(self, size):
+    preds = np.random.normal(size=size)
+    targets = np.random.normal(size=size)
+    mask = np.random.randint(2, size=size, dtype=bool)
+    x = _regression.cosine_distance(preds[mask], targets[mask])
+    y = _regression.cosine_distance(preds, targets, where=mask)
+    np.testing.assert_allclose(x, y)
+
+  @parameterized.parameters(dict(size=5), dict(size=10))
+  def test_mask_similarity(self, size):
+    preds = np.random.normal(size=size)
+    targets = np.random.normal(size=size)
+    mask = np.random.randint(2, size=size, dtype=bool)
+    x = _regression.cosine_similarity(preds[mask], targets[mask])
+    y = _regression.cosine_similarity(preds, targets, where=mask)
+    np.testing.assert_allclose(x, y, atol=1e-4)
+

if __name__ == '__main__':
absltest.main()
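The regression tests exercise the same slicing-versus-`where` equivalence for the cosine losses. A short sketch of the relationships involved, again assuming the public `optax.losses` aliases accept `where`, and using the documented identity that cosine distance is one minus cosine similarity:

    import numpy as np
    from optax import losses

    preds = np.random.normal(size=5)
    targets = np.random.normal(size=5)
    mask = np.array([True, False, True, True, False])

    sim = losses.cosine_similarity(preds, targets, where=mask)
    dist = losses.cosine_distance(preds, targets, where=mask)

    # `where=mask` should match slicing out the masked entries first...
    np.testing.assert_allclose(
        sim, losses.cosine_similarity(preds[mask], targets[mask]), atol=1e-4)
    # ...and distance should be one minus similarity on those entries.
    np.testing.assert_allclose(dist, 1.0 - sim, atol=1e-4)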
