From d930ccd363b47cf72d77d003dc55ae7f91b89fd2 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Tue, 17 Mar 2026 19:29:40 -0700 Subject: [PATCH] improve error messages for tree mode duplicates check PiperOrigin-RevId: 885327573 --- flax/nnx/__init__.py | 2 + flax/nnx/extract.py | 38 ++++++--- flax/nnx/graphlib.py | 129 ++++++++++++++++------------- flax/nnx/transforms/autodiff.py | 22 ++--- flax/nnx/transforms/compilation.py | 10 +-- flax/nnx/transforms/iteration.py | 14 ++-- flax/nnx/transforms/transforms.py | 14 ++-- tests/nnx/graph_utils_test.py | 60 +++++++++++++- tests/nnx/transforms_test.py | 12 +-- 9 files changed, 195 insertions(+), 106 deletions(-) diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 696dfd026..aa2a6b641 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -78,6 +78,8 @@ from .graphlib import cached_partial as cached_partial from .graphlib import flatten as flatten from .graphlib import unflatten as unflatten +from .graphlib import set_graph_mode as set_graph_mode +from .graphlib import set_graph_updates as set_graph_updates from .nn import initializers as initializers from .nn.activations import celu as celu from .nn.activations import elu as elu diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index da546f89e..54c704f75 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -335,6 +335,17 @@ def to_tree2( prefix: tp.Any = Missing, check_aliasing: bool = True, ) -> tp.Any: + """to_tree2 has two main tasks: + + 1. Convert all graph nodes to NodeStates (a tree representation). + 2. Check all Variables are aliased consistently given the prefix tree, + e.g. vmap's in/out_axes arguments. + + Each NodeState contains the `GraphDef` and State for each object, these + are generated using `graphlib.flatten`. `extract.broadcast_prefix` is used + to calculate the prefix for each node, `check_consistent_aliasing2` traverses + the nodes subgraph and checks for Variable aliasing. + """ ref_index: graphlib.RefMap = graphlib.RefMap() def _to_node_states(leaf): @@ -343,8 +354,8 @@ def _to_node_states(leaf): graphdef, flat_state = graphlib.flatten( leaf, ref_index=ref_index, graph=True ) - states = graphlib._to_nested_state(graphdef, (flat_state,)) - return NodeStates.from_split(graphdef, *states) + (state,) = graphlib._to_nested_state(graphdef, (flat_state,)) + return NodeStates.from_split(graphdef, state) is_leaf = lambda x: ( isinstance(x, variablelib.Variable) or graphlib.is_graph_node(x) @@ -503,8 +514,8 @@ def updates_and_snapshot(args: A) -> tuple[A, A]: return updates, snapshot -def check_no_aliases(**kwargs): - Attrs = namedtuple('Attrs', kwargs.keys()) +def check_no_aliases(fn_name: str, /, **kwargs): + Attrs = namedtuple('Attrs', kwargs.keys()) # type: ignore[misc] container = Attrs(**kwargs) is_leaf = lambda x: isinstance(x, variablelib.Variable) seen: dict[int, jax.tree_util.KeyPath] = {} @@ -518,9 +529,11 @@ def check_no_aliases(**kwargs): path_str = jax.tree_util.keystr(path) seen_path_str = jax.tree_util.keystr(seen[var_id]) raise ValueError( - f'Variable at {path_str} is the same instance as at ' - f'{seen_path_str}. tree-mode transforms do not support ' - f'returning input Variables as outputs.' + f'Duplicate {leaf}\nfound at paths:\n\n' + f' - {seen_path_str}\n' + f' - {path_str}\n\n' + f'nnx.{fn_name} with graph_updates=False does not support ' + 'returning input Variables as outputs.' ) seen[var_id] = path @@ -558,7 +571,9 @@ def _mask_updates(path, current, snapshot): ) -def apply_variable_updates(args_tree: A, updates_tree: A) -> None: +def apply_variable_updates( + args_tree: A, updates_tree: A, *, fn_name: str, +) -> None: is_leaf = lambda x: isinstance(x, variablelib.Variable) or isinstance(x, Mask) args_leaves, treedef = jax.tree.flatten_with_path(args_tree, is_leaf=is_leaf) updates_leaves = treedef.flatten_up_to(updates_tree) @@ -571,8 +586,11 @@ def apply_variable_updates(args_tree: A, updates_tree: A) -> None: path_str = jax.tree_util.keystr(path) seen_path_str = jax.tree_util.keystr(seen[var_id]) raise ValueError( - f'Variable at {path_str} was already seen at {seen_path_str}. ' - 'tree-mode jit does not support shared Variable references.' + f'Duplicate {variable}\nfound at paths:\n\n' + f' - {seen_path_str}\n' + f' - {path_str}\n\n' + f'Tree mode (graph=False) does not support shared references. ' + + graphlib._tree_mode_suggestion(fn_name) ) seen[var_id] = path if isinstance(update, variablelib.Variable): diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index 292d8ac56..c7df5c71c 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -51,23 +51,34 @@ def _tree_mode_suggestion(fn_name: str) -> str: return ( - f'\n\nIf the structure is intended to be a graph, consider ' - f'using graph=True or nnx.graph.{fn_name}.' + f'Consider the following options:\n\n' + '1. Remove the duplicates and guarantee a tree structure.\n' + f'2. Enable graph mode by passing graph=True to {fn_name} e.g.\n\n' + f' nnx.{fn_name}(..., graph=True)\n\n' + f'3. Use nnx.compat.{fn_name} instead e.g.\n\n' + f' nnx.compat.{fn_name}(...)' ) -def _check_valid_pytree(node: tp.Any, fn_name: str) -> None: +def _check_valid_pytree( + node: tp.Any, fn_name: str, path: str = '', +) -> None: from flax.nnx import pytreelib if ( isinstance(node, pytreelib.Pytree) and not node._pytree__is_pytree ): - raise ValueError( + msg = ( f"Cannot use '{fn_name}' with graph=False on a " f"'{type(node).__name__}' instance that has pytree=False. " + ) + if path: + msg += f"Found at path: {path}. " + msg += ( f"Pytree subclasses with pytree=False are not registered as " f"JAX pytrees and cannot be used in tree-mode. " + _tree_mode_suggestion(fn_name) ) + raise ValueError(msg) Names = tp.Sequence[int] Node = tp.TypeVar('Node') @@ -658,13 +669,38 @@ def _tree_flatten( leaves: list[tp.Any], paths: list[PathParts] | None, ) -> None: - def _is_leaf(x): + seen_variables: dict[int, str] = {} + seen_refs: dict[int, str] = {} + def _is_leaf(path, x): if isinstance(x, Variable): + var_id = id(x) + str_path = jax.tree_util.keystr(path) + if var_id in seen_variables: + raise ValueError( + f'Duplicate {x}\nfound at paths:\n\n' + f' - {seen_variables[var_id]}\n' + f' - {str_path}\n\n' + 'Tree mode (graph=False) does not support shared references. ' + + _tree_mode_suggestion('split') + ) + seen_variables[var_id] = str_path return True - _check_valid_pytree(x, 'flatten') + if variablelib.is_array_ref(x): + ref_id = id(x) + str_path = jax.tree_util.keystr(path) + if ref_id in seen_refs: + raise ValueError( + f'Duplicate {x}\nfound at paths:\n\n' + f' - {seen_refs[ref_id]}\n' + f' - {str_path}\n\n' + 'Tree mode (graph=False) does not support shared references. ' + + _tree_mode_suggestion('split') + ) + seen_refs[ref_id] = str_path + _check_valid_pytree(x, 'flatten', jax.tree_util.keystr(path)) return False jax_leaves, treedef = jax.tree_util.tree_flatten_with_path( - node, is_leaf=_is_leaf + node, is_leaf=_is_leaf, is_leaf_takes_path=True ) nnx_paths_and_leaves: list[tuple[PathParts, tp.Any]] = [ (jax_to_nnx_path(jax_path), value) for jax_path, value in jax_leaves @@ -682,29 +718,9 @@ def _is_leaf(x): ) nodes.append(tree_nodedef) - seen_variables: set[int] = set() - seen_refs: set[int] = set() sorted_leaf_index = 0 for nnx_path, value in nnx_paths_and_leaves: if isinstance(value, Variable): - var_id = id(value) - if var_id in seen_variables: - raise ValueError( - f'Duplicate Variable found at path {nnx_path!r}. ' - 'Tree mode (graph=False) does not support shared references. ' - + _tree_mode_suggestion('split') - ) - seen_variables.add(var_id) - raw_value = value.get_raw_value() - if variablelib.is_array_ref(raw_value): - ref_id = id(raw_value) - if ref_id in seen_refs: - raise ValueError( - f'Duplicate Ref found inside Variable at path {nnx_path!r}. ' - 'Tree mode (graph=False) does not support shared references. ' - + _tree_mode_suggestion('split') - ) - seen_refs.add(ref_id) nodes.append(VariableDef( type=value.var_type, index=sorted_leaf_index, @@ -712,15 +728,6 @@ def _is_leaf(x): metadata=HashableMapping(value.get_metadata()), array_refdef=None, )) - elif variablelib.is_array_ref(value): - ref_id = id(value) - if ref_id in seen_refs: - raise ValueError( - f'Duplicate Ref found at path {nnx_path!r}. ' - 'Tree mode (graph=False) does not support shared references. ' - + _tree_mode_suggestion('split') - ) - seen_refs.add(ref_id) leaves.append(value) if paths is not None: paths.append(nnx_path) @@ -2946,40 +2953,45 @@ def _iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]: def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]: - in_progress: set[int] = set() - seen_refs: set[int] = set() + in_progress: dict[int, str] = {} + seen_refs: dict[int, str] = {} stack: list[tuple[PathParts, tp.Any, bool]] = [((), node, False)] while stack: path, current, traversed = stack.pop() if traversed: - in_progress.discard(id(current)) + in_progress.pop(id(current), None) yield path, current continue if not is_pytree_node(current, check_graph_registry=False): - _check_valid_pytree(current, 'iter_graph') + _check_valid_pytree(current, 'iter_graph', '/'.join(map(str, path))) if isinstance(current, Variable) or variablelib.is_array_ref(current): obj_id = id(current) + str_path = '/'.join(map(str, path)) if obj_id in seen_refs: raise ValueError( - f'Found duplicate Variable or Ref at path ' - f'"{"/".join(map(str, path))}". ' - 'Shared references are not supported with graph=False. ' + f'Duplicate {current}\nfound at paths:\n\n' + f' - {seen_refs[obj_id]}\n' + f' - {str_path}\n\n' + 'Tree mode (graph=False) does not support shared references. ' + _tree_mode_suggestion('iter_graph') ) - seen_refs.add(obj_id) + seen_refs[obj_id] = str_path yield path, current continue obj_id = id(current) + str_path = '/'.join(map(str, path)) if obj_id in in_progress: raise ValueError( - f'Found cycle at path "{"/".join(map(str, path))}". ' + f'Cycle detected for {type(current).__name__}\nfound at paths:\n\n' + f' - {in_progress[obj_id]}\n' + f' - {str_path}\n\n' 'Cycles are not supported with graph=False. ' + _tree_mode_suggestion('iter_graph') ) - in_progress.add(obj_id) + in_progress[obj_id] = str_path stack.append((path, current, True)) children, _ = jax.tree_util.tree_flatten_with_path( @@ -3156,32 +3168,37 @@ def _recursive_map_tree( f: tp.Callable[[PathParts, tp.Any], tp.Any], node: tp.Any, ) -> tp.Any: - in_progress: set[int] = set() - seen_refs: set[int] = set() + in_progress: dict[int, str] = {} + seen_refs: dict[int, str] = {} def _recurse(path: PathParts, current: tp.Any) -> tp.Any: if not is_pytree_node(current, check_graph_registry=False): - _check_valid_pytree(current, 'recursive_map') + _check_valid_pytree(current, 'recursive_map', '/'.join(map(str, path))) if isinstance(current, Variable) or is_array_ref(current): obj_id = id(current) + str_path = '/'.join(map(str, path)) if obj_id in seen_refs: raise ValueError( - f'Found duplicate Variable or Ref at path ' - f'"{"/".join(map(str, path))}". ' - 'Shared references are not supported with graph=False. ' + f'Duplicate {current}\nfound at paths:\n\n' + f' - {seen_refs[obj_id]}\n' + f' - {str_path}\n\n' + 'Tree mode (graph=False) does not support shared references. ' + _tree_mode_suggestion('recursive_map') ) - seen_refs.add(obj_id) + seen_refs[obj_id] = str_path return f(path, current) obj_id = id(current) + str_path = '/'.join(map(str, path)) if obj_id in in_progress: raise ValueError( - f'Found cycle at path "{"/".join(map(str, path))}". ' + f'Cycle detected for {type(current).__name__}\nfound at paths:\n\n' + f' - {in_progress[obj_id]}\n' + f' - {str_path}\n\n' 'Cycles are not supported with graph=False. ' + _tree_mode_suggestion('recursive_map') ) - in_progress.add(obj_id) + in_progress[obj_id] = str_path children_with_path, treedef = jax.tree_util.tree_flatten_with_path( current, is_leaf=lambda x: x is not current @@ -3195,7 +3212,7 @@ def _recurse(path: PathParts, current: tp.Any) -> tp.Any: new_node = treedef.unflatten(new_children) result = f(path, new_node) - in_progress.discard(obj_id) + in_progress.pop(obj_id, None) return result return _recurse((), node) diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 15e941e19..19d421769 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -79,7 +79,7 @@ def __call__(self, *args, **kwargs): out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases(args=updates[0], kwargs=updates[1], out=out) + extract.check_no_aliases('grad', args=updates[0], kwargs=updates[1], out=out) updates = extract.mask_variable_updates(updates, snapshot) if self.has_aux: @@ -189,7 +189,7 @@ def tree_grad_wrapper(*args, **kwargs): if graph: grads = extract.from_tree2(grads) result = grads - extract.apply_variable_updates((args, kwargs), updates) + extract.apply_variable_updates((args, kwargs), updates, fn_name='grad') return result return tree_grad_wrapper @@ -570,7 +570,7 @@ def __call__(self, *args): out = self.f(*args) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases(args=updates, out=out) + extract.check_no_aliases('vjp', args=updates, out=out) updates = extract.mask_variable_updates(updates, snapshot) if self.has_aux: primals_out, aux = out @@ -706,7 +706,7 @@ def vjp( raw_vjp_fn = vjp_fn def vjp_fn(g): return extract.from_tree2(raw_vjp_fn(g)) - extract.apply_variable_updates(primals, updates) + extract.apply_variable_updates(primals, updates, fn_name='vjp') if has_aux: return primals_out, vjp_fn, user_aux else: @@ -735,7 +735,7 @@ def __call__(self, *args): out = self.f(*args) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases(args=updates, out=out) + extract.check_no_aliases('jvp', args=updates, out=out) updates = extract.mask_variable_updates(updates, snapshot) if self.has_aux: primals_out, aux = out @@ -881,7 +881,7 @@ def jvp( if graph: primals_out = extract.from_tree2(primals_out) tangent_out = extract.from_tree2(tangent_out) - extract.apply_variable_updates(primals, updates) + extract.apply_variable_updates(primals, updates, fn_name='jvp') if has_aux: return primals_out, tangent_out, aux else: @@ -909,7 +909,7 @@ def __call__(self, *args): out = self.f(*args) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases(args=updates, out=out) + extract.check_no_aliases('custom_vjp', args=updates, out=out) updates = extract.mask_variable_updates(updates, snapshot) return out, updates @@ -931,7 +931,7 @@ def __call__(self, *args): if self.graph: out = extract.to_tree2(out) residual = extract.to_tree2(residual) - extract.check_no_aliases(args=updates, out=out) + extract.check_no_aliases('custom_vjp', args=updates, out=out) updates = extract.mask_variable_updates(updates, snapshot) return (out, updates), residual @@ -1007,7 +1007,7 @@ def __call__( ) if self.graph: out = extract.from_tree2(out) - extract.apply_variable_updates(args, updates) + extract.apply_variable_updates(args, updates, fn_name='custom_vjp') return out def defvjp( @@ -1567,7 +1567,7 @@ def __call__(self, *args, **kwargs): out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases(args=updates[0], kwargs=updates[1], out=out) + extract.check_no_aliases('remat', args=updates[0], kwargs=updates[1], out=out) updates = extract.mask_variable_updates(updates, snapshot) return out, updates @@ -1664,7 +1664,7 @@ def simple_remat_wrapper(*args, **kwargs): out, updates = checkpointed_fn(*args, **kwargs) if graph: out = extract.from_tree2(out) - extract.apply_variable_updates((args, kwargs), updates) + extract.apply_variable_updates((args, kwargs), updates, fn_name='remat') return out return simple_remat_wrapper # type: ignore[return-value] diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index 4d88a733f..70c908659 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -466,7 +466,7 @@ def __call__(self, *args, **kwargs): out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out, prefix=self.out_shardings) - extract.check_no_aliases(args=args_updates, kwargs=kwargs_updates, out=out) + extract.check_no_aliases('jit', args=args_updates, kwargs=kwargs_updates, out=out) def donated_arg(jax_path, c, s): path = graphlib.jax_to_nnx_path(jax_path) return path[0] in self.donate_argnums or extract.variable_changed(c, s) @@ -573,7 +573,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: args, kwargs = self._maybe_to_tree(args, kwargs) out, updates = self.jitted_fn(*self.partial_args, *args, **kwargs) extract.apply_variable_updates( - ((*self.partial_args, *args), kwargs), updates) + ((*self.partial_args, *args), kwargs), updates, fn_name='jit') return self._maybe_from_tree(out) def __get__(self, obj, objtype=None): @@ -1150,7 +1150,7 @@ def call(*args, **kwargs): def __call__(self, *args, **kwargs): args, kwargs = self.jit_wrapped._maybe_to_tree(args, kwargs) out, updates = self.compiled(*args, **kwargs) - extract.apply_variable_updates((args, kwargs), updates) + extract.apply_variable_updates((args, kwargs), updates, fn_name='jit') return self.jit_wrapped._maybe_from_tree(out) @property @@ -1261,7 +1261,7 @@ def __call__(self, *args): out = self.f(*args) if self.graph: out = extract.to_tree2(out, prefix=self.out_specs) - extract.check_no_aliases(args=updates, out=out) + extract.check_no_aliases('shard_map', args=updates, out=out) updates = extract.mask_variable_updates(updates, snapshot) return out, updates @@ -1546,7 +1546,7 @@ def shard_map_wrapper(*args, **kwargs): check_aliasing=in_specs is not None, ) out, updates = shard_map_fn(*args, **kwargs) - extract.apply_variable_updates(args, updates) + extract.apply_variable_updates(args, updates, fn_name='shard_map') if graph: out = extract.from_tree2(out) return out diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 2a676e9c5..9652ba207 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -139,7 +139,7 @@ def wrapper(*in_args, **in_kwargs): _apply_axis_fn(args, in_axes, metadata, spmd.add_axis) _apply_axis_fn(out, out_axes, metadata, spmd.add_axis) updates = extract.mask_variable_updates(updates, snapshot) - extract.apply_variable_updates(in_args, updates) + extract.apply_variable_updates(in_args, updates, fn_name='transform_metadata') if graph: out = extract.from_tree2(out) return out @@ -275,7 +275,7 @@ def __call__(self, *args, **kwargs): out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out, prefix=self.out_axes) - extract.check_no_aliases(args=updates[0], kwargs=updates[1], out=out) + extract.check_no_aliases('vmap', args=updates[0], kwargs=updates[1], out=out) updates = extract.mask_variable_updates(updates, snapshot) return out, updates @@ -297,7 +297,7 @@ def __call__(self, *args, **kwargs): out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out, prefix=self.out_axes) - extract.check_no_aliases(args=updates[0], kwargs=updates[1], out=out) + extract.check_no_aliases('pmap', args=updates[0], kwargs=updates[1], out=out) updates = extract.mask_variable_updates(updates, snapshot) return out, updates @@ -525,7 +525,7 @@ def simple_vmap_wrapper(*args, **kwargs): check_aliasing=in_axes is not None, ) out, updates = vmapped_fn(*args, **kwargs) - extract.apply_variable_updates((args, kwargs), updates) + extract.apply_variable_updates((args, kwargs), updates, fn_name='vmap') if graph: out = extract.from_tree2(out) return out @@ -792,7 +792,7 @@ def simple_pmap_wrapper(*args, **kwargs): check_aliasing=in_axes is not None, ) out, updates = pmapped_fn(*args, **kwargs) - extract.apply_variable_updates((args, kwargs), updates) + extract.apply_variable_updates((args, kwargs), updates, fn_name='pmap') if graph: out = extract.from_tree2(out) return out @@ -1417,7 +1417,7 @@ def keep_fn(path, cur, snap): ) return changed - extract.check_no_aliases(args=masked_carry_updates, out=out) + extract.check_no_aliases('scan', args=masked_carry_updates, out=out) masked_carry_updates = extract.mask_variable_updates( masked_carry_updates, masked_carry_snapshot, keep_fn=keep_fn, ) @@ -1687,7 +1687,7 @@ def simple_scan_wrapper(*args): out, updates = result masked_args = extract.mask_at(args, carry_arg_index) - extract.apply_variable_updates(masked_args, updates) + extract.apply_variable_updates(masked_args, updates, fn_name='scan') if carry_arg_index is not None: carry_in = args[carry_arg_index] diff --git a/flax/nnx/transforms/transforms.py b/flax/nnx/transforms/transforms.py index eb128edca..b01fff4a6 100644 --- a/flax/nnx/transforms/transforms.py +++ b/flax/nnx/transforms/transforms.py @@ -269,7 +269,7 @@ def __call__(self, *args, **kwargs): out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases(args=args, kwargs=kwargs, out=out) + extract.check_no_aliases('eval_shape', args=args, kwargs=kwargs, out=out) return out @@ -314,7 +314,7 @@ def eval_shape( if not graph or not graph_updates: if graph: args, kwargs = extract.to_tree2((args, kwargs)) - extract.check_no_aliases(args=args, kwargs=kwargs) + extract.check_no_aliases('eval_shape', args=args, kwargs=kwargs) out = jax.eval_shape( SimpleEvalShapeFn(f_call, graph=graph), *args, **kwargs ) @@ -369,7 +369,7 @@ def __call__(self, *args): out = self.f(*args) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases(args=updates, out=out) + extract.check_no_aliases('checkify', args=updates, out=out) updates = extract.mask_variable_updates(updates, snapshot) return out, updates @@ -437,7 +437,7 @@ def simple_checkify_wrapper(*args): error, (out, updates) = checkify_fn(*args) if graph: out = extract.from_tree2(out) - extract.apply_variable_updates(args, updates) + extract.apply_variable_updates(args, updates, fn_name='checkify') return error, out return simple_checkify_wrapper # type: ignore @@ -481,7 +481,7 @@ def __call__(self, *args): out = self.f(*args) if self.graph: out = extract.to_tree2(out) - extract.check_no_aliases(args=updates, out=out) + extract.check_no_aliases('switch', args=updates, out=out) return out, updates @@ -527,7 +527,7 @@ def cond( ) if graph: out = extract.from_tree2(out) - extract.apply_variable_updates(operands, updates) + extract.apply_variable_updates(operands, updates, fn_name='cond') return out @general.split_inputs(ctxtag='cond') @@ -580,7 +580,7 @@ def switch( ) if graph: out = extract.from_tree2(out) - extract.apply_variable_updates(operands, updates) + extract.apply_variable_updates(operands, updates, fn_name='switch') return out @general.split_inputs(ctxtag='switch') diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index aec5af0fa..88bc8c026 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -1244,6 +1244,14 @@ def test_graphdef_hash_with_sequential(self): ) hash(nnx.graphdef(net)) + @nnx.set_graph_mode(False) + def test_split_graph_error(self): + v = nnx.Variable(jnp.array(1.0)) + with self.assertRaisesRegex( + ValueError, 'found at paths' + ): + graphdef, state = nnx.split((v, v)) + class SimpleModule(nnx.Module): pass @@ -1481,7 +1489,7 @@ def test_iter_graph_tree_mode_shared_variable_raises(self): root.b = var with self.assertRaisesRegex( - ValueError, 'Shared references are not supported with graph=False' + ValueError, 'found at paths' ): list(nnx.iter_graph(root, graph=False)) @@ -1491,7 +1499,7 @@ def test_iter_graph_tree_mode_cycle_raises(self): a.append(b) with self.assertRaisesRegex( - ValueError, 'Cycles are not supported with graph=False' + ValueError, 'found at paths' ): list(nnx.iter_graph(a, graph=False)) @@ -1586,7 +1594,7 @@ def test_recursive_map_tree_mode_shared_variable_raises(self): g = [v, v] with self.assertRaisesRegex( - ValueError, 'Shared references are not supported with graph=False' + ValueError, 'found at paths' ): nnx.recursive_map(lambda path, node: node, g, graph=False) @@ -1596,10 +1604,54 @@ def test_recursive_map_tree_mode_cycle_raises(self): a.append(b) with self.assertRaisesRegex( - ValueError, 'Cycles are not supported with graph=False' + ValueError, 'found at paths' ): nnx.recursive_map(lambda path, node: node, a, graph=False) + def test_check_valid_pytree_flatten(self): + class NotAPytree(nnx.Pytree, pytree=False): + def __init__(self): + self.x = 1 + + node = [NotAPytree()] + with self.assertRaisesRegex( + ValueError, "pytree=False.*Found at path" + ): + nnx.graphlib.flatten(node, graph=False) + + def test_check_valid_pytree_iter_graph(self): + class NotAPytree(nnx.Pytree, pytree=False): + def __init__(self): + self.x = 1 + + node = nnx.List([NotAPytree()]) + with self.assertRaisesRegex( + ValueError, "pytree=False.*Found at path" + ): + list(nnx.iter_graph(node, graph=False)) + + def test_check_valid_pytree_iter_children(self): + class NotAPytree(nnx.Pytree, pytree=False): + def __init__(self): + self.x = 1 + + node = NotAPytree() + with self.assertRaisesRegex( + ValueError, "pytree=False" + ): + list(nnx.iter_children(node, graph=False)) + + def test_check_valid_pytree_recursive_map(self): + class NotAPytree(nnx.Pytree, pytree=False): + def __init__(self): + self.x = 1 + + node = nnx.List([NotAPytree()]) + with self.assertRaisesRegex( + ValueError, "pytree=False.*Found at path" + ): + nnx.recursive_map(lambda path, node: node, node, graph=False) + if __name__ == '__main__': absltest.main() diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 3dc909cc5..ee30506fd 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -32,7 +32,6 @@ from flax import errors - class TestJIT(parameterized.TestCase): def test_jit(self): m = nnx.Dict(a=nnx.Param(1)) @@ -617,7 +616,7 @@ def test_tree_jit_no_input_output_aliasing(self): def f(v): return v - with self.assertRaisesRegex(ValueError, 'same instance'): + with self.assertRaisesRegex(ValueError, 'does not support returning input Variables as outputs'): f(v) def test_tree_jit_no_shared_variable_refs(self): @@ -625,10 +624,11 @@ def test_tree_jit_no_shared_variable_refs(self): @nnx.jit(graph=False) def f(v1, v2): - v1[...] += 1 - return None + pass - with self.assertRaisesRegex(ValueError, 'already seen'): + with self.assertRaisesRegex( + ValueError, 'found at paths' + ): f(v, v) def test_tree_jit_new_variable_output_ok(self): @@ -3471,7 +3471,7 @@ def test_scan_input_output_aliasing(self): def f(carry): return carry - with self.assertRaisesRegex(ValueError, 'same instance'): + with self.assertRaisesRegex(ValueError, 'does not support returning input Variables as outputs'): f(v)