Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@
from .graphlib import cached_partial as cached_partial
from .graphlib import flatten as flatten
from .graphlib import unflatten as unflatten
from .graphlib import set_graph_mode as set_graph_mode
from .graphlib import set_graph_updates as set_graph_updates
from .nn import initializers as initializers
from .nn.activations import celu as celu
from .nn.activations import elu as elu
Expand Down
38 changes: 28 additions & 10 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,17 @@ def to_tree2(
prefix: tp.Any = Missing,
check_aliasing: bool = True,
) -> tp.Any:
"""to_tree2 has two main tasks:

1. Convert all graph nodes to NodeStates (a tree representation).
2. Check all Variables are aliased consistently given the prefix tree,
e.g. vmap's in/out_axes arguments.

Each NodeState contains the `GraphDef` and State for each object, these
are generated using `graphlib.flatten`. `extract.broadcast_prefix` is used
to calculate the prefix for each node, `check_consistent_aliasing2` traverses
the nodes subgraph and checks for Variable aliasing.
"""
ref_index: graphlib.RefMap = graphlib.RefMap()

def _to_node_states(leaf):
Expand All @@ -343,8 +354,8 @@ def _to_node_states(leaf):
graphdef, flat_state = graphlib.flatten(
leaf, ref_index=ref_index, graph=True
)
states = graphlib._to_nested_state(graphdef, (flat_state,))
return NodeStates.from_split(graphdef, *states)
(state,) = graphlib._to_nested_state(graphdef, (flat_state,))
return NodeStates.from_split(graphdef, state)

is_leaf = lambda x: (
isinstance(x, variablelib.Variable) or graphlib.is_graph_node(x)
Expand Down Expand Up @@ -503,8 +514,8 @@ def updates_and_snapshot(args: A) -> tuple[A, A]:
return updates, snapshot


def check_no_aliases(**kwargs):
Attrs = namedtuple('Attrs', kwargs.keys())
def check_no_aliases(fn_name: str, /, **kwargs):
Attrs = namedtuple('Attrs', kwargs.keys()) # type: ignore[misc]
container = Attrs(**kwargs)
is_leaf = lambda x: isinstance(x, variablelib.Variable)
seen: dict[int, jax.tree_util.KeyPath] = {}
Expand All @@ -518,9 +529,11 @@ def check_no_aliases(**kwargs):
path_str = jax.tree_util.keystr(path)
seen_path_str = jax.tree_util.keystr(seen[var_id])
raise ValueError(
f'Variable at {path_str} is the same instance as at '
f'{seen_path_str}. tree-mode transforms do not support '
f'returning input Variables as outputs.'
f'Duplicate {leaf}\nfound at paths:\n\n'
f' - {seen_path_str}\n'
f' - {path_str}\n\n'
f'nnx.{fn_name} with graph_updates=False does not support '
'returning input Variables as outputs.'
)
seen[var_id] = path

Expand Down Expand Up @@ -558,7 +571,9 @@ def _mask_updates(path, current, snapshot):
)


def apply_variable_updates(args_tree: A, updates_tree: A) -> None:
def apply_variable_updates(
args_tree: A, updates_tree: A, *, fn_name: str,
) -> None:
is_leaf = lambda x: isinstance(x, variablelib.Variable) or isinstance(x, Mask)
args_leaves, treedef = jax.tree.flatten_with_path(args_tree, is_leaf=is_leaf)
updates_leaves = treedef.flatten_up_to(updates_tree)
Expand All @@ -571,8 +586,11 @@ def apply_variable_updates(args_tree: A, updates_tree: A) -> None:
path_str = jax.tree_util.keystr(path)
seen_path_str = jax.tree_util.keystr(seen[var_id])
raise ValueError(
f'Variable at {path_str} was already seen at {seen_path_str}. '
'tree-mode jit does not support shared Variable references.'
f'Duplicate {variable}\nfound at paths:\n\n'
f' - {seen_path_str}\n'
f' - {path_str}\n\n'
f'Tree mode (graph=False) does not support shared references. '
+ graphlib._tree_mode_suggestion(fn_name)
)
seen[var_id] = path
if isinstance(update, variablelib.Variable):
Expand Down
129 changes: 73 additions & 56 deletions flax/nnx/graphlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,23 +51,34 @@

def _tree_mode_suggestion(fn_name: str) -> str:
return (
f'\n\nIf the structure is intended to be a graph, consider '
f'using graph=True or nnx.graph.{fn_name}.'
f'Consider the following options:\n\n'
'1. Remove the duplicates and guarantee a tree structure.\n'
f'2. Enable graph mode by passing graph=True to {fn_name} e.g.\n\n'
f' nnx.{fn_name}(..., graph=True)\n\n'
f'3. Use nnx.compat.{fn_name} instead e.g.\n\n'
f' nnx.compat.{fn_name}(...)'
)

def _check_valid_pytree(node: tp.Any, fn_name: str) -> None:
def _check_valid_pytree(
node: tp.Any, fn_name: str, path: str = '',
) -> None:
from flax.nnx import pytreelib
if (
isinstance(node, pytreelib.Pytree)
and not node._pytree__is_pytree
):
raise ValueError(
msg = (
f"Cannot use '{fn_name}' with graph=False on a "
f"'{type(node).__name__}' instance that has pytree=False. "
)
if path:
msg += f"Found at path: {path}. "
msg += (
f"Pytree subclasses with pytree=False are not registered as "
f"JAX pytrees and cannot be used in tree-mode. "
+ _tree_mode_suggestion(fn_name)
)
raise ValueError(msg)

Names = tp.Sequence[int]
Node = tp.TypeVar('Node')
Expand Down Expand Up @@ -658,13 +669,38 @@ def _tree_flatten(
leaves: list[tp.Any],
paths: list[PathParts] | None,
) -> None:
def _is_leaf(x):
seen_variables: dict[int, str] = {}
seen_refs: dict[int, str] = {}
def _is_leaf(path, x):
if isinstance(x, Variable):
var_id = id(x)
str_path = jax.tree_util.keystr(path)
if var_id in seen_variables:
raise ValueError(
f'Duplicate {x}\nfound at paths:\n\n'
f' - {seen_variables[var_id]}\n'
f' - {str_path}\n\n'
'Tree mode (graph=False) does not support shared references. '
+ _tree_mode_suggestion('split')
)
seen_variables[var_id] = str_path
return True
_check_valid_pytree(x, 'flatten')
if variablelib.is_array_ref(x):
ref_id = id(x)
str_path = jax.tree_util.keystr(path)
if ref_id in seen_refs:
raise ValueError(
f'Duplicate {x}\nfound at paths:\n\n'
f' - {seen_refs[ref_id]}\n'
f' - {str_path}\n\n'
'Tree mode (graph=False) does not support shared references. '
+ _tree_mode_suggestion('split')
)
seen_refs[ref_id] = str_path
_check_valid_pytree(x, 'flatten', jax.tree_util.keystr(path))
return False
jax_leaves, treedef = jax.tree_util.tree_flatten_with_path(
node, is_leaf=_is_leaf
node, is_leaf=_is_leaf, is_leaf_takes_path=True
)
nnx_paths_and_leaves: list[tuple[PathParts, tp.Any]] = [
(jax_to_nnx_path(jax_path), value) for jax_path, value in jax_leaves
Expand All @@ -682,45 +718,16 @@ def _is_leaf(x):
)
nodes.append(tree_nodedef)

seen_variables: set[int] = set()
seen_refs: set[int] = set()
sorted_leaf_index = 0
for nnx_path, value in nnx_paths_and_leaves:
if isinstance(value, Variable):
var_id = id(value)
if var_id in seen_variables:
raise ValueError(
f'Duplicate Variable found at path {nnx_path!r}. '
'Tree mode (graph=False) does not support shared references. '
+ _tree_mode_suggestion('split')
)
seen_variables.add(var_id)
raw_value = value.get_raw_value()
if variablelib.is_array_ref(raw_value):
ref_id = id(raw_value)
if ref_id in seen_refs:
raise ValueError(
f'Duplicate Ref found inside Variable at path {nnx_path!r}. '
'Tree mode (graph=False) does not support shared references. '
+ _tree_mode_suggestion('split')
)
seen_refs.add(ref_id)
nodes.append(VariableDef(
type=value.var_type,
index=sorted_leaf_index,
outer_index=None,
metadata=HashableMapping(value.get_metadata()),
array_refdef=None,
))
elif variablelib.is_array_ref(value):
ref_id = id(value)
if ref_id in seen_refs:
raise ValueError(
f'Duplicate Ref found at path {nnx_path!r}. '
'Tree mode (graph=False) does not support shared references. '
+ _tree_mode_suggestion('split')
)
seen_refs.add(ref_id)
leaves.append(value)
if paths is not None:
paths.append(nnx_path)
Expand Down Expand Up @@ -2946,40 +2953,45 @@ def _iter_graph(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:


def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
in_progress: set[int] = set()
seen_refs: set[int] = set()
in_progress: dict[int, str] = {}
seen_refs: dict[int, str] = {}
stack: list[tuple[PathParts, tp.Any, bool]] = [((), node, False)]
while stack:
path, current, traversed = stack.pop()

if traversed:
in_progress.discard(id(current))
in_progress.pop(id(current), None)
yield path, current
continue

if not is_pytree_node(current, check_graph_registry=False):
_check_valid_pytree(current, 'iter_graph')
_check_valid_pytree(current, 'iter_graph', '/'.join(map(str, path)))
if isinstance(current, Variable) or variablelib.is_array_ref(current):
obj_id = id(current)
str_path = '/'.join(map(str, path))
if obj_id in seen_refs:
raise ValueError(
f'Found duplicate Variable or Ref at path '
f'"{"/".join(map(str, path))}". '
'Shared references are not supported with graph=False. '
f'Duplicate {current}\nfound at paths:\n\n'
f' - {seen_refs[obj_id]}\n'
f' - {str_path}\n\n'
'Tree mode (graph=False) does not support shared references. '
+ _tree_mode_suggestion('iter_graph')
)
seen_refs.add(obj_id)
seen_refs[obj_id] = str_path
yield path, current
continue

obj_id = id(current)
str_path = '/'.join(map(str, path))
if obj_id in in_progress:
raise ValueError(
f'Found cycle at path "{"/".join(map(str, path))}". '
f'Cycle detected for {type(current).__name__}\nfound at paths:\n\n'
f' - {in_progress[obj_id]}\n'
f' - {str_path}\n\n'
'Cycles are not supported with graph=False. '
+ _tree_mode_suggestion('iter_graph')
)
in_progress.add(obj_id)
in_progress[obj_id] = str_path

stack.append((path, current, True))
children, _ = jax.tree_util.tree_flatten_with_path(
Expand Down Expand Up @@ -3156,32 +3168,37 @@ def _recursive_map_tree(
f: tp.Callable[[PathParts, tp.Any], tp.Any],
node: tp.Any,
) -> tp.Any:
in_progress: set[int] = set()
seen_refs: set[int] = set()
in_progress: dict[int, str] = {}
seen_refs: dict[int, str] = {}

def _recurse(path: PathParts, current: tp.Any) -> tp.Any:
if not is_pytree_node(current, check_graph_registry=False):
_check_valid_pytree(current, 'recursive_map')
_check_valid_pytree(current, 'recursive_map', '/'.join(map(str, path)))
if isinstance(current, Variable) or is_array_ref(current):
obj_id = id(current)
str_path = '/'.join(map(str, path))
if obj_id in seen_refs:
raise ValueError(
f'Found duplicate Variable or Ref at path '
f'"{"/".join(map(str, path))}". '
'Shared references are not supported with graph=False. '
f'Duplicate {current}\nfound at paths:\n\n'
f' - {seen_refs[obj_id]}\n'
f' - {str_path}\n\n'
'Tree mode (graph=False) does not support shared references. '
+ _tree_mode_suggestion('recursive_map')
)
seen_refs.add(obj_id)
seen_refs[obj_id] = str_path
return f(path, current)

obj_id = id(current)
str_path = '/'.join(map(str, path))
if obj_id in in_progress:
raise ValueError(
f'Found cycle at path "{"/".join(map(str, path))}". '
f'Cycle detected for {type(current).__name__}\nfound at paths:\n\n'
f' - {in_progress[obj_id]}\n'
f' - {str_path}\n\n'
'Cycles are not supported with graph=False. '
+ _tree_mode_suggestion('recursive_map')
)
in_progress.add(obj_id)
in_progress[obj_id] = str_path

children_with_path, treedef = jax.tree_util.tree_flatten_with_path(
current, is_leaf=lambda x: x is not current
Expand All @@ -3195,7 +3212,7 @@ def _recurse(path: PathParts, current: tp.Any) -> tp.Any:
new_node = treedef.unflatten(new_children)
result = f(path, new_node)

in_progress.discard(obj_id)
in_progress.pop(obj_id, None)
return result

return _recurse((), node)
Expand Down
Loading
Loading