onnxruntime.InferenceSession hangs #20354

Closed

doloresgarcia opened this issue Apr 17, 2024 · 21 comments
Labels
core runtime issues related to core runtime

Comments

@doloresgarcia

doloresgarcia commented Apr 17, 2024

Describe the issue

Hi! I have exported an ONNX model from PyTorch and I am trying to use it for inference, but onnxruntime.InferenceSession hangs without any error.
I have tried this on Linux and macOS. When using strace, nothing is shown after the last log message.

The model can be found here:
model

To reproduce

import onnxruntime as ort
ort.set_default_logger_severity(0)
so = ort.SessionOptions()
print(so.inter_op_num_threads)
print(so.intra_op_num_threads)
print("starting to load")
ort_session = ort.InferenceSession(
    "model_1.onnx",
    providers=["CPUExecutionProvider"],
)
print("finished loading")

The log gives this:

0
0
starting to load
2024-04-17 14:39:53.167254537 [I:onnxruntime:, inference_session.cc:330 operator()] Flush-to-zero and denormal-as-zero are off
2024-04-17 14:39:53.244762912 [I:onnxruntime:, inference_session.cc:338 ConstructorCommon] Creating and using per session threadpools since use_per_session_threads_ is true
2024-04-17 14:39:53.244791816 [I:onnxruntime:, inference_session.cc:356 ConstructorCommon] Dynamic block base set to 0
2024-04-17 14:39:53.245374968 [V:onnxruntime:Default, env.cc:248 ThreadMain] pthread_setaffinity_np succeed for thread: 1325925, index: 0, mask: {1, }
2024-04-17 14:39:53.245475139 [V:onnxruntime:Default, env.cc:248 ThreadMain] pthread_setaffinity_np succeed for thread: 1325927, index: 2, mask: {3, }
2024-04-17 14:39:53.245640304 [V:onnxruntime:Default, env.cc:248 ThreadMain] pthread_setaffinity_np succeed for thread: 1325930, index: 5, mask: {6, }
2024-04-17 14:39:53.257577164 [V:onnxruntime:Default, env.cc:248 ThreadMain] pthread_setaffinity_np succeed for thread: 1325926, index: 1, mask: {2, }
2024-04-17 14:39:53.257664768 [V:onnxruntime:Default, env.cc:248 ThreadMain] pthread_setaffinity_np succeed for thread: 1325928, index: 3, mask: {4, }
2024-04-17 14:39:53.257809630 [V:onnxruntime:Default, env.cc:248 ThreadMain] pthread_setaffinity_np succeed for thread: 1325929, index: 4, mask: {5, }
2024-04-17 14:39:53.260755085 [V:onnxruntime:Default, env.cc:248 ThreadMain] pthread_setaffinity_np succeed for thread: 1325932, index: 7, mask: {8, }
2024-04-17 14:39:53.260733432 [V:onnxruntime:Default, env.cc:248 ThreadMain] pthread_setaffinity_np succeed for thread: 1325931, index: 6, mask: {7, }
2024-04-17 14:39:53.264674323 [V:onnxruntime:Default, env.cc:248 ThreadMain] pthread_setaffinity_np succeed for thread: 1325933, index: 8, mask: {9, }
2024-04-17 14:40:02.464322875 [I:onnxruntime:, inference_session.cc:1402 Initialize] Initializing session.
2024-04-17 14:40:02.464374810 [I:onnxruntime:Default, bfc_arena.cc:29 BFCArena] Creating BFCArena for Cpu with following configs: initial_chunk_size_bytes: 1048576 max_dead_bytes_per_chunk: 134217728 initial_growth_chunk_size_bytes: 2097152 max_power_of_two_extend_bytes: 1073741824 memory limit: 18446744073709551615 arena_extend_strategy: 0
2024-04-17 14:40:02.495584676 [V:onnxruntime:Default, bfc_arena.cc:66 BFCArena] Creating 21 bins of max chunk size 256 to 268435456
2024-04-17 14:40:08.501566408 [I:onnxruntime:, constant_sharing.cc:256 ApplyImpl] Total shared scalar initializer count: 7799
2024-04-17 14:40:23.650234892 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer '_val_149'. It is no longer used by any node.
2024-04-17 14:40:23.650378373 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer '_val_142'. It is no longer used by any node.
2024-04-17 14:40:23.650385274 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer '_val_139'. It is no longer used by any node.
2024-04-17 14:40:23.650391115 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer '_val_137'. It is no longer used by any node.
2024-04-17 14:40:23.650403024 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer '_val_144'. It is no longer used by any node.
2024-04-17 14:40:23.650410618 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer '_val_146'. It is no longer used by any node.
2024-04-17 14:40:23.650428728 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ortshared_7_0_1_7'. It is no longer used by any node.
2024-04-17 14:40:23.650434231 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ortshared_7_0_1_6'. It is no longer used by any node.
2024-04-17 14:40:23.650441269 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ortshared_7_0_1_4'. It is no longer used by any node.
2024-04-17 14:40:23.650446846 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ortshared_7_0_1_3'. It is no longer used by any node.
2024-04-17 14:40:23.650456481 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ortshared_7_1_1_1'. It is no longer used by any node.
2024-04-17 14:40:23.650463897 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ortshared_7_0_1_1'. It is no longer used by any node.
2024-04-17 14:40:23.650473228 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ortshared_7_0_1_0'. It is no longer used by any node.
2024-04-17 14:40:23.652591674 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer '_val_148'. It is no longer used by any node.
2024-04-17 14:40:23.654060774 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ortshared_7_0_1_5'. It is no longer used by any node.
2024-04-17 14:40:23.654923282 [I:onnxruntime:, graph.cc:3556 CleanUnusedInitializersAndNodeArgs] Removing initializer 'ortshared_7_0_1_2'. It is no longer used by any node.
2024-04-17 14:40:27.459510017 [I:onnxruntime:, constant_sharing.cc:256 ApplyImpl] Total shared scalar initializer count: 2

Urgency

This is a key step toward using our model in production.

Platform

Linux

OS Version

Linux 9.3

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.16.3

ONNX Runtime API

Python

Architecture

X86_64

Execution Provider

Default CPU

Execution Provider Library Version

No response

@yuslepukhin added the core runtime label Apr 17, 2024
@doloresgarcia
Author

Hi @justinchuby, I am wondering if this could be related to the introduction of the new transformations in 124160. Do you think that could be the case? (Sorry to bother you again.)

@justinchuby
Contributor

Hi @justinchuby, I am wondering if this could be related to the introduction of the new transformations in 124160. Do you think that could be the case? (Sorry to bother you again.)

I suspect there may be another cause. Could you test with the latest ONNX Runtime release to see whether it is still an issue?

@doloresgarcia
Author

doloresgarcia commented Apr 18, 2024

Thanks for checking, @justinchuby! I have now tested with 1.17.3 and it is still the case :/

@justinchuby
Contributor

Is the model open source? Could you share the source code for it?

@justinchuby
Contributor

justinchuby commented Apr 18, 2024

Please try the following:

Set the env var TORCHLIB_EXPERIMENTAL_PREFER_TRACING=1 before running the PyTorch export script to get the model, then inline the model with:

import onnx
import onnx.inliner

model_proto = onnx.load("model.onnx")
inlined = onnx.inliner.inline_local_functions(model_proto)
onnx.save(inlined, "model_inlined.onnx")

Not guaranteed to succeed, but I'm curious whether that would help.

@yuslepukhin
Member

Try different optimization levels and see if this affects the outcome.
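For example, one way to try a different level, assuming the same model file and CPU provider as in the repro above:

import onnxruntime as ort

so = ort.SessionOptions()
# step through ORT_DISABLE_ALL, ORT_ENABLE_BASIC, ORT_ENABLE_EXTENDED, ORT_ENABLE_ALL
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
ort_session = ort.InferenceSession(
    "model_1.onnx",
    sess_options=so,
    providers=["CPUExecutionProvider"],
)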

@justinchuby
Contributor

Some observations: the model has ~350k nodes

@doloresgarcia
Author

Is the model open source? Could you share the source code for it?
The model is an adaptation of GATr (with the ._VF einsums removed so that it is ONNX-exportable):
https://github.com/Qualcomm-AI-research/geometric-algebra-transformer/blob/main/gatr/nets/gatr.py

@doloresgarcia
Author

doloresgarcia commented Apr 18, 2024

Please try the following:

Set the env var TORCHLIB_EXPERIMENTAL_PREFER_TRACING=1 before running the PyTorch export script to get the model, then inline the model with:

model_proto = onnx.load("model.onnx")
inlined = onnx.inliner.inline_local_functions(model_proto)
onnx.save(inlined, "model_inlined.onnx")

Not guaranteed to succeed, but I'm curious whether that would help.

This code runs and returns the inlined model. The InferenceSession log now shows an error:

2024-04-18 23:22:23.255135644 [W:onnxruntime:, constant_folding.cc:212 ApplyImpl] Could not find a CPU kernel and hence can't constant fold CastLike node 'n1__11634_2008'
2024-04-18 23:22:23.255240965 [W:onnxruntime:, constant_folding.cc:212 ApplyImpl] Could not find a CPU kernel and hence can't constant fold CastLike node 'n1__11602_1985'
sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (n0__11868) Op (Mul) [ShapeInferenceError] Incompatible dimensions

@doloresgarcia
Author

Some observations: the model has ~350k nodes

Would this alone make the inference session take too long to start, or not start at all?

@doloresgarcia
Author

doloresgarcia commented Apr 19, 2024

Try different optimization levels and see if this affects the outcome.

Thanks for the reply, @yuslepukhin!
With ort.GraphOptimizationLevel.ORT_DISABLE_ALL the session does initialize (after ~3 h), but then there is also a shape bug:
Status Message: updates tensor should have shape equal to indices.shape[:-1] + data.shape[indices.shape[-1]:]. updates shape: {}, indices shape: {3,1}, data shape: {4,4}

What is the correct way to debug this? I have no information about where to look for this operation in the original code. I am assuming this is a conversion error.
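One way to start narrowing it down, assuming the constraint in that message comes from a Scatter-style op (the op family that uses the updates/indices/data wording above), is to list those nodes in the exported graph and work backwards from their inputs to the indexing code in the original model:

import onnx

model = onnx.load("model_inlined.onnx")  # or model_1.onnx, whichever was loaded when the error occurred
for node in model.graph.node:
    # ScatterND/ScatterElements are the ops whose updates/indices/data shapes
    # must satisfy the constraint quoted in the error message
    if node.op_type in ("ScatterND", "ScatterElements", "Scatter"):
        print(node.name, node.op_type)
        print("  inputs:", list(node.input))
        print("  outputs:", list(node.output))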

@yuslepukhin
Member

yuslepukhin commented Apr 19, 2024

Some observations: the model has ~350k nodes

Would this alone make the inference session take too long to start, or not start at all?

The model inlining takes a lot of time. Stand by.
How exactly was the conversion performed?

@doloresgarcia
Author

I have optimized the model and now I can start the inference session and run it. Thank you @yuslepukhin and @justinchuby :)

@justinchuby
Contributor

I have optimized the model and now I can start the inference session and run it. Thank you @yuslepukhin and @justinchuby :)

Awesome! Curious what was done?

@yuslepukhin
Member

The initial model fails the ONNX check:

This is from the ORT-optimized model (inlining only):

Graph must be in single static assignment (SSA) form, however '_inlfunc_IsScalar_tmp' has been used as output names multiple times.

==> Context: Bad node spec for node. Name: _inlfunc_aten_mean_dim_n1 OpType: If

@doloresgarcia
Author

doloresgarcia commented Apr 23, 2024

I have optimized the model and now I can start the inference session and run it. Thank you @yuslepukhin and @justinchuby :)

Awesome! Curious what was done?

The graph had many constants that were created by the model inside functions; I initialized those with the model instead. There were also some conversion errors, for example:
x[..., index_list] is not converted well and has to be modified to use torch.index_select.
However, operations like einsum do not seem to be dynamic with input shape (this is for a GNN-like architecture), so that is problematic.
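For example, the indexing rewrite looks roughly like this (the tensor shapes and the index list below are made up for illustration):

import torch

x = torch.randn(8, 16, 4)             # made-up tensor; last dimension is indexed
index_list = torch.tensor([0, 2, 3])  # made-up index list

y_fancy = x[..., index_list]                                # exported poorly
y_select = torch.index_select(x, dim=-1, index=index_list)  # exports reliably
assert torch.equal(y_fancy, y_select)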

@justinchuby
Contributor

initialized those with the model instead

Do you mean turning Constant operators into graph initializers?

einsum do not seem to be dynamic with input shape

Could you share a concrete example?

@phierhager

Hello @doloresgarcia,
I am trying to convert, save, load, and run a custom PyTorch model via ONNX Runtime. However, as in your case, the run gets stuck and I get no clear error messages besides UnsqueezeElimination cannot remove node _inlfunc_aten_mean_dim_n1 and UnsqueezeElimination cannot remove node _inlfunc_aten_mean_dim_token_14647_n1. If I turn off optimization, I get no error message and the process is killed after a while.
Can you give some guidance on what exactly you did, besides the torch.index_select change, to get the model working with onnxruntime? That would be a great help!
Thank you.

@doloresgarcia
Author

initialized those with the model instead

Do you mean turning Constant operators into graph initializers?

einsum do not seem to be dynamic with input shape

Could you share a concrete example?

I mean matrices that were created inside functions used in many layers. The solution was to add those as arguments of the main model class and pass them down to those layers. This reduced the time needed to start the inference session, and now it works quickly.
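Roughly, the refactor looks like this (a hypothetical sketch; the names, shapes, and layer structure are made up, not the actual GATr code):

import torch
from torch import nn

class Layer(nn.Module):
    def forward(self, x, basis):
        # the constant is received from the top-level model instead of being
        # rebuilt inside the layer on every call
        return x @ basis

class Model(nn.Module):
    def __init__(self, basis: torch.Tensor, num_layers: int = 4):
        super().__init__()
        self.register_buffer("basis", basis)  # one shared constant in the graph
        self.layers = nn.ModuleList(Layer() for _ in range(num_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x, self.basis)
        return x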

@doloresgarcia
Author

Hello @doloresgarcia, I am trying to convert, save, load, and run a custom PyTorch model via ONNX Runtime. However, as in your case, the run gets stuck and I get no clear error messages besides UnsqueezeElimination cannot remove node _inlfunc_aten_mean_dim_n1 and UnsqueezeElimination cannot remove node _inlfunc_aten_mean_dim_token_14647_n1. If I turn off optimization, I get no error message and the process is killed after a while. Can you give some guidance on what exactly you did, besides the torch.index_select change, to get the model working with onnxruntime? That would be a great help! Thank you.

I am using torch.onnx.dynamo_export, which seems to support more complex models than torch.onnx.export.
I am also disabling graph optimization as you say:
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
However, my error messages appeared even after disabling it, so I guess it must be something different.
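Put together, the export and load steps look roughly like this (a sketch only; TinyModel and the input shape are placeholders, not the actual model):

import torch
import onnxruntime as ort

class TinyModel(torch.nn.Module):  # placeholder for the real model
    def forward(self, x):
        return x.relu().sum(dim=-1)

model = TinyModel().eval()
example_input = torch.randn(1, 16)

onnx_program = torch.onnx.dynamo_export(model, example_input)
onnx_program.save("model_dynamo.onnx")

so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
session = ort.InferenceSession(
    "model_dynamo.onnx",
    sess_options=so,
    providers=["CPUExecutionProvider"],
)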

@phierhager

Hello @doloresgarcia, I am trying to convert, save, load, and run a custom PyTorch model via ONNX Runtime. However, as in your case, the run gets stuck and I get no clear error messages besides UnsqueezeElimination cannot remove node _inlfunc_aten_mean_dim_n1 and UnsqueezeElimination cannot remove node _inlfunc_aten_mean_dim_token_14647_n1. If I turn off optimization, I get no error message and the process is killed after a while. Can you give some guidance on what exactly you did, besides the torch.index_select change, to get the model working with onnxruntime? That would be a great help! Thank you.

I am using torch.onnx.dynamo_export, which seems to support more complex models than torch.onnx.export. I am also disabling graph optimization as you say: so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL. However, my error messages appeared even after disabling it, so I guess it must be something different.

Okay, thank you for the quick reply.
