diff --git a/autoparallel/activation_checkpointing.py b/autoparallel/activation_checkpointing.py
index 71fe731e..0b19587f 100644
--- a/autoparallel/activation_checkpointing.py
+++ b/autoparallel/activation_checkpointing.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 import logging
+import operator
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import Optional, Union
@@ -207,6 +208,26 @@ def mark_nodes_as_must_save_to_stage_recomputation(
         if node.meta.get("recompute", None) is not None:
             # do not mess with allgather nodes that have already been marked recompute!
             continue
+        if node.target is operator.getitem:
+            # we need to be a bit careful: we are trying to manually emulate setting "recompute" tags
+            # in the same way that torch.compile does when it encounters userland SAC.
+            #
+            # torch.compile does this by using TorchDispatchModes to intercept ops as they are traced,
+            # and setting their "recompute" tag.
+            #
+            # However, TorchDispatchModes *only* intercept OpOverloads (and HOPs).
+            # getitem is neither, and so in vanilla torch.compile usage,
+            # getitem nodes receive no tags.
+            #
+            # What happens if we blindly set all nodes to PREFER_RECOMPUTE? Example bad outcome:
+            # - the user is using attention, so we see this series of ops in the joint graph:
+            #   attention_fw -> getitem -> attention_bw (the getitem is an output used for the bw)
+            # - the user runs SAC and marks attention_fw as MUST_SAVE
+            # - if we mark getitem as PREFER_RECOMPUTE while attention_fw is MUST_SAVE,
+            #   the partitioner ends up generating an invalid graph.
+            # Today the partitioner relies on the fact that getitem's recompute behavior
+            # is implicitly determined by the recompute behavior of the multi-output op preceding it.
+            continue
         node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
         # add an arbitrarily large graph id. I'm assuming 100000 here, which should be fine
         # and is the same we add for the all-gather nodes
@@ -327,7 +348,7 @@ def ac_joint_pass(graph: torch.fx.Graph, ac_stage_size_in_GiB: float = 2.0):
     # policy, but this is not working yet
     save_list = {
         torch.ops.aten.mm.default,
-        # torch.ops.aten._scaled_dot_product_efficient_attention.default,
-        # torch.ops.aten._scaled_dot_product_flash_attention.default,
+        torch.ops.aten._scaled_dot_product_efficient_attention.default,
+        torch.ops.aten._scaled_dot_product_flash_attention.default,
     }
     _apply_ac_policy(graph, save_list=save_list)
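
For illustration only (not part of the patch): a minimal standalone sketch of the situation the new getitem branch guards against. It builds a toy FX graph by hand where a multi-output attention op feeds a getitem, pre-tags the attention forward as MUST_SAVE (standing in for a userland SAC policy), and then runs a tag_prefer_recompute helper that mirrors the patched loop. The graph shape, node names, and the helper are hypothetical; only CheckpointPolicy, operator.getitem, and the aten ops come from the diff itself.

import operator

import torch
import torch.fx as fx
from torch.utils.checkpoint import CheckpointPolicy

# Toy joint-graph shape: attention_fw -> getitem -> consumer.
g = fx.Graph()
q = g.placeholder("q")
k = g.placeholder("k")
v = g.placeholder("v")
attn_fw = g.call_function(
    torch.ops.aten._scaled_dot_product_flash_attention.default, (q, k, v)
)
attn_out = g.call_function(operator.getitem, (attn_fw, 0))
out = g.call_function(torch.ops.aten.mm.default, (attn_out, attn_out))
g.output(out)

# Pretend a user SAC policy already marked the attention forward as MUST_SAVE.
attn_fw.meta["recompute"] = CheckpointPolicy.MUST_SAVE


def tag_prefer_recompute(graph: fx.Graph) -> None:
    """Mirror of the patched loop: tag untagged nodes PREFER_RECOMPUTE,
    but never tag getitem nodes (their behavior follows their producer)."""
    for node in graph.nodes:
        if node.op != "call_function":
            continue
        if node.meta.get("recompute", None) is not None:
            # already tagged (e.g. by the SAC policy above) -- leave it alone
            continue
        if node.target is operator.getitem:
            # leave untagged: tagging it PREFER_RECOMPUTE could contradict the
            # MUST_SAVE tag on the multi-output op that produces it
            continue
        node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE


tag_prefer_recompute(g)
for node in g.nodes:
    print(node.name, node.meta.get("recompute", "<untagged>"))

Running this leaves the getitem node untagged while attention_fw stays MUST_SAVE and mm becomes PREFER_RECOMPUTE, which is the consistent state the partitioner expects.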