Skip to content

Fix layernorm and softmax axis after upstream#17255

Merged
pengwa merged 3 commits into
mainfrom
pengwa/fixbug
Aug 25, 2023
Merged

Fix layernorm and softmax axis after upstream#17255
pengwa merged 3 commits into
mainfrom
pengwa/fixbug

Conversation

@pengwa
Copy link
Copy Markdown
Contributor

@pengwa pengwa commented Aug 22, 2023

Fix layernorm and softmax axis after upstream

For Gather (the slicing is a scalar), the output rank is small than its inputs.

When we upstream this kind of Gather before softmax or layernorm, we should also update the axis attribute.
Otherwise, the axis might be out-of-date and incorrect for the updated rank.

  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_fallback.py", line 157, in handle_exception
    raise exception
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 280, in forward
    self._build_graph(graph_transformer_config)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 158, in wrapper
    result = func(graph_execution_manager, *args, **kwargs)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 273, in wrapper
    result = func(graph_execution_manager, *args, **kwargs)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 361, in _build_graph
    super()._build_graph(graph_transformer_config)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 184, in _build_graph
    self._graph_builder.build(config)
RuntimeError: /onnxruntime/orttraining/orttraining/python/orttraining_pybind_state.cc:823 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const onnxruntime::training::TrainingGraphTransformerConfiguration&)> [ONNXRuntimeError] : 1 : FAIL : Node (Softmax_2904) Op (Softmax) [ShapeInferenceError] 'axis' must be in [-3 , 2]. Its actual value is: 3

@pengwa pengwa added the training issues related to ONNX Runtime training; typically submitted using template label Aug 22, 2023
@pengwa pengwa requested review from askhade and baijumeswani August 22, 2023 08:57
@faxu faxu added the triage:approved Approved for cherrypicks for release label Aug 24, 2023
Copy link
Copy Markdown
Contributor

@askhade askhade left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@pengwa pengwa merged commit 7c98f45 into main Aug 25, 2023
@pengwa pengwa deleted the pengwa/fixbug branch August 25, 2023 04:26
Lafi7e pushed a commit that referenced this pull request Aug 28, 2023
### Fix layernorm and softmax axis after upstream

For Gather (the slicing is a scalar), the output rank is small than its
inputs.

When we upstream this kind of Gather before softmax or layernorm, we
should also update the axis attribute.
Otherwise, the axis might be out-of-date and incorrect for the updated
rank.

```
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_fallback.py", line 157, in handle_exception
    raise exception
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 280, in forward
    self._build_graph(graph_transformer_config)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 158, in wrapper
    result = func(graph_execution_manager, *args, **kwargs)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 273, in wrapper
    result = func(graph_execution_manager, *args, **kwargs)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 361, in _build_graph
    super()._build_graph(graph_transformer_config)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 184, in _build_graph
    self._graph_builder.build(config)
RuntimeError: /onnxruntime/orttraining/orttraining/python/orttraining_pybind_state.cc:823 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const onnxruntime::training::TrainingGraphTransformerConfiguration&)> [ONNXRuntimeError] : 1 : FAIL : Node (Softmax_2904) Op (Softmax) [ShapeInferenceError] 'axis' must be in [-3 , 2]. Its actual value is: 3
```
snnn pushed a commit that referenced this pull request Aug 28, 2023
guyang3532 added a commit that referenced this pull request Nov 16, 2023
Similar to #17255
update axis for Layernormalization when Reshape upstream it.
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024
### Fix layernorm and softmax axis after upstream

For Gather (the slicing is a scalar), the output rank is small than its
inputs.

When we upstream this kind of Gather before softmax or layernorm, we
should also update the axis attribute.
Otherwise, the axis might be out-of-date and incorrect for the updated
rank.

```
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_fallback.py", line 157, in handle_exception
    raise exception
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 280, in forward
    self._build_graph(graph_transformer_config)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 158, in wrapper
    result = func(graph_execution_manager, *args, **kwargs)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 273, in wrapper
    result = func(graph_execution_manager, *args, **kwargs)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 361, in _build_graph
    super()._build_graph(graph_transformer_config)
  File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 184, in _build_graph
    self._graph_builder.build(config)
RuntimeError: /onnxruntime/orttraining/orttraining/python/orttraining_pybind_state.cc:823 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const onnxruntime::training::TrainingGraphTransformerConfiguration&)> [ONNXRuntimeError] : 1 : FAIL : Node (Softmax_2904) Op (Softmax) [ShapeInferenceError] 'axis' must be in [-3 , 2]. Its actual value is: 3
```
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024
Similar to microsoft#17255
update axis for Layernormalization when Reshape upstream it.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

training issues related to ONNX Runtime training; typically submitted using template triage:approved Approved for cherrypicks for release

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants