[export] allow register dataclass as pytree node (pytorch#106160)
In this PR, we allow users to register customized flatten/unflatten/serialization/deserialization functions for a dataclass. We provide default implementations for flatten/unflatten; a decorator could be built on top of this when needed.
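
A minimal usage sketch of the new helper (the `Output` dataclass here is a hypothetical stand-in for a HuggingFace-style model output):

```python
from dataclasses import dataclass
from typing import Optional

import torch
from torch._export.utils import register_dataclass_as_pytree_node
from torch.utils._pytree import tree_flatten, tree_unflatten


@dataclass
class Output:
    logits: torch.Tensor
    loss: Optional[torch.Tensor] = None  # optional field, dropped by the default flatten


register_dataclass_as_pytree_node(Output)

out = Output(logits=torch.randn(2, 3))
flat, spec = tree_flatten(out)         # flat == [out.logits]; `loss` is None and omitted
restored = tree_unflatten(flat, spec)  # rebuilds an Output with loss=None
```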

## Motivation:
HuggingFace and many internal models return dataclass outputs, and torch.export wants to maintain the invariant that the export result (i.e. the exported_program) has the same calling convention and return structure as the original callable.

This is not yet supported in export: we cannot recover the original dataclass from the flattened output produced by the underlying graph module (produced by dynamo and processed further by aot_export). We need a place to store the dataclass's metadata so that we can reconstruct it. To avoid adding hacky code to export and to allow principled extensibility, we think extending pytree is a good option.

## Implementation:
@zou3519 mentioned https://github.com/pytorch/pytorch/pull/93214/files and [jax-2371](google/jax#2371 (comment)), which suggest that making dataclasses default pytree nodes is not a good idea, but that providing a default implementation for dataclasses could be useful. Since this currently seems to be an export-only feature, we added the extension point in export.

We also add a `return_none_fields` flag to control whether fields that are None are returned after flattening; it is expected to be False in `produce_matching` of dynamo.export.
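
For illustration, a sketch of how the flag changes the flattened output (mirroring the tests below; `Point` is a hypothetical dataclass):

```python
from dataclasses import dataclass
from typing import Optional

from torch._export.utils import register_dataclass_as_pytree_node
from torch.utils._pytree import tree_flatten


@dataclass
class Point:
    x: int
    y: int
    z: Optional[int] = None


p = Point(x=1, y=2)

register_dataclass_as_pytree_node(Point)  # default: fields that are None are dropped
assert tree_flatten(p)[0] == [1, 2]

# Re-registering overrides the previous registration, as the tests below do.
register_dataclass_as_pytree_node(Point, return_none_fields=True)
assert tree_flatten(p)[0] == [1, 2, None]
```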

Also added some tests.

Pull Request resolved: pytorch#106160
Approved by: https://github.com/zhxchen17
ydwu4 authored and bobby-palmer committed Jul 29, 2023
1 parent d4a8dbe commit 4aa1590
Showing 2 changed files with 143 additions and 0 deletions.
89 changes: 89 additions & 0 deletions test/export/test_export.py
@@ -4,10 +4,13 @@
import torch
import torch._dynamo as torchdynamo
from torch._export import export, dynamic_dim
from torch._export.utils import register_dataclass_as_pytree_node
from torch._export.constraints import constrain_as_size, constrain_as_value
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._pytree import tree_flatten, tree_unflatten, LeafSpec, TreeSpec
from functorch.experimental.control_flow import map
from dataclasses import dataclass


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
@@ -330,6 +333,92 @@ def fn_ddo(x):
        ):
            _ = export(fn_ddo, (torch.tensor([2, 3, 5]),), constraints=None)

    def test_pytree_register_data_class(self):

        @dataclass
        class MyDataClass:
            x: int
            y: int
            z: int = None

        dt = MyDataClass(x=3, y=4)
        flat, spec = tree_flatten(dt)
        self.assertEqual(spec, LeafSpec())
        self.assertEqual(len(flat), 1)

        register_dataclass_as_pytree_node(MyDataClass)

        flat, spec = tree_flatten(dt)
        self.assertEqual(
            spec,
            TreeSpec(
                MyDataClass,
                (
                    MyDataClass,
                    ['x', 'y'],
                    ['z']
                ),
                [LeafSpec(), LeafSpec()]
            )
        )
        self.assertEqual(flat, [3, 4])

        orig_dt = tree_unflatten(flat, spec)
        self.assertTrue(isinstance(orig_dt, MyDataClass))
        self.assertEqual(orig_dt.x, 3)
        self.assertEqual(orig_dt.y, 4)
        self.assertEqual(orig_dt.z, None)

        # Override the registration to keep None fields
        register_dataclass_as_pytree_node(MyDataClass, return_none_fields=True)

        flat, spec = tree_flatten(dt)
        self.assertEqual(
            spec,
            TreeSpec(
                MyDataClass,
                (
                    MyDataClass,
                    ['x', 'y', 'z'],
                    [],
                ),
                [LeafSpec(), LeafSpec(), LeafSpec()]
            )
        )
        self.assertEqual(flat, [3, 4, None])

        orig_dt = tree_unflatten(flat, spec)
        self.assertTrue(isinstance(orig_dt, MyDataClass))
        self.assertEqual(orig_dt.x, 3)
        self.assertEqual(orig_dt.y, 4)
        self.assertEqual(orig_dt.z, None)

    def test_pytree_register_nested_data_class(self):

        @dataclass
        class Inner:
            x: int
            y: int

        @dataclass
        class Outer:
            xy: Inner
            ab: Inner

        xy = Inner(1, 2)
        ab = Inner(3, 4)
        dt = Outer(xy, ab)
        inp = {"dt1": (dt, ({},)), "dt2": ((torch.ones(1),), dt)}

        register_dataclass_as_pytree_node(Inner)
        register_dataclass_as_pytree_node(Outer)

        flat, spec = tree_flatten(inp)
        self.assertEqual(flat, [1, 2, 3, 4, torch.ones(1), 1, 2, 3, 4])

        unflat = tree_unflatten(flat, spec)
        self.assertEqual(unflat, inp)


if __name__ == '__main__':
    run_tests()
54 changes: 54 additions & 0 deletions torch/_export/utils.py
@@ -0,0 +1,54 @@
import dataclasses

from typing import Any, List, Optional, Tuple

from torch.utils._pytree import (
    _register_pytree_node,
    Context,
    FlattenFunc,
    MaybeFromStrFunc,
    ToStrFunc,
    UnflattenFunc,
)


def register_dataclass_as_pytree_node(
    typ: Any,
    flatten_fn: Optional[FlattenFunc] = None,
    unflatten_fn: Optional[UnflattenFunc] = None,
    to_str_fn: Optional[ToStrFunc] = None,
    maybe_from_str_fn: Optional[MaybeFromStrFunc] = None,
    *,
    return_none_fields: bool = False,
) -> None:
    assert dataclasses.is_dataclass(
        typ
    ), f"Only dataclasses can be registered with this function: {typ}"

    def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
        flattened = []
        flat_names = []
        none_names = []
        for f in dataclasses.fields(obj):
            name, val = f.name, getattr(obj, f.name)
            if val is not None or return_none_fields:
                flattened.append(val)
                flat_names.append(name)
            else:
                none_names.append(name)
        return flattened, (typ, flat_names, none_names)

    def default_unflatten_fn(values: List[Any], context: Context) -> Any:
        typ, flat_names, none_names = context
        return typ(**dict(zip(flat_names, values)), **{k: None for k in none_names})

    flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
    unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn

    _register_pytree_node(
        typ,
        flatten_fn,
        unflatten_fn,
        None,
        None,
    )
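
As a follow-up illustration (not part of the diff), callers can pass their own flatten/unflatten pair instead of the defaults; the `Pair` dataclass and helper names below are hypothetical:

```python
from dataclasses import dataclass
from typing import Any, List, Tuple

from torch._export.utils import register_dataclass_as_pytree_node
from torch.utils._pytree import Context, tree_flatten, tree_unflatten


@dataclass
class Pair:
    a: int
    b: int


def flatten_pair(obj: Pair) -> Tuple[List[Any], Context]:
    # Flatten into a fixed-order list; no extra metadata is needed here.
    return [obj.a, obj.b], None


def unflatten_pair(values: List[Any], context: Context) -> Pair:
    return Pair(*values)


register_dataclass_as_pytree_node(Pair, flatten_pair, unflatten_pair)

flat, spec = tree_flatten(Pair(1, 2))    # flat == [1, 2]
assert tree_unflatten(flat, spec) == Pair(1, 2)
```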
