From e352712ce9266b52dca641cc9cff4fde3eaae388 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Thu, 28 Aug 2025 07:10:02 -0700 Subject: [PATCH 1/2] fix getitem handling in existing SAC tag pass, turn on attention SAC in example --- autoparallel/activation_checkpointing.py | 26 ++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/autoparallel/activation_checkpointing.py b/autoparallel/activation_checkpointing.py index 71fe731e..6217228c 100644 --- a/autoparallel/activation_checkpointing.py +++ b/autoparallel/activation_checkpointing.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. import logging +import operator from collections import defaultdict from dataclasses import dataclass from typing import Optional, Union @@ -207,6 +208,26 @@ def mark_nodes_as_must_save_to_stage_recomputation( if node.meta.get("recompute", None) is not None: # do not mess with allgather nodes that have already been marked recompute! continue + if node.target is operator.getitem: + # we need to be a bit careful: we are trying to manually emulate setting "precompute" tags + # in the same way that compiel does when it encounters userland SAC. + # + # torch.compile does this by using TorchDispatchModes to intercept ops as they are traced, + # and setting their "recompute" tag. + # + # However, TorchDispatchModes *only* intercept OpOverloads (and HOPs) + # getitem is neither, and so in vanilla torch.compile usage, + # getitem nodes recieve no tags. + # + # What happens if we blindly set all nodes to PREFER_RECOMPUTE? Example bad outcome: + # - user is using attention, so we see this series of ops in the joint graph: + # attention_fw -> getitem -> attention_bw (the getitem is an output used for the bw) + # - user runs SAC, and marks attention_fw as MUST_SAVE + # - if we mark getitem as PREFER_RECOMPUTE, and attention_fw as MUST_SAVE, + # the partitioner ends up generating an invalid graph. + # Today the partitioner relies on the fact that getitem's recompute behavior + # is implicitly determined by the recompute behavior of the multi-output op preceding it. + continue node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE # add an arbitrarily large graph id. I'm assuming 100000 here, which should be fine # and is the same we add for the all-gather nodes @@ -314,6 +335,7 @@ def _apply_ac_policy(joint_graph: torch.fx.Graph, save_list: set[torch.ops.OpOve counter += 1 continue must_save_nodes.append(node) + must_save_nodes.append(node) _mark_nodes_as_must_save(must_save_nodes) @@ -327,7 +349,7 @@ def ac_joint_pass(graph: torch.fx.Graph, ac_stage_size_in_GiB: float = 2.0): # policy, but this is not working yet save_list = { torch.ops.aten.mm.default, - # torch.ops.aten._scaled_dot_product_efficient_attention.default, - # torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, } _apply_ac_policy(graph, save_list=save_list) From 2fca0916a07eb0ad8dc17961420f0157eb19eb29 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Thu, 28 Aug 2025 07:18:24 -0700 Subject: [PATCH 2/2] cleanup --- autoparallel/activation_checkpointing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/autoparallel/activation_checkpointing.py b/autoparallel/activation_checkpointing.py index 6217228c..0b19587f 100644 --- a/autoparallel/activation_checkpointing.py +++ b/autoparallel/activation_checkpointing.py @@ -335,7 +335,6 @@ def _apply_ac_policy(joint_graph: torch.fx.Graph, save_list: set[torch.ops.OpOve counter += 1 continue must_save_nodes.append(node) - must_save_nodes.append(node) _mark_nodes_as_must_save(must_save_nodes)