diff --git a/flax/nnx/nnx/module.py b/flax/nnx/nnx/module.py
index 2dbc26d3e..13292bcff 100644
--- a/flax/nnx/nnx/module.py
+++ b/flax/nnx/nnx/module.py
@@ -28,7 +28,7 @@
 from flax.nnx.nnx.object import Object, ObjectMeta
 from flax.nnx.nnx.graph import GraphState, StateLeaf
 from flax.nnx.nnx.state import State
-from flax.typing import Path, PathParts
+from flax.typing import Key, Path, PathParts
 
 A = tp.TypeVar('A')
 B = tp.TypeVar('B')
@@ -184,7 +184,8 @@ def sow(
       setattr(self, name, variable_type(reduced_value))
 
   def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]:
-    """Iterates over all nested Modules of the current Module, including the current Module.
+    """Recursively iterates over all nested :class:`Module`'s of the current Module, including
+    the current Module.
 
     ``iter_modules`` creates a generator that yields the path and the Module instance, where
     the path is a tuple of strings or integers representing the path to the Module from the
@@ -194,13 +195,18 @@ def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]:
 
       >>> from flax import nnx
       ...
+      >>> class SubModule(nnx.Module):
+      ...   def __init__(self, din, dout, rngs):
+      ...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
+      ...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
+      ...
       >>> class Block(nnx.Module):
       ...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
       ...     self.linear = nnx.Linear(din, dout, rngs=rngs)
+      ...     self.submodule = SubModule(din, dout, rngs=rngs)
       ...     self.dropout = nnx.Dropout(0.5)
       ...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
       ...
-      ...
       >>> model = Block(2, 5, rngs=nnx.Rngs(0))
       >>> for path, module in model.iter_modules():
       ...   print(path, type(module).__name__)
@@ -208,12 +214,54 @@ def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]:
       ('batch_norm',) BatchNorm
       ('dropout',) Dropout
       ('linear',) Linear
+      ('submodule', 'linear1') Linear
+      ('submodule', 'linear2') Linear
+      ('submodule',) SubModule
       () Block
     """
     for path, value in graph.iter_graph(self):
       if isinstance(value, Module):
         yield path, value
 
+  def iter_children(self) -> tp.Iterator[tuple[Key, Module]]:
+    """Iterates over all children :class:`Module`'s of the current Module. This
+    method is similar to :func:`iter_modules`, except it only iterates over the
+    immediate children, and does not recurse further down.
+
+    ``iter_children`` creates a generator that yields the key and the Module instance,
+    where the key is a string representing the attribute name of the Module to access
+    the corresponding child Module.
+
+    Example::
+
+      >>> from flax import nnx
+      ...
+      >>> class SubModule(nnx.Module):
+      ...   def __init__(self, din, dout, rngs):
+      ...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
+      ...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
+      ...
+      >>> class Block(nnx.Module):
+      ...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
+      ...     self.linear = nnx.Linear(din, dout, rngs=rngs)
+      ...     self.submodule = SubModule(din, dout, rngs=rngs)
+      ...     self.dropout = nnx.Dropout(0.5)
+      ...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
+      ...
+      >>> model = Block(2, 5, rngs=nnx.Rngs(0))
+      >>> for key, module in model.iter_children():
+      ...   print(key, type(module).__name__)
+      ...
+      batch_norm BatchNorm
+      dropout Dropout
+      linear Linear
+      submodule SubModule
+    """
+    node_dict = graph.get_node_impl(self).node_dict(self)
+    for key, value in node_dict.items():
+      if isinstance(value, Module):
+        yield key, value
+
   def set_attributes(
     self,
     *filters: filterlib.Filter,
diff --git a/flax/nnx/tests/module_test.py b/flax/nnx/tests/module_test.py
index 1590d4934..f627d3233 100644
--- a/flax/nnx/tests/module_test.py
+++ b/flax/nnx/tests/module_test.py
@@ -616,18 +616,44 @@ def __init__(self, *, rngs: nnx.Rngs):
           {'a': nnx.Linear(1, 1, rngs=rngs)},
           {'b': nnx.Conv(1, 1, 1, rngs=rngs)},
         ]
+        self.linear = nnx.Linear(1, 1, rngs=rngs)
+        self.dropout = nnx.Dropout(0.5, rngs=rngs)
 
     module = Foo(rngs=nnx.Rngs(0))
 
     modules = list(module.iter_modules())
 
-    assert len(modules) == 3
-    assert modules[0][0] == ('submodules', 0, 'a')
-    assert isinstance(modules[0][1], nnx.Linear)
-    assert modules[1][0] == ('submodules', 1, 'b')
-    assert isinstance(modules[1][1], nnx.Conv)
-    assert modules[2][0] == ()
-    assert isinstance(modules[2][1], Foo)
+    assert len(modules) == 5
+    assert modules[0][0] == ('dropout',)
+    assert isinstance(modules[0][1], nnx.Dropout)
+    assert modules[1][0] == ('linear',)
+    assert isinstance(modules[1][1], nnx.Linear)
+    assert modules[2][0] == ('submodules', 0, 'a')
+    assert isinstance(modules[2][1], nnx.Linear)
+    assert modules[3][0] == ('submodules', 1, 'b')
+    assert isinstance(modules[3][1], nnx.Conv)
+    assert modules[4][0] == ()
+    assert isinstance(modules[4][1], Foo)
+
+  def test_children_modules_iterator(self):
+    class Foo(nnx.Module):
+      def __init__(self, *, rngs: nnx.Rngs):
+        self.submodules = [
+          {'a': nnx.Linear(1, 1, rngs=rngs)},
+          {'b': nnx.Conv(1, 1, 1, rngs=rngs)},
+        ]
+        self.linear = nnx.Linear(1, 1, rngs=rngs)
+        self.dropout = nnx.Dropout(0.5, rngs=rngs)
+
+    module = Foo(rngs=nnx.Rngs(0))
+
+    modules = list(module.iter_children())
+
+    assert len(modules) == 2
+    assert modules[0][0] == 'dropout'
+    assert isinstance(modules[0][1], nnx.Dropout)
+    assert modules[1][0] == 'linear'
+    assert isinstance(modules[1][1], nnx.Linear)
 
   def test_array_in_module(self):
     class Foo(nnx.Module):