Automatically treat dataclasses as pytrees #2371
Comments
Thanks for starting this! Treating dataclasses as pytrees sgtm. (Quick terminology aside: can a value really be "differentiable"? Maybe we should say "values that don't inhabit a vector space", or "non-vspace values" for short.) The non-vspace values question is interestingly related, but could be considered separately from dataclasses. After all, if we think of dataclasses as just pytrees (i.e. isomorphic to tuples) as the title of this issue proposes, then asking about non-vspace values in dataclasses would be no different than asking the same question about pytrees in general. And since a leaf is a pytree, we can just ask the question in the case of a single scalar: what should We could instead choose for dataclasses to act differently from pytrees, and then maybe do something interesting with their type annotations (which other pytree types don't have). But that seems like a different proposal from "treat dataclasses as pytrees", and may require more discussion (e.g. it may require we change our convention for Should we separate out the question of what one might do with dataclass metadata, and restrict this issue to being about treating dataclasses as pytrees? AIUI that's what your code is already up to! |
This needs tests, and likely some iteration on the code from my branch (or a new implementation) to get things working. Note that the flattening/unflattening logic here needs to be written in C++. |
Hello @shoyer! Thanks for the initiative! This is a crucial feature and one of the main things that stops me from using JAX today in day-to-day research. I would like to finish this work, and I will be able to do that after the NeurIPS 2020 deadline. You mentioned that this PR lacks testing; is there anything else? Does it still need [un]flatten implementations? Thanks! |
For context in the various other tree libraries ( Instead of treating all dataclasses as jaxtrees, could we instead create a drop in replacement for dataclass for users who know they want this behavior? Here's an example implementation which is basically a fork of flax.struct:

from dataclasses import dataclass
from typing import Any, Type, TypeVar
import jax
import jax.numpy as jnp

T = TypeVar("T")

def jax_tree(cls: T) -> T:
  is_data = lambda x: isinstance(x, jnp.ndarray) or hasattr(x, '__jax_dataclass')

  def flatten_fun(obj):
    meta = {}
    data = {}
    for k, v in obj.__dict__.items():
      if isinstance(v, list):  # We can add other containers here.
        are_data = list(map(is_data, v))
        assert all(are_data) or not any(are_data)
        data[k] = v
      elif is_data(v):
        data[k] = v
      else:
        meta[k] = v
    meta['__data_keys'] = list(data.keys())
    data = list(data.values())
    return tuple(data), tuple(meta.items())

  def unflatten_fun(meta, data):
    meta = dict(meta)
    data = dict(zip(meta.pop('__data_keys'), data))
    return cls(**meta, **data)

  jax.tree_util.register_pytree_node(cls, flatten_fun, unflatten_fun)
  cls.__jax_dataclass = True
  return dataclass(cls)

jax.tree = jax_tree

@jax.tree
class Bar:
  c: jnp.ndarray

@jax.tree
class Foo(object):
  a: jnp.ndarray
  b: Bar

>>> foo = Foo(jnp.ones([]), Bar(jnp.zeros([])))
>>> jax.tree_util.tree_leaves(foo)
[DeviceArray(1., dtype=float32), DeviceArray(0., dtype=float32)] |
@tomhennigan Unlike flax's version, it looks like you're trying to guess which components are categorized as auxiliary parameters and which are pytree-like? I tried to do that, but unfortunately there may be components that are both hashable and pytree-like, for example, elements of type I ended up modifying

for field_info in dataclasses.fields(data_clz):
    if not field_info.init:
        continue
    if field_info.metadata.get('pytree_like', True):
        tree_fields.append(field_info.name)
    else:
        hashed_fields.append(field_info.name)

Also, I fixed a problem where they don't preserve metadata passed into their field factory:

def field(pytree_like: bool = True, **kwargs: Any) -> dataclasses.Field:
    return dataclasses.field(metadata={**kwargs.pop('metadata', {}),
                                       'pytree_like': pytree_like},
                             **kwargs)

Other minor things I did were
Hope any of this is useful :) If there's one change that I would love to see, but haven't done in my own code yet, it's to declare JAX dataclasses using a mixin rather than a decorator. The benefits are that
|
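The metadata-based split described in the comment above can be sketched end to end. This is a minimal sketch under stated assumptions, not the commenter's actual code: `register_dataclass_pytree` and the `Kernel` class are hypothetical names introduced for illustration, and only the `pytree_like` metadata key and the shape of the `field` wrapper come from the comment.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


def field(pytree_like: bool = True, **kwargs: Any) -> Any:
    """dataclasses.field wrapper that records whether a field is a pytree child."""
    metadata = {**(kwargs.pop("metadata", None) or {}), "pytree_like": pytree_like}
    return dataclasses.field(metadata=metadata, **kwargs)


def register_dataclass_pytree(cls):
    """Hypothetical helper: register a dataclass using its field metadata."""
    tree_fields, hashed_fields = [], []
    for field_info in dataclasses.fields(cls):
        if not field_info.init:
            continue
        if field_info.metadata.get("pytree_like", True):
            tree_fields.append(field_info.name)
        else:
            hashed_fields.append(field_info.name)

    def flatten(obj):
        # Children are traversed by JAX; aux data is treated as static and must be hashable.
        children = tuple(getattr(obj, name) for name in tree_fields)
        aux = tuple(getattr(obj, name) for name in hashed_fields)
        return children, aux

    def unflatten(aux, children):
        kwargs = dict(zip(tree_fields, children))
        kwargs.update(zip(hashed_fields, aux))
        return cls(**kwargs)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@register_dataclass_pytree
@dataclasses.dataclass(frozen=True)
class Kernel:  # Illustrative class, not from the thread.
    lengthscale: jnp.ndarray
    variance: jnp.ndarray
    name: str = field(pytree_like=False, default="rbf")


kernel = Kernel(jnp.asarray(1.0), jnp.asarray(2.0))
leaves, treedef = jax.tree_util.tree_flatten(kernel)
restored = jax.tree_util.tree_unflatten(treedef, leaves)
```

Here only `lengthscale` and `variance` show up as leaves, while `name` travels in the tree definition, so marking a field as static is an explicit choice of the class author rather than a guess based on the value's type.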
Hello all! @tomhennigan, @NeilGirdhar thanks for your input.
Can the same reasoning be applied to
Agreed and I think the same option should be available for other structures. Actually, this is another important missing bit in JAX! I raised that issue before in #2588. Example: a non-parametric Gaussian process model with trainable hyperparameters, the lengthscale and variance of the squared exponential kernel. Often there is a need to experiment with the model by computing gradients w.r.t. only the variance, only the lengthscale, or both. I had to write different code for each case, which is super annoying considering that other frameworks (TF and PyTorch) support trainability of tensors out of the box. Long story short, I think two features will bring more users to JAX:
|
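For reference, one way to select which hyperparameters are differentiated today is the `argnums` parameter of `jax.grad`. The sketch below assumes a toy objective (the kernel value at two fixed inputs) chosen only for illustration; it requires the hyperparameters to be separate positional arguments, which is the per-case bookkeeping the comment above complains about.

```python
import jax
import jax.numpy as jnp


def squared_exponential(x1, x2, lengthscale, variance):
    """Squared exponential kernel value for two scalar inputs."""
    return variance * jnp.exp(-0.5 * ((x1 - x2) / lengthscale) ** 2)


def loss(lengthscale, variance):
    # Toy objective (an assumption for illustration): the kernel value at fixed inputs.
    return squared_exponential(0.0, 1.0, lengthscale, variance)


lengthscale, variance = 1.0, 2.0

# Differentiate w.r.t. only the variance, only the lengthscale, or both.
grad_variance = jax.grad(loss, argnums=1)(lengthscale, variance)
grad_lengthscale = jax.grad(loss, argnums=0)(lengthscale, variance)
grad_both = jax.grad(loss, argnums=(0, 1))(lengthscale, variance)
```

With a pytree of parameters instead of positional arguments, the same selection has to be done by splitting the tree by hand, which is where a per-field "trainable" marker would help.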
@mattjj @shoyer I've opened a PR with an implementation, see tensorflow/tensorflow#46894 |
There has been some offline discussion; I think the consensus is that treating dataclasses as pytrees by default is probably not something we want to do, for a few reasons. There is some past experience in TF suggesting this could be problematic: TF recurses into arbitrary data structures which leads to the need for hacky workarounds (e.g. allow/deny in AutoGraph) Better would be to allow manual registration of dataclasses, perhaps by making JAX's One side-note here: it may be better to only handle frozen dataclasses in this case, because general dataclasses are mutable, which can cause subtle bugs due to side-effects. |
I haven't thought through all the corner cases, but what if we provided a function similar to

import dataclasses
import jax

def register_pytree_node_dataclass(cls):
  _flatten = lambda obj: jax.tree_util.tree_flatten(dataclasses.asdict(obj))
  _unflatten = lambda d, children: cls(**d.unflatten(children))
  jax.tree_util.register_pytree_node(cls, _flatten, _unflatten)
  return cls

@register_pytree_node_dataclass
@dataclasses.dataclass
class DataClass:
  a: int
  b: list
  c: dict

d = DataClass(2, ['a', 'b', 'c'], {'x': 1, 'y': 2})
leaves, treedef = jax.tree_util.tree_flatten(d)
print(treedef.unflatten(leaves))
# DataClass(a=2, b=['a', 'b', 'c'], c={'x': 1, 'y': 2})

Alternatively, we could make |
Do you have more details about this hacky workaround? Is it just for namedtuple? I can understand not wanting to recurse into arbitrary classes.
I'd like all of my dataclasses to always be registered. If you have ideas for how to do this without manually registering each one, or using jax specific decorators, I think that would be a reasonable compromise. It seems that this should be enforced for namedtuple too if this is a legitimate concern.
Agreed, too bad they are mutable by default. |
Currently the I suspect the best compromise we can hope for at the moment is explicit registration. |
One note is that we do have some more heuristic logic for identifying named tuples: |
@hawkinsp - what would you think about pytree exposing a boolean flag (False by default) that optionally registers dataclasses? |
I think a flag would be confusing. Whatever behavior we have should be consistent. |
@hawkinsp yeah, the same heuristic is used in a couple places. I lifted it out to a function https://github.com/tensorflow/tensorflow/pull/46894/files#diff-f29d52c716a11de00124e98f74e14994b626a81cdde917580fb1ec16a54f0ab2R134-R138
@jakevdp so I can just write an importlib hook that will register every dataclass, which effectively removes this limitation. so that's fine, though I know many others would like this functionality by default (hence this issue) - it's better developer ergonomics. I'd still like to know why this is acceptable for

class Foo(namedtuple(field_names="test", typename="Foo")):
  test: str

@dataclass(frozen=True)
class Bar:
  test: str |
It's better developer ergonomics for people who want dataclasses registered; worse for people who don't
I think many people would prefer namedtuple to not be registered automatically, but that would be a fairly significant breaking change at this point so it's unlikely to happen. |
@jakevdp ok, fair enough, but what is the buggy behavior we're trying to avoid? "internal discussions" aren't helpful for those of us on GitHub. it seems odd that undesirable functionality will be halfway supported. maybe y'all can document it here and close the related issues as |
@jheek might want to comment here. We probably could look into unregistering |
I've been following the issue for a while. I prefer the drop-in replacement idea like the one suggested by @tomhennigan. This would allow an easy, explicit way to mark fields as static or non-static. Marking fields as static cannot be automatically done based on the types of the elements. For example, an integer element might or might not be able to be static. If the integer is the result of a tracer, it needs to be non-static, but if it's used as the limit of a scan, it has to be static. There is no way to know at definition time. I ended up forking from flax too. To mark attributes as static, I went with using a modified field constructor, but I named the parameter I like @hawkinsp suggestion of deregistering |
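The point that the same integer may need to be static in one use and traced in another can be illustrated with `jax.jit`. This is a minimal sketch: the function names are made up for the example, and `jnp.arange` stands in for the scan limit mentioned above, since both need a concrete Python integer at trace time.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def sum_of_powers(x, n):
    # n fixes the length of jnp.arange, i.e. an array shape, so it must be static.
    exponents = jnp.arange(n)
    return jnp.sum(x ** exponents)


@jax.jit
def scale(x, k):
    # k only participates in arithmetic, so it can be a traced value.
    return x * k


a = sum_of_powers(2.0, 3)  # 2**0 + 2**1 + 2**2
b = scale(2.0, 3)
```

Both functions receive the integer `3`, and nothing in its type says which treatment is needed; that depends on how the function body uses it, which is why a type-based guess at class definition time cannot work.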
named tuples are I think registering dataclasses by default (with a flag or always) does lead to problems. It might be worth considering allowing dataclasses to be registered post-hoc, e.g.:

@jax.tree_util.dataclass
class Foo:
  ...

# equivalent to
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Foo:
  ...

Having a field decorator for defining metadata can be really useful at times. In JAX it often avoids the need to define functions and flag-like arguments in static_argnums. A non-registered dataclass can also really help. One thing that comes to mind is configs. They often contain bools, integers, and floats so they can often be raised into jax but they are intended to be constants and allow for optimisations like removing dead branches and avoiding removing multiply by 0 (for example |
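The config use case can be sketched with a non-registered frozen dataclass passed as a static argument. This is an illustrative sketch, with `Config` and `apply` as hypothetical names: because the config is a compile-time constant, the Python `if` is resolved during tracing and the untaken branch never reaches the compiled computation.

```python
import dataclasses
from functools import partial

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Config:
    # Frozen dataclasses with the default eq=True are hashable,
    # which jit requires of static arguments.
    use_bias: bool = False
    scale: float = 1.0


@partial(jax.jit, static_argnums=0)
def apply(config, x, bias):
    y = config.scale * x
    if config.use_bias:  # Plain Python branch, resolved at trace time.
        y = y + bias
    return y


x = jnp.ones(3)
bias = jnp.full(3, 10.0)
without_bias = apply(Config(use_bias=False), x, bias)
with_bias = apply(Config(use_bias=True, scale=2.0), x, bias)
```

If `Config` were registered as a pytree, `use_bias` would arrive as a tracer inside `apply` and the `if` would fail, so automatic registration would change the behaviour of code like this.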
In this pr, we allow users to register a customized flatten/unflatten/serialization/deserialization for a dataclass. We provide some default implementation for flatten/unflatten. We could implement a decorator based on it when needed.
## Motivation:
HuggingFace and many internal models return dataclass output and torch.export wants to maintain the invariant that export result (i.e. exported_program) has the same calling convention and result as the original callable. This is not supported in export yet: we cannot recover the original dataclass from flattened output produced by the underlying graph module (produced by dynamo and processed further by aot_export). We need to have a place to store the metadata of the dataclass so that we can re-construct it. To avoid adding hacky code in export and allow principled extensibility, we think extending pytree may be a good option.
## Implementation:
@zou3519 mentioned https://github.com/pytorch/pytorch/pull/93214/files and [jax-2371](jax-ml/jax#2371 (comment)), which suggests that it's not a good idea to make dataclass a default pytree node but it could be good to provide a default implementation for dataclass. Since currently, this seems to be an export-only feature, we added this extension point in export. We also add "return_none_fields" flag to control whether none fields are returned after flattening, which is expected to be False in produce_matching of dynamo.export. Also added some tests.
Pull Request resolved: #106160
Approved by: https://github.com/zhxchen17
Crashing the party to give my opinion, adding a new decorator seems the most convenient to me, I would love to have this functionality. Using a decorator that is separate from the usual |
@mishmish66 you might find the dataclasses utility functions in

# Convenience decorator matching @dataclasses.dataclass:
@chex.dataclass
class MyChexDataclass:
  foo: PyTree[jax.Array]

# You can also register regular dataclasses:
@dataclasses.dataclass
class MyRegularDataclass:
  foo: PyTree[jax.Array]

chex.register_dataclass_type_with_jax_tree_util(MyRegularDataclass) |
This whole chex tool is great, thanks for the suggestion I'll definitely be using this! |
JAX should automatically treat dataclasses as pytrees, so they don't have to be explicitly registered.
Ideally we would also support some syntax for non-differentiable parameters. Flax does so by adding custom metadata into dataclasses.Field.metadata with the special flax.struct.field constructor, which seems like a very clean way to do this.
I started working on this in a branch, but haven't tested anything so it very likely is entirely broken/non-functional! If somebody wants to finish this up it would be awesome :)
master...shoyer:dataclasses-pytree
xref #1808