From eb0a3a71fb8487a64a480db07988090713f2acfa Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Fri, 10 Oct 2025 11:42:27 -0400 Subject: [PATCH] Revert "Move autoparallel to use leaner export API (#181)" This reverts commit f1887eb1e6e2f388120e80c762f4b192bb39b564. --- autoparallel/api.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index 5ab07997..3ae9dc63 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -8,10 +8,9 @@ import warnings from contextlib import ExitStack, contextmanager from types import MethodType -from typing import Any, Optional, Tuple, Union +from typing import Optional, Union import torch -from torch._dynamo.functional_export import _dynamo_graph_capture_for_export from torch._functorch.aot_autograd import ( aot_compile_joint_with_descriptors, aot_export_joint_with_descriptors, @@ -23,7 +22,6 @@ from torch._subclasses import FakeTensorMode from torch.distributed.fsdp import MixedPrecisionPolicy from torch.distributed.tensor import DeviceMesh -from torch.export._trace import _restore_state_dict from torch.export._unlift import _assign_attr from torch.export.unflatten import _AttrKind @@ -165,21 +163,6 @@ def enable_local_map_wrapping(): yield -def _export(model: torch.nn.Module, inputs: Tuple[Any]) -> torch.nn.Module: - """ - Thin wrapper around graph capture output that restores the - original calling convention and attribute fqn. TODO: - 1) Use bytecode for calling convention instead of pytree for more - seamless UX. - 2) Attach guards - 3) Be more careful about tensor constants names. - """ - with torch._dynamo.config.patch(install_free_tensors=True): - gm = _dynamo_graph_capture_for_export(model)(*inputs) - _restore_state_dict(model, gm) - return gm - - class AutoParallel: """ Args: @@ -296,10 +279,13 @@ def build_model_graph(self): with set_dtype_cast( True ), enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): - torch_ir_with_fqn = _export(self.model, inputs) + with torch._dynamo.config.patch( + install_free_tensors=True + ), monkey_patch_export_verifier(): + ep = torch.export.export(self.model, inputs, strict=True) self.joint_with_descriptors = aot_export_joint_with_descriptors( self.stack, - torch_ir_with_fqn, + ep.module(), inputs, decompositions=decomp_table, )