diff --git a/flax/nnx/nn/stochastic.py b/flax/nnx/nn/stochastic.py
index baed1526a..a1eea770d 100644
--- a/flax/nnx/nn/stochastic.py
+++ b/flax/nnx/nn/stochastic.py
@@ -150,7 +150,9 @@ def __call__(
     broadcast_shape = list(inputs.shape)
     for dim in self.broadcast_dims:
       broadcast_shape[dim] = 1
-    mask = random.bernoulli(key, p=keep_prob, shape=broadcast_shape)
+    mask = random.bernoulli(
+      key, p=keep_prob, shape=broadcast_shape, out_sharding=jax.typeof(inputs).sharding
+    )
     mask = jnp.broadcast_to(mask, inputs.shape)
     return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
 
diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py
index 9f2ec333e..4b8d8ba3e 100644
--- a/tests/nnx/spmd_test.py
+++ b/tests/nnx/spmd_test.py
@@ -243,6 +243,24 @@ def test_out_sharding_embed_attend(self):
     assert 'float32[2@X,10]' in str(jax.typeof(layer.attend(sharded_array)))
     assert 'float32[2@X,10@Y]' in str(jax.typeof(layer.attend(sharded_array, out_sharding=P("X", "Y"))))
 
+  def test_out_sharding_dropout(self):
+    mesh = jax.make_mesh((2, 2), ("X", "Y"), axis_types=(AxisType.Explicit, AxisType.Explicit))
+    with jax.set_mesh(mesh):
+      replicated_array = jnp.arange(8).reshape(2, 4).astype(jnp.float32)
+      sharded_array = reshard(replicated_array, P("X", None))
+      layers = [
+        nnx.Dropout(rate=0.5, rngs=nnx.Rngs(0)),
+        nnx.Dropout(rate=0.5, broadcast_dims=(1,), rngs=nnx.Rngs(0)),
+      ]
+      for layer in layers:
+        assert 'float32[2@X,4]' in str(jax.typeof(layer(sharded_array)))
+
+        @jax.jit
+        def func(x, rngs):
+          return layer(x, rngs=rngs)
+
+        assert 'float32[2@X,4]' in str(jax.typeof(func(sharded_array, nnx.Rngs(0))))
+
   @parameterized.product(use_hijax=[True, False])
   def test_logical_rules(self, use_hijax):
     self.enter_context(nnx.var_defaults(hijax=use_hijax))
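
For context, a minimal usage sketch of the behavior this change enables: under a mesh with explicit axis types, the Bernoulli mask is now drawn with `out_sharding` matching the input, so `nnx.Dropout` output keeps the input's sharding annotation (e.g. `@X`) both eagerly and under `jit`. The one-axis mesh, shapes, and import paths for `reshard`/`AxisType` below are illustrative assumptions, not part of the diff; they mirror the test above.

```python
# Hedged sketch, not part of this change. Assumes >= 2 devices (on CPU,
# XLA_FLAGS=--xla_force_host_platform_device_count=2 can fake them) and
# that reshard/AxisType are importable from jax.sharding, which may
# vary across JAX versions.
import jax
import jax.numpy as jnp
from jax.sharding import AxisType, PartitionSpec as P, reshard

from flax import nnx

# A one-axis mesh in explicit sharding mode, so array types carry shardings.
mesh = jax.make_mesh((2,), ("X",), axis_types=(AxisType.Explicit,))
with jax.set_mesh(mesh):
  x = reshard(jnp.ones((4, 8), jnp.float32), P("X", None))
  layer = nnx.Dropout(rate=0.5, rngs=nnx.Rngs(0))
  # The mask inherits the input's sharding, so the output type stays
  # sharded along X: float32[4@X,8]
  print(jax.typeof(layer(x)))
```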