Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,25 @@ def _update_variable_sharding_metadata(
):
def _update_axes_fn(tree_node):
if isinstance(tree_node, extract.TreeNode) and isinstance(
tree_node.metatata, StateAxes
tree_node.metatata, (StateAxes, int)
):
graphdef_states_out: list[extract.GraphDefState] = []
for graphdef_state, axis in zip(
if isinstance(tree_node.metatata, int):
graph_def_state = tree_node.graphdef_states[0]
assert isinstance(graph_def_state, extract.GraphDefState)
graphdef_state = axis_fn(
graph_def_state, tree_node.metatata, transform_metadata
)
return tree_node.replace(graphdef_states=(graphdef_state,))
else:
graphdef_states_out: list[extract.GraphDefState] = []
for graphdef_state, axis in zip(
tree_node.graphdef_states, tree_node.metatata.axes
):
assert isinstance(graphdef_state, extract.GraphDefState)
if isinstance(axis, int):
graphdef_state = axis_fn(graphdef_state, axis, transform_metadata)
graphdef_states_out.append(graphdef_state)
return tree_node.replace(graphdef_states=tuple(graphdef_states_out))
):
assert isinstance(graphdef_state, extract.GraphDefState)
if isinstance(axis, int):
graphdef_state = axis_fn(graphdef_state, axis, transform_metadata)
graphdef_states_out.append(graphdef_state)
return tree_node.replace(graphdef_states=tuple(graphdef_states_out))
return tree_node

return jax.tree.map(
Expand All @@ -130,7 +138,7 @@ def _vmap_split_fn(ctx: graph.SplitContext, path, prefix, x):
return extract.TreeNode.from_split(
*ctx.split(x, *prefix.filters), metadata=prefix
)
return extract.TreeNode.from_split(*ctx.split(x))
return extract.TreeNode.from_split(*ctx.split(x), metadata=prefix)


@dataclasses.dataclass(eq=False)
Expand Down
21 changes: 21 additions & 0 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2235,6 +2235,27 @@ def forward(model, x):

self.assertEqual(y.shape, (5, 4, 3))

def test_metadata(self):
@nnx.vmap(
in_axes=(None,),
out_axes=0,
axis_size=5,
transform_metadata={nnx.spmd.PARTITION_NAME: 'c'},
)
def create_block(rngs: nnx.Rngs):
return nnx.Linear(
16,
32,
rngs=rngs,
kernel_init=nnx.with_partitioning(
nnx.initializers.lecun_normal(), ('a', 'b')
),
)

m = create_block(nnx.Rngs(0))
self.assertEqual(m.kernel.value.shape, (5, 16, 32))
self.assertEqual(m.kernel.sharding, ('c', 'a', 'b'))


class TestPmap(absltest.TestCase):

Expand Down
Loading