From 893ba57c22571a1429681cae3f036a7e7b5f7c54 Mon Sep 17 00:00:00 2001 From: Kevin Chen <45886021+kevinch-nv@users.noreply.github.com> Date: Fri, 4 Jul 2025 08:20:02 -0700 Subject: [PATCH 1/7] Add option for GraphViewerToProto serialization to skip writing data (#25263) ### Description Adds `include_initializer_data` option to `GraphViewerToProto` to skip writing initializer raw data and external data when serializing. ### Motivation and Context For TensorRT EP, partitioned graphs must be serialized to proto in order for getCapability() to run. For cases where the weights are not strictly needed (i.e. weightless engines), serializing the graph without initializer data reduces the overall memory required. --------- Signed-off-by: Kevin Chen --- .../core/graph/graph_proto_serializer.cc | 23 +++++++++++++++++-- .../core/graph/graph_proto_serializer.h | 3 ++- .../shared_library/provider_interfaces.h | 3 ++- .../shared_library/provider_wrappedtypes.h | 5 ++-- .../core/session/provider_bridge_ort.cc | 5 ++-- 5 files changed, 31 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/graph/graph_proto_serializer.cc b/onnxruntime/core/graph/graph_proto_serializer.cc index 80bb3f13814d1..993020278eb03 100644 --- a/onnxruntime/core/graph/graph_proto_serializer.cc +++ b/onnxruntime/core/graph/graph_proto_serializer.cc @@ -11,7 +11,8 @@ void GraphViewerToProto(const GraphViewer& graph_view, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializer, bool include_outer_scope_args, - ExecutionOrder order) { + ExecutionOrder order, + bool include_initializer_data) { graph_proto.set_name(graph_view.Name()); graph_proto.set_doc_string(graph_view.Description()); @@ -92,7 +93,25 @@ void GraphViewerToProto(const GraphViewer& graph_view, const auto& [name, init] = *it; current_scope_initializer_set.insert(name); auto* p_initializer = graph_proto.add_initializer(); - ORT_THROW_IF_ERROR(get_initializer_with_data(*init, *p_initializer)); + + // Do not save raw or external data into the graph, only the metadata + if (!include_initializer_data && (init->has_raw_data() || init->has_data_location())) { + // Set datatype + if (init->has_data_type()) { + p_initializer->set_data_type(init->data_type()); + } + // Set name + if (init->has_name()) { + p_initializer->set_name(init->name()); + } + + // Set dims + for (int i = 0; i < init->dims_size(); ++i) { + p_initializer->add_dims(init->dims()[i]); + } + } else { + ORT_THROW_IF_ERROR(get_initializer_with_data(*init, *p_initializer)); + } } // handle outer scope value which is a constant initializer diff --git a/onnxruntime/core/graph/graph_proto_serializer.h b/onnxruntime/core/graph/graph_proto_serializer.h index ce21e1b609b26..2a8180477c476 100644 --- a/onnxruntime/core/graph/graph_proto_serializer.h +++ b/onnxruntime/core/graph/graph_proto_serializer.h @@ -11,5 +11,6 @@ void GraphViewerToProto(const GraphViewer& graph_view, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializer, bool include_outer_scope_args, - ExecutionOrder order = ExecutionOrder::DEFAULT); + ExecutionOrder order = ExecutionOrder::DEFAULT, + bool include_initializer_data = true); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index dba26b3982d86..44dd70211327e 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -1097,7 +1097,8 @@ struct ProviderHost { ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args, - int execution_order) noexcept = 0; + int execution_order, + bool include_initializer_data) noexcept = 0; virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0; virtual IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 80b5e26db8680..23fbead1e9707 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1150,8 +1150,9 @@ class GraphViewer final { void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args, - int execution_order = 0) const { - g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args, execution_order); + int execution_order = 0, + bool include_initializer_data = true) const { + g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args, execution_order, include_initializer_data); } const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->GraphViewer__GetProducerNode(this, node_arg_name); } IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return g_host->GraphViewer__GetSchemaRegistry(this); } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 8cd16fb4e7347..3db35ae8769e0 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1389,8 +1389,9 @@ struct ProviderHostImpl : ProviderHost { ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args, - int execution_order) noexcept override { - GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args, static_cast(execution_order)); + int execution_order, + bool include_initializer_data) noexcept override { + GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args, static_cast(execution_order), include_initializer_data); } const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); } IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const override { return p->GetSchemaRegistry(); } From dfc27cd7c7ea327e3610e0f90ae56b54f9be614c Mon Sep 17 00:00:00 2001 From: Ankit Maheshkar Date: Sat, 5 Jul 2025 00:59:29 +0530 Subject: [PATCH 2/7] [OVEP] OpenVINO EP Features Release 1.23 (#25262) ### Description This pull request includes a wide range of feature updates, optimizations, and bug fixes aimed at improving performance, memory efficiency, dynamic shaped model support, ORT GenAI support for GenAI models viz. LLMs / SLMs , and overall stability of the OpenVINO Execution Provider (OVEP). ### Key Enhancements - Dynamic Shaped Model Support: Added support for inferencing dynamic shaped models using `reshape_input` provider option Enabled workload type handling for dynamic-shaped models - Performance Optimizations: Reduced peak memory usage by optimizing fallback logic and model proto handling. Improved CPU inference path efficiency. Removed unintended model copies during compilation. - ORT GenAI Feature Pass: [ORT GenAI](https://github.com/microsoft/onnxruntime-genai) is now supported using OpenVINO EP using `enable_causallm` provider option as `True` - EPContext OVIR Encapsulation Feature: ORT now supports EpContext Models with OVIR (i.e. model.xml & model.bin) stored into `ep_cache_context` attribute Compilation, Inference & Pre-Compiled Cached Blob Support - Quantization Enhancements: Enabled QDQ stripping path using adaptive stripping. Enabled QDQ Channel Wise Quantization for Intel NPU friendly quantization using `MatMul4BitsQuantizer/ DefaultWeightOnlyQuantConfig` using option `channel_wised_quantize` as `True` ``` from onnxruntime.quantization import matmul_nbits_quantizer # Define quantization configuration and process quant_config = matmul_nbits_quantizer.DefaultWeightOnlyQuantConfig( block_size=128, is_symmetric=True, quant_format=quant_utils.QuantFormat.QDQ, channel_wised_quantize=True) ``` - Operator & Backend Improvements: Added support for the HardSwish operator Fixed logic for unsupported op modes and improved precision accuracy - Bug Fixes: Fixed metadata naming and file path validation Addressed device selection issues and provider key verification Resolved deprecated OV element types and LUID check issues --------- Signed-off-by: Jianhui Dai Signed-off-by: dependabot[bot] Signed-off-by: bfilipek Co-authored-by: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Co-authored-by: n1harika Co-authored-by: sfatimar Co-authored-by: Jaskaran Singh Nagi Co-authored-by: Eric Crawford Co-authored-by: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Co-authored-by: Scott McKay Co-authored-by: Seungtaek Kim Co-authored-by: co63oc Co-authored-by: Jambay Kinley Co-authored-by: Hector Li Co-authored-by: Jian Chen Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Co-authored-by: Jiajia Qin Co-authored-by: Alessio Soldano Co-authored-by: Changming Sun Co-authored-by: Ashish Garg Co-authored-by: Ashish Garg Co-authored-by: Jie Chen Co-authored-by: wp Co-authored-by: Satya Kumar Jandhyala Co-authored-by: Prathik Rao Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Tianlei Wu Co-authored-by: Jianhui Dai Co-authored-by: xhcao Co-authored-by: Wanming Lin Co-authored-by: Mark Schofield Co-authored-by: jiangzhaoming Co-authored-by: Yi-Hong Lyu Co-authored-by: vraspar Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Co-authored-by: saurabh Co-authored-by: Ranjit Ranjan <165394499+ranjitshs@users.noreply.github.com> Co-authored-by: Baiju Meswani Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: jatinwadhwa921 Co-authored-by: Pallavi Gupta Co-authored-by: Nikolay Proshunin Co-authored-by: Preetha Veeramalai Co-authored-by: Javier Martinez Co-authored-by: Bartlomiej Filipek Co-authored-by: bopeng1234 Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> Co-authored-by: TejalKhade28 Co-authored-by: Vishnudas Thaniel S Co-authored-by: Yaru Du Co-authored-by: Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> Co-authored-by: Dvoretckii, Mikhail Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- cmake/onnxruntime_providers_openvino.cmake | 5 - cmake/onnxruntime_unittests.cmake | 7 + .../providers/openvino/backend_manager.cc | 262 ++++--- .../core/providers/openvino/backend_manager.h | 8 +- .../core/providers/openvino/backend_utils.cc | 79 +- .../core/providers/openvino/backend_utils.h | 53 +- .../openvino/backends/basic_backend.cc | 689 ++++++------------ .../openvino/backends/basic_backend.h | 200 +++-- .../core/providers/openvino/contexts.h | 8 +- .../core/providers/openvino/ibackend.h | 5 +- .../openvino/onnx_ctx_model_helper.cc | 39 +- .../openvino/onnx_ctx_model_helper.h | 1 + .../openvino/openvino_execution_provider.cc | 93 ++- .../openvino/openvino_parser_utils.cc | 120 +++ .../openvino/openvino_parser_utils.h | 4 + .../openvino/openvino_provider_factory.cc | 52 +- .../core/providers/openvino/ov_allocator.cc | 18 +- .../core/providers/openvino/ov_interface.cc | 637 ++++++++++------ .../core/providers/openvino/ov_interface.h | 92 ++- .../openvino/ov_stateful_patch_utils.cc | 350 +++++++++ .../openvino/ov_stateful_patch_utils.h | 84 +++ .../openvino/ov_versions/capability.cc | 6 +- .../openvino/ov_versions/data_ops.cc | 7 +- .../python/onnxruntime_pybind_state.cc | 58 +- .../quantization/matmul_nbits_quantizer.py | 69 +- onnxruntime/test/perftest/ort_test_session.cc | 14 +- .../cpu/reduction/reduction_ops_test.cc | 6 +- .../openvino/openvino_ep_context_test.cc | 77 ++ 28 files changed, 2030 insertions(+), 1013 deletions(-) create mode 100644 onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc create mode 100644 onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h create mode 100644 onnxruntime/test/providers/openvino/openvino_ep_context_test.cc diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index 03f67983c70ab..d7cb2d5ea0d0f 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -30,11 +30,6 @@ endif() list(APPEND OPENVINO_LIB_LIST openvino::frontend::onnx openvino::runtime ${PYTHON_LIBRARIES}) - if ((DEFINED ENV{OPENCL_LIBS}) AND (DEFINED ENV{OPENCL_INCS}) AND onnxruntime_USE_OPENVINO_GPU) - add_definitions(-DIO_BUFFER_ENABLED=1) - list(APPEND OPENVINO_LIB_LIST $ENV{OPENCL_LIBS}) - endif() - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_openvino_cc_srcs}) onnxruntime_add_shared_library_module(onnxruntime_providers_openvino ${onnxruntime_providers_openvino_cc_srcs} "${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc") diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index e8809bd2392c8..d1fb06a95f4c9 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -767,6 +767,13 @@ if(onnxruntime_USE_AZURE) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_azure) endif() +if (onnxruntime_USE_OPENVINO) + list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/openvino/*) + list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_openvino) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_openvino) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_shared) +endif() + file(GLOB onnxruntime_test_framework_src CONFIGURE_DEPENDS ${onnxruntime_test_framework_src_patterns} ) diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 9ef7e4b86db5f..041d9c07e41fe 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -28,9 +28,10 @@ SessionContext& BackendManager::GetSessionContext() { return session_context_; } -ov::CompiledModel& BackendManager::GetOVCompiledModel() { - ov::CompiledModel& ov_ptr = concrete_backend_->GetOVCompiledModel(); - return (ov_ptr); +ov::CompiledModel BackendManager::GetOVCompiledModel() { + if (concrete_backend_) + return concrete_backend_->GetOVCompiledModel(); + return ov::CompiledModel(); } BackendManager::BackendManager(SessionContext& session_context, @@ -42,6 +43,9 @@ BackendManager::BackendManager(SessionContext& session_context, session_context_(session_context), shared_context_{shared_context} { subgraph_context_.is_ep_ctx_graph = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(subgraph); + // If the graph contains a OVIR wrapped node, we check if it has matching xml file name attribute + subgraph_context_.is_ep_ctx_ovir_encapsulated = ep_ctx_handle_.CheckEPCacheContextAttribute(subgraph, + session_context_.onnx_model_path_name.filename().replace_extension("xml").string()); subgraph_context_.model_precision = [&](const GraphViewer& graph_viewer) { // return empty if graph has no inputs or if types are not one of FP32/FP16 @@ -65,6 +69,9 @@ BackendManager::BackendManager(SessionContext& session_context, // Save the indexes of graph inputs among fused_node's inputDefs // (which also contains initializers). for (uint32_t index = 0; const auto& node : subgraph.GetInputs()) { + if (subgraph.GetGraph().GetConsumerNodes(node->Name()).size() == 0) { + continue; // Skip if the input is a dangling node + } subgraph_context_.input_names.insert({node->Name(), index++}); } @@ -77,6 +84,11 @@ BackendManager::BackendManager(SessionContext& session_context, ptr_stream_t model_stream; std::unique_ptr model_proto; if (subgraph_context_.is_ep_ctx_graph) { + if (!session_context_.reshape.empty()) { + std::string exception_str = + "[OpenVINO-EP] Bounded dynamic model execution using provider option reshape_input is not supported for OVEP EPContext model"; + ORT_THROW(exception_str); + } model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.so_context_file_path, subgraph); } else { model_proto = GetModelProtoFromFusedNode(fused_node, subgraph, logger); @@ -84,29 +96,29 @@ BackendManager::BackendManager(SessionContext& session_context, std::string device_type = session_context_.device_type; auto& sw = shared_context_.shared_weights; - if (session_context_.so_share_ep_contexts) { + if (session_context_.so_share_ep_contexts && !sw.metadata.empty()) { std::filesystem::path weight_filename = session_context_.onnx_model_path_name.parent_path(); - if (sw.external_weight_filename.empty() && !sw.metadata.empty()) { + if (sw.external_weight_filename.empty()) { // Reasonable assumption that all metadata entries have the same external file location sw.external_weight_filename = sw.metadata.begin()->second.location; } weight_filename /= sw.external_weight_filename; std::ifstream weight_file(weight_filename); - if (weight_file) { - if (!sw.mapped_weights) { - sw.mapped_weights = std::make_unique(weight_filename); - } - backend_utils::CreateOVTensors(session_context_.device_type, sw.metadata, *sw.mapped_weights); + ORT_ENFORCE(weight_file, "Initializer file not found: ", weight_filename.string()); + if (!sw.mapped_weights) { + sw.mapped_weights = std::make_unique(weight_filename); } + backend_utils::CreateOVTensors(session_context_.device_type, sw.metadata, *sw.mapped_weights); } if (ModelHasSymbolicInputDims(subgraph)) { subgraph_context_.has_dynamic_input_shape = true; LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims"; - if ((session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos) && - !session_context_.disable_dynamic_shapes) { + if ((!session_context_.disable_dynamic_shapes && + (session_context_.device_type.find("CPU") != std::string::npos || + session_context_.device_type.find("GPU") != std::string::npos)) || + (subgraph_context_.is_ep_ctx_graph)) { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. " << "Creating backend Dynamic Shapes"; try { @@ -141,58 +153,35 @@ BackendManager::BackendManager(SessionContext& session_context, model_stream); } catch (const OnnxRuntimeException& ex) { std::string exception_str = ex.what(); - bool eligible_for_cpu_fallback = device_type.find("NPU") != std::string::npos && - !session_context_.so_disable_cpu_ep_fallback && - !subgraph_context_.is_ep_ctx_graph; -#if defined(OPENVINO_DISABLE_NPU_FALLBACK) - eligible_for_cpu_fallback = false; -#else - if (eligible_for_cpu_fallback) { - LOGS_DEFAULT(VERBOSE) << exception_str; - LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU." - << "Falling back to OV CPU for execution"; - session_context_.device_type = "CPU"; - session_context_.precision = "FP32"; - try { - concrete_backend_ = BackendFactory::MakeBackend(model_proto, - session_context_, - subgraph_context_, - shared_context_, - model_stream); - } catch (std::string const& msg) { - ORT_THROW(msg); - } - } -#endif - if (!eligible_for_cpu_fallback) { - if (device_type.find("NPU") != std::string::npos && - exception_str.find("intel_npu") != std::string::npos) { - // Handle NPU device related errors + + if (session_context_.device_type.find("NPU") != std::string::npos && + exception_str.find("intel_npu") != std::string::npos) { + // Handle NPU device related errors #ifndef NDEBUG - ORT_THROW(exception_str + "\nModel needs to be recompiled\n"); + ORT_THROW(exception_str + "\nModel needs to be recompiled\n"); #else - std::string error_message = "UNKNOWN NPU ERROR"; - std::string error_code = "code 0x0"; - std::regex error_message_pattern(R"(\bZE_\w*\b)"); - std::regex error_code_pattern("code 0x[0-9a-fA-F]+"); - std::smatch matches; - if (std::regex_search(exception_str, matches, error_message_pattern)) { - error_message = matches[0]; - } - if (std::regex_search(exception_str, matches, error_code_pattern)) { - error_code = matches[0]; - } - throw std::runtime_error(error_message + ", " + error_code + "\nModel needs to be recompiled\n"); -#endif - } else { - ORT_THROW(exception_str); + std::string error_message = "UNKNOWN NPU ERROR"; + std::string error_code = "code 0x0"; + std::regex error_message_pattern(R"(\bZE_\w*\b)"); + std::regex error_code_pattern("code 0x[0-9a-fA-F]+"); + std::smatch matches; + if (std::regex_search(exception_str, matches, error_message_pattern)) { + error_message = matches[0]; + } + if (std::regex_search(exception_str, matches, error_code_pattern)) { + error_code = matches[0]; } + throw std::runtime_error(error_message + ", " + error_code + "\nModel needs to be recompiled\n"); +#endif + } else { + ORT_THROW(exception_str); } } } - if (session_context_.so_context_enable && !subgraph_context_.is_ep_ctx_graph) { + if (session_context_.so_context_enable && + (subgraph_context_.is_ep_ctx_ovir_encapsulated || !subgraph_context_.is_ep_ctx_graph)) { auto status = onnxruntime::openvino_ep::BackendManager::ExportCompiledBlobAsEPCtxNode(subgraph); - if ((!status.IsOK())) { + if (!status.IsOK()) { ORT_THROW(status); } } @@ -287,24 +276,83 @@ bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& mod } bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const { - bool has_sym_dims = false; - auto graph_inputs = subgraph.GetInputs(); - for (auto input : graph_inputs) { + const auto& graph_inputs = subgraph.GetInputs(); + + // First validate shapes if provided by user + bool shapes_valid = true; + if (!session_context_.reshape.empty()) { + try { + ValidateInputShapes(session_context_.reshape, graph_inputs); + } catch (const std::exception& e) { + LOGS_DEFAULT(ERROR) << "[OpenVINO-EP] Shape validation failed: " << e.what(); + session_context_.reshape.clear(); // Clear the shape map as it's invalid + shapes_valid = false; + } + } + + // Count dynamic inputs and check if reshape covers all of them + size_t dynamic_input_count = 0; + bool all_dynamic_inputs_covered = true; + + for (const auto* input : graph_inputs) { + // Skip dangling inputs (no consumers) + if (subgraph.GetGraph().GetConsumerNodes(input->Name()).empty()) { + continue; + } + + // Check if input has dynamic dimensions + bool has_dynamic_dim = false; + + // Case 1: Completely undefined shape if (input->Shape() == nullptr) { - has_sym_dims = true; - break; + has_dynamic_dim = true; } - for (auto& dim : input->Shape()->dim()) { - if (dim.value_case() != dim.kDimValue) { - has_sym_dims = true; - break; + // Case 2: Shape defined but with symbolic dimensions + else { + for (const auto& dim : input->Shape()->dim()) { + if (dim.value_case() != dim.kDimValue) { + has_dynamic_dim = true; + break; + } } } - if (has_sym_dims) { - break; + + // If dynamic, count it and check if reshape covers it + if (has_dynamic_dim) { + dynamic_input_count++; + + // Check if this dynamic input is covered by reshape input + if (!session_context_.reshape.empty() && + session_context_.reshape.find(input->Name()) == session_context_.reshape.end()) { + all_dynamic_inputs_covered = false; + LOGS_DEFAULT(WARNING) << "[OpenVINO-EP] reshape_input is provided but doesn't cover dynamic input: " + << input->Name(); + } } } - return has_sym_dims; + + const bool has_symbolic_dims = (dynamic_input_count > 0); + + // Early return if no reshape input provided + if (session_context_.reshape.empty()) { + return has_symbolic_dims; // Return based on whether model has symbolic dims + } + + // For dynamic models with incomplete reshape coverage, clear shapes + if (has_symbolic_dims && !all_dynamic_inputs_covered) { + session_context_.reshape.clear(); + LOGS_DEFAULT(WARNING) << "reshape_input does not cover all dynamic dimensions, " + << "ignoring all provided shapes"; + return true; // Model is dynamic + } + + // If shapes are valid with complete coverage for dynamic model, treat as concrete + if (has_symbolic_dims && shapes_valid && all_dynamic_inputs_covered) { + LOGS_DEFAULT(INFO) << "All dynamic dimensions successfully covered by reshape_input"; + return false; // Model is now effectively static with concrete shapes + } + + return has_symbolic_dims; // Return dynamic status based on symbolic dimensions } // Check to see if the graph is QDQ @@ -380,8 +428,9 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, #endif const auto& onnx_model_path_name = subgraph.ModelPath(); - // QDQ stripping enabled only for the NPU - if (session_context_.device_type.find("NPU") != std::string::npos && + // QDQ stripping enabled only for the NPU and experimentally on the GPU + if ((session_context_.device_type.find("NPU") != std::string::npos || + session_context_.device_type.find("GPU") != std::string::npos) && (enable_ovep_qdq_optimizer || session_context_.so_share_ep_contexts)) { std::unique_ptr model; Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, shared_context_.shared_weights); @@ -475,9 +524,44 @@ BackendManager::ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_p return model_copy; } +void BackendManager::ValidateInputShapes(const reshape_t& shapes, + const std::vector& graph_inputs) const { + for (const auto& [tensor_name, requested_shape] : shapes) { + // Find matching input in graph + const NodeArg* graph_input = nullptr; + for (const auto* input : graph_inputs) { + if (input->Name() == tensor_name) { + graph_input = input; + break; + } + } + + if (!graph_input) { + ORT_THROW("Input '" + tensor_name + "' specified in reshape_input does not exist in the graph"); + } + + const ONNX_NAMESPACE::TensorShapeProto* graph_shape = graph_input->Shape(); + if (!graph_shape) { + ORT_THROW("Graph input '" + tensor_name + "' has no shape information"); + } + + // Check dimensions count matches + size_t graph_dim_count = graph_shape->dim_size(); + size_t requested_dim_count = requested_shape.get_max_shape().size(); + + if (graph_dim_count != requested_dim_count) { + ORT_THROW("Dimensions mismatch for input '" + tensor_name + + "': graph expects " + std::to_string(graph_dim_count) + + " dimensions but reshape_input specifies " + + std::to_string(requested_dim_count) + " dimensions"); + } + } +} + void BackendManager::Compute(OrtKernelContext* context) { Ort::KernelContext ctx(context); std::chrono::high_resolution_clock::time_point start_compute, end_compute; + #ifdef OPENVINO_FIL_ENABLED static bool fil_enabled = true; if (fil_enabled) { @@ -485,21 +569,26 @@ void BackendManager::Compute(OrtKernelContext* context) { LOGS_DEFAULT(INFO) << "Start Compute"; } #endif - // OV NPU doesn't support dynamic shaped model inference. + // if disable_dynamic_shapes is set to true then execution of dynamic model is done // by rewriting the model to static shaped model at runtime based on input shape. - // disable_dynamic_shapes is always set to true for OV NPU plugin. - if (subgraph_context_.has_dynamic_input_shape && - !session_context_.disable_dynamic_shapes && - (session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos)) { + // disable_dynamic_shapes should be set for devices that don't support dynamic shapes. + bool need_dynamic_backend = subgraph_context_.has_dynamic_input_shape && + session_context_.disable_dynamic_shapes && !subgraph_context_.is_ep_ctx_graph; + + if (!need_dynamic_backend) { concrete_backend_->Infer(context); - } else if (subgraph_context_.has_dynamic_input_shape) { + } else { std::vector> tensor_shapes = GetInputTensorShapes(ctx); auto key = MakeMapKeyString(tensor_shapes, session_context_.device_type); std::shared_ptr dynamic_backend; - auto search = backend_map_.find(key); - if (search == backend_map_.end()) { + + { + std::unique_lock lock(mutex_); + dynamic_backend = backend_map_[key]; + } + + if (!dynamic_backend) { ptr_stream_t model_stream; LOGS_DEFAULT(INFO) << "[OpenVINO-EP] " << "Creating dynamic backend for key: " << key; @@ -540,14 +629,11 @@ void BackendManager::Compute(OrtKernelContext* context) { } #endif } + std::unique_lock lock(mutex_); backend_map_.insert({key, dynamic_backend}); - } else { - dynamic_backend = search->second; } dynamic_backend->Infer(context); - } else { - concrete_backend_->Infer(context); } #ifdef OPENVINO_FIL_ENABLED if (fil_enabled) { @@ -565,5 +651,11 @@ void BackendManager::ShutdownBackendManager() { concrete_backend_.reset(); } +void BackendManager::RewindKVCache(size_t index) { + if (concrete_backend_) { + concrete_backend_->RewindKVCache(index); + } +} + } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index cdc27701ec2e6..f091f95fe1c16 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -29,7 +29,8 @@ class BackendManager { void ShutdownBackendManager(); SessionContext& GetSessionContext(); Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph); - ov::CompiledModel& GetOVCompiledModel(); + ov::CompiledModel GetOVCompiledModel(); + void RewindKVCache(size_t index); private: std::unique_ptr GetModelProtoFromFusedNode( @@ -38,7 +39,11 @@ class BackendManager { const logging::Logger& logger) const; bool ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const; + std::unordered_set IdentifyDynamicInputs(const onnxruntime::GraphViewer& subgraph, + const std::vector& graph_inputs) const; bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const; + void ValidateInputShapes(const reshape_t& shapes, + const std::vector& graph_inputs) const; std::shared_ptr ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_proto); @@ -49,6 +54,7 @@ class BackendManager { std::unique_ptr model_proto_; std::shared_ptr concrete_backend_; + std::mutex mutex_; std::map> backend_map_; SubGraphContext subgraph_context_; EPCtxHandler& ep_ctx_handle_; diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index 2ee5e9ec3e3a9..73fbe9a0fa76f 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -121,7 +121,7 @@ std::istream& operator>>(std::istream& stream, SharedContext::SharedWeights::Met namespace backend_utils { bool IsDebugEnabled() { - const std::string env_name = onnxruntime::GetEnvironmentVar("ORT_OPENVINO_ENABLE_DEBUG"); + static std::string env_name = onnxruntime::GetEnvironmentVar("ORT_OPENVINO_ENABLE_DEBUG"); if (!env_name.empty()) { return true; } @@ -129,7 +129,7 @@ bool IsDebugEnabled() { } bool IsCILogEnabled() { - const std::string env_name = onnxruntime::GetEnvironmentVar("ORT_OPENVINO_ENABLE_CI_LOG"); + static std::string env_name = onnxruntime::GetEnvironmentVar("ORT_OPENVINO_ENABLE_CI_LOG"); if (!env_name.empty()) { return true; } @@ -146,6 +146,10 @@ CreateOVModel(std::string&& model, try { auto ov_model = OVCore::Get()->ReadModel(std::move(model), session_context.onnx_model_path_name.string()); + if (!session_context.reshape.empty()) { + LOGS_DEFAULT(INFO) << log_tag << "Reshaping the ov tensor to specified shape"; + ov_model->reshape(session_context.reshape); + } // Check for Constant Folding if ((session_context.device_type != "NPU") && !session_context.is_wholly_supported_graph) { ov::pass::ConstantFolding pass_const_obj; @@ -175,32 +179,6 @@ CreateOVModel(std::string&& model, } } -Ort::UnownedValue -GetOutputTensor(Ort::KernelContext& context, size_t batch_size, - OVInferRequestPtr infer_request, - std::string output_name, - const SubGraphContext::string_index_map_t& output_names) { - auto graph_output_blob = infer_request->GetTensor(output_name); - - auto graph_output_dims = graph_output_blob->get_shape(); - - if (batch_size > 1) { - // Add the batch size as dim 0. - graph_output_dims.insert(graph_output_dims.begin(), batch_size); - } - size_t num_dims = graph_output_dims.size(); - std::unique_ptr output_shape(new int64_t[num_dims]); - for (size_t j = 0; j < num_dims; j++) { - output_shape[j] = static_cast(graph_output_dims[j]); - } - auto it = output_names.find(output_name); - if (it == output_names.end()) { - ORT_THROW(log_tag + "Output names mismatch between OpenVINO and ONNX"); - } - int index = it->second; - return context.GetOutput(index, output_shape.get(), num_dims); -} - Ort::UnownedValue GetOutputTensor(Ort::KernelContext& context, std::string output_name, @@ -216,14 +194,9 @@ GetOutputTensor(Ort::KernelContext& context, ORT_THROW(log_tag + "Output names mismatch between OpenVINO and ONNX"); } int index = it->second; - auto shape = node->get_shape(); + auto output_shape = ParameterShape::ToOrtShape(node->get_shape()); - size_t num_dims = shape.size(); - std::unique_ptr output_shape(new int64_t[num_dims]); - for (size_t j = 0; j < num_dims; j++) { - output_shape[j] = static_cast(shape[j]); - } - return context.GetOutput(index, output_shape.get(), num_dims); + return context.GetOutput(index, output_shape); } int GetFirstAvailableDevice(SessionContext& session_context) { @@ -308,15 +281,6 @@ void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx, std::memcpy(input_data, batch_memory_offset, input_data_size); } -void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, - size_t batch_slice_idx) { - auto output_data = outputBlob->data(); - size_t output_data_size = outputBlob->get_byte_size(); - char* tensor_data = output_tensor.GetTensorMutableData(); - char* batch_memory_offset = tensor_data + output_data_size * batch_slice_idx; - std::memcpy(batch_memory_offset, output_data, output_data_size); -} - void printPerformanceCounts(const std::vector& performanceMap, std::ostream& stream, std::string deviceName) { int64_t totalTime = 0; @@ -436,6 +400,33 @@ void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map) metadata_map.clear(); } +bool IsModelStreamXML(std::istream& model_stream) { + std::streampos originalPos = model_stream.tellg(); + + // first, get the total size of model_stream in bytes + model_stream.seekg(0, std::ios::end); + auto end_pos = model_stream.tellg(); + // Restore the stream position + model_stream.seekg(originalPos); + auto total_size = end_pos - originalPos; + + // Choose 32 bytes to hold content of: + // ' header_check_len); + + // read 32 bytes into header + std::string header(header_check_len, '\0'); + model_stream.read(&header[0], header_check_len); + // Clear any read errors + model_stream.clear(); + // Restore the stream position + model_stream.seekg(originalPos); + + // return true if the header starts with '; + + static ov::PartialShape ToOvPartialShape(const ort_shape_t& ort_shape) { + std::vector ov_shape(ort_shape.size()); + std::transform(ort_shape.begin(), ort_shape.end(), ov_shape.begin(), [](int64_t dim) { + return dim == -1 ? ov::Dimension::dynamic() : ov::Dimension(dim); + }); + return ov::PartialShape(ov_shape); + } + + static ort_shape_t ToOrtShape(const ov::PartialShape& ov_shape) { + ort_shape_t ort_shape(ov_shape.size()); + std::transform(ov_shape.begin(), ov_shape.end(), ort_shape.begin(), [](const auto& dim) { + return dim.is_dynamic() ? -1 : dim.get_length(); + }); + return ort_shape; + } + + static ort_shape_t ToOrtShape(const ov::Shape& ov_shape) { + ort_shape_t ort_shape(ov_shape.size()); + std::transform(ov_shape.begin(), ov_shape.end(), ort_shape.begin(), [](const auto& dim) { + return narrow(dim); + }); + return ort_shape; + } + + operator ov::Shape() const { return ov_.get_shape(); } + operator const ov::PartialShape&() const { return ov_; } + operator const ort_shape_t&() const { return ort_; } + + explicit ParameterShape(const ort_shape_t& ort_shape) : ort_(ort_shape), ov_(ToOvPartialShape(ort_shape)) {} + explicit ParameterShape(const ov::PartialShape& ov_partial_shape) : ov_(ov_partial_shape), ort_(ToOrtShape(ov_partial_shape)) {} + + private: + ort_shape_t ort_; + ov::PartialShape ov_; +}; + namespace backend_utils { -const std::string log_tag = "[OpenVINO-EP] "; bool IsDebugEnabled(); @@ -48,19 +88,10 @@ GetOutputTensor(Ort::KernelContext& context, const SubGraphContext::string_index_map_t& output_names, std::shared_ptr node); -Ort::UnownedValue -GetOutputTensor(Ort::KernelContext& context, size_t batch_size, - OVInferRequestPtr infer_request, - std::string output_name, - const SubGraphContext::string_index_map_t& output_names); - void FillInputBlob(OVTensorPtr inputBlob, size_t batch_slice_idx, std::string input_name, Ort::KernelContext& context, const SubGraphContext& subgraph_context); -void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, - size_t batch_slice_idx); - std::shared_ptr CreateOVModel(std::string&& model, const SessionContext& session_context, @@ -76,6 +107,8 @@ void printPerformanceCounts(const std::vector& performanceMap, void printPerformanceCounts(OVInferRequestPtr request, std::ostream& stream, std::string deviceName); +bool IsModelStreamXML(std::istream& model_stream); + } // namespace backend_utils } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index c814df618e3b3..df75f84a5fee0 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -9,12 +9,14 @@ #include #include #include +#include #include "core/providers/shared_library/provider_api.h" #include "core/providers/openvino/backend_utils.h" #include "core/providers/openvino/backends/basic_backend.h" #include "core/providers/openvino/onnx_ctx_model_helper.h" #include "core/providers/openvino/backend_manager.h" +#include "core/providers/openvino/ov_stateful_patch_utils.h" namespace onnxruntime { @@ -29,98 +31,112 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr ptr_stream_t& model_stream) : session_context_{session_context}, subgraph_context_{subgraph_context}, shared_context_{shared_context} { std::string& hw_target = session_context_.device_type; + bool enable_causallm = session_context_.enable_causallm; if (ValidateSubgraph(const_outputs_map_)) return; - // OV Config + // Pre-requisite is provider_option "context" must be set + auto auto_unified_compile = ((hw_target.find("AUTO") == std::string::npos) || + (session_context_.OpenVINO_Version.at(0) >= 2024 && + session_context_.OpenVINO_Version.at(1) > 2)); ov::AnyMap device_config; - PopulateConfigValue(device_config); - - // Enable caching - EnableCaching(); - - // Setting OpenCL queue throttling for GPU - EnableGPUThrottling(device_config); - - // Enable streams; default=1 unless ovverriden by user config - EnableStreams(); - - // Set the inference_num_threads property of the CPU - SetNumThreads(device_config); - - auto npuw_status = - std::any_of(device_config.begin(), device_config.end(), [&](const std::pair& pair) { - return (pair.first.find("NPU_USE_NPUW") != std::string::npos) && (pair.second.is()) && - (pair.second.as() == "YES"); - }); - - if (npuw_status) { - LOGS_DEFAULT(INFO) << log_tag << "NPUW Enabled during compilation"; - } - - try { - // IO_BUFFER is enabled on GPU HW. - // Pre-requisite is provider_option "context" must be set -#if defined(IO_BUFFER_ENABLED) - cl_context ctx = static_cast(session_context_.context); - remote_context_ = new ov::intel_gpu::ocl::ClContext(OVCore::Get()->core, ctx); - if (subgraph_context_.is_ep_ctx_graph) { - exe_network_ = OVCore::Get()->ImportModel(*model_stream, - remote_context_, - subgraph_context_.subgraph_name); - model_stream.reset(); // Delete stream after it is no longer needed - } else { - std::string model = model_proto->SerializeAsString(); - if (!subgraph_context.has_dynamic_input_shape) { - model_proto.reset() + SetOVDeviceConfiguration(device_config); + if (subgraph_context_.is_ep_ctx_graph) { + try { + if (subgraph_context_.is_ep_ctx_ovir_encapsulated) { + // model_file_path will use so_context_file_path if the onnx_model_path_name is not available, + // especially in case of CreateSessionFormArray() where user must explicitly + // specify absolute path for so_context_file_path. + auto model_file_path = [this]() { + if (!session_context_.onnx_model_path_name.empty() && + std::filesystem::exists(session_context_.onnx_model_path_name)) return session_context_.onnx_model_path_name; + + ORT_ENFORCE(!session_context_.so_context_file_path.empty() && + std::filesystem::path(session_context_.so_context_file_path).is_absolute() && + std::filesystem::exists(session_context_.so_context_file_path), + log_tag + + "Context file path must be non-empty & absolute, when using CreateSessionFormArray() API explicitly." + " Please set a valid absolute path for ep.context_file_path in session options."); + // Return absolute context file path as input to ImportEPCtxOVIREncapsulation() function. + return session_context_.so_context_file_path; + }; + // If the EPContext node with OVIR Encapsulation, then create + // an executable network from EP_CACHE_CONTEXT using read_model() & compile_model() + exe_network_ = OVCore::Get()->ImportEPCtxOVIREncapsulation(*model_stream, + hw_target, + device_config, + enable_causallm, + model_file_path()); + } else { + // If the blob is held in an EPContext node, then skip FE+Compile + // and directly move on to creating a backend with the executable blob + exe_network_ = OVCore::Get()->ImportModel(*model_stream, + hw_target, + device_config, + subgraph_context_.subgraph_name); } - auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); - LOGS_DEFAULT(INFO) << log_tag << "IO Buffering Enabled"; - exe_network_ = OVCore::Get()->CompileModel( - ov_model, remote_context_, subgraph_context_.subgraph_name); + model_stream.reset(); + } catch (const char* msg) { + ORT_THROW(msg); + } // Delete stream after it is no longer needed + } else { + std::string model = model_proto->SerializeAsString(); + if (!subgraph_context.has_dynamic_input_shape) { + model_proto.reset(); } -#else // !IO_BUFFER_ENABLED - auto auto_unified_compile = ((hw_target.find("AUTO") == std::string::npos) || - (session_context_.OpenVINO_Version.at(0) >= 2024 && - session_context_.OpenVINO_Version.at(1) > 2)); - if (subgraph_context_.is_ep_ctx_graph) { - // If the blob is held in an EPContext node, then skip FE+Compile - // and directly move on to creating a backend with the executable blob - exe_network_ = OVCore::Get()->ImportModel(*model_stream, - hw_target, - device_config, - subgraph_context_.subgraph_name); - model_stream.reset(); // Delete stream after it is no longer needed - } else if (!session_context_.has_external_weights && - !subgraph_context_.has_dynamic_input_shape && - !session_context_.so_context_enable && - auto_unified_compile) { - // Unified OV compile_model is efficient when ov model caching is enabled - // Unified OV compile_model API is supported with AUTO from version 2024.3 and above - // Inputs with static dimenstions - // Not enabled for models with external weights and when ep context is set. - const std::string model = model_proto->SerializeAsString(); - exe_network_ = OVCore::Get()->CompileModel(model, - hw_target, - device_config, - subgraph_context_.subgraph_name); - } else { // For all other types use ov::ov_core read_model() to generate OV IR - // followed by ov::ov_core compile_model() - std::string model = model_proto->SerializeAsString(); - if (!subgraph_context.has_dynamic_input_shape) { - model_proto.reset(); + try { + // SetOVDeviceConfiguration(device_config); + if (!session_context_.has_external_weights && + !subgraph_context_.has_dynamic_input_shape && + !session_context_.so_context_enable && + session_context_.reshape.empty() && + !enable_causallm && + auto_unified_compile) { + // Unified OV compile_model is efficient when ov model caching is enabled + // Unified OV compile_model API is supported with AUTO from version 2024.3 and above + // Inputs with static dimensions + // Not enabled for models with external weights and when ep context is set. + + exe_network_ = OVCore::Get()->CompileModel(model, + hw_target, + device_config, + subgraph_context_.subgraph_name); + } else { // For all other types use ov::ov_core read_model() to generate OV IR + // followed by ov::ov_core compile_model() + auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); + exe_network_ = OVCore::Get()->CompileModel( + ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name); } - auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); - exe_network_ = OVCore::Get()->CompileModel( - ov_model, hw_target, device_config, subgraph_context_.subgraph_name); - } + LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; + } catch (const OnnxRuntimeException& ex) { + std::string exception_str = ex.what(); + bool eligible_for_cpu_fallback = session_context_.device_type.find("NPU") != std::string::npos && + !session_context_.so_disable_cpu_ep_fallback && + !subgraph_context_.is_ep_ctx_graph; +#if defined(OPENVINO_DISABLE_NPU_FALLBACK) + eligible_for_cpu_fallback = false; #endif - LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; - } catch (const char* msg) { - ORT_THROW(msg); + if (eligible_for_cpu_fallback) { + LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU." + << "Falling back to OV CPU for execution"; + session_context_.device_type = "CPU"; + session_context_.precision = "FP32"; + device_config.clear(); + SetOVDeviceConfiguration(device_config); + try { + // Recreate the model with CPU device type + auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); + exe_network_ = OVCore::Get()->CompileModel( + ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name); + } catch (std::string const& msg) { + ORT_THROW(msg); + } + } else { + ORT_THROW(ex.what()); + } + } } - int num_infer_req = (session_context_.num_of_threads > 0) ? session_context_.num_of_threads : 1; std::function initializer = [](OVInferRequestPtr) {}; auto metadata = shared_context_.shared_weights.metadata; @@ -137,7 +153,8 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr } }; } - inferRequestsQueue_ = std::unique_ptr(new InferRequestsQueue(exe_network_, num_infer_req, std::move(initializer))); + infer_req_pool_ = std::make_unique(exe_network_, num_infer_req, std::move(initializer)); + bindings_ = std::make_unique(exe_network_, subgraph_context_, session_context_); } bool BasicBackend::ValidateSubgraph(std::map>& const_outputs_map) { @@ -198,6 +215,15 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { if (!session_context_.load_config.empty()) { const std::map& target_config = session_context_.load_config; + if ((session_context_.device_type.find("NPU") != std::string::npos) && session_context_.enable_causallm) { + if (target_config.find("NPU") != target_config.end()) { + auto npu_genai_config = target_config.at("NPU"); + CausalLMConfig().ApplyConfig(npu_genai_config, device_config); + } else { + LOGS_DEFAULT(WARNING) << "ORT GenAI CausalLMConfig Configuration not found."; + } + } + if (session_context_.device_type.find("NPU") != std::string::npos) { auto npuw_config = target_config.at("NPU"); @@ -263,7 +289,8 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options, const std::vector& supported_properties) { for (const auto& [key, value] : config_options) { - if (key.find("NPUW") != std::string::npos) { + if ((key.find("NPUW") != std::string::npos) || + ((device_config.find(key) != device_config.end()) && session_context_.enable_causallm)) { continue; } if (is_supported_and_mutable(key, supported_properties)) { @@ -356,331 +383,59 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) { device_config.emplace(ov::inference_num_threads(session_context_.num_of_threads)); } -// Starts an asynchronous inference request for data in slice indexed by batch_slice_idx on -// an Infer Request indexed by infer_req_idx -void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) { - try { - auto ov_input_info = exe_network_.Get().inputs(); - - // Loop over subgraph original input names to find the correspondent OV input name - for (const auto& [onnx_input_name, onnx_input_index] : subgraph_context_.input_names) { - std::string input_name{}; - uint32_t input_idx = 0; - for (uint32_t index = 0; const auto& ov_input : ov_input_info) { - if (ov_input.get_names().contains(onnx_input_name)) { - input_name = onnx_input_name; - input_idx = index; - break; - } - index++; - } - ORT_ENFORCE(!input_name.empty(), log_tag, - "Input names mismatch between OpenVINO and ONNX. ", onnx_input_name, - " doesn't exist in the list of OpenVINO input tensor names"); - size_t batch_slice_idx = 0; - if (subgraph_context_.has_dynamic_input_shape && - !session_context_.disable_dynamic_shapes && - (session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos)) { - auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name)); - auto tensor_info = tensor.GetTensorTypeAndShapeInfo(); - auto tensor_shape = tensor_info.GetShape(); - auto tensor_size = tensor_shape.size(); - const char* tensor_data = tensor.GetTensorData(); - auto tensor_iter = 0; - ov::Shape input_tensor_shape = ov::Shape(tensor_size, 0); - for (auto i = tensor_shape.begin(); i != tensor_shape.end(); ++i) { - input_tensor_shape[tensor_iter] = *i; - tensor_iter += 1; - } - const auto& input = ov_input_info.at(input_idx); - OVTensorPtr tensor_ptr; - // avoid input copies on the CPU device - if (session_context_.device_type.find("CPU") != std::string::npos) { - tensor_ptr = std::make_shared(input.get_element_type(), input_tensor_shape, - (void*)tensor_data); - } else { - tensor_ptr = std::make_shared(input.get_element_type(), input_tensor_shape); - FillInputBlob(tensor_ptr, batch_slice_idx, input_name, context, subgraph_context_); - } - - try { - infer_request->SetTensor(std::move(input_name), tensor_ptr); - } catch (const char* msg) { - ORT_THROW(msg); - } - } else { - if ((session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos)) { - OVTensorPtr graph_input_blob; - try { - graph_input_blob = infer_request->GetTensor(input_name); - } catch (const char* msg) { - ORT_THROW(msg); - } - FillInputBlob(std::move(graph_input_blob), batch_slice_idx, std::move(input_name), context, subgraph_context_); - } else { - auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name)); - ort_tensor_key_t ort_tensor_key{input_name}; - auto it = ort_ov_tensor_map.find(ort_tensor_key); - if ((it == ort_ov_tensor_map.end()) || - (it != ort_ov_tensor_map.end() && (it->second.ort_ptr != tensor.GetTensorRawData()))) { - ov_tensor_data_t ov_tensor_data; - const auto& input = ov_input_info.at(input_idx); - ov_tensor_data.tensor_ptr = std::make_shared(input.get_element_type(), input.get_shape(), - const_cast(tensor.GetTensorRawData())); - - ov_tensor_data.ort_ptr = tensor.GetTensorRawData(); - ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data; - - try { - infer_request->SetTensor(std::move(input_name), ov_tensor_data.tensor_ptr); - } catch (const char* msg) { - ORT_THROW(msg); - } - } - } - } - } // Loop subgraph original input names +void BasicBackend::SetOVDeviceConfiguration(ov::AnyMap& device_config) { + PopulateConfigValue(device_config); - if (session_context_.device_type.find("NPU") != std::string::npos) { - // Set the output blob as remote blob - auto graph_output_info = exe_network_.Get().outputs(); - auto output_idx = 0; - for (auto output_info_iter = graph_output_info.begin(); - output_info_iter != graph_output_info.end(); ++output_info_iter) { - auto output_names = output_info_iter->get_names(); - std::string onnx_output_name; - std::string output_name; - // using the output name retrieved from ONNX original to match with the output names returned by OV tensors - for (auto it = subgraph_context_.output_names.begin(); it != subgraph_context_.output_names.end(); ++it) { - onnx_output_name = it->first; - if (output_names.find(onnx_output_name) != output_names.end()) { - // Assigning the output_name - output_name = it->first; - break; - } - } - size_t batch_size = 1; - Ort::UnownedValue tensor = GetOutputTensor(context, - batch_size, - infer_request, - output_name, - subgraph_context_.output_names); - ort_tensor_key_t ort_tensor_key{output_name}; - const auto& it = ort_ov_tensor_map.find(ort_tensor_key); - if ((it == ort_ov_tensor_map.end()) || - (it != ort_ov_tensor_map.end() && (it->second.ort_ptr != tensor.GetTensorRawData()))) { - ov_tensor_data_t ov_tensor_data; - const auto& output = graph_output_info.at(output_idx); - ov_tensor_data.ort_ptr = tensor.GetTensorRawData(); - ov_tensor_data.tensor_ptr = std::make_shared(output.get_element_type(), output.get_shape(), - const_cast(tensor.GetTensorRawData())); - ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data; - - try { - infer_request->SetTensor(std::move(output_name), ov_tensor_data.tensor_ptr); - } catch (const char* msg) { - ORT_THROW(msg); - } - } - output_idx++; - } - } + // Enable caching + EnableCaching(); - // Start Async inference - infer_request->StartAsync(); - } catch (const char* msg) { - ORT_THROW(msg); - } -} + // Setting OpenCL queue throttling for GPU + EnableGPUThrottling(device_config); -#ifdef IO_BUFFER_ENABLED -// Wait for Remote Aynchronous inference completion -void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) { - try { - auto graph_input_info = exe_network_.Get().inputs(); - int input_idx = 0; - for (auto input_info_iter = graph_input_info.begin(); - input_info_iter != graph_input_info.end(); ++input_info_iter) { - auto input_names = input_info_iter->get_names(); - std::string onnx_input_name; - std::string input_name; - // use names retrieved from original ONNX model to assign the right onnx input name for the graph - for (auto it = subgraph_context_.input_names.begin(); it != subgraph_context_.input_names.end(); ++it) { - if (it->second == input_idx) { - onnx_input_name = it->first; - break; - } - } - // using the input name retrieved from ONNX original to match with the input names returned by OV tensors - if (input_names.find(onnx_input_name) != input_names.end()) { - input_name = onnx_input_name; - } else { - ORT_THROW(log_tag + - "Input names mismatch between OpenVINO and ONNX. " + - onnx_input_name + - " doesn't exist in the list of OpenVINO input tensor names"); - } - input_idx++; - // Kernel Context Input Buffer - const auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name)); - // If the ORTValue wraps a device pointer - auto mem_info = tensor.GetTensorMemoryInfo(); - if (mem_info.GetAllocatorName() == OpenVINO_GPU) { - // Get the shared buffer pointer - const void* tensor_data = tensor.GetTensorRawData(); - const cl::Buffer* shared_buffer_const = static_cast(tensor_data); - // Create an Input Remote Blob - auto input = graph_input_info.at(0); - auto remote_blob = remote_context_->create_tensor( - input.get_element_type(), input.get_shape(), *shared_buffer_const); - ov::Tensor tensor_remote = static_cast(remote_blob); - OVTensorPtr tensor_ptr = std::make_shared(tensor_remote); - infer_request->SetTensor(input_name, tensor_ptr); - } else { - OVTensorPtr graph_input_blob; - graph_input_blob = infer_request->GetTensor(input_name); - size_t batch_slice_idx = 0; - FillInputBlob(graph_input_blob, batch_slice_idx, input_name, context, subgraph_context_); - } - } + // Enable streams; default=1 unless overridden by user configuration + EnableStreams(); - // Set the output blob as remote blob - auto graph_output_info = exe_network_.Get().outputs(); - for (auto output_info_iter = graph_output_info.begin(); - output_info_iter != graph_output_info.end(); ++output_info_iter) { - auto output_names = output_info_iter->get_names(); - std::string onnx_output_name; - std::string output_name; - bool output_name_found = false; - // using the output name retrieved from ONNX original to match with the output names returned by OV tensors - for (auto it = subgraph_context_.output_names.begin(); it != subgraph_context_.output_names.end(); ++it) { - onnx_output_name = it->first; - if (output_names.find(onnx_output_name) != output_names.end()) { - // Assigning the output_name - output_name = it->first; - output_name_found = true; - break; - } - } - if (!output_name_found) { - ORT_THROW( - log_tag + - "Output names mismatch between OpenVINO and ONNX. [ONNX Output: ] " + - onnx_output_name + " doesn't exist in the list of OpenVINO output tensor names"); - } + // Set the inference_num_threads property of the CPU + SetNumThreads(device_config); - size_t batch_size = 1; - Ort::UnownedValue tensor = GetOutputTensor(context, - batch_size, - infer_request, - output_name, - subgraph_context_.output_names); - auto mem_info = tensor.GetTensorMemoryInfo(); - // Check if ORT Value wraps a device pointer - if (mem_info.GetAllocatorName() == OpenVINO_GPU) { - const void* tensor_data = tensor.GetTensorRawData(); - const cl::Buffer* shared_buffer_const = static_cast(tensor_data); - // Create a shared Blob, set the Infer Request Output Blob - auto output = graph_output_info.at(0); - auto remote_tensor = - remote_context_->create_tensor(output.get_element_type(), output.get_shape(), *shared_buffer_const); - ov::Tensor tensor_t = static_cast(remote_tensor); - OVTensorPtr tensor_ptr = std::make_shared(tensor_t); - try { - infer_request->SetTensor(output_name, tensor_ptr); - } catch (const char* msg) { - ORT_THROW(msg); - } - } - } + auto npuw_status = + std::any_of(device_config.begin(), device_config.end(), [&](const std::pair& pair) { + return (pair.first.find("NPU_USE_NPUW") != std::string::npos) && (pair.second.is()) && + (pair.second.as() == "YES"); + }); - // Start Async inference - infer_request->StartAsync(); - } catch (const char* msg) { - ORT_THROW(msg); + if (npuw_status) { + LOGS_DEFAULT(INFO) << log_tag << "NPUW Enabled during compilation"; } } -#endif - -// Wait for asynchronous inference completion on an Infer Request object indexed by infer_req_idx -// and copy the results into a slice location within the batched output buffer indexed by batch_slice_idx -void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) { - // Wait for Async inference completion - try { - infer_request->WaitRequest(); - auto graph_output_info = exe_network_.Get().outputs(); - for (auto output_info_iter = graph_output_info.begin(); - output_info_iter != graph_output_info.end(); ++output_info_iter) { - OVTensorPtr graph_output_blob; - auto output_names = output_info_iter->get_names(); - std::string onnx_output_name; - std::string output_name; - bool output_name_found = false; - // using the output name retrieved from ONNX original to match with the output names returned by OV tensors - for (auto it = subgraph_context_.output_names.begin(); it != subgraph_context_.output_names.end(); ++it) { - onnx_output_name = it->first; - if (output_names.find(onnx_output_name) != output_names.end()) { - // Assigning the output_name - output_name = it->first; - output_name_found = true; - break; - } - } - if (!output_name_found) { - ORT_THROW( - log_tag + - "Output names mismatch between OpenVINO and ONNX. " - "[ONNX Output: ] " + - onnx_output_name + - " doesn't exist in the " - "list of OpenVINO output tensor names"); - } - if ((session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos)) { - try { - graph_output_blob = infer_request->GetTensor(output_name); - } catch (const char* msg) { - ORT_THROW(msg); - } - size_t batch_size = 1; - Ort::UnownedValue output_tensor = - GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names); - auto mem_info = output_tensor.GetTensorMemoryInfo(); - if (mem_info.GetAllocatorName() == OpenVINO_GPU) { - return; - } else { - size_t batch_slice = 0; - FillOutputBlob(std::move(graph_output_blob), output_tensor, batch_slice); - } - } - } - if (!const_outputs_map_.empty()) { - for (const auto& item : const_outputs_map_) { - const auto& out_name = item.first; - auto node = item.second; - Ort::UnownedValue output_tensor = GetOutputTensor(context, - out_name, - subgraph_context_.output_names, - node); - auto mem_info = output_tensor.GetTensorMemoryInfo(); - if (mem_info.GetAllocatorName() == OpenVINO_GPU) { - ORT_THROW(log_tag + "IO Buffering is not supported for constant subgraphs"); - } else { - FillOutputsWithConstantData(std::move(node), output_tensor); - } - } +void BasicBackend::ValidateOrtDimsAgainstPartialShape(const std::vector& ort_dims, + const ov::PartialShape& partial_shape) const { + // Check if the number of dimensions matches + if (static_cast(ort_dims.size()) != partial_shape.rank().get_length()) { + ORT_THROW("Mismatch in number of dimensions between ORT tensor and OpenVINO PartialShape."); + } + // Validate each dimension + for (size_t i = 0; i < ort_dims.size(); ++i) { + const auto& ov_dim = partial_shape[i]; // OpenVINO dimension at index i + int64_t ort_dim = ort_dims[i]; // ORT dimension at index i + + // Check if the ORT dimension is within the specified range + int64_t min_dim = ov_dim.get_min_length(); + int64_t max_dim = ov_dim.get_max_length(); + if (ort_dim < min_dim || ort_dim > max_dim) { + ORT_THROW(" ORT Dimension is out of range"); } - } catch (const char* msg) { - ORT_THROW(msg); } } -void BasicBackend::Infer(OrtKernelContext* ctx) { - // Preliminary Thread safety mechanism - // currently allows a maximum of 8 Infer request's to parallel execute at the same time +void BasicBackend::RewindKVCache(size_t index) { + infer_req_pool_->forEachIdleRequest([&](OVInferRequestPtr& infer_request) { + infer_request->RewindKVCache(index); + }); +} + +void BasicBackend::Infer(OrtKernelContext* ctx) const { Ort::KernelContext context(ctx); LOGS_DEFAULT(INFO) << log_tag << "Running graph " << subgraph_context_.subgraph_name; @@ -690,79 +445,107 @@ void BasicBackend::Infer(OrtKernelContext* ctx) { for (const auto& item : const_outputs_map_) { std::string out_name = item.first; std::shared_ptr node = item.second; - try { - Ort::UnownedValue output_tensor = GetOutputTensor(context, - std::move(out_name), - subgraph_context_.output_names, - node); - FillOutputsWithConstantData(std::move(node), output_tensor); - } catch (std::string const& msg) { - ORT_THROW(msg); - } + Ort::UnownedValue output_tensor = GetOutputTensor(context, + std::move(out_name), + subgraph_context_.output_names, + node); + FillOutputsWithConstantData(std::move(node), output_tensor); } - // Get Output tensors + LOGS_DEFAULT(INFO) << log_tag << "Inference successful"; - // Enable CI Logs + if (IsCILogEnabled()) { std::cout << "Inference successful" << std::endl; } + return; + } - } else { - // Requesting for an idle infer_request from a pool of infer_requests_ - OVInferRequestPtr infer_request; - infer_request = inferRequestsQueue_->getIdleRequest(); -#ifdef IO_BUFFER_ENABLED - if ((session_context_.device_type.find("GPU") != std::string::npos) && - (session_context_.context != nullptr) && session_context_.is_wholly_supported_graph) { - try { - StartRemoteAsyncInference(context, infer_request); - } catch (std::string const& msg) { - ORT_THROW(msg); - } - } else { - try { - StartAsyncInference(context, infer_request); - } catch (std::string const& msg) { - ORT_THROW(msg); + // guarded_request will be released back to the pool when it goes out of scope + auto guarded_request = infer_req_pool_->getRequest(); + auto& infer_request = guarded_request.infer_request_; + + if (bindings_->has_dynamic_io_) { + // Dynamic shape inference + + // We don't know the output shapes so we need to get the outputs from the infer request and copy them into the ort + // tensors instead of binding them to the infer request directly. + + // Bind inputs + for (const auto& input_info : bindings_->network_inputs_) { + // Set the input shape based on the input tensor from ort + auto tensor = context.GetInput(input_info.onnx_index); + auto ort_shape = tensor.GetTensorTypeAndShapeInfo().GetShape(); + if (input_info.IsBoundedDynamic()) { + ValidateOrtDimsAgainstPartialShape(ort_shape, input_info.shape); } + auto input_shape = ParameterShape(ort_shape); + + infer_request->SetTensor(input_info.name, + input_info.type, + input_shape, + const_cast(tensor.GetTensorRawData())); } -#else - try { - StartAsyncInference(context, infer_request); - } catch (const std::runtime_error& e) { - ORT_THROW(log_tag + " Exception at StartAsyncInference: " + e.what()); + + // Run Inference + infer_request->Infer(); + + // Copy outputs + for (const auto& output_info : bindings_->network_outputs_) { + auto ov_tensor = infer_request->GetTensor(output_info.name); + auto output_shape = ParameterShape::ToOrtShape(ov_tensor->get_shape()); + auto ort_tensor = context.GetOutput(output_info.onnx_index, output_shape); + + ORT_ENFORCE(ov_tensor->get_byte_size() == ort_tensor.GetTensorSizeInBytes(), + log_tag + "Output tensor size mismatch for " + output_info.name); + + std::memcpy(ort_tensor.GetTensorMutableRawData(), + ov_tensor->data(), + ov_tensor->get_byte_size()); } -#endif - try { - CompleteAsyncInference(context, infer_request); - } catch (const std::runtime_error& e) { - ORT_THROW(log_tag + " Exception at CompleteAsyncInference: " + e.what()); + } else { + // Static shape inference + + // Bind inputs + for (const auto& input_info : bindings_->network_inputs_) { + infer_request->SetTensor(input_info.name, + input_info.type, + input_info.shape, + const_cast(context.GetInput(input_info.onnx_index).GetTensorRawData())); } - // Get Output tensors - LOGS_DEFAULT(INFO) << log_tag << "Inference successful"; - // Enable CI Logs - if (IsCILogEnabled()) { - std::cout << "Inference successful" << std::endl; + // Bind outputs + for (const auto& output_info : bindings_->network_outputs_) { + infer_request->SetTensor(output_info.name, + output_info.type, + output_info.shape, + context.GetOutput(output_info.onnx_index, output_info.shape).GetTensorMutableRawData()); } - // Create a duplicate infer_request_ shared ptr on the stack in the current local scope, - // as the infer_request gets freed in the next stage the reference count for the infer_request decrements & - // thus we dont have any dangling ptr leading to seg faults in the debug mode subsequent execution call - OVInferRequestPtr infer_request_ = infer_request; + // Run Inference + infer_request->Infer(); + } + + // Fill constant outputs if needed + for (const auto& [name, node] : const_outputs_map_) { + Ort::UnownedValue output_tensor = GetOutputTensor(context, + name, + subgraph_context_.output_names, + node); + FillOutputsWithConstantData(node, output_tensor); + } + + LOGS_DEFAULT(INFO) << log_tag << "Inference successful"; + if (IsCILogEnabled()) { + std::cout << "Inference successful" << std::endl; + } - // Once the inference is completed, the infer_request becomes free and is placed back into pool of infer_requests_ - inferRequestsQueue_->putIdleRequest(std::move(infer_request)); #ifndef NDEBUG -#ifndef IO_BUFFER_ENABLED // Printing performance counts is disabled when IO_BUFFER_ENABLED - if (openvino_ep::backend_utils::IsDebugEnabled()) { - inferRequestsQueue_->printstatus(); // Printing the elements of infer_requests_ vector pool only in debug mode - std::string& hw_target = session_context_.device_type; - printPerformanceCounts(std::move(infer_request_), std::cout, hw_target); - } -#endif -#endif + // Print performance counts before releasing the infer_request for thread safety + if (openvino_ep::backend_utils::IsDebugEnabled()) { + std::string& hw_target = session_context_.device_type; + printPerformanceCounts(infer_request, std::cout, hw_target); } +#endif } } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 7d905f4a1e2f7..5c75a9ae183e2 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -13,21 +13,117 @@ #include #include #include +#include +#include #include "core/session/onnxruntime_cxx_api.h" #include "core/providers/openvino/contexts.h" #include "core/providers/openvino/ibackend.h" #include "core/providers/openvino/ov_interface.h" +#include "core/providers/openvino/backend_utils.h" namespace onnxruntime { namespace openvino_ep { -struct ov_tensor_data_t { - OVTensorPtr tensor_ptr; - const void* ort_ptr; +struct ParameterInfo { + std::string name; + uint32_t ov_index; + uint32_t onnx_index; + ov::element::Type type; + ParameterShape shape; + uint8_t dynamic_flags = 0; + + // Query methods + bool IsStatic() const { return dynamic_flags == 0; } + bool IsFullyDynamic() const { return dynamic_flags & 1; } + bool IsBoundedDynamic() const { return dynamic_flags & 2; } + bool IsMixed() const { return (dynamic_flags & 3) == 3; } + + // Setter methods + void SetFullyDynamic(bool value) { + dynamic_flags = value ? (dynamic_flags | 1) : (dynamic_flags & ~1); + } + void SetBoundedDynamic(bool value) { + dynamic_flags = value ? (dynamic_flags | 2) : (dynamic_flags & ~2); + } }; -class InferRequestsQueue; +struct OnnxToOvNetworkBindings { + std::vector network_outputs_; + std::vector network_inputs_; + bool has_dynamic_io_ = false; + + inline static const std::array special_io_names_{ + "beam_idx", + "past_key_values", + "present", + }; + + OnnxToOvNetworkBindings(OVExeNetwork& exec_network, SubGraphContext& subgraph_context, SessionContext& session_context) { + auto populate = [&](auto& input_output_map, const SubGraphContext::string_index_map_t& onnx_input_map, const auto& ov_parameters) { + for (const auto& [onnx_name, onnx_param_index] : onnx_input_map) { + auto it = std::find_if(ov_parameters.begin(), ov_parameters.end(), + [&onnx_name](const auto& ov_parameter_info) { return ov_parameter_info.get_names().contains(onnx_name); }); + bool matched_names = it != ov_parameters.end(); + + // For Stateful Model Compilation, the ONNX model includes KV cache (past/present) tensors. + // However, these tensors are internally converted to a stateful representation, which removes them. + // It's also possible that the onnx model does not contain tensors such as beam_idx, whereas our converted + // stateful representation has introduced these new tensors, creating a name mismatch (matched_names=false). + // So, if there is a name mismatch, or the name matches our special io list, we simply continue processing + // here to prevent runtime exceptions. + if (session_context.enable_causallm) { + if (!matched_names || + std::any_of(special_io_names_.begin(), special_io_names_.end(), + [&onnx_name](const std::string& name) { return onnx_name.find(name) != std::string::npos; })) { + // This case also requires dynamic shape inference, so we'll mark the bindings as dynamic. + has_dynamic_io_ = true; + continue; + } + } + + ORT_ENFORCE(matched_names, log_tag, + "Input names mismatch between OpenVINO and ONNX. ", onnx_name, + " doesn't exist in the list of OpenVINO input tensor names"); + + auto ov_param_index = std::distance(ov_parameters.begin(), it); + + auto shape = ov_parameters[ov_param_index].get_partial_shape(); + auto type = ov_parameters[ov_param_index].get_element_type(); + ParameterInfo info{onnx_name, ov_param_index, onnx_param_index, type, ParameterShape{shape}}; + + // Analyze shape dynamism and set flags + if (!shape.is_static()) { + has_dynamic_io_ = true; + // Analyze dynamic dimensions + bool has_fully_dynamic = false; + bool has_bounded_dynamic = false; + + for (const auto& dim : shape) { + if (dim.is_dynamic()) { + if (dim.get_interval().has_upper_bound()) { + has_bounded_dynamic = true; + } else { + has_fully_dynamic = true; + } + } + } + + info.SetFullyDynamic(has_fully_dynamic); + info.SetBoundedDynamic(has_bounded_dynamic); + } + + input_output_map.push_back(std::move(info)); + } + }; + + // Populate inputs and outputs + populate(network_inputs_, subgraph_context.input_names, exec_network.Get().inputs()); + populate(network_outputs_, subgraph_context.output_names, exec_network.Get().outputs()); + } +}; + +class InferRequestPool; class BasicBackend : public IBackend { public: BasicBackend(std::unique_ptr& model_proto, @@ -36,88 +132,96 @@ class BasicBackend : public IBackend { SharedContext& shared_context, ptr_stream_t& model_stream); - void Infer(OrtKernelContext* context) override; + void Infer(OrtKernelContext* context) const override; ~BasicBackend() override = default; - ov::CompiledModel& GetOVCompiledModel() override { + ov::CompiledModel GetOVCompiledModel() override { return exe_network_.Get(); } + void RewindKVCache(size_t index) override; private: - void PopulateCompiledDirectory(std::string, std::string&, std::string&, bool&); bool ValidateSubgraph(std::map>& const_outputs_map); void PopulateConfigValue(ov::AnyMap& device_config); void EnableCaching(); void EnableGPUThrottling(ov::AnyMap& device_config); void EnableStreams(); void SetNumThreads(ov::AnyMap& device_config); - void StartAsyncInference(Ort::KernelContext& context, std::shared_ptr infer_request); - -#ifdef IO_BUFFER_ENABLED - void StartRemoteAsyncInference(Ort::KernelContext& context, std::shared_ptr infer_request); -#endif - - void CompleteAsyncInference(Ort::KernelContext& context, std::shared_ptr infer_request); + void SetOVDeviceConfiguration(ov::AnyMap& device_config); + void ValidateOrtDimsAgainstPartialShape(const std::vector& ort_dims, + const ov::PartialShape& partial_shape) const; SessionContext& session_context_; SubGraphContext subgraph_context_; SharedContext& shared_context_; - mutable std::mutex compute_lock_; OVExeNetwork exe_network_; std::map> const_outputs_map_; - std::unique_ptr inferRequestsQueue_; -#if defined IO_BUFFER_ENABLED - OVRemoteContextPtr remote_context_; -#endif + std::unique_ptr infer_req_pool_; using ort_tensor_key_t = const std::string; - std::map ort_ov_tensor_map; + std::unique_ptr bindings_; }; -class InferRequestsQueue { +class InferRequestPool { public: - InferRequestsQueue(OVExeNetwork& net, size_t nireq, std::function initializer) { - OVInferRequestPtr infer_request; - for (size_t id = 0; id < nireq; id++) { - infer_request = std::make_shared(net.CreateInferRequest()); - initializer(infer_request); - infer_requests_.push_back(infer_request); + struct GuardedInferReq { + OVInferRequestPtr infer_request_; + GuardedInferReq(InferRequestPool& queue, OVInferRequestPtr&& infer_req) : queue_(queue), infer_request_(std::move(infer_req)) {} + ~GuardedInferReq() { queue_.putIdleRequest(std::move(infer_request_)); } + + // Movable but not copyable + ORT_DISALLOW_COPY_AND_ASSIGNMENT(GuardedInferReq); + GuardedInferReq(GuardedInferReq&&) = default; + GuardedInferReq& operator=(GuardedInferReq&&) = default; + + private: + InferRequestPool& queue_; + friend class InferRequestPool; + }; + + InferRequestPool(OVExeNetwork& net, size_t initial_size, std::function initializer) : exe_network_(net), initializer_(std::move(initializer)) { + for (size_t id = 0; id < initial_size; id++) { + infer_requests_.emplace_back(createInferRequest()); } } + ~InferRequestPool() = default; - ~InferRequestsQueue() { - // clearing out the infer_requests_ vector pool in the class's destructor - for (auto& pointer : infer_requests_) { - pointer = nullptr; + GuardedInferReq getRequest() { + std::unique_lock lock(_mutex); + if (infer_requests_.empty()) { + infer_requests_.emplace_back(createInferRequest()); } - infer_requests_.erase(std::remove(infer_requests_.begin(), infer_requests_.end(), nullptr), infer_requests_.end()); + auto request = std::move(infer_requests_.back()); + infer_requests_.pop_back(); + return GuardedInferReq(*this, std::move(request)); } - void printstatus() { - std::cout << "printing elements of the vector (infer_requests_): " << std::endl; - for (auto i = infer_requests_.begin(); i != infer_requests_.end(); ++i) { - i->get()->QueryStatus(); + template + void forEachIdleRequest(Func&& func) { + std::unique_lock lock(_mutex); + for (auto& infer_request : infer_requests_) { + func(infer_request); } - std::cout << '\n'; } - void putIdleRequest(OVInferRequestPtr infer_request) { - std::unique_lock lock(_mutex); - infer_requests_.push_back(infer_request); - _cv.notify_one(); + private: + void putIdleRequest(OVInferRequestPtr&& infer_request) { + if (infer_request) { + std::unique_lock lock(_mutex); + infer_requests_.emplace_back(std::move(infer_request)); + } } - OVInferRequestPtr getIdleRequest() { - std::unique_lock lock(_mutex); - _cv.wait(lock, [this] { return infer_requests_.size() > 0; }); - auto request = infer_requests_.at(0); - infer_requests_.erase(infer_requests_.begin()); - return request; + OVInferRequestPtr createInferRequest() { + auto infer_request = exe_network_.CreateInferRequest(); + initializer_(infer_request); + return infer_request; } private: std::mutex _mutex; - std::condition_variable _cv; std::vector infer_requests_; + OVExeNetwork& exe_network_; + std::function initializer_; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index 7560f4570bd32..6a2b375d733f9 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -12,6 +12,7 @@ #include #include "core/common/common.h" #include "core/providers/openvino/ov_interface.h" +#include "core/providers/shared_library/provider_api.h" namespace onnxruntime { namespace openvino_ep { @@ -63,10 +64,12 @@ class SharedContext : public WeakSingleton { fs::path external_weight_filename; std::unique_ptr mapped_weights; Metadata::Map metadata; + fs::path metadata_filepath; } shared_weights; }; using config_t = std::map; +using reshape_t = std::map; struct ProviderInfo { std::string device_type{""}; // [device_type]: Overrides the accelerator hardware type and @@ -84,6 +87,7 @@ struct ProviderInfo { // dump and load the blobs for the model caching/kernel caching // (GPU) feature. If blob files are already present, // it will be directly loaded. + reshape_t reshape{}; // Used for reshaping the ov input tensor shape at runtime. std::string model_priority{"DEFAULT"}; // High-level OpenVINO model priority hint // Defines what model should be provided with more performant // bounded resource first @@ -97,6 +101,7 @@ struct ProviderInfo { bool disable_dynamic_shapes{false}; // [disable_dynamic_shapes]: Rewrite dynamic shaped models to // static shape at runtime and execute. bool enable_qdq_optimizer{false}; // Enables QDQ pruning for efficient inference latency with NPU + bool enable_causallm{false}; // Enables Causal LM Compilation for ORT GenAI OVEP Pass bool so_context_enable{false}; // ORT session option bool so_disable_cpu_ep_fallback{false}; // ORT session option bool so_context_embed_mode{false}; // ORT session option @@ -105,7 +110,7 @@ struct ProviderInfo { const ConfigOptions* config_options{NULL}; const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision", "load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer", - "enable_causallm", "disable_dynamic_shapes"}; + "enable_causallm", "disable_dynamic_shapes", "reshape_input"}; }; // Holds context applicable to the entire EP instance. @@ -133,6 +138,7 @@ struct SubGraphContext { string_index_map_t output_names; std::string model_precision; bool is_ep_ctx_graph = false; + bool is_ep_ctx_ovir_encapsulated = false; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/ibackend.h b/onnxruntime/core/providers/openvino/ibackend.h index 04d1f52cbf834..ec38425f602eb 100644 --- a/onnxruntime/core/providers/openvino/ibackend.h +++ b/onnxruntime/core/providers/openvino/ibackend.h @@ -14,9 +14,10 @@ namespace openvino_ep { class IBackend { public: - virtual void Infer(OrtKernelContext* context) = 0; - virtual ov::CompiledModel& GetOVCompiledModel() = 0; + virtual void Infer(OrtKernelContext* context) const = 0; + virtual ov::CompiledModel GetOVCompiledModel() = 0; virtual ~IBackend() = default; + virtual void RewindKVCache(size_t index) {} }; using ptr_stream_t = std::unique_ptr; class BackendFactory { diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index 7bd4f8d96cc55..9e70756a254aa 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -7,6 +7,7 @@ #include #include "core/providers/openvino/onnx_ctx_model_helper.h" +#include "core/providers/openvino/backend_utils.h" namespace onnxruntime { namespace openvino_ep { @@ -123,6 +124,16 @@ std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesy ORT_ENFORCE(std::filesystem::exists(blob_filepath), "Blob file not found: ", blob_filepath.string()); result.reset((std::istream*)new std::ifstream(blob_filepath, std::ios_base::binary | std::ios_base::in)); } + + bool isXML = backend_utils::IsModelStreamXML(*result); + if (!isXML) { + // If the model stream is not an XML (i.e. precompiled blob), the OpenVINO SDK version that it was + // exported with must match the version that is currently running. + ORT_ENFORCE((attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_), + "EPCtx blob was exported / is compatible with OpenVINO SDK version " + attrs.at(EP_SDK_VER).s() + + ", but OpenVINO SDK version currently in use is " + openvino_sdk_version_); + } + LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node"; return result; } @@ -142,7 +153,6 @@ bool EPCtxHandler::CheckForOVEPCtxNode(const Node& node) const { if (node.OpType() == EPCONTEXT_OP) { auto& attrs = node.GetAttributes(); bool result = (attrs.count(SOURCE) == 1) && (attrs.at(SOURCE).s() == kOpenVINOExecutionProvider); - result &= (attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_); result &= attrs.count(EMBED_MODE) == 1; result &= attrs.count(EP_CACHE_CONTEXT) == 1; return result; @@ -155,5 +165,32 @@ InlinedVector EPCtxHandler::GetEPCtxNodes() const { return InlinedVector(epctx_nodes.begin(), epctx_nodes.end()); } +// Check if graph's only node is EPContext & EP_CACHE_CONTEXT attribute has target extension. +// @param graph_viewer: The graph to inspect. +// @param target_attr_extn: The string to search for in the EP_CACHE_CONTEXT attribute. +// @return true if the node exists, is of the correct type, and the attribute contains the extension; false otherwise. +bool EPCtxHandler::CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, const std::string& target_attr_extn) const { + // Only check if the graph has exactly one node + if (graph_viewer.NumberOfNodes() != 1) { + return false; + } + // Get the first node in topological order + auto first_index = *graph_viewer.GetNodesInTopologicalOrder().begin(); + const Node* node = graph_viewer.GetNode(first_index); + if (!node) { + return false; + } + // Check OpType and required attributes + if (node->OpType() != EPCONTEXT_OP) { + return false; + } + const auto& attrs = node->GetAttributes(); + auto it = attrs.find(EP_CACHE_CONTEXT); + if (it != attrs.end()) { + return it->second().s().find(target_attr_extn) != std::string::npos; + } + return false; +} + } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h index ff978bd6534d8..b9ddb40a7a233 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h @@ -33,6 +33,7 @@ class EPCtxHandler { std::string&& model_blob_str) const; std::unique_ptr GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const; InlinedVector GetEPCtxNodes() const; + bool CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, const std::string& target_attr_extn) const; private: const std::string openvino_sdk_version_; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index f9d4ab13cf2ce..a0aa04293ac37 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -55,7 +55,7 @@ static std::vector parseDevices(const std::string& device_string, OpenVINOExecutionProvider::OpenVINOExecutionProvider(const ProviderInfo& info, std::shared_ptr shared_context) : IExecutionProvider{onnxruntime::kOpenVINOExecutionProvider}, session_context_(info), - shared_context_{shared_context}, + shared_context_{std::move(shared_context)}, ep_ctx_handle_{session_context_.openvino_sdk_version, *GetLogger()} { InitProviderOrtApi(); } @@ -102,15 +102,24 @@ common::Status OpenVINOExecutionProvider::Compile( graph_body_viewer_0.DomainToVersionMap().at(kOnnxDomain); } - // Temporary code to read metadata before it moves to the .bin - auto& metadata = shared_context_->shared_weights.metadata; - if (session_context_.so_share_ep_contexts && metadata.empty()) { - // Metadata is always read from model location, this could be a source or epctx model - fs::path metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin"; - std::ifstream file(metadata_filename, std::ios::binary); - if (file) { - file >> metadata; + // The block below is executed during EP context model inference + auto& metadata = shared_context_->shared_weights.metadata; // Metadata object in memory + if (session_context_.so_share_ep_contexts && + !session_context_.so_context_enable && + metadata.empty()) { + fs::path context_model_file_path = session_context_.so_context_file_path; + if (context_model_file_path.empty()) { + // If ep.context_file_path is not set the input model path is used + context_model_file_path = session_context_.onnx_model_path_name; } + + // Metadata is always read from model location, this could be a source or epctx model + fs::path metadata_filename = context_model_file_path.stem().string() + "_metadata.bin"; + fs::path metadata_file_path = context_model_file_path.parent_path() / metadata_filename; + std::ifstream file(metadata_file_path, std::ios::binary); + ORT_RETURN_IF_NOT(file, "Metadata file was not found: " + metadata_file_path.string()); + shared_context_->shared_weights.metadata_filepath = metadata_file_path; + file >> metadata; } struct OpenVINOEPFunctionState { @@ -173,22 +182,31 @@ common::Status OpenVINOExecutionProvider::Compile( } } - if (session_context_.so_share_ep_contexts) { - fs::path metadata_filename; - if (session_context_.so_context_file_path.empty()) { - metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin"; - } else { - metadata_filename = session_context_.so_context_file_path.parent_path() / "metadata.bin"; + // The block below is executed during EP context model generation + if (session_context_.so_context_enable && + session_context_.so_share_ep_contexts && + !metadata.empty()) { + // For models after the first the metadata name comes from the shared context + fs::path metadata_file_path = shared_context_->shared_weights.metadata_filepath; + if (metadata_file_path.empty()) { + metadata_file_path = session_context_.so_context_file_path; + std::string name_append{"_metadata.bin"}; + if (metadata_file_path.empty()) { + metadata_file_path = session_context_.onnx_model_path_name; + name_append = "_ctx" + name_append; + } + auto metadata_filename = metadata_file_path.stem().string() + name_append; + metadata_file_path.replace_filename(metadata_filename); + shared_context_->shared_weights.metadata_filepath = metadata_file_path; } // Metadata is generated only for shared contexts - // If saving metadata then save it to the provided path or ose the original model path + // If saving metadata then save it to the provided path or use the original model path // Multiple calls to Compile() will update the metadata and for the last call // the resulting file will contain the aggregated content - std::ofstream file(metadata_filename, std::ios::binary); - if (file) { - file << metadata; - } + std::ofstream file{metadata_file_path, std::ios::binary}; + ORT_RETURN_IF_NOT(file, "Metadata file could not be written: ", metadata_file_path); + file << metadata; } return status; @@ -238,10 +256,39 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span= 0) { + backend.RewindKVCache(static_cast(index)); + } else { + LOGS_DEFAULT(WARNING) << "kvcache_rewind index is < 0:\t" << index; } } } else { diff --git a/onnxruntime/core/providers/openvino/openvino_parser_utils.cc b/onnxruntime/core/providers/openvino/openvino_parser_utils.cc index 067076b1f84f2..a78bd1fe2effc 100644 --- a/onnxruntime/core/providers/openvino/openvino_parser_utils.cc +++ b/onnxruntime/core/providers/openvino/openvino_parser_utils.cc @@ -1,4 +1,5 @@ #include +#include #include "core/providers/openvino/openvino_parser_utils.h" #include "core/providers/shared_library/provider_api.h" @@ -116,5 +117,124 @@ std::string OpenVINOParserUtils::ParsePrecision(const ProviderOptions& provider_ } } +reshape_t OpenVINOParserUtils::ParseInputShape(const std::string& reshape_input_definition) { + reshape_t parsed_shape_map; + + // Return empty map for empty input + if (reshape_input_definition.empty()) { + ORT_THROW("Empty input shape definition provided in reshape_input parameter"); + } + + // Regular expressions for parsing + const std::regex tensor_pattern(R"(([^\[\],]+)\s*\[(.*?)\])"); // e.g. "input_1[1..5, 2, 3..4],data[1,2,3]" + // const std::regex dimension_pattern(R"(\s*(\d+(?:\.\.\d+)?)\s*)"); // e.g. "1..5", "2", "3..4" + const std::regex dimension_pattern(R"(\s*([^,\s]+)\s*)"); + // Find all tensor shape definitions using regex + auto tensor_begin = std::sregex_iterator( + reshape_input_definition.begin(), + reshape_input_definition.end(), + tensor_pattern); + auto tensor_end = std::sregex_iterator(); + + // If no matches found, throw error + if (tensor_begin == tensor_end) { + ORT_THROW("Invalid input shape definition format: " + reshape_input_definition); + } + + // Process each tensor definition e.g. "input_1[1..5, 2, 3..4],data[1,2,3]" + for (std::sregex_iterator i = tensor_begin; i != tensor_end; ++i) { + std::smatch tensor_match = *i; + + // Extract tensor name and trim whitespace + std::string tensor_name = tensor_match[1].str(); // Group 1: tensor name e.g. "input_1" + tensor_name = TrimWhitespace(tensor_name); + + if (tensor_name.empty()) { + ORT_THROW("Empty tensor name provided in reshape_input parameter"); + } + + // Extract dimensions string + std::string dimensions_str = tensor_match[2].str(); // Group 2: dimensions string [e.g. "1..5, 2, 3..4"] + std::vector dimensions; + + // Find all dimension e.g. "1..5", "2", "3..4" using regex + auto dim_begin = std::sregex_iterator( + dimensions_str.begin(), + dimensions_str.end(), + dimension_pattern); + auto dim_end = std::sregex_iterator(); + + // Process each dimension + for (std::sregex_iterator j = dim_begin; j != dim_end; ++j) { + std::smatch dim_match = *j; + std::string dim_value = dim_match[1].str(); + + // Check if dimension is a range + size_t range_separator_pos = dim_value.find(".."); + if (range_separator_pos != std::string::npos) { + // Parse range + dimensions.push_back(ParseDimensionRange(dim_value, tensor_name)); + } else { + // Parse single value + bool is_valid_integer = !dim_value.empty() && + std::all_of(dim_value.begin(), dim_value.end(), [](char c) { + return std::isdigit(static_cast(c)); + }); + + if (!is_valid_integer) { + ORT_THROW("Invalid dimension value: '" + dim_value + "' for tensor: " + tensor_name); + } + + dimensions.push_back(std::stoi(dim_value)); + } + } + + // Store parsed shape in result map + parsed_shape_map[tensor_name] = ov::PartialShape(dimensions); + } + + return parsed_shape_map; +} + +// Helper function to trim whitespace from a string +std::string OpenVINOParserUtils::TrimWhitespace(const std::string& str) { + const std::string whitespace = " \t\n\r\f\v"; + size_t start = str.find_first_not_of(whitespace); + + if (start == std::string::npos) { + return ""; + } + + size_t end = str.find_last_not_of(whitespace); + return str.substr(start, end - start + 1); +} + +// Helper function to parse dimension range (e.g. "1..5") +ov::Dimension OpenVINOParserUtils::ParseDimensionRange(const std::string& range_str, const std::string& tensor_name) { + size_t range_separator_pos = range_str.find(".."); + if (range_separator_pos == std::string::npos) { + ORT_THROW("Invalid dimension range format: " + range_str); + } + + std::string range_start_str = TrimWhitespace(range_str.substr(0, range_separator_pos)); + std::string range_end_str = TrimWhitespace(range_str.substr(range_separator_pos + 2)); + + // Validate range values + if (range_start_str.empty() || range_end_str.empty() || + !std::all_of(range_start_str.begin(), range_start_str.end(), ::isdigit) || + !std::all_of(range_end_str.begin(), range_end_str.end(), ::isdigit)) { + ORT_THROW("Invalid dimension range format: '" + range_str + "' for tensor: " + tensor_name); + } + + int range_start = std::stoi(range_start_str); + int range_end = std::stoi(range_end_str); + + if (range_start > range_end) { + ORT_THROW("Invalid dimension range (start > end): " + range_str + " for tensor: " + tensor_name); + } + + return ov::Dimension(range_start, range_end); +} + } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/openvino_parser_utils.h b/onnxruntime/core/providers/openvino/openvino_parser_utils.h index 3e23c9e788463..e6aa0e0a46a3b 100644 --- a/onnxruntime/core/providers/openvino/openvino_parser_utils.h +++ b/onnxruntime/core/providers/openvino/openvino_parser_utils.h @@ -7,6 +7,7 @@ #include #include "core/framework/provider_options.h" +#include "core/providers/openvino/contexts.h" namespace onnxruntime { namespace openvino_ep { @@ -16,6 +17,9 @@ class OpenVINOParserUtils { static std::string ParsePrecision(const ProviderOptions& provider_options, std::string& device_type, const std::string& option_name); + static reshape_t ParseInputShape(const std::string& reshape_input_definition); + static std::string TrimWhitespace(const std::string& str); + static ov::Dimension ParseDimensionRange(const std::string& range_str, const std::string& tensor_name); }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index f7f15dc62fd11..bad1d416eeda2 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -117,8 +117,6 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio luid_list = split(luid_str, ','); } - bool all_devices_found = true; - for (auto device : devices_to_check) { bool device_found = false; // Check deprecated device format (CPU_FP32, GPU.0_FP16, etc.) and remove the suffix in place @@ -137,8 +135,11 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio // Here we need to find the full device name (with .idx, but without _precision) if (std::find(std::begin(available_devices), std::end(available_devices), device) != std::end(available_devices)) device_found = true; + if (!device_found) { + ORT_THROW("[ERROR] [OpenVINO] Device ", device, " is not available"); + } if (device_prefix != "CPU" && luid_list.size() > 0) { - for (auto dev : available_devices) { + for (const auto& dev : available_devices) { ov::device::LUID ov_luid = OVCore::Get()->core.get_property(dev, ov::device::luid); std::stringstream ov_luid_str; ov_luid_str << ov_luid; @@ -149,11 +150,10 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio ORT_THROW(msg); } } - all_devices_found = all_devices_found && device_found; } if (luid_list.size() > 0) { std::string ov_luid_devices; - for (auto luid_str : luid_list) { + for (const auto& luid_str : luid_list) { if (ov_luid_map.contains(luid_str)) { std::string ov_dev = ov_luid_map.at(luid_str); std::string ov_dev_strip = split(ov_dev, '.')[0]; @@ -170,26 +170,19 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio } if (!device_mode.empty()) { selected_device = device_mode + ":" + ov_luid_devices; - for (auto dev_str : devices_to_check) { - auto default_dev = split(dev_str, '.')[0]; + for (const auto& dev_str : devices_to_check) { + const auto default_dev = split(dev_str, '.')[0]; if (ov_luid_devices.find(default_dev) == std::string::npos) selected_device = selected_device + "," + dev_str; } } else { - selected_device = ov_luid_devices; + selected_device = std::move(ov_luid_devices); } } - // If invalid device is chosen error is thrown - if (!all_devices_found) { - ORT_THROW( - "[ERROR] [OpenVINO] You have selected wrong configuration value for the key 'device_type'. " - "Select from 'CPU', 'GPU', 'NPU', 'GPU.x' where x = 0,1,2 and so on or from" - " HETERO/MULTI/AUTO/BATCH options available. \n"); - } else { - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Choosing Device: " << selected_device; - return selected_device; - } + + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Choosing Device: " << selected_device; + return selected_device; } void ParseProviderOptions([[maybe_unused]] ProviderInfo& result, [[maybe_unused]] const ProviderOptions& config_options) {} @@ -215,7 +208,7 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, // Minor optimization: we'll hold an OVCore reference to ensure we don't create a new core between ParseDeviceType and // (potential) SharedContext creation. auto ov_core = OVCore::Get(); - pi.device_type = ParseDeviceType(ov_core, provider_options); + pi.device_type = ParseDeviceType(std::move(ov_core), provider_options); if (provider_options.contains("device_id")) { std::string dev_id = provider_options.at("device_id").data(); @@ -233,6 +226,10 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, pi.precision = OpenVINOParserUtils::ParsePrecision(provider_options, pi.device_type, "precision"); + if (provider_options.contains("reshape_input")) { + pi.reshape = OpenVINOParserUtils::ParseInputShape(provider_options.at("reshape_input")); + } + if (provider_options.contains("load_config")) { auto parse_config = [&](const std::string& config_str) -> std::map { // If the config string is empty, return an empty map and skip processing @@ -343,19 +340,26 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, pi.enable_qdq_optimizer = ParseBooleanOption(provider_options, "enable_qdq_optimizer"); + pi.enable_causallm = ParseBooleanOption(provider_options, "enable_causallm"); + pi.disable_dynamic_shapes = ParseBooleanOption(provider_options, "disable_dynamic_shapes"); } catch (std::string msg) { ORT_THROW(msg); } - // Always true for NPU plugin or when passed . - if (pi.device_type.find("NPU") != std::string::npos) { - pi.disable_dynamic_shapes = true; - } + + // Should likely account for meta devices as well, but for now keep the current behavior. + bool target_devices_support_dynamic_shapes = + pi.device_type.find("GPU") != std::string::npos || + pi.device_type.find("CPU") != std::string::npos || + (pi.device_type.find("NPU") != std::string::npos && + pi.enable_causallm); + + pi.disable_dynamic_shapes = !target_devices_support_dynamic_shapes; } struct OpenVINOProviderFactory : IExecutionProviderFactory { OpenVINOProviderFactory(ProviderInfo provider_info, std::shared_ptr shared_context) - : provider_info_(std::move(provider_info)), shared_context_(shared_context) {} + : provider_info_(std::move(provider_info)), shared_context_(std::move(shared_context)) {} ~OpenVINOProviderFactory() override {} diff --git a/onnxruntime/core/providers/openvino/ov_allocator.cc b/onnxruntime/core/providers/openvino/ov_allocator.cc index 02e83213250a9..248c859b20dee 100644 --- a/onnxruntime/core/providers/openvino/ov_allocator.cc +++ b/onnxruntime/core/providers/openvino/ov_allocator.cc @@ -28,7 +28,7 @@ void* OVRTAllocator::Alloc(size_t size) { try { ov::Tensor* tensor = new ov::Tensor(remote_ctx_.create_host_tensor(ov::element::Type_t::u8, {size})); - std::unique_lock lock(mutex_); + std::lock_guard lock(mutex_); allocated_.insert({tensor->data(), tensor}); return reinterpret_cast(tensor->data()); } catch (const ov::Exception& e) { @@ -38,12 +38,16 @@ void* OVRTAllocator::Alloc(size_t size) { void OVRTAllocator::Free(void* p) { try { - std::unique_lock lock(mutex_); - auto it = allocated_.find(p); - if (it != allocated_.end()) { - ov::Tensor* tensor = it->second; - allocated_.erase(it); - lock.unlock(); + ov::Tensor* tensor = nullptr; + { + std::lock_guard lock(mutex_); + auto it = allocated_.find(p); + if (it != allocated_.end()) { + tensor = it->second; + allocated_.erase(it); + } + } + if (tensor) { delete tensor; } } catch (const ov::Exception& e) { diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index a175ca863d1d1..918940b9d9917 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -3,17 +3,28 @@ #include "core/providers/openvino/ov_interface.h" +#include + #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/providers/shared_library/provider_api.h" #include "core/providers/openvino/backend_utils.h" - -using Exception = ov::Exception; +#include "core/providers/openvino/backends/basic_backend.h" +#include "core/providers/openvino/ov_stateful_patch_utils.h" namespace onnxruntime { namespace openvino_ep { -static const std::string log_tag = "[OpenVINO-EP] "; +template +inline auto OvExceptionBoundary(Func&& func, std::format_string&& fmt, Args&&... args) { + try { + return func(); + } catch (const ov::Exception& e) { + ORT_THROW(log_tag + std::vformat(fmt.get(), std::make_format_args(args...)) + ": " + std::string(e.what())); + } catch (...) { + ORT_THROW(log_tag + std::vformat(fmt.get(), std::make_format_args(args...))); + } +} #ifndef NDEBUG void printDebugInfo(const ov::CompiledModel& obj) { @@ -38,274 +49,460 @@ void printDebugInfo(const ov::CompiledModel& obj) { std::cout << " " << item2.first << ": " << item2.second.as() << std::endl; } } - } else { - std::cout << " " << cfg << ": " << prop.as() << std::endl; + else { + std::cout << " " << cfg << ": " << prop.as() << std::endl; + } } } } -} #endif -// Function to check if a given OV property is enabled -std::optional queryOVProperty(const std::string& property, const std::string& device_type) { - try { - // Get the property value - auto supported_properties = OVCore::Get()->core.get_property(device_type, ov::supported_properties); - return std::find(supported_properties.begin(), supported_properties.end(), property) != supported_properties.end(); - } catch (const std::exception&) { - return std::nullopt; // Property not found or invalid + // Function to check if a given OV property is enabled + std::optional queryOVProperty(const std::string& property, const std::string& device_type) { + try { + // Get the property value + auto supported_properties = OVCore::Get()->core.get_property(device_type, ov::supported_properties); + return std::find(supported_properties.begin(), supported_properties.end(), property) != supported_properties.end(); + } catch (const std::exception&) { + return std::nullopt; // Property not found or invalid + } } -} -std::shared_ptr OVCore::ReadModel(std::string&& model, const std::string& model_path) { - try { - std::istringstream modelStringStream(std::move(model)); - std::istream& modelStream = modelStringStream; - // Try to load with FrontEndManager - ov::frontend::FrontEndManager manager; - ov::frontend::FrontEnd::Ptr FE; - ov::frontend::InputModel::Ptr inputModel; - - ov::AnyVector params{&modelStream, model_path}; - - FE = manager.load_by_model(params); - if (FE) { - inputModel = FE->load(params); - return FE->convert(inputModel); + std::shared_ptr OVCore::ReadModel(std::string && model, const std::string& model_path) { + return OvExceptionBoundary([&]() { + std::istringstream modelStringStream(std::move(model)); + std::istream& modelStream = modelStringStream; + // Try to load with FrontEndManager + ov::frontend::FrontEndManager manager; + ov::frontend::FrontEnd::Ptr FE; + ov::frontend::InputModel::Ptr inputModel; + + ov::AnyVector params{&modelStream, model_path}; + + FE = manager.load_by_model(params); + if (FE) { + inputModel = FE->load(params); + return FE->convert(inputModel); + } else { + ORT_THROW(log_tag + "Unknown exception while Reading network"); + } + }, + "Exception while Reading network"); + } + + OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr & model, + std::string & hw_target, + const ov::AnyMap& device_config) { + ov::CompiledModel compiled_model; + ov::AnyMap config = device_config; + + if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) { + std::cout << "Stateless OV Model Statistic:" << std::endl; + LogBasicModelInfo(model); + } + + bool model_status = IsStateful(model); + LOGS_DEFAULT(INFO) << log_tag << "Model IsStateful() Status:\t" << (model_status ? "True" : "False"); + if (!model_status) { + LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl; + PatchStatefulDecoder(model); + } + + if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) { + std::cout << "Stateful OV Model Statistic:" << std::endl; + LogBasicModelInfo(model); + } + + auto kv_pos = GetKVAxesPos(model); + + if (hw_target.find("NPU") != std::string::npos) { + KVDesc kv_desc; + auto parse_genai_config = [&](const std::string& key, unsigned int default_value) { + return (config.count(key) && !config.at(key).empty() && config.at(key).as() != "0") ? config.at(key).as() : default_value; + }; + + kv_desc.max_prompt_len = parse_genai_config("MAX_PROMPT_LEN", CausalLMConfig().max_prompt_len); + kv_desc.min_response_len = parse_genai_config("MIN_RESPONSE_LEN", CausalLMConfig().min_response_len); + + // For compilation, MAX_PROMPT_LEN & MIN_RESPONSE_LEN should not be 0 + if (kv_desc.max_prompt_len == 0 || kv_desc.min_response_len == 0) { + ORT_THROW(log_tag + "MAX_PROMPT_LEN and MIN_RESPONSE_LEN cannot be 0 or empty"); + } + + if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) { + std::cout << "kv_pos.batch = " << kv_pos.batch << std::endl; + std::cout << "kv_pos.seq_len = " << kv_pos.seq_len << std::endl; + std::cout << "kv_desc.max_prompt_len:\t" << kv_desc.max_prompt_len << std::endl; + std::cout << "kv_desc.min_response_len:\t" << kv_desc.min_response_len << std::endl; + } + + UpdateNPUConfig(config, kv_pos, kv_desc); } else { - ORT_THROW(log_tag + "[OpenVINO-EP] Unknown exception while Reading network"); + // This patches the OV IR model so that it only produces the logits required for sampling. + // Actually either way that happens within NPUW::LLMCompiledModel creation for NPU device, + // while this is here mostly to align this behavior for other devices viz. (CPU, GPU). + ApplySliceBeforeMatmulTransformation(model); } - } catch (const Exception& e) { - ORT_THROW(log_tag + "[OpenVINO-EP] Exception while Reading network: " + std::string(e.what())); - } catch (...) { - ORT_THROW(log_tag + "[OpenVINO-EP] Unknown exception while Reading network"); - } -} -OVExeNetwork OVCore::CompileModel(std::shared_ptr& ie_cnn_network, - std::string& hw_target, - ov::AnyMap& device_config, - const std::string& name) { - ov::CompiledModel obj; - try { - obj = core.compile_model(ie_cnn_network, hw_target, device_config); -#ifndef NDEBUG - printDebugInfo(obj); -#endif - OVExeNetwork exe(obj); + LOGS_DEFAULT(INFO) << log_tag << "Compiling OV Model using Stateful Transformation flow"; + compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config); + OVExeNetwork exe(compiled_model, hw_target, true); return exe; - } catch (const Exception& e) { - ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Exception while Loading Network for graph " + name); } -} -OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, - std::string& hw_target, - ov::AnyMap& device_config, - const std::string& name) { - ov::CompiledModel obj; - try { - obj = core.compile_model(onnx_model, ov::Tensor(), hw_target, device_config); + OVExeNetwork OVCore::CompileModel(std::shared_ptr & ie_cnn_network, + std::string & hw_target, + ov::AnyMap & device_config, + bool enable_causallm, + const std::string& name) { + return OvExceptionBoundary([&]() { + OVExeNetwork exe; + if (enable_causallm) { + auto mutable_model = ie_cnn_network->clone(); + exe = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config); + } else { + auto obj = core.compile_model(ie_cnn_network, hw_target, device_config); + exe = OVExeNetwork(obj, hw_target); + } + #ifndef NDEBUG - printDebugInfo(obj); + printDebugInfo(exe.Get()); #endif - OVExeNetwork exe(obj); - return exe; - } catch (const Exception& e) { - ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Exception while Loading Network for graph " + name); + + return exe; + }, + "Exception while Loading Network for graph {}", name); } -} -OVExeNetwork OVCore::ImportModel(std::istream& model_stream, - std::string hw_target, - const ov::AnyMap& device_config, - std::string name) { - try { - ov::CompiledModel obj; - obj = core.import_model(model_stream, hw_target, device_config); + OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, + std::string& hw_target, + ov::AnyMap& device_config, + const std::string& name) { + return OvExceptionBoundary([&]() { + ov::CompiledModel obj; + + obj = core.compile_model(onnx_model, ov::Tensor(), hw_target, device_config); #ifndef NDEBUG - printDebugInfo(obj); + printDebugInfo(obj); #endif - OVExeNetwork exe(obj); - return exe; - } catch (const Exception& e) { - ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Exception while Loading Network for graph " + name); + OVExeNetwork exe(obj, hw_target); + return exe; + }, + "Exception while Loading Network for graph {}", name); } -} - -void OVCore::SetCache(const std::string& cache_dir_path) { - core.set_property(ov::cache_dir(cache_dir_path)); -} -#ifdef IO_BUFFER_ENABLED -OVExeNetwork OVCore::CompileModel(std::shared_ptr& model, - OVRemoteContextPtr context, std::string name) { - try { - auto obj = core.compile_model(model, *context); + OVExeNetwork OVCore::ImportModel(std::istream & model_stream, + std::string hw_target, + const ov::AnyMap& device_config, + std::string name) { + return OvExceptionBoundary([&]() { + ov::CompiledModel obj; + obj = core.import_model(model_stream, hw_target, device_config); + OVExeNetwork exe(obj, hw_target); #ifndef NDEBUG - printDebugInfo(obj); + printDebugInfo(exe.Get()); #endif - return OVExeNetwork(obj); - } catch (const Exception& e) { - ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Exception while Loading Network for graph " + name); + return exe; + }, + "Exception while Loading Network for graph {}", name); } -} -OVExeNetwork OVCore::ImportModel(std::shared_ptr model_stream, - OVRemoteContextPtr context, std::string name) { - try { - auto obj = core.import_model(*model_stream, *context); + + OVExeNetwork OVCore::ImportEPCtxOVIREncapsulation(std::istream & model_stream, + std::string & hw_target, + const ov::AnyMap& device_config, + bool enable_causallm, + std::filesystem::path model_file_path) { + return OvExceptionBoundary([&]() { + OVExeNetwork exe; + + bool isXML = backend_utils::IsModelStreamXML(model_stream); + + // Helper function to check if file exists and is readable + const auto check_file_access = [&model_file_path](const std::filesystem::path& path) { + try { + if (!std::filesystem::exists(path) || std::filesystem::is_empty(path)) { + ORT_THROW(log_tag + "Required file missing or empty: " + path.string()); + } + std::ifstream file(path); + if (!file) { + ORT_THROW(log_tag + "Required file not readable: " + path.string()); + } + } catch (const std::exception& e) { + ORT_THROW(log_tag + "Exception while checking file access for: " + path.string() + " - " + e.what()); + } + }; + + if (isXML) { + // If the model is XML, we need to load it with the XML content in read_model() + // where weights from bin file is directly consumed + auto xml_file_path = model_file_path.parent_path() / (model_file_path.stem().string() + ".xml"); + + check_file_access(xml_file_path); + + LOGS_DEFAULT(INFO) << log_tag << "Reading OVIR from XML file path: " << xml_file_path.string(); + + // Load the model explicitly with XML contents + std::shared_ptr model = core.read_model(xml_file_path.string()); + + if (enable_causallm) { + exe = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config); + } else { + auto obj = core.compile_model(model, hw_target, device_config); + exe = OVExeNetwork(obj, hw_target); + } + } + #ifndef NDEBUG - printDebugInfo(obj); + printDebugInfo(exe.Get()); #endif - OVExeNetwork exe(obj); - return exe; - } catch (const Exception& e) { - ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Exception while Loading Network for graph " + name); + return exe; + }, + "Exception while Loading Network from OVIR model file: {}", model_file_path.string()); } -} -#endif -std::vector OVCore::GetAvailableDevices() const { - std::vector available_devices = core.get_available_devices(); - return available_devices; -} + void OVCore::SetCache(const std::string& cache_dir_path) { + core.set_property(ov::cache_dir(cache_dir_path)); + } -std::vector OVCore::GetAvailableDevices(const std::string& device_type) const { - std::vector available_devices; - std::vector devicesIDs; - // Uses logic from OpenVINO to only return available devices of the specified type (e.g. CPU, NPU or GPU) - try { - devicesIDs = core.get_property(device_type, ov::available_devices); - } catch (const ov::Exception&) { - // plugin is not created by e.g. invalid env - // Empty device list will be returned - } catch (const std::runtime_error& ex) { - // plugin is not created by e.g. invalid env - // Empty device list will be returned - ORT_THROW("[ERROR] [OpenVINO] An exception occurred while trying to create the ", - device_type, - " device: ", - ex.what()); - } catch (const std::exception& ex) { - ORT_THROW("[ERROR] [OpenVINO] An exception occurred while trying to create the ", - device_type, - " device: ", - ex.what()); - } catch (...) { - ORT_THROW("[ERROR] [OpenVINO] Unknown exception occurred while trying to create the ", - device_type, - " device"); + std::vector OVCore::GetAvailableDevices() const { + std::vector available_devices = core.get_available_devices(); + return available_devices; } - if (devicesIDs.size() > 1 || - (devicesIDs.size() == 1 && devicesIDs[0] == "0")) { - for (const auto& deviceID : devicesIDs) { - available_devices.push_back(device_type + '.' + deviceID); + std::vector OVCore::GetAvailableDevices(const std::string& device_type) const { + std::vector available_devices; + std::vector devicesIDs; + // Uses logic from OpenVINO to only return available devices of the specified type (e.g. CPU, NPU or GPU) + try { + devicesIDs = core.get_property(device_type, ov::available_devices); + } catch (const ov::Exception&) { + // plugin is not created by e.g. invalid env + // Empty device list will be returned + } catch (const std::exception& ex) { + ORT_THROW(log_tag + "An exception occurred while trying to create the ", + device_type, + " device: ", + ex.what()); + } catch (...) { + ORT_THROW(log_tag + "Unknown exception occurred while trying to create the ", + device_type, + " device"); } + + if (devicesIDs.size() > 1 || + (devicesIDs.size() == 1 && devicesIDs[0] == "0")) { + for (const auto& deviceID : devicesIDs) { + available_devices.push_back(device_type + '.' + deviceID); + } + } + if (!devicesIDs.empty()) { + available_devices.push_back(device_type); + } + + return available_devices; } - if (!devicesIDs.empty()) { - available_devices.push_back(device_type); + + void OVCore::SetStreams(const std::string& device_type, int num_streams) { + core.set_property(device_type, {ov::num_streams(num_streams)}); } - return available_devices; -} + std::shared_ptr OVExeNetwork::CreateInferRequest() { + return OvExceptionBoundary([&]() { + auto infReq = compiled_model_obj.create_infer_request(); + std::shared_ptr ovInfReq; + if (is_stateful_causallm) { + ovInfReq = std::make_shared(std::move(infReq), target_device); + } else { + ovInfReq = std::make_shared(std::move(infReq)); + } + return ovInfReq; + }, -void OVCore::SetStreams(const std::string& device_type, int num_streams) { - core.set_property(device_type, {ov::num_streams(num_streams)}); -} + "Exception while creating InferRequest object"); + } -OVInferRequest OVExeNetwork::CreateInferRequest() { - try { - auto infReq = obj.create_infer_request(); - OVInferRequest inf_obj(std::move(infReq)); - return inf_obj; - } catch (const Exception& e) { - ORT_THROW(log_tag + "Exception while creating InferRequest object: " + e.what()); - } catch (...) { - ORT_THROW(log_tag + "Exception while creating InferRequest object."); + OVTensorPtr OVInferRequest::GetTensor(const std::string& input_name) { + return OvExceptionBoundary([&]() { + auto tobj = ovInfReq.get_tensor(input_name); + OVTensorPtr blob = std::make_shared(tobj); + return blob; + }, + " Cannot access IE Blob for input: {}", input_name); } -} -OVTensorPtr OVInferRequest::GetTensor(const std::string& input_name) { - try { - auto tobj = ovInfReq.get_tensor(input_name); - OVTensorPtr blob = std::make_shared(tobj); - return blob; - } catch (const Exception& e) { - ORT_THROW(log_tag + " Cannot access IE Blob for input: " + input_name + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Cannot access IE Blob for input: " + input_name); + std::string OVInferRequest::GetInputTensorName(uint32_t index) { + return OvExceptionBoundary([&]() { + const auto& model = ovInfReq.get_compiled_model(); + return *model.input(index).get_names().begin(); + }, + " Cannot access IE Blob for input number: {}", index); } -} -std::string OVInferRequest::GetInputTensorName(uint32_t index) { - try { - const auto& model = ovInfReq.get_compiled_model(); - return *model.input(index).get_names().begin(); - } catch (const Exception& e) { - ORT_THROW(log_tag + " Cannot access IE Blob for input number: ", index, e.what()); - } catch (...) { - ORT_THROW(log_tag + " Cannot access IE Blob for input number: ", index); + void OVInferRequest::SetTensor(const std::string& name, OVTensorPtr& blob) { + OvExceptionBoundary([&]() { + ovInfReq.set_tensor(name, *(blob.get())); + }, + " Cannot set Remote Blob for output: {}", name); } -} -void OVInferRequest::SetTensor(const std::string& name, OVTensorPtr& blob) { - try { - ovInfReq.set_tensor(name, *(blob.get())); - } catch (const Exception& e) { - ORT_THROW(log_tag + " Cannot set Remote Blob for output: " + name + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Cannot set Remote Blob for output: " + name); + uint32_t OVInferRequest::GetNumInputs() { + return static_cast(ovInfReq.get_compiled_model().inputs().size()); } -} -uint32_t OVInferRequest::GetNumInputs() { - return static_cast(ovInfReq.get_compiled_model().inputs().size()); -} + void OVInferRequest::Infer() { + OvExceptionBoundary([&]() { + ovInfReq.infer(); + }, + "In Error Couldn't start Inference"); + } -void OVInferRequest::StartAsync() { - try { - ovInfReq.start_async(); - } catch (const Exception& e) { - ORT_THROW(log_tag + " Couldn't start Inference: " + e.what()); - } catch (...) { - ORT_THROW(log_tag + " In Error Couldn't start Inference"); + StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device) + : OVInferRequest(std::move(infer_request)), target_device(device) { + bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos)); + if (gpu_or_npu) { + prefill_use_full_chat_history = true; + } } -} -void OVInferRequest::Infer() { - try { - ovInfReq.infer(); - } catch (const Exception& e) { - ORT_THROW(log_tag + " Couldn't start Inference: " + e.what()); - } catch (...) { - ORT_THROW(log_tag + " In Error Couldn't start Inference"); + void StatefulOVInferRequest::FillTensor(const std::string& tensor_name, const ov::element::Type& type, + const std::vector& shape, int32_t fill_value) { + ov::Tensor tensor = ov::Tensor(type, shape); + std::fill_n(tensor.data(), tensor.get_size(), fill_value); + ovInfReq.set_tensor(tensor_name, tensor); } -} -void OVInferRequest::WaitRequest() { - try { - ovInfReq.wait(); - } catch (const Exception& e) { - ORT_THROW(log_tag + " Wait Model Failed: " + e.what()); - } catch (...) { - ORT_THROW(log_tag + " Wait Mode Failed"); + void StatefulOVInferRequest::CacheTensor(const std::string& tensor_name, std::vector& cache) { + auto tensor = ovInfReq.get_tensor(tensor_name); + auto* pData = tensor.data(); + for (size_t i = 0; i < tensor.get_size(); i++) { + cache.emplace_back(pData[i]); + } } -} -void OVInferRequest::QueryStatus() { - std::cout << "ovInfReq.query_state()" - << " "; -} + void StatefulOVInferRequest::SetTensorFromCache(const std::string& tensor_name, + const std::vector& cache_data) { + auto tensor = ovInfReq.get_tensor(tensor_name); + auto new_shape = tensor.get_shape(); + new_shape[1] = cache_data.size(); + + auto new_tensor = ov::Tensor(tensor.get_element_type(), new_shape); + auto* pNewData = new_tensor.data(); + std::memcpy(pNewData, cache_data.data(), cache_data.size() * sizeof(int64_t)); + + ovInfReq.set_tensor(tensor_name, new_tensor); + } + + std::optional StatefulOVInferRequest::FindTensor(const std::string& tensor_name) { + // Check if tensor exists by examining input names in the compiled model + const auto& model = ovInfReq.get_compiled_model(); + bool tensor_exists = false; + + for (const auto& input : model.inputs()) { + const auto& names = input.get_names(); + if (names.find(tensor_name) != names.end()) { + tensor_exists = true; + break; + } + } + + if (tensor_exists) { + return ovInfReq.get_tensor(tensor_name); + } + + return std::nullopt; + } + + void StatefulOVInferRequest::PreProcessInferRequest() { + // Workaround: Setting the value here as it cannot be set at the ORT GenAI layer currently. + // TODO(ankit): Address this issue and implement the fix at the appropriate layer. + FillTensor("beam_idx", ov::element::i32, {1}, 0); + + // If 'prefill use full chat history' mode is enabled, we need to cache input_ids and position_ids. + if (prefill_use_full_chat_history) { + auto input_ids_tensor = ovInfReq.get_tensor("input_ids"); + CacheTensor("input_ids", cached_input_ids); + + // "position_ids" (GQA with Rotary Embeddings doesnt have position_ids) - check if exists + auto position_ids_opt = FindTensor("position_ids"); + bool has_position_ids = position_ids_opt.has_value(); + + if (has_position_ids) { + CacheTensor("position_ids", cached_position_ids); + } + + // If we're about to run the prefill model + if (input_ids_tensor.get_size() > 1) { + // Check if the size of the current "input_ids" tensor does not match the size of the cached "input_ids". + // This indicates that we are running a subsequent prompt (not the initial prefill). + if (input_ids_tensor.get_shape()[1] != cached_input_ids.size()) { + // Clear the internal KVCache state. For NPU device, this operation is a no-op. + ovInfReq.reset_state(); + + // Set tensors using cached values + SetTensorFromCache("input_ids", cached_input_ids); + + // Only set position_ids if it exists and we have cached values + if (has_position_ids && !cached_position_ids.empty()) { + SetTensorFromCache("position_ids", cached_position_ids); + } + } + } + } + } + + void StatefulOVInferRequest::Infer() { + PreProcessInferRequest(); + OVInferRequest::Infer(); + } + + void StatefulOVInferRequest::RewindKVCache(size_t index) { + LOGS_DEFAULT(INFO) << log_tag << "RewindKVCache: Rewinding OpenVINO-internal KVCache state to index=" << index; + + if (prefill_use_full_chat_history) { + // Clear the internal KVCache state. For NPU device, this operation is a no-op. + ovInfReq.reset_state(); + + // Resize the cached "input_ids" and "position_ids" to the specified index. + if (cached_input_ids.size() > index) { + cached_input_ids.resize(index); + } + + if (cached_position_ids.size() > index) { + cached_position_ids.resize(index); + } + } else { + if (index == 0) { + // In this case, since we're resetting the entire KVCache, simply reset the state. + ovInfReq.reset_state(); + } else { + // Retrieve KVCache states and trim them to the specified index. + // The following logic is adapted from: + // https://github.com/openvinotoolkit/openvino.genai/blob/releases/2025/1/src/cpp/src/utils.cpp#L329 + auto states = ovInfReq.query_state(); + for (auto& state : states) { + ov::Tensor old_tensor = state.get_state(); + // Tensor shape: [batch_size, num_kv_heads, seq_len, head_size] + auto shape = old_tensor.get_shape(); + + if (shape[2] > index) { + // Update the sequence length dimension to the specified index. + shape[2] = index; + + ov::Coordinate new_shape_begin{0, 0, 0, 0}; + ov::Coordinate new_shape_end{shape}; + + // Create a trimmed tensor with the updated shape. + auto trimmed_tensor = ov::Tensor(old_tensor, new_shape_begin, new_shape_end); + + // Copy the trimmed tensor into a new tensor and update the state. + ov::Tensor new_tensor(old_tensor.get_element_type(), shape); + trimmed_tensor.copy_to(new_tensor); + + state.set_state(new_tensor); + } + } + } + } + } } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index bebe73bd702dd..fb1757199698b 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -9,15 +9,15 @@ #include #include #include +#include +#include #include "openvino/openvino.hpp" #include "openvino/runtime/intel_npu/properties.hpp" #include "openvino/pass/convert_fp32_to_fp16.hpp" #include "openvino/frontend/manager.hpp" - -#ifdef IO_BUFFER_ENABLED -#include -#endif +#include "openvino/core/dimension.hpp" +#include "openvino/core/partial_shape.hpp" #include @@ -33,11 +33,6 @@ typedef ov::Model OVNetwork; typedef std::shared_ptr OVInferRequestPtr; typedef std::shared_ptr OVTensorPtr; -#ifdef IO_BUFFER_ENABLED -typedef ov::intel_gpu::ocl::ClContext* OVRemoteContextPtr; -typedef ov::RemoteContext OVRemoteContext; -#endif - std::optional queryOVProperty(const std::string& property, const std::string& device_type); template @@ -72,10 +67,14 @@ struct OVCore : WeakSingleton { // OV Interface For Reading Model std::shared_ptr ReadModel(std::string&& model_stream, const std::string& model_path); + OVExeNetwork StatefulCompileModel(std::shared_ptr& model, + std::string& hw_target, + const ov::AnyMap& device_config); // OV Interface for Compiling OV Model Type OVExeNetwork CompileModel(std::shared_ptr& ie_cnn_network, std::string& hw_target, ov::AnyMap& device_config, + bool enable_causallm, const std::string& name); // OV Interface for Fast Compile OVExeNetwork CompileModel(const std::string& onnx_model, @@ -87,14 +86,12 @@ struct OVCore : WeakSingleton { std::string hw_target, const ov::AnyMap& device_config, std::string name); -#ifdef IO_BUFFER_ENABLED - OVExeNetwork CompileModel(std::shared_ptr& model, - OVRemoteContextPtr context, - std::string name); - OVExeNetwork ImportModel(std::shared_ptr model_stream, - OVRemoteContextPtr context, - std::string name); -#endif + OVExeNetwork ImportEPCtxOVIREncapsulation(std::istream& model_stream, + std::string& hw_target, + const ov::AnyMap& device_config, + bool enable_causallm, + std::filesystem::path model_file_path); + std::vector GetAvailableDevices() const; std::vector GetAvailableDevices(const std::string& device_type) const; void SetCache(const std::string& cache_dir_path); @@ -102,32 +99,75 @@ struct OVCore : WeakSingleton { }; class OVExeNetwork { - ov::CompiledModel obj; + ov::CompiledModel compiled_model_obj; + std::string target_device; + bool is_stateful_causallm; public: - explicit OVExeNetwork(ov::CompiledModel md) : obj(md) {} - OVExeNetwork() : obj(ov::CompiledModel()) {} - ov::CompiledModel& Get() { return obj; } - OVInferRequest CreateInferRequest(); + explicit OVExeNetwork(ov::CompiledModel compiled_model, std::string device, bool stateful_causallm = false) + : compiled_model_obj(compiled_model), target_device(device), is_stateful_causallm(stateful_causallm) {} + OVExeNetwork() : compiled_model_obj(ov::CompiledModel()) {} + ov::CompiledModel& Get() { return compiled_model_obj; } + std::shared_ptr CreateInferRequest(); }; class OVInferRequest { + struct ov_tensor_data_t { + OVTensorPtr tensor_ptr; + const void* ort_ptr; + }; + + protected: ov::InferRequest ovInfReq; + std::unordered_map bindings_cache_; public: uint32_t GetNumInputs(); OVTensorPtr GetTensor(const std::string& name); std::string GetInputTensorName(uint32_t index); + + // Set tensor described param_info and ort_ptr. Overrides shape in param_info with shape_override. Call infer req tensor if ort_ptr is last set. + void SetTensor(const std::string& name, const ov::element::Type& type, const ov::Shape& shape, void* ort_ptr) { + auto& cached_binding = bindings_cache_[name]; + if (cached_binding.ort_ptr != ort_ptr) { + auto tensor_ptr = std::make_shared(type, shape, const_cast(ort_ptr)); + SetTensor(name, tensor_ptr); + cached_binding = {tensor_ptr, ort_ptr}; + } + } + void SetTensor(const std::string& name, OVTensorPtr& blob); - void StartAsync(); - void Infer(); - void WaitRequest(); - void QueryStatus(); + virtual void Infer(); explicit OVInferRequest(ov::InferRequest obj) : ovInfReq(std::move(obj)) {} OVInferRequest() : ovInfReq(ov::InferRequest()) {} ov::InferRequest& GetNewObj() { return ovInfReq; } + virtual void RewindKVCache(size_t index) {} }; + +class StatefulOVInferRequest : public OVInferRequest { + public: + explicit StatefulOVInferRequest(ov::InferRequest infer_request, std::string device); + + void Infer() override; + void RewindKVCache(size_t index) override; + void FillTensor(const std::string& tensor_name, const ov::element::Type& type, + const std::vector& shape, int32_t fill_value); + void CacheTensor(const std::string& tensor_name, std::vector& cache); + void SetTensorFromCache(const std::string& tensor_name, const std::vector& cache_data); + std::optional FindTensor(const std::string& tensor_name); + + private: + void PreProcessInferRequest(); + std::string target_device; + + // If prefill_use_full_chat_history is true, cache the "input_ids" & "position_ids" tensors, + // and ensure that full chat history is passed for each prefill call. + bool prefill_use_full_chat_history = false; + std::vector cached_input_ids; + std::vector cached_position_ids; +}; + } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc new file mode 100644 index 0000000000000..67ba42884e4f0 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -0,0 +1,350 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include "core/providers/openvino/ov_stateful_patch_utils.h" + +namespace onnxruntime { +namespace openvino_ep { + +void LogBasicModelInfo(const std::shared_ptr& model) { + std::cout << "Model Name: " << model->get_friendly_name() << std::endl; + + // Log detailed information about model inputs and outputs + auto inputs = model->inputs(); + auto outputs = model->outputs(); + + std::cout << "\tInputs: " << std::endl; + for (const ov::Output& input : inputs) { + const std::string name = input.get_any_name(); + const ov::element::Type type = input.get_element_type(); + const ov::PartialShape shape = input.get_partial_shape(); + const ov::Layout layout = ov::layout::get_layout(input); + + std::cout << "\t\t" << name << ", " << type << ", " << shape << ", " << layout.to_string() << std::endl; + } + + std::cout << "\tOutputs: " << std::endl; + for (const ov::Output& output : outputs) { + const std::string name = output.get_any_name(); + const ov::element::Type type = output.get_element_type(); + const ov::PartialShape shape = output.get_partial_shape(); + const ov::Layout layout = ov::layout::get_layout(output); + + std::cout << "\t\t" << name << ", " << type << ", " << shape << ", " << layout.to_string() << std::endl; + } + + return; +} + +bool ModelHasInputOutputNames(std::shared_ptr model, const std::string& name_to_match) { + for (const ov::Output& input : model->inputs()) { + auto& names = input.get_names(); + + for (auto& name : names) { + if (name == name_to_match) { + return true; + } + } + } + + for (const ov::Output& output : model->outputs()) { + auto& names = output.get_names(); + for (auto& name : names) { + if (name == name_to_match) { + return true; + } + } + } + + return false; +} + +void FuseCacheReorder(std::shared_ptr ov_model, + std::vector& not_kv_inputs, + const std::vector& key_value_input_names, + int gather_dim) { + if (ModelHasInputOutputNames(ov_model, "beam_idx")) { + throw std::runtime_error("Model already has fused cache"); + } + + std::string main_input_name = "inputs_embeds"; + if (ModelHasInputOutputNames(ov_model, "input_ids")) { + main_input_name = "input_ids"; + } + + auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0]; + + auto beam_idx = std::make_shared(ov::element::i32, ov::PartialShape({input_batch})); + beam_idx->set_friendly_name("beam_idx"); + beam_idx->output(0).get_tensor().add_names({"beam_idx"}); + ov_model->add_parameters({beam_idx}); + not_kv_inputs.push_back(beam_idx->get_friendly_name()); + + // Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx + for (const auto& input_name : key_value_input_names) { + auto parameter_output_port = ov_model->input(input_name); + auto consumers = parameter_output_port.get_target_inputs(); + + auto gather_op = + std::make_shared(parameter_output_port, + beam_idx, + ov::opset13::Constant::create(ov::element::i64, {}, {gather_dim})); + + // Replace the source output for all consumers of the input tensor + for (auto& consumer : consumers) { + consumer.replace_source_output(gather_op->output(0)); + } + } + + // Validate the modified model + ov_model->validate_nodes_and_infer_types(); +} + +void MakeStateful(std::shared_ptr& ov_model, + const std::vector& key_value_input_names, + const std::vector& key_value_output_names) { + std::map input_output_map; + + // Create mapping for KV-cache inputs and outputs + for (size_t i = 0; i < key_value_input_names.size(); ++i) { + input_output_map[key_value_input_names[i]] = key_value_output_names[i]; + } + + // Apply the transformation to make the model stateful + ov::pass::Manager manager; + manager.register_pass(input_output_map); + manager.run_passes(ov_model); +} + +// Converted to C++ from below reference URL: +// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281 +void PatchStatefulDecoder(std::shared_ptr model) { + std::vector key_value_input_names; + std::vector not_kv_inputs; + for (const ov::Output& input : model->inputs()) { + auto& names = input.get_names(); + + bool found = false; + for (auto& name : names) { + if (name.find("key_values") != std::string::npos) { + key_value_input_names.push_back(name); + found = true; + break; + } + } + + if (!found) { + not_kv_inputs.push_back(input.get_any_name()); + } + } + + std::vector key_value_output_names; + for (const ov::Output& output : model->outputs()) { + auto& names = output.get_names(); + for (auto& name : names) { + if (name.find("present") != std::string::npos) { + key_value_output_names.push_back(name); + break; + } + } + } + + if (key_value_input_names.empty() || key_value_output_names.empty()) { + std::cout << "no key_value_input_names or key_value_output_names found" << std::endl; + return; + } + + // By default, batch is the 0 - th but chatglm uses 1 - st dimension as batch + // TODO(ryan): Deduce from a model via ordinal reshape(? ) and topology + // batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0 + auto batch_dim = 0; + + FuseCacheReorder(model, not_kv_inputs, key_value_input_names, batch_dim); + + MakeStateful(model, key_value_input_names, key_value_output_names); +} + +// Some other utility functions copied from OpenVINO GenAI +bool HasOpWithType(const std::shared_ptr& function, const std::string& type_name) { + for (const auto& op : function->get_ops()) { + if (op->get_type_name() == type_name) { + return true; + } + } + return false; +} + +std::tuple, int64_t> FindLLMMatmul(const std::shared_ptr& model) { + auto last_node = model->output(0).get_node()->input_value(0).get_node_shared_ptr(); + std::shared_ptr matmul = ov::as_type_ptr(last_node); + + // In the case of PagedAttention, all tokens are moved to the batch dimension, + // and slicing/gathering must be performed accordingly. + const bool pa_based_model = HasOpWithType(model, "PagedAttentionExtension"); + int64_t slice_gather_dim = pa_based_model ? 0 : 1; + + // There are several patterns for MatMul we are looking for: + // MatMul -> Result + // MatMul -> Add -> Result + // MatMul -> Transpose -> Result + // MatMul -> Divide -> Tanh -> Multiply -> Result + // MatMul -> Convert -> Result + if (!matmul) { + if (auto add = ov::as_type_ptr(last_node)) { + matmul = ov::as_type_ptr(add->input_value(0).get_node_shared_ptr()); + } else if (auto transpose = ov::as_type_ptr(last_node)) { + matmul = ov::as_type_ptr(transpose->input_value(0).get_node_shared_ptr()); + auto order = ov::as_type_ptr(transpose->input_value(1).get_node_shared_ptr())->get_axis_vector_val(); + slice_gather_dim = order[slice_gather_dim]; + } else if (auto multiply = ov::as_type_ptr(last_node)) { + if (auto tanh = ov::as_type_ptr(multiply->input_value(0).get_node_shared_ptr())) { + if (auto divide = ov::as_type_ptr(tanh->input_value(0).get_node_shared_ptr())) { + matmul = ov::as_type_ptr(divide->input_value(0).get_node_shared_ptr()); + } + } + } else if (auto convert = ov::as_type_ptr(last_node)) { + matmul = ov::as_type_ptr(convert->input_value(0).get_node_shared_ptr()); + } + } + return std::make_tuple(matmul, slice_gather_dim); +} + +void ApplySliceBeforeMatmulTransformation(std::shared_ptr model) { + std::shared_ptr matmul = nullptr; + int64_t slice_gather_dim = -1; + std::tie(matmul, slice_gather_dim) = FindLLMMatmul(model); + + if (matmul && matmul->input(0).get_partial_shape().rank().get_length() == 3) { + auto start = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{-1}); + auto stop = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{-2}); + auto step = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{-1}); + auto axis = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{slice_gather_dim}); + auto slice = std::make_shared(matmul->input_value(0), start, stop, step, axis); + matmul->input(0).replace_source_output(slice); + } +} + +void UpdateConfig(ov::AnyMap& config, const std::pair& pair) { + if (config.count(pair.first) == 0) { + config.insert(pair); + } +} + +std::optional PopOption(ov::AnyMap& config, const std::string& option_name) { + if (auto it = config.find(option_name); it != config.end()) { + std::optional found = std::make_optional(it->second); + config.erase(it); + return found; + } + return std::nullopt; +} + +void RenameKey(ov::AnyMap& config, const std::string& old_key, const std::string& new_key) { + if (config.count(old_key) != 0) { + auto opt_value = PopOption(config, old_key); + config[new_key] = opt_value.value(); + } +} + +KVAxesPosition GetKVAxesPos(std::shared_ptr model) { + // Sequence length axis in key/values tensors. For most cases, the tensor shape is + // [batch_size, num_kv_heads, seq_len, head_size]. Therefore, the sequence length axis + // is usually at index 2, and the batch axis is at index 0. + KVAxesPosition kv_pos{0u, 2u}; + + // "ReadValue" node is KV cache representation in stateful model + std::string kv_node_type_name = std::string(ov::op::v6::ReadValue::get_type_info_static().name); + + for (const auto& op : model->get_ops()) { + // Check input size, as in LoRA adapters case it could be 0 + if (op->get_type_name() != kv_node_type_name || op->get_input_size() < 1) { + continue; + } + + // Shape example: [-1,4,0,64] + auto shape = op->get_input_partial_shape(0); + + for (int64_t i = 0; i < shape.rank().get_length(); i++) { + // Find axis = 0. This would be sequence length axis. + if (shape[i] == 0) { + kv_pos.seq_len = i; + } else if (shape[i].is_dynamic()) { + // Dynamic axis is a batch + kv_pos.batch = i; + } + } + break; + } + + return kv_pos; +} + +void UpdateNPUConfig(ov::AnyMap& config, const KVAxesPosition& kv_pos, const KVDesc& kv_desc) { + UpdateConfig(config, {"NPU_USE_NPUW", "YES"}); + UpdateConfig(config, {"NPUW_LLM", "YES"}); + + UpdateConfig(config, {"NPUW_LLM_BATCH_DIM", kv_pos.batch}); + UpdateConfig(config, {"NPUW_LLM_SEQ_LEN_DIM", kv_pos.seq_len}); + + UpdateConfig(config, {"NPUW_LLM_MAX_PROMPT_LEN", kv_desc.max_prompt_len}); + UpdateConfig(config, {"NPUW_LLM_MIN_RESPONSE_LEN", kv_desc.min_response_len}); + + RenameKey(config, "++PREFILL_CONFIG", "++NPUW_LLM_PREFILL_CONFIG"); + RenameKey(config, "++GENERATE_CONFIG", "++NPUW_LLM_GENERATE_CONFIG"); + RenameKey(config, "PREFILL_CONFIG", "NPUW_LLM_PREFILL_CONFIG"); + RenameKey(config, "PREFILL_HINT", "NPUW_LLM_PREFILL_HINT"); + RenameKey(config, "GENERATE_CONFIG", "NPUW_LLM_GENERATE_CONFIG"); + RenameKey(config, "GENERATE_HINT", "NPUW_LLM_GENERATE_HINT"); + + const size_t npuw_context_len_threshold = 2048; + if ((kv_desc.max_prompt_len + kv_desc.min_response_len) >= npuw_context_len_threshold) { + // This improves accuracy for generation sequences that exceed 2k tokens. + config["++NPUW_LLM_PREFILL_CONFIG"] = ov::AnyMap{{"NPUW_DEVICES", "NPU,CPU"}, {"NPUW_ONLINE_AVOID", "P:SinCos/NPU"}}; + config["++NPUW_LLM_GENERATE_CONFIG"] = ov::AnyMap{{"NPUW_DEVICES", "NPU,CPU"}, {"NPUW_ONLINE_AVOID", "P:SinCos/NPU"}}; + } +} + +std::optional PopOptionNew(ov::AnyMap& config, const std::string& option_name) { + if (auto it = config.find(option_name); it != config.end()) { + std::optional found = std::make_optional(it->second); + config.erase(it); + return found; + } + return std::nullopt; +} + +std::optional PopIntAndCast(ov::AnyMap& config, const std::string& key) { + auto anyopt = PopOptionNew(config, key); + if (anyopt.has_value()) { + const auto any = anyopt.value(); + int64_t value; + // NB: Integer value coming from python has int64_t datatype + if (any.is()) { + value = any.as(); + } else if (any.is()) { + value = any.as(); + } else { + OPENVINO_THROW("Failed to extract " + key + ". Type mismatch: expected types: int or int64_t"); + } + if (value < 0) { + OPENVINO_THROW(key + " cannot be negative!"); + } + return std::make_optional(static_cast(value)); + } + return std::nullopt; +} + +bool IsStateful(const std::shared_ptr& model) { + for (auto&& ptr : model->get_ordered_ops()) { + if (ov::is_type(ptr) || + ov::is_type(ptr) || + ov::is_type(ptr) || + ov::is_type(ptr)) { + return true; + } + } + return false; +} + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h new file mode 100644 index 0000000000000..0b89c4ed02e13 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h @@ -0,0 +1,84 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "openvino/pass/manager.hpp" +#include "openvino/pass/make_stateful.hpp" +#include "openvino/opsets/opset13.hpp" + +namespace onnxruntime { +namespace openvino_ep { + +void LogBasicModelInfo(const std::shared_ptr& model); + +bool ModelHasInputOutputNames(std::shared_ptr model, const std::string& name_to_match); + +void FuseCacheReorder(std::shared_ptr ov_model, + std::vector& not_kv_inputs, + const std::vector& key_value_input_names, + int gather_dim); + +void MakeStateful(std::shared_ptr& ov_model, + const std::vector& key_value_input_names, + const std::vector& key_value_output_names); + +void PatchStatefulDecoder(std::shared_ptr model); + +bool HasOpWithType(const std::shared_ptr& function, const std::string& type_name); + +std::tuple, int64_t> FindLLMMatmul(const std::shared_ptr& model); + +void ApplySliceBeforeMatmulTransformation(std::shared_ptr model); + +void UpdateConfig(ov::AnyMap& config, const std::pair& pair); + +std::optional PopOption(ov::AnyMap& config, const std::string& option_name); + +void RenameKey(ov::AnyMap& config, const std::string& old_key, const std::string& new_key); + +struct KVAxesPosition { + size_t batch; + size_t seq_len; +}; + +KVAxesPosition GetKVAxesPos(std::shared_ptr model); + +struct KVDesc { + uint32_t max_prompt_len; + uint32_t min_response_len; +}; + +struct CausalLMConfig { + void ApplyConfig(const ov::AnyMap& external_config, ov::AnyMap& genai_config) { + if (external_config.find("MAX_PROMPT_LEN") != external_config.end()) { + max_prompt_len = external_config.at("MAX_PROMPT_LEN").as(); + } + if (external_config.find("MIN_RESPONSE_LEN") != external_config.end()) { + min_response_len = external_config.at("MIN_RESPONSE_LEN").as(); + } + genai_config["MAX_PROMPT_LEN"] = ov::Any(max_prompt_len); + genai_config["MIN_RESPONSE_LEN"] = ov::Any(min_response_len); + } + + unsigned int max_prompt_len = 1024; + unsigned int min_response_len = 128; +}; + +void UpdateNPUConfig(ov::AnyMap& config, const KVAxesPosition& kv_pos, const KVDesc& kv_desc); + +std::optional PopOptionNew(ov::AnyMap& config, const std::string& option_name); +std::optional PopIntAndCast(ov::AnyMap& config, const std::string& key); + +bool IsStateful(const std::shared_ptr& model); + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index d56687f868c3d..88ddde8610c6e 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -37,7 +37,10 @@ GetCapability::GetCapability(const EPCtxHandler& ep_ctx_handler, if (device_type_.find("NPU") != std::string::npos) { device_type_ = "CPU"; if (enable_qdq_optimizer) npu_qdq_optimizer_enabled = true; + } else if (enable_qdq_optimizer && device_type_.find("GPU") != std::string::npos) { + npu_qdq_optimizer_enabled = true; // see data_ops.cc ~615 where we check for int16 types for gpu, this may change to a better approach later } + #if OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 5 data_ops_ = new DataOps(graph_viewer_, V_2024_5, device_type_, npu_qdq_optimizer_enabled); #elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 6 @@ -126,9 +129,6 @@ std::vector> GetCapability::Execute() { } } - // Initializers need to be part of meta_def->inputs - Iterable2String(inputs, ng_required_initializers); - // Fill outputs with names Iterable2String(outputs, graph_viewer_.GetOutputs()); diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 4e1387d2ef4a9..84001c1161efc 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -158,6 +158,7 @@ std::vector supported_op_mode = { {"InstanceNormalization", V_2020_4, {"CPU", "GPU"}}, {"HardSigmoid", V_2020_4, {"CPU", "GPU"}}, {"HardMax", V_2022_1, {"CPU", "GPU"}}, + {"HardSwish", V_2025_0, {"CPU", "GPU"}}, {"LayerNormalization", V_2023_0, {"CPU", "GPU"}}, {"LeakyRelu", V_2020_4, {"CPU", "GPU"}}, {"Less", V_2020_4, {"CPU", "GPU"}}, @@ -419,7 +420,8 @@ void DataOps::populate_op_mode_supported() { UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1}, [this](const Node* node, const InitializedTensorSet&) { - const auto& input_arg = node->InputDefs()[1]; + const auto& input_args = node->InputDefs(); + const auto& input_arg = (input_args.size() > 1) ? input_args[1] : input_args[0]; auto shape = input_arg->Shape(); // Reshape op with empty dim is Rejected for Myriad // [TODO] Is this condition required anymore with Myriad removed? @@ -612,6 +614,9 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { (var.second == dtype)) { return true; } + // experimentally for GPU and qdq stripping mode allow int16 types + if (npu_qdq_optimizer_enabled_ && (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16 || dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16)) + return true; } #ifndef NDEBUG if (openvino_ep::backend_utils::IsDebugEnabled()) { diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 3fb16d30f4970..bdc4f65e590d9 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1181,63 +1181,13 @@ static std::shared_ptr CreateExecutionProviderFactory } else if (type == kOpenVINOExecutionProvider) { #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) ProviderOptions OV_provider_options_map; + const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision", + "load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer", + "enable_causallm", "disable_dynamic_shapes", "reshape_input"}; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { for (auto option : it->second) { - if (option.first == "device_type") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "precision") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "enable_opencl_throttling") { - if (!(option.second == "True" || option.second == "true" || - option.second == "False" || option.second == "false")) { - ORT_THROW("Invalid value passed for enable_opencl_throttling: ", option.second); - } - OV_provider_options_map[option.first] = option.second; - } else if (option.first == "disable_dynamic_shapes") { - if (!(option.second == "True" || option.second == "true" || - option.second == "False" || option.second == "false")) { - ORT_THROW("Invalid value passed for disable_dynamic_shapes: ", option.second); - } - OV_provider_options_map[option.first] = option.second; - } else if (option.first == "enable_dynamic_shapes") { - LOGS_DEFAULT(WARNING) << " Deprecation notice - 'enable_dynamic_shapes' is Deprected. Upgrade the API to disable_dynamic_shapes parameter." - "Please refer https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements to ensure all dependencies are met."; - std::string value; - if (!(option.second == "True" || option.second == "true" || - option.second == "False" || option.second == "false")) { - ORT_THROW("Invalid value passed for enable_dynamic_shapes: ", option.second); - } - if (option.second == "True" || option.second == "true") { - value = "false"; - } else { - value = "true"; - } - OV_provider_options_map["disable_dynamic_shapes"] = value; - } else if (option.first == "num_of_threads") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "model_priority") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "num_streams") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "load_config") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "cache_dir") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "context") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "enable_qdq_optimizer") { - OV_provider_options_map[option.first] = option.second; - continue; - } else if (option.first == "enable_causallm") { + if (valid_provider_keys.count(option.first)) { OV_provider_options_map[option.first] = option.second; continue; } else { diff --git a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py index 9de11041f5331..174527118ce8b 100644 --- a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py @@ -235,6 +235,7 @@ def __init__( op_types_to_quantize: tuple[str, ...] | None = None, quant_axes: tuple[tuple[str, int], ...] | None = None, bits: int = 4, + channel_wised_quantize: bool = False, ): """ This is a class for weight only affine quantization configuration. @@ -269,6 +270,9 @@ def __init__( self.is_symmetric = is_symmetric self.bits = bits self.accuracy_level = accuracy_level + self.channel_wised_quantize = channel_wised_quantize + if channel_wised_quantize and quant_format == QuantFormat.QOperator: + raise NotImplementedError("QuantFormat.QOperator is not supported channel_wised_quantize yet") class NVAWQWeightOnlyQuantConfig(WeightOnlyQuantConfig): @@ -767,6 +771,26 @@ def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, Gr return None, None +# transpose int4 matrix (packed as uint8) +def transpose_packed_int4_matrix(packed, rows, cols): + # unpack to int4 matrix + total = rows * cols + high = (packed >> 4) & 0x0F + low = packed & 0x0F + int4_vals = np.empty(total, dtype=np.uint8) + int4_vals[0::2] = low + int4_vals[1::2] = high + int4_matrix = int4_vals.reshape((rows, cols)) + + # transpose int4 matrix + int4_matrix_transposed = int4_matrix.T + + # pack to uint8 + flat = int4_matrix_transposed.reshape(-1) + packed = ((flat[1::2] << 4) & 0xF0) | (flat[0::2] & 0x0F) + return packed.astype(np.uint8) + + class DefaultWeightOnlyQuantizer: def __init__(self, config: DefaultWeightOnlyQuantConfig): self.config = config @@ -803,6 +827,10 @@ def qbits_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.n packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric ) else: + # block size equal to rows (K) if channel wised quantize enabled + block_size = rows if self.config.channel_wised_quantize else self.config.block_size + k_blocks = (rows + block_size - 1) // block_size + assert qbits == 4, "QDQ format only support 4 bits quantization" packed = np.zeros((rows * cols + 1) // 2, dtype="uint8") zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8") @@ -845,6 +873,21 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis ) scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales") + # if QDQ, CW and SYM enabled, optimize for Intel NPU, tranpose the weight to NHWC format will increase performance + qdq_opt_for_intel_npu_enabled = ( + self.config.quant_format == QuantFormat.QDQ + and self.config.channel_wised_quantize + and self.config.is_symmetric + ) + if qdq_opt_for_intel_npu_enabled: + rows, cols = b_ndarray.shape + packed = transpose_packed_int4_matrix(packed, rows, cols) + scales = scales.reshape((cols, 1)) # (cols, 1) + b_quant = onnx.helper.make_tensor( + b_tensor.name + f"_DQ_Q{bits}", qtype, [cols, rows], packed.tobytes(), True + ) + scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales") + for input in b_graph.input: if input.name == input_b: b_graph.input.remove(input) @@ -884,7 +927,12 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis else: dq_input_names = [b_quant.name, scales_tensor.name] dq_output_names = [b_quant.name + "_output"] - matmul_input_names = [node.input[0], dq_output_names[0]] + tp_input_names = [dq_output_names[0]] + tp_output_names = [dq_output_names[0] + "_transposed"] + matmul_input_names = [ + node.input[0], + tp_output_names[0] if qdq_opt_for_intel_npu_enabled else dq_output_names[0], + ] matmul_output_names = [node.output[0]] if not self.config.is_symmetric: zp_tensor = onnx.helper.make_tensor( @@ -892,7 +940,11 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis ) dq_input_names.append(zp_tensor.name) b_graph.initializer.extend([zp_tensor]) - dq_kwargs = {"axis": 0, "block_size": self.config.block_size} + rows, cols = b_ndarray.shape + dq_kwargs = { + "axis": 1 if qdq_opt_for_intel_npu_enabled else 0, + "block_size": rows if self.config.channel_wised_quantize else self.config.block_size, + } dq_node = onnx.helper.make_node( "DequantizeLinear", inputs=dq_input_names, @@ -906,7 +958,16 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis outputs=matmul_output_names, name=node.name + f"_matmul_Q{bits}" if node.name else "", ) - output_nodes.extend([dq_node, matmul_node]) + if qdq_opt_for_intel_npu_enabled: + tp_node = onnx.helper.make_node( + "Transpose", + inputs=tp_input_names, + outputs=tp_output_names, + perm=[1, 0], + ) + output_nodes.extend([dq_node, tp_node, matmul_node]) + else: + output_nodes.extend([dq_node, matmul_node]) return output_nodes @@ -1171,6 +1232,7 @@ def __init__( quant_format=QuantFormat.QOperator, op_types_to_quantize: tuple[str, ...] | None = None, quant_axes: tuple[tuple[str, int], ...] | None = None, + channel_wised_quantize: bool = False, algo_config: WeightOnlyQuantConfig | None = None, ): if nodes_to_exclude is None: @@ -1193,6 +1255,7 @@ def __init__( op_types_to_quantize=op_types_to_quantize, quant_axes=quant_axes, bits=4, # default to 4 bits + channel_wised_quantize=channel_wised_quantize, ) self.algo_config = algo_config diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index dc14ac67d5ea2..7a210ca8482a4 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -764,6 +764,15 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } else { ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_qdq_optimizer' should be a boolean i.e. true or false. Default value is false.\n"); } + } else if (key == "enable_causallm") { + if (value == "true" || value == "True" || + value == "false" || value == "False") { + ov_options[key] = value; + } else { + ORT_THROW( + "[ERROR] [OpenVINO] The value for the key 'enable_causallm' should be a boolean i.e. true or false." + " Default value is false. This provider option must be used with CausalLM Models viz. LLMs & SLMs only.\n"); + } } else if (key == "disable_dynamic_shapes") { if (value == "true" || value == "True" || value == "false" || value == "False") { @@ -817,11 +826,14 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); device_memory_name_ = std::move(value); } else if (key == "device_luid") { ov_options[key] = value; + } else if (key == "reshape_input") { + ov_options[key] = value; } else { ORT_THROW( "[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO." " ['device_type', 'device_id', 'num_of_threads', 'load_config', 'cache_dir', 'num_streams', " - "'enable_opencl_throttling', 'disable_dynamic_shapes', 'enable_qdq_optimizer', 'model_priority'] \n"); + "'enable_opencl_throttling', 'disable_dynamic_shapes', 'enable_qdq_optimizer'," + " 'enable_causallm', 'model_priority'] \n"); } } session_options.AppendExecutionProvider_OpenVINO_V2(ov_options); diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 92cd82c2c9420..c56aa3fb5feac 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -4099,7 +4099,11 @@ TEST(ReductionOpTest, ReduceSum_noop_axes_input_initializer_opset_18) { 3.0f, 4.0f}); test.AddInput("axes", {0}, {}, true); test.AddOutput("reduced", {1, 2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}); - test.Run(); + test.Run( + OpTester::ExpectResult::kExpectSuccess, + "", + {kOpenVINOExecutionProvider} // OpenVINO: Disabled temporarily + ); } TEST(ReductionOpTest, ReduceSum_empty_axes_input_initializer_opset_18) { diff --git a/onnxruntime/test/providers/openvino/openvino_ep_context_test.cc b/onnxruntime/test/providers/openvino/openvino_ep_context_test.cc new file mode 100644 index 0000000000000..e205b3aeb064a --- /dev/null +++ b/onnxruntime/test/providers/openvino/openvino_ep_context_test.cc @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/framework/provider_options.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/float16.h" + +#include "test/util/include/test_utils.h" +#include "test/util/include/test/test_environment.h" +#include "test/util/include/default_providers.h" + +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/inference_session.h" +#include "core/graph/model_saving_options.h" + +#include "test/optimizer/qdq_test_utils.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::logging; + +extern std::unique_ptr ort_env; + +class OVEPEPContextTests : public ::testing::Test { +}; + +namespace onnxruntime { +namespace test { + +// Test if folder path given to ep_context_file_path throws an error +TEST_F(OVEPEPContextTests, OVEPEPContextFolderPath) { + Ort::SessionOptions sessionOptions; + std::unordered_map ov_options; + + // The below line could fail the test in non NPU platforms.Commenting it out so that the device used for building OVEP will be used. + // ov_options["device_type"] = "NPU"; + + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + onnxruntime::Model model("OVEP_Test_Model", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + // Serialize the model to a string. + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); + + const std::string ep_context_file_path = "./ep_context_folder_path/"; + + sessionOptions.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + sessionOptions.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ep_context_file_path.c_str()); + sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options); + + try { + Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), sessionOptions); + FAIL(); // Should not get here! + } catch (const Ort::Exception& excpt) { + ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_INVALID_ARGUMENT); + ASSERT_THAT(excpt.what(), testing::HasSubstr("context_file_path should not point to a folder.")); + } +} + +} // namespace test +} // namespace onnxruntime From 97ccf3f694b7c4e0bbad3e4a54a1242114ed67b7 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Fri, 4 Jul 2025 21:01:44 -0700 Subject: [PATCH 3/7] Add OrtEpFactory::GetVersion and store EP version in EP metadata. (#25272) ### Description Add OrtEpFactory::GetVersion and store EP version in EP metadata. ### Motivation and Context Enforce plugin EP version specification and make it accessible from EP metadata. --------- Co-authored-by: Scott McKay --- cmake/onnxruntime.cmake | 5 +- .../core/session/onnxruntime_ep_c_api.h | 22 ++++++- .../onnxruntime_ep_device_ep_metadata_keys.h | 10 ++++ onnxruntime/core/common/semver.cc | 60 +++++++++++++++++++ onnxruntime/core/common/semver.h | 32 ++++++++++ .../providers/cuda/cuda_provider_factory.cc | 6 ++ .../providers/qnn/qnn_provider_factory.cc | 6 ++ onnxruntime/core/session/ep_api.cc | 18 ++++++ onnxruntime/core/session/ep_api_utils.h | 4 ++ .../core/session/ep_factory_internal.cc | 6 ++ .../core/session/ep_factory_internal.h | 1 + onnxruntime/test/autoep/library/ep_factory.cc | 11 +++- onnxruntime/test/autoep/library/ep_factory.h | 7 ++- .../test/autoep/test_autoep_selection.cc | 4 +- onnxruntime/test/common/semver_test.cc | 37 ++++++++++++ .../python/onnxruntime_test_python_autoep.py | 3 +- 16 files changed, 222 insertions(+), 10 deletions(-) create mode 100644 include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h create mode 100644 onnxruntime/core/common/semver.cc create mode 100644 onnxruntime/core/common/semver.h create mode 100644 onnxruntime/test/common/semver_test.cc diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index ae6684b061883..010696a61022c 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -22,13 +22,14 @@ endif() function(get_c_cxx_api_headers HEADERS_VAR) set(_headers "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_c_api.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_ep_c_api.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_api.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_inline.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_ep_c_api.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h" ) if (onnxruntime_ENABLE_TRAINING_APIS) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index c53a2f42247d9..44c7bb6ee424a 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -350,12 +350,12 @@ struct OrtEp { uint32_t ort_version_supported; /** \brief Get the execution provider name. + * + * The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it. * * \param[in] this_ptr The OrtEp instance. * \return The execution provider name. * - * \note Returned string is owned by ORT and valid until UnregisterExecutionProviderLibrary is called. - * * \since Version 1.22. */ const char*(ORT_API_CALL* GetName)(_In_ const OrtEp* this_ptr); @@ -578,6 +578,8 @@ struct OrtEpFactory { uint32_t ort_version_supported; /** \brief Get the name of the execution provider that the factory creates. + * + * The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it. * * \param[in] this_ptr The OrtEpFactory instance. * \return The name of the execution provider the factory creates. @@ -587,6 +589,8 @@ struct OrtEpFactory { const char*(ORT_API_CALL* GetName)(const OrtEpFactory* this_ptr); /** \brief Get the name of vendor who owns the execution provider that the factory creates. + * + * The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it. * * \param[in] this_ptr The OrtEpFactory instance. * \return vendor The vendor name of the execution provider the factory creates. @@ -659,6 +663,20 @@ struct OrtEpFactory { */ void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); + /** \brief Get the version of the execution provider that the factory creates. + * + * The version string should adhere to the Semantic Versioning 2.0 specification + * (https://github.com/semver/semver/blob/v2.0.0/semver.md). + * + * The returned string should be a null-terminated, UTF-8 encoded string. ORT will copy it. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \return The execution provider version string. + * + * \since Version 1.23. + */ + const char*(ORT_API_CALL* GetVersion)(_In_ const OrtEpFactory* this_ptr); + /** \brief Create an OrtAllocator for the given OrtMemoryInfo. * * This is used to create an allocator that an execution provider requires. The factory that creates the EP is diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h new file mode 100644 index 0000000000000..f0992f05f31e5 --- /dev/null +++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +// This file contains well-known keys for OrtEpDevice EP metadata entries. +// It does NOT specify all available metadata keys. + +// Key for the execution provider version string. This should be available for all plugin EPs. +static const char* const kOrtEpDevice_EpMetadataKey_Version = "version"; diff --git a/onnxruntime/core/common/semver.cc b/onnxruntime/core/common/semver.cc new file mode 100644 index 0000000000000..618d9dc29ea74 --- /dev/null +++ b/onnxruntime/core/common/semver.cc @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/semver.h" + +#include + +#include "core/common/common.h" +#include "core/common/narrow.h" +#include "core/common/parse_string.h" + +namespace onnxruntime { + +Status ParseSemVerVersion(std::string_view version_string, SemVerVersion* semver_version_out) { + // Semantic Versioning version regex was copied from here: + // https://github.com/semver/semver/blob/d58db1686379c8c6d52e32d42d3a530a964264e5/semver.md?plain=1#L357 + static const std::regex semver_pattern{ + R"(^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)(?:-((?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\+([0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?$)"}; + + std::cmatch match_result{}; + ORT_RETURN_IF_NOT(std::regex_match(version_string.data(), version_string.data() + version_string.size(), + match_result, semver_pattern), + "Version string is not in semantic versioning format: '", version_string, "'"); + + auto sub_match_to_string_view = [](const std::csub_match& sub_match) -> std::optional { + if (!sub_match.matched) { + return std::nullopt; + } + return std::string_view{sub_match.first, narrow(sub_match.length())}; + }; + + auto parse_version_component = + [&sub_match_to_string_view](const std::csub_match& sub_match, uint32_t& component) -> Status { + const auto component_str = sub_match_to_string_view(sub_match); + ORT_RETURN_IF_NOT(component_str.has_value(), "sub_match does not match anything."); + return ParseStringWithClassicLocale(*component_str, component); + }; + + SemVerVersion semver_version{}; + + ORT_RETURN_IF_ERROR(parse_version_component(match_result[1], semver_version.major)); + ORT_RETURN_IF_ERROR(parse_version_component(match_result[2], semver_version.minor)); + ORT_RETURN_IF_ERROR(parse_version_component(match_result[3], semver_version.patch)); + + semver_version.prerelease = sub_match_to_string_view(match_result[4]); + semver_version.build_metadata = sub_match_to_string_view(match_result[5]); + + if (semver_version_out) { + *semver_version_out = std::move(semver_version); + } + return Status::OK(); +} + +SemVerVersion ParseSemVerVersion(std::string_view version_string) { + SemVerVersion result{}; + ORT_THROW_IF_ERROR(ParseSemVerVersion(version_string, &result)); + return result; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/common/semver.h b/onnxruntime/core/common/semver.h new file mode 100644 index 0000000000000..a07c24f016886 --- /dev/null +++ b/onnxruntime/core/common/semver.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/common/status.h" + +namespace onnxruntime { + +// Semantic Versioning version utilities. +// See https://github.com/semver/semver/blob/v2.0.0/semver.md. + +// Semantic Versioning version components. +struct SemVerVersion { + uint32_t major{}; + uint32_t minor{}; + uint32_t patch{}; + std::optional prerelease{}; + std::optional build_metadata{}; +}; + +// Parse a Semantic Versioning version from `version_string`. +// If provided, the parsed version components will be written to `semver_version`. +Status ParseSemVerVersion(std::string_view version_string, SemVerVersion* semver_version); + +// Parse a Semantic Versioning version from `version_string`. +SemVerVersion ParseSemVerVersion(std::string_view version_string); + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 6ba2dd8176590..2de496a9168a0 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -308,12 +308,14 @@ CUDA_Provider* GetProvider() { } // namespace onnxruntime #include "core/framework/error_code_helper.h" +#include "onnxruntime_config.h" // for ORT_VERSION // OrtEpApi infrastructure to be able to use the CUDA EP as an OrtEpFactory for auto EP selection. struct CudaEpFactory : OrtEpFactory { CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in} { GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; @@ -329,6 +331,10 @@ struct CudaEpFactory : OrtEpFactory { return factory->vendor.c_str(); } + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return ORT_VERSION; + } + static OrtStatus* GetSupportedDevicesImpl(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* devices, size_t num_devices, diff --git a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc index 8a5f83f636824..c679ea1adb286 100644 --- a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc +++ b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc @@ -116,6 +116,7 @@ ORT_API(onnxruntime::Provider*, GetProvider) { } #include "core/framework/error_code_helper.h" +#include "onnxruntime_config.h" // for ORT_VERSION // OrtEpApi infrastructure to be able to use the QNN EP as an OrtEpFactory for auto EP selection. struct QnnEpFactory : OrtEpFactory { @@ -126,6 +127,7 @@ struct QnnEpFactory : OrtEpFactory { : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, qnn_backend_type{qnn_backend_type} { GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; @@ -143,6 +145,10 @@ struct QnnEpFactory : OrtEpFactory { return factory->vendor.c_str(); } + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return ORT_VERSION; + } + // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports. // An EP created with this factory is expected to be able to execute a model with *all* supported // hardware devices at once. A single instance of QNN EP is not currently setup to partition a model among diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/ep_api.cc index bbadfbee70656..ad965845041f7 100644 --- a/onnxruntime/core/session/ep_api.cc +++ b/onnxruntime/core/session/ep_api.cc @@ -5,6 +5,8 @@ #include #include + +#include "core/common/semver.h" #include "core/framework/error_code_helper.h" #include "core/framework/func_api.h" #include "core/framework/ort_value.h" @@ -14,6 +16,7 @@ #include "core/graph/ep_api_types.h" #include "core/session/abi_devices.h" #include "core/session/abi_ep_types.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/ort_apis.h" using namespace onnxruntime; @@ -34,6 +37,21 @@ ORT_API_STATUS_IMPL(CreateEpDevice, _In_ OrtEpFactory* ep_factory, ep_device->ep_metadata = *ep_metadata; } + // Add EP version from OrtEpFactory to metadata. OrtEpFactory::GetVersion is supported since 1.23. + if (ep_factory->ort_version_supported >= uint32_t{23}) { + if (ep_device->ep_metadata.Entries().find(kOrtEpDevice_EpMetadataKey_Version) != + ep_device->ep_metadata.Entries().end()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "The provided EP metadata should not explicitly specify the EP version."); + } + + { + std::string ep_version = ep_factory->GetVersion(ep_factory); + ORT_API_RETURN_IF_STATUS_NOT_OK(ParseSemVerVersion(ep_version, nullptr)); + ep_device->ep_metadata.Add(kOrtEpDevice_EpMetadataKey_Version, std::move(ep_version)); + } + } + if (ep_options) { ep_device->ep_options = *ep_options; } diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/ep_api_utils.h index 366f934fc610e..daccd24453371 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/ep_api_utils.h @@ -16,6 +16,10 @@ struct ForwardToFactory { return static_cast(this_ptr)->GetVendor(); } + static const char* ORT_API_CALL GetVersion(const OrtEpFactory* this_ptr) noexcept { + return static_cast(this_ptr)->GetVersion(); + } + static OrtStatus* ORT_API_CALL GetSupportedDevices(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* devices, size_t num_devices, diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/ep_factory_internal.cc index b906f25935983..b289010cc6c5b 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/ep_factory_internal.cc @@ -8,6 +8,7 @@ #include "core/session/abi_session_options_impl.h" #include "core/session/ep_api_utils.h" #include "core/session/ort_apis.h" +#include "onnxruntime_config.h" // for ORT_VERSION namespace onnxruntime { @@ -24,11 +25,16 @@ EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::stri OrtEpFactory::GetName = Forward::GetFactoryName; OrtEpFactory::GetVendor = Forward::GetVendor; + OrtEpFactory::GetVersion = Forward::GetVersion; OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices; OrtEpFactory::CreateEp = Forward::CreateEp; OrtEpFactory::ReleaseEp = Forward::ReleaseEp; } +const char* EpFactoryInternal::GetVersion() const noexcept { + return ORT_VERSION; +} + OrtStatus* EpFactoryInternal::GetSupportedDevices(const OrtHardwareDevice* const* devices, size_t num_devices, OrtEpDevice** ep_devices, diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/ep_factory_internal.h index 1951b51a38bee..087c0c60f8f4e 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/ep_factory_internal.h @@ -39,6 +39,7 @@ class EpFactoryInternal : public OrtEpFactory { const char* GetName() const noexcept { return ep_name_.c_str(); } const char* GetVendor() const noexcept { return vendor_.c_str(); } + const char* GetVersion() const noexcept; OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, _In_ size_t num_devices, diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index c2fa5ec88a0d8..d4895102b0bf1 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -14,6 +14,7 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis) ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVersion = GetVersionImpl; GetSupportedDevices = GetSupportedDevicesImpl; @@ -86,6 +87,12 @@ const char* ORT_API_CALL ExampleEpFactory::GetVendorImpl(const OrtEpFactory* thi return factory->vendor_.c_str(); } +/*static*/ +const char* ORT_API_CALL ExampleEpFactory::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->ep_version_.c_str(); +} + /*static*/ OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* devices, @@ -107,7 +114,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory* factory->ort_api.CreateKeyValuePairs(&ep_options); // random example using made up values - factory->ort_api.AddKeyValuePair(ep_metadata, "version", "0.1"); + factory->ort_api.AddKeyValuePair(ep_metadata, "supported_devices", "CrackGriffin 7+"); factory->ort_api.AddKeyValuePair(ep_options, "run_really_fast", "true"); // OrtEpDevice copies ep_metadata and ep_options. @@ -136,7 +143,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory* // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { // Ort::KeyValuePairs ep_metadata; // Ort::KeyValuePairs ep_options; - // ep_metadata.Add("version", "0.1"); + // ep_metadata.Add("supported_devices", "CrackGriffin 7+"); // ep_options.Add("run_really_fast", "true"); // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; // ep_devices[num_ep_devices++] = ep_device.release(); diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/ep_factory.h index 8ab67fc9d8ce6..fda77f12c4814 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/ep_factory.h @@ -22,6 +22,8 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* devices, size_t num_devices, @@ -49,8 +51,9 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, OrtDataTransferImpl** data_transfer) noexcept; - const std::string ep_name_; // EP name - const std::string vendor_{"Contoso"}; // EP vendor name + const std::string ep_name_; // EP name + const std::string vendor_{"Contoso"}; // EP vendor name + const std::string ep_version_{"0.1.0"}; // EP version // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. using MemoryInfoUniquePtr = std::unique_ptr>; diff --git a/onnxruntime/test/autoep/test_autoep_selection.cc b/onnxruntime/test/autoep/test_autoep_selection.cc index be20d2c7c5a60..01dece34e50b0 100644 --- a/onnxruntime/test/autoep/test_autoep_selection.cc +++ b/onnxruntime/test/autoep/test_autoep_selection.cc @@ -14,6 +14,7 @@ #include "core/session/abi_key_value_pairs.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "test_allocator.h" #include "test/shared_lib/utils.h" @@ -564,7 +565,8 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { ASSERT_STREQ(test_ep_device->EpVendor(), "Contoso"); auto metadata = test_ep_device->EpMetadata(); - ASSERT_STREQ(metadata.GetValue("version"), "0.1"); + ASSERT_STREQ(metadata.GetValue(kOrtEpDevice_EpMetadataKey_Version), "0.1.0"); + ASSERT_STREQ(metadata.GetValue("supported_devices"), "CrackGriffin 7+"); auto options = test_ep_device->EpOptions(); ASSERT_STREQ(options.GetValue("run_really_fast"), "true"); diff --git a/onnxruntime/test/common/semver_test.cc b/onnxruntime/test/common/semver_test.cc new file mode 100644 index 0000000000000..5ec066e59b838 --- /dev/null +++ b/onnxruntime/test/common/semver_test.cc @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/semver.h" + +#include "gtest/gtest.h" + +#include "test/util/include/asserts.h" + +namespace onnxruntime::test { + +TEST(SemVerParsingTest, Basic) { + { + auto semver = ParseSemVerVersion("1.2.3-abcde+fghij"); + EXPECT_EQ(semver.major, 1); + EXPECT_EQ(semver.minor, 2); + EXPECT_EQ(semver.patch, 3); + EXPECT_EQ(semver.prerelease, "abcde"); + EXPECT_EQ(semver.build_metadata, "fghij"); + } + + { + auto semver = ParseSemVerVersion("1.2.3"); + EXPECT_EQ(semver.major, 1); + EXPECT_EQ(semver.minor, 2); + EXPECT_EQ(semver.patch, 3); + EXPECT_EQ(semver.prerelease, std::nullopt); + EXPECT_EQ(semver.build_metadata, std::nullopt); + } +} + +TEST(SemVerParsingTest, Invalid) { + SemVerVersion semver{}; + ASSERT_STATUS_NOT_OK(ParseSemVerVersion("version one point zero", &semver)); +} + +} // namespace onnxruntime::test diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py index f1c924a1ade94..0c52740398b7a 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py +++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py @@ -210,7 +210,8 @@ def test_example_plugin_ep_devices(self): self.assertEqual(test_ep_device.ep_vendor, "Contoso") ep_metadata = test_ep_device.ep_metadata - self.assertEqual(ep_metadata["version"], "0.1") + self.assertEqual(ep_metadata["version"], "0.1.0") + self.assertEqual(ep_metadata["supported_devices"], "CrackGriffin 7+") ep_options = test_ep_device.ep_options self.assertEqual(ep_options["run_really_fast"], "true") From 5e4d8dc36191675263b896170aa9fa4dd6ac2e13 Mon Sep 17 00:00:00 2001 From: Bonoy0328 <42759143+Bonoy0328@users.noreply.github.com> Date: Sun, 6 Jul 2025 13:19:45 +0800 Subject: [PATCH 4/7] Fix INT32 bias overflow in QOperator INT8 symmetric quantization by adjusting weight scale and requantizing (#25278) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Overview This PR introduces a critical fix for **QOperator INT8 symmetric quantization** in ONNX Runtime. It addresses a situation where the computed **bias scale** (`input_scale * weight_scale`) becomes too small, leading to **int32 overflow** or **precision clipping** during bias quantization. ### Problem In symmetric quantization (i.e., zero_point = 0), the bias tensor is quantized using a fixed-point scale: **bias_scale = input_scale * weight_scale** When this value is too small, the quantized int32 bias may exceed the range of `int32`, causing saturation or significant quantization error. This was observed to cause **>51% accuracy loss** in some models. ### Solution This PR adds two new functions to mitigate this: --- #### 🔧 `_adjust_weight_scale_for_int32_bias(...)` Located in `onnx_quantizer.py`, this function: - **Inspects the float bias range** to compute the smallest valid bias scale (based on int32 dynamic range) - **Compares** this threshold against `input_scale * weight_scale` - If too small, **scales up the weight scale** accordingly, to prevent overflow - Supports both per-tensor and per-channel weight quantization cases This logic is **only triggered when**: - The weight's zero point is exactly zero (i.e. symmetric) - The weight data type is `INT8` or `INT16` --- #### 🔄 `_requantize_weight(...)` After weight scale adjustment, this function: - **Finds the original quantized weight** (`q_weight`), scale, and zero point from the initializer list - **Removes** the outdated quantized weight and scale - **Re-quantizes** the original float weights using the new scale and the same zero point - **Re-inserts** them into the model to maintain consistency --- ### Summary of Benefits - ✅ Prevents int32 overflow or saturation during symmetric bias quantization - ✅ Ensures weight and bias quantization remain consistent - ✅ Reduced quantization error from >51.4% to ~3% in test models - ✅ Fix is limited in scope to QOperator + symmetric INT8/INT16 flow (safe for other modes) - ✅ Improves robustness of static quantization for hardware that performs integer-only inference --- ### Code Location - `onnxruntime/quantization/onnx_quantizer.py` - `def _adjust_weight_scale_for_int32_bias(...)` - `def _requantize_weight(...)` - Integrated in `quantize_bias_static(...)` --- Please let me know if you'd like additional test coverage or integration points. Thanks! --- .../tools/quantization/onnx_quantizer.py | 155 ++++++++++++++++++ .../test_qoperator_adjust_int32_bias.py | 105 ++++++++++++ 2 files changed, 260 insertions(+) create mode 100644 onnxruntime/test/python/quantization/test_qoperator_adjust_int32_bias.py diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index b0a78281041d0..148e4c06a8051 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -28,6 +28,7 @@ get_qmin_qmax_for_qType, get_qrange_for_qType, ms_domain, + quantize_onnx_initializer, save_and_reload_model_with_shape_infer, tensor_proto_to_array, ) @@ -635,6 +636,137 @@ def find_quantized_value(self, input_name): return self.parent.find_quantized_value(input_name) return None + def adjust_single_weight_scale_if_needed( + self, + bias_val, + input_scale, + weight_scale, + weight_scale_dtype, + weight_name, + bias_name, + qrange, + multiplicative_epsilon, + idx=None, + ): + """Adjust a single weight scale to ensure the int32 bias does not overflow.""" + absmax = np.abs(bias_val) + bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * absmax) / qrange + + input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64) + weight_scale_fp64 = np.array(weight_scale.item(), dtype=np.float64) + bias_candidate_scale = input_scale_fp64 * weight_scale_fp64 + + if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0): + ratio = bias_smallest_valid_scale / bias_candidate_scale + new_scale = weight_scale_fp64 * ratio + if idx is None: + logging.info( + f"Increasing scale for weight `{weight_name}` by the ratio {ratio} to " + f"ensure bias `{bias_name}` has a valid scale." + ) + return True, np.array(new_scale, dtype=weight_scale_dtype) + else: + logging.info( + f"Increased scale[{idx}] for weight `{weight_name}` by ratio {ratio} " + f"to ensure bias `{bias_name}` has a valid scale." + ) + return True, new_scale.astype(weight_scale_dtype) + return False, weight_scale + + def _adjust_weight_scale_for_int32_bias( + self, + input_scale: np.ndarray, + weight_scale: np.ndarray, + weight_name: str, + bias_tp: onnx.TensorProto, + is_per_channel: bool, + ) -> tuple[bool, np.ndarray | None]: + """Checks if the bias scale is too small and increases the weight scale if needed.""" + + if not weight_scale.size: + return False, None + + bias_float_data = tensor_proto_to_array(bias_tp) + int32_info = np.iinfo(np.int32) + multiplicative_epsilon = 1.0001 + qrange = np.array(int32_info.max, dtype=np.float64) - np.array(int32_info.min + 1, dtype=np.float64) + weight_scale_dtype = weight_scale.dtype + updated = False + + if not is_per_channel: + rmin = np.minimum(bias_float_data.min(), np.array(0, dtype=np.float64)) + rmax = np.maximum(bias_float_data.max(), np.array(0, dtype=np.float64)) + absmax = np.maximum(np.abs(rmin), np.abs(rmax)) + changed, new_scale = self.adjust_single_weight_scale_if_needed( + absmax, + input_scale, + weight_scale, + weight_scale_dtype, + weight_name, + bias_tp.name, + qrange, + multiplicative_epsilon, + ) + if changed: + weight_scale = new_scale + updated = True + elif weight_scale.shape and len(weight_scale.shape) == 1: + for i in range(weight_scale.shape[0]): + changed, new_scale = self.adjust_single_weight_scale_if_needed( + bias_float_data[i], + input_scale, + weight_scale[i], + weight_scale_dtype, + weight_name, + bias_tp.name, + qrange, + multiplicative_epsilon, + idx=i, + ) + if changed: + weight_scale[i] = new_scale + updated = True + + return updated, weight_scale + + def _requantize_weight(self, weight_name: str, new_scale: np.ndarray) -> None: + """Re-quantizes the given weight initializer using the provided scale.""" + + if weight_name not in self.quantized_value_map: + return + + qv = self.quantized_value_map[weight_name] + + weight_tp = find_by_name(weight_name, self.model.initializer()) + scale_init = find_by_name(qv.scale_name, self.model.initializer()) + zp_init = find_by_name(qv.zp_name, self.model.initializer()) + q_weight_init = find_by_name(qv.q_name, self.model.initializer()) + + if weight_tp is None or scale_init is None or zp_init is None or q_weight_init is None: + return + + self.model.remove_initializer(scale_init) + self.model.remove_initializer(q_weight_init) + + weight_zero_point = onnx.numpy_helper.to_array(zp_init) + axis = qv.axis + + # Add new scale initializer + scale_np = np.asarray(new_scale, dtype=onnx.helper.tensor_dtype_to_np_dtype(weight_tp.data_type)) + new_scale_init = onnx.numpy_helper.from_array(scale_np.reshape(scale_init.dims), qv.scale_name) + self.model.add_initializer(new_scale_init) + + # Add new quantized weight initializer + new_q_weight = quantize_onnx_initializer( + weight_tp, + self.weight_qType, + weight_zero_point, + scale_np, + axis, + quant_weight_name=qv.q_name, + ) + self.model.add_initializer(new_q_weight) + def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0): """ Quantized the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale @@ -660,6 +792,29 @@ def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0): inputscale_initializer = find_by_name(input_scale_name, self.model.initializer()) input_scale = tensor_proto_to_array(inputscale_initializer) + # Adjust weight scale if quantizing to int32 may overflow due to a small scale + weight_zp_name = self.quantized_value_map[weight_name].zp_name + weight_zp_init = find_by_name(weight_zp_name, self.model.initializer()) + weight_zero_point = onnx.numpy_helper.to_array(weight_zp_init) if weight_zp_init is not None else None + is_per_channel = self.per_channel + if ( + weight_zero_point is not None + and weight_zero_point.size + and not weight_zero_point.any() + and self.weight_qType in (onnx_proto.TensorProto.INT8,) + ): + bias_initializer = find_by_name(bias_name, self.model.initializer()) + did_update, new_weight_scale = self._adjust_weight_scale_for_int32_bias( + input_scale, + weight_scale, + weight_name, + bias_initializer, + is_per_channel, + ) + if did_update: + self._requantize_weight(weight_name, new_weight_scale) + weight_scale = new_weight_scale + ( quantized_bias_name, quantized_bias_scale_name, diff --git a/onnxruntime/test/python/quantization/test_qoperator_adjust_int32_bias.py b/onnxruntime/test/python/quantization/test_qoperator_adjust_int32_bias.py new file mode 100644 index 0000000000000..e4c958996f773 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_qoperator_adjust_int32_bias.py @@ -0,0 +1,105 @@ +import os +import tempfile +import unittest + +import numpy as np +import onnx +from op_test_utils import TestDataFeeds, check_model_correctness + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static + + +class TestAdjustWeightScaleForInt32BiasQOperator(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qop.adj_int32_bias_") + cls._tmp_dir_path = cls._tmp_model_dir.name + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_conv_test_model(self, input_shape, weight_shape, onnx_float_type): + np_float_type = onnx.helper.tensor_dtype_to_np_dtype(onnx_float_type) + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx_float_type, input_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx_float_type, None) + + tiny_value = 1e-7 if np_float_type == np.float32 else 0.007782 + + # Step 1: reshape to (C_out, -1) to ensure per-channel broadcasting + weight_data = np.full(weight_shape, tiny_value, dtype=np_float_type) + weight_data = weight_data.reshape(weight_shape[0], -1) + for i in range(weight_data.shape[0]): + for j in range(weight_data.shape[1]): + if j % 2 == 0: + weight_data[i, j] = -weight_data[i, j] + # Step 2: reshape back to original shape + weight_data = weight_data.reshape(weight_shape) + weight = onnx.numpy_helper.from_array(weight_data, "weight") + + bias_shape = [weight_shape[0]] + bias_data = np.ones(bias_shape, dtype=np_float_type) + for i in range(len(bias_data)): + bias_data[i] = 5.0 if (i % 2 == 0) else -4.5 + if np_float_type == np.float16: + bias_data[i] = 1400 if (i % 2 == 0) else -1200 + bias = onnx.numpy_helper.from_array(bias_data, "bias") + + conv_node = onnx.helper.make_node("Conv", ["input_0", "weight", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph([conv_node], "Convfloat", [input_0], [output_0], initializer=[weight, bias]) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_adjust_weight_scale_for_int32_bias_qop(self): + test_configs = [ + (onnx.TensorProto.FLOAT, True), + (onnx.TensorProto.FLOAT, False), + (onnx.TensorProto.FLOAT, True), + (onnx.TensorProto.FLOAT, False), + ] + + for float_type, per_channel in test_configs: + with self.subTest(float_type=float_type, per_channel=per_channel): + label = f"_f{float_type}_perchannel{per_channel}" + float_model_path = os.path.join(self._tmp_dir_path, f"conv{label}.float.onnx") + qop_model_path = os.path.join(self._tmp_dir_path, f"conv{label}.qop.onnx") + + input_shape = [1, 1, 128, 128] + weight_shape = [8, 1, 1, 1] + float_model = self.build_conv_test_model(input_shape, weight_shape, float_type) + onnx.save_model(float_model, float_model_path) + + np_float_type = onnx.helper.tensor_dtype_to_np_dtype(float_type) + input_rmin = 0.0 + input_scale = 0.05 if float_type == onnx.TensorProto.FLOAT else 0.01 + input_rmax = (input_scale * 255.0) + input_rmin + input_data_list = [ + {"input_0": np.full(input_shape, input_rmin, dtype=np_float_type)}, + {"input_0": np.full(input_shape, (input_rmax - input_rmin) / 2.0, dtype=np_float_type)}, + {"input_0": np.full(input_shape, input_rmax, dtype=np_float_type)}, + ] + data_reader = TestDataFeeds(input_data_list) + + quantize_static( + float_model_path, + qop_model_path, + data_reader, + activation_type=QuantType.QInt8, + weight_type=QuantType.QInt8, + per_channel=per_channel, + quant_format=QuantFormat.QOperator, + extra_options={ + "ActivationSymmetric": True, + "WeightSymmetric": True, + }, + ) + + data_reader.rewind() + check_model_correctness(self, float_model_path, qop_model_path, data_reader.get_next()) + + +if __name__ == "__main__": + unittest.main() From e63e053b14910b3c867938ca02f3107d0456e2c6 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Mon, 7 Jul 2025 10:31:15 +0800 Subject: [PATCH 5/7] [webgpu] Enable graph capture (#24900) This PR enables graph capture capabilities in the WebGPU provider, which is similar with jsep one #18989. All limitations are similar with JS/CUDA EP: 1. Models with control-flow ops (i.e. If, Loop and Scan ops) are not supported. 2. Usage of graph capture is limited to models where-in all ops in the model can be partitioned to the WebGPU EP or CPU EP and no memory copy between them. 3. Shapes of inputs/outputs cannot change across inference calls. 4. IOBinding is required. And all inputs/outputs are pre-allocated gpu buffers. When users use graph capture feature, we suppose they will do some pre-process and post-process for the inference's inputs and outputs in order to keep the whole pipeline on GPU to avoid some unnecessary cpu to gpu or gpu to cpu copying. The usage will be like below: ``` // Initialize Dawn { // 1. Create Dawn instance ... instance = wgpu::CreateInstance(&instanceDescriptor); // 2. Create the adapter ... instance.RequestAdapter // 3. Create device from adapter ... adapter.RequestDevice } // Create session options webgpu_options_ = std::make_unique(); std::unordered_map provider_options; provider_options["dawnProcTable"] = std::to_string(reinterpret_cast(&dawn::native::GetProcs())); provider_options["webgpuInstance"] = std::to_string(reinterpret_cast(instance_.Get())); provider_options["webgpuDevice"] = std::to_string(reinterpret_cast(device_.Get())); provider_options["deviceId"] = "1"; provider_options["enableGraphCapture"] = "1"; // add WebGPU provider webgpu_options_->AppendExecutionProvider("WebGPU", provider_options); ... // create webgpu session webgpu_session_ = std::make_unique(*env_, model_path_.c_str(), *webgpu_options_); ... Ort::MemoryInfo memory_info_gpu("WebGPU_Buffer", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault); Ort::Allocator allocator(*webgpu_session_, memory_info_gpu); auto input_buffer = allocator.GetAllocation(input_tensor_size_ * sizeof(float)); auto output_buffer = allocator.GetAllocation(output_tensor_size_ * sizeof(float)); // Create IoBinding objects Ort::IoBinding webgpu_binding(*webgpu_session_); // Upload cpu data to input_buffer or copy gpu buffer to input_buffer ... // Create an OrtValue tensor backed by data on gpu memory Ort::Value bound_x = Ort::Value::CreateTensor(memory_info_gpu, reinterpret_cast(input_buffer.get()), input_tensor_size_, input_dims_.data(), input_dims_.size()); Ort::Value bound_y = Ort::Value::CreateTensor(memory_info_gpu, reinterpret_cast(output_buffer.get()), output_tensor_size_, output_dims_.data(), output_dims_.size()); webgpu_binding.BindInput("input", bound_x); webgpu_binding.BindOutput("output", bound_y); // Run inference webgpu_session_->Run(Ort::RunOptions{nullptr}, webgpu_binding); // normal run + capturing ... // post process output_buffer's content ... // Update input_buffer's content ... // Run again webgpu_session_->Run(Ort::RunOptions{nullptr}, webgpu_binding); // replay() ... // post process output_buffer's content ... ``` --- .../core/providers/webgpu/allocator.cc | 13 +- onnxruntime/core/providers/webgpu/allocator.h | 12 +- .../core/providers/webgpu/buffer_manager.cc | 206 ++++++++++++++++-- .../core/providers/webgpu/buffer_manager.h | 36 +-- .../core/providers/webgpu/compute_context.cc | 9 + .../core/providers/webgpu/compute_context.h | 7 +- .../core/providers/webgpu/data_transfer.cc | 10 +- .../core/providers/webgpu/data_transfer.h | 6 +- .../core/providers/webgpu/webgpu_context.cc | 103 +++++++-- .../core/providers/webgpu/webgpu_context.h | 19 +- .../webgpu/webgpu_execution_provider.cc | 43 +++- .../webgpu/webgpu_execution_provider.h | 13 +- .../webgpu/webgpu_provider_factory.cc | 6 +- .../test/framework/inference_session_test.cc | 181 +++++++++++---- onnxruntime/test/util/default_providers.cc | 9 + .../test/util/include/default_providers.h | 2 + 16 files changed, 546 insertions(+), 129 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index 315d0cd75e946..48d884858f493 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -3,7 +3,7 @@ #include "core/framework/session_state.h" #include "core/providers/webgpu/allocator.h" -#include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/buffer_manager.h" namespace onnxruntime { namespace webgpu { @@ -15,18 +15,17 @@ void* GpuBufferAllocator::Alloc(size_t size) { stats_.num_allocs++; -#if !defined(__wasm__) - if (!session_initialized_ && context_.DeviceHasFeature(wgpu::FeatureName::BufferMapExtendedUsages)) { - return context_.BufferManager().CreateUMA(size); + // Check if the buffer manager supports UMA and we're not yet in an initialized session + if (!session_initialized_ && buffer_manager_.SupportsUMA()) { + return buffer_manager_.CreateUMA(size); } -#endif // !defined(__wasm__) - return context_.BufferManager().Create(size); + return buffer_manager_.Create(size); } void GpuBufferAllocator::Free(void* p) { if (p != nullptr) { - context_.BufferManager().Release(static_cast(p)); + buffer_manager_.Release(static_cast(p)); stats_.num_allocs--; } } diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h index 0b27f713777bc..02f78c18fc947 100644 --- a/onnxruntime/core/providers/webgpu/allocator.h +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -9,27 +9,29 @@ namespace onnxruntime { namespace webgpu { -class WebGpuContext; +class BufferManager; class GpuBufferAllocator : public IAllocator { public: - GpuBufferAllocator(const WebGpuContext& context) + GpuBufferAllocator(const BufferManager& buffer_manager) : IAllocator( OrtMemoryInfo(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0), OrtMemTypeDefault)), - context_{context} { + buffer_manager_{buffer_manager} { } virtual void* Alloc(size_t size) override; virtual void Free(void* p) override; void GetStats(AllocatorStats* stats) override; - void OnSessionInitializationEnd(); + // Return the associated BufferManager + const BufferManager& GetBufferManager() const { return buffer_manager_; } + private: AllocatorStats stats_; - const WebGpuContext& context_; + const BufferManager& buffer_manager_; bool session_initialized_ = false; }; diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index 1d8c689cbd909..c02049f60e4db 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -37,7 +37,7 @@ class DisabledCacheManager : public IBufferCacheManager { wgpuBufferRelease(buffer); } - void OnRefresh() override { + void OnRefresh(const SessionState& /*session_status*/) override { // no-op } }; @@ -59,7 +59,7 @@ class LazyReleaseCacheManager : public IBufferCacheManager { pending_buffers_.emplace_back(buffer); } - void OnRefresh() override { + void OnRefresh(const SessionState& /*session_status*/) override { Release(); pending_buffers_.clear(); } @@ -103,7 +103,7 @@ class SimpleCacheManager : public IBufferCacheManager { pending_buffers_.emplace_back(buffer); } - void OnRefresh() override { + void OnRefresh(const SessionState& /*session_status*/) override { for (auto& buffer : pending_buffers_) { buffers_[static_cast(wgpuBufferGetSize(buffer))].emplace_back(buffer); } @@ -196,12 +196,9 @@ class BucketCacheManager : public IBufferCacheManager { pending_buffers_.emplace_back(buffer); } - void OnRefresh() override { - // TODO: consider graph capture. currently not supported - + void OnRefresh(const SessionState& /*session_status*/) override { for (auto& buffer : pending_buffers_) { auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); - auto it = buckets_.find(buffer_size); if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { it->second.emplace_back(buffer); @@ -249,6 +246,155 @@ class BucketCacheManager : public IBufferCacheManager { std::vector buckets_keys_; }; +class GraphCacheManager : public IBufferCacheManager { + public: + GraphCacheManager() : buckets_limit_{BUCKET_DEFAULT_LIMIT_TABLE} { + Initialize(); + } + GraphCacheManager(std::unordered_map&& buckets_limit) : buckets_limit_{buckets_limit} { + Initialize(); + } + + size_t CalculateBufferSize(size_t request_size) override { + // binary serch size + auto it = std::lower_bound(buckets_keys_.begin(), buckets_keys_.end(), request_size); + if (it == buckets_keys_.end()) { + return NormalizeBufferSize(request_size); + } else { + return *it; + } + } + + WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && !it->second.empty()) { + auto buffer = it->second.back(); + it->second.pop_back(); + return buffer; + } + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh(const SessionState& /*session_status*/) override { + // Initialize buckets if they don't exist yet + if (buckets_.empty()) { + for (const auto& pair : buckets_limit_) { + buckets_.emplace(pair.first, std::vector()); + } + } + + for (auto& buffer : pending_buffers_) { + auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); + auto it = buckets_.find(buffer_size); + if (it != buckets_.end()) { + it->second.emplace_back(buffer); + } else { + // insert a new bucket if it doesn't exist + buckets_[buffer_size] = std::vector{buffer}; + } + } + + pending_buffers_.clear(); + } + + ~GraphCacheManager() { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } + for (auto& pair : buckets_) { + for (auto& buffer : pair.second) { + wgpuBufferRelease(buffer); + } + } + } + + protected: + void Initialize() { + buckets_keys_.reserve(buckets_limit_.size()); + for (const auto& pair : buckets_limit_) { + buckets_keys_.push_back(pair.first); + } + std::sort(buckets_keys_.begin(), buckets_keys_.end()); + +#ifndef NDEBUG // if debug build + ORT_ENFORCE(std::all_of(buckets_keys_.begin(), buckets_keys_.end(), [](size_t size) { return size % 16 == 0; }), + "Bucket sizes must be multiples of 16."); + + for (size_t i = 1; i < buckets_keys_.size(); ++i) { + ORT_ENFORCE(buckets_keys_[i] > buckets_keys_[i - 1], "Bucket sizes must be in increasing order."); + } +#endif + } + std::unordered_map buckets_limit_; + std::unordered_map> buckets_; + std::vector pending_buffers_; + std::vector buckets_keys_; +}; + +class GraphSimpleCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { + auto it = buffers_.find(buffer_size); + if (it != buffers_.end() && !it->second.empty()) { + auto buffer = it->second.back(); + it->second.pop_back(); + return buffer; + } + + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh(const SessionState& session_status) override { + for (auto& buffer : pending_buffers_) { + if (session_status == SessionState::Default) { + buffers_[static_cast(wgpuBufferGetSize(buffer))].emplace_back(buffer); + } else { + captured_buffers_.emplace_back(buffer); + } + } + pending_buffers_.clear(); + } + + public: + ~GraphSimpleCacheManager() { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } + for (auto& pair : buffers_) { + for (auto& buffer : pair.second) { + wgpuBufferRelease(buffer); + } + } + for (auto& buffer : captured_buffers_) { + wgpuBufferRelease(buffer); + } + } + + protected: + std::map> buffers_; + std::vector pending_buffers_; + std::vector captured_buffers_; +}; + std::unique_ptr CreateBufferCacheManager(BufferCacheMode cache_mode) { switch (cache_mode) { case BufferCacheMode::Disabled: @@ -259,6 +405,10 @@ std::unique_ptr CreateBufferCacheManager(BufferCacheMode ca return std::make_unique(); case BufferCacheMode::Bucket: return std::make_unique(); + case BufferCacheMode::Graph: + return std::make_unique(); + case BufferCacheMode::GraphSimple: + return std::make_unique(); default: ORT_NOT_IMPLEMENTED("Unsupported buffer cache mode"); } @@ -278,6 +428,12 @@ std::ostream& operator<<(std::ostream& os, BufferCacheMode mode) { case BufferCacheMode::Bucket: os << "Bucket"; break; + case BufferCacheMode::Graph: + os << "Graph"; + break; + case BufferCacheMode::GraphSimple: + os << "GraphSimple"; + break; default: os << "Unknown(" << static_cast(mode) << ")"; } @@ -292,7 +448,7 @@ BufferManager::BufferManager(WebGpuContext& context, BufferCacheMode storage_buf default_cache_{CreateBufferCacheManager(BufferCacheMode::Disabled)} { } -void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) { +void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) const { // If the buffer is mapped, we can directly write to it. void* mapped_data = wgpuBufferGetMappedRange(dst, 0, WGPU_WHOLE_MAP_SIZE); // ensure the buffer is mapped if (mapped_data) { @@ -317,10 +473,10 @@ void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) { auto& command_encoder = context_.GetCommandEncoder(); context_.EndComputePass(); command_encoder.CopyBufferToBuffer(staging_buffer, 0, dst, 0, buffer_size); - context_.Flush(); + context_.Flush(*this); } -void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) { +void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const { ORT_ENFORCE(src != dst, "Source and destination buffers must be different."); EnforceBufferUnmapped(context_, src); EnforceBufferUnmapped(context_, dst); @@ -337,7 +493,7 @@ void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) { command_encoder.CopyBufferToBuffer(src, 0, dst, 0, buffer_size); } -WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) { +WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) const { auto& cache = GetCacheManager(usage); auto buffer_size = cache.CalculateBufferSize(size); @@ -358,7 +514,7 @@ WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) { return buffer; } -WGPUBuffer BufferManager::CreateUMA(size_t size, wgpu::BufferUsage usage) { +WGPUBuffer BufferManager::CreateUMA(size_t size, wgpu::BufferUsage usage) const { ORT_ENFORCE(usage & wgpu::BufferUsage::Storage, "UMA buffer must be a storage buffer."); auto& cache = GetCacheManager(usage); auto buffer_size = cache.CalculateBufferSize(size); @@ -378,12 +534,21 @@ WGPUBuffer BufferManager::CreateUMA(size_t size, wgpu::BufferUsage usage) { return buffer; } -void BufferManager::Release(WGPUBuffer buffer) { +bool BufferManager::SupportsUMA() const { +#if !defined(__wasm__) + // Check if the device supports the BufferMapExtendedUsages feature + return context_.DeviceHasFeature(wgpu::FeatureName::BufferMapExtendedUsages); +#else + return false; +#endif // !defined(__wasm__) +} + +void BufferManager::Release(WGPUBuffer buffer) const { EnforceBufferUnmapped(context_, buffer); GetCacheManager(buffer).ReleaseBuffer(buffer); } -void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) { +void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) const { EnforceBufferUnmapped(context_, src); auto buffer_size = NormalizeBufferSize(size); @@ -395,7 +560,7 @@ void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) { auto& command_encoder = context_.GetCommandEncoder(); context_.EndComputePass(); command_encoder.CopyBufferToBuffer(src, 0, staging_buffer, 0, buffer_size); - context_.Flush(); + context_.Flush(*this); // TODO: revise wait in whole project @@ -405,13 +570,14 @@ void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) { auto mapped_data = staging_buffer.GetConstMappedRange(); memcpy(dst, mapped_data, size); + staging_buffer.Unmap(); } -void BufferManager::RefreshPendingBuffers() { - storage_cache_->OnRefresh(); - uniform_cache_->OnRefresh(); - query_resolve_cache_->OnRefresh(); - default_cache_->OnRefresh(); +void BufferManager::RefreshPendingBuffers(const SessionState& session_status) const { + storage_cache_->OnRefresh(session_status); + uniform_cache_->OnRefresh(session_status); + query_resolve_cache_->OnRefresh(session_status); + default_cache_->OnRefresh(session_status); } IBufferCacheManager& BufferManager::GetCacheManager(wgpu::BufferUsage usage) const { diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.h b/onnxruntime/core/providers/webgpu/buffer_manager.h index b9028ad5de858..8f7882d2f2fa8 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.h +++ b/onnxruntime/core/providers/webgpu/buffer_manager.h @@ -14,11 +14,20 @@ namespace webgpu { class WebGpuContext; +// For command capture and replay +enum class SessionState { + Default, + Capturing, + Replaying +}; + enum class BufferCacheMode { Disabled, LazyRelease, Simple, - Bucket + Bucket, + Graph, + GraphSimple, }; std::ostream& operator<<(std::ostream& os, BufferCacheMode mode); @@ -26,12 +35,13 @@ std::ostream& operator<<(std::ostream& os, BufferCacheMode mode); // IBufferCacheManager is an interface for buffer cache management. // // By implementing this interface, we can have different buffer cache management strategies. -// Currently, we have 3 strategies: +// Currently, we have 5 strategies: // - Disabled: no cache. always allocate a new buffer and release it immediately after use. // - LazyRelease: no cache. the difference from Disabled is that it delays the release of buffers until the next refresh. // - Simple: a simple cache that always keeps buffers. when a buffer is requested, it tries to find a buffer in the cache. // - Bucket: a cache that keeps buffers in different buckets based on the buffer size, with a maximum number of buffers in each bucket. -// +// - Graph: used for graph capturing storage buffer cache mode. All buffers will be cached. Buffers can be reused across runs and in one run. +// - GraphSimple: used for graph capturing uniform buffer cache mode. All buffers will be cached. Buffers can be reused across runs but can't be reused in one run. class IBufferCacheManager { public: virtual ~IBufferCacheManager() = default; @@ -49,7 +59,7 @@ class IBufferCacheManager { virtual void ReleaseBuffer(WGPUBuffer buffer) = 0; // when a stream refresh is requested - virtual void OnRefresh() = 0; + virtual void OnRefresh(const SessionState& session_status) = 0; }; // @@ -58,16 +68,16 @@ class IBufferCacheManager { class BufferManager { public: BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode); - - void Upload(void* src, WGPUBuffer dst, size_t size); - void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size); - WGPUBuffer Create(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst); + void Upload(void* src, WGPUBuffer dst, size_t size) const; + void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const; + WGPUBuffer Create(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst) const; // Create a buffer mapped for writing. - WGPUBuffer CreateUMA(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | - wgpu::BufferUsage::CopyDst); - void Release(WGPUBuffer buffer); - void Download(WGPUBuffer src, void* dst, size_t size); - void RefreshPendingBuffers(); + WGPUBuffer CreateUMA(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst) const; + // Check if CreateUMA is supported (i.e., the device has BufferMapExtendedUsages feature) + bool SupportsUMA() const; + void Release(WGPUBuffer buffer) const; + void Download(WGPUBuffer src, void* dst, size_t size) const; + void RefreshPendingBuffers(const SessionState& session_status) const; private: IBufferCacheManager& GetCacheManager(wgpu::BufferUsage usage) const; diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index 1713a9a1ad050..904a2885ff619 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -5,6 +5,8 @@ #include "core/providers/webgpu/compute_context.h" #include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/allocator.h" +#include "core/providers/webgpu/buffer_manager.h" namespace onnxruntime { namespace webgpu { @@ -26,5 +28,12 @@ Status ComputeContext::PopErrorScope() { return Status::OK(); } +const webgpu::BufferManager& ComputeContext::BufferManager() const { + OrtDevice gpu_device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0); + AllocatorPtr allocator = kernel_context_.GetAllocator(gpu_device); + const GpuBufferAllocator* gpu_allocator = static_cast(allocator.get()); + return gpu_allocator->GetBufferManager(); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 7a9cf1ecf85ba..c9bdb3a92f162 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -20,6 +20,7 @@ class Tensor; namespace webgpu { class WebGpuContext; +class BufferManager; class ComputeContext { public: @@ -115,7 +116,6 @@ class ComputeContext { ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceAllocator(&allocator)); return {data_type, std::forward(shape), allocator}; } - // // Run a compute shader program. // @@ -123,6 +123,11 @@ class ComputeContext { return webgpu_context_.Run(*this, program); } + // + // Get the buffer manager from the GPU allocator. + // + const webgpu::BufferManager& BufferManager() const; + // // Push error scope. // diff --git a/onnxruntime/core/providers/webgpu/data_transfer.cc b/onnxruntime/core/providers/webgpu/data_transfer.cc index ac376b4fce069..6d66a7308f1de 100644 --- a/onnxruntime/core/providers/webgpu/data_transfer.cc +++ b/onnxruntime/core/providers/webgpu/data_transfer.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/providers/webgpu/data_transfer.h" -#include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/buffer_manager.h" namespace onnxruntime { namespace webgpu { @@ -25,15 +25,15 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { if (dst_device.Type() == OrtDevice::GPU) { if (src_device.Type() == OrtDevice::GPU) { // copy from GPU to GPU - context_.BufferManager().MemCpy(static_cast(const_cast(src_data)), - static_cast(dst_data), bytes); + buffer_manager_.MemCpy(static_cast(const_cast(src_data)), + static_cast(dst_data), bytes); } else { // copy from CPU to GPU - context_.BufferManager().Upload(const_cast(src_data), static_cast(dst_data), bytes); + buffer_manager_.Upload(const_cast(src_data), static_cast(dst_data), bytes); } } else /* if (src_device.Type() == OrtDevice::GPU) */ { // copy from GPU to CPU - context_.BufferManager().Download(static_cast(const_cast(src_data)), dst_data, bytes); + buffer_manager_.Download(static_cast(const_cast(src_data)), dst_data, bytes); } } diff --git a/onnxruntime/core/providers/webgpu/data_transfer.h b/onnxruntime/core/providers/webgpu/data_transfer.h index f9949576aa60b..0adf380149acf 100644 --- a/onnxruntime/core/providers/webgpu/data_transfer.h +++ b/onnxruntime/core/providers/webgpu/data_transfer.h @@ -9,11 +9,11 @@ namespace onnxruntime { namespace webgpu { -class WebGpuContext; +class BufferManager; class DataTransfer : public IDataTransfer { public: - DataTransfer(const WebGpuContext& context) : context_{context} {}; + DataTransfer(const BufferManager& buffer_manager) : buffer_manager_{buffer_manager} {}; ~DataTransfer() {}; bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; @@ -21,7 +21,7 @@ class DataTransfer : public IDataTransfer { common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; private: - const WebGpuContext& context_; + const BufferManager& buffer_manager_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 4bb41c2eb0ba6..09cbe8b52f2cc 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -401,6 +401,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; WGPUBuffer uniform_buffer = nullptr; + const webgpu::BufferManager& buffer_mgr = context.BufferManager(); if (uniform_buffer_total_size > 0) { std::vector uniform_data_buffer(uniform_buffer_total_size); @@ -408,7 +409,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { memcpy(uniform_data_buffer.data() + offset, uniform.data.data(), uniform.data.size()); } - uniform_buffer = buffer_mgr_->Create(uniform_buffer_total_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); + uniform_buffer = buffer_mgr.Create(uniform_buffer_total_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); device_queue_.WriteBuffer(uniform_buffer, 0, uniform_data_buffer.data(), uniform_buffer_total_size); } @@ -429,13 +430,11 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { } LaunchComputePipeline(compute_pass_encoder, bind_buffers, *program_artifact, x, y, z); - if (uniform_buffer) { - buffer_mgr_->Release(uniform_buffer); + buffer_mgr.Release(uniform_buffer); } WriteTimestamp(num_pending_dispatches_ * 2 + 1); - ++num_pending_dispatches_; if (num_pending_dispatches_ >= max_num_pending_dispatches_ || @@ -443,7 +442,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { EndComputePass(); } if (num_pending_dispatches_ >= max_num_pending_dispatches_) { - Flush(); + Flush(buffer_mgr); num_pending_dispatches_ = 0; } @@ -659,7 +658,7 @@ Status WebGpuContext::PopErrorScope() { return status; } -void WebGpuContext::Flush() { +void WebGpuContext::Flush(const webgpu::BufferManager& buffer_mgr) { if (!current_command_encoder_) { return; } @@ -690,10 +689,11 @@ void WebGpuContext::Flush() { pending_queries_.emplace_back(std::move(pending_kernels_), query_read_buffer); pending_kernels_.clear(); } - auto command_buffer = current_command_encoder_.Finish(); device_queue_.Submit(1, &command_buffer); - BufferManager().RefreshPendingBuffers(); + if (session_status_ != SessionState::Replaying) { + buffer_mgr.RefreshPendingBuffers(session_status_); + } current_command_encoder_ = nullptr; num_pending_dispatches_ = 0; } @@ -724,15 +724,90 @@ void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& comput bind_group_desc.label = {program_artifact.name.data(), program_artifact.name.length()}; auto bind_group = wgpuDeviceCreateBindGroup(Device().Get(), &bind_group_desc); + if (session_status_ == SessionState::Capturing) { + external_captured_commands_->push_back({program_artifact.compute_pipeline, + bind_group, + bind_group_layout, + {x, y, z}}); + } else { + compute_pass_encoder.SetPipeline(program_artifact.compute_pipeline); + wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, bind_group, 0, nullptr); + compute_pass_encoder.DispatchWorkgroups(x, y, z); + + wgpuBindGroupRelease(bind_group); + wgpuBindGroupLayoutRelease(bind_group_layout); + } +} + +void WebGpuContext::CaptureBegin(std::vector* captured_commands, const webgpu::BufferManager& buffer_manager) { + LOGS_DEFAULT(VERBOSE) << "CaptureBegin with external storage"; + // Flush any pending commands before we change the status + Flush(buffer_manager); + + external_captured_commands_ = captured_commands; + + // Make sure the external vector is empty before we start capturing + if (external_captured_commands_) { + external_captured_commands_->clear(); + } + + // TODO: support profiling with graph capture. + ORT_ENFORCE(!is_profiling_, "profiling is not supported yet under graph capture mode"); + + session_status_ = SessionState::Capturing; +} + +void WebGpuContext::Replay(const std::vector& captured_commands, const webgpu::BufferManager& buffer_manager) { + LOGS_DEFAULT(VERBOSE) << "Replay with external storage"; + session_status_ = SessionState::Replaying; + // Replay all captured commands from the provided vector + const size_t command_count = captured_commands.size(); + for (size_t i = 0; i < command_count; ++i) { + auto& command = captured_commands[i]; + const auto& compute_pass_encoder = GetComputePassEncoder(); + WriteTimestamp(num_pending_dispatches_ * 2); + compute_pass_encoder.SetPipeline(command.compute_pipeline); + wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, command.bind_group, 0, nullptr); + compute_pass_encoder.DispatchWorkgroups(command.dispatch_group[0], command.dispatch_group[1], command.dispatch_group[2]); + WriteTimestamp(num_pending_dispatches_ * 2 + 1); + ++num_pending_dispatches_; + if (num_pending_dispatches_ >= max_num_pending_dispatches_ || + (is_profiling_ && query_type_ == TimestampQueryType::AtPasses)) { + EndComputePass(); + } + if (num_pending_dispatches_ >= max_num_pending_dispatches_) { + Flush(buffer_manager); + num_pending_dispatches_ = 0; + } + } + + // Flush any remaining commands + Flush(buffer_manager); + + session_status_ = SessionState::Default; +} + +void WebGpuContext::CaptureEnd() { + LOGS_DEFAULT(VERBOSE) << "CaptureEnd"; - // TODO support graph capture + session_status_ = SessionState::Default; + external_captured_commands_ = nullptr; +} - compute_pass_encoder.SetPipeline(program_artifact.compute_pipeline); - wgpuComputePassEncoderSetBindGroup(compute_pass_encoder.Get(), 0, bind_group, 0, nullptr); - compute_pass_encoder.DispatchWorkgroups(x, y, z); +void WebGpuContext::ReleaseGraphResources(std::vector& captured_commands) { + LOGS_DEFAULT(VERBOSE) << "ReleaseGraphResources: Releasing " << captured_commands.size() << " captured command resources"; - wgpuBindGroupRelease(bind_group); - wgpuBindGroupLayoutRelease(bind_group_layout); + for (auto& command : captured_commands) { + if (command.bind_group != nullptr) { + wgpuBindGroupRelease(command.bind_group); + command.bind_group = nullptr; + } + + if (command.bind_group_layout != nullptr) { + wgpuBindGroupLayoutRelease(command.bind_group_layout); + command.bind_group_layout = nullptr; + } + } } std::unordered_map WebGpuContextFactory::contexts_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 4111f809b1627..935c99e0a11f6 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -10,7 +10,6 @@ #include "core/common/common.h" #include "core/framework/library_handles.h" -#include "core/providers/webgpu/webgpu_execution_provider.h" #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/program_manager.h" @@ -26,6 +25,14 @@ class WebGpuContext; class ComputeContext; class ProgramBase; +// Definition for CapturedCommandInfo in the webgpu namespace +struct CapturedCommandInfo { + wgpu::ComputePipeline compute_pipeline; + WGPUBindGroup bind_group; + WGPUBindGroupLayout bind_group_layout; + std::array dispatch_group; +}; + struct WebGpuContextConfig { int context_id; WGPUInstance instance; @@ -118,8 +125,12 @@ class WebGpuContext final { current_compute_pass_encoder_ = nullptr; } } + void CaptureBegin(std::vector* captured_commands, const webgpu::BufferManager& buffer_manager); + void CaptureEnd(); + void Replay(const std::vector& captured_commands, const webgpu::BufferManager& buffer_manager); + void ReleaseGraphResources(std::vector& captured_commands); - void Flush(); + void Flush(const webgpu::BufferManager& buffer_mgr); webgpu::BufferManager& BufferManager() const { return *buffer_mgr_; } @@ -243,6 +254,10 @@ class WebGpuContext final { uint64_t gpu_timestamp_offset_ = 0; bool is_profiling_ = false; bool preserve_device_; + SessionState session_status_{SessionState::Default}; + + // External vector to store captured commands, owned by EP + std::vector* external_captured_commands_ = nullptr; #if defined(ENABLE_PIX_FOR_WEBGPU_EP) std::unique_ptr pix_frame_generator_ = nullptr; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 13c746a6b1d31..460d220ecf1b9 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -772,11 +772,21 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, context_{context}, preferred_data_layout_{config.data_layout}, force_cpu_node_names_{std::move(config.force_cpu_node_names)}, - enable_graph_capture_{config.enable_graph_capture} {} + enable_graph_capture_{config.enable_graph_capture} { + // If graph capture is enabled, create a dedicated buffer manager for graph mode + if (enable_graph_capture_) { + // Create buffer manager for graph capture mode with appropriate cache modes + graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create( + context_, + webgpu::BufferCacheMode::Graph, + webgpu::BufferCacheMode::GraphSimple, + webgpu::BufferCacheMode::Disabled); + } +} std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo gpuBufferAllocatorCreationInfo([&](int) { - return std::make_unique(context_); + return std::make_unique(BufferManager()); }, 0, false); auto preferred_allocators = std::vector{CreateAllocator(gpuBufferAllocatorCreationInfo)}; @@ -846,7 +856,7 @@ std::shared_ptr WebGpuExecutionProvider::GetKernelRegistry() con } std::unique_ptr WebGpuExecutionProvider::GetDataTransfer() const { - return std::make_unique(context_); + return std::make_unique(BufferManager()); } #if defined(__wasm__) @@ -871,6 +881,12 @@ std::optional WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::s } WebGpuExecutionProvider::~WebGpuExecutionProvider() { + // Release all resources associated with the captured graph + if (!captured_commands_.empty()) { + context_.ReleaseGraphResources(captured_commands_); + } + // The graph_buffer_mgr_ will be automatically cleaned up by unique_ptr + WebGpuContextFactory::ReleaseContext(context_id_); } @@ -897,23 +913,24 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_ } if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { - ORT_NOT_IMPLEMENTED("graph capture not implemented"); + context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_); } + return Status::OK(); } Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /*run_options*/) { + context_.Flush(BufferManager()); + if (IsGraphCaptureEnabled() && !IsGraphCaptured(0)) { if (IsGraphCaptureAllowed()) { - ORT_NOT_IMPLEMENTED("graph capture not implemented"); - // is_graph_captured_ = true; + context_.CaptureEnd(); + is_graph_captured_ = true; } else { IncrementRegularRunCountBeforeGraphCapture(); } } - context_.Flush(); - if (profiler_->Enabled()) { context_.CollectProfilingData(profiler_->Events()); } @@ -937,10 +954,18 @@ bool WebGpuExecutionProvider::IsGraphCaptured(int) const { Status WebGpuExecutionProvider::ReplayGraph(int) { ORT_ENFORCE(IsGraphCaptured(0)); - ORT_ENFORCE(false); + context_.Replay(captured_commands_, *graph_buffer_mgr_); return Status::OK(); } +webgpu::BufferManager& WebGpuExecutionProvider::BufferManager() const { + if (graph_buffer_mgr_) { + return *graph_buffer_mgr_; + } else { + return context_.BufferManager(); + } +} + bool WebGpuExecutionProvider::IsGraphCaptureAllowed() const { return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; } diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 2003f9b2ebcc6..2567be2a1eb18 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -8,6 +8,7 @@ #include "core/framework/session_options.h" #include "core/graph/constants.h" #include "core/providers/providers.h" +#include "core/providers/webgpu/buffer_manager.h" struct pthreadpool; namespace onnxruntime { @@ -18,9 +19,11 @@ template KernelCreateInfo BuildKernelCreateInfo(); class WebGpuContext; -enum class BufferCacheMode; class WebGpuProfiler; class GpuBufferAllocator; + +// Forward declare CapturedCommandInfo which is now defined in webgpu_context.h +struct CapturedCommandInfo; } // namespace webgpu struct WebGpuExecutionProviderConfig { @@ -81,10 +84,12 @@ class WebGpuExecutionProvider : public IExecutionProvider { bool IsGraphCaptureEnabled() const override; bool IsGraphCaptured(int graph_annotation_id) const override; Status ReplayGraph(int graph_annotation_id) override; + webgpu::BufferManager& BufferManager() const; private: bool IsGraphCaptureAllowed() const; void IncrementRegularRunCountBeforeGraphCapture(); + int context_id_; webgpu::WebGpuContext& context_; webgpu::WebGpuProfiler* profiler_ = nullptr; @@ -95,6 +100,12 @@ class WebGpuExecutionProvider : public IExecutionProvider { int regular_run_count_before_graph_capture_ = 0; const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. webgpu::GpuBufferAllocator* allocator_ = nullptr; + + // Buffer manager specifically for graph capture mode + std::unique_ptr graph_buffer_mgr_ = nullptr; + + // Store captured commands directly in the EP instead of in WebGpuContext + std::vector captured_commands_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index d6812b2d0704d..80b3988215c6b 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -220,10 +220,12 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( webgpu::WebGpuBufferCacheConfig buffer_cache_config; - buffer_cache_config.storage.mode = parse_buffer_cache_mode(kStorageBufferCacheMode, webgpu::BufferCacheMode::Bucket); + buffer_cache_config.storage.mode = parse_buffer_cache_mode(kStorageBufferCacheMode, + webgpu::BufferCacheMode::Bucket); LOGS_DEFAULT(VERBOSE) << "WebGPU EP storage buffer cache mode: " << buffer_cache_config.storage.mode; - buffer_cache_config.uniform.mode = parse_buffer_cache_mode(kUniformBufferCacheMode, webgpu::BufferCacheMode::Simple); + buffer_cache_config.uniform.mode = parse_buffer_cache_mode(kUniformBufferCacheMode, + webgpu::BufferCacheMode::Simple); LOGS_DEFAULT(VERBOSE) << "WebGPU EP uniform buffer cache mode: " << buffer_cache_config.uniform.mode; buffer_cache_config.query_resolve.mode = parse_buffer_cache_mode(kQueryResolveBufferCacheMode, webgpu::BufferCacheMode::Disabled); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 2ce3c4859394d..add9fa6a504c9 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -217,7 +217,7 @@ static void CreateMatMulModel(std::unique_ptr& p_model, Prov if (provider_type == kCpuExecutionProvider) { node.SetExecutionProviderType(provider_type); } else { -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) node.SetExecutionProviderType(provider_type); #endif } @@ -286,55 +286,89 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, bool is_preallocate_output_vec, ProviderType allocation_provider, IExecutionProvider* gpu_provider, - OrtDevice* output_device) { + OrtDevice* output_device, + bool enable_graph_capture) { std::unique_ptr io_binding; Status st = session_object.NewIOBinding(&io_binding); ASSERT_TRUE(st.IsOK()); - auto input_allocator = io_binding->GetCPUAllocator(bind_provider_type); // bind a value to A with input that will produce invalid output in order to test replacement of a feed std::vector values_mul_x_tmp = {12.f, 11.f, 10.f, 9.f, 8.f, 7.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f}; std::vector dims_mul_x_A_tmp = {3, 4}; - OrtValue input_tmp; - CreateMLValue(input_allocator, dims_mul_x_A_tmp, values_mul_x_tmp, &input_tmp); - ASSERT_STATUS_OK(io_binding->BindInput("A", input_tmp)); - const void* tmp_A = io_binding->GetInputs()[0].Get().DataRaw(); // location of data post binding - - // prepare inputs std::vector values_mul_x = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f}; - - /* - 0 1 2 3 0 1 2 - 4 5 6 7 3 4 5 - 8 9 10 11 6 7 8 - 9 10 11 - */ - // bind one input to cpu allocator from bind_provider_type, and another on user provided CPU memory - // so both code pathes are covered - OrtValue input_ml_value_A; std::vector dims_mul_x_A = {3, 4}; - CreateMLValue(input_allocator, dims_mul_x_A, values_mul_x, &input_ml_value_A); - - OrtValue input_ml_value_B; std::vector dims_mul_x_B = {4, 3}; - CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_mul_x_B, values_mul_x, - &input_ml_value_B); - - ASSERT_STATUS_OK(io_binding->BindInput("A", input_ml_value_A)); - ASSERT_STATUS_OK(io_binding->BindInput("B", input_ml_value_B)); - // check location of 'A' post-binding has changed to validate that the previous value was replaced - ASSERT_TRUE(io_binding->GetInputs()[0].Get().DataRaw() != tmp_A); + auto cpu_alloc = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; + onnxruntime::AllocatorPtr gpu_alloc = nullptr; + if (allocation_provider == kWebGpuExecutionProvider) { + // Use session_object.GetAllocator to get the OrtAllocator for WebGPU. + // Otherwise, gpu_provider->CreatePreferredAllocators() will create a new OrtAllocator which will go to the create UMA path. + // And it can't be used for copying buffer to buffer since the target buffer is still in mapped state. + OrtMemoryInfo mem_info(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0)); + gpu_alloc = session_object.GetAllocator(mem_info); + } else if (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider) { + gpu_alloc = gpu_provider->CreatePreferredAllocators()[0]; + } + if (enable_graph_capture) { + // For graph capture, all inputs/outputs should be in preallocated gpu memory. + ASSERT_TRUE(is_preallocate_output_vec); + OrtValue input_ml_value_A_cpu; + CreateMLValue(cpu_alloc, dims_mul_x_A, values_mul_x, &input_ml_value_A_cpu); + auto& cpu_tensor_a = input_ml_value_A_cpu.Get(); + Tensor gpu_tensor_a(cpu_tensor_a.DataType(), cpu_tensor_a.Shape(), gpu_alloc); + st = gpu_provider->GetDataTransfer()->CopyTensor(cpu_tensor_a, gpu_tensor_a); + ASSERT_TRUE(st.IsOK()); + OrtValue input_ml_value_A; + Tensor::InitOrtValue(std::move(gpu_tensor_a), input_ml_value_A); + + OrtValue input_ml_value_B_cpu; + CreateMLValue(cpu_alloc, dims_mul_x_B, values_mul_x, &input_ml_value_B_cpu); + auto& cpu_tensor_b = input_ml_value_B_cpu.Get(); + Tensor gpu_tensor_b(cpu_tensor_b.DataType(), cpu_tensor_b.Shape(), gpu_alloc); + st = gpu_provider->GetDataTransfer()->CopyTensor(cpu_tensor_b, gpu_tensor_b); + ASSERT_TRUE(st.IsOK()); + OrtValue input_ml_value_B; + Tensor::InitOrtValue(std::move(gpu_tensor_b), input_ml_value_B); + ASSERT_STATUS_OK(io_binding->BindInput("A", input_ml_value_A)); + ASSERT_STATUS_OK(io_binding->BindInput("B", input_ml_value_B)); + } else { + auto input_allocator = io_binding->GetCPUAllocator(bind_provider_type); + OrtValue input_tmp; + CreateMLValue(input_allocator, dims_mul_x_A_tmp, values_mul_x_tmp, &input_tmp); + ASSERT_STATUS_OK(io_binding->BindInput("A", input_tmp)); + const void* tmp_A = io_binding->GetInputs()[0].Get().DataRaw(); // location of data post binding + + // prepare inputs + /* + 0 1 2 3 0 1 2 + 4 5 6 7 3 4 5 + 8 9 10 11 6 7 8 + 9 10 11 + */ + // bind one input to cpu allocator from bind_provider_type, and another on user provided CPU memory + // so both code pathes are covered + OrtValue input_ml_value_A; + CreateMLValue(input_allocator, dims_mul_x_A, values_mul_x, &input_ml_value_A); + + OrtValue input_ml_value_B; + CreateMLValue(cpu_alloc, dims_mul_x_B, values_mul_x, &input_ml_value_B); + + ASSERT_STATUS_OK(io_binding->BindInput("A", input_ml_value_A)); + ASSERT_STATUS_OK(io_binding->BindInput("B", input_ml_value_B)); + + // check location of 'A' post-binding has changed to validate that the previous value was replaced + ASSERT_TRUE(io_binding->GetInputs()[0].Get().DataRaw() != tmp_A); + } // prepare outputs std::vector expected_output_dims = {3, 3}; OrtValue output_ml_value; if (is_preallocate_output_vec) { if (allocation_provider == kCpuExecutionProvider) { - AllocateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], expected_output_dims, - &output_ml_value); - } else if (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider) { - AllocateMLValue(gpu_provider->CreatePreferredAllocators()[0], expected_output_dims, &output_ml_value); + AllocateMLValue(cpu_alloc, expected_output_dims, &output_ml_value); + } else if (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider || allocation_provider == kWebGpuExecutionProvider) { + AllocateMLValue(gpu_alloc, expected_output_dims, &output_ml_value); } else { ORT_THROW("Unsupported provider"); } @@ -351,6 +385,7 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, // prepare expected inputs and outputs std::vector expected_values_mul_y = {42, 48, 54, 114, 136, 158, 186, 224, 262}; + std::vector expected_values_mul_y_2 = {174, 216, 258, 102, 128, 154, 30, 40, 50}; // Now run st = session_object.Run(run_options, *io_binding.get()); @@ -358,24 +393,24 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, std::cout << "Run returned status: " << st.ErrorMessage() << std::endl; ASSERT_TRUE(st.IsOK()); - if ((is_preallocate_output_vec && (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider)) || + if ((is_preallocate_output_vec && (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider || allocation_provider == kWebGpuExecutionProvider)) || (output_device && output_device->Type() == OrtDevice::GPU)) { -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) // in this case we need to copy the tensor from cuda to cpu std::vector& outputs = io_binding->GetOutputs(); ASSERT_EQ(1u, outputs.size()); auto& rtensor = outputs.front().Get(); auto element_type = rtensor.DataType(); auto& shape = rtensor.Shape(); - auto cpu_allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; - std::unique_ptr cpu_tensor = std::make_unique(element_type, - shape, - cpu_allocator); + std::unique_ptr cpu_tensor = std::make_unique(element_type, shape, cpu_alloc); #ifdef USE_CUDA st = GetProviderInfo_CUDA().CreateGPUDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); #endif #ifdef USE_ROCM st = GetProviderInfo_ROCM().CreateGPUDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); +#endif +#ifdef USE_WEBGPU + st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); #endif ASSERT_TRUE(st.IsOK()); OrtValue ml_value; @@ -385,11 +420,40 @@ void RunModelWithBindingMatMul(InferenceSession& session_object, VerifyOutputs({ml_value}, expected_output_dims, expected_values_mul_y); #endif } else { - if (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider) { + if (allocation_provider == kCudaExecutionProvider || allocation_provider == kRocmExecutionProvider || allocation_provider == kWebGpuExecutionProvider) { ASSERT_STATUS_OK(gpu_provider->Sync()); } VerifyOutputs(io_binding->GetOutputs(), expected_output_dims, expected_values_mul_y); } + + if (enable_graph_capture) { + // Update input_a's value. Run again. Replay the captured graph + OrtValue input_a2; + CreateMLValue(cpu_alloc, dims_mul_x_A_tmp, values_mul_x_tmp, &input_a2); + auto& cpu_tensor_a2 = input_a2.Get(); + st = gpu_provider->GetDataTransfer()->CopyTensor(cpu_tensor_a2, const_cast(io_binding->GetInputs()[0].Get())); + ASSERT_TRUE(st.IsOK()); + + st = session_object.Run(run_options, *io_binding.get()); + + std::cout << "Run returned status: " << st.ErrorMessage() << std::endl; + ASSERT_TRUE(st.IsOK()); + + // Copy the tensor from gpu to cpu + std::vector& outputs = io_binding->GetOutputs(); + ASSERT_EQ(1u, outputs.size()); + auto& rtensor = outputs.front().Get(); + auto element_type = rtensor.DataType(); + auto& shape = rtensor.Shape(); + std::unique_ptr cpu_tensor = std::make_unique(element_type, shape, cpu_alloc); + st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); + ASSERT_TRUE(st.IsOK()); + OrtValue ml_value; + ml_value.Init(cpu_tensor.release(), + DataTypeImpl::GetType(), + DataTypeImpl::GetType()->GetDeleteFunc()); + VerifyOutputs({ml_value}, expected_output_dims, expected_values_mul_y_2); + } } TEST(InferenceSessionTests, NoTimeout) { @@ -1059,16 +1123,16 @@ static void TestBindHelper(const std::string& log_str, ProviderType run_provider_type, bool preallocate_output, ProviderType allocation_provider = kCpuExecutionProvider, - OrtDevice* output_device = nullptr) { + OrtDevice* output_device = nullptr, + bool enable_graph_capture = false) { SessionOptions so; so.session_logid = "InferenceSessionTests." + log_str; so.session_log_verbosity_level = 1; // change to 1 for detailed logging - InferenceSession session_object{so, GetEnvironment()}; IExecutionProvider* gpu_provider{}; - if (bind_provider_type == kCudaExecutionProvider || bind_provider_type == kRocmExecutionProvider) { + if (bind_provider_type == kCudaExecutionProvider || bind_provider_type == kRocmExecutionProvider || bind_provider_type == kWebGpuExecutionProvider) { #ifdef USE_CUDA auto provider = DefaultCudaExecutionProvider(); gpu_provider = provider.get(); @@ -1078,6 +1142,15 @@ static void TestBindHelper(const std::string& log_str, auto provider = DefaultRocmExecutionProvider(); gpu_provider = provider.get(); ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(provider))); +#endif +#ifdef USE_WEBGPU + ConfigOptions config_options{}; + ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kEnableGraphCapture, + enable_graph_capture ? webgpu::options::kEnableGraphCapture_ON : webgpu::options::kEnableGraphCapture_OFF) + .IsOK()); + auto provider = WebGpuExecutionProviderWithOptions(config_options); + gpu_provider = provider.get(); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(provider))); #endif } @@ -1100,7 +1173,8 @@ static void TestBindHelper(const std::string& log_str, preallocate_output, allocation_provider, gpu_provider, - output_device); + output_device, + enable_graph_capture); } TEST(InferenceSessionTests, TestBindCpu) { @@ -1187,12 +1261,15 @@ TEST(InferenceSessionTests, InvalidInputTypeOfTensorElement) { ASSERT_TRUE(!st.IsOK()); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) #if USE_CUDA constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider; #elif USE_ROCM constexpr const char* kGpuExecutionProvider = kRocmExecutionProvider; +#elif USE_WEBGPU +constexpr const char* kGpuExecutionProvider = kWebGpuExecutionProvider; #endif + TEST(InferenceSessionTests, TestBindCuda) { TestBindHelper("TestBindCuda", kGpuExecutionProvider, @@ -1223,7 +1300,7 @@ TEST(InferenceSessionTests, TestBindCudaPreallocateOutputOnCpu2) { true /* preallocate output on CPU */, kCpuExecutionProvider); } - +#ifndef USE_WEBGPU TEST(InferenceSessionTests, TestBindCudaSpecifyOutputDeviceOnCuda) { OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0); @@ -1234,7 +1311,17 @@ TEST(InferenceSessionTests, TestBindCudaSpecifyOutputDeviceOnCuda) { kGpuExecutionProvider, &device /* specify output device */); } - +#else +TEST(InferenceSessionTests, TestGraphCapture) { + TestBindHelper("TestGraphCapture", + kGpuExecutionProvider, + kGpuExecutionProvider, + true /* preallocate output on GPU */, + kGpuExecutionProvider, + nullptr, + true /* enable graph capture*/); +} +#endif // !USE_WEBGPU #endif TEST(InferenceSessionTests, ModelWithoutOpset) { diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 81cb56d34c925..2e4aa3923b649 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -313,6 +313,15 @@ std::unique_ptr DefaultWebGpuExecutionProvider(bool is_nhwc) #endif } +std::unique_ptr WebGpuExecutionProviderWithOptions(const ConfigOptions& config_options) { +#ifdef USE_WEBGPU + return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider(); +#else + ORT_UNUSED_PARAMETER(config_options); + return nullptr; +#endif +} + std::unique_ptr DefaultCannExecutionProvider() { #ifdef USE_CANN OrtCANNProviderOptions provider_options{}; diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index ce6434991051c..67d85edb4b8ef 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -5,6 +5,7 @@ #include "core/common/optional.h" #include "core/providers/providers.h" #include "core/providers/provider_factory_creators.h" +#include "core/framework/config_options.h" #include "core/framework/execution_provider.h" namespace onnxruntime { @@ -64,6 +65,7 @@ std::unique_ptr QnnExecutionProviderWithOptions(const Provid const SessionOptions* session_options = nullptr); std::unique_ptr DefaultXnnpackExecutionProvider(); std::unique_ptr DefaultWebGpuExecutionProvider(bool is_nhwc = true); +std::unique_ptr WebGpuExecutionProviderWithOptions(const ConfigOptions& config_options); std::unique_ptr DefaultCannExecutionProvider(); std::unique_ptr DefaultDmlExecutionProvider(); From fcd448ad5e5bb0cd373d5400669509b0b1eb5fd7 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Mon, 7 Jul 2025 12:38:05 +0800 Subject: [PATCH 6/7] [WebNN] Always create a new constant for zero_points (#25286) MatMulNBits is a decomposed op in WebNN EP. Previously, we share the WebNN constant of zero_points if they have the same value and data type. However, this brings a lot of complexity for developers to fuse it back to MatMulNBits in the underlying WebNN implementation in Chromium. In this PR, we will always create a new constant for zero_points. --- .../builders/impl/matMulNBits_op_builder.cc | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc index 521aa4a4bfc5a..111d03571e974 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/matMulNBits_op_builder.cc @@ -100,20 +100,25 @@ Status MatMulNBitsBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // x_zero_point has the same shape as x_scale const bool has_zero_points = TensorExists(input_defs, 3); emscripten::val x_zero_point = emscripten::val::undefined(); + emscripten::val zero_points_desc = emscripten::val::object(); + zero_points_desc.set("dataType", emscripten::val("uint4")); + zero_points_desc.set("shape", x_scale_shape_array); + zero_points_desc.set("dimensions", x_scale_shape_array); if (has_zero_points) { // zero_points is an initializer with data type 'uint8', we need to register it as 'uint4' WebNN constant const auto zero_points_tensor = *initializers.at(input_defs[3]->Name()); - emscripten::val zero_points_desc = emscripten::val::object(); - zero_points_desc.set("dataType", emscripten::val("uint4")); - zero_points_desc.set("shape", x_scale_shape_array); - zero_points_desc.set("dimensions", x_scale_shape_array); ORT_RETURN_IF_ERROR(model_builder.RegisterConstant(zero_points_tensor, x_zero_point, zero_points_desc, logger)); } else { // zero_points' default value is 8, referred from CPU EP const int8_t default_zero_point = 8; - x_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT4, - default_zero_point, - x_scale_shape); + // Always create a new WebNN constant for zero_points to facilitate MatMulNBits fusion in Chromium + auto num_elements = (Product(x_scale_shape) + 1) / 2; + emscripten::val default_zero_point_buffer = emscripten::val::global("Uint8Array").new_(num_elements); + default_zero_point_buffer.call("fill", + emscripten::val(PackInt8ToUint8DoubledNibbles( + default_zero_point, ONNX_NAMESPACE::TensorProto_DataType_UINT4))); + x_zero_point = + model_builder.GetBuilder().call("constant", zero_points_desc, default_zero_point_buffer); } // DequantizeLinear From a3c3e2f038216d03a184e76ac32e452612a45e33 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 7 Jul 2025 02:02:43 -0700 Subject: [PATCH 7/7] [webgpu] a few optimizations to graph capture implementation (#25305) ### Description 1. rename `SessionState` to `GraphCaptureState`, since there is already one SessionState type in ORT. 2. optimize implementation of `ComputeContext::BufferManager()` --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/core/providers/webgpu/allocator.h | 3 --- .../core/providers/webgpu/buffer_manager.cc | 24 +++++++++---------- .../core/providers/webgpu/buffer_manager.h | 6 ++--- .../core/providers/webgpu/compute_context.cc | 11 ++++----- .../core/providers/webgpu/compute_context.h | 4 +++- .../core/providers/webgpu/webgpu_context.cc | 14 +++++------ .../core/providers/webgpu/webgpu_context.h | 2 +- .../core/providers/webgpu/webgpu_kernel.h | 10 ++++++-- 8 files changed, 39 insertions(+), 35 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h index 02f78c18fc947..de9b0a800ef64 100644 --- a/onnxruntime/core/providers/webgpu/allocator.h +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -26,9 +26,6 @@ class GpuBufferAllocator : public IAllocator { void GetStats(AllocatorStats* stats) override; void OnSessionInitializationEnd(); - // Return the associated BufferManager - const BufferManager& GetBufferManager() const { return buffer_manager_; } - private: AllocatorStats stats_; const BufferManager& buffer_manager_; diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index c02049f60e4db..e8140a4d59eab 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -37,7 +37,7 @@ class DisabledCacheManager : public IBufferCacheManager { wgpuBufferRelease(buffer); } - void OnRefresh(const SessionState& /*session_status*/) override { + void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { // no-op } }; @@ -59,7 +59,7 @@ class LazyReleaseCacheManager : public IBufferCacheManager { pending_buffers_.emplace_back(buffer); } - void OnRefresh(const SessionState& /*session_status*/) override { + void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { Release(); pending_buffers_.clear(); } @@ -103,7 +103,7 @@ class SimpleCacheManager : public IBufferCacheManager { pending_buffers_.emplace_back(buffer); } - void OnRefresh(const SessionState& /*session_status*/) override { + void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { for (auto& buffer : pending_buffers_) { buffers_[static_cast(wgpuBufferGetSize(buffer))].emplace_back(buffer); } @@ -196,7 +196,7 @@ class BucketCacheManager : public IBufferCacheManager { pending_buffers_.emplace_back(buffer); } - void OnRefresh(const SessionState& /*session_status*/) override { + void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { for (auto& buffer : pending_buffers_) { auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); auto it = buckets_.find(buffer_size); @@ -283,7 +283,7 @@ class GraphCacheManager : public IBufferCacheManager { pending_buffers_.emplace_back(buffer); } - void OnRefresh(const SessionState& /*session_status*/) override { + void OnRefresh(GraphCaptureState /*graph_capture_state*/) override { // Initialize buckets if they don't exist yet if (buckets_.empty()) { for (const auto& pair : buckets_limit_) { @@ -363,9 +363,9 @@ class GraphSimpleCacheManager : public IBufferCacheManager { pending_buffers_.emplace_back(buffer); } - void OnRefresh(const SessionState& session_status) override { + void OnRefresh(GraphCaptureState graph_capture_state) override { for (auto& buffer : pending_buffers_) { - if (session_status == SessionState::Default) { + if (graph_capture_state == GraphCaptureState::Default) { buffers_[static_cast(wgpuBufferGetSize(buffer))].emplace_back(buffer); } else { captured_buffers_.emplace_back(buffer); @@ -573,11 +573,11 @@ void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) const { staging_buffer.Unmap(); } -void BufferManager::RefreshPendingBuffers(const SessionState& session_status) const { - storage_cache_->OnRefresh(session_status); - uniform_cache_->OnRefresh(session_status); - query_resolve_cache_->OnRefresh(session_status); - default_cache_->OnRefresh(session_status); +void BufferManager::RefreshPendingBuffers(GraphCaptureState graph_capture_state) const { + storage_cache_->OnRefresh(graph_capture_state); + uniform_cache_->OnRefresh(graph_capture_state); + query_resolve_cache_->OnRefresh(graph_capture_state); + default_cache_->OnRefresh(graph_capture_state); } IBufferCacheManager& BufferManager::GetCacheManager(wgpu::BufferUsage usage) const { diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.h b/onnxruntime/core/providers/webgpu/buffer_manager.h index 8f7882d2f2fa8..e854139496726 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.h +++ b/onnxruntime/core/providers/webgpu/buffer_manager.h @@ -15,7 +15,7 @@ namespace webgpu { class WebGpuContext; // For command capture and replay -enum class SessionState { +enum class GraphCaptureState { Default, Capturing, Replaying @@ -59,7 +59,7 @@ class IBufferCacheManager { virtual void ReleaseBuffer(WGPUBuffer buffer) = 0; // when a stream refresh is requested - virtual void OnRefresh(const SessionState& session_status) = 0; + virtual void OnRefresh(GraphCaptureState graph_capture_state) = 0; }; // @@ -77,7 +77,7 @@ class BufferManager { bool SupportsUMA() const; void Release(WGPUBuffer buffer) const; void Download(WGPUBuffer src, void* dst, size_t size) const; - void RefreshPendingBuffers(const SessionState& session_status) const; + void RefreshPendingBuffers(GraphCaptureState graph_capture_state) const; private: IBufferCacheManager& GetCacheManager(wgpu::BufferUsage usage) const; diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index 904a2885ff619..25caa9b954fc0 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -7,12 +7,14 @@ #include "core/providers/webgpu/webgpu_context.h" #include "core/providers/webgpu/allocator.h" #include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" namespace onnxruntime { namespace webgpu { -ComputeContext::ComputeContext(OpKernelContext& kernel_context) +ComputeContext::ComputeContext(OpKernelContext& kernel_context, const WebGpuExecutionProvider& ep) : webgpu_context_{WebGpuContextFactory::GetContext(kernel_context.GetDeviceId())}, - kernel_context_{kernel_context} { + kernel_context_{kernel_context}, + ep_{ep} { } void ComputeContext::PushErrorScope() { @@ -29,10 +31,7 @@ Status ComputeContext::PopErrorScope() { } const webgpu::BufferManager& ComputeContext::BufferManager() const { - OrtDevice gpu_device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0); - AllocatorPtr allocator = kernel_context_.GetAllocator(gpu_device); - const GpuBufferAllocator* gpu_allocator = static_cast(allocator.get()); - return gpu_allocator->GetBufferManager(); + return ep_.BufferManager(); } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index c9bdb3a92f162..fe95917e4e906 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -16,6 +16,7 @@ namespace onnxruntime { class Tensor; +class WebGpuExecutionProvider; namespace webgpu { @@ -24,7 +25,7 @@ class BufferManager; class ComputeContext { public: - ComputeContext(OpKernelContext& kernel_context); + ComputeContext(OpKernelContext& kernel_context, const WebGpuExecutionProvider& ep); virtual ~ComputeContext() = default; @@ -145,6 +146,7 @@ class ComputeContext { protected: WebGpuContext& webgpu_context_; OpKernelContext& kernel_context_; + const WebGpuExecutionProvider& ep_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 09cbe8b52f2cc..4bd79a627df22 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -691,8 +691,8 @@ void WebGpuContext::Flush(const webgpu::BufferManager& buffer_mgr) { } auto command_buffer = current_command_encoder_.Finish(); device_queue_.Submit(1, &command_buffer); - if (session_status_ != SessionState::Replaying) { - buffer_mgr.RefreshPendingBuffers(session_status_); + if (graph_capture_state_ != GraphCaptureState::Replaying) { + buffer_mgr.RefreshPendingBuffers(graph_capture_state_); } current_command_encoder_ = nullptr; num_pending_dispatches_ = 0; @@ -724,7 +724,7 @@ void WebGpuContext::LaunchComputePipeline(const wgpu::ComputePassEncoder& comput bind_group_desc.label = {program_artifact.name.data(), program_artifact.name.length()}; auto bind_group = wgpuDeviceCreateBindGroup(Device().Get(), &bind_group_desc); - if (session_status_ == SessionState::Capturing) { + if (graph_capture_state_ == GraphCaptureState::Capturing) { external_captured_commands_->push_back({program_artifact.compute_pipeline, bind_group, bind_group_layout, @@ -754,12 +754,12 @@ void WebGpuContext::CaptureBegin(std::vector* captu // TODO: support profiling with graph capture. ORT_ENFORCE(!is_profiling_, "profiling is not supported yet under graph capture mode"); - session_status_ = SessionState::Capturing; + graph_capture_state_ = GraphCaptureState::Capturing; } void WebGpuContext::Replay(const std::vector& captured_commands, const webgpu::BufferManager& buffer_manager) { LOGS_DEFAULT(VERBOSE) << "Replay with external storage"; - session_status_ = SessionState::Replaying; + graph_capture_state_ = GraphCaptureState::Replaying; // Replay all captured commands from the provided vector const size_t command_count = captured_commands.size(); for (size_t i = 0; i < command_count; ++i) { @@ -784,13 +784,13 @@ void WebGpuContext::Replay(const std::vector& captu // Flush any remaining commands Flush(buffer_manager); - session_status_ = SessionState::Default; + graph_capture_state_ = GraphCaptureState::Default; } void WebGpuContext::CaptureEnd() { LOGS_DEFAULT(VERBOSE) << "CaptureEnd"; - session_status_ = SessionState::Default; + graph_capture_state_ = GraphCaptureState::Default; external_captured_commands_ = nullptr; } diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 935c99e0a11f6..3084483db522d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -254,7 +254,7 @@ class WebGpuContext final { uint64_t gpu_timestamp_offset_ = 0; bool is_profiling_ = false; bool preserve_device_; - SessionState session_status_{SessionState::Default}; + GraphCaptureState graph_capture_state_{GraphCaptureState::Default}; // External vector to store captured commands, owned by EP std::vector* external_captured_commands_ = nullptr; diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h index d7682e751d9e4..e37be2944a22b 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -9,6 +9,8 @@ #include "core/framework/op_kernel.h" namespace onnxruntime { + +class WebGpuExecutionProvider; namespace webgpu { // ----------------------------------------------------------------------- @@ -17,11 +19,12 @@ namespace webgpu { class WebGpuKernel : public OpKernel { public: explicit WebGpuKernel(const OpKernelInfo& info) - : OpKernel(info) { + : OpKernel(info), + ep_(*static_cast(info.GetExecutionProvider())) { } Status Compute(OpKernelContext* p_op_kernel_context) const override { - ComputeContext context{*p_op_kernel_context}; + ComputeContext context{*p_op_kernel_context, ep_}; context.PushErrorScope(); Status s = ComputeInternal(context); @@ -31,6 +34,9 @@ class WebGpuKernel : public OpKernel { } virtual Status ComputeInternal(ComputeContext& context) const = 0; + + private: + const WebGpuExecutionProvider& ep_; }; } // namespace webgpu