Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion flax/nnx/nn/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,9 @@ def __call__(
broadcast_shape = list(inputs.shape)
for dim in self.broadcast_dims:
broadcast_shape[dim] = 1
mask = random.bernoulli(key, p=keep_prob, shape=broadcast_shape)
mask = random.bernoulli(
key, p=keep_prob, shape=broadcast_shape, out_sharding=jax.typeof(inputs).sharding
)
mask = jnp.broadcast_to(mask, inputs.shape)
return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))

Expand Down
18 changes: 18 additions & 0 deletions tests/nnx/spmd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,24 @@ def test_out_sharding_embed_attend(self):
assert 'float32[2@X,10]' in str(jax.typeof(layer.attend(sharded_array)))
assert 'float32[2@X,10@Y]' in str(jax.typeof(layer.attend(sharded_array, out_sharding=P("X", "Y"))))

def test_out_sharding_dropout(self):
mesh = jax.make_mesh((2, 2), ("X", "Y"), axis_types=(AxisType.Explicit, AxisType.Explicit))
with jax.set_mesh(mesh):
replicated_array = jnp.arange(8).reshape(2, 4).astype(jnp.float32)
sharded_array = reshard(replicated_array, P("X", None))
layers = [
nnx.Dropout(rate=0.5, rngs=nnx.Rngs(0)),
nnx.Dropout(rate=0.5, broadcast_dims=(1,), rngs=nnx.Rngs(0)),
]
for layer in layers:
assert 'float32[2@X,4]' in str(jax.typeof(layer(sharded_array)))

@jax.jit
def func(x, rngs):
return layer(x, rngs=rngs)

assert 'float32[2@X,4]' in str(jax.typeof(func(sharded_array, nnx.Rngs(0))))

@parameterized.product(use_hijax=[True, False])
def test_logical_rules(self, use_hijax):
self.enter_context(nnx.var_defaults(hijax=use_hijax))
Expand Down
Loading