From a138d9f3d5e1767ca50603744c55151758846732 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Sat, 21 Mar 2026 07:45:41 -0700 Subject: [PATCH] fix transform_metadata PiperOrigin-RevId: 887307013 --- flax/nnx/extract.py | 15 ++++++-- flax/nnx/transforms/iteration.py | 16 ++++----- tests/nnx/spmd_test.py | 24 +++++++++++++ tests/nnx/transforms_test.py | 61 +++++++++++++++++++++++++++++--- 4 files changed, 100 insertions(+), 16 deletions(-) diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index 96226a208..2ade95e4e 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -517,7 +517,14 @@ def updates_and_snapshot(args: A) -> tuple[A, A]: for leaf in leaves: if isinstance(leaf, variablelib.Variable): updates_leaves.append(leaf) - snapshot_leaves.append(leaf.copy()) + # don't snapshot hijax or ref Variables as their updates are automatically + # masked out in mask_variable_updates. However, the leaf is kept in the + # updates to check for aliasing. This avoids a copy operation which has + # significance for ref Variables. + if leaf.hijax or leaf.ref: + snapshot_leaves.append(Mask()) + else: + snapshot_leaves.append(leaf.copy()) else: updates_leaves.append(Mask()) snapshot_leaves.append(Mask()) @@ -597,6 +604,10 @@ def mask_variable_updates( keep_fn = lambda _, _pfx, cur, snap: variable_changed(cur, snap) def _mask_updates(path, prefix_leaf, current, snapshot): + if current is None: + # None leaves should remain None, they only appear here because + # is_leaf catches None values for the prefix + return None if isinstance(current, variablelib.Variable): if current.hijax or current.ref: return Mask() @@ -610,7 +621,7 @@ def _mask_updates(path, prefix_leaf, current, snapshot): current_tree, snapshot_tree, is_leaf=is_leaf, ) return broadcast_prefix_map( - _mask_updates, prefix, current_tree, snapshot_tree, is_leaf=is_leaf, + _mask_updates, prefix, current_tree, snapshot_tree, is_leaf=is_leaf ) diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 3d2592ee4..67ec87740 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -69,21 +69,19 @@ def _apply_axis_fn( axis_fn: tp.Callable[..., tp.Any], ) -> None: is_leaf = lambda x: x is None or isinstance(x, variablelib.Variable) - _, per_leaf_axes = extract.broadcast_prefix2(axes, tree, is_leaf=is_leaf) - leaves = jax.tree_util.tree_leaves(tree, is_leaf=is_leaf) - for leaf, axis in zip(leaves, per_leaf_axes): - if (axis is None or isinstance(axis, int)) and isinstance( - leaf, variablelib.Variable - ): + def apply_fn(path, axis, leaf): + if isinstance(axis, int) and isinstance(leaf, variablelib.Variable): axis_fn(leaf, axis, metadata) + extract.broadcast_prefix_map(apply_fn, axes, tree, is_leaf=is_leaf) + @tp.overload def transform_metadata( *, in_axes: tp.Any = 0, out_axes: tp.Any = 0, - partition: str, + partition: str | None, graph: bool | None = None, ) -> tp.Callable[[F], F]: ... @@ -96,7 +94,7 @@ def transform_metadata( in_axes: tp.Any = 0, out_axes: tp.Any = 0, graph: bool | None = None, - partition: str, + partition: str | None, ) -> F: ... @@ -106,8 +104,8 @@ def transform_metadata( *, in_axes: tp.Any = 0, out_axes: tp.Any = 0, + partition: str | None, graph: bool | None = None, - partition: str, ) -> F | tp.Callable[[F], F]: if f is Missing: return functools.partial( diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py index d05a26b16..9b81e680e 100644 --- a/tests/nnx/spmd_test.py +++ b/tests/nnx/spmd_test.py @@ -219,6 +219,30 @@ def f(v): self.assertEqual(v2[...], 10) + def test_transform_metadata_decorator_none_partition(self): + v = nnx.Param( + jnp.array(1), + out_sharding=(None, 'dout'), + eager_sharding=False, + ) + + @nnx.transform_metadata(in_axes=0, out_axes=1, partition=None) + def f(v): + v[...] += 1 + self.assertEqual(v.out_sharding, ('dout',)) + v2 = nnx.Param( + jnp.array(10), + out_sharding=('dmid', 'dout'), + eager_sharding=False, + ) + return v2 + + v2 = f(v) + self.assertEqual(v.out_sharding, (None, 'dout')) + self.assertEqual(v[...], 2) + self.assertEqual(v2.out_sharding, ('dmid', None, 'dout')) + self.assertEqual(v2[...], 10) + @parameterized.product(use_eager_sharding=[True, False]) def test_eager_sharding_context(self, use_eager_sharding): rngs = nnx.Rngs(0) diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index db179a9af..1e9bf06e8 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -5537,12 +5537,9 @@ def dict_scan(carry, x): def test_no_carry_all_scanned(self): def double(x): - return (x * 2,) + return x * 2 - (ys,) = pure_jax_fancy_scan( - double, jnp.arange(5.0), - in_axes=(0,), out_axes=(0,), - ) + ys = pure_jax_fancy_scan(double, jnp.arange(5.0), in_axes=0, out_axes=0) np.testing.assert_allclose(ys, jnp.arange(5.0) * 2) def test_reverse(self): @@ -5651,6 +5648,60 @@ def f(x): ) np.testing.assert_allclose(ys, jnp.arange(5.0) * 2) + def test_scan_axis_1(self): + def cumsum(carry, x): + carry = carry + x + return carry, carry + + x = jnp.arange(10.0).reshape((2, 5)) + final_carry, ys = pure_jax_fancy_scan( + cumsum, jnp.zeros(2), x, + in_axes=(nnx.Carry, 1), out_axes=(nnx.Carry, 1), + ) + np.testing.assert_allclose(final_carry, jnp.array([10.0, 35.0])) + expected_ys = jnp.array([ + [0., 1., 3., 6., 10.], + [5., 11., 18., 26., 35.] + ]) + np.testing.assert_allclose(ys, expected_ys) + + def test_scan_axis_negative_1(self): + def cumsum(carry, x): + carry = carry + x + return carry, carry + + x = jnp.arange(10.0).reshape((2, 5)) + final_carry, ys = pure_jax_fancy_scan( + cumsum, jnp.zeros(2), x, + in_axes=(nnx.Carry, -1), out_axes=(nnx.Carry, -1), + ) + np.testing.assert_allclose(final_carry, jnp.array([10.0, 35.0])) + expected_ys = jnp.array([ + [0., 1., 3., 6., 10.], + [5., 11., 18., 26., 35.] + ]) + np.testing.assert_allclose(ys, expected_ys) + + def test_scan_different_in_out_axes(self): + def cumsum(carry, x): + carry = carry + x + return carry, carry + + x = jnp.arange(10.0).reshape((2, 5)) + final_carry, ys = pure_jax_fancy_scan( + cumsum, jnp.zeros(2), x, + in_axes=(nnx.Carry, 1), out_axes=(nnx.Carry, 0), + ) + np.testing.assert_allclose(final_carry, jnp.array([10.0, 35.0])) + expected_ys = jnp.array([ + [0., 5.], + [1., 11.], + [3., 18.], + [6., 26.], + [10., 35.] + ]) + np.testing.assert_allclose(ys, expected_ys) + if __name__ == '__main__': absltest.main()