diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 778e5c6cf..11db8a2fe 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -920,7 +920,7 @@ def loss_fn(params): prev_loss = loss -class StochasticTest(absltest.TestCase): +class StochasticTest(parameterized.TestCase): def test_dropout(self): rng = random.key(0) key1, key2 = random.split(rng) @@ -975,6 +975,70 @@ def test_dropout_rate_limits(self): res = jax.grad(lambda x, k: jnp.sum(fn(x, k)))(inputs, key3) self.assertFalse(np.isnan(res).any()) + @parameterized.parameters( + { + 'num_dims': 2, + 'broadcast_dims': (1,), + 'slice_fn': lambda out, i: out[i, :], + 'summed_total': 2 * 10, + }, + { + 'num_dims': 2, + 'broadcast_dims': (0,), + 'slice_fn': lambda out, i: out[:, i], + 'summed_total': 2 * 10, + }, + { + 'num_dims': 3, + 'broadcast_dims': (1, 2), + 'slice_fn': lambda out, i: out[i, :, :], + 'summed_total': 2 * 10 * 10, + }, + { + 'num_dims': 3, + 'broadcast_dims': (1,), + 'slice_fn': lambda out, i, j: out[i, :, j], + 'summed_total': 2 * 10, + }, + { + 'num_dims': 4, + 'broadcast_dims': (0, 2, 3), + 'slice_fn': lambda out, i: out[:, i, :, :], + 'summed_total': 2 * 10 * 10 * 10, + }, + { + 'num_dims': 4, + 'broadcast_dims': (0, 1), + 'slice_fn': lambda out, i, j: out[:, :, i, j], + 'summed_total': 2 * 10 * 10, + }, + { + 'num_dims': 4, + 'broadcast_dims': (3,), + 'slice_fn': lambda out, i, j, k: out[i, j, k, :], + 'summed_total': 2 * 10, + }, + ) + def test_dropout_broadcast( + self, num_dims, broadcast_dims, slice_fn, summed_total + ): + module = nn.Dropout( + rate=0.5, broadcast_dims=broadcast_dims, deterministic=False + ) + x = jnp.ones((10,) * num_dims) + out = module.apply({}, x, rngs={'dropout': random.key(0)}) + + for i in range(10): + if num_dims - len(broadcast_dims) >= 2: + for j in range(10): + if num_dims - len(broadcast_dims) >= 3: + for k in range(10): + self.assertTrue(slice_fn(out, i, j, k).sum() in (0, summed_total)) + else: + self.assertTrue(slice_fn(out, i, j).sum() in (0, summed_total)) + else: + self.assertTrue(slice_fn(out, i).sum() in (0, summed_total)) + def test_dropout_manual_rng(self): def clone(key): if hasattr(jax.random, 'clone'):