
deepcompile: Fix backward graph recompilation due to unbalanced forward/backward visits #7980

Merged
tohtana merged 2 commits into deepspeedai:master from openanolis:eternalNight/wrap_compiled_function_during_forward on Apr 22, 2026
Conversation

@eternalNight
Contributor

@eternalNight eternalNight commented Apr 20, 2026

In PyTorch AOT Autograd, having input tensors that require grad does not guarantee backward graph compilation. If no output requires grad and no input requiring grad is mutated, aot_autograd skips backward compilation (see [1]).
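
A minimal illustrative sketch of that PyTorch behavior (this is not the DeepCompile reproduction from [2]): the input requires grad, yet AOT Autograd traces a forward-only graph because no output requires grad.

```python
import torch

def fn(x):
    # The output is detached, so it does not require grad even though the input does.
    return (x * 2).detach()

compiled = torch.compile(fn, backend="aot_eager")
x = torch.randn(4, requires_grad=True)
out = compiled(x)
assert not out.requires_grad  # AOT Autograd skipped backward compilation for this frame
```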

DeepCompile previously assumed that every forward graph with inputs requiring grad would also get a backward graph, relying solely on the presence of tensors with `requires_grad` set. This mismatch caused unbalanced forward/backward visits, leaving graphs unvisited in `frames_needing_bwd`. The patched `FunctionMeta` then remained in effect during backward execution, raising a `KeyError` when removing the already-removed frame IDs from the `frames_needing_bwd` set. A reproduction can be found at [2].

Simply putting a guard on the set removal operation is insufficient: the backward graph would still be recompiled on each iteration, severely impacting performance.

Instead of duplicating how AOT Autograd decides whether to compile the backward graph, this change relies on the fact that a joint graph requires a backward pass if and only if it is partitioned into forward and backward modules. The frame IDs of partitioned graphs are collected in the patched partition functions and then used to determine `needs_backward` in the forward compile function. `backend_fn` is not a suitable place for the second step, since AOT Autograd creates the fw/bw compile functions before partitioning the joint graph.
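
A hedged sketch of the idea, not DeepCompile's actual code (names such as `frames_with_partitioned_bwd`, `make_fw_compiler`, and `compile_fw` are hypothetical): because AOT Autograd invokes the partitioner only for joint graphs that need a backward pass, recording the frame ID inside the partition function lets the forward compiler decide `needs_backward` reliably.

```python
from functorch.compile import min_cut_rematerialization_partition

# Frames whose joint graph was actually partitioned into fw/bw modules.
frames_with_partitioned_bwd = set()

def make_partition_fn(frame_id):
    def partition_fn(joint_module, joint_inputs, **kwargs):
        # The partitioner only runs when a backward pass is needed, so
        # reaching this point marks the frame as needing backward.
        frames_with_partitioned_bwd.add(frame_id)
        return min_cut_rematerialization_partition(joint_module, joint_inputs, **kwargs)
    return partition_fn

def make_fw_compiler(frame_id, compile_fw):
    def fw_compiler(gm, example_inputs):
        # The forward compiler is invoked after partitioning (or after it was
        # skipped), so this membership test is already settled here.
        needs_backward = frame_id in frames_with_partitioned_bwd
        return compile_fw(gm, example_inputs, needs_backward)
    return fw_compiler
```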

References

[1] https://github.com/pytorch/pytorch/blob/aea31e0c306e2315bf6d84255e0dde7adf09762a/torch/_functorch/aot_autograd.py#L618
[2] https://gist.github.com/eternalNight/96d6bc60e2bf566fda1300154d0e89dc


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ebbe957609


Comment thread deepspeed/compile/backend.py Outdated
Comment thread deepspeed/compile/patch_compiled_func.py Outdated
@eternalNight eternalNight force-pushed the eternalNight/wrap_compiled_function_during_forward branch from ebbe957 to dd475b3 on April 21, 2026 03:57
Collaborator

@tohtana tohtana left a comment


Thank you, @eternalNight! This is a significant improvement.

@tohtana tohtana enabled auto-merge (squash) April 22, 2026 06:39
@tohtana tohtana merged commit 077bff5 into deepspeedai:master Apr 22, 2026
7 checks passed
