
A performance issue when using ONNX Runtime-TensorRT #19934

Open
hy846130226 opened this issue Mar 15, 2024 · 18 comments
Assignees
Labels
ep:TensorRT (issues related to TensorRT execution provider), performance (issues related to performance regressions), platform:windows (issues related to the Windows platform)

Comments

@hy846130226

Describe the issue

I'm using ONNX Runtime with the TensorRT execution provider.
I found that every time I load the ONNX model it costs some time, sometimes shorter and sometimes longer.
So I printed the log.
[image]
I'd like to know what the red-highlighted areas spend their time doing.

To reproduce

Just use ONNX Runtime with the TensorRT execution provider to run an ONNX model.
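
For concreteness, a minimal repro sketch with the C++ API (not from the original report; the model path is a placeholder, and the timing simply brackets session construction, where the initialization cost shows up):

```cpp
// Minimal sketch: create a session with the TensorRT EP and time how long
// session construction takes. "model.onnx" is a placeholder path.
#include <onnxruntime_cxx_api.h>
#include <chrono>
#include <iostream>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_VERBOSE, "trt_init_repro");
  Ort::SessionOptions so;

  const OrtApi& api = Ort::GetApi();
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&trt_options));
  Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(so, trt_options));

  auto t0 = std::chrono::steady_clock::now();
  Ort::Session session(env, L"model.onnx", so);  // EP load / engine build happen here
  auto t1 = std::chrono::steady_clock::now();
  std::cout << "session init took "
            << std::chrono::duration_cast<std::chrono::milliseconds>(t1 - t0).count()
            << " ms\n";

  api.ReleaseTensorRTProviderOptions(trt_options);
  return 0;
}
```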

Urgency

No response

Platform

Windows

OS Version

WIN10

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.16.3

ONNX Runtime API

C++

Architecture

X64

Execution Provider

TensorRT

Execution Provider Library Version

CUDA 11.6, TensorRT 8.6

@github-actions github-actions bot added the ep:CUDA, ep:TensorRT, and platform:windows labels Mar 15, 2024
@jywu-msft
Member

There are various strategies for reducing the session initialization time. We're in the process of putting together a doc to provide guidance.
+@chilo-ms

@hy846130226
Author

Hi @jywu-msft
Thanks! It would be very helpful to have such a document.

@hy846130226
Author

I have read the source code and found that this operation costs a lot of time.
Could someone tell me why? Is ONNX Runtime doing some optimization on the model?

[image]

@hy846130226
Author

Oh, I found the main place where the time is spent.
It's here:
[image]
It seems ONNX Runtime is loading the TensorRT EP.

How is that done?
By dynamically loading the DLL, or something like that?
Why does it cost so much time?

@jywu-msft
Member

jywu-msft commented Mar 20, 2024

There are two areas which cost the most time during TensorRT EP initialization:

  1. TensorRT builder instantiation. Here it loads a DLL with TensorRT kernels.
  2. TensorRT engine build. (This can take the most time because it does kernel auto-tuning, where it measures timings for different kernels/tactics.)

For 2), there is an option to serialize a built engine to disk so that you don't need to rebuild it the next time you initialize a session. The option is trt_engine_cache_enable; can you try it?
Avoiding 1) is a little more complicated. If 2) is enough, then you can try that first.
@chilo-ms to add more comments.
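
A hedged sketch of turning that option on through the C++ provider-options API (the cache directory is a placeholder):

```cpp
#include <onnxruntime_cxx_api.h>

// Sketch: configure the TensorRT EP with engine caching enabled, so a built
// engine is serialized to disk and reused on later session initializations.
Ort::SessionOptions MakeTrtSessionOptions() {
  const OrtApi& api = Ort::GetApi();
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&trt_options));

  // "./trt_cache" is a placeholder directory for the serialized engines.
  const char* keys[] = {"trt_engine_cache_enable", "trt_engine_cache_path"};
  const char* values[] = {"1", "./trt_cache"};
  Ort::ThrowOnError(api.UpdateTensorRTProviderOptions(trt_options, keys, values, 2));

  Ort::SessionOptions so;
  Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(so, trt_options));
  api.ReleaseTensorRTProviderOptions(trt_options);  // the EP keeps its own copy
  return so;
}
```

The first session still pays the engine-build cost; subsequent sessions load the cached engine from disk instead of rebuilding it.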

@hy846130226
Author

Hi @jywu-msft

Regarding 2):
I found that even when I use trt_engine_cache_enable, it still costs time, though it is indeed shorter.
That is because it creates the TRT IBuilder, which costs some time. But to my knowledge, if I have an off-the-shelf TRT engine, I only need an IRuntime.

Like this:
[image]

So why doesn't ONNX Runtime-TensorRT check whether trt_engine_cache_enable is enabled, and if it is, skip loading the IBuilder?
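
Roughly what the screenshot is getting at; a minimal sketch against the TensorRT 8.x C++ API, with a placeholder engine-file name:

```cpp
// Sketch: with a pre-built engine file, plain TensorRT only needs an IRuntime
// to deserialize it, no IBuilder. "model.engine" is a placeholder path.
#include <NvInfer.h>
#include <cstdio>
#include <fstream>
#include <vector>

class Logger : public nvinfer1::ILogger {
  void log(Severity severity, const char* msg) noexcept override {
    if (severity <= Severity::kWARNING) std::printf("%s\n", msg);
  }
};

int main() {
  // Read the serialized engine from disk.
  std::ifstream file("model.engine", std::ios::binary | std::ios::ate);
  std::vector<char> blob(static_cast<size_t>(file.tellg()));
  file.seekg(0);
  file.read(blob.data(), blob.size());

  Logger logger;
  nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
  nvinfer1::ICudaEngine* engine =
      runtime->deserializeCudaEngine(blob.data(), blob.size());
  // ...createExecutionContext() and enqueue inference from here...
  return engine != nullptr ? 0 : 1;
}
```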

@hy846130226
Author

And about 1):

I think it is indeed not easy.
Could you roughly describe the process for me? I'm having a bit of trouble understanding the code, so that would be greatly appreciated!

@chilo-ms
Contributor

chilo-ms commented Mar 21, 2024

So why doesn't ONNX Runtime-TensorRT check whether trt_engine_cache_enable is enabled, and if it is, skip loading the IBuilder?

ORT TRT has a similar feature (starting from 1.17.0) which skips TRT builder instantiation and simply deserializes the engine cache to run inference.

However, we still need an "ONNX" model to start with. So ORT TRT helps the user create the "embed engine" model, which is basically an ONNX model containing only one node that wraps the engine cache.
Running this embed engine model skips those lengthy processes such as TRT builder instantiation.

Please see the highlighted part below to learn how to use the ORT TRT provider options to generate/run the embed engine model.
[image]

BTW, we are working on documenting the usage of the embed engine model.
Also note that there are constraints on using it, such as:

  • The whole model should be TRT-eligible.
  • It supports dynamic-shape input only when the user explicitly specifies the shape range, meaning the engine won't be rebuilt across the inference runs.
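
A hedged sketch of the provider options involved (option names as of ORT 1.17; the output path is a placeholder):

```cpp
#include <onnxruntime_cxx_api.h>

// Sketch: ask ORT TRT (1.17+) to dump an "embed engine" (EPContext) model
// alongside the engine cache. "./model_ctx.onnx" is a placeholder path.
Ort::SessionOptions MakeEmbedEngineDumpOptions() {
  const OrtApi& api = Ort::GetApi();
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&trt_options));

  const char* keys[] = {"trt_engine_cache_enable",
                        "trt_dump_ep_context_model",
                        "trt_ep_context_file_path"};
  const char* values[] = {"1", "1", "./model_ctx.onnx"};
  Ort::ThrowOnError(api.UpdateTensorRTProviderOptions(trt_options, keys, values, 3));

  Ort::SessionOptions so;
  Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(so, trt_options));
  api.ReleaseTensorRTProviderOptions(trt_options);
  // First run with the original model dumps ./model_ctx.onnx; later sessions
  // should be created from ./model_ctx.onnx to skip the lengthy steps above.
  return so;
}
```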

@sophies927 sophies927 removed the ep:CUDA label Mar 21, 2024
@hy846130226
Author

Hi @chilo-ms,

I tried to use trt_dump_ep_context_model as follows:

[image]

But I got an error:
[ONNXRuntimeError] : 1 : FAIL : provider_options_utils.h:148 onnxruntime::ProviderOptionsParser::Parse Unknown provider option: "trt_dump_ep_context_model".

@hy846130226
Author

And I tried to modify the source code in a simple way: I commented out the fields for IBuilder, INetworkDefinition, and IParser.

[image]
[image]

I found it could still work.

This is a simplistic version, I know.

I will keep debugging to see whether this causes any errors. Also, I want to know: if I have a TensorRT engine in trt_engine_cache_path, trt_engine_cache_enable is enabled, and I do not initialize the IBuilder, is this approach correct?

@hy846130226
Author

And I tried to modify the source code in a simple way: I commented out the fields for IBuilder, INetworkDefinition, and IParser.

[image] [image]

I found it could still work.

This is a simplistic version, I know.

I will keep debugging to see whether this causes any errors. Also, I want to know: if I have a TensorRT engine in trt_engine_cache_path, trt_engine_cache_enable is enabled, and I do not initialize the IBuilder, is this approach correct?

I think if I comment out those fields for IBuilder, INetworkDefinition, and IParser, so that outside code cannot obtain the associated objects, that would also prove that the outside code does not use those objects, right?

@chilo-ms
Contributor

Hi @chilo-ms,

I tried to use trt_dump_ep_context_model as follows:

[image]

But I got an error: [ONNXRuntimeError] : 1 : FAIL : provider_options_utils.h:148 onnxruntime::ProviderOptionsParser::Parse Unknown provider option: "trt_dump_ep_context_model".

What ORT version are you using?
Please use 1.17.0 or above, or the main branch.

@chilo-ms
Contributor

And I tried to modify the source code in a simple way: I commented out the fields for IBuilder, INetworkDefinition, and IParser.

[image] [image]

I found it could still work.

This is a simplistic version, I know.

I will keep debugging to see whether this causes any errors. Also, I want to know: if I have a TensorRT engine in trt_engine_cache_path, trt_engine_cache_enable is enabled, and I do not initialize the IBuilder, is this approach correct?

Your idea is basically right.
Please see the ORT TRT code (here and here) in the main branch.

In addition to the code path (in EP Compile) you found that involves builder instantiation, there is also builder instantiation in EP GetCapability. So that's why we need the "Embed Engine" model to skip builder instantiation.

@hy846130226
Author

Hi @chilo-ms

Thank you very much for your reply!

I will try to remove the IBuilder creation step when the model has already been generated.

And about EP GetCapability, I also have a question; here is the link:
#20029

"So that's why we need the "Embed Engine" model to skip builder instantiation." I don't understand why the EP GetCapability method needs to create an IBuilder object. To my knowledge, the IBuilder is used to create other TRT objects, such as the INetworkDefinition.

And if I already have a TRT engine built from the ONNX model, could I skip this step in the process?

@hy846130226
Author

Hi @chilo-ms,
I tried to use trt_dump_ep_context_model as follows:
[image]
But I got an error: [ONNXRuntimeError] : 1 : FAIL : provider_options_utils.h:148 onnxruntime::ProviderOptionsParser::Parse Unknown provider option: "trt_dump_ep_context_model".

What ORT version are you using? Please use 1.17.0 or above, or the main branch.

Yes, my version is 1.16.3.

At first I downloaded your 1.17.0 or 1.17.3 packages, but there were no DLLs in them.
So I used 1.16.3.

Why don't the newest packages on NuGet have the DLLs?

I will also build the DLLs from the newest code.

@jywu-msft
Member

jywu-msft commented Mar 23, 2024

Use the 1.17.1 NuGet package.
There are multiple packages,
i.e., Microsoft.ML.OnnxRuntime.Gpu depends on Microsoft.ML.OnnxRuntime.Gpu.Windows,
and that package contains the onnxruntime DLLs.

@hy846130226
Author

Hi @jywu-msft

I tried the 1.17.1 Microsoft.ML.OnnxRuntime.Gpu package, which depends on Microsoft.ML.OnnxRuntime.Gpu.Windows.

But I got this error:
[image]

I checked the structure of the 1.17.1 package and found that the directory was "buildTransitive", not "build", which means Visual Studio could not load the .props/.targets files.

[image]

I feel confused; am I missing something?

@chilo-ms
Contributor

chilo-ms commented Mar 26, 2024

"So that's why we need the "Embed Engine" model to skip builder instantization."I do not know why the EP GetCapability method need to genearte IBuilder Object, as my knowledage, the IBuilder is used to generate some trt objects, such as the INetworkDefinition.

And if I already have a trt model from onnx, could I skip this step in process?

Because TRT parser needs TRT networks which depends on TRT builder.
https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L2082
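
To illustrate that dependency chain, a sketch against the TensorRT 8.x C++ API (these are the standard TensorRT entry points; the logger is assumed to be provided by the caller):

```cpp
// Sketch: the ONNX parser is created from a network, and the network is
// created from a builder, so parsing a model during GetCapability already
// requires an IBuilder.
#include <NvInfer.h>
#include <NvOnnxParser.h>

void BuildParserChain(nvinfer1::ILogger& logger) {
  nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
  const auto explicit_batch =
      1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
  nvinfer1::INetworkDefinition* network = builder->createNetworkV2(explicit_batch);
  nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, logger);
  // parser->parse(...) is what needs the network, hence the builder.
}
```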

If you have a TRT engine cache, you still need the embed engine model to skip the process for now.
Please see the embed engine model (EPContext node model) to skip the whole GetCapability.
Here are two PRs which introduce the embed engine model feature:
#18217
#19154
But we are working on another PR that can skip GetCapability without the embed engine model, using just the engine cache. (This is the exact feature that you want.)

Also, I'm working on the document for users to better understand this feature.

@sophies927 sophies927 added the performance label Mar 28, 2024