Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,14 @@ def updates_and_snapshot(args: A) -> tuple[A, A]:
for leaf in leaves:
if isinstance(leaf, variablelib.Variable):
updates_leaves.append(leaf)
snapshot_leaves.append(leaf.copy())
# don't snapshot hijax or ref Variables as their updates are automatically
# masked out in mask_variable_updates. However, the leaf is kept in the
# updates to check for aliasing. This avoids a copy operation which has
# significance for ref Variables.
if leaf.hijax or leaf.ref:
snapshot_leaves.append(Mask())
else:
snapshot_leaves.append(leaf.copy())
else:
updates_leaves.append(Mask())
snapshot_leaves.append(Mask())
Expand Down Expand Up @@ -597,6 +604,10 @@ def mask_variable_updates(
keep_fn = lambda _, _pfx, cur, snap: variable_changed(cur, snap)

def _mask_updates(path, prefix_leaf, current, snapshot):
if current is None:
# None leaves should remain None, they only appear here because
# is_leaf catches None values for the prefix
return None
if isinstance(current, variablelib.Variable):
if current.hijax or current.ref:
return Mask()
Expand All @@ -610,7 +621,7 @@ def _mask_updates(path, prefix_leaf, current, snapshot):
current_tree, snapshot_tree, is_leaf=is_leaf,
)
return broadcast_prefix_map(
_mask_updates, prefix, current_tree, snapshot_tree, is_leaf=is_leaf,
_mask_updates, prefix, current_tree, snapshot_tree, is_leaf=is_leaf
)


Expand Down
16 changes: 7 additions & 9 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,19 @@ def _apply_axis_fn(
axis_fn: tp.Callable[..., tp.Any],
) -> None:
is_leaf = lambda x: x is None or isinstance(x, variablelib.Variable)
_, per_leaf_axes = extract.broadcast_prefix2(axes, tree, is_leaf=is_leaf)
leaves = jax.tree_util.tree_leaves(tree, is_leaf=is_leaf)
for leaf, axis in zip(leaves, per_leaf_axes):
if (axis is None or isinstance(axis, int)) and isinstance(
leaf, variablelib.Variable
):
def apply_fn(path, axis, leaf):
if isinstance(axis, int) and isinstance(leaf, variablelib.Variable):
axis_fn(leaf, axis, metadata)

extract.broadcast_prefix_map(apply_fn, axes, tree, is_leaf=is_leaf)


@tp.overload
def transform_metadata(
  *,
  in_axes: tp.Any = 0,
  out_axes: tp.Any = 0,
  # partition may be None to disable partition-based metadata handling
  # while still lifting/unlifting axis metadata.
  partition: str | None,
  graph: bool | None = None,
) -> tp.Callable[[F], F]:
  ...
Expand All @@ -96,7 +94,7 @@ def transform_metadata(
in_axes: tp.Any = 0,
out_axes: tp.Any = 0,
graph: bool | None = None,
partition: str,
partition: str | None,
) -> F:
...

Expand All @@ -106,8 +104,8 @@ def transform_metadata(
*,
in_axes: tp.Any = 0,
out_axes: tp.Any = 0,
partition: str | None,
graph: bool | None = None,
partition: str,
) -> F | tp.Callable[[F], F]:
if f is Missing:
return functools.partial(
Expand Down
24 changes: 24 additions & 0 deletions tests/nnx/spmd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,30 @@ def f(v):
self.assertEqual(v2[...], 10)


def test_transform_metadata_decorator_none_partition(self):
  """transform_metadata(partition=None) still lifts/unlifts out_sharding.

  The mapped-axis entry is removed from ``out_sharding`` inside the
  transformed function and re-inserted (as ``None``) on the way out.
  """
  param = nnx.Param(
      jnp.array(1),
      out_sharding=(None, 'dout'),
      eager_sharding=False,
  )

  @nnx.transform_metadata(in_axes=0, out_axes=1, partition=None)
  def body(p):
    p[...] += 1
    # Inside the transform the in_axes=0 entry has been stripped.
    self.assertEqual(p.out_sharding, ('dout',))
    fresh = nnx.Param(
        jnp.array(10),
        out_sharding=('dmid', 'dout'),
        eager_sharding=False,
    )
    return fresh

  result = body(param)
  # The input keeps its original metadata and reflects the mutation.
  self.assertEqual(param.out_sharding, (None, 'dout'))
  self.assertEqual(param[...], 2)
  # The output gains a None entry at out_axes=1.
  self.assertEqual(result.out_sharding, ('dmid', None, 'dout'))
  self.assertEqual(result[...], 10)

@parameterized.product(use_eager_sharding=[True, False])
def test_eager_sharding_context(self, use_eager_sharding):
rngs = nnx.Rngs(0)
Expand Down
61 changes: 56 additions & 5 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5537,12 +5537,9 @@ def dict_scan(carry, x):

def test_no_carry_all_scanned(self):
  """A carry-free scan with scalar int axes maps over every input.

  ``double`` returns a bare array (not a 1-tuple) and the axes are
  given as bare ints, so the scan result is a single stacked array.
  """
  def double(x):
    return x * 2

  ys = pure_jax_fancy_scan(double, jnp.arange(5.0), in_axes=0, out_axes=0)
  np.testing.assert_allclose(ys, jnp.arange(5.0) * 2)

def test_reverse(self):
Expand Down Expand Up @@ -5651,6 +5648,60 @@ def f(x):
)
np.testing.assert_allclose(ys, jnp.arange(5.0) * 2)

def test_scan_axis_1(self):
  """Scan over axis 1 of the input, stacking outputs along axis 1."""
  def running_sum(acc, col):
    acc = acc + col
    return acc, acc

  xs = jnp.arange(10.0).reshape((2, 5))
  final_acc, ys = pure_jax_fancy_scan(
      running_sum,
      jnp.zeros(2),
      xs,
      in_axes=(nnx.Carry, 1),
      out_axes=(nnx.Carry, 1),
  )
  # Row-wise cumulative sums; the carry ends at each row's total.
  np.testing.assert_allclose(final_acc, jnp.array([10.0, 35.0]))
  expected = jnp.array(
      [[0.0, 1.0, 3.0, 6.0, 10.0],
       [5.0, 11.0, 18.0, 26.0, 35.0]]
  )
  np.testing.assert_allclose(ys, expected)

def test_scan_axis_negative_1(self):
  """Negative axes resolve like positive ones: -1 == axis 1 for rank 2."""
  def running_sum(acc, col):
    acc = acc + col
    return acc, acc

  xs = jnp.arange(10.0).reshape((2, 5))
  final_acc, ys = pure_jax_fancy_scan(
      running_sum,
      jnp.zeros(2),
      xs,
      in_axes=(nnx.Carry, -1),
      out_axes=(nnx.Carry, -1),
  )
  # Identical result to scanning over axis 1 explicitly.
  np.testing.assert_allclose(final_acc, jnp.array([10.0, 35.0]))
  expected = jnp.array(
      [[0.0, 1.0, 3.0, 6.0, 10.0],
       [5.0, 11.0, 18.0, 26.0, 35.0]]
  )
  np.testing.assert_allclose(ys, expected)

def test_scan_different_in_out_axes(self):
  """Inputs scanned over axis 1 may be stacked on axis 0 in the output."""
  def running_sum(acc, col):
    acc = acc + col
    return acc, acc

  xs = jnp.arange(10.0).reshape((2, 5))
  final_acc, ys = pure_jax_fancy_scan(
      running_sum,
      jnp.zeros(2),
      xs,
      in_axes=(nnx.Carry, 1),
      out_axes=(nnx.Carry, 0),
  )
  np.testing.assert_allclose(final_acc, jnp.array([10.0, 35.0]))
  # Same cumulative sums as the axis-1 test, transposed because the
  # per-step outputs are now stacked along axis 0.
  expected = jnp.array(
      [[0.0, 5.0],
       [1.0, 11.0],
       [3.0, 18.0],
       [6.0, 26.0],
       [10.0, 35.0]]
  )
  np.testing.assert_allclose(ys, expected)


if __name__ == '__main__':
absltest.main()
Loading