forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[export] allow register dataclass as pytree node (pytorch#106160)
In this pr, we allow users to register a customized flatten/unflatten/serialization/deserialization for a dataclass. We provide some default implementation for flatten/unflatten. We could implement a decorator based on it when needed. ## Motivation: HuggingFace and many internal models return dataclass output and torch.export wants to maintain the invariant that export result (i.e. exported_program) has the same calling convention and result as the original callable. This is not supported in export yet: we cannot recover the original dataclass from flattened output produced by the underlying graph module (produced by dynamo and processed further by aot_export). We need to have a place to store the metadata of the dataclass so that we can re-construct it. To avoid adding hacky code in export and allow princinpled extensibility, we think extending pytree may be a good option. ## Implementation: @zou3519 mentioned https://github.com/pytorch/pytorch/pull/93214/files and [jax-2371](google/jax#2371 (comment)), which suggests that it's not a good idea to make dataclass a default pytree node but it could be good to provide a default implementation for dataclass. Since currently, this seems to be an export-only feature, we added this extension point in export. We also add "return_none_fields" flag to control whether none fields are returned after flattening, which is expected to be False in produce_matching of dynamo.export. Also added some tests. Pull Request resolved: pytorch#106160 Approved by: https://github.com/zhxchen17
- Loading branch information
1 parent
d4a8dbe
commit 4aa1590
Showing
2 changed files
with
143 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import dataclasses | ||
|
||
from typing import Any, List, Optional, Tuple | ||
|
||
from torch.utils._pytree import ( | ||
_register_pytree_node, | ||
Context, | ||
FlattenFunc, | ||
MaybeFromStrFunc, | ||
ToStrFunc, | ||
UnflattenFunc, | ||
) | ||
|
||
|
||
def register_dataclass_as_pytree_node( | ||
typ: Any, | ||
flatten_fn: Optional[FlattenFunc] = None, | ||
unflatten_fn: Optional[UnflattenFunc] = None, | ||
to_str_fn: Optional[ToStrFunc] = None, | ||
maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, | ||
*, | ||
return_none_fields: bool = False, | ||
) -> None: | ||
assert dataclasses.is_dataclass( | ||
typ | ||
), f"Only dataclasses can be registered with this function: {typ}" | ||
|
||
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]: | ||
flattened = [] | ||
flat_names = [] | ||
none_names = [] | ||
for f in dataclasses.fields(obj): | ||
name, val = f.name, getattr(obj, f.name) | ||
if val is not None or return_none_fields: | ||
flattened.append(val) | ||
flat_names.append(name) | ||
else: | ||
none_names.append(name) | ||
return flattened, (typ, flat_names, none_names) | ||
|
||
def default_unflatten_fn(values: List[Any], context: Context) -> Any: | ||
typ, flat_names, none_names = context | ||
return typ(**dict(zip(flat_names, values)), **{k: None for k in none_names}) | ||
|
||
flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn | ||
unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn | ||
|
||
_register_pytree_node( | ||
typ, | ||
flatten_fn, | ||
unflatten_fn, | ||
None, | ||
None, | ||
) |