fix: add setup_context for torch.func compatibility#7916

Open
roycho96 wants to merge 21 commits into deepspeedai:master from roycho96:fix/support-func-torch
Conversation

@roycho96

LinearFunctionForZeroStage3 uses the legacy forward(ctx, ...) pattern, which is incompatible with torch.func transforms (torch.func.grad, torch.func.grad_and_value, vmap, etc.):

RuntimeError: In order to use an autograd.Function with functorch transforms
(vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod.

This affects any library that uses torch.func internally on a ZeRO-3 model.
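For context, here is a minimal sketch of the modern shape torch.func requires: a forward that takes no ctx, plus a setup_context staticmethod. The class name is hypothetical; this is not DeepSpeed's actual code, just the pattern the fix adopts.

```python
import torch

class LinearSketch(torch.autograd.Function):
    """Hypothetical stand-in illustrating the forward/setup_context split."""

    # Modern pattern: forward takes no ctx and only computes the output.
    @staticmethod
    def forward(input, weight):
        return input @ weight.t()

    # torch.func transforms require this hook; it receives the forward
    # inputs and output and stashes whatever backward will need on ctx.
    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight = inputs
        ctx.save_for_backward(input, weight)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        return grad_output @ weight, grad_output.t() @ input

# With setup_context present, torch.func.grad no longer raises the
# "must override the setup_context staticmethod" RuntimeError.
x = torch.ones(4, 3)
w = torch.ones(2, 3)
grad_x = torch.func.grad(lambda t: LinearSketch.apply(t, w).sum())(x)
```

The legacy forward(ctx, ...) form still works for plain autograd, which is why the bug only surfaces once something calls torch.func on the model.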

Fix

Fixes #7913

Note

As pointed out by @zhangj1an in #7913, PostBackwardFunctionModule and PreBackwardFunctionForModule in parameter_offload.py have the same issue. Those will be addressed in a follow-up commit within this PR.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: eed37042bc


Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
… unpack error

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
@roycho96 roycho96 force-pushed the fix/support-func-torch branch from 252aea1 to 39b1755 on March 21, 2026 at 10:28
@roycho96
Author

> Thanks for the work! I implemented the same fix, so it looks good to me. To reduce reviewers' effort, I left 2 minor comments; these should leave linear.py with only 17 insertions (+) and 7 deletions (-).

Hi @zhangj1an, I've sent you a collaborator invite to my fork. Feel free to push your fix directly to the branch. Thanks for the suggestion!

… setup_context

Co-authored-by: zhangj1an <jianmusings@gmail.com>
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
@zhangj1an zhangj1an force-pushed the fix/support-func-torch branch from 444122c to 6df37af on March 22, 2026 at 08:45
…afe linear

Avoid asymmetric custom_bwd without custom_fwd on the setup_context forward path;
mirror forward AMP in backward via torch.amp.autocast.

Signed-off-by: Zhang <jianmusings@gmail.com>
PyTorch versions that expose autograd.Function.setup_context need the
modern forward + setup_context shape for torch.func / functorch.
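The commit message above can be sketched as a version gate (illustrative only; the class name is hypothetical and the real change lives in linear.py): older PyTorch builds lack autograd.Function.setup_context, so each build gets the forward shape it understands.

```python
import torch

# Gate on setup_context availability: only PyTorch versions that expose it
# accept (and, for torch.func, require) the modern forward shape.
if hasattr(torch.autograd.Function, "setup_context"):

    class GatedLinear(torch.autograd.Function):
        @staticmethod
        def forward(input, weight):
            return input @ weight.t()

        @staticmethod
        def setup_context(ctx, inputs, output):
            ctx.save_for_backward(*inputs)

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            return grad_output @ weight, grad_output.t() @ input

else:

    class GatedLinear(torch.autograd.Function):
        # Legacy pattern: ctx is the first forward argument.
        @staticmethod
        def forward(ctx, input, weight):
            ctx.save_for_backward(input, weight)
            return input @ weight.t()

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            return grad_output @ weight, grad_output.t() @ input
```

Either branch behaves identically under ordinary autograd; only the first is usable under torch.func transforms.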

Signed-off-by: Zhang <jianmusings@gmail.com>
@zhangj1an zhangj1an force-pushed the fix/support-func-torch branch from 0a66444 to 5e83d05 on March 22, 2026 at 09:34
zhangj1an and others added 9 commits on March 24, 2026 at 22:21
Signed-off-by: Zhang <jianmusings@gmail.com>
Signed-off-by: Zhang <jianmusings@gmail.com>
Signed-off-by: Zhang <jianmusings@gmail.com>
Signed-off-by: Zhang <jianmusings@gmail.com>
Signed-off-by: Zhang <jianmusings@gmail.com>
Signed-off-by: Zhang <jianmusings@gmail.com>
Signed-off-by: Zhang <jianmusings@gmail.com>
@roycho96 roycho96 marked this pull request as ready for review March 25, 2026 14:35
@roycho96 roycho96 requested a review from loadams as a code owner March 25, 2026 14:35

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 60d20da79f


@Flink-ddd
Contributor

Hi @tohtana , would you mind reviewing this PR when you're free? It addresses a useful compatibility fix for torch.func. Much appreciated!

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
@roycho96
Author

I additionally fixed the autocast backward: it is now always wrapped in autocast(enabled=ctx._fwd_used_autocast) to match @custom_bwd semantics and to prevent an outer autocast context from leaking into backward.
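A minimal sketch of that idea (hypothetical class, not the actual linear.py diff): setup_context records whether autocast was active during forward, and backward re-enters that same state, mirroring what @custom_fwd/@custom_bwd do for legacy-style Functions.

```python
import torch

class AutocastMirrorLinear(torch.autograd.Function):
    @staticmethod
    def forward(input, weight):
        return input @ weight.t()

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight = inputs
        ctx.save_for_backward(input, weight)
        # What @custom_fwd records: was autocast active during forward?
        ctx._fwd_used_autocast = torch.is_autocast_enabled()

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # What @custom_bwd does: run backward under the *forward's* autocast
        # state, so an outer autocast context cannot leak in (or be lost).
        with torch.amp.autocast("cuda", enabled=ctx._fwd_used_autocast):
            return grad_output @ weight, grad_output.t() @ input
```

Without the wrapper, a backward pass triggered inside an autocast region would run the matmuls in reduced precision even when the forward ran in full precision, which is exactly the asymmetry the commit message above describes.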

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
@roycho96 roycho96 force-pushed the fix/support-func-torch branch from 1acca1f to 04c456f on March 29, 2026 at 03:40
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Collaborator

@tohtana tohtana left a comment


Thank you @roycho96, @zhangj1an!

This PR overall looks good to me. The only issue is that I merged my PR #7920, which caused a conflict, since I hadn't checked the changes in this PR. Sorry about that.
I opened a new PR on your fork to resolve the conflict. Please check and merge it.

Signed-off-by: Zhang Jian <jianmusings@gmail.com>
Signed-off-by: Zhang Jian <jianmusings@gmail.com>
Signed-off-by: Zhang Jian <jianmusings@gmail.com>


Development

Successfully merging this pull request may close these issues.

[BUG] LinearFunctionForZeroStage3 crashes with torch.func transforms (missing setup_context)

5 participants