Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autograd Function Fallback bug fix - moe support #8105

Merged
merged 11 commits into from
Jul 7, 2021
Merged

Conversation

pengwa
Copy link
Contributor

@pengwa pengwa commented Jun 20, 2021

Description: Autograd Function Fallback bug fix - moe support

  • Support forward inputs orders like "Non_tensor/Tensor/Non_tensor". Correspondingly, support "None/Tensor_Grad/None" for backward outputs.
  • Report RuntimeError when PythonOp detected but _enable_custom_autograd_function is NOT enabled.
  • Simplify the attributed used by PythonOpGrad. Renaming some attribute to reflect whether it is used for tensor inputs or all (tensor + non-tensor) inputs.

Attached the PythonOP/PythonOpGrad schemas (and how those attributes used) after the change:

image

Motivation and Context

  • Why is this change required? What problem does it solve?
  • If it fixes an open issue, please link to the issue here.

…rrespondingly, support "None/Tensor_Grad/None" fpr backward outputs.
…raining\python\training\ortmodule\__init__.py (1 issue)"
Copy link
Contributor

@tlh20 tlh20 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few minor comments on places we could extend assertion checks. Adjacent to some of the changes in this PR there is a use of "static" that looks suspicious -- could you check if that is correct?

onnxruntime/core/language_interop_ops/torch/torch_proxy.cc Outdated Show resolved Hide resolved
onnxruntime/core/language_interop_ops/torch/torch_proxy.cc Outdated Show resolved Hide resolved
orttraining/orttraining/core/graph/gradient_builder.cc Outdated Show resolved Hide resolved
orttraining/orttraining/core/graph/training_op_defs.cc Outdated Show resolved Hide resolved
orttraining/orttraining/core/graph/training_op_defs.cc Outdated Show resolved Hide resolved
orttraining/orttraining/core/graph/training_op_defs.cc Outdated Show resolved Hide resolved
orttraining/orttraining/core/graph/training_op_defs.cc Outdated Show resolved Hide resolved
Co-authored-by: Tim Harris <tiharr@microsoft.com>
@SherlockNoMad SherlockNoMad added the training issues related to ONNX Runtime training; typically submitted using template label Jul 2, 2021
SherlockNoMad
SherlockNoMad previously approved these changes Jul 2, 2021
pengwa and others added 3 commits July 3, 2021 10:45
Refine the schema description

Co-authored-by: Tim Harris <tiharr@microsoft.com>
@pengwa
Copy link
Contributor Author

pengwa commented Jul 3, 2021

A few minor comments on places we could extend assertion checks. Adjacent to some of the changes in this PR there is a use of "static" that looks suspicious -- could you check if that is correct?

Thanks @tlh20 for your time reviewing , I have addressed all of comments . :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
training issues related to ONNX Runtime training; typically submitted using template
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants