Skip to content

Commit

Permalink
Merge pull request #21268 from jakevdp:register-dataclass
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 634624518
  • Loading branch information
jax authors committed May 17, 2024
2 parents 0e92433 + defb53f commit 1829a66
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/jax.tree_util.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ List of Functions
Partial
all_leaves
build_tree
register_dataclass
register_pytree_node
register_pytree_node_class
register_pytree_with_keys
Expand Down
88 changes: 76 additions & 12 deletions jax/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from functools import partial
import operator as op
import textwrap
from typing import Any, Callable, NamedTuple, TypeVar, Union, overload
from typing import Any, Callable, NamedTuple, Sequence, TypeVar, Union, overload

from jax._src import traceback_util
from jax._src.lib import pytree
Expand All @@ -32,7 +32,7 @@
traceback_util.register_exclusion(__file__)

T = TypeVar("T")
U = TypeVar("U", bound=type[Any])
Typ = TypeVar("Typ", bound=type[Any])
H = TypeVar("H", bound=Hashable)

Leaf = Any
Expand Down Expand Up @@ -254,7 +254,7 @@ def register_pytree_node(nodetype: type[T],
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)


def register_pytree_node_class(cls: U) -> U:
def register_pytree_node_class(cls: Typ) -> Typ:
"""Extends the set of types that are considered internal nodes in pytrees.
This function is a thin wrapper around ``register_pytree_node``, and provides
Expand Down Expand Up @@ -807,7 +807,7 @@ def flatten_func_impl(tree):
)


def register_pytree_with_keys_class(cls: U) -> U:
def register_pytree_with_keys_class(cls: Typ) -> Typ:
"""Extends the set of types that are considered internal nodes in pytrees.
This function is similar to ``register_pytree_node_class``, but requires a
Expand Down Expand Up @@ -838,19 +838,82 @@ def tree_unflatten(cls, aux_data, children):


def register_dataclass(
nodetype: type, data_fields: list[str], meta_fields: list[str]
):
nodetype: Typ, data_fields: Sequence[str], meta_fields: Sequence[str]
) -> Typ:
"""Extends the set of types that are considered internal nodes in pytrees.
This differs from ``register_pytree_with_keys_class`` in that the C++
registries use the optimized C++ dataclass builtin instead of the argument
functions.
See :ref:`extending-pytrees` for more information about registering pytrees.
Args:
nodetype: a Python type to treat as an internal pytree node.
meta_fields: auxiliary data field names.
data_fields: data field names.
nodetype: a Python type to treat as an internal pytree node. This is assumed
to have the semantics of a :obj:`~dataclasses.dataclass`: namely, class
attributes represent the whole of the object state, and can be passed
as keywords to the class constructor to create a copy of the object.
All defined attributes should be listed among ``meta_fields`` or ``data_fields``.
meta_fields: auxiliary data field names. These fields *must* contain static,
hashable, immutable objects, as these objects are used to generate JIT cache
keys. In particular, ``meta_fields`` cannot contain :class:`jax.Array` or
:class:`numpy.ndarray` objects.
data_fields: data field names. These fields *must* be JAX-compatible objects
such as arrays (:class:`jax.Array` or :class:`numpy.ndarray`), scalars, or
pytrees whose leaves are arrays or scalars. Note that ``data_fields`` may be
``None``, as this is recognized by JAX as an empty pytree.
Returns:
The input class ``nodetype`` is returned unchanged after being added to JAX's
pytree registry. This return value allows ``register_dataclass`` to be partially
evaluated and used as a decorator as in the example below.
Example:
>>> from dataclasses import dataclass
>>> from functools import partial
>>>
>>> @partial(jax.tree_util.register_dataclass,
... data_fields=['x', 'y'],
... meta_fields=['op'])
... @dataclass
... class MyStruct:
... x: jax.Array
... y: jax.Array
... op: str
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
Now that this class is registered, it can be used with functions in :mod:`jax.tree_util`:
>>> leaves, treedef = jax.tree.flatten(m)
>>> leaves
[Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
>>> treedef
PyTreeDef(CustomNode(MyStruct[('add',)], [*, *]))
>>> jax.tree.unflatten(treedef, leaves)
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
In particular, this registration allows ``m`` to be passed seamlessly through code
wrapped in :func:`jax.jit` and other JAX transformations:
>>> @jax.jit
... def compiled_func(m):
... if m.op == 'add':
... return m.x + m.y
... else:
... raise ValueError(f"{m.op=}")
...
>>> compiled_func(m)
Array([1., 2., 3.], dtype=float32)
"""
# Store inputs as immutable tuples in this scope, because we close over them
# for later evaluation. This prevents potentially confusing behavior if the
# caller were to pass in lists that are later mutated.
meta_fields = tuple(meta_fields)
data_fields = tuple(data_fields)

def flatten_with_keys(x):
meta = tuple(getattr(x, name) for name in meta_fields)
data = tuple((GetAttrKey(name), getattr(x, name)) for name in data_fields)
Expand All @@ -867,13 +930,14 @@ def flatten_func(x):
data = tuple(getattr(x, name) for name in data_fields)
return data, meta

default_registry.register_dataclass_node(nodetype, data_fields, meta_fields)
none_leaf_registry.register_dataclass_node(nodetype, data_fields, meta_fields)
dispatch_registry.register_dataclass_node(nodetype, data_fields, meta_fields)
default_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
none_leaf_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
dispatch_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
flatten_with_keys, unflatten_func
)
return nodetype


def register_static(cls: type[H]) -> type[H]:
Expand Down

0 comments on commit 1829a66

Please sign in to comment.