From 674afbe1cc324883d1e388102430154dc42ec12e Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Mon, 29 Jan 2024 17:36:27 -0800 Subject: [PATCH] [TensorRT EP] Fix InferenceSession::Run() not thread-safe issue (#19301) Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this function concurrently, TRT EP needs to carefully take care of concurrency here, if not, following concurrent issue might happen: - It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream. In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling InferenceSession::Run() concurrently, the trt execution context instance is shared by all the threads and each thread aquires different stream from ORT. So TRT EP will end up having one trt execution context using multiple streams which is not suggested. But, since the whole compute_func() is protected by the lock and if cudaStreamSynchronize() is enforced here, one trt execution context per stream is guaranteed. Therefore, TRT EP needs to call cudaStreamSynchronize() at compute_func() which means to wait until stream has completed all operations to prevent the concurrent github isse: https://github.com/microsoft/onnxruntime/issues/19275 --- .../tensorrt/tensorrt_execution_provider.cc | 48 ++++++++++++++++--- .../tensorrt/tensorrt_execution_provider.h | 6 +-- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index aa02d8384afa..795c85a478b4 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1657,6 +1657,16 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv } } + // cuda graph: + // cudaStreamSynchronize() is not allowed in cuda graph capture. + // + // external stream: + // If user provides "external" cuda stream, only this cuda stream will be used even if multiple threads are running InferenceSession.Run() concurrently. + // So, no need to synchronize different streams after enqueueV3. + if (cuda_graph_enable_ || external_stream_) { + sync_stream_after_enqueue_ = false; + } + { auto lock = GetApiLock(); runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger())); @@ -2491,7 +2501,6 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } else if (number_of_trt_nodes == number_of_ort_nodes) { LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; } else { - sync_stream_after_enqueue_ = true; LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; } @@ -3078,7 +3087,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, + input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, @@ -3106,7 +3115,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView const std::unordered_map& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; - bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; auto fused_node_name = trt_state->fused_node_name; auto& shape_ranges = trt_state->input_shape_ranges; auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name]; @@ -3499,7 +3507,21 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); } - if (sync_stream_after_enqueue || dds_output_set.size() > 0) { + /* + * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this function concurrently, + * TRT EP needs to carefully take care of concurrency here, if not, following concurrent issue might happen: + * + * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream. + * In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling InferenceSession::Run() concurrently, + * the trt execution context instance is shared by all the threads and each thread aquires different stream from ORT. + * So TRT EP will end up having one trt execution context using multiple streams which is not suggested. + * But, since the whole compute_func() is protected by the lock and if cudaStreamSynchronize() is enforced here, one trt execution context per stream + * is guaranteed. + * + * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above. + * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture. + */ + if (sync_stream_after_enqueue_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); } @@ -3643,7 +3665,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con &contexts_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - sync_stream_after_enqueue_, context_memory_sharing_enable_, &max_ctx_mem_size_, &tensorrt_mu_}; @@ -3670,7 +3691,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; auto fused_node_name = trt_state->fused_node_name; - bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name]; auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); @@ -3780,7 +3800,21 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); } - if (sync_stream_after_enqueue || dds_output_set.size() > 0) { + /* + * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this function concurrently, + * TRT EP needs to carefully take care of concurrency here, if not, following concurrent issue might happen: + * + * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream. + * In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling InferenceSession::Run() concurrently, + * the trt execution context instance is shared by all the threads and each thread aquires different stream from ORT. + * So TRT EP will end up having one trt execution context using multiple streams which is not suggested. + * But, since the whole compute_func() is protected by the lock and if cudaStreamSynchronize() is enforced here, one trt execution context per stream + * is guaranteed. + * + * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above. + * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture. + */ + if (sync_stream_after_enqueue_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 401a8da119ac..d21b6bf02804 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -149,7 +149,6 @@ struct TensorrtFuncState { std::vector> input_info; std::vector> output_info; std::unordered_map>>> input_shape_ranges; - bool sync_stream_after_enqueue = false; OrtMutex* tensorrt_mu_ptr = nullptr; bool fp16_enable = false; bool int8_enable = false; @@ -193,7 +192,6 @@ struct TensorrtShortFuncState { std::unique_ptr* context = nullptr; std::vector> input_info; std::vector> output_info; - bool sync_stream_after_enqueue = false; bool context_memory_sharing_enable = false; size_t* max_context_mem_size_ptr = nullptr; OrtMutex* tensorrt_mu_ptr = nullptr; @@ -332,8 +330,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { cudnnHandle_t external_cudnn_handle_ = nullptr; cublasHandle_t external_cublas_handle_ = nullptr; - // Call cudaStreamSynchronize() after TRT enqueueV2()/enqueueV3() - mutable bool sync_stream_after_enqueue_ = false; + // Call cudaStreamSynchronize() after TRT enqueueV3() + mutable bool sync_stream_after_enqueue_ = true; CUDAGraph cuda_graph_; bool is_graph_captured_ = false;