diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 36c351f34..c169a91fa 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -107,17 +107,25 @@ def _update_variable_sharding_metadata( ): def _update_axes_fn(tree_node): if isinstance(tree_node, extract.TreeNode) and isinstance( - tree_node.metatata, StateAxes + tree_node.metatata, (StateAxes, int) ): - graphdef_states_out: list[extract.GraphDefState] = [] - for graphdef_state, axis in zip( + if isinstance(tree_node.metatata, int): + graph_def_state = tree_node.graphdef_states[0] + assert isinstance(graph_def_state, extract.GraphDefState) + graphdef_state = axis_fn( + graph_def_state, tree_node.metatata, transform_metadata + ) + return tree_node.replace(graphdef_states=(graphdef_state,)) + else: + graphdef_states_out: list[extract.GraphDefState] = [] + for graphdef_state, axis in zip( tree_node.graphdef_states, tree_node.metatata.axes - ): - assert isinstance(graphdef_state, extract.GraphDefState) - if isinstance(axis, int): - graphdef_state = axis_fn(graphdef_state, axis, transform_metadata) - graphdef_states_out.append(graphdef_state) - return tree_node.replace(graphdef_states=tuple(graphdef_states_out)) + ): + assert isinstance(graphdef_state, extract.GraphDefState) + if isinstance(axis, int): + graphdef_state = axis_fn(graphdef_state, axis, transform_metadata) + graphdef_states_out.append(graphdef_state) + return tree_node.replace(graphdef_states=tuple(graphdef_states_out)) return tree_node return jax.tree.map( @@ -130,7 +138,7 @@ def _vmap_split_fn(ctx: graph.SplitContext, path, prefix, x): return extract.TreeNode.from_split( *ctx.split(x, *prefix.filters), metadata=prefix ) - return extract.TreeNode.from_split(*ctx.split(x)) + return extract.TreeNode.from_split(*ctx.split(x), metadata=prefix) @dataclasses.dataclass(eq=False) diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index be487628f..b693be2d1 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -2235,6 +2235,27 @@ def forward(model, x): self.assertEqual(y.shape, (5, 4, 3)) + def test_metadata(self): + @nnx.vmap( + in_axes=(None,), + out_axes=0, + axis_size=5, + transform_metadata={nnx.spmd.PARTITION_NAME: 'c'}, + ) + def create_block(rngs: nnx.Rngs): + return nnx.Linear( + 16, + 32, + rngs=rngs, + kernel_init=nnx.with_partitioning( + nnx.initializers.lecun_normal(), ('a', 'b') + ), + ) + + m = create_block(nnx.Rngs(0)) + self.assertEqual(m.kernel.value.shape, (5, 16, 32)) + self.assertEqual(m.kernel.sharding, ('c', 'a', 'b')) + class TestPmap(absltest.TestCase):