deepcompile: Fix backward graph recompilation due to unbalanced forward/backward visits #7980
Merged
tohtana merged 2 commits into deepspeedai:master on Apr 22, 2026
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ebbe957609
ebbe957 to dd475b3 (Compare)
tohtana approved these changes on Apr 22, 2026

tohtana (Collaborator) left a comment:
Thank you, @eternalNight! This is a significant improvement.
In PyTorch AOT Autograd, having input tensors that require grad does not guarantee backward graph compilation. If no output requires grad and no grad-requiring input is mutated, aot_autograd skips backward compilation entirely (see [1]).
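For illustration (this sketch is not part of the PR), such a forward-only case is easy to construct: the input requires grad, but the output is detached, so neither skip condition is violated and AOT Autograd takes the inference path.

```python
import torch

# Minimal sketch: the input requires grad, but the output is detached, so no
# output requires grad and no grad-requiring input is mutated. Per [1], AOT
# Autograd takes the inference path and compiles no backward graph.
@torch.compile(backend="aot_eager")
def fwd_only(x):
    return (x * 2).detach()  # cuts the autograd tape at the output

x = torch.randn(4, requires_grad=True)
y = fwd_only(x)
assert not y.requires_grad  # no backward graph exists for this frame
```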
DeepCompile previously required backward compilation for every forward graph that involved grad, relying solely on the existence of `requires_grad` tensors among the inputs. This mismatch caused unbalanced forward/backward visits, leaving graphs unvisited in `frames_needing_bwd`. The patched `FunctionMeta` then remained in effect during backward execution, raising `KeyError` when removing the (already-removed) frame IDs from the `frames_needing_bwd` set. A reproduction can be found at [2].

Simply putting a guard on the set removal operation is insufficient: the backward graph is still recompiled on each iteration, severely impacting performance.
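A hypothetical sketch of the bookkeeping imbalance (the names mirror the description above, not DeepCompile's actual code):

```python
# The forward side registers every grad-requiring frame, but AOT Autograd
# never runs a backward for forward-only graphs, so the set drifts out of
# balance and a replayed removal eventually fails.
frames_needing_bwd = set()

def on_forward_compile(frame_id):
    # Old heuristic: register whenever any input requires grad.
    frames_needing_bwd.add(frame_id)

def on_backward_run(frame_id):
    # Raises KeyError once the frame ID has already been removed.
    frames_needing_bwd.remove(frame_id)

on_forward_compile(0)   # forward-only graph: its backward visit never comes
on_forward_compile(1)
on_backward_run(1)
on_backward_run(1)      # stale FunctionMeta replays the removal -> KeyError
```

Swapping `remove` for `discard` would hide the exception, but the frame would still be treated as needing a backward graph, which is what forces the recompilation on every iteration.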
Instead of duplicating the logic by which AOT Autograd decides whether to compile the backward graph, this change uses the fact that a joint graph requires a backward pass if and only if it is partitioned into a forward and a backward module. The frame IDs of partitioned graphs are collected in the patched partition functions and then used to determine `needs_backward` in the forward compile function. `backend_fn` is not a suitable place for that second step, since autograd creates the fw/bw compile functions before partitioning a joint graph.
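A sketch of the idea with hypothetical wrapper names (DeepCompile's actual patching differs in detail): AOT Autograd invokes the partitioner only when a joint graph has a backward pass, and it does so before calling the forward compiler, so a frame-ID record made in the partitioner is an exact signal by the time `fw_compiler` runs.

```python
from torch._functorch.partitioners import default_partition

# Frame IDs of joint graphs that were actually split into fw/bw modules.
partitioned_frames: set[int] = set()

def make_partition_fn(frame_id):
    def partition_fn(joint_gm, joint_inputs, **kwargs):
        # AOT Autograd calls the partitioner iff the joint graph has a
        # backward pass, so this records exactly the frames needing one.
        partitioned_frames.add(frame_id)
        return default_partition(joint_gm, joint_inputs, **kwargs)
    return partition_fn

def make_fw_compiler(frame_id):
    def fw_compiler(gm, example_inputs):
        # The partitioner has already run by the time the forward module is
        # compiled, so membership can replace the requires_grad guess.
        needs_backward = frame_id in partitioned_frames
        if needs_backward:
            pass  # arm the backward bookkeeping for this frame only
        return gm  # the GraphModule is callable and serves as the compiled fwd
    return fw_compiler
```

This also shows why `backend_fn` cannot host the check: it runs before partitioning, while `fw_compiler` runs after it, so the membership test is available exactly when needed.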
References

[1] https://github.com/pytorch/pytorch/blob/aea31e0c306e2315bf6d84255e0dde7adf09762a/torch/_functorch/aot_autograd.py#L618
[2] https://gist.github.com/eternalNight/96d6bc60e2bf566fda1300154d0e89dc

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>