diff --git a/autoparallel/api.py b/autoparallel/api.py
index b24cb611..6928e1e4 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -250,9 +250,20 @@ def build_model_graph(self):
             inputs = (inputs,)
 
         with set_dtype_cast(True):
+            # TODO: why do we need to trace twice, and why does the second
+            # trace have to run under preserve_node_meta?
+            # If we trace only once under preserve_node_meta,
+            # nn_module_stack does not get set on the nodes.
+            with torch.fx.traceback.preserve_node_meta():
+                ep_with_ac = torch.export.export(self.model, inputs)
             ep = torch.export.export(self.model, inputs)
+            for n, n0 in zip(ep.graph.nodes, ep_with_ac.graph.nodes):
+                if "nn_module_stack" in n.meta:
+                    n0.meta["nn_module_stack"] = n.meta["nn_module_stack"]
+                if "fwd_nn_module_stack" in n.meta:
+                    n0.meta["fwd_nn_module_stack"] = n.meta["fwd_nn_module_stack"]
             self.joint_with_descriptors = aot_export_joint_with_descriptors(
-                self.stack, ep.module(), inputs, decompositions=decomp_table
+                self.stack, ep_with_ac.module(), inputs, decompositions=decomp_table
             )
 
         gm = self.joint_with_descriptors.graph_module
diff --git a/examples/example_autoparallel.py b/examples/example_autoparallel.py
index ff9eff2a..4c85b2fe 100644
--- a/examples/example_autoparallel.py
+++ b/examples/example_autoparallel.py
@@ -4,15 +4,27 @@
 # LICENSE file in the root directory of this source tree.
 
+import functools
+
 import torch
 from torch import nn
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor.placement_types import Replicate, Shard
 from torch.testing._internal.distributed.fake_pg import FakeStore
+from torch.utils.checkpoint import create_selective_checkpoint_contexts
 
 from autoparallel.api import AutoParallel
 
 
+def policy_fn(ctx, op, *args, **kwargs):
+    if op == torch.ops.aten._scaled_dot_product_flash_attention.default:
+        return torch.utils.checkpoint.CheckpointPolicy.PREFER_SAVE
+    return torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE
+
+
+context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
+
+
 class Block(nn.Module):
     def __init__(self, nheads, dim1, dim2):
         super().__init__()
@@ -48,7 +60,7 @@ def _compute_attention(self, x):
 
     def forward(self, x):
         o = torch.utils.checkpoint.checkpoint(
-            self._compute_attention, x, use_reentrant=False
+            self._compute_attention, x, use_reentrant=False, context_fn=context_fn
        )
         o0 = o + x
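Note on the api.py hunk above: the added loop copies nn_module_stack (and fwd_nn_module_stack) metadata from the plain export onto the export taken under preserve_node_meta. Below is a minimal sketch, not part of the patch, of what that metadata looks like on an exported graph; the Tiny module and tensor shapes are made up for illustration.

import torch
from torch import nn


class Tiny(nn.Module):  # hypothetical toy module, for illustration only
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x).relu()


ep = torch.export.export(Tiny(), (torch.randn(2, 4),))

for node in ep.graph.nodes:
    # nn_module_stack records the module hierarchy that produced each node;
    # this is the metadata the api.py hunk propagates between the two traces.
    stack = node.meta.get("nn_module_stack")
    if stack:
        print(node.name, list(stack.values()))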
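And a self-contained sketch of the selective activation checkpointing pattern that example_autoparallel.py now uses: a policy function decides per-op whether to save or recompute an activation, and create_selective_checkpoint_contexts turns it into the context_fn passed to torch.utils.checkpoint.checkpoint. The attention helper and tensor shapes below are assumptions for illustration, not taken from the example.

import functools

import torch
import torch.nn.functional as F
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


def policy_fn(ctx, op, *args, **kwargs):
    # Save the attention output (expensive to recompute); recompute everything
    # else during the backward pass.
    if op == torch.ops.aten._scaled_dot_product_flash_attention.default:
        return CheckpointPolicy.PREFER_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)


def attention(q, k, v):
    # Illustrative only; on CPU this dispatches to the math backend, so the
    # flash-attention branch of the policy may never fire there.
    return F.scaled_dot_product_attention(q, k, v)


q, k, v = (torch.randn(2, 4, 16, 8, requires_grad=True) for _ in range(3))
out = checkpoint(attention, q, k, v, use_reentrant=False, context_fn=context_fn)
out.sum().backward()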