From 138aa9aa9d72569fd31ac6a59d2daca9b3983408 Mon Sep 17 00:00:00 2001 From: Kevin Chen Date: Thu, 3 Apr 2025 13:34:00 -0700 Subject: [PATCH] Remove explicit batch network flag for TRT 10+ Signed-off-by: Kevin Chen --- .../core/providers/tensorrt/tensorrt_execution_provider.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 2c443cdc33458..ded135bf50ec8 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2296,8 +2296,9 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect auto network_flags = 0; #if NV_TENSORRT_MAJOR > 8 network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); -#endif +#else network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); +#endif auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); @@ -2907,8 +2908,9 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView auto network_flags = 0; #if NV_TENSORRT_MAJOR > 8 network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); -#endif +#else network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); +#endif auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger));