
[Feature Request] TensorRT Explicit Engine Profile Values #13851

Closed

seddonm1 opened this issue Dec 6, 2022 · 10 comments
Labels
ep:TensorRT (issues related to TensorRT execution provider), feature request (request for unsupported feature or enhancement)

Comments


seddonm1 commented Dec 6, 2022

Describe the feature request

The current behavior of the TensorRT execution provider when using dynamic batch sizes is to use the shape of the incoming tensor to programmatically determine OptProfileSelector::kMIN, OptProfileSelector::kMAX and OptProfileSelector::kOPT. These values are then passed to the TensorRT engine builder to optimize the model, which can be an extremely slow process (observed to take up to ~10 minutes with some models). The problem is that, because ONNX Runtime does not necessarily have the information it needs to build a correctly shaped model, it can end up building the model multiple times.

This proposal (and I am happy to do the work) is to allow a user to optionally pass explicit values for kMIN, kMAX and kOPT in the tensorrt_provider_options.

The current algorithm is basically the following (a Python sketch follows the list):

  • if there is no existing profile, initialize the first value to INT_MAX and the second to -INT_MAX; otherwise read the two values from the profile file
  • if the first value > the incoming tensor's batch_size, set first = batch_size and flag the model for rebuild
  • if the second value < the incoming tensor's batch_size, set second = batch_size and flag the model for rebuild
  • build the model with kMIN = first, kMAX = second and kOPT = second
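
A minimal Python sketch of that implicit logic, for illustration only (the real implementation is C++ inside the TensorRT EP; the function and variable names here are hypothetical):

```python
INT_MAX = 2**31 - 1

def update_profile(profile, batch_size):
    """Widen the cached [min, max] batch range to cover the incoming tensor."""
    if profile is None:
        first, second = INT_MAX, -INT_MAX  # no existing profile file
    else:
        first, second = profile            # read the two values from the profile file

    needs_rebuild = False
    if first > batch_size:   # incoming batch is smaller than any seen so far
        first, needs_rebuild = batch_size, True
    if second < batch_size:  # incoming batch is larger than any seen so far
        second, needs_rebuild = batch_size, True

    # the engine is then built with kMIN = first, kMAX = second, kOPT = second
    return (first, second), needs_rebuild
```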

Describe scenario use case

This proposal brings the ONNX Runtime TensorRT provider in line with NVIDIA DeepStream, which explicitly sets kMIN = 1 and sets kMAX and kOPT to a user-supplied batch_size.

This proposal does not intend to change existing behavior, only to allow a user to override it if they want.
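
For reference, a hedged sketch of what such an override could look like from the Python API. The provider option names shown are the ones eventually added by the implementing PR later in this thread; the model path and input name are placeholders:

```python
import onnxruntime as ort

# Explicit kMIN/kMAX/kOPT shapes per named input, instead of letting the EP
# infer them from incoming tensors (option names from the PR merged below).
trt_options = {
    "trt_profile_min_shapes": "images:1x3x640x640",
    "trt_profile_max_shapes": "images:32x3x640x640",
    "trt_profile_opt_shapes": "images:32x3x640x640",
}

session = ort.InferenceSession(
    "model.onnx",  # placeholder path
    providers=[("TensorrtExecutionProvider", trt_options)],
)
```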

@seddonm1 added the feature request label Dec 6, 2022
@github-actions bot added the ep:TensorRT label Dec 6, 2022
fxmarty (Contributor) commented Jan 2, 2023

+1

jywu-msft (Member) commented:

Thanks! We're looking into implementing this feature, as well as the ability to attach multiple profiles to a single engine, so one can switch between profiles for a single engine.
We will reach out to get feedback about the interface for passing in the profile values, e.g. if there are multiple subgraphs/engines, how do we map the profile values to the engines?

seddonm1 (Author) commented:

Great. This is really important.

I had to write some (super hacky) code to find the specific input tensors I required and override the values.

NVIDIA trtexec allows the user to specify input shapes for named tensors, e.g. --shapes=input:32x3x244x244. Maybe this is a good pattern?
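
For illustration, a small Python sketch of parsing that name:dim1xdim2x... pattern (a hypothetical helper, not ORT or trtexec code):

```python
def parse_shape(spec: str) -> tuple[str, list[int]]:
    """Split 'input:32x3x244x244' into the tensor name and its dims."""
    name, _, dims = spec.rpartition(":")
    return name, [int(d) for d in dims.split("x")]

assert parse_shape("input:32x3x244x244") == ("input", [32, 3, 244, 244])
```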

jywu-msft (Member) commented Mar 31, 2023

> Great. This is really important.
>
> I had to write some (super hacky) code to find the specific input tensors I required and override the values.
>
> NVIDIA trtexec allows the user to specify input shapes for named tensors, e.g. --shapes=input:32x3x244x244. Maybe this is a good pattern?

Yes, I think we can do something like that.
In addition to overriding the shape profile, it would probably make sense to support overriding the engine file (e.g. one created by trtexec).
We can only support these overrides if the graph is fully supported by TensorRT (no partitions assigned to other EPs).

Our goal is to support more features of trtexec in the coming months.

jywu-msft (Member) commented:

@seddonm1 @fxmarty FYI, there's a PR in development for the profile override feature you requested.
Would it be possible to try this dev branch to confirm it meets your requirements?
We are hoping to get this feature in as part of the upcoming ORT 1.15 release (targeted for the end of May); explicit engine file override and multi-profile support should follow soon after.
Thanks.

seddonm1 (Author) commented Apr 21, 2023

Hi @jywu-msft
I have compiled it and tried to run, but received this error when passing the string images:1x3x640x640:

[TensorRT EP] The format of provider option 'trt_profile_min_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'

It could be on my end, and I can investigate further next week.

Apart from that, I am concerned about constraint (1), "whole graph can be run on TRT". I think this means that a graph with an ONNX NonMaxSuppression node will not work. I will fix my issue and then get back to you with proper feedback.

chilo-ms (Contributor) commented Apr 21, 2023

> Hi @jywu-msft I have compiled it and tried to run, but received this error when passing the string images:1x3x640x640:
>
> [TensorRT EP] The format of provider option 'trt_profile_min_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'
>
> It could be on my end, and I can investigate further next week.
>
> Apart from that, I am concerned about constraint (1), "whole graph can be run on TRT". I think this means that a graph with an ONNX NonMaxSuppression node will not work. I will fix my issue and then get back to you with proper feedback.

Hi @seddonm1

If possible, could you share the whole provider-options string that you used? From the error message, it seems the format of the provider option 'trt_profile_min_shapes' is wrong. You can also turn on verbose logging to see the whole log, including the provider options as well as the parsing of the profile string.

Following is the example I tested for these explicit profiles using onnxruntime_perf_test
(note: you need to set all three provider options, "trt_profile_min_shapes|imgs:1x3x384x288 trt_profile_max_shapes|imgs:32x3x384x288 trt_profile_opt_shapes|imgs:32x3x384x288"):

./build/Linux/Release/./onnxruntime_perf_test -e tensorrt -r 100 -i "trt_profile_min_shapes|imgs:1x3x384x288 trt_profile_max_shapes|imgs:32x3x384x288 trt_profile_opt_shapes|imgs:32x3x384x288" ~/test.onnx
....
2023-04-21 17:59:14.249626764 [V:onnxruntime:Default, tensorrt_execution_provider_utils.h:386 MakeInputNameShapePair] [TensorRT EP] imgs:1x3x384x288
2023-04-21 17:59:14.249665939 [V:onnxruntime:Default, tensorrt_execution_provider_utils.h:445 ParseProfileShapes] [TensorRT EP] imgs
2023-04-21 17:59:14.249676018 [V:onnxruntime:Default, tensorrt_execution_provider_utils.h:451 ParseProfileShapes] [TensorRT EP] 1, 3, 384, 288,
2023-04-21 17:59:14.249685025 [V:onnxruntime:Default, tensorrt_execution_provider_utils.h:386 MakeInputNameShapePair] [TensorRT EP] imgs:32x3x384x288
2023-04-21 17:59:14.249692519 [V:onnxruntime:Default, tensorrt_execution_provider_utils.h:445 ParseProfileShapes] [TensorRT EP] imgs
2023-04-21 17:59:14.249700665 [V:onnxruntime:Default, tensorrt_execution_provider_utils.h:451 ParseProfileShapes] [TensorRT EP] 32, 3, 384, 288,
2023-04-21 17:59:14.249706606 [V:onnxruntime:Default, tensorrt_execution_provider_utils.h:386 MakeInputNameShapePair] [TensorRT EP] imgs:32x3x384x288
2023-04-21 17:59:14.249715483 [V:onnxruntime:Default, tensorrt_execution_provider_utils.h:445 ParseProfileShapes] [TensorRT EP] imgs
2023-04-21 17:59:14.249723679 [V:onnxruntime:Default, tensorrt_execution_provider_utils.h:451 ParseProfileShapes] [TensorRT EP] 32, 3, 384, 288,
2023-04-21 17:59:14.249732706 [V:onnxruntime:Default, tensorrt_execution_provider.cc:908 TensorrtExecutionProvider] [TensorRT EP] TensorRT provider options: device_id: 0, trt_max_partition_iterations: 1000, trt_min_subgraph_size: 1, trt_max_workspace_size: 1073741824, trt_fp16_enable: 0, trt_int8_enable: 0, trt_int8_calibration_cache_name: , int8_calibration_cache_available: 0, trt_int8_use_native_tensorrt_calibration_table: 0, trt_dla_enable: 0, trt_dla_core: 0, trt_dump_subgraphs: 0, trt_engine_cache_enable: 0, trt_cache_path: , trt_engine_decryption_enable: 0, trt_engine_decryption_lib_path: , trt_force_sequential_engine_build: 0, trt_context_memory_sharing_enable: 0, trt_layer_norm_fp32_fallback: 0, trt_build_heuristics_enable: 0, trt_sparsity_enable: 0, trt_builder_optimization_level: 2, trt_auxiliary_streams: -1, trt_tactic_sources: , trt_profile_min_shapes: imgs:1x3x384x288, trt_profile_max_shapes: imgs:32x3x384x288, trt_profile_opt_shapes: imgs:32x3x384x288
...

Re: "Apart from that, I am concerned about constraint (1), 'whole graph can be run on TRT'. I think this means that a graph with an ONNX NonMaxSuppression node will not work."

This PR actually supports graphs that can only partially run on TRT; could you also try it with your model?
The reason we imposed the "whole graph can be run on TRT" restriction was to simplify the user scenario for explicit profiles. We are open to any suggestions, so we would appreciate your feedback after you test it. Thank you!

seddonm1 (Author) commented:

Hi @chilo-ms

I have fixed the problem on my side that was causing the issue I experienced (sorry to waste your time).

This is a good change. I have tested with and without an ONNX NonMaxSuppression node and it works correctly.

Good work 👍

chilo-ms (Contributor) commented Apr 24, 2023

Good to hear that this PR can work for your model.

There is a follow-up question on which we'd like your opinion: how should these explicit profiles work with the engine cache?
If you set the explicit-profile provider options as well as trt_engine_cache_enable=true and run with TRT EP, it will generate the engine cache for you. When that engine cache is then used in another deployment environment, this PR asks the user to add the additional provider option trt_engine_cache_built_with_explicit_profiles (with no need to repeat the explicit-profile provider options) in order to use the engine built with explicit profiles. Otherwise, TRT EP will fall back to its original logic of determining the min/max/opt profiles and rebuild the engine if needed.

Do you think this usage is easy? We really appreciate your feedback!

[UPDATE 4/26] After internal discussion, we decided not to add trt_engine_cache_built_with_explicit_profiles. The user needs to provide the three explicit profiles regardless of whether the engine cache is enabled.
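
For illustration, a sketch of combining the explicit profiles with the engine cache under the decision above (option names are the TRT EP provider options discussed in this thread; the paths and shapes are examples):

```python
import onnxruntime as ort

trt_options = {
    "trt_engine_cache_enable": True,        # cache built engines on disk
    "trt_engine_cache_path": "./trt_cache",
    # Per the 4/26 update: supply all three even when the cache is enabled.
    "trt_profile_min_shapes": "imgs:1x3x384x288",
    "trt_profile_max_shapes": "imgs:32x3x384x288",
    "trt_profile_opt_shapes": "imgs:32x3x384x288",
}

session = ort.InferenceSession(
    "test.onnx",  # example model path
    providers=[("TensorrtExecutionProvider", trt_options)],
)
```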

chilo-ms added a commit that referenced this issue May 1, 2023
Previously, the TRT EP set TRT optimization profiles for dynamic-shape inputs based on incoming tensor values; users couldn't explicitly specify the profiles.

This PR lets users specify min/max/opt profiles through three newly added provider options:

`trt_profile_min_shapes`, `trt_profile_max_shapes` and `trt_profile_opt_shapes`

with the format "input1:dim1xdim2...,input2:dim3xdim4...".
(Note: this is similar to the --minShapes, --maxShapes and --optShapes trtexec command-line [flags](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#trtexec-flags).)

For example, if you are using onnxruntime_perf_test, you can try this:

`./onnxruntime_perf_test -e tensorrt -r 1 -i
"trt_profile_min_shapes|imgs:1x3x384x288
trt_profile_max_shapes|imgs:32x3x384x288
trt_profile_opt_shapes|imgs:16x3x384x288" your_model_path`

If the engine cache is enabled, you still need to provide these three explicit provider options in order to use this feature. ORT TRT will compare the min/max/opt profile shapes with the ones saved in the .profile file to decide whether to rebuild the engine.

Constraint on using these provider options: (1) min/max/opt profile shapes must be specified for all dynamic-shape inputs.

This feature was also requested by other users:
#13851
chilo-ms (Contributor) commented May 1, 2023

Hi @seddonm1

Just to let you know that this feature is now merged to main.
Feel free to try it out again and let me know if there are any issues. Thanks.

chilo-ms closed this as completed May 1, 2023
ShukantPal pushed a commit to ShukantPal/onnxruntime that referenced this issue May 7, 2023 (same commit message as above).