Fix layernorm and softmax axis after upstream#17255
Merged
Merged
Conversation
askhade
reviewed
Aug 25, 2023
Lafi7e
pushed a commit
that referenced
this pull request
Aug 28, 2023
### Fix layernorm and softmax axis after upstream
For Gather (the slicing is a scalar), the output rank is small than its
inputs.
When we upstream this kind of Gather before softmax or layernorm, we
should also update the axis attribute.
Otherwise, the axis might be out-of-date and incorrect for the updated
rank.
```
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_fallback.py", line 157, in handle_exception
raise exception
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 280, in forward
self._build_graph(graph_transformer_config)
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 158, in wrapper
result = func(graph_execution_manager, *args, **kwargs)
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 273, in wrapper
result = func(graph_execution_manager, *args, **kwargs)
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 361, in _build_graph
super()._build_graph(graph_transformer_config)
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 184, in _build_graph
self._graph_builder.build(config)
RuntimeError: /onnxruntime/orttraining/orttraining/python/orttraining_pybind_state.cc:823 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const onnxruntime::training::TrainingGraphTransformerConfiguration&)> [ONNXRuntimeError] : 1 : FAIL : Node (Softmax_2904) Op (Softmax) [ShapeInferenceError] 'axis' must be in [-3 , 2]. Its actual value is: 3
```
snnn
pushed a commit
that referenced
this pull request
Aug 28, 2023
Cherry-pick 1st round for rel-1.16.0 from https://github.com/microsoft/onnxruntime/issues?q=label%3Arelease%3A1.16+label%3Atriage%3Aapproved+is%3Aclosed except #17201 because it caused UT failure and is not fixed yet. PR list: #16417 #16936 #17000 #17236 #17238 #17240 #17252 #17255 #17258 #17265 #17267 #17277
guyang3532
added a commit
that referenced
this pull request
Nov 16, 2023
Similar to #17255 update axis for Layernormalization when Reshape upstream it.
kleiti
pushed a commit
to kleiti/onnxruntime
that referenced
this pull request
Mar 22, 2024
### Fix layernorm and softmax axis after upstream
For Gather (the slicing is a scalar), the output rank is small than its
inputs.
When we upstream this kind of Gather before softmax or layernorm, we
should also update the axis attribute.
Otherwise, the axis might be out-of-date and incorrect for the updated
rank.
```
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_fallback.py", line 157, in handle_exception
raise exception
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 280, in forward
self._build_graph(graph_transformer_config)
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 158, in wrapper
result = func(graph_execution_manager, *args, **kwargs)
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 273, in wrapper
result = func(graph_execution_manager, *args, **kwargs)
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 361, in _build_graph
super()._build_graph(graph_transformer_config)
File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 184, in _build_graph
self._graph_builder.build(config)
RuntimeError: /onnxruntime/orttraining/orttraining/python/orttraining_pybind_state.cc:823 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const onnxruntime::training::TrainingGraphTransformerConfiguration&)> [ONNXRuntimeError] : 1 : FAIL : Node (Softmax_2904) Op (Softmax) [ShapeInferenceError] 'axis' must be in [-3 , 2]. Its actual value is: 3
```
kleiti
pushed a commit
to kleiti/onnxruntime
that referenced
this pull request
Mar 22, 2024
Similar to microsoft#17255 update axis for Layernormalization when Reshape upstream it.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fix layernorm and softmax axis after upstream
For Gather (the slicing is a scalar), the output rank is small than its inputs.
When we upstream this kind of Gather before softmax or layernorm, we should also update the axis attribute.
Otherwise, the axis might be out-of-date and incorrect for the updated rank.