diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index 4cc499390..fe4357a9f 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -1602,7 +1602,7 @@ def _cached_partial( >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param) ... - >>> @nnx.jit + >>> @nnx.jit(graph_updates=True, graph=True) ... def train_step(model, optimizer, x, y): ... def loss_fn(model): ... return jnp.mean((model(x) - y) ** 2) @@ -1611,8 +1611,9 @@ def _cached_partial( ... optimizer.update(model, grads) ... return loss ... - >>> cached_train_step = nnx.cached_partial(train_step, model, optimizer) - ... + >>> with nnx.set_graph_mode(True): + ... with nnx.set_graph_updates(True): + ... cached_train_step = nnx.cached_partial(train_step, model, optimizer) >>> for step in range(total_steps:=2): ... x, y = jnp.ones((10, 2)), jnp.ones((10, 3)) ... # loss = train_step(model, optimizer, x, y) @@ -3156,7 +3157,7 @@ def iter_graph( >>> module = Linear(3, 4, rngs=nnx.Rngs(0)) >>> graph = [module, module] ... - >>> for path, value in nnx.iter_graph(graph): + >>> for path, value in nnx.iter_graph(graph, graph=True): ... print(path, type(value).__name__) ... (0, '_pytree__nodes') HashableMapping