fix: add setup_context for torch.func compatibility #7916
roycho96 wants to merge 21 commits into deepspeedai:master
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: eed37042bc
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
… unpack error Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Hi @zhangj1an, I've sent you a collaborator invite to my fork. Feel free to push your fix directly to the branch. Thanks for the suggestion!
… setup_context Co-authored-by: zhangj1an <jianmusings@gmail.com> Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
…afe linear Avoid asymmetric custom_bwd without custom_fwd on the setup_context forward path; mirror forward AMP in backward via torch.amp.autocast. Signed-off-by: Zhang <jianmusings@gmail.com>
PyTorch versions that expose autograd.Function.setup_context need the modern forward + setup_context shape for torch.func / functorch. Signed-off-by: Zhang <jianmusings@gmail.com>
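The modern shape mentioned above can be sketched as follows. This is a minimal, hypothetical `ScaledLinear` example (not the actual DeepSpeed `LinearFunctionForZeroStage3` code), assuming PyTorch 2.x where `autograd.Function.setup_context` is available: `forward` no longer receives `ctx`, and all state saving moves into `setup_context`.

```python
import torch

# Hypothetical sketch of the modern autograd.Function shape that
# torch.func / functorch requires: forward() takes no ctx, and tensors
# are saved in a separate setup_context() staticmethod.
class ScaledLinear(torch.autograd.Function):

    @staticmethod
    def forward(input, weight):
        # No ctx here: forward sees only the actual operands.
        return input @ weight.t()

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Saving for backward happens here instead of inside forward.
        input, weight = inputs
        ctx.save_for_backward(input, weight)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # grad wrt input: (N, out) @ (out, in) -> (N, in)
        grad_input = grad_output @ weight
        # grad wrt weight: (out, N) @ (N, in) -> (out, in)
        grad_weight = grad_output.t() @ input
        return grad_input, grad_weight


x = torch.randn(2, 3, requires_grad=True)
w = torch.randn(4, 3, requires_grad=True)
out = ScaledLinear.apply(x, w)
out.sum().backward()
```

Written this way, the same `Function` works both in eager mode (via `apply`) and under `torch.func` transforms, whereas the legacy `forward(ctx, ...)` shape raises when a transform tries to trace it.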
Signed-off-by: Zhang <jianmusings@gmail.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 60d20da79f
Hi @tohtana, would you mind reviewing this PR when you're free? It addresses a useful compatibility fix for torch.func. Much appreciated!
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
I additionally fixed the autocast backward: it is always wrapped with autocast(enabled=ctx._fwd_used_autocast) so the backward pass matches the autocast state used in forward.
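A minimal sketch of that idea, assuming a hypothetical `AutocastAwareLinear` function (the attribute name `ctx._fwd_used_autocast` follows the convention mentioned above; the rest is illustrative, not the DeepSpeed implementation): record the forward autocast state in `setup_context`, then replay it in `backward`.

```python
import torch

# Hypothetical sketch: mirror the forward AMP state in backward by
# recording whether autocast was active during forward and re-entering
# autocast with the same enabled flag when computing gradients.
class AutocastAwareLinear(torch.autograd.Function):

    @staticmethod
    def forward(input, weight):
        return input @ weight.t()

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight = inputs
        ctx.save_for_backward(input, weight)
        # Record the forward autocast state so backward can replay it.
        ctx._fwd_used_autocast = torch.is_autocast_enabled()

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # Re-enter autocast with the same enabled flag as the forward pass,
        # so backward math runs under the same AMP policy.
        with torch.amp.autocast("cuda", enabled=ctx._fwd_used_autocast):
            grad_input = grad_output @ weight
            grad_weight = grad_output.t() @ input
        return grad_input, grad_weight
```

The point of the symmetric wrap is that gradients are computed with the same dtype policy as the activations they correspond to, avoiding the asymmetry of a `custom_bwd` without a matching `custom_fwd`.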
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Thank you @roycho96, @zhangj1an!
This PR overall looks good to me. The only issue is that I merged my PR #7920, which caused a conflict since I hadn't checked the changes in this PR. Sorry for that.
I opened a new PR on your fork to resolve the conflict. Please check and merge it.
Resolve master merge conflict for deepspeedai#7916
Signed-off-by: Zhang Jian <jianmusings@gmail.com>
`LinearFunctionForZeroStage3` uses the legacy `forward(ctx, ...)` pattern, which is incompatible with `torch.func` transforms (`torch.func.grad`, `torch.func.grad_and_value`, `vmap`, etc.). This affects any library that uses `torch.func` internally on a ZeRO-3 model.
Fix
Fixes #7913
Note
As pointed out by @zhangj1an in #7913, `PostBackwardFunctionModule` and `PreBackwardFunctionForModule` in `parameter_offload.py` have the same issue. Those will be addressed in a follow-up commit within this PR.
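To illustrate why the `forward` + `setup_context` split matters for `torch.func`, here is a minimal, hypothetical `MySquare` function (not DeepSpeed code) written in the new shape; on PyTorch 2.x, `torch.func.grad` can differentiate through it, which fails with the legacy `forward(ctx, ...)` pattern.

```python
import torch
from torch.func import grad

# Hypothetical example: a custom Function in the forward + setup_context
# shape is traceable by torch.func transforms such as grad().
class MySquare(torch.autograd.Function):

    @staticmethod
    def forward(x):
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output


# d/dx (x^2) at x = 3.0 is 6.0
g = grad(lambda t: MySquare.apply(t).sum())(torch.tensor(3.0))
print(g)  # tensor(6.)
```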