25 changes: 23 additions & 2 deletions autoparallel/activation_checkpointing.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 import logging
+import operator
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import Optional, Union
@@ -207,6 +208,26 @@ def mark_nodes_as_must_save_to_stage_recomputation(
         if node.meta.get("recompute", None) is not None:
             # do not mess with allgather nodes that have already been marked recompute!
             continue
+        if node.target is operator.getitem:
+            # We need to be a bit careful: we are trying to manually emulate setting "recompute" tags
+            # in the same way that torch.compile does when it encounters userland SAC.
+            #
+            # torch.compile does this by using TorchDispatchModes to intercept ops as they are traced,
+            # and setting their "recompute" tag.
+            #
+            # However, TorchDispatchModes *only* intercept OpOverloads (and HOPs).
+            # getitem is neither, so in vanilla torch.compile usage,
+            # getitem nodes receive no tags.
+            #
+            # What happens if we blindly set all nodes to PREFER_RECOMPUTE? Example bad outcome:
+            # - the user is using attention, so we see this series of ops in the joint graph:
+            #   attention_fw -> getitem -> attention_bw (the getitem is an output used for the bw)
+            # - the user runs SAC and marks attention_fw as MUST_SAVE
+            # - if we mark getitem as PREFER_RECOMPUTE and attention_fw as MUST_SAVE,
+            #   the partitioner ends up generating an invalid graph.
+            #   Today the partitioner relies on the fact that getitem's recompute behavior
+            #   is implicitly determined by the recompute behavior of the multi-output op preceding it.
+            continue
node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
# add an arbitrarily large graph id. I'm assuming 100000 here, which should be fine
# and is the same we add for the all-gather nodes
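
A minimal standalone sketch (not part of this PR) of the getitem pattern the comment above describes: multi-output ATen ops traced with `make_fx` show up in the FX graph as one call followed by `operator.getitem` nodes, and getitem's target is a plain Python function rather than an OpOverload, so TorchDispatchModes never see it.

```python
# Hypothetical sketch (not part of this PR): multi-output aten ops appear in an
# FX graph as a single call followed by operator.getitem nodes.
import operator

import torch
from torch.fx.experimental.proxy_tensor import make_fx


def f(x):
    var, mean = torch.var_mean(x)  # a multi-output aten op
    return var + mean


gm = make_fx(f)(torch.randn(4, 4))
for node in gm.graph.nodes:
    if node.target is operator.getitem:
        # getitem's target is a plain Python function, not an OpOverload, so
        # TorchDispatchModes never intercept it and it receives no "recompute"
        # tag; the partitioner infers its behavior from the producing op.
        print(node, "consumes output", node.args[1], "of", node.args[0])
```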
@@ -327,7 +348,7 @@ def ac_joint_pass(graph: torch.fx.Graph, ac_stage_size_in_GiB: float = 2.0):
     # policy, but this is not working yet
     save_list = {
         torch.ops.aten.mm.default,
-        # torch.ops.aten._scaled_dot_product_efficient_attention.default,
-        # torch.ops.aten._scaled_dot_product_flash_attention.default,
+        torch.ops.aten._scaled_dot_product_efficient_attention.default,
+        torch.ops.aten._scaled_dot_product_flash_attention.default,
     }
     _apply_ac_policy(graph, save_list=save_list)
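
For reference, a minimal sketch (assuming PyTorch >= 2.4; not part of this PR) of the userland SAC this pass emulates, expressed with PyTorch's public checkpoint API and the same save_list:

```python
# Hypothetical sketch (not part of this PR): userland selective activation
# checkpointing via torch.utils.checkpoint (requires torch >= 2.4).
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}


def policy_fn(ctx, op, *args, **kwargs):
    # Save matmul/attention outputs; prefer recomputing everything else.
    if op in save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


def fn(x, w):
    return torch.mm(torch.relu(torch.mm(x, w)), w)


x = torch.randn(64, 64, requires_grad=True)
w = torch.randn(64, 64, requires_grad=True)
out = checkpoint(
    fn,
    x,
    w,
    use_reentrant=False,
    context_fn=partial(create_selective_checkpoint_contexts, policy_fn),
)
out.sum().backward()
```

Here the MUST_SAVE / PREFER_RECOMPUTE tags are applied by TorchDispatchModes as ops are traced, which is exactly the mechanism that skips getitem nodes and motivates the special case added in this PR.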