From dd475b343f1313cda96a9366976e18151d0ebcf5 Mon Sep 17 00:00:00 2001 From: Junjie Mao Date: Tue, 21 Apr 2026 11:24:06 +0800 Subject: [PATCH] deepcompile: Fix KeyError due to unbalanced forward/backward visits In PyTorch AOT Autograd, having tensors requiring grad in inputs doesn't guarantee backward graph compilation. If no output requires grad and no input requiring grad is mutated, aot_autograd skips backward compilation (see [1]). DeepCompile previously required backward compilation for every forward graph which required grad, but relied solely on the existence of requires_grad tensors. This mismatch caused unbalanced forward/backward visits, leaving graphs unvisited in `frames_needing_bwd`. The patched FunctionMeta then remained effective during backward execution, raising KeyError when removing the (already-removed) frame IDs from the `frames_needing_bwd` set. A reproduction can be found at [2]. Simply putting a guard on the set removal operation is insufficient. The backward graph is still recompiled on each iteration, severely impacting performance. Instead of duplicating how AOT Autograd determines whether to compile the backward graph, use the fact that a joint graph requires a backward pass if and only if it is partitioned into a forward and a backward module. The frame IDs of partitioned graphs are collected in the patched partition functions and then used to determine `needs_backward` in the forward compile function. `backend_fn` is not a proper place for the second step since autograd creates fw/bw compile functions before partitioning a joint graph.
References [1] https://github.com/pytorch/pytorch/blob/aea31e0c306e2315bf6d84255e0dde7adf09762a/torch/_functorch/aot_autograd.py#L618 [2] https://gist.github.com/eternalNight/96d6bc60e2bf566fda1300154d0e89dc Signed-off-by: Junjie Mao --- deepspeed/compile/backend.py | 45 +++++++++++++++++++++++--------- deepspeed/compile/inductor.py | 34 +++++++++++++++--------- deepspeed/compile/partitioner.py | 5 +++- 3 files changed, 57 insertions(+), 27 deletions(-) diff --git a/deepspeed/compile/backend.py b/deepspeed/compile/backend.py index 0df72ed1666c..3858c2f20993 100644 --- a/deepspeed/compile/backend.py +++ b/deepspeed/compile/backend.py @@ -3,7 +3,7 @@ # DeepSpeed Team -from typing import Dict, List, Callable, Tuple +from typing import Dict, List, Callable, Tuple, Set import time import gc from collections import OrderedDict, deque @@ -12,7 +12,6 @@ from torch.fx import Graph, GraphModule try: - import torch.utils._pytree as pytree import torch._dynamo from functorch.compile import make_boxed_func from torch._functorch.aot_autograd import aot_module_simplified @@ -47,11 +46,19 @@ class GraphOrder: def __init__(self): self.frames = OrderedDict() - def add_graph(self, graph_id: int, frame_id: int, needs_backward: bool): + def __len__(self): + return len(self.frames) + + def add_graph(self, graph_id: int, frame_id: int): if frame_id not in self.frames: - self.frames[frame_id] = (graph_id, needs_backward) + self.frames[frame_id] = (graph_id, None) + + def set_needs_backward(self, frame_id: int, needs_backward: bool): + if frame_id in self.frames: + self.frames[frame_id] = (self.frames[frame_id][0], needs_backward) def get_graph_order(self) -> List[Tuple[int, bool]]: + assert all(isinstance(needs_backward, bool) for _, needs_backward in self.frames.values()) return list(self.frames.values()) def clear(self): @@ -61,6 +68,7 @@ def clear(self): graph_order_with_frame_id = GraphOrder() frames_needing_bwd = set() +frames_partitioned: Set[int] = set() profiling_results: 
Dict[int, ProfilingResult] = {} opt_pass_times = [] opt_passes = {} @@ -96,6 +104,7 @@ def launch_compile_passes(global_steps: int): graph_order_with_frame_id.clear() profiling_results.clear() param_manager.clear() + frames_partitioned.clear() def set_time_and_tensor_size(graph_id, graph: Graph, mem, bwd, profiling_results): @@ -225,11 +234,16 @@ def make_backend(backend, compile_config, compile_kwargs={}): def backend_fn(gm: GraphModule, real_inputs): graph_id = id(gm.graph) - needs_backward = pytree.tree_any(lambda x: x.requires_grad if torch.is_tensor(x) else False, real_inputs) + # Checking the existence of input tensors requiring grad alone is insufficient to determine `needs_backward`. + # AOT autograd also checks the graph data flow and skips the backward pass if no output requires grad and no + # input requiring grad is mutated. + # + # Instead of replicating AOT autograd's backward pass determination (which is too costly), we infer whether + # backward pass is needed by checking if the joint graph is partitioned (into a forward and a backward module). + # This check cannot be placed here because autograd creates the fw/bw compiler callables before graph + # partitioning. It is thus postponed to the point where the fw compiler is called.
frame_id = gm.meta["dynamo_compile_id"].frame_id - graph_order_with_frame_id.add_graph(graph_id, frame_id, needs_backward) - - graph_order = graph_order_with_frame_id.get_graph_order() + graph_order_with_frame_id.add_graph(graph_id, frame_id) z3_partition = any(hasattr(v, "ds_id") for v in real_inputs) if z3_partition: @@ -258,11 +272,14 @@ def backend_fn(gm: GraphModule, real_inputs): if graph_id not in profiling_results: profiling_results[graph_id] = ProfilingResult() profiling_results[graph_id].param_indices = param_indices - profiling_results[graph_id].needs_backward = needs_backward def make_fw_graph(gm, sample_inputs): time_start = time.time() - graph_index = len(graph_order) - 1 + graph_index = len(graph_order_with_frame_id) - 1 + + needs_backward = frame_id in frames_partitioned + graph_order_with_frame_id.set_needs_backward(frame_id, needs_backward) + profiling_results[graph_id].needs_backward = needs_backward if needs_backward: if len(frames_needing_bwd) == 0: @@ -290,7 +307,7 @@ def make_fw_graph(gm, sample_inputs): opt_passes=next_passes, gm=gm, graph_id=graph_id, - graph_order=graph_order, + graph_order=graph_order_with_frame_id.get_graph_order(), profiling_results=profiling_results, create_inputs_fn=lambda: real_inputs_with_rng, mem_budget=.0, # unused @@ -308,6 +325,7 @@ def make_fw_graph(gm, sample_inputs): def make_bw_graph(gm, sample_inputs): time_start = time.time() + graph_order = graph_order_with_frame_id.get_graph_order() graph_index = get_index_by_graph_id(graph_order, graph_id) log_rank0( f"Bwd start {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}", @@ -368,7 +386,8 @@ def compiler_fn(gm, sample_inputs): return compiler_fn - partition_fn = get_wrapped_partitioner(z3_partition, param_indices, min_cut_rematerialization_partition) + partition_fn = get_wrapped_partitioner(z3_partition, param_indices, min_cut_rematerialization_partition, + frame_id, frames_partitioned) aot_mod = 
aot_module_simplified(gm, real_inputs, fw_compiler=make_compiler_fn(make_fw_graph), @@ -377,7 +396,7 @@ def compiler_fn(gm, sample_inputs): return torch._dynamo.optimize(**compile_kwargs)(aot_mod) elif backend == "inductor": patch_create_aot_dispatcher_function(graph_id, z3_partition, make_fw_graph, make_bw_graph, real_inputs, - param_indices, param_manager) + param_indices, param_manager, frame_id, frames_partitioned) return torch._inductor.compile(gm, real_inputs) diff --git a/deepspeed/compile/inductor.py b/deepspeed/compile/inductor.py index 49a20e71727f..0fe880260439 100644 --- a/deepspeed/compile/inductor.py +++ b/deepspeed/compile/inductor.py @@ -3,6 +3,8 @@ # DeepSpeed Team +from typing import Set + import torch try: @@ -60,23 +62,29 @@ def wrapped_compiler(gm, fake_inputs): return wrapped_compiler -def wrap_partition_fn(partition_fn, real_inputs, param_indices): +def wrap_partition_fn(z3_partition: bool, partition_fn, real_inputs, param_indices, frame_id: int, + frames_partitioned: Set[int]): def wrapped_partition_fn(*args, **kwargs): - fn = get_wrapped_partitioner(True, param_indices, partition_fn=partition_fn) + fn = get_wrapped_partitioner(z3_partition, + param_indices, + partition_fn=partition_fn, + frame_id=frame_id, + frames_partitioned=frames_partitioned) fw_module, bw_module = fn(*args, **kwargs) - # get parameter names - pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices) + if z3_partition: + # get parameter names + pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices) - def fix_placeholder_meta(graph): - for n in graph.nodes: - if n.op == "placeholder" and n.name in pm.param_names: - n.meta["val"] = torch.empty([0], dtype=n.meta["val"].dtype, device=n.meta["val"].device) + def fix_placeholder_meta(graph): + for n in graph.nodes: + if n.op == "placeholder" and n.name in pm.param_names: + n.meta["val"] = torch.empty([0], dtype=n.meta["val"].dtype, device=n.meta["val"].device) - fix_placeholder_meta(fw_module.graph) 
- fix_placeholder_meta(bw_module.graph) + fix_placeholder_meta(fw_module.graph) + fix_placeholder_meta(bw_module.graph) return fw_module, bw_module @@ -84,7 +92,7 @@ def fix_placeholder_meta(graph): def patch_create_aot_dispatcher_function(graph_id: int, z3_partition: bool, make_fw_graph, make_bw_graph, real_inputs, - param_indices, param_manager): + param_indices, param_manager, frame_id: int, frames_partitioned: Set[int]): from torch._dynamo.backends.common import AotAutograd import functools @@ -112,8 +120,8 @@ def patched_init(self, **kwargs): bwd=True) kwargs["inference_compiler"] = kwargs["fw_compiler"] - if z3_partition: - kwargs["partition_fn"] = wrap_partition_fn(kwargs["partition_fn"], real_inputs, param_indices) + kwargs["partition_fn"] = wrap_partition_fn(z3_partition, kwargs["partition_fn"], real_inputs, + param_indices, frame_id, frames_partitioned) original_init(self, **kwargs) diff --git a/deepspeed/compile/partitioner.py b/deepspeed/compile/partitioner.py index 28db711433e8..9f1345da875a 100644 --- a/deepspeed/compile/partitioner.py +++ b/deepspeed/compile/partitioner.py @@ -3,7 +3,7 @@ # DeepSpeed Team -from typing import Tuple, List +from typing import Tuple, List, Set import torch from torch.fx import GraphModule, Graph, Node @@ -85,10 +85,13 @@ def get_wrapped_partitioner( z3_partition: bool, param_indices: List[Tuple[int, int, torch.Size]], partition_fn, + frame_id: int, + frames_partitioned: Set[int], ): def partition_recompute_ds_params(joint_module: GraphModule, _joint_inputs, *, num_fwd_outputs, **kwargs) -> Tuple[GraphModule, GraphModule]: + frames_partitioned.add(frame_id) if z3_partition: _recompute_param_aliases(joint_module.graph, param_indices) return partition_fn(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs, **kwargs)