From efd8617c9bc23fe8bb7ebec900a63f9f9a9a4123 Mon Sep 17 00:00:00 2001
From: Tugsbayasgalan Manlaibaatar
Date: Wed, 1 Oct 2025 13:34:20 -0700
Subject: [PATCH] Move autoparallel to use leaner export API

---
 autoparallel/api.py | 26 ++++++++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/autoparallel/api.py b/autoparallel/api.py
index ce087738..824b8dd5 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -8,9 +8,10 @@
 import warnings
 from contextlib import ExitStack, contextmanager
 from types import MethodType
-from typing import Optional, Union
+from typing import Any, Optional, Tuple, Union
 
 import torch
+from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
 from torch._functorch.aot_autograd import (
     aot_compile_joint_with_descriptors,
     aot_export_joint_with_descriptors,
@@ -22,6 +23,7 @@
 from torch._subclasses import FakeTensorMode
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor import DeviceMesh
+from torch.export._trace import _restore_state_dict
 from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
 
@@ -163,6 +165,21 @@ def enable_local_map_wrapping():
         yield
 
 
+def _export(model: torch.nn.Module, inputs: Tuple[Any]) -> torch.nn.Module:
+    """
+    Thin wrapper around graph capture output that restores the
+    original calling convention and attribute fqn. TODO:
+    1) Use bytecode for calling convention instead of pytree for more
+    seamless UX.
+    2) Attach guards
+    3) Be more careful about tensor constants names.
+    """
+    with torch._dynamo.config.patch(install_free_tensors=True):
+        gm = _dynamo_graph_capture_for_export(model)(*inputs)
+        _restore_state_dict(model, gm)
+        return gm
+
+
 class AutoParallel:
     """
     Args:
@@ -279,13 +296,10 @@ def build_model_graph(self):
         with set_dtype_cast(
             True
         ), enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
-            with torch._dynamo.config.patch(
-                install_free_tensors=True
-            ), monkey_patch_export_verifier():
-                ep = torch.export.export(self.model, inputs, strict=True)
+            torch_ir_with_fqn = _export(self.model, inputs)
             self.joint_with_descriptors = aot_export_joint_with_descriptors(
                 self.stack,
-                ep.module(),
+                torch_ir_with_fqn,
                 inputs,
                 decompositions=decomp_table,
                 fw_compiler=self.compiler_fn,