From f2454222941ab4cff1af98c354d8ec59a66b1b1d Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 1 Jul 2025 23:16:53 -0700 Subject: [PATCH 01/19] Fix NPM packaging pipeline (#25244) ### Description Fixes NPM packaging pipeline. --- .../github/azure-pipelines/templates/linux-wasm-ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index 3eae6ec9c3fdf..fa12aab6a91b2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -129,11 +129,11 @@ jobs: parameters: Today: $(Today) ${{ if eq(parameters.BuildStaticLib, true)}}: - AdditionalKey: wasm_inferencing_webgpu_exp | ${{ parameters.BuildConfig }} | static + AdditionalKey: wasm_inferencing_webgpu | ${{ parameters.BuildConfig }} | static ${{ else }}: - AdditionalKey: wasm_inferencing_webgpu_exp | ${{ parameters.BuildConfig }} + AdditionalKey: wasm_inferencing_webgpu | ${{ parameters.BuildConfig }} CacheDir: $(ORT_CACHE_DIR)/wasm_inferencing_webgpu - Arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)/wasm_inferencing_webgpu --use_webgpu --use_jsep --use_webnn --target onnxruntime_webassembly --skip_tests' + Arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)/wasm_inferencing_webgpu --use_webgpu --use_webnn --target onnxruntime_webassembly --skip_tests' DisplayName: 'Build (simd + threads + WebGPU experimental)' WithCache: ${{ parameters.WithCache }} From e80cd8a3ef9b519d125c3acc221e84319911dc0c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 2 Jul 2025 08:52:37 +0000 Subject: [PATCH 02/19] Bump electron from 28.1.4 to 28.3.2 in /js/web (#25241) --- js/web/package-lock.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 6f54e9dcdc944..eabb198e97177 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1067,9 +1067,9 @@ "dev": true }, "node_modules/electron": { - "version": "28.1.4", - "resolved": "https://registry.npmjs.org/electron/-/electron-28.1.4.tgz", - "integrity": "sha512-WE6go611KOhtH6efRPMnVC7FE7DCKnQ3ZyHFeI1DbaCy8OU4UjZ8/CZGcuZmZgRdxSBEHoHdgaJkWRHZzF0FOg==", + "version": "28.3.2", + "resolved": "https://registry.npmjs.org/electron/-/electron-28.3.2.tgz", + "integrity": "sha512-bmrQpdncbYNTArlg4n+qsASoXy3eeCELxeRmwUS52RNgvio1gGx5FLCwf8d4R+TsxwfkDWOaWbW0taIKheivKA==", "dev": true, "hasInstallScript": true, "dependencies": { @@ -4460,9 +4460,9 @@ "dev": true }, "electron": { - "version": "28.1.4", - "resolved": "https://registry.npmjs.org/electron/-/electron-28.1.4.tgz", - "integrity": "sha512-WE6go611KOhtH6efRPMnVC7FE7DCKnQ3ZyHFeI1DbaCy8OU4UjZ8/CZGcuZmZgRdxSBEHoHdgaJkWRHZzF0FOg==", + "version": "28.3.2", + "resolved": "https://registry.npmjs.org/electron/-/electron-28.3.2.tgz", + "integrity": "sha512-bmrQpdncbYNTArlg4n+qsASoXy3eeCELxeRmwUS52RNgvio1gGx5FLCwf8d4R+TsxwfkDWOaWbW0taIKheivKA==", "dev": true, "requires": { "@electron/get": "^2.0.0", From 0ef1b3442c79b490446fbefa9b1320394bdaa1ef Mon Sep 17 00:00:00 2001 From: Jianhui Dai Date: Wed, 2 Jul 2025 21:19:32 +0800 Subject: [PATCH 03/19] [webgpu] Add 2% tolerance to `MatMulNBits.Float32_8b_AccuracyLevel4` (#25249) ### Description Add 2% more tolerance to `MatMulNBits` accuracy level 
int8 compared with f32/f16, to fix #25231. ### Motivation and Context See above. --- onnxruntime/test/contrib_ops/matmul_8bits_test.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index bd7ee13aeae31..8151f9fb3dcc7 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -305,7 +305,11 @@ TEST(MatMulNBits, Float32_8b_AccuracyLevel4) { TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); - TestMatMul8BitsTyped(); + + // Using a 2% larger tolerance for accuracy level int8 compared to the accuracy level f32/f16. + constexpr float abs_error = 0.1f * 1.02f; + constexpr float rel_error = 0.02f * 1.02f; + TestMatMul8BitsTyped(abs_error, rel_error); } TEST(MatMulNBits, Float32_8b_AccuracyLevel1) { From f1c39d7e56fb9638191be0a7b8063eb0a8b74a0c Mon Sep 17 00:00:00 2001 From: Ashrit Shetty Date: Wed, 2 Jul 2025 10:41:42 -0700 Subject: [PATCH 04/19] Refactor LogProviderOptions (#25250) ### Description This pull request refactors the logging of provider options in the ONNX Runtime framework to improve telemetry functionality. The changes include consolidating logging logic, introducing platform-independent methods, and enhancing the telemetry interface for better extensibility. --- .../core/framework/execution_providers.h | 30 ++++++++--------- onnxruntime/core/platform/telemetry.cc | 8 +++++ onnxruntime/core/platform/telemetry.h | 4 +++ .../core/platform/windows/telemetry.cc | 32 +++++++++++++++++++ onnxruntime/core/platform/windows/telemetry.h | 4 +++ 5 files changed, 61 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index 29cf79ec385d8..ad861af38f5e4 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -54,7 +54,6 @@ class ExecutionProviders { auto it = exec_provider_options_.find(provider_id); if (it != exec_provider_options_.end()) { const auto& options = it->second; - LogProviderOptions(provider_id, options, true); } } @@ -97,22 +96,6 @@ class ExecutionProviders { return Status::OK(); } -#ifdef _WIN32 - void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions, bool captureState) { - for (const auto& config_pair : providerOptions) { - TraceLoggingWrite( - telemetry_provider_handle, - "ProviderOptions", - TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingString(provider_id.c_str(), "ProviderId"), - TraceLoggingString(config_pair.first.c_str(), "Key"), - TraceLoggingString(config_pair.second.c_str(), "Value"), - TraceLoggingBool(captureState, "isCaptureState")); - } - } -#endif - const IExecutionProvider* Get(const onnxruntime::Node& node) const { return Get(node.GetExecutionProviderType()); } @@ -157,6 +140,19 @@ class ExecutionProviders { // with a container that has unique_ptr or something move-only. 
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ExecutionProviders); + void LogProviderOptions(const std::string& provider_id, const ProviderOptions& options, bool capture_state) { + const Env& env = Env::Default(); + // Convert ProviderOptions to string for telemetry logging + std::string provider_options_str; + for (const auto& config_pair : options) { + if (!provider_options_str.empty()) { + provider_options_str += ","; + } + provider_options_str += config_pair.first + ":" + config_pair.second; + } + env.GetTelemetryProvider().LogProviderOptions(provider_id, provider_options_str, capture_state); + } + std::vector> exec_providers_; std::vector exec_provider_ids_; ProviderOptionsMap exec_provider_options_; diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index 6754e2471f52c..9cf89a04f031c 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -116,4 +116,12 @@ void Telemetry::LogAutoEpSelection(uint32_t session_id, const std::string& selec ORT_UNUSED_PARAMETER(available_execution_provider_ids); } +void Telemetry::LogProviderOptions(const std::string& provider_id, + const std::string& provider_options_string, + bool captureState) const { + ORT_UNUSED_PARAMETER(provider_id); + ORT_UNUSED_PARAMETER(provider_options_string); + ORT_UNUSED_PARAMETER(captureState); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index 0103588f0e0d7..cb7a6176e5aec 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -82,6 +82,10 @@ class Telemetry { const std::vector& requested_execution_provider_ids, const std::vector& available_execution_provider_ids) const; + virtual void LogProviderOptions(const std::string& provider_id, + const std::string& provider_options_string, + bool captureState) const; + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Telemetry); }; diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 47c9d0d75df16..44ef44a3f5aff 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -478,4 +478,36 @@ void WindowsTelemetry::LogAutoEpSelection(uint32_t session_id, const std::string TraceLoggingString(available_execution_provider_string.c_str(), "availableExecutionProviderIds")); } +void WindowsTelemetry::LogProviderOptions(const std::string& provider_id, const std::string& provider_options_string, bool captureState) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + // Difference is MeasureEvent & isCaptureState, but keep in sync otherwise + if (!captureState) { + TraceLoggingWrite(telemetry_provider_handle, + "ProviderOptions", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingString(provider_id.c_str(), "providerId"), + TraceLoggingString(provider_options_string.c_str(), "providerOptions")); + } else { + TraceLoggingWrite(telemetry_provider_handle, + "ProviderOptions_CaptureState", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), + // Not a measure event + 
TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingString(provider_id.c_str(), "providerId"), + TraceLoggingString(provider_options_string.c_str(), "providerOptions")); + } +} + } // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 787c8ba2d5e7f..7281063d50c2e 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -75,6 +75,10 @@ class WindowsTelemetry : public Telemetry { const std::vector& requested_execution_provider_ids, const std::vector& available_execution_provider_ids) const override; + void LogProviderOptions(const std::string& provider_id, + const std::string& provider_options_string, + bool captureState) const override; + using EtwInternalCallback = std::function; From 102c3f6a368bcf65e673370105a420c6e6cb8113 Mon Sep 17 00:00:00 2001 From: Jeff Kilpatrick Date: Wed, 2 Jul 2025 10:43:02 -0700 Subject: [PATCH 05/19] [QNN EP] Add qnn_version to build_and_package_info.py (#25229) ### Description This makes the QAIRT/QNN version available in the Python client as `onnxruntime.capi.build_and_package_info.qnn_version`, similar to how it's already done for `cuda_version` and `rcom_version`. ### Motivation and Context Users in some situations need to bring their own QAIRT/QNN SDK. In these cases, it is important to know the correct version to supply to ensure compatibility. --- setup.py | 8 ++++++-- tools/ci_build/build.py | 6 ++++++ .../ci_build/github/android/build_aar_package.py | 11 ++--------- tools/python/util/__init__.py | 1 + tools/python/util/qnn_helpers.py | 15 +++++++++++++++ 5 files changed, 30 insertions(+), 11 deletions(-) create mode 100644 tools/python/util/qnn_helpers.py diff --git a/setup.py b/setup.py index c13f8160cb112..709df50fc0d52 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ def parse_arg_remove_string(argv, arg_name_equal): is_migraphx = False is_openvino = False is_qnn = False +qnn_version = None # The following arguments are mutually exclusive if wheel_name_suffix == "gpu": # TODO: how to support multiple CUDA versions? 
@@ -88,6 +89,7 @@ def parse_arg_remove_string(argv, arg_name_equal): elif parse_arg_remove_boolean(sys.argv, "--use_qnn"): is_qnn = True package_name = "onnxruntime-qnn" + qnn_version = parse_arg_remove_string(sys.argv, "--qnn_version=") elif is_migraphx: package_name = "onnxruntime-migraphx" if not nightly_build else "ort-migraphx-nightly" @@ -734,7 +736,7 @@ def reformat_run_count(count_str): install_requires = f.read().splitlines() -def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version): +def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version, qnn_version): sys.path.append(path.join(path.dirname(__file__), "onnxruntime", "python")) from onnxruntime_collect_build_info import find_cudart_versions @@ -763,9 +765,11 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm ) elif rocm_version: f.write(f"rocm_version = '{rocm_version}'\n") + elif qnn_version: + f.write(f"qnn_version = '{qnn_version}'\n") -save_build_and_package_info(package_name, version_number, cuda_version, rocm_version) +save_build_and_package_info(package_name, version_number, cuda_version, rocm_version, qnn_version) extras_require = {} if package_name == "onnxruntime-gpu" and is_cuda_version_12: diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index f0782dff23345..add6be8fb2f77 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -40,6 +40,7 @@ def version_to_tuple(version: str) -> tuple: is_linux, is_macOS, is_windows, + parse_qnn_version_from_sdk_yaml, run, ) @@ -1883,6 +1884,7 @@ def build_python_wheel( use_cann, use_azure, use_qnn, + qnn_home, wheel_name_suffix, enable_training, nightly_build=False, @@ -1940,6 +1942,9 @@ def build_python_wheel( args.append("--use_cann") elif use_qnn: args.append("--use_qnn") + qnn_version = parse_qnn_version_from_sdk_yaml(qnn_home) + if qnn_version: + args.append(f"--qnn_version={qnn_version}") elif use_azure: args.append("--use_azure") @@ -2560,6 +2565,7 @@ def main(): args.use_cann, args.use_azure, args.use_qnn, + args.qnn_home, args.wheel_name_suffix, args.enable_training, nightly_build=nightly_build, diff --git a/tools/ci_build/github/android/build_aar_package.py b/tools/ci_build/github/android/build_aar_package.py index 24eac06e95e5d..7e7017d4d3c87 100644 --- a/tools/ci_build/github/android/build_aar_package.py +++ b/tools/ci_build/github/android/build_aar_package.py @@ -16,7 +16,7 @@ JAVA_ROOT = os.path.join(REPO_DIR, "java") sys.path.insert(0, os.path.join(REPO_DIR, "tools", "python")) -from util import is_windows # noqa: E402 +from util import is_windows, parse_qnn_version_from_sdk_yaml # noqa: E402 # We by default will build all 4 ABIs DEFAULT_BUILD_ABIS = ["armeabi-v7a", "arm64-v8a", "x86", "x86_64"] @@ -102,14 +102,7 @@ def _build_aar(args): if qnn_android_build: qnn_home = args.qnn_path - sdk_file = os.path.join(qnn_home, "sdk.yaml") - qnn_sdk_version = None - with open(sdk_file) as f: - for line in f: - if line.strip().startswith("version:"): - # yaml file has simple key: value format with version as key - qnn_sdk_version = line.split(":", 1)[1].strip() - break + qnn_sdk_version = parse_qnn_version_from_sdk_yaml(qnn_home) # Note: The QNN package version does not follow Semantic Versioning (SemVer) format. 
# only use major.minor.patch version for qnn sdk version and truncate the build_id info if any diff --git a/tools/python/util/__init__.py b/tools/python/util/__init__.py index 8631218ca9e00..5790373760042 100644 --- a/tools/python/util/__init__.py +++ b/tools/python/util/__init__.py @@ -4,6 +4,7 @@ from .get_azcopy import get_azcopy # noqa: F401 from .logger import get_logger from .platform_helpers import is_linux, is_macOS, is_windows # noqa: F401 +from .qnn_helpers import parse_qnn_version_from_sdk_yaml # noqa: F401 from .run import run # noqa: F401 from .vcpkg_helpers import ( # noqa: F401 generate_android_triplets, diff --git a/tools/python/util/qnn_helpers.py b/tools/python/util/qnn_helpers.py new file mode 100644 index 0000000000000..ff08205e2518d --- /dev/null +++ b/tools/python/util/qnn_helpers.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + + +def parse_qnn_version_from_sdk_yaml(qnn_home): + sdk_file = os.path.join(qnn_home, "sdk.yaml") + with open(sdk_file) as f: + for line in f: + if line.strip().startswith("version:"): + # yaml file has simple key: value format with version as key + return line.split(":", 1)[1].strip() + return None From 4d3949bb86194ae06965696408113309da80f531 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Wed, 2 Jul 2025 13:00:20 -0700 Subject: [PATCH 06/19] Separate TRT and TRT RTX directory usage (#25248) ### Description This enables to build TRT and TRT RTX in the same ORT build. --- cmake/onnxruntime_providers_nv.cmake | 44 +++++++++++----------- cmake/onnxruntime_providers_tensorrt.cmake | 8 ++-- tools/ci_build/build.py | 13 +++++-- tools/ci_build/build_args.py | 1 + 4 files changed, 36 insertions(+), 30 deletions(-) diff --git a/cmake/onnxruntime_providers_nv.cmake b/cmake/onnxruntime_providers_nv.cmake index a804f2d7ae55c..e59463b6b91f1 100644 --- a/cmake/onnxruntime_providers_nv.cmake +++ b/cmake/onnxruntime_providers_nv.cmake @@ -17,7 +17,7 @@ endif () add_definitions("-DONNX_ML=1") add_definitions("-DONNX_NAMESPACE=onnx") set(CUDA_INCLUDE_DIRS ${CUDAToolkit_INCLUDE_DIRS}) - set(TENSORRT_ROOT ${onnxruntime_TENSORRT_HOME}) + set(TENSORRT_RTX_ROOT ${onnxruntime_TENSORRT_RTX_HOME}) set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) set(PROTOBUF_LIBRARY ${PROTOBUF_LIB}) if (WIN32) @@ -34,12 +34,12 @@ endif () endif() set(CXX_VERSION_DEFINED TRUE) - find_path(TENSORRT_INCLUDE_DIR NvInfer.h - HINTS ${TENSORRT_ROOT} + find_path(TENSORRT_RTX_INCLUDE_DIR NvInfer.h + HINTS ${TENSORRT_RTX_ROOT} PATH_SUFFIXES include) - file(READ ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h NVINFER_VER_CONTENT) + file(READ ${TENSORRT_RTX_INCLUDE_DIR}/NvInferVersion.h NVINFER_VER_CONTENT) string(REGEX MATCH "define TRT_MAJOR_RTX * +([0-9]+)" NV_TRT_MAJOR_RTX "${NVINFER_VER_CONTENT}") string(REGEX REPLACE "define TRT_MAJOR_RTX * +([0-9]+)" "\\1" NV_TRT_MAJOR_RTX "${NV_TRT_MAJOR_RTX}") string(REGEX MATCH "define TRT_MINOR_RTX * +([0-9]+)" NV_TRT_MINOR_RTX "${NVINFER_VER_CONTENT}") @@ -54,37 +54,37 @@ endif () endif() if (WIN32) - set(NVINFER_LIB "tensorrt_rtx_${NV_TRT_MAJOR_RTX}_${NV_TRT_MINOR_RTX}") - set(PARSER_LIB "tensorrt_onnxparser_rtx_${NV_TRT_MAJOR_RTX}_${NV_TRT_MINOR_RTX}") + set(TRT_RTX_LIB "tensorrt_rtx_${NV_TRT_MAJOR_RTX}_${NV_TRT_MINOR_RTX}") + set(RTX_PARSER_LIB "tensorrt_onnxparser_rtx_${NV_TRT_MAJOR_RTX}_${NV_TRT_MINOR_RTX}") endif() - if (NOT NVINFER_LIB) - set(NVINFER_LIB "tensorrt_rtx") + if (NOT 
TRT_RTX_LIB) + set(TRT_RTX_LIB "tensorrt_rtx") endif() - if (NOT PARSER_LIB) - set(PARSER_LIB "tensorrt_onnxparser_rtx") + if (NOT RTX_PARSER_LIB) + set(RTX_PARSER_LIB "tensorrt_onnxparser_rtx") endif() - MESSAGE(STATUS "Looking for ${NVINFER_LIB}") + MESSAGE(STATUS "Looking for ${TRT_RTX_LIB}") - find_library(TENSORRT_LIBRARY_INFER ${NVINFER_LIB} - HINTS ${TENSORRT_ROOT} + find_library(TENSORRT_LIBRARY_INFER ${TRT_RTX_LIB} + HINTS ${TENSORRT_RTX_ROOT} PATH_SUFFIXES lib lib64 lib/x64) if (NOT TENSORRT_LIBRARY_INFER) - MESSAGE(STATUS "Can't find ${NVINFER_LIB}") + MESSAGE(STATUS "Can't find ${TRT_RTX_LIB}") endif() if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) - MESSAGE(STATUS "Looking for ${PARSER_LIB}") + MESSAGE(STATUS "Looking for ${RTX_PARSER_LIB}") - find_library(TENSORRT_LIBRARY_NVONNXPARSER ${PARSER_LIB} - HINTS ${TENSORRT_ROOT} + find_library(TENSORRT_LIBRARY_NVONNXPARSER ${RTX_PARSER_LIB} + HINTS ${TENSORRT_RTX_ROOT} PATH_SUFFIXES lib lib64 lib/x64) if (NOT TENSORRT_LIBRARY_NVONNXPARSER) - MESSAGE(STATUS "Can't find ${PARSER_LIB}") + MESSAGE(STATUS "Can't find ${RTX_PARSER_LIB}") endif() set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_NVONNXPARSER}) @@ -104,7 +104,6 @@ endif () # The onnx_tensorrt repo contains a test program, getSupportedAPITest, which doesn't support Windows. It uses # unistd.h. So we must exclude it from our build. onnxruntime_fetchcontent_makeavailable is for the purpose. onnxruntime_fetchcontent_makeavailable(onnx_tensorrt) - include_directories(${onnx_tensorrt_SOURCE_DIR}) set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) if ( CMAKE_COMPILER_IS_GNUCC ) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") @@ -114,9 +113,9 @@ endif () unset(PROTOBUF_LIBRARY) unset(OLD_CMAKE_CXX_FLAGS) unset(OLD_CMAKE_CUDA_FLAGS) - set_target_properties(${PARSER_LIB} PROPERTIES LINK_FLAGS "/ignore:4199") + set_target_properties(${RTX_PARSER_LIB} PROPERTIES LINK_FLAGS "/ignore:4199") target_compile_options(nvonnxparser_static PRIVATE /FIio.h /wd4100) - target_compile_options(${PARSER_LIB} PRIVATE /FIio.h /wd4100) + target_compile_options(${RTX_PARSER_LIB} PRIVATE /FIio.h /wd4100) endif() # Static libraries are just nvonnxparser_static on all platforms set(onnxparser_link_libs nvonnxparser_static) @@ -124,7 +123,6 @@ endif () MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") endif() - include_directories(${TENSORRT_INCLUDE_DIR}) # ${TENSORRT_LIBRARY} is empty if we link nvonnxparser_static. 
# nvonnxparser_static is linked against tensorrt libraries in onnx-tensorrt # See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121 @@ -152,7 +150,7 @@ endif () else() target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart) endif() - target_include_directories(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} + target_include_directories(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${TENSORRT_RTX_INCLUDE_DIR} ${onnx_tensorrt_SOURCE_DIR} PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 3698aaa902922..69c81a5ec7b9d 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -138,7 +138,6 @@ # The onnx_tensorrt repo contains a test program, getSupportedAPITest, which doesn't support Windows. It uses # unistd.h. So we must exclude it from our build. onnxruntime_fetchcontent_makeavailable is for the purpose. onnxruntime_fetchcontent_makeavailable(onnx_tensorrt) - include_directories(${onnx_tensorrt_SOURCE_DIR}) set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) if ( CMAKE_COMPILER_IS_GNUCC ) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") @@ -158,7 +157,6 @@ MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") endif() - include_directories(${TENSORRT_INCLUDE_DIR}) # ${TENSORRT_LIBRARY} is empty if we link nvonnxparser_static. 
# nvonnxparser_static is linked against tensorrt libraries in onnx-tensorrt # See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121 @@ -197,9 +195,11 @@ else() target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart) endif() - target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} + target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${TENSORRT_INCLUDE_DIR} PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) - + if (NOT onnxruntime_USE_TENSORRT_BUILTIN_PARSER) + target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnx_tensorrt_SOURCE_DIR}) + endif() # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found set_target_properties(onnxruntime_providers_tensorrt PROPERTIES LINKER_LANGUAGE CUDA) set_target_properties(onnxruntime_providers_tensorrt PROPERTIES FOLDER "ONNXRuntime") diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index add6be8fb2f77..fab11fa4e4d22 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -350,6 +350,7 @@ def generate_build_tree( rocm_home, nccl_home, tensorrt_home, + tensorrt_rtx_home, migraphx_home, acl_home, acl_libs, @@ -722,8 +723,10 @@ def generate_build_tree( cmake_args.append("-Donnxruntime_ROCM_HOME=" + rocm_home) cmake_args.append("-Donnxruntime_ROCM_VERSION=" + args.rocm_version) - if args.use_tensorrt or args.use_nv_tensorrt_rtx: + if args.use_tensorrt: cmake_args.append("-Donnxruntime_TENSORRT_HOME=" + tensorrt_home) + if args.use_nv_tensorrt_rtx: + cmake_args.append("-Donnxruntime_TENSORRT_RTX_HOME=" + tensorrt_rtx_home) if args.use_cuda: nvcc_threads = number_of_nvcc_threads(args) @@ -1383,7 +1386,7 @@ def setup_cann_vars(args): def setup_tensorrt_vars(args): tensorrt_home = "" - if args.use_tensorrt or args.use_nv_tensorrt_rtx: + if args.use_tensorrt: tensorrt_home = args.tensorrt_home if args.tensorrt_home else os.getenv("TENSORRT_HOME") tensorrt_home_valid = tensorrt_home is not None and os.path.exists(tensorrt_home) if not tensorrt_home_valid: @@ -2352,7 +2355,10 @@ def main(): # if using tensorrt, setup tensorrt paths tensorrt_home = "" - if args.use_tensorrt or args.use_nv_tensorrt_rtx: + tensorrt_rtx_home = "" + if args.use_nv_tensorrt_rtx: + tensorrt_rtx_home = args.tensorrt_rtx_home + if args.use_tensorrt: tensorrt_home = setup_tensorrt_vars(args) # if using migraphx, setup migraphx paths @@ -2495,6 +2501,7 @@ def main(): rocm_home, nccl_home, tensorrt_home, + tensorrt_rtx_home, migraphx_home, acl_home, acl_libs, diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index 22f9cc054006e..7448ebe931d1e 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -648,6 +648,7 @@ def add_execution_provider_args(parser: argparse.ArgumentParser) -> None: ) trt_group.add_argument("--use_tensorrt_oss_parser", action="store_true", help="Use TensorRT OSS ONNX parser.") trt_group.add_argument("--tensorrt_home", help="Path to TensorRT installation directory.") + trt_group.add_argument("--tensorrt_rtx_home", help="Path to TensorRT RTX installation directory.") # --- Nv --- nv_group = parser.add_argument_group("Nv Execution Provider") From d81905edfdd845d05640f409c2a5596bfafe8c68 Mon Sep 17 00:00:00 2001 From: 
chenweng-quic <168707118+chenweng-quic@users.noreply.github.com> Date: Thu, 3 Jul 2025 04:26:45 +0800 Subject: [PATCH 07/19] [QNN EP] Improve QNN EP UDO support for QDQ model (#25194) ### Description - Add handling for input DQ and output Q. To avoid dummy qnn Quantize/Dequantize node translated, move udo translation from builder to qnn_node_group. - Change dlopen flag or HTP is not able to locate symbol. ### Motivation and Context To improve the UDO support for QDQ model. --- .../qnn/builder/op_builder_factory.cc | 7 +- .../qnn/builder/op_builder_factory.h | 16 - .../qnn/builder/opbuilder/base_op_builder.h | 47 --- .../qnn/builder/opbuilder/udo_builder.cc | 134 -------- .../qnn/builder/qnn_backend_manager.cc | 4 +- .../qnn/builder/qnn_backend_manager.h | 13 +- .../providers/qnn/builder/qnn_model_wrapper.h | 48 +++ .../builder/qnn_node_group/qnn_node_group.cc | 46 ++- .../builder/qnn_node_group/qnn_node_group.h | 4 + .../qnn/builder/qnn_node_group/udo_fusion.cc | 317 ++++++++++++++++++ .../qnn/builder/qnn_node_group/udo_fusion.h | 69 ++++ .../core/providers/qnn/builder/qnn_utils.cc | 4 +- 12 files changed, 481 insertions(+), 228 deletions(-) delete mode 100644 onnxruntime/core/providers/qnn/builder/opbuilder/udo_builder.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/udo_fusion.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/udo_fusion.h diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index 529f6ce824033..b27bbf1ed2f13 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -10,8 +10,6 @@ namespace onnxruntime { namespace qnn { -static OpBuilderRegistrations op_registrations; - OpBuilderRegistrations::OpBuilderRegistrations() { { CreateSimpleOpBuilder("Add", *this); @@ -201,11 +199,8 @@ OpBuilderRegistrations::OpBuilderRegistrations() { } } -void RegisterUDOBuilder(const std::string& op_type, const std::string& op_package) { - CreateUDOBuilder(op_type, op_package, op_registrations); -} - const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) { + static const OpBuilderRegistrations op_registrations; return op_registrations.GetOpBuilderByOnnxOpType(onnx_op_type); } diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index 4fa9ec6cc0fe1..b8eb0584b342a 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -41,20 +41,6 @@ class OpBuilderRegistrations { } } - void RegisterUDOBuilder(const std::string& op_type, std::unique_ptr builder) { - auto builder_type = builder->GetOpBuilderType(); - auto pos_in_builder_type_map = builder_type_builder_map_.find(builder_type); - if (pos_in_builder_type_map != builder_type_builder_map_.end()) { - // already have this builder type, re-use it for this op_type - op_builder_map_[op_type] = pos_in_builder_type_map->second; - } else { - // New Op builder, add to vector and all the maps - builders_.push_back(std::move(builder)); - op_builder_map_[op_type] = builders_.back().get(); - builder_type_builder_map_[builder_type] = builders_.back().get(); - } - } - private: std::vector> builders_; // @@ -62,7 +48,6 @@ class OpBuilderRegistrations { // std::unordered_map builder_type_builder_map_; }; -void RegisterUDOBuilder(const std::string& op_type, const 
std::string& op_package); const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type); @@ -128,6 +113,5 @@ void CreateCumSumOpBuilder(const std::string& op_type, OpBuilderRegistrations& o void CreateMeanOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); -void CreateUDOBuilder(const std::string& op_type, const std::string& op_package, OpBuilderRegistrations& op_registrations); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 2ccfad206a38a..e009bc558e884 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -107,53 +107,6 @@ class BaseOpBuilder : public IOpBuilder { const logging::Logger& logger, std::vector& input_names) const ORT_MUST_USE_RESULT; - template - Status AddQnnScalar(QnnModelWrapper& qnn_model_wrapper, - const NodeIndex& node_index, - const std::string& node_name, - const T& scalar, - const std::string& qnn_scalar_param_name, - std::vector& param_names) const { - Qnn_Scalar_t qnn_scalar = QNN_SCALAR_INIT; - if (std::is_same::value) { - qnn_scalar.dataType = QNN_DATATYPE_FLOAT_32; - qnn_scalar.floatValue = static_cast(scalar); - } else if (std::is_same::value) { - qnn_scalar.dataType = QNN_DATATYPE_UINT_32; - qnn_scalar.uint32Value = static_cast(scalar); - } else if (std::is_same::value) { - qnn_scalar.dataType = QNN_DATATYPE_INT_32; - qnn_scalar.int32Value = static_cast(scalar); - } else if (std::is_same::value) { - qnn_scalar.dataType = QNN_DATATYPE_INT_64; - qnn_scalar.int64Value = static_cast(scalar); - } else if (std::is_same::value) { - qnn_scalar.dataType = QNN_DATATYPE_BOOL_8; - qnn_scalar.bool8Value = static_cast(scalar); - } else { - ORT_RETURN_IF(true, "QNN EP: Unsupported scalar dtype"); - } - QnnParamWrapper qnn_param_wrapper(node_index, node_name, qnn_scalar_param_name, qnn_scalar); - param_names.push_back(qnn_param_wrapper.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(qnn_param_wrapper)); - return Status::OK(); - } - - Status AddQnnScalar(QnnModelWrapper& qnn_model_wrapper, - const NodeIndex& node_index, - const std::string& node_name, - const std::string& scalar, - const std::string& qnn_scalar_param_name, - std::vector& param_names) const { - Qnn_Scalar_t qnn_scalar = QNN_SCALAR_INIT; - qnn_scalar.dataType = QNN_DATATYPE_STRING; - qnn_scalar.stringValue = scalar.c_str(); - QnnParamWrapper qnn_param_wrapper(node_index, node_name, qnn_scalar_param_name, qnn_scalar); - param_names.push_back(qnn_param_wrapper.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(qnn_param_wrapper)); - return Status::OK(); - } - Status SetOutputQParamEqualToInputIfNearlyEqual(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/udo_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/udo_builder.cc deleted file mode 100644 index 339c521952bcf..0000000000000 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/udo_builder.cc +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include -#include -#include -#include - -#include "core/providers/qnn/ort_api.h" -#include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" -#include "core/providers/qnn/builder/qnn_utils.h" - -namespace onnxruntime { -namespace qnn { - -class UDOBuilder : public BaseOpBuilder { - public: - UDOBuilder(const std::string& op_type, const std::string& op_package) : BaseOpBuilder(op_type + "_UDOBuilder"), op_type_(op_type), op_package_(op_package) {} - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(UDOBuilder); - - protected: - Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const override ORT_MUST_USE_RESULT; - - private: - const std::string op_type_; - const std::string op_package_; -}; - -Status UDOBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const { - ORT_UNUSED_PARAMETER(logger); - std::string node_name = utils::GetNodeName(node_unit); - const auto& outputs = node_unit.Outputs(); - std::vector output_names; - for (size_t i = 0; i < outputs.size(); ++i) { - const auto& output_name = outputs[i].node_arg.Name(); - - TensorInfo output_info = {}; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(outputs[i], output_info)); - bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_name); - - Qnn_TensorType_t tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; - QnnTensorWrapper output_tensorwrapper(output_name, - tensor_type, - output_info.qnn_data_type, - std::move(output_info.quant_param), - std::move(output_info.shape)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); - output_names.emplace_back(output_name); - } - std::vector param_names; - NodeAttrHelper node_helper(node_unit); - auto& attrs = node_unit.GetNode().GetAttributes(); - for (auto& attr : attrs) { - std::string attr_name = attr.first; - auto& attr_value = attr.second; - LOGS(logger, VERBOSE) << "Parse attr name: " << attr_name << " for op " << node_name; - switch (attr_value.type()) { - case ONNX_NAMESPACE::AttributeProto::FLOAT: { - auto optional_float = node_helper.GetFloat(attr_name); - ORT_RETURN_IF_NOT(optional_float.has_value(), - "Failed to get values from attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, optional_float.value(), attr_name, param_names)); - break; - } - case ONNX_NAMESPACE::AttributeProto::FLOATS: { - auto optional_floats = node_helper.GetFloats(attr_name); - ORT_RETURN_IF_NOT(optional_floats.has_value(), - "Failed to get values from attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); - std::vector floats_data(optional_floats.value().begin(), optional_floats.value().end()); - auto param_wrapper = createQnnParamWrapper(node_unit.Index(), node_name, attr_name, - {static_cast(floats_data.size())}, std::move(floats_data)); - param_names.push_back(param_wrapper.GetParamTensorName()); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddParamWrapper(std::move(param_wrapper)), - "Failed to add tensor attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); - break; - } - case 
ONNX_NAMESPACE::AttributeProto::INT: { - auto optional_int64 = node_helper.GetInt64(attr_name); - ORT_RETURN_IF_NOT(optional_int64.has_value(), - "Failed to get values from attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, optional_int64.value(), attr_name, param_names)); - break; - } - case ONNX_NAMESPACE::AttributeProto::INTS: { - auto optional_int64s = node_helper.GetInt64s(attr_name); - ORT_RETURN_IF_NOT(optional_int64s.has_value(), - "Failed to get values from attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); - std::vector int64s_data(optional_int64s.value().begin(), optional_int64s.value().end()); - auto param_wrapper = createQnnParamWrapper(node_unit.Index(), node_name, attr_name, - {static_cast(int64s_data.size())}, std::move(int64s_data)); - param_names.push_back(param_wrapper.GetParamTensorName()); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddParamWrapper(std::move(param_wrapper)), - "Failed to add tensor attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); - break; - } - case ONNX_NAMESPACE::AttributeProto::STRING: { - auto optional_string = node_helper.GetString(attr_name); - ORT_RETURN_IF_NOT(optional_string.has_value(), - "Failed to get values from attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); - ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, optional_string.value(), attr_name, param_names)); - break; - } - default: { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to add scalar attr ", attr_name, " data_type ", attr_value.type(), " in op ", node_name, " to qnn_model_wrapper."); - } - } - } - - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, - op_package_, - op_type_, - std::move(input_names), - std::move(output_names), - std::move(param_names), - do_op_validation), - "Failed to add node."); - return Status::OK(); -} - -void CreateUDOBuilder(const std::string& op_type, const std::string& op_package, OpBuilderRegistrations& op_registrations) { - op_registrations.RegisterUDOBuilder(op_type, std::make_unique(op_type, op_package)); -} - -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 2c7098749a985..d22edaf33eb1c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -215,7 +215,7 @@ Status QnnBackendManager::GetQnnInterfaceProvider(const char* lib_path, T** interface_provider) { std::string error_msg; *backend_lib_handle = LoadLib(lib_path, - static_cast(DlOpenFlag::DL_NOW) | static_cast(DlOpenFlag::DL_LOCAL), + static_cast(DlOpenFlag::DL_NOW) | static_cast(DlOpenFlag::DL_GLOBAL), error_msg); ORT_RETURN_IF(nullptr == *backend_lib_handle, "Unable to load backend, error: ", error_msg, " ", DlError()); @@ -816,7 +816,7 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord it.second.get()}; QnnContext_Params_t context_params = {QnnContext_ParamsVersion_t::QNN_CONTEXT_PARAMS_VERSION_1, - context_params_v1}; + {context_params_v1}}; buffer_list.push_back(std::move(buffer)); context_params_list.push_back(std::move(context_params)); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 371dc6dd4fc4a..3e68df3024565 100644 --- 
a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -27,6 +27,7 @@ #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_context_mem_handle_manager.h" #include "core/providers/qnn/builder/qnn_def.h" +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" namespace onnxruntime { namespace qnn { @@ -388,13 +389,13 @@ class QnnBackendManager : public std::enable_shared_from_this } ORT_RETURN_IF(QNN_SUCCESS != result, "Failed to register op package to backend. Error: ", QnnErrorHandleToString(result)); LOGS(*logger_, VERBOSE) << "Successfully register the op package."; - std::string op_package_for_registration = std::filesystem::path(op_package.path).stem().string(); - // remove lib prefix in Linux - std::string prefix = "lib"; - if (op_package_for_registration.compare(0, prefix.size(), prefix) == 0) { - op_package_for_registration = op_package_for_registration.substr(prefix.size()); + std::string op_package_for_registration = op_package.interface; + std::string suffix = "InterfaceProvider"; + if (op_package_for_registration.size() >= suffix.size() && + op_package_for_registration.compare(op_package_for_registration.size() - suffix.size(), suffix.size(), suffix) == 0) { + op_package_for_registration.erase(op_package_for_registration.size() - suffix.size()); } - qnn::RegisterUDOBuilder(op_package.op_type, op_package_for_registration); + registerUDO(op_package.op_type, op_package_for_registration); } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index 745dfde7bfac8..f940351b7626e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -5,6 +5,7 @@ #include #include +#include #include #include "QnnInterface.h" @@ -332,5 +333,52 @@ class QnnModelWrapper { utils::QnnJSONGraph json_qnn_graph_; }; // QnnModelWrapper +template +inline Status AddQnnScalar(QnnModelWrapper& qnn_model_wrapper, + const NodeIndex& node_index, + const std::string& node_name, + const T& scalar, + const std::string& qnn_scalar_param_name, + std::vector& param_names) { + Qnn_Scalar_t qnn_scalar = QNN_SCALAR_INIT; + if (std::is_same::value) { + qnn_scalar.dataType = QNN_DATATYPE_FLOAT_32; + qnn_scalar.floatValue = static_cast(scalar); + } else if (std::is_same::value) { + qnn_scalar.dataType = QNN_DATATYPE_UINT_32; + qnn_scalar.uint32Value = static_cast(scalar); + } else if (std::is_same::value) { + qnn_scalar.dataType = QNN_DATATYPE_INT_32; + qnn_scalar.int32Value = static_cast(scalar); + } else if (std::is_same::value) { + qnn_scalar.dataType = QNN_DATATYPE_INT_64; + qnn_scalar.int64Value = static_cast(scalar); + } else if (std::is_same::value) { + qnn_scalar.dataType = QNN_DATATYPE_BOOL_8; + qnn_scalar.bool8Value = static_cast(scalar); + } else { + ORT_RETURN_IF(true, "QNN EP: Unsupported scalar dtype"); + } + QnnParamWrapper qnn_param_wrapper(node_index, node_name, qnn_scalar_param_name, qnn_scalar); + param_names.push_back(qnn_param_wrapper.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(qnn_param_wrapper)); + return Status::OK(); +} + +inline Status AddQnnScalar(QnnModelWrapper& qnn_model_wrapper, + const NodeIndex& node_index, + const std::string& node_name, + const std::string& scalar, + const std::string& qnn_scalar_param_name, + std::vector& param_names) { + Qnn_Scalar_t 
qnn_scalar = QNN_SCALAR_INIT; + qnn_scalar.dataType = QNN_DATATYPE_STRING; + qnn_scalar.stringValue = scalar.c_str(); + QnnParamWrapper qnn_param_wrapper(node_index, node_name, qnn_scalar_param_name, qnn_scalar); + param_names.push_back(qnn_param_wrapper.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(qnn_param_wrapper)); + return Status::OK(); +} + } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index cc512524e4dd7..4711a7fd264b1 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -17,6 +17,7 @@ #include "core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/channel_shuffle_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/udo_fusion.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/ort_api.h" @@ -64,13 +65,36 @@ class QnnNodeUnitWrapper : public IQnnNodeGroup { /// /// The type of a function that tries to fuse NodeUnits into a IQnnNodeGroup. /// -using FusionFunc = std::unique_ptr (*)( - QnnModelWrapper&, - const NodeUnit&, - const std::unordered_map&, - const std::unordered_map&, - const logging::Logger&); - +using FusionFunc = std::function(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& udo_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger)>; + +// Maps a starting operator type to the fusion function. +static std::unordered_map fusions = { + {"DequantizeLinear", DQQFusion::TryFusion}, + {"HardSigmoid", HardSigmoidMulFusion::TryFusion}, + {"Gemm", ReshapeGemmFusion::TryFusion}, + {"Mul", ScaleSoftmaxFusion::TryFusion}, + {"Transpose", ChannelShuffleFusion::TryFusion}}; + +void registerUDO(const std::string& node_type, const std::string& op_package) { + std::function(QnnModelWrapper & qnn_model_wrapper, + const NodeUnit& udo_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger)> + boundFunction = std::bind(&UDOQDQFusion::TryFusion, + node_type, + op_package, + /*qnn_model_wrapper=*/std::placeholders::_1, + /*udo_node_unit=*/std::placeholders::_2, + /*node_to_node_unit=*/std::placeholders::_3, + /*node_unit_to_qnn_node_group=*/std::placeholders::_4, + /*logger=*/std::placeholders::_5); + fusions[node_type] = boundFunction; +} /// /// Given a starting NodeUnit, this function tries all possible fusions that start with that NodeUnit. /// If successful, returns a IQnnNodeGroup object that represents the fusion of various NodeUnits. @@ -88,14 +112,6 @@ static std::unique_ptr TryQnnFusions( const std::unordered_map& node_to_node_unit, const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger) { - // Maps a starting operator type to the fusion function. - static std::unordered_map fusions = { - {"DequantizeLinear", DQQFusion::TryFusion}, - {"HardSigmoid", HardSigmoidMulFusion::TryFusion}, - {"Gemm", ReshapeGemmFusion::TryFusion}, - {"Mul", ScaleSoftmaxFusion::TryFusion}, - {"Transpose", ChannelShuffleFusion::TryFusion}}; - // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). 
if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode) { return nullptr; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.h index 276fbaae3b3c9..d133287cc9910 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.h @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -48,6 +49,9 @@ class IQnnNodeGroup { virtual std::string_view Type() const = 0; }; +/// Function to register fusion for QDQ +void registerUDO(const std::string& node_type, const std::string& op_package); + /// /// Traverses the ONNX graph to create IQnnNodeGroup objects, each containing one or more NodeUnits. /// The returned IQnnNodeGroup objects are sorted in topological order. diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/udo_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/udo_fusion.cc new file mode 100644 index 0000000000000..9c5b82bcfd68f --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/udo_fusion.cc @@ -0,0 +1,317 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/qnn_node_group/udo_fusion.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include + +namespace onnxruntime { +namespace qnn { + +namespace { + +Status GetInputNodeUnits(const GraphViewer& graph_viewer, + const NodeUnit& node_unit, + const std::unordered_map& node_unit_map, + const std::unordered_map& qnn_node_group_map, + /*out*/ std::map& input_node_units) { + const Node& node = node_unit.GetNode(); + + // input must be of a valid type. + for (auto input_edge_iter = node.InputEdgesBegin(); input_edge_iter != node.InputEdgesEnd(); ++input_edge_iter) { + auto& input_node = (*input_edge_iter).GetNode(); + if (graph_viewer.GetNode(input_node.Index()) == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input node not exists in graph."); + } + + const auto input_node_unit_it = node_unit_map.find(&input_node); + if (input_node_unit_it == node_unit_map.end()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input node has no NodeUnit mapping."); + } + const NodeUnit* input_node_unit = input_node_unit_it->second; + + // Check if input quant node has already been handled. Should not be the case if the calling + // fusion function has been called in topological order, but check to be safe. 
+ if (input_node_unit->OpType() == DEQUANTIZE_LINEAR && qnn_node_group_map.count(input_node_unit) != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input quant node has been added"); + } + + if (input_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input node is not in single format."); + } + input_node_units[(*input_edge_iter).GetDstArgIndex()] = input_node_unit; + } + return Status::OK(); +} + +Status GetOutputNodeUnits(const GraphViewer& graph_viewer, + const NodeUnit& node_unit, + const std::unordered_map& node_unit_map, + /*out*/ std::map& output_node_units) { + const Node& node = node_unit.GetNode(); + + // Child must be of a valid type. + for (auto output_edge_iter = node.OutputEdgesBegin(); output_edge_iter != node.OutputEdgesEnd(); ++output_edge_iter) { + auto& output_node = (*output_edge_iter).GetNode(); + if (graph_viewer.GetNode(output_node.Index()) == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output node not exists in graph."); + } + + const auto output_node_unit_it = node_unit_map.find(&output_node); + if (output_node_unit_it == node_unit_map.end()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output node has no NodeUnit mapping."); + } + const NodeUnit* output_node_unit = output_node_unit_it->second; + + if (output_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output node is not in single format."); + } + output_node_units[(*output_edge_iter).GetSrcArgIndex()] = output_node_unit; + } + return Status::OK(); +} + +static Status CreateOrValidateOnQnn( + const std::string& op_type, + const std::string& op_package, + QnnModelWrapper& qnn_model_wrapper, + const std::map& input_node_units, + const NodeUnit& node_unit, + const std::map& output_node_units, + bool do_op_validation, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(do_op_validation); + + std::string node_name = utils::GetNodeName(node_unit); + + // get qnn inputs + const auto& inputs = node_unit.Inputs(); + // std::vector input_tensor_wrappers; + std::vector input_names; + for (size_t i = 0; i < inputs.size(); i++) { + const NodeUnit* input_edge_src_node_unit = (input_node_units.find(i) != input_node_units.end()) ? 
input_node_units.at(i) : nullptr; + if (!inputs[i].node_arg.Exists()) { + continue; + } + + // since input could come from initialize or graph input, which are not NodeUnit, + // we have to compare the name to get the correct order + std::string input_name = inputs[i].node_arg.Name(); + const NodeUnitIODef* input_def = &inputs[i]; + if (input_edge_src_node_unit && input_edge_src_node_unit->OpType() == DEQUANTIZE_LINEAR) { + input_name = input_edge_src_node_unit->Inputs()[0].node_arg.Name(); + input_def = &(input_edge_src_node_unit->Inputs()[0]); + } + + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) { + TensorInfo tensor_info; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(*input_def, tensor_info)); + + QnnTensorWrapper tensor_wrapper; + // input_tensor_wrappers.emplace_back(QnnTensorWrapper()); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(tensor_info, input_name, tensor_wrapper)); + // input_tensor_wrappers.emplace_back(tensor_wrapper); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(tensor_wrapper)), + "Failed to add tensor: " + input_name); + } else { + LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name; + } + + input_names.emplace_back(input_name); + } + + // get qnn outputs + const auto& outputs = node_unit.Outputs(); + // std::vector output_tensor_wrappers; + std::vector output_names; + for (size_t i = 0; i < outputs.size(); ++i) { + const NodeUnit* output_edge_dst_node_unit = (output_node_units.find(i) != output_node_units.end()) ? output_node_units.at(i) : nullptr; + if (!outputs[i].node_arg.Exists()) { + continue; + } + std::string output_name = outputs[i].node_arg.Name(); + const NodeUnitIODef* output_def = &outputs[i]; + if (output_edge_dst_node_unit && output_edge_dst_node_unit->OpType() == QUANTIZE_LINEAR) { + output_name = output_edge_dst_node_unit->Outputs()[0].node_arg.Name(); + output_def = &(output_edge_dst_node_unit->Outputs()[0]); + } + + TensorInfo output_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(*output_def, output_info)); + bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_name); + + Qnn_TensorType_t tensor_type = is_graph_output ? 
QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + QnnTensorWrapper output_tensor_wrapper(output_name, + tensor_type, + output_info.qnn_data_type, + std::move(output_info.quant_param), + std::move(output_info.shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor_wrapper)), "Failed to add tensor."); + output_names.emplace_back(output_name); + } + + // get qnn params + NodeAttrHelper node_helper(node_unit); + std::vector param_names; + for (auto& attr : node_unit.GetNode().GetAttributes()) { + std::string attr_name = attr.first; + auto& attr_value = attr.second; + LOGS(logger, VERBOSE) << "Parse attribute name: " << attr_name << " for op " << node_name; + switch (attr_value.type()) { + case ONNX_NAMESPACE::AttributeProto::FLOAT: { + auto optional_float = node_helper.GetFloat(attr_name); + ORT_RETURN_IF_NOT(optional_float.has_value(), + "Failed to get values from attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, optional_float.value(), attr_name, param_names)); + break; + } + case ONNX_NAMESPACE::AttributeProto::FLOATS: { + auto optional_floats = node_helper.GetFloats(attr_name); + ORT_RETURN_IF_NOT(optional_floats.has_value(), + "Failed to get values from attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); + std::vector floats_data(optional_floats.value().begin(), optional_floats.value().end()); + auto param_wrapper = createQnnParamWrapper(node_unit.Index(), node_name, attr_name, + {static_cast(floats_data.size())}, std::move(floats_data)); + param_names.push_back(param_wrapper.GetParamTensorName()); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddParamWrapper(std::move(param_wrapper)), + "Failed to add tensor attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); + break; + } + case ONNX_NAMESPACE::AttributeProto::INT: { + auto optional_int64 = node_helper.GetInt64(attr_name); + ORT_RETURN_IF_NOT(optional_int64.has_value(), + "Failed to get values from attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, SafeInt(optional_int64.value()), attr_name, param_names)); + break; + } + case ONNX_NAMESPACE::AttributeProto::INTS: { + auto optional_int64s = node_helper.GetInt64s(attr_name); + ORT_RETURN_IF_NOT(optional_int64s.has_value(), + "Failed to get values from attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); + std::vector int32s_data(optional_int64s.value().size(), 0); + for (size_t i = 0; i < optional_int64s.value().size(); i++) { + int32s_data[i] = SafeInt(optional_int64s.value()[i]); + } + auto param_wrapper = createQnnParamWrapper(node_unit.Index(), node_name, attr_name, + {static_cast(int32s_data.size())}, std::move(int32s_data)); + param_names.push_back(param_wrapper.GetParamTensorName()); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddParamWrapper(std::move(param_wrapper)), + "Failed to add tensor attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); + break; + } + case ONNX_NAMESPACE::AttributeProto::STRING: { + auto optional_string = node_helper.GetString(attr_name); + ORT_RETURN_IF_NOT(optional_string.has_value(), + "Failed to get values from attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, optional_string.value(), attr_name, param_names)); + break; + } + default: { + return ORT_MAKE_STATUS(ONNXRUNTIME, 
FAIL, "Failed to add scalar attr ", attr_name, " data_type ", attr_value.type(), " in op ", node_name, " to qnn_model_wrapper."); + } + } + } + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, + op_package, + op_type, + std::move(input_names), + std::move(output_names), + std::move(param_names), + do_op_validation), + "Failed to add node."); + + return Status::OK(); +} + +} // namespace + +std::unique_ptr UDOQDQFusion::TryFusion( + const std::string& op_type, + const std::string& op_package, + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& udo_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + + // find all input DequantizeLinear nodes + std::map input_node_units; + Status status = GetInputNodeUnits(graph_viewer, udo_node_unit, node_to_node_unit, node_unit_to_qnn_node_group, input_node_units); + if (!status.IsOK()) { + return nullptr; + } + + // find all output QuantizeLinear nodes + std::map output_node_units; + status = GetOutputNodeUnits(graph_viewer, udo_node_unit, node_to_node_unit, output_node_units); + if (!status.IsOK()) { + return nullptr; + } + + // Convert UDO node + status = CreateOrValidateOnQnn(op_type, op_package, qnn_model_wrapper, input_node_units, udo_node_unit, output_node_units, true, logger); + if (!status.IsOK()) { + return nullptr; + } + + return std::make_unique(op_type, op_package, input_node_units, udo_node_unit, output_node_units); +} + +UDOQDQFusion::UDOQDQFusion( + const std::string& op_type, + const std::string& op_package, const std::map& input_node_units, const NodeUnit& node_unit, const std::map& output_node_units) + : op_type_(op_type), + op_package_(op_package), + input_node_units_(input_node_units), + node_unit_(&node_unit), + output_node_units_(output_node_units) { + // only return input dq nodes/ node unit / output q nodes since they are the same group + for (auto& input_node_unit : input_node_units_) { + if (input_node_unit.second->OpType() == DEQUANTIZE_LINEAR) { + all_nodes_.push_back(input_node_unit.second); + } + } + all_nodes_.push_back(node_unit_); + for (auto& output_node_unit : output_node_units_) { + if (output_node_unit.second->OpType() == QUANTIZE_LINEAR) { + all_nodes_.push_back(output_node_unit.second); + } + } +} +Status UDOQDQFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + return CreateOrValidateOnQnn(op_type_, op_package_, qmw, input_node_units_, *node_unit_, output_node_units_, true, logger); +} + +Status UDOQDQFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + return CreateOrValidateOnQnn(op_type_, op_package_, qmw, input_node_units_, *node_unit_, output_node_units_, false, logger); +} + +gsl::span UDOQDQFusion::GetNodeUnits() const { + auto res = gsl::make_span(all_nodes_.data(), all_nodes_.size()); + return res; +} + +const NodeUnit* UDOQDQFusion::GetTargetNodeUnit() const { + return node_unit_; +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/udo_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/udo_fusion.h new file mode 100644 index 0000000000000..668590f9acd1d --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/udo_fusion.h @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; +/// +/// Represents a fusion of a DQ - UDO - Q +/// This is translated into a QNN custom op. One day this should be implemented in the QDQ actions. +/// The contained NodeUnits are of type SingleNode since they are not a part of a QDQ node unit. +/// +class UDOQDQFusion : public IQnnNodeGroup { + public: + UDOQDQFusion( + const std::string& op_type, + const std::string& op_package, + const std::map& input_dq_units, + const NodeUnit& node_unit, + const std::map& output_q_units); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(UDOQDQFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "UDOQDQFusion"; } + + /// + /// Traverses graph to check if the given starting NodeUnit is part of a valid sequence. + /// If so, returns a IQnnNodeGroup that contains the UDO, 2x DQ and Q NodeUnits. + /// + /// Used for validation and traverse/query the graph + /// DQ node unit that could start the sequence + /// Maps a Node to a NodeUnit. + /// Maps a NodeUnit to a IQnnNodeGroup. + /// + /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise + static std::unique_ptr TryFusion( + const std::string& op_type, + const std::string& op_package, + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& input_dq_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + const std::string op_type_; + const std::string op_package_; + const std::map input_node_units_; + const NodeUnit* node_unit_; + const std::map output_node_units_; + std::vector all_nodes_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index c6b0b4b6668f3..407fce4a4374c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -133,7 +133,7 @@ std::ostream& operator<<(std::ostream& out, const Qnn_Scalar_t& scalar) { out << scalar.int32Value; break; case QNN_DATATYPE_INT_64: - out << "int64_t is not supported in QNN except for UDO"; + out << "int64_t is not supported"; break; case QNN_DATATYPE_UINT_8: out << static_cast(scalar.uint8Value); @@ -145,7 +145,7 @@ std::ostream& operator<<(std::ostream& out, const Qnn_Scalar_t& scalar) { out << scalar.uint32Value; break; case QNN_DATATYPE_UINT_64: - out << "uint64_t is not supported in QNN except for UDO"; + out << "uint64_t is not supported"; break; case QNN_DATATYPE_FLOAT_16: break; From 6707dd4ec5a4b03d33852539ed16cb90ff6d751b Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Thu, 3 Jul 2025 08:39:04 +1000 Subject: [PATCH 08/19] Use non-CPU device type and id for host accessible memory (#25043) ### Description Use the non-CPU device type and id for host accessible memory to make the link between CPU and the non-CPU device explicit. Update the data transfer implementations to check vendor id. 
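
For illustration, a minimal sketch of the new convention, using the CUDA pinned allocator case as the example. The `MakeCudaPinnedDevice` helper is hypothetical and only exists to show the shape of the change; the real call sites are in the diff below.

```cpp
#include "core/framework/ortdevice.h"

// Hypothetical helper showing the new convention: host accessible (pinned) memory
// reports the GPU device type, vendor id and device id rather than CPU, so the link
// between the CPU-visible buffer and the GPU it is registered with stays explicit.
OrtDevice MakeCudaPinnedDevice(OrtDevice::DeviceId device_id) {
  return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE,
                   OrtDevice::VendorIds::NVIDIA, device_id);
}

// Consumers that only care about CPU accessibility use the new helper instead of
// checking Type() == OrtDevice::CPU:
//   if (device.UsesCpuMemory()) { /* CPU or HOST_ACCESSIBLE memory */ }
```

The data transfer implementations additionally check `OrtDevice::Vendor()` in `CanCopy`, so e.g. the CUDA `GPUDataTransfer` only claims copies that involve an NVIDIA GPU device.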
### Motivation and Context --- .../onnxruntime/core/framework/ortdevice.h | 17 +++++- .../core/framework/stream_handles.h | 8 ++- .../core/session/onnxruntime_c_api.h | 7 ++- .../core/framework/allocation_planner.cc | 54 +++++++++-------- onnxruntime/core/framework/allocator.cc | 5 +- onnxruntime/core/framework/data_transfer.cc | 2 +- .../debug_node_inputs_outputs_utils.cc | 4 +- onnxruntime/core/framework/execution_frame.cc | 2 +- onnxruntime/core/framework/session_state.cc | 8 ++- onnxruntime/core/framework/sparse_tensor.cc | 4 +- onnxruntime/core/framework/sparse_utils.cc | 6 +- onnxruntime/core/framework/utils.cc | 2 +- .../core/providers/cann/cann_allocator.h | 2 +- .../providers/cann/cann_execution_provider.cc | 5 +- .../core/providers/cann/npu_data_transfer.cc | 58 +++++++++++++------ .../core/providers/cuda/cuda_allocator.h | 6 +- .../providers/cuda/cuda_execution_provider.cc | 15 ++--- .../providers/cuda/cuda_provider_factory.cc | 6 +- .../providers/cuda/cuda_provider_factory.h | 2 +- .../core/providers/cuda/gpu_data_transfer.cc | 50 +++++++++++----- .../core/providers/cuda/tensor/reshape.cc | 9 +-- .../providers/migraphx/gpu_data_transfer.cc | 47 ++++++++++----- .../providers/migraphx/migraphx_allocator.h | 2 +- .../migraphx/migraphx_execution_provider.cc | 9 +-- .../providers/nv_tensorrt_rtx/nv_allocator.h | 6 +- .../nv_tensorrt_rtx/nv_data_transfer.cc | 51 +++++++++++----- .../nv_tensorrt_rtx/nv_execution_provider.cc | 10 ++-- .../core/providers/qnn/qnn_allocator.cc | 1 + .../providers/shared_library/provider_api.h | 4 +- .../provider_bridge_provider.cc | 4 +- .../shared_library/provider_interfaces.h | 4 +- .../tensorrt/tensorrt_execution_provider.cc | 9 ++- onnxruntime/core/session/onnxruntime_c_api.cc | 5 +- .../core/session/provider_bridge_ort.cc | 6 +- .../python/onnxruntime_pybind_state.cc | 2 +- .../test/framework/allocation_planner_test.cc | 3 +- .../cuda/test_cases/allocator_cuda_test.cc | 2 +- .../cuda_execution_provider_test.cc | 16 ++--- .../cuda/test_cases/cuda_test_provider.cc | 2 +- 39 files changed, 275 insertions(+), 180 deletions(-) diff --git a/include/onnxruntime/core/framework/ortdevice.h b/include/onnxruntime/core/framework/ortdevice.h index ffc0da918c9df..536d641b4eef9 100644 --- a/include/onnxruntime/core/framework/ortdevice.h +++ b/include/onnxruntime/core/framework/ortdevice.h @@ -13,7 +13,9 @@ #undef INTEL #endif -// Struct to represent a physical device. +// Struct to represent a combination of physical device and memory type. +// A memory allocation and allocator have a specific OrtDevice associated with them, and this information is used +// to determine when data transfer is required. struct OrtDevice { using DeviceType = int8_t; using MemoryType = int8_t; @@ -41,7 +43,13 @@ struct OrtDevice { QNN_HTP_SHARED = 4, }; - static const MemoryType HOST_ACCESSIBLE = 5; // Device memory that is accessible from host and device. + // HOST_ACCESSIBLE memory is treated as CPU memory. + // When creating an OrtDevice with MemType::HOST_ACCESSIBLE: + // - For memory that is only accessible by a specific device and CPU, use the specific device type and id. + // - When creating an OrtDevice for an EP allocator, you would typically use the same device type and id + // that the EP is registered with (i.e. the OrtDevice passed to the base IExecutionProvider constructor). + // - Otherwise use OrtDevice::CPU. + static const MemoryType HOST_ACCESSIBLE = 5; }; // PCI vendor ids @@ -101,6 +109,11 @@ struct OrtDevice { return alignment; } + // CPU or HOST_ACCESSIBLE memory. 
+ bool UsesCpuMemory() const noexcept { + return device_type == CPU || memory_type == MemType::HOST_ACCESSIBLE; + } + std::string ToString() const { std::ostringstream ostr; ostr << "Device:[" diff --git a/include/onnxruntime/core/framework/stream_handles.h b/include/onnxruntime/core/framework/stream_handles.h index 402ea2da2148c..9b9bc94105005 100644 --- a/include/onnxruntime/core/framework/stream_handles.h +++ b/include/onnxruntime/core/framework/stream_handles.h @@ -26,7 +26,9 @@ class Notification; // i.e. different cuda stream on different GPU. class Stream { public: - Stream(StreamHandle h, const OrtDevice& d) : handle_(h), device_(d) {} + Stream(StreamHandle h, const OrtDevice& d) + : handle_(h), device_(d) { + } virtual ~Stream() = default; virtual std::unique_ptr CreateNotification(size_t /*num_consumers*/) { @@ -168,8 +170,8 @@ class IStreamCommandHandleRegistry { virtual ~IStreamCommandHandleRegistry() = default; // Wait is a little special as we need to consider the source stream the notification generated, and the stream we are waiting. // i.e., for an cuda event what notify the memory copy, it could be wait on a CPU stream, or on another cuda stream. - [[nodiscard]] virtual WaitNotificationFn GetWaitHandle(OrtDevice::DeviceType notification_ower_device_type, - OrtDevice::DeviceType executor_device_type) const = 0; + [[nodiscard]] virtual WaitNotificationFn GetWaitHandle(const OrtDevice& notification_owner_device, + const OrtDevice& executor_device) const = 0; // Get the stream creation function registered on the given device type. [[nodiscard]] virtual CreateStreamFn GetCreateStreamFn(OrtDevice::DeviceType execution_device_type) const = 0; // register a wait methond which will be invoked when we wait a notification (created by 'notification_device_type' device) on a stream at 'device_type' device. diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9106cd94ad031..c9aaa38426a7b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5376,13 +5376,14 @@ struct OrtApi { * \param[in] vendor_id PCI Vendor ID. Use 0 for a generic allocator (e.g. WebGPU). * \param[in] device_id Device ID if there are multiple devices of the same type. e.g. 2 GPU devices. * \param[in] mem_type Memory type. Use OrtDeviceMemoryType_DEFAULT for device memory, and - * OrtDeviceMemoryType_HOST_ACCESSIBLE (if applicable) for memory used to transfer - * between the device and the CPU. + * OrtDeviceMemoryType_HOST_ACCESSIBLE (if applicable) for memory used to transfer between the + * device and the CPU. Use the device_type and device_id of the GPU/NPU that the memory is also + * accessible to. * \param[in] alignment Alignment of the memory if required. Pass 0 for default alignment. * \param[in] allocator_type Allocator type. If OrtAllocatorType::OrtArenaAllocator, the ORT arena will be used. * Caveat: Support for OrtArenaAllocator is currently limited to usage of internal ORT * allocators via CreateAllocator/CreateAndRegisterAllocator/CreateAndRegisterAllocatorV2. - * \param[out] out Newly created ::OrtMemoryInfo. Must be freed with OrtAPi::ReleaseMemoryInfo + * \param[out] out Newly created ::OrtMemoryInfo. 
Must be freed with OrtApi::ReleaseMemoryInfo * * \snippet{doc} snippets.dox OrtStatus Return Value * diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 62895c0137a78..a0c00b1cd26e5 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -730,22 +730,25 @@ class PlannerImpl { ProcessDef(index, graph_viewer_.GetNodeArg(pair.first)); } - // If the suggested_device is also CPU and default mem type, then - // we check which one has higher alignment and use that one if it is so. - // If the suggested device is CPU, but not the default mem type, then - // it is a CPU accessible memory device allocator. They typically have a page aligment - // so that would satisfy the alignment requirement of any other CPU consumers. - // If one device is not on CPU, we default on the one that is CPU. + // If both devices are OrtDevice::CPU or both are HOST_ACCESSIBLE we use the one with the higher alignment. + // If one is OrtDevice::CPU and one is HOST_ACCESSIBLE memory, we use the HOST_ACCESSIBLE one as that would + // typically have a page alignment and would satisfy the alignment requirement of any other CPU consumers. + // If one device is not on CPU, we default to the one that is CPU. auto determine_device = [](const OrtDevice& output_device, const OrtDevice& suggested_device) -> OrtDevice { - if (output_device.Type() == OrtDevice::CPU && suggested_device.Type() == OrtDevice::CPU) { - if (output_device.MemType() == OrtDevice::MemType::DEFAULT && - suggested_device.MemType() == OrtDevice::MemType::DEFAULT) { + const bool output_is_cpu = output_device.UsesCpuMemory(); // CPU or HOST_ACCESSIBLE memory + const bool suggested_is_cpu = suggested_device.UsesCpuMemory(); + if (output_is_cpu && suggested_is_cpu) { + // if both are CPU or both are HOST_ACCESSIBLE pick based on alignment. + if ((output_device.Type() == OrtDevice::CPU && suggested_device.Type() == OrtDevice::CPU) || + (output_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE && + suggested_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE)) { return (output_device.GetAlignment() >= suggested_device.GetAlignment()) ? output_device : suggested_device; } else { - return (output_device.MemType() != OrtDevice::MemType::DEFAULT) ? output_device : suggested_device; + // prefer host accessible memory device allocator as it most likely has the higher alignment requirement + return (output_device.Type() != OrtDevice::CPU) ? output_device : suggested_device; } } else { - return (output_device.Type() == OrtDevice::CPU) ? output_device : suggested_device; + return (output_is_cpu) ? output_device : suggested_device; } }; @@ -916,19 +919,17 @@ class PlannerImpl { // We only do it for CPU based EPs. We are not likely to encounter // non CPU devices here since they are already taken care of by using MemCpy nodes earlier. // However, we still ignore them. 
- if (output_device.Type() == OrtDevice::CPU && - output_device.MemType() == OrtDevice::MemType::DEFAULT) { + if (output_device.Type() == OrtDevice::CPU) { const auto& output_name = node_output->Name(); const auto consumers = graph_viewer_.GetConsumerNodes(output_name); for (const auto* consumer : consumers) { if (consumer != nullptr) { const auto& ep_type = consumer->GetExecutionProviderType(); - auto suggested_device = execution_providers_.Get(ep_type)->GetOrtDeviceByMemType( - OrtMemType::OrtMemTypeCPUInput); - if (suggested_device.Type() == OrtDevice::CPU && - suggested_device.MemType() == OrtDevice::MemType::DEFAULT) { + auto suggested_device = execution_providers_.Get(ep_type) + ->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeCPUInput); + if (suggested_device.Type() == OrtDevice::CPU) { output_device = determine_device(output_device, suggested_device); - } else if (suggested_device.Type() == OrtDevice::CPU) { + } else if (suggested_device.UsesCpuMemory()) { // Edge case: there are more than one downstream nodes that suggest their own CPU accessible // memory. In that case, we can not win them all, but the chosen device would still make it run // and reduce a number of copies for some. @@ -2070,10 +2071,10 @@ class PlannerImpl { for (size_t i = 0; i < num_logic_streams_; ++i) { for (auto node_index : stream_nodes_[i]) { auto* node = graph_viewer_.GetNode(node_index); - auto stream_device = execution_plan[i]->device_.Type(); + auto stream_device = execution_plan[i]->device_; // Neither trigger ActivateNotification/WaitOnEPStep for Shape op (whose output is ready for all the EPs), nor // upstream is on CPU device (As currently we never invoke RegisterWaitFn(CPU, ...) for all kinds of EP, thus no wait_handle can be retrieved for this case) - if (node->OpType() != "Shape" && stream_device != OrtDevice::CPU) { + if (node->OpType() != "Shape" && !stream_device.UsesCpuMemory()) { for (auto it = node->OutputNodesBegin(); it != node->OutputNodesEnd(); ++it) { bool output_consumed_in_subgraph = true; for (auto* output : node->OutputDefs()) { @@ -2087,9 +2088,11 @@ class PlannerImpl { // 2. the consumer is in the same stream(non-cpu device), but it consumes a CPU tensor from an non-shape op. // for example, a resize cuda kernel consumer a tensor from MemCpyToHost cuda kernel on the same stream. 
// in this case, the FIFO can't guarantee the cpu tensor is ready when resize kernel is launching - OrtDevice::DeviceType output_arg_device = AllocPlan(output_arg_idx).location.Type(); - WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, output_arg_device); - if ((plan_.node_stream_map_[it->Index()] != i || output_arg_device == OrtDevice::CPU) && wait_handle != nullptr) { + const auto& output_arg_device = AllocPlan(output_arg_idx).location; + WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, + output_arg_device); + if ((plan_.node_stream_map_[it->Index()] != i || output_arg_device.UsesCpuMemory()) && + wait_handle != nullptr) { if (node_to_notification.find(node_index) == node_to_notification.end()) { node_to_notification[node_index] = plan_.notification_owners.size(); plan_.notification_owners.push_back(i); @@ -2103,8 +2106,9 @@ class PlannerImpl { if (output_consumed_in_subgraph) { const auto downstream = plan_.node_stream_map_[it->Index()]; if (downstream != i) { - auto downstream_device = execution_plan[downstream]->device_.Type(); - WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, downstream_device); + const auto& downstream_device = execution_plan[downstream]->device_; + WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, + downstream_device); if (wait_handle) { if (node_to_notification.find(node_index) == node_to_notification.end()) { node_to_notification[node_index] = plan_.notification_owners.size(); diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index c2ff70b8e9808..30ff8342a8009 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -223,12 +223,13 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA mem_type1); } else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) { *out = new OrtMemoryInfo( - name1, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, device_id), + name1, type, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, device_id), mem_type1); } else if (strcmp(name1, onnxruntime::HIP_PINNED) == 0) { *out = new OrtMemoryInfo( name1, type, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, device_id), + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, device_id), mem_type1); } else if (strcmp(name1, onnxruntime::QNN_HTP_SHARED) == 0) { *out = new OrtMemoryInfo( diff --git a/onnxruntime/core/framework/data_transfer.cc b/onnxruntime/core/framework/data_transfer.cc index f0109267d2975..ef2f62d963528 100644 --- a/onnxruntime/core/framework/data_transfer.cc +++ b/onnxruntime/core/framework/data_transfer.cc @@ -42,7 +42,7 @@ common::Status IDataTransfer::CopySparseTensors(const std::vectorsession_state_.GetStreamHandleRegistryInstance().GetWaitHandle( - current_stream->GetDevice().Type(), current_stream->GetDevice().Type()); + current_stream->GetDevice(), current_stream->GetDevice()); void* p_data = stream_aware_alloc->AllocOnStream(buffer_size, current_stream, wait_handle); Tensor::InitOrtValue(element_type, shape, p_data, std::move(alloc), ort_value); } else { diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 7d0026cc35558..2cd5103b823d1 100644 --- 
a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -33,8 +33,12 @@ class StreamCommandHandleRegistryImpl : public IStreamCommandHandleRegistry { // Wait is a little special as we need to consider the source stream the notification generated, // and the stream we are waiting. // i.e., for an cuda event what notify the memory copy, it could be wait on a CPU stream, or on another cuda stream. - WaitNotificationFn GetWaitHandle(const OrtDevice::DeviceType notification_owner_device_type, - const OrtDevice::DeviceType executor_device_type) const override { + WaitNotificationFn GetWaitHandle(const OrtDevice& notification_owner_device, + const OrtDevice& executor_device) const override { + auto notification_owner_device_type = notification_owner_device.UsesCpuMemory() ? OrtDevice::CPU + : notification_owner_device.Type(); + auto executor_device_type = executor_device.UsesCpuMemory() ? OrtDevice::CPU : executor_device.Type(); + auto it = notification_wait_map_.find(GetWaitKey(notification_owner_device_type, executor_device_type)); return it == notification_wait_map_.end() ? nullptr : it->second; } diff --git a/onnxruntime/core/framework/sparse_tensor.cc b/onnxruntime/core/framework/sparse_tensor.cc index 4e40e3dd81ca2..ffbdd11d59846 100644 --- a/onnxruntime/core/framework/sparse_tensor.cc +++ b/onnxruntime/core/framework/sparse_tensor.cc @@ -525,8 +525,8 @@ Status SparseTensor::Copy(const IDataTransfer& data_transfer, SparseTensor& dst_ ORT_RETURN_IF_NOT(dst_tensor.Format() == SparseFormat::kUndefined, "Destination should be empty"); ORT_RETURN_IF_NOT(dst_tensor.allocator_ != nullptr, "Destination must have a CPU allocator set"); - ORT_RETURN_IF_NOT((!is_string || dst_tensor.Location().device.Type() == OrtDevice::CPU), - "X-device copy of strings not supported"); + ORT_RETURN_IF((is_string && !dst_tensor.Location().device.UsesCpuMemory()), + "X-device copy of strings not supported"); ORT_RETURN_IF_NOT(dst_tensor.DataType() == DataType(), "Src and Dst must be of the same type"); ORT_RETURN_IF_NOT(dst_tensor.dense_shape_.Size() == dense_shape_.Size(), "Must have the same shape"); diff --git a/onnxruntime/core/framework/sparse_utils.cc b/onnxruntime/core/framework/sparse_utils.cc index b186612e0240b..c42f6d190512c 100644 --- a/onnxruntime/core/framework/sparse_utils.cc +++ b/onnxruntime/core/framework/sparse_utils.cc @@ -79,9 +79,9 @@ Status DenseTensorToSparseCsr(const DataTransferManager& data_manager, const Ten const bool is_string = src.IsDataTypeString(); - if (is_string && dst_allocator->Info().device.Type() != OrtDevice::CPU) { + if (is_string && !dst_allocator->Info().device.UsesCpuMemory()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Unable to convert strings tensor to a sparse tensor that not on CPU"); + "Unable to convert strings tensor to a sparse tensor that is not on CPU"); } const IDataTransfer* data_transfer = data_manager.GetDataTransfer(cpu_allocator->Info().device, @@ -514,4 +514,4 @@ Status DenseTensorToSparseCoo(const DataTransferManager& data_manager, const Ten } // namespace sparse_utils } // namespace onnxruntime -#endif // !defined(DISABLE_SPARSE_TENSORS) \ No newline at end of file +#endif // !defined(DISABLE_SPARSE_TENSORS) diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index b4f01bca1a097..c6bb5d931cbe6 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -556,7 +556,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& 
session_state, cons size_t num_streams = device_stream_collection->NumStreams(); for (size_t i = 0; i < num_streams; i++) { Stream* stream = device_stream_collection->GetStream(i); - if (stream && stream->GetDevice().Type() != OrtDevice::CPU) { + if (stream && !stream->GetDevice().UsesCpuMemory()) { device_stream = stream; break; } diff --git a/onnxruntime/core/providers/cann/cann_allocator.h b/onnxruntime/core/providers/cann/cann_allocator.h index 14daf46e45b16..9607ee1a35049 100644 --- a/onnxruntime/core/providers/cann/cann_allocator.h +++ b/onnxruntime/core/providers/cann/cann_allocator.h @@ -27,7 +27,7 @@ class CANNPinnedAllocator : public IAllocator { CANNPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::HUAWEI, + OrtDevice(OrtDevice::NPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::HUAWEI, device_id), OrtMemTypeCPUOutput)) {} diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 47ba70d4a5529..7943f56d12741 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1474,7 +1474,7 @@ std::vector CANNExecutionProvider::CreatePreferredAllocators() { [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, CANN_PINNED); }, - DEFAULT_CPU_ALLOCATOR_DEVICE_ID); + info_.device_id); return std::vector{ CreateCannAllocator(info_.device_id, info_.npu_mem_limit, info_.arena_extend_strategy, @@ -1491,7 +1491,8 @@ OrtDevice CANNExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) cons if (mem_type == OrtMemTypeCPUInput) return OrtDevice(); if (mem_type == OrtMemTypeCPUOutput) - return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::HUAWEI, 0); + return OrtDevice(OrtDevice::NPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::HUAWEI, + default_device_.Id()); return default_device_; } diff --git a/onnxruntime/core/providers/cann/npu_data_transfer.cc b/onnxruntime/core/providers/cann/npu_data_transfer.cc index 7821926a98a94..93596fbd53617 100644 --- a/onnxruntime/core/providers/cann/npu_data_transfer.cc +++ b/onnxruntime/core/providers/cann/npu_data_transfer.cc @@ -11,7 +11,19 @@ NPUDataTransfer::NPUDataTransfer() {} NPUDataTransfer::~NPUDataTransfer() {} bool NPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return src_device.Type() == OrtDevice::NPU || dst_device.Type() == OrtDevice::NPU; + OrtDevice::DeviceType src_type = src_device.Type(); + OrtDevice::DeviceType dst_type = dst_device.Type(); + + // check that only our NPU is involved + if ((src_type == OrtDevice::NPU && src_device.Vendor() != OrtDevice::VendorIds::HUAWEI) || + (dst_type == OrtDevice::NPU && dst_device.Vendor() != OrtDevice::VendorIds::HUAWEI)) { + return false; + } + + // copy must involve an NPU, and be device to device or cpu (exclude other device types) + return (src_type == OrtDevice::NPU || dst_type == OrtDevice::NPU) && + (src_type == OrtDevice::NPU || src_type == OrtDevice::CPU) && + (dst_type == OrtDevice::NPU || dst_type == OrtDevice::CPU); } common::Status NPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { @@ -22,9 +34,14 @@ common::Status NPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const auto& src_device = 
src.Location().device; auto& dst_device = dst.Location().device; + const bool dst_is_npu_default = dst_device.Type() == OrtDevice::NPU && + dst_device.MemType() == OrtDevice::MemType::DEFAULT; + const bool src_is_npu_default = src_device.Type() == OrtDevice::NPU && + src_device.MemType() == OrtDevice::MemType::DEFAULT; + // for the sync version of memcpy, launch to cann default stream - if (dst_device.Type() == OrtDevice::NPU) { - if (src_device.Type() == OrtDevice::NPU) { + if (dst_is_npu_default) { + if (src_is_npu_default) { // Copy only if the two addresses are different. if (dst_data != src_data) { CANN_RETURN_IF_ERROR(aclrtMemcpy(dst_data, @@ -43,7 +60,7 @@ common::Status NPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const ACL_MEMCPY_HOST_TO_DEVICE)); CANN_RETURN_IF_ERROR(aclrtSynchronizeStream(nullptr)); } - } else if (src_device.Type() == OrtDevice::NPU) { + } else if (src_is_npu_default) { // copying from NPU to CPU memory, this is blocking CANN_RETURN_IF_ERROR(aclrtMemcpy(dst_data, bytes, @@ -67,16 +84,13 @@ common::Status NPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, auto& src_device = src.Location().device; auto& dst_device = dst.Location().device; - if (dst_device.Type() == OrtDevice::NPU) { - if (src_device.Type() == OrtDevice::CPU) { - // copy from pinned memory to NPU, this is non-blocking - CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(dst_data, - bytes, - src_data, - bytes, - ACL_MEMCPY_HOST_TO_DEVICE, - static_cast(stream.GetHandle()))); - } else if (src_device.Type() == OrtDevice::NPU) { + const bool dst_is_npu_default = dst_device.Type() == OrtDevice::NPU && + dst_device.MemType() == OrtDevice::MemType::DEFAULT; + const bool src_is_npu_default = src_device.Type() == OrtDevice::NPU && + src_device.MemType() == OrtDevice::MemType::DEFAULT; + + if (dst_is_npu_default) { + if (src_is_npu_default) { // copying between NPU, this is non-blocking if (dst_data != src_data) { CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(dst_data, @@ -86,17 +100,23 @@ common::Status NPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, ACL_MEMCPY_DEVICE_TO_DEVICE, static_cast(stream.GetHandle()))); } - } - } else if (src_device.Type() == OrtDevice::NPU) { - if (dst_device.Type() == OrtDevice::CPU) { - // copying from NPU to pinned memory, this is non-blocking + } else { + // copy from pinned or CPU memory to NPU, this is non-blocking CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(dst_data, bytes, src_data, bytes, - ACL_MEMCPY_DEVICE_TO_HOST, + ACL_MEMCPY_HOST_TO_DEVICE, static_cast(stream.GetHandle()))); } + } else if (src_is_npu_default) { + // copying from NPU to pinned or CPU memory, this is non-blocking + CANN_RETURN_IF_ERROR(aclrtMemcpyAsync(dst_data, + bytes, + src_data, + bytes, + ACL_MEMCPY_DEVICE_TO_HOST, + static_cast(stream.GetHandle()))); } else { if (src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE) { // sync the stream first to make sure the data arrived diff --git a/onnxruntime/core/providers/cuda/cuda_allocator.h b/onnxruntime/core/providers/cuda/cuda_allocator.h index 004ec2e876ec0..3ba1f09ed4922 100644 --- a/onnxruntime/core/providers/cuda/cuda_allocator.h +++ b/onnxruntime/core/providers/cuda/cuda_allocator.h @@ -53,11 +53,11 @@ class CUDAExternalAllocator : public CUDAAllocator { // TODO: add a default constructor class CUDAPinnedAllocator : public IAllocator { public: - CUDAPinnedAllocator(const char* name) + CUDAPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) : IAllocator( OrtMemoryInfo(name, 
OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, - 0 /*CPU device always with id 0*/), + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, + device_id), OrtMemTypeCPUOutput)) {} void* Alloc(size_t size) override; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 3aaeeee1cbc20..cc7edbc15c329 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2877,22 +2877,17 @@ OrtDevice CUDAExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) cons if (mem_type == OrtMemTypeCPUInput) return OrtDevice(); if (mem_type == OrtMemTypeCPUOutput) - return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, - 0 /*CPU device id always be 0*/); + return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, + default_device_.Id()); return default_device_; } std::vector CUDAExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo pinned_memory_info( - [](OrtDevice::DeviceId) { - return std::make_unique(CUDA_PINNED); + [](OrtDevice::DeviceId device_id) { + return std::make_unique(device_id, CUDA_PINNED); }, - // TODO: should we use info_.device_id instead of DEFAULT_CPU_ALLOCATOR_DEVICE_ID? - // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb - // says the pinned memory allocated by cudaMallocHost is associated with a specific device, so it may be more - // correct to use the GPU device id, unless we wanted to share the pinned memory allocator across devices, - // at the risk the lifetime isn't managed correctly if one of those devices go away. - 0); + info_.device_id); return std::vector{ CreateCudaAllocator(info_.device_id, info_.gpu_mem_limit, info_.arena_extend_strategy, info_.external_allocator_info, info_.default_memory_arena_cfg), diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index b1df607e8ce99..6ba2dd8176590 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -84,8 +84,8 @@ struct ProviderInfo_CUDA_Impl final : ProviderInfo_CUDA { return std::make_unique(device_id, name); } - std::unique_ptr CreateCUDAPinnedAllocator(const char* name) override { - return std::make_unique(name); + std::unique_ptr CreateCUDAPinnedAllocator(int16_t device_id, const char* name) override { + return std::make_unique(device_id, name); } std::unique_ptr CreateGPUDataTransfer() override { @@ -112,7 +112,7 @@ struct ProviderInfo_CUDA_Impl final : ProviderInfo_CUDA { void CudaCall_true(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) override { CudaCall(cudaError(retCode), exprString, libName, cudaError(successCode), msg, file, line); } void CopyGpuToCpu(void* dst_ptr, const void* src_ptr, const size_t size, const OrtMemoryInfo& dst_location, const OrtMemoryInfo& src_location) override { - ORT_ENFORCE(dst_location.device.Type() == OrtDevice::CPU); + ORT_ENFORCE(dst_location.device.UsesCpuMemory(), "Copy destination is not CPU memory"); // Current CUDA device. 
int device; diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.h b/onnxruntime/core/providers/cuda/cuda_provider_factory.h index 4d5ef658f6be0..cf352757686bd 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.h +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.h @@ -25,7 +25,7 @@ struct ProviderInfo_CUDA { virtual OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) = 0; virtual std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name) = 0; - virtual std::unique_ptr CreateCUDAPinnedAllocator(const char* name) = 0; + virtual std::unique_ptr CreateCUDAPinnedAllocator(int16_t device_id, const char* name) = 0; virtual std::unique_ptr CreateGPUDataTransfer() = 0; virtual void cuda__Impl_Cast(void* stream, const int64_t* input_data, int32_t* output_data, size_t count) = 0; diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc index 8127b4697de22..68f159ea0f843 100644 --- a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc @@ -8,8 +8,18 @@ namespace onnxruntime { bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE || - dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE; + OrtDevice::DeviceType src_type = src_device.Type(); + OrtDevice::DeviceType dst_type = dst_device.Type(); + + if ((src_type == OrtDevice::GPU && src_device.Vendor() != OrtDevice::VendorIds::NVIDIA) || + (dst_type == OrtDevice::GPU && dst_device.Vendor() != OrtDevice::VendorIds::NVIDIA)) { + return false; + } + + // copy must involve a GPU, and be device to device or cpu (exclude other device types) + return (src_type == OrtDevice::GPU || dst_type == OrtDevice::GPU) && + (src_type == OrtDevice::GPU || src_type == OrtDevice::CPU) && + (dst_type == OrtDevice::GPU || dst_type == OrtDevice::CPU); } common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { @@ -20,9 +30,14 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const auto& src_device = src.Location().device; auto& dst_device = dst.Location().device; + const bool dst_is_gpu_default = dst_device.Type() == OrtDevice::GPU && + dst_device.MemType() == OrtDevice::MemType::DEFAULT; + const bool src_is_gpu_default = src_device.Type() == OrtDevice::GPU && + src_device.MemType() == OrtDevice::MemType::DEFAULT; + // for the sync version of memcpy, launch to cuda default stream - if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::GPU) { + if (dst_is_gpu_default) { + if (src_is_gpu_default) { // Copy only if the two addresses are different. 
if (dst_data != src_data) { CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice)); @@ -39,7 +54,7 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); } } - } else if (src_device.Type() == OrtDevice::GPU) { + } else if (src_is_gpu_default) { // copying from GPU to CPU memory, this is blocking CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost)); } else { @@ -59,24 +74,27 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, auto& src_device = src.Location().device; auto& dst_device = dst.Location().device; - if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::CPU) { - // copy from pinned or non-pinned CPU memory to GPU - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, - static_cast(stream.GetHandle()))); - } else if (src_device.Type() == OrtDevice::GPU) { + const bool dst_is_gpu_default = dst_device.Type() == OrtDevice::GPU && + dst_device.MemType() == OrtDevice::MemType::DEFAULT; + const bool src_is_gpu_default = src_device.Type() == OrtDevice::GPU && + src_device.MemType() == OrtDevice::MemType::DEFAULT; + + if (dst_is_gpu_default) { + if (src_is_gpu_default) { // copying between GPU, this is non-blocking if (dst_data != src_data) { CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); } - } - } else if (src_device.Type() == OrtDevice::GPU) { - if (dst_device.Type() == OrtDevice::CPU) { - // copy from GPU to pinned or non-pinned CPU memory. - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, + } else { + // copy from pinned or non-pinned CPU memory to GPU + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, static_cast(stream.GetHandle()))); } + } else if (src_is_gpu_default) { + // copy from GPU to pinned or non-pinned CPU memory. 
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, + static_cast(stream.GetHandle()))); } else { if (src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE) { // sync the stream first to make sure the data arrived diff --git a/onnxruntime/core/providers/cuda/tensor/reshape.cc b/onnxruntime/core/providers/cuda/tensor/reshape.cc index ab364c274a32d..8ffcba9b716da 100644 --- a/onnxruntime/core/providers/cuda/tensor/reshape.cc +++ b/onnxruntime/core/providers/cuda/tensor/reshape.cc @@ -18,7 +18,7 @@ TensorShape InferReshapeOutputShape( TensorShape InferReshapeOutputShape(const Tensor* src, const Tensor* shape, bool allow_zero) { ORT_ENFORCE(shape != nullptr, "Cannot reshape to a null shape."); ORT_ENFORCE(shape->Shape().NumDimensions() == 1, "Shape must be an 1-D tensor."); - ORT_ENFORCE(shape->Location().device.Type() == OrtDevice::CPU, "Shape must be on CPU."); + ORT_ENFORCE(shape->Location().device.UsesCpuMemory(), "Shape must be on CPU."); return InferReshapeOutputShape( src->Shape(), @@ -39,7 +39,7 @@ Status FuncReshape( return ORT_MAKE_STATUS( ONNXRUNTIME, FAIL, "The shape tensor for reshaping must be a vector, but got ", shape->Shape(), "."); } - if (shape->Location().device.Type() != OrtDevice::CPU) { + if (!shape->Location().device.UsesCpuMemory()) { return Status(common::ONNXRUNTIME, common::FAIL, "Shape tensor must be on CPU."); } @@ -65,8 +65,9 @@ std::unique_ptr FuncReshape( ORT_ENFORCE(X != nullptr, "Missing data tensor to be reshaped."); ORT_ENFORCE(shape != nullptr, "Missing shape tensor for reshaping."); - ORT_ENFORCE(shape->Shape().NumDimensions() == 1, "The shape tensor for reshaping must be a vector, but got ", shape->Shape(), "."); - ORT_ENFORCE(shape->Location().device.Type() == OrtDevice::CPU, "Shape tensor must be on CPU."); + ORT_ENFORCE(shape->Shape().NumDimensions() == 1, "The shape tensor for reshaping must be a vector, but got ", + shape->Shape(), "."); + ORT_ENFORCE(shape->Location().device.UsesCpuMemory(), "Shape tensor must be on CPU."); // Calculate output's shape. 
auto dst_shape = InferReshapeOutputShape(X, shape, allow_zero); diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc index 01d9ee99f07fe..c9cd6e21b4eba 100644 --- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc @@ -10,8 +10,19 @@ namespace onnxruntime { bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE || - dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE; + OrtDevice::DeviceType src_type = src_device.Type(); + OrtDevice::DeviceType dst_type = dst_device.Type(); + + // check that only our GPU is involved + if ((src_type == OrtDevice::GPU && src_device.Vendor() != OrtDevice::VendorIds::AMD) || + (dst_type == OrtDevice::GPU && dst_device.Vendor() != OrtDevice::VendorIds::AMD)) { + return false; + } + + // copy must involve a GPU, and be device to device or cpu (exclude other device types) + return (src_type == OrtDevice::GPU || dst_type == OrtDevice::GPU) && + (src_type == OrtDevice::GPU || src_type == OrtDevice::CPU) && + (dst_type == OrtDevice::GPU || dst_type == OrtDevice::CPU); } common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { @@ -22,9 +33,14 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const auto& src_device = src.Location().device; auto& dst_device = dst.Location().device; + const bool dst_is_gpu_default = dst_device.Type() == OrtDevice::GPU && + dst_device.MemType() == OrtDevice::MemType::DEFAULT; + const bool src_is_gpu_default = src_device.Type() == OrtDevice::GPU && + src_device.MemType() == OrtDevice::MemType::DEFAULT; + // for the sync version of memcpy, launch to hip default stream - if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::GPU) { + if (dst_is_gpu_default) { + if (src_is_gpu_default) { // Copy only if the two addresses are different. if (dst_data != src_data) { HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToDevice)); @@ -39,7 +55,7 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } } - } else if (src_device.Type() == OrtDevice::GPU) { + } else if (src_is_gpu_default) { // copying from GPU to CPU memory, this is blocking HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); } else { @@ -59,22 +75,23 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, auto& src_device = src.Location().device; auto& dst_device = dst.Location().device; - if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::CPU) { - // If source are not pinned, the memory copy will be performed synchronously. - // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. 
- HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, - static_cast(stream.GetHandle()))); - } else if (src_device.Type() == OrtDevice::GPU) { + const bool dst_is_gpu_default = dst_device.Type() == OrtDevice::GPU && + dst_device.MemType() == OrtDevice::MemType::DEFAULT; + const bool src_is_gpu_default = src_device.Type() == OrtDevice::GPU && + src_device.MemType() == OrtDevice::MemType::DEFAULT; + + if (dst_is_gpu_default) { + if (src_is_gpu_default) { // copying between GPU, this is non-blocking HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); } else { - // copy from other CPU memory to GPU, this is blocking - HIP_CALL_THROW(hipMemcpyWithStream(dst_data, src_data, bytes, hipMemcpyHostToDevice, + // If source are not pinned, the memory copy will be performed synchronously. + // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. + HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); } - } else if (src_device.Type() == OrtDevice::GPU) { + } else if (src_is_gpu_default) { // If dest are not pinned, the memory copy will be performed synchronously. // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.h b/onnxruntime/core/providers/migraphx/migraphx_allocator.h index 8dcfe63796c89..f6b7788e0604c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.h @@ -55,7 +55,7 @@ class MIGraphXPinnedAllocator final : public IAllocator { MIGraphXPinnedAllocator(const int device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, static_cast(device_id)), OrtMemTypeCPUOutput)) {} diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 1b5882062361c..aa8b21ea3fe52 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -334,12 +334,13 @@ AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::Devic std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, onnxruntime::CUDA); }, info_.device_id); + [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, onnxruntime::CUDA); }, + info_.device_id); AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, onnxruntime::CUDA_PINNED); }, - 0); + info_.device_id); return std::vector{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)}; } @@ -1620,8 +1621,8 @@ OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) if (mem_type == OrtMemTypeCPUInput) return OrtDevice(); if (mem_type == OrtMemTypeCPUOutput) - return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, - 0 /*CPU device 
id always be 0*/); + return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, + default_device_.Id()); return default_device_; } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h index b4b638ccb82f1..1ab5e47a08523 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h @@ -53,11 +53,11 @@ class CUDAExternalAllocator : public CUDAAllocator { // TODO: add a default constructor class CUDAPinnedAllocator : public IAllocator { public: - CUDAPinnedAllocator(const char* name) + CUDAPinnedAllocator(const char* name, OrtDevice::DeviceId device_id) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, - 0 /*CPU device always with id 0*/), + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, + device_id), OrtMemTypeCPUOutput)) {} void* Alloc(size_t size) override; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.cc index 0dfc5b8f8f7d4..d334aeaa86cdf 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.cc @@ -10,8 +10,19 @@ #define CUDA_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDA_CALL(expr)) namespace onnxruntime { bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE || - dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE; + OrtDevice::DeviceType src_type = src_device.Type(); + OrtDevice::DeviceType dst_type = dst_device.Type(); + + // check that only our GPU is involved + if ((src_type == OrtDevice::GPU && src_device.Vendor() != OrtDevice::VendorIds::NVIDIA) || + (dst_type == OrtDevice::GPU && dst_device.Vendor() != OrtDevice::VendorIds::NVIDIA)) { + return false; + } + + // copy must involve a GPU, and be device to device or cpu (exclude other device types) + return (src_type == OrtDevice::GPU || dst_type == OrtDevice::GPU) && + (src_type == OrtDevice::GPU || src_type == OrtDevice::CPU) && + (dst_type == OrtDevice::GPU || dst_type == OrtDevice::CPU); } common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { @@ -22,9 +33,14 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const auto& src_device = src.Location().device; auto& dst_device = dst.Location().device; + const bool dst_is_gpu_default = dst_device.Type() == OrtDevice::GPU && + dst_device.MemType() == OrtDevice::MemType::DEFAULT; + const bool src_is_gpu_default = src_device.Type() == OrtDevice::GPU && + src_device.MemType() == OrtDevice::MemType::DEFAULT; + // for the sync version of memcpy, launch to cuda default stream - if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::GPU) { + if (dst_is_gpu_default) { + if (src_is_gpu_default) { // Copy only if the two addresses are different. 
if (dst_data != src_data) { CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice)); @@ -41,7 +57,7 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); } } - } else if (src_device.Type() == OrtDevice::GPU) { + } else if (src_is_gpu_default) { // copying from GPU to CPU memory, this is blocking CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost)); } else { @@ -61,24 +77,27 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, auto& src_device = src.Location().device; auto& dst_device = dst.Location().device; - if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::CPU) { - // copy from pinned or non-pinned CPU memory to GPU - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, - static_cast(stream.GetHandle()))); - } else if (src_device.Type() == OrtDevice::GPU) { + const bool dst_is_gpu_default = dst_device.Type() == OrtDevice::GPU && + dst_device.MemType() == OrtDevice::MemType::DEFAULT; + const bool src_is_gpu_default = src_device.Type() == OrtDevice::GPU && + src_device.MemType() == OrtDevice::MemType::DEFAULT; + + if (dst_is_gpu_default) { + if (src_is_gpu_default) { // copying between GPU, this is non-blocking if (dst_data != src_data) { CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); } - } - } else if (src_device.Type() == OrtDevice::GPU) { - if (dst_device.Type() == OrtDevice::CPU) { - // copy from GPU to pinned or non-pinned CPU memory. - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, + } else { + // copy from pinned or non-pinned CPU memory to GPU + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, static_cast(stream.GetHandle()))); } + } else if (src_is_gpu_default) { + // copy from GPU to pinned or non-pinned CPU memory. 
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, + static_cast(stream.GetHandle()))); } else { if (src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE) { // sync the stream first to make sure the data arrived diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index f06ed1424eb24..711d81186bad1 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1304,11 +1304,9 @@ std::vector NvExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - ORT_UNUSED_PARAMETER(device_id); - return std::make_unique(CUDA_PINNED); - ; + return std::make_unique(device_id, CUDA_PINNED); }, - 0); + narrow(device_id_)); return std::vector{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)}; } @@ -3260,8 +3258,8 @@ OrtDevice NvExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const if (mem_type == OrtMemTypeCPUInput) return OrtDevice(); if (mem_type == OrtMemTypeCPUOutput) - return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, - 0 /*CPU device id always be 0*/); + return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, + default_device_.Id()); return default_device_; } diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index d248644c13ddb..f989021faedaf 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -114,6 +114,7 @@ AllocationTracker& GlobalAllocationTracker() { OrtMemoryInfo HtpSharedMemoryAllocator::AssociatedMemoryInfo() { return OrtMemoryInfo{QNN_HTP_SHARED, OrtAllocatorType::OrtDeviceAllocator, + // QNN EP registers with OrtDevice::CPU so we use that for HOST_ACCESSIBLE as well OrtDevice{OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::QUALCOMM, /*device_id*/ 0}, OrtMemTypeDefault}; diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 4d3ae4f4a7e07..71d51c4c2992d 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -310,13 +310,13 @@ inline OrtStatus* CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept std::unique_ptr CreateCPUAllocator(const OrtMemoryInfo& memory_info); std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name); -std::unique_ptr CreateCUDAPinnedAllocator(const char* name); +std::unique_ptr CreateCUDAPinnedAllocator(int16_t device_id, const char* name); std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name); std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name); std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name); -std::unique_ptr CreateROCMPinnedAllocator(const char* name); +std::unique_ptr CreateROCMPinnedAllocator(int16_t device_id, const char* name); std::unique_ptr CreateGPUDataTransfer(); diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index afabc1fa9b1c9..c9ff0d807633f 100644 --- 
a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -352,8 +352,8 @@ std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* n return g_host->CreateCUDAAllocator(device_id, name); } -std::unique_ptr CreateCUDAPinnedAllocator(const char* name) { - return g_host->CreateCUDAPinnedAllocator(name); +std::unique_ptr CreateCUDAPinnedAllocator(int16_t device_id, const char* name) { + return g_host->CreateCUDAPinnedAllocator(device_id, name); } std::unique_ptr CreateGPUDataTransfer() { diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index f843b86375e78..dba26b3982d86 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -187,7 +187,7 @@ struct ProviderHost { virtual std::string demangle(const std::string& name) = 0; virtual std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name) = 0; - virtual std::unique_ptr CreateCUDAPinnedAllocator(const char* name) = 0; + virtual std::unique_ptr CreateCUDAPinnedAllocator(int16_t device_id, const char* name) = 0; virtual std::unique_ptr CreateGPUDataTransfer() = 0; virtual void cuda__Impl_Cast(void* stream, const int64_t* input_data, int32_t* output_data, size_t count) = 0; @@ -205,7 +205,7 @@ struct ProviderHost { #ifdef USE_ROCM virtual std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name) = 0; - virtual std::unique_ptr CreateROCMPinnedAllocator(const char* name) = 0; + virtual std::unique_ptr CreateROCMPinnedAllocator(int16_t device_id, const char* name) = 0; virtual void rocm__Impl_Cast(void* stream, const int64_t* input_data, int32_t* output_data, size_t count) = 0; virtual void rocm__Impl_Cast(void* stream, const int32_t* input_data, int64_t* output_data, size_t count) = 0; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 14782b5a52262..1121775bf5ef7 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1917,10 +1917,9 @@ std::vector TensorrtExecutionProvider::CreatePreferredAllocators() AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - ORT_UNUSED_PARAMETER(device_id); - return CreateCUDAPinnedAllocator(onnxruntime::CUDA_PINNED); + return CreateCUDAPinnedAllocator(device_id, onnxruntime::CUDA_PINNED); }, - 0); + narrow(device_id_)); return std::vector{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)}; } @@ -4535,8 +4534,8 @@ OrtDevice TensorrtExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) if (mem_type == OrtMemTypeCPUInput) return OrtDevice(); if (mem_type == OrtMemTypeCPUOutput) - return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, - 0 /*CPU device id always be 0*/); + return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, + default_device_.Id()); return default_device_; } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 15f86cf0d7002..b30d9182684c4 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -332,9 +332,10 @@ 
ORT_API_STATUS_IMPL(OrtApis::CreateSparseTensorAsOrtValue, _Inout_ OrtAllocator* namespace { #if !defined(DISABLE_SPARSE_TENSORS) std::unique_ptr GetDataTransfer(const OrtDevice& src_device, const OrtDevice& dst_device) { - if (src_device.Type() == OrtDevice::CPU && dst_device.Type() == OrtDevice::CPU) { + if (src_device.UsesCpuMemory() && dst_device.UsesCpuMemory()) { return std::make_unique(); } + #if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) if (src_device.Type() == OrtDevice::GPU || dst_device.Type() == OrtDevice::GPU) { if (auto* provider_info = TryGetProviderInfo_CUDA()) { @@ -348,7 +349,7 @@ std::unique_ptr GetDataTransfer(const OrtDevice& src_device, cons SparseTensor& ValidateFillInputArgs(OrtValue* v, const TensorShape& values_shape, const OrtMemoryInfo* data_mem_info) { auto& sparse_tensor = SparseTensor::GetSparseTensorFromOrtValue(*v); if (sparse_tensor.IsDataTypeString()) { - if ((data_mem_info->device.Type() != OrtDevice::CPU) || sparse_tensor.Location().device.Type() != OrtDevice::CPU) { + if (!data_mem_info->device.UsesCpuMemory() || !sparse_tensor.Location().device.UsesCpuMemory()) { ORT_THROW("Strings can only reside in CPU memory"); } } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 2a1f7580ac3aa..8cd16fb4e7347 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -281,7 +281,7 @@ struct ProviderHostImpl : ProviderHost { void CPUAllocator__Free(CPUAllocator* p, void* allocation) override { return p->CPUAllocator::Free(allocation); } std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_CUDA().CreateCUDAAllocator(device_id, name); } - std::unique_ptr CreateCUDAPinnedAllocator(const char* name) override { return GetProviderInfo_CUDA().CreateCUDAPinnedAllocator(name); } + std::unique_ptr CreateCUDAPinnedAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_CUDA().CreateCUDAPinnedAllocator(device_id, name); } void cuda__Impl_Cast(void* stream, const int64_t* input_data, int32_t* output_data, size_t count) override { return GetProviderInfo_CUDA().cuda__Impl_Cast(stream, input_data, output_data, count); } void cuda__Impl_Cast(void* stream, const int32_t* input_data, int64_t* output_data, size_t count) override { return GetProviderInfo_CUDA().cuda__Impl_Cast(stream, input_data, output_data, count); } @@ -1941,9 +1941,9 @@ void UnloadSharedProviders() { } // Used by test code -std::unique_ptr CreateCUDAPinnedAllocator(const char* name) { +std::unique_ptr CreateCUDAPinnedAllocator(int16_t device_id, const char* name) { if (auto* info = onnxruntime::TryGetProviderInfo_CUDA()) - return info->CreateCUDAPinnedAllocator(name); + return info->CreateCUDAPinnedAllocator(device_id, name); return nullptr; } diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 03db6f069cd75..5624befd0ca66 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -2078,7 +2078,7 @@ for model inference.)pbdoc"); } else if (strcmp(name, onnxruntime::CUDA_PINNED) == 0) { return std::make_unique( onnxruntime::CUDA_PINNED, type, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, static_cast(id)), mem_type); } else { diff --git 
a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index c957f54e51a9c..8827696bc2fb9 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -316,7 +316,8 @@ class PlannerTest : public ::testing::Test { public: // Wait is a little special as we need to consider the source stream the notification generated, and the stream we are waiting. // i.e., for an cuda event what notify the memory copy, it could be wait on a CPU stream, or on another cuda stream. - virtual WaitNotificationFn GetWaitHandle(const OrtDevice::DeviceType /*notification_owner_ep_type*/, const OrtDevice::DeviceType /*executor_ep_type*/) const override { + virtual WaitNotificationFn GetWaitHandle(const OrtDevice& /*notification_owner_device*/, + const OrtDevice& /*executor_device*/) const override { return nullptr; } diff --git a/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc b/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc index 0410e51ce207d..91a4fe9a54251 100644 --- a/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc @@ -34,7 +34,7 @@ TEST(AllocatorTest, CUDAAllocatorTest) { EXPECT_TRUE(cuda_addr); AllocatorCreationInfo pinned_memory_info( - [](int) { return std::make_unique(CUDA_PINNED); }); + [](int device_id) { return std::make_unique(device_id, CUDA_PINNED); }); auto pinned_allocator = CreateAllocator(pinned_memory_info); diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc index 7f6f6ba3bb4b0..68b1483c25792 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc @@ -23,7 +23,7 @@ TEST(TestDeferredRelease, WithArena) { // Create CUDA EP. CUDAExecutionProviderInfo info; CUDAExecutionProvider ep(info); - AllocatorPtr gpu_alloctor = ep.CreatePreferredAllocators()[0]; + AllocatorPtr gpu_allocator = ep.CreatePreferredAllocators()[0]; onnxruntime::RunOptions run_opts; run_opts.run_tag = "log1"; @@ -32,7 +32,7 @@ TEST(TestDeferredRelease, WithArena) { AllocatorPtr cpu_pinned_alloc = ep.CreatePreferredAllocators()[1]; // let the CudaStream instance "own" the default stream, so we can avoid the // work to initialize cublas/cudnn/... It is ok since it is just a customized unit test. 
- CudaStream stream(nullptr, gpu_alloctor->Info().device, cpu_pinned_alloc, false, true, nullptr, nullptr, info); + CudaStream stream(nullptr, gpu_allocator->Info().device, cpu_pinned_alloc, false, true, nullptr, nullptr, info); // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; @@ -60,22 +60,22 @@ TEST(TestDeferredRelease, WithoutArena) { onnxruntime::RunOptions run_opts; run_opts.run_tag = "log1"; - OrtDevice pinned_device{OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, - DEFAULT_CPU_ALLOCATOR_DEVICE_ID}; + OrtDevice pinned_device{OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, + info.device_id}; // Create allocator without BFCArena AllocatorCreationInfo pinned_memory_info( - [](OrtDevice::DeviceId) { - return std::make_unique(CUDA_PINNED); + [](OrtDevice::DeviceId device_id) { + return std::make_unique(device_id, CUDA_PINNED); }, pinned_device.Id(), false /* no arena */); auto cuda_pinned_alloc = CreateAllocator(pinned_memory_info); - AllocatorPtr gpu_alloctor = ep.CreatePreferredAllocators()[0]; + AllocatorPtr gpu_allocator = ep.CreatePreferredAllocators()[0]; // Allocator for call cudaMallocHost and cudaFreeHost // For details, see CUDAPinnedAllocator in cuda_allocator.cc. // let the CudaStream instance "own" the default stream, so we can avoid the // work to initialize cublas/cudnn/... It is ok since it is just a customized unit test. - CudaStream stream(nullptr, gpu_alloctor->Info().device, cuda_pinned_alloc, false, true, nullptr, nullptr, info); + CudaStream stream(nullptr, gpu_allocator->Info().device, cuda_pinned_alloc, false, true, nullptr, nullptr, info); // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc index 022c6250138d4..8213a95dcff08 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc @@ -47,7 +47,7 @@ struct ProviderInfo_CUDA_TestImpl : ProviderInfo_CUDA { return nullptr; } - std::unique_ptr CreateCUDAPinnedAllocator(const char*) override { + std::unique_ptr CreateCUDAPinnedAllocator(int16_t, const char*) override { return nullptr; } From 265756146d3cf8b4f12269000c11e598f85b0650 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 2 Jul 2025 17:39:28 -0700 Subject: [PATCH 09/19] [build] do not use separated artifacts for wasm build (#25267) ### Description Previously we have 2 artifacts: - `${{ inputs.build_config }}_wasm` - `${{ inputs.build_config }}_wasm_webgpu` Now that we use a different file name so that we can simplify this part and make it a single artifact. 
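For reference, a minimal sketch of the consolidated copy step (paths and build config are illustrative assumptions; the actual logic lives in the workflow and pipeline template files changed below):

```bash
# Both builds now emit distinctly named files, so a single artifact folder can hold them all.
mkdir -p artifacts/wasm
cp build/wasm_inferencing/Release/ort-wasm-simd-threaded.{wasm,mjs} artifacts/wasm/
if [ -d build/wasm_inferencing_webgpu ]; then
  cp build/wasm_inferencing_webgpu/Release/ort-wasm-simd-threaded.asyncify.{wasm,mjs} artifacts/wasm/
fi
```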
--- .../linux-wasm-ci-build-and-test-workflow.yml | 18 ++++----------- .github/workflows/windows-web-ci-workflow.yml | 16 ------------- .../templates/linux-wasm-ci.yml | 22 ++++-------------- .../azure-pipelines/templates/win-web-ci.yml | 23 ------------------- 4 files changed, 8 insertions(+), 71 deletions(-) diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index bde704edc2b6b..6667f0bc8e5ae 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -107,13 +107,10 @@ jobs: cp ${{ github.workspace }}/build/wasm_inferencing_jsep/${{ inputs.build_config }}/ort-wasm-simd-threaded.jsep.wasm ${{ github.workspace }}/artifacts/wasm/ cp ${{ github.workspace }}/build/wasm_inferencing_jsep/${{ inputs.build_config }}/ort-wasm-simd-threaded.jsep.mjs ${{ github.workspace }}/artifacts/wasm/ fi - - - name: Create WebGPU Artifacts - if: ${{ inputs.skip_publish != true && inputs.build_webgpu == true }} - run: | - mkdir -p ${{ github.workspace }}/artifacts/wasm_webgpu/ - cp ${{ github.workspace }}/build/wasm_inferencing_webgpu/${{ inputs.build_config }}/ort-wasm-simd-threaded.asyncify.wasm ${{ github.workspace }}/artifacts/wasm_webgpu/ - cp ${{ github.workspace }}/build/wasm_inferencing_webgpu/${{ inputs.build_config }}/ort-wasm-simd-threaded.asyncify.mjs ${{ github.workspace }}/artifacts/wasm_webgpu/ + if [ -d ${{ github.workspace }}/build/wasm_inferencing_webgpu ]; then + cp ${{ github.workspace }}/build/wasm_inferencing_webgpu/${{ inputs.build_config }}/ort-wasm-simd-threaded.asyncify.wasm ${{ github.workspace }}/artifacts/wasm/ + cp ${{ github.workspace }}/build/wasm_inferencing_webgpu/${{ inputs.build_config }}/ort-wasm-simd-threaded.asyncify.mjs ${{ github.workspace }}/artifacts/wasm/ + fi - name: Upload WASM artifacts if: ${{ inputs.skip_publish != true }} @@ -122,13 +119,6 @@ jobs: name: ${{ inputs.build_config }}_wasm path: ${{ github.workspace }}/artifacts/wasm - - name: Upload WebGPU artifacts - if: ${{ inputs.skip_publish != true && inputs.build_webgpu == true }} - uses: actions/upload-artifact@v4 - with: - name: ${{ inputs.build_config }}_wasm_webgpu - path: ${{ github.workspace }}/artifacts/wasm_webgpu - - name: Test (Node.js) (simd + threads) # onnxruntime_test_all is currently only supported in Debug build because it requires exception, which is disabled in Release build. 
if: ${{ inputs.build_config == 'Debug' }} diff --git a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml index a2a0e8d36def8..fcbef760d4626 100644 --- a/.github/workflows/windows-web-ci-workflow.yml +++ b/.github/workflows/windows-web-ci-workflow.yml @@ -83,22 +83,6 @@ jobs: run: | copy ${{ github.workspace }}\artifacts_wasm\ort-*.mjs ${{ github.workspace }}\js\web\dist\ - - name: Download WebAssembly WebGPU artifacts - uses: actions/download-artifact@v4 - with: - name: ${{ inputs.build_config }}_wasm_webgpu - path: ${{ github.workspace }}/artifacts_wasm_webgpu - - - name: Binplace dist files (.wasm) for WebGPU - shell: cmd - run: | - copy ${{ github.workspace }}\artifacts_wasm_webgpu\ort-*.wasm ${{ github.workspace }}\js\web\dist\ - - - name: Binplace dist files (.mjs) for WebGPU - shell: cmd - run: | - copy ${{ github.workspace }}\artifacts_wasm_webgpu\ort-*.mjs ${{ github.workspace }}\js\web\dist\ - - name: npm ci for /js/ run: npm ci working-directory: ${{ github.workspace }}/js diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index fa12aab6a91b2..9f76c150ca2a4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -146,37 +146,23 @@ jobs: cp $(Build.BinariesDirectory)/wasm_inferencing_jsep/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.jsep.wasm $(Build.ArtifactStagingDirectory)/wasm/ cp $(Build.BinariesDirectory)/wasm_inferencing_jsep/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.jsep.mjs $(Build.ArtifactStagingDirectory)/wasm/ fi + if [ -d $(Build.BinariesDirectory)/wasm_inferencing_webgpu ]; then + cp $(Build.BinariesDirectory)/wasm_inferencing_webgpu/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.asyncify.wasm $(Build.ArtifactStagingDirectory)/wasm/ + cp $(Build.BinariesDirectory)/wasm_inferencing_webgpu/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.asyncify.mjs $(Build.ArtifactStagingDirectory)/wasm/ + fi displayName: 'Create Artifacts' - - ${{ if eq(parameters.BuildWebGPU, true) }}: - - script: | - mkdir -p $(Build.ArtifactStagingDirectory)/wasm_webgpu/ - cp $(Build.BinariesDirectory)/wasm_inferencing_webgpu/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.asyncify.wasm $(Build.ArtifactStagingDirectory)/wasm_webgpu/ - cp $(Build.BinariesDirectory)/wasm_inferencing_webgpu/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.asyncify.mjs $(Build.ArtifactStagingDirectory)/wasm_webgpu/ - displayName: 'Create Artifacts (WebGPU EP)' - ${{ if eq(parameters.is1ES, false) }}: - task: PublishPipelineArtifact@1 displayName: 'Publish Pipeline Artifact' inputs: artifactName: '${{ parameters.BuildConfig }}_wasm' targetPath: '$(Build.ArtifactStagingDirectory)/wasm' - - ${{ if eq(parameters.BuildWebGPU, true) }}: - - task: PublishPipelineArtifact@1 - displayName: 'Publish Pipeline Artifact (WebGPU EP)' - inputs: - artifactName: '${{ parameters.BuildConfig }}_wasm_webgpu' - targetPath: '$(Build.ArtifactStagingDirectory)/wasm_webgpu' - ${{ if eq(parameters.is1ES, true) }}: - task: 1ES.PublishPipelineArtifact@1 displayName: 'Publish Pipeline Artifact' inputs: artifactName: '${{ parameters.BuildConfig }}_wasm' targetPath: '$(Build.ArtifactStagingDirectory)/wasm' - - ${{ if eq(parameters.BuildWebGPU, true) }}: - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Pipeline Artifact (WebGPU EP)' - inputs: - artifactName: '${{ 
parameters.BuildConfig }}_wasm_webgpu' - targetPath: '$(Build.ArtifactStagingDirectory)/wasm_webgpu' - task: PublishTestResults@2 displayName: 'Publish unit test results' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index fecbfc8657894..b32242cfea010 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -225,29 +225,6 @@ jobs: # === Start of experimental WebGPU EP tests === - ${{ if eq(parameters.RunWebGpuTests, true) }}: - - task: DownloadPipelineArtifact@2 - inputs: - patterns: '${{ parameters.BuildConfig }}_wasm_webgpu/**/*' - path: $(Pipeline.Workspace)\artifacts_wasm_webgpu - displayName: 'Download WebAssembly artifacts' - - task: CopyFiles@2 - inputs: - sourceFolder: $(Pipeline.Workspace)\artifacts_wasm_webgpu - contents: | - **\ort-*.wasm - targetFolder: $(Build.SourcesDirectory)\js\web\dist - flattenFolders: true - overWrite: true - displayName: 'Binplace dist files (.wasm)' - - task: CopyFiles@2 - inputs: - sourceFolder: $(Pipeline.Workspace)\artifacts_wasm_webgpu - contents: | - **\ort-*.mjs - targetFolder: $(Build.SourcesDirectory)\js\web\dist - flattenFolders: true - overWrite: true - displayName: 'Binplace dist files (.mjs)' - script: | powershell "Get-WmiObject Win32_Process -Filter \"name = 'chrome.exe'\" | Format-List CommandLine" displayName: 'Check active Chrome processes (before test)' From 4fb92bf8319be5919cd7e26043eef24707ae46d3 Mon Sep 17 00:00:00 2001 From: quic-hungjuiw Date: Thu, 3 Jul 2025 13:29:45 +0800 Subject: [PATCH 10/19] [QNN EP] Add Infrastructure to check datatypes (#25257) ### Description - Add BaseOpBuilder::ProcessDataTypes - Add CheckCpuDataTypes, CheckHtpDataTypes and CheckGpuDataTypes - Check if datatypes are supported on QnnCpu and QnnHtp for BatchNorm - Add corresponding unit test for BatchNorm on QnnCpu and QnnHtp ### Motivation and Context - Due to varying datatype support for each op on various backends (QnnCpu, QnnHtp, QnnGpu), we need an infrastructure to check datatypes according to the document https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/operations.html#backend-supplements --- .../qnn/builder/opbuilder/base_op_builder.cc | 45 +++++++ .../qnn/builder/opbuilder/base_op_builder.h | 12 ++ .../opbuilder/batch_norm_op_builder.cc | 69 +++++++++- ...ch_norm_htp_test.cc => batch_norm_test.cc} | 119 +++++++++++++++++- 4 files changed, 238 insertions(+), 7 deletions(-) rename onnxruntime/test/providers/qnn/{batch_norm_htp_test.cc => batch_norm_test.cc} (71%) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 74518e2fcb7a2..0152ad27c0ba2 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -16,6 +16,8 @@ std::string BaseOpBuilder::GetOpBuilderType() const { Status BaseOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const { + // General Datatype checks on various QNN backend (HTP, CPU, GPU) + ORT_RETURN_IF_ERROR(ProcessDataTypes(qnn_model_wrapper, node_unit)); return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); } @@ -37,6 +39,47 @@ Status BaseOpBuilder::AddToModelBuilder(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +Status 
BaseOpBuilder::ProcessDataTypes(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit) const { + std::vector input_qnn_dtypes; + std::vector output_qnn_dtypes; + const auto& inputs = node_unit.Inputs(); + const auto& outputs = node_unit.Outputs(); + for (auto input : inputs) { + TensorInfo tensor_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input, tensor_info)); + Qnn_DataType_t qnn_data_type = tensor_info.qnn_data_type; + input_qnn_dtypes.push_back(qnn_data_type); + } + for (auto output : outputs) { + TensorInfo tensor_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(output, tensor_info)); + Qnn_DataType_t qnn_data_type = tensor_info.qnn_data_type; + output_qnn_dtypes.push_back(qnn_data_type); + } + if (IsCpuBackend(qnn_model_wrapper.GetQnnBackendType())) { + return CheckCpuDataTypes(input_qnn_dtypes, output_qnn_dtypes); + } else if (IsNpuBackend(qnn_model_wrapper.GetQnnBackendType())) { + return CheckHtpDataTypes(input_qnn_dtypes, output_qnn_dtypes); + } else if (IsGpuBackend(qnn_model_wrapper.GetQnnBackendType())) { + return CheckGpuDataTypes(input_qnn_dtypes, output_qnn_dtypes); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Only support backend: CPU, HTP and GPU"); +} + +Status BaseOpBuilder::CheckCpuDataTypes(const std::vector, + const std::vector) const { + return Status::OK(); +} +Status BaseOpBuilder::CheckHtpDataTypes(const std::vector, + const std::vector) const { + return Status::OK(); +} +Status BaseOpBuilder::CheckGpuDataTypes(const std::vector, + const std::vector) const { + return Status::OK(); +} + Status BaseOpBuilder::ProcessInput(QnnModelWrapper& qnn_model_wrapper, const NodeUnitIODef& input, const logging::Logger& logger, @@ -371,6 +414,8 @@ Status BaseOpBuilder::ProcessAxisAttribute(const QnnModelWrapper& qnn_model_wrap } Status DataTypeCheckForCpuBackend(QnnModelWrapper& qnn_model_wrapper, ONNX_NAMESPACE::DataType onnx_tensor_data_type) { + // TODO: Retire the DataTypeCheckForCpuBackend once all Ops transition to using BaseOpBuilder::ProcessDataTypes + // Due to varying datatype support for each op in Qnn CPU backend, we need to implement CheckCpuDataTypes for each op. 
const auto float_elem_type = ONNX_NAMESPACE::Utils::DataTypeUtils::ToType("float"); bool is_cpu_backend = (qnn_model_wrapper.GetQnnBackendType() == QnnBackendType::CPU); ORT_RETURN_IF(is_cpu_backend && onnx_tensor_data_type != float_elem_type, "QNN CPU backend only support float data type."); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index e009bc558e884..e910afcbcf6c6 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -70,6 +70,18 @@ class BaseOpBuilder : public IOpBuilder { return Status::OK(); } + Status ProcessDataTypes(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit) const ORT_MUST_USE_RESULT; + + virtual Status CheckCpuDataTypes(const std::vector, + const std::vector) const ORT_MUST_USE_RESULT; + + virtual Status CheckHtpDataTypes(const std::vector, + const std::vector) const ORT_MUST_USE_RESULT; + + virtual Status CheckGpuDataTypes(const std::vector, + const std::vector) const ORT_MUST_USE_RESULT; + virtual Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc index 21036dc55aefc..51f6523559987 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc @@ -428,6 +428,13 @@ class BatchNormOpBuilder : public BaseOpBuilder { } return Status::OK(); } + + protected: + Status CheckCpuDataTypes(const std::vector in_dtypes, + const std::vector out_dtypes) const override ORT_MUST_USE_RESULT; + + Status CheckHtpDataTypes(const std::vector in_dtypes, + const std::vector out_dtypes) const override ORT_MUST_USE_RESULT; }; // BatchNorm is sensitive with data layout, no special validation so far @@ -441,12 +448,12 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, // Still do it here so hopefully QNN Op validation API can tell us some details why it's not supported return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); } else { + // Check input datatype. Can't use Qnn Op validation API since it's before layout transformation + ORT_RETURN_IF_ERROR(ProcessDataTypes(qnn_model_wrapper, node_unit)); + const auto& inputs = node_unit.Inputs(); ORT_RETURN_IF_NOT(inputs.size() == 5, "5 input expected per BatchNorm Onnx Spec."); - // Check input type is float for CPU. 
Can't use Qnn Op validation API since it's before layout transformation - ORT_RETURN_IF_ERROR(DataTypeCheckForCpuBackend(qnn_model_wrapper, inputs[0].node_arg.Type())); - std::vector input_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape of input 0."); const size_t input_rank = input_shape.size(); @@ -617,5 +624,61 @@ void CreateBatchNormOpBuilder(const std::string& op_type, OpBuilderRegistrations op_registrations.AddOpBuilder(op_type, std::make_unique()); } +Status BatchNormOpBuilder::CheckCpuDataTypes(const std::vector in_dtypes, + const std::vector out_dtypes) const { + bool is_supported_dtype = false; + // in_dtypes: [X, scale, B, input_mean, input_var] + std::vector all_dtypes(in_dtypes.begin(), in_dtypes.begin() + 3); + // out_dtypes: [Y, running_mean, running_var] + all_dtypes.insert(all_dtypes.end(), out_dtypes.begin(), out_dtypes.begin() + 1); + // FP32 + if ( + (all_dtypes == std::vector{QNN_DATATYPE_FLOAT_32, QNN_DATATYPE_FLOAT_32, QNN_DATATYPE_FLOAT_32, QNN_DATATYPE_FLOAT_32})) { + is_supported_dtype = true; + } + // INT8 + else if ( + (all_dtypes == std::vector{QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_UFIXED_POINT_8}) || + (all_dtypes == std::vector{QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_SFIXED_POINT_32, QNN_DATATYPE_UFIXED_POINT_8})) { + is_supported_dtype = true; + } + ORT_RETURN_IF_NOT(is_supported_dtype, "QNN Batchnorm unsupported datatype on CPU."); + return Status::OK(); +} + +Status BatchNormOpBuilder::CheckHtpDataTypes(const std::vector in_dtypes, + const std::vector out_dtypes) const { + bool is_supported_dtype = false; + // in_dtypes: [X, scale, B, input_mean, input_var] + std::vector all_dtypes(in_dtypes.begin(), in_dtypes.begin() + 3); + // out_dtypes: [Y, running_mean, running_var] + all_dtypes.insert(all_dtypes.end(), out_dtypes.begin(), out_dtypes.begin() + 1); + // FP16 + if ( + (all_dtypes == std::vector{QNN_DATATYPE_FLOAT_16, QNN_DATATYPE_FLOAT_16, QNN_DATATYPE_FLOAT_16, QNN_DATATYPE_FLOAT_16}) || + (all_dtypes == std::vector{QNN_DATATYPE_FLOAT_32, QNN_DATATYPE_FLOAT_32, QNN_DATATYPE_FLOAT_32, QNN_DATATYPE_FLOAT_32})) { + is_supported_dtype = true; + } + // INT16 + else if ( + (all_dtypes == std::vector{QNN_DATATYPE_UFIXED_POINT_16, QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_UFIXED_POINT_16}) || + (all_dtypes == std::vector{QNN_DATATYPE_UFIXED_POINT_16, QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_SFIXED_POINT_32, QNN_DATATYPE_UFIXED_POINT_16}) || + (all_dtypes == std::vector{QNN_DATATYPE_UFIXED_POINT_16, QNN_DATATYPE_UFIXED_POINT_16, QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_UFIXED_POINT_16}) || + (all_dtypes == std::vector{QNN_DATATYPE_UFIXED_POINT_16, QNN_DATATYPE_UFIXED_POINT_16, QNN_DATATYPE_SFIXED_POINT_32, QNN_DATATYPE_UFIXED_POINT_16}) || + (all_dtypes == std::vector{QNN_DATATYPE_UFIXED_POINT_16, QNN_DATATYPE_SFIXED_POINT_16, QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_UFIXED_POINT_16}) || + (all_dtypes == std::vector{QNN_DATATYPE_UFIXED_POINT_16, QNN_DATATYPE_SFIXED_POINT_16, QNN_DATATYPE_SFIXED_POINT_32, QNN_DATATYPE_UFIXED_POINT_16})) { + is_supported_dtype = true; + } + // INT8 + else if ( + (all_dtypes == std::vector{QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_UFIXED_POINT_8}) || + (all_dtypes == std::vector{QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_SFIXED_POINT_32, 
QNN_DATATYPE_UFIXED_POINT_8}) || + (all_dtypes == std::vector{QNN_DATATYPE_SFIXED_POINT_8, QNN_DATATYPE_SFIXED_POINT_8, QNN_DATATYPE_SFIXED_POINT_8, QNN_DATATYPE_SFIXED_POINT_8})) { + is_supported_dtype = true; + } + ORT_RETURN_IF_NOT(is_supported_dtype, "QNN Batchnorm unsupported datatype on HTP."); + return Status::OK(); +}; + } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc b/onnxruntime/test/providers/qnn/batch_norm_test.cc similarity index 71% rename from onnxruntime/test/providers/qnn/batch_norm_htp_test.cc rename to onnxruntime/test/providers/qnn/batch_norm_test.cc index 73bb6f2d203c0..c88a0ce6cf0b2 100644 --- a/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc +++ b/onnxruntime/test/providers/qnn/batch_norm_test.cc @@ -14,8 +14,6 @@ namespace onnxruntime { namespace test { -#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) - // Computes the mean and variance of inputs within a channel. // Requires an input with rank >= 3 template @@ -141,6 +139,66 @@ GetTestQDQModelFn BuildQDQBatchNormTestCase(const TestInputDef +static void RunBatchNormQDQTestOnCPU(const TestInputDef& input_def, + const TestInputDef& scale_def, + const TestInputDef& bias_def, + ExpectedEPNodeAssignment expected_ep_assignment, + QDQTolerance tolerance = QDQTolerance()) { + ProviderOptions provider_options; + provider_options["backend_type"] = "cpu"; + provider_options["offload_graph_io_quantization"] = "0"; + + // Runs model with DQ-> InstanceNorm -> Q and compares the outputs of the CPU and QNN EPs. + TestQDQModelAccuracy(BuildBatchNormTestCase(input_def, scale_def, bias_def), + BuildQDQBatchNormTestCase(input_def, scale_def, bias_def), + provider_options, + 21, + expected_ep_assignment, + tolerance); +} + +TEST_F(QnnCPUBackendTests, BatchNorm2D_fp32) { + constexpr int64_t num_channels = 2; + std::vector input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f, + -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f}; + + ProviderOptions provider_options; + provider_options["backend_type"] = "cpu"; + + RunQnnModelTest( + BuildBatchNormTestCase( + TestInputDef({2, num_channels, 2, 2}, false, input_data), // Input data + TestInputDef({num_channels}, true, {1.0f, 2.0f}), // Scale initializer + TestInputDef({num_channels}, true, {1.1f, 2.1f}) // Bias initializer + ), + provider_options, + 13, + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnCPUBackendTests, BatchNorm2D_int8) { + constexpr int64_t num_channels = 2; + std::vector input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f, + -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f}; + + RunBatchNormQDQTestOnCPU( + TestInputDef({2, num_channels, 2, 2}, false, input_data), // Input data + TestInputDef({num_channels}, true, {1.0f, 2.0f}), // Scale initializer + TestInputDef({num_channels}, true, {1.1f, 2.1f}), // Bias initializer + ExpectedEPNodeAssignment::All, + QDQTolerance()); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + /** * Runs an BatchNormalization model on the QNN HTP backend. Checks the graph node assignment, and that inference * outputs for QNN and CPU match. @@ -228,7 +286,7 @@ TEST_F(QnnHTPBackendTests, BatchNorm1D) { // Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. // Use an input of rank 4. 
-TEST_F(QnnHTPBackendTests, BatchNorm2D_a8w8) { +TEST_F(QnnHTPBackendTests, BatchNorm2D_U8U8S32) { constexpr int64_t num_channels = 2; std::vector input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f, -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f}; @@ -241,7 +299,46 @@ TEST_F(QnnHTPBackendTests, BatchNorm2D_a8w8) { // Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. // Use an input of rank 4. -TEST_F(QnnHTPBackendTests, BatchNorm2D_a16w8) { +TEST_F(QnnHTPBackendTests, BatchNorm2D_U8S8S32) { + constexpr int64_t num_channels = 2; + std::vector input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f, + -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f}; + + RunBatchNormQDQTest(TestInputDef({2, num_channels, 2, 2}, false, input_data), // Input data + TestInputDef({num_channels}, true, {1.0f, 2.0f}), // Scale initializer + TestInputDef({num_channels}, true, {1.1f, 2.1f}), // Bias initializer + ExpectedEPNodeAssignment::None); +} + +// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. +// Use an input of rank 4. +TEST_F(QnnHTPBackendTests, BatchNorm2D_U8U16S32) { + constexpr int64_t num_channels = 2; + std::vector input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f, + -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f}; + + RunBatchNormQDQTest(TestInputDef({2, num_channels, 2, 2}, false, input_data), // Input data + TestInputDef({num_channels}, true, {1.0f, 2.0f}), // Scale initializer + TestInputDef({num_channels}, true, {1.1f, 2.1f}), // Bias initializer + ExpectedEPNodeAssignment::None); +} + +// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. +// Use an input of rank 4. +TEST_F(QnnHTPBackendTests, BatchNorm2D_U8S16S32) { + constexpr int64_t num_channels = 2; + std::vector input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f, + -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f}; + + RunBatchNormQDQTest(TestInputDef({2, num_channels, 2, 2}, false, input_data), // Input data + TestInputDef({num_channels}, true, {1.0f, 2.0f}), // Scale initializer + TestInputDef({num_channels}, true, {1.1f, 2.1f}), // Bias initializer + ExpectedEPNodeAssignment::None); +} + +// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. +// Use an input of rank 4. +TEST_F(QnnHTPBackendTests, BatchNorm2D_U16U8S32) { constexpr int64_t num_channels = 2; std::vector input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f, -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f}; @@ -252,6 +349,20 @@ TEST_F(QnnHTPBackendTests, BatchNorm2D_a16w8) { ExpectedEPNodeAssignment::All); } +// Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. +// Use an input of rank 4. +// Turn on the testcase when ORT QNN-EP supports unsigned symmetric dtypes +TEST_F(QnnHTPBackendTests, DISABLED_BatchNorm2D_U16U16S32) { + constexpr int64_t num_channels = 2; + std::vector input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f, + -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f}; + + RunBatchNormQDQTest(TestInputDef({2, num_channels, 2, 2}, false, input_data), // Input data + TestInputDef({num_channels}, true, {1.0f, 2.0f}), // Scale initializer + TestInputDef({num_channels}, true, {1.1f, 2.1f}), // Bias initializer + ExpectedEPNodeAssignment::All); +} + // Test FP16 BatchNormalization on the HTP backend. 
TEST_F(QnnHTPBackendTests, BatchNorm_FP16) { constexpr int64_t num_channels = 2; From a09246868ff8da633dc75a1bead1c3593802b5d2 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 2 Jul 2025 23:06:09 -0700 Subject: [PATCH 11/19] [build] Fix CUDA build (#25273) ### Description ### Motivation and Context --- onnxruntime/core/session/environment.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index b3176b399756e..8acf1df06b46d 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -398,7 +398,9 @@ Status Environment::CreateAndRegisterAllocatorV2(const std::string& provider_typ #if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) if (provider_type == onnxruntime::kCudaExecutionProvider) { if (mem_info.device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE) { - AllocatorPtr allocator_ptr = GetProviderInfo_CUDA().CreateCUDAPinnedAllocator(onnxruntime::CUDA_PINNED); + AllocatorPtr allocator_ptr = GetProviderInfo_CUDA().CreateCUDAPinnedAllocator( + static_cast(mem_info.device.Id()), + onnxruntime::CUDA_PINNED); return RegisterAllocatorImpl(allocator_ptr); } else { CUDAExecutionProviderInfo cuda_ep_info; From 7fc6235861d0b01fe6bba3e6b2f233f5797a92cf Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 3 Jul 2025 02:57:56 -0700 Subject: [PATCH 12/19] [build] fix build on mac (#25270) ### Description Fixes build on macOS. always include `"core/graph/onnx_protobuf.h"` before including other onnx/protobuf headers. ``` .../build/MacOS/Debug/_deps/protobuf-src/src/google/protobuf/parse_context.h:328:47: error: implicit conversion loses integer precision: 'long' to 'int' [-Werror,-Wshorten-64-to-32] 328 | int chunk_size = buffer_end_ + kSlopBytes - ptr; | ~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~ ``` --- onnxruntime/test/framework/function_test.cc | 1 + onnxruntime/test/optimizer/graph_transform_test.cc | 4 +++- .../layout_transformation_potentially_added_ops_test.cc | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index 180a75a64c10e..6cde2fbc71f5d 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -3,6 +3,7 @@ #include "gtest/gtest.h" +#include "core/graph/onnx_protobuf.h" #include "onnx/defs/parser.h" #include "core/common/span_utils.h" diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 099b8b23dc93d..d4b54852cc1d0 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -9,6 +9,9 @@ #include "gtest/gtest.h" #include "gmock/gmock.h" + +#include "core/graph/onnx_protobuf.h" + #include "onnx/defs/parser.h" #include "onnx/defs/printer.h" @@ -18,7 +21,6 @@ #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" -#include "core/graph/onnx_protobuf.h" #include "core/mlas/inc/mlas_q4.h" #include "core/optimizer/attention_fusion.h" #include "core/optimizer/bias_dropout_fusion.h" diff --git a/onnxruntime/test/optimizer/layout_transformation_potentially_added_ops_test.cc b/onnxruntime/test/optimizer/layout_transformation_potentially_added_ops_test.cc index 2d8a40ddd26db..c092c660633b1 100644 --- 
a/onnxruntime/test/optimizer/layout_transformation_potentially_added_ops_test.cc +++ b/onnxruntime/test/optimizer/layout_transformation_potentially_added_ops_test.cc @@ -5,7 +5,7 @@ #include "gtest/gtest.h" -#include "onnx/defs/schema.h" +#include "core/graph/onnx_protobuf.h" #include "core/graph/constants.h" From 05da42ca258a2395de284b0a8c1a131df368e3ac Mon Sep 17 00:00:00 2001 From: Kevin Chen <45886021+kevinch-nv@users.noreply.github.com> Date: Thu, 3 Jul 2025 11:27:11 -0700 Subject: [PATCH 13/19] Fix TRT-EP build for EP graph tests (#25202) ### Description New binary `onnxruntime_ep_graph_test` uses `test_main.cc`, which contains deprecated declarations from `NvInfer.h`. Ignore these warnings when building TRT EP. ### Motivation and Context Fixes build for TRT EP. Signed-off-by: Kevin Chen --- cmake/onnxruntime_unittests.cmake | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index c202650a134a6..e8809bd2392c8 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1111,7 +1111,7 @@ if (NOT IOS) target_link_libraries(onnx_test_runner PRIVATE onnx_test_runner_common ${GETOPT_LIB_WIDE} ${onnx_test_libs} nlohmann_json::nlohmann_json) target_include_directories(onnx_test_runner PRIVATE ${ONNXRUNTIME_ROOT}) - + if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_link_libraries(onnx_test_runner PRIVATE Python::Python) endif() @@ -1232,7 +1232,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_include_directories(onnxruntime_perf_test PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${onnxruntime_exec_src_dir} ${CMAKE_CURRENT_BINARY_DIR}) - + if (WIN32) target_compile_options(onnxruntime_perf_test PRIVATE ${disabled_warnings}) if (NOT DEFINED SYS_PATH_LIB) @@ -1338,7 +1338,7 @@ endif() if (onnxruntime_USE_CUDA) list(APPEND onnxruntime_shared_lib_test_LIBS) endif() - + if (onnxruntime_USE_TENSORRT) list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER}) endif() @@ -1372,7 +1372,7 @@ endif() if (onnxruntime_USE_NV) target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) endif() - + if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE @@ -1429,7 +1429,7 @@ endif() DEPENDS ${all_dependencies} ) - + target_compile_definitions(onnxruntime_test_debug_node_inputs_outputs PRIVATE DEBUG_NODE_INPUTS_OUTPUTS) @@ -1983,6 +1983,11 @@ if (onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" LIBS ${onnxruntime_ep_graph_test_LIBS} DEPENDS ${all_dependencies} ) + if (UNIX AND (onnxruntime_USE_TENSORRT OR onnxruntime_USE_NV)) + # The test_main.cc includes NvInfer.h where it has many deprecated declarations + # simply ignore them for TensorRT EP build + set_property(TARGET onnxruntime_ep_graph_test APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") + endif() endif() include(onnxruntime_fuzz_test.cmake) From 517d684c8db3e0024fc4693f0b0399b8473b2461 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Thu, 3 Jul 2025 13:02:30 -0700 Subject: [PATCH 14/19] [EP ABI] Use pre-allocated input buffers for APIs that return arrays. (#25247) ### Description - Remove `OrtArrayOfConstObjects` from C API - Rework graph APIs that return an array of objects to take pre-allocated buffers as input. 
- Rename `Node_GetParentGraph` to `Node_GetGraph` - Fixes C/C++ API documentation generation: https://github.com/microsoft/onnxruntime/actions/runs/16029991022 ### Motivation and Context Make the graph C APIs easier to use. The `OrtArrayOfConstObjects` approach was too verbose to use and made the API harder to understand because the function signatures did not show the element data types. Example usage with `OrtArrayOfConstObjects`: ```c++ const OrtGraph* graph; // Assumed is initialized OrtArrayOfConstObjects* nodes = nullptr; RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, &nodes)); // Get array size_t num_nodes = 0; RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetSize(nodes, &num_nodes)); // Get size // Use the nodes. for (size_t i = 0; i < num_nodes; i++) { const OrtNode* node = nullptr; RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(nodes, i, reinterpret_cast(&node))); // Inspect OrtNode properties ... } // Have to manually release the OrtArrayOfConstObjects // A C++ ORT wrapper class would help via RAII, but the same C api calls are made under the hood. ort_api.ReleaseArrayOfConstObjects(nodes); ``` Example usage with "pre-allocated" buffers style: ```c++ const OrtGraph* graph; // Assumed is initialized // Get number of nodes. size_t num_nodes = 0; RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); // Pre-allocate buffer of OrtNode* and get nodes. std::vector nodes(num_nodes); RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); // Use the nodes. for (size_t i = 0; i < num_nodes; i++) { const OrtNode* node = nodes[i]; // Inspect OrtNode properties. } // std::vector destructor cleans up for us. ``` --------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- .../core/session/onnxruntime_c_api.h | 407 +++++++++--------- .../core/framework/abi_pointer_array.h | 19 - onnxruntime/core/graph/abi_graph_types.h | 114 +++-- onnxruntime/core/graph/ep_api_types.cc | 145 +++++-- onnxruntime/core/graph/ep_api_types.h | 65 ++- .../core/graph/model_editor_api_types.h | 45 +- .../session/ep_plugin_provider_interfaces.cc | 1 - onnxruntime/core/session/onnxruntime_c_api.cc | 327 +++++++------- onnxruntime/core/session/ort_apis.h | 66 +-- onnxruntime/test/autoep/library/ep.cc | 122 ++---- onnxruntime/test/ep_graph/test_ep_graph.cc | 158 +++---- .../test/ep_graph/test_ep_graph_topo_sort.cc | 46 +- .../test/ep_graph/test_ep_graph_utils.h | 21 - 13 files changed, 764 insertions(+), 772 deletions(-) delete mode 100644 onnxruntime/core/framework/abi_pointer_array.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index c9aaa38426a7b..4fd47f9f3f997 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -321,7 +321,6 @@ ORT_RUNTIME_CLASS(ModelCompilationOptions); ORT_RUNTIME_CLASS(HardwareDevice); ORT_RUNTIME_CLASS(EpDevice); ORT_RUNTIME_CLASS(KeyValuePairs); -ORT_RUNTIME_CLASS(ArrayOfConstObjects); #ifdef _MSC_VER typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -493,17 +492,6 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e _Out_ size_t* num_selected, _In_ void* state); -/** \brief Enum tags for ORT runtime types used to identify the type of elements in containers, - * like OrtArrayOfConstObjects. 
- */ -typedef enum OrtTypeTag { - ORT_TYPE_TAG_Void, - ORT_TYPE_TAG_OrtValueInfo, - ORT_TYPE_TAG_OrtOpAttr, - ORT_TYPE_TAG_OrtNode, - ORT_TYPE_TAG_OrtGraph, -} OrtTypeTag; - /** \brief Algorithm to use for cuDNN Convolution Op */ typedef enum OrtCudnnConvAlgoSearch { @@ -3864,11 +3852,11 @@ struct OrtApi { * assigned to QNN EP is dumped to a separate file. * "json_qnn_graph_dir": Directory in which to dump QNN JSON graphs. If not specified, QNN graphs are dumped in the * program's current working directory. Ignored if "dump_json_qnn_graph" is not set. - * "op_packages": QNN UDO op_package for QNN EP, allowed format: - *   ::[:],::[:], - *   where op_type is the name of the operation, op_package_path is the path to the op package shared library, - * interface is the symbol name to register the op life cycle functions, and target is the backend type. For more - * details, refer to: https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/op_packages.html + * "op_packages": QNN UDO op_package for QNN EP, allowed format: + *   "::[:],::[:]", + *   where op_type is the name of the operation, op_package_path is the path to the op package shared library, + * interface is the symbol name to register the op life cycle functions, and target is the backend type. For more + * details, refer to: https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/op_packages.html * * XNNPACK supported keys: * "intra_op_num_threads": number of thread-pool size to use for XNNPACK execution provider. @@ -5394,136 +5382,8 @@ struct OrtApi { _In_ size_t alignment, enum OrtAllocatorType allocator_type, _Outptr_ OrtMemoryInfo** out); - // - // OrtArrayOfConstObjects - // - - /** \brief Create an OrtArrayOfConstObjects instance, which represents an array of - * pointers to constant opaque objects (i.e., each element is a 'const void*'). - * - * The OrtArrayOfConstObjects instance does not own the underlying objects, only the pointers - * to them. - * - * An OrtArrayOfConstObjects instance stores elements of type 'const void*'. Users - * must check the object's type via ArrayOfConstObjects_GetObjectType before casting objects - * to their actual type. - * - * \param[in] object_type The object's type as indicated by the OrtTypeTag enum. - * \param[in] initial_size The backing array's initial size. Can be set to 0. - * \param[in] initial_value Each element's initial value. Can be set to NULL. - * \param[out] out A pointer to a newly created OrtArrayOfConstObjects instance. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \note Must be released by calling ReleaseArrayOfConstObjects. - * - * \since Version 1.23. - */ - ORT_API2_STATUS(CreateArrayOfConstObjects, _In_ OrtTypeTag object_type, _In_ size_t initial_size, - _In_ const void* initial_value, _Outptr_ OrtArrayOfConstObjects** out); - - ORT_CLASS_RELEASE(ArrayOfConstObjects); - - /** \brief Get a tag that represents the type of the opaque objects stored in a OrtArrayOfConstObjects instance. - * - * Refer to OrtTypeTag for valid values. - * - * \param[in] array The OrtArrayOfConstObjects instance. - * \param[out] type_tag Output parameter set to the type tag that corresponds to the object type. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. 
- */ - ORT_API2_STATUS(ArrayOfConstObjects_GetObjectType, _In_ const OrtArrayOfConstObjects* array, - _Out_ OrtTypeTag* type_tag); - - /** \brief Get a pointer to a data buffer of contiguous elements, where each element is a constant pointer to a - * constant opaque object (i.e., each element is a 'const void* const'). - * - * Caller must cast the objects to the appropriate type indicated by ArrayOfConstObjects_GetObjectType. - * - * \param[in] array The OrtArrayOfConstObjects instance. - * \param[out] data Output parameter set to the contiguous data buffer that stores all elements. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. - */ - ORT_API2_STATUS(ArrayOfConstObjects_GetData, _In_ const OrtArrayOfConstObjects* array, - _Outptr_ const void* const** data); - - /** \brief Get a pointer to a data buffer of contiguous elements, where each element is a pointer to a - * constant opaque object (i.e., each element is a 'const void*'). - * - * Caller must cast the objects to the appropriate type indicated by ArrayOfConstObjects_GetObjectType. - * - * \param[in] array The OrtArrayOfConstObjects instance. - * \param[out] data Output parameter set to the contiguous data buffer that stores all elements. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. - */ - ORT_API2_STATUS(ArrayOfConstObjects_GetMutableData, _In_ OrtArrayOfConstObjects* array, _Outptr_ const void*** data); - - /** \brief Get the number of elements contained by the given OrtArrayOfConstObjects instance. - * - * \param[in] array The OrtArrayOfConstObjects instance. - * \param[out] size Output parameter set to the number of elements in the array. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. - */ - ORT_API2_STATUS(ArrayOfConstObjects_GetSize, _In_ const OrtArrayOfConstObjects* array, _Out_ size_t* size); - - /** \brief Get the element at the given index. Returns an error status if the index is outside the array bounds. - * - * Caller must cast the object to the appropriate type indicated by ArrayOfConstObjects_GetObjectType. - * Example: - * // Assume OrtTypeTag is ORT_TYPE_TAG_OrtNode and there is at least one node in the array. - * const OrtNode* node = nullptr; - * OrtStatus status = ort_api.ArrayOfConstObjects_GetElementAt(nodes, 0, reinterpret_cast(&node))); - * - * \param[in] array The OrtArrayOfConstObjects instance. - * \param[in] index The index of the element. - * \param[out] out Output parameter set to the element at the given index. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. - */ - ORT_API2_STATUS(ArrayOfConstObjects_GetElementAt, _In_ const OrtArrayOfConstObjects* array, _In_ size_t index, - _Outptr_ const void** out); - - /** \brief Set the element at the given index. Returns an error status if the index is outside the array bounds. - * - * \param[in] array The OrtArrayOfConstObjects instance. - * \param[in] index The index of the element. - * \param[in] element The element to set. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. - */ - ORT_API2_STATUS(ArrayOfConstObjects_SetElementAt, _In_ OrtArrayOfConstObjects* array, _In_ size_t index, - _In_ const void* element); - - /** \brief Appends an element to the end of the array, which increases the size of the array by one. - * - * \param[in] array The OrtArrayOfConstObjects instance. - * \param[in] element The element to append. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. - */ - ORT_API2_STATUS(ArrayOfConstObjects_AppendElement, _In_ OrtArrayOfConstObjects* array, _In_ const void* element); - - // - // OrtValueInfo - // + /// \name OrtValueInfo + /// @{ /** \brief Get the OrtNode that produces the value represented by the given OrtValueInfo. * Optionally returns the associated output index. @@ -5568,21 +5428,23 @@ struct OrtApi { * - input_indices: [0, 1] * * \param[in] value_info The OrtValueInfo instance. - * \param[out] nodes Pre-allocated array of size `max_num_consumers` that will be filled with OrtNode instances. - * \param[out] input_indices Pre-allocated array of `max_num_consumers` elements that will be filled + * \param[out] nodes Pre-allocated array of size `num_consumers` that is filled with OrtNode instances. + * \param[out] input_indices Pre-allocated array of `num_consumers` elements that is filled * with input indices. Index is set to -1 for an "implicit" input to a consumer node * that contains a subgraph (e.g., If, Loop) with nodes that use the value internally. - * \param[in] max_num_consumers The maximum size of the `consumer_nodes` and `consumer_input_indices` arrays. - * Typical usage sets this to the value of ValueInfo_GetValueNumConsumers(). + * \param[in] num_consumers The size of the `consumer_nodes` and `consumer_input_indices` arrays. + * Typical usage sets this to the value of ValueInfo_GetValueNumConsumers(). + * An error status is returned if `num_consumers` is less than the number of actual + * consumers. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ ORT_API2_STATUS(ValueInfo_GetValueConsumers, _In_ const OrtValueInfo* value_info, - _Out_writes_all_(max_num_consumers) const OrtNode** nodes, - _Out_writes_all_(max_num_consumers) int64_t* input_indices, - _In_ size_t max_num_consumers); + _Out_writes_all_(num_consumers) const OrtNode** nodes, + _Out_writes_all_(num_consumers) int64_t* input_indices, + _In_ size_t num_consumers); /** \brief Get the underlying initializer value, as an OrtValue, from the given OrtValueInfo. * @@ -5677,9 +5539,10 @@ struct OrtApi { ORT_API2_STATUS(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_info, _Out_ bool* is_from_outer_scope); - // - // OrtGraph - // + /// @} + + /// \name OrtGraph + /// @{ /** \brief Returns a graph's name. * @@ -5703,34 +5566,78 @@ struct OrtApi { */ ORT_API2_STATUS(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); + /** \brief Returns the number of graph inputs. + * + * \note The count includes initializers that are included in the list of graph inputs. + * + * \param[in] graph The OrtGraph instance. + * \param[out] num_inputs Output parameter set to the number of graph inputs. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetNumInputs, _In_ const OrtGraph* graph, _Out_ size_t* num_inputs); + /** \brief Returns the graph's inputs as OrtValueInfo instances. * - * Includes initializers that are included in the list of graph inputs. + * \note The result includes initializers that are included in the list of graph inputs. * * \param[in] graph The OrtGraph instance. - * \param[out] inputs Output parameter set to a new OrtArrayOfConstObjects instance containing the graph inputs - * as OrtValueInfo instances. Must be released by calling ReleaseArrayOfConstObjects. 
+   * \param[out] inputs     Pre-allocated array of `num_inputs` elements that is filled with the graph's inputs.
+   * \param[in] num_inputs  The size of the `inputs` array.
+   *                        Typical usage sets this to the result of Graph_GetNumInputs(). An error status is
+   *                        returned if `num_inputs` is less than the number of graph inputs.
    *
    * \snippet{doc} snippets.dox OrtStatus Return Value
    *
    * \since Version 1.23.
    */
-  ORT_API2_STATUS(Graph_GetInputs, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** inputs);
+  ORT_API2_STATUS(Graph_GetInputs, _In_ const OrtGraph* graph,
+                  _Out_writes_(num_inputs) const OrtValueInfo** inputs, _In_ size_t num_inputs);
+
+  /** \brief Returns the number of graph outputs.
+   *
+   * \param[in] graph        The OrtGraph instance.
+   * \param[out] num_outputs Output parameter set to the number of graph outputs.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.23.
+   */
+  ORT_API2_STATUS(Graph_GetNumOutputs, _In_ const OrtGraph* graph, _Out_ size_t* num_outputs);
 
   /** \brief Returns the graph's outputs as OrtValueInfo instances.
    *
    * \param[in] graph    The OrtGraph instance.
-   * \param[out] outputs Output parameter set to a new OrtArrayOfConstObjects instance containing the graph outputs
-   *                     as OrtValueInfo instances. Must be released by calling ReleaseArrayOfConstObjects.
+   * \param[out] outputs     Pre-allocated array of `num_outputs` elements that is filled with the graph's outputs.
+   * \param[in] num_outputs  The size of the `outputs` array.
+   *                         Typical usage sets this to the result of Graph_GetNumOutputs(). An error status is
+   *                         returned if `num_outputs` is less than the number of graph outputs.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.23.
+   */
+  ORT_API2_STATUS(Graph_GetOutputs, _In_ const OrtGraph* graph,
+                  _Out_writes_(num_outputs) const OrtValueInfo** outputs, _In_ size_t num_outputs);
+
+  /** \brief Returns the number of graph initializers.
+   *
+   * Counts constant and non-constant initializers.
+   *
+   * \param[in] graph             The OrtGraph instance.
+   * \param[out] num_initializers Output parameter set to the number of graph initializers.
    *
    * \snippet{doc} snippets.dox OrtStatus Return Value
    *
    * \since Version 1.23.
    */
-  ORT_API2_STATUS(Graph_GetOutputs, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** outputs);
+  ORT_API2_STATUS(Graph_GetNumInitializers, _In_ const OrtGraph* graph, _Out_ size_t* num_initializers);
 
-  /** \brief Returns the graph's initializers as OrtValueInfo instances. Includes constant and non-constant
-   *         initializers.
+  /** \brief Returns the graph's initializers as OrtValueInfo instances.
+   *
+   * Includes constant and non-constant initializers.
    *
    * For ONNX IR version < 4, all initializers are constant.
    *
@@ -5740,15 +5647,29 @@ struct OrtApi {
    * Call ValueInfo_GetInitializerValue to get the initializer's data.
    *
    * \param[in] graph            The OrtGraph instance.
-   * \param[out] initializers    Output parameter set to a new OrtArrayOfConstObjects instance containing the graph's
-   *                             initializers as OrtValueInfo instances.
-   *                             Must be released by calling ReleaseArrayOfConstObjects.
+   * \param[out] initializers     Pre-allocated array of `num_initializers` elements that is filled with the
+   *                              initializers.
+   * \param[in] num_initializers  The size of the `initializers` array. Typical usage sets this to the
+   *                              result of Graph_GetNumInitializers(). An error status is returned if
+   *                              `num_initializers` is less than the number of graph initializers.
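+   *
+   * A minimal usage sketch of the size-query-then-fill pattern (illustrative only; `ort_api` and `graph` are
+   * assumed to be a valid OrtApi pointer and OrtGraph instance, and error handling is omitted):
+   *
+   *   size_t num_initializers = 0;
+   *   ort_api->Graph_GetNumInitializers(graph, &num_initializers);
+   *   std::vector<const OrtValueInfo*> initializers(num_initializers);
+   *   ort_api->Graph_GetInitializers(graph, initializers.data(), initializers.size());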
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetInitializers, _In_ const OrtGraph* graph, + _Out_writes_(num_initializers) const OrtValueInfo** initializers, + _In_ size_t num_initializers); + + /** \brief Returns the number of graph nodes. + * + * \param[in] graph The OrtGraph instance. + * \param[out] num_nodes Output parameter set to the number of graph nodes. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ - ORT_API2_STATUS(Graph_GetInitializers, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** initializers); + ORT_API2_STATUS(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t* num_nodes); /** \brief Returns the graph's nodes as OrtNode instances. * @@ -5756,14 +5677,17 @@ struct OrtApi { * own node ordering if a different order is required. * * \param[in] graph The OrtGraph instance. - * \param[out] nodes Output parameter set to a new OrtArrayOfConstObjects instance containing the graph's nodes as - * OrtNode instances. Must be released by calling ReleaseArrayOfConstObjects. + * \param[out] nodes Pre-allocated array of `num_nodes` elements that is filled with the graph's nodes. + * \param[in] num_nodes The size of the `nodes` array. Typical usage sets this to the + * result of Graph_GetNumNodes(). An error status is returned if + * `num_nodes` is less than the number of graph nodes. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ - ORT_API2_STATUS(Graph_GetNodes, const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** nodes); + ORT_API2_STATUS(Graph_GetNodes, const OrtGraph* graph, + _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); /** \brief Get the parent node for the given graph, if any exists. * @@ -5780,9 +5704,10 @@ struct OrtApi { */ ORT_API2_STATUS(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); - // - // OrtNode - // + /// @} + + /// \name OrtNode + /// @{ /** \brief Returns a node's identifier. * @@ -5842,58 +5767,119 @@ struct OrtApi { */ ORT_API2_STATUS(Node_GetSinceVersion, _In_ const OrtNode* node, _Out_ int* since_version); - /** \brief Returns a node's inputs as OrtValueInfo instances. + /** \brief Returns the number of node inputs. * * \param[in] node The OrtNode instance. - * \param[out] inputs Output parameter set to the OrtArrayOfConstObjects instance containing the node's inputs - * as OrtValueInfo instances. Must be released by calling ReleaseArrayOfConstObjects. + * \param[out] num_inputs Output parameter set to the number of node inputs. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ - ORT_API2_STATUS(Node_GetInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** inputs); + ORT_API2_STATUS(Node_GetNumInputs, _In_ const OrtNode* node, _Out_ size_t* num_inputs); - /** \brief Returns a node's outputs as OrtValueInfo instances. + /** \brief Returns the node's inputs as OrtValueInfo instances. * * \param[in] node The OrtNode instance. - * \param[out] outputs Output parameter set to a new OrtArrayOfConstObjects instance containing the node's outputs - * as OrtValueInfo instances. Must be released by calling ReleaseArrayOfConstObjects. + * \param[out] inputs Pre-allocated array of `num_inputs` elements that is filled with the node's inputs. + * \param[in] num_inputs The size of the `inputs` array. + * Typical usage sets this to the result of Node_GetNumInputs(). 
An error status is
+   *                       returned if `num_inputs` is less than the number of node inputs.
    *
    * \snippet{doc} snippets.dox OrtStatus Return Value
    *
    * \since Version 1.23.
    */
-  ORT_API2_STATUS(Node_GetOutputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** outputs);
+  ORT_API2_STATUS(Node_GetInputs, _In_ const OrtNode* node,
+                  _Out_writes_(num_inputs) const OrtValueInfo** inputs, _In_ size_t num_inputs);
 
-  /** \brief Get the implicit inputs, as OrtValueInfo instances, that are used within the given node's subgraphs.
+  /** \brief Returns the number of node outputs.
+   *
+   * \param[in] node         The OrtNode instance.
+   * \param[out] num_outputs Output parameter set to the number of node outputs.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.23.
+   */
+  ORT_API2_STATUS(Node_GetNumOutputs, _In_ const OrtNode* node, _Out_ size_t* num_outputs);
+
+  /** \brief Returns the node's outputs as OrtValueInfo instances.
+   *
+   * \param[in] node         The OrtNode instance.
+   * \param[out] outputs     Pre-allocated array of `num_outputs` elements that is filled with the node's outputs.
+   * \param[in] num_outputs  The size of the `outputs` array.
+   *                         Typical usage sets this to the result of Node_GetNumOutputs(). An error status is
+   *                         returned if `num_outputs` is less than the number of node outputs.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.23.
+   */
+  ORT_API2_STATUS(Node_GetOutputs, _In_ const OrtNode* node,
+                  _Out_writes_(num_outputs) const OrtValueInfo** outputs, _In_ size_t num_outputs);
+
+  /** \brief Returns the number of node implicit inputs.
    *
    * Certain operator types (e.g., If and Loop) contain nested subgraphs. The internal nodes within the nested subgraphs
    * may use values from the outer scope. Those "outer scope" values are considered implicit inputs to the node that
    * contains the subgraphs (e.g., the If or Loop node).
    *
    * \param[in] node The OrtNode instance.
-   * \param[out] implicit_inputs Output parameter set to a new OrtArrayOfConstObjects instance containing the node's
-   *                             implicit inputs as OrtValueInfo instances.
-   *                             Must be released by calling ReleaseArrayOfConstObjects.
+   * \param[out] num_implicit_inputs Output parameter set to the number of node implicit inputs.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.23.
+   */
+  ORT_API2_STATUS(Node_GetNumImplicitInputs, _In_ const OrtNode* node, _Out_ size_t* num_implicit_inputs);
+
+  /** \brief Get the implicit inputs, as OrtValueInfo instances, that are used within the given node's subgraphs.
+   *
+   * \note Only certain operator types (e.g., If and Loop) contain nested subgraphs.
+   * The internal nodes within the nested subgraphs may use values from the outer scope. Those "outer scope" values
+   * are considered implicit inputs to the node that contains the subgraphs (e.g., the If or Loop node).
+   *
+   * \param[in] node                 The OrtNode instance.
+   * \param[out] implicit_inputs     Pre-allocated array of `num_implicit_inputs` elements that is filled with the
+   *                                 node's implicit inputs.
+   * \param[in] num_implicit_inputs  The size of the `implicit_inputs` array. Typical usage sets this to the result
+   *                                 of Node_GetNumImplicitInputs(). An error status is returned if
+   *                                 `num_implicit_inputs` is less than the number of node implicit inputs.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.23.
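+   *
+   * A minimal usage sketch (illustrative only; `ort_api` and `node` are assumed to be a valid OrtApi pointer
+   * and OrtNode instance, and error handling is omitted):
+   *
+   *   size_t num_implicit_inputs = 0;
+   *   ort_api->Node_GetNumImplicitInputs(node, &num_implicit_inputs);
+   *   std::vector<const OrtValueInfo*> implicit_inputs(num_implicit_inputs);
+   *   ort_api->Node_GetImplicitInputs(node, implicit_inputs.data(), implicit_inputs.size());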
   */
-  ORT_API2_STATUS(Node_GetImplicitInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** implicit_inputs);
+  ORT_API2_STATUS(Node_GetImplicitInputs, _In_ const OrtNode* node,
+                  _Out_writes_(num_implicit_inputs) const OrtValueInfo** implicit_inputs,
+                  _In_ size_t num_implicit_inputs);
+
+  /** \brief Returns the number of node attributes.
+   *
+   * \param[in] node            The OrtNode instance.
+   * \param[out] num_attributes Output parameter set to the number of node attributes.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.23.
+   */
+  ORT_API2_STATUS(Node_GetNumAttributes, _In_ const OrtNode* node, _Out_ size_t* num_attributes);
 
   /** \brief Returns a node's attributes as OrtOpAttr instances.
    *
    * \param[in] node        The OrtNode instance.
-   * \param[out] attributes Output parameter set to the OrtArrayOfConstObjects instance containing the node's attributes
-   *                        as OrtOpAttr instances. Must be released by calling ReleaseArrayOfConstObjects.
+   * \param[out] attributes Pre-allocated array of `num_attributes` elements that is filled with the node's attributes.
+   * \param[in] num_attributes The size of the `attributes` array.
+   *                           Typical usage sets this to the result of Node_GetNumAttributes(). An error status is
+   *                           returned if `num_attributes` is less than the number of node attributes.
    *
    * \snippet{doc} snippets.dox OrtStatus Return Value
    *
    * \since Version 1.23.
    */
-  ORT_API2_STATUS(Node_GetAttributes, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** attributes);
+  ORT_API2_STATUS(Node_GetAttributes, _In_ const OrtNode* node,
+                  _Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes);
 
   /** \brief Gets the OrtNode's attribute as OrtOpAttr by name.
    *
@@ -5905,7 +5891,8 @@ struct OrtApi {
    *
    * \since Version 1.23.
    */
-  ORT_API2_STATUS(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute);
+  ORT_API2_STATUS(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name,
+                  _Outptr_ const OrtOpAttr** attribute);
 
   /** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr.
    *
@@ -5929,34 +5916,51 @@ struct OrtApi {
    */
   ORT_API2_STATUS(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name);
 
+  /** \brief Returns the number of subgraphs contained by the given node.
+   *
+   * \note Only certain operator types (e.g., If and Loop) contain nested subgraphs.
+   *
+   * \param[in] node           The OrtNode instance.
+   * \param[out] num_subgraphs Output parameter set to the number of node subgraphs.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.23.
+   */
+  ORT_API2_STATUS(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs);
+
   /** \brief Get the subgraphs, as OrtGraph instances, contained by the given node.
    *
-   * Certain operator types (e.g., If and Loop) contain nested subgraphs.
+   * \note Only certain operator types (e.g., If and Loop) contain nested subgraphs.
    *
    * \param[in] node       The OrtNode instance.
-   * \param[out] subgraphs Output parameter set to a new OrtArrayOfConstObjects instance containing the node's
-   *                       subgraphs as OrtGraph instances. Must be released by calling ReleaseArrayOfConstObjects.
+   * \param[out] subgraphs Pre-allocated array of `num_subgraphs` elements that is filled with the node's subgraphs.
+   * \param[in] num_subgraphs The size of the `subgraphs` array.
+   *                          Typical usage sets this to the result of Node_GetNumSubgraphs().
An error status is + * returned if `num_subgraphs` is less than the number of node subgraphs. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ - ORT_API2_STATUS(Node_GetSubgraphs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** subgraphs); + ORT_API2_STATUS(Node_GetSubgraphs, _In_ const OrtNode* node, + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs); /** \brief Get the node's parent OrtGraph instance. * * Can return NULL if the OrtNode was created without an owning graph. * * \param[in] node The OrtNode instance. - * \param[out] parent_graph Output parameter set to the node's parent OrtGraph. Can be set to NULL - * if the node is not currently contained by a graph. + * \param[out] graph Output parameter set to the node's OrtGraph. Can be set to NULL + * if the node is not currently contained by a graph. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ - ORT_API2_STATUS(Node_GetParentGraph, _In_ const OrtNode* node, - _Outptr_result_maybenull_ const OrtGraph** parent_graph); + ORT_API2_STATUS(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); + + /// @} /// \name OrtRunOptions /// @{ @@ -6623,7 +6627,6 @@ typedef enum OrtCompileApiFlags { * \since Version 1.22. */ struct OrtCompileApi { - /// @} /// \name OrtModelCompilationOptions /// @{ ORT_CLASS_RELEASE(ModelCompilationOptions); diff --git a/onnxruntime/core/framework/abi_pointer_array.h b/onnxruntime/core/framework/abi_pointer_array.h deleted file mode 100644 index 91af1f7f9c6c0..0000000000000 --- a/onnxruntime/core/framework/abi_pointer_array.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "core/session/onnxruntime_c_api.h" - -struct OrtArrayOfConstObjects { - OrtArrayOfConstObjects() = default; - explicit OrtArrayOfConstObjects(OrtTypeTag object_type) : object_type(object_type) {} - OrtArrayOfConstObjects(OrtTypeTag object_type, size_t size, const void* initial_val = nullptr) - : object_type(object_type), storage(size, initial_val) {} - - OrtTypeTag object_type = OrtTypeTag::ORT_TYPE_TAG_Void; - std::vector storage; -}; diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 46ec81abecce0..c3dd9321ebb0b 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -180,48 +180,82 @@ struct OrtNode { virtual onnxruntime::Status GetSinceVersion(int& since_version) const = 0; /// - /// Gets the node's inputs as an array of OrtValueInfo elements wrapped in an OrtArrayOfConstObjects. + /// Returns the number of node inputs. /// - /// Output parameter set to the node's inputs. + /// The number of node inputs. + virtual size_t GetNumInputs() const = 0; + + /// + /// Gets the node's inputs as OrtValueInfo instances. + /// + /// Buffer into which to copy the inputs. + /// A status indicating success or an error. + virtual onnxruntime::Status GetInputs(gsl::span inputs) const = 0; + + /// + /// Returns the number of node outputs. + /// + /// The number of node outputs. + virtual size_t GetNumOutputs() const = 0; + + /// + /// Gets the node's outputs as OrtValueInfo instances. + /// + /// Buffer into which to copy the outputs. /// A status indicating success or an error. 
- virtual onnxruntime::Status GetInputs(std::unique_ptr& inputs) const = 0; + virtual onnxruntime::Status GetOutputs(gsl::span outputs) const = 0; /// - /// Gets the node's outputs as an array of OrtValueInfo elements wrapped in an OrtArrayOfConstObjects. + /// Returns the number of node implicit inputs. + /// Applies to a node that contains a subgraph (e.g., If or Loop). An implicit input is a value consumed by an + /// internal subgraph node that is not defined in the subgraph. /// - /// Output parameter set to the node's outputs. + /// Output parameter set to the number of implicit inputs. /// A status indicating success or an error. - virtual onnxruntime::Status GetOutputs(std::unique_ptr& outputs) const = 0; + virtual onnxruntime::Status GetNumImplicitInputs(size_t& num_implicit_inputs) const = 0; /// - /// Gets the node's implicit inputs as an array of OrtValueInfo elements wrapped in an OrtArrayOfConstObjects. + /// Gets the node's implicit inputs. /// Applies to a node that contains a subgraph (e.g., If or Loop). An implicit input is a value consumed by an /// internal subgraph node that is not defined in the subgraph. /// - /// Output parameter set to the node's implicit inputs. + /// Buffer into which to copy the implicit inputs. /// A status indicating success or an error. - virtual onnxruntime::Status GetImplicitInputs(std::unique_ptr& implicit_inputs) const = 0; + virtual onnxruntime::Status GetImplicitInputs(gsl::span implicit_inputs) const = 0; /// - /// Gets the node's attributes as an array of OrtOpAttr elements wrapped in an OrtArrayOfConstObjects. + /// Returns the number of node attributes. /// - /// Output parameter set to the node's attributes. + /// The number of node attributes. + virtual size_t GetNumAttributes() const = 0; + + /// + /// Gets the node's attributes. + /// + /// Buffer into which to copy the attributes. /// A status indicating success or an error. - virtual onnxruntime::Status GetAttributes(std::unique_ptr& attrs) const = 0; + virtual onnxruntime::Status GetAttributes(gsl::span attrs) const = 0; + + /// + /// Gets the number of node subgraphs. + /// + /// Output parameter set to the number of subgraphs. + /// A status indicating success or an error. + virtual onnxruntime::Status GetNumSubgraphs(size_t& num_subgraphs) const = 0; /// /// Gets the node's subgraphs (e.g., subgraphs contained by an If or Loop node). /// - /// Output parameter set to the node's subgraphs as OrtGraph instances. + /// Buffer into which to copy the subgraphs. /// A status indicating success or an error. - virtual onnxruntime::Status GetSubgraphs(std::unique_ptr& subgraphs) const = 0; + virtual onnxruntime::Status GetSubgraphs(gsl::span subgraphs) const = 0; /// /// Gets the node's parent graph, which is the graph that contains this node. /// /// Output parameter set to the node's parent graph. /// A status indicating success or an error. - virtual onnxruntime::Status GetParentGraph(const OrtGraph*& parent_graph) const = 0; + virtual onnxruntime::Status GetGraph(const OrtGraph*& parent_graph) const = 0; OrtGraphIrApi graph_ir_api = OrtGraphIrApi::kInvalid; }; @@ -247,34 +281,56 @@ struct OrtGraph { virtual int64_t GetOnnxIRVersion() const = 0; /// - /// Gets the graph's inputs (including initializers) as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. + /// Returns the number of graph inputs, including initializers that appear in the list of graph inputs. + /// + /// The number of graph inputs. 
+ virtual size_t GetNumInputs() const = 0; + + /// + /// Gets the graph's inputs (including initializers) as OrtValueInfo instances. /// - /// Output parameter set to the graph's inputs. + /// Buffer into which to copy the inputs. /// A status indicating success or an error. - virtual onnxruntime::Status GetInputs(std::unique_ptr& inputs) const = 0; + virtual onnxruntime::Status GetInputs(gsl::span inputs) const = 0; /// - /// Gets the graph's outputs as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. + /// Returns the number of graph outputs. /// - /// Output parameter set to the graph's outputs. + /// The number of graph outputs. + virtual size_t GetNumOutputs() const = 0; + + /// + /// Gets the graph's outputs as OrtValueInfo instances. + /// + /// Buffer into which to copy the outputs. /// A status indicating success or an error. - virtual onnxruntime::Status GetOutputs(std::unique_ptr& outputs) const = 0; + virtual onnxruntime::Status GetOutputs(gsl::span outputs) const = 0; + + /// + /// Returns the number of graph initializers (both constant and non-constant). + /// + /// The number of graph initializers. + virtual size_t GetNumInitializers() const = 0; /// - /// Gets the graph's initializers (both constant and non-constant) as OrtValueInfo instances wrapped in an - /// OrtArrayOfConstObjects. + /// Gets the graph's initializers (both constant and non-constant) as OrtValueInfo instances. /// - /// Output parameter set to the graph's initializers. + /// The buffer into which to copy the initializers. /// A status indicating success or an error. - virtual onnxruntime::Status GetInitializers(std::unique_ptr& initializers) const = 0; + virtual onnxruntime::Status GetInitializers(gsl::span initializers) const = 0; + + /// + /// Returns the number of graph nodes. + /// + /// The number of graph nodes. + virtual size_t GetNumNodes() const = 0; /// - /// Gets the graph's nodes as OrtNode instances wrapped in an OrtArrayOfConstObjects. The nodes are sorted in - /// a default "reverse DFS" topological order. + /// Gets the graph's nodes. The nodes are sorted in a default "reverse DFS" topological order. /// - /// Output parameter set to the graph's nodes. + /// Buffer into which to copy the nodes. /// A status indicating success or an error. - virtual onnxruntime::Status GetNodes(std::unique_ptr& nodes) const = 0; + virtual onnxruntime::Status GetNodes(gsl::span nodes) const = 0; /// /// Gets the graph's parent node, if any. The parent_node is nullptr if this is not a nested subgraph. diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index d1133a445ebfa..698c7422a1e2a 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -22,6 +22,21 @@ namespace onnxruntime { +template +static Status CheckCopyDestination(std::string_view error_array_label, size_t src_size, gsl::span dst) { + if (dst.size() < src_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Not enough space for ", error_array_label, ": expected buffer with room for at least ", + src_size, " elements, but got buffer with room for only ", dst.size(), " elements."); + } + + if (dst.data() == nullptr && src_size > 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Buffer to store ", error_array_label, " is NULL."); + } + + return Status::OK(); +} + // Create an EpValueInfo from a NodeArg. 
static std::unique_ptr CreateValueInfo(const NodeArg& node_arg, const EpGraph* ep_graph, size_t flags) { const auto* type_proto = node_arg.TypeAsProto(); @@ -152,62 +167,84 @@ Status EpNode::GetSinceVersion(int& since_version) const { return Status::OK(); } -Status EpNode::GetInputs(std::unique_ptr& result) const { - result = std::make_unique(ORT_TYPE_TAG_OrtValueInfo); - result->storage.reserve(inputs_.size()); +size_t EpNode::GetNumInputs() const { + return inputs_.size(); +} + +Status EpNode::GetInputs(gsl::span dst) const { + const size_t num_inputs = inputs_.size(); + ORT_RETURN_IF_ERROR((CheckCopyDestination("node inputs", num_inputs, dst))); - for (const EpValueInfo* input : inputs_) { - result->storage.push_back(input); + for (size_t i = 0; i < num_inputs; ++i) { + dst[i] = inputs_[i]; } return Status::OK(); } -Status EpNode::GetOutputs(std::unique_ptr& result) const { - result = std::make_unique(ORT_TYPE_TAG_OrtValueInfo); - result->storage.reserve(outputs_.size()); +size_t EpNode::GetNumOutputs() const { + return outputs_.size(); +} + +Status EpNode::GetOutputs(gsl::span dst) const { + const size_t num_outputs = outputs_.size(); + ORT_RETURN_IF_ERROR((CheckCopyDestination("node outputs", num_outputs, dst))); - for (const EpValueInfo* output : outputs_) { - result->storage.push_back(output); + for (size_t i = 0; i < num_outputs; ++i) { + dst[i] = outputs_[i]; } return Status::OK(); } -Status EpNode::GetImplicitInputs(std::unique_ptr& result) const { - result = std::make_unique(ORT_TYPE_TAG_OrtValueInfo); - result->storage.reserve(implicit_inputs_.size()); +Status EpNode::GetNumImplicitInputs(size_t& num_implicit_inputs) const { + num_implicit_inputs = implicit_inputs_.size(); + return Status::OK(); +} + +Status EpNode::GetImplicitInputs(gsl::span dst) const { + const size_t num_implicit_inputs = implicit_inputs_.size(); + ORT_RETURN_IF_ERROR((CheckCopyDestination("node implicit inputs", num_implicit_inputs, dst))); - for (const EpValueInfo* implicit_input : implicit_inputs_) { - result->storage.push_back(implicit_input); + for (size_t i = 0; i < num_implicit_inputs; ++i) { + dst[i] = implicit_inputs_[i]; } return Status::OK(); } -Status EpNode::GetAttributes(std::unique_ptr& result) const { - result = std::make_unique(ORT_TYPE_TAG_OrtOpAttr); - result->storage.reserve(attributes_.size()); +size_t EpNode::GetNumAttributes() const { + return attributes_.size(); +} + +Status EpNode::GetAttributes(gsl::span dst) const { + const size_t num_attributes = attributes_.size(); + ORT_RETURN_IF_ERROR((CheckCopyDestination("node attributes", num_attributes, dst))); - for (const OrtOpAttr* attr : attributes_) { - result->storage.push_back(attr); + for (size_t i = 0; i < num_attributes; ++i) { + dst[i] = attributes_[i]; } return Status::OK(); } -Status EpNode::GetSubgraphs(std::unique_ptr& result) const { - result = std::make_unique(ORT_TYPE_TAG_OrtGraph); - result->storage.reserve(subgraphs_.size()); +Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { + num_subgraphs = subgraphs_.size(); + return Status::OK(); +} - for (const SubgraphState& subgraph : subgraphs_) { - result->storage.push_back(subgraph.ep_subgraph->ToExternal()); +Status EpNode::GetSubgraphs(gsl::span dst) const { + const size_t num_subgraphs = subgraphs_.size(); + ORT_RETURN_IF_ERROR((CheckCopyDestination("node attributes", num_subgraphs, dst))); + + for (size_t i = 0; i < num_subgraphs; ++i) { + dst[i] = subgraphs_[i].ep_subgraph.get(); } return Status::OK(); } -Status EpNode::GetParentGraph(const OrtGraph*& 
parent_graph) const { +Status EpNode::GetGraph(const OrtGraph*& parent_graph) const { parent_graph = ep_graph_->ToExternal(); return Status::OK(); } @@ -623,45 +660,61 @@ const std::string& EpGraph::GetName() const { return graph_viewer_.Name(); } int64_t EpGraph::GetOnnxIRVersion() const { return graph_viewer_.GetOnnxIRVersion(); } -Status EpGraph::GetInputs(std::unique_ptr& result) const { - result = std::make_unique(ORT_TYPE_TAG_OrtValueInfo); - result->storage.reserve(inputs_.size()); +size_t EpGraph::GetNumInputs() const { + return inputs_.size(); +} + +Status EpGraph::GetInputs(gsl::span dst) const { + const size_t num_inputs = inputs_.size(); + ORT_RETURN_IF_ERROR((CheckCopyDestination("graph inputs", num_inputs, dst))); - for (const EpValueInfo* input : inputs_) { - result->storage.push_back(input); + for (size_t i = 0; i < num_inputs; ++i) { + dst[i] = inputs_[i]; } return Status::OK(); } -Status EpGraph::GetOutputs(std::unique_ptr& result) const { - result = std::make_unique(ORT_TYPE_TAG_OrtValueInfo); - result->storage.reserve(outputs_.size()); +size_t EpGraph::GetNumOutputs() const { + return outputs_.size(); +} + +Status EpGraph::GetOutputs(gsl::span dst) const { + const size_t num_outputs = outputs_.size(); + ORT_RETURN_IF_ERROR((CheckCopyDestination("graph outputs", num_outputs, dst))); - for (const EpValueInfo* output : outputs_) { - result->storage.push_back(output); + for (size_t i = 0; i < num_outputs; ++i) { + dst[i] = outputs_[i]; } return Status::OK(); } -Status EpGraph::GetInitializers(std::unique_ptr& result) const { - result = std::make_unique(ORT_TYPE_TAG_OrtValueInfo); - result->storage.reserve(initializer_value_infos_.size()); +size_t EpGraph::GetNumInitializers() const { + return initializer_value_infos_.size(); +} + +Status EpGraph::GetInitializers(gsl::span dst) const { + const size_t num_initializers = initializer_value_infos_.size(); + ORT_RETURN_IF_ERROR((CheckCopyDestination("graph initializers", num_initializers, dst))); - for (const EpValueInfo* initializer_value_info : initializer_value_infos_) { - result->storage.push_back(initializer_value_info); + for (size_t i = 0; i < num_initializers; ++i) { + dst[i] = initializer_value_infos_[i]; } return Status::OK(); } -Status EpGraph::GetNodes(std::unique_ptr& result) const { - result = std::make_unique(ORT_TYPE_TAG_OrtNode); - result->storage.reserve(nodes_.size()); +size_t EpGraph::GetNumNodes() const { + return nodes_.size(); +} + +Status EpGraph::GetNodes(gsl::span dst) const { + const size_t num_nodes = nodes_.size(); + ORT_RETURN_IF_ERROR((CheckCopyDestination("graph nodes", num_nodes, dst))); - for (const std::unique_ptr& ep_node : nodes_) { - result->storage.push_back(ep_node->ToExternal()); + for (size_t i = 0; i < num_nodes; ++i) { + dst[i] = nodes_[i].get(); } return Status::OK(); diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index ba1c4c1ee2b45..4240f5636b7ae 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -12,7 +12,6 @@ #include #include "core/common/inlined_containers.h" -#include "core/framework/abi_pointer_array.h" #include "core/framework/allocator.h" #include "core/graph/basic_types.h" #include "core/graph/abi_graph_types.h" @@ -155,23 +154,38 @@ struct EpNode : public OrtNode { // Gets the opset version in which this node's operator was first defined. 
Status GetSinceVersion(int& since_version) const override; - // Gets the node's inputs as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. - Status GetInputs(std::unique_ptr& inputs) const override; + // Get the number of node inputs. + size_t GetNumInputs() const override; - // Gets the node's outputs as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. - Status GetOutputs(std::unique_ptr& outputs) const override; + // Gets the node's inputs. + Status GetInputs(gsl::span inputs) const override; - // Gets the node's implicit inputs as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. - Status GetImplicitInputs(std::unique_ptr& inputs) const override; + // Get the number of node outputs. + size_t GetNumOutputs() const override; - // Gets the node's attributes as OrtOpAttr instances wrapped in an OrtArrayOfConstObjects. - Status GetAttributes(std::unique_ptr& attrs) const override; + // Gets the node's outputs. + Status GetOutputs(gsl::span outputs) const override; + + // Gets the number of implicit inputs. + Status GetNumImplicitInputs(size_t& num_implicit_inputs) const override; + + // Gets the node's implicit inputs. + Status GetImplicitInputs(gsl::span inputs) const override; + + // Get the number of node attributes. + size_t GetNumAttributes() const override; + + // Gets the node's attributes. + Status GetAttributes(gsl::span attrs) const override; + + // Gets the number of subgraphs contained by this node. + Status GetNumSubgraphs(size_t& num_subgraphs) const override; // Gets the subgraphs contained by this node. - Status GetSubgraphs(std::unique_ptr& subgraphs) const override; + Status GetSubgraphs(gsl::span subgraphs) const override; // Gets this node's parent graph, which is the graph that directly contains this node. - Status GetParentGraph(const OrtGraph*& parent_graph) const override; + Status GetGraph(const OrtGraph*& parent_graph) const override; // // Helper functions used when working directly with an EpNode. @@ -257,20 +271,31 @@ struct EpGraph : public OrtGraph { // Returns the model's ONNX IR version. int64_t GetOnnxIRVersion() const override; - // Gets the graph's inputs as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. + // Get the number of graph inputs, including initializers that are listed as graph inputs. + size_t GetNumInputs() const override; + + // Gets the graph's inputs as OrtValueInfo instances. // Includes initializers that are graph inputs. - Status GetInputs(std::unique_ptr& inputs) const override; + Status GetInputs(gsl::span inputs) const override; + + // Get the number of graph outputs. + size_t GetNumOutputs() const override; - // Gets the graph's outputs as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. - Status GetOutputs(std::unique_ptr& outputs) const override; + // Gets the graph's outputs as OrtValueInfo instances. + Status GetOutputs(gsl::span outputs) const override; - // Gets the graph's initializers as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. + // Get the number of graph initializers, including both constant and non-constant initializers. + size_t GetNumInitializers() const override; + + // Gets the graph's initializers as OrtValueInfo instances. // Includes both constant initializers and non-constant initializers (aka optional graph inputs). - Status GetInitializers(std::unique_ptr& initializers) const override; + Status GetInitializers(gsl::span initializers) const override; + + // Get the number of nodes in the graph. 
+ size_t GetNumNodes() const override; - // Gets the graph's nodes as OrtNode instances wrapped in an OrtArrayOfConstObjects. - // The nodes are sorted in a default "reverse DFS" topological order. - Status GetNodes(std::unique_ptr& nodes) const override; + // Gets the graph's nodes. The nodes are sorted in a default "reverse DFS" topological order. + Status GetNodes(gsl::span nodes) const override; // Gets the graph's parent node or nullptr if this is not a nested subgraph. Status GetParentNode(const OrtNode*& parent_node) const override; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 1fb6018164977..6330a42c115db 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -99,32 +100,48 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting an OrtNode's opset version"); } - Status GetInputs(std::unique_ptr& /*inputs*/) const override { + size_t GetNumInputs() const override { return input_names.size(); } + + Status GetInputs(gsl::span /*inputs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting input OrtValueInfos for OrtNode"); } - Status GetOutputs(std::unique_ptr& /*outputs*/) const override { + size_t GetNumOutputs() const override { return output_names.size(); } + + Status GetOutputs(gsl::span /*outputs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting output OrtValueInfos for OrtNode"); } - Status GetImplicitInputs(std::unique_ptr& /*implicit_inputs*/) const override { + Status GetNumImplicitInputs(size_t& /*num_implicit_inputs*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the implicit inputs for OrtNode"); + } + + Status GetImplicitInputs(gsl::span /*implicit_inputs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the implicit inputs for OrtNode"); } - Status GetAttributes(std::unique_ptr& /*attrs*/) const override { + size_t GetNumAttributes() const override { return attributes.size(); } + + Status GetAttributes(gsl::span /*attrs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode"); } - Status GetSubgraphs(std::unique_ptr& /*subgraphs*/) const override { + Status GetNumSubgraphs(size_t& /*num_subgraphs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); } - Status GetParentGraph(const OrtGraph*& /*parent_graph*/) const override { + Status GetSubgraphs(gsl::span /*subgraphs*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); + } + + Status GetGraph(const OrtGraph*& /*parent_graph*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the parent graph for OrtNode"); } @@ -159,22 +176,30 @@ struct ModelEditorGraph : public OrtGraph { return ONNX_NAMESPACE::Version::IR_VERSION; } - Status GetInputs(std::unique_ptr& /*result*/) const override { + size_t GetNumInputs() const override { return inputs.size(); } + + Status GetInputs(gsl::span 
/*result*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the graph inputs."); } - Status GetOutputs(std::unique_ptr& /*result*/) const override { + size_t GetNumOutputs() const override { return outputs.size(); } + + Status GetOutputs(gsl::span /*result*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the graph outputs."); } - Status GetInitializers(std::unique_ptr& /*result*/) const override { + size_t GetNumInitializers() const override { return initializers.size() + external_initializers.size(); } + + Status GetInitializers(gsl::span /*result*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the graph initializers."); } - Status GetNodes(std::unique_ptr& /*result*/) const override { + size_t GetNumNodes() const override { return nodes.size(); } + + Status GetNodes(gsl::span /*result*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the graph nodes."); } diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc index 98e490a490c00..cac91a4ec52d2 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc @@ -8,7 +8,6 @@ #include #include #include -#include "core/framework/abi_pointer_array.h" #include "core/framework/compute_capability.h" #include "core/framework/error_code_helper.h" #include "core/framework/model_metadef_id_generator.h" diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index b30d9182684c4..8983124ce039d 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -15,7 +15,6 @@ #include "core/common/safeint.h" #include "core/common/status.h" #include "core/common/string_helper.h" -#include "core/framework/abi_pointer_array.h" #include "core/framework/allocator.h" #include "core/framework/callback.h" #include "core/framework/data_types.h" @@ -2404,107 +2403,6 @@ ORT_API(void, OrtApis::ReleaseModel, _Frees_ptr_opt_ OrtModel* model) { delete model; } -ORT_API_STATUS_IMPL(OrtApis::CreateArrayOfConstObjects, _In_ OrtTypeTag elem_type, _In_ size_t initial_size, - _In_ const void* initial_value, _Outptr_ OrtArrayOfConstObjects** out) { - API_IMPL_BEGIN - auto array = std::make_unique(elem_type, initial_size, initial_value); - *out = array.release(); - return nullptr; - API_IMPL_END -} - -ORT_API(void, OrtApis::ReleaseArrayOfConstObjects, _Frees_ptr_opt_ OrtArrayOfConstObjects* array) { - delete array; -} - -ORT_API_STATUS_IMPL(OrtApis::ArrayOfConstObjects_GetObjectType, _In_ const OrtArrayOfConstObjects* array, - _Out_ OrtTypeTag* type_tag) { - API_IMPL_BEGIN - if (type_tag == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'type_tag' argument is NULL"); - } - - *type_tag = array->object_type; - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::ArrayOfConstObjects_GetData, _In_ const OrtArrayOfConstObjects* array, - _Outptr_ const void* const** data) { - API_IMPL_BEGIN - if (data == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'data' argument is NULL"); - } - - *data = array->storage.data(); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::ArrayOfConstObjects_GetMutableData, _In_ 
OrtArrayOfConstObjects* array, - _Outptr_ const void*** data) { - API_IMPL_BEGIN - if (data == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'data' argument is NULL"); - } - - *data = array->storage.data(); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::ArrayOfConstObjects_GetSize, _In_ const OrtArrayOfConstObjects* array, - _Out_ size_t* size) { - API_IMPL_BEGIN - if (size == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'size' argument is NULL"); - } - - *size = array->storage.size(); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::ArrayOfConstObjects_GetElementAt, _In_ const OrtArrayOfConstObjects* array, - _In_ size_t index, _Outptr_ const void** out) { - API_IMPL_BEGIN - if (out == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'out' argument is NULL"); - } - - if (index >= array->storage.size()) { - std::ostringstream oss; - oss << "'index' value (" << index << ") is out of bounds for array of size " << array->storage.size(); - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); - } - - *out = array->storage[index]; - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::ArrayOfConstObjects_SetElementAt, _In_ OrtArrayOfConstObjects* array, _In_ size_t index, - _In_ const void* element) { - API_IMPL_BEGIN - if (index >= array->storage.size()) { - std::ostringstream oss; - oss << "'index' value (" << index << ") is out of bounds for array of size " << array->storage.size(); - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); - } - - array->storage[index] = element; - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::ArrayOfConstObjects_AppendElement, _In_ OrtArrayOfConstObjects* array, - _In_ const void* element) { - API_IMPL_BEGIN - array->storage.push_back(element); - return nullptr; - API_IMPL_END -} - ORT_API_STATUS_IMPL(OrtApis::GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name) { API_IMPL_BEGIN @@ -2566,16 +2464,22 @@ ORT_API_STATUS_IMPL(OrtApis::ValueInfo_GetValueNumConsumers, _In_ const OrtValue } ORT_API_STATUS_IMPL(OrtApis::ValueInfo_GetValueConsumers, _In_ const OrtValueInfo* value_info, - _Out_writes_all_(max_num_consumers) const OrtNode** nodes, - _Out_writes_all_(max_num_consumers) int64_t* input_indices, - _In_ size_t max_num_consumers) { + _Out_writes_all_(num_consumers) const OrtNode** nodes, + _Out_writes_all_(num_consumers) int64_t* input_indices, + _In_ size_t num_consumers) { API_IMPL_BEGIN #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) std::vector consumer_infos; ORT_API_RETURN_IF_STATUS_NOT_OK(value_info->GetConsumerInfos(consumer_infos)); - size_t num_uses = std::min(max_num_consumers, consumer_infos.size()); - for (size_t i = 0; i < num_uses; ++i) { + if (num_consumers < consumer_infos.size()) { + std::ostringstream oss; + oss << "Not enough space for value consumers: expected buffer with at least " << consumer_infos.size() + << " elements, got " << num_consumers; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); + } + + for (size_t i = 0; i < consumer_infos.size(); ++i) { nodes[i] = consumer_infos[i].node; input_indices[i] = consumer_infos[i].input_index; } @@ -2585,7 +2489,7 @@ ORT_API_STATUS_IMPL(OrtApis::ValueInfo_GetValueConsumers, _In_ const OrtValueInf ORT_UNUSED_PARAMETER(value_info); ORT_UNUSED_PARAMETER(nodes); ORT_UNUSED_PARAMETER(input_indices); - ORT_UNUSED_PARAMETER(max_num_consumers); + 
ORT_UNUSED_PARAMETER(num_consumers); return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "ValueInfo_GetValueConsumers() is not supported in this build."); #endif API_IMPL_END @@ -2687,59 +2591,91 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Graph_GetInputs, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** inputs) { +ORT_API_STATUS_IMPL(OrtApis::Graph_GetNumInputs, _In_ const OrtGraph* graph, _Out_ size_t* num_inputs) { API_IMPL_BEGIN - if (inputs == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'inputs' argument is NULL"); + if (num_inputs == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_inputs' argument is NULL"); } - std::unique_ptr array; - ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetInputs(array)); + *num_inputs = graph->GetNumInputs(); - *inputs = array.release(); return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Graph_GetOutputs, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** outputs) { +ORT_API_STATUS_IMPL(OrtApis::Graph_GetInputs, _In_ const OrtGraph* graph, + _Out_writes_(num_inputs) const OrtValueInfo** inputs, _In_ size_t num_inputs) { API_IMPL_BEGIN - if (outputs == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'outputs' argument is NULL"); + gsl::span inputs_span(inputs, num_inputs); + ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetInputs(inputs_span)); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Graph_GetNumOutputs, _In_ const OrtGraph* graph, _Out_ size_t* num_outputs) { + API_IMPL_BEGIN + if (num_outputs == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_outputs' argument is NULL"); } - std::unique_ptr array; - ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetOutputs(array)); + *num_outputs = graph->GetNumOutputs(); - *outputs = array.release(); return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Graph_GetInitializers, _In_ const OrtGraph* graph, - _Outptr_ OrtArrayOfConstObjects** initializers) { +ORT_API_STATUS_IMPL(OrtApis::Graph_GetOutputs, _In_ const OrtGraph* graph, + _Out_writes_(num_outputs) const OrtValueInfo** outputs, _In_ size_t num_outputs) { API_IMPL_BEGIN - if (initializers == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'initializers' argument is NULL"); + gsl::span outputs_span(outputs, num_outputs); + ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetOutputs(outputs_span)); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Graph_GetNumInitializers, _In_ const OrtGraph* graph, _Out_ size_t* num_initializers) { + API_IMPL_BEGIN + if (num_initializers == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_initializers' argument is NULL"); } - std::unique_ptr array; - ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetInitializers(array)); + *num_initializers = graph->GetNumInitializers(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Graph_GetInitializers, _In_ const OrtGraph* graph, + _Out_writes_(num_initializers) const OrtValueInfo** initializers, + _In_ size_t num_initializers) { + API_IMPL_BEGIN + gsl::span initializers_span(initializers, num_initializers); + ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetInitializers(initializers_span)); - *initializers = array.release(); return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Graph_GetNodes, const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** nodes) { +ORT_API_STATUS_IMPL(OrtApis::Graph_GetNumNodes, _In_ const 
OrtGraph* graph, _Out_ size_t* num_nodes) { API_IMPL_BEGIN - if (nodes == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'nodes' argument is NULL"); + if (num_nodes == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_nodes' argument is NULL"); } - std::unique_ptr array; - ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetNodes(array)); + *num_nodes = graph->GetNumNodes(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Graph_GetNodes, _In_ const OrtGraph* graph, + _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes) { + API_IMPL_BEGIN + gsl::span nodes_span(nodes, num_nodes); + ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetNodes(nodes_span)); - *nodes = array.release(); return nullptr; API_IMPL_END } @@ -2817,59 +2753,83 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetSinceVersion, _In_ const OrtNode* node, _Ou API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** inputs) { +ORT_API_STATUS_IMPL(OrtApis::Node_GetNumInputs, _In_ const OrtNode* node, _Out_ size_t* num_inputs) { API_IMPL_BEGIN - if (inputs == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'inputs' argument is NULL"); + if (num_inputs == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_inputs' argument is NULL"); } - std::unique_ptr array; - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetInputs(array)); + *num_inputs = node->GetNumInputs(); + return nullptr; + API_IMPL_END +} - *inputs = array.release(); +ORT_API_STATUS_IMPL(OrtApis::Node_GetInputs, _In_ const OrtNode* node, + _Out_writes_(num_inputs) const OrtValueInfo** inputs, _In_ size_t num_inputs) { + API_IMPL_BEGIN + gsl::span inputs_span(inputs, num_inputs); + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetInputs(inputs_span)); return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetOutputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** outputs) { +ORT_API_STATUS_IMPL(OrtApis::Node_GetNumOutputs, _In_ const OrtNode* node, _Out_ size_t* num_outputs) { API_IMPL_BEGIN - if (outputs == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'outputs' argument is NULL"); + if (num_outputs == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_outputs' argument is NULL"); } - std::unique_ptr array; - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetOutputs(array)); + *num_outputs = node->GetNumOutputs(); + return nullptr; + API_IMPL_END +} - *outputs = array.release(); +ORT_API_STATUS_IMPL(OrtApis::Node_GetOutputs, _In_ const OrtNode* node, + _Out_writes_(num_outputs) const OrtValueInfo** outputs, _In_ size_t num_outputs) { + API_IMPL_BEGIN + gsl::span outputs_span(outputs, num_outputs); + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetOutputs(outputs_span)); return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetImplicitInputs, _In_ const OrtNode* node, - _Outptr_ OrtArrayOfConstObjects** implicit_inputs) { +ORT_API_STATUS_IMPL(OrtApis::Node_GetNumImplicitInputs, _In_ const OrtNode* node, _Out_ size_t* num_implicit_inputs) { API_IMPL_BEGIN - if (implicit_inputs == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'implicit_inputs' argument is NULL"); + if (num_implicit_inputs == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_implicit_inputs' argument is NULL"); } - std::unique_ptr array; - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetImplicitInputs(array)); + 
ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetNumImplicitInputs(*num_implicit_inputs)); + return nullptr; + API_IMPL_END +} - *implicit_inputs = array.release(); +ORT_API_STATUS_IMPL(OrtApis::Node_GetImplicitInputs, _In_ const OrtNode* node, + _Out_writes_(num_implicit_inputs) const OrtValueInfo** implicit_inputs, + _In_ size_t num_implicit_inputs) { + API_IMPL_BEGIN + gsl::span inputs_span(implicit_inputs, num_implicit_inputs); + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetImplicitInputs(inputs_span)); return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributes, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** attributes) { +ORT_API_STATUS_IMPL(OrtApis::Node_GetNumAttributes, _In_ const OrtNode* node, _Out_ size_t* num_attributes) { API_IMPL_BEGIN - if (attributes == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'attributes' argument is NULL"); + if (num_attributes == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_attributes' argument is NULL"); } - std::unique_ptr array; - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetAttributes(array)); + *num_attributes = node->GetNumAttributes(); + return nullptr; + API_IMPL_END +} - *attributes = array.release(); +ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributes, _In_ const OrtNode* node, + _Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes) { + API_IMPL_BEGIN + gsl::span attrs_span(attributes, num_attributes); + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetAttributes(attrs_span)); return nullptr; API_IMPL_END } @@ -2950,29 +2910,35 @@ ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetName, _In_ const OrtOpAttr* attribute, _O API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetSubgraphs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** subgraphs) { +ORT_API_STATUS_IMPL(OrtApis::Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs) { API_IMPL_BEGIN - if (subgraphs == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'subgraphs' argument is NULL"); + if (num_subgraphs == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_subgraphs' argument is NULL"); } - std::unique_ptr array; - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetSubgraphs(array)); + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetNumSubgraphs(*num_subgraphs)); + return nullptr; + API_IMPL_END +} - *subgraphs = array.release(); +ORT_API_STATUS_IMPL(OrtApis::Node_GetSubgraphs, _In_ const OrtNode* node, + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs) { + API_IMPL_BEGIN + gsl::span graphs_span(subgraphs, num_subgraphs); + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetSubgraphs(graphs_span)); return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetParentGraph, _In_ const OrtNode* node, - _Outptr_result_maybenull_ const OrtGraph** parent_graph) { +ORT_API_STATUS_IMPL(OrtApis::Node_GetGraph, _In_ const OrtNode* node, + _Outptr_result_maybenull_ const OrtGraph** graph) { API_IMPL_BEGIN - if (parent_graph == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'parent_graph' argument is NULL"); + if (graph == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'graph' argument is NULL"); } - *parent_graph = nullptr; - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetParentGraph(*parent_graph)); + *graph = nullptr; + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetGraph(*graph)); return nullptr; API_IMPL_END } @@ -3616,16 +3582,6 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::CreateMemoryInfo_V2, - 
&OrtApis::CreateArrayOfConstObjects, - &OrtApis::ReleaseArrayOfConstObjects, - &OrtApis::ArrayOfConstObjects_GetObjectType, - &OrtApis::ArrayOfConstObjects_GetData, - &OrtApis::ArrayOfConstObjects_GetMutableData, - &OrtApis::ArrayOfConstObjects_GetSize, - &OrtApis::ArrayOfConstObjects_GetElementAt, - &OrtApis::ArrayOfConstObjects_SetElementAt, - &OrtApis::ArrayOfConstObjects_AppendElement, - &OrtApis::ValueInfo_GetValueProducer, &OrtApis::ValueInfo_GetValueNumConsumers, &OrtApis::ValueInfo_GetValueConsumers, @@ -3637,9 +3593,13 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::ValueInfo_IsFromOuterScope, &OrtApis::Graph_GetName, &OrtApis::Graph_GetOnnxIRVersion, + &OrtApis::Graph_GetNumInputs, &OrtApis::Graph_GetInputs, + &OrtApis::Graph_GetNumOutputs, &OrtApis::Graph_GetOutputs, + &OrtApis::Graph_GetNumInitializers, &OrtApis::Graph_GetInitializers, + &OrtApis::Graph_GetNumNodes, &OrtApis::Graph_GetNodes, &OrtApis::Graph_GetParentNode, &OrtApis::Node_GetId, @@ -3647,15 +3607,20 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetOperatorType, &OrtApis::Node_GetDomain, &OrtApis::Node_GetSinceVersion, + &OrtApis::Node_GetNumInputs, &OrtApis::Node_GetInputs, + &OrtApis::Node_GetNumOutputs, &OrtApis::Node_GetOutputs, + &OrtApis::Node_GetNumImplicitInputs, &OrtApis::Node_GetImplicitInputs, + &OrtApis::Node_GetNumAttributes, &OrtApis::Node_GetAttributes, &OrtApis::Node_GetAttributeByName, &OrtApis::OpAttr_GetType, &OrtApis::OpAttr_GetName, + &OrtApis::Node_GetNumSubgraphs, &OrtApis::Node_GetSubgraphs, - &OrtApis::Node_GetParentGraph, + &OrtApis::Node_GetGraph, &OrtApis::GetRunConfigEntry, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 8e6734b914be2..4c4ab07493237 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -608,30 +608,14 @@ ORT_API_STATUS_IMPL(CreateMemoryInfo_V2, _In_ const char* name, _In_ enum OrtMem _In_ size_t alignment, enum OrtAllocatorType allocator_type, _Outptr_ OrtMemoryInfo** out); -// OrtArrayOfConstObjects -ORT_API_STATUS_IMPL(CreateArrayOfConstObjects, _In_ OrtTypeTag object_type, _In_ size_t initial_size, - _In_ const void* initial_value, _Outptr_ OrtArrayOfConstObjects** out); -ORT_API(void, ReleaseArrayOfConstObjects, _Frees_ptr_opt_ OrtArrayOfConstObjects* array); -ORT_API_STATUS_IMPL(ArrayOfConstObjects_GetObjectType, _In_ const OrtArrayOfConstObjects* array, - _Out_ OrtTypeTag* type_tag); -ORT_API_STATUS_IMPL(ArrayOfConstObjects_GetData, _In_ const OrtArrayOfConstObjects* array, - _Outptr_ const void* const** data); -ORT_API_STATUS_IMPL(ArrayOfConstObjects_GetMutableData, _In_ OrtArrayOfConstObjects* array, _Outptr_ const void*** data); -ORT_API_STATUS_IMPL(ArrayOfConstObjects_GetSize, _In_ const OrtArrayOfConstObjects* array, _Out_ size_t* size); -ORT_API_STATUS_IMPL(ArrayOfConstObjects_GetElementAt, _In_ const OrtArrayOfConstObjects* array, _In_ size_t index, - _Outptr_ const void** out); -ORT_API_STATUS_IMPL(ArrayOfConstObjects_SetElementAt, _In_ OrtArrayOfConstObjects* array, _In_ size_t index, - _In_ const void* element); -ORT_API_STATUS_IMPL(ArrayOfConstObjects_AppendElement, _In_ OrtArrayOfConstObjects* array, _In_ const void* element); - // OrtValueInfo ORT_API_STATUS_IMPL(ValueInfo_GetValueProducer, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtNode** producer_node, _Out_opt_ size_t* producer_output_index); ORT_API_STATUS_IMPL(ValueInfo_GetValueNumConsumers, _In_ const OrtValueInfo* value_info, _Out_ size_t* num_consumers); 
ORT_API_STATUS_IMPL(ValueInfo_GetValueConsumers, _In_ const OrtValueInfo* value_info, - _Out_writes_all_(max_num_consumers) const OrtNode** nodes, - _Out_writes_all_(max_num_consumers) int64_t* input_indices, - _In_ size_t max_num_consumers); + _Out_writes_all_(num_consumers) const OrtNode** nodes, + _Out_writes_all_(num_consumers) int64_t* input_indices, + _In_ size_t num_consumers); ORT_API_STATUS_IMPL(ValueInfo_GetInitializerValue, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtValue** initializer_value); ORT_API_STATUS_IMPL(ValueInfo_IsRequiredGraphInput, _In_ const OrtValueInfo* value_info, @@ -647,10 +631,19 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i // OrtGraph ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name); ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); -ORT_API_STATUS_IMPL(Graph_GetInputs, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** inputs); -ORT_API_STATUS_IMPL(Graph_GetOutputs, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** outputs); -ORT_API_STATUS_IMPL(Graph_GetInitializers, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** initializers); -ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** nodes); +ORT_API_STATUS_IMPL(Graph_GetNumInputs, _In_ const OrtGraph* graph, _Out_ size_t* num_inputs); +ORT_API_STATUS_IMPL(Graph_GetInputs, _In_ const OrtGraph* graph, + _Out_writes_(num_inputs) const OrtValueInfo** inputs, _In_ size_t num_inputs); +ORT_API_STATUS_IMPL(Graph_GetNumOutputs, _In_ const OrtGraph* graph, _Out_ size_t* num_outputs); +ORT_API_STATUS_IMPL(Graph_GetOutputs, _In_ const OrtGraph* graph, + _Out_writes_(num_outputs) const OrtValueInfo** outputs, _In_ size_t num_outputs); +ORT_API_STATUS_IMPL(Graph_GetNumInitializers, _In_ const OrtGraph* graph, _Out_ size_t* num_initializers); +ORT_API_STATUS_IMPL(Graph_GetInitializers, _In_ const OrtGraph* graph, + _Out_writes_(num_initializers) const OrtValueInfo** initializers, + _In_ size_t num_initializers); +ORT_API_STATUS_IMPL(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t* num_nodes); +ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph, + _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); // OrtNode @@ -659,16 +652,27 @@ ORT_API_STATUS_IMPL(Node_GetName, _In_ const OrtNode* node, _Outptr_ const char* ORT_API_STATUS_IMPL(Node_GetOperatorType, _In_ const OrtNode* node, _Outptr_ const char** operator_type); ORT_API_STATUS_IMPL(Node_GetDomain, _In_ const OrtNode* node, _Outptr_ const char** domain_name); ORT_API_STATUS_IMPL(Node_GetSinceVersion, _In_ const OrtNode* node, _Out_ int* since_version); -ORT_API_STATUS_IMPL(Node_GetInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** inputs); -ORT_API_STATUS_IMPL(Node_GetOutputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** outputs); -ORT_API_STATUS_IMPL(Node_GetImplicitInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** implicit_inputs); -ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** attrs); -ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute); +ORT_API_STATUS_IMPL(Node_GetNumInputs, _In_ const OrtNode* 
node, _Out_ size_t* num_inputs); +ORT_API_STATUS_IMPL(Node_GetInputs, _In_ const OrtNode* node, + _Out_writes_(num_inputs) const OrtValueInfo** inputs, _In_ size_t num_inputs); +ORT_API_STATUS_IMPL(Node_GetNumOutputs, _In_ const OrtNode* node, _Out_ size_t* num_outputs); +ORT_API_STATUS_IMPL(Node_GetOutputs, _In_ const OrtNode* node, + _Out_writes_(num_outputs) const OrtValueInfo** outputs, _In_ size_t num_outputs); +ORT_API_STATUS_IMPL(Node_GetNumImplicitInputs, _In_ const OrtNode* node, _Out_ size_t* num_implicit_inputs); +ORT_API_STATUS_IMPL(Node_GetImplicitInputs, _In_ const OrtNode* node, + _Out_writes_(num_implicit_inputs) const OrtValueInfo** implicit_inputs, + _In_ size_t num_implicit_inputs); +ORT_API_STATUS_IMPL(Node_GetNumAttributes, _In_ const OrtNode* node, _Out_ size_t* num_attributes); +ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, + _Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes); +ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, + _Outptr_ const OrtOpAttr** attribute); ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type); ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); -ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** subgraphs); -ORT_API_STATUS_IMPL(Node_GetParentGraph, _In_ const OrtNode* node, - _Outptr_result_maybenull_ const OrtGraph** parent_graph); +ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs); +ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, + _Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs); +ORT_API_STATUS_IMPL(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options, _In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value); diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index d9418e3e3156d..b498c40079f48 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -3,6 +3,7 @@ #include "ep.h" +#include #include #include #include @@ -180,18 +181,13 @@ const char* ORT_API_CALL ExampleEp ::GetNameImpl(const OrtEp* this_ptr) noexcept } OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* graph) { - OrtArrayOfConstObjects* initializers = nullptr; - DeferOrtRelease release_initializers(&initializers, ort_api.ReleaseArrayOfConstObjects); size_t num_initializers = 0; + RETURN_IF_ERROR(ort_api.Graph_GetNumInitializers(graph, &num_initializers)); - RETURN_IF_ERROR(ort_api.Graph_GetInitializers(graph, &initializers)); - RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetSize(initializers, &num_initializers)); - - for (size_t i = 0; i < num_initializers; ++i) { - const OrtValueInfo* initializer = nullptr; - RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(initializers, i, - reinterpret_cast(&initializer))); + std::vector initializers(num_initializers); + RETURN_IF_ERROR(ort_api.Graph_GetInitializers(graph, initializers.data(), initializers.size())); + for (const OrtValueInfo* initializer : initializers) { bool is_constant = false; RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(initializer, &is_constant)); @@ -233,53 +229,39 @@ OrtStatus* ORT_API_CALL 
ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG OrtEpGraphSupportInfo* graph_support_info) { ExampleEp* ep = static_cast(this_ptr); - OrtArrayOfConstObjects* nodes_array = nullptr; - DeferOrtRelease release_nodes_array(&nodes_array, ep->ort_api.ReleaseArrayOfConstObjects); - size_t num_nodes = 0; - - RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graph, &nodes_array)); - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(nodes_array, &num_nodes)); + RETURN_IF_ERROR(ep->ort_api.Graph_GetNumNodes(graph, &num_nodes)); if (num_nodes == 0) { return nullptr; // No nodes to process } - const void* const* nodes_data = nullptr; - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetData(nodes_array, &nodes_data)); - auto nodes_span = gsl::span(reinterpret_cast(nodes_data), num_nodes); + std::vector nodes(num_nodes); + RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); std::vector supported_nodes; - for (const OrtNode* node : nodes_span) { + for (const OrtNode* node : nodes) { const char* op_type = nullptr; RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node, &op_type)); if (std::strncmp(op_type, "Mul", 4) == 0) { // Check that Mul has inputs/output of type float - OrtArrayOfConstObjects* inputs_array = nullptr; - OrtArrayOfConstObjects* outputs_array = nullptr; - DeferOrtRelease release_inputs(&inputs_array, ep->ort_api.ReleaseArrayOfConstObjects); - DeferOrtRelease release_outputs(&outputs_array, ep->ort_api.ReleaseArrayOfConstObjects); - - RETURN_IF_ERROR(ep->ort_api.Node_GetInputs(node, &inputs_array)); - RETURN_IF_ERROR(ep->ort_api.Node_GetOutputs(node, &outputs_array)); - size_t num_inputs = 0; size_t num_outputs = 0; - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(inputs_array, &num_inputs)); - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(outputs_array, &num_outputs)); + RETURN_IF_ERROR(ep->ort_api.Node_GetNumInputs(node, &num_inputs)); + RETURN_IF_ERROR(ep->ort_api.Node_GetNumOutputs(node, &num_outputs)); RETURN_IF(num_inputs != 2 || num_outputs != 1, ep->ort_api, "Mul should have 2 inputs and 1 output"); - const void* const* inputs_data = nullptr; - const void* const* outputs_data = nullptr; - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetData(inputs_array, &inputs_data)); - RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetData(outputs_array, &outputs_data)); + std::vector inputs(num_inputs); + std::vector outputs(num_outputs); + RETURN_IF_ERROR(ep->ort_api.Node_GetInputs(node, inputs.data(), inputs.size())); + RETURN_IF_ERROR(ep->ort_api.Node_GetOutputs(node, outputs.data(), outputs.size())); std::array is_float = {false, false, false}; - RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, static_cast(inputs_data[0]), is_float[0])); - RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, static_cast(inputs_data[1]), is_float[1])); - RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, static_cast(outputs_data[0]), is_float[2])); + RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, inputs[0], is_float[0])); + RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, inputs[1], is_float[1])); + RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, outputs[0], is_float[2])); if (!is_float[0] || !is_float[1] || !is_float[2]) { continue; // Input or output is not of type float } @@ -321,43 +303,30 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const // implementation could transfer the weights to device memory. 
ep->SaveConstantInitializers(graphs[0]); - OrtArrayOfConstObjects* nodes_array = nullptr; - DeferOrtRelease release_nodes(&nodes_array, ort_api.ReleaseArrayOfConstObjects); size_t num_nodes = 0; + RETURN_IF_ERROR(ep->ort_api.Graph_GetNumNodes(graphs[0], &num_nodes)); - RETURN_IF_ERROR(ort_api.Graph_GetNodes(graphs[0], &nodes_array)); - RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetSize(nodes_array, &num_nodes)); + std::vector nodes(num_nodes); + RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graphs[0], nodes.data(), nodes.size())); if (num_nodes != 1) { return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); } - const OrtNode* node_to_compile = nullptr; - RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(nodes_array, 0, - reinterpret_cast(&node_to_compile))); - const char* node_op_type = nullptr; - RETURN_IF_ERROR(ort_api.Node_GetOperatorType(node_to_compile, &node_op_type)); + RETURN_IF_ERROR(ort_api.Node_GetOperatorType(nodes[0], &node_op_type)); if (std::strncmp(node_op_type, "Mul", 4) != 0) { return ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); } // Now we know we're compiling a single Mul node. Create a computation kernel. - OrtArrayOfConstObjects* inputs = nullptr; - DeferOrtRelease release_inputs(&inputs, ort_api.ReleaseArrayOfConstObjects); - - RETURN_IF_ERROR(ort_api.Node_GetInputs(node_to_compile, &inputs)); - const OrtValueInfo* input0 = nullptr; - const OrtValueInfo* input1 = nullptr; - - RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(inputs, 0, reinterpret_cast(&input0))); - RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(inputs, 1, reinterpret_cast(&input1))); + std::array node_inputs = {}; + std::array node_input_names = {}; - const char* input0_name = nullptr; - const char* input1_name = nullptr; - RETURN_IF_ERROR(ort_api.GetValueInfoName(input0, &input0_name)); - RETURN_IF_ERROR(ort_api.GetValueInfoName(input1, &input1_name)); + RETURN_IF_ERROR(ort_api.Node_GetInputs(nodes[0], node_inputs.data(), node_inputs.size())); + RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[0], &node_input_names[0])); + RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[1], &node_input_names[1])); // Associate the name of the fused node with our MulKernel. const char* fused_node_name = nullptr; @@ -365,7 +334,9 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const ep->kernels_.emplace(std::string(fused_node_name), std::make_unique(ep->ort_api, ep->logger_, ep->float_initializers_, - input0_name, input1_name)); + node_input_names[0], + node_input_names[1])); + // Update the OrtNodeComputeInfo associated with the graph. auto node_compute_info = std::make_unique(*ep); node_compute_infos[0] = node_compute_info.release(); @@ -398,17 +369,14 @@ OrtStatus* ExampleEp::CreateEpContextNodes(gsl::span fused_nodes assert(fused_nodes.size() == ep_context_nodes.size()); // Helper to collect input or output names from an array of OrtValueInfo instances. 
- auto collect_input_output_names = [&](const OrtArrayOfConstObjects& value_infos, + auto collect_input_output_names = [&](gsl::span value_infos, std::vector& result) -> OrtStatus* { - size_t num_values = 0; - RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetSize(&value_infos, &num_values)); - - std::vector value_names(num_values, nullptr); + size_t num_values = value_infos.size(); + std::vector value_names(num_values); - for (size_t i = 0; i < num_values; i++) { - const void* value_info = nullptr; // Is a const OrtValueInfo* - RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(&value_infos, i, &value_info)); - RETURN_IF_ERROR(ort_api.GetValueInfoName(static_cast(value_info), &value_names[i])); + for (size_t i = 0; i < num_values; ++i) { + const OrtValueInfo* value_info = value_infos[i]; + RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &value_names[i])); } result = std::move(value_names); @@ -422,19 +390,21 @@ OrtStatus* ExampleEp::CreateEpContextNodes(gsl::span fused_nodes RETURN_IF_ERROR(ort_api.Node_GetName(fused_node, &fused_node_name)); - OrtArrayOfConstObjects* fused_node_inputs = nullptr; - OrtArrayOfConstObjects* fused_node_outputs = nullptr; - DeferOrtRelease defer_release0(&fused_node_inputs, ort_api.ReleaseArrayOfConstObjects); - DeferOrtRelease defer_release1(&fused_node_outputs, ort_api.ReleaseArrayOfConstObjects); + size_t num_fused_node_inputs = 0; + size_t num_fused_node_outputs = 0; + RETURN_IF_ERROR(ort_api.Node_GetNumInputs(fused_node, &num_fused_node_inputs)); + RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(fused_node, &num_fused_node_outputs)); - RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node, &fused_node_inputs)); - RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node, &fused_node_outputs)); + std::vector fused_node_inputs(num_fused_node_inputs); + std::vector fused_node_outputs(num_fused_node_outputs); + RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node, fused_node_inputs.data(), fused_node_inputs.size())); + RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node, fused_node_outputs.data(), fused_node_outputs.size())); std::vector input_names; std::vector output_names; - RETURN_IF_ERROR(collect_input_output_names(*fused_node_inputs, /*out*/ input_names)); - RETURN_IF_ERROR(collect_input_output_names(*fused_node_outputs, /*out*/ output_names)); + RETURN_IF_ERROR(collect_input_output_names(fused_node_inputs, /*out*/ input_names)); + RETURN_IF_ERROR(collect_input_output_names(fused_node_outputs, /*out*/ output_names)); int64_t is_main_context = (i == 0); int64_t embed_mode = 1; diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index d47ebce483be4..60498e6510ec2 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -72,30 +72,6 @@ TEST(EpGraphTest, Check3LayerNestedSubgraph) { // Utils for traversing an OrtGraph and checking against GraphViewer. // -// Convert an OrtArrayOfConstObjects into a span of Ort___ pointers. 
-template -static void GetSpanFromArrayOfConstObjects(const OrtArrayOfConstObjects* ort_array, - /*out*/ gsl::span& span) { - const OrtApi& ort_api = Ort::GetApi(); - - size_t size = 0; - ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetSize(ort_array, &size)); - - const void* const* raw_data = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetData(ort_array, &raw_data)); - - auto data = reinterpret_cast(raw_data); - span = gsl::span(data, size); -} - -static void CheckArrayObjectType(const OrtArrayOfConstObjects* ort_array, OrtTypeTag expected_object_type) { - const OrtApi& ort_api = Ort::GetApi(); - - OrtTypeTag api_object_type = ORT_TYPE_TAG_Void; - ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetObjectType(ort_array, &api_object_type)); - ASSERT_EQ(api_object_type, expected_object_type); -} - // Checks that the OrtTypeInfo obtained from the public C API matches another OrtTypeInfo // obtained from the internal ORT graph IR. static void CheckTypeInfo(const OrtTypeInfo* api_type_info, const OrtTypeInfo* type_info) { @@ -339,49 +315,34 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ // Check graph inputs. const auto& graph_input_node_args = graph_viewer.GetInputsIncludingInitializers(); - OrtArrayOfConstObjects* api_graph_inputs_container = nullptr; - DeferOrtRelease release_graph_inputs(&api_graph_inputs_container, - ort_api.ReleaseArrayOfConstObjects); - gsl::span api_graph_inputs{}; - - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInputs(&api_graph, &api_graph_inputs_container)); - - CheckArrayObjectType(api_graph_inputs_container, ORT_TYPE_TAG_OrtValueInfo); - GetSpanFromArrayOfConstObjects(api_graph_inputs_container, api_graph_inputs); + size_t api_num_graph_inputs = 0; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumInputs(&api_graph, &api_num_graph_inputs)); + ASSERT_EQ(api_num_graph_inputs, graph_input_node_args.size()); - ASSERT_EQ(api_graph_inputs.size(), graph_input_node_args.size()); + std::vector api_graph_inputs(api_num_graph_inputs); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInputs(&api_graph, api_graph_inputs.data(), api_graph_inputs.size())); CheckValueInfosCApi(graph_viewer, api_graph_inputs, graph_input_node_args); // Check graph outputs. 
const auto& graph_output_node_args = graph_viewer.GetOutputs(); - OrtArrayOfConstObjects* api_graph_outputs_container = nullptr; - DeferOrtRelease release_graph_outputs(&api_graph_outputs_container, - ort_api.ReleaseArrayOfConstObjects); - gsl::span api_graph_outputs{}; + size_t api_num_graph_outputs = 0; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumOutputs(&api_graph, &api_num_graph_outputs)); + ASSERT_EQ(api_num_graph_outputs, graph_output_node_args.size()); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetOutputs(&api_graph, &api_graph_outputs_container)); - - CheckArrayObjectType(api_graph_outputs_container, ORT_TYPE_TAG_OrtValueInfo); - GetSpanFromArrayOfConstObjects(api_graph_outputs_container, api_graph_outputs); - - ASSERT_EQ(api_graph_outputs.size(), graph_output_node_args.size()); + std::vector api_graph_outputs(api_num_graph_outputs); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetOutputs(&api_graph, api_graph_outputs.data(), api_graph_outputs.size())); CheckValueInfosCApi(graph_viewer, api_graph_outputs, graph_output_node_args); // Check graph initializers const auto& graph_initializers = graph_viewer.GetAllInitializedTensors(); - OrtArrayOfConstObjects* api_initializers_container = nullptr; - DeferOrtRelease release_initializers(&api_initializers_container, - ort_api.ReleaseArrayOfConstObjects); - gsl::span api_initializers{}; - - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInitializers(&api_graph, &api_initializers_container)); + size_t api_num_initializers = 0; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumInitializers(&api_graph, &api_num_initializers)); + ASSERT_EQ(api_num_initializers, graph_initializers.size()); - CheckArrayObjectType(api_initializers_container, ORT_TYPE_TAG_OrtValueInfo); - GetSpanFromArrayOfConstObjects(api_initializers_container, api_initializers); - - ASSERT_EQ(api_initializers.size(), graph_initializers.size()); + std::vector api_initializers(api_num_initializers); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInitializers(&api_graph, api_initializers.data(), api_initializers.size())); CheckInitializerValueInfosCApi(api_initializers, graph_initializers); // Check if it has a parent node. @@ -398,23 +359,18 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } // Check all nodes. - OrtArrayOfConstObjects* api_nodes_container = nullptr; - DeferOrtRelease release_nodes(&api_nodes_container, - ort_api.ReleaseArrayOfConstObjects); - ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, &api_nodes_container)); - CheckArrayObjectType(api_nodes_container, ORT_TYPE_TAG_OrtNode); - size_t api_num_nodes = 0; - ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetSize(api_nodes_container, &api_num_nodes)); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&api_graph, &api_num_nodes)); ASSERT_EQ(api_num_nodes, graph_viewer.NumberOfNodes()); + std::vector api_nodes(api_num_nodes); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, api_nodes.data(), api_nodes.size())); + std::vector node_indices = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT); for (size_t node_idx = 0; node_idx < api_num_nodes; node_idx++) { // Check basic node properties. 
const Node* node = graph_viewer.GetNode(node_indices[node_idx]); - const OrtNode* api_node = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetElementAt(api_nodes_container, node_idx, - reinterpret_cast(&api_node))); + const OrtNode* api_node = api_nodes[node_idx]; CheckNode(node, api_node); int api_since_version = 0; @@ -424,49 +380,37 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ // Check node inputs const auto input_node_args = node->InputDefs(); - OrtArrayOfConstObjects* api_node_inputs_container = nullptr; - DeferOrtRelease release_node_inputs(&api_node_inputs_container, - ort_api.ReleaseArrayOfConstObjects); - gsl::span api_node_inputs{}; - - ASSERT_ORTSTATUS_OK(ort_api.Node_GetInputs(api_node, &api_node_inputs_container)); - - CheckArrayObjectType(api_node_inputs_container, ORT_TYPE_TAG_OrtValueInfo); - GetSpanFromArrayOfConstObjects(api_node_inputs_container, api_node_inputs); - ASSERT_EQ(api_node_inputs.size(), input_node_args.size()); + size_t api_node_num_inputs = 0; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumInputs(api_node, &api_node_num_inputs)); + ASSERT_EQ(api_node_num_inputs, input_node_args.size()); + std::vector api_node_inputs(api_node_num_inputs); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetInputs(api_node, api_node_inputs.data(), api_node_inputs.size())); CheckValueInfosCApi(graph_viewer, api_node_inputs, input_node_args); // Check node outputs const auto output_node_args = node->OutputDefs(); - OrtArrayOfConstObjects* api_node_outputs_container = nullptr; - DeferOrtRelease release_node_outputs(&api_node_outputs_container, - ort_api.ReleaseArrayOfConstObjects); - gsl::span api_node_outputs{}; - - ASSERT_ORTSTATUS_OK(ort_api.Node_GetOutputs(api_node, &api_node_outputs_container)); - - CheckArrayObjectType(api_node_outputs_container, ORT_TYPE_TAG_OrtValueInfo); - GetSpanFromArrayOfConstObjects(api_node_outputs_container, api_node_outputs); - ASSERT_EQ(api_node_outputs.size(), output_node_args.size()); + size_t api_node_num_outputs = 0; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumOutputs(api_node, &api_node_num_outputs)); + ASSERT_EQ(api_node_num_outputs, output_node_args.size()); + std::vector api_node_outputs(api_node_num_outputs); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetOutputs(api_node, api_node_outputs.data(), api_node_outputs.size())); CheckValueInfosCApi(graph_viewer, api_node_outputs, output_node_args); // Check node attributes const auto& node_attrs = node->GetAttributes(); - if (node_attrs.size() > 0) { - OrtArrayOfConstObjects* api_node_attributes = nullptr; - DeferOrtRelease release_node_attributes(&api_node_attributes, - ort_api.ReleaseArrayOfConstObjects); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributes(api_node, &api_node_attributes)); - CheckArrayObjectType(api_node_attributes, ORT_TYPE_TAG_OrtOpAttr); + if (!node_attrs.empty()) { + size_t api_num_node_attributes = 0; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumAttributes(api_node, &api_num_node_attributes)); + + std::vector api_node_attributes(api_num_node_attributes); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributes(api_node, api_node_attributes.data(), api_node_attributes.size())); size_t attr_idx = 0; for (const auto& node_attr : node_attrs) { - const OrtOpAttr* api_node_attr = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetElementAt(api_node_attributes, attr_idx, - reinterpret_cast(&api_node_attr))); + const OrtOpAttr* api_node_attr = api_node_attributes[attr_idx]; ASSERT_NE(api_node_attr, nullptr); api_node_attr = nullptr; @@ -531,32 +475,28 @@ 
static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ if (!node_subgraphs.empty()) { // Check node's implicit inputs to its subgraph nodes. const auto implicit_input_node_args = node->ImplicitInputDefs(); - OrtArrayOfConstObjects* api_node_implicit_inputs_container = nullptr; - DeferOrtRelease release_node_implicit(&api_node_implicit_inputs_container, - ort_api.ReleaseArrayOfConstObjects); - gsl::span api_node_implicit_inputs{}; - ASSERT_ORTSTATUS_OK(ort_api.Node_GetImplicitInputs(api_node, &api_node_implicit_inputs_container)); + size_t api_num_node_implicit_inputs = 0; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumImplicitInputs(api_node, &api_num_node_implicit_inputs)); + ASSERT_EQ(api_num_node_implicit_inputs, implicit_input_node_args.size()); - CheckArrayObjectType(api_node_implicit_inputs_container, ORT_TYPE_TAG_OrtValueInfo); - GetSpanFromArrayOfConstObjects(api_node_implicit_inputs_container, api_node_implicit_inputs); - ASSERT_EQ(api_node_implicit_inputs.size(), implicit_input_node_args.size()); + std::vector api_node_implicit_inputs(api_num_node_implicit_inputs); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetImplicitInputs(api_node, api_node_implicit_inputs.data(), + api_node_implicit_inputs.size())); CheckValueInfosCApi(graph_viewer, api_node_implicit_inputs, implicit_input_node_args); // Recursively check subgraphs. - OrtArrayOfConstObjects* api_node_subgraphs = nullptr; - DeferOrtRelease release_node_subgraphs(&api_node_subgraphs, - ort_api.ReleaseArrayOfConstObjects); - ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, &api_node_subgraphs)); - CheckArrayObjectType(api_node_subgraphs, ORT_TYPE_TAG_OrtGraph); + size_t api_num_node_subgraphs = 0; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumSubgraphs(api_node, &api_num_node_subgraphs)); + + std::vector api_node_subgraphs(api_num_node_subgraphs); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, api_node_subgraphs.data(), api_node_subgraphs.size())); for (size_t subgraph_idx = 0; subgraph_idx < node_subgraphs.size(); subgraph_idx++) { auto subgraph_viewer = std::make_unique(*node_subgraphs[subgraph_idx]); + const OrtGraph* api_subgraph = api_node_subgraphs[subgraph_idx]; - const OrtGraph* api_subgraph = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetElementAt(api_node_subgraphs, subgraph_idx, - reinterpret_cast(&api_subgraph))); CheckGraphCApi(*subgraph_viewer, *api_subgraph); } } diff --git a/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc b/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc index 5816037c2845d..63652d8835e77 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc @@ -47,19 +47,16 @@ struct VisitorPriorityQueue { static Ort::Status GetNodeInputEdgeCount(const OrtNode* node, size_t& num_input_edges) { const OrtApi& ort_api = Ort::GetApi(); - OrtArrayOfConstObjects* inputs = nullptr; - DeferOrtRelease release_inputs(&inputs, ort_api.ReleaseArrayOfConstObjects); - RETURN_IF_API_ERROR(ort_api.Node_GetInputs(node, &inputs)); - size_t num_inputs = 0; - RETURN_IF_API_ERROR(ort_api.ArrayOfConstObjects_GetSize(inputs, &num_inputs)); + RETURN_IF_API_ERROR(ort_api.Node_GetNumInputs(node, &num_inputs)); + + std::vector inputs(num_inputs); + RETURN_IF_API_ERROR(ort_api.Node_GetInputs(node, inputs.data(), inputs.size())); // Sum the number of inputs with a producer node. 
num_input_edges = 0; - for (size_t i = 0; i < num_inputs; ++i) { - const OrtValueInfo* input = nullptr; - RETURN_IF_API_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(inputs, i, reinterpret_cast(&input))); + for (const OrtValueInfo* input : inputs) { if (input == nullptr) continue; // Skip missing optional input const OrtNode* producer_node = nullptr; @@ -74,20 +71,17 @@ static Ort::Status GetNodeInputEdgeCount(const OrtNode* node, size_t& num_input_ static Ort::Status GetOutputNodes(const OrtNode* node, std::vector& result) { const OrtApi& ort_api = Ort::GetApi(); - OrtArrayOfConstObjects* outputs = nullptr; - DeferOrtRelease release_outputs(&outputs, ort_api.ReleaseArrayOfConstObjects); - RETURN_IF_API_ERROR(ort_api.Node_GetOutputs(node, &outputs)); - size_t num_outputs = 0; - RETURN_IF_API_ERROR(ort_api.ArrayOfConstObjects_GetSize(outputs, &num_outputs)); + RETURN_IF_API_ERROR(ort_api.Node_GetNumOutputs(node, &num_outputs)); + + std::vector outputs(num_outputs); + RETURN_IF_API_ERROR(ort_api.Node_GetOutputs(node, outputs.data(), outputs.size())); std::vector output_nodes; output_nodes.reserve(num_outputs); // May have more than `num_outputs` // Gather the OrtNode consumers of every output. - for (size_t i = 0; i < num_outputs; ++i) { - const OrtValueInfo* output = nullptr; - RETURN_IF_API_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(outputs, i, reinterpret_cast(&output))); + for (const OrtValueInfo* output : outputs) { if (output == nullptr) continue; // Skip missing optional output size_t num_consumers = 0; @@ -115,21 +109,19 @@ static Ort::Status KahnsTopologicalSort(const OrtGraph& graph, const OrtApi& ort_api = Ort::GetApi(); // Get all nodes - OrtArrayOfConstObjects* nodes_array = nullptr; - DeferOrtRelease release_nodes(&nodes_array, ort_api.ReleaseArrayOfConstObjects); - size_t num_nodes = 0; - const void* const* nodes_raw_data = nullptr; + RETURN_IF_API_ERROR(ort_api.Graph_GetNumNodes(&graph, &num_nodes)); - RETURN_IF_API_ERROR(ort_api.Graph_GetNodes(&graph, &nodes_array)); - RETURN_IF_API_ERROR(ort_api.ArrayOfConstObjects_GetSize(nodes_array, &num_nodes)); - RETURN_IF_API_ERROR(ort_api.ArrayOfConstObjects_GetData(nodes_array, &nodes_raw_data)); + if (num_nodes == 0) { + return Ort::Status{nullptr}; // Nothing to sort. + } - auto nodes_span = gsl::span(reinterpret_cast(nodes_raw_data), num_nodes); + std::vector nodes(num_nodes); + RETURN_IF_API_ERROR(ort_api.Graph_GetNodes(&graph, nodes.data(), nodes.size())); // Get the maximum node ID. Not really required if we chose to represent the `in_degree` as a map instead of vector. size_t max_node_id = 0; - for (const OrtNode* node : nodes_span) { + for (const OrtNode* node : nodes) { size_t node_id = 0; RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); max_node_id = std::max(max_node_id, node_id); @@ -142,7 +134,7 @@ static Ort::Status KahnsTopologicalSort(const OrtGraph& graph, topo_order.reserve(num_nodes); // Initialize in_degree and initial nodes to visit first. 
- for (const OrtNode* node : nodes_span) { + for (const OrtNode* node : nodes) { size_t input_edge_count = 0; RETURN_IF_API_ERROR(GetNodeInputEdgeCount(node, input_edge_count)); @@ -166,7 +158,7 @@ static Ort::Status KahnsTopologicalSort(const OrtGraph& graph, } std::vector output_nodes; - GetOutputNodes(current_node, output_nodes); + RETURN_IF_API_ERROR(GetOutputNodes(current_node, output_nodes)); for (const OrtNode* output_node : output_nodes) { size_t output_node_id = 0; diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.h b/onnxruntime/test/ep_graph/test_ep_graph_utils.h index 9c04a72a42248..b0ed825f21d71 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_utils.h +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.h @@ -35,27 +35,6 @@ class TestGraph { std::unique_ptr api_graph; }; -// Helper to release a C API Ort object at the end of its scope. -// Useful when not using the public C++ API. -// Example: -// { -// OrtTensorTypeAndShapeInfo* info = nullptr; -// DeferOrtRelease defer_release(&info, c_api.ReleaseTensorTypeAndShapeInfo); -// ... -// } /* Release is called at end of scope*/ -template -struct DeferOrtRelease { - DeferOrtRelease(T** obj_ptr, std::function release_func) : obj_ptr_(obj_ptr), release_func_(release_func) {} - ~DeferOrtRelease() { - if (obj_ptr_ != nullptr && *obj_ptr_ != nullptr) { - release_func_(*obj_ptr_); - *obj_ptr_ = nullptr; - } - } - T** obj_ptr_ = nullptr; - std::function release_func_ = nullptr; -}; - struct NodeArgConsumer { NodeArgConsumer(const Node* node, int64_t index) : node(node), input_index(index) {} const Node* node = nullptr; From c122f0d8b80dc54548469d29e79b9bb3dcd58d7d Mon Sep 17 00:00:00 2001 From: Piotr Kubaj Date: Thu, 3 Jul 2025 20:27:20 +0000 Subject: [PATCH 15/19] platform.cpp: support for POWER9 and POWER10 on FreeBSD (#25186) ### Description This commit adds support for running on POWER9 and POWER10 processors on FreeBSD. The only major difference from Linux is that FreeBSD uses elf_aux_info() instead of getauxval(). ### Motivation and Context Fix issue with Stream notification function. The stream can be nullptr so using a reference was incorrect. Try and improve readability. ### Motivation and Context Fix incorrect function signature. 
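To make the FreeBSD difference mentioned above concrete, here is a hedged, self-contained sketch (not the actual platform.cpp change): on Linux, getauxval() returns the auxiliary-vector entry directly, while FreeBSD's elf_aux_info() fills a caller-supplied buffer and returns 0 on success. AT_HWCAP2 carries the POWER ISA feature bits used for POWER9/POWER10 detection; exact header locations for the constants may differ per OS.

```cpp
// Hedged sketch of the portability shim described above; not the patch's code.
#include <sys/auxv.h>

unsigned long GetPowerHwcap2() {
#if defined(__FreeBSD__)
  unsigned long hwcap2 = 0;
  // elf_aux_info(int aux, void* buf, int buflen) returns 0 on success.
  if (elf_aux_info(AT_HWCAP2, &hwcap2, sizeof(hwcap2)) != 0) {
    hwcap2 = 0;  // auxiliary vector entry not available
  }
  return hwcap2;
#else
  return getauxval(AT_HWCAP2);  // glibc: returns 0 if the entry is absent
#endif
}
```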
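The stream-notification fix changes the wait callback to take `Stream*` instead of `Stream&`, because a host-side wait has no stream to wait on. A minimal standalone sketch (the local `Stream`/`Notification` types stand in for the real ONNX Runtime ones) shows why a pointer is needed: the callback can now be invoked with `nullptr`, which a reference parameter could not express.

```cpp
// Minimal sketch, not the ONNX Runtime implementation.
#include <cassert>
#include <functional>
#include <iostream>

struct Stream { /* placeholder for onnxruntime::Stream */ };
struct Notification { /* placeholder for synchronize::Notification */ };

using WaitNotificationFn = std::function<void(Stream*, Notification&)>;

void WaitOnDevice(Stream* stream, Notification& /*notification*/) {
  assert(stream != nullptr);  // device-side waits must be enqueued on a stream
  std::cout << "wait enqueued on device stream\n";
}

void WaitOnHost(Stream* /*stream*/, Notification& /*notification*/) {
  // host-side wait blocks the calling thread; no stream is involved
  std::cout << "blocking host wait\n";
}

int main() {
  Stream s;
  Notification n;
  WaitNotificationFn device_wait = WaitOnDevice;
  WaitNotificationFn host_wait = WaitOnHost;
  device_wait(&s, n);
  host_wait(nullptr, n);  // with Stream&, this call could not be written
  return 0;
}
```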
--- .../onnxruntime/core/framework/allocator.h | 2 +- .../core/framework/stream_handles.h | 8 ++- .../core/framework/allocation_planner.cc | 49 ++++++++++++++----- onnxruntime/core/framework/bfc_arena.cc | 6 ++- onnxruntime/core/framework/execution_steps.cc | 13 +++-- .../partial_graph_execution_state.cc | 2 +- .../framework/sequential_execution_plan.h | 2 +- .../core/framework/sequential_executor.cc | 2 +- .../framework/stream_execution_context.cc | 8 ++- .../core/providers/cann/cann_stream_handle.cc | 6 +-- .../core/providers/cann/cann_stream_handle.h | 2 +- .../providers/cuda/cuda_execution_provider.cc | 3 +- .../core/providers/cuda/cuda_stream_handle.cc | 7 +-- .../core/providers/cuda/cuda_stream_handle.h | 2 +- .../migraphx/migraphx_stream_handle.cc | 6 +-- .../migraphx/migraphx_stream_handle.h | 2 +- onnxruntime/test/framework/bfc_arena_test.cc | 2 +- 17 files changed, 81 insertions(+), 41 deletions(-) diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 6f519249b98b6..609386fd1f081 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -86,7 +86,7 @@ class Stream; namespace synchronize { class Notification; } -using WaitNotificationFn = std::function; +using WaitNotificationFn = std::function; void* AllocateBufferWithOptions(IAllocator& allocator, size_t size, bool use_reserve, Stream* stream, WaitNotificationFn wait_fn); template diff --git a/include/onnxruntime/core/framework/stream_handles.h b/include/onnxruntime/core/framework/stream_handles.h index 9b9bc94105005..441e3ebda1502 100644 --- a/include/onnxruntime/core/framework/stream_handles.h +++ b/include/onnxruntime/core/framework/stream_handles.h @@ -172,12 +172,16 @@ class IStreamCommandHandleRegistry { // i.e., for an cuda event what notify the memory copy, it could be wait on a CPU stream, or on another cuda stream. [[nodiscard]] virtual WaitNotificationFn GetWaitHandle(const OrtDevice& notification_owner_device, const OrtDevice& executor_device) const = 0; - // Get the stream creation function registered on the given device type. + + // Get the stream creation function registered for the given device type. [[nodiscard]] virtual CreateStreamFn GetCreateStreamFn(OrtDevice::DeviceType execution_device_type) const = 0; - // register a wait methond which will be invoked when we wait a notification (created by 'notification_device_type' device) on a stream at 'device_type' device. + + // register a wait method which will be invoked to await a notification that is + // created by 'notification_device_type' device on a stream at 'device_type' device. virtual void RegisterWaitFn(OrtDevice::DeviceType notification_device_type, OrtDevice::DeviceType device_type, WaitNotificationFn fn) = 0; + // register a handle about how to create stream on given device type. virtual void RegisterCreateStreamFn(OrtDevice::DeviceType device_type, CreateStreamFn f) = 0; diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index a0c00b1cd26e5..e77496b6e8196 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -2018,6 +2018,7 @@ class PlannerImpl { execution_plan.emplace_back(nullptr); } } + // 2. Determining following things: // a. which node needs to generate the notification // b. 
which node needs to trigger downstream @@ -2046,10 +2047,16 @@ class PlannerImpl { return producer_topoindex < yieldOp_index_in_toposort && yieldOp_index_in_toposort < consumer_topoindex; }; #endif + size_t num_trigger_points = 0; InlinedHashMap node_to_trigger_points; + + // map of node that will generate a notification to plan_.notification_owners entry InlinedHashMap node_to_notification; + + // map of waiting node index to pairs of nodes + notification fn that it is waiting on std::map> node_to_wait; + for (size_t i = 0; i < num_logic_streams_; ++i) { for (auto node_index : stream_nodes_[i]) { auto* node = graph_viewer_.GetNode(node_index); @@ -2068,17 +2075,24 @@ class PlannerImpl { } } } + for (size_t i = 0; i < num_logic_streams_; ++i) { for (auto node_index : stream_nodes_[i]) { auto* node = graph_viewer_.GetNode(node_index); auto stream_device = execution_plan[i]->device_; - // Neither trigger ActivateNotification/WaitOnEPStep for Shape op (whose output is ready for all the EPs), nor - // upstream is on CPU device (As currently we never invoke RegisterWaitFn(CPU, ...) for all kinds of EP, thus no wait_handle can be retrieved for this case) + // We don't need an ActivateNotificationStep or WaitOnEPStep for Shape as it always runs on CPU and isn't + // dependent on input from other devices. + // We also skip CPU streams as there is no wait function for CPU -> Device, so GetWaitHandle will always return + // null. EPs only register Device -> Device and Device -> CPU wait handlers currently. if (node->OpType() != "Shape" && !stream_device.UsesCpuMemory()) { + // for each node consuming one or more outputs from the current node for (auto it = node->OutputNodesBegin(); it != node->OutputNodesEnd(); ++it) { bool output_consumed_in_subgraph = true; + + // find the output/s the downstream node consumes for (auto* output : node->OutputDefs()) { if (output->Exists()) { + // TODO: is this correct or do we need to iterate ImplicitInputDefs as well? if (std::find(it->InputDefs().begin(), it->InputDefs().end(), output) != it->InputDefs().end()) { output_consumed_in_subgraph = false; // output directly consumed in current graph OrtValueIndex output_arg_idx; @@ -2086,7 +2100,7 @@ class PlannerImpl { // there are two cases we need notification: // 1. the consumer is not in the same stream // 2. the consumer is in the same stream(non-cpu device), but it consumes a CPU tensor from an non-shape op. - // for example, a resize cuda kernel consumer a tensor from MemCpyToHost cuda kernel on the same stream. + // for example, a resize cuda kernel consumes a tensor from MemCpyToHost cuda kernel on the same stream. 
// in this case, the FIFO can't guarantee the cpu tensor is ready when resize kernel is launching const auto& output_arg_device = AllocPlan(output_arg_idx).location; WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, @@ -2094,15 +2108,17 @@ class PlannerImpl { if ((plan_.node_stream_map_[it->Index()] != i || output_arg_device.UsesCpuMemory()) && wait_handle != nullptr) { if (node_to_notification.find(node_index) == node_to_notification.end()) { - node_to_notification[node_index] = plan_.notification_owners.size(); - plan_.notification_owners.push_back(i); + node_to_notification[node_index] = plan_.notification_owner_stream.size(); + plan_.notification_owner_stream.push_back(i); } + // if node_index is already in the map, it will NOT be overwritten by insert() node_to_wait[it->Index()].insert({node_index, wait_handle}); } } } } + if (output_consumed_in_subgraph) { const auto downstream = plan_.node_stream_map_[it->Index()]; if (downstream != i) { @@ -2111,8 +2127,8 @@ class PlannerImpl { downstream_device); if (wait_handle) { if (node_to_notification.find(node_index) == node_to_notification.end()) { - node_to_notification[node_index] = plan_.notification_owners.size(); - plan_.notification_owners.push_back(i); + node_to_notification[node_index] = plan_.notification_owner_stream.size(); + plan_.notification_owner_stream.push_back(i); } node_to_wait[it->Index()].insert({node_index, wait_handle}); } @@ -2143,6 +2159,7 @@ class PlannerImpl { // add dependency for current logic stream dependence_graph_[node_index].insert(stream_nodes_[i][j - 1]); } + auto* node = graph_viewer_.GetNode(node_index); std::unordered_set visited; // TODO(leca): See the bug description in PlannerTest.MultiStreamMultiOutput. Can remove this variable once this bug is fixed for (auto it = node->InputNodesBegin(); it != node->InputNodesEnd(); ++it) { @@ -2150,7 +2167,8 @@ class PlannerImpl { continue; } visited.insert(it->Index()); - // check whether we need to add barrier + + // add barrier if input node is not in this logic stream if (std::find(stream_nodes_[i].begin(), stream_nodes_[i].end(), it->Index()) == stream_nodes_[i].end() #ifdef ENABLE_TRAINING && !AreNodesSeparatedByYield(it->Index(), node_index) @@ -2162,17 +2180,21 @@ class PlannerImpl { size_t trigger_point_index = trigger_point_it->second; // push a barrier size_t barrier_id = plan_.num_barriers++; - plan_.downstream_map[trigger_point_index].push_back({i, - static_cast(execution_plan[i]->steps_.size())}); + // we add to the downstream map which causes TriggerDownstreamStep to run which decrements the + // barrier from the downstream stream when the downstream node is ready. + plan_.downstream_map[trigger_point_index].push_back( + {i, static_cast(execution_plan[i]->steps_.size())}); execution_plan[i]->steps_.emplace_back(std::make_unique(barrier_id, node_index)); } } + // if current node has a waiter for a notification add WaitOnEPStep. 
auto wait_it = node_to_wait.find(node_index); if (wait_it != node_to_wait.end()) { - for (auto wait_param : wait_it->second) { - execution_plan[i]->steps_.emplace_back(std::make_unique(wait_param.second, - node_to_notification[wait_param.first], node_index)); + for (const auto& [node_producing_notification, notification_fn] : wait_it->second) { + execution_plan[i]->steps_.emplace_back( + std::make_unique(notification_fn, + node_to_notification[node_producing_notification], node_index)); } } @@ -2180,6 +2202,7 @@ class PlannerImpl { // add dependency for model graph dependence_graph_[it->Index()].insert(node_index); } + // push launch kernel command #if defined(ORT_MINIMAL_BUILD) execution_plan[i]->steps_.emplace_back(std::make_unique(node_index)); diff --git a/onnxruntime/core/framework/bfc_arena.cc b/onnxruntime/core/framework/bfc_arena.cc index 6513c4b95a818..ed64769d13fcc 100644 --- a/onnxruntime/core/framework/bfc_arena.cc +++ b/onnxruntime/core/framework/bfc_arena.cc @@ -879,8 +879,10 @@ void StreamAwareArena::SecureTheChunk(Stream* chunk_stream, Stream* target_strea if (chunk_stream && target_stream && chunk_stream != target_stream) { auto notification = chunk_stream->CreateNotification(1); notification->ActivateAndUpdate(); - if (wait_fn) - wait_fn(*target_stream, *notification); + if (wait_fn) { + wait_fn(target_stream, *notification); + } + target_stream->UpdateStreamClock(notification->GetStreamSyncTable()); // it should be ok to release the notification now, as the wait is already launch to stream. } diff --git a/onnxruntime/core/framework/execution_steps.cc b/onnxruntime/core/framework/execution_steps.cc index b647833cfd373..36f663699be4f 100644 --- a/onnxruntime/core/framework/execution_steps.cc +++ b/onnxruntime/core/framework/execution_steps.cc @@ -34,11 +34,16 @@ Status WaitOnEPStep::Execute(StreamExecutionContext& ctx, const bool& /*terminate_flag*/, bool& continue_flag) { ORT_ENFORCE(wait_handle_, "WaitOnEPStep.wait_handle is null"); - wait_handle_(*ctx.GetDeviceStream(stream_idx), *ctx.GetNotification(notification_idx_)); - // update streams clock status - if (ctx.GetDeviceStream(stream_idx)) { - ctx.GetDeviceStream(stream_idx)->UpdateStreamClock(ctx.GetNotification(notification_idx_)->GetStreamSyncTable()); + + auto* stream = ctx.GetDeviceStream(stream_idx); + auto& notification = *ctx.GetNotification(notification_idx_); + wait_handle_(stream, notification); + + // update the stream's clock status + if (stream != nullptr) { + stream->UpdateStreamClock(notification.GetStreamSyncTable()); } + LOGS(ctx.GetLogger(), VERBOSE) << "stream " << stream_idx << " wait on Notification with id: " << notification_idx_; continue_flag = true; return Status::OK(); diff --git a/onnxruntime/core/framework/partial_graph_execution_state.cc b/onnxruntime/core/framework/partial_graph_execution_state.cc index ce0572927d94a..875395ce0bfe8 100644 --- a/onnxruntime/core/framework/partial_graph_execution_state.cc +++ b/onnxruntime/core/framework/partial_graph_execution_state.cc @@ -79,7 +79,7 @@ StreamExecutionContext& PartialGraphExecutionState::GetExecutionContext(gsl::spa execution_context_ = std::make_unique( session_state, valid_streams, - execution_plan->notification_owners, + execution_plan->notification_owner_stream, execution_plan->num_barriers, device_streams, feed_mlvalue_idxs, diff --git a/onnxruntime/core/framework/sequential_execution_plan.h b/onnxruntime/core/framework/sequential_execution_plan.h index d9472e404c0e4..d2a4378c160a6 100644 --- 
a/onnxruntime/core/framework/sequential_execution_plan.h +++ b/onnxruntime/core/framework/sequential_execution_plan.h @@ -157,7 +157,7 @@ struct SequentialExecutionPlan : public ExecutionPlanBase { // elements in node_release_list[i] is the index in release_actions. std::vector> node_release_list; // for each notification, what is the stream-idx of the its owner. - std::vector notification_owners; + std::vector notification_owner_stream; // key: notification index. // value: {stream_idx, step_idx} // giving a notification, we used this map to figure out what is the downstream steps it need to trigger. diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index 26a57ec3ea02f..7180d976c1d3c 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -601,7 +601,7 @@ onnxruntime::Status ExecuteThePlan(const SessionState& session_state, gsl::span< #ifdef ORT_ENABLE_STREAM StreamExecutionContext ctx(session_state, valid_streams, - execution_plan->notification_owners, + execution_plan->notification_owner_stream, execution_plan->num_barriers, device_streams, feed_mlvalue_idxs, diff --git a/onnxruntime/core/framework/stream_execution_context.cc b/onnxruntime/core/framework/stream_execution_context.cc index e8beb98749028..317f49dbb21da 100644 --- a/onnxruntime/core/framework/stream_execution_context.cc +++ b/onnxruntime/core/framework/stream_execution_context.cc @@ -52,6 +52,8 @@ StreamExecutionContext::StreamExecutionContext(const SessionState& sess_state, #endif // init barriers + // one for the producer node: BarrierStep in execution_plan[i]->steps_ + // one for the downstream node: run via plan_.downstream_map for (size_t i = 0; i < num_barriers; ++i) { count_down_barriers_[i].Set(2); } @@ -64,9 +66,11 @@ StreamExecutionContext::StreamExecutionContext(const SessionState& sess_state, } } -synchronize::Notification* StreamExecutionContext ::GetNotification(size_t idx) { return notifications_[idx].get(); } +synchronize::Notification* StreamExecutionContext::GetNotification(size_t idx) { + return notifications_[idx].get(); +} -bool StreamExecutionContext ::DecCountDownBarrier(size_t barrier_id) { +bool StreamExecutionContext::DecCountDownBarrier(size_t barrier_id) { return count_down_barriers_[barrier_id].Dec(); } diff --git a/onnxruntime/core/providers/cann/cann_stream_handle.cc b/onnxruntime/core/providers/cann/cann_stream_handle.cc index bcb5a62cf6c22..041fc54a725a9 100644 --- a/onnxruntime/core/providers/cann/cann_stream_handle.cc +++ b/onnxruntime/core/providers/cann/cann_stream_handle.cc @@ -57,11 +57,11 @@ void CannStream::Flush() { } // CPU Stream command handles -void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification) { - static_cast(¬ification)->wait_on_device(stream); +void WaitCannNotificationOnDevice(Stream* stream, synchronize::Notification& notification) { + static_cast(¬ification)->wait_on_device(*stream); } -void WaitCannNotificationOnHost(Stream& /*stream*/, synchronize::Notification& notification) { +void WaitCannNotificationOnHost(Stream* /*stream*/, synchronize::Notification& notification) { static_cast(¬ification)->wait_on_host(); } diff --git a/onnxruntime/core/providers/cann/cann_stream_handle.h b/onnxruntime/core/providers/cann/cann_stream_handle.h index 5d822d23f966f..f20eafb2b4b35 100644 --- a/onnxruntime/core/providers/cann/cann_stream_handle.h +++ b/onnxruntime/core/providers/cann/cann_stream_handle.h @@ -12,7 
+12,7 @@ #include "core/providers/cann/cann_call.h" namespace onnxruntime { -void WaitCannNotificationOnDevice(Stream& stream, synchronize::Notification& notification); +void WaitCannNotificationOnDevice(Stream* stream, synchronize::Notification& notification); struct CannStream : Stream { CannStream(aclrtStream stream, const OrtDevice& device, bool own_flag); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index cc7edbc15c329..d9acb9ccdc30f 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2856,7 +2856,8 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, return result; } -void CUDAExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const { +void CUDAExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, + AllocatorMap& allocators) const { // This allocator must be the same to the allocator // used in AllocateBufferOnCPUPinned. auto allocator = allocators[GetOrtDeviceByMemType(OrtMemTypeCPU)]; diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index 51fd2c67b7478..b6cbffb073774 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -228,11 +228,12 @@ void* CudaStream::GetResource(int version, int id) const { } // CPU Stream command handles -void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification) { - static_cast(¬ification)->wait_on_device(stream); +void WaitCudaNotificationOnDevice(Stream* stream, synchronize::Notification& notification) { + assert(stream != nullptr); // should never happen + static_cast(¬ification)->wait_on_device(*stream); } -void WaitCudaNotificationOnHost(Stream& /*stream*/, synchronize::Notification& notification) { +void WaitCudaNotificationOnHost(Stream* /*stream*/, synchronize::Notification& notification) { static_cast(¬ification)->wait_on_host(); } diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.h b/onnxruntime/core/providers/cuda/cuda_stream_handle.h index 15e7a0553c84e..c75cf15f7c2f8 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.h +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.h @@ -11,7 +11,7 @@ namespace onnxruntime { struct CudaStream; -void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification); +void WaitCudaNotificationOnDevice(Stream* stream, synchronize::Notification& notification); struct DeferredCpuAllocator : public OrtAllocator { DeferredCpuAllocator(CudaStream&); diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc index e8e349af75aba..8ed4e4a45a8c4 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc @@ -135,11 +135,11 @@ void* MIGraphXStream::GetResource(int version, int id) const { } // CPU Stream command handles -void WaitMIGraphXNotificationOnDevice(Stream& stream, synchronize::Notification& notification) { - static_cast(¬ification)->wait_on_device(stream); +void WaitMIGraphXNotificationOnDevice(Stream* stream, synchronize::Notification& notification) { + static_cast(¬ification)->wait_on_device(*stream); } 
-void WaitMIGraphXNotificationOnHost(Stream& /*stream*/, synchronize::Notification& notification) { +void WaitMIGraphXNotificationOnHost(Stream* /*stream*/, synchronize::Notification& notification) { static_cast(¬ification)->wait_on_host(); } diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h index 85b0aff87a436..d0ef3334b38c9 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h @@ -7,7 +7,7 @@ #include "migraphx_call.h" namespace onnxruntime { -void WaitMIGraphXNotificationOnDevice(Stream& stream, synchronize::Notification& notification); +void WaitMIGraphXNotificationOnDevice(Stream* stream, synchronize::Notification& notification); struct MIGraphXStream : Stream { MIGraphXStream(hipStream_t stream, diff --git a/onnxruntime/test/framework/bfc_arena_test.cc b/onnxruntime/test/framework/bfc_arena_test.cc index e9f734057da1c..670447f2804dc 100644 --- a/onnxruntime/test/framework/bfc_arena_test.cc +++ b/onnxruntime/test/framework/bfc_arena_test.cc @@ -406,7 +406,7 @@ TEST(StreamAwareArenaTest, TestSecureTheChunk) { bool waitFunctionInvoked = false; void* p2 = a.AllocOnStream(BFCArena::DEFAULT_INITIAL_CHUNK_SIZE_BYTES, &stream2, - [&waitFunctionInvoked](Stream&, synchronize::Notification&) { waitFunctionInvoked = true; }); + [&waitFunctionInvoked](Stream*, synchronize::Notification&) { waitFunctionInvoked = true; }); std::unordered_map syncTable; stream2.CloneCurrentStreamSyncTable(syncTable); From 763d55421b66ade3622f3fd2f7f2414f7fb6463d Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Thu, 3 Jul 2025 19:19:10 -0700 Subject: [PATCH 18/19] `OrtKeyValuePairs` updates (#25284) ### Description - Change the entries container type from `std::unordered_map` to `std::map`. This enables deterministic iteration order. - Enforce internal container state consistency. `OrtKeyValuePairs` has several internal containers that must stay consistent. Previously, the internal containers were public. This change makes them private and also fixes the copy/move behavior. - Add unit tests. ### Motivation and Context Some fixes and improvements. 
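A short standalone sketch (not ORT code; the keys and values are made up for illustration) of the observable difference the container change buys: `std::map` iterates entries in key order, so the key/value views handed out through the C API are deterministic across runs, and `insert_or_assign` gives the documented overwrite-on-duplicate-key behaviour rather than silently keeping the old value.

```cpp
// Minimal sketch of deterministic iteration with std::map; hypothetical entries.
#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<std::string, std::string> entries;
  entries.insert_or_assign("vendor", "ACME");
  entries.insert_or_assign("device_id", "0");
  entries.insert_or_assign("vendor", "ACME v2");  // same key: overwritten, not duplicated

  // Always prints device_id first, then vendor, regardless of insertion order.
  for (const auto& [key, value] : entries) {
    std::cout << key << "=" << value << "\n";
  }
  return 0;
}
```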
--- .../core/session/onnxruntime_c_api.h | 2 + onnxruntime/core/framework/allocator.cc | 13 +-- .../core/platform/windows/device_discovery.cc | 2 +- onnxruntime/core/session/abi_devices.h | 5 +- .../core/session/abi_key_value_pairs.h | 94 ++++++++++++------ .../core/session/allocator_adapters.cc | 44 +++++---- .../core/session/ep_library_internal.cc | 2 +- onnxruntime/core/session/onnxruntime_c_api.cc | 9 +- .../core/session/provider_policy_context.cc | 4 +- onnxruntime/core/session/utils.cc | 4 +- .../python/onnxruntime_pybind_state.cc | 20 ++-- .../python/onnxruntime_pybind_state_common.h | 9 +- .../test/autoep/test_autoep_selection.cc | 4 +- .../test/framework/key_value_pairs_test.cc | 96 +++++++++++++++++++ 14 files changed, 225 insertions(+), 83 deletions(-) create mode 100644 onnxruntime/test/framework/key_value_pairs_test.cc diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 4fd47f9f3f997..7cdcbb3bc76bf 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5075,6 +5075,8 @@ struct OrtApi { void(ORT_API_CALL* CreateKeyValuePairs)(_Outptr_ OrtKeyValuePairs** out); /** \brief Add a key-value pair to the OrtKeyValuePairs instance. + * + * If a pair with the same key already exists, it is overwritten. * * \param[in] kvps OrtKeyValuePairs instance. * \param[in] key Key to be added. diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 30ff8342a8009..9324fa76ded4f 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -31,27 +31,28 @@ Status OrtArenaCfg::FromKeyValuePairs(const OrtKeyValuePairs& kvps, OrtArenaCfg& return Status::OK(); }; - if (auto it = kvps.entries.find(ConfigKeyNames::ArenaExtendStrategy); it != kvps.entries.end()) { + const auto& kvps_entries = kvps.Entries(); + if (auto it = kvps_entries.find(ConfigKeyNames::ArenaExtendStrategy); it != kvps_entries.end()) { ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.arena_extend_strategy)); } - if (auto it = kvps.entries.find(ConfigKeyNames::InitialChunkSizeBytes); it != kvps.entries.end()) { + if (auto it = kvps_entries.find(ConfigKeyNames::InitialChunkSizeBytes); it != kvps_entries.end()) { ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.initial_chunk_size_bytes)); } - if (auto it = kvps.entries.find(ConfigKeyNames::MaxDeadBytesPerChunk); it != kvps.entries.end()) { + if (auto it = kvps_entries.find(ConfigKeyNames::MaxDeadBytesPerChunk); it != kvps_entries.end()) { ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.max_dead_bytes_per_chunk)); } - if (auto it = kvps.entries.find(ConfigKeyNames::InitialGrowthChunkSizeBytes); it != kvps.entries.end()) { + if (auto it = kvps_entries.find(ConfigKeyNames::InitialGrowthChunkSizeBytes); it != kvps_entries.end()) { ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.initial_growth_chunk_size_bytes)); } - if (auto it = kvps.entries.find(ConfigKeyNames::MaxPowerOfTwoExtendBytes); it != kvps.entries.end()) { + if (auto it = kvps_entries.find(ConfigKeyNames::MaxPowerOfTwoExtendBytes); it != kvps_entries.end()) { ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.max_power_of_two_extend_bytes)); } - if (auto it = kvps.entries.find(ConfigKeyNames::MaxMem); it != kvps.entries.end()) { + if (auto it = kvps_entries.find(ConfigKeyNames::MaxMem); it != kvps_entries.end()) { 
ORT_RETURN_IF_ERROR(from_string(it->first, it->second, cfg.max_mem)); } diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index 3908af40f962b..dcc030cb3467d 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -581,7 +581,7 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor << ", vendor:" << ortdevice.vendor << ", type:" << std::dec << static_cast(ortdevice.type) << ", metadata: ["; - for (auto& [key, value] : ortdevice.metadata.entries) { + for (auto& [key, value] : ortdevice.metadata.Entries()) { oss << key << "=" << value << ", "; } diff --git a/onnxruntime/core/session/abi_devices.h b/onnxruntime/core/session/abi_devices.h index 8f9e8c20926fc..67253a83ab490 100644 --- a/onnxruntime/core/session/abi_devices.h +++ b/onnxruntime/core/session/abi_devices.h @@ -26,7 +26,7 @@ struct OrtHardwareDevice { onnxruntime::HashCombine(hd.vendor_id, h); onnxruntime::HashCombine(hd.vendor, h); onnxruntime::HashCombine(hd.type, h); - for (const auto& [key, value] : hd.metadata.entries) { + for (const auto& [key, value] : hd.metadata.Entries()) { onnxruntime::HashCombine(key, h); onnxruntime::HashCombine(value, h); } @@ -51,8 +51,7 @@ struct equal_to { lhs.vendor_id == rhs.vendor_id && lhs.device_id == rhs.device_id && lhs.vendor == rhs.vendor && - lhs.metadata.keys == rhs.metadata.keys && - lhs.metadata.values == rhs.metadata.values; + lhs.metadata.Entries() == rhs.metadata.Entries(); } }; } // namespace std diff --git a/onnxruntime/core/session/abi_key_value_pairs.h b/onnxruntime/core/session/abi_key_value_pairs.h index 150575b3a9efc..7d739439b7a27 100644 --- a/onnxruntime/core/session/abi_key_value_pairs.h +++ b/onnxruntime/core/session/abi_key_value_pairs.h @@ -4,20 +4,41 @@ #pragma once #include +#include #include -#include #include +#include + +#include "gsl/gsl" struct OrtKeyValuePairs { - std::unordered_map entries; - // members to make returning all key/value entries via the C API easier - std::vector keys; - std::vector values; + OrtKeyValuePairs() = default; + + OrtKeyValuePairs(const OrtKeyValuePairs& other) { + CopyFromMap(other.entries_); + } + + OrtKeyValuePairs(OrtKeyValuePairs&& other) : OrtKeyValuePairs{} { + swap(*this, other); + } + + OrtKeyValuePairs& operator=(OrtKeyValuePairs other) { // handles copy and move assignment + swap(*this, other); + return *this; + } + + friend void swap(OrtKeyValuePairs& a, OrtKeyValuePairs& b) { + using std::swap; + swap(a.entries_, b.entries_); + swap(a.keys_, b.keys_); + swap(a.values_, b.values_); + } - void Copy(const std::unordered_map& src) { - entries = src; + void CopyFromMap(std::map src) { + entries_ = std::move(src); Sync(); } + void Add(const char* key, const char* value) { // ignore if either are nullptr. 
if (key && value) { @@ -25,17 +46,16 @@ struct OrtKeyValuePairs { } } - void Add(const std::string& key, const std::string& value) { + void Add(std::string key, std::string value) { if (key.empty()) { // ignore empty keys return; } - auto iter_inserted = entries.insert({key, value}); - bool inserted = iter_inserted.second; + auto [it, inserted] = entries_.insert_or_assign(std::move(key), std::move(value)); if (inserted) { - const auto& entry = *iter_inserted.first; - keys.push_back(entry.first.c_str()); - values.push_back(entry.second.c_str()); + const auto& [entry_key, entry_value] = *it; + keys_.push_back(entry_key.c_str()); + values_.push_back(entry_value.c_str()); } else { // rebuild is easier and changing an entry is not expected to be a common case. Sync(); @@ -48,27 +68,47 @@ struct OrtKeyValuePairs { return; } - auto iter = entries.find(key); - if (iter != entries.end()) { - auto key_iter = std::find(keys.begin(), keys.end(), iter->first.c_str()); - // there should only ever be one matching entry, and keys and values should be in sync - if (key_iter != keys.end()) { - auto idx = std::distance(keys.begin(), key_iter); - keys.erase(key_iter); - values.erase(values.begin() + idx); + auto iter = entries_.find(key); + if (iter != entries_.end()) { + auto key_iter = std::find(keys_.begin(), keys_.end(), iter->first.c_str()); + // there should only ever be one matching entry, and keys_ and values_ should be in sync + if (key_iter != keys_.end()) { + auto idx = std::distance(keys_.begin(), key_iter); + keys_.erase(key_iter); + values_.erase(values_.begin() + idx); } - entries.erase(iter); + entries_.erase(iter); } } + const std::map& Entries() const { + return entries_; + } + + gsl::span Keys() const { + return keys_; + } + + gsl::span Values() const { + return values_; + } + private: void Sync() { - keys.clear(); - values.clear(); - for (const auto& entry : entries) { - keys.push_back(entry.first.c_str()); - values.push_back(entry.second.c_str()); + keys_.clear(); + values_.clear(); + for (const auto& entry : entries_) { + keys_.push_back(entry.first.c_str()); + values_.push_back(entry.second.c_str()); } } + + // Note: Use std::map so that we can iterate through entries in a deterministic order. + std::map entries_; + + // members to make returning all key/value entries via the C API easier + // Note: The elements point to strings owned by `entries_`. 
+ std::vector keys_; + std::vector values_; }; diff --git a/onnxruntime/core/session/allocator_adapters.cc b/onnxruntime/core/session/allocator_adapters.cc index 9e38a0ef75ccc..c6eff29a0bd4f 100644 --- a/onnxruntime/core/session/allocator_adapters.cc +++ b/onnxruntime/core/session/allocator_adapters.cc @@ -41,8 +41,8 @@ OrtAllocatorImplWrappingIAllocator::OrtAllocatorImplWrappingIAllocator(onnxrunti [](const OrtAllocator* this_, OrtKeyValuePairs** stats) noexcept -> OrtStatusPtr { API_IMPL_BEGIN auto kvp = std::make_unique(); - auto stats_map = static_cast(this_)->Stats(); - kvp->Copy(stats_map); + const auto& stats_map = static_cast(this_)->Stats(); + kvp->CopyFromMap(std::map(stats_map.begin(), stats_map.end())); *stats = reinterpret_cast(kvp.release()); return nullptr; API_IMPL_END @@ -130,25 +130,27 @@ void IAllocatorImplWrappingOrtAllocator::GetStats(AllocatorStats* stats) { std::unique_ptr kvp_guard(&kvps, release_fn); - for (size_t i = 0; i < kvps->keys.size(); ++i) { - if (strcmp(kvps->keys[i], "Limit") == 0) { - stats->bytes_limit = std::stoll(kvps->values[i]); - } else if (strcmp(kvps->keys[i], "InUse") == 0) { - stats->bytes_in_use = std::stoll(kvps->values[i]); - } else if (strcmp(kvps->keys[i], "TotalAllocated") == 0) { - stats->total_allocated_bytes = std::stoll(kvps->values[i]); - } else if (strcmp(kvps->keys[i], "MaxInUse") == 0) { - stats->max_bytes_in_use = std::stoll(kvps->values[i]); - } else if (strcmp(kvps->keys[i], "NumAllocs") == 0) { - stats->num_allocs = std::stoll(kvps->values[i]); - } else if (strcmp(kvps->keys[i], "NumReserves") == 0) { - stats->num_reserves = std::stoll(kvps->values[i]); - } else if (strcmp(kvps->keys[i], "NumArenaExtensions") == 0) { - stats->num_arena_extensions = std::stoll(kvps->values[i]); - } else if (strcmp(kvps->keys[i], "NumArenaShrinkages") == 0) { - stats->num_arena_shrinkages = std::stoll(kvps->values[i]); - } else if (strcmp(kvps->keys[i], "MaxAllocSize") == 0) { - stats->max_alloc_size = std::stoll(kvps->values[i]); + const auto keys = kvps->Keys(), values = kvps->Values(); + + for (size_t i = 0; i < keys.size(); ++i) { + if (strcmp(keys[i], "Limit") == 0) { + stats->bytes_limit = std::stoll(values[i]); + } else if (strcmp(keys[i], "InUse") == 0) { + stats->bytes_in_use = std::stoll(values[i]); + } else if (strcmp(keys[i], "TotalAllocated") == 0) { + stats->total_allocated_bytes = std::stoll(values[i]); + } else if (strcmp(keys[i], "MaxInUse") == 0) { + stats->max_bytes_in_use = std::stoll(values[i]); + } else if (strcmp(keys[i], "NumAllocs") == 0) { + stats->num_allocs = std::stoll(values[i]); + } else if (strcmp(keys[i], "NumReserves") == 0) { + stats->num_reserves = std::stoll(values[i]); + } else if (strcmp(keys[i], "NumArenaExtensions") == 0) { + stats->num_arena_extensions = std::stoll(values[i]); + } else if (strcmp(keys[i], "NumArenaShrinkages") == 0) { + stats->num_arena_shrinkages = std::stoll(values[i]); + } else if (strcmp(keys[i], "MaxAllocSize") == 0) { + stats->max_alloc_size = std::stoll(values[i]); } } } diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc index aa032f24f13c0..25f70f7549a16 100644 --- a/onnxruntime/core/session/ep_library_internal.cc +++ b/onnxruntime/core/session/ep_library_internal.cc @@ -83,7 +83,7 @@ std::unique_ptr EpLibraryInternal::CreateDmlEp() { // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is associated with // a specific device. 
// How would we know what options should not allow user overrides if set in OrtEpDevice? - if (auto it = device.metadata.entries.find("DxgiAdapterNumber"); it != device.metadata.entries.end()) { + if (auto it = device.metadata.Entries().find("DxgiAdapterNumber"); it != device.metadata.Entries().end()) { ep_options = std::make_unique(); ep_options->Add("device_id", it->second.c_str()); } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 8983124ce039d..2551fbc8b6099 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2985,7 +2985,8 @@ ORT_API(void, OrtApis::AddKeyValuePair, _In_ OrtKeyValuePairs* kvps, ORT_API(const char*, OrtApis::GetKeyValue, _In_ const OrtKeyValuePairs* kvps, _In_ const char* key) { const char* value = nullptr; - if (auto entry = kvps->entries.find(key); entry != kvps->entries.end()) { + const auto& entries = kvps->Entries(); + if (auto entry = entries.find(key); entry != entries.end()) { value = entry->second.c_str(); } @@ -2994,9 +2995,9 @@ ORT_API(const char*, OrtApis::GetKeyValue, _In_ const OrtKeyValuePairs* kvps, _I ORT_API(void, OrtApis::GetKeyValuePairs, _In_ const OrtKeyValuePairs* kvps, _Outptr_ const char* const** keys, _Outptr_ const char* const** values, _Out_ size_t* num_entries) { - *keys = kvps->keys.data(); - *values = kvps->values.data(); - *num_entries = kvps->entries.size(); + *keys = kvps->Keys().data(); + *values = kvps->Values().data(); + *num_entries = kvps->Entries().size(); } ORT_API(void, OrtApis::RemoveKeyValuePair, _Frees_ptr_opt_ OrtKeyValuePairs* kvps, _In_ const char* key) { diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index a5258a4811bf7..6b54c33e9b10b 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -30,7 +30,7 @@ bool IsDiscreteDevice(const OrtEpDevice* d) { return false; } - const auto& entries = d->device->metadata.entries; + const auto& entries = d->device->metadata.Entries(); if (auto it = entries.find("Discrete"); it != entries.end()) { return it->second == "1"; } @@ -366,7 +366,7 @@ Status ProviderPolicyContext::AddEpDefaultOptionsToSession(InferenceSession& ses auto& config_options = sess.GetMutableSessionOptions().config_options; for (auto device : devices) { const std::string ep_options_prefix = OrtSessionOptions::GetProviderOptionPrefix(device->ep_name.c_str()); - for (const auto& [key, value] : device->ep_options.entries) { + for (const auto& [key, value] : device->ep_options.Entries()) { const std::string option_key = ep_options_prefix + key; // preserve user-provided options as they override any defaults the EP factory specified earlier if (config_options.configurations.find(option_key) == config_options.configurations.end()) { diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index d0f2e862d61d9..69039beb49363 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -60,7 +60,7 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con // add ep_options to SessionOptions with prefix. // preserve any user provided values. 
const std::string ep_options_prefix = OrtSessionOptions::GetProviderOptionPrefix(ep_device->ep_name.c_str()); - for (const auto& [key, value] : ep_device->ep_options.entries) { + for (const auto& [key, value] : ep_device->ep_options.Entries()) { auto prefixed_key = ep_options_prefix + key; if (session_options.config_options.configurations.count(key) == 0) { // add the default value with prefix @@ -353,7 +353,7 @@ Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, // first add the default values with prefix followed by user specified values so those win const std::string prefix = OrtSessionOptions::GetProviderOptionPrefix(ep_device->ep_name.c_str()); auto& config_options = session_options.config_options; - for (const auto& [key, value] : ep_device->ep_options.entries) { + for (const auto& [key, value] : ep_device->ep_options.Entries()) { ORT_RETURN_IF_ERROR(config_options.AddConfigEntry((prefix + key).c_str(), value.c_str())); } diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 5624befd0ca66..3fb16d30f4970 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1854,10 +1854,10 @@ static OrtStatus* ORT_API_CALL PyEpSelectionPolicyWrapper(_In_ const OrtEpDevice _In_ void* state) { PyEpSelectionDelegate* actual_delegate = reinterpret_cast(state); std::vector py_ep_devices(ep_devices, ep_devices + num_devices); - std::unordered_map py_model_metadata = - model_metadata ? model_metadata->entries : std::unordered_map{}; - std::unordered_map py_runtime_metadata = - runtime_metadata ? runtime_metadata->entries : std::unordered_map{}; + std::map py_model_metadata = + model_metadata ? model_metadata->Entries() : std::map{}; + std::map py_runtime_metadata = + runtime_metadata ? 
runtime_metadata->Entries() : std::map{}; *num_selected = 0; std::vector py_selected; @@ -1986,8 +1986,8 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra R"pbdoc(Hardware device's unique identifier.)pbdoc") .def_property_readonly( "metadata", - [](OrtHardwareDevice* hw_device) -> std::unordered_map { - return hw_device->metadata.entries; + [](OrtHardwareDevice* hw_device) -> std::map { + return hw_device->metadata.Entries(); }, R"pbdoc(Hardware device's metadata as string key/value pairs.)pbdoc"); @@ -2004,14 +2004,14 @@ for model inference.)pbdoc"); R"pbdoc(The execution provider's vendor name.)pbdoc") .def_property_readonly( "ep_metadata", - [](OrtEpDevice* ep_device) -> std::unordered_map { - return ep_device->ep_metadata.entries; + [](OrtEpDevice* ep_device) -> std::map { + return ep_device->ep_metadata.Entries(); }, R"pbdoc(The execution provider's additional metadata for the OrtHardwareDevice.)pbdoc") .def_property_readonly( "ep_options", - [](OrtEpDevice* ep_device) -> std::unordered_map { - return ep_device->ep_options.entries; + [](OrtEpDevice* ep_device) -> std::map { + return ep_device->ep_options.Entries(); }, R"pbdoc(The execution provider's options used to configure the provider to use the OrtHardwareDevice.)pbdoc") .def_property_readonly( diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index edb10bc28a871..b3251abbc427e 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -248,10 +248,11 @@ extern OrtDevice::DeviceId cuda_device_id; extern size_t gpu_mem_limit; #if !defined(ORT_MINIMAL_BUILD) -using PyEpSelectionDelegate = std::function(const std::vector& ep_devices, - const std::unordered_map& model_metadata, - const std::unordered_map& runtime_metadata, - size_t max_selections)>; +using PyEpSelectionDelegate = + std::function(const std::vector& ep_devices, + const std::map& model_metadata, + const std::map& runtime_metadata, + size_t max_selections)>; #endif // Thin wrapper over internal C OrtSessionOptions to store additional state. diff --git a/onnxruntime/test/autoep/test_autoep_selection.cc b/onnxruntime/test/autoep/test_autoep_selection.cc index eed2068f92f53..be20d2c7c5a60 100644 --- a/onnxruntime/test/autoep/test_autoep_selection.cc +++ b/onnxruntime/test/autoep/test_autoep_selection.cc @@ -107,7 +107,7 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod // C API. Test the C++ API because if it works the C API must also work. 
// ASSERT_ORTSTATUS_OK(Ort::GetApi().SessionOptionsAppendExecutionProvider_V2( // session_options, env, devices.data(), devices.size(), - // provider_options.keys.data(), provider_options.values.data(), provider_options.entries.size())); + // provider_options.Keys().data(), provider_options.Values().data(), provider_options.Entries().size())); std::vector ep_devices; ep_devices.reserve(devices.size()); for (const auto* device : devices) { @@ -370,7 +370,7 @@ static OrtStatus* ORT_API_CALL PolicyDelegate(_In_ const OrtEpDevice** ep_device return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Expected to be able to select 2 devices."); } - if (model_metadata->entries.empty()) { + if (model_metadata->Entries().empty()) { return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Model metadata was empty."); } diff --git a/onnxruntime/test/framework/key_value_pairs_test.cc b/onnxruntime/test/framework/key_value_pairs_test.cc new file mode 100644 index 0000000000000..c7d34a007e40e --- /dev/null +++ b/onnxruntime/test/framework/key_value_pairs_test.cc @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/abi_key_value_pairs.h" + +#include +#include + +#include "gtest/gtest.h" + +namespace onnxruntime::test { + +namespace { + +// Verify that the OrtKeyValuePairs internal containers are all consistent. +void CheckConsistency(const OrtKeyValuePairs& kvps) { + ASSERT_EQ(kvps.Keys().size(), kvps.Entries().size()); + ASSERT_EQ(kvps.Values().size(), kvps.Entries().size()); + + for (const auto& [k, v] : kvps.Entries()) { + auto key_it = std::find(kvps.Keys().begin(), kvps.Keys().end(), k.c_str()); + ASSERT_NE(key_it, kvps.Keys().end()); + + const auto entry_idx = std::distance(kvps.Keys().begin(), key_it); + ASSERT_EQ(kvps.Values()[entry_idx], v.c_str()); + } +} + +} // namespace + +TEST(OrtKeyValuePairsTest, BasicUsage) { + const auto kvp_entry_map = std::map{ + {"a", "1"}, {"b", "2"}, {"c", "3"}}; + + OrtKeyValuePairs kvps{}; + kvps.CopyFromMap(kvp_entry_map); + CheckConsistency(kvps); + ASSERT_EQ(kvps.Entries(), kvp_entry_map); + + kvps.Add("d", "4"); + CheckConsistency(kvps); + ASSERT_EQ(kvps.Entries().size(), 4); + + kvps.Remove("c"); + CheckConsistency(kvps); + ASSERT_EQ(kvps.Entries().size(), 3); +} + +TEST(OrtKeyValuePairsTest, CopyAndMove) { + const auto kvp_entry_map = std::map{ + {"a", "1"}, {"b", "2"}, {"c", "3"}}; + + OrtKeyValuePairs kvps0{}; + kvps0.CopyFromMap(kvp_entry_map); + CheckConsistency(kvps0); + + OrtKeyValuePairs kvps1 = kvps0; + CheckConsistency(kvps1); + ASSERT_EQ(kvps1.Entries(), kvps0.Entries()); + + OrtKeyValuePairs kvps2 = std::move(kvps1); + CheckConsistency(kvps1); + CheckConsistency(kvps2); + ASSERT_TRUE(kvps1.Entries().empty()); + ASSERT_EQ(kvps2.Entries(), kvps0.Entries()); +} + +TEST(OrtKeyValuePairsTest, Overwrite) { + OrtKeyValuePairs kvps{}; + + kvps.Add("a", "1"); + CheckConsistency(kvps); + + kvps.Add("a", "2"); + CheckConsistency(kvps); + ASSERT_EQ(kvps.Values().size(), 1); + ASSERT_STREQ(kvps.Values()[0], "2"); +} + +TEST(OrtKeyValuePairsTest, IgnoredInput) { + OrtKeyValuePairs kvps{}; + + kvps.Add(nullptr, "1"); + CheckConsistency(kvps); + ASSERT_EQ(kvps.Entries().size(), size_t{0}); + + kvps.Add("a", nullptr); + CheckConsistency(kvps); + ASSERT_EQ(kvps.Entries().size(), size_t{0}); + + kvps.Add("", "1"); // empty key is ignored + CheckConsistency(kvps); + ASSERT_EQ(kvps.Entries().size(), size_t{0}); +} + +} // namespace onnxruntime::test From 
2f878c60296de169a8a523e692d3d65893f7c133 Mon Sep 17 00:00:00 2001 From: Jeff Kilpatrick Date: Thu, 3 Jul 2025 21:55:40 -0700 Subject: [PATCH 19/19] [QNN EP] Upgrade QNN to 2.36.0 (#25283) ### Description Update Qnn default version to 2.36.0.250627 --- .../android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml | 2 +- .../azure-pipelines/c-api-noopenmp-packaging-pipelines.yml | 2 +- .../github/azure-pipelines/custom-nuget-packaging-pipeline.yml | 2 +- tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml | 2 +- tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml | 2 +- .../github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml | 2 +- .../github/azure-pipelines/stages/py-cpu-packaging-stage.yml | 2 +- .../azure-pipelines/templates/android-java-api-aar-test.yml | 2 +- .../github/azure-pipelines/templates/android-java-api-aar.yml | 2 +- tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml | 2 +- .../azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml | 2 +- .../azure-pipelines/templates/jobs/download_win_qnn_sdk.yml | 2 +- .../ci_build/github/azure-pipelines/templates/py-linux-qnn.yml | 2 +- .../github/azure-pipelines/templates/py-win-arm64-qnn.yml | 2 +- .../github/azure-pipelines/templates/py-win-arm64ec-qnn.yml | 2 +- .../github/azure-pipelines/templates/py-win-x64-qnn.yml | 2 +- tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml | 2 +- .../github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml | 2 +- tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml | 2 +- 19 files changed, 19 insertions(+), 19 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index 8c16b6b7caa69..ee7f8f2fa386a 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.35.0.250530 + default: 2.36.0.250627 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 8d23a2576b1ef..aa25e3f31166a 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -60,7 +60,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.35.0.250530 + default: 2.36.0.250627 resources: repositories: diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index d3bd9c79afe08..7addb3217072a 100644 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -6,7 +6,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.35.0.250530 + default: 2.36.0.250627 - name: IsReleaseBuild displayName: Is a release build? Set it to true if you are doing an Onnx Runtime release. 
diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 1006936403d45..cf8bbbed70525 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.35.0.250530 + default: 2.36.0.250627 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 96ae6952f8827..de024f0b3456f 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.35.0.250530 + default: 2.36.0.250627 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 742bd68fad104..4fa916db0de39 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.35.0.250530 + default: 2.36.0.250627 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index 8440a2c98bb6a..433250f05125e 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' 
- default: 2.35.0.250530 + default: 2.36.0.250627 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index 68e1e1b39c56c..ab779e164b36e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -19,7 +19,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.35.0.250530' + default: '2.36.0.250627' - name: enableWebGpu displayName: Enable WebGPU test diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 4c5801664dda9..110f83ff587c8 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -53,7 +53,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.35.0.250530' + default: '2.36.0.250627' - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index ed2d914df8d81..535784933a087 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -47,7 +47,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: 2.35.0.250530 + default: 2.36.0.250627 - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 4d343be496475..3e7427cc7a2e3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.35.0.250530' + default: '2.36.0.250627' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index 062b70a2249f6..e3f549e2d649f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.35.0.250530' + default: '2.36.0.250627' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index b31a8f4a2190f..d533fb7c83ddd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -26,7 +26,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.35.0.250530 + default: 2.36.0.250627 - name: is1ES displayName: 'Whether the pipeline is running in 1ES' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 3f6b554a679ca..cd060d1fbf19f 100644 --- 
a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.35.0.250530 + default: 2.36.0.250627 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 8d3c7a2914672..2a2ac49b4e073 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.35.0.250530 + default: 2.36.0.250627 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 263992a034ffe..8528fa3907e96 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.35.0.250530 + default: 2.36.0.250627 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 0960ae5ebda83..1406ce338f13e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.35.0.250530' + QnnSdk: '2.36.0.250627' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index e3774fc4476ec..78fce1f9b9602 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.35.0.250530 + default: 2.36.0.250627 jobs: - job: 'BUILD_QNN_EP' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 9b96d3eb8e304..eb77c9422853d 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.35.0.250530 + default: 2.36.0.250627 jobs: - job: 'BUILD_QNN_EP'