Add test_mask and test_axis methods.
carlosgmartin committed Jun 13, 2024
1 parent 674f016 commit 512310d
Showing 2 changed files with 116 additions and 8 deletions.
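Every loss touched here gains the same two checks. test_mask asserts that passing where=mask is equivalent to boolean-indexing the inputs before the call, and test_axis asserts that passing axis=k is equivalent to moving axis k to the last position and using the default axis (for the integer-label variant only preds is moved, since integer targets carry no class axis). A minimal standalone sketch of both properties, assuming the functions under test are exposed via the public optax.losses aliases:

import numpy as np
from optax import losses

rng = np.random.default_rng(0)

# Property 1: where=mask should match filtering the inputs up front.
preds = rng.normal(size=10)
targets = rng.dirichlet(np.ones(10))
mask = rng.integers(2, size=10).astype(bool)
np.testing.assert_allclose(
    losses.softmax_cross_entropy(preds[mask], targets[mask]),
    losses.softmax_cross_entropy(preds, targets, where=mask),
    atol=1e-4,
)

# Property 2: axis=k should match moving axis k to the last position.
preds = rng.normal(size=(4, 5, 6))
targets = rng.dirichlet(np.ones(6), size=(4, 5))
np.testing.assert_allclose(
    losses.softmax_cross_entropy(preds, targets, axis=1),
    losses.softmax_cross_entropy(
        np.moveaxis(preds, 1, -1), np.moveaxis(targets, 1, -1)
    ),
    atol=1e-4,
)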
107 changes: 100 additions & 7 deletions optax/losses/_classification_test.py
@@ -89,7 +89,22 @@ def test_mask(self, size):
    mask = np.random.randint(2, size=size, dtype=bool)
    x = _classification.softmax_cross_entropy(preds[mask], targets[mask])
    y = _classification.softmax_cross_entropy(preds, targets, where=mask)
-    np.testing.assert_allclose(x, y)
+    np.testing.assert_allclose(x, y, atol=1e-4)
+
+  @parameterized.parameters(
+      dict(axis=0, shape=[4, 5, 6]),
+      dict(axis=1, shape=[4, 5, 6]),
+      dict(axis=2, shape=[4, 5, 6]),
+  )
+  def test_axis(self, shape, axis):
+    preds = np.random.normal(size=shape)
+    targets = np.random.dirichlet(np.ones(shape[-1]), size=shape[:-1])
+    x = _classification.softmax_cross_entropy(preds, targets, axis=axis)
+    y = _classification.softmax_cross_entropy(
+        np.moveaxis(preds, axis, -1),
+        np.moveaxis(targets, axis, -1),
+    )
+    np.testing.assert_allclose(x, y, atol=1e-4)


class SafeSoftmaxCrossEntropyTest(parameterized.TestCase):
@@ -215,6 +230,22 @@ def test_gradient(self):
        order=1,
    )

+  @parameterized.parameters(
+      dict(axis=0, shape=[4, 5, 6]),
+      dict(axis=1, shape=[4, 5, 6]),
+      dict(axis=2, shape=[4, 5, 6]),
+  )
+  def test_axis(self, shape, axis):
+    preds = np.random.normal(size=shape)
+    targets = np.random.randint(shape[axis], size=shape[:axis] + shape[axis+1:])
+    f = _classification.softmax_cross_entropy_with_integer_labels
+    x = f(preds, targets, axis=axis)
+    y = f(
+        np.moveaxis(preds, axis, -1),
+        targets,
+    )
+    np.testing.assert_allclose(x, y, atol=1e-4)


class SigmoidCrossEntropyTest(parameterized.TestCase):

@@ -369,7 +400,22 @@ def test_mask(self, size):
    mask = np.random.randint(2, size=size, dtype=bool)
    x = _classification.poly_loss_cross_entropy(preds[mask], targets[mask])
    y = _classification.poly_loss_cross_entropy(preds, targets, where=mask)
-    np.testing.assert_allclose(x, y)
+    np.testing.assert_allclose(x, y, atol=1e-4)
+
+  @parameterized.parameters(
+      dict(axis=0, shape=[4, 5, 6]),
+      dict(axis=1, shape=[4, 5, 6]),
+      dict(axis=2, shape=[4, 5, 6]),
+  )
+  def test_axis(self, shape, axis):
+    preds = np.random.normal(size=shape)
+    targets = np.random.dirichlet(np.ones(shape[-1]), size=shape[:-1])
+    x = _classification.poly_loss_cross_entropy(preds, targets, axis=axis)
+    y = _classification.poly_loss_cross_entropy(
+        np.moveaxis(preds, axis, -1),
+        np.moveaxis(targets, axis, -1),
+    )
+    np.testing.assert_allclose(x, y, atol=1e-4)


class HingeTest(parameterized.TestCase):
@@ -533,7 +579,22 @@ def test_mask(self, size):
    mask = np.random.randint(2, size=size, dtype=bool)
    x = _classification.convex_kl_divergence(preds[mask], targets[mask])
    y = _classification.convex_kl_divergence(preds, targets, where=mask)
-    np.testing.assert_allclose(x, y)
+    np.testing.assert_allclose(x, y, atol=1e-4)
+
+  @parameterized.parameters(
+      dict(axis=0, shape=[4, 5, 6]),
+      dict(axis=1, shape=[4, 5, 6]),
+      dict(axis=2, shape=[4, 5, 6]),
+  )
+  def test_axis(self, shape, axis):
+    preds = np.random.normal(size=shape)
+    targets = np.random.dirichlet(np.ones(shape[-1]), size=shape[:-1])
+    x = _classification.convex_kl_divergence(preds, targets, axis=axis)
+    y = _classification.convex_kl_divergence(
+        np.moveaxis(preds, axis, -1),
+        np.moveaxis(targets, axis, -1),
+    )
+    np.testing.assert_allclose(x, y, atol=1e-4)


class PerceptronTest(parameterized.TestCase):
@@ -624,7 +685,22 @@ def test_mask(self, size):
    mask = np.random.randint(2, size=size, dtype=bool)
    x = _classification.kl_divergence(preds[mask], targets[mask])
    y = _classification.kl_divergence(preds, targets, where=mask)
-    np.testing.assert_allclose(x, y)
+    np.testing.assert_allclose(x, y, atol=1e-4)
+
+  @parameterized.parameters(
+      dict(axis=0, shape=[4, 5, 6]),
+      dict(axis=1, shape=[4, 5, 6]),
+      dict(axis=2, shape=[4, 5, 6]),
+  )
+  def test_axis(self, shape, axis):
+    preds = np.random.normal(size=shape)
+    targets = np.random.dirichlet(np.ones(shape[-1]), size=shape[:-1])
+    x = _classification.kl_divergence(preds, targets, axis=axis)
+    y = _classification.kl_divergence(
+        np.moveaxis(preds, axis, -1),
+        np.moveaxis(targets, axis, -1),
+    )
+    np.testing.assert_allclose(x, y, atol=1e-4)


class KLDivergenceWithLogTargetsTest(parameterized.TestCase):
@@ -667,9 +743,26 @@ def test_mask(self, size):
    preds = np.random.normal(size=size)
    targets = np.log(np.random.dirichlet(np.ones(size)))
    mask = np.random.randint(2, size=size, dtype=bool)
-    x = _classification.kl_divergence_with_log_targets(preds[mask], targets[mask])
-    y = _classification.kl_divergence_with_log_targets(preds, targets, where=mask)
-    np.testing.assert_allclose(x, y)
+    f = _classification.kl_divergence_with_log_targets
+    x = f(preds[mask], targets[mask])
+    y = f(preds, targets, where=mask)
+    np.testing.assert_allclose(x, y, atol=1e-4)
+
+  @parameterized.parameters(
+      dict(axis=0, shape=[4, 5, 6]),
+      dict(axis=1, shape=[4, 5, 6]),
+      dict(axis=2, shape=[4, 5, 6]),
+  )
+  def test_axis(self, shape, axis):
+    preds = np.random.normal(size=shape)
+    targets = np.log(np.random.dirichlet(np.ones(shape[-1]), size=shape[:-1]))
+    f = _classification.kl_divergence_with_log_targets
+    x = f(preds, targets, axis=axis)
+    y = f(
+        np.moveaxis(preds, axis, -1),
+        np.moveaxis(targets, axis, -1),
+    )
+    np.testing.assert_allclose(x, y, atol=1e-4)


def _lengths_to_paddings(lengths: chex.Array, maxlength: int) -> chex.Array:
17 changes: 16 additions & 1 deletion optax/losses/_regression_test.py
@@ -180,7 +180,7 @@ def test_mask_distance(self, size):
    mask = np.random.randint(2, size=size, dtype=bool)
    x = _regression.cosine_distance(preds[mask], targets[mask])
    y = _regression.cosine_distance(preds, targets, where=mask)
-    np.testing.assert_allclose(x, y)
+    np.testing.assert_allclose(x, y, atol=1e-4)

  @parameterized.parameters(dict(size=5), dict(size=10))
  def test_mask_similarity(self, size):
@@ -191,6 +191,21 @@ def test_mask_similarity(self, size):
    y = _regression.cosine_similarity(preds, targets, where=mask)
    np.testing.assert_allclose(x, y, atol=1e-4)

+  @parameterized.parameters(
+      dict(axis=0, shape=[4, 5, 6]),
+      dict(axis=1, shape=[4, 5, 6]),
+      dict(axis=2, shape=[4, 5, 6]),
+  )
+  def test_axis(self, shape, axis):
+    preds = np.random.normal(size=shape)
+    targets = np.random.normal(size=shape)
+    x = _regression.cosine_similarity(preds, targets, axis=axis)
+    y = _regression.cosine_similarity(
+        np.moveaxis(preds, axis, -1),
+        np.moveaxis(targets, axis, -1),
+    )
+    np.testing.assert_allclose(x, y, atol=1e-4)


if __name__ == '__main__':
  absltest.main()
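In downstream code, the axis argument these tests exercise spares callers a transpose when the class axis is not last; a hedged illustration (the class-first layout and shapes here are invented for the example, and the public optax.losses alias is assumed):

import numpy as np
from optax import losses

# Logits laid out [classes, batch]; name the class axis explicitly.
logits = np.random.normal(size=(6, 32)).astype(np.float32)
labels = np.random.randint(6, size=32)
per_example = losses.softmax_cross_entropy_with_integer_labels(
    logits, labels, axis=0
)
print(per_example.shape)  # (32,)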
