Skip to content

Allow defining customized PythonOp shape inferer#17093

Merged
pengwa merged 5 commits into
mainfrom
pengwa/python_op_shape_infer
Aug 14, 2023
Merged

Allow defining customized PythonOp shape inferer#17093
pengwa merged 5 commits into
mainfrom
pengwa/python_op_shape_infer

Conversation

@pengwa
Copy link
Copy Markdown
Contributor

@pengwa pengwa commented Aug 10, 2023

Allow defining customized PythonOp shape inferer

For torch.autograd.Function, we converted it to PythonOp in MSDomain, there are two places to do shape inferencing for it:

  1. in SymbolicShapeInfer, there is one.
  2. in PythonOp op definition.

For common PythonOp, since we don't know the relation ship between inputs and outputs, so we only infer the rank from output ranks, and generate symbolic dimensions for each dim. While this will introduce many meaningless symbolic dimensions, sometimes blocking our graph transformers to do op fusion.

This PR provide a way to define custom shape inferencing for torch.autograd.Function we defined, to propagate the original dimensions across the PythonOp at the best efforts.

But the 2rd one is not covered yet, we could refine that later. Fixing 1st one is enough for ORTModule training/evaluation.

Motivation and Context

@pengwa pengwa added the training issues related to ONNX Runtime training; typically submitted using template label Aug 10, 2023
@pengwa pengwa requested review from ajindal1 and askhade August 10, 2023 13:45
@pengwa pengwa force-pushed the pengwa/python_op_shape_infer branch from 8081e63 to af446cf Compare August 10, 2023 16:24
Comment thread onnxruntime/python/tools/symbolic_shape_infer.py
ajindal1
ajindal1 previously approved these changes Aug 10, 2023
Copy link
Copy Markdown
Contributor

@ajindal1 ajindal1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just added 1 comment, otherwise LGTM.

@pengwa pengwa merged commit cd7b3f5 into main Aug 14, 2023
@pengwa pengwa deleted the pengwa/python_op_shape_infer branch August 14, 2023 01:13
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024
### Allow defining customized PythonOp shape inferer

For `torch.autograd.Function`, we converted it to PythonOp in MSDomain,
there are two places to do shape inferencing for it:

1. in SymbolicShapeInfer, there is one. 
2. in PythonOp op definition. 

For common PythonOp, since we don't know the relation ship between
inputs and outputs, so we only infer the rank from output ranks, and
generate symbolic dimensions for each dim. While this will introduce
many meaningless symbolic dimensions, sometimes blocking our graph
transformers to do op fusion.

This PR provide a way to define custom shape inferencing for
`torch.autograd.Function` we defined, to propagate the original
dimensions across the PythonOp at the best efforts.

But the 2rd one is not covered yet, we could refine that later. Fixing
1st one is enough for ORTModule training/evaluation.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

training issues related to ONNX Runtime training; typically submitted using template

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants