[LocalFunction] Shape mismatch attempting to re-use buffer #17061

Closed
BowenBao opened this issue Aug 9, 2023 · 18 comments
Labels: converter:dynamo (issues related to supporting the PyTorch Dynamo exporter), core runtime (issues related to core runtime)

Comments

BowenBao commented Aug 9, 2023

Similar to #16813 (unsure if related), this issue occurs when running a dynamo-exporter-produced model with local functions. The fully inlined model runs successfully.

NOTE: ORT must be built with ONNX v1.14.1 or above; otherwise a segfault may occur during ONNX shape inference.

Example model and repro script can be found here: torchbench_hf_Bart.

========== hf_Bart failed: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running If node. Name:'_inline_aten_tn5' Status Message: /bert_ort/bowbao/repos/bench/onnxruntime/onnxruntime/core/framework/op_kernel.cc:83 virtual OrtValue* onnxruntime::OpKernelContext::OutputMLValue(int, const onnxruntime::TensorShape&) status.IsOK() was false. Shape mismatch attempting to re-use buffer. {768,3072} != {768,50265}. Validate usage of dim_value (values should be > 0) and dim_param (all values with the same string should equate to the same size) in shapes in the model.

========== inlined hf_Bart passed
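
For anyone hitting a similar error: the tail of the message suggests auditing dim_value and dim_param usage. A minimal sketch of such an audit (the model path is illustrative, not part of the original repro):

    import onnx

    model = onnx.load("hf_Bart.onnx")  # illustrative path
    infos = list(model.graph.value_info) + list(model.graph.input) + list(model.graph.output)
    for vi in infos:
        for i, d in enumerate(vi.type.tensor_type.shape.dim):
            if d.HasField("dim_value") and d.dim_value <= 0:
                print(f"{vi.name} dim {i}: non-positive dim_value {d.dim_value}")
            elif d.HasField("dim_param"):
                print(f"{vi.name} dim {i}: symbolic dim {d.dim_param!r}")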

BowenBao added the core runtime and converter:dynamo labels Aug 9, 2023
snnn commented Aug 9, 2023

Most likely it is an ONNX shape inference bug.

BowenBao commented Aug 9, 2023

> Most likely it is an ONNX shape inference bug.

cc @jcwchen

jcwchen commented Aug 9, 2023

To clarify, will this issue be resolved by ONNX 1.14.1 (and by ORT consuming the ONNX 1.14.1 commit)? Or is it a new issue that needs to be fixed in ONNX?

BTW, ONNX 1.14.1 is coming out soon, and we should be able to consume it before the next ORT release.

BowenBao commented Aug 9, 2023

This is a new issue; the repro was from a local build of ORT with ONNX 1.14.1.

BowenBao commented Aug 9, 2023

> Most likely it is an ONNX shape inference bug.

The model does pass the shape inference checks below without error; this is part of the provided repro (the model path shown here is illustrative):

    import onnx

    onnx_model = onnx.load("hf_Bart.onnx")  # illustrative path to the exported model
    # infer_shapes returns a new ModelProto; strict_mode=True raises on inconsistencies.
    onnx.shape_inference.infer_shapes(
        onnx_model, check_type=True, strict_mode=True, data_prop=True
    )
    onnx.checker.check_model(onnx_model, full_check=True)

snnn commented Aug 9, 2023

It doesn't mean the shape inference functions generate correct results. Shape inference didn't crash, but its results could still be wrong. For example, suppose a MatMul has two inputs, matrix A with shape [m,n] and matrix B with shape [n,k]; the result should have shape [m,k]. If a shape inference function wrongly thinks the output shape is [m,m], the onnx checker would still pass, but ONNX Runtime would allocate a wrong buffer for the output.
So, what I want to say is: ONNX should have a way to test whether its shape inference functions are correct, and it shouldn't rely on ORT for that. Otherwise it amounts to "I must be right because I say I am right." If the correctness of the shape inference functions is established via ORT, then ORT cannot trust the functions' results.
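
To make the MatMul example concrete, here is a minimal sketch of the correct inference rule (illustrative only, not ONNX's actual implementation):

    def matmul_output_shape(a_shape, b_shape):
        # Correct 2-D MatMul rule: [m, n] x [n, k] -> [m, k].
        m, n = a_shape
        n2, k = b_shape
        assert n == n2, "inner dimensions must match"
        return [m, k]

    # A buggy rule that returned [m, m] would still pass the onnx checker,
    # but ORT would then pre-allocate an output buffer of the wrong size
    # (compare the {768,3072} vs {768,50265} mismatch in the error above).
    assert matmul_output_shape([768, 3072], [3072, 50265]) == [768, 50265]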

yuslepukhin commented:

I will take a look at it.

jcwchen commented Aug 10, 2023

> The model does pass the shape inference checks without error; this is part of the provided repro.

ONNX shape inference might overlook some issues coming from subgraphs or local functions due to onnx/onnx#5463. The fresh PR onnx/onnx#5488 should fix that, and we can run shape inference again with that fix (and strict_mode enabled) to see whether ONNX can catch it.

yuslepukhin self-assigned this Aug 14, 2023
yuslepukhin commented:

@jcwchen Is this going to be in 1.14.1?

jcwchen commented Aug 14, 2023

> @jcwchen Is this going to be in 1.14.1?

No, because onnx/onnx#5488 hasn't been merged yet, and ONNX 1.14.1 will probably be out this week.

As a workaround, if you want ONNX shape inference to be able to catch local-function-related issues, you can try https://pypi.org/project/onnx-weekly/ instead once onnx/onnx#5488 has been merged.

yuslepukhin commented:

I am actually debugging in C++ with ORT built with ONNX 1.14.1.

jcwchen commented Aug 14, 2023

> I am actually debugging in C++ with ORT built with ONNX 1.14.1.

IIUC, this ONNX issue (not honoring check_type and strict_mode for local functions) should be fine for ORT, because ORT does explicitly set check_type and strict_mode when using shape inference for local functions:

ONNX_NAMESPACE::ShapeInferenceOptions options{true, 1, false};  // check_type = true, strict_mode = 1, data_prop = false

The fix onnx/onnx#5488 only addresses the issue that ONNX's shape inference API suppresses local-function-related shape inference errors.
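
For reference, those C++ options appear to correspond to the following keyword arguments in the Python API (a sketch; the field-to-argument mapping is my reading of the struct, and onnx_model is the model from the repro snippet above):

    import onnx

    # Assumed mapping of ShapeInferenceOptions{true, 1, false}:
    # check_type=True, strict_mode=True, data_prop=False.
    inferred = onnx.shape_inference.infer_shapes(
        onnx_model, check_type=True, strict_mode=True, data_prop=False
    )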

yuslepukhin commented:

The inlined version of the model runs without any errors. Most of the time is spent on optimizations, when they are enabled.

jcwchen commented Aug 14, 2023

> The inlined version of the model runs without any errors. Most of the time is spent on optimizations, when they are enabled.

It sounds like it should throw some errors, but it doesn't? If so, there might be other unidentified issues/limitations in ONNX's function shape inference, and I think onnx/onnx#5488 won't fix those.

yuslepukhin commented:

The non-inlined version does repro the error in the description. The inlining is done using ONNX. The script also performs shape inferencing. I will try to run this without it and see what happens.
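
For context, ONNX-level inlining along those lines might look like this (a sketch assuming ONNX >= 1.14, where the onnx.inliner module is available; file names are illustrative):

    import onnx
    from onnx import inliner

    model = onnx.load("hf_Bart.onnx")  # exported model with local functions
    inlined = inliner.inline_local_functions(model)
    onnx.save(inlined, "hf_Bart_inlined.onnx")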

yuslepukhin commented Aug 23, 2023

The root cause of the problem has been identified. We are going to issue a temporary fix.
Here are the known workarounds in the meantime (a Python sketch of both follows the list).

Note: the debugging and testing were done against the ONNX rel-1.14.1 branch. We are assuming this is what we are going to ship.

  • You can inline and pre-optimize the model and save it; there is no need for ONNX-level inlining, as the built-in inlining works just fine. Set session_options.optimized_model_filepath = "optimized_model_name.onnx". You can then load the optimized model quickly and run it. (Recommended option.)
  • Disable memory reuse: session_options.enable_mem_reuse = False. This will increase memory consumption and would likely impact inference performance.
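
A minimal Python sketch of both workarounds (file names are illustrative):

    import onnxruntime as ort

    # Option 1 (recommended): have ORT inline/optimize once and save the result.
    so = ort.SessionOptions()
    so.optimized_model_filepath = "hf_Bart_optimized.onnx"
    ort.InferenceSession("hf_Bart.onnx", so)  # writes the optimized model to disk

    # Later runs load the pre-optimized model directly, which is fast.
    sess = ort.InferenceSession("hf_Bart_optimized.onnx")

    # Option 2: disable memory reuse (more memory, likely slower inference).
    so2 = ort.SessionOptions()
    so2.enable_mem_reuse = False
    sess2 = ort.InferenceSession("hf_Bart.onnx", so2)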

yuslepukhin added a commit that referenced this issue Aug 23, 2023
### Description
Temporarily disable symbol tables.

### Motivation and Context
Local symbol tables mark unrelated shapes for re-use and cause inference to error out.

#17061
er3x3 pushed a commit that referenced this issue Aug 28, 2023
(same commit message as above)
jberclaz commented:

@yuslepukhin, can you comment on the fix for this issue? Has it been included in release 1.16.0? Thanks.

yuslepukhin commented:

> @yuslepukhin, can you comment on the fix for this issue? Has it been included in release 1.16.0? Thanks.

I believe so. It was approved for that.

kleiti pushed a commit to kleiti/onnxruntime that referenced this issue Mar 22, 2024
…#17267)

(same commit message as above)