diff --git a/.github/actions/webgpu-validate-shader-key/parse-chromium-debug-log.js b/.github/actions/webgpu-validate-shader-key/parse-chromium-debug-log.js index 2342381f03934..d34ff95192089 100644 --- a/.github/actions/webgpu-validate-shader-key/parse-chromium-debug-log.js +++ b/.github/actions/webgpu-validate-shader-key/parse-chromium-debug-log.js @@ -16,7 +16,7 @@ async function processChromiumDebugLog() { for await (const line of rl) { const result = - /^\[.+INFO:CONSOLE\(\d+\)]\ "(?.+)",\ source:\ [^"]+?\(\d+\)$/.exec( + /^\[.+INFO:CONSOLE.+?]\ "(?.+)",\ source:\ [^"]+?\(\d+\)$/.exec( line ); if (!result) { diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 3fadd99bd4ccc..a0188b864d849 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -8,16 +8,16 @@ name: "CodeQL" on: push: - branches: [ "main", nuget_pkg, rel-* ] + branches: ["main", nuget_pkg, rel-*] pull_request: # The branches below must be a subset of the branches above - branches: [ "main" ] + branches: ["main"] schedule: - cron: '41 13 * * 0' workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: @@ -32,44 +32,44 @@ jobs: strategy: fail-fast: false matrix: - language: [ 'java', 'javascript', 'python' ] + language: ['java', 'javascript', 'python'] # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support steps: - - name: Checkout repository - uses: actions/checkout@v4 + - name: Checkout repository + uses: actions/checkout@v4 # Initializes the CodeQL tools for scanning. - - name: Initialize CodeQL - uses: github/codeql-action/init@v3 - with: - languages: ${{ matrix.language }} + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. # By default, queries listed here will override any specified in a config file. # Prefix the list here with "+" to use these queries and those in the config file. # Details on CodeQL's query packs refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs - queries: security-extended,security-and-quality + queries: security-extended,security-and-quality # Setup Java to use a version that is not too old for the project - - if: ${{ matrix.language == 'java' }} - name: Setup Java 11 - uses: actions/setup-java@v4 - with: - java-version: '11' - distribution: 'microsoft' + - if: ${{ matrix.language == 'java' }} + name: Setup Java 11 + uses: actions/setup-java@v4 + with: + java-version: '11' + distribution: 'microsoft' - - if: ${{ matrix.language == 'javascript' }} - uses: actions/setup-node@v4 - with: - node-version: 20 + - if: ${{ matrix.language == 'javascript' }} + uses: actions/setup-node@v4 + with: + node-version: 20 # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - - if: ${{ matrix.language != 'cpp' }} - name: Autobuild - uses: github/codeql-action/autobuild@v3 + - if: ${{ matrix.language != 'cpp' }} + name: Autobuild + uses: github/codeql-action/autobuild@v3 - - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index 88ec3b451dd2d..91e42583d361f 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -5,9 +5,9 @@ name: "Validate Gradle Wrapper" on: push: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] pull_request: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] workflow_dispatch: jobs: @@ -17,3 +17,6 @@ jobs: steps: - uses: actions/checkout@v4 - uses: gradle/actions/wrapper-validation@v4 +concurrency: + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} + cancel-in-progress: true diff --git a/.github/workflows/ios.yml b/.github/workflows/ios.yml index 2c51bf0ce476f..75f5d02fd3720 100644 --- a/.github/workflows/ios.yml +++ b/.github/workflows/ios.yml @@ -12,53 +12,53 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: iOS_CI_on_Mac: runs-on: macos-14 steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - submodules: false - - name: Use Xcode ${{ env.XCODE_VERSION }} - shell: bash - run: | - set -e -x - XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ env.XCODE_VERSION }}.app/Contents/Developer" - sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: false + - name: Use Xcode ${{ env.XCODE_VERSION }} + shell: bash + run: | + set -e -x + XCODE_DEVELOPER_DIR="/Applications/Xcode_${{ env.XCODE_VERSION }}.app/Contents/Developer" + sudo xcode-select --switch "${XCODE_DEVELOPER_DIR}" - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v7 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: (CPU, CoreML, XNNPACK EPs) Build onnxruntime for iOS x86_64 and run tests using simulator - shell: bash - run: | - python3 ${{ github.workspace }}/tools/ci_build/build.py \ - --skip_submodule_sync \ - --build_dir ${{ github.workspace }}/iOS \ - --build_shared_lib \ - --use_coreml \ - --use_xnnpack \ - --ios \ - --apple_sysroot iphonesimulator \ - --osx_arch x86_64 \ - --apple_deploy_target=15.1 \ - --use_xcode \ - --config RelWithDebInfo \ - --build_apple_framework \ - --parallel \ - --use_binskim_compliant_compile_flags - env: - ORT_GET_SIMULATOR_DEVICE_INFO_REQUESTED_RUNTIME_VERSION: ${{ env.IOS_SIMULATOR_RUNTIME_VERSION }} + - name: (CPU, CoreML, XNNPACK EPs) Build onnxruntime for iOS x86_64 and run tests using simulator + shell: bash + run: | + python3 ${{ github.workspace }}/tools/ci_build/build.py \ + --skip_submodule_sync \ + --build_dir ${{ github.workspace }}/iOS \ + --build_shared_lib \ + --use_coreml \ + --use_xnnpack \ + --ios \ + --apple_sysroot iphonesimulator \ + --osx_arch x86_64 \ + --apple_deploy_target=15.1 \ + --use_xcode \ + --config RelWithDebInfo \ + --build_apple_framework \ + --parallel \ + --use_binskim_compliant_compile_flags + env: + ORT_GET_SIMULATOR_DEVICE_INFO_REQUESTED_RUNTIME_VERSION: ${{ env.IOS_SIMULATOR_RUNTIME_VERSION }} timeout-minutes: 150 env: XCODE_VERSION: 15.3.0 - IOS_SIMULATOR_RUNTIME_VERSION: 17.4 \ No newline at end of file + IOS_SIMULATOR_RUNTIME_VERSION: 17.4 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index c7e434de0aa33..16c9008f3675f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -9,7 +9,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name == 'workflow_dispatch' }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: @@ -109,18 +109,7 @@ jobs: github_token: ${{ secrets.github_token }} reporter: github-pr-check level: info - flags: --linelength=120 - --exclude=java/src/main/native/*.c - --exclude=onnxruntime/core/mlas/inc/* - --exclude=onnxruntime/core/mlas/lib/* - --exclude=onnxruntime/contrib_ops/cuda/bert/flash_attention/* - --exclude=build/Debug/* - --exclude=cmake/* - --exclude=csharp/test/* - --exclude=onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/* - --exclude=orttraining/orttraining/test/* - --exclude=onnxruntime/test/* - --exclude=winml/* + flags: --linelength=120 --exclude=java/src/main/native/*.c --exclude=onnxruntime/core/mlas/inc/* --exclude=onnxruntime/core/mlas/lib/* --exclude=onnxruntime/contrib_ops/cuda/bert/flash_attention/* --exclude=build/Debug/* --exclude=cmake/* --exclude=csharp/test/* --exclude=onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/GeneratedShaders/* --exclude=orttraining/orttraining/test/* --exclude=onnxruntime/test/* --exclude=winml/* filter: "-runtime/references" lint-js: diff --git a/.github/workflows/linux-dnnl.yml b/.github/workflows/linux-dnnl.yml index 0d2c959354e3d..f6e4fe5708140 100644 --- a/.github/workflows/linux-dnnl.yml +++ b/.github/workflows/linux-dnnl.yml @@ -8,13 +8,13 @@ name: Linux DNNL CI on: push: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] pull_request: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true permissions: diff --git a/.github/workflows/linux_ci.yml b/.github/workflows/linux_ci.yml index 031c0cdf7d620..6f517f2656e94 100644 --- a/.github/workflows/linux_ci.yml +++ b/.github/workflows/linux_ci.yml @@ -21,13 +21,13 @@ name: Linux CI on: push: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] pull_request: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true permissions: diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index d3a54e1506e39..0dbe63371c7b8 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -2,13 +2,13 @@ name: Linux CUDA CI on: push: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] pull_request: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true permissions: diff --git a/.github/workflows/linux_openvino_ci.yml b/.github/workflows/linux_openvino_ci.yml index 12495b1f26c65..0a4827087309e 100644 --- a/.github/workflows/linux_openvino_ci.yml +++ b/.github/workflows/linux_openvino_ci.yml @@ -2,13 +2,13 @@ name: Linux OpenVINO CI on: push: - branches: [ main, 'rel-*' ] + branches: [main, 'rel-*'] pull_request: - branches: [ main, 'rel-*' ] + branches: [main, 'rel-*'] workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true permissions: diff --git a/.github/workflows/linux_tensorrt_ci.yml b/.github/workflows/linux_tensorrt_ci.yml index f8d4a0d4dd218..405de75e95454 100644 --- a/.github/workflows/linux_tensorrt_ci.yml +++ b/.github/workflows/linux_tensorrt_ci.yml @@ -2,13 +2,13 @@ name: Linux TensorRT CI on: push: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] pull_request: - branches: [ main, 'rel-*'] + branches: [main, 'rel-*'] workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true permissions: diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index 8fcc53bbd9991..9cc1604d71e68 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -12,7 +12,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true env: diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml index af890d88995be..bb8a0638afea2 100644 --- a/.github/workflows/pr_checks.yml +++ b/.github/workflows/pr_checks.yml @@ -13,7 +13,7 @@ permissions: # set top-level default permissions as security best practice contents: read concurrency: - group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name == 'workflow_dispatch' }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: diff --git a/.github/workflows/web.yml b/.github/workflows/web.yml index 8490252ca657b..8f922ef26cd7e 100644 --- a/.github/workflows/web.yml +++ b/.github/workflows/web.yml @@ -12,7 +12,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true jobs: diff --git a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml index 57f687d8502ff..b45663a6145e3 100644 --- a/.github/workflows/windows-web-ci-workflow.yml +++ b/.github/workflows/windows-web-ci-workflow.yml @@ -205,6 +205,14 @@ jobs: log_file_path: ${{ runner.temp }}\web\test\07\chrome_debug.log is_chromium_log: true + # this step is added to help investigate the shader validation failure which is hard to reproduce + - name: Upload WebGPU shader validation log on failure + if: ${{ failure() && inputs.run_webgpu_tests == true && inputs.build_config == 'Debug' }} + uses: actions/upload-artifact@v4 + with: + name: webgpu-shader-validation-logs + path: ${{ runner.temp }}\web\test\07\chrome_debug.log + - name: E2E package consuming test if: ${{ inputs.build_config == 'Release' }} run: npm run test:e2e -- --browser=Chrome_default diff --git a/CODEOWNERS b/CODEOWNERS index a55067ed798d8..3ce36ef57524a 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -3,12 +3,6 @@ # Mobile /onnxruntime/core/flatbuffers/schema/ort.fbs @microsoft/onnxruntime-mobile -# MLAS and related contrib ops -/onnxruntime/core/graph/contrib_ops/nhwc_schema_defs.cc @microsoft/onnxruntime-mlas -/onnxruntime/core/graph/contrib_ops/nchwc_schema_defs.cc @microsoft/onnxruntime-mlas -/onnxruntime/core/graph/contrib_ops/quantization_defs.* @microsoft/onnxruntime-mlas -/onnxruntime/core/mlas/** @microsoft/onnxruntime-mlas - # Dependencies requirements-dev.txt @microsoft/onnxruntime-admin requirements-doc.txt @microsoft/onnxruntime-admin diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 0204ce1423bbf..121799e16ee97 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -974,7 +974,7 @@ if (onnxruntime_USE_JSEP) list(APPEND ONNXRUNTIME_PROVIDER_NAMES js) endif() if (onnxruntime_USE_QNN OR onnxruntime_USE_QNN_INTERFACE) - + if(onnxruntime_USE_QNN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_QNN=1) else() @@ -1552,7 +1552,6 @@ if (Git_FOUND) string(APPEND ORT_BUILD_INFO "git-branch=${ORT_GIT_BRANCH}, git-commit-id=${ORT_GIT_COMMIT}, ") endif() string(APPEND ORT_BUILD_INFO "build type=${CMAKE_BUILD_TYPE}") -string(APPEND ORT_BUILD_INFO ", cmake cxx flags: ${CMAKE_CXX_FLAGS}") configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_config.h) get_property(onnxruntime_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG) @@ -1790,12 +1789,10 @@ if (onnxruntime_USE_WINML) list(APPEND ONNXRUNTIME_CMAKE_FILES winml) endif() # if (onnxruntime_USE_WINML) -if (onnxruntime_BUILD_SHARED_LIB OR onnxruntime_BUILD_APPLE_FRAMEWORK) - if (onnxruntime_BUILD_APPLE_FRAMEWORK AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin|iOS|visionOS|tvOS") - message(FATAL_ERROR "onnxruntime_BUILD_APPLE_FRAMEWORK can only be enabled for macOS or iOS or visionOS or tvOS.") - endif() - list(APPEND ONNXRUNTIME_CMAKE_FILES onnxruntime) +if (onnxruntime_BUILD_APPLE_FRAMEWORK AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin|iOS|visionOS|tvOS") + message(FATAL_ERROR "onnxruntime_BUILD_APPLE_FRAMEWORK can only be enabled for macOS or iOS or visionOS or tvOS.") endif() +list(APPEND ONNXRUNTIME_CMAKE_FILES onnxruntime) if (onnxruntime_BUILD_JAVA) message(STATUS "Java Build is enabled") @@ -1910,18 +1907,50 @@ if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS) ) endif() -if(TARGET onnxruntime) -# Install - include(GNUInstallDirs) +if(NOT onnxruntime_BUILD_SHARED_LIB AND onnxruntime_USE_WEBGPU) + message(WARNING "CMake target files will not be generated for static onnxruntime builds with webgpu support") +else() + # Install include(CMakePackageConfigHelpers) + set(PROJECT_CONFIG_CONTENT "@PACKAGE_INIT@\n") + + if (NOT onnxruntime_BUILD_SHARED_LIB) + string(APPEND PROJECT_CONFIG_CONTENT + "include(CMakeFindDependencyMacro)\n\ + find_dependency(absl)\n\ + find_dependency(date)\n\ + find_dependency(Eigen3)\n\ + find_dependency(nlohmann_json)\n\ + find_dependency(ONNX)\n\ + find_dependency(re2)\n\ + find_dependency(flatbuffers)\n\ + find_dependency(cpuinfo)\n\ + find_dependency(protobuf)\n\ + find_dependency(Boost COMPONENTS mp11)\n\ + find_dependency(Microsoft.GSL 4.0)\n\ + if(NOT WIN32 AND NOT CMAKE_SYSTEM_NAME STREQUAL \"Android\")\n\ + find_dependency(Iconv)\n\ + endif()\n\ + if(WIN32)\n\ + find_dependency(wil)\n\ + endif()\n\ + find_path(safeint_SOURCE_DIR NAMES \"SafeInt.hpp\" REQUIRED)\n\ + add_library(safeint_interface IMPORTED INTERFACE)\n\ + target_include_directories(safeint_interface INTERFACE ${safeint_SOURCE_DIR})\n\ + ") + endif() + string(APPEND PROJECT_CONFIG_CONTENT - "include(\"\${CMAKE_CURRENT_LIST_DIR}/${PROJECT_NAME}Targets.cmake\")") + "include(\"\${CMAKE_CURRENT_LIST_DIR}/${PROJECT_NAME}Targets.cmake\")\n") + + file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/PROJECT_CONFIG_FILE" ${PROJECT_CONFIG_CONTENT}) install(EXPORT ${PROJECT_NAME}Targets NAMESPACE ${PROJECT_NAME}:: DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}) -# Create config for find_package() + + # Create config for find_package() configure_package_config_file( "${CMAKE_CURRENT_BINARY_DIR}/PROJECT_CONFIG_FILE" ${PROJECT_NAME}Config.cmake INSTALL_DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}") @@ -1936,16 +1965,16 @@ if(TARGET onnxruntime) "${PROJECT_BINARY_DIR}/${PROJECT_NAME}Config.cmake" "${PROJECT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake" DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}") -endif() -if(DEFINED BUILD_AS_ARM64X) - set(ARM64X_TARGETS onnxruntime) + if(DEFINED BUILD_AS_ARM64X) + set(ARM64X_TARGETS onnxruntime) - # Add additional ARM64X build targets - if (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB) - list(APPEND ARM64X_TARGETS onnxruntime_providers_shared) - list(APPEND ARM64X_TARGETS onnxruntime_providers_qnn) - endif() + # Add additional ARM64X build targets + if (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB) + list(APPEND ARM64X_TARGETS onnxruntime_providers_shared) + list(APPEND ARM64X_TARGETS onnxruntime_providers_qnn) + endif() - include("${CMAKE_CURRENT_SOURCE_DIR}/arm64x.cmake") + include("${CMAKE_CURRENT_SOURCE_DIR}/arm64x.cmake") + endif() endif() diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake index cd93a8da0061f..5cfb9e78b4720 100644 --- a/cmake/external/abseil-cpp.cmake +++ b/cmake/external/abseil-cpp.cmake @@ -9,15 +9,21 @@ set(BUILD_TESTING 0) set(ABSL_BUILD_TESTING OFF) set(ABSL_BUILD_TEST_HELPERS OFF) set(ABSL_USE_EXTERNAL_GOOGLETEST ON) + +# Both abseil and xnnpack create a target called memory, which +# results in a duplicate target if ABSL_ENABLE_INSTALL is on. +if (onnxruntime_USE_XNNPACK) + set(ABSL_ENABLE_INSTALL OFF) +else() + set(ABSL_ENABLE_INSTALL ON) +endif() + if(Patch_FOUND AND WIN32) set(ABSL_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/abseil/absl_windows.patch) else() set(ABSL_PATCH_COMMAND "") endif() -if(WIN32 AND NOT Patch_FOUND) - #see https://github.com/google/re2/issues/425 and https://github.com/google/re2/issues/436 - set(ABSL_ENABLE_INSTALL ON) -endif() + # NB! Advancing Abseil version changes its internal namespace, # currently absl::lts_20240116 which affects abseil-cpp.natvis debugger # visualization file, that must be adjusted accordingly, unless we eliminate diff --git a/cmake/external/extensions.cmake b/cmake/external/extensions.cmake index e12d74d734e97..8c00c1c8a530b 100644 --- a/cmake/external/extensions.cmake +++ b/cmake/external/extensions.cmake @@ -52,9 +52,24 @@ else() add_subdirectory(${onnxruntime_EXTENSIONS_PATH} ${CMAKE_BINARY_DIR}/_deps/extensions-subbuild EXCLUDE_FROM_ALL) endif() +# move internal includes generated by onnxruntime-extensions as PUBLIC +# from INTERFACE_INCLUDE_DIRECTORIES to INTERFACE_DIRECTORIES for the +# targets that we will link to +get_target_property(ocos_operators_INTERFACE_INCLUDES ocos_operators INTERFACE_INCLUDE_DIRECTORIES) +target_include_directories(ocos_operators PRIVATE ${ocos_operators_INTERFACE_INCLUDES}) +set_target_properties(ocos_operators PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "") + +get_target_property(noexcep_operators_INTERFACE_INCLUDES noexcep_operators INTERFACE_INCLUDE_DIRECTORIES) +target_include_directories(noexcep_operators PRIVATE ${noexcep_operators_INTERFACE_INCLUDES}) +set_target_properties(noexcep_operators PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "") + +get_target_property(ortcustomops_INTERFACE_INCLUDES ortcustomops INTERFACE_INCLUDE_DIRECTORIES) +target_include_directories(ortcustomops PRIVATE ${ortcustomops_INTERFACE_INCLUDES} ${ONNXRUNTIME_INCLUDE_DIR}/core/session) +set_target_properties(ortcustomops PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "") + # target library or executable are defined in CMakeLists.txt of onnxruntime-extensions target_include_directories(ocos_operators PRIVATE ${RE2_INCLUDE_DIR} ${json_SOURCE_DIR}/include) -target_include_directories(ortcustomops PUBLIC ${onnxruntime_EXTENSIONS_PATH}/includes) +target_include_directories(ortcustomops PUBLIC $) if(OCOS_ENABLE_SPM_TOKENIZER) onnxruntime_add_include_to_target(sentencepiece-static ${PROTOBUF_LIB} ${ABSEIL_LIBS}) endif() @@ -64,3 +79,10 @@ onnxruntime_add_include_to_target(noexcep_operators ${PROTOBUF_LIB} ${ABSEIL_LIB add_dependencies(ocos_operators ${onnxruntime_EXTERNAL_DEPENDENCIES}) add_dependencies(ortcustomops ${onnxruntime_EXTERNAL_DEPENDENCIES}) +if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS ocos_operators ortcustomops noexcep_operators EXPORT ${PROJECT_NAME}Targets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 1f31acb057b37..d967e806eb5a3 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -179,7 +179,8 @@ endif() # for cross-compiling #2. if ONNX_CUSTOM_PROTOC_EXECUTABLE is not set, Compile everything(including protoc) from source code. if(Patch_FOUND) - set(ONNXRUNTIME_PROTOBUF_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/protobuf/protobuf_cmake.patch) + set(ONNXRUNTIME_PROTOBUF_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/protobuf/protobuf_cmake.patch && + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/protobuf/protobuf_android_log.patch) else() set(ONNXRUNTIME_PROTOBUF_PATCH_COMMAND "") endif() @@ -287,15 +288,16 @@ if(NOT TARGET Boost::mp11) EXCLUDE_FROM_ALL FIND_PACKAGE_ARGS NAMES Boost ) - onnxruntime_fetchcontent_makeavailable(mp11) + FetchContent_Populate(mp11) if(NOT TARGET Boost::mp11) - add_library(Boost::mp11 ALIAS Boost::headers) + add_library(Boost::mp11 IMPORTED INTERFACE) + target_include_directories(Boost::mp11 INTERFACE $) endif() endif() endif() set(JSON_BuildTests OFF CACHE INTERNAL "") -set(JSON_Install OFF CACHE INTERNAL "") +set(JSON_Install ON CACHE INTERNAL "") onnxruntime_fetchcontent_declare( nlohmann_json @@ -407,6 +409,13 @@ set(GSL_TARGET "Microsoft.GSL::GSL") set(GSL_INCLUDE_DIR "$") onnxruntime_fetchcontent_makeavailable(GSL) +if (NOT GSL_FOUND AND NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS GSL EXPORT ${PROJECT_NAME}Targets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() + find_path(safeint_SOURCE_DIR NAMES "SafeInt.hpp") if(NOT safeint_SOURCE_DIR) unset(safeint_SOURCE_DIR) @@ -420,7 +429,7 @@ if(NOT safeint_SOURCE_DIR) # use fetch content rather than makeavailable because safeint only includes unconditional test targets FetchContent_Populate(safeint) endif() -add_library(safeint_interface INTERFACE) +add_library(safeint_interface IMPORTED INTERFACE) target_include_directories(safeint_interface INTERFACE ${safeint_SOURCE_DIR}) @@ -433,7 +442,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "iOS" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR set(FLATBUFFERS_BUILD_FLATC OFF CACHE BOOL "FLATBUFFERS_BUILD_FLATC" FORCE) endif() set(FLATBUFFERS_BUILD_TESTS OFF CACHE BOOL "FLATBUFFERS_BUILD_TESTS" FORCE) -set(FLATBUFFERS_INSTALL OFF CACHE BOOL "FLATBUFFERS_INSTALL" FORCE) +set(FLATBUFFERS_INSTALL ON CACHE BOOL "FLATBUFFERS_INSTALL" FORCE) set(FLATBUFFERS_BUILD_FLATHASH OFF CACHE BOOL "FLATBUFFERS_BUILD_FLATHASH" FORCE) set(FLATBUFFERS_BUILD_FLATLIB ON CACHE BOOL "FLATBUFFERS_BUILD_FLATLIB" FORCE) if(Patch_FOUND) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 6c1d4485ebcc9..1b124e3bb3f74 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -53,138 +53,148 @@ endfunction() get_c_cxx_api_headers(ONNXRUNTIME_PUBLIC_HEADERS) -#If you want to verify if there is any extra line in symbols.txt, run -# nm -C -g --defined libonnxruntime.so |grep -v '\sA\s' | cut -f 3 -d ' ' | sort -# after build - -list(APPEND SYMBOL_FILES "${REPO_ROOT}/tools/ci_build/gen_def.py") -foreach(f ${ONNXRUNTIME_PROVIDER_NAMES}) - list(APPEND SYMBOL_FILES "${ONNXRUNTIME_ROOT}/core/providers/${f}/symbols.txt") -endforeach() - -if(NOT CMAKE_SYSTEM_NAME MATCHES "AIX") -add_custom_command(OUTPUT ${SYMBOL_FILE} ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c - COMMAND ${Python_EXECUTABLE} "${REPO_ROOT}/tools/ci_build/gen_def.py" - --version_file "${ONNXRUNTIME_ROOT}/../VERSION_NUMBER" --src_root "${ONNXRUNTIME_ROOT}" - --config ${ONNXRUNTIME_PROVIDER_NAMES} --style=${OUTPUT_STYLE} --output ${SYMBOL_FILE} - --output_source ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c - DEPENDS ${SYMBOL_FILES} - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) - -add_custom_target(onnxruntime_generate_def ALL DEPENDS ${SYMBOL_FILE} ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c) -endif() -if(WIN32) - onnxruntime_add_shared_library(onnxruntime - ${SYMBOL_FILE} - "${ONNXRUNTIME_ROOT}/core/dll/dllmain.cc" - "${ONNXRUNTIME_ROOT}/core/dll/delay_load_hook.cc" - "${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc" - ) -elseif(onnxruntime_BUILD_APPLE_FRAMEWORK) - # apple framework requires the header file be part of the library - onnxruntime_add_shared_library(onnxruntime - ${ONNXRUNTIME_PUBLIC_HEADERS} - "${CMAKE_CURRENT_BINARY_DIR}/generated_source.c" - ) +if(onnxruntime_BUILD_SHARED_LIB) + #If you want to verify if there is any extra line in symbols.txt, run + # nm -C -g --defined libonnxruntime.so |grep -v '\sA\s' | cut -f 3 -d ' ' | sort + # after build - # create Info.plist for the framework and podspec for CocoaPods (optional) - set(MACOSX_FRAMEWORK_NAME "onnxruntime") - set(MACOSX_FRAMEWORK_IDENTIFIER "com.microsoft.onnxruntime") + list(APPEND SYMBOL_FILES "${REPO_ROOT}/tools/ci_build/gen_def.py") + foreach(f ${ONNXRUNTIME_PROVIDER_NAMES}) + list(APPEND SYMBOL_FILES "${ONNXRUNTIME_ROOT}/core/providers/${f}/symbols.txt") + endforeach() - # Setup weak frameworks for macOS/iOS. 'weak' as the CoreML or WebGPU EPs are optionally enabled. - if(onnxruntime_USE_COREML) - list(APPEND _weak_frameworks "\\\"CoreML\\\"") - endif() + if(NOT CMAKE_SYSTEM_NAME MATCHES "AIX") + add_custom_command(OUTPUT ${SYMBOL_FILE} ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c + COMMAND ${Python_EXECUTABLE} "${REPO_ROOT}/tools/ci_build/gen_def.py" + --version_file "${ONNXRUNTIME_ROOT}/../VERSION_NUMBER" --src_root "${ONNXRUNTIME_ROOT}" + --config ${ONNXRUNTIME_PROVIDER_NAMES} --style=${OUTPUT_STYLE} --output ${SYMBOL_FILE} + --output_source ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c + DEPENDS ${SYMBOL_FILES} + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) - if(onnxruntime_USE_WEBGPU) - list(APPEND _weak_frameworks "\\\"QuartzCore\\\"") - list(APPEND _weak_frameworks "\\\"IOSurface\\\"") - list(APPEND _weak_frameworks "\\\"Metal\\\"") + add_custom_target(onnxruntime_generate_def ALL DEPENDS ${SYMBOL_FILE} ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c) endif() + if(WIN32) + onnxruntime_add_shared_library(onnxruntime + ${SYMBOL_FILE} + "${ONNXRUNTIME_ROOT}/core/dll/dllmain.cc" + "${ONNXRUNTIME_ROOT}/core/dll/delay_load_hook.cc" + "${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc" + ) + elseif(onnxruntime_BUILD_APPLE_FRAMEWORK) + # apple framework requires the header file be part of the library + onnxruntime_add_shared_library(onnxruntime + ${ONNXRUNTIME_PUBLIC_HEADERS} + "${CMAKE_CURRENT_BINARY_DIR}/generated_source.c" + ) + + # create Info.plist for the framework and podspec for CocoaPods (optional) + set(MACOSX_FRAMEWORK_NAME "onnxruntime") + set(MACOSX_FRAMEWORK_IDENTIFIER "com.microsoft.onnxruntime") + + # Setup weak frameworks for macOS/iOS. 'weak' as the CoreML or WebGPU EPs are optionally enabled. + if(onnxruntime_USE_COREML) + list(APPEND _weak_frameworks "\\\"CoreML\\\"") + endif() - if (_weak_frameworks) - string(JOIN ", " APPLE_WEAK_FRAMEWORK ${_weak_frameworks}) - endif() + if(onnxruntime_USE_WEBGPU) + list(APPEND _weak_frameworks "\\\"QuartzCore\\\"") + list(APPEND _weak_frameworks "\\\"IOSurface\\\"") + list(APPEND _weak_frameworks "\\\"Metal\\\"") + endif() - set(INFO_PLIST_PATH "${CMAKE_CURRENT_BINARY_DIR}/Info.plist") - configure_file(${REPO_ROOT}/cmake/Info.plist.in ${INFO_PLIST_PATH}) - configure_file( - ${REPO_ROOT}/tools/ci_build/github/apple/framework_info.json.template - ${CMAKE_CURRENT_BINARY_DIR}/framework_info.json) - set_target_properties(onnxruntime PROPERTIES - FRAMEWORK TRUE - FRAMEWORK_VERSION A - MACOSX_FRAMEWORK_INFO_PLIST ${INFO_PLIST_PATH} - # Note: The PUBLIC_HEADER and VERSION properties for the 'onnxruntime' target will be set later in this file. - ) -else() - if(CMAKE_SYSTEM_NAME MATCHES "AIX") - onnxruntime_add_shared_library(onnxruntime ${ONNXRUNTIME_ROOT}/core/session/onnxruntime_c_api.cc) + if (_weak_frameworks) + string(JOIN ", " APPLE_WEAK_FRAMEWORK ${_weak_frameworks}) + endif() + + set(INFO_PLIST_PATH "${CMAKE_CURRENT_BINARY_DIR}/Info.plist") + configure_file(${REPO_ROOT}/cmake/Info.plist.in ${INFO_PLIST_PATH}) + configure_file( + ${REPO_ROOT}/tools/ci_build/github/apple/framework_info.json.template + ${CMAKE_CURRENT_BINARY_DIR}/framework_info.json) + set_target_properties(onnxruntime PROPERTIES + FRAMEWORK TRUE + FRAMEWORK_VERSION A + MACOSX_FRAMEWORK_INFO_PLIST ${INFO_PLIST_PATH} + # Note: The PUBLIC_HEADER and VERSION properties for the 'onnxruntime' target will be set later in this file. + ) else() - onnxruntime_add_shared_library(onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c ) - endif() - if(NOT APPLE) - include(CheckLinkerFlag) - check_linker_flag(CXX "LINKER:-rpath=\$ORIGIN" LINKER_SUPPORT_RPATH) - if(LINKER_SUPPORT_RPATH) - target_link_options(onnxruntime PRIVATE "LINKER:-rpath=\$ORIGIN") + if(CMAKE_SYSTEM_NAME MATCHES "AIX") + onnxruntime_add_shared_library(onnxruntime ${ONNXRUNTIME_ROOT}/core/session/onnxruntime_c_api.cc) + else() + onnxruntime_add_shared_library(onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c ) + endif() + if(NOT APPLE) + include(CheckLinkerFlag) + check_linker_flag(CXX "LINKER:-rpath=\$ORIGIN" LINKER_SUPPORT_RPATH) + if(LINKER_SUPPORT_RPATH) + target_link_options(onnxruntime PRIVATE "LINKER:-rpath=\$ORIGIN") + endif() endif() endif() -endif() -if(CMAKE_SYSTEM_NAME MATCHES "AIX") - add_dependencies(onnxruntime ${onnxruntime_EXTERNAL_DEPENDENCIES}) -else() - add_dependencies(onnxruntime onnxruntime_generate_def ${onnxruntime_EXTERNAL_DEPENDENCIES}) -endif() -target_include_directories(onnxruntime PRIVATE ${ONNXRUNTIME_ROOT} PUBLIC "$") + if(CMAKE_SYSTEM_NAME MATCHES "AIX") + add_dependencies(onnxruntime ${onnxruntime_EXTERNAL_DEPENDENCIES}) + else() + add_dependencies(onnxruntime onnxruntime_generate_def ${onnxruntime_EXTERNAL_DEPENDENCIES}) + endif() + target_include_directories(onnxruntime PRIVATE ${ONNXRUNTIME_ROOT} PUBLIC "$") -target_compile_definitions(onnxruntime PRIVATE FILE_NAME=\"onnxruntime.dll\") + target_compile_definitions(onnxruntime PRIVATE FILE_NAME=\"onnxruntime.dll\") -if(UNIX) - if (APPLE) - target_link_options(onnxruntime PRIVATE "LINKER:-dead_strip") - elseif(NOT CMAKE_SYSTEM_NAME MATCHES "AIX") - target_link_options(onnxruntime PRIVATE "LINKER:--version-script=${SYMBOL_FILE}" "LINKER:--no-undefined" "LINKER:--gc-sections") + if(UNIX) + if (APPLE) + target_link_options(onnxruntime PRIVATE "LINKER:-dead_strip") + elseif(NOT CMAKE_SYSTEM_NAME MATCHES "AIX") + target_link_options(onnxruntime PRIVATE "LINKER:--version-script=${SYMBOL_FILE}" "LINKER:--no-undefined" "LINKER:--gc-sections") + endif() + else() + target_link_options(onnxruntime PRIVATE "-DEF:${SYMBOL_FILE}") endif() -else() - target_link_options(onnxruntime PRIVATE "-DEF:${SYMBOL_FILE}") -endif() -if (APPLE OR ${CMAKE_SYSTEM_NAME} MATCHES "^iOS") + if (APPLE) target_link_options(onnxruntime PRIVATE "LINKER:-exported_symbols_list,${SYMBOL_FILE}") - if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS") - set_target_properties(onnxruntime PROPERTIES - MACOSX_RPATH TRUE - INSTALL_RPATH_USE_LINK_PATH FALSE - BUILD_WITH_INSTALL_NAME_DIR TRUE - INSTALL_NAME_DIR @rpath) - else() - set_target_properties(onnxruntime PROPERTIES INSTALL_RPATH "@loader_path") - endif() -endif() + set_target_properties(onnxruntime PROPERTIES + MACOSX_RPATH TRUE + SKIP_BUILD_RPATH TRUE + INSTALL_RPATH_USE_LINK_PATH FALSE + BUILD_WITH_INSTALL_NAME_DIR TRUE + INSTALL_NAME_DIR @rpath) + endif() -if(CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_MINIMAL_BUILD) - # target onnxruntime is a shared library, the dummy __cxa_demangle is only attach to it to avoid - # affecting downstream ort library users with the behavior of dummy __cxa_demangle. So the dummy - # __cxa_demangle must not expose to libonnxruntime_common.a. It works as when the linker is - # creating the DSO, our dummy __cxa_demangle always comes before libc++abi.a so the - # __cxa_demangle in libc++abi.a is discarded, thus, huge binary size reduction. - target_sources(onnxruntime PRIVATE "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc") - target_compile_definitions(onnxruntime PRIVATE USE_DUMMY_EXA_DEMANGLE=1) -endif() + if(CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_MINIMAL_BUILD) + # target onnxruntime is a shared library, the dummy __cxa_demangle is only attach to it to avoid + # affecting downstream ort library users with the behavior of dummy __cxa_demangle. So the dummy + # __cxa_demangle must not expose to libonnxruntime_common.a. It works as when the linker is + # creating the DSO, our dummy __cxa_demangle always comes before libc++abi.a so the + # __cxa_demangle in libc++abi.a is discarded, thus, huge binary size reduction. + target_sources(onnxruntime PRIVATE "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc") + target_compile_definitions(onnxruntime PRIVATE USE_DUMMY_EXA_DEMANGLE=1) + endif() -# strip binary on Android, or for a minimal build on Unix -if(CMAKE_SYSTEM_NAME STREQUAL "Android" OR (onnxruntime_MINIMAL_BUILD AND UNIX)) - if (onnxruntime_MINIMAL_BUILD AND ADD_DEBUG_INFO_TO_MINIMAL_BUILD) - # don't strip - else() - set_target_properties(onnxruntime PROPERTIES LINK_FLAGS_RELEASE -s) - set_target_properties(onnxruntime PROPERTIES LINK_FLAGS_MINSIZEREL -s) + # strip binary on Android, or for a minimal build on Unix + if(CMAKE_SYSTEM_NAME STREQUAL "Android" OR (onnxruntime_MINIMAL_BUILD AND UNIX)) + if (onnxruntime_MINIMAL_BUILD AND ADD_DEBUG_INFO_TO_MINIMAL_BUILD) + # don't strip + else() + set_target_properties(onnxruntime PROPERTIES LINK_FLAGS_RELEASE -s) + set_target_properties(onnxruntime PROPERTIES LINK_FLAGS_MINSIZEREL -s) + endif() + endif() + + # we need to copy C/C++ API headers to be packed into Android AAR package + if(CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_BUILD_JAVA) + set(ANDROID_HEADERS_DIR ${CMAKE_CURRENT_BINARY_DIR}/android/headers) + file(MAKE_DIRECTORY ${ANDROID_HEADERS_DIR}) + # copy the header files one by one + foreach(h_ ${ONNXRUNTIME_PUBLIC_HEADERS}) + get_filename_component(HEADER_NAME_ ${h_} NAME) + add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${h_} ${ANDROID_HEADERS_DIR}/${HEADER_NAME_}) + endforeach() endif() endif() @@ -199,6 +209,10 @@ if(CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_BUILD_JAVA) endforeach() endif() +if (NOT onnxruntime_BUILD_SHARED_LIB) + add_library(onnxruntime INTERFACE) +endif() + set(onnxruntime_INTERNAL_PROVIDER_LIBRARIES ${PROVIDERS_ACL} ${PROVIDERS_ARMNN} @@ -251,10 +265,17 @@ endif() # If you are linking a new library, please add it to the list onnxruntime_INTERNAL_LIBRARIES or onnxruntime_EXTERNAL_LIBRARIES, # Please do not add a library directly to the target_link_libraries command -target_link_libraries(onnxruntime PRIVATE +if (onnxruntime_BUILD_SHARED_LIB) + target_link_libraries(onnxruntime PRIVATE + ${onnxruntime_INTERNAL_LIBRARIES} + ${onnxruntime_EXTERNAL_LIBRARIES} + ) +else() + target_link_libraries(onnxruntime INTERFACE ${onnxruntime_INTERNAL_LIBRARIES} ${onnxruntime_EXTERNAL_LIBRARIES} -) + ) +endif() if(WIN32) target_link_options(onnxruntime PRIVATE ${onnxruntime_DELAYLOAD_FLAGS}) diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index 993bb5d89efee..f9cd35fa71aa8 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -222,7 +222,7 @@ endif() if (NOT onnxruntime_BUILD_SHARED_LIB) install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/common DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core) - install(TARGETS onnxruntime_common + install(TARGETS onnxruntime_common EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} diff --git a/cmake/onnxruntime_flatbuffers.cmake b/cmake/onnxruntime_flatbuffers.cmake index 3ab4c19122ba1..066fb561de57d 100644 --- a/cmake/onnxruntime_flatbuffers.cmake +++ b/cmake/onnxruntime_flatbuffers.cmake @@ -22,10 +22,9 @@ if (FLATBUFFERS_BUILD_FLATC) add_dependencies(onnxruntime_flatbuffers flatc) endif() if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_flatbuffers + install(TARGETS onnxruntime_flatbuffers EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() - diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index e96bb32a7cd21..15f3105b34ecb 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -66,9 +66,9 @@ endif() if(onnxruntime_USE_TENSORRT OR onnxruntime_USE_NCCL OR onnxruntime_USE_NV) # TODO: for now, core framework depends on CUDA. It should be moved to TensorRT EP # TODO: provider_bridge_ort.cc should not include nccl.h -target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) else() -target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) +target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}) endif() # Needed for the provider interface, as it includes training headers when training is enabled if (onnxruntime_ENABLE_TRAINING_OPS) @@ -118,7 +118,7 @@ if (onnxruntime_BUILD_SHARED_LIB) install(FILES ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/framework/provider_options.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/) else() install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/framework DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core) - install(TARGETS onnxruntime_framework + install(TARGETS onnxruntime_framework EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} diff --git a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake index 4d51325b8414e..fba1a680bb62a 100644 --- a/cmake/onnxruntime_graph.cmake +++ b/cmake/onnxruntime_graph.cmake @@ -157,7 +157,7 @@ endif() if (NOT onnxruntime_BUILD_SHARED_LIB) install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/graph DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core) - install(TARGETS onnxruntime_graph + install(TARGETS onnxruntime_graph EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} diff --git a/cmake/onnxruntime_lora.cmake b/cmake/onnxruntime_lora.cmake index 7ba48454d997e..26ee21c645584 100644 --- a/cmake/onnxruntime_lora.cmake +++ b/cmake/onnxruntime_lora.cmake @@ -22,7 +22,7 @@ add_dependencies(onnxruntime_lora ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_lora PROPERTIES FOLDER "ONNXRuntime") if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_lora + install(TARGETS onnxruntime_lora EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 5d46ac9adb7c2..3279a17f8cd5e 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -790,7 +790,7 @@ if (PLATFORM_NAME STREQUAL "macabi") endif() if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_mlas + install(TARGETS onnxruntime_mlas EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index 173c872d4cc06..e60cfbe1c0566 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -131,7 +131,7 @@ set_target_properties(onnxruntime_optimizer PROPERTIES FOLDER "ONNXRuntime") if (NOT onnxruntime_BUILD_SHARED_LIB) install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/optimizer DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core) - install(TARGETS onnxruntime_optimizer + install(TARGETS onnxruntime_optimizer EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} diff --git a/cmake/onnxruntime_providers_acl.cmake b/cmake/onnxruntime_providers_acl.cmake index 1726151a3597a..afa41d72c1e9c 100644 --- a/cmake/onnxruntime_providers_acl.cmake +++ b/cmake/onnxruntime_providers_acl.cmake @@ -28,9 +28,9 @@ set_target_properties(onnxruntime_providers_acl PROPERTIES LINKER_LANGUAGE CXX) if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_acl + install(TARGETS onnxruntime_providers_acl EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file + endif() diff --git a/cmake/onnxruntime_providers_armnn.cmake b/cmake/onnxruntime_providers_armnn.cmake index d6e0f3bd1b6cc..5238a1d19a197 100644 --- a/cmake/onnxruntime_providers_armnn.cmake +++ b/cmake/onnxruntime_providers_armnn.cmake @@ -25,9 +25,9 @@ set_target_properties(onnxruntime_providers_armnn PROPERTIES LINKER_LANGUAGE CXX) if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_armnn + install(TARGETS onnxruntime_providers_armnn EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file + endif() diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake index 0dd6e77c4b67a..757198ffb651d 100644 --- a/cmake/onnxruntime_providers_coreml.cmake +++ b/cmake/onnxruntime_providers_coreml.cmake @@ -210,7 +210,7 @@ target_include_directories(onnxruntime_providers_coreml PRIVATE ${ONNXRUNTIME_RO set_target_properties(onnxruntime_providers_coreml PROPERTIES LINKER_LANGUAGE CXX) if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_coreml + install(TARGETS onnxruntime_providers_coreml EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index b9c810e1250a0..5a2dfb3210988 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -267,7 +267,7 @@ if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD endif() if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers + install(TARGETS onnxruntime_providers EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} diff --git a/cmake/onnxruntime_providers_dml.cmake b/cmake/onnxruntime_providers_dml.cmake index e5ff19f133ee5..62136c5c568d7 100644 --- a/cmake/onnxruntime_providers_dml.cmake +++ b/cmake/onnxruntime_providers_dml.cmake @@ -87,7 +87,7 @@ set_target_properties(onnxruntime_providers_dml PROPERTIES FOLDER "ONNXRuntime") if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_dml + install(TARGETS onnxruntime_providers_dml EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} diff --git a/cmake/onnxruntime_providers_nnapi.cmake b/cmake/onnxruntime_providers_nnapi.cmake index b718a976eb26f..06364ebd49593 100644 --- a/cmake/onnxruntime_providers_nnapi.cmake +++ b/cmake/onnxruntime_providers_nnapi.cmake @@ -74,7 +74,7 @@ endif() if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_nnapi + install(TARGETS onnxruntime_providers_nnapi EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} diff --git a/cmake/onnxruntime_providers_nv.cmake b/cmake/onnxruntime_providers_nv.cmake index 06d44b5289518..12d824fc3360e 100644 --- a/cmake/onnxruntime_providers_nv.cmake +++ b/cmake/onnxruntime_providers_nv.cmake @@ -36,44 +36,30 @@ file(READ ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h NVINFER_VER_CONTENT) - string(REGEX MATCH "define NV_TENSORRT_MAJOR * +([0-9]+)" NV_TENSORRT_MAJOR "${NVINFER_VER_CONTENT}") - string(REGEX REPLACE "define NV_TENSORRT_MAJOR * +([0-9]+)" "\\1" NV_TENSORRT_MAJOR "${NV_TENSORRT_MAJOR}") - string(REGEX MATCH "define NV_TENSORRT_MINOR * +([0-9]+)" NV_TENSORRT_MINOR "${NVINFER_VER_CONTENT}") - string(REGEX REPLACE "define NV_TENSORRT_MINOR * +([0-9]+)" "\\1" NV_TENSORRT_MINOR "${NV_TENSORRT_MINOR}") - string(REGEX MATCH "define NV_TENSORRT_PATCH * +([0-9]+)" NV_TENSORRT_PATCH "${NVINFER_VER_CONTENT}") - string(REGEX REPLACE "define NV_TENSORRT_PATCH * +([0-9]+)" "\\1" NV_TENSORRT_PATCH "${NV_TENSORRT_PATCH}") - math(EXPR NV_TENSORRT_MAJOR_INT "${NV_TENSORRT_MAJOR}") - math(EXPR NV_TENSORRT_MINOR_INT "${NV_TENSORRT_MINOR}") - math(EXPR NV_TENSORRT_PATCH_INT "${NV_TENSORRT_PATCH}") - - if (NV_TENSORRT_MAJOR) - MESSAGE(STATUS "NV_TENSORRT_MAJOR is ${NV_TENSORRT_MAJOR}") + string(REGEX MATCH "define TRT_MAJOR_RTX * +([0-9]+)" NV_TRT_MAJOR_RTX "${NVINFER_VER_CONTENT}") + string(REGEX REPLACE "define TRT_MAJOR_RTX * +([0-9]+)" "\\1" NV_TRT_MAJOR_RTX "${NV_TRT_MAJOR_RTX}") + string(REGEX MATCH "define TRT_MINOR_RTX * +([0-9]+)" NV_TRT_MINOR_RTX "${NVINFER_VER_CONTENT}") + string(REGEX REPLACE "define TRT_MINOR_RTX * +([0-9]+)" "\\1" NV_TRT_MINOR_RTX "${NV_TRT_MINOR_RTX}") + math(EXPR NV_TRT_MAJOR_RTX_INT "${NV_TRT_MAJOR_RTX}") + math(EXPR NV_TRT_MINOR_RTX_INT "${NV_TRT_MINOR_RTX}") + + if (NV_TRT_MAJOR_RTX) + MESSAGE(STATUS "NV_TRT_MAJOR_RTX is ${NV_TRT_MAJOR_RTX}") else() - MESSAGE(STATUS "Can't find NV_TENSORRT_MAJOR macro") + MESSAGE(STATUS "Can't find NV_TRT_MAJOR_RTX macro") endif() - # Check TRT version >= 10.0.1.6 - if ((NV_TENSORRT_MAJOR_INT GREATER 10) OR - (NV_TENSORRT_MAJOR_INT EQUAL 10 AND NV_TENSORRT_MINOR_INT GREATER 0) OR - (NV_TENSORRT_MAJOR_INT EQUAL 10 AND NV_TENSORRT_PATCH_INT GREATER 0)) - set(TRT_GREATER_OR_EQUAL_TRT_10_GA ON) - else() - message( FATAL_ERROR "Only TensorRT 10.x or higher is supported." ) - endif() - - # TensorRT 10 GA onwards, the TensorRT libraries will have major version appended to the end on Windows, - # for example, nvinfer_10.dll, nvonnxparser_10.dll ... - if (WIN32 AND TRT_GREATER_OR_EQUAL_TRT_10_GA) - set(NVINFER_LIB "nvinfer_${NV_TENSORRT_MAJOR}") - set(PARSER_LIB "nvonnxparser_${NV_TENSORRT_MAJOR}") + if (WIN32) + set(NVINFER_LIB "tensorrt_rtx_${NV_TRT_MAJOR_RTX}_${NV_TRT_MINOR_RTX}") + set(PARSER_LIB "tensorrt_onnxparser_rtx_${NV_TRT_MAJOR_RTX}_${NV_TRT_MINOR_RTX}") endif() if (NOT NVINFER_LIB) - set(NVINFER_LIB "nvinfer") + set(NVINFER_LIB "tensorrt_rtx") endif() if (NOT PARSER_LIB) - set(PARSER_LIB "nvonnxparser") + set(PARSER_LIB "tensorrt_onnxparser_rtx") endif() MESSAGE(STATUS "Looking for ${NVINFER_LIB}") @@ -100,9 +86,8 @@ set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_NVONNXPARSER}) MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}") else() - if (TRT_GREATER_OR_EQUAL_TRT_10_GA) - set(ONNX_USE_LITE_PROTO ON) - endif() + set(ONNX_USE_LITE_PROTO ON) + onnxruntime_fetchcontent_declare( onnx_tensorrt URL ${DEP_URL_onnx_tensorrt} diff --git a/cmake/onnxruntime_providers_qnn.cmake b/cmake/onnxruntime_providers_qnn.cmake index 60b3aaf38cd85..748e3de843bab 100644 --- a/cmake/onnxruntime_providers_qnn.cmake +++ b/cmake/onnxruntime_providers_qnn.cmake @@ -161,3 +161,11 @@ ) endif() endif() + +if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_providers_qnn EXPORT ${PROJECT_NAME}Targets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() diff --git a/cmake/onnxruntime_providers_rknpu.cmake b/cmake/onnxruntime_providers_rknpu.cmake index 408bcfde06c36..831df84aa6e08 100644 --- a/cmake/onnxruntime_providers_rknpu.cmake +++ b/cmake/onnxruntime_providers_rknpu.cmake @@ -36,9 +36,9 @@ set_target_properties(onnxruntime_providers_rknpu PROPERTIES LINKER_LANGUAGE CXX) if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_rknpu + install(TARGETS onnxruntime_providers_rknpu EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file + endif() diff --git a/cmake/onnxruntime_providers_webnn.cmake b/cmake/onnxruntime_providers_webnn.cmake index 05c63c22244db..da5720a9f7cb5 100644 --- a/cmake/onnxruntime_providers_webnn.cmake +++ b/cmake/onnxruntime_providers_webnn.cmake @@ -22,4 +22,12 @@ add_dependencies(onnxruntime_providers_webnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_providers_webnn PROPERTIES FOLDER "ONNXRuntime") - set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX) \ No newline at end of file + set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX) + +if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_providers_webnn EXPORT ${PROJECT_NAME}Targets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() diff --git a/cmake/onnxruntime_providers_xnnpack.cmake b/cmake/onnxruntime_providers_xnnpack.cmake index 796536ac9d12b..1fbe553fa58a0 100644 --- a/cmake/onnxruntime_providers_xnnpack.cmake +++ b/cmake/onnxruntime_providers_xnnpack.cmake @@ -28,7 +28,7 @@ set_target_properties(onnxruntime_providers_xnnpack PROPERTIES LINKER_LANGUAGE CXX) if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_xnnpack + install(TARGETS onnxruntime_providers_xnnpack EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index c57a2a962303d..8f7a96e052fa1 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -468,6 +468,9 @@ file(GLOB onnxruntime_python_quantization_fusions_src CONFIGURE_DEPENDS file(GLOB onnxruntime_python_quantization_ep_qnn_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/quantization/execution_providers/qnn/*.py" ) +file(GLOB onnxruntime_python_quantization_neural_compressor_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/quantization/neural_compressor/*.py" +) file(GLOB onnxruntime_python_transformers_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/*.py" ) @@ -581,6 +584,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/fusions COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/execution_providers COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/execution_providers/qnn + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/neural_compressor COMMAND ${CMAKE_COMMAND} -E make_directory $/quantization COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models @@ -660,6 +664,9 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_quantization_ep_qnn_src} $/onnxruntime/quantization/execution_providers/qnn/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_quantization_neural_compressor_src} + $/onnxruntime/quantization/neural_compressor/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_src} $/onnxruntime/transformers/ diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index bd46c7476d0f0..d61512fa3cf09 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -69,7 +69,7 @@ endif() if (NOT onnxruntime_BUILD_SHARED_LIB) install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/session DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core) - install(TARGETS onnxruntime_session + install(TARGETS onnxruntime_session EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} diff --git a/cmake/onnxruntime_util.cmake b/cmake/onnxruntime_util.cmake index 17e64ca2c1d65..851b68a4b61a0 100644 --- a/cmake/onnxruntime_util.cmake +++ b/cmake/onnxruntime_util.cmake @@ -21,7 +21,7 @@ if (WIN32) endif() if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_util + install(TARGETS onnxruntime_util EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} diff --git a/cmake/patches/protobuf/protobuf_android_log.patch b/cmake/patches/protobuf/protobuf_android_log.patch new file mode 100644 index 0000000000000..56bca9d325a75 --- /dev/null +++ b/cmake/patches/protobuf/protobuf_android_log.patch @@ -0,0 +1,26 @@ +diff --git a/cmake/libprotobuf-lite.cmake b/cmake/libprotobuf-lite.cmake +index 83e970312..96908991f 100644 +--- a/cmake/libprotobuf-lite.cmake ++++ b/cmake/libprotobuf-lite.cmake +@@ -102,7 +102,7 @@ if(protobuf_LINK_LIBATOMIC) + target_link_libraries(libprotobuf-lite PRIVATE atomic) + endif() + if(${CMAKE_SYSTEM_NAME} STREQUAL "Android") +- target_link_libraries(libprotobuf-lite PRIVATE log) ++ target_link_libraries(libprotobuf-lite PRIVATE -llog) + endif() + target_include_directories(libprotobuf-lite PUBLIC ${protobuf_SOURCE_DIR}/src) + if(protobuf_BUILD_SHARED_LIBS) +diff --git a/cmake/libprotobuf.cmake b/cmake/libprotobuf.cmake +index 07e4bcf57..0cf27caff 100644 +--- a/cmake/libprotobuf.cmake ++++ b/cmake/libprotobuf.cmake +@@ -118,7 +118,7 @@ if(protobuf_LINK_LIBATOMIC) + target_link_libraries(libprotobuf PRIVATE atomic) + endif() + if(${CMAKE_SYSTEM_NAME} STREQUAL "Android") +- target_link_libraries(libprotobuf PRIVATE log) ++ target_link_libraries(libprotobuf PRIVATE -llog) + endif() + target_include_directories(libprotobuf PUBLIC ${protobuf_SOURCE_DIR}/src) + if(protobuf_BUILD_SHARED_LIBS) diff --git a/cmake/tensorboard/compat/proto/CMakeLists.txt b/cmake/tensorboard/compat/proto/CMakeLists.txt index ad31e4062a8a4..addc3779e9521 100644 --- a/cmake/tensorboard/compat/proto/CMakeLists.txt +++ b/cmake/tensorboard/compat/proto/CMakeLists.txt @@ -23,3 +23,10 @@ add_dependencies(tensorboard ${onnxruntime_EXTERNAL_DEPENDENCIES}) if(WIN32) target_compile_options(tensorboard PRIVATE "/wd4100" "/wd4125" "/wd4127" "/wd4267" "/wd4456" "/wd4800" "/wd6011" "/wd6387" "/wd28182") endif() + +if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS tensorboard EXPORT ${PROJECT_NAME}Targets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() diff --git a/cmake/vcpkg-ports/dlpack/portfile.cmake b/cmake/vcpkg-ports/dlpack/portfile.cmake new file mode 100644 index 0000000000000..fdf328836d4dd --- /dev/null +++ b/cmake/vcpkg-ports/dlpack/portfile.cmake @@ -0,0 +1,25 @@ +set(VCPKG_BUILD_TYPE release) # header-only port + +vcpkg_from_github( + OUT_SOURCE_PATH SOURCE_PATH + REPO dmlc/dlpack + REF 5c210da409e7f1e51ddf445134a4376fdbd70d7d + SHA512 4bc5f5fd36b20ef2943989d5c06fe9cd34f942cdfd4b4866a4405649f7faac47fcdcf3a1fa60eb7b96b643222e5e4b036cbca7d49835dc5f8b659708620a2e8f + HEAD_REF main +) + +vcpkg_cmake_configure( + SOURCE_PATH "${SOURCE_PATH}" + OPTIONS + -DBUILD_MOCK=FALSE +) + +vcpkg_cmake_install() + +vcpkg_cmake_config_fixup(CONFIG_PATH "lib/cmake/dlpack") + +file(REMOVE_RECURSE "${CURRENT_PACKAGES_DIR}/lib") + +vcpkg_install_copyright(FILE_LIST "${SOURCE_PATH}/LICENSE") + +file(COPY "${CMAKE_CURRENT_LIST_DIR}/usage" DESTINATION "${CURRENT_PACKAGES_DIR}/share/${PORT}") diff --git a/cmake/vcpkg-ports/dlpack/usage b/cmake/vcpkg-ports/dlpack/usage new file mode 100644 index 0000000000000..771ec78517174 --- /dev/null +++ b/cmake/vcpkg-ports/dlpack/usage @@ -0,0 +1,4 @@ +dlpack provides CMake targets: + + find_package(dlpack CONFIG REQUIRED) + target_link_libraries(main PRIVATE dlpack::dlpack) diff --git a/cmake/vcpkg-ports/dlpack/vcpkg.json b/cmake/vcpkg-ports/dlpack/vcpkg.json new file mode 100644 index 0000000000000..48f2f22a0a058 --- /dev/null +++ b/cmake/vcpkg-ports/dlpack/vcpkg.json @@ -0,0 +1,17 @@ +{ + "name": "dlpack", + "version-semver": "1.1.1", + "description": "DLPack is an open in-memory tensor structure for sharing tensors among frameworks", + "homepage": "https://github.com/dmlc/dlpack", + "license": "Apache-2.0", + "dependencies": [ + { + "name": "vcpkg-cmake", + "host": true + }, + { + "name": "vcpkg-cmake-config", + "host": true + } + ] +} diff --git a/cmake/winml.cmake b/cmake/winml.cmake index ef635c0c8c794..f2651d0cbc2b2 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -316,7 +316,7 @@ if (onnxruntime_WINML_NAMESPACE_OVERRIDE STREQUAL "Windows") target_compile_definitions(winml_adapter PRIVATE "BUILD_INBOX=1") endif() -# wil requires C++17 +# will requires C++17 set_target_properties(winml_adapter PROPERTIES CXX_STANDARD 17) set_target_properties(winml_adapter PROPERTIES CXX_STANDARD_REQUIRED ON) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs new file mode 100644 index 0000000000000..9f42bf2247529 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Microsoft.ML.OnnxRuntime +{ + using System; + using System.Runtime.InteropServices; + + /// + /// This class is used to set options for model compilation, and to produce a compiled model using those options. + /// See https://onnxruntime.ai/docs/api/c/ for further details of various options. + /// + public class OrtModelCompilationOptions : SafeHandle + { + /// + /// Create a new OrtModelCompilationOptions object from SessionOptions. + /// + /// SessionOptions instance to read settings from. + public OrtModelCompilationOptions(SessionOptions sessionOptions) + : base(IntPtr.Zero, true) + { + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtCreateModelCompilationOptionsFromSessionOptions( + OrtEnv.Instance().Handle, sessionOptions.Handle, out handle)); + } + + /// + /// Compile the model using the options set in this object. + /// + public void CompileModel() + { + NativeApiStatus.VerifySuccess(NativeMethods.CompileApi.OrtCompileModel(OrtEnv.Instance().Handle, handle)); + } + + + /// + /// Set the input model to compile. + /// + /// Path to ONNX model to compile. + public void SetInputModelPath(string path) + { + var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path); + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelPath(handle, platformPath)); + } + + /// + /// Set the input model to compile to be a byte array. + /// The input bytes are NOT copied and must remain valid while in use by ORT. + /// + /// Input model bytes. + public void SetInputModelFromBuffer(byte[] buffer) + { + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelFromBuffer( + handle, buffer, (UIntPtr)buffer.Length)); + } + + /// + /// Set the path to write the compiled ONNX model to. + /// + /// Path to write compiled model to. + public void SetOutputModelPath(string path) + { + var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path); + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelPath(handle, platformPath)); + + } + + /// + /// Set the path to a file to write initializers as external data to, + /// and the threshold that determines when to write an initializer to the external data file. + /// + /// Path to file to write external data to. + /// Size at which an initializer will be written to external data. + public void SetOutputModelExternalInitializersFile(string filePath, ulong threshold) + { + var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(filePath); + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelExternalInitializersFile( + handle, platformPath, new UIntPtr(threshold))); + } + + // TODO: In order to use this to create an InferenceSession without copying bytes we need more infrastructure. + // - Need something that wraps the allocator, pointer and size and is SafeHandle based. + // - When it is disposed we need to use the allocator to release the native buffer. + // - Need the 4 InferenceSession ctors that take byte[] for the model to be duplicated to handle this new + // wrapper type. + // Due to that making this API internal so we can test it. We can make it public when the other infrastructure + // is in place as it will change the signature of the API. + internal void SetOutputModelBuffer(OrtAllocator allocator, + ref IntPtr outputModelBufferPtr, ref UIntPtr outputModelBufferSizePtr) + { + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelBuffer( + handle, allocator.Pointer, ref outputModelBufferPtr, ref outputModelBufferSizePtr)); + } + + /// + /// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute + /// of EPContext nodes. + /// + /// Enable if true. Default is false. + public void SetEpContextEmbedMode(bool embed) + { + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextEmbedMode(handle, embed)); + } + + internal IntPtr Handle => handle; + + + /// + /// Indicates whether the native handle is invalid. + /// + public override bool IsInvalid => handle == IntPtr.Zero; + + /// + /// Release the native instance of OrtModelCompilationOptions. + /// + /// true + protected override bool ReleaseHandle() + { + NativeMethods.CompileApi.OrtReleaseModelCompilationOptions(handle); + handle = IntPtr.Zero; + return true; + } + } +} \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs index b62a3c50bfda6..792f0ddd0f777 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs @@ -1046,7 +1046,7 @@ public ulong ProfilingStartTimeNs } } - private static void OrtCallback(IntPtr userData, IntPtr[] ouputs, uint numOutputs, IntPtr status) + private static void OrtCallback(IntPtr userData, IntPtr[] outputs, uint numOutputs, IntPtr status) { var hostHdl = GCHandle.FromIntPtr(userData); CallbackHost host = (CallbackHost)hostHdl.Target; @@ -1635,7 +1635,7 @@ internal TensorTypeAndShape(TensorElementType elementType, int[] dimensions, str } /// - /// Represents sequnce metdata + /// Represents sequence metadata /// public class SequenceMetadata { @@ -1648,7 +1648,7 @@ internal SequenceMetadata(NodeMetadata elementData) ElementMeta = elementData; } /// - /// Element Metatada, recursive definition with a Tensor being a base case + /// Element Metadata, recursive definition with a Tensor being a base case /// may contain maps, tensors and other sequences /// public NodeMetadata ElementMeta { get; } @@ -1669,7 +1669,7 @@ internal OptionalMetadata(NodeMetadata elementData) } /// - /// Element Metatada, recursive definition with a Tensor being a base case + /// Element Metadata, recursive definition with a Tensor being a base case /// may contain maps, tensors and sequences /// public NodeMetadata ElementMeta { get; } @@ -1876,7 +1876,7 @@ public TensorElementType ElementDataType } /// - /// Convinience method to check for string + /// Convenience method to check for string /// public bool IsString { diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs new file mode 100644 index 0000000000000..3a87f87d124e9 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Microsoft.ML.OnnxRuntime.CompileApi; + +using System; +using System.Runtime.InteropServices; + +// NOTE: The order of the APIs in this struct should match exactly that in OrtCompileApi +// See onnxruntime/core/session/compile_api.cc. +[StructLayout(LayoutKind.Sequential)] +public struct OrtCompileApi +{ + public IntPtr ReleaseModelCompilationOptions; + public IntPtr CreateModelCompilationOptionsFromSessionOptions; + public IntPtr ModelCompilationOptions_SetInputModelPath; + public IntPtr ModelCompilationOptions_SetInputModelFromBuffer; + public IntPtr ModelCompilationOptions_SetOutputModelPath; + public IntPtr ModelCompilationOptions_SetOutputModelExternalInitializersFile; + public IntPtr ModelCompilationOptions_SetOutputModelBuffer; + public IntPtr ModelCompilationOptions_SetEpContextEmbedMode; + public IntPtr CompileModel; +} + +internal class NativeMethods +{ + private static OrtCompileApi _compileApi; + + // + // Define the delegate signatures, and a static member for each to hold the marshaled function pointer. + // + // We populate the static members in the constructor of this class. + // + // The C# code will call the C++ API through the delegate instances in the static members. + // + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtReleaseModelCompilationOptions(IntPtr /* OrtModelCompilationOptions* */ options); + public DOrtReleaseModelCompilationOptions OrtReleaseModelCompilationOptions; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCreateModelCompilationOptionsFromSessionOptions( + IntPtr /* const OrtEnv* */ env, + IntPtr /* const OrtSessionOptions* */ sessionOptions, + out IntPtr /* OrtModelCompilationOptions** */ outOptions); + public DOrtCreateModelCompilationOptionsFromSessionOptions + OrtCreateModelCompilationOptionsFromSessionOptions; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModelPath( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const ORTCHAR_T* */ inputModelPath); + public DOrtModelCompilationOptions_SetInputModelPath OrtModelCompilationOptions_SetInputModelPath; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModelFromBuffer( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const void* */ inputModelData, + UIntPtr /* size_t */ inputModelDataSize); + public DOrtModelCompilationOptions_SetInputModelFromBuffer + OrtModelCompilationOptions_SetInputModelFromBuffer; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelPath( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const ORTCHAR_T* */ outputModelPath); + public DOrtModelCompilationOptions_SetOutputModelPath OrtModelCompilationOptions_SetOutputModelPath; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const ORTCHAR_T* */ externalInitializersFilePath, + UIntPtr /* size_t */ externalInitializerSizeThreshold); + public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile + OrtModelCompilationOptions_SetOutputModelExternalInitializersFile; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelBuffer( + IntPtr /* OrtModelCompilationOptions* */ options, + IntPtr /* OrtAllocator* */ allocator, + ref IntPtr /* void** */ outputModelBufferPtr, + ref UIntPtr /* size_t* */ outputModelBufferSizePtr); + public DOrtModelCompilationOptions_SetOutputModelBuffer OrtModelCompilationOptions_SetOutputModelBuffer; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetEpContextEmbedMode( + IntPtr /* OrtModelCompilationOptions* */ options, + bool embedEpContextInModel); + public DOrtModelCompilationOptions_SetEpContextEmbedMode OrtModelCompilationOptions_SetEpContextEmbedMode; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCompileModel( + IntPtr /* const OrtEnv* */ env, + IntPtr /* const OrtModelCompilationOptions* */ modelOptions); + public DOrtCompileModel OrtCompileModel; + + internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi) + { + +#if NETSTANDARD2_0 + IntPtr compileApiPtr = getCompileApi(); + _compileApi = (OrtCompileApi)Marshal.PtrToStructure(compileApiPtr, typeof(OrtCompileApi)); +#else + _compileApi = (OrtCompileApi)getCompileApi(); +#endif + + OrtReleaseModelCompilationOptions = + (DOrtReleaseModelCompilationOptions)Marshal.GetDelegateForFunctionPointer( + _compileApi.ReleaseModelCompilationOptions, + typeof(DOrtReleaseModelCompilationOptions)); + + OrtCreateModelCompilationOptionsFromSessionOptions = + (DOrtCreateModelCompilationOptionsFromSessionOptions)Marshal.GetDelegateForFunctionPointer( + _compileApi.CreateModelCompilationOptionsFromSessionOptions, + typeof(DOrtCreateModelCompilationOptionsFromSessionOptions)); + + OrtModelCompilationOptions_SetInputModelPath = + (DOrtModelCompilationOptions_SetInputModelPath)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetInputModelPath, + typeof(DOrtModelCompilationOptions_SetInputModelPath)); + + OrtModelCompilationOptions_SetInputModelFromBuffer = + (DOrtModelCompilationOptions_SetInputModelFromBuffer)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetInputModelFromBuffer, + typeof(DOrtModelCompilationOptions_SetInputModelFromBuffer)); + + OrtModelCompilationOptions_SetOutputModelPath = + (DOrtModelCompilationOptions_SetOutputModelPath)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelPath, + typeof(DOrtModelCompilationOptions_SetOutputModelPath)); + + OrtModelCompilationOptions_SetOutputModelExternalInitializersFile = + (DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelExternalInitializersFile, + typeof(DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile)); + + OrtModelCompilationOptions_SetOutputModelBuffer = + (DOrtModelCompilationOptions_SetOutputModelBuffer)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelBuffer, + typeof(DOrtModelCompilationOptions_SetOutputModelBuffer)); + + OrtModelCompilationOptions_SetEpContextEmbedMode = + (DOrtModelCompilationOptions_SetEpContextEmbedMode)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetEpContextEmbedMode, + typeof(DOrtModelCompilationOptions_SetEpContextEmbedMode)); + + OrtCompileModel = + (DOrtCompileModel)Marshal.GetDelegateForFunctionPointer( + _compileApi.CompileModel, + typeof(DOrtCompileModel)); + } +} diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 77c35aac65b92..8cca2b42e987a 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -336,12 +336,46 @@ public struct OrtApi public IntPtr GetModelEditorApi; public IntPtr CreateTensorWithDataAndDeleterAsOrtValue; public IntPtr SessionOptionsSetLoadCancellationFlag; + + public IntPtr GetCompileApi; + + public IntPtr CreateKeyValuePairs; + public IntPtr AddKeyValuePair; + public IntPtr GetKeyValue; + public IntPtr GetKeyValuePairs; + public IntPtr RemoveKeyValuePair; + public IntPtr ReleaseKeyValuePairs; + + public IntPtr RegisterExecutionProviderLibrary; + public IntPtr UnregisterExecutionProviderLibrary; + + public IntPtr GetEpDevices; + + public IntPtr SessionOptionsAppendExecutionProvider_V2; + public IntPtr SessionOptionsSetEpSelectionPolicy; + public IntPtr SessionOptionsSetEpSelectionPolicyDelegate; + + public IntPtr HardwareDevice_Type; + public IntPtr HardwareDevice_VendorId; + public IntPtr HardwareDevice_Vendor; + public IntPtr HardwareDevice_DeviceId; + public IntPtr HardwareDevice_Metadata; + + public IntPtr EpDevice_EpName; + public IntPtr EpDevice_EpVendor; + public IntPtr EpDevice_EpMetadata; + public IntPtr EpDevice_EpOptions; + public IntPtr EpDevice_Device; + public IntPtr GetEpApi; + public IntPtr GetTensorSizeInBytes; } internal static class NativeMethods { static OrtApi api_; + static internal CompileApi.NativeMethods CompileApi; + #if NETSTANDARD2_0 [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr DOrtGetApi(UInt32 version); @@ -375,12 +409,15 @@ static NativeMethods() api_ = (OrtApi)OrtGetApi(ORT_API_VERSION); OrtGetVersionString = (DOrtGetVersionString)Marshal.GetDelegateForFunctionPointer(OrtGetApiBase().GetVersionString, typeof(DOrtGetVersionString)); #endif + OrtCreateStatus = (DOrtCreateStatus)Marshal.GetDelegateForFunctionPointer( + api_.CreateStatus, typeof(DOrtCreateStatus)); OrtCreateEnv = (DOrtCreateEnv)Marshal.GetDelegateForFunctionPointer(api_.CreateEnv, typeof(DOrtCreateEnv)); OrtCreateEnvWithCustomLogger = (DOrtCreateEnvWithCustomLogger)Marshal.GetDelegateForFunctionPointer(api_.CreateEnvWithCustomLogger, typeof(DOrtCreateEnvWithCustomLogger)); OrtCreateEnvWithGlobalThreadPools = (DOrtCreateEnvWithGlobalThreadPools)Marshal.GetDelegateForFunctionPointer(api_.CreateEnvWithGlobalThreadPools, typeof(DOrtCreateEnvWithGlobalThreadPools)); OrtCreateEnvWithCustomLoggerAndGlobalThreadPools = (DOrtCreateEnvWithCustomLoggerAndGlobalThreadPools)Marshal.GetDelegateForFunctionPointer(api_.CreateEnvWithCustomLoggerAndGlobalThreadPools, typeof(DOrtCreateEnvWithCustomLoggerAndGlobalThreadPools)); OrtReleaseEnv = (DOrtReleaseEnv)Marshal.GetDelegateForFunctionPointer(api_.ReleaseEnv, typeof(DOrtReleaseEnv)); + OrtEnableTelemetryEvents = (DOrtEnableTelemetryEvents)Marshal.GetDelegateForFunctionPointer(api_.EnableTelemetryEvents, typeof(DOrtEnableTelemetryEvents)); OrtDisableTelemetryEvents = (DOrtDisableTelemetryEvents)Marshal.GetDelegateForFunctionPointer(api_.DisableTelemetryEvents, typeof(DOrtDisableTelemetryEvents)); @@ -504,6 +541,7 @@ static NativeMethods() OrtValueIsTensor = (DOrtValueIsTensor)Marshal.GetDelegateForFunctionPointer(api_.IsTensor, typeof(DOrtValueIsTensor)); OrtValueIsSparseTensor = (DOrtValueIsSparseTensor)Marshal.GetDelegateForFunctionPointer(api_.IsSparseTensor, typeof(DOrtValueIsSparseTensor)); OrtGetTensorMutableData = (DOrtGetTensorMutableData)Marshal.GetDelegateForFunctionPointer(api_.GetTensorMutableData, typeof(DOrtGetTensorMutableData)); + OrtGetTensorSizeInBytes = (DOrtGetTensorSizeInBytes)Marshal.GetDelegateForFunctionPointer(api_.GetTensorSizeInBytes, typeof(DOrtGetTensorSizeInBytes)); OrtFillStringTensor = (DOrtFillStringTensor)Marshal.GetDelegateForFunctionPointer(api_.FillStringTensor, typeof(DOrtFillStringTensor)); OrtGetResizedStringTensorElementBuffer = (DOrtGetResizedStringTensorElementBuffer)Marshal.GetDelegateForFunctionPointer(api_.GetResizedStringTensorElementBuffer, typeof(DOrtGetResizedStringTensorElementBuffer)); OrtGetStringTensorContent = (DOrtGetStringTensorContent)Marshal.GetDelegateForFunctionPointer(api_.GetStringTensorContent, typeof(DOrtGetStringTensorContent)); @@ -582,6 +620,90 @@ static NativeMethods() typeof(DReleaseLoraAdapter)); OrtRunOptionsAddActiveLoraAdapter = (DOrtRunOptionsAddActiveLoraAdapter)Marshal.GetDelegateForFunctionPointer( api_.RunOptionsAddActiveLoraAdapter, typeof(DOrtRunOptionsAddActiveLoraAdapter)); + + OrtGetCompileApi = (DOrtGetCompileApi)Marshal.GetDelegateForFunctionPointer( + api_.GetCompileApi, typeof(DOrtGetCompileApi)); + + // populate the CompileApi struct now that we have the delegate to get the compile API pointer. + CompileApi = new CompileApi.NativeMethods(OrtGetCompileApi); + + OrtCreateKeyValuePairs = (DOrtCreateKeyValuePairs)Marshal.GetDelegateForFunctionPointer( + api_.CreateKeyValuePairs, typeof(DOrtCreateKeyValuePairs)); + + OrtAddKeyValuePair = (DOrtAddKeyValuePair)Marshal.GetDelegateForFunctionPointer( + api_.AddKeyValuePair, typeof(DOrtAddKeyValuePair)); + + OrtGetKeyValue = (DOrtGetKeyValue)Marshal.GetDelegateForFunctionPointer( + api_.GetKeyValue, typeof(DOrtGetKeyValue)); + + OrtGetKeyValuePairs = (DOrtGetKeyValuePairs)Marshal.GetDelegateForFunctionPointer( + api_.GetKeyValuePairs, typeof(DOrtGetKeyValuePairs)); + + OrtRemoveKeyValuePair = (DOrtRemoveKeyValuePair)Marshal.GetDelegateForFunctionPointer( + api_.RemoveKeyValuePair, typeof(DOrtRemoveKeyValuePair)); + + OrtReleaseKeyValuePairs = (DOrtReleaseKeyValuePairs)Marshal.GetDelegateForFunctionPointer( + api_.ReleaseKeyValuePairs, typeof(DOrtReleaseKeyValuePairs)); + + OrtHardwareDevice_Type = (DOrtHardwareDevice_Type)Marshal.GetDelegateForFunctionPointer( + api_.HardwareDevice_Type, typeof(DOrtHardwareDevice_Type)); + + OrtHardwareDevice_VendorId = (DOrtHardwareDevice_VendorId)Marshal.GetDelegateForFunctionPointer( + api_.HardwareDevice_VendorId, typeof(DOrtHardwareDevice_VendorId)); + + OrtHardwareDevice_Vendor = (DOrtHardwareDevice_Vendor)Marshal.GetDelegateForFunctionPointer( + api_.HardwareDevice_Vendor, typeof(DOrtHardwareDevice_Vendor)); + + OrtHardwareDevice_DeviceId = (DOrtHardwareDevice_DeviceId)Marshal.GetDelegateForFunctionPointer( + api_.HardwareDevice_DeviceId, typeof(DOrtHardwareDevice_DeviceId)); + + OrtHardwareDevice_Metadata = (DOrtHardwareDevice_Metadata)Marshal.GetDelegateForFunctionPointer( + api_.HardwareDevice_Metadata, typeof(DOrtHardwareDevice_Metadata)); + + + OrtEpDevice_EpName = (DOrtEpDevice_EpName)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_EpName, typeof(DOrtEpDevice_EpName)); + + OrtEpDevice_EpVendor = (DOrtEpDevice_EpVendor)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_EpVendor, typeof(DOrtEpDevice_EpVendor)); + + OrtEpDevice_EpMetadata = (DOrtEpDevice_EpMetadata)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_EpMetadata, typeof(DOrtEpDevice_EpMetadata)); + + OrtEpDevice_EpOptions = (DOrtEpDevice_EpOptions)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_EpOptions, typeof(DOrtEpDevice_EpOptions)); + + OrtEpDevice_Device = (DOrtEpDevice_Device)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_Device, typeof(DOrtEpDevice_Device)); + + OrtRegisterExecutionProviderLibrary = + (DOrtRegisterExecutionProviderLibrary)Marshal.GetDelegateForFunctionPointer( + api_.RegisterExecutionProviderLibrary, + typeof(DOrtRegisterExecutionProviderLibrary)); + + OrtUnregisterExecutionProviderLibrary = + (DOrtUnregisterExecutionProviderLibrary)Marshal.GetDelegateForFunctionPointer( + api_.UnregisterExecutionProviderLibrary, + typeof(DOrtUnregisterExecutionProviderLibrary)); + + OrtGetEpDevices = (DOrtGetEpDevices)Marshal.GetDelegateForFunctionPointer( + api_.GetEpDevices, + typeof(DOrtGetEpDevices)); + + OrtSessionOptionsAppendExecutionProvider_V2 = + (DOrtSessionOptionsAppendExecutionProvider_V2)Marshal.GetDelegateForFunctionPointer( + api_.SessionOptionsAppendExecutionProvider_V2, + typeof(DOrtSessionOptionsAppendExecutionProvider_V2)); + + OrtSessionOptionsSetEpSelectionPolicy = + (DSessionOptionsSetEpSelectionPolicy)Marshal.GetDelegateForFunctionPointer( + api_.SessionOptionsSetEpSelectionPolicy, + typeof(DSessionOptionsSetEpSelectionPolicy)); + + OrtSessionOptionsSetEpSelectionPolicyDelegate = + (DSessionOptionsSetEpSelectionPolicyDelegate)Marshal.GetDelegateForFunctionPointer( + api_.SessionOptionsSetEpSelectionPolicyDelegate, + typeof(DSessionOptionsSetEpSelectionPolicyDelegate)); } internal class NativeLib @@ -817,13 +939,18 @@ internal class NativeLib #endregion Status API #region InferenceSession API + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCreateStatus( + uint /* OrtErrorCode */ code, + byte[] /* const char* */ msg); + public static DOrtCreateStatus OrtCreateStatus; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DOrtCreateSession( IntPtr /* (OrtEnv*) */ environment, //[MarshalAs(UnmanagedType.LPStr)]string modelPath byte[] modelPath, - IntPtr /* (OrtSessionOptions*) */ sessopnOptions, + IntPtr /* (OrtSessionOptions*) */ sessionOptions, out IntPtr /**/ session); public static DOrtCreateSession OrtCreateSession; @@ -1332,7 +1459,7 @@ out IntPtr lora_adapter /// bytes /// size in bytes /// optional device allocator - /// resuling LoraAdapter instance + /// resulting LoraAdapter instance /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DCreateLoraAdapterFromArray( @@ -1350,7 +1477,7 @@ out IntPtr lora_adapter #endregion - #region RunOptions API +#region RunOptions API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateRunOptions(out IntPtr /* OrtRunOptions** */ runOptions); @@ -1971,6 +2098,12 @@ out IntPtr lora_adapter public static DOrtGetTensorMutableData OrtGetTensorMutableData; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTensorSizeInBytes(IntPtr /* const struct OrtValue*/ ortValue, + out UIntPtr /* size_t* */ tensorSizeInBytes); + + public static DOrtGetTensorSizeInBytes OrtGetTensorSizeInBytes; + /// \param value A tensor created from OrtCreateTensor... function. /// \param len total data length, not including the trailing '\0' chars. [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -2153,7 +2286,255 @@ out IntPtr lora_adapter #endregion -#region Misc API +#region Compile API + +#if NETSTANDARD2_0 + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr DOrtGetCompileApi(); +#else + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate ref CompileApi.OrtCompileApi DOrtGetCompileApi(); +#endif + public static DOrtGetCompileApi OrtGetCompileApi; +#endregion + +#region Auto EP API related + // + // OrtKeyValuePairs + + /// + /// Create an OrtKeyValuePairs instance. + /// + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtCreateKeyValuePairs(out IntPtr /* OrtKeyValuePairs** */ kvps); + + /// + /// Add/replace a key-value pair in the OrtKeyValuePairs instance. + /// + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtAddKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, + byte[] /* const char* */ key, + byte[] /* const char* */ value); + + /// + /// Get the value for the provided key. + /// + /// Value. Returns IntPtr.Zero if key was not found. + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const char* */ DOrtGetKeyValue(IntPtr /* const OrtKeyValuePairs* */ kvps, + byte[] /* const char* */ key); + + /// + /// Get all the key-value pairs in the OrtKeyValuePairs instance. + /// + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtGetKeyValuePairs(IntPtr /* const OrtKeyValuePairs* */ kvps, + out IntPtr /* const char* const** */ keys, + out IntPtr /* const char* const** */ values, + out UIntPtr /* size_t* */ numEntries); + + /// + /// Remove a key-value pair from the OrtKeyValuePairs instance. + /// Ignores keys that are not present. + /// + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, + byte[] /* const char* */ key); + + /// + /// Release the OrtKeyValuePairs instance. + /// + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtReleaseKeyValuePairs(IntPtr /* OrtKeyValuePairs* */ kvps); + + + public static DOrtCreateKeyValuePairs OrtCreateKeyValuePairs; + public static DOrtAddKeyValuePair OrtAddKeyValuePair; + public static DOrtGetKeyValue OrtGetKeyValue; + public static DOrtGetKeyValuePairs OrtGetKeyValuePairs; + public static DOrtRemoveKeyValuePair OrtRemoveKeyValuePair; + public static DOrtReleaseKeyValuePairs OrtReleaseKeyValuePairs; + + + // + // OrtHardwareDevice + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate int /* OrtHardwareDeviceType */ DOrtHardwareDevice_Type( + IntPtr /* const OrtHardwareDevice* */ device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate uint /* uint32_t */ DOrtHardwareDevice_VendorId( + IntPtr /* const OrtHardwareDevice* */ device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const char* */ DOrtHardwareDevice_Vendor( + IntPtr /* const OrtHardwareDevice* */ device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate uint /* uint32_t */ DOrtHardwareDevice_DeviceId( + IntPtr /* const OrtHardwareDevice* */ device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const OrtKeyValuePairs* */ DOrtHardwareDevice_Metadata( + IntPtr /* const OrtHardwareDevice* */ device); + + + public static DOrtHardwareDevice_Type OrtHardwareDevice_Type; + public static DOrtHardwareDevice_VendorId OrtHardwareDevice_VendorId; + public static DOrtHardwareDevice_Vendor OrtHardwareDevice_Vendor; + public static DOrtHardwareDevice_DeviceId OrtHardwareDevice_DeviceId; + public static DOrtHardwareDevice_Metadata OrtHardwareDevice_Metadata; + + // + // OrtEpDevice + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const char* */ DOrtEpDevice_EpName(IntPtr /* const OrtEpDevice* */ ep_device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const char* */ DOrtEpDevice_EpVendor(IntPtr /* const OrtEpDevice* */ ep_device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const OrtKeyValuePairs* */ DOrtEpDevice_EpMetadata( + IntPtr /* const OrtEpDevice* */ ep_device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const OrtKeyValuePairs* */ DOrtEpDevice_EpOptions( + IntPtr /* const OrtEpDevice* */ ep_device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const OrtHardwareDevice* */ DOrtEpDevice_Device( + IntPtr /* const OrtEpDevice* */ ep_device); + + + public static DOrtEpDevice_EpName OrtEpDevice_EpName; + public static DOrtEpDevice_EpVendor OrtEpDevice_EpVendor; + public static DOrtEpDevice_EpMetadata OrtEpDevice_EpMetadata; + public static DOrtEpDevice_EpOptions OrtEpDevice_EpOptions; + public static DOrtEpDevice_Device OrtEpDevice_Device; + + // + // Auto Selection EP registration and selection customization + + /// + /// Register an execution provider library. + /// The library must implement CreateEpFactories and ReleaseEpFactory. + /// + /// Environment to add the EP library to. + /// Name to register the library under. + /// Absolute path to the library. + /// OrtStatus* + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtRegisterExecutionProviderLibrary( + IntPtr /* OrtEnv* */ env, + byte[] /* const char* */ registration_name, + byte[] /* const ORTCHAR_T* */ path); + + /// + /// Unregister an execution provider library. + /// + /// The environment to unregister the library from. + /// The name the library was registered under. + /// OrtStatus* + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtUnregisterExecutionProviderLibrary( + IntPtr /* OrtEnv* */ env, + byte[] /* const char* */ registration_name); + + public static DOrtRegisterExecutionProviderLibrary OrtRegisterExecutionProviderLibrary; + public static DOrtUnregisterExecutionProviderLibrary OrtUnregisterExecutionProviderLibrary; + + /// + /// Get the OrtEpDevices that are available. + /// These are all the possible execution provider and device pairs. + /// + /// OrtStatus* + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtGetEpDevices( + IntPtr /* const OrtEnv* */ env, + out IntPtr /* const OrtEpDevice* const** */ ep_devices, + out UIntPtr /* size_t* */ num_ep_devices); + + public static DOrtGetEpDevices OrtGetEpDevices; + + /// + /// Add execution provider devices to the session options. + /// Priority is based on the order of the OrtEpDevice instances. Highest priority first. + /// All OrtEpDevice instances in ep_devices must be for the same execution provider. + /// e.g. selecting OpenVINO for GPU and NPU would have an OrtEpDevice for GPU and NPU. + /// + /// SessionOptions to add to. + /// Environment that the OrtEpDevice instances came from by calling GetEpDevices + /// One or more OrtEpDevice instances. + /// Number of OrtEpDevice instances. + /// User overrides for execution provider options. May be IntPtr.Zero. + /// User overrides for execution provider options. May be IntPtr.Zero. + /// Number of user overrides for execution provider options. + /// + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtSessionOptionsAppendExecutionProvider_V2( + IntPtr /* OrtSessionOptions* */ sess_options, + IntPtr /* OrtEnv* */ env, + IntPtr[] /* const OrtEpDevice* const* */ ep_devices, + UIntPtr /* size_t */ num_ep_devices, + IntPtr /* const char* const* */ ep_option_keys, // use OrtKeyValuePairs.GetKeyValuePairHandles + IntPtr /* const char* const* */ ep_option_vals, + UIntPtr /* size_t */ num_ep_options); + + public static DOrtSessionOptionsAppendExecutionProvider_V2 OrtSessionOptionsAppendExecutionProvider_V2; + + /// + /// Delegate to do custom execution provider selection. + /// + /// Available OrtEpDevices to select from. + /// Number of OrtEpDevices. + /// Metadata from the ONNX model. + /// Runtime metadata. May be IntPtr.Zero. + /// OrtEpDevices that were selected. Pre-allocated array for delegate to update. + /// Maximum number of OrtEpDevices that can be selected. + /// Number of OrtEpDevices that were selected. + /// State that was provided in when the delegate was registered. + /// OrtStatus* + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr DOrtEpSelectionDelegate( + IntPtr /* OrtEpDevice** */ epDevices, + uint numDevices, + IntPtr /* OrtKeyValuePairs* */ modelMetadata, + IntPtr /* OrtKeyValuePairs* */ runtimeMetadata, + IntPtr /* OrtEpDevice** */ selected, + uint maxSelected, + out UIntPtr numSelected, + IntPtr /* void* */ state + ); + + /// + /// Set the execution provider selection policy. + /// + /// SessionOptions to set the policy for. + /// Selection policy. + /// OrtStatus* + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DSessionOptionsSetEpSelectionPolicy( + IntPtr /* OrtSessionOptions* */ session_options, + int /* OrtExecutionProviderDevicePolicy */ policy); + public static DSessionOptionsSetEpSelectionPolicy OrtSessionOptionsSetEpSelectionPolicy; + + /// + /// Set the execution provider selection policy delegate. + /// + /// SessionOptions to set the policy for. + /// Selection policy delegate. + /// State that is passed through to the selection delegate. + /// OrtStatus* + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DSessionOptionsSetEpSelectionPolicyDelegate( + IntPtr /* OrtSessionOptions* */ session_options, + IntPtr /* DOrtEpSelectionDelegate* */ selection_delegate, + IntPtr /* void* */ state); + public static DSessionOptionsSetEpSelectionPolicyDelegate OrtSessionOptionsSetEpSelectionPolicyDelegate; + + + #endregion + #region Misc API /// /// Queries all the execution providers supported in the native onnxruntime shared library diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs index f4b2649f8d055..5c70808b82be1 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System; +using System.Collections.Generic; using System.Runtime.InteropServices; namespace Microsoft.ML.OnnxRuntime @@ -376,6 +377,68 @@ public OrtLoggingLevel EnvLogLevel } } + /// + /// Register an execution provider library with the OrtEnv instance. + /// A registered execution provider library can be used by all sessions created with the OrtEnv instance. + /// Devices the execution provider can utilize are added to the values returned by GetEpDevices() and can + /// be used in SessionOptions.AppendExecutionProvider to select an execution provider for a device. + /// + /// Coming: A selection policy can be specified and ORT will automatically select the best execution providers + /// and devices for the model. + /// + /// The name to register the library under. + /// The path to the library to register. + /// + /// + public void RegisterExecutionProviderLibrary(string registrationName, string libraryPath) + { + var registrationNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(registrationName); + var pathUtf8 = NativeOnnxValueHelper.GetPlatformSerializedString(libraryPath); + + NativeApiStatus.VerifySuccess( + NativeMethods.OrtRegisterExecutionProviderLibrary(handle, registrationNameUtf8, pathUtf8)); + } + + /// + /// Unregister an execution provider library from the OrtEnv instance. + /// + /// The name the library was registered under. + public void UnregisterExecutionProviderLibrary(string registrationName) + { + var registrationNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(registrationName); + + NativeApiStatus.VerifySuccess( + NativeMethods.OrtUnregisterExecutionProviderLibrary(handle, registrationNameUtf8)); + } + + /// + /// Get the list of all execution provider and device combinations that are available. + /// These can be used to select the execution provider and device for a session. + /// + /// + /// + /// + public IReadOnlyList GetEpDevices() + { + IntPtr epDevicesPtr; + UIntPtr numEpDevices; + + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetEpDevices(handle, out epDevicesPtr, out numEpDevices)); + + int count = (int)numEpDevices; + var epDevices = new List(count); + + IntPtr[] epDevicePtrs = new IntPtr[count]; + Marshal.Copy(epDevicesPtr, epDevicePtrs, 0, count); + + foreach (var ptr in epDevicePtrs) + { + epDevices.Add(new OrtEpDevice(ptr)); + } + + return epDevices.AsReadOnly(); + } + #endregion #region SafeHandle overrides diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs new file mode 100644 index 0000000000000..0318e08519128 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.ML.OnnxRuntime +{ + /// + /// Represents the combination of an execution provider and a hardware device + /// that the execution provider can utilize. + /// + public class OrtEpDevice + { + /// + /// Construct an OrtEpDevice from an existing native OrtEpDevice instance. + /// + /// Native OrtEpDevice handle. + internal OrtEpDevice(IntPtr epDeviceHandle) + { + _handle = epDeviceHandle; + } + + internal IntPtr Handle => _handle; + + /// + /// The name of the execution provider. + /// + public string EpName + { + get + { + IntPtr namePtr = NativeMethods.OrtEpDevice_EpName(_handle); + return NativeOnnxValueHelper.StringFromNativeUtf8(namePtr); + } + } + + /// + /// The vendor who owns the execution provider. + /// + public string EpVendor + { + get + { + IntPtr vendorPtr = NativeMethods.OrtEpDevice_EpVendor(_handle); + return NativeOnnxValueHelper.StringFromNativeUtf8(vendorPtr); + } + } + + /// + /// Execution provider metadata. + /// + public OrtKeyValuePairs EpMetadata + { + get + { + return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpMetadata(_handle)); + } + } + + /// + /// Execution provider options. + /// + public OrtKeyValuePairs EpOptions + { + get + { + return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpOptions(_handle)); + } + } + + /// + /// The hardware device that the execution provider can utilize. + /// + public OrtHardwareDevice HardwareDevice + { + get + { + IntPtr devicePtr = NativeMethods.OrtEpDevice_Device(_handle); + return new OrtHardwareDevice(devicePtr); + } + } + + private readonly IntPtr _handle; + } +} \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs new file mode 100644 index 0000000000000..af7115a92285e --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + + +namespace Microsoft.ML.OnnxRuntime +{ + using System; + using System.Runtime.InteropServices; + + /// + /// Represents the type of hardware device. + /// Matches OrtHardwareDeviceType in the ORT C API. + /// + public enum OrtHardwareDeviceType + { + CPU = 0, + GPU = 1, + NPU = 2, + } + + /// + /// Represents a hardware device that is available on the current system. + /// + public class OrtHardwareDevice + { + + /// + /// Construct an OrtHardwareDevice for a native OrtHardwareDevice instance. + /// + /// Native OrtHardwareDevice handle. + internal OrtHardwareDevice(IntPtr deviceHandle) + { + _handle = deviceHandle; + } + + /// + /// Get the type of hardware device. + /// + public OrtHardwareDeviceType Type + { + get + { + return (OrtHardwareDeviceType)NativeMethods.OrtHardwareDevice_Type(_handle); + } + } + + /// + /// Get the vendor ID of the hardware device if known. + /// + /// + /// For PCIe devices the vendor ID is the PCIe vendor ID. See https://pcisig.com/membership/member-companies. + /// + public uint VendorId + { + get + { + return NativeMethods.OrtHardwareDevice_VendorId(_handle); + } + } + + /// + /// The vendor (manufacturer) of the hardware device. + /// + public string Vendor + { + get + { + IntPtr vendorPtr = NativeMethods.OrtHardwareDevice_Vendor(_handle); + return NativeOnnxValueHelper.StringFromNativeUtf8(vendorPtr); + } + } + + /// + /// Get the device ID of the hardware device if known. + /// + /// + /// This is the identifier of the device model. + /// PCIe device IDs can be looked up at https://www.pcilookup.com/ when combined with the VendorId. + /// It is NOT a unique identifier for the device in the current system. + /// + public uint DeviceId + { + get + { + return NativeMethods.OrtHardwareDevice_DeviceId(_handle); + } + } + + /// + /// Get device metadata. + /// This may include information such as whether a GPU is discrete or integrated. + /// The available metadata will differ by platform and device type. + /// + public OrtKeyValuePairs Metadata + { + get + { + return new OrtKeyValuePairs(NativeMethods.OrtHardwareDevice_Metadata(_handle)); + } + } + + private readonly IntPtr _handle; + } +} \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs new file mode 100644 index 0000000000000..6a8d1037d9017 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + + +namespace Microsoft.ML.OnnxRuntime +{ + using System; + using System.Collections.Generic; + using System.Runtime.InteropServices; + + /// + /// Class to manage key-value pairs. + /// These are most often used for options and metadata. + /// + /// + /// + /// + public class OrtKeyValuePairs : SafeHandle + { + private readonly bool _createdHandle; + + // cache the values here for convenience. + // we could force a call to the C API every time in case something was changed in the background. + private Dictionary _keyValuePairs; + + /// + /// Create a new OrtKeyValuePairs instance. + /// + /// + /// A backing native instance is created and kept in sync with the C# content. + /// + public OrtKeyValuePairs() + : base(IntPtr.Zero, ownsHandle: true) + { + NativeMethods.OrtCreateKeyValuePairs(out handle); + _createdHandle = true; + _keyValuePairs = new Dictionary(); + } + + /// + /// Create a new OrtKeyValuePairs instance from an existing native OrtKeyValuePairs handle. + /// + /// Native OrtKeyValuePairs handle. + /// + /// The instance is read-only, so calling Add or Remove will throw an InvalidOperationError. + /// + internal OrtKeyValuePairs(IntPtr constHandle) + : base(constHandle, ownsHandle: false) + { + _createdHandle = false; + _keyValuePairs = GetLatest(); + } + + /// + /// Create a new OrtKeyValuePairs instance from a dictionary. + /// + /// Key-value pairs to add. + /// + /// A backing native instance is created and kept in sync with the C# content. + /// + public OrtKeyValuePairs(IReadOnlyDictionary keyValuePairs) + : base(IntPtr.Zero, ownsHandle: true) + { + NativeMethods.OrtCreateKeyValuePairs(out handle); + _createdHandle = true; + _keyValuePairs = new Dictionary(keyValuePairs != null ? keyValuePairs.Count : 0); + + if (keyValuePairs != null && keyValuePairs.Count > 0) + { + foreach (var kvp in keyValuePairs) + { + Add(kvp.Key, kvp.Value); + } + } + } + + /// + /// Current key-value pair entries. + /// + /// + /// Call Refresh() to update the cached values with the latest from the backing native instance. + /// In general that should not be required as it's not expected an OrtKeyValuePairs instance would be + /// updated by both native and C# code. + /// + public IReadOnlyDictionary Entries => _keyValuePairs; + + /// + /// Adds a key-value pair. Overrides any existing value for the key. + /// + /// Key to add. Must not be null or empty. + /// Value to add. May be empty. Must not be null. + public void Add(string key, string value) + { + if (!_createdHandle) + { + throw new InvalidOperationException($"{nameof(Add)} can only be called on instances you created."); + } + + var keyPtr = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(key); + var valuePtr = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(value); + NativeMethods.OrtAddKeyValuePair(handle, keyPtr, valuePtr); + _keyValuePairs[key] = value; // update the cached value + } + + /// + /// Update the cached values with the latest from the backing native instance as that is the source of truth. + /// + public void Refresh() + { + // refresh the cached values. + _keyValuePairs = GetLatest(); + } + + /// + /// Removes a key-value pair by key. Ignores keys that do not exist. + /// + /// Key to remove. + public void Remove(string key) + { + if (!_createdHandle) + { + throw new InvalidOperationException($"{nameof(Remove)} can only be called on instances you created."); + } + + var keyPtr = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(key); + NativeMethods.OrtRemoveKeyValuePair(handle, keyPtr); + + _keyValuePairs.Remove(key); // update the cached value + } + + // for internal usage to pass into the call to OrtSessionOptionsAppendExecutionProvider_V2 + // from SessionOptions::AppendExecutionProvider + internal void GetKeyValuePairHandles(out IntPtr keysHandle, out IntPtr valuesHandle, out UIntPtr numEntries) + { + if (IsInvalid) + { + throw new InvalidOperationException($"{nameof(GetKeyValuePairHandles)}: Invalid instance."); + } + + NativeMethods.OrtGetKeyValuePairs(handle, out keysHandle, out valuesHandle, out numEntries); + } + + /// + /// Fetch all the key/value pairs to make sure we are in sync with the C API. + /// + private Dictionary GetLatest() + { + var dict = new Dictionary(); + if (IsInvalid) + { + return dict; + } + + IntPtr keys, values; + UIntPtr numEntries; + NativeMethods.OrtGetKeyValuePairs(handle, out keys, out values, out numEntries); + + ulong count = numEntries.ToUInt64(); + int offset = 0; + for (ulong i = 0; i < count; i++, offset += IntPtr.Size) + { + IntPtr keyPtr = Marshal.ReadIntPtr(keys, offset); + IntPtr valuePtr = Marshal.ReadIntPtr(values, offset); + var key = NativeOnnxValueHelper.StringFromNativeUtf8(keyPtr); + var value = NativeOnnxValueHelper.StringFromNativeUtf8(valuePtr); + dict.Add(key, value); + } + + return dict; + } + + /// + /// Indicates whether the native handle is invalid. + /// + public override bool IsInvalid { get { return handle == IntPtr.Zero; } } + + /// + /// Release the native instance of OrtKeyValuePairs if we own it. + /// + /// true + protected override bool ReleaseHandle() + { + if (_createdHandle) + { + NativeMethods.OrtReleaseKeyValuePairs(handle); + handle = IntPtr.Zero; + } + + return true; + } + } +} \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index 7a5c3aaa19eac..01ee3aa5ae753 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -312,6 +312,17 @@ public SystemNumericsTensors.TensorSpan GetTensorSpanMutableRawData() w } #endif + /// + /// This API computes and returns the size of the tensor data in bytes. + /// + /// size of the tensor data in bytes + public long GetTensorSizeInBytes() + { + // The native API verifies that this is a non-string tensor + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorSizeInBytes(Handle, out UIntPtr size)); + return (long)size; + } + /// /// Fetch string tensor element buffer pointer at the specified index, /// convert/copy to UTF-16 char[] and return a ReadOnlyMemory{char} instance. @@ -689,8 +700,8 @@ public static OrtValue CreateTensorValueFromMemory(T[] data, long[] shape) wh /// The method will attempt to pin managed memory so no copying occurs when data is passed down /// to native code. /// - /// Tensor object - /// discovered tensor element type + /// + /// Tensor object /// And instance of OrtValue constructed on top of the object [Experimental("SYSLIB5001")] public static OrtValue CreateTensorValueFromSystemNumericsTensorObject(SystemNumericsTensors.Tensor tensor) where T : unmanaged @@ -1357,7 +1368,7 @@ public static OrtValue CreateMapWithStringValues(K[] keys, IReadOnlyCollectio /// This API helps the user to process a map OrtValue without /// having to deal with the lifespan of intermediate OrtValues. /// - /// each API value is fed to the vistor functor. + /// each API value is fed to the visitor functor. /// /// visitor function /// Allocator to use for intermediate values diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 9b0f183f03681..9794d2c184d5d 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -32,6 +32,21 @@ public enum ExecutionMode ORT_PARALLEL = 1, } + /// + /// Controls the execution provider selection when using automatic EP selection. + /// Execution providers must be registered with the OrtEnv to be available for selection. + /// + public enum ExecutionProviderDevicePolicy + { + DEFAULT = 0, + PREFER_CPU = 1, + PREFER_NPU, + PREFER_GPU, + MAX_PERFORMANCE, + MAX_EFFICIENCY, + MIN_OVERALL_POWER, + } + /// /// Holds the options for creating an InferenceSession /// It forces the instantiation of the OrtEnv singleton. @@ -408,6 +423,82 @@ public void AppendExecutionProvider(string providerName, Dictionary + /// Select execution providers from the list of available execution providers and devices returned by + /// GetEpDevices. + /// + /// One or more OrtEpDevice instances may be provided in epDevices, but must all be for the same + /// execution provider. + /// + /// Make multiple calls to AppendExecutionProvider if you wish to use multiple execution providers. + /// + /// e.g. + /// - if execution provider 'A' has an OrtEpDevice for NPU and one for GPU and you wish to use it for + /// both devices, pass the two OrtEpDevice instances in the epDevices list in one call. + /// - if you wish to use execution provider 'B' for GPU and execution provider 'C' for CPU, + /// make two calls to AppendExecutionProvider, with one OrtEpDevice in the epDevices list in each call. + /// + /// The priority of the execution providers is set by the order in which they are appended. + /// Highest priority is first. + /// + /// OrtEnv that provided the OrtEpDevice instances via a call to GetEpDevices. + /// One or more OrtEpDevice instances to append. + /// These must all have the save EpName value. + /// Optional options to configure the execution provider. May be null. + /// epDevices was empty. + /// + public void AppendExecutionProvider(OrtEnv env, IReadOnlyList epDevices, + IReadOnlyDictionary epOptions) + { + if (epDevices == null || epDevices.Count == 0) + { + throw new ArgumentException("No execution provider devices were specified."); + } + + // Convert EpDevices to native pointers + IntPtr[] epDevicePtrs = new IntPtr[epDevices.Count]; + for (int i = 0; i < epDevices.Count; i++) + { + epDevicePtrs[i] = epDevices[i].Handle; + } + + if (epOptions != null && epOptions.Count > 0) + { + // this creates an OrtKeyValuePairs instance with a backing native instance + using var kvps = new OrtKeyValuePairs(epOptions); + + // get the native key/value handles so we can pass those straight through to the C API + // and not have to do any special marshaling here. + IntPtr epOptionsKeys, epOptionsValues; + UIntPtr epOptionsCount; + kvps.GetKeyValuePairHandles(out epOptionsKeys, out epOptionsValues, out epOptionsCount); + + NativeApiStatus.VerifySuccess( + NativeMethods.OrtSessionOptionsAppendExecutionProvider_V2( + handle, + env.Handle, + epDevicePtrs, + (UIntPtr)epDevices.Count, + epOptionsKeys, + epOptionsValues, + epOptionsCount)); + } + else + { + NativeApiStatus.VerifySuccess( + NativeMethods.OrtSessionOptionsAppendExecutionProvider_V2( + handle, + env.Handle, + epDevicePtrs, + (UIntPtr)epDevices.Count, + IntPtr.Zero, // EP options keys + IntPtr.Zero, // EP options values + UIntPtr.Zero)); // EP options count + } + + } + #endregion //ExecutionProviderAppends #region Public Methods @@ -452,8 +543,8 @@ public void RegisterCustomOpLibraryV2(string libraryPath, out IntPtr libraryHand // End result of that is // SessionOptions.RegisterCustomOpLibrary calls NativeMethods.OrtRegisterCustomOpsLibrary_V2 // SessionOptions.RegisterCustomOpLibraryV2 calls NativeMethods.OrtRegisterCustomOpsLibrary - var utf8Path = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath); - NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, utf8Path, + var platformPath = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath); + NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, platformPath, out libraryHandle)); } @@ -536,6 +627,41 @@ public void AddFreeDimensionOverrideByName(string dimName, long dimValue) var utf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(dimName); NativeApiStatus.VerifySuccess(NativeMethods.OrtAddFreeDimensionOverrideByName(handle, utf8, dimValue)); } + + /// + /// Set the execution provider selection policy if using automatic execution provider selection. + /// Execution providers must be registered with the OrtEnv to be available for selection. + /// + /// Policy to use. + public void SetEpSelectionPolicy(ExecutionProviderDevicePolicy policy) + { + NativeApiStatus.VerifySuccess( + NativeMethods.OrtSessionOptionsSetEpSelectionPolicy(handle, (int)policy)); + } + + /// + /// Set the execution provider selection policy if using automatic execution provider selection. + /// Execution providers must be registered with the OrtEnv to be available for selection. + /// + /// Delegate that implements the custom selection policy. + public void SetEpSelectionPolicyDelegate(EpSelectionDelegate selectionDelegate = null) + { + _epSelectionPolicyConnector = new EpSelectionPolicyConnector(selectionDelegate); + _epSelectionPolicyDelegate = new NativeMethods.DOrtEpSelectionDelegate( + EpSelectionPolicyConnector.EpSelectionPolicyWrapper); + + // make sure these stay alive. not sure if this is necessary when they're class members though + _epSelectionPolicyConnectorHandle = GCHandle.Alloc(_epSelectionPolicyConnector); + _epSelectionPolicyDelegateHandle = GCHandle.Alloc(_epSelectionPolicyDelegate); + + IntPtr funcPtr = Marshal.GetFunctionPointerForDelegate(_epSelectionPolicyDelegate); + + NativeApiStatus.VerifySuccess( + NativeMethods.OrtSessionOptionsSetEpSelectionPolicyDelegate( + handle, + funcPtr, + GCHandle.ToIntPtr(_epSelectionPolicyConnectorHandle))); + } #endregion internal IntPtr Handle @@ -811,7 +937,120 @@ public void SetLoadCancellationFlag(bool value) { NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsSetLoadCancellationFlag(handle, value)); } + #endregion + + #region Selection Policy Delegate helpers + /// + /// Delegate to select execution provider devices from a list of available devices. + /// + /// OrtEpDevices to select from. + /// Model metadata. + /// Runtime metadata. + /// Maximum number of devices that can be selected. + /// Selected devices. Ordered by priority. Highest priority first. + public delegate List EpSelectionDelegate(IReadOnlyList epDevices, + OrtKeyValuePairs modelMetadata, + OrtKeyValuePairs runtimeMetadata, + uint maxSelections); + + /// + /// Class to bridge the C# and native worlds for the EP selection policy delegate + /// + internal class EpSelectionPolicyConnector + { + private readonly EpSelectionDelegate _csharpDelegate; + + internal EpSelectionPolicyConnector(EpSelectionDelegate selectionDelegate) + { + _csharpDelegate = selectionDelegate; + } + /// + /// Delegate to convert between the C and C# worlds + /// + /// OrtEpDevices to select from. + /// Number of OrtEpDevices. + /// Model metadata. + /// Runtime metadata. + /// Pre-allocated OrtEpDevice buffer to update with selected devices. + /// Number of entries in selectedOut. + /// Number of OrtEpDevies that were selected. + /// Opaque state. + /// nullptr for OrtStatus* to indicate success. + /// Currently we don't have a way to create an OrtStatus instance from the C# bindings. + /// Can add if we need to return an explicit error message. + /// + public static IntPtr EpSelectionPolicyWrapper(IntPtr /* OrtEpDevice** */ epDevicesIn, + uint numDevices, + IntPtr /* OrtKeyValuePairs* */ modelMetadataIn, + IntPtr /* OrtKeyValuePairs* */ runtimeMetadataIn, + IntPtr /* OrtEpDevice** */ selectedOut, + uint maxSelected, + out UIntPtr numSelected, + IntPtr state) + { + numSelected = UIntPtr.Zero; + + try + { + + Span epDevicesIntPtrs; + Span selectedDevicesIntPtrs; + EpSelectionPolicyConnector connector = (EpSelectionPolicyConnector)GCHandle.FromIntPtr(state).Target; + + unsafe + { + void* ptr = epDevicesIn.ToPointer(); + epDevicesIntPtrs = new Span(ptr, checked((int)numDevices)); + } + + List epDevices = new List(); + for (int i = 0; i < numDevices; i++) + { + + epDevices.Add(new OrtEpDevice(epDevicesIntPtrs[i])); + } + + OrtKeyValuePairs modelMetadata = new OrtKeyValuePairs(modelMetadataIn); + OrtKeyValuePairs runtimeMetadata = new OrtKeyValuePairs(runtimeMetadataIn); + + var selected = connector._csharpDelegate(epDevices, modelMetadata, runtimeMetadata, maxSelected); + + if (selected.Count > maxSelected) + { + var error = $"The number of selected devices ({selected.Count}) returned by " + + $"the C# selection delegate exceeds the maximum ({maxSelected})."; + IntPtr status = NativeMethods.OrtCreateStatus((uint)ErrorCode.Fail, + NativeOnnxValueHelper.StringToZeroTerminatedUtf8(error)); + return status; + } + + numSelected = (UIntPtr)selected.Count; + + unsafe + { + void* ptr = selectedOut.ToPointer(); + selectedDevicesIntPtrs = new Span(ptr, (int)maxSelected); + } + + int idx = 0; + foreach (var epDevice in selected) + { + selectedDevicesIntPtrs[idx] = epDevice.Handle; + idx++; + } + } + catch (Exception ex) + { + var error = $"The C# selection delegate threw an exception: {ex.Message}"; + IntPtr status = NativeMethods.OrtCreateStatus((uint)ErrorCode.Fail, + NativeOnnxValueHelper.StringToZeroTerminatedUtf8(error)); + return status; + } + + return IntPtr.Zero; + } + } #endregion #region Private Methods @@ -897,8 +1136,43 @@ protected override bool ReleaseHandle() { NativeMethods.OrtReleaseSessionOptions(handle); handle = IntPtr.Zero; + + if (_epSelectionPolicyConnectorHandle.IsAllocated) + { + _epSelectionPolicyConnectorHandle.Free(); + _epSelectionPolicyConnector = null; + } + + if (_epSelectionPolicyDelegateHandle.IsAllocated) + { + _epSelectionPolicyDelegateHandle.Free(); + _epSelectionPolicyDelegate = null; + } + + return true; } #endregion + + /// + /// Helper class to connect C and C# usage of the EP selection policy delegate. + /// + EpSelectionPolicyConnector _epSelectionPolicyConnector = null; + + /// + /// Handle to the EP selection policy connector that is passed to the C API as state for the + /// EP selection policy delegate. + /// + GCHandle _epSelectionPolicyConnectorHandle = default; + + /// + /// Delegate instance that is provided to the C API. + /// + NativeMethods.DOrtEpSelectionDelegate _epSelectionPolicyDelegate = null; + + /// + /// Handle to the EP selection policy delegate that is passed to the C API. + /// + GCHandle _epSelectionPolicyDelegateHandle = default; } } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/ArrayUtilities.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/ArrayUtilities.shared.cs index 025c1331ce54d..a0c1a0a150365 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/ArrayUtilities.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/ArrayUtilities.shared.cs @@ -179,7 +179,7 @@ public static void GetIndices(ReadOnlySpan strides, bool reverseStride, int } /// - /// Calculates the n-d indices from the 1-d index in a layout specificed by strides + /// Calculates the n-d indices from the 1-d index in a layout specified by strides /// /// /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs index e927b8105c6c9..9de80a15942e4 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Tensors/Tensor.shared.cs @@ -68,7 +68,7 @@ public class TensorTypeInfo /// Ctor /// /// TensorElementType value - /// size fo the type in bytes + /// size of the type in bytes public TensorTypeInfo(TensorElementType elementType, int typeSize) { ElementType = elementType; @@ -674,7 +674,7 @@ public Tensor GetDiagonal(int offset) // the diagonal will be the length of the smaller axis // if offset it positive, the length will shift along the second axis - // if the offsett is negative, the length will shift along the first axis + // if the offset is negative, the length will shift along the first axis // In that way the length of the diagonal will be // Min(offset < 0 ? axisLength0 + offset : axisLength0, offset > 0 ? axisLength1 - offset : axisLength1) // To illustrate, consider the following @@ -907,21 +907,21 @@ public virtual T this[ReadOnlySpan indices] } /// - /// Gets the value at the specied index, where index is a linearized version of n-dimension indices using strides. + /// Gets the value at the specified index, where index is a linearized version of n-dimension indices using strides. /// /// An integer index computed as a dot-product of indices. /// The value at the specified position in this Tensor. public abstract T GetValue(int index); /// - /// Sets the value at the specied index, where index is a linearized version of n-dimension indices using strides. + /// Sets the value at the specified index, where index is a linearized version of n-dimension indices using strides. /// /// An integer index computed as a dot-product of indices. /// The new value to set at the specified position in this Tensor. public abstract void SetValue(int index, T value); - #region statics + #region statistics /// /// Performs a value comparison of the content and shape of two tensors. Two tensors are equal if they have the same shape and same value at every set of indices. If not equal a tensor is greater or less than another tensor based on the first non-equal element when enumerating in linear order. /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props_qnn.xml b/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props_qnn.xml index fa0e957418fab..83ffb22ccf6b2 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props_qnn.xml +++ b/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props_qnn.xml @@ -12,7 +12,7 @@ - $(MSBuildThisFileDirectory)../../runtimes/win-arm64x/native/onnxruntime.lib;%(AdditionalDependencies) + $(MSBuildThisFileDirectory)../../runtimes/win-arm64/native/onnxruntime.lib;%(AdditionalDependencies) @@ -24,7 +24,7 @@ - $(MSBuildThisFileDirectory)../../runtimes/win-arm64x/native/onnxruntime.lib;%(AdditionalDependencies) + $(MSBuildThisFileDirectory)../../runtimes/win-arm64/native/onnxruntime.lib;%(AdditionalDependencies) @@ -36,7 +36,7 @@ x86 - arm64x + arm64 arm $(Platform) @@ -47,31 +47,31 @@ - + Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-arm64\native\onnxruntime.dll')"> onnxruntime.dll PreserveNewest false - + Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-arm64\native\onnxruntime_providers_shared.dll')"> onnxruntime_providers_shared.dll PreserveNewest false - onnxruntime.dll PreserveNewest false - + Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-arm64\native\onnxruntime_providers_shared.dll')"> onnxruntime_providers_shared.dll PreserveNewest false diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs new file mode 100644 index 0000000000000..72c165df56418 --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// not supported on mobile platforms +#if !(ANDROID || IOS) + +namespace Microsoft.ML.OnnxRuntime.Tests; + +using System; +using System.Globalization; +using System.Runtime.InteropServices; +using Xunit; + + +public class CompileApiTests +{ + private OrtEnv ortEnvInstance = OrtEnv.Instance(); + + + [Fact] + public void BasicUsage() + { + var so = new SessionOptions(); + using (var compileOptions = new OrtModelCompilationOptions(so)) + { + // mainly checking these don't throw which ensures all the plumbing for the binding works. + compileOptions.SetInputModelPath("model.onnx"); + compileOptions.SetOutputModelPath("compiled_model.onnx"); + + compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512); + compileOptions.SetEpContextEmbedMode(true); + + } + + // setup a new instance as SetOutputModelExternalInitializersFile is incompatible with SetOutputModelBuffer + using (var compileOptions = new OrtModelCompilationOptions(so)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + compileOptions.SetInputModelFromBuffer(model); + + // SetOutputModelBuffer updates the user provided IntPtr and size when it allocates data post-compile. + // Due to that we need to allocate an IntPtr and UIntPtr here. + IntPtr bytePtr = new IntPtr(); + UIntPtr bytesSize = new UIntPtr(); + var allocator = OrtAllocator.DefaultInstance; + compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize); + + compileOptions.CompileModel(); + + Assert.NotEqual(IntPtr.Zero, bytePtr); + Assert.NotEqual(UIntPtr.Zero, bytesSize); + + byte[] compiledBytes = new byte[bytesSize.ToUInt64()]; + Marshal.Copy(bytePtr, compiledBytes, 0, (int)bytesSize.ToUInt32()); + + // Check the compiled model is valid + using (var session = new InferenceSession(compiledBytes, so)) + { + Assert.NotNull(session); + } + + allocator.FreeMemory(bytePtr); + } + } +} + +#endif \ No newline at end of file diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxMl.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxMl.cs index 510097e2f4c29..f9e9f42f8b78e 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxMl.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OnnxMl.cs @@ -209,7 +209,7 @@ public enum Version { [pbr::OriginalName("IR_VERSION_2019_9_19")] IrVersion2019919 = 6, /// /// IR VERSION 7 published on May 8, 2020 - /// - Add support to allow function body graph to rely on multiple external opreator sets. + /// - Add support to allow function body graph to rely on multiple external operator sets. /// - Add a list to promote inference graph's initializers to global and /// mutable variables. Global variables are visible in all graphs of the /// stored models. @@ -2481,7 +2481,7 @@ public string DocString { /// /// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain". /// In case of any conflicts the behavior (whether the model local functions are given higher priority, - /// or standard opserator sets are given higher priotity or this is treated as error) is defined by + /// or standard opserator sets are given higher priority or this is treated as error) is defined by /// the runtimes. /// /// The operator sets imported by FunctionProto should be compatible with the ones diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs new file mode 100644 index 0000000000000..9368f9d8bc298 --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs @@ -0,0 +1,286 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// not supported on mobile platforms +#if !(ANDROID || IOS) + +namespace Microsoft.ML.OnnxRuntime.Tests; + +using System; +using System.Linq; +using System.IO; +using System.Runtime.InteropServices; +using Xunit; +using System.Collections.Generic; + +/// +/// Tests for auto ep selection/registration. +/// Includes testing of OrtHardwareDevice and OrtEpDevice as those only come from auto ep related code and we only +/// get read-only access to them (i.e. we can't directly create instances of them to test). +/// +public class OrtAutoEpTests +{ + private OrtEnv ortEnvInstance = OrtEnv.Instance(); + + private void ReadHardwareDeviceValues(OrtHardwareDevice device) + { + Assert.True(device.Type == OrtHardwareDeviceType.CPU || + device.Type == OrtHardwareDeviceType.GPU || + device.Type == OrtHardwareDeviceType.NPU); + if (device.Type == OrtHardwareDeviceType.CPU) + { + Assert.NotEmpty(device.Vendor); + } + else + { + Assert.True(device.VendorId != 0); + Assert.True(device.DeviceId != 0); + } + + var metadata = device.Metadata; + Assert.NotNull(metadata); + foreach (var kvp in metadata.Entries) + { + Assert.NotEmpty(kvp.Key); + // Assert.NotEmpty(kvp.Value); this is allowed + } + } + + [Fact] + public void GetEpDevices() + { + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotNull(epDevices); + Assert.NotEmpty(epDevices); + foreach (var ep_device in epDevices) + { + Assert.NotEmpty(ep_device.EpName); + Assert.NotEmpty(ep_device.EpVendor); + var metadata = ep_device.EpMetadata; + Assert.NotNull(metadata); + var options = ep_device.EpOptions; + Assert.NotNull(options); + ReadHardwareDeviceValues(ep_device.HardwareDevice); + } + } + + [Fact] + public void RegisterUnregisterLibrary() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), "example_plugin_ep.dll"); + Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist."); + + // example plugin ep uses the registration name as the ep name + const string epName = "csharp_ep"; + + // register. shouldn't throw + ortEnvInstance.RegisterExecutionProviderLibrary(epName, libFullPath); + + // check OrtEpDevice was found + var epDevices = ortEnvInstance.GetEpDevices(); + var found = epDevices.Any(d => string.Equals(epName, d.EpName, StringComparison.OrdinalIgnoreCase)); + Assert.True(found); + + // unregister + ortEnvInstance.UnregisterExecutionProviderLibrary(epName); + } + } + + [Fact] + public void AppendToSessionOptionsV2() + { + var runTest = (Func> getEpOptions) => + { + using SessionOptions sessionOptions = new SessionOptions(); + sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; + + var epDevices = ortEnvInstance.GetEpDevices(); + + // cpu ep ignores the provider options so we can use any value in epOptions and it won't break. + List selectedEpDevices = epDevices.Where(d => d.EpName == "CPUExecutionProvider").ToList(); + + Dictionary epOptions = getEpOptions(); + sessionOptions.AppendExecutionProvider(ortEnvInstance, selectedEpDevices, epOptions); + + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + // session should load successfully + using (var session = new InferenceSession(model)) + { + Assert.NotNull(session); + } + }; + + runTest(() => + { + // null options + return null; + }); + + runTest(() => + { + // empty options + return new Dictionary(); + }); + + runTest(() => + { + // dummy options + return new Dictionary + { + { "random_key", "value" }, + }; + }); + } + + [Fact] + public void SetEpSelectionPolicy() + { + using SessionOptions sessionOptions = new SessionOptions(); + sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; + + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotEmpty(epDevices); + + // doesn't matter what the value is. should fallback to ORT CPU EP + sessionOptions.SetEpSelectionPolicy(ExecutionProviderDevicePolicy.PREFER_GPU); + + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + // session should load successfully + using (var session = new InferenceSession(model, sessionOptions)) + { + Assert.NotNull(session); + } + } + + private static List SelectionPolicyDelegate(IReadOnlyList epDevices, + OrtKeyValuePairs modelMetadata, + OrtKeyValuePairs runtimeMetadata, + uint maxSelections) + { + Assert.NotEmpty(modelMetadata.Entries); + Assert.True(epDevices.Count > 0); + + // select first device and last (if there are more than one). + var selected = new List(); + + selected.Add(epDevices[0]); + + // add ORT CPU EP which is always last. + if (maxSelections > 2 && epDevices.Count > 1) + { + selected.Add(epDevices.Last()); + } + + return selected; + } + + [Fact] + public void SetEpSelectionPolicyDelegate() + { + using SessionOptions sessionOptions = new SessionOptions(); + sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; + + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotEmpty(epDevices); + + // doesn't matter what the value is. should fallback to ORT CPU EP + sessionOptions.SetEpSelectionPolicyDelegate(SelectionPolicyDelegate); + + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + // session should load successfully + using (var session = new InferenceSession(model, sessionOptions)) + { + Assert.NotNull(session); + } + } + + // select max + 1, starting with all devices + private static List SelectionPolicyDelegateTooMany(IReadOnlyList epDevices, + OrtKeyValuePairs modelMetadata, + OrtKeyValuePairs runtimeMetadata, + uint maxSelections) + { + Assert.NotEmpty(modelMetadata.Entries); + Assert.True(epDevices.Count > 0); + var selected = new List(epDevices); + + while (selected.Count < (maxSelections + 1)) + { + selected.Add(epDevices.Last()); + } + + return selected; + } + + [Fact] + public void SetEpSelectionPolicyDelegateTooMany() + { + using SessionOptions sessionOptions = new SessionOptions(); + sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; + + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotEmpty(epDevices); + + // select too many devices + sessionOptions.SetEpSelectionPolicyDelegate(SelectionPolicyDelegateTooMany); + + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + // session should fail + try + { + using var session = new InferenceSession(model, sessionOptions); + Assert.Fail("Should have thrown an exception"); + } + catch (OnnxRuntimeException ex) + { + // Current C++ max is 8. We copy all devices and keep adding until we exceed that. + const int max = 8; + var numSelected = epDevices.Count > max ? epDevices.Count : (max + 1); + var expected = "[ErrorCode:Fail] EP selection delegate failed: The number of selected devices " + + $"({numSelected}) returned by the C# selection delegate exceeds the maximum ({max})"; + Assert.Contains(expected, ex.Message); + } + } + + // throw exception in user provided delegate + private static List SelectionPolicyDelegateThrows(IReadOnlyList epDevices, + OrtKeyValuePairs modelMetadata, + OrtKeyValuePairs runtimeMetadata, + uint maxSelections) + { + throw new ArgumentException("Test exception"); + } + + [Fact] + public void SetEpSelectionPolicyDelegateThrows() + { + using SessionOptions sessionOptions = new SessionOptions(); + sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; + + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotEmpty(epDevices); + + sessionOptions.SetEpSelectionPolicyDelegate(SelectionPolicyDelegateThrows); + + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + try + { + using var session = new InferenceSession(model, sessionOptions); + Assert.Fail("Should have thrown an exception"); + } + catch (OnnxRuntimeException ex) + { + var expected = "[ErrorCode:Fail] EP selection delegate failed: " + + "The C# selection delegate threw an exception: Test exception"; + Assert.Contains(expected, ex.Message); + } + } +} +#endif \ No newline at end of file diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtKeyValuePairsTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtKeyValuePairsTests.cs new file mode 100644 index 0000000000000..b89b970688d5f --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtKeyValuePairsTests.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using Xunit; + +namespace Microsoft.ML.OnnxRuntime.Tests; + +public class OrtKeyValuePairsTests +{ + private OrtEnv ortEnvInstance = OrtEnv.Instance(); + + + [Fact] + public void CRUD() + { + using var kvp = new OrtKeyValuePairs(); + kvp.Add("key1", "value1"); + kvp.Add("key2", "value2"); + kvp.Add("key3", ""); // allowed + + Assert.Equal("value1", kvp.Entries["key1"]); + Assert.Equal("value2", kvp.Entries["key2"]); + Assert.Equal("", kvp.Entries["key3"]); + + kvp.Remove("key1"); + Assert.False(kvp.Entries.ContainsKey("key1")); + + kvp.Remove("invalid_key"); // shouldn't break + + Assert.Equal(2, kvp.Entries.Count); + + // refresh from the C API to make sure everything is in sync + kvp.Refresh(); + Assert.Equal(2, kvp.Entries.Count); + Assert.Equal("value2", kvp.Entries["key2"]); + Assert.Equal("", kvp.Entries["key3"]); + } +} diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtValueTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtValueTests.cs index 69baa3f58b23a..eb2e4ad7fb5bb 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtValueTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtValueTests.cs @@ -153,6 +153,8 @@ static void VerifyTensorCreateWithData(OrtValue tensor, TensorElementType dat // Verify contained data Assert.Equal(originalData.ToArray(), tensor.GetTensorDataAsSpan().ToArray()); + var byteData = tensor.GetTensorMutableRawData(); + Assert.Equal(byteData.Length, tensor.GetTensorSizeInBytes()); } [Fact(DisplayName = "CreateTensorOverManagedBuffer")] @@ -278,7 +280,8 @@ public void CreateMapFromValues() // Must return always 2 for map since we have two ort values Assert.Equal(2, map.GetValueCount()); - map.ProcessMap((keys, values) => { + map.ProcessMap((keys, values) => + { Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, keys.OnnxType); Assert.Equal(OnnxValueType.ONNX_TYPE_TENSOR, values.OnnxType); Assert.Equal(ml_data_1, keys.GetTensorDataAsSpan().ToArray()); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs index 9e5e2b6203790..2efbd55f9f350 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TestDataLoader.cs @@ -550,7 +550,7 @@ internal static OrtValue CreateOrtValueFromRawData(OrtAllocator allocator, ReadO var ortValue = OrtValue.CreateAllocatedTensorValue(allocator, elementType, shape); try { - // The endianess data in protobuf is little endian. + // The endianness data in protobuf is little endian. // We simply copy raw memory into the tensor raw data. var span = ortValue.GetTensorMutableRawData(); Assert.Equal(rawData.Length, span.Length); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj index a8abcd2b4aa1c..ee3c8c69aa2ae 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj @@ -70,7 +70,8 @@ + $(NativeBuildOutputDir)\custom_op_library*.dll; + $(NativeBuildOutputDir)\example_plugin_ep.dll"> PreserveNewest false diff --git a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/OnnxMl.cs b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/OnnxMl.cs index 9ee5e44a356da..6dd94750afd0a 100644 --- a/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/OnnxMl.cs +++ b/csharp/tools/Microsoft.ML.OnnxRuntime.PerfTool/OnnxMl.cs @@ -209,7 +209,7 @@ public enum Version { [pbr::OriginalName("IR_VERSION_2019_9_19")] IrVersion2019919 = 6, /// /// IR VERSION 7 published on May 8, 2020 - /// - Add support to allow function body graph to rely on multiple external opreator sets. + /// - Add support to allow function body graph to rely on multiple external operator sets. /// - Add a list to promote inference graph's initializers to global and /// mutable variables. Global variables are visible in all graphs of the /// stored models. @@ -2481,7 +2481,7 @@ public string DocString { /// /// Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain". /// In case of any conflicts the behavior (whether the model local functions are given higher priority, - /// or standard opserator sets are given higher priotity or this is treated as error) is defined by + /// or standard opserator sets are given higher priority or this is treated as error) is defined by /// the runtimes. /// /// The operator sets imported by FunctionProto should be compatible with the ones diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 0308b5c79c508..dbe7d9b85092a 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -815,7 +815,7 @@ This version of the operator has been available since version 1 of the 'com.micr scale = 1. / (1. - ratio). ``` - This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt that the mask is output as a bit-packed uint32 tensor, instead of a boolean tensor. + This op functions in much the same was as Dropout-11 and Dropout-13 do, except that the mask is output as a bit-packed uint32 tensor, instead of a boolean tensor. #### Version @@ -2231,9 +2231,9 @@ This version of the operator has been available since version 1 of the 'com.micr
dtype : int
Output Type. Same definition as attribute 'to' for operator Cast.
transA : int
-
Whether A should be transposed. Float 8 only supprted transA=0.
+
Whether A should be transposed. Float 8 only supported transA=0.
transB : int
-
Whether B should be transposed. Float 8 only supprted transB=1.
+
Whether B should be transposed. Float 8 only supported transB=1.
#### Inputs (2 - 6) @@ -2309,7 +2309,7 @@ This version of the operator has been available since version 1 of the 'com.micr
emb : U
-
embeddding - 3D tensor with shape (batch_size, seq_len, dim)
+
embedding - 3D tensor with shape (batch_size, seq_len, dim)
q : T
q state - 4D tensor with shape (batch_size, num_heads, seq_len, dim)
q_rot : T
@@ -2816,7 +2816,7 @@ This version of the operator has been available since version 1 of the 'com.micr Matrix product with right hand matrix being pre-packed and quantized int4 data blob. During quantization, the matrix is divided into blocks, where each block is a - continguous subset inside each column. Each block is quantized into a + contiguous subset inside each column. Each block is quantized into a sequence of 4b integers with a scaling factor and an optional offset. Currently 3 quantization types are supported: (0): block size 32, no offset, (1): block size 32, with offset, (2): block size 64, @@ -6076,7 +6076,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.WhisperBeamSearch** - Beam Search for whisper model, especiall with cross_qk features etc. + Beam Search for whisper model, especially with cross_qk features etc. #### Version diff --git a/docs/How_To_Update_ONNX_Dev_Notes.md b/docs/How_To_Update_ONNX_Dev_Notes.md index 895b552508cf6..98d830bd20645 100644 --- a/docs/How_To_Update_ONNX_Dev_Notes.md +++ b/docs/How_To_Update_ONNX_Dev_Notes.md @@ -6,7 +6,7 @@ If you need to update the ONNX submodule to a different version, follow the step ## Update ONNX installation -Currently, ONNXRUNTIME supports two ways to install ONNX cpp dependencies, one is through cmake/deps.txt, and the other one is by vcpkg. And both of them are guarded by CI. It is recommeded to test vcpkg within Windows machines. +Currently, ONNXRUNTIME supports two ways to install ONNX cpp dependencies, one is through cmake/deps.txt, and the other one is by vcpkg. And both of them are guarded by CI. It is recommended to test vcpkg within Windows machines. ### Update the ONNX submodule (commit would be more precise than branch) diff --git a/docs/cmake_guideline.md b/docs/cmake_guideline.md index e03706476d73f..2ea0d5e0841b3 100644 --- a/docs/cmake_guideline.md +++ b/docs/cmake_guideline.md @@ -144,7 +144,7 @@ Here system means the combination of - CPU Arch: x86_32, x86_64, armv6, armv7, arvm7l, aarch64, … - OS: bare-metal, linux, Windows - Libc: gnu libc/ulibc/musl/… -- ABI: ARM has mutilple ABIs like eabi, eabihf… +- ABI: ARM has multiple ABIs like eabi, eabihf… When "host system" != "target system" (any different in the four dimensions), we call it cross-compiling. For example, when you build a Windows EXE on Linux, or build an ARM program on an x86_64 CPU, you are doing cross-compiling. Then special handling is needed. diff --git a/docs/python/examples/plot_convert_pipeline_vectorizer.py b/docs/python/examples/plot_convert_pipeline_vectorizer.py index 2215cb73ee643..62617e581b2e8 100644 --- a/docs/python/examples/plot_convert_pipeline_vectorizer.py +++ b/docs/python/examples/plot_convert_pipeline_vectorizer.py @@ -95,4 +95,4 @@ ######################### # Very similar. *ONNX Runtime* uses floats instead of doubles, -# that explains the small discrepencies. +# that explains the small discrepancies. diff --git a/docs/python/examples/plot_train_convert_predict.py b/docs/python/examples/plot_train_convert_predict.py index f0fd8694fb541..24e49d421f285 100644 --- a/docs/python/examples/plot_train_convert_predict.py +++ b/docs/python/examples/plot_train_convert_predict.py @@ -16,7 +16,7 @@ Train a logistic regression +++++++++++++++++++++++++++ -The first step consists in retrieving the iris datset. +The first step consists in retrieving the iris dataset. """ from sklearn.datasets import load_iris @@ -95,7 +95,7 @@ ############################# # And then with ONNX Runtime. -# The probabilies appear to be +# The probabilities appear to be prob_name = sess.get_outputs()[1].name prob_rt = sess.run([prob_name], {input_name: X_test.astype(numpy.float32)})[0] diff --git a/include/onnxruntime/core/framework/ortdevice.h b/include/onnxruntime/core/framework/ortdevice.h index 472575d1998f5..2a377238e0e27 100644 --- a/include/onnxruntime/core/framework/ortdevice.h +++ b/include/onnxruntime/core/framework/ortdevice.h @@ -103,7 +103,7 @@ struct OrtDevice { }; inline bool operator==(const OrtDevice& left, const OrtDevice& other) { - return left.Id() == other.Id() && left.MemType() == other.MemType() && left.Type() == other.Type(); + return left.Id() == other.Id() && left.MemType() == other.MemType() && left.Type() == other.Type() && left.GetAlignment() == other.GetAlignment(); } inline bool operator!=(const OrtDevice& left, const OrtDevice& other) { diff --git a/include/onnxruntime/core/framework/sparse_tensor.h b/include/onnxruntime/core/framework/sparse_tensor.h index da10e3bc10307..4e9d5c41fbc76 100644 --- a/include/onnxruntime/core/framework/sparse_tensor.h +++ b/include/onnxruntime/core/framework/sparse_tensor.h @@ -93,7 +93,7 @@ class SparseTensor final { /// /// The factory function creates an instance of SparseTensor on the heap - /// using appropriate constructor and initializes OrtValue instance wit it. + /// using appropriate constructor and initializes OrtValue instance with it. /// /// element data type /// dense shape of the sparse tensor @@ -110,7 +110,7 @@ class SparseTensor final { /// /// The factory function creates an instance of SparseTensor on the heap - /// using appropriate constructor and initializes OrtValue instance wit it. + /// using appropriate constructor and initializes OrtValue instance with it. /// /// element data type /// dense shape of the sparse tensor diff --git a/include/onnxruntime/core/graph/model_saving_options.h b/include/onnxruntime/core/graph/model_saving_options.h index 924799f15b247..45536a6967606 100644 --- a/include/onnxruntime/core/graph/model_saving_options.h +++ b/include/onnxruntime/core/graph/model_saving_options.h @@ -21,7 +21,7 @@ struct ModelSavingOptions { explicit ModelSavingOptions(size_t size_threshold) : initializer_size_threshold(size_threshold) {} - // Mimimal initializer size in bytes to be externalized on disk + // Minimal initializer size in bytes to be externalized on disk size_t initializer_size_threshold; // Offset will always be page aligned and allocation granularity aligned for // mmap support. This is done by padding previous tensor data with zeros diff --git a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h index 26fc440f7bfc5..172ec3ec50686 100644 --- a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h +++ b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h @@ -256,7 +256,7 @@ class ThreadPoolProfiler { void LogCoreAndBlock(std::ptrdiff_t block_size); // called in main thread to log core and block size for task breakdown void LogThreadId(int thread_idx); // called in child thread to log its id void LogRun(int thread_idx); // called in child thread to log num of run - std::string DumpChildThreadStat(); // return all child statitics collected so far + std::string DumpChildThreadStat(); // return all child statistics collected so far private: static const char* GetEventName(ThreadPoolEvent); @@ -739,7 +739,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // Allocate a new tag to use to identify work items from a given // thread in a parallel section. Ideally, threads will have // unique tags, but re-use is not incorrect if the counter wraps - // (for intsance, if a long-running workload is calling into ORT + // (for instance, if a long-running workload is calling into ORT // from a fresh thread for each request). We must not re-use the // default tag 0 which is used to identify work items added via // Schedule as opposed to requests for help in parallel sections. @@ -912,7 +912,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter } } - // Now we know that dispatch is finshed, we synchronize with the + // Now we know that dispatch is finished, we synchronize with the // tasks that were created (if any) for the parallel section. We // revoke tasks still in queues, and then wait for any that are // still running. @@ -1004,7 +1004,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // * First, suppose that a single job is running a series of loops. // Its main thread enters a parallel loop. Initially, let's assume // its preferred worker array is [_,0,1,2], writing "_" for the - // unusued element for the par_idx=0 work that the main thread will + // unused element for the par_idx=0 work that the main thread will // run. // // The main thread schedules the dispatcher task onto worker 0. diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9a5891f9e236d..0d2da44971b3a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -425,6 +425,34 @@ typedef enum OrtExecutionProviderDevicePolicy { OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER, } OrtExecutionProviderDevicePolicy; +/** \brief Delegate to allow providing custom OrtEpDevice selection logic + * + * This delegate is called by the EP selection code to allow the user to provide custom device selection logic. + * The user can use this to select OrtEpDevice instances from the list of available devices. + * + * \param ep_devices The list of available devices. + * \param num_devices The number of available devices. + * \param model_metadata The model metadata. + * \param runtime_metadata The runtime metadata. May be nullptr. + * \param selected Pre-allocated array to populate with selected OrtEpDevice pointers from ep_devices. + * \param max_selected The maximum number of devices that can be selected in the pre-allocated array. + Currently the maximum is 8. + * \param num_selected The number of selected devices. + * \param state Opaque pointer. Required to use the delegate from other languages like C# and python. + * + * \return OrtStatus* Selection status. Return nullptr on success. + * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. + * ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** ep_devices, + _In_ size_t num_devices, + _In_ const OrtKeyValuePairs* model_metadata, + _In_opt_ const OrtKeyValuePairs* runtime_metadata, + _Inout_ const OrtEpDevice** selected, + _In_ size_t max_selected, + _Out_ size_t* num_selected, + _In_ void* state); + /** \brief Algorithm to use for cuDNN Convolution Op */ typedef enum OrtCudnnConvAlgoSearch { @@ -546,7 +574,7 @@ typedef struct OrtROCMProviderOptions { */ int device_id; - /** \brief ROCM MIOpen Convolution algorithm exaustive search option. + /** \brief ROCM MIOpen Convolution algorithm exhaustive search option. * Defaults to 0 (false). */ int miopen_conv_exhaustive_search; @@ -3608,9 +3636,9 @@ struct OrtApi { * \param[in] op_name Operator name * \param[in] domain Operator domain * \param[in] version Operator opset version - * \param[in] type_constraint_names Name of the type contraints, such as "T" or "T1" - * \param[in] type_constraint_values Type of each contraints - * \param[in] type_constraint_count Number of contraints + * \param[in] type_constraint_names Name of the type constraints, such as "T" or "T1" + * \param[in] type_constraint_values Type of each constraints + * \param[in] type_constraint_count Number of constraints * \param[in] attr_values Attributes used to initialize the operator * \param[in] attr_count Number of the attributes * \param[in] input_count Number of inputs @@ -4298,7 +4326,7 @@ struct OrtApi { /** \brief Get the logging severity level of the ::OrtLogger. * - * Can be used in a custom operator to get the logging serverity level of the ::OrtLogger associated with + * Can be used in a custom operator to get the logging severity level of the ::OrtLogger associated with * the ::OrtKernelInfo. * * \param[in] logger The ::OrtLogger instance. @@ -4716,12 +4744,12 @@ struct OrtApi { _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); - /** \brief Get scratch buffer from the corresponding allocator under the sepcific OrtMemoryInfo object. + /** \brief Get scratch buffer from the corresponding allocator under the specific OrtMemoryInfo object. * NOTE: callers are responsible to release this scratch buffer from the corresponding allocator * \param[in] context OrtKernelContext instance * \param[in] mem_info OrtMemoryInfo instance * \param[in] count_or_bytes How many bytes is this scratch buffer - * \param[out] out A pointer to the scrach buffer + * \param[out] out A pointer to the scratch buffer * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -5073,7 +5101,8 @@ struct OrtApi { ORT_API2_STATUS(GetEpDevices, _In_ const OrtEnv* env, _Outptr_ const OrtEpDevice* const** ep_devices, _Out_ size_t* num_ep_devices); - /** \brief Append execution provider to the session options by name. + /** \brief Append the execution provider that is responsible for the selected OrtEpDevice instances + * to the session options. * * \param[in] session_options Session options to add execution provider to. * \param[in] env Environment that execution providers were registered with. @@ -5098,6 +5127,33 @@ struct OrtApi { _In_reads_(num_op_options) const char* const* ep_option_vals, size_t num_ep_options); + /** \brief Set the execution provider selection policy for the session. + * + * Allows users to specify a device selection policy for automatic execution provider (EP) selection. + * If custom selection is required please use SessionOptionsSetEpSelectionPolicyDelegate instead. + * + * \param[in] session_options The OrtSessionOptions instance. + * \param[in] policy The device selection policy to use (see OrtExecutionProviderDevicePolicy). + * + * \since Version 1.22 + */ + ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* session_options, + _In_ OrtExecutionProviderDevicePolicy policy); + + /** \brief Set the execution provider selection policy delegate for the session. + * + * Allows users to provide a custom device selection policy for automatic execution provider (EP) selection. + * + * \param[in] session_options The OrtSessionOptions instance. + * \param[in] delegate Delegate callback for custom selection. + * \param[in] delegate_state Optional state that will be passed to the delegate callback. nullptr if not required. + * + * \since Version 1.22 + */ + ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* session_options, + _In_ EpSelectionDelegate delegate, + _In_opt_ void* delegate_state); + /** \brief Get the hardware device type. * * \param[in] device The OrtHardwareDevice instance to query. @@ -5195,6 +5251,21 @@ struct OrtApi { * \since Version 1.22. */ const OrtEpApi*(ORT_API_CALL* GetEpApi)(); + + /** \brief Compute total size in bytes of the tensor data contained in an OrtValue. + * + * Returns the total number of bytes used to store the tensor data. For numeric tensors, + * this is sizeof(element_type) * total_element_count. OrtValues that are not tensors or + * that are tensors that contain strings will cause an error to be returned. + * + * \param[in] ort_value OrtValue instance containing a tensor + * \param[out] size The total size of the tensor data in bytes + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(GetTensorSizeInBytes, _In_ const OrtValue* ort_value, _Out_ size_t* size); }; /* @@ -5257,7 +5328,7 @@ struct OrtCustomOp { // Returns the memory type of the input tensors. This API allows the custom op // to place the inputs on specific devices. By default, it returns // OrtMemTypeDefault, which means the input is placed on the default device for - // the execution provider. If the inputs need to be with different memory tyeps, + // the execution provider. If the inputs need to be with different memory types, // this function can be overridden to return the specific memory types. OrtMemType(ORT_API_CALL* GetInputMemoryType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); @@ -6060,7 +6131,7 @@ struct OrtEpFactory { * \param[in] session_options The OrtSessionOptions instance that contains the configuration options for the * session. This will include ep_options from GetSupportedDevices as well as any * user provided overrides. - * Execution provider options will have been added with a prefix of 'ep..'. + * Execution provider options will have been added with a prefix of 'ep.[ep name].'. * The OrtSessionOptions instance will NOT be valid after this call and should not be * stored for later use. * \param[in] logger The OrtLogger instance for the session that the execution provider should use for logging. @@ -6068,7 +6139,7 @@ struct OrtEpFactory { * * \snippet{doc} snippets.dox OrtStatus Return Value * - * \since Version . This is a placeholder. + * \since Version [coming soon]. This is a placeholder. */ OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr, _In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -6082,7 +6153,7 @@ struct OrtEpFactory { * \param[in] this_ptr The OrtEpFactory instance. * \param[in] ep The OrtEp instance to release. * - * \since Version . This is a placeholder. + * \since Version [coming soon]. This is a placeholder. */ void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 0ecc27c59dc28..bf8e57894d384 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -195,7 +195,7 @@ inline const OrtEpApi& GetEpApi() { * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data. * * \code{.unparsed} - * // This example demonstrates converion from float to float16 + * // This example demonstrates conversion from float to float16 * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f}; * std::vector fp16_values; * fp16_values.reserve(std::size(values)); @@ -337,7 +337,7 @@ static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match"); * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data. * * \code{.unparsed} - * // This example demonstrates converion from float to float16 + * // This example demonstrates conversion from float to float16 * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f}; * std::vector bfp16_values; * bfp16_values.reserve(std::size(values)); @@ -772,7 +772,7 @@ struct KeyValuePairsImpl : Ort::detail::Base { // Const object holder that does not own the underlying object using ConstKeyValuePairs = detail::KeyValuePairsImpl>; -/** \brief Wrapper around ::OrtKeyValuePair */ +/** \brief Wrapper around ::OrtKeyValuePairs */ struct KeyValuePairs : detail::KeyValuePairsImpl { explicit KeyValuePairs(std::nullptr_t) {} ///< No instance is created /// Take ownership of a pointer created by C API @@ -1085,19 +1085,29 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX - ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options); - ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options); /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK. SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, const std::unordered_map& provider_options = {}); + /// Append EPs that have been registered previously with the OrtEnv. + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_V2 SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector& ep_devices, const KeyValuePairs& ep_options); + /// Append EPs that have been registered previously with the OrtEnv. + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_V2 SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector& ep_devices, const std::unordered_map& ep_options); + /// Wraps OrtApi::SessionOptionsSetEpSelectionPolicy + SessionOptionsImpl& SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy); + + /// Wraps OrtApi::SessionOptionsSetEpSelectionPolicyDelegate + SessionOptionsImpl& SetEpSelectionPolicy(EpSelectionDelegate delegate, void* state = nullptr); + SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn @@ -1728,6 +1738,18 @@ struct ConstValueImpl : Base { /// byte length for the specified string element size_t GetStringTensorElementLength(size_t element_index) const; + /// + /// Returns the total size of the tensor data in bytes. + /// + /// The total size of the tensor data in bytes + /// Throws an exception if the OrtValue does not contain a tensor or + /// if it contains a tensor that contains strings + /// + /// For numeric tensors, this is sizeof(element_type) * total_element_count. + /// + /// + size_t GetTensorSizeInBytes() const; ///< Wraps OrtApi::GetTensorSizeInBytes + #if !defined(DISABLE_SPARSE_TENSORS) /// /// The API returns the sparse data format this OrtValue holds in a sparse tensor. @@ -1809,7 +1831,7 @@ struct ValueImpl : ConstValueImpl { /// by the vector of dims. /// /// - /// [in] expressed by a vecotr of dimensions offsets + /// [in] expressed by a vector of dimensions offsets /// template R& At(const std::vector& location); @@ -2666,7 +2688,7 @@ struct CustomOpBase : OrtCustomOp { return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; } - // Default implemention of GetInputMemoryType() that returns OrtMemTypeDefault + // Default implementation of GetInputMemoryType() that returns OrtMemTypeDefault OrtMemType GetInputMemoryType(size_t /*index*/) const { return OrtMemTypeDefault; } diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 48b3b80cced55..0d0b3198a8736 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1149,6 +1149,18 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_V2( return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy) { + ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicy(this->p_, policy)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetEpSelectionPolicy(EpSelectionDelegate delegate, void* state) { + ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicyDelegate(this->p_, delegate, state)); + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn)); @@ -1811,6 +1823,13 @@ inline size_t ConstValueImpl::GetStringTensorElementLength(size_t element_ind return out; } +template +inline size_t ConstValueImpl::GetTensorSizeInBytes() const { + size_t out; + ThrowOnError(GetApi().GetTensorSizeInBytes(this->p_, &out)); + return out; +} + template template inline const R* ConstValueImpl::GetTensorData() const { diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index ce87d8c56d3fe..5002e16ba116c 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -361,7 +361,7 @@ struct TensorArray : public ArgBase { tensor = std::make_unique>(ctx, ith_input, true); break; default: - ORT_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION); + ORT_CXX_API_THROW("unknown input type", ORT_RUNTIME_EXCEPTION); break; } tensors_.emplace_back(tensor.release()); diff --git a/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java b/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java index 8400ef53ff6d7..3901df19d0570 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java @@ -927,7 +927,7 @@ public BlockSparseTensor( } if (indicesShape.length < 2) { throw new IllegalArgumentException( - "Expected [numBlocks, co-ordinates] or larger, but indices shape was " + "Expected [numBlocks, coordinates] or larger, but indices shape was " + Arrays.toString(indicesShape)); } } diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 41c4f5e02bd69..7246738fd4406 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1145,7 +1145,7 @@ public void testPreTrainedModel(String opset, String modelName) throws IOExcepti try (OrtSession session = env.createSession(onnxModelFileName)) { String testDataDirNamePattern; if (opset.equals("opset9") && modelName.equals("LSTM_Seq_lens_unpacked")) { - testDataDirNamePattern = "seq_lens"; // discrepency in data directory + testDataDirNamePattern = "seq_lens"; // discrepancy in data directory } else { testDataDirNamePattern = "test_data"; } diff --git a/java/src/test/java/ai/onnxruntime/OnnxMl.java b/java/src/test/java/ai/onnxruntime/OnnxMl.java index d96c6ccba6e31..1347fb5e1a259 100644 --- a/java/src/test/java/ai/onnxruntime/OnnxMl.java +++ b/java/src/test/java/ai/onnxruntime/OnnxMl.java @@ -18317,7 +18317,7 @@ public interface TensorProtoOrBuilder * For float and complex64 values * Complex64 tensors are encoded as a single array of floats, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be FLOAT or COMPLEX64. @@ -18333,7 +18333,7 @@ public interface TensorProtoOrBuilder * For float and complex64 values * Complex64 tensors are encoded as a single array of floats, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be FLOAT or COMPLEX64. @@ -18349,7 +18349,7 @@ public interface TensorProtoOrBuilder * For float and complex64 values * Complex64 tensors are encoded as a single array of floats, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be FLOAT or COMPLEX64. @@ -18663,7 +18663,7 @@ public interface TensorProtoOrBuilder * For double * Complex128 tensors are encoded as a single array of doubles, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 @@ -18679,7 +18679,7 @@ public interface TensorProtoOrBuilder * For double * Complex128 tensors are encoded as a single array of doubles, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 @@ -18695,7 +18695,7 @@ public interface TensorProtoOrBuilder * For double * Complex128 tensors are encoded as a single array of doubles, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 @@ -20167,7 +20167,7 @@ public OnnxMl.TensorProto.SegmentOrBuilder getSegmentOrBuilder() { * For float and complex64 values * Complex64 tensors are encoded as a single array of floats, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be FLOAT or COMPLEX64. @@ -20185,7 +20185,7 @@ public java.util.List getFloatDataList() { * For float and complex64 values * Complex64 tensors are encoded as a single array of floats, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be FLOAT or COMPLEX64. @@ -20203,7 +20203,7 @@ public int getFloatDataCount() { * For float and complex64 values * Complex64 tensors are encoded as a single array of floats, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be FLOAT or COMPLEX64. @@ -20621,7 +20621,7 @@ public OnnxMl.TensorProto.DataLocation getDataLocation() { * For double * Complex128 tensors are encoded as a single array of doubles, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 @@ -20639,7 +20639,7 @@ public java.util.List getDoubleDataList() { * For double * Complex128 tensors are encoded as a single array of doubles, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 @@ -20657,7 +20657,7 @@ public int getDoubleDataCount() { * For double * Complex128 tensors are encoded as a single array of doubles, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 @@ -21732,7 +21732,7 @@ private void ensureFloatDataIsMutable() { * For float and complex64 values * Complex64 tensors are encoded as a single array of floats, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be FLOAT or COMPLEX64. @@ -21750,7 +21750,7 @@ public java.util.List getFloatDataList() { * For float and complex64 values * Complex64 tensors are encoded as a single array of floats, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be FLOAT or COMPLEX64. @@ -21768,7 +21768,7 @@ public int getFloatDataCount() { * For float and complex64 values * Complex64 tensors are encoded as a single array of floats, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be FLOAT or COMPLEX64. @@ -21786,7 +21786,7 @@ public float getFloatData(int index) { * For float and complex64 values * Complex64 tensors are encoded as a single array of floats, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be FLOAT or COMPLEX64. @@ -21807,7 +21807,7 @@ public Builder setFloatData(int index, float value) { * For float and complex64 values * Complex64 tensors are encoded as a single array of floats, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be FLOAT or COMPLEX64. @@ -21828,7 +21828,7 @@ public Builder addFloatData(float value) { * For float and complex64 values * Complex64 tensors are encoded as a single array of floats, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be FLOAT or COMPLEX64. @@ -21849,7 +21849,7 @@ public Builder addAllFloatData(java.lang.Iterable val * For float and complex64 values * Complex64 tensors are encoded as a single array of floats, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be FLOAT or COMPLEX64. @@ -23107,7 +23107,7 @@ private void ensureDoubleDataIsMutable() { * For double * Complex128 tensors are encoded as a single array of doubles, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 @@ -23125,7 +23125,7 @@ public java.util.List getDoubleDataList() { * For double * Complex128 tensors are encoded as a single array of doubles, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 @@ -23143,7 +23143,7 @@ public int getDoubleDataCount() { * For double * Complex128 tensors are encoded as a single array of doubles, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 @@ -23161,7 +23161,7 @@ public double getDoubleData(int index) { * For double * Complex128 tensors are encoded as a single array of doubles, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 @@ -23182,7 +23182,7 @@ public Builder setDoubleData(int index, double value) { * For double * Complex128 tensors are encoded as a single array of doubles, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 @@ -23203,7 +23203,7 @@ public Builder addDoubleData(double value) { * For double * Complex128 tensors are encoded as a single array of doubles, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 @@ -23224,7 +23224,7 @@ public Builder addAllDoubleData(java.lang.Iterable v * For double * Complex128 tensors are encoded as a single array of doubles, * with the real components appearing in odd numbered positions, - * and the corresponding imaginary component apparing in the + * and the corresponding imaginary component appearing in the * subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] * is encoded as [1.0, 2.0 ,3.0 ,4.0] * When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 diff --git a/js/README.md b/js/README.md index dbc58f3a75ebd..4fc1af240852f 100644 --- a/js/README.md +++ b/js/README.md @@ -307,7 +307,7 @@ From ORT v1.19 onwards, the ONNX Runtime Mobile packages are no longer published 3. In ``, run the below python script to build the ONNX Runtime Android archive file. On a Windows machine, this requires an admin account to build. - You can build a 'full' package that supports all operators and types, or a reduced size package that supports a limited set of operators and types based on your model/s to miminize the binary size. + You can build a 'full' package that supports all operators and types, or a reduced size package that supports a limited set of operators and types based on your model/s to minimize the binary size. See [here](https://onnxruntime.ai/docs/build/custom.html) for information about how the reduced build works, including creating the configuration file using your model/s. The instructions here show how to build a 'full' package. diff --git a/js/common/lib/inference-session-impl.ts b/js/common/lib/inference-session-impl.ts index 797dba8b94089..877c595bffd15 100644 --- a/js/common/lib/inference-session-impl.ts +++ b/js/common/lib/inference-session-impl.ts @@ -6,7 +6,7 @@ import { InferenceSessionHandler } from './backend.js'; import { InferenceSession as InferenceSessionInterface } from './inference-session.js'; import { OnnxValue } from './onnx-value.js'; import { Tensor } from './tensor.js'; -import { TRACE_FUNC_BEGIN, TRACE_FUNC_END } from './trace.js'; +import { TRACE_FUNC_BEGIN, TRACE_FUNC_END, TRACE_EVENT_BEGIN, TRACE_EVENT_END } from './trace.js'; type SessionOptions = InferenceSessionInterface.SessionOptions; type RunOptions = InferenceSessionInterface.RunOptions; @@ -22,6 +22,7 @@ export class InferenceSession implements InferenceSessionInterface { run(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; async run(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise { TRACE_FUNC_BEGIN(); + TRACE_EVENT_BEGIN('InferenceSession.run'); const fetches: { [name: string]: OnnxValue | null } = {}; let options: RunOptions = {}; // check inputs @@ -120,6 +121,7 @@ export class InferenceSession implements InferenceSessionInterface { } } } + TRACE_EVENT_END('InferenceSession.run'); TRACE_FUNC_END(); return returnValue; } @@ -144,6 +146,7 @@ export class InferenceSession implements InferenceSessionInterface { arg3?: SessionOptions, ): Promise { TRACE_FUNC_BEGIN(); + TRACE_EVENT_BEGIN('InferenceSession.create'); // either load from a file or buffer let filePathOrUint8Array: string | Uint8Array; let options: SessionOptions = {}; @@ -207,6 +210,7 @@ export class InferenceSession implements InferenceSessionInterface { // resolve backend, update session options with validated EPs, and create session handler const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options); const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, optionsWithValidatedEPs); + TRACE_EVENT_END('InferenceSession.create'); TRACE_FUNC_END(); return new InferenceSession(handler); } diff --git a/js/common/lib/trace.ts b/js/common/lib/trace.ts index 25d178f15a29d..0f20dd39935ac 100644 --- a/js/common/lib/trace.ts +++ b/js/common/lib/trace.ts @@ -51,3 +51,25 @@ export const TRACE_FUNC_END = (extraMsg?: string) => { } TRACE_FUNC('END', extraMsg); }; + +/** + * @ignore + */ +export const TRACE_EVENT_BEGIN = (extraMsg?: string) => { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { + return; + } + // eslint-disable-next-line no-console + console.time(`ORT::${extraMsg}`); +}; + +/** + * @ignore + */ +export const TRACE_EVENT_END = (extraMsg?: string) => { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { + return; + } + // eslint-disable-next-line no-console + console.timeEnd(`ORT::${extraMsg}`); +}; diff --git a/js/node/CMakeLists.txt b/js/node/CMakeLists.txt index 2bd6f22e5f901..52af5dc48a21a 100644 --- a/js/node/CMakeLists.txt +++ b/js/node/CMakeLists.txt @@ -113,12 +113,6 @@ endif() if (WIN32) file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/onnxruntime.dll DESTINATION ${dist_folder}) - if (ORT_NODEJS_DLL_DEPS) - foreach(dll ${ORT_NODEJS_DLL_DEPS}) - file(COPY ${dll} DESTINATION ${dist_folder}) - endforeach() - endif() - elseif (APPLE) file(COPY ${ONNXRUNTIME_BUILD_DIR}/libonnxruntime.dylib DESTINATION ${dist_folder} FOLLOW_SYMLINK_CHAIN) @@ -128,3 +122,9 @@ elseif (UNIX) else() message(FATAL_ERROR "Platform not supported.") endif() + +if (ORT_NODEJS_DLL_DEPS) + foreach(dll ${ORT_NODEJS_DLL_DEPS}) + file(COPY ${dll} DESTINATION ${dist_folder}) + endforeach() +endif() diff --git a/js/node/README.md b/js/node/README.md index c271d8daccc8b..b8414546c4729 100644 --- a/js/node/README.md +++ b/js/node/README.md @@ -27,13 +27,14 @@ The following table lists the supported versions of ONNX Runtime Node.js binding | EPs/Platforms | Windows x64 | Windows arm64 | Linux x64 | Linux arm64 | MacOS x64 | MacOS arm64 | | ------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | | CPU | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | -| WebGPU | ✔️ \[1] | ✔️ \[1] | ✔️ \[1] | ✔️ \[1] | ✔️ \[1] | ✔️ \[1] | +| WebGPU | ✔️ \[1] | ✔️ \[1] | ❌ \[2] | ❌ \[2] | ✔️ \[1] | ✔️ \[1] | | DirectML | ✔️ | ✔️ | ❌ | ❌ | ❌ | ❌ | -| CUDA | ❌ | ❌ | ✔️\[2] | ❌ | ❌ | ❌ | +| CUDA | ❌ | ❌ | ✔️\[3] | ❌ | ❌ | ❌ | | CoreML | ❌ | ❌ | ❌ | ❌ | ✔️ | ✔️ | - \[1]: WebGPU support is currently experimental. -- \[2]: CUDA v12. See [CUDA EP Installation](#cuda-ep-installation) for details. +- \[2]: WebGPU support is not available on Linux x64 and arm64 yet in the pre-built binaries. +- \[3]: CUDA v12. See [CUDA EP Installation](#cuda-ep-installation) for details. To use on platforms without pre-built binaries, you can build Node.js binding from source and consume it by `npm install /js/node/`. See also [instructions](https://onnxruntime.ai/docs/build/inferencing.html#apis-and-language-bindings) for building ONNX Runtime Node.js binding locally. diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 01acb257e9d0f..981a684154df1 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -64,6 +64,7 @@ platforms. Check the [WebNN status](https://webmachinelearning.github.io/webnn-s | LRN | ai.onnx(7-12, 13+) | pad, averagePool2d, transpose, add, mul, pow, div | | | LSTM | ai.onnx(7-13, 14-21, 22+) | lstm | Only supports 'layout' == 0, 'input_forget' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' | | MatMul | ai.onnx(7-8, 9-12, 13+) | matmul | | +| MatMulInteger | ai.onnx(10+) | cast, dequantizeLinear, matmul | | | MatMulNBits | com.microsoft(1+) | add, dequantizeLinear, matmul, reshape, transpose | Inputs 'B' and 'zero_points' (if present) should be constants, input 'g_idx' is not supported, only bits=4 is supported | | Max | ai.onnx(7, 8-11, 12, 13+) | max | | | MaxPool | ai.onnx(7, 8-9, 10, 11, 12+) | maxPool2d | Only supports 4-D input, 2-D 'kernel_shape', 'storage_order' != 1, one output | @@ -91,7 +92,7 @@ platforms. Check the [WebNN status](https://webmachinelearning.github.io/webnn-s | Relu | ai.onnx(7-12, 13, 14+) | relu | | | Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape | Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported | | Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | Only supports 4-D input, antialias == 0, exclude_outside == 0, keep_aspect_ratio_policy == 'stretch', 'linear' and 'nearest' modes, input 'scales' and 'sizes' if present must be a constant | -| RotaryEmbedding | com.microsoft(1+) | add, concat, gather, mul, reshape, split | | +| RotaryEmbedding | ai.onnx(23+), com.microsoft(1+) | add, concat, gather, mul, reshape, slice, split | | | ScatterElements | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterElements | Only supports 'reduction' == 'none' | | ScatterND | ai.onnx(11-12, 13-15, 16-17, 18+) | scatterND | Only supports 'reduction' == 'none' | | Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice | | diff --git a/js/web/lib/onnxjs/backends/backend-webgl.ts b/js/web/lib/onnxjs/backends/backend-webgl.ts index a122068eb67bc..25c29f8680e6f 100644 --- a/js/web/lib/onnxjs/backends/backend-webgl.ts +++ b/js/web/lib/onnxjs/backends/backend-webgl.ts @@ -12,7 +12,7 @@ import { WebGLContext } from './webgl/webgl-context'; import { createWebGLContext } from './webgl/webgl-context-factory'; /** - * WebGLBackend is the entry point for all WebGL opeartions + * WebGLBackend is the entry point for all WebGL operations * When it starts it created the WebGLRenderingContext * and other main framework components such as Program and Texture Managers */ diff --git a/js/web/lib/onnxjs/backends/webgl/glsl-vec-lib.ts b/js/web/lib/onnxjs/backends/webgl/glsl-vec-lib.ts index 7b1ba915e7c10..d50e3a510d480 100644 --- a/js/web/lib/onnxjs/backends/webgl/glsl-vec-lib.ts +++ b/js/web/lib/onnxjs/backends/webgl/glsl-vec-lib.ts @@ -5,7 +5,7 @@ import { GlslContext, GlslLib, GlslLibRoutine } from './glsl-definitions'; /** * GLSL Library responsible for vec routines - * Vec is an varible length int array. The length is fixed at the time of + * Vec is an variable length int array. The length is fixed at the time of * generating the library functions from the dimensions of the output. */ export class VecGlslLib extends GlslLib { diff --git a/js/web/lib/onnxjs/backends/webgl/texture-manager.ts b/js/web/lib/onnxjs/backends/webgl/texture-manager.ts index 3aad95b33e3e4..05efced124d14 100644 --- a/js/web/lib/onnxjs/backends/webgl/texture-manager.ts +++ b/js/web/lib/onnxjs/backends/webgl/texture-manager.ts @@ -16,7 +16,7 @@ export interface TextureManagerConfig { /** * TextureManager is the mainly responsible for caching Textures * Textures are cached in 2 levels: - * 1. the texures which are associated with a dataId (from Tensor) + * 1. the textures which are associated with a dataId (from Tensor) * Caching these is crucial to performance. These are In-use Textures * 2. textures which are not in use by any current ProgramInfo/Tensor * These are called Free Textures diff --git a/js/web/lib/onnxjs/backends/webgl/webgl-context.ts b/js/web/lib/onnxjs/backends/webgl/webgl-context.ts index 19684dec81b3d..a559e6053863c 100644 --- a/js/web/lib/onnxjs/backends/webgl/webgl-context.ts +++ b/js/web/lib/onnxjs/backends/webgl/webgl-context.ts @@ -581,7 +581,7 @@ ${shaderSource}`); // TODO: add webgl 1 handling. throw new Error('WebGL1 profiling currently not supported'); } - // return miliseconds + // return milliseconds return timeElapsed / 1000000; } diff --git a/js/web/lib/onnxjs/instrument.ts b/js/web/lib/onnxjs/instrument.ts index df6a1777054fd..a8cb8685950e6 100644 --- a/js/web/lib/onnxjs/instrument.ts +++ b/js/web/lib/onnxjs/instrument.ts @@ -27,7 +27,7 @@ export declare namespace Logger { */ provider?: Provider; /** - * Specify the minimal logger serverity. 'warning' by default + * Specify the minimal logger severity. 'warning' by default */ minimalSeverity?: Logger.Severity; /** @@ -178,7 +178,7 @@ function createCategorizedLogger(category: string): Logger.CategorizedLogger { }; } -// NOTE: argument 'category' is put the last parameter beacause typescript +// NOTE: argument 'category' is put the last parameter because typescript // doesn't allow optional argument put in front of required argument. This // order is different from a usual logging API. function logInternal(severity: Logger.Severity, content: string, _stack: number, category?: string) { @@ -440,7 +440,7 @@ export class Profiler { this._timingEvents.length - this._flushPointer >= this._flushBatchSize || currentTime - this._flushTime >= this._flushIntervalInMilliseconds ) { - // should flush when either batch size accumlated or interval elepsed + // should flush when either batch size accumulated or interval elepsed for ( const previousPointer = this._flushPointer; diff --git a/js/web/lib/onnxjs/session.ts b/js/web/lib/onnxjs/session.ts index 26243ed9fe509..ffa616b921c5e 100644 --- a/js/web/lib/onnxjs/session.ts +++ b/js/web/lib/onnxjs/session.ts @@ -225,7 +225,7 @@ export class Session { for (let i = 0; i < expectedDims.length; ++i) { if (expectedDims[i] !== actualDims[i] && (!noneDimSupported || expectedDims[i] !== 0)) { - // data shape mis-match AND not a 'None' dimension. + // data shape mismatch AND not a 'None' dimension. return false; } } diff --git a/js/web/lib/onnxjs/util.ts b/js/web/lib/onnxjs/util.ts index 3afe84c046eb4..b25e9ef33b38a 100644 --- a/js/web/lib/onnxjs/util.ts +++ b/js/web/lib/onnxjs/util.ts @@ -100,7 +100,7 @@ export class MatMulUtil { /** * Fix the output shape computed for MatMul operation if it needs fixing - * @param outputShape The computed outputShape. Should be an array (atleast of length 2) of positive integers. + * @param outputShape The computed outputShape. Should be an array (at least of length 2) of positive integers. * This will be mutated. * @param aRank The rank of tensor A. * @param bRank The rank of tensor B. @@ -183,7 +183,7 @@ export class BroadcastUtil { /** * Given the indices of a broadcasted tensor, calculate the original indices * @param broadcastedIndices The given indices of the broadcasted tensor. - * @param originalShape The original shape of the tensor before broadcas + * @param originalShape The original shape of the tensor before broadcast * @returns The calculated indices that maps to the original tensor. */ static index(broadcastedIndices: readonly number[], originalShape: readonly number[]): number[] { @@ -243,7 +243,7 @@ export class BroadcastUtil { c.set([], op(a.get([]) as number, b.get([]) as number)); } - // atleast one input is a non-scalar + // at least one input is a non-scalar else { const outputIndices = new Array(outputShape.length); const originalIndicesA = new Array(a.dims.length); @@ -611,7 +611,7 @@ export class ShapeUtil { } /** - * normailze axis of range [-r, r) into [0, r). + * normalize axis of range [-r, r) into [0, r). */ static normalizeAxis(axis: number, tensorRank: number): number { if (axis < -tensorRank && axis >= tensorRank) { diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 8ab6b054bf8a7..463e26d0208e5 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -291,6 +291,8 @@ export const init = async ( }, // jsepDownloadTensor async (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => backend.downloadTensor(tensorId, dstBuffer), + // jsepEnableTraceEvent + !!env.trace, ]); } }; diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts index 85aca96057df2..7f87a2475ac16 100644 --- a/js/web/lib/wasm/jsep/util.ts +++ b/js/web/lib/wasm/jsep/util.ts @@ -189,7 +189,7 @@ export class ShapeUtil { } /** - * normailze axis of range [-r, r) into [0, r). + * normalize axis of range [-r, r) into [0, r). */ static normalizeAxis(axis: number, tensorRank: number): number { if (axis < -tensorRank && axis >= tensorRank) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts index dd59d5f03d47d..321cb136f86b1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts @@ -29,7 +29,7 @@ const createBiasAddProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const outputShape = inputs[0].dims; const channels = inputs[0].dims[2]; - // since channel number can be only 320/640/1280, it's always divisable by 4 + // since channel number can be only 320/640/1280, it's always divisible by 4 const outputSize = ShapeUtil.size(outputShape) / 4; const dataType = inputs[0].dataType; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index 48da675193ad8..9374d50111ca1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -22,12 +22,12 @@ export interface EinsumAttributes extends AttributeWithCacheKey { const symbolPattern = '[a-zA-Z]|\\.\\.\\.'; // The pattern each symbol in each term in the symbolic equation should match const termPattern = '(' + symbolPattern + ')+'; // The pattern each term in the symbolic equation should match -const termPatternOnly = '^' + termPattern + '$'; // The patterns only matchs a term begin to end. +const termPatternOnly = '^' + termPattern + '$'; // The patterns only matches a term begin to end. const lhsPattern = '(' + termPattern + ',)*' + termPattern; // The pattern the LHS should match -const lhsPatternOnly = '^' + lhsPattern + '$'; // The patterns only matchs a LHS begin to end. +const lhsPatternOnly = '^' + lhsPattern + '$'; // The patterns only matches a LHS begin to end. interface SymbolInfo { - count: number; // Symbol corresponding to a dimmension of an input + count: number; // Symbol corresponding to a dimension of an input inputIndices: number[]; // Number of input variables the symbol corresponds to dimValue: number; // Number of dimensions the symbol corresponds to } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts index 32b3c54f734dc..d218be3ce8b5f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts @@ -100,7 +100,7 @@ export const validateInputs = ( } const hasPastKey = pastKey && pastKey.dims.length !== 0; const hasPastValue = pastValue && pastValue.dims.length !== 0; - // Currenly the onnxruntime GQA specification only support key/value BNSH format. + // Currently the onnxruntime GQA specification only support key/value BNSH format. const isPastkvBSNH = hasPastKey && pastKey.dims.length === 4 && diff --git a/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts b/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts index ec1d23e4887d5..286984c15feca 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts @@ -78,46 +78,15 @@ const atomicReductionSnippet = (reduction: string, ptr: string, v: string, type: } }; -const calcDataOffsetSnippet = (dataRank: number, parallel: boolean) => - `${ - dataRank === 1 - ? ` - let element_count_dim = uniforms.output_strides; - let dim_value = uniforms.output_shape;` - : ` - let element_count_dim = uniforms.output_strides[${parallel ? 'i - indices_start' : 'i'}]; - let dim_value = uniforms.output_shape[${parallel ? 'i - indices_start' : 'i'} + uniforms.last_index_dimension];` - } - - if (index >= 0) { - if (index >= i32(dim_value)) { - index = i32(dim_value - 1); - } - } else { - if (index < -i32(dim_value)) { - index = 0; - } else { - index += i32(dim_value); - } - } - data_offset += u32((u32(index) * element_count_dim));`; - -const updateElementsSnippet = (attributes: ScatterNDAttributes, outputTypeValue: ReductionType, parallel: boolean) => - `for (var i = 0u; i < uniforms.num_updates_elements; i++) { - let value = updates[uniforms.num_updates_elements * ${parallel ? 'global_idx' : 'idx'} + i]; - ${atomicReductionSnippet(attributes.reduction, 'output[data_offset + i]', 'value', outputTypeValue)} - }`; - const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: ScatterNDAttributes): ProgramInfo => { const inputShape = inputs[0].dims; const indicesShape = inputs[1].dims; const outputShape = inputShape; // TODO: support bool with components 4. const components = 1; - const outputSize = Math.ceil(ShapeUtil.size(indicesShape) / components); + const outputSize = Math.ceil(ShapeUtil.sizeToDimension(indicesShape, indicesShape.length - 1) / components); const lastIndexDimension = indicesShape[indicesShape.length - 1]; const numUpdatesElements = ShapeUtil.sizeFromDimension(inputShape, lastIndexDimension); - const numIndicesElements = ShapeUtil.sizeFromDimension(indicesShape, 0) / lastIndexDimension; const programUniforms: ProgramUniform[] = [ { type: DataType.uint32, data: outputSize }, @@ -142,48 +111,45 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S .declareVariables(indices, updates, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} - var hasDuplicates = false; - if (${attributes.reduction === 'none'}) { - for (var i = 0; i < ${numIndicesElements}; i = i + 1) { - for (var j = i + 1; j < ${numIndicesElements}; j = j + 1) { - var index_i = i32(indices[i].x); - var index_j = i32(indices[j].x); - if (index_i == index_j) { - hasDuplicates = true; - break; - } + var data_offset = 0u; + let indices_start = uniforms.last_index_dimension * global_idx; + let indices_end = indices_start + uniforms.last_index_dimension; + for (var i = indices_start; i < indices_end; i++) { + var index = i32(indices[i].x); + ${ + inputs[0].dims.length === 1 + ? ` + let element_count_dim = uniforms.output_strides; + let dim_value = uniforms.output_shape;` + : ` + let element_count_dim = uniforms.output_strides[i - indices_start]; + let dim_value = uniforms.output_shape[i - indices_start];` + } + if (index >= 0) { + if (index >= i32(dim_value)) { + index = i32(dim_value - 1); } - if (hasDuplicates) { - break; + } else { + if (index < -i32(dim_value)) { + index = 0; + } else { + index += i32(dim_value); } } + data_offset += u32((u32(index) * element_count_dim)); } - if (${attributes.reduction === 'none'} && hasDuplicates) { - if (global_idx != 0u) { - return; - } - // Process each index-update pair individually when duplicates exist - for (var idx = 0u; idx < ${numIndicesElements}u; idx++) { - var data_offset = 0u; - for (var i = 0u; i < uniforms.last_index_dimension; i++) { - var index = i32(indices[idx * uniforms.last_index_dimension + i].x); - ${calcDataOffsetSnippet(inputShape.length, false)} - } - ${updateElementsSnippet(attributes, output.type.value as ReductionType, false)} - } - return; + for (var i = 0u; i < uniforms.num_updates_elements; i++) { + let value = updates[uniforms.num_updates_elements * global_idx + i]; + ${atomicReductionSnippet( + attributes.reduction, + 'output[data_offset + i]', + 'value', + output.type.value as ReductionType, + )} } - var data_offset = 0u; - var indices_start = uniforms.last_index_dimension * global_idx; - var indices_end = indices_start + uniforms.last_index_dimension; - for (var i = indices_start; i < indices_end; i++) { - var index = i32(indices[i].x); - ${calcDataOffsetSnippet(inputShape.length, true)} - } - ${updateElementsSnippet(attributes, output.type.value as ReductionType, true)} - }`; + }`; }; return { name: 'ScatterND', diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 8c39505734e41..76a69e4bb8061 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -140,7 +140,7 @@ export const parseSplitAttributes = (attributes: Record): Split const splitSizes: number[] = attributes.splitSizes as number[]; const numOutputs = (attributes.numOutputs as number) < 0 ? splitSizes.length : (attributes.numOutputs as number); if (numOutputs !== splitSizes.length) { - throw new Error('numOutputs and splitSizes lengh must be equal'); + throw new Error('numOutputs and splitSizes length must be equal'); } return createAttributeWithCacheKey({ axis, numOutputs, splitSizes }); }; diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 30b1f5101e5f2..a6b0ac6d5a051 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -124,7 +124,7 @@ export const initializeWebAssemblyAndOrtRuntime = async (): Promise => { ) { // for a build bundled the wasm JS, if either of the following conditions is met: // - the proxy worker is loaded from a blob URL - // - `import.meta.url` is a file URL, it means it is overwriten by the bundler. + // - `import.meta.url` is a file URL, it means it is overwritten by the bundler. // // in either case, the path information is lost, we need to pass the path of the .wasm file to the worker. // we need to use the bundler preferred URL format: diff --git a/js/web/lib/wasm/run-options.ts b/js/web/lib/wasm/run-options.ts index d15c8339b6824..a0ef57df31129 100644 --- a/js/web/lib/wasm/run-options.ts +++ b/js/web/lib/wasm/run-options.ts @@ -22,7 +22,7 @@ export const setRunOptions = (options: InferenceSession.RunOptions): [number, nu options.logSeverityLevel < 0 || options.logSeverityLevel > 4 ) { - throw new Error(`log serverity level is not valid: ${options.logSeverityLevel}`); + throw new Error(`log severity level is not valid: ${options.logSeverityLevel}`); } if (options?.logVerbosityLevel === undefined) { diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 89a4484e5a1c4..cd787379220c1 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -179,7 +179,7 @@ export const setSessionOptions = async (options?: InferenceSession.SessionOption const logSeverityLevel = sessionOptions.logSeverityLevel ?? 2; // Default to 2 - warning if (!Number.isInteger(logSeverityLevel) || logSeverityLevel < 0 || logSeverityLevel > 4) { - throw new Error(`log serverity level is not valid: ${logSeverityLevel}`); + throw new Error(`log severity level is not valid: ${logSeverityLevel}`); } const logVerbosityLevel = sessionOptions.logVerbosityLevel ?? 0; // Default to 0 - verbose diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index f42a224ed2e85..cfdc0053b3485 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -6,7 +6,7 @@ // https://github.com/webmachinelearning/webnn/issues/677 /// -import { Env, InferenceSession, Tensor } from 'onnxruntime-common'; +import { Env, InferenceSession, Tensor, TRACE_EVENT_BEGIN, TRACE_EVENT_END } from 'onnxruntime-common'; import { SerializableInternalBuffer, @@ -711,6 +711,7 @@ export const run = async ( try { [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); + TRACE_EVENT_BEGIN('wasm prepareInputOutputTensor'); // create input tensors for (let i = 0; i < inputCount; i++) { await prepareInputOutputTensor( @@ -736,6 +737,7 @@ export const run = async ( enableGraphCapture, ); } + TRACE_EVENT_END('wasm prepareInputOutputTensor'); for (let i = 0; i < inputCount; i++) { wasm.setValue(inputValuesOffset + i * ptrSize, inputTensorHandles[i], '*'); @@ -755,6 +757,7 @@ export const run = async ( ); } + TRACE_EVENT_BEGIN('wasm bindInputsOutputs'); // process inputs for (let i = 0; i < inputCount; i++) { const index = inputIndices[i]; @@ -788,6 +791,7 @@ export const run = async ( } } } + TRACE_EVENT_END('wasm bindInputsOutputs'); activeSessions.set(sessionId, [ sessionHandle, inputNamesUTF8Encoded, @@ -830,6 +834,7 @@ export const run = async ( const output: TensorMetadata[] = []; const outputPromises: Array> = []; + TRACE_EVENT_BEGIN('wasm ProcessOutputTensor'); for (let i = 0; i < outputCount; i++) { const tensor = Number(wasm.getValue(outputValuesOffset + i * ptrSize, '*')); if (tensor === outputTensorHandles[i]) { @@ -1028,6 +1033,7 @@ export const run = async ( for (const [index, data] of await Promise.all(outputPromises)) { output[index][2] = data; } + TRACE_EVENT_END('wasm ProcessOutputTensor'); return output; } finally { wasm.webnnOnRunEnd?.(sessionHandle); diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index f2d051927b1d5..29a4028ae46cc 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -71,6 +71,7 @@ export declare namespace JSEP { ensureTensor: EnsureTensorFunction, uploadTensor: UploadTensorFunction, downloadTensor: DownloadTensorFunction, + enableTraceEvent: boolean, ], ): void; } diff --git a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json index c9da59b4b0021..48f0a8f3e9d5c 100644 --- a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json +++ b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json @@ -12,7 +12,7 @@ }, "devDependencies": { "@vitejs/plugin-vue": "^5.2.1", - "vite": "^6.2.6" + "vite": "^6.3.5" } }, "node_modules/@babel/helper-string-parser": { @@ -944,6 +944,21 @@ "integrity": "sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==", "license": "MIT" }, + "node_modules/fdir": { + "version": "6.4.4", + "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.4.4.tgz", + "integrity": "sha512-1NZP+GK4GfuAv3PqKvxQRDMjdSRZjnkq7KfhlNrCNNlZ0ygQFpebfrnfnq/W7fpUnAv9aGWmY1zKx7FYL3gwhg==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "picomatch": "^3 || ^4" + }, + "peerDependenciesMeta": { + "picomatch": { + "optional": true + } + } + }, "node_modules/fsevents": { "version": "2.3.3", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", @@ -992,6 +1007,19 @@ "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", "license": "ISC" }, + "node_modules/picomatch": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.2.tgz", + "integrity": "sha512-M7BAV6Rlcy5u+m6oPhAPFgJTzAioX/6B0DxyvDlo9l8+T3nLKbrczg2WLUyzd45L8RqfUMyGPzekbMvX2Ldkwg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, "node_modules/postcss": { "version": "8.5.3", "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.3.tgz", @@ -1068,16 +1096,36 @@ "node": ">=0.10.0" } }, + "node_modules/tinyglobby": { + "version": "0.2.13", + "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.13.tgz", + "integrity": "sha512-mEwzpUgrLySlveBwEVDMKk5B57bhLPYovRfPAXD5gA/98Opn0rCDj3GtLwFvCvH5RK9uPCExUROW5NjDwvqkxw==", + "dev": true, + "license": "MIT", + "dependencies": { + "fdir": "^6.4.4", + "picomatch": "^4.0.2" + }, + "engines": { + "node": ">=12.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/SuperchupuDev" + } + }, "node_modules/vite": { - "version": "6.2.6", - "resolved": "https://registry.npmjs.org/vite/-/vite-6.2.6.tgz", - "integrity": "sha512-9xpjNl3kR4rVDZgPNdTL0/c6ao4km69a/2ihNQbcANz8RuCOK3hQBmLSJf3bRKVQjVMda+YvizNE8AwvogcPbw==", + "version": "6.3.5", + "resolved": "https://registry.npmjs.org/vite/-/vite-6.3.5.tgz", + "integrity": "sha512-cZn6NDFE7wdTpINgs++ZJ4N49W2vRp8LCKrn3Ob1kYNtOo21vfDoaV5GzBfLU4MovSAB8uNRm4jgzVQZ+mBzPQ==", "dev": true, "license": "MIT", "dependencies": { "esbuild": "^0.25.0", + "fdir": "^6.4.4", + "picomatch": "^4.0.2", "postcss": "^8.5.3", - "rollup": "^4.30.1" + "rollup": "^4.34.9", + "tinyglobby": "^0.2.13" }, "bin": { "vite": "bin/vite.js" diff --git a/js/web/test/e2e/exports/testcases/vite-default/package.json b/js/web/test/e2e/exports/testcases/vite-default/package.json index 5169734074299..f7d5751354905 100644 --- a/js/web/test/e2e/exports/testcases/vite-default/package.json +++ b/js/web/test/e2e/exports/testcases/vite-default/package.json @@ -13,6 +13,6 @@ }, "devDependencies": { "@vitejs/plugin-vue": "^5.2.1", - "vite": "^6.2.6" + "vite": "^6.3.5" } } diff --git a/js/web/test/unittests/backends/webgl/test-pack-unpack.ts b/js/web/test/unittests/backends/webgl/test-pack-unpack.ts index 28821663ffd50..bae0a9f3e8fc9 100644 --- a/js/web/test/unittests/backends/webgl/test-pack-unpack.ts +++ b/js/web/test/unittests/backends/webgl/test-pack-unpack.ts @@ -257,7 +257,7 @@ describe('#UnitTest# - pack - Tensor pack', () => { for (let k = 0; k < testDataSet.length; ++k) { const testData = testDataSet[k]; describe('Test pack', () => {}); - it(`Test pack kernal ${textureLayout[w]} ${JSON.stringify(testData)}`, () => { + it(`Test pack kernel ${textureLayout[w]} ${JSON.stringify(testData)}`, () => { const webglInferenceHandler = inferenceHandler as WebGLInferenceHandler; const elementCount = testData.elementCount; @@ -291,7 +291,7 @@ describe('#UnitTest# - pack - Tensor pack', () => { // compile shader code const programInfo = createPackProgramInfoLoader(inferenceHandler! as WebGLInferenceHandler, inputTensor); - // run kernal and get output + // run kernel and get output const resultTextureData = webglInferenceHandler.executeProgram(programInfo, [inputTensor]); const gl = webglInferenceHandler.session.textureManager.glContext.gl; const resultDataBuffer = createArrayFromTexture( @@ -324,7 +324,7 @@ describe('#UnitTest# - unpack - Tensor unpack', () => { for (let k = 0; k < testDataSet.length; ++k) { const testData = testDataSet[k]; describe(`Test unpack ${JSON.stringify(testData)}`, () => {}); - it(`Test unpack kernal ${testData.inputShape}`, () => { + it(`Test unpack kernel ${testData.inputShape}`, () => { const webglInferenceHandler = inferenceHandler as WebGLInferenceHandler; const elementCount = testData.elementCount; @@ -365,7 +365,7 @@ describe('#UnitTest# - unpack - Tensor unpack', () => { // compile shader code const programInfo = createUnpackProgramInfoLoader(inferenceHandler! as WebGLInferenceHandler, inputTensor); - // run kernal and get output + // run kernel and get output const resultTextureData = webglInferenceHandler.executeProgram(programInfo, [inputTensor]); const result = resultTextureData.tensor.data; @@ -419,7 +419,7 @@ describe('#UnitTest# - pack-unpack round trip', () => { packResultData.tensor, ); - // run unpack kernal and get output + // run unpack kernel and get output const unpackResultData = webglInferenceHandler.executeProgram(unpackProgramInfo, [inputTensor]); const resultData = unpackResultData.tensor.data; diff --git a/js/web/test/unittests/backends/webgl/test-reshape-packed.ts b/js/web/test/unittests/backends/webgl/test-reshape-packed.ts index b90372db1250a..61abcadc5b98b 100644 --- a/js/web/test/unittests/backends/webgl/test-reshape-packed.ts +++ b/js/web/test/unittests/backends/webgl/test-reshape-packed.ts @@ -134,7 +134,7 @@ describe('#UnitTest# - reshape - packed', () => { const inputData = createAscendingArray(elementCount); const inputTensorA = new Tensor(inputTensorShape, 'float32', undefined, undefined, inputData); - // run kernal and get output + // run kernel and get output const resultTensor = webglInferenceHandler.reshapePacked(inputTensorA, outputTensorShape); const result = resultTensor.data; diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index f8e13f7629e6b..6ef0707f4b7c6 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -30,6 +30,10 @@ NodeArg, # noqa: F401 OrtAllocatorType, # noqa: F401 OrtArenaCfg, # noqa: F401 + OrtEpDevice, # noqa: F401 + OrtExecutionProviderDevicePolicy, # noqa: F401 + OrtHardwareDevice, # noqa: F401 + OrtHardwareDeviceType, # noqa: F401 OrtMemoryInfo, # noqa: F401 OrtMemType, # noqa: F401 OrtSparseFormat, # noqa: F401 @@ -44,11 +48,14 @@ get_available_providers, # noqa: F401 get_build_info, # noqa: F401 get_device, # noqa: F401 + get_ep_devices, # noqa: F401 get_version_string, # noqa: F401 has_collective_ops, # noqa: F401 + register_execution_provider_library, # noqa: F401 set_default_logger_severity, # noqa: F401 set_default_logger_verbosity, # noqa: F401 set_seed, # noqa: F401 + unregister_execution_provider_library, # noqa: F401 ) import_capi_exception = None @@ -64,6 +71,7 @@ AdapterFormat, # noqa: F401 InferenceSession, # noqa: F401 IOBinding, # noqa: F401 + ModelCompiler, # noqa: F401 OrtDevice, # noqa: F401 OrtValue, # noqa: F401 SparseTensor, # noqa: F401 diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index eb5de4634a4d8..33f90d0a5f791 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -143,7 +143,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { const max_k_step: u32 = 16u; const vec_factor: u32 = 4u; const qkv_head_size_vec: u32 = qkv_head_size / vec_factor; - const min_value : q_element_t = q_element_t(-65504.0); + const min_value = f32(-3.402823e+38f);; // Default SHM usage limit is 16KB in Dawn. // vec4 * qkv_head_size_vec * max_k_step = 8 * (128/4) * 16 = 4KB. 128 is head_size for phi4. @@ -195,32 +195,32 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // Move half of o_tile from private memory into workgroup memory to reduce register pressure. // Note that register spill was observed on Qualcomm if whole o_tile is on private memory. // vec4 * half_qkv_head_size_vec * workgroup_size_x = 8 * (128/4/2) * 64 = 8KB. - var o_tile_r : array, workgroup_size_x>; + var o_tile_r : array, half_qkv_head_size_vec>, workgroup_size_x>; // Private memory per lane. - var o_tile : array; + var o_tile : array, half_qkv_head_size_vec>; fn writeo(o_idx_global: u32, head_idx: u32, local_idx: u32) { // Stored as float16[batch_size,sequence_length,3072] let offset = o_idx_global * num_heads * qkv_head_size_vec + head_idx * qkv_head_size_vec; for (var idx:u32 = 0; idx < half_qkv_head_size_vec; idx ++) { - output[offset+idx] = o_tile[idx]; - output[offset+idx+half_qkv_head_size_vec] = o_tile_r[local_idx][idx]; + output[offset+idx] = q_value_t(o_tile[idx]); + output[offset+idx+half_qkv_head_size_vec] = q_value_t(o_tile_r[local_idx][idx]); } } )HELPER_FN"; } else { shader.AdditionalImplementation() << R"HELPER_FN( // Private memory per lane. - var o_tile : array; + var o_tile : array, qkv_head_size_vec>; fn writeo(o_idx_global: u32, head_idx: u32) { // Stored as float16[batch_size,sequence_length,3072] let offset = o_idx_global * num_heads * qkv_head_size_vec + head_idx * qkv_head_size_vec; for (var idx:u32 = 0; idx < qkv_head_size_vec; idx ++) { - output[offset+idx] = o_tile[idx]; + output[offset+idx] = q_value_t(o_tile[idx]); } } )HELPER_FN"; @@ -268,8 +268,8 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { loadq(q_idx_global, head_idx); } - var previous_max : q_element_t = min_value; - var previous_denom : q_element_t = 0; + var previous_max : f32 = min_value; + var previous_denom : f32 = 0; for(var k_start = 0u; k_start < uniforms.total_sequence_length; k_start+=capped_sg_size) { @@ -279,16 +279,16 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { workgroupBarrier(); // Compute QKt - var qk_1:vec4; - var qk_2:vec4; - var qk_3:vec4; - var qk_4:vec4; + var qk_1:vec4; + var qk_2:vec4; + var qk_3:vec4; + var qk_4:vec4; if (sg_size > 8) { for (var i:u32 = 0u; i < qkv_head_size_vec; i++) { - var k_local = k_tile[capped_sg_id][i]; - var q_own = q_tile[i]; + var k_local = vec4(k_tile[capped_sg_id][i]); + var q_own = vec4(q_tile[i]); qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0)); qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1)); qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2)); @@ -311,8 +311,8 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { { for (var i:u32 = 0u; i < qkv_head_size_vec; i++) { - var k_local = k_tile[capped_sg_id][i]; - var q_own = q_tile[i]; + var k_local = vec4(k_tile[capped_sg_id][i]); + var q_own = vec4(q_tile[i]); qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0)); qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1)); qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2)); @@ -324,12 +324,12 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { } } - qk_1 = qk_1 * q_element_t(uniforms.alpha) + loadAttentionBias(q_idx_global, k_start, head_idx); - qk_2 = qk_2 * q_element_t(uniforms.alpha) + loadAttentionBias(q_idx_global, k_start+4, head_idx); + qk_1 = qk_1 * uniforms.alpha + vec4(loadAttentionBias(q_idx_global, k_start, head_idx)); + qk_2 = qk_2 * uniforms.alpha + vec4(loadAttentionBias(q_idx_global, k_start+4, head_idx)); if (sg_size > 8) { - qk_3 = qk_3 * q_element_t(uniforms.alpha) + loadAttentionBias(q_idx_global, k_start+8, head_idx); - qk_4 = qk_4 * q_element_t(uniforms.alpha) + loadAttentionBias(q_idx_global, k_start+12, head_idx); + qk_3 = qk_3 * uniforms.alpha + vec4(loadAttentionBias(q_idx_global, k_start+8, head_idx)); + qk_4 = qk_4 * uniforms.alpha + vec4(loadAttentionBias(q_idx_global, k_start+12, head_idx)); } let seq_causal_length = select(uniforms.total_sequence_length, uniforms.past_sequence_length + q_idx_global + 1, uniforms.is_gqa > 0); @@ -386,11 +386,11 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { } let local_max = max(max(local_max_temp.x, local_max_temp.y),max(local_max_temp.z, local_max_temp.w)); let new_max = max(previous_max, local_max); - qk_1 = q_value_t(exp(vec4(qk_1) - f32(new_max))); - qk_2 = q_value_t(exp(vec4(qk_2) - f32(new_max))); + qk_1 = exp(qk_1 - new_max); + qk_2 = exp(qk_2 - new_max); if (sg_size > 8) { - qk_3 = q_value_t(exp(vec4(qk_3) - f32(new_max))); - qk_4 = q_value_t(exp(vec4(qk_4) - f32(new_max))); + qk_3 = exp(qk_3 - new_max); + qk_4 = exp(qk_4 - new_max); } let sum_vec = qk_1 + qk_2 + qk_3 + qk_4; let sum = sum_vec.x + sum_vec.y + sum_vec.z + sum_vec.w; @@ -398,7 +398,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // Compute lhs term of update di prime and the compute di prime. let dleft = previous_denom * exp(previous_max-new_max); var d = dleft + sum; - d = select(d,q_element_t(0.0000001),d==0); + d = select(d,f32(0.0000001),d==0); qk_1 = qk_1 / d; qk_2 = qk_2 / d; if (sg_size > 8) { @@ -416,7 +416,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { if (sg_size > 8) { for (var i:u32 = 0; i < half_qkv_head_size_vec; i++) { - var val = v_tile[capped_sg_id][i]; + var val = vec4(v_tile[capped_sg_id][i]); var sum = subgroupShuffle(val, 0) * qk_1[0]; sum += subgroupShuffle(val, 1) * qk_1[1]; sum += subgroupShuffle(val, 2) * qk_1[2]; @@ -435,7 +435,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { sum += subgroupShuffle(val, 15) * qk_4[3]; o_tile[i] = o_tile[i] * o_ratio + sum; - val = v_tile[capped_sg_id][half_qkv_head_size_vec + i]; + val = vec4(v_tile[capped_sg_id][half_qkv_head_size_vec + i]); sum = subgroupShuffle(val, 0) * qk_1[0]; sum += subgroupShuffle(val, 1) * qk_1[1]; sum += subgroupShuffle(val, 2) * qk_1[2]; @@ -459,7 +459,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { { for (var i:u32 = 0; i < half_qkv_head_size_vec; i++) { - var val = v_tile[capped_sg_id][i]; + var val = vec4(v_tile[capped_sg_id][i]); var sum = subgroupShuffle(val, 0) * qk_1[0]; sum += subgroupShuffle(val, 1) * qk_1[1]; sum += subgroupShuffle(val, 2) * qk_1[2]; @@ -470,7 +470,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { sum += subgroupShuffle(val, 7) * qk_2[3]; o_tile[i] = o_tile[i] * o_ratio + sum; - val = v_tile[capped_sg_id][half_qkv_head_size_vec + i]; + val = vec4(v_tile[capped_sg_id][half_qkv_head_size_vec + i]); sum = subgroupShuffle(val, 0) * qk_1[0]; sum += subgroupShuffle(val, 1) * qk_1[1]; sum += subgroupShuffle(val, 2) * qk_1[2]; @@ -493,7 +493,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { if (sg_size > 8) { for (var i:u32 = 0; i < qkv_head_size_vec; i++) { - var val = v_tile[capped_sg_id][i]; + var val = vec4(v_tile[capped_sg_id][i]); var sum = subgroupShuffle(val, 0) * qk_1[0]; sum += subgroupShuffle(val, 1) * qk_1[1]; sum += subgroupShuffle(val, 2) * qk_1[2]; @@ -517,7 +517,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { { for (var i:u32 = 0; i < qkv_head_size_vec; i++) { - var val = select(vec4(0), v_tile[capped_sg_id][i], k_start + capped_sg_id < seq_causal_length); + var val = vec4(v_tile[capped_sg_id][i]); var sum = subgroupShuffle(val, 0) * qk_1[0]; sum += subgroupShuffle(val, 1) * qk_1[1]; sum += subgroupShuffle(val, 2) * qk_1[2]; @@ -547,7 +547,7 @@ Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) shader.AddInput("attention_bias", ShaderUsage::UseUniform); } shader.AddOutput("output", ShaderUsage::UseUniform); - shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); // Note that this shader adopts similar algorithm with dp4a generation shader. // // This algorithm works to compute dot product of keys with queries parallelly, by processing on the k (head_size) dimension at each step amongst tile_size_k_vec threads, @@ -569,8 +569,8 @@ Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) << "const sub_tile_count = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n"; shader.AdditionalImplementation() << R"ADDNL_FN( var tile_q: array; -var inner_qk_values: array, tile_size>; -var tile_qk: array; +var inner_qk_values: array, tile_size>; +var tile_qk: array; )ADDNL_FN"; if (has_attention_bias_) { @@ -602,30 +602,34 @@ var tile_qk: array; tile_q[local_idx] = q[q_offset + k + local_idx]; } workgroupBarrier(); - let q_data = tile_q[local_col]; + let q_data = vec4(tile_q[local_col]); if (k + local_col < uniforms.head_size_vec) { for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { if (total_seq_offset + row_offset + local_row < total_sequence_length) { - inner_qk_values[row_offset + local_row][local_col] += dot(present_key[present_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col], q_data); + inner_qk_values[row_offset + local_row][local_col] += dot(vec4(present_key[present_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col]), q_data); } } } workgroupBarrier(); } - if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { - var sum = q_element_t(0); + if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length && head_idx < uniforms.num_heads) { + var sum = f32(0); for (var i = 0u; i < tile_size_k_vec; i++) { sum += inner_qk_values[local_idx][i]; } let output_idx = head_idx * total_sequence_length + total_seq_offset + local_idx; - sum = sum * q_element_t(uniforms.alpha) + loadAttentionBias(output_idx); + sum = sum * uniforms.alpha + f32(loadAttentionBias(output_idx)); tile_qk[local_idx] = sum; output[output_idx] = sum; } workgroupBarrier(); + if (head_idx >= uniforms.num_heads) { + return; + } + if (local_idx == 0u) { // Calculate the max and sum in current split. var l_max = f32(-3.402823e+38f); @@ -637,7 +641,7 @@ var tile_qk: array; l_sum += exp(f32(tile_qk[i]) - l_max); } let meta_offset = head_idx * uniforms.num_total_seq_length_tile + workgroup_idx % uniforms.num_total_seq_length_tile; - metadata[meta_offset] = metadata_value_t(metadata_element_t(l_max), metadata_element_t(l_sum)); + metadata[meta_offset] = metadata_value_t(l_max, l_sum); } )MAIN_FN"; @@ -672,15 +676,16 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte // present_sequence_length is used to index into the KV cache, for static kv cache it is the max sequence length. {static_cast(parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_)}, {static_cast(parameters.n_reps)}, - {num_total_seq_length_tile}}); + {num_total_seq_length_tile}, + {static_cast(parameters.num_heads_)}}); return context.RunProgram(program); } Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddInput("metadata", ShaderUsage::UseUniform); - shader.AddInput("qk", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("qk", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); shader.AddOutput("out_split_vx", ShaderUsage::UseUniform); // Note that this shader adopts similar algorithm with dp4a generation shader. @@ -700,7 +705,7 @@ Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shad << "const tile_size_k_vec = " << tile_size_k_vec << "u;\n" << "const sub_tile_count = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n"; shader.AdditionalImplementation() << R"HELPER_FN( -var tile_qk: array; +var tile_qk: array; var tile_output: array; var qkv_values: array, sub_tile_count>; @@ -718,24 +723,26 @@ var qkv_values: array, let present_offset = u32(head_idx / uniforms.n_reps) * uniforms.head_size_vec * uniforms.present_sequence_length; // Calculate the global max and sum in qk. - var g_max = f32(-3.402823e+38f); - for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++) - { - let meta_offset = head_idx * uniforms.num_total_seq_length_tile + i; - g_max = max(g_max, f32(metadata[meta_offset].x)); - } - var g_sum = f32(0); - for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++) + if (head_idx < uniforms.num_heads) { - let meta_offset = head_idx * uniforms.num_total_seq_length_tile + i; - let m_value = metadata[meta_offset]; - g_sum += exp(f32(m_value.x) - g_max) * f32(m_value.y); - } + var g_max = f32(-3.402823e+38f); + var g_sum = f32(0); + for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++) + { + let meta_offset = head_idx * uniforms.num_total_seq_length_tile + i; + g_max = max(g_max, metadata[meta_offset].x); + } + for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++) + { + let meta_offset = head_idx * uniforms.num_total_seq_length_tile + i; + let m_value = metadata[meta_offset]; + g_sum += exp(m_value.x - g_max) * m_value.y; + } - if (total_seq_offset + local_idx < total_sequence_length) { - tile_qk[local_idx] = qk_value_t(exp(f32(qk[head_idx * total_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum); + if (total_seq_offset + local_idx < total_sequence_length) { + tile_qk[local_idx] = present_value_element_t(exp(qk[head_idx * total_sequence_length + total_seq_offset + local_idx] - g_max) / g_sum); + } } - for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) { var value = present_value_value_t(0); qkv_values[local_row][local_col] = present_value_value_t(0); @@ -760,6 +767,10 @@ var qkv_values: array, workgroupBarrier(); } + if (head_idx >= uniforms.num_heads) { + return; + } + for (var i = local_idx; i < uniforms.head_size_vec; i += workgroup_size_x) { let out_offset = head_idx * uniforms.num_total_seq_length_tile * uniforms.head_size_vec + (workgroup_idx % uniforms.num_total_seq_length_tile) * uniforms.head_size_vec + i; out_split_vx[out_offset] = tile_output[i]; @@ -791,7 +802,8 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte {static_cast(head_size_vec)}, {static_cast(parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_)}, {static_cast(parameters.n_reps)}, - num_total_seq_length_tile}); + num_total_seq_length_tile, + {static_cast(parameters.num_heads_)}}); return context.RunProgram(program); } @@ -830,6 +842,10 @@ var tile_input: array, TILE_SIZE>; tile_input[local_row][local_col] = value; workgroupBarrier(); + if (head_idx >= uniforms.num_heads) { + return; + } + if (local_idx < TILE_SIZE && head_size_offset + local_idx < uniforms.head_size_vec) { value = output_value_t(0); for (var i = 0u; i < TILE_SIZE; i++) { @@ -860,7 +876,8 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& .SetWorkgroupSize(tile_size * tile_size) .AddUniformVariables({{static_cast(parameters.v_head_size_ / components)}, num_total_seq_length_tile, - {num_head_size_tile}}); + {num_head_size_tile}, + {static_cast(parameters.num_heads_)}}); return context.RunProgram(program); } @@ -903,14 +920,14 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const TensorShapeVector qk_dims({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.total_sequence_length_}); const TensorShape qk_shape(qk_dims); - Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape); + Tensor qk = context.CreateGPUTensor(DataTypeImpl::GetType(), qk_shape); constexpr uint32_t tile_size = 64; const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size; // The metadata is used to store the max and sum of each tile. const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_, num_total_seq_length_tile, 2}); const TensorShape metadata_shape(metadata_dims); - Tensor metadata = context.CreateGPUTensor(Q->DataType(), metadata_shape); + Tensor metadata = context.CreateGPUTensor(DataTypeImpl::GetType(), metadata_shape); ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKT(context, Q, attention_bias, &qk, present_key, &metadata, parameters, num_total_seq_length_tile, tile_size)); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 181e411cdc91f..bfa4e7cb6d53f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -78,7 +78,8 @@ class FlashAttentionDecodeQKTProgram final : public Program #include #include "contrib_ops/webgpu/quantization/matmul_nbits.h" @@ -18,17 +19,42 @@ namespace webgpu { namespace { -std::string QuantizedDataType(int components) { - switch (components) { - case 1: - return "array"; - case 2: - return "mat4x2"; - case 4: - return "mat2x4"; - default: - return "array"; +std::string ReadZeroPoint(uint32_t nbits, bool has_zero_points) { + ORT_ENFORCE(nbits == 8 || nbits == 4, "Only 4/8 bits are supported for webgpu matmulnbits"); + std::stringstream ss; + if (has_zero_points) { + ss << "const elements_in_uint32 = " << (32 / nbits) << "u;\n" + << "const bits = " << nbits << "u;\n"; + ss << R"( +fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> output_element_t { + if (row < r_dim && col < c_dim) { + let offset = row * c_dim + col; + + // u32 holds elements_in_uint32 packed nbits. + let array_index = offset / elements_in_uint32; + let component_index = offset % elements_in_uint32; + let packed_value = zero_points[array_index]; + + // Extract the nbits component + let shift_amount = component_index * bits; +)"; + ss << " let masked_value = (packed_value >> shift_amount) & " << (nbits == 4 ? "0xFu" : "0xFF") << ";\n"; + ss << R"( + return output_element_t(masked_value); } + return output_element_t(0); +} +)"; + } else { + ss << "const default_zero_point = " << (nbits == 4 ? 8 : 128) << ";\n"; + ss << R"( +fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> output_element_t { + // The default zero point is 8. + return output_element_t(default_zero_point); +} +)"; + } + return ss.str(); } constexpr unsigned int kMinMForTileOptimization = 4; @@ -46,483 +72,6 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T4", DataTypeImpl::GetTensorType()), MatMulNBits); -Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); - const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); - - if (block_size_ == 32) { - const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); - const uint32_t tile_size = WorkgroupSizeX() * components_b_ * 8; // each uint32 has 8 data. - const uint32_t a_length_per_tile = tile_size / a.NumComponents(); - const uint32_t blocks_per_tile = tile_size / block_size_; - if (tile_m_ > 1 && use_subgroup_) { - ORT_ENFORCE(a.NumComponents() == 4, "input a's components must be equal to 4."); - ORT_ENFORCE(components_b_ == 4, "input b's components must be equal to 4."); - shader.AdditionalImplementation() << "fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {\n" - " if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n" - << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n" - << " } else {\n" - " return input_a_value_t(0);\n" - " }\n" - "}\n" - << "var sub_b: array, " << WorkgroupSizeY() << ">;\n" - << "var sub_scale: array, " << WorkgroupSizeY() << ">;\n" - << "var inter_results: array, " << WorkgroupSizeY() << ">," << tile_m_ << ">;\n"; - shader.MainFunctionBody() << " let col = workgroup_id.x * " << WorkgroupSizeY() << ";\n" - << " let row = workgroup_id.y * " << tile_m_ << ";\n" - << " let batch = workgroup_id.z;\n"; - shader.MainFunctionBody() << " let n_blocks_per_col = uniforms.input_b_shape[1];\n" - << " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n" - // Loop over shared dimension. - << " for (var tile: u32 = 0; tile < num_tiles; tile += 1) {\n" - << " // load one tile B/scale data into shared memory.\n" - // Each thread processes one block. - " let b_col = col + local_id.y;\n" - << " let block = tile * " << blocks_per_tile << " + local_id.x;\n" - << " if (b_col < uniforms.input_b_shape[0] && block < n_blocks_per_col) {\n" - << " sub_b[local_id.y][local_id.x] = " << b.GetByIndices("input_b_indices_t(b_col, block, 0)") << ";\n" - << " sub_scale[local_id.y][local_id.x] = " << scales.GetByOffset("b_col * n_blocks_per_col + block") << ";\n" - << " } else {\n" - " sub_b[local_id.y][local_id.x] = input_b_value_t(0);\n" - " sub_scale[local_id.y][local_id.x] = output_value_t(0);\n" - " }\n" - " workgroupBarrier();\n" - << " var in_y = (local_idx % 32) / 4;\n" - " var in_x = (local_idx / 32) * 4 + local_idx % 4;\n" - << " var word_offset = (local_idx % 4) * " << block_size_ / a.NumComponents() << ";\n" - << " if (sg_size == 8u) {\n" - " in_y = local_idx % 8;\n" - " in_x = local_idx / 8;\n" - << " word_offset = 0u;\n" - " } else if (sg_size == 16u) {\n" - " in_y = (local_idx % 16) / 2;\n" - " in_x = (local_idx / 16) * 2 + local_idx % 2;\n" - << " word_offset = (local_idx % 2) * " << block_size_ / a.NumComponents() << ";\n" - << " } else if (sg_size == 32u) {\n" - " in_y = (local_idx % 32) / 4;\n" - " in_x = (local_idx / 32) * 4 + local_idx % 4;\n" - << " word_offset = (local_idx % 4) * " << block_size_ / a.NumComponents() << ";\n" - << " } else if (sg_size == 64u) {\n" - " in_y = local_idx / 8;\n" - " in_x = local_idx % 8;\n" - << " word_offset = (local_idx % 8) * " << block_size_ / a.NumComponents() << ";\n" - << " }\n"; - if (has_zero_points_) { - const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); - shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" - " let zero_point_byte_count = b_col * zero_point_bytes_per_col + (block >> 0x1u);\n" - " let zero_point_word_index = zero_point_byte_count >> 0x2u;\n" - " let zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" - " let zero_point_nibble_offset: u32 = block & 0x1u;\n" - " let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" - << " let zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" - << " let zero_point = output_element_t((zero_point_word) & 0xFu);\n"; - } else { - // The default zero point is 8 for unsigned 4-bit quantization. - shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; - } - shader.MainFunctionBody() << " let scale = sub_scale[in_y][in_x];\n" - " let b_data = sub_b[in_y][in_x];\n"; - shader.MainFunctionBody() << " let a_col_start = tile * " << a_length_per_tile << ";\n"; - for (uint32_t i = 0; i < tile_m_; i++) { - shader.MainFunctionBody() << " let a_data" << i << " = mm_readA(batch, row + " << i << ", a_col_start + local_idx);\n"; - } - - shader.MainFunctionBody() << " if (sg_size == 8u) {\n"; - shader.MainFunctionBody() << " for (var i: u32 = 0; i < 4; i++) {\n"; - shader.MainFunctionBody() << " let b_value = b_data[i];\n" - " let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);\n" - " let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);\n" - " let b_quantized_values = mat2x4(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" - " let b_dequantized_values = (b_quantized_values - mat2x4(zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point)) * scale;\n"; - for (uint32_t i = 0; i < tile_m_; i++) { - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a0 = subgroupShuffle(a_data" << i << ", i * 2);\n"; - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a1 = subgroupShuffle(a_data" << i << ", i * 2 + 1);\n"; - shader.MainFunctionBody() << " inter_results[" << i << "][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);\n"; - } - shader.MainFunctionBody() << " }\n"; - shader.MainFunctionBody() << " } else if (sg_size == 16u) {\n"; - shader.MainFunctionBody() << " for (var i: u32 = 0; i < 4; i++) {\n"; - shader.MainFunctionBody() << " let b_value = b_data[i];\n" - " let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);\n" - " let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);\n" - " let b_quantized_values = mat2x4(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" - " let b_dequantized_values = (b_quantized_values - mat2x4(zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point)) * scale;\n"; - for (uint32_t i = 0; i < tile_m_; i++) { - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a0 = subgroupShuffle(a_data" << i << ", i * 2);\n"; - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a00 = subgroupShuffle(a_data" << i << ", i * 2 + 8);\n"; - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a1 = subgroupShuffle(a_data" << i << ", i * 2 + 1);\n"; - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a11 = subgroupShuffle(a_data" << i << ", i * 2 + 9);\n"; - shader.MainFunctionBody() << " inter_results[" << i << "][in_y][in_x] += dot(select(a00, a0, local_idx % 2 == 0), b_dequantized_values[0]) + dot(select(a11, a1, local_idx % 2 == 0), b_dequantized_values[1]);\n"; - } - shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n" - << " }\n"; - shader.MainFunctionBody() << " } else {\n"; - shader.MainFunctionBody() << " for (var i: u32 = 0; i < 4; i++) {\n"; - shader.MainFunctionBody() << " let b_value = b_data[i];\n" - " let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);\n" - " let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);\n" - " let b_quantized_values = mat2x4(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" - " let b_dequantized_values = (b_quantized_values - mat2x4(zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point)) * scale;\n"; - for (uint32_t i = 0; i < tile_m_; i++) { - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a0 = subgroupShuffle(a_data" << i << ", word_offset);\n"; - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a1 = subgroupShuffle(a_data" << i << ", word_offset + 1);\n"; - shader.MainFunctionBody() << " inter_results[" << i << "][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);\n"; - } - shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n"; - shader.MainFunctionBody() << " }\n"; - shader.MainFunctionBody() << " }\n"; - shader.MainFunctionBody() << " workgroupBarrier();\n"; - - shader.MainFunctionBody() << " }\n"; - shader.MainFunctionBody() << " if (local_idx < " << WorkgroupSizeY() * tile_m_ << ") {\n" - << " let inner_row = local_idx / " << WorkgroupSizeY() << ";\n" - << " let inner_col = local_idx % " << WorkgroupSizeY() << ";\n" - << " var output_value = output_value_t(0);\n" - << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" - << " output_value += inter_results[inner_row][inner_col][b];\n" - " }\n" - " if (row + inner_row < uniforms.output_shape[1] && col + inner_col < uniforms.output_shape[2]) {\n" - << " " << y.SetByIndices("output_indices_t(batch, row + inner_row, col + inner_col)", "output_value") << ";\n" - << " }\n" - " }\n"; - } else { - if (tile_m_ == 1) { - shader.AdditionalImplementation() << "fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {\n" - " if (col < uniforms.input_a_shape[2]) {\n" - << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n" - << " } else {\n" - " return input_a_value_t(0);\n" - " }\n" - "}\n" - << "var sub_a: array;\n" - << "var inter_results: array, " << WorkgroupSizeY() << ">;\n"; - std::string offset = "workgroup_idx * " + std::to_string(WorkgroupSizeY()); - shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n" - << " let col = output_indices[2];\n" - " let row = output_indices[1];\n" - " let batch = output_indices[0];\n"; - } else { - ORT_ENFORCE(tile_m_ < WorkgroupSizeY(), "tile_m must be less than or equal to WorkgroupSizeY."); - ORT_ENFORCE(WorkgroupSizeX() == WorkgroupSizeY(), "WorkgroupSizeX must be equal to WorkgroupSizeY."); - - shader.AdditionalImplementation() << "fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {\n" - " if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n" - << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n" - << " } else {\n" - " return input_a_value_t(0);\n" - " }\n" - "}\n" - << "var sub_a: array," << tile_m_ << ">;\n" - << "var inter_results: array, " << WorkgroupSizeY() << ">," << tile_m_ << ">;\n"; - shader.MainFunctionBody() << " let col = workgroup_id.x * " << WorkgroupSizeY() << ";\n" - << " let row = workgroup_id.y * " << tile_m_ << ";\n" - << " let batch = workgroup_id.z;\n"; - } - shader.MainFunctionBody() << " let n_blocks_per_col = uniforms.input_b_shape[1];\n" - << " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n" - // Loop over shared dimension. - << " for (var tile: u32 = 0; tile < num_tiles; tile += 1) {\n" - << " let a_col_start = tile * " << a_length_per_tile << ";\n" - << " // load one tile A data into shared memory.\n" - << " for (var a_offset = local_idx; a_offset < " << a_length_per_tile << "; a_offset += " << workgroup_size << ") {\n" - << " let a_col = a_col_start + a_offset;\n"; - if (tile_m_ == 1) { - shader.MainFunctionBody() << " sub_a[a_offset] = mm_readA(batch, row, a_col);\n"; - } else { - for (uint32_t i = 0; i < tile_m_; i++) { - shader.MainFunctionBody() << " sub_a[" << i << "][a_offset] = mm_readA(batch, row + " << i << ", a_col);\n"; - } - } - shader.MainFunctionBody() << " }\n" - " workgroupBarrier();\n" - // Each thread processes one block. - " let b_row = col + local_id.y;\n" - << " let block = tile * " << blocks_per_tile << " + local_id.x;\n"; - if (has_zero_points_) { - const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); - shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" - " let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);\n" - " let zero_point_word_index = zero_point_byte_count >> 0x2u;\n" - " let zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" - " let zero_point_nibble_offset: u32 = block & 0x1u;\n" - " let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" - << " let zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" - << " let zero_point = output_element_t((zero_point_word) & 0xFu);\n"; - } else { - // The default zero point is 8 for unsigned 4-bit quantization. - shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; - } - shader.MainFunctionBody() << " var scale = output_element_t(0);\n" - " var b_data = input_b_value_t(0);\n" - << " if (block < n_blocks_per_col) {\n" - << " scale = " << scales.GetByOffset("b_row * n_blocks_per_col + block") << ";\n" - << " b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n" - << " }\n" - << " var word_offset = local_id.x * " << block_size_ / a.NumComponents() << ";\n" - << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; - shader.MainFunctionBody() << " let b_value = b_data"; - if (components_b_ > 1) { - shader.MainFunctionBody() << "[i]"; - } - shader.MainFunctionBody() << ";\n" - " let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);\n" - " let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);\n" - " let b_quantized_values = mat2x4(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" - " let b_dequantized_values = (b_quantized_values - mat2x4("; - for (int i = 0; i < 8; i++) { - shader.MainFunctionBody() << "zero_point"; - if (i < 7) { - shader.MainFunctionBody() << ", "; - } - } - shader.MainFunctionBody() << ")) * scale;\n"; - if (tile_m_ == 1) { - switch (a.NumComponents()) { - case 1: - shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(vec4(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]), b_dequantized_values[0]) + dot(vec4(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]), b_dequantized_values[1]);\n"; - break; - case 2: - shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(vec4(sub_a[word_offset], sub_a[word_offset + 1]), b_dequantized_values[0]) + dot(vec4(sub_a[word_offset + 2], sub_a[word_offset + 3]), b_dequantized_values[1]);\n"; - break; - case 4: - shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(sub_a[word_offset], b_dequantized_values[0]) + dot(sub_a[word_offset + 1], b_dequantized_values[1]);\n"; - break; - default: - break; - } - } else { - for (uint32_t i = 0; i < tile_m_; i++) { - switch (a.NumComponents()) { - case 1: - shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(vec4(sub_a[" << i << "][word_offset], sub_a[" << i << "][word_offset + 1], sub_a[" << i << "][word_offset + 2], sub_a[" << i << "][word_offset + 3]), b_dequantized_values[0]) + dot(vec4(sub_a[" << i << "][word_offset + 4], sub_a[" << i << "][word_offset + 5], sub_a[" << i << "][word_offset + 6], sub_a[" << i << "][word_offset + 7]), b_dequantized_values[1]);\n"; - break; - case 2: - shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(vec4(sub_a[" << i << "][word_offset], sub_a[" << i << "][word_offset + 1]), b_dequantized_values[0]) + dot(vec4(sub_a[" << i << "][word_offset + 2], sub_a[" << i << "][word_offset + 3]), b_dequantized_values[1]);\n"; - break; - case 4: - shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(sub_a[" << i << "][word_offset], b_dequantized_values[0]) + dot(sub_a[" << i << "][word_offset + 1], b_dequantized_values[1]);\n"; - break; - default: - break; - } - } - } - shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n" - << " }\n" - " workgroupBarrier();\n" - " }\n"; - if (tile_m_ == 1) { - shader.MainFunctionBody() << " if (local_idx < " << WorkgroupSizeY() << ") {\n" - << " var output_value = output_value_t(0);\n" - << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" - << " output_value += inter_results[local_idx][b];\n" - " }\n" - " if (col + local_idx < uniforms.output_shape[2]) {\n" - << " " << y.SetByIndices("output_indices_t(batch, row, col + local_idx)", "output_value") << ";\n" - << " }\n" - " }\n"; - } else { - shader.MainFunctionBody() << " if (local_id.y < " << tile_m_ << ") {\n" - << " var output_value = output_value_t(0);\n" - << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" - << " output_value += inter_results[local_id.y][local_id.x][b];\n" - " }\n" - " if (row + local_id.y < uniforms.output_shape[1] && col + local_id.x < uniforms.output_shape[2]) {\n" - << " " << y.SetByIndices("output_indices_t(batch, row + local_id.y, col + local_id.x)", "output_value") << ";\n" - << " }\n" - " }\n"; - } - } - } else { - const std::string quantized_data_type = QuantizedDataType(a.NumComponents()); - const int output_element_number = y.NumComponents() * onnxruntime::narrow(output_number_); - - const uint32_t shared_memory_size = output_number_ * WORKGROUP_SIZE; - std::string offset = "workgroup_idx * " + std::to_string(output_number_); - shader.AdditionalImplementation() << "var workgroup_shared : array;\n"; - shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n" - << " let col = output_indices[2];\n" - " let row = output_indices[1];\n" - " let batch = output_indices[0];\n" - " let n_blocks_per_col = uniforms.input_b_shape[1];\n" - " let blob_size = uniforms.input_b_shape[2];\n" - " for (var block = local_id.x; block < n_blocks_per_col; block += workgroup_size_x) {\n" - << " var word_offset = block * uniforms.block_size / " << a.NumComponents() << ";\n"; - - // prepare scale and zero point - shader.MainFunctionBody() << " var col_index = col * " << y.NumComponents() << ";\n"; - if (has_zero_points_) { - const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); - shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" - " var zero_point_byte_count: u32;\n" - " var zero_point_word_index: u32;\n" - " var zero_point_byte_offset: u32;\n" - " let zero_point_nibble_offset: u32 = block & 0x1u;\n" - " var zero_point_bits_offset: u32;\n" - " var zero_point_word: u32;\n"; - for (int c = 0; c < output_element_number; c++) { - shader.MainFunctionBody() << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n" - << " zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);\n" - " zero_point_word_index = zero_point_byte_count >> 0x2u;\n" - " zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" - " zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" - << " zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" - << " let zero_point" << c << " = output_element_t((zero_point_word) & 0xFu);\n" - << " col_index += 1;\n"; - } - } else { - shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; - for (int c = 0; c < output_element_number; c++) { - shader.MainFunctionBody() << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n" - << " col_index += 1;\n"; - } - } - - shader.MainFunctionBody() << " for (var word: u32 = 0; word < blob_size; word += 1) {\n"; - - // prepare b data - shader.MainFunctionBody() << " col_index = col * " << y.NumComponents() << ";\n"; - for (int c = 0; c < output_element_number; c++) { - shader.MainFunctionBody() << " let b" << c << "_data = " << b.GetByIndices("input_b_indices_t(col_index, block, word)") << ";\n" - << " col_index += 1;\n"; - } - shader.MainFunctionBody() << " var b_value : u32;\n" - " let b_mask : u32 = 0x0F0F0F0Fu;\n" - " var b_value_lower : vec4;\n" - " var b_value_upper : vec4;\n" - << " var b_quantized_values : " << quantized_data_type << ";\n" - << " var b_dequantized_values : " << quantized_data_type << ";\n"; - - shader.MainFunctionBody() << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; - - // process one word - shader.MainFunctionBody() << " var input_offset = " << a.IndicesToOffset("input_a_indices_t(batch, row, word_offset)") << ";\n" - << " var a_data: " << quantized_data_type << ";\n" - << " for (var j: u32 = 0; j < " << (8 / a.NumComponents()) << "; j++) {\n" - << " if (word_offset + j < uniforms.input_a_shape[2]) {\n" - << " a_data[j] = " << a.GetByOffset("input_offset") << ";\n" - << " input_offset++;\n" - " } else {\n" - " a_data[j] = input_a_value_t(0);\n" - " }\n" - " }\n"; - for (int c = 0; c < output_element_number; c++) { - shader.MainFunctionBody() << " b_value = b" << c << "_data"; - if (components_b_ > 1) { - shader.MainFunctionBody() << "[i]"; - } - shader.MainFunctionBody() << ";\n" - " b_value_lower = unpack4xU8(b_value & b_mask);\n" - " b_value_upper = unpack4xU8((b_value >> 4) & b_mask);\n" - << " b_quantized_values = " << quantized_data_type << "(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" - << " b_dequantized_values = "; - if (a.NumComponents() == 1) { - if (has_zero_points_) { - shader.MainFunctionBody() << quantized_data_type << "((b_quantized_values[0] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[1] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[2] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[3] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[4] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[5] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[6] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[7] - zero_point" << c << ") * scale" << c << ");\n"; - } else { - shader.MainFunctionBody() << quantized_data_type << "((b_quantized_values[0] - zero_point) * scale" << c << ", " - << "(b_quantized_values[1] - zero_point) * scale" << c << "," - << "(b_quantized_values[2] - zero_point) * scale" << c << "," - << "(b_quantized_values[3] - zero_point) * scale" << c << "," - << "(b_quantized_values[4] - zero_point) * scale" << c << "," - << "(b_quantized_values[5] - zero_point) * scale" << c << "," - << "(b_quantized_values[6] - zero_point) * scale" << c << "," - << "(b_quantized_values[7] - zero_point) * scale" << c << ");\n"; - } - } else { - shader.MainFunctionBody() << "(b_quantized_values - " << quantized_data_type << "("; - for (int i = 0; i < 8; i++) { - if (has_zero_points_) { - shader.MainFunctionBody() << "zero_point" << c; - } else { - shader.MainFunctionBody() << "zero_point"; - } - if (i < 7) { - shader.MainFunctionBody() << ", "; - } - } - shader.MainFunctionBody() << ")) * scale" << c << ";\n"; - } - - shader.MainFunctionBody() << " workgroup_shared[local_id.x * " << output_number_ << " + " << c / y.NumComponents() << "]"; - if (y.NumComponents() > 1) { - shader.MainFunctionBody() << "[" << c % y.NumComponents() << "]"; - } - shader.MainFunctionBody() << " += "; - if (a.NumComponents() == 1) { - shader.MainFunctionBody() << "a_data[0] * b_dequantized_values[0] + " - "a_data[1] * b_dequantized_values[1] + " - "a_data[2] * b_dequantized_values[2] + " - "a_data[3] * b_dequantized_values[3] + " - "a_data[4] * b_dequantized_values[4] + " - "a_data[5] * b_dequantized_values[5] + " - "a_data[6] * b_dequantized_values[6] + " - "a_data[7] * b_dequantized_values[7];\n"; - } else if (a.NumComponents() == 2) { - shader.MainFunctionBody() << "dot(a_data[0], b_dequantized_values[0]) + " - "dot(a_data[1], b_dequantized_values[1]) + " - "dot(a_data[2], b_dequantized_values[2]) + " - "dot(a_data[3], b_dequantized_values[3]);\n"; - } else if (a.NumComponents() == 4) { - shader.MainFunctionBody() << "dot(a_data[0], b_dequantized_values[0]) + " - "dot(a_data[1], b_dequantized_values[1]);\n"; - } - } - - shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n" - << " }\n" - " }\n" - " }\n" - " workgroupBarrier();\n" - << " if (local_id.x < " << output_number_ << ") {\n" - << " var output_value = output_value_t(0);\n" - " var workgroup_shared_offset = local_id.x;\n" - << " let blocks_num = min(" << shared_memory_size << ", n_blocks_per_col);\n" - << " for (var b = 0u; b < blocks_num; b++) {\n" - " output_value += workgroup_shared[workgroup_shared_offset];\n" - << " workgroup_shared_offset += " << output_number_ << ";\n" - << " }\n" - << " " << y.SetByIndices("output_indices_t(batch, row, col + local_id.x)", "output_value") << "\n" - << " }\n"; - } - - return Status::OK(); -} - Status MatMulNBitsWideTileProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); @@ -541,65 +90,21 @@ Status MatMulNBitsWideTileProgram::GenerateShaderCode(ShaderHelper& shader) cons // memory read/write helpers shader.AdditionalImplementation() << "fn mm_read_a(batch : u32, row : u32, col : u32) -> input_a_value_t {\n" - << " if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n" + << " if (batch < uniforms.input_a_shape[0] && row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n" << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n" << " }\n" << " return input_a_value_t(0);\n" << "}\n"; + if (nbits_ == 4) { + shader.AdditionalImplementation() << "\n" + << "fn mm_read_b(row : u32, col : u32) -> input_b_value_t {\n" + << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n" + << " return " << b.GetByIndices("input_b_indices_t(row, col, 0)") << ";\n" + << " }\n" + << " return input_b_value_t(0);\n" + << "}\n"; - shader.AdditionalImplementation() << "\n" - << "fn mm_read_b(row : u32, col : u32) -> input_b_value_t {\n" - << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n" - << " return " << b.GetByIndices("input_b_indices_t(row, col, 0)") << ";\n" - << " }\n" - << " return input_b_value_t(0);\n" - << "}\n"; - - shader.AdditionalImplementation() << "\n" - << "fn mm_read_scale(row : u32, col : u32) -> output_element_t {\n" - << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n" - << " return scales[row * uniforms.input_b_shape[1] + col];\n" - << " }\n" - << " return output_element_t(0);\n" - << "}\n"; - - if (has_zero_points_) { shader.AdditionalImplementation() << R"( -fn mm_read_zero(row : u32, col : u32) -> output_element_t { - if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) { - let offset = row * uniforms.input_b_stride[0] + col * uniforms.input_b_stride[1]; - - // u32 holds 8 packed uint4. - let array_index = offset / 8u; - let component_index = offset % 8u; - let packed_value = zero_points[array_index]; - - // Extract the uint4 component - let shift_amount = component_index * 4u; - let masked_value = (packed_value >> shift_amount) & 0xFu; - - return output_element_t(masked_value); - } - return output_element_t(0); -} -)"; - } else { - shader.AdditionalImplementation() << R"( -fn mm_read_zero(row : u32, col : u32) -> output_element_t { - // The default zero point is 8. - return output_element_t(8); -} -)"; - } - - shader.AdditionalImplementation() << "\n" - << "fn mm_write_y(batch : u32, row : u32, col : u32, value : output_value_t) {\n" - << " if (row < uniforms.output_shape[1] && col < uniforms.output_shape[2]) {\n" - << " " << y.SetByIndices("output_indices_t(batch, row, col)", "value") << "\n" - << " }\n" - << "}\n"; - - shader.AdditionalImplementation() << R"( fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scale : output_element_t) -> mat2x4 { let lower_values: vec4 = unpack4xU8(packed_value & 0x0F0F0F0Fu); let upper_values: vec4 = unpack4xU8((packed_value >> 4u) & 0x0F0F0F0Fu); @@ -620,6 +125,23 @@ fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scal return dequantized_values; } )"; + } + + shader.AdditionalImplementation() << "\n" + << "fn mm_read_scale(row : u32, col : u32) -> output_element_t {\n" + << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n" + << " return scales[row * uniforms.input_b_shape[1] + col];\n" + << " }\n" + << " return output_element_t(0);\n" + << "}\n" + << ReadZeroPoint(nbits_, has_zero_points_); + + shader.AdditionalImplementation() << "\n" + << "fn mm_write_y(batch : u32, row : u32, col : u32, value : output_value_t) {\n" + << " if (row < uniforms.output_shape[1] && col < uniforms.output_shape[2]) {\n" + << " " << y.SetByIndices("output_indices_t(batch, row, col)", "value") << "\n" + << " }\n" + << "}\n"; // declare const variables shader.AdditionalImplementation() << "\n" @@ -635,9 +157,9 @@ fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scal // main shader.MainFunctionBody() << R"MAIN_FN( - let batch = workgroup_id.z; - let row = workgroup_id.y * kTileM; - let col = workgroup_id.x * kTileN; + let batch = workgroup_idx / (uniforms.num_M_tile * uniforms.num_N_tile); + let row = ((workgroup_idx / uniforms.num_N_tile) % uniforms.num_M_tile) * kTileM; + let col = (workgroup_idx % uniforms.num_N_tile) * kTileN; let a_elements_per_col = uniforms.input_a_shape[2]; let a_blocks_per_col = (a_elements_per_col + kAComponentsForBlock32 - 1) / kAComponentsForBlock32; @@ -655,10 +177,13 @@ fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scal let b_row = col + local_idx; let b_col = a_block_idx; - let b_data = mm_read_b(b_row, b_col); let scale = mm_read_scale(b_row, b_col); - let zero_point = mm_read_zero(b_row, b_col); + let zero_point = mm_read_zero(b_row, b_col, uniforms.input_b_shape[0], uniforms.zero_blocks_per_col); +)MAIN_FN"; + if (nbits_ == 4) { + shader.MainFunctionBody() << R"MAIN_FN( + let b_data = mm_read_b(b_row, b_col); // `b` component size is 4. for (var b_idx = 0u; b_idx < 4u; b_idx++) { let b_dequantized = dequantize_packed8xU4(b_data[b_idx], zero_point, scale); @@ -669,10 +194,37 @@ fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scal results[m_idx] += f32(dot(a_data0, b_dequantized[0])) + f32(dot(a_data1, b_dequantized[1])); } } +)MAIN_FN"; + } else { + shader.MainFunctionBody() << " var b_data0 = vec4(0);\n" + " var b_data1 = vec4(0);\n" + " if (b_row < uniforms.input_b_shape[0] && b_col < uniforms.input_b_shape[1]) {\n" + << " b_data0 = " << b.GetByIndices("input_b_indices_t(b_row, b_col, 0)") << ";\n" + << " b_data1 = " << b.GetByIndices("input_b_indices_t(b_row, b_col, 1)") << ";\n" + " }" + << R"MAIN_FN( + for (var b_idx = 0u; b_idx < 4u; b_idx++) { + let b_dequantized0 = (vec4(unpack4xU8(b_data0[b_idx])) - vec4(zero_point)) * scale; + let b_dequantized1 = (vec4(unpack4xU8(b_data1[b_idx])) - vec4(zero_point)) * scale; + for (var m_idx = 0u; m_idx < kTileM; m_idx++) { + let a_data0 = a_data_tile[m_idx][b_idx]; + let a_data1 = a_data_tile[m_idx][b_idx + 4u]; + + results[m_idx] += f32(dot(a_data0, b_dequantized0)) + f32(dot(a_data1, b_dequantized1)); + } + } +)MAIN_FN"; + } + + shader.MainFunctionBody() << R"MAIN_FN( workgroupBarrier(); } + if (batch >= uniforms.input_a_shape[0]) { + return; + } + // Write the results. for (var m_idx = 0u; m_idx < kTileM; m_idx++) { mm_write_y(batch, row + m_idx, col + local_idx, output_value_t(results[m_idx])); @@ -682,6 +234,156 @@ fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scal return Status::OK(); } +// Apply similar idea with DP4AMatMulNBitsSmallMProgram algorithm. +Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias); + const auto& b = shader.AddInput("input_b"); + shader.AddInput("scales_b"); + if (has_zero_points_) { + shader.AddInput("zero_points", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseElementTypeAlias); + const uint32_t components_a = a.NumComponents(); + const uint32_t components_b = b.NumComponents() / 4; // b is stored as uint32 which includs 4 uint8. + constexpr uint32_t tile_size_k_vec = 16; + uint32_t elements_in_value_b = components_b * (32 / nbits_); + uint32_t tile_k_size = tile_size_k_vec * elements_in_value_b; + const uint32_t a_length_per_tile = tile_k_size / components_a; + + shader.AdditionalImplementation() << "const a_length_per_tile = " << a_length_per_tile << "u;\n" + << "const tile_size_k_vec = " << tile_size_k_vec << ";\n" + << "const tile_size_k = " << tile_k_size << "u;\n" + << "const tile_size = " << tile_size_ << "u;\n" + << "const elements_in_value_b = " << elements_in_value_b << "u;\n" + << "const sub_tile_count = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n" + << "const component_a = " << components_a << "u;\n" + << "const component_b = " << components_b << "u;\n"; + shader.AdditionalImplementation() << R"ADDNL_FN( + // Shared memory + var tile_A : array; + var inter_results: array, tile_size>; + fn loadSHMA(batch: u32, a_global: u32, kidx: u32, col: u32) + { + let k_offset = kidx / component_a + col; + if (batch < uniforms.batch_count && k_offset < uniforms.K_of_a) { + tile_A[col] = input_a[batch * uniforms.M * uniforms.K_of_a + a_global * uniforms.K_of_a + k_offset]; + } else { + tile_A[col] = input_a_value_t(0); + } + } +)ADDNL_FN" + << ReadZeroPoint(nbits_, has_zero_points_); + + shader.MainFunctionBody() << R"MAIN_FN( + let batch = workgroup_idx / (uniforms.M * uniforms.num_N_tile); + let a_global = (workgroup_idx / uniforms.num_N_tile) % uniforms.M; + let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; + + let idx = local_idx % tile_size_k_vec; + let idy = local_idx / tile_size_k_vec; + + for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) + { + for (var id = local_idx; id < a_length_per_tile; id += workgroup_size_x) + { + loadSHMA(batch, a_global, kidx, id); + } + workgroupBarrier(); + + for (var local_row_offset = 0u; local_row_offset < tile_size; local_row_offset += sub_tile_count) + { + var b_global = b_global_base + local_row_offset + idy; + var k_offset = kidx / elements_in_value_b + idx; + if (b_global < uniforms.N && k_offset < uniforms.K_of_b) + { + let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size; + let scale_b = scales_b[b_global * uniforms.blocks_per_col + block_idx]; + let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); + var b_value = input_b[b_global * uniforms.K_of_b + k_offset]; +)MAIN_FN"; + + if (nbits_ == 4) { + shader.MainFunctionBody() << R"MAIN_FN( + var sum = output_element_t(0); + var a_offset = idx * (8 / component_a) * component_b; + for (var i = 0u; i < component_b; i++) { + let b_value_lower = vec4(unpack4xU8(b_value[i] & 0x0F0F0F0Fu)) - vec4(zero); + let b_value_upper = vec4(unpack4xU8((b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(zero); + let b0 = vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]) * scale_b; + let b1 = vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]) * scale_b; +)MAIN_FN"; + switch (components_a) { + case 1: + shader.MainFunctionBody() << " sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) +" + " dot(vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1);\n" + " a_offset += 8;\n"; + break; + case 2: + shader.MainFunctionBody() << " sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b0) +" + "dot(vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1);\n" + " a_offset += 4;\n"; + break; + case 4: + shader.MainFunctionBody() << " sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1);\n" + " a_offset += 2;\n"; + break; + default: + break; + } + shader.MainFunctionBody() << " }\n"; + } else { + shader.MainFunctionBody() << R"MAIN_FN( + var sum = output_element_t(0); + var a_offset = idx * (4 / component_a) * component_b; + for (var i = 0u; i < component_b; i++) { + let b_value = (vec4(unpack4xU8(b_value[i])) - vec4(zero)) * scale_b; +)MAIN_FN"; + switch (components_a) { + case 1: + shader.MainFunctionBody() << " sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value);\n" + " a_offset += 4;\n"; + break; + case 2: + shader.MainFunctionBody() << " sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b_value);\n" + " a_offset += 2;\n"; + break; + case 4: + shader.MainFunctionBody() << " sum += dot(tile_A[a_offset], b_value);\n" + " a_offset += 1;\n"; + break; + default: + break; + } + shader.MainFunctionBody() << " }\n"; + } + + shader.MainFunctionBody() << R"MAIN_FN( + inter_results[local_row_offset + idy][idx] += sum; + } + } + workgroupBarrier(); + } + + if (batch >= uniforms.batch_count) { + return; + } + + if (local_idx < tile_size) { + var output_value = output_element_t(0); + for (var b = 0u; b < tile_size_k_vec; b++) { + output_value += inter_results[local_idx][b]; + } + let b_global = b_global_base + local_idx; + let output_idx = batch * uniforms.M * uniforms.N + a_global * uniforms.N + b_global; + if (b_global < uniforms.N) { + output[output_idx] = output_value; + } + } +)MAIN_FN"; + + return Status::OK(); +} + Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* a = context.Input(0); const Tensor* b = context.Input(1); @@ -724,20 +426,18 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context } // On FP32 only GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M. - if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType() || nbits == 8 || + if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) { return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, nbits, context, y); } - // TODO: Remvoe it once the 8bits is supported for the non-dp4 path. - ORT_ENFORCE(nbits == 4, "Only 4 bits are supported for the non-dp4 path for webgpu matmulnbits"); + // zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)]. So here we need to check whether n_blocks_per_col is divisible by 8/nbits. + uint32_t zero_blocks_per_col = n_blocks_per_col % (8 / nbits) == 0 ? n_blocks_per_col : n_blocks_per_col + 1; // WideTileProgram // This program is optimized for Block32 prefill using Tile16x128. - // TODO: loosen restrictions on vendor. - const bool use_wide_tile_program = block_size == 32 && components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization && - context.AdapterInfo().vendor == std::string_view{"intel"}; + const bool use_wide_tile_program = block_size == 32 && components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization; if (use_wide_tile_program) { // Enforce output components to 1. components = 1; @@ -745,8 +445,10 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context constexpr uint32_t workgroup_size = 128; constexpr uint32_t tile_m = workgroup_size / 8; constexpr uint32_t tile_n = workgroup_size; + uint32_t num_N_tile = (N + tile_n - 1) / tile_n; + uint32_t num_M_tile = (M + tile_m - 1) / tile_m; - MatMulNBitsWideTileProgram program{has_zero_points, tile_m, tile_n}; + MatMulNBitsWideTileProgram program{has_zero_points, tile_m, tile_n, nbits}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize((N + tile_n - 1) / tile_n, (M + tile_m - 1) / tile_m, @@ -762,7 +464,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, onnxruntime::narrow(components_b * 4)}, {scales, ProgramTensorMetadataDependency::None}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, onnxruntime::narrow(components)}) - .AddUniformVariable({block_size}); + .AddUniformVariables({{block_size}, {zero_blocks_per_col}, {num_N_tile}, {num_M_tile}}) + .CacheHint(nbits, has_zero_points); if (has_zero_points) { program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } @@ -770,47 +473,21 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context return context.RunProgram(program); } - // Generic program - // TODO: Support output_number > 1. Some cases are failed when output_number > 1. - constexpr uint32_t output_number = 1; - const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1; - const bool has_subgroup = context.HasFeature(wgpu::FeatureName::Subgroups); - const bool use_subgroup = has_subgroup && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32; - MatMulNBitsProgram program{output_number, block_size, tile_m, static_cast(components_b), has_zero_points, use_subgroup}; - if (M > kMinMForTileOptimization && block_size == 32) { - components = 1; - constexpr uint32_t workgroup_size = 64; - constexpr uint32_t workgroup_y = 8; - constexpr uint32_t workgroup_x = workgroup_size / workgroup_y; - program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); - program.SetDispatchGroupSize((N + workgroup_y - 1) / workgroup_y, - (M + tile_m - 1) / tile_m, - batch_count); - program.CacheHint("T_M" + std::to_string(tile_m) + "Subgroup" + std::to_string(use_subgroup)); - } else if (block_size == 32) { - components = 1; - // TODO: Tune the workgroup size when `M=1`. - constexpr uint32_t workgroup_size = 128; - const uint32_t workgroup_y = N % 8 == 0 ? 8 : 1; - const uint32_t workgroup_x = workgroup_size / workgroup_y; - program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); - program.SetDispatchGroupSize(data_size / components / workgroup_y); - program.CacheHint("T_M" + std::to_string(tile_m)); - } else { - program.SetDispatchGroupSize(data_size / components / output_number); - program.CacheHint("O_N" + std::to_string(output_number)); - } - - TensorShape reshaped_a_shape{batch_count, M, K / components_a}; - TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; - TensorShape reshaped_y_shape{batch_count, M, N / components}; - + constexpr uint32_t workgroup_size = 128; + constexpr uint32_t tile_size = 8; + constexpr uint32_t kU32Components = 4; + uint32_t components_b_with_u32 = components_b * kU32Components; + uint32_t num_N_tile = (N + tile_size - 1) / tile_size; + MatMulNBitsProgram program{tile_size, nbits, has_zero_points}; + program.SetWorkgroupSize(workgroup_size); + program.SetDispatchGroupSize((N + tile_size - 1) / tile_size, M, batch_count); program - .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, static_cast(components_a)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, static_cast(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, - {scales, ProgramTensorMetadataDependency::None}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast(components)}) - .AddUniformVariable({block_size}); + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {scales, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank}) + .AddUniformVariables({{M}, {N}, {K}, {K / components_a}, {n_blocks_per_col * blob_size / components_b_with_u32}, {block_size}, {n_blocks_per_col}, {zero_blocks_per_col}, {num_N_tile}, {batch_count}}) + .CacheHint(nbits, has_zero_points); if (has_zero_points) { program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index d5e4bc68fc33a..807576c91752b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -12,41 +12,44 @@ namespace webgpu { using namespace onnxruntime::webgpu; -class MatMulNBitsProgram final : public Program { +class MatMulNBitsWideTileProgram final : public Program { public: - MatMulNBitsProgram(uint32_t output_number, uint32_t block_size, uint32_t tile_m, int components_b, bool has_zero_points, bool use_subgroup) : Program{"MatMulNBits"}, - output_number_{output_number}, - block_size_{block_size}, - tile_m_{tile_m}, - components_b_{components_b}, - has_zero_points_{has_zero_points}, - use_subgroup_(use_subgroup) { - } + MatMulNBitsWideTileProgram(bool has_zero_points, uint32_t tile_m, uint32_t tile_n, uint32_t nbits) + : Program{"MatMulNBitsWideTileProgram"}, has_zero_points_{has_zero_points}, tile_m_(tile_m), tile_n_(tile_n), nbits_(nbits) {} Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}, + {"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32}, + {"num_N_tile", ProgramUniformVariableDataType::Uint32}, + {"num_M_tile", ProgramUniformVariableDataType::Uint32}); private: - uint32_t output_number_; - uint32_t block_size_; - uint32_t tile_m_; - int components_b_; bool has_zero_points_; - bool use_subgroup_; + uint32_t tile_m_; + uint32_t tile_n_; + uint32_t nbits_; }; -class MatMulNBitsWideTileProgram final : public Program { +class MatMulNBitsProgram final : public Program { public: - MatMulNBitsWideTileProgram(bool has_zero_points, uint32_t tile_m, uint32_t tile_n) - : Program{"MatMulNBitsWideTileProgram"}, has_zero_points_{has_zero_points}, tile_m_(tile_m), tile_n_(tile_n) {} - + MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points) : Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points) {} Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"M", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K_of_a", ProgramUniformVariableDataType::Uint32}, + {"K_of_b", ProgramUniformVariableDataType::Uint32}, + {"block_size", ProgramUniformVariableDataType::Uint32}, + {"blocks_per_col", ProgramUniformVariableDataType::Uint32}, + {"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32}, + {"num_N_tile", ProgramUniformVariableDataType::Uint32}, + {"batch_count", ProgramUniformVariableDataType::Uint32}); private: + uint32_t tile_size_; + uint32_t nbits_; bool has_zero_points_; - uint32_t tile_m_; - uint32_t tile_n_; }; class MatMulNBits final : public WebGpuKernel { diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index 09650be9358d0..674473a173445 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -248,8 +248,8 @@ bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& cont // some precision issues with subgroupMatrixMultiplyAccumulate. It is possible to support higher accuracy // by setting compute_precision to Fp32, but that will be slower. For 1K token prefill FP16 Phi 3.5 is around 5s, // FP322 is around 7s. - return context.AdapterInfo().backendType == wgpu::BackendType::Metal && - has_subgroup_matrix && + return has_subgroup_matrix && + context.AdapterInfo().vendor == std::string_view{"apple"} && accuracy_level == 4 && block_size == 32 && batch_count == 1 && diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 97766028cfe12..91961bf22ce1e 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -106,6 +106,7 @@ void CPUIDInfo::X86Init() { GetCPUID(0, data); vendor_ = GetX86Vendor(data); + vendor_id_ = GetVendorId(vendor_); int num_IDs = data[0]; if (num_IDs >= 1) { @@ -151,6 +152,14 @@ std::string CPUIDInfo::GetX86Vendor(int32_t* data) { #endif // defined(CPUIDINFO_ARCH_X86) +uint32_t CPUIDInfo::GetVendorId(const std::string& vendor) { + if (vendor == "GenuineIntel") return 0x8086; + if (vendor == "GenuineAMD") return 0x1022; + if (vendor.find("Qualcomm") == 0) return 'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24); + if (vendor.find("NV") == 0) return 0x10DE; + return 0; +} + #if defined(CPUIDINFO_ARCH_ARM) #if defined(__linux__) @@ -204,6 +213,7 @@ void CPUIDInfo::ArmLinuxInit() { void CPUIDInfo::ArmWindowsInit() { // Get the ARM vendor string from the registry vendor_ = GetArmWindowsVendor(); + vendor_id_ = GetVendorId(vendor_); // Read MIDR and ID_AA64ISAR1_EL1 register values from Windows registry // There should be one per CPU diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 4d6e7e8b9105e..b820fa2ab1af7 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -19,6 +19,10 @@ class CPUIDInfo { return vendor_; } + uint32_t GetCPUVendorId() const { + return vendor_id_; + } + bool HasAMX_BF16() const { return has_amx_bf16_; } bool HasAVX() const { return has_avx_; } bool HasAVX2() const { return has_avx2_; } @@ -123,6 +127,9 @@ class CPUIDInfo { bool has_arm_neon_bf16_{false}; std::string vendor_; + uint32_t vendor_id_; + + uint32_t GetVendorId(const std::string& vendor); #if defined(CPUIDINFO_ARCH_X86) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index e3e54be3f7c21..8ed5eeaa8d44f 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -806,7 +806,13 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers ORT_RETURN_IF(ep_context_gen_options.error_if_no_compiled_nodes, "Compiled model does not contain any EPContext nodes. " "Check that the session EPs support compilation and can execute at least one model subgraph."); - return Status::OK(); + + LOGS(logger, WARNING) << "Compiled model does not contain any EPContext nodes. " + "Either the session EPs do not support compilation or " + "no subgraphs were able to be compiled."; + + // we continue on to generate the compiled model which may benefit from L1 optimizations even if there are not + // EPContext nodes. } auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair { diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 94ff2bb55a055..89a43c4f71ee6 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -90,6 +90,16 @@ struct EpContextModelGenerationOptions { size_t output_external_initializer_size_threshold = 0; }; +struct EpSelectionPolicy { + // flag to detect that a policy was set by the user. + // need to preserve current behavior of defaulting to CPU EP if no EPs are explicitly registered + // and no selection policy was explicitly provided. + bool enable{false}; + OrtExecutionProviderDevicePolicy policy = OrtExecutionProviderDevicePolicy_DEFAULT; + EpSelectionDelegate delegate{}; + void* state{nullptr}; // state for the delegate +}; + /** * Configuration information for a session. */ @@ -222,6 +232,11 @@ struct SessionOptions { // copied internally and the flag needs to be accessible across all copies. std::shared_ptr load_cancellation_flag = std::make_shared(false); + // Policy to guide Execution Provider selection + EpSelectionPolicy ep_selection_policy = {false, + OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_DEFAULT, + nullptr}; + // Options for generating compile EPContext models were previously stored in session_option.configs as // string key/value pairs. To support more advanced options, such as setting input/output buffers, we // now have to store EPContext options in a struct of type EpContextModelGenerationOptions. @@ -253,6 +268,7 @@ inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_ << " use_per_session_threads:" << session_options.use_per_session_threads << " thread_pool_allow_spinning:" << session_options.thread_pool_allow_spinning << " use_deterministic_compute:" << session_options.use_deterministic_compute + << " ep_selection_policy:" << session_options.ep_selection_policy.policy << " config_options: { " << session_options.config_options << " }" //<< " initializers_to_share_map:" << session_options.initializers_to_share_map #if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_EXTERNAL_INITIALIZERS) diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index cef902e506075..cacd772b61d76 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -369,7 +369,7 @@ common::Status SaveInitializedTensors( if (memory_profile_func) memory_profile_func(planner); - for (auto i : planned_initializers_memory_sizes_in_byte) { + for (const auto& i : planned_initializers_memory_sizes_in_byte) { LOGS(logger, INFO) << "[Memory] SessionStateInitializer statically allocates " << i.second << " bytes for " << i.first.ToString() << std::endl; } diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index abfd93e13ca3b..7eea7d218e278 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1474,7 +1474,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .SetDoc(GemmaRotaryEmbedding_ver1_doc) .Input(0, "emb", - "embeddding - 3D tensor with shape (batch_size, seq_len, dim)", + "embedding - 3D tensor with shape (batch_size, seq_len, dim)", "U") .Input(1, "q", diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index d87688a62040c..f9f7be60a9bd6 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1188,7 +1188,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, OpSchema() - .SetDoc("Beam Search for whisper model, especiall with cross_qk features etc.") + .SetDoc("Beam Search for whisper model, especially with cross_qk features etc.") .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts (i.e. the start of transcription token id)", AttributeProto::INT, static_cast(-1)) @@ -2079,7 +2079,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MatMulFpQ4, 1, .SetDoc(R"DOC( Matrix product with right hand matrix being pre-packed and quantized int4 data blob. During quantization, the matrix is divided into blocks, where each block is a -continguous subset inside each column. Each block is quantized into a +contiguous subset inside each column. Each block is quantized into a sequence of 4b integers with a scaling factor and an optional offset. Currently 3 quantization types are supported: (0): block size 32, no offset, (1): block size 32, with offset, (2): block size 64, @@ -2691,12 +2691,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GemmFloat8, 1, .SetDoc(R"DOC(Generic Gemm for float and float 8.)DOC") .Attr( "transA", - "Whether A should be transposed. Float 8 only supprted transA=0.", + "Whether A should be transposed. Float 8 only supported transA=0.", AttributeProto::INT, static_cast(0)) .Attr( "transB", - "Whether B should be transposed. Float 8 only supprted transB=1.", + "Whether B should be transposed. Float 8 only supported transB=1.", AttributeProto::INT, static_cast(0)) .Attr( @@ -3388,7 +3388,7 @@ where scale = 1. / (1. - ratio). ``` -This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt that the mask is output as a bit-packed uint32 tensor, instead of a boolean tensor. +This op functions in much the same was as Dropout-11 and Dropout-13 do, except that the mask is output as a bit-packed uint32 tensor, instead of a boolean tensor. )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(BitmaskDropout) diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h index a740801e00514..4c133103bee04 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -54,8 +54,8 @@ struct PackedQuantBDataStruct { const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T); #if defined(MLAS_TARGET_AMD64_IX86) - // _mm256_load_si256 requires alignment on a 32-byte boundary - PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); + // avx512 requires alignment on a 64-byte boundary + PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 64); #else PackedQuantBData = (std::byte*)PackedQuantBWorkspace; #endif diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index 09d53f9b852db..d2d9886ab61f7 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -131,8 +131,8 @@ accumulate_q8_blklen32_r1c1blk2_avx2( __m256& acc0 ) { - const __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - const __m256i bv1_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr + 32)); + const __m256i bv0_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i bv1_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); // 01 01 01 01 __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); __m256 scale_a0b_2_ps = _mm256_mul_ps(scale_b_2_ps, scale_a0_2_ps); @@ -194,8 +194,8 @@ accumulate_q8_blklen32_r2c1blk2_avx2( __m256& acc1 ) { - const __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - const __m256i bv1_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr + 32)); + const __m256i bv0_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i bv1_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); // 01 01 01 01 __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); __m256 scale_a0b_2_ps = _mm256_mul_ps(scale_b_2_ps, scale_a0_2_ps); @@ -360,7 +360,7 @@ accumulate_q8_blklen32_r1c1blk1_avx2( __m256& acc0 ) { - const __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i bv0_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); #if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) @@ -402,7 +402,7 @@ accumulate_q8_blklen32_r2c1blk1_avx2( __m256& acc1 ) { - const __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i bv0_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); #if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) @@ -667,10 +667,10 @@ Q8Int8GemmR2xC4BlkLen32Avx2( // process 2 blks of 64 4b weights a time for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { // load A: - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]); @@ -686,8 +686,8 @@ Q8Int8GemmR2xC4BlkLen32Avx2( if (k_blks_remaining > 0) { // load A - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr)); + const __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); float scale_a00 = *QuantAScalePtr; float scale_a10 = *(QuantAScalePtr + BlockCountK); @@ -865,10 +865,10 @@ Q8Int8GemmR2xC1BlkLen32Avx2( size_t k_blks_remaining = BlockCountK; // process 2 blks of 64 4b weights a time for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); @@ -880,8 +880,8 @@ Q8Int8GemmR2xC1BlkLen32Avx2( } if (k_blks_remaining > 0) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + const __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); const float scale_a00 = *QuantAScalePtr; const float scale_a10 = *(QuantAScalePtr + BlockCountK); @@ -1082,8 +1082,8 @@ Q8Int8GemmR1xC4BlkLen32Avx2( __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + BlkLen32)); accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]); @@ -1099,7 +1099,7 @@ Q8Int8GemmR1xC4BlkLen32Avx2( if (k_blks_remaining > 0) { // load A - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); const float scale_a00 = *QuantAScalePtr; const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; @@ -1257,8 +1257,8 @@ Q8Int8GemmR1xC1BlkLen32Avx2( __m256 acc0 = _mm256_setzero_ps(); size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + BlkLen32)); accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); // increment block pointers @@ -1269,7 +1269,7 @@ Q8Int8GemmR1xC1BlkLen32Avx2( } if (k_blks_remaining > 0) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); const float& scale_a00 = *QuantAScalePtr; const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h index 2bf27df2dccce..20583740396a7 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -87,13 +87,12 @@ static MLAS_FORCEINLINE void accumulate_q8_blklen64_r1c1blk1_avx2( const __m256i& av00_32_epi8, const __m256i& av01_32_epi8, - const std::byte* QuantBDataPtr, + const __m256i& bv0_32_epi8, + const __m256i& bv1_32_epi8, float scale_a0b, __m256& acc0 ) { - __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - __m256i bv1_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr + 32)); __m256 scale_8_ps = _mm256_set1_ps(scale_a0b); #if !defined(__GNUC__) || (__GNUC__ > 10) @@ -143,15 +142,14 @@ accumulate_q8_blklen64_r2c1blk1_avx2( const __m256i& av01_32_epi8, const __m256i& av10_32_epi8, const __m256i& av11_32_epi8, - const std::byte* QuantBDataPtr, + const __m256i& bv0_32_epi8, + const __m256i& bv1_32_epi8, float scale_a0b, float scale_a1b, __m256& acc0, __m256& acc1 ) { - __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - __m256i bv1_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr + 32)); __m256 scale0_8_ps = _mm256_set1_ps(scale_a0b); __m256 scale1_8_ps = _mm256_set1_ps(scale_a1b); @@ -416,21 +414,54 @@ Q8Int8GemmR2xC4BlkLen64Avx2( const float scale_a0b3 = (*(QuantBScalePtr + 3)) * scale_a0; const float scale_a1b3 = (*(QuantBScalePtr + 3)) * scale_a1; - for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); - - accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, scale_a0b0, scale_a1b0, acc[0], acc[NCols4]); - accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + SubblkDataSizeInBytes, scale_a0b1, scale_a1b1, acc[1], acc[NCols4 + 1]); - accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, scale_a0b2, scale_a1b2, acc[2], acc[NCols4 + 2]); - accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, scale_a0b3, scale_a1b3, acc[3], acc[NCols4 + 3]); + __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + __m256i av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + 32)); + + __m256i bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + __m256i bv10_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes)); + __m256i bv11_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes + 32)); + __m256i bv20_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes)); + __m256i bv21_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes + 32)); + __m256i bv30_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes)); + __m256i bv31_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes + 32)); + + for (size_t kk = 0; kk < PerBlkSubblkCount - 1; kk++) { + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, scale_a1b0, acc[0], acc[NCols4]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv10_32_epi8, bv11_32_epi8, scale_a0b1, scale_a1b1, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv20_32_epi8, bv21_32_epi8, scale_a0b2, scale_a1b2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv30_32_epi8, bv31_32_epi8, scale_a0b3, scale_a1b3, acc[3], acc[NCols4 + 3]); // increment block pointers QuantAPtr += SubblkLen; QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + + av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + 32)); + + bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + bv10_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes)); + bv11_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes + 32)); + bv20_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes)); + bv21_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes + 32)); + bv30_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes)); + bv31_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes + 32)); } + + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, scale_a1b0, acc[0], acc[NCols4]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv10_32_epi8, bv11_32_epi8, scale_a0b1, scale_a1b1, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv20_32_epi8, bv21_32_epi8, scale_a0b2, scale_a1b2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv30_32_epi8, bv31_32_epi8, scale_a0b3, scale_a1b3, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + QuantAScalePtr++; QuantBScalePtr += NCols4; } // k_blks_remaining @@ -587,18 +618,36 @@ Q8Int8GemmR2xC1BlkLen64Avx2( const float scale_a0b0 = (*QuantBScalePtr) * scale_a0; const float scale_a1b0 = (*QuantBScalePtr) * scale_a1; - for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + __m256i av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + 32)); + + __m256i bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); - accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, scale_a0b0, scale_a1b0, acc0, acc1); + for (size_t kk = 0; kk < PerBlkSubblkCount - 1; kk++) { + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, scale_a1b0, acc0, acc1); // increment block pointers QuantAPtr += SubblkLen; QuantBDataPtr += SubblkDataSizeInBytes; + + av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + 32)); + + bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); } + + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, scale_a1b0, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + QuantAScalePtr++; QuantBScalePtr++; } @@ -751,19 +800,47 @@ Q8Int8GemmR1xC4BlkLen64Avx2( const float scale_a0b2 = (*(QuantBScalePtr + 2)) * scale_a0; const float scale_a0b3 = (*(QuantBScalePtr + 3)) * scale_a0; - for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - - accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, scale_a0b0, acc[0]); - accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + SubblkDataSizeInBytes, scale_a0b1, acc[1]); - accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, scale_a0b2, acc[2]); - accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, scale_a0b3, acc[3]); - + __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + + __m256i bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + __m256i bv10_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes)); + __m256i bv11_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes + 32)); + __m256i bv20_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes)); + __m256i bv21_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes + 32)); + __m256i bv30_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes)); + __m256i bv31_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes + 32)); + + for (size_t kk = 0; kk < PerBlkSubblkCount - 1; kk++) { + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, acc[0]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv10_32_epi8, bv11_32_epi8, scale_a0b1, acc[1]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv20_32_epi8, bv21_32_epi8, scale_a0b2, acc[2]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv30_32_epi8, bv31_32_epi8, scale_a0b3, acc[3]); // increment block pointers QuantAPtr += SubblkLen; QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + + av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + + bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + bv10_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes)); + bv11_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes + 32)); + bv20_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes)); + bv21_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes + 32)); + bv30_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes)); + bv31_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes + 32)); } + + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, acc[0]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv10_32_epi8, bv11_32_epi8, scale_a0b1, acc[1]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv20_32_epi8, bv21_32_epi8, scale_a0b2, acc[2]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv30_32_epi8, bv31_32_epi8, scale_a0b3, acc[3]); + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + QuantAScalePtr++; QuantBScalePtr += NCols4; } @@ -909,16 +986,30 @@ Q8Int8GemmR1xC1BlkLen64Avx2( const float scale_a0 = *QuantAScalePtr; const float scale_a0b0 = (*QuantBScalePtr) * scale_a0; - for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + + __m256i bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); - accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, scale_a0b0, acc0); + for (size_t kk = 0; kk < PerBlkSubblkCount - 1; kk++) { + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, acc0); // increment block pointers QuantAPtr += SubblkLen; QuantBDataPtr += SubblkDataSizeInBytes; + + av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + + bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); } + + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, acc0); + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + QuantAScalePtr++; QuantBScalePtr++; } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h index 7ca72debd6d25..6e8cebeac7bc5 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h @@ -156,8 +156,8 @@ accumulate_q8_blklen128_r1c1blk1_avx512( __m512& acc0 ) { - __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); - __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + __m512i bv0_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr + 64)); if constexpr (vnni) { dot_accumulate_1blkvnni(bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, scale_a0b, acc0); @@ -204,8 +204,8 @@ accumulate_q8_blklen128_r2c1blk1_avx512( __m512& acc1 ) { - __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); - __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + __m512i bv0_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr + 64)); if constexpr (vnni) { dot_accumulate_1blkvnni(bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, scale_a0b, acc0); @@ -428,10 +428,10 @@ Q8Int8GemmR2xC4BlkLen128Avx512( const float scale_a1b3 = (*(QuantAScalePtr + BlockCountK)) * (*(QuantBScalePtr + 3)); for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m512i av00_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av01_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); - const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + const __m512i av00_64_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, scale_a0b0, scale_a1b0, acc[0], acc[NCols4]); accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, scale_a0b1, scale_a1b1, acc[1], acc[NCols4 + 1]); @@ -606,10 +606,10 @@ Q8Int8GemmR2xC1BlkLen128Avx512( const float scale_a1b0 = (*(QuantAScalePtr + BlockCountK)) * (*QuantBScalePtr); for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m512i av00_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av01_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); - const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + const __m512i av00_64_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, scale_a0b0, scale_a1b0, acc0, acc1); @@ -768,8 +768,8 @@ Q8Int8GemmR1xC4BlkLen128Avx512( const float scale_a0b3 = (*QuantAScalePtr) * (*(QuantBScalePtr + 3)); for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av0_64_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, scale_a0b0, acc[0]); accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, scale_a0b1, acc[1]); @@ -925,8 +925,8 @@ Q8Int8GemmR1xC1BlkLen128Avx512( const float scale_a0b0 = (*QuantAScalePtr) * (*QuantBScalePtr); for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { - const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av0_64_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, scale_a0b0, acc0); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h index 33d4fde26ae5b..68bf1dae6d0d0 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -200,8 +200,8 @@ accumulate_q8_blklen64_r1c1blk2_avx512( __m512& acc0 ) { - __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); - __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + __m512i bv0_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr + 64)); const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); @@ -259,8 +259,8 @@ accumulate_q8_blklen64_r2c1blk2_avx512( __m512& acc1 ) { - __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); - __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + __m512i bv0_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr + 64)); const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); @@ -440,7 +440,7 @@ accumulate_q8_blklen64_r1c1blk1_avx512( __m512& acc0 ) { - __m512i bv_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); @@ -485,7 +485,7 @@ accumulate_q8_blklen64_r2c1blk1_avx512( __m512& acc1 ) { - __m512i bv_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); @@ -750,10 +750,10 @@ Q8Int8GemmR2xC4BlkLen64Avx512( size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda + 64)); accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]); @@ -768,8 +768,8 @@ Q8Int8GemmR2xC4BlkLen64Avx512( } // k_blks_remaining for (; k_blks_remaining > 0; --k_blks_remaining) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); @@ -940,10 +940,10 @@ Q8Int8GemmR2xC1BlkLen64Avx512( size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda + 64)); accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); @@ -955,8 +955,8 @@ Q8Int8GemmR2xC1BlkLen64Avx512( } for (; k_blks_remaining > 0; --k_blks_remaining) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); @@ -1119,8 +1119,8 @@ Q8Int8GemmR1xC4BlkLen64Avx512( __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining >= PerAccuBlk2; k_blks_remaining -= PerAccuBlk2) { - const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av0_64_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + 64)); accumulate_q8_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); accumulate_q8_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]); accumulate_q8_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]); @@ -1134,7 +1134,7 @@ Q8Int8GemmR1xC4BlkLen64Avx512( } for (; k_blks_remaining > 0; --k_blks_remaining) { - const __m512i av_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); accumulate_q8_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); accumulate_q8_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); @@ -1293,8 +1293,8 @@ Q8Int8GemmR1xC1BlkLen64Avx512( __m512 acc0 = _mm512_setzero_ps(); size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + 64)); accumulate_q8_blklen64_r1c1blk2_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); @@ -1306,7 +1306,7 @@ Q8Int8GemmR1xC1BlkLen64Avx512( } for (; k_blks_remaining > 0; --k_blks_remaining) { - const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); accumulate_q8_blklen64_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index e7df817dea34c..bb38f37fb0eb8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -22,8 +22,8 @@ QNBitGemmPackQuantBDataSize( const size_t ScaleSize = N * BlockCountK * sizeof(float); size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); - // _mm256_load_si256 requires alignment on a 32-byte boundary - constexpr size_t PackedQuantBDataAlignment = 32; + // avx512 requires alignment on a 64-byte boundary + constexpr size_t PackedQuantBDataAlignment = 64; PackedQuantBDataSize += PackedQuantBDataAlignment - 1; constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); BlkSumSize += BlkSumAlignment - 1; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_q8_block.h b/onnxruntime/core/mlas/lib/sqnbitgemm_q8_block.h index 80af2f46790df..9dc217c5ddb97 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_q8_block.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_q8_block.h @@ -66,5 +66,5 @@ MLAS_FORCEINLINE constexpr size_t Q8BlkAlignment() { - return alignof(float); + return 16 * alignof(float); } diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation_dev_notes.md b/onnxruntime/core/optimizer/layout_transformation/layout_transformation_dev_notes.md index 5b58e2b646717..0daa9d1ddeaee 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation_dev_notes.md +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation_dev_notes.md @@ -13,7 +13,7 @@ Foreach 1. call GetCapability 2. IF EP.DesiredFormat == NHWC 2.1. Invoke Layout Transformer - 2.2 If graph is modified -> call GetCapability (layout transformer can add new nodes to the graph) + 2.2. If graph is modified -> call GetCapability (layout transformer can add new nodes to the graph) 3 Compile ``` diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 05627dd25857f..6515661a2ee6a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -401,6 +401,37 @@ void ConvSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { builder.input_nodes.resize(3, NodesToOptimizeIndices::kEmptyNodeIndex); } +bool EinsumNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, const Node* redundant_clip_node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + if (!CheckQDQNodes(graph_viewer, node, redundant_clip_node, dq_nodes, q_nodes, /*num_dq_inputs=*/-1, + /*is_empty_q_nodes_allowed=*/true)) { + return false; + } + size_t num_dq_inputs = dq_nodes.size(); + for (size_t i = 0; i < num_dq_inputs; ++i) { + int32_t dt_input = dq_nodes[i]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (!allow_int8_ && dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) { + return false; + } + if (!allow_16bit_ && Is16BitIntType(dt_input)) { + return false; + } + if (!allow_4bit_ && Is4BitIntType(dt_input)) { + return false; + } + } + if (!q_nodes.empty()) { + int32_t dt_input0 = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_input0 != dt_output) { + return false; + } + } + return true; +} + bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 36e04146040db..e4f4844fb88ad 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -182,6 +182,22 @@ class PadNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; }; +// one ore more DQ nodes for each input -> node -> Q +class EinsumNodeGroupSelector : public NodeGroupSelector { + public: + explicit EinsumNodeGroupSelector(bool allow_int8 = true, bool allow_16bit = true, bool allow_4bit = true) + : allow_int8_(allow_int8), allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} + + private: + bool Check(const GraphViewer& graph_viewer, + const Node& node, const Node* redundant_clip_node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; + bool allow_int8_; + bool allow_16bit_; + bool allow_4bit_; +}; + // 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not // The lack of a trailing Q isn't really a QDQ node group, so we default support for that to off. class MatMulNodeGroupSelector : public NodeGroupSelector { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index e531d19d4c643..d3957a34dcfca 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -113,6 +113,9 @@ static const OpVersionsAndSelector::OpVersionsMap GetConvOpVersionsMap() { static const OpVersionsAndSelector::OpVersionsMap GetConvTransposeOpVersionsMap() { return {{"ConvTranspose", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetEinsumOpVersionsMap() { + return {{"Einsum", {}}}; +} static const OpVersionsAndSelector::OpVersionsMap GetMatMulOpVersionsMap() { return {{"MatMul", {}}}; } @@ -202,6 +205,13 @@ void RegisterConvTransposeSelector(Selectors& qdq_selectors) { std::move(selector)); } +void RegisterEinsumSelector(Selectors& qdq_selectors) { + /* register selector for einsum op */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetEinsumOpVersionsMap(), + std::move(selector)); +} + void RegisterMatMulSelector(Selectors& qdq_selectors) { /* register selector for matmul op */ std::unique_ptr selector = std::make_unique(); @@ -267,6 +277,7 @@ void SelectorManager::CreateSelectors() { RegisterSplitSelector(qdq_selectors_); RegisterConvSelector(qdq_selectors_); RegisterConvTransposeSelector(qdq_selectors_); + RegisterEinsumSelector(qdq_selectors_); RegisterMatMulSelector(qdq_selectors_); RegisterGemmSelector(qdq_selectors_); RegisterInstanceAndLayerNormalizationSelector(qdq_selectors_); diff --git a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc index 83c5d7bc8d92a..58e90ea3c71c2 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc @@ -89,7 +89,9 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph } const auto& dq_attrs = dq_1->GetAttributes(); - if (dq_attrs.find("block_size") != dq_attrs.end()) { + auto attr_it = dq_attrs.find("block_size"); + // Default value of block_size=0 has no significance. Don't skip weight_bias_quantization. + if (attr_it != dq_attrs.end() && attr_it->second.i() != 0) { continue; } diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 10cb6eb97bdd6..93c7efc9ca167 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -2487,6 +2487,7 @@ constexpr HandlerInfo reshape_handler = {&FirstInput, &HandleReshape, /*transpos static const std::unordered_map handler_map{ {"Cast", simple_node_handler}, {"Exp", simple_node_handler}, + {"Gelu", simple_node_handler}, {"Identity", simple_node_handler}, {"LeakyRelu", simple_node_handler}, {"Log", simple_node_handler}, diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index 206774c896ff5..e9c7830f6d7a4 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -89,4 +89,12 @@ void Telemetry::LogExecutionProviderEvent(LUID* adapterLuid) const { ORT_UNUSED_PARAMETER(adapterLuid); } +void Telemetry::LogDriverInfoEvent(const std::string_view device_class, + const std::wstring_view& driver_names, + const std::wstring_view& driver_versions) const { + ORT_UNUSED_PARAMETER(device_class); + ORT_UNUSED_PARAMETER(driver_names); + ORT_UNUSED_PARAMETER(driver_versions); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index bc261fddcd56e..d9afcace2fb81 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -69,6 +69,10 @@ class Telemetry { virtual void LogExecutionProviderEvent(LUID* adapterLuid) const; + virtual void LogDriverInfoEvent(const std::string_view device_class, + const std::wstring_view& driver_names, + const std::wstring_view& driver_versions) const; + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Telemetry); }; diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index 88fbec37c8075..61db2bf368b09 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -12,6 +12,7 @@ #include "core/common/cpuid_info.h" #include "core/common/logging/logging.h" +#include "core/platform/env.h" #include "core/session/abi_devices.h" //// For SetupApi info @@ -56,6 +57,26 @@ struct DeviceInfo { std::unordered_map metadata; }; +struct DriverInfo { + std::wstring driver_versions; + std::wstring driver_names; + + void AddDevice(const std::wstring& driver_version, const std::wstring& driver_name) { + if (!driver_version.empty()) { + if (!driver_versions.empty()) { + driver_versions += L", "; + } + driver_versions += driver_version; + } + if (!driver_name.empty()) { + if (!driver_names.empty()) { + driver_names += L", "; + } + driver_names += driver_name; + } + } +}; + uint64_t GetDeviceKey(uint32_t vendor_id, uint32_t device_id) { return (uint64_t(vendor_id) << 32) | device_id; } @@ -68,6 +89,20 @@ uint64_t GetLuidKey(LUID luid) { return (uint64_t(luid.HighPart) << 32) | luid.LowPart; } +// Converts a wide string (up to 4 characters) representing a hardware ID component (e.g., "ABCD" from "VEN_ABCD") +// into a uint32_t. The conversion is done in a little-endian manner, meaning the first character +// of the string becomes the least significant byte of the integer, and the fourth character +// becomes the most significant byte. +uint32_t WStringToUint32Id(const std::wstring& vendor_name) { + uint32_t vendor_id = 0; + for (size_t i = 0; i < 4 && i < vendor_name.size(); ++i) { + // For little-endian, place each character at the appropriate byte position + // First character goes into lowest byte, last character into highest byte + vendor_id |= static_cast(vendor_name[i] & 0xFF) << (i * 8); + } + return vendor_id; +} + // returns info for display and processor entries. key is (vendor_id << 32 | device_id) // npus: (vendor_id << 32 | device_id) for devices we think are NPUs from DXCORE std::unordered_map GetDeviceInfoSetupApi(const std::unordered_set& npus) { @@ -75,11 +110,14 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde const GUID local_DXCORE_ADAPTER_ATTRIBUTE_D3D12_GENERIC_ML = {0xb71b0d41, 0x1088, 0x422f, 0xa2, 0x7c, 0x2, 0x50, 0xb7, 0xd3, 0xa9, 0x88}; const GUID local_DXCORE_HARDWARE_TYPE_ATTRIBUTE_NPU = {0xd46140c4, 0xadd7, 0x451b, 0x9e, 0x56, 0x6, 0xfe, 0x8c, 0x3b, 0x58, 0xed}; + const GUID local_GUID_DEVCLASS_COMPUTEACCELERATOR = {0xf01a9d53, 0x3ff6, 0x48d2, 0x9f, 0x97, 0xc8, 0xa7, 0x00, 0x4b, 0xe1, 0x0c}; + + std::unordered_map device_version_info; std::array guids = { GUID_DEVCLASS_DISPLAY, GUID_DEVCLASS_PROCESSOR, - GUID_DEVCLASS_SYSTEM, + local_GUID_DEVCLASS_COMPUTEACCELERATOR, }; for (auto guid : guids) { @@ -103,23 +141,34 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde //// Get hardware ID (contains VEN_xxxx&DEV_xxxx) if (SetupDiGetDeviceRegistryPropertyW(devInfo, &devData, SPDRP_HARDWAREID, ®DataType, (PBYTE)buffer, sizeof(buffer), &size)) { + uint32_t vendor_id = 0; + uint32_t device_id = 0; + // PCI\VEN_xxxx&DEV_yyyy&... // ACPI\VEN_xxxx&DEV_yyyy&... if we're lucky. // ACPI values seem to be very inconsistent, so we check fairly carefully and always require a device id. const auto get_id = [](const std::wstring& hardware_id, const std::wstring& prefix) -> uint32_t { if (auto idx = hardware_id.find(prefix); idx != std::wstring::npos) { auto id = hardware_id.substr(idx + prefix.size(), 4); - if (std::all_of(id.begin(), id.end(), iswxdigit)) { - return std::stoul(id, nullptr, 16); + if (id.size() == 4) { + // DXCore reports vendor and device IDs as 32-bit integer representations of the ASCII string. + return WStringToUint32Id(id); } } return 0; }; - uint32_t vendor_id = get_id(buffer, L"VEN_"); - uint32_t device_id = get_id(buffer, L"DEV_"); - // won't always have a vendor id from an ACPI entry. need at least a device id to identify the hardware + // Processor ID should come from CPUID mapping. + if (guid == GUID_DEVCLASS_PROCESSOR) { + vendor_id = CPUIDInfo::GetCPUIDInfo().GetCPUVendorId(); + } else { + vendor_id = get_id(buffer, L"VEN_"); + } + + device_id = get_id(buffer, L"DEV_"); + + // Won't always have a vendor id from an ACPI entry. ACPI is not defined for this purpose. if (vendor_id == 0 && device_id == 0) { continue; } @@ -138,8 +187,8 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde entry = &device_info[key]; entry->vendor_id = vendor_id; entry->device_id = device_id; - // put the first hardware id string in the metadata. ignore the other lines. - entry->metadata.emplace(L"SPDRP_HARDWAREID", std::wstring(buffer, wcslen(buffer))); + // put the first hardware id string in the metadata. ignore the other lines. not sure if this is of value. + // entry->metadata.emplace(L"SPDRP_HARDWAREID", std::wstring(buffer, wcslen(buffer))); } else { // need valid ids continue; @@ -156,14 +205,14 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde (PBYTE)buffer, sizeof(buffer), &size)) { std::wstring desc{buffer}; - // Should we require the NPU to be found by DXCore or do we want to allow this vague matching? - // Probably depends on whether we always attempt to run DXCore or not. - const auto possible_npu = [](const std::wstring& desc) { - return (desc.find(L"NPU") != std::wstring::npos || - desc.find(L"Neural") != std::wstring::npos || - desc.find(L"AI Engine") != std::wstring::npos || - desc.find(L"VPU") != std::wstring::npos); - }; + // For now, require dxcore to identify an NPU. + // If we want to try and infer it from the description this _may_ work but is untested. + // const auto possible_npu = [](const std::wstring& desc) { + // return (desc.find(L"NPU") != std::wstring::npos || + // desc.find(L"Neural") != std::wstring::npos || + // desc.find(L"AI Engine") != std::wstring::npos || + // desc.find(L"VPU") != std::wstring::npos); + // }; // use description if no friendly name if (entry->description.empty()) { @@ -171,15 +220,15 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde } uint64_t npu_key = GetDeviceKey(*entry); - bool is_npu = npus.count(npu_key) > 0 || possible_npu(desc); + bool is_npu = npus.count(npu_key) > 0; // rely on dxcore to determine if something is an NPU if (guid == GUID_DEVCLASS_DISPLAY) { entry->type = OrtHardwareDeviceType_GPU; } else if (guid == GUID_DEVCLASS_PROCESSOR) { entry->type = is_npu ? OrtHardwareDeviceType_NPU : OrtHardwareDeviceType_CPU; - } else if (guid == GUID_DEVCLASS_SYSTEM) { + } else if (guid == local_GUID_DEVCLASS_COMPUTEACCELERATOR) { if (!is_npu) { - // we're only iterating system devices to look for NPUs so drop anything else + // we're only iterating compute accelerator devices to look for NPUs so drop anything else device_info.erase(key); continue; } @@ -194,28 +243,64 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde continue; } - if (SetupDiGetDeviceRegistryPropertyW(devInfo, &devData, SPDRP_MFG, nullptr, - (PBYTE)buffer, sizeof(buffer), &size)) { - entry->vendor = std::wstring(buffer, wcslen(buffer)); + if (entry->type == OrtHardwareDeviceType_CPU) { + // get 12 byte string from CPUID. easier for a user to match this if they are explicitly picking a device. + std::string_view cpuid_vendor = CPUIDInfo::GetCPUIDInfo().GetCPUVendor(); + entry->vendor = std::wstring(cpuid_vendor.begin(), cpuid_vendor.end()); } - // Add the UI number if GPU. Helpful if user has integrated and discrete GPUs - if (entry->type == OrtHardwareDeviceType_GPU) { - DWORD ui_number = 0; - if (SetupDiGetDeviceRegistryPropertyW(devInfo, &devData, SPDRP_UI_NUMBER, nullptr, - (PBYTE)&ui_number, sizeof(ui_number), &size)) { - // use value read. - } else { - // infer it as 0 if not set. + if (entry->vendor.empty()) { + if (SetupDiGetDeviceRegistryPropertyW(devInfo, &devData, SPDRP_MFG, nullptr, + (PBYTE)buffer, sizeof(buffer), &size)) { + entry->vendor = std::wstring(buffer, wcslen(buffer)); } + } - entry->metadata.emplace(L"SPDRP_UI_NUMBER", std::to_wstring(ui_number)); + // Generate telemetry event to log the GPU and NPU driver name and version. + if (entry->type == OrtHardwareDeviceType_CPU) { + // Skip processor entries for telemetry. + continue; + } + + // Open the device's driver registry key + HKEY dev_reg_key = SetupDiOpenDevRegKey(devInfo, &devData, + DICS_FLAG_GLOBAL, + 0, + DIREG_DRV, + KEY_READ); + + if (dev_reg_key != INVALID_HANDLE_VALUE) { + // Query the "DriverVersion" string + std::wstring driver_version_str; + wchar_t driver_version[256]; + DWORD str_size = sizeof(driver_version); + DWORD type = 0; + if (RegQueryValueExW(dev_reg_key, L"DriverVersion", + nullptr, &type, + reinterpret_cast(driver_version), + &str_size) == ERROR_SUCCESS && + type == REG_SZ) { + // Ensure proper null termination of a string retrieved from the Windows Registry API. + driver_version[(str_size / sizeof(wchar_t)) - 1] = 0; + driver_version_str = driver_version; + } + RegCloseKey(dev_reg_key); + device_version_info[entry->type].AddDevice(driver_version_str, entry->description); } } SetupDiDestroyDeviceInfoList(devInfo); } + // Log driver information for GPUs and NPUs + const Env& env = Env::Default(); + for (const auto& [type, info] : device_version_info) { + if (!info.driver_versions.empty() || !info.driver_names.empty()) { + const std::string_view driver_class = (type == OrtHardwareDeviceType_GPU) ? "GPU" : "NPU"; + env.GetTelemetryProvider().LogDriverInfoEvent(driver_class, info.driver_names, info.driver_versions); + } + } + return device_info; } @@ -252,9 +337,7 @@ std::unordered_map GetDeviceInfoD3D12() { info.description = std::wstring(desc.Description); info.metadata[L"DxgiAdapterNumber"] = std::to_wstring(i); - info.metadata[L"VideoMemory"] = std::to_wstring(desc.DedicatedVideoMemory / (1024 * 1024)) + L" MB"; - info.metadata[L"SystemMemory"] = std::to_wstring(desc.DedicatedSystemMemory / (1024 * 1024)) + L" MB"; - info.metadata[L"SharedSystemMemory"] = std::to_wstring(desc.DedicatedSystemMemory / (1024 * 1024)) + L" MB"; + info.metadata[L"DxgiVideoMemory"] = std::to_wstring(desc.DedicatedVideoMemory / (1024 * 1024)) + L" MB"; } // iterate by high-performance GPU preference to add that info @@ -272,7 +355,7 @@ std::unordered_map GetDeviceInfoD3D12() { auto it = device_info.find(key); if (it != device_info.end()) { DeviceInfo& info = it->second; - info.metadata[L"HighPerformanceIndex"] = std::to_wstring(i); + info.metadata[L"DxgiHighPerformanceIndex"] = std::to_wstring(i); } } @@ -365,6 +448,20 @@ std::unordered_map GetDeviceInfoDxcore() { return device_info; } + +DeviceInfo GetDeviceInfoCPUID() { + DeviceInfo cpu_info{}; + cpu_info.type = OrtHardwareDeviceType_CPU; + + auto& cpuinfo = CPUIDInfo::GetCPUIDInfo(); + cpu_info.vendor_id = cpuinfo.GetCPUVendorId(); + + std::string_view cpuid_vendor = cpuinfo.GetCPUVendor(); + cpu_info.vendor = std::wstring(cpuid_vendor.begin(), cpuid_vendor.end()); + cpu_info.description = cpu_info.vendor; + + return cpu_info; +} } // namespace // Get devices from various sources and combine them into a single set of devices. @@ -386,6 +483,22 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor // setupapi_info. key is vendor_id+device_id std::unordered_map setupapi_info = GetDeviceInfoSetupApi(npus); + // Ensure we have at least one CPU + bool found_cpu = false; + for (auto& [key, device] : setupapi_info) { + if (device.type == OrtHardwareDeviceType_CPU) { + found_cpu = true; + break; + } + } + + // If no CPU was found via SetupApi, add one from CPUID + if (!found_cpu) { + DeviceInfo device = GetDeviceInfoCPUID(); + uint64_t key = GetDeviceKey(device); + setupapi_info[key] = std::move(device); + } + // add dxcore info for any devices that are not in d3d12. // d3d12 info is more complete and has a good description and metadata. // dxcore has 'Discrete' in metadata so add that if found @@ -405,25 +518,40 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor } } - std::wstring_convert > converter; // wstring to string - const auto device_to_ortdevice = [&converter]( + // our log output to std::wclog breaks with UTF8 chars that are not supported by the current code page. + // e.g. (TM) symbol. that stops ALL logging working on at least arm64. + // safest way to avoid that is to keep it to single byte chars. + // process the OrtHardwareDevice values this way so it can be safely logged. + // only the 'description' metadata is likely to be affected and that is mainly for diagnostic purposes. + const auto to_safe_string = [](const std::wstring& wstr) -> std::string { + std::string str(wstr.size(), ' '); + std::transform(wstr.begin(), wstr.end(), str.begin(), [](wchar_t wchar) { + if (wchar >= 0 && wchar <= 127) { + return static_cast(wchar); + } + return ' '; + }); + return str; + }; + + const auto device_to_ortdevice = [&to_safe_string]( DeviceInfo& device, std::unordered_map* extra_metadata = nullptr) { - OrtHardwareDevice ortdevice{device.type, device.vendor_id, device.device_id, converter.to_bytes(device.vendor)}; + OrtHardwareDevice ortdevice{device.type, device.vendor_id, device.device_id, to_safe_string(device.vendor)}; if (!device.description.empty()) { - ortdevice.metadata.Add("Description", converter.to_bytes(device.description)); + ortdevice.metadata.Add("Description", to_safe_string(device.description)); } for (auto& [key, value] : device.metadata) { - ortdevice.metadata.Add(converter.to_bytes(key), converter.to_bytes(value)); + ortdevice.metadata.Add(to_safe_string(key), to_safe_string(value)); } if (extra_metadata) { // add any extra metadata from the dxcore info for (auto& [key, value] : *extra_metadata) { if (device.metadata.find(key) == device.metadata.end()) { - ortdevice.metadata.Add(converter.to_bytes(key), converter.to_bytes(value)); + ortdevice.metadata.Add(to_safe_string(key), to_safe_string(value)); } } } @@ -431,6 +559,7 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor std::ostringstream oss; oss << "Adding OrtHardwareDevice {vendor_id:0x" << std::hex << ortdevice.vendor_id << ", device_id:0x" << ortdevice.device_id + << ", vendor:" << ortdevice.vendor << ", type:" << std::dec << static_cast(ortdevice.type) << ", metadata: ["; for (auto& [key, value] : ortdevice.metadata.entries) { diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 47789af9d5a47..0775e19c5654b 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -385,4 +385,21 @@ void WindowsTelemetry::LogExecutionProviderEvent(LUID* adapterLuid) const { TraceLoggingUInt32(adapterLuid->HighPart, "adapterLuidHighPart")); } +void WindowsTelemetry::LogDriverInfoEvent(const std::string_view device_class, const std::wstring_view& driver_names, const std::wstring_view& driver_versions) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + TraceLoggingWrite(telemetry_provider_handle, + "DriverInfo", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingString(device_class.data(), "deviceClass"), + TraceLoggingWideString(driver_names.data(), "driverNames"), + TraceLoggingWideString(driver_versions.data(), "driverVersions")); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index b23a60a44b5f0..92b3d11d77702 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -60,6 +60,10 @@ class WindowsTelemetry : public Telemetry { void LogExecutionProviderEvent(LUID* adapterLuid) const override; + void LogDriverInfoEvent(const std::string_view device_class, + const std::wstring_view& driver_names, + const std::wstring_view& driver_versions) const override; + using EtwInternalCallback = std::function; diff --git a/onnxruntime/core/providers/cpu/text/string_normalizer.h b/onnxruntime/core/providers/cpu/text/string_normalizer.h index 750c59ec21e21..9759140847b3a 100644 --- a/onnxruntime/core/providers/cpu/text/string_normalizer.h +++ b/onnxruntime/core/providers/cpu/text/string_normalizer.h @@ -27,7 +27,7 @@ class StringNormalizer : public OpKernel { private: bool is_case_sensitive_{true}; CaseAction case_change_action_{NONE}; - // Set this to lower because some characters do not have capital case. + // Set this to lower because some characters do not have upper case. // used for case-insensitive compare CaseAction compare_caseaction_{LOWER}; std::string locale_name_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 9a694b03387ae..e70ddc481ba43 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -870,6 +870,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "Mod", "Mul", "Multinomial", + "MultiHeadAttention", "Neg", "NegativeLogLikelihoodLoss", "NonMaxSuppression", @@ -884,6 +885,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "QLinearConv", "QLinearMatMul", "QuantizeLinear", + "QuickGelu", "DynamicQuantizeLinear", "RandomNormal", "RandomNormalLike", diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc new file mode 100644 index 0000000000000..4e8179d86fd73 --- /dev/null +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "nv_allocator.h" +#include "nv_includes.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" +namespace onnxruntime { + +void CUDAAllocator::CheckDevice(bool throw_when_fail) const { +#ifndef NDEBUG + // check device to match at debug build + // if it's expected to change, call cudaSetDevice instead of the check + int current_device; + auto cuda_err = cudaGetDevice(¤t_device); + if (cuda_err == cudaSuccess) { + ORT_ENFORCE(current_device == Info().id); + } else if (throw_when_fail) { + CUDA_CALL_THROW(cuda_err); + } +#else + ORT_UNUSED_PARAMETER(throw_when_fail); +#endif +} + +void CUDAAllocator::SetDevice(bool throw_when_fail) const { + int current_device; + auto cuda_err = cudaGetDevice(¤t_device); + if (cuda_err == cudaSuccess) { + int allocator_device_id = Info().id; + if (current_device != allocator_device_id) { + cuda_err = cudaSetDevice(allocator_device_id); + } + } + + if (cuda_err != cudaSuccess && throw_when_fail) { + CUDA_CALL_THROW(cuda_err); + } +} + +void* CUDAAllocator::Alloc(size_t size) { + SetDevice(true); + CheckDevice(true); + void* p = nullptr; + if (size > 0) { + // BFCArena was updated recently to handle the exception and adjust the request size + CUDA_CALL_THROW(cudaMalloc((void**)&p, size)); + } + return p; +} + +void CUDAAllocator::Free(void* p) { + SetDevice(false); + CheckDevice(false); // ignore CUDA failure when free + cudaFree(p); // do not throw error since it's OK for cudaFree to fail during shutdown +} + +void* CUDAExternalAllocator::Alloc(size_t size) { + void* p = nullptr; + if (size > 0) { + p = alloc_(size); + + // review(codemzs): ORT_ENFORCE does not seem appropriate. + ORT_ENFORCE(p != nullptr); + } + + return p; +} + +void CUDAExternalAllocator::Free(void* p) { + free_(p); + std::lock_guard lock(lock_); + auto it = reserved_.find(p); + if (it != reserved_.end()) { + reserved_.erase(it); + if (empty_cache_) empty_cache_(); + } +} + +void* CUDAExternalAllocator::Reserve(size_t size) { + void* p = Alloc(size); + if (!p) return nullptr; + std::lock_guard lock(lock_); + ORT_ENFORCE(reserved_.find(p) == reserved_.end()); + reserved_.insert(p); + return p; +} + +void* CUDAPinnedAllocator::Alloc(size_t size) { + void* p = nullptr; + if (size > 0) { + CUDA_CALL_THROW(cudaMallocHost((void**)&p, size)); + } + return p; +} + +void CUDAPinnedAllocator::Free(void* p) { + CUDA_CALL_THROW(cudaFreeHost(p)); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h new file mode 100644 index 0000000000000..a3f05bded5de9 --- /dev/null +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/framework/allocator.h" +#include + +namespace onnxruntime { + +class CUDAAllocator : public IAllocator { + public: + CUDAAllocator(OrtDevice::DeviceId device_id, const char* name) + : IAllocator( + OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id), + device_id, OrtMemTypeDefault)) {} + void* Alloc(size_t size) override; + void Free(void* p) override; + + private: + void CheckDevice(bool throw_when_fail) const; + void SetDevice(bool throw_when_fail) const; +}; + +class CUDAExternalAllocator : public CUDAAllocator { + typedef void* (*ExternalAlloc)(size_t size); + typedef void (*ExternalFree)(void* p); + typedef void (*ExternalEmptyCache)(); + + public: + CUDAExternalAllocator(OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free, void* empty_cache) + : CUDAAllocator(device_id, name) { + alloc_ = reinterpret_cast(alloc); + free_ = reinterpret_cast(free); + empty_cache_ = reinterpret_cast(empty_cache); + } + + void* Alloc(size_t size) override; + void Free(void* p) override; + void* Reserve(size_t size) override; + + private: + mutable std::mutex lock_; + ExternalAlloc alloc_; + ExternalFree free_; + ExternalEmptyCache empty_cache_; + InlinedHashSet reserved_; +}; + +// TODO: add a default constructor +class CUDAPinnedAllocator : public IAllocator { + public: + CUDAPinnedAllocator(const char* name) + : IAllocator( + OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0 /*CPU device always with id 0*/), + 0, OrtMemTypeCPUOutput)) {} + + void* Alloc(size_t size) override; + void Free(void* p) override; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_cuda_call.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_cuda_call.cc new file mode 100644 index 0000000000000..8e9ea1257cdd2 --- /dev/null +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_cuda_call.cc @@ -0,0 +1,145 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/shared_library/provider_api.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include + +#ifdef _WIN32 +#else // POSIX +#include +#include +#endif + +namespace onnxruntime { + +using namespace common; + +template +const char* CudaErrString(ERRTYPE) { + ORT_NOT_IMPLEMENTED(); +} + +#define CASE_ENUM_TO_STR(x) \ + case x: \ + return #x + +template <> +const char* CudaErrString(cudaError_t x) { + cudaDeviceSynchronize(); + return cudaGetErrorString(x); +} + +#ifndef USE_CUDA_MINIMAL +template <> +const char* CudaErrString(cublasStatus_t e) { + cudaDeviceSynchronize(); + switch (e) { + CASE_ENUM_TO_STR(CUBLAS_STATUS_SUCCESS); + CASE_ENUM_TO_STR(CUBLAS_STATUS_NOT_INITIALIZED); + CASE_ENUM_TO_STR(CUBLAS_STATUS_ALLOC_FAILED); + CASE_ENUM_TO_STR(CUBLAS_STATUS_INVALID_VALUE); + CASE_ENUM_TO_STR(CUBLAS_STATUS_ARCH_MISMATCH); + CASE_ENUM_TO_STR(CUBLAS_STATUS_MAPPING_ERROR); + CASE_ENUM_TO_STR(CUBLAS_STATUS_EXECUTION_FAILED); + CASE_ENUM_TO_STR(CUBLAS_STATUS_INTERNAL_ERROR); + CASE_ENUM_TO_STR(CUBLAS_STATUS_NOT_SUPPORTED); + CASE_ENUM_TO_STR(CUBLAS_STATUS_LICENSE_ERROR); + default: + return "(look for CUBLAS_STATUS_xxx in cublas_api.h)"; + } +} + +template <> +const char* CudaErrString(curandStatus) { + cudaDeviceSynchronize(); + return "(see curand.h & look for curandStatus or CURAND_STATUS_xxx)"; +} + +template <> +const char* CudaErrString(cudnnStatus_t e) { + cudaDeviceSynchronize(); + return cudnnGetErrorString(e); +} + +template <> +const char* CudaErrString(cufftResult e) { + cudaDeviceSynchronize(); + switch (e) { + CASE_ENUM_TO_STR(CUFFT_SUCCESS); + CASE_ENUM_TO_STR(CUFFT_ALLOC_FAILED); + CASE_ENUM_TO_STR(CUFFT_INVALID_VALUE); + CASE_ENUM_TO_STR(CUFFT_INTERNAL_ERROR); + CASE_ENUM_TO_STR(CUFFT_SETUP_FAILED); + CASE_ENUM_TO_STR(CUFFT_INVALID_SIZE); + default: + return "Unknown cufft error status"; + } +} +#endif + +#ifdef ORT_USE_NCCL +template <> +const char* CudaErrString(ncclResult_t e) { + cudaDeviceSynchronize(); + return ncclGetErrorString(e); +} +#endif + +template +int GetErrorCode(ERRTYPE err) { + return static_cast(err); +} + +template +std::conditional_t CudaCall( + ERRTYPE retCode, const char* exprString, const char* libName, SUCCTYPE successCode, const char* msg, + const char* file, const int line) { + if (retCode != successCode) { + try { +#ifdef _WIN32 + std::string hostname_str = GetEnvironmentVar("COMPUTERNAME"); + if (hostname_str.empty()) { + hostname_str = "?"; + } + const char* hostname = hostname_str.c_str(); +#else + char hostname[HOST_NAME_MAX]; + if (gethostname(hostname, HOST_NAME_MAX) != 0) + strcpy(hostname, "?"); +#endif + int currentCudaDevice = -1; + cudaGetDevice(¤tCudaDevice); + cudaGetLastError(); // clear last CUDA error + static char str[1024]; + snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s", + libName, GetErrorCode(retCode), CudaErrString(retCode), currentCudaDevice, + hostname, + file, line, exprString, msg); + if constexpr (THRW) { + // throw an exception with the error info + ORT_THROW(str); + } else { + LOGS_DEFAULT(ERROR) << str; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str); + } + } catch (const std::exception& e) { // catch, log, and rethrow since CUDA code sometimes hangs in destruction, + // so we'd never get to see the error + if constexpr (THRW) { + ORT_THROW(e.what()); + } else { + LOGS_DEFAULT(ERROR) << e.what(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what()); + } + } + } + if constexpr (!THRW) { + return Status::OK(); + } +} + +template Status CudaCall(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line); +template void CudaCall(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line); + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.cc new file mode 100644 index 0000000000000..4779ddd1a9556 --- /dev/null +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.cc @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/shared_library/provider_api.h" + +#include "nv_data_transfer.h" + +#include "core/providers/cuda/shared_inc/cuda_call.h" +#define CUDA_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDA_CALL(expr)) +namespace onnxruntime { +bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { + return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::CUDA_PINNED || + dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED; +} + +common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { + size_t bytes = src.SizeInBytes(); + const void* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); + + auto& src_device = src.Location().device; + auto& dst_device = dst.Location().device; + + // for the sync version of memcpy, launch to cuda default stream + if (dst_device.Type() == OrtDevice::GPU) { + if (src_device.Type() == OrtDevice::GPU) { + // Copy only if the two addresses are different. + if (dst_data != src_data) { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice)); + // For device memory to device memory copy, no host-side synchronization is performed by cudaMemcpy. + // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } + } else { + // copy from other CPU memory to GPU, this is blocking + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice)); + if (src_device.MemType() != OrtDevice::MemType::CUDA_PINNED) { + // For cudaMemcpy from pageable host memory to device memory, DMA to final destination may not have completed. + // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } + } + } else if (src_device.Type() == OrtDevice::GPU) { + // copying from GPU to CPU memory, this is blocking + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost)); + } else { + // copying between cpu memory + ORT_ENFORCE(dst_data != src_data); + memcpy(dst_data, src_data, bytes); + } + + return Status::OK(); +} + +common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const { + size_t bytes = src.SizeInBytes(); + const void* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); + + auto& src_device = src.Location().device; + auto& dst_device = dst.Location().device; + + if (dst_device.Type() == OrtDevice::GPU) { + if (src_device.Type() == OrtDevice::CPU) { + // copy from pinned or non-pinned CPU memory to GPU + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, static_cast(stream.GetHandle()))); + } else if (src_device.Type() == OrtDevice::GPU) { + // copying between GPU, this is non-blocking + if (dst_data != src_data) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); + } + } + } else if (src_device.Type() == OrtDevice::GPU) { + if (dst_device.Type() == OrtDevice::CPU) { + // copy from GPU to pinned or non-pinned CPU memory. + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, static_cast(stream.GetHandle()))); + } + } else { + if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) { + // sync the stream first to make sure the data arrived + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast(stream.GetHandle()))); + } + + ORT_ENFORCE(dst_data != src_data); + memcpy(dst_data, src_data, bytes); + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.h new file mode 100644 index 0000000000000..272ea367ac7e4 --- /dev/null +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "nv_includes.h" +#include "core/framework/data_transfer.h" + +namespace onnxruntime { + +class GPUDataTransfer : public IDataTransfer { + public: + GPUDataTransfer() = default; + ~GPUDataTransfer() = default; + + bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; + + // Dumpen MSVC warning about not fully overriding + using IDataTransfer::CopyTensor; + common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; + common::Status CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 25c130a849793..6a7ff63dbc0ed 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include #include @@ -12,10 +13,11 @@ #include "nv_execution_provider.h" #include "nv_execution_provider_utils.h" #include "nv_execution_provider_custom_ops.h" +#include "nv_allocator.h" +#include "nv_data_transfer.h" #include "onnx_ctx_model_helper.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/math/unary_elementwise_ops_impl.h" -#include "core/providers/cuda/gpu_data_transfer.h" #include "core/session/allocator_adapters.h" #include "cuda_runtime_api.h" #include @@ -113,16 +115,6 @@ void Impl_Cast( } } // namespace cuda -template <> -Status CudaCall(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line) { - return g_host->CudaCall_false(retCode, exprString, libName, successCode, msg, file, line); -} - -template <> -void CudaCall(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line) { - return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line); -} - #if NV_TENSORRT_MAJOR >= 10 void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept { @@ -1311,13 +1303,14 @@ void NvExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { std::vector NvExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { return CreateCUDAAllocator(device_id, onnxruntime::CUDA); }, + [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, CUDA); }, narrow(device_id_)); AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { ORT_UNUSED_PARAMETER(device_id); - return CreateCUDAPinnedAllocator(onnxruntime::CUDA_PINNED); + return std::make_unique(CUDA_PINNED); + ; }, 0); @@ -1325,7 +1318,7 @@ std::vector NvExecutionProvider::CreatePreferredAllocators() { } std::unique_ptr NvExecutionProvider::GetDataTransfer() const { - return onnxruntime::CreateGPUDataTransfer(); + return std::make_unique(); } Status NvExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { @@ -2021,10 +2014,12 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, } // Remove subgraphs if its size is less than the predefined minimal size - for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end(); ++it) { + for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end();) { const size_t subgraph_size = it->first.size(); if (subgraph_size < min_subgraph_size_) { - supported_nodes_vector.erase(it--); + it = supported_nodes_vector.erase(it); + } else { + ++it; } } @@ -2586,11 +2581,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr if (mem_size > max_ctx_mem_size_) { max_ctx_mem_size_ = mem_size; } -#if NV_TENSORRT_MAJOR < 10 - trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); -#else trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); -#endif } else { trt_context = std::unique_ptr(trt_engine->createExecutionContext()); } @@ -2971,11 +2962,8 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra if (mem_size > max_ctx_mem_size_) { max_ctx_mem_size_ = mem_size; } -#if NV_TENSORRT_MAJOR < 10 - trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); -#else trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); -#endif + } else { trt_context = std::unique_ptr(trt_engine->createExecutionContext()); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index 795886fa255ed..ba9f7baa4c1ee 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -59,8 +59,29 @@ Status PoolOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, std::vector input_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape"); - if (input_shape.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Pool2D only support 2D!"); + + bool is1d = (input_shape.size() == 3); + if (!is1d && input_shape.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Pool only supports rank 3 or 4!"); + } + + NodeAttrHelper node_helper(node_unit); + + if (is1d) { + auto kernel_shape = node_helper.Get("kernel_shape", std::vector{}); + ORT_RETURN_IF_NOT(kernel_shape.size() == 1, "QNN Pool1D: kernel_shape must have length 1!"); + + auto pads = node_helper.Get("pads", std::vector{}); + ORT_RETURN_IF_NOT(pads.size() == 2, "QNN Pool1D: pads must have length 2!"); + + auto strides = node_helper.Get("strides", std::vector{}); + ORT_RETURN_IF_NOT(strides.empty() || strides.size() == 1, "QNN Pool1D: strides must have length 1 or omitted!"); + + auto dilations = node_helper.Get("dilations", std::vector{1}); + ORT_RETURN_IF_NOT(dilations.size() == 1, "QNN Pool1D: dilations must have length 1 or omitted!"); + } else { + auto dilations = node_helper.Get("dilations", std::vector{1, 1}); + ORT_RETURN_IF_NOT(dilations.size() == 2, "QNN Pool2D: dilations must have length 2 or omitted!"); } if (node_unit.Outputs().size() > 1) { @@ -73,15 +94,9 @@ Status PoolOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } - NodeAttrHelper node_helper(node_unit); - auto dilation_values = node_helper.Get("dilations", std::vector{1, 1}); - if (dilation_values != std::vector{1, 1}) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN does not support Dilation attribute"); - } - if (op_type == "MaxPool" || op_type == "AveragePool") { auto auto_pad = node_helper.Get("auto_pad", std::string("NOTSET")); - ORT_RETURN_IF(auto_pad != "NOTSET" && auto_pad != "SAME_LOWER" && auto_pad != "SAME_UPPER", + ORT_RETURN_IF(auto_pad != "NOTSET" && auto_pad != "SAME_LOWER" && auto_pad != "SAME_UPPER" && auto_pad != "VALID", "QNN Pool operators do not support 'auto_pad' value: ", auto_pad.c_str()); } @@ -94,19 +109,49 @@ Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper, int32_t& ceil_mode, std::vector&& input_shape, std::vector&& output_shape) const { - filter_size = node_helper.Get("kernel_shape", std::vector{1, 1}); - ORT_RETURN_IF_NOT(filter_size.size() == 2, "QNN only support kernel_shape with shape[2]."); - - strides = node_helper.Get("strides", std::vector{1, 1}); - ORT_RETURN_IF_NOT(strides.size() == 2, "QNN only support strides with shape[2]."); + { + auto raw_filter_size = node_helper.Get("kernel_shape", std::vector{1, 1}); + if (raw_filter_size.size() == 1) { + filter_size = {1, raw_filter_size[0]}; + } else { + filter_size = raw_filter_size; + } + } + ORT_RETURN_IF_NOT(filter_size.size() == 2, + "QNN only support kernel_shape with shape[2]."); + + { + auto raw_strides = node_helper.Get("strides", std::vector{1, 1}); + if (raw_strides.size() == 1) { + strides = {1, raw_strides[0]}; + } else { + strides = raw_strides; + } + } + ORT_RETURN_IF_NOT(strides.size() == 2, + "QNN only support strides with shape[2]."); + + { + auto raw_pad_amount = node_helper.Get("pads", std::vector{0, 0, 0, 0}); + if (raw_pad_amount.size() == 2) { + pad_amount = {0, raw_pad_amount[0], 0, raw_pad_amount[1]}; + } else { + pad_amount = raw_pad_amount; + } + } - pad_amount = node_helper.Get("pads", std::vector{0, 0, 0, 0}); auto auto_pad = node_helper.Get("auto_pad", std::string("NOTSET")); - ORT_RETURN_IF(auto_pad != "NOTSET" && auto_pad != "SAME_LOWER" && auto_pad != "SAME_UPPER", + ORT_RETURN_IF(auto_pad != "NOTSET" && auto_pad != "SAME_LOWER" && auto_pad != "SAME_UPPER" && auto_pad != "VALID", "QNN Pool operators do not support 'auto_pad' value: ", auto_pad.c_str()); if (auto_pad.compare("NOTSET") != 0) { - std::vector dilations = node_helper.Get("dilations", std::vector{1, 1}); + std::vector dilations; + auto raw_dilations = node_helper.Get("dilations", std::vector{1, 1}); + if (raw_dilations.size() == 1) { + dilations = {1, raw_dilations[0]}; + } else { + dilations = raw_dilations; + } auto total_pads_0 = (output_shape[1] - 1) * strides[0] + (filter_size[0] - 1) * dilations[0] + 1 - input_shape[1]; auto total_pads_1 = (output_shape[2] - 1) * strides[1] + (filter_size[1] - 1) * dilations[1] + 1 - input_shape[2]; @@ -144,6 +189,36 @@ void SetPoolParam(const NodeUnit& node_unit, qnn_model_wrapper.AddParamWrapper(std::move(qnn_param)); } +std::vector ComputePoolOutputShape( + const std::vector& input_shape, // {N, H, W, C} + const std::vector& kernel_shape, // {k_h, k_w} + const std::vector& strides, // {s_h, s_w} + const std::vector& pads) { + assert(input_shape.size() == 4 && + kernel_shape.size() == 2 && + strides.size() == 2 && + pads.size() == 4); + + const uint32_t N = input_shape[0]; + const uint32_t H = input_shape[1]; + const uint32_t W = input_shape[2]; + const uint32_t C = input_shape[3]; + + // pad the spatial dims + uint32_t padded_H = H + pads[0] + pads[2]; + uint32_t padded_W = W + pads[1] + pads[3]; + + // floor-mode on NHWC + uint32_t out_H = (padded_H < kernel_shape[0]) + ? 0 + : (padded_H - kernel_shape[0]) / strides[0] + 1; + uint32_t out_W = (padded_W < kernel_shape[1]) + ? 0 + : (padded_W - kernel_shape[1]) / strides[1] + 1; + + return {N, out_H, out_W, C}; +} + Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, @@ -154,7 +229,45 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra const auto& inputs = node_unit.Inputs(); std::vector input_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape"); - ORT_RETURN_IF_NOT(input_shape.size() == 4, "Input should have 4 dimension NCHW."); + + const auto& reshape_input = node_unit.Inputs()[0]; + TensorInfo reshape_input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(reshape_input, reshape_input_info)); + + bool needs_reshape = false; + const std::string reshape4d = input_names[0] + "_pre_reshape"; + if (input_shape.size() == 3) { + needs_reshape = true; + // build new_shape = {N, 1, C, L} + std::vector new_shape = { + input_shape[0], + 1, + input_shape[1], + input_shape[2]}; + + const std::string reshape_node_name = "pre_reshape"; + QnnTensorWrapper rw( + reshape4d, + QNN_TENSOR_TYPE_NATIVE, + reshape_input_info.qnn_data_type, + reshape_input_info.quant_param.Copy(), + std::move(new_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(rw)), + "Failed to add reshape-4d tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( + reshape_node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + "Reshape", + {input_names[0]}, + {reshape4d}, + {}, + do_op_validation), + "Failed to create reshape-4d node."); + input_names[0] = reshape4d; + input_shape = {input_shape[0], 1, input_shape[1], input_shape[2]}; + } + + ORT_RETURN_IF_NOT(input_shape.size() == 4, "Input should have 4 dims NCHW or 3 dims for 1D pooling."); // Default value for GlobalAveragePool // Pool use filter & stride with shape [filter_height, filter_width] // With layout transformer, the input has shape [batch, height, width, channel], @@ -192,6 +305,22 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra std::move(input_shape), std::move(output_shape))); } + std::vector onnx_in_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, onnx_in_shape), "Cannot get shape"); + // Reshaped input rank-4 for MaxPool + if (onnx_in_shape.size() == 3) { + onnx_in_shape = {onnx_in_shape[0], + 1, + onnx_in_shape[1], + onnx_in_shape[2]}; + } + + // Calculate MaxPool output for rank-4 when input is rank 3 + auto pooled_shape = ComputePoolOutputShape(onnx_in_shape, + filter_size, + stride, + pad_amount); + SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_FILTER_SIZE, std::move(filter_size_dim), std::move(filter_size), param_tensor_names, qnn_model_wrapper); SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_PAD_AMOUNT, std::move(pad_amount_dim), std::move(pad_amount), param_tensor_names, qnn_model_wrapper); SetPoolParam(node_unit, QNN_OP_POOL_MAX_2D_PARAM_STRIDE, std::move(stride_dim), std::move(stride), param_tensor_names, qnn_model_wrapper); @@ -229,12 +358,65 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra qnn_model_wrapper.AddParamWrapper(std::move(count_pad_for_edges_param)); } - ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, - std::move(input_names), - std::move(param_tensor_names), - logger, - do_op_validation, - GetQnnOpType(op_type))); + if (!needs_reshape) { + ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, + std::move(input_names), + std::move(param_tensor_names), + logger, + do_op_validation, + GetQnnOpType(op_type))); + + return Status::OK(); + } + const auto& outputs = node_unit.Outputs(); + const std::string real_out = outputs[0].node_arg.Name(); + const std::string pool_name = "poolmax2d"; + const std::string pool_out = real_out + "_post_reshape"; + const std::string post_reshape_node_name = "post_reshape"; + const std::string qnn_op = GetQnnOpType(op_type); + TensorInfo output_info{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info)); + bool is_graph_output = qnn_model_wrapper.IsGraphOutput(real_out); + Qnn_TensorType_t tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + QnnTensorWrapper pool_tensor( + pool_out, + QNN_TENSOR_TYPE_NATIVE, + reshape_input_info.qnn_data_type, + output_info.quant_param.Copy(), + std::move(pooled_shape)); + + ORT_RETURN_IF_NOT( + qnn_model_wrapper.AddTensorWrapper(std::move(pool_tensor)), + "Failed to add tensor for pool_out"); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( + pool_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + qnn_op, + {reshape4d}, + {pool_out}, + std::move(param_tensor_names), + do_op_validation), + "Failed to create QNN Pool node for rank-3 input."); + + std::vector final_shape3d = output_info.shape; + QnnTensorWrapper reshape_back_tensor( + real_out, + tensor_type, + output_info.qnn_data_type, + output_info.quant_param.Copy(), + std::move(final_shape3d)); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_back_tensor)), "Failed to add tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( + post_reshape_node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + "Reshape", + {pool_out}, + {real_out}, + {}, + do_op_validation), + "Failed to create reshape-back node."); return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc index 347f0651069dc..85844721b1f2c 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc @@ -185,18 +185,15 @@ Status ResizeOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, "QNN EP: Resize on the NPU does not support nearest_mode ", nearest_mode.c_str()); #endif - const bool use_resize_nn_op = nearest_mode == "floor"; + // Use ResizeNearestNeighbor for rank-4 inputs. + const bool use_resize_nn_op = input_rank == 4; // If HTP uses ResizeNearestNeighbor ("floor"), then the "pytorch_half_pixel" coordinate_transformation_mode // is not supported. - ORT_RETURN_IF(use_resize_nn_op && transformation_mode == "pytorch_half_pixel", + ORT_RETURN_IF(!use_resize_nn_op && nearest_mode == "floor" && transformation_mode == "pytorch_half_pixel", "QNN EP: Resize on the NPU does not support the combination of nearest_mode == 'floor' ", " and coordinate_transformation_mode == 'pytorch_half_pixel'."); - // QNN's ResizeNearestNeighbor requires rank 4 inputs. - ORT_RETURN_IF(use_resize_nn_op && input_rank != 4, - "QNN EP: Resize on the NPU with nearest_mode == 'floor' requires an input with rank 4."); - #if QNN_API_VERSION_MAJOR >= 2 && QNN_API_VERSION_MINOR >= 14 // QNN's Resize only supports "round_prefer_ceil" if transformation_mode is "align_corners". ORT_RETURN_IF(!use_resize_nn_op && transformation_mode != "align_corners", @@ -267,11 +264,11 @@ Status ResizeOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); std::string qnn_op_type = "Resize"; - if (is_npu_backend && input_rank == 4 && interp_mode == "nearest" && nearest_mode == "floor") { + if (is_npu_backend && input_rank == 4 && interp_mode == "nearest") { // Translate Resize with - // {input_rank: 4, mode: "nearest", nearest_mode: "floor", coordinate_transformation_mode: XXX} to - // QNN's ResizeNearestNeighbor operator on the HTP backend. This combination of parameters is not supported on HTP - // via QNN's Resize operator. Note that QNN's ResizeNearestNeighbor operator always uses "floor" rounding. + // {input_rank: 4, mode: "nearest", coordinate_transformation_mode: XXX} to + // QNN's ResizeNearestNeighbor operator on the HTP backend. QNN ResizeNearestNeighbor + // seems to be faster than QNN Resize. qnn_op_type = "ResizeNearestNeighbor"; // 'align_corners' diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index edd1f3e9eb53b..0009dab837525 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -1280,7 +1280,7 @@ Status QnnBackendManager::ExtractBackendProfilingInfo() { } ORT_RETURN_IF(!tracelogging_provider_ep_enabled && profiling_file_path_.empty(), - "Need to specify a cvs file via provider option profiling_file_path if ETW not enabled."); + "Need to specify a CSV file via provider option profiling_file_path if ETW not enabled."); ORT_RETURN_IF(nullptr == profile_backend_handle_, "Backend profile handle not valid."); @@ -1311,7 +1311,7 @@ Status QnnBackendManager::ExtractBackendProfilingInfo() { } std::ofstream outfile; - if (!tracelogging_provider_ep_enabled) { + if (!profiling_file_path_.empty()) { // Write to CSV in append mode std::ifstream infile(profiling_file_path_.c_str()); bool exists = infile.good(); @@ -1334,10 +1334,11 @@ Status QnnBackendManager::ExtractBackendProfilingInfo() { tracelogging_provider_ep_enabled)); } - if (!tracelogging_provider_ep_enabled) { - outfile.close(); - LOGS(*logger_, VERBOSE) << "Wrote QNN profiling events (" << num_events << ") to qnn-profiling-data.csv"; - } else { + if (outfile) { + LOGS(*logger_, VERBOSE) << "Wrote QNN profiling events (" << num_events << ") to file (" + << profiling_file_path_ << ")"; + } + if (tracelogging_provider_ep_enabled) { LOGS(*logger_, VERBOSE) << "Wrote QNN profiling events (" << num_events << ") to ETW"; } } @@ -1399,11 +1400,7 @@ Status QnnBackendManager::ExtractProfilingEventBasic( std::string message = GetEventTypeString(event_data.type); std::string unit = GetUnitString(event_data.unit); -#ifndef _WIN32 - tracelogging_provider_ep_enabled = false; -#endif - - if (!tracelogging_provider_ep_enabled) { + if (outfile) { outfile << "UNKNOWN" << "," << message << "," @@ -1413,7 +1410,9 @@ Status QnnBackendManager::ExtractProfilingEventBasic( << "," << eventLevel << "," << (event_data.identifier ? event_data.identifier : "NULL") << "\n"; - } else { + } + + if (tracelogging_provider_ep_enabled) { #ifdef _WIN32 LogQnnProfileEventAsTraceLogging( (uint64_t)0, @@ -1443,11 +1442,7 @@ Status QnnBackendManager::ExtractProfilingEventExtended( std::string message = GetEventTypeString(event_data_extended.v1.type); std::string unit = GetUnitString(event_data_extended.v1.unit); -#ifndef _WIN32 - tracelogging_provider_ep_enabled = false; -#endif - - if (!tracelogging_provider_ep_enabled) { + if (outfile) { if (event_data_extended.version == QNN_PROFILE_DATA_VERSION_1) { outfile << event_data_extended.v1.timestamp << "," << message << "," @@ -1458,7 +1453,9 @@ Status QnnBackendManager::ExtractProfilingEventExtended( << eventLevel << "," << (event_data_extended.v1.identifier ? event_data_extended.v1.identifier : "NULL") << "\n"; } - } else { + } + + if (tracelogging_provider_ep_enabled) { #ifdef _WIN32 LogQnnProfileEventAsTraceLogging( event_data_extended.v1.timestamp, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 8421bd4a99196..ec84820bb7896 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -8,10 +8,10 @@ #include #include "QnnOpDef.h" -#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/providers/qnn/builder/qnn_node_group.h" +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" #include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/qnn_allocator.h" #include "core/providers/qnn/shared_context.h" diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 39ec9dba18f07..0f0b42bf754d7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -69,7 +69,7 @@ Status QnnModelWrapper::MakeTensorWrapper(const NodeUnitIODef& tensor, QnnTensor } Qnn_TensorMemType_t mem_type = QNN_TENSORMEMTYPE_RAW; - if (true == model_settings_.htp_shared_memory) { + if (true == model_settings_.htp_shared_memory && (IsGraphInput(tensor_name) || IsGraphOutput(tensor_name))) { mem_type = QNN_TENSORMEMTYPE_MEMHANDLE; } tensor_wrapper = QnnTensorWrapper(tensor_name, GetTensorType(tensor_name), tensor_info.qnn_data_type, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h index d3d552bc172ec..cbc052cbebe25 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h @@ -7,8 +7,8 @@ #include #include +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" #include "core/providers/qnn/ort_api.h" -#include "core/providers/qnn/builder/qnn_node_group.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h index 0a1b16d24ffcd..51243b9ffa79b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h @@ -7,8 +7,8 @@ #include #include +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" #include "core/providers/qnn/ort_api.h" -#include "core/providers/qnn/builder/qnn_node_group.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index 85969b9e2dc05..0390a305b2df9 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/qnn/builder/qnn_node_group.h" - #include #include #include @@ -10,13 +8,15 @@ #include #include #include -#include "core/providers/qnn/ort_api.h" -#include "core/providers/qnn/builder/qnn_utils.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" + #include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" #include "core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/ort_api.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.h similarity index 100% rename from onnxruntime/core/providers/qnn/builder/qnn_node_group.h rename to onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.h diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h index 6c953e6cf72c5..7e3f4b962a15c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h @@ -9,8 +9,8 @@ #include #include +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" #include "core/providers/qnn/ort_api.h" -#include "core/providers/qnn/builder/qnn_node_group.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc index 93b2fca296389..bd74f3d43b325 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc @@ -4,8 +4,8 @@ #include #include +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" #include "core/providers/qnn/ort_api.h" -#include "core/providers/qnn/builder/qnn_node_group.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h index c4cf4e8a20a92..f0b2afb67006e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h @@ -7,8 +7,8 @@ #include #include +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" #include "core/providers/qnn/ort_api.h" -#include "core/providers/qnn/builder/qnn_node_group.h" namespace onnxruntime { namespace qnn { diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 2d117927cbaf7..65ef19f0b6c0e 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -8,13 +8,13 @@ #include #include -#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_def.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" -#include "core/providers/qnn/builder/qnn_node_group.h" +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" #include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/qnn_allocator.h" #include "core/providers/qnn/qnn_telemetry.h" #include "core/providers/qnn/rpcmem_library.h" diff --git a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc index 7b92a23e428eb..b2f289448b013 100644 --- a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc +++ b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc @@ -79,6 +79,27 @@ struct QNN_Provider : Provider { return std::make_shared(*provider_options, config_options); } + Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t num_devices, + ProviderOptions& provider_options, + const OrtSessionOptions& session_options, + const OrtLogger& logger, + std::unique_ptr& ep) override { + if (num_devices != 1) { + return Status(common::ONNXRUNTIME, ORT_EP_FAIL, "QNN EP only supports one device."); + } + + const ConfigOptions* config_options = &session_options.GetConfigOptions(); + + std::array configs_array = {&provider_options, config_options}; + const void* arg = reinterpret_cast(&configs_array); + auto ep_factory = CreateExecutionProviderFactory(arg); + ep = ep_factory->CreateProvider(session_options, logger); + + return Status::OK(); + } + void Initialize() override {} void Shutdown() override {} } g_provider; @@ -93,4 +114,121 @@ ORT_API(onnxruntime::Provider*, GetProvider) { return &onnxruntime::g_provider; } } + +#include "core/framework/error_code_helper.h" + +// OrtEpApi infrastructure to be able to use the QNN EP as an OrtEpFactory for auto EP selection. +struct QnnEpFactory : OrtEpFactory { + QnnEpFactory(const OrtApi& ort_api_in, + const char* ep_name, + OrtHardwareDeviceType hw_type, + const char* qnn_backend_type) + : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, qnn_backend_type{qnn_backend_type} { + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + GetSupportedDevices = GetSupportedDevicesImpl; + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + } + + // Returns the name for the EP. Each unique factory configuration must have a unique name. + // Ex: a factory that supports NPU should have a different than a factory that supports GPU. + static const char* GetNameImpl(const OrtEpFactory* this_ptr) { + const auto* factory = static_cast(this_ptr); + return factory->ep_name.c_str(); + } + + static const char* GetVendorImpl(const OrtEpFactory* this_ptr) { + const auto* factory = static_cast(this_ptr); + return factory->vendor.c_str(); + } + + // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports. + // An EP created with this factory is expected to be able to execute a model with *all* supported + // hardware devices at once. A single instance of QNN EP is not currently setup to partition a model among + // multiple different QNN backends at once (e.g, npu, cpu, gpu), so this factory instance is set to only + // support one backend: npu. To support a different backend, like gpu, create a different factory instance + // that only supports GPU. + static OrtStatus* GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) { + size_t& num_ep_devices = *p_num_ep_devices; + auto* factory = static_cast(this_ptr); + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type && + factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id) { + OrtKeyValuePairs* ep_options = nullptr; + factory->ort_api.CreateKeyValuePairs(&ep_options); + factory->ort_api.AddKeyValuePair(ep_options, "backend_type", factory->qnn_backend_type.c_str()); + ORT_API_RETURN_IF_ERROR( + factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, ep_options, + &ep_devices[num_ep_devices++])); + } + } + + return nullptr; + } + + static OrtStatus* CreateEpImpl(OrtEpFactory* /*this_ptr*/, + _In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, + _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/, + _In_ size_t /*num_devices*/, + _In_ const OrtSessionOptions* /*session_options*/, + _In_ const OrtLogger* /*logger*/, + _Out_ OrtEp** /*ep*/) { + return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT, "QNN EP factory does not support this method."); + } + + static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) { + // no-op as we never create an EP here. + } + + const OrtApi& ort_api; + const std::string ep_name; // EP name + const std::string vendor{"Microsoft"}; // EP vendor name + + // Qualcomm vendor ID. Refer to the ACPI ID registry (search Qualcomm): https://uefi.org/ACPI_ID_List + const uint32_t vendor_id{'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24)}; + const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice + const std::string qnn_backend_type; // QNN backend type for OrtHardwareDevice +}; + +extern "C" { +// +// Public symbols +// +OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + + // Factory could use registration_name or define its own EP name. + auto factory_npu = std::make_unique(*ort_api, + onnxruntime::kQnnExecutionProvider, + OrtHardwareDeviceType_NPU, "htp"); + + // If want to support GPU, create a new factory instance because QNN EP is not currently setup to partition a single model + // among heterogeneous devices. + // std::unique_ptr factory_gpu = std::make_unique(*ort_api, "QNNExecutionProvider_GPU", OrtHardwareDeviceType_GPU, "gpu"); + + if (max_factories < 1) { + return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, + "Not enough space to return EP factory. Need at least one."); + } + + factories[0] = factory_npu.release(); + *num_factories = 1; + + return nullptr; +} + +OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { + delete static_cast(factory); + return nullptr; +} +} #endif // !BUILD_QNN_EP_STATIC_LIB diff --git a/onnxruntime/core/providers/qnn/symbols.def b/onnxruntime/core/providers/qnn/symbols.def index 4ec2f7914c208..3afed01da1966 100644 --- a/onnxruntime/core/providers/qnn/symbols.def +++ b/onnxruntime/core/providers/qnn/symbols.def @@ -1,2 +1,4 @@ EXPORTS GetProvider + CreateEpFactories + ReleaseEpFactory diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index eee6a05f12729..afabc1fa9b1c9 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -347,7 +347,7 @@ common::Status IExecutionProvider::Compile(const std::vector& return g_host->IExecutionProvider__Compile(this, fused_nodes_and_graphs, node_compute_funcs); } -#if defined(USE_TENSORRT) || defined(USE_NV) +#if defined(USE_TENSORRT) std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name) { return g_host->CreateCUDAAllocator(device_id, name); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index ded135bf50ec8..72eb2579e9d42 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2587,11 +2587,12 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, supported_nodes_vector.clear(); } - // Remove subgraphs if its size is less than the predefined minimal size - for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end(); ++it) { + for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end();) { const size_t subgraph_size = it->first.size(); if (subgraph_size < min_subgraph_size_) { - supported_nodes_vector.erase(it--); + it = supported_nodes_vector.erase(it); + } else { + ++it; } } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h index a72de6ed75399..8d4dc19690eac 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h @@ -78,7 +78,7 @@ struct TensorRTCustomOp : Ort::CustomOpBasemutable_graph() = *graph_proto_subgraph; auto& logger = logging::LoggingManager::DefaultLogger(); auto model = Model::Create(std::move(*model_proto), ToPathString(filename), nullptr, logger); + auto status = model->MainGraph().Resolve(); + vai_assert(status.IsOK(), "graph resolve error:" + status.ErrorMessage()); if (initializer_size_threshold == std::numeric_limits::max()) { model_proto = model->ToProto(); } else { diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h index bd91bbd81e1fa..0d3374bce325b 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h @@ -59,21 +59,33 @@ class GatherOpBuilder : public BaseOpBuilder { std::vector>& outputs, const NodeUnit& node_unit) override { LOGS_DEFAULT(VERBOSE) << "Creating Gather Op."; + auto indices = node_unit.Inputs()[1]; + int8_t is_scalar_indices = 0; NodeAttrHelper helper(node_unit.GetNode()); auto axis = helper.Get("axis", 0); axis = util::ReverseAxis(axis, inputs[0]->GetShape().size()); auto op = graph_ep->GetGraph()->CreateOperation(axis, 0); + auto indices_shape_proto = indices.node_arg.Shape(); + if (indices_shape_proto != nullptr) { + if (indices_shape_proto->dim_size() == 0) { + is_scalar_indices = 1; + } + } else { + is_scalar_indices = 1; + } bool is_i64_indices = inputs[1]->GetDataType() == tim::vx::DataType::INT64; if (!is_i64_indices) { + inputs[1]->SetScalar(is_scalar_indices); (*op).BindInputs(inputs).BindOutputs(outputs); } else { std::vector origin_data(inputs[1]->GetSpec().GetElementNum()); inputs[1]->CopyDataFromTensor(origin_data.data()); std::vector transformed_data(origin_data.begin(), origin_data.end()); - tim::vx::TensorSpec ts = inputs[1]->GetSpec().SetAttribute(tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec ts = inputs[1]->GetSpec(); ts.SetDataType(tim::vx::DataType::INT32); auto transformed_indices = graph_ep->GetGraph()->CreateTensor(ts, transformed_data.data()); + transformed_indices->SetScalar(is_scalar_indices); (*op).BindInput(inputs[0]).BindInput(transformed_indices).BindOutput(outputs[0]); } graph_ep->GetOps().push_back(std::move(op)); diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc index db8a87d9eaf24..ac113ffc1dc64 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc @@ -73,20 +73,23 @@ bool GraphEP::Prepare() { } bool GraphEP::SupportedOp(const onnxruntime::GraphViewer& graph_viewer, - const NodeUnit& node_unit) { + const NodeUnit& node_unit, + const logging::Logger& logger) { const auto& supported_builtins = vsi::npu::SupportedBuiltinOps(); const auto& target_node = node_unit.GetNode(); const auto& it = supported_builtins.find(target_node.OpType()); if (supported_builtins.end() != it) { return it->second->IsSupported(graph_viewer, node_unit); } - LOGS_DEFAULT(WARNING) << "Fallback unsupported op (node_unit) " << node_unit.OpType() + LOGS(logger, WARNING) << "Fallback unsupported op (node_unit) " << node_unit.OpType() << " to cpu."; return false; } -bool GraphEP::IsNodeSupportedInGroup(const NodeUnit& node_unit, const GraphViewer& graph_viewer) { - return SupportedOp(graph_viewer, node_unit); +bool GraphEP::IsNodeSupportedInGroup(const NodeUnit& node_unit, + const GraphViewer& graph_viewer, + const logging::Logger& logger) { + return SupportedOp(graph_viewer, node_unit, logger); } const NodeUnit& GraphEP::GetNodeUnit(const Node* node) const { @@ -151,7 +154,7 @@ bool GraphEP::BindTensors(const std::shared_ptr& nodeio_info) { if (!input_names.empty()) { for (auto& name : input_names) { if (tensors_.find(name) == tensors_.end() || tensors_[name] == nullptr) { - LOGS_DEFAULT(ERROR) << "Input tensor not defined or not found!"; + LOGS(logger_, ERROR) << "Input tensor not defined or not found!"; return false; } (*op).BindInput(tensors_[name]); @@ -160,7 +163,7 @@ bool GraphEP::BindTensors(const std::shared_ptr& nodeio_info) { if (!output_names.empty()) { for (auto& name : output_names) { if (tensors_.find(name) == tensors_.end() || tensors_[name] == nullptr) { - LOGS_DEFAULT(ERROR) << "Output tensor not defined or not found!"; + LOGS(logger_, ERROR) << "Output tensor not defined or not found!"; return false; } (*op).BindOutput(tensors_[name]); diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h index 5bb332fad0177..31a983810b20a 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h @@ -57,11 +57,12 @@ class GraphEP { bool Prepare(); static bool SupportedOp(const onnxruntime::GraphViewer& graph_viewer, - const NodeUnit& node_unit); + const NodeUnit& node_unit, const logging::Logger& logger); // If a node is supported by VSINPU in a partition node group // `node_outputs_in_group` is the set of the output names of the nodes added to this group so far - static bool IsNodeSupportedInGroup(const NodeUnit& node_unit, const GraphViewer& graph_viewer); + static bool IsNodeSupportedInGroup(const NodeUnit& node_unit, const GraphViewer& graph_viewer, + const logging::Logger& logger); const NodeUnit& GetNodeUnit(const Node* node) const; diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc index 76cf9c9a797e1..3b70ab3c9241b 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc @@ -37,8 +37,6 @@ #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/providers/partitioning_utils.h" -#include "core/providers/qnn/ort_api.h" - namespace onnxruntime { VSINPUExecutionProvider::VSINPUExecutionProvider(const VSINPUExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kVSINPUExecutionProvider, @@ -56,17 +54,17 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { std::vector> result; - + const auto& logger = *GetLogger(); if (graph_viewer.IsSubgraph()) { return result; } for (const auto& tensor : graph_viewer.GetAllInitializedTensors()) { if (tensor.second->has_data_location()) { - LOGS_DEFAULT(VERBOSE) << "location:" << tensor.second->data_location(); + LOGS(logger, VERBOSE) << "location:" << tensor.second->data_location(); if (tensor.second->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - LOGS_DEFAULT(WARNING) << "VSINPU: Initializers with external data location are not " + LOGS(logger, WARNING) << "VSINPU: Initializers with external data location are not " "currently supported"; return result; } @@ -93,11 +91,11 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie supported = it->second; } else { // We only check the target node of the node unit - supported = vsi::npu::GraphEP::IsNodeSupportedInGroup(*node_unit, graph_viewer); + supported = vsi::npu::GraphEP::IsNodeSupportedInGroup(*node_unit, graph_viewer, logger); node_unit_supported_result[node_unit] = supported; } - LOGS_DEFAULT(VERBOSE) << "Node supported: [" << supported + LOGS(logger, VERBOSE) << "Node supported: [" << supported << "] Operator type: [" << node.OpType() << "] index: [" << node.Index() << "] name: [" << node.Name() @@ -158,9 +156,9 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // If the graph is partitioned in multiple subgraphs, and this may impact performance, // we want to give users a summary message at warning level. if (num_of_partitions > 1) { - LOGS_DEFAULT(WARNING) << summary_msg; + LOGS(logger, WARNING) << summary_msg; } else { - LOGS_DEFAULT(INFO) << summary_msg; + LOGS(logger, INFO) << summary_msg; } return result; diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h index 1f96bde81b1d6..8053fef46a4f1 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.h @@ -50,7 +50,7 @@ class VSINPUExecutionProvider : public IExecutionProvider { std::mutex& GetMutex() { return mutex_; } private: - int device_id_; + OrtDevice::DeviceId device_id_; std::mutex mutex_; }; diff --git a/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc b/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc index 38db604695a54..9e934e9eb5db7 100644 --- a/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc +++ b/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc @@ -63,11 +63,15 @@ std::string GetActivationSnippet(const Activation& activation, std::string value case ActivationKind::Sigmoid: return "value = " + value_type_cast(1.0) + " / (" + value_type_cast(1.0) + " + exp(-value));"; case ActivationKind::Clip: - return "value = clamp(value, " + value_type_cast(activation.activation_params_.Clip.minimum_) + ", " + value_type_cast(activation.activation_params_.Clip.maximum_) + ");"; + return "value = clamp(value, " + value_type_cast(activation.activation_params_.Clip.minimum_) + ", " + + value_type_cast(activation.activation_params_.Clip.maximum_) + ");"; case ActivationKind::HardSigmoid: - return "value = clamp(" + value_type_cast(activation.activation_params_.HardSigmoid.alpha_) + " * value + " + value_type_cast(activation.activation_params_.HardSigmoid.beta_) + ", 0.0" + ", 1.0" + ");"; + return "value = clamp(" + value_type_cast(activation.activation_params_.HardSigmoid.alpha_) + " * value + " + + value_type_cast(activation.activation_params_.HardSigmoid.beta_) + ", " + value_type_cast(0.0) + ", " + + value_type_cast(1.0) + ");"; case ActivationKind::LeakyRelu: - return "value = select(" + base_type_cast(activation.activation_params_.LeakyRelu.alpha_) + " * value, value, value >= " + value_type_cast(0.0) + ");"; + return "value = select(" + base_type_cast(activation.activation_params_.LeakyRelu.alpha_) + + " * value, value, value >= " + value_type_cast(0.0) + ");"; case ActivationKind::Tanh: return "value = tanh(value);"; default: diff --git a/onnxruntime/core/providers/webgpu/nn/instance_norm.cc b/onnxruntime/core/providers/webgpu/nn/instance_norm.cc index f3bccec4872fc..7b39980f85605 100644 --- a/onnxruntime/core/providers/webgpu/nn/instance_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/instance_norm.cc @@ -88,13 +88,13 @@ Status InstanceNormProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& channel_scale_shift = shader.AddInput("channel_scale_shift", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") - << "let outputIndices = " << output.OffsetToIndices("global_idx") + << "let outputIndices = " << output.OffsetToIndices("global_idx") << ";\n" << "let batch = outputIndices[0];\n" << "let channel = outputIndices[1];\n" << "let channel_scale_shift_indices = channel_scale_shift_indices_t(batch, channel, 0);\n" << "let channel_scale_shift = " << channel_scale_shift.GetByIndices("channel_scale_shift_indices") << ";\n" << "let input_value = " << input.GetByOffset("global_idx") << ";\n" - << "let output_value = input_value * output_value_t(channel_scale_sift.x) + output_value_t(channel_scale_shift.y);\n" + << "let output_value = input_value * output_value_t(channel_scale_shift.x) + output_value_t(channel_scale_shift.y);\n" << output.SetByOffset("global_idx", "output_value") << ";\n"; return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/nn/pool.cc b/onnxruntime/core/providers/webgpu/nn/pool.cc index 12c135dbbf46d..d650392b71fb5 100644 --- a/onnxruntime/core/providers/webgpu/nn/pool.cc +++ b/onnxruntime/core/providers/webgpu/nn/pool.cc @@ -95,6 +95,9 @@ Status PoolProgram::GenerateShaderCode(ShaderHelper& shader) const { var_decl_code = SS_GET(var_decl_ss); sampling_code = " value = max(value, x_val);\n"; + if (are_small_output_big_kernel_) { + downsampling_code = " sum_or_max_shared[local_idx] = value;\n"; + } } else { SS(var_decl_ss, kStringInitialSize); var_decl_ss << " var value = " << (is_float16_ ? "f16(0)" : "f32(0)") << ";\n"; @@ -113,7 +116,12 @@ Status PoolProgram::GenerateShaderCode(ShaderHelper& shader) const { sampling_code = SS_GET(sampling_ss); SS(downsampling_ss, kStringInitialSize); - downsampling_ss << " value /= " << (is_float16_ ? "f16" : "f32") << "(count);\n"; + if (are_small_output_big_kernel_) { + downsampling_ss << " sum_or_max_shared[local_idx] = value;\n" + << " count_shared[local_idx] = count;\n"; + } else { + downsampling_ss << " value /= " << (is_float16_ ? "f16" : "f32") << "(count);\n"; + } downsampling_code = SS_GET(downsampling_ss); } @@ -125,13 +133,54 @@ Status PoolProgram::GenerateShaderCode(ShaderHelper& shader) const { auto data_dim_end = input.Rank(); data_dim_end = is_nhwc_ ? data_dim_end - 1 : data_dim_end; + std::string sum_or_max_shared; + if (are_small_output_big_kernel_) { + shader.AdditionalImplementation() + << "var sum_or_max_shared : array<" << (is_float16_ ? "f16" : "f32") << ",workgroup_size_x >;\n" + << (!is_max_pool_ ? "var count_shared : array;\n" : ""); + + SS(shared_ss, 512); + std::string sum_or_max_shared_op; + std::string count_shared_op; + if (is_max_pool_) { + sum_or_max_shared_op = "sum_or_max_shared[local_idx] = max(sum_or_max_shared[local_idx], sum_or_max_shared[local_idx + reduce_size]);\n"; + } else { + sum_or_max_shared_op = "sum_or_max_shared[local_idx] += sum_or_max_shared[local_idx + reduce_size];\n"; + count_shared_op = "count_shared[local_idx] += count_shared[local_idx + reduce_size];\n"; + } + + shared_ss << " workgroupBarrier();\n" + << " var reduce_size : u32 = workgroup_size_x;\n" + << " for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" + << " reduce_size = curr_size + (reduce_size & 1);\n" + << " if (local_idx < curr_size) {\n" + << " " << sum_or_max_shared_op + << " " << count_shared_op + << " }\n" + << " workgroupBarrier();\n" + << " }\n"; + sum_or_max_shared = SS_GET(shared_ss); + } + std::string kernel_loop_decl_code = are_small_output_big_kernel_ ? " for (var i: u32 = local_idx; i < uniforms.kernel_size; i += workgroup_size_x) {\n" : " for (var i: u32 = 0; i < uniforms.kernel_size; i++) {\n"; + + SS(output_ss, kStringInitialSize); + if (are_small_output_big_kernel_) { + output_ss << " if (local_idx == 0) {\n" + << " value = sum_or_max_shared[0]" << (!is_max_pool_ ? (is_float16_ ? " / f16(count_shared[0])" : " / f32(count_shared[0])") : "") << ";\n" + << " " << output.SetByOffset("workgroup_idx", "value") << ";\n" + << " }\n"; + } else { + output_ss << " " << output.SetByOffset("global_idx", "value") << ";\n"; + } + std::string output_code = SS_GET(output_ss); + auto& body = shader.MainFunctionBody(); - body << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") - << " let y_indices = " << output.OffsetToIndices("global_idx") << ";\n" + body << (are_small_output_big_kernel_ ? "" : shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")) + << " let y_indices = " << output.OffsetToIndices((are_small_output_big_kernel_ ? "workgroup_idx" : "global_idx")) << ";\n" << " var x_indices = y_indices;\n" << " var k_indices: array;\n" << var_decl_code - << " for (var i: u32 = 0; i < uniforms.kernel_size; i++) {\n" + << kernel_loop_decl_code << " var offset = i;\n" // ---- Compute offset to indices in pooling window. << " for (var j = 0; j < " << kernel_rank << "; j++) {\n" @@ -162,7 +211,8 @@ Status PoolProgram::GenerateShaderCode(ShaderHelper& shader) const { << " }\n" << " }\n" << downsampling_code - << " " << output.SetByOffset("global_idx", "value") << ";\n"; + << sum_or_max_shared + << output_code; return Status::OK(); } @@ -225,7 +275,6 @@ Status Pool::ComputeInternal(ComputeContext& context) const { } bool is_float16 = X->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; bool count_include_pad = pool_attrs_.count_include_pad; - PoolProgram program{is_max_pool, is_nhwc, kernel_shape, is_float16, count_include_pad}; // Number of elements uint32_t output_size = gsl::narrow_cast(Y->Shape().Size()); @@ -235,16 +284,25 @@ Status Pool::ComputeInternal(ComputeContext& context) const { const auto strides_u32 = NarrowToU32(strides); const auto dilations_u32 = NarrowToU32(dilations); - program.CacheHint(kernel_shape.size(), is_max_pool, is_nhwc, is_float16, count_include_pad) + bool are_small_output_big_kernel = output_size <= 128 && kernel_size >= 128; + PoolProgram program{is_max_pool, is_nhwc, kernel_shape, is_float16, count_include_pad, are_small_output_big_kernel}; + + program.CacheHint(kernel_shape.size(), is_max_pool, is_nhwc, is_float16, count_include_pad, are_small_output_big_kernel) .AddInputs({{X, ProgramTensorMetadataDependency::TypeAndRank}}) .AddOutputs({{Y}}) - .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .AddUniformVariables({output_size, kernel_size, gsl::span(kernel_strides.data(), kernel_strides.size()), gsl::span(pads_u32.data(), pads_u32.size()), gsl::span(strides_u32.data(), strides_u32.size()), gsl::span(dilations_u32.data(), dilations_u32.size())}); + if (are_small_output_big_kernel) { + program.SetWorkgroupSize(128) + .SetDispatchGroupSize(output_size); + } else { + program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + } + return context.RunProgram(program); } diff --git a/onnxruntime/core/providers/webgpu/nn/pool.h b/onnxruntime/core/providers/webgpu/nn/pool.h index c1716542e5549..57bdc64954acd 100644 --- a/onnxruntime/core/providers/webgpu/nn/pool.h +++ b/onnxruntime/core/providers/webgpu/nn/pool.h @@ -14,13 +14,14 @@ namespace webgpu { class PoolProgram final : public Program { public: PoolProgram(bool is_max_pool, bool is_nhwc, const TensorShapeVector& kernel_shape, bool is_float16, - bool count_include_pad) + bool count_include_pad, bool are_small_output_big_kernel) : Program{"Pool"}, is_max_pool_{is_max_pool}, is_nhwc_{is_nhwc}, kernel_shape_{kernel_shape}, is_float16_{is_float16}, - count_include_pad_{count_include_pad} {} + count_include_pad_{count_include_pad}, + are_small_output_big_kernel_{are_small_output_big_kernel} {} Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -39,6 +40,7 @@ class PoolProgram final : public Program { const TensorShapeVector kernel_shape_; const bool is_float16_; const bool count_include_pad_; + const bool are_small_output_big_kernel_; }; template diff --git a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc index e2b5d73168935..0305049e9b789 100644 --- a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc @@ -79,7 +79,7 @@ Status DequantizeLinearProgram::GenerateShaderCode(ShaderHelper& shader) const { if (packed_) { shader.MainFunctionBody() << "let zero_point_index = " << output.IndicesGet("output_indices", "uniforms.axis") << ";\n" - << "let zero_point_input = " << zero_point.GetByOffset("zero_point_index / 4") << ";\n" + << "let zero_point_input = " << zero_point.GetByOffset("u32(zero_point_index / 4)") << ";\n" << "let zero_point_vec = " << unpack << ";\n" << "let zero_point_value = zero_point_vec[zero_point_index % 4];\n"; } else { @@ -92,7 +92,7 @@ Status DequantizeLinearProgram::GenerateShaderCode(ShaderHelper& shader) const { if (packed_) { shader.MainFunctionBody() << "let zero_point_offset = " << scale.GetByIndices("scale_indices") << ";\n" - << "let zero_point_input = " << zero_point.GetByOffset("zero_point_offset / 4") << ";\n" + << "let zero_point_input = " << zero_point.GetByOffset("u32(zero_point_offset / 4)") << ";\n" << "let zero_point_vec = " << unpack << ";\n" << "let zero_point_value = zero_point_vec[zero_point_offset % 4];\n"; } else { diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 59855e6117641..36f6b512a0a93 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -125,7 +125,8 @@ namespace { // Validate if the tensor element type matches the program variable data type Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType var_type, bool is_atomic = false) { if (is_atomic) { - ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 || var_type == ProgramVariableDataType::Uint32, + // float32 is not a valid data type for atomic. However the data may be bitcast-ed to i32 and used to simulate atomic operation using atomicCompareExchangeWeak. + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 || var_type == ProgramVariableDataType::Uint32 || var_type == ProgramVariableDataType::Float32, "Unexpected program variable type ", int(var_type), " for atomic variable"); } @@ -422,11 +423,17 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha bool is_atomic = program_.Outputs()[i].is_atomic; ss << "@group(0) @binding(" << input_vars_.size() + i << ") var " << output->name_ << ": array<"; if (is_atomic) { - ss << "atomic<"; - } - ss << output->StorageType(); - if (is_atomic) { - ss << ">"; + if (output->type_ == ProgramVariableDataType::Float32) { + ss << "atomic"; + } else if (output->type_ == ProgramVariableDataType::Uint32) { + ss << "atomic"; + } else if (output->type_ == ProgramVariableDataType::Int32) { + ss << "atomic"; + } else { + ORT_RETURN_IF(true, "Unsupported atomic type: ", int(output->type_)); + } + } else { + ss << output->StorageType(); } ss << ">;\n"; } diff --git a/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc new file mode 100644 index 0000000000000..986255ea1f185 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc @@ -0,0 +1,222 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/shader_variable.h" +#include "scatter_nd.h" + +namespace onnxruntime { +namespace webgpu { + +Status ScatterNDProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + shader.AddInput("updates", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseShapeAndStride); + const auto output_rank = static_cast(output.Rank()); + auto atomic_reduction_snippet = [](ScatterNDReduction reduction, const std::string& ptr, const std::string& value, const std::string& data_type) -> std ::string { + std::ostringstream ss; + bool is_32_bit_integer = data_type == "i32" || data_type == "u32"; + bool is_unsigned_integer = data_type == "u32"; + std::ostringstream ss_float_start; + ss_float_start << " {\n" + << " var oldValue = 0" << (is_unsigned_integer ? "u" : "") << ";\n" + << " loop {\n" + << " let newValueF32 = "; + std::ostringstream ss_float_end; + ss_float_end << ";\n" + << " let newValue = bitcast<" << (is_unsigned_integer ? "u32" : "i32") << ">(newValueF32);\n" + << " let res = atomicCompareExchangeWeak(&" << ptr << ", oldValue, newValue);\n" + << " if res.exchanged {\n" + << " break;\n" + << " }\n" + << " oldValue = res.old_value;\n" + << " }\n" + << " }\n"; + switch (reduction) { + case ScatterNDReduction::None: + ss << " " << ptr << " = " << value << ";\n"; + break; + case ScatterNDReduction::Add: + if (is_32_bit_integer) { + ss << " atomicAdd(&" << ptr << ", bitcast<" << data_type << ">(" << value << "));\n"; + } else { + // atomicAdd only supports uint/int type. For float, we use + // atomicCompareExchangeWeak to simulate. + ss << ss_float_start.str() << "bitcast<" << data_type << ">(oldValue) + (" << value << ")" << ss_float_end.str() + << "\n"; + } + break; + case ScatterNDReduction::Max: + if (is_32_bit_integer) { + ss << " atomicMax(&" << ptr << ", bitcast<" << data_type << ">(" << value << "));\n"; + } else { + // atomicMax only supports uint/int type. For float, we use + // atomicCompareExchangeWeak to simulate. + ss << ss_float_start.str() << "max(bitcast<" << data_type << ">(oldValue), (" << value << "))" << ss_float_end.str(); + } + break; + case ScatterNDReduction::Min: + if (is_32_bit_integer) { + ss << " atomicMin(&" << ptr << ", bitcast<" << data_type << ">(" << value << "));\n"; + } else { + // atomicMin only supports uint/int type. For float, we use + // atomicCompareExchangeWeak to simulate. + ss << ss_float_start.str() << "min(bitcast<" << data_type << ">(oldValue), (" << value << "))" << ss_float_end.str(); + } + break; + case ScatterNDReduction::Mul: + // atomicMul is not supported, we use atomicCompareExchangeWeak to simulate. + ss << ss_float_start.str() << "(bitcast<" << data_type << ">(oldValue) * (" << value << "))" << ss_float_end.str(); + break; + default: + ORT_THROW("Unsupported reduction type: ", static_cast(reduction)); + // The controlflow should never reach here. + } + return ss.str(); + }; + + auto calc_data_offset_snippet = [](size_t output_rank) -> std::string { + std::ostringstream ss; + if (output_rank < 2) { + ss << " let element_count_dim = 1u;\n"; + } else { + ss << " let element_count_dim = select(" << GetElementAt("uniforms.output_stride", "i - indices_start", output_rank - 1) << ", 1u, i - indices_start == " << (output_rank - 1) << ");\n"; + } + ss << " let dim_value = " << GetElementAt("uniforms.output_shape", "i - indices_start", output_rank) << ";\n" + << " if (index >= 0) {\n" + << " if (index >= i32(dim_value)) {\n" + << " index = i32(dim_value - 1);\n" + << " }\n" + << " } else {\n" + << " if (index < -i32(dim_value)) {\n" + << " index = 0;\n" + << " } else {\n" + << " index += i32(dim_value);\n" + << " }\n" + << " }\n" + << " data_offset += u32((u32(index) * element_count_dim));\n"; + return ss.str(); + }; + + auto update_elements_snippet = [atomic_reduction_snippet](ScatterNDReduction reduction, const std::string& data_type) -> std::string { + std::ostringstream ss; + ss << " for (var i = 0u; i < uniforms.num_updates_elements; i++) {\n" + << " let value = updates[uniforms.num_updates_elements * global_idx + i];\n" + << atomic_reduction_snippet(reduction, "output[data_offset + i]", "value", data_type) << "\n" + << " }\n"; + return ss.str(); + }; + std::string data_type_str; + bool reducible = false; + if (data_type_ == DataTypeImpl::GetType()) { + reducible = true; + data_type_str = "i32"; + } else if (data_type_ == DataTypeImpl::GetType()) { + reducible = true; + data_type_str = "u32"; + } else if (data_type_ == DataTypeImpl::GetType()) { + reducible = true; + data_type_str = "f32"; + } else { + // Default value. + data_type_str = "output_element_t"; + } + if (reduction_ != ScatterNDReduction::None && !reducible) { + ORT_THROW("ScatterND: Reduction is not supported for data type ", data_type_str); + } + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " var data_offset = 0u;\n" + << " var indices_start = uniforms.last_index_dimension * global_idx;\n" + << " var indices_end = indices_start + uniforms.last_index_dimension;\n" + << " for (var i = indices_start; i < indices_end; i++) {\n" + << " var index = i32(indices[i].x);\n" + << calc_data_offset_snippet(output_rank) + << " }\n" + << update_elements_snippet(reduction_, data_type_str); + return Status::OK(); +} + +Status ScatterND::ComputeInternal(ComputeContext& context) const { + const Tensor* input = context.Input(0); + const auto* indices = context.Input(1); + const auto* updates = context.Input(2); + const auto& input_shape = input->Shape(); + const auto& indices_shape = indices->Shape(); + auto indices_rank = indices_shape.NumDimensions(); + auto last_index_dimension = static_cast(indices_shape[indices_rank - 1]); + auto num_updates_elements = static_cast(input_shape.SizeFromDimension(last_index_dimension)); + // TODO: support bool with components 4. + const size_t components = 1; + auto output_size = static_cast((indices_shape.SizeToDimension(indices_rank - 1) + components - 1) / components); + auto* output = context.Output(0, input_shape); + MLDataType data_type = input->DataType(); + const void* source = input->DataRaw(); + void* target = output->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input, *output)); + } + ScatterNDProgram program(reduction_, data_type); + program + .CacheHint(static_cast(reduction_)) + .AddInputs({{indices, ProgramTensorMetadataDependency::TypeAndRank}, + {updates, ProgramTensorMetadataDependency::TypeAndRank}}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({output_size, last_index_dimension, num_updates_elements}); + if (reduction_ != ScatterNDReduction::None && (data_type == DataTypeImpl::GetType() || data_type == DataTypeImpl::GetType() || + data_type == DataTypeImpl::GetType())) { + program.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, ProgramOutput::Atomic}); + } else { + program.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank}); + } + return context.RunProgram(program); +} + +ONNX_OPERATOR_KERNEL_EX( + ScatterND, + kOnnxDomain, + 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .MayInplace(0, 0), + ScatterND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ScatterND, + kOnnxDomain, + 16, + 17, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .MayInplace(0, 0), + ScatterND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ScatterND, + kOnnxDomain, + 13, + 15, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .MayInplace(0, 0), + ScatterND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ScatterND, + kOnnxDomain, + 11, + 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .MayInplace(0, 0), + ScatterND); +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/scatter_nd.h b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.h new file mode 100644 index 0000000000000..40bcbadebf65d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.h @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace webgpu { + +enum class ScatterNDReduction : int { + None = 0, + Add = 1, + Mul = 2, + Min = 3, + Max = 4, +}; + +class ScatterNDProgram final : public Program { + public: + ScatterNDProgram(ScatterNDReduction reduction, MLDataType data_type) : Program{"ScatterND"}, reduction_(reduction), data_type_(data_type) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"last_index_dimension", ProgramUniformVariableDataType::Uint32}, + {"num_updates_elements", ProgramUniformVariableDataType::Uint32}); + ScatterNDReduction reduction_; + MLDataType data_type_; +}; + +class ScatterND : public WebGpuKernel { + public: + ScatterND(const OpKernelInfo& info) : WebGpuKernel(info) { + std::string reduction = info.GetAttrOrDefault("reduction", "none"); + if (reduction == "add") { + reduction_ = ScatterNDReduction::Add; + } else if (reduction == "mul") { + reduction_ = ScatterNDReduction::Mul; + } else if (reduction == "min") { + reduction_ = ScatterNDReduction::Min; + } else if (reduction == "max") { + reduction_ = ScatterNDReduction::Max; + } else if (reduction == "none") { + reduction_ = ScatterNDReduction::None; + } else { + ORT_THROW("Reduction '", reduction, "' is not supported on webgpu when opset <= 18."); + } + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + ScatterNDReduction reduction_{ScatterNDReduction::None}; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index 0df7d1ae9fa2f..5f1496ff7a40e 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -165,6 +165,11 @@ Status Transpose::ComputeInternal(ComputeContext& context) const { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); + int64_t output_size = output_shape.Size(); + if (output_size == 0) { + return Status::OK(); + } + return DoTranspose(context, *p_perm, *input_tensor, *output_tensor); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 17f4fa1bd44b3..27380645baf57 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -10,7 +10,9 @@ #endif #if !defined(__wasm__) +#if !defined(BUILD_DAWN_MONOLITHIC_LIBRARY) #include "dawn/dawn_proc.h" +#endif #if !defined(USE_EXTERNAL_DAWN) #include "dawn/native/DawnNative.h" #endif diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 928e48d78d7e5..9ea79e4cf28a3 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -400,6 +400,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, DequantizeLinear); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 15, ScatterND); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, 17, ScatterND); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ScatterND); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -732,6 +737,11 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index d77634e1f5b9b..6556c293f81bf 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -183,6 +183,10 @@ inline bool IsEmptyTensor(const GraphViewer& graph_viewer, const std::string& na return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; }); } +inline bool IsOnnxDomain(std::string_view domain) { + return (domain == onnxruntime::kOnnxDomain) || (domain == onnxruntime::kOnnxDomainAlias); +} + inline bool TensorExists(const ConstPointerContainer>& defs, size_t tensor_index) noexcept { return tensor_index < defs.size() && defs[tensor_index]->Exists(); } @@ -203,9 +207,10 @@ const std::map> decomposed_op_ma {"add", "cast", "concat", "constant", "cumulativeSum", "div", "expand", "lesser", "matmul", "reshape", "scatterND", "softmax", "transpose", "where"}}, {"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}}, + {"MatMulInteger", {"cast", "dequantizeLinear", "matmul"}}, {"MatMulNBits", {"add", "dequantizeLinear", "matmul", "reshape", "transpose"}}, {"MultiHeadAttention", {"add", "cast", "concat", "constant", "div", "matmul", "reshape", "softmax", "transpose"}}, - {"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "split"}}, + {"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "slice", "split"}}, {"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}}, {"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}}, }; @@ -262,7 +267,6 @@ const std::map op_map = { {"LpPool", "l2Pool2d"}, {"LSTM", "lstm"}, {"MatMul", "matmul"}, - {"MatMulInteger", "matmulInteger"}, {"Max", "max"}, {"MaxPool", "maxPool2d"}, {"Min", "min"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index d4be0f1bee18e..02f46c85d1d06 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -26,6 +26,8 @@ class GemmOpBuilder : public BaseOpBuilder { const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; + bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; // Add operator related. @@ -35,111 +37,136 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const auto& input_defs = node.InputDefs(); const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C + std::vector a_shape; + std::vector b_shape; + std::vector output_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[a_idx], a_shape, logger), "Can not get shape of A"); + ORT_RETURN_IF_NOT(GetShape(*input_defs[b_idx], b_shape, logger), "Can not get shape of B"); + ORT_RETURN_IF_NOT(GetShape(*node.OutputDefs()[0], output_shape, logger), "Can not get output shape"); + emscripten::val a = model_builder.GetOperand(node.InputDefs()[a_idx]->Name()); emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name()); emscripten::val output = emscripten::val::object(); - emscripten::val options = emscripten::val::object(); - options.set("label", node.Name()); - if (op_type == "MatMul") { - std::vector a_shape; - if (!GetShape(*input_defs[a_idx], a_shape, logger)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Can not get shape of A."); - } - std::vector b_shape; - if (!GetShape(*input_defs[b_idx], b_shape, logger)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Can not get shape of B."); - } - // If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. - bool extended_a_shape = false; - if (a_shape.size() == 1) { - extended_a_shape = true; - a_shape.insert(a_shape.begin(), 1); - emscripten::val a_shape_arr = emscripten::val::array(GetNarrowedIntfromInt64(a_shape)); - emscripten::val reshape_a_options = emscripten::val::object(); - reshape_a_options.set("label", node.Name() + "_reshape_a"); - a = model_builder.GetBuilder().call("reshape", a, a_shape_arr, reshape_a_options); - } - // If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. - bool extended_b_shape = false; - if (b_shape.size() == 1) { - extended_b_shape = true; - b_shape.push_back(1); - emscripten::val b_shape_arr = emscripten::val::array(GetNarrowedIntfromInt64(b_shape)); - emscripten::val reshape_b_options = emscripten::val::object(); - reshape_b_options.set("label", node.Name() + "_reshape_b"); - b = model_builder.GetBuilder().call("reshape", b, b_shape_arr, reshape_b_options); - } + emscripten::val common_options = emscripten::val::object(); - output = model_builder.GetBuilder().call("matmul", a, b, options); + // MatMul and MatMulInteger in ONNX allow 1-D inputs while matmul in WebNN only supports at least 2-D inputs. + // We can support 1-D inputs by reshaping them to 2-D. We don't care Gemm here because it only provides 2-D inputs. - emscripten::val reshape_output_options = emscripten::val::object(); - reshape_output_options.set("label", node.Name() + "_reshape_output"); - // If the inputs are both 1D, reduce the output to a scalar. - if (extended_a_shape && extended_b_shape) { - output = model_builder.GetBuilder().call("reshape", - output, - emscripten::val::array(), - reshape_output_options); - } - // After matrix multiplication the prepended 1 is removed. - else if (extended_a_shape) { - std::vector new_shape; - for (size_t i = 0; i < b_shape.size() - 2; i++) { - new_shape.push_back(SafeInt(b_shape[i])); - } - new_shape.push_back(SafeInt(b_shape.back())); - output = model_builder.GetBuilder().call("reshape", - output, - emscripten::val::array(new_shape), - reshape_output_options); - } - // After matrix multiplication the appended 1 is removed. - else if (extended_b_shape) { - std::vector new_shape; - for (size_t i = 0; i < a_shape.size() - 1; i++) { - new_shape.push_back(SafeInt(a_shape[i])); - } + // If the input A is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. + if (a_shape.size() == 1) { + a_shape.insert(a_shape.begin(), 1); + emscripten::val a_shape_arr = emscripten::val::array(GetNarrowedIntfromInt64(a_shape)); + common_options.set("label", node.Name() + "_reshape_a"); + a = model_builder.GetBuilder().call("reshape", a, a_shape_arr, common_options); + } + // If the input B is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. + if (b_shape.size() == 1) { + b_shape.push_back(1); + emscripten::val b_shape_arr = emscripten::val::array(GetNarrowedIntfromInt64(b_shape)); + common_options.set("label", node.Name() + "_reshape_b"); + b = model_builder.GetBuilder().call("reshape", b, b_shape_arr, common_options); + } + + if (op_type == "MatMul") { + common_options.set("label", node.Name()); + output = model_builder.GetBuilder().call("matmul", a, b, common_options); + + // If A or B input is 1-D, we need to reshape the output back to its original shape. + if (a_shape.size() == 1 || b_shape.size() == 1) { + common_options.set("label", node.Name() + "_reshape_output"); + emscripten::val output_shape_arr = emscripten::val::array(GetNarrowedIntfromInt64(output_shape)); output = model_builder.GetBuilder().call("reshape", output, - emscripten::val::array(new_shape), - reshape_output_options); + output_shape_arr, + common_options); } } else if (op_type == "MatMulInteger") { - emscripten::val a_zero_point = emscripten::val::null(); - emscripten::val b_zero_point = emscripten::val::null(); - if (input_defs.size() >= 3) { + // WebNN doesn't provide a dedicated op for MatMulInteger, it can be simply decomposed by + // DequantizeLinear A, B -> MatMul -> Cast (to int32) + int32_t a_type; + ORT_RETURN_IF_NOT(GetType(*input_defs[0], a_type, logger), "Cannot get data type of input A"); + + emscripten::val a_zero_point, b_zero_point, a_scale, b_scale; + if (TensorExists(input_defs, 2)) { a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); + std::vector a_zero_point_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[2], a_zero_point_shape, logger), "Cannot get shape of a_zero_point"); + // Scale is not used by MatMulInteger but required by DequantizeLinear. So set it to deafult value 1.0f. + // The scale input should have the same shape as the zero point input. + a_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + 1.0f, + GetNarrowedIntfromInt64(a_zero_point_shape)); } else { - a_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); + // If a_zero_point is not provided, create default scalar for zero_point and scale inputs. + a_zero_point = model_builder.CreateOrGetConstant(a_type, 0); + a_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f); } - if (input_defs.size() >= 4) { + + // Dequantize A to Float32 + common_options.set("label", node.Name() + "_dequantized_a"); + emscripten::val dequantized_a = model_builder.GetBuilder().call("dequantizeLinear", + a, + a_scale, + a_zero_point, + common_options); + if (TensorExists(input_defs, 3)) { b_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); + std::vector b_zero_point_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[3], b_zero_point_shape, logger), "Cannot get shape of b_zero_point"); + b_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + 1.0f, + GetNarrowedIntfromInt64(b_zero_point_shape)); } else { - b_zero_point = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8, 0); + b_zero_point = model_builder.CreateOrGetConstant(a_type, 0); + b_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f); + } + + // Dequantize B to Float32 + common_options.set("label", node.Name() + "_dequantized_b"); + emscripten::val dequantized_b = model_builder.GetBuilder().call("dequantizeLinear", + b, + b_scale, + b_zero_point, + common_options); + // MatMul dequantized A and B + common_options.set("label", node.Name() + "_matmul_dequantized_ab"); + emscripten::val matmul_dequantized_ab = model_builder.GetBuilder().call("matmul", + dequantized_a, + dequantized_b, + common_options); + // Cast matmul_dequantized_ab to int32 + common_options.set("label", node.Name() + "_cast_output"); + output = model_builder.GetBuilder().call("cast", + matmul_dequantized_ab, + emscripten::val("int32"), + common_options); + // If A or B input is 1-D, we need to reshape the output back to its original shape. + if (a_shape.size() == 1 || b_shape.size() == 1) { + common_options.set("label", node.Name() + "_reshape_output"); + emscripten::val output_shape_arr = emscripten::val::array(GetNarrowedIntfromInt64(output_shape)); + output = model_builder.GetBuilder().call("reshape", + output, + output_shape_arr, + common_options); } - output = model_builder.GetBuilder().call("matmulInteger", - a, - a_zero_point, - b, - b_zero_point, - options); } else { // Gemm NodeAttrHelper helper(node); const auto transA = helper.Get("transA", 0); - options.set("aTranspose", emscripten::val(transA == 1)); + common_options.set("aTranspose", emscripten::val(transA == 1)); const auto transB = helper.Get("transB", 0); - options.set("bTranspose", emscripten::val(transB == 1)); + common_options.set("bTranspose", emscripten::val(transB == 1)); const auto alpha = helper.Get("alpha", 1.0f); const auto beta = helper.Get("beta", 1.0f); - options.set("alpha", alpha); - options.set("beta", beta); + common_options.set("alpha", alpha); + common_options.set("beta", beta); // Add bias if present. if (input_defs.size() > 2) { - options.set("c", model_builder.GetOperand(node.InputDefs()[c_idx]->Name())); + common_options.set("c", model_builder.GetOperand(node.InputDefs()[c_idx]->Name())); } - output = model_builder.GetBuilder().call("gemm", a, b, options); + common_options.set("label", node.Name()); + output = model_builder.GetBuilder().call("gemm", a, b, common_options); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); @@ -241,7 +268,31 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, return false; } - return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); + if (op_type == "MatMulInteger") { + // The first decomposed op of MatMulInteger is DequantizeLinear, and so + // we only need to ensure it supports the input0_type. + return IsDataTypeSupportedByOp("DequantizeLinear", input0_type, wnn_limits, "input", "x", logger); + } else { + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); + } +} + +bool GemmOpBuilder::HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& output = *node.OutputDefs()[0]; + const std::string_view op_type = node.OpType(); + int32_t output_type; + if (!GetType(output, output_type, logger)) { + return false; + } + + if (op_type == "MatMulInteger") { + // The last decomposed op of MatMulInteger is Cast, and so + // we only need to ensure it supports the output_type. + return IsDataTypeSupportedByOp("Cast", output_type, wnn_limits, "output", "Output", logger); + } else { + return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "Output", logger); + } } void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc index d02dd61460f60..ad22758028f2c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc @@ -70,20 +70,34 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build const auto& input_defs = node.InputDefs(); int32_t input_data_type; ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_data_type, logger), "Cannot get input type"); + + const bool is_onnx_domain = IsOnnxDomain(node.Domain()); + // The input indexes for the onnx domain and the microsoft domain are different. + const size_t cos_cache_idx = is_onnx_domain ? 1 : 2; + const size_t sin_cache_idx = is_onnx_domain ? 2 : 3; + const size_t position_ids_idx = is_onnx_domain ? 3 : 1; std::vector input_shape; std::vector position_ids_shape; std::vector cos_cache_shape; + // Since opset 23, the position_ids input is optional. + const bool has_position_ids = TensorExists(input_defs, position_ids_idx); ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); - ORT_RETURN_IF_NOT(GetShape(*input_defs[1], position_ids_shape, logger), "Cannot get position_ids shape"); - ORT_RETURN_IF_NOT(GetShape(*input_defs[2], cos_cache_shape, logger), "Cannot get cos_cache shape"); + ORT_RETURN_IF_NOT(GetShape(*input_defs[cos_cache_idx], cos_cache_shape, logger), "Cannot get cos_cache shape"); + if (has_position_ids) { + ORT_RETURN_IF_NOT(GetShape(*input_defs[position_ids_idx], position_ids_shape, logger), + "Cannot get position_ids shape"); + } const bool input_is_4d = input_shape.size() == 4; // When position_ids is a 1D tensor, it represents the start offset for each sequence. - const bool position_ids_is_offset = position_ids_shape.size() == 1; + const bool position_ids_is_offset = has_position_ids && position_ids_shape.size() == 1; emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); - emscripten::val position_ids = model_builder.GetOperand(input_defs[1]->Name()); - emscripten::val cos_cache = model_builder.GetOperand(input_defs[2]->Name()); - emscripten::val sin_cache = model_builder.GetOperand(input_defs[3]->Name()); + emscripten::val position_ids; + if (has_position_ids) { + position_ids = model_builder.GetOperand(input_defs[position_ids_idx]->Name()); + } + emscripten::val cos_cache = model_builder.GetOperand(input_defs[cos_cache_idx]->Name()); + emscripten::val sin_cache = model_builder.GetOperand(input_defs[sin_cache_idx]->Name()); const auto node_name = node.Name(); emscripten::val wnn_builder = model_builder.GetBuilder(); @@ -97,19 +111,34 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build // - 3D: [batch_size, sequence_length, hidden_size] // - 4D: [batch_size, num_heads, sequence_length, head_size] const uint32_t batch_size = static_cast(input_shape[0]); - const uint32_t sequence_length = input_is_4d ? static_cast(input_shape[2]) - : static_cast(input_shape[1]); - const uint32_t hidden_size = input_is_4d ? static_cast(input_shape[1] * input_shape[3]) - : static_cast(input_shape[2]); - const uint32_t head_size = num_heads == 0 ? static_cast(cos_cache_shape[1]) * 2 - : hidden_size / num_heads; - if (num_heads == 0) { - num_heads = hidden_size / head_size; + uint32_t sequence_length, hidden_size, head_size; + if (input_is_4d) { + sequence_length = static_cast(input_shape[2]); + hidden_size = static_cast(input_shape[1] * input_shape[3]); + num_heads = static_cast(input_shape[1]); + head_size = static_cast(input_shape[3]); + } else { + sequence_length = static_cast(input_shape[1]); + hidden_size = static_cast(input_shape[2]); + // Since opset 23, if the input is 3D, the num_heads must be provided. + if (is_onnx_domain) { + assert(num_heads != 0); + head_size = hidden_size / num_heads; + } else { + if (num_heads == 0) { + head_size = static_cast(cos_cache_shape[1]) * 2; + num_heads = hidden_size / head_size; + } else { + head_size = hidden_size / num_heads; + } + } } + if (rotary_embedding_dim == 0) { rotary_embedding_dim = head_size; } + const uint32_t half_rotary_embedding_dim = rotary_embedding_dim / 2; emscripten::val transpose_options = emscripten::val::object(); // Ensure the input is reshaped to: [batch_size, sequence_length, num_heads, head_size]. @@ -147,8 +176,8 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build // Split the partial input0 data into 2 equal parts. // Firstly reshape the partial input0. const std::vector new_partial_input0_shape = - interleaved ? std::vector({batch_size, sequence_length, num_heads, rotary_embedding_dim / 2, 2}) - : std::vector({batch_size, sequence_length, num_heads, 2, rotary_embedding_dim / 2}); + interleaved ? std::vector({batch_size, sequence_length, num_heads, half_rotary_embedding_dim, 2}) + : std::vector({batch_size, sequence_length, num_heads, 2, half_rotary_embedding_dim}); emscripten::val reshape_partial_input0_options = emscripten::val::object(); reshape_partial_input0_options.set("label", node_name + "_reshape_partial_input0"); partial_input0 = wnn_builder.call( @@ -196,19 +225,34 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build "add", position_ids, position_ids_range, position_ids_add_range_options); } - // Gather the cosine/sine values based on the position_ids. - emscripten::val gather_cos_sin_options = emscripten::val::object(); - gather_cos_sin_options.set("label", node_name + "_gather_cos_sin"); - gather_cos_sin_options.set("axis", 0); - emscripten::val gather_cos = wnn_builder.call( - "gather", cos_cache, position_ids, gather_cos_sin_options); - emscripten::val gather_sin = wnn_builder.call( - "gather", sin_cache, position_ids, gather_cos_sin_options); + // Gather the cosine/sine values based on the position_ids (if it presents). + emscripten::val gather_cos = cos_cache; + emscripten::val gather_sin = sin_cache; + if (has_position_ids) { + emscripten::val gather_cos_sin_options = emscripten::val::object(); + gather_cos_sin_options.set("label", node_name + "_gather_cos_sin"); + gather_cos_sin_options.set("axis", 0); + gather_cos = wnn_builder.call("gather", gather_cos, position_ids, gather_cos_sin_options); + gather_sin = wnn_builder.call("gather", gather_sin, position_ids, gather_cos_sin_options); + } + + // If it is full rotation, we need to slice the gathered cosine/sine + // to get the shape [batch_size, sequence_length, rotary_embedding_dim / 2]. + if (cos_cache_shape.back() != static_cast(half_rotary_embedding_dim)) { + emscripten::val slice_gather_cos_sin_options = emscripten::val::object(); + slice_gather_cos_sin_options.set("label", node_name + "_slice_gather_cos_sin"); + const std::vector starts{0, 0, 0}; + const std::vector sizes{batch_size, sequence_length, half_rotary_embedding_dim}; + gather_cos = wnn_builder.call("slice", gather_cos, emscripten::val::array(starts), + emscripten::val::array(sizes), slice_gather_cos_sin_options); + gather_sin = wnn_builder.call("slice", gather_sin, emscripten::val::array(starts), + emscripten::val::array(sizes), slice_gather_cos_sin_options); + } - // After gathering cosine/sine, reshape and broadcast them to match the number of heads of the input data. + // Reshape and broadcast them to match the number of heads of the input data. const std::vector reshaped_cos_sin_shape = - interleaved ? std::vector({batch_size, sequence_length, 1, rotary_embedding_dim / 2, 1}) - : std::vector({batch_size, sequence_length, 1, 1, rotary_embedding_dim / 2}); + interleaved ? std::vector({batch_size, sequence_length, 1, half_rotary_embedding_dim, 1}) + : std::vector({batch_size, sequence_length, 1, 1, half_rotary_embedding_dim}); emscripten::val reshape_gather_cos_sin_options = emscripten::val::object(); reshape_gather_cos_sin_options.set("label", node_name + "_reshape_gather_cos_sin"); gather_cos = wnn_builder.call( @@ -312,10 +356,13 @@ bool RotaryEmbeddingOpBuilder::IsOpSupportedImpl(const GraphViewer&, const Node& const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + const bool is_onnx_domain = IsOnnxDomain(node.Domain()); + // The input indexes for the onnx domain and the microsoft domain are different. + const size_t cos_cache_idx = is_onnx_domain ? 1 : 2; std::vector input_shape; std::vector cos_cache_shape; if (!GetShape(*input_defs[0], input_shape, logger)) return false; - if (!GetShape(*input_defs[2], cos_cache_shape, logger)) return false; + if (!GetShape(*input_defs[cos_cache_idx], cos_cache_shape, logger)) return false; const auto input_size = input_shape.size(); if (input_size != 3 && input_size != 4) { LOGS(logger, VERBOSE) << "RotaryEmbedding only supports 3D or 4D input shape, input is " << input_size << "D shape"; @@ -326,11 +373,23 @@ bool RotaryEmbeddingOpBuilder::IsOpSupportedImpl(const GraphViewer&, const Node& const int is_packed_batching = helper.Get("is_packed_batching", 0); const int num_heads = helper.Get("num_heads", 0); const int rotary_embedding_dim = helper.Get("rotary_embedding_dim", 0); - const auto sequence_length = input_size == 4 ? input_shape[2] : input_shape[1]; - if (is_packed_batching == 0 && sequence_length > cos_cache_shape[0]) { - LOGS(logger, VERBOSE) << "RotaryEmbedding: updating cos_cache and sin_cache is not currently supported"; - return false; + + if (is_onnx_domain) { + if (input_size == 3 && num_heads == 0) { + LOGS(logger, VERBOSE) << "RotaryEmbedding: num_heads must be provided if input is 3D"; + return false; + } + } else { + if (is_packed_batching == 0 && sequence_length > cos_cache_shape[0]) { + LOGS(logger, VERBOSE) << "RotaryEmbedding: updating cos_cache and sin_cache is not currently supported"; + return false; + } + + if (rotary_embedding_dim > 0 && num_heads == 0) { + LOGS(logger, VERBOSE) << "RotaryEmbedding: num_heads must be provided if rotary_embedding_dim is specified"; + return false; + } } if (input_size == 4 && num_heads != 0 && num_heads != input_shape[1]) { @@ -339,11 +398,6 @@ bool RotaryEmbeddingOpBuilder::IsOpSupportedImpl(const GraphViewer&, const Node& return false; } - if (rotary_embedding_dim > 0 && num_heads == 0) { - LOGS(logger, VERBOSE) << "RotaryEmbedding: num_heads must be provided if rotary_embedding_dim is specified"; - return false; - } - return true; } @@ -353,14 +407,22 @@ bool RotaryEmbeddingOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const std::string_view op_type = node.OpType(); + const bool is_onnx_domain = IsOnnxDomain(node.Domain()); + // The input indexes for the onnx domain and the microsoft domain are different. + const size_t cos_cache_idx = is_onnx_domain ? 1 : 2; + const size_t sin_cache_idx = is_onnx_domain ? 2 : 3; + const size_t position_ids_idx = is_onnx_domain ? 3 : 1; int32_t input_type = 0; int32_t position_ids_type = 0; int32_t cos_cache_type = 0; int32_t sin_cache_type = 0; + // Since opset 23, the position_ids is an optional input. + const bool has_position_ids = TensorExists(input_defs, position_ids_idx); + if (!GetType(*input_defs[0], input_type, logger) || - !GetType(*input_defs[1], position_ids_type, logger) || - !GetType(*input_defs[2], cos_cache_type, logger) || - !GetType(*input_defs[3], sin_cache_type, logger)) { + (has_position_ids && !GetType(*input_defs[position_ids_idx], position_ids_type, logger)) || + !GetType(*input_defs[cos_cache_idx], cos_cache_type, logger) || + !GetType(*input_defs[sin_cache_idx], sin_cache_type, logger)) { return false; } @@ -369,7 +431,7 @@ bool RotaryEmbeddingOpBuilder::HasSupportedInputsImpl(const GraphViewer&, return false; } - if (position_ids_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { + if (has_position_ids && position_ids_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) { return false; } diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index 40fdfc609e6a1..ef829e82823d0 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -158,6 +158,11 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap& outputs) { auto webnnEnsureTensor = emscripten::val::module_property("webnnEnsureTensor"); auto promises = emscripten::val::array(); + bool trace = emscripten::val::module_property("webnnEnableTraceEvent").as(); + emscripten::val console = emscripten::val::global("console"); + if (trace) { + console.call("time", emscripten::val("ORT::Dispatch::jsepEnsureTensor")); + } for (const auto& [_, tensor] : inputs) { emscripten::val shape = emscripten::val::array(); for (const auto& dim : tensor.tensor_info.shape) { @@ -176,6 +181,9 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap(tensor.buffer), tensor.tensor_info.data_type, shape, false); promises.call("push", ml_tensor); } + if (trace) { + console.call("timeEnd", emscripten::val("ORT::Dispatch::jsepEnsureTensor")); + } auto ml_tensors = emscripten::val::global("Promise").call("all", promises).await(); for (const auto& [name, _] : inputs) { wnn_inputs_.set(name, ml_tensors.call("shift")); diff --git a/onnxruntime/core/providers/webnn/data_transfer.cc b/onnxruntime/core/providers/webnn/data_transfer.cc index aa85277b72453..17369e6fbc75d 100644 --- a/onnxruntime/core/providers/webnn/data_transfer.cc +++ b/onnxruntime/core/providers/webnn/data_transfer.cc @@ -20,6 +20,11 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { // We don't need to transfer the tensor to an MLTensor, so we don't need to copy the data. return Status::OK(); } + bool trace = emscripten::val::module_property("webnnEnableTraceEvent").as(); + emscripten::val console = emscripten::val::global("console"); + if (trace) { + console.call("time", emscripten::val("ORT::DataTransfer::CopyTensor")); + } size_t bytes = src.SizeInBytes(); if (bytes > 0) { @@ -30,10 +35,16 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { if (dst_device.Type() == OrtDevice::GPU) { EM_ASM({ Module.webnnUploadTensor($0, HEAPU8.subarray($1, $1 + $2)); }, dst_data, reinterpret_cast(src_data), bytes); + if (trace) { + console.call("timeEnd", emscripten::val("ORT::DataTransfer::webnnUploadTensor")); + } } else { auto webnnDownloadTensor = emscripten::val::module_property("webnnDownloadTensor"); auto subarray = emscripten::typed_memory_view(bytes, static_cast(dst_data)); webnnDownloadTensor(reinterpret_cast(src_data), subarray).await(); + if (trace) { + console.call("timeEnd", emscripten::val("ORT::DataTransfer::webnnDownloadTensor")); + } } } diff --git a/onnxruntime/core/session/abi_key_value_pairs.h b/onnxruntime/core/session/abi_key_value_pairs.h index 3242be817881a..150575b3a9efc 100644 --- a/onnxruntime/core/session/abi_key_value_pairs.h +++ b/onnxruntime/core/session/abi_key_value_pairs.h @@ -57,6 +57,8 @@ struct OrtKeyValuePairs { keys.erase(key_iter); values.erase(values.begin() + idx); } + + entries.erase(iter); } } diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 0b116c2fa64f6..c205e05baadb9 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -366,6 +366,29 @@ ORT_API_STATUS_IMPL(OrtApis::SetDeterministicCompute, _Inout_ OrtSessionOptions* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* options, + _In_ OrtExecutionProviderDevicePolicy policy) { + API_IMPL_BEGIN + options->value.ep_selection_policy.enable = true; + options->value.ep_selection_policy.policy = policy; + options->value.ep_selection_policy.delegate = nullptr; + options->value.ep_selection_policy.state = nullptr; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* options, + _In_opt_ EpSelectionDelegate delegate, + _In_opt_ void* state) { + API_IMPL_BEGIN + options->value.ep_selection_policy.enable = true; + options->value.ep_selection_policy.policy = OrtExecutionProviderDevicePolicy_DEFAULT; + options->value.ep_selection_policy.delegate = delegate; + options->value.ep_selection_policy.state = state; + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options, _In_ bool is_cancel) { API_IMPL_BEGIN diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index a3f6addd100ad..ad128fee6cc3d 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -8,11 +8,13 @@ #include #include "core/common/common.h" +#include "core/session/allocator_adapters.h" #include "core/framework/error_code_helper.h" #include "core/session/abi_session_options_impl.h" #include "core/session/inference_session.h" #include "core/session/model_compilation_options.h" #include "core/session/ort_apis.h" +#include "core/session/ort_env.h" #include "core/session/utils.h" #else #include "core/common/common.h" @@ -43,7 +45,8 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::CreateModelCompilationOptionsFromSessionOptio return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "The session_options argument must be a non-null pointer"); } - auto model_compile_options = std::make_unique(*env, *session_options); + auto model_compile_options = std::make_unique(env->GetEnvironment(), + *session_options); *out = reinterpret_cast(model_compile_options.release()); return nullptr; #else @@ -150,7 +153,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelExterna ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer, _In_ OrtModelCompilationOptions* ort_model_compile_options, - _Inout_ OrtAllocator* allocator, void** output_model_data_ptr, size_t* output_model_data_size_ptr) { + _Inout_ OrtAllocator* ort_allocator, void** output_model_data_ptr, size_t* output_model_data_size_ptr) { API_IMPL_BEGIN #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); @@ -163,17 +166,18 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer, return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output model buffer: size pointer is null"); } - if (allocator == nullptr) { + if (ort_allocator == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid allocator for output model buffer: allocator pointer is null"); } - ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetOutputModelBuffer(allocator, + onnxruntime::AllocatorPtr allocator = std::make_shared(ort_allocator); + ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetOutputModelBuffer(std::move(allocator), output_model_data_ptr, output_model_data_size_ptr)); return nullptr; #else ORT_UNUSED_PARAMETER(ort_model_compile_options); - ORT_UNUSED_PARAMETER(allocator); + ORT_UNUSED_PARAMETER(ort_allocator); ORT_UNUSED_PARAMETER(output_model_data_ptr); ORT_UNUSED_PARAMETER(output_model_data_size_ptr); return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); @@ -202,23 +206,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env, API_IMPL_BEGIN #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); - ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->Check()); - - std::unique_ptr session; - const OrtSessionOptions* session_options = &model_compile_options->GetSessionOptions(); - - if (model_compile_options->InputModelComesFromFile()) { - PathString input_model_path = ToPathString(model_compile_options->GetInputModelPath()); - ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(session_options, env, - input_model_path.c_str(), - nullptr, 0, session)); - } else { - ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(session_options, env, nullptr, - model_compile_options->GetInputModelData(), - model_compile_options->GetInputModelDataSize(), session)); - } - - ORT_API_RETURN_IF_ERROR(InitializeSession(session_options, *session)); + ORT_API_RETURN_IF_STATUS_NOT_OK(onnxruntime::CompileModel(env->GetEnvironment(), *model_compile_options)); return nullptr; #else ORT_UNUSED_PARAMETER(env); diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc index c515195c7e6bf..aa032f24f13c0 100644 --- a/onnxruntime/core/session/ep_library_internal.cc +++ b/onnxruntime/core/session/ep_library_internal.cc @@ -183,14 +183,14 @@ std::vector> EpLibraryInternal::CreateInterna // CPU EP internal_eps.push_back(CreateCpuEp()); -#if defined(USE_DML) - internal_eps.push_back(CreateDmlEp()); -#endif - #if defined(USE_WEBGPU) internal_eps.push_back(CreateWebGpuEp()); #endif +#if defined(USE_DML) + internal_eps.push_back(CreateDmlEp()); +#endif + return internal_eps; } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 8ec7312cc6354..df70856a64e99 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3266,6 +3266,7 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod // save model metadata model_metadata_.producer_name = model.ProducerName(); + model_metadata_.producer_version = model.ProducerVersion(); model_metadata_.description = model.DocString(); model_metadata_.graph_description = model.GraphDocString(); model_metadata_.domain = model.Domain(); @@ -3430,6 +3431,10 @@ const Model& InferenceSession::GetModel() const { return *model_; } +const Environment& InferenceSession::GetEnvironment() const { + return environment_; +} + SessionIOBinding::SessionIOBinding(InferenceSession* session) : sess_(session) { ORT_ENFORCE(session->NewIOBinding(&binding_).IsOK()); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index a21388d1e9918..51350390a0456 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -80,6 +80,7 @@ struct ModelMetadata { ModelMetadata& operator=(const ModelMetadata&) = delete; std::string producer_name; + std::string producer_version; std::string graph_name; std::string domain; std::string description; @@ -350,8 +351,8 @@ class InferenceSession { /** * Initializes a previously loaded ONNX model. Initialization includes but is not - * limited to graph transformations, construction of kernels, etc. - * This method assumes that a method has been loaded previously. + * limited to graph transformations, construction of kernels, EP policy decisions, etc. + * This method assumes that a model has been loaded previously. * This API is thread-safe. * @return OK if success */ @@ -603,6 +604,7 @@ class InferenceSession { #endif const Model& GetModel() const; + const Environment& GetEnvironment() const; protected: #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index 80ef18de5cfa3..d0cb092f78843 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -8,18 +8,20 @@ #include #include -#include "core/session/allocator_adapters.h" +#include "core/framework/allocator.h" #include "core/session/onnxruntime_session_options_config_keys.h" -#include "core/session/ort_env.h" +#include "core/session/environment.h" namespace onnxruntime { -ModelCompilationOptions::ModelCompilationOptions(const OrtEnv& env, const OrtSessionOptions& session_options) +ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment& env, const OrtSessionOptions& session_options) : env_(env), session_options_(session_options) { session_options_.value.has_explicit_ep_context_gen_options = true; session_options_.value.ep_context_gen_options = session_options.value.GetEpContextGenerationOptions(); session_options_.value.ep_context_gen_options.enable = true; session_options_.value.ep_context_gen_options.overwrite_existing_output_file = true; - session_options_.value.ep_context_gen_options.error_if_no_compiled_nodes = true; + // defaulting to false to support wider usage. will log WARNING if compiling model with no context nodes. + // TODO: Add ability for user to explicitly set this. + session_options_.value.ep_context_gen_options.error_if_no_compiled_nodes = false; // Shouldn't fail because the key/value strings are below the maximum string length limits in ConfigOptions. ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1").IsOK()); @@ -84,15 +86,14 @@ void ModelCompilationOptions::SetOutputModelExternalInitializersFile(const std:: external_initializer_size_threshold; } -Status ModelCompilationOptions::SetOutputModelBuffer(OrtAllocator* allocator, +Status ModelCompilationOptions::SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr) { ORT_RETURN_IF_ERROR(ResetOutputModelSettings()); session_options_.value.ep_context_gen_options.output_model_buffer_ptr = output_model_buffer_ptr; session_options_.value.ep_context_gen_options.output_model_buffer_size_ptr = output_model_buffer_size_ptr; - session_options_.value.ep_context_gen_options.output_model_buffer_allocator = - std::make_shared(allocator); + session_options_.value.ep_context_gen_options.output_model_buffer_allocator = std::move(allocator); return Status::OK(); } diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 5ee64d48c3060..9238264003645 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -8,11 +8,13 @@ #include #include "core/common/status.h" #include "core/common/path_string.h" +#include "core/framework/allocator.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" namespace onnxruntime { +class Environment; /// /// Stores options to compile ONNX models into "EPContext" models. @@ -23,9 +25,9 @@ class ModelCompilationOptions { /// Creates an instance with the session options to use for model compilation. /// The session options are expected to have execution providers that compile. /// - /// Reference to OrtEnv + /// Reference to Environment /// Reference to session options - ModelCompilationOptions(const OrtEnv& env, const OrtSessionOptions& session_options); + ModelCompilationOptions(const onnxruntime::Environment& env, const OrtSessionOptions& session_options); /// /// Sets the file path to the input ONNX model to compile. @@ -67,7 +69,7 @@ class ModelCompilationOptions { /// Pointer to the buffer that will contain the compiled model /// Set to the size of the buffer /// Status indicating potential error - Status SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, + Status SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); /// @@ -122,7 +124,7 @@ class ModelCompilationOptions { Status CheckInputModelSettings() const; Status CheckOutputModelSettings() const; - const OrtEnv& env_; + const onnxruntime::Environment& env_; OrtSessionOptions session_options_; std::string input_model_path_; const void* input_model_data_ = nullptr; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index b5c271594055a..868fab767fa7b 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1206,6 +1206,33 @@ ORT_API_STATUS_IMPL(OrtApis::GetStringTensorElementLength, _In_ const OrtValue* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::GetTensorSizeInBytes, _In_ const OrtValue* value, _Out_ size_t* size) { + API_IMPL_BEGIN + + if (value == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Input `value` argument must not be null"); + } + + if (size == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Output `size` argument must not be null"); + } + + if (!value->IsAllocated() || !value->IsTensor()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtValue is expected to contain a tensor"); + } + + const auto& tensor = value->Get(); + + // Check if this is a string tensor + if (tensor.IsDataTypeString()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "String tensors are not supported by this API"); + } + + *size = tensor.SizeInBytes(); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::GetStringTensorContent, _In_ const OrtValue* value, _Out_writes_bytes_all_(s_len) void* s, size_t s_len, _Out_writes_all_(offsets_len) size_t* offsets, size_t offsets_len) { API_IMPL_BEGIN @@ -2465,49 +2492,16 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS _In_reads_(num_op_options) const char* const* ep_option_vals, size_t num_ep_options) { API_IMPL_BEGIN - if (num_ep_devices > 1) { - const auto& ep_name = ep_devices[0]->ep_name; - bool all_match = std::all_of(ep_devices + 1, ep_devices + num_ep_devices, - [&ep_name](const OrtEpDevice* ep_device) { return ep_device->ep_name == ep_name; }); - if (!all_match) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "All OrtEpDevice values in ep_devices must have the same execution provider."); - } - } - - EpFactoryInternal* internal_factory = nullptr; - for (size_t i = 0; i < num_ep_devices; ++i) { - const OrtEpDevice* entry = ep_devices[i]; - - // we expect the internal factory to be available for internal and provider bridge EPs, which is all we support. - internal_factory = env->GetEnvironment().GetEpFactoryInternal(entry->ep_factory); - if (!internal_factory) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "EP is not currently supported by this API"); - } - - // add the options to the session options with the EP prefix. - // first add the default values with prefix followed by user specified values so those win - const auto prefix = OrtSessionOptions::GetProviderOptionPrefix(entry->ep_name.c_str()); - auto& config_options = session_options->value.config_options; - for (const auto& [key, value] : entry->ep_options.entries) { - ORT_API_RETURN_IF_STATUS_NOT_OK(config_options.AddConfigEntry((prefix + key).c_str(), value.c_str())); - } - - for (size_t j = 0; j < num_ep_options; ++j) { - if (ep_option_keys[j] == nullptr) { - continue; - } + std::unique_ptr provider_factory = nullptr; - ORT_API_RETURN_IF_STATUS_NOT_OK(config_options.AddConfigEntry((prefix + ep_option_keys[j]).c_str(), - ep_option_vals[j])); - } - } - - if (internal_factory) { - session_options->provider_factories.push_back( - std::make_unique( - *internal_factory, std::vector(ep_devices, ep_devices + num_ep_devices))); - } + ORT_API_RETURN_IF_STATUS_NOT_OK(CreateIExecutionProviderFactoryForEpDevices( + env->GetEnvironment(), + session_options->value, + gsl::span(ep_devices, num_ep_devices), + gsl::span(ep_option_keys, num_ep_options), + gsl::span(ep_option_vals, num_ep_options), + /*output*/ provider_factory)); + session_options->provider_factories.push_back(std::move(provider_factory)); return nullptr; API_IMPL_END @@ -3012,6 +3006,8 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::UnregisterExecutionProviderLibrary, &OrtApis::GetEpDevices, &OrtApis::SessionOptionsAppendExecutionProvider_V2, + &OrtApis::SessionOptionsSetEpSelectionPolicy, + &OrtApis::SessionOptionsSetEpSelectionPolicyDelegate, &OrtApis::HardwareDevice_Type, &OrtApis::HardwareDevice_VendorId, @@ -3027,6 +3023,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::GetEpApi, // End of Version 22 - DO NOT MODIFY ABOVE (see above text for more information) + &OrtApis::GetTensorSizeInBytes, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. @@ -3061,7 +3058,7 @@ static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeo // no additions in version 19, 20, and 21 static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Size of version 20 API cannot change"); -static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 315, "Size of version 22 API cannot change"); +static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.23.0", diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 0033eb0d604f2..81af6694f6273 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -562,8 +562,9 @@ ORT_API(void, GetKeyValuePairs, _In_ const OrtKeyValuePairs* kvps, ORT_API(void, RemoveKeyValuePair, _In_ OrtKeyValuePairs* kvps, _In_ const char* key); ORT_API(void, ReleaseKeyValuePairs, _Frees_ptr_opt_ OrtKeyValuePairs*); -ORT_API_STATUS_IMPL(RegisterExecutionProviderLibrary, _In_ OrtEnv* env, const char* ep_name, const ORTCHAR_T* path); -ORT_API_STATUS_IMPL(UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* ep_name); +ORT_API_STATUS_IMPL(RegisterExecutionProviderLibrary, _In_ OrtEnv* env, const char* registration_name, + const ORTCHAR_T* path); +ORT_API_STATUS_IMPL(UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name); ORT_API_STATUS_IMPL(GetEpDevices, _In_ const OrtEnv* env, _Outptr_ const OrtEpDevice* const** ep_devices, _Out_ size_t* num_ep_devices); @@ -575,6 +576,13 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_V2, _In_ OrtSessionOpt _In_reads_(num_op_options) const char* const* ep_option_vals, size_t num_ep_options); +ORT_API_STATUS_IMPL(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* sess_options, + _In_ OrtExecutionProviderDevicePolicy policy); + +ORT_API_STATUS_IMPL(SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* sess_options, + _In_ EpSelectionDelegate delegate, + _In_opt_ void* state); + // OrtHardwareDevice accessors. ORT_API(OrtHardwareDeviceType, HardwareDevice_Type, _In_ const OrtHardwareDevice* device); ORT_API(uint32_t, HardwareDevice_VendorId, _In_ const OrtHardwareDevice* device); @@ -590,4 +598,7 @@ ORT_API(const OrtKeyValuePairs*, EpDevice_EpOptions, _In_ const OrtEpDevice* ep_ ORT_API(const OrtHardwareDevice*, EpDevice_Device, _In_ const OrtEpDevice* ep_device); ORT_API(const OrtEpApi*, GetEpApi); + +ORT_API_STATUS_IMPL(GetTensorSizeInBytes, _In_ const OrtValue* ort_value, _Out_ size_t* size); + } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc new file mode 100644 index 0000000000000..a4e0c16b411a1 --- /dev/null +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -0,0 +1,385 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include "core/session/provider_policy_context.h" + +#include + +#include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" +#include "core/session/ep_factory_internal.h" +#include "core/session/inference_session.h" +#include "core/session/inference_session_utils.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { +namespace { +bool MatchesEpVendor(const OrtEpDevice* d) { + // TODO: Would be better to match on Id. Should the EP add that in EP metadata? + return d->device->vendor == d->ep_vendor; +} + +bool IsDiscreteDevice(const OrtEpDevice* d) { + if (d->device->type != OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + return false; + } + + const auto& entries = d->device->metadata.entries; + if (auto it = entries.find("Discrete"); it != entries.end()) { + return it->second == "1"; + } + + return false; +} + +bool IsDefaultCpuEp(const OrtEpDevice* d) { + return d->device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU && + d->ep_vendor == "Microsoft"; +} + +// Sort devices. NPU -> GPU -> CPU +// Within in type, vendor owned, not. +// Default CPU EP is last +std::vector OrderDevices(const std::vector& devices) { + std::vector sorted_devices(devices.begin(), devices.end()); + std::sort(sorted_devices.begin(), sorted_devices.end(), [](const OrtEpDevice* a, const OrtEpDevice* b) { + auto aDeviceType = a->device->type; + auto bDeviceType = b->device->type; + if (aDeviceType != bDeviceType) { + // NPU -> GPU -> CPU + // std::sort is ascending order, so NPU < GPU < CPU + + // one is an NPU + if (aDeviceType == OrtHardwareDeviceType::OrtHardwareDeviceType_NPU) { + return true; + } else if (bDeviceType == OrtHardwareDeviceType::OrtHardwareDeviceType_NPU) { + return false; + } + + if (aDeviceType == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + return true; + } else if (bDeviceType == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + return false; + } + + // this shouldn't be reachable as it would imply both are CPU + ORT_THROW("Unexpected combination of devices"); + } + + // both devices are the same + + if (aDeviceType == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + bool aDiscrete = IsDiscreteDevice(a); + bool bDiscrete = IsDiscreteDevice(b); + if (aDiscrete != bDiscrete) { + return aDiscrete == true; // prefer discrete + } + + // both discrete or both integrated + } + + // prefer device matching platform vendor + bool aVendor = MatchesEpVendor(a); + bool bVendor = MatchesEpVendor(b); + if (aVendor != bVendor) { + return aVendor == true; // prefer the device that matches the EP vendor + } + + // default CPU EP last + bool aIsDefaultCpuEp = IsDefaultCpuEp(a); + bool bIsDefaultCpuEp = IsDefaultCpuEp(b); + if (!aIsDefaultCpuEp && !bIsDefaultCpuEp) { + // neither are default CPU EP. both do/don't match vendor. + // TODO: implement tie-breaker for this scenario. arbitrarily sort by ep name + return a->ep_name < b->ep_name; + } + + // one is the default CPU EP + return aIsDefaultCpuEp == false; // prefer the one that is not the default CPU EP + }); + + return sorted_devices; +} + +OrtKeyValuePairs GetModelMetadata(const InferenceSession& session) { + OrtKeyValuePairs metadata; + auto status_and_metadata = session.GetModelMetadata(); + + if (!status_and_metadata.first.IsOK()) { + return metadata; + } + + // use field names from onnx.proto + const auto& model_metadata = *status_and_metadata.second; + metadata.Add("producer_name", model_metadata.producer_name); + metadata.Add("producer_version", model_metadata.producer_version); + metadata.Add("domain", model_metadata.domain); + metadata.Add("model_version", std::to_string(model_metadata.version)); + metadata.Add("doc_string", model_metadata.description); + metadata.Add("graph_name", model_metadata.graph_name); // name from main GraphProto + metadata.Add("graph_description", model_metadata.graph_description); // descriptions from main GraphProto + for (const auto& entry : model_metadata.custom_metadata_map) { + metadata.Add(entry.first, entry.second); + } + + return metadata; +} +} // namespace + +// Select execution providers based on the device policy and available devices and add to session +Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const OrtSessionOptions& options, + InferenceSession& sess) { + // Get the list of devices from the environment and order them. + // Ordered by preference within each type. NPU -> GPU -> NPU + // TODO: Should environment.cc do the ordering? + std::vector execution_devices = OrderDevices(env.GetOrtEpDevices()); + + // The list of devices selected by policies + std::vector devices_selected; + + // Run the delegate if it was passed in lieu of any other policy + if (options.value.ep_selection_policy.delegate) { + auto model_metadata = GetModelMetadata(sess); + OrtKeyValuePairs runtime_metadata; // TODO: where should this come from? + + std::vector delegate_devices(execution_devices.begin(), execution_devices.end()); + std::array selected_devices{nullptr}; + size_t num_selected = 0; + + EpSelectionDelegate delegate = options.value.ep_selection_policy.delegate; + auto* status = delegate(delegate_devices.data(), delegate_devices.size(), + &model_metadata, &runtime_metadata, + selected_devices.data(), selected_devices.size(), &num_selected, + options.value.ep_selection_policy.state); + + // return or fall-through for both these cases + // going with explicit failure for now so it's obvious to user what is happening + if (status != nullptr) { + std::string delegate_error_msg = OrtApis::GetErrorMessage(status); // copy + OrtApis::ReleaseStatus(status); + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate failed: ", delegate_error_msg); + } + + if (num_selected == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate did not select anything."); + } + + if (num_selected > selected_devices.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "EP selection delegate selected too many EP devices (", num_selected, "). ", + "The limit is ", selected_devices.size(), " EP devices."); + } + + // Copy the selected devices to the output vector + devices_selected.reserve(num_selected); + for (size_t i = 0; i < num_selected; ++i) { + devices_selected.push_back(selected_devices[i]); + } + } else { + // Create the selector for the chosen policy + std::unique_ptr selector; + switch (options.value.ep_selection_policy.policy) { + case OrtExecutionProviderDevicePolicy_DEFAULT: + selector = std::make_unique(); + break; + case OrtExecutionProviderDevicePolicy_PREFER_CPU: + selector = std::make_unique(); + break; + case OrtExecutionProviderDevicePolicy_PREFER_NPU: + case OrtExecutionProviderDevicePolicy_MAX_EFFICIENCY: + case OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER: + selector = std::make_unique(); + break; + case OrtExecutionProviderDevicePolicy_PREFER_GPU: + case OrtExecutionProviderDevicePolicy_MAX_PERFORMANCE: + selector = std::make_unique(); + break; + } + + // Execute policy + + selector->SelectProvidersForDevices(execution_devices, devices_selected); + } + + // Fail if we did not find any device matches + if (devices_selected.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "No execution providers selected. Please check the device policy and available devices."); + } + + // Configure the session options for the devices. This updates the SessionOptions in the InferenceSession with any + // EP options that have not been overridden by the user. + ORT_RETURN_IF_ERROR(AddEpDefaultOptionsToSession(sess, devices_selected)); + + // Create OrtSessionOptions for the CreateEp call. + // Once the InferenceSession is created, its SessionOptions is the source of truth and contains all the values from + // the user provided OrtSessionOptions. We do a copy for simplicity. The OrtSessionOptions instance goes away + // once we exit this function so an EP implementation should not use OrtSessionOptions after it returns from + // CreateEp. + auto& session_options = sess.GetMutableSessionOptions(); + OrtSessionOptions ort_so; + ort_so.value = session_options; + const auto& session_logger = sess.GetLogger(); + const OrtLogger& api_session_logger = *session_logger->ToExternal(); + + // Remove the ORT CPU EP if configured to do so + bool disable_ort_cpu_ep = ort_so.value.config_options.GetConfigEntry(kOrtSessionOptionsDisableCPUEPFallback) == "1"; + if (disable_ort_cpu_ep) { + RemoveOrtCpuDevice(devices_selected); + } + + // Fold the EPs into a single structure per factory + std::vector eps_selected; + FoldSelectedDevices(devices_selected, eps_selected); + + // Iterate through the selected EPs and create them + for (size_t idx = 0; idx < eps_selected.size(); ++idx) { + std::unique_ptr ep = nullptr; + ORT_RETURN_IF_ERROR(CreateExecutionProvider(env, ort_so, api_session_logger, eps_selected[idx], ep)); + if (ep != nullptr) { + ORT_RETURN_IF_ERROR(sess.RegisterExecutionProvider(std::move(ep))); + } + } + + return Status::OK(); +} + +void ProviderPolicyContext::FoldSelectedDevices(std::vector devices_selected, + std::vector& eps_selected) { + while (devices_selected.size() > 0) { + const auto ep_name = std::string(devices_selected[0]->ep_name); + SelectionInfo info; + info.ep_factory = devices_selected[0]->ep_factory; + + do { + auto iter = std::find_if(devices_selected.begin(), devices_selected.end(), [&ep_name](const OrtEpDevice* d) { + return d->ep_name == ep_name; + }); + + if (iter != devices_selected.end()) { + info.devices.push_back((*iter)->device); + info.ep_metadata.push_back(&(*iter)->ep_metadata); + devices_selected.erase(iter); + } else { + break; + } + } while (true); + + eps_selected.push_back(info); + } +} + +Status ProviderPolicyContext::CreateExecutionProvider(const Environment& env, OrtSessionOptions& options, + const OrtLogger& logger, + SelectionInfo& info, std::unique_ptr& ep) { + EpFactoryInternal* internal_factory = env.GetEpFactoryInternal(info.ep_factory); + + if (internal_factory) { + // this is a factory we created and registered internally for internal and provider bridge EPs + OrtStatus* status = internal_factory->CreateIExecutionProvider(info.devices.data(), info.ep_metadata.data(), + info.devices.size(), &options, &logger, + &ep); + if (status != nullptr) { + return ToStatus(status); + } + } else { + // in the real setup we need an IExecutionProvider wrapper implementation that uses the OrtEp internally, + // and we would add that IExecutionProvider to the InferenceSession. + // but first we need OrtEp and the OrtEpApi to be implemented. + ORT_NOT_IMPLEMENTED("IExecutionProvider that wraps OrtEp has not been implemented."); + + // OrtEp* api_ep = nullptr; + //// add the ep_options to session options but leave any existing entries (user provided overrides) untouched. + // auto status = info.ep_factory->CreateEp(info.ep_factory, info.devices.data(), info.ep_metadata.data(), + // info.devices.size(), &options, &logger, + // &api_ep); + // if (status != nullptr) { + // return ToStatus(status); + // } + } + + return Status::OK(); +} + +Status ProviderPolicyContext::AddEpDefaultOptionsToSession(InferenceSession& sess, + std::vector devices) { + auto& config_options = sess.GetMutableSessionOptions().config_options; + for (auto device : devices) { + const std::string ep_options_prefix = OrtSessionOptions::GetProviderOptionPrefix(device->ep_name.c_str()); + for (const auto& [key, value] : device->ep_options.entries) { + const std::string option_key = ep_options_prefix + key; + // preserve user-provided options as they override any defaults the EP factory specified earlier + if (config_options.configurations.find(option_key) == config_options.configurations.end()) { + // use AddConfigEntry for the error checking it does + ORT_RETURN_IF_ERROR(config_options.AddConfigEntry(option_key.c_str(), value.c_str())); + } + } + } + + return Status::OK(); +} + +void ProviderPolicyContext::RemoveOrtCpuDevice(std::vector& devices) { + // Remove the Microsoft CPU EP. always last if available. + if (IsDefaultCpuEp(devices.back())) { + devices.pop_back(); + } +} + +void DefaultEpPolicy::SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected) { + // Default policy is prefer CPU + PreferCpuEpPolicy().SelectProvidersForDevices(sorted_devices, selected); +} + +void PreferCpuEpPolicy::SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected) { + // Select the first CPU device from sorted devices + auto first_cpu = std::find_if(sorted_devices.begin(), sorted_devices.end(), + [](const OrtEpDevice* device) { + return device->device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU; + }); + + ORT_ENFORCE(first_cpu != sorted_devices.end(), "No CPU based execution providers were found."); + selected.push_back(*first_cpu); + + // add ORT CPU EP as the final option to ensure maximum coverage of opsets and operators + if (!IsDefaultCpuEp(*first_cpu) && IsDefaultCpuEp(sorted_devices.back())) { + selected.push_back(sorted_devices.back()); + } +} + +void PreferNpuEpPolicy::SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected) { + // Select the first NPU if there is one. + if (sorted_devices.front()->device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_NPU) { + selected.push_back(sorted_devices.front()); + } + + // CPU fallback + PreferCpuEpPolicy().SelectProvidersForDevices(sorted_devices, selected); +} + +void PreferGpuEpPolicy::SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected) { + // Select the first GPU device + auto first_gpu = std::find_if(sorted_devices.begin(), sorted_devices.end(), + [](const OrtEpDevice* device) { + return device->device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU; + }); + + if (first_gpu != sorted_devices.end()) { + selected.push_back(*first_gpu); + } + + // Add a CPU fallback + PreferCpuEpPolicy().SelectProvidersForDevices(sorted_devices, selected); +} +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/provider_policy_context.h b/onnxruntime/core/session/provider_policy_context.h new file mode 100644 index 0000000000000..185f9523312ba --- /dev/null +++ b/onnxruntime/core/session/provider_policy_context.h @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_MINIMAL_BUILD) + +#include "core/session/abi_session_options_impl.h" +#include "core/session/environment.h" +#include "core/session/onnxruntime_c_api.h" // For OrtExecutionProviderDevicePolicy + +namespace onnxruntime { + +struct SelectionInfo { + OrtEpFactory* ep_factory; + std::vector devices; + std::vector ep_metadata; +}; + +class IEpPolicySelector { + public: + /// + /// Select the OrtEpDevice instances to use. + /// Selection is in priority order. Highest priority first. + /// + /// Ordered devices. + /// Type order is NPU -> GPU -> CPU + /// Within a type: Discrete -> Integrated if GPU, EP vendor matches device vendor, vendor does not match + /// ORT CPU EP is always last if available. + /// + /// + virtual void SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected_devices) = 0; + + virtual ~IEpPolicySelector() = default; +}; + +class ProviderPolicyContext { + public: + ProviderPolicyContext() = default; + + Status SelectEpsForSession(const Environment& env, const OrtSessionOptions& options, InferenceSession& sess); + Status AddEpDefaultOptionsToSession(InferenceSession& sess, std::vector devices); + void RemoveOrtCpuDevice(std::vector& devices); + Status CreateExecutionProvider(const Environment& env, OrtSessionOptions& options, const OrtLogger& logger, + SelectionInfo& info, std::unique_ptr& ep); + void FoldSelectedDevices(std::vector devices_selected, // copy + std::vector& eps_selected); + + private: +}; + +class DefaultEpPolicy : public IEpPolicySelector { + public: + void SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected_devices) override; +}; + +class PreferCpuEpPolicy : public IEpPolicySelector { + public: + void SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected_devices) override; +}; + +class PreferNpuEpPolicy : public IEpPolicySelector { + public: + void SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected_devices) override; +}; + +class PreferGpuEpPolicy : public IEpPolicySelector { + public: + void SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected_devices) override; +}; + +} // namespace onnxruntime + +#endif // !ORT_MINIMAL_BUILD diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index adb019fdde86d..8ca4ef6af1f44 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -10,14 +10,18 @@ #include "core/session/environment.h" #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" -#include "core/session/ep_factory_internal.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" -#include "core/session/ep_library_plugin.h" -#include "core/session/ep_library_provider_bridge.h" #include "core/session/ort_apis.h" #include "core/session/ort_env.h" -#include "core/session/onnxruntime_session_options_config_keys.h" + +#if !defined(ORT_MINIMAL_BUILD) +#include "core/session/ep_factory_internal.h" +#include "core/session/ep_library_plugin.h" +#include "core/session/ep_library_provider_bridge.h" +#include "core/session/model_compilation_options.h" +#include "core/session/provider_policy_context.h" +#endif // !defined(ORT_MINIMAL_BUILD) using namespace onnxruntime; #if !defined(ORT_MINIMAL_BUILD) @@ -71,6 +75,11 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con return ToStatus(status); } } else { + // in the real setup we need an IExecutionProvider wrapper implementation that uses the OrtEp internally, + // and we would add that IExecutionProvider to the InferenceSession. + ORT_NOT_IMPLEMENTED("IExecutionProvider that wraps OrtEp has not been implemented."); + + /* OrtEp* api_ep = nullptr; auto status = ep_device->ep_factory->CreateEp( ep_device->ep_factory, devices.data(), ep_metadata.data(), devices.size(), @@ -79,10 +88,7 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con if (status != nullptr) { return ToStatus(status); } - - // in the real setup we need an IExecutionProvider wrapper implementation that uses the OrtEp internally, - // and we would add that IExecutionProvider to the InferenceSession. - ORT_NOT_IMPLEMENTED("IExecutionProvider that wraps OrtEp has not been implemented."); + */ } ORT_RETURN_IF_ERROR(sess.RegisterExecutionProvider(std::move(ep))); @@ -117,13 +123,14 @@ common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, err_msg); } -// provider either model_path, or modal_data + model_data_length. -OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, - _In_ const OrtEnv* env, - _In_opt_z_ const ORTCHAR_T* model_path, - _In_opt_ const void* model_data, - size_t model_data_length, - std::unique_ptr& sess) { +// Internal function that creates an InferenceSession and loads the model. +// Caller should provide either model_path, or modal_data + model_data_length. +static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* options, + const onnxruntime::Environment& env, + _In_opt_z_ const ORTCHAR_T* model_path, + _In_opt_ const void* model_data, + size_t model_data_length, + std::unique_ptr& sess) { // quick check here to decide load path. InferenceSession will provide error message for invalid values. // TODO: Could move to a helper const Env& os_env = Env::Default(); // OS environment (!= ORT environment) @@ -152,12 +159,12 @@ OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, if (model_path != nullptr) { sess = std::make_unique( options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment(), + env, model_path); } else { sess = std::make_unique( options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment(), + env, model_data, static_cast(model_data_length)); } #else @@ -166,17 +173,9 @@ OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, } else { sess = std::make_unique( options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment()); + env); } -#if !defined(ORT_MINIMAL_BUILD) - // TEMPORARY for testing. Manually specify the EP to select. - auto auto_select_ep_name = sess->GetSessionOptions().config_options.GetConfigEntry("test.ep_to_select"); - if (auto_select_ep_name) { - ORT_API_RETURN_IF_STATUS_NOT_OK(TestAutoSelectEPsImpl(env->GetEnvironment(), *sess, *auto_select_ep_name)); - } -#endif // !defined(ORT_MINIMAL_BUILD) - #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) // Add custom domains if (options && !options->custom_op_domains_.empty()) { @@ -200,6 +199,17 @@ OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, return nullptr; } +// Creates an InferenceSession and loads the model. +// Caller should provide either model_path, or modal_data + model_data_length. +OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, + _In_ const OrtEnv* env, + _In_opt_z_ const ORTCHAR_T* model_path, + _In_opt_ const void* model_data, + size_t model_data_length, + std::unique_ptr& sess) { + return CreateSessionAndLoadModelImpl(options, env->GetEnvironment(), model_path, model_data, model_data_length, sess); +} + OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, _In_ onnxruntime::InferenceSession& sess, _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container) { @@ -207,22 +217,38 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, ORT_ENFORCE(session_logger != nullptr, "Session logger is invalid, but should have been initialized during session construction."); - // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of - // byte addressable memory - std::vector> provider_list; - if (options) { + const bool has_provider_factories = options != nullptr && !options->provider_factories.empty(); + + if (has_provider_factories) { + std::vector> provider_list; for (auto& factory : options->provider_factories) { auto provider = factory->CreateProvider(*options, *session_logger->ToExternal()); provider_list.push_back(std::move(provider)); } + + // register the providers + for (auto& provider : provider_list) { + if (provider) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess.RegisterExecutionProvider(std::move(provider))); + } + } } +#if !defined(ORT_MINIMAL_BUILD) + else { + // TEMPORARY for testing. Manually specify the EP to select. + auto auto_select_ep_name = sess.GetSessionOptions().config_options.GetConfigEntry("test.ep_to_select"); + if (auto_select_ep_name) { + ORT_API_RETURN_IF_STATUS_NOT_OK(TestAutoSelectEPsImpl(sess.GetEnvironment(), sess, *auto_select_ep_name)); + } - // register the providers - for (auto& provider : provider_list) { - if (provider) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess.RegisterExecutionProvider(std::move(provider))); + // if there are no providers registered, and there's an ep selection policy set, do auto ep selection. + // note: the model has already been loaded so model metadata should be available to the policy delegate callback. + if (options != nullptr && options->value.ep_selection_policy.enable) { + ProviderPolicyContext context; + ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpsForSession(sess.GetEnvironment(), *options, sess)); } } +#endif // !defined(ORT_MINIMAL_BUILD) if (prepacked_weights_container != nullptr) { ORT_API_RETURN_IF_STATUS_NOT_OK(sess.AddPrePackedWeightsContainer( @@ -236,6 +262,28 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, namespace onnxruntime { #if !defined(ORT_MINIMAL_BUILD) +Status CompileModel(const Environment& env, const ModelCompilationOptions& model_compile_options) { + ORT_RETURN_IF_ERROR(model_compile_options.Check()); + + std::unique_ptr session; + const OrtSessionOptions* session_options = &model_compile_options.GetSessionOptions(); + + if (model_compile_options.InputModelComesFromFile()) { + PathString input_model_path = ToPathString(model_compile_options.GetInputModelPath()); + ORT_RETURN_IF_ERROR(ToStatus(CreateSessionAndLoadModelImpl(session_options, env, + input_model_path.c_str(), + nullptr, 0, session))); + } else { + ORT_RETURN_IF_ERROR(ToStatus(CreateSessionAndLoadModelImpl(session_options, env, nullptr, + model_compile_options.GetInputModelData(), + model_compile_options.GetInputModelDataSize(), + session))); + } + + ORT_RETURN_IF_ERROR(ToStatus(InitializeSession(session_options, *session))); + return Status::OK(); +} + Status LoadPluginOrProviderBridge(const std::string& registration_name, const ORTCHAR_T* library_path, std::unique_ptr& ep_library, @@ -274,5 +322,65 @@ Status LoadPluginOrProviderBridge(const std::string& registration_name, return Status::OK(); } -#endif + +Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, + SessionOptions& session_options, + gsl::span ep_devices, + gsl::span ep_option_keys, + gsl::span ep_option_vals, + /*output*/ std::unique_ptr& out) { + if (ep_devices.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Must provide one or more OrtEpDevice instances."); + } + + const size_t num_ep_options = ep_option_keys.size(); + if (ep_option_vals.size() != num_ep_options) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Must provide the same number of keys and values for EP options."); + } + + const auto& ep_name = ep_devices[0]->ep_name; + bool all_match = std::all_of(ep_devices.begin() + 1, ep_devices.end(), + [&ep_name](const OrtEpDevice* ep_device) { return ep_device->ep_name == ep_name; }); + if (!all_match) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "All OrtEpDevice values in ep_devices must have the same execution provider."); + } + + EpFactoryInternal* internal_factory = nullptr; + for (const OrtEpDevice* ep_device : ep_devices) { + // we expect the internal factory to be available for internal and provider bridge EPs, which is all we support. + internal_factory = env.GetEpFactoryInternal(ep_device->ep_factory); + if (!internal_factory) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "EP is not currently supported by this API"); + } + + // add the options to the session options with the EP prefix. + // first add the default values with prefix followed by user specified values so those win + const std::string prefix = OrtSessionOptions::GetProviderOptionPrefix(ep_device->ep_name.c_str()); + auto& config_options = session_options.config_options; + for (const auto& [key, value] : ep_device->ep_options.entries) { + ORT_RETURN_IF_ERROR(config_options.AddConfigEntry((prefix + key).c_str(), value.c_str())); + } + + for (size_t j = 0; j < num_ep_options; ++j) { + if (ep_option_keys[j] == nullptr) { + continue; + } + + ORT_RETURN_IF_ERROR(config_options.AddConfigEntry((prefix + ep_option_keys[j]).c_str(), ep_option_vals[j])); + } + } + + if (!internal_factory) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "EP is not currently supported by this API"); + } + + out = std::make_unique(*internal_factory, + std::vector(ep_devices.begin(), + ep_devices.end())); + return Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) } // namespace onnxruntime diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h index 535a7b041609d..5a5dcae9165ed 100644 --- a/onnxruntime/core/session/utils.h +++ b/onnxruntime/core/session/utils.h @@ -3,7 +3,10 @@ #pragma once +#include +#include #include +#include #include "core/common/common.h" #include "core/session/onnxruntime_c_api.h" @@ -14,9 +17,18 @@ struct OrtStatus; struct OrtPrepackedWeightsContainer; namespace onnxruntime { class InferenceSession; +class ModelCompilationOptions; +} // namespace onnxruntime + +#if !defined(ORT_MINIMAL_BUILD) +namespace onnxruntime { +class Environment; class EpLibrary; class EpFactoryInternal; +struct IExecutionProviderFactory; +struct SessionOptions; } // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, _In_ const OrtEnv* env, @@ -29,8 +41,18 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, _In_ onnxruntime::InferenceSession& sess, _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr); +#if !defined(ORT_MINIMAL_BUILD) namespace onnxruntime { +/// +/// Compiles an ONNX model into a model with EPContext nodes. Each EPContext node represents a subgraph compiled for +/// a specific execution provider. +/// +/// A reference to the Environment instance. +/// An object specifying the compilation options. +/// A Status indicating an error or success. +Status CompileModel(const Environment& env, const ModelCompilationOptions& model_compile_options); + // load a library that is added using RegisterExecutionProviderLibrary. // infer whether it's a provider bridge library or plugin library Status LoadPluginOrProviderBridge(const std::string& registration_name, @@ -38,4 +60,14 @@ Status LoadPluginOrProviderBridge(const std::string& registration_name, std::unique_ptr& ep_library, std::vector& internal_factories); +// Creates an IExecutionProviderFactory instance for a list of OrtEpDevices that all refer to the same EP. +// Adds all provider options to the OrtSessionOptions configuration. +Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, + SessionOptions& session_options, + gsl::span ep_devices, + gsl::span ep_options_keys, + gsl::span ep_options_vals, + /*output*/ std::unique_ptr& out); + } // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index ed0298a85b8e7..15c423d7285bc 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -172,10 +172,10 @@ class Session: This is the main class used to run a model. """ - def __init__(self): + def __init__(self, enable_fallback: bool = True): # self._sess is managed by the derived class and relies on bindings from C.InferenceSession self._sess = None - self._enable_fallback = True + self._enable_fallback = enable_fallback def get_session_options(self) -> onnxruntime.SessionOptions: "Return the session options. See :class:`onnxruntime.SessionOptions`." @@ -446,7 +446,7 @@ def __init__( means execute a node using `CUDAExecutionProvider` if capable, otherwise execute using `CPUExecutionProvider`. """ - super().__init__() + super().__init__(enable_fallback=int(kwargs.get("enable_fallback", 1)) == 1) if isinstance(path_or_bytes, (str, os.PathLike)): self._model_path = os.fspath(path_or_bytes) @@ -459,7 +459,6 @@ def __init__( self._sess_options = sess_options self._sess_options_initial = sess_options - self._enable_fallback = True if "read_config_from_model" in kwargs: self._read_config_from_model = int(kwargs["read_config_from_model"]) == 1 else: @@ -542,6 +541,16 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi providers, provider_options, available_providers ) + # Print a warning if user passed providers to InferenceSession() but the SessionOptions instance + # already has provider information (e.g., via add_provider_for_devices()). The providers specified + # here will take precedence. + if self._sess_options is not None and (providers or provider_options) and self._sess_options.has_providers(): + warnings.warn( + "Specified 'providers'/'provider_options' when creating InferenceSession but SessionOptions has " + "already been configured with providers. InferenceSession will only use the providers " + "passed to InferenceSession()." + ) + session_options = self._sess_options if self._sess_options else C.get_default_session_options() self._register_ep_custom_ops(session_options, providers, provider_options, available_providers) @@ -609,6 +618,115 @@ def _register_ep_custom_ops(self, session_options, providers, provider_options, C.register_nv_tensorrt_rtx_plugins_as_custom_ops(session_options, providers[i][1]) +class ModelCompiler: + """ + This class is used to compile an ONNX model. A compiled ONNX model has EPContext nodes that each + encapsulates a subgraph compiled/optimized for a specific execution provider. + + Refer to the EPContext design document for more information about EPContext models: + https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html + + :: + + sess_options = onnxruntime.SessionOptions() + sess_options.add_provider("SomeExecutionProvider", {"option1": "value1"}) + # Alternatively, allow ONNX Runtime to select the provider automatically given a policy: + # sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_NPU) + + model_compiler = onnxruntime.ModelCompiler(sess_options, "input_model.onnx") + model_compiler.compile_to_file("output_model.onnx") + """ + + def __init__( + self, + sess_options: onnxruntime.SessionOptions, + input_model_path_or_bytes: str | os.PathLike | bytes, + embed_compiled_data_into_model: bool = False, + external_initializers_file_path: str | os.PathLike | None = None, + external_initializers_size_threshold: int = 1024, + ): + """ + Creates a ModelCompiler instance. + + :param sess_options: Session options containing the providers for which the model will be compiled. + Refer to SessionOptions.add_provider() and SessionOptions.set_provider_selection_policy(). + :param input_model_path_or_bytes: The path to the input model file or bytes representing a serialized + ONNX model. + :param embed_compiled_data_into_model: Defaults to False. Set to True to embed compiled binary data into + EPContext nodes in the compiled model. + :param external_initializers_file_path: Defaults to None. Set to a path for a file that will store the + initializers for non-compiled nodes. + :param external_initializers_size_threshold: Defaults to 1024. Ignored if `external_initializers_file_path` + is None or empty. Initializers larger than this threshold are stored in the external initializers file. + """ + input_model_path: str | os.PathLike | None = None + input_model_bytes: bytes | None = None + if isinstance(input_model_path_or_bytes, (str, os.PathLike)): + if not input_model_path_or_bytes: + raise ValueError("Input model path is empty") + input_model_path = os.fspath(input_model_path_or_bytes) + elif isinstance(input_model_path_or_bytes, bytes): + if len(input_model_path_or_bytes) == 0: + raise ValueError("Input model bytes array is empty") + input_model_bytes = input_model_path_or_bytes + else: + raise TypeError(f"Unable to load from type '{type(input_model_path_or_bytes)}'") + + if external_initializers_file_path: + if not isinstance(external_initializers_file_path, (str, os.PathLike)): + arg_type = type(external_initializers_file_path) + raise TypeError(f"Output external initializer filepath is of unexpected type '{arg_type}'") + external_initializers_file_path = os.fspath(external_initializers_file_path) + else: + external_initializers_file_path = "" + + if input_model_path: + self._model_compiler = C.ModelCompiler( + sess_options, + input_model_path, + True, # is path + embed_compiled_data_into_model, + external_initializers_file_path, + external_initializers_size_threshold, + ) + else: + self._model_compiler = C.ModelCompiler( + sess_options, + input_model_bytes, + False, # is bytes + embed_compiled_data_into_model, + external_initializers_file_path, + external_initializers_size_threshold, + ) + + def compile_to_file(self, output_model_path: str | None = None): + """ + Compiles to an output file. If an output file path is not provided, + the output file path is generated based on the input model path by replacing + '.onnx' with '_ctx.onnx'. Ex: The generated output file is 'model_ctx.onnx' for + an input model with path 'model.onnx'. + + Raises an 'InvalidArgument' exception if the compilation options are invalid. + + :param output_model_path: Defaults to None. The path for the output/compiled model. + """ + if output_model_path: + if not isinstance(output_model_path, (str, os.PathLike)): + raise TypeError(f"Output model's filepath is of unexpected type '{type(output_model_path)}'") + output_model_path = os.fspath(output_model_path) + self._model_compiler.compile_to_file(output_model_path) + + def compile_to_bytes(self) -> bytes: + """ + Compiles to bytes representing the serialized compiled ONNX model. + + Raises an 'InvalidArgument' exception if the compilation options are invalid. + + :return: A bytes object representing the compiled ONNX model. + """ + return self._model_compiler.compile_to_bytes() + + class IOBinding: """ This class provides API to bind input/output to a specified device, e.g. GPU. @@ -888,6 +1006,13 @@ def element_type(self) -> int: """ return self._ortvalue.element_type() + def tensor_size_in_bytes(self) -> int: + """ + Returns the size of the data in the OrtValue in bytes + if the OrtValue is a tensor. + """ + return self._ortvalue.tensor_size_in_bytes() + def has_value(self) -> bool: """ Returns True if the OrtValue corresponding to an diff --git a/onnxruntime/python/onnxruntime_pybind_exceptions.cc b/onnxruntime/python/onnxruntime_pybind_exceptions.cc index 1a1aae6f48ad1..8f3b97c8c7786 100644 --- a/onnxruntime/python/onnxruntime_pybind_exceptions.cc +++ b/onnxruntime/python/onnxruntime_pybind_exceptions.cc @@ -35,6 +35,8 @@ void RegisterExceptions(pybind11::module& m) { pybind11::register_exception(m, "NotImplemented"); pybind11::register_exception(m, "InvalidGraph"); pybind11::register_exception(m, "EPFail"); + pybind11::register_exception(m, "ModelLoadCanceled"); + pybind11::register_exception(m, "ModelRequiresCompilation"); } void OrtPybindThrowIfError(onnxruntime::common::Status status) { @@ -61,6 +63,10 @@ void OrtPybindThrowIfError(onnxruntime::common::Status status) { throw InvalidGraph(std::move(msg)); case onnxruntime::common::StatusCode::EP_FAIL: throw EPFail(std::move(msg)); + case onnxruntime::common::StatusCode::MODEL_LOAD_CANCELED: + throw ModelLoadCanceled(std::move(msg)); + case onnxruntime::common::StatusCode::MODEL_REQUIRES_COMPILATION: + throw ModelRequiresCompilation(std::move(msg)); default: throw std::runtime_error(std::move(msg)); } diff --git a/onnxruntime/python/onnxruntime_pybind_exceptions.h b/onnxruntime/python/onnxruntime_pybind_exceptions.h index bc7d31ff2be2d..86bc4a5da8d46 100644 --- a/onnxruntime/python/onnxruntime_pybind_exceptions.h +++ b/onnxruntime/python/onnxruntime_pybind_exceptions.h @@ -44,6 +44,12 @@ struct InvalidGraph : std::runtime_error { struct EPFail : std::runtime_error { explicit EPFail(const std::string& what) : std::runtime_error(what) {} }; +struct ModelLoadCanceled : std::runtime_error { + explicit ModelLoadCanceled(const std::string& what) : std::runtime_error(what) {} +}; +struct ModelRequiresCompilation : std::runtime_error { + explicit ModelRequiresCompilation(const std::string& what) : std::runtime_error(what) {} +}; void RegisterExceptions(pybind11::module& m); diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc new file mode 100644 index 0000000000000..8bb7ee2098caf --- /dev/null +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// Licensed under the MIT License. +#include "python/onnxruntime_pybind_model_compiler.h" + +#include +#include +#include +#include "core/common/common.h" +#include "core/framework/error_code_helper.h" +#include "core/session/utils.h" + +namespace onnxruntime { +namespace python { + +onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr& out, + std::shared_ptr env, + const PySessionOptions& sess_options, + std::string&& input_model_path_or_bytes, bool input_model_is_path, + bool embed_compiled_data_into_model, + const std::string& external_initializers_file_path, + size_t external_initializers_size_threshold) { + auto model_compiler = std::make_unique(env, sess_options, PrivateConstructorTag{}); + ModelCompilationOptions& compile_options = model_compiler->model_compile_options_; + + if (input_model_is_path) { + compile_options.SetInputModelPath(input_model_path_or_bytes); + } else { + model_compiler->input_model_bytes_ = std::move(input_model_path_or_bytes); + compile_options.SetInputModelFromBuffer(reinterpret_cast(model_compiler->input_model_bytes_.data()), + model_compiler->input_model_bytes_.size()); + } + + ORT_RETURN_IF_ERROR(compile_options.SetEpContextEmbedMode(embed_compiled_data_into_model)); + + if (!external_initializers_file_path.empty()) { + compile_options.SetOutputModelExternalInitializersFile(external_initializers_file_path, + external_initializers_size_threshold); + } + + out = std::move(model_compiler); + return Status::OK(); +} + +onnxruntime::Status PyModelCompiler::CompileToFile(const std::string& output_model_path) { + ORT_RETURN_IF_ERROR(model_compile_options_.SetOutputModelPath(output_model_path)); + ORT_RETURN_IF_ERROR(onnxruntime::CompileModel(*env_, model_compile_options_)); + return Status::OK(); +} + +onnxruntime::Status PyModelCompiler::CompileToBytes(std::string& output_buffer) { + if (!output_buffer.empty()) { + // Opt to return an error if the output buffer is not empty instead of just calling output_buffer.clear() + // because the C++ standard does not explicitly require that capacity is unchanged by a call to clear(). + // Don't want to reallocate a large buffer an extra time unnecessarily. So, we'll consider this an internal + // ORT error. + // Refer to: https://en.cppreference.com/w/cpp/string/basic_string/clear + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output buffer should be empty."); + } + + onnxruntime::AllocatorPtr allocator = std::make_shared(); + + void* buffer_data = nullptr; + size_t buffer_size = 0; + ORT_RETURN_IF_ERROR(model_compile_options_.SetOutputModelBuffer(allocator, &buffer_data, &buffer_size)); + ORT_RETURN_IF_ERROR(onnxruntime::CompileModel(*env_, model_compile_options_)); + + // Copy into output buffer. + output_buffer.reserve(buffer_size); + gsl::span src(reinterpret_cast(buffer_data), buffer_size); + std::copy(src.begin(), src.end(), std::back_inserter(output_buffer)); + return Status::OK(); +} + +PyModelCompiler::PyModelCompiler(std::shared_ptr env, const PySessionOptions& sess_options, + PrivateConstructorTag) + : env_(env), model_compile_options_(*env, sess_options) { +} +} // namespace python +} // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.h b/onnxruntime/python/onnxruntime_pybind_model_compiler.h new file mode 100644 index 0000000000000..6c9f48fa00ba6 --- /dev/null +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.h @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// Licensed under the MIT License. +#pragma once + +#if !defined(ORT_MINIMAL_BUILD) +#include +#include +#include "core/common/status.h" +#include "core/session/model_compilation_options.h" +#include "python/onnxruntime_pybind_state_common.h" + +namespace onnxruntime { +class Environment; + +namespace python { +/// +/// Class exposed to Python that enables compiling ONNX models. +/// Internally wraps a onnxruntime::ModelCompilationOptions that stores and validates settings. +/// +class PyModelCompiler { + private: + // private tag to pass to constructor to ensure that constructor cannot be directly called externally + struct PrivateConstructorTag {}; + + public: + /// + /// Static class function that creates a unique_ptr with the given settings. + /// + /// Output parameter for the result + /// The Environment instance + /// The SessionOptions from which to initialize compilation options. + /// An r-value string that could be the input model's path or bytes + /// True if 'input_model_path_or_bytes' is a path, and false if its bytes. + /// True to embed compiled binary data into EPContext nodes. + /// The file into which to store initializers for non-compiled + /// nodes. + /// Ignored if 'external_initializers_file_path' is empty. + /// Initializers with a size greater than this threshold are dumped into the external file. + /// A Status indicating error or success. + static onnxruntime::Status Create(/*out*/ std::unique_ptr& out, + std::shared_ptr env, + const PySessionOptions& sess_options, + std::string&& input_model_path_or_bytes, bool input_model_is_path, + bool embed_compiled_data_into_model = false, + const std::string& external_initializers_file_path = {}, + size_t external_initializers_size_threshold = 1024); + + // Note: Creation should be done via Create(). This constructor is public so that it can be called from + // std::make_shared(). + PyModelCompiler(std::shared_ptr env, const PySessionOptions& sess_options, + PrivateConstructorTag); + + /// + /// Compiles the input model and saves the result to an output file. + /// If the 'output_model_path' is not specified, + /// it is generated based on the input model's path by replacing '.onnx' with '_ctx.onnx'. + /// + /// The path into which to save the compiled model. + /// A Status indicating error or success. + onnxruntime::Status CompileToFile(const std::string& output_model_path = {}); + + /// + /// Compiles the input model and stores the result into a buffer. + /// + /// A reference to the output buffer into which to store the + /// serialized ONNX model bytes. + /// A Status indicating error or success. + onnxruntime::Status CompileToBytes(std::string& output_buffer); + + private: + std::shared_ptr env_; + onnxruntime::ModelCompilationOptions model_compile_options_; + std::string input_model_bytes_; +}; +} // namespace python +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index 66ceacda75e6d..382cd742c96aa 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -327,6 +327,9 @@ void addOrtValueMethods(pybind11::module& m) { "This integer is one type defined by ONNX TensorProto_DataType " "(such as onnx.TensorProto.FLOAT)." "Raises an exception in any other case.") + .def("tensor_size_in_bytes", [](const OrtValue* ort_value) -> size_t { + ORT_ENFORCE(ort_value->IsTensor(), "Only OrtValues that are Tensors are currently supported"); + return ort_value->Get().SizeInBytes(); }, "Returns tensor size in bytes.") .def("has_value", [](const OrtValue* ort_value) -> bool { return ort_value->IsAllocated(); }) .def("is_tensor", [](const OrtValue* ort_value) -> bool { return ort_value->IsTensor(); }) .def("is_sparse_tensor", [](const OrtValue* ort_value) -> bool { return ort_value->IsSparseTensor(); }) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 0f15c5fbbdba0..aa2c0cc6a0f86 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -2,10 +2,15 @@ // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. +#include #include "python/onnxruntime_pybind_exceptions.h" #include "python/onnxruntime_pybind_mlvalue.h" #include "python/onnxruntime_pybind_state_common.h" +#if !defined(ORT_MINIMAL_BUILD) +#include "python/onnxruntime_pybind_model_compiler.h" +#endif // !defined(ORT_MINIMAL_BUILD) + #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API #include "python/numpy_helper.h" @@ -18,6 +23,7 @@ #include "core/framework/arena_extend_strategy.h" #include "core/framework/data_transfer_utils.h" #include "core/framework/data_types_internal.h" +#include "core/framework/error_code_helper.h" #include "core/framework/provider_options_utils.h" #include "core/framework/random_seed.h" #include "core/framework/sparse_tensor.h" @@ -26,6 +32,7 @@ #include "core/graph/graph_viewer.h" #include "core/platform/env.h" #include "core/providers/get_execution_providers.h" +#include "core/providers/providers.h" #include "core/providers/tensorrt/tensorrt_provider_options.h" #include "core/session/IOBinding.h" #include "core/session/abi_session_options_impl.h" @@ -34,6 +41,13 @@ #include "core/session/lora_adapters.h" +#if !defined(ORT_MINIMAL_BUILD) +#include "core/session/abi_devices.h" +#include "core/session/ep_factory_internal.h" +#include "core/session/provider_policy_context.h" +#include "core/session/utils.h" +#endif + #ifdef ENABLE_ATEN #include "contrib_ops/cpu/aten_ops/aten_op_executor.h" #endif @@ -402,7 +416,7 @@ py::object AddTensorAsPyObj(const OrtValue& val, const DataTransferManager* data return GetPyObjFromTensor(val, data_transfer_manager, mem_cpy_to_host_functions); } -static std::unique_ptr LoadExecutionProvider( +static std::shared_ptr LoadExecutionProviderFactory( const std::string& ep_shared_lib_path, const ProviderOptions& provider_options = {}, const std::string& entry_symbol_name = "GetProvider") { @@ -417,8 +431,7 @@ static std::unique_ptr LoadExecutionProvider( OrtPybindThrowIfError(Env::Default().GetSymbolFromLibrary(handle, entry_symbol_name, (void**)&PGetProvider)); Provider* provider = PGetProvider(); - std::shared_ptr ep_factory = provider->CreateExecutionProviderFactory(&provider_options); - return ep_factory->CreateProvider(); + return provider->CreateExecutionProviderFactory(&provider_options); } #if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) @@ -539,14 +552,21 @@ void RegisterNvTensorRTRtxPluginsAsCustomOps(PySessionOptions& so, const Provide } #endif -std::unique_ptr CreateExecutionProviderInstance( +/** + * Creates an IExecutionProviderFactory instance of the specified type. + * @param session_options The session options. + * @param type The execution provider type (e.g., CUDAExecutionProvider). + * @param provider_options_map A map of provider options. + * + * @return A shared_ptr with the factory instance, or null if unable to create it. + */ +static std::shared_ptr CreateExecutionProviderFactoryInstance( const SessionOptions& session_options, const std::string& type, const ProviderOptionsMap& provider_options_map) { if (type == kCpuExecutionProvider) { return onnxruntime::CPUProviderFactoryCreator::Create( - session_options.enable_cpu_mem_arena) - ->CreateProvider(); + session_options.enable_cpu_mem_arena); } else if (type == kTensorrtExecutionProvider) { #if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) // If the environment variable 'ORT_TENSORRT_UNAVAILABLE' exists, then we do not load TensorRT. This is set by _ld_preload for the manylinux case @@ -869,11 +889,11 @@ std::unique_ptr CreateExecutionProviderInstance( } } if (std::shared_ptr tensorrt_provider_factory = onnxruntime::TensorrtProviderFactoryCreator::Create(¶ms)) { - return tensorrt_provider_factory->CreateProvider(); + return tensorrt_provider_factory; } } else { if (std::shared_ptr tensorrt_provider_factory = onnxruntime::TensorrtProviderFactoryCreator::Create(cuda_device_id)) { - return tensorrt_provider_factory->CreateProvider(); + return tensorrt_provider_factory; } } } @@ -892,11 +912,11 @@ std::unique_ptr CreateExecutionProviderInstance( ProviderOptions info = it->second; if (std::shared_ptr nv_tensorrt_rtx_provider_factory = onnxruntime::NvProviderFactoryCreator::Create( info, &session_options)) { - return nv_tensorrt_rtx_provider_factory->CreateProvider(); + return nv_tensorrt_rtx_provider_factory; } } else { if (std::shared_ptr nv_tensorrt_rtx_provider_factory = onnxruntime::NvProviderFactoryCreator::Create(cuda_device_id)) { - return nv_tensorrt_rtx_provider_factory->CreateProvider(); + return nv_tensorrt_rtx_provider_factory; } } } @@ -1024,12 +1044,12 @@ std::unique_ptr CreateExecutionProviderInstance( } if (std::shared_ptr migraphx_provider_factory = onnxruntime::MIGraphXProviderFactoryCreator::Create(¶ms)) { - return migraphx_provider_factory->CreateProvider(); + return migraphx_provider_factory; } } else { if (std::shared_ptr migraphx_provider_factory = onnxruntime::MIGraphXProviderFactoryCreator::Create(cuda_device_id)) { - return migraphx_provider_factory->CreateProvider(); + return migraphx_provider_factory; } } #endif @@ -1048,7 +1068,7 @@ std::unique_ptr CreateExecutionProviderInstance( // hence we must try to initialize it here if we can since FromProviderOptions might contain // external CUDA allocator. external_allocator_info = info.external_allocator_info; - return cuda_provider_info->CreateExecutionProviderFactory(info)->CreateProvider(); + return cuda_provider_info->CreateExecutionProviderFactory(info); } } #if defined(USE_CUDA) @@ -1081,7 +1101,7 @@ std::unique_ptr CreateExecutionProviderInstance( // however they still exist and are in-use. Nevertheless, it is used to return ROCMAllocator, hence we must // try to initialize it here if we can since FromProviderOptions might contain external ROCM allocator. external_allocator_info = info.external_allocator_info; - return rocm_provider_info->CreateExecutionProviderFactory(info)->CreateProvider(); + return rocm_provider_info->CreateExecutionProviderFactory(info); } else { if (!Env::Default().GetEnvironmentVar("ROCM_PATH").empty()) { ORT_THROW( @@ -1118,7 +1138,7 @@ std::unique_ptr CreateExecutionProviderInstance( #endif // !defined(DNNL_ORT_THREAD) dnnl_options.use_arena = session_options.enable_cpu_mem_arena; - return onnxruntime::DnnlProviderFactoryCreator::Create(&dnnl_options)->CreateProvider(); + return onnxruntime::DnnlProviderFactoryCreator::Create(&dnnl_options); #endif } else if (type == kOpenVINOExecutionProvider) { #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) @@ -1189,10 +1209,9 @@ std::unique_ptr CreateExecutionProviderInstance( } if (std::shared_ptr openvino_provider_factory = onnxruntime::OpenVINOProviderFactoryCreator::Create( &OV_provider_options_map, &session_options)) { - auto p = openvino_provider_factory->CreateProvider(); // Reset global variables config to avoid it being accidentally passed on to the next session openvino_device_type.clear(); - return p; + return openvino_provider_factory; } else { if (!Env::Default().GetEnvironmentVar("INTEL_OPENVINO_DIR").empty()) { ORT_THROW("INTEL_OPENVINO_DIR is set but OpenVINO library wasn't able to be loaded. Please install a supported version of OpenVINO as mentioned in the requirements page (https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements), ensure dependency libraries are in the PATH and your hardware is supported."); @@ -1210,7 +1229,7 @@ std::unique_ptr CreateExecutionProviderInstance( } info["session_options"] = std::to_string((uintptr_t)(void*)&session_options); if (auto vitisai_factory = onnxruntime::VitisAIProviderFactoryCreator::Create(info); vitisai_factory) { - return vitisai_factory->CreateProvider(); + return vitisai_factory; } LOGS_DEFAULT(WARNING) << "Failed to create " << type @@ -1238,21 +1257,18 @@ std::unique_ptr CreateExecutionProviderInstance( } } } - return onnxruntime::ACLProviderFactoryCreator::Create(enable_fast_math) - ->CreateProvider(); + return onnxruntime::ACLProviderFactoryCreator::Create(enable_fast_math); #endif } else if (type == kArmNNExecutionProvider) { #ifdef USE_ARMNN return onnxruntime::ArmNNProviderFactoryCreator::Create( - session_options.enable_cpu_mem_arena) - ->CreateProvider(); + session_options.enable_cpu_mem_arena); #endif } else if (type == kDmlExecutionProvider) { #ifdef USE_DML auto cit = provider_options_map.find(type); return onnxruntime::DMLProviderFactoryCreator::CreateFromProviderOptions( - session_options.config_options, cit == provider_options_map.end() ? ProviderOptions{} : cit->second, true) - ->CreateProvider(); + session_options.config_options, cit == provider_options_map.end() ? ProviderOptions{} : cit->second, true); #endif } else if (type == kNnapiExecutionProvider) { #if defined(USE_NNAPI) @@ -1261,15 +1277,15 @@ std::unique_ptr CreateExecutionProviderInstance( #endif const auto partitioning_stop_ops_list = session_options.config_options.GetConfigEntry( kOrtSessionOptionsConfigNnapiEpPartitioningStopOps); - return onnxruntime::NnapiProviderFactoryCreator::Create(0, partitioning_stop_ops_list)->CreateProvider(); + return onnxruntime::NnapiProviderFactoryCreator::Create(0, partitioning_stop_ops_list); #endif } else if (type == kVSINPUExecutionProvider) { #ifdef USE_VSINPU - return onnxruntime::VSINPUProviderFactoryCreator::Create()->CreateProvider(); + return onnxruntime::VSINPUProviderFactoryCreator::Create(); #endif } else if (type == kRknpuExecutionProvider) { #ifdef USE_RKNPU - return onnxruntime::RknpuProviderFactoryCreator::Create()->CreateProvider(); + return onnxruntime::RknpuProviderFactoryCreator::Create(); #endif } else if (type == kCoreMLExecutionProvider) { #if defined(USE_COREML) @@ -1300,36 +1316,35 @@ std::unique_ptr CreateExecutionProviderInstance( } } else { // read from provider_options - return onnxruntime::CoreMLProviderFactoryCreator::Create(options)->CreateProvider(); + return onnxruntime::CoreMLProviderFactoryCreator::Create(options); } } - return onnxruntime::CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider(); + return onnxruntime::CoreMLProviderFactoryCreator::Create(coreml_flags); #endif } else if (type == kXnnpackExecutionProvider) { #if defined(USE_XNNPACK) auto cit = provider_options_map.find(type); return onnxruntime::XnnpackProviderFactoryCreator::Create( - cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options) - ->CreateProvider(); + cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options); #endif } else if (type == kWebGpuExecutionProvider) { #if defined(USE_WEBGPU) - return onnxruntime::WebGpuProviderFactoryCreator::Create(session_options.config_options)->CreateProvider(); + return onnxruntime::WebGpuProviderFactoryCreator::Create(session_options.config_options); #endif } else if (type == kCannExecutionProvider) { #ifdef USE_CANN if (auto* cann_provider_info = TryGetProviderInfo_CANN()) { const CANNExecutionProviderInfo info = GetCannExecutionProviderInfo(cann_provider_info, provider_options_map); - return cann_provider_info->CreateExecutionProviderFactory(info)->CreateProvider(); + return cann_provider_info->CreateExecutionProviderFactory(info); } else { ORT_THROW("create CANN ExecutionProvider fail"); } #endif } else if (type == kAzureExecutionProvider) { #ifdef USE_AZURE - return onnxruntime::AzureProviderFactoryCreator::Create({})->CreateProvider(); + return onnxruntime::AzureProviderFactoryCreator::Create({}); #endif } else if (type == kQnnExecutionProvider) { #if defined(USE_QNN) || defined(USE_QNN_PROVIDER_INTERFACE) @@ -1337,7 +1352,7 @@ std::unique_ptr CreateExecutionProviderInstance( auto qnn_factory = onnxruntime::QNNProviderFactoryCreator::Create( cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options); if (qnn_factory) { - return qnn_factory->CreateProvider(); + return qnn_factory; } LOGS_DEFAULT(WARNING) << "Failed to create " << type @@ -1362,7 +1377,7 @@ std::unique_ptr CreateExecutionProviderInstance( provider_options.insert(option); } } - return LoadExecutionProvider(shared_lib_path_it->second, provider_options, entry_symbol); + return LoadExecutionProviderFactory(shared_lib_path_it->second, provider_options, entry_symbol); } } // unknown provider @@ -1371,20 +1386,58 @@ std::unique_ptr CreateExecutionProviderInstance( return nullptr; } +/** + * Create an IExecutionProvider instance of the specified type. Note: this is called by orttraining code. + * @param session_options The session options. + * @param type The execution provider type (e.g., CUDAExecutionProvider). + * @param provider_options_map A map of provider options. + * + * @return A unique_ptr with the execution provider instance, or null if unable to create it. + */ +std::unique_ptr CreateExecutionProviderInstance(const SessionOptions& session_options, + const std::string& type, + const ProviderOptionsMap& provider_options_map) { + auto ep_factory = CreateExecutionProviderFactoryInstance(session_options, type, provider_options_map); + if (ep_factory) { + return ep_factory->CreateProvider(); + } + return nullptr; +} + /* * Register execution provider with options. */ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector& provider_types, const ProviderOptionsMap& provider_options_map) { - ORT_UNUSED_PARAMETER(provider_options_map); - for (const std::string& type : provider_types) { auto ep = CreateExecutionProviderInstance(sess->GetSessionOptions(), type, provider_options_map); - if (ep) + if (ep) { OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(ep))); + } } } +/** + * Adds an explicit execution provider factory to the session options. + * + * @param py_sess_options The session options. + * @param provider_type The type of the provider to add. + * @param provider_options The options for the execution provider as a map of string key/value pairs. + * + * @return A Status indicating an error or success. + */ +static Status AddExplicitEpFactory(PySessionOptions& py_sess_options, const std::string& provider_type, + const ProviderOptions& provider_options) { + const ProviderOptionsMap provider_options_map = {{provider_type, provider_options}}; + auto ep_factory = CreateExecutionProviderFactoryInstance(py_sess_options.value, provider_type, provider_options_map); + if (!ep_factory) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to add provider of type '", + provider_type, "' to SessionOptions. Provider configuration is not supported."); + } + py_sess_options.provider_factories.push_back(std::move(ep_factory)); + return Status::OK(); +} + /** * Generate a map for mapping execution provider to excution provider options. * @@ -1426,6 +1479,83 @@ static void RegisterCustomOpDomains(PyInferenceSession* sess, const PySessionOpt } #endif +#if !defined(ORT_MINIMAL_BUILD) +/** + * Add the execution provider that is responsible for the selected OrtEpDevice instances to the session options. + * + * @param py_sess_options The session options. + * @param provider_type The type of the provider to add. + * @param provider_options The options for the execution provider as a map of string key/value pairs. + * + * @return A Status indicating an error or success. + */ +static Status AddEpFactoryFromEpDevices(PySessionOptions& py_sess_options, + const std::vector& ep_devices, + const ProviderOptions& provider_options) { + std::shared_ptr env = GetEnv(); + const size_t num_ep_options = provider_options.size(); + std::vector ep_option_keys; + std::vector ep_option_vals; + + ep_option_keys.reserve(num_ep_options); + ep_option_vals.reserve(num_ep_options); + for (const auto& [key, val] : provider_options) { + ep_option_keys.push_back(key.c_str()); + ep_option_vals.push_back(val.c_str()); + } + + std::unique_ptr provider_factory = nullptr; + ORT_RETURN_IF_ERROR(CreateIExecutionProviderFactoryForEpDevices(*env, + py_sess_options.value, + ep_devices, + ep_option_keys, + ep_option_vals, + /*output*/ provider_factory)); + py_sess_options.provider_factories.push_back(std::move(provider_factory)); + return Status::OK(); +} + +/** + * Initializes the inference session using EPs specified in the session options. + * + * @param py_sess The inference session. + * @param disabled_optimizer_names Set of optimizers to disable. + * @return A Status indicating error or success. + */ +static Status InitializeSessionEpsFromSessionOptions(PyInferenceSession& py_sess, + const std::unordered_set& disabled_optimizer_names) { + ORT_RETURN_IF(py_sess.GetSessionHandle() == nullptr, "Invalid Python InferenceSession handle"); + InferenceSession& sess = *py_sess.GetSessionHandle(); + + const logging::Logger* sess_logger = sess.GetLogger(); + ORT_RETURN_IF(sess_logger == nullptr, "Invalid InferenceSession logger handle"); + + const OrtSessionOptions& ort_session_options = py_sess.GetOrtSessionOptions(); + + // if there are no providers registered, and there's an ep selection policy set, do auto ep selection + if (ort_session_options.provider_factories.empty() && ort_session_options.value.ep_selection_policy.enable) { + ProviderPolicyContext context; + ORT_RETURN_IF_ERROR(context.SelectEpsForSession(*GetEnv(), ort_session_options, sess)); + } else { + for (const auto& provider_factory : ort_session_options.provider_factories) { + std::unique_ptr ep = provider_factory->CreateProvider(ort_session_options, + *(sess_logger->ToExternal())); + if (ep) { + ORT_RETURN_IF_ERROR(sess.RegisterExecutionProvider(std::move(ep))); + } + } + } + + if (!disabled_optimizer_names.empty()) { + ORT_RETURN_IF_ERROR(sess.FilterEnabledOptimizers({disabled_optimizer_names.cbegin(), + disabled_optimizer_names.cend()})); + } + + ORT_RETURN_IF_ERROR(sess.Initialize()); + return Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) + void InitializeSession(InferenceSession* sess, ExecutionProviderRegistrationFn ep_registration_fn, const std::vector& provider_types, @@ -1528,6 +1658,43 @@ void addGlobalMethods(py::module& m) { throw std::runtime_error("Error when creating and registering allocator in create_and_register_allocator_v2: " + st.ErrorMessage()); } }); + m.def( + "register_execution_provider_library", + [](const std::string& registration_name, const PathString& library_path) -> void { +#if !defined(ORT_MINIMAL_BUILD) + std::shared_ptr env = GetEnv(); + OrtPybindThrowIfError(env->RegisterExecutionProviderLibrary(registration_name, library_path.c_str())); +#else + ORT_UNUSED_PARAMETER(registration_name); + ORT_UNUSED_PARAMETER(library_path); + ORT_THROW("Execution provider libraries are not supported in this build."); +#endif + }, + R"pbdoc(Register an execution provider library with ONNX Runtime.)pbdoc"); + m.def( + "unregister_execution_provider_library", + [](const std::string& registration_name) -> void { +#if !defined(ORT_MINIMAL_BUILD) + std::shared_ptr env = GetEnv(); + OrtPybindThrowIfError(env->UnregisterExecutionProviderLibrary(registration_name)); +#else + ORT_UNUSED_PARAMETER(registration_name); + ORT_THROW("Execution provider libraries are not supported in this build."); +#endif + }, + R"pbdoc(Unregister an execution provider library from ONNX Runtime.)pbdoc"); + m.def( + "get_ep_devices", + []() -> const std::vector& { +#if !defined(ORT_MINIMAL_BUILD) + std::shared_ptr env = GetEnv(); + return env->GetOrtEpDevices(); +#else + ORT_THROW("OrtEpDevices are not supported in this build"); +#endif + }, + R"pbdoc(Get the list of available OrtEpDevice instances.)pbdoc", + py::return_value_policy::reference); #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) m.def( @@ -1632,6 +1799,70 @@ void addGlobalMethods(py::module& m) { #endif } +#if !defined(ORT_MINIMAL_BUILD) +/** + * Calls the user's Python EP selection function and converts the results to a format that can be used + * by ORT to select OrtEpDevice instances. The user's function is set by calling + * SessionOptions.set_provider_selection_policy_delegate() on the Python side. The result of this wrapper + * function is used in core/session/provider_policy_context.cc. + * + * @param ep_devices OrtEpDevices to select from. + * @param num_devices Number of OrtEpDevices to select from. + * @param model_metadata Model's metadata. + * @param runtime_metadata Runtime metadata. + * @param selected Pre-allocated OrtEpDevice buffer to update with selected devices. + * @param max_selected Maximum number of entries in the pre-allocated 'selected' buffer. + * @param state Opaque state that holds a pointer to the user's Python function. + * + * @return nullptr OrtStatus* to indicate success. + */ +static OrtStatus* ORT_API_CALL PyEpSelectionPolicyWrapper(_In_ const OrtEpDevice** ep_devices, + _In_ size_t num_devices, + _In_ const OrtKeyValuePairs* model_metadata, + _In_opt_ const OrtKeyValuePairs* runtime_metadata, + _Inout_ const OrtEpDevice** selected, + _In_ size_t max_selected, + _Out_ size_t* num_selected, + _In_ void* state) { + PyEpSelectionDelegate* actual_delegate = reinterpret_cast(state); + std::vector py_ep_devices(ep_devices, ep_devices + num_devices); + std::unordered_map py_model_metadata = + model_metadata ? model_metadata->entries : std::unordered_map{}; + std::unordered_map py_runtime_metadata = + runtime_metadata ? runtime_metadata->entries : std::unordered_map{}; + + *num_selected = 0; + std::vector py_selected; + OrtStatus* status = nullptr; + + // Call the Python delegate function and convert any exceptions to a status. + ORT_TRY { + py_selected = (*actual_delegate)(py_ep_devices, py_model_metadata, py_runtime_metadata, max_selected); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what())); + }); + } + + if (status != nullptr) { + return status; + } + + if (py_selected.size() > max_selected) { + return ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "selected too many EP devices (", py_selected.size(), "). ", + "The limit is ", max_selected, " EP devices.")); + } + + *num_selected = py_selected.size(); + for (size_t i = 0; i < py_selected.size(); ++i) { + selected[i] = py_selected[i]; + } + + return nullptr; +} +#endif // !defined(ORT_MINIMAL_BUILD) + void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) { py::enum_(m, "GraphOptimizationLevel") .value("ORT_DISABLE_ALL", GraphOptimizationLevel::ORT_DISABLE_ALL) @@ -1672,6 +1903,75 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .def_static("webgpu", []() { return OrtDevice::GPU; }) .def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; }); + py::enum_(m, "OrtExecutionProviderDevicePolicy") + .value("DEFAULT", OrtExecutionProviderDevicePolicy_DEFAULT) + .value("PREFER_CPU", OrtExecutionProviderDevicePolicy_PREFER_CPU) + .value("PREFER_NPU", OrtExecutionProviderDevicePolicy_PREFER_NPU) + .value("PREFER_GPU", OrtExecutionProviderDevicePolicy_PREFER_GPU) + .value("MAX_PERFORMANCE", OrtExecutionProviderDevicePolicy_MAX_PERFORMANCE) + .value("MAX_EFFICIENCY", OrtExecutionProviderDevicePolicy_MAX_EFFICIENCY) + .value("MIN_OVERALL_POWER", OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER); + + py::enum_(m, "OrtHardwareDeviceType") + .value("CPU", OrtHardwareDeviceType_CPU) + .value("GPU", OrtHardwareDeviceType_GPU) + .value("NPU", OrtHardwareDeviceType_NPU); + + py::class_ py_hw_device(m, "OrtHardwareDevice", R"pbdoc(ONNX Runtime hardware device information.)pbdoc"); + py_hw_device.def_property_readonly( + "type", + [](OrtHardwareDevice* hw_device) -> OrtHardwareDeviceType { return hw_device->type; }, + R"pbdoc(Hardware device's type.)pbdoc") + .def_property_readonly( + "vendor_id", + [](OrtHardwareDevice* hw_device) -> uint32_t { return hw_device->vendor_id; }, + R"pbdoc(Hardware device's vendor identifier.)pbdoc") + .def_property_readonly( + "vendor", + [](OrtHardwareDevice* hw_device) -> std::string { return hw_device->vendor; }, + R"pbdoc(Hardware device's vendor name.)pbdoc") + .def_property_readonly( + "device_id", + [](OrtHardwareDevice* hw_device) -> uint32_t { return hw_device->device_id; }, + R"pbdoc(Hardware device's unique identifier.)pbdoc") + .def_property_readonly( + "metadata", + [](OrtHardwareDevice* hw_device) -> std::unordered_map { + return hw_device->metadata.entries; + }, + R"pbdoc(Hardware device's metadata as string key/value pairs.)pbdoc"); + + py::class_ py_ep_device(m, "OrtEpDevice", + R"pbdoc(Represents a hardware device that an execution provider supports +for model inference.)pbdoc"); + py_ep_device.def_property_readonly( + "ep_name", + [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; }, + R"pbdoc(The execution provider's name.)pbdoc") + .def_property_readonly( + "ep_vendor", + [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; }, + R"pbdoc(The execution provider's vendor name.)pbdoc") + .def_property_readonly( + "ep_metadata", + [](OrtEpDevice* ep_device) -> std::unordered_map { + return ep_device->ep_metadata.entries; + }, + R"pbdoc(The execution provider's additional metadata for the OrtHardwareDevice.)pbdoc") + .def_property_readonly( + "ep_options", + [](OrtEpDevice* ep_device) -> std::unordered_map { + return ep_device->ep_options.entries; + }, + R"pbdoc(The execution provider's options used to configure the provider to use the OrtHardwareDevice.)pbdoc") + .def_property_readonly( + "device", + [](OrtEpDevice* ep_device) -> const OrtHardwareDevice& { + return *ep_device->device; + }, + R"pbdoc(The OrtHardwareDevice instance for the OrtEpDevice.)pbdoc", + py::return_value_policy::reference_internal); + py::class_ ort_arena_cfg_binding(m, "OrtArenaCfg"); // Note: Doesn't expose initial_growth_chunk_sizes_bytes/max_power_of_two_extend_bytes option. // This constructor kept for backwards compatibility, key-value pair constructor overload exposes all options @@ -1736,6 +2036,85 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc"); sess .def(py::init()) + .def( + // Equivalent to the C API's SessionOptionsAppendExecutionProvider. + "add_provider", + [](PySessionOptions* sess_options, + const std::string& provider_name, + const ProviderOptions& provider_options = {}) { + OrtPybindThrowIfError(AddExplicitEpFactory(*sess_options, provider_name, provider_options)); + }, + R"pbdoc(Adds an explicit execution provider.)pbdoc") + .def( + // Equivalent to the C API's SessionOptionsAppendExecutionProvider_V2. + "add_provider_for_devices", + [](PySessionOptions* sess_options, + const std::vector& ep_devices, + const ProviderOptions& provider_options = {}) { +#if !defined(ORT_MINIMAL_BUILD) + OrtPybindThrowIfError(AddEpFactoryFromEpDevices(*sess_options, + ep_devices, + provider_options)); +#else + ORT_UNUSED_PARAMETER(sess_options); + ORT_UNUSED_PARAMETER(ep_devices); + ORT_UNUSED_PARAMETER(provider_options); + ORT_THROW("OrtEpDevices are not supported in this build"); +#endif + }, + R"pbdoc(Adds the execution provider that is responsible for the selected OrtEpDevice instances. All OrtEpDevice instances +must refer to the same execution provider.)pbdoc") + .def( + // Equivalent to the C API's SessionOptionsSetEpSelectionPolicy. + "set_provider_selection_policy", + [](PySessionOptions* py_sess_options, + OrtExecutionProviderDevicePolicy policy) { +#if !defined(ORT_MINIMAL_BUILD) + py_sess_options->py_ep_selection_delegate = nullptr; + + py_sess_options->value.ep_selection_policy.enable = true; + py_sess_options->value.ep_selection_policy.policy = policy; + py_sess_options->value.ep_selection_policy.delegate = nullptr; + py_sess_options->value.ep_selection_policy.state = nullptr; +#else + ORT_UNUSED_PARAMETER(py_sess_options); + ORT_UNUSED_PARAMETER(policy); + ORT_THROW("EP selection policies are not supported in this build"); +#endif + }, + R"pbdoc(Sets the execution provider selection policy for the session. Allows users to specify a +selection policy for automatic execution provider (EP) selection.)pbdoc") + .def( + // Equivalent to the C API's SessionOptionsSetEpSelectionPolicyDelegate. + "set_provider_selection_policy_delegate", + [](PySessionOptions* py_sess_options, + PyEpSelectionDelegate delegate_fn) { +#if !defined(ORT_MINIMAL_BUILD) + py_sess_options->py_ep_selection_delegate = delegate_fn; // Store python callback in PySessionOptions + + py_sess_options->value.ep_selection_policy.enable = true; + py_sess_options->value.ep_selection_policy.policy = OrtExecutionProviderDevicePolicy_DEFAULT; + py_sess_options->value.ep_selection_policy.delegate = PyEpSelectionPolicyWrapper; + py_sess_options->value.ep_selection_policy.state = + reinterpret_cast(&py_sess_options->py_ep_selection_delegate); +#else + ORT_UNUSED_PARAMETER(py_sess_options); + ORT_UNUSED_PARAMETER(delegate_fn); + ORT_THROW("EP selection policies are not supported in this build"); +#endif + }, + R"pbdoc(Sets the execution provider selection policy delegate for the session. Allows users to specify a +custom selection policy function for automatic execution provider (EP) selection. The delegate must return a list of +selected OrtEpDevice instances. The signature of the delegate is +def custom_delegate(ep_devices: Sequence[OrtEpDevice], model_metadata: dict[str, str], runtime_metadata: dict[str, str], +max_selections: int) -> Sequence[OrtEpDevice])pbdoc") + .def( + "has_providers", + [](PySessionOptions* sess_options) -> bool { + return !sess_options->provider_factories.empty() || sess_options->value.ep_selection_policy.enable; + }, + R"pbdoc(Returns true if the SessionOptions has been configured with providers, OrtEpDevices, or +policies that will run the model.)pbdoc") .def_property( "enable_cpu_mem_arena", [](const PySessionOptions* options) -> bool { return options->value.enable_cpu_mem_arena; }, @@ -2132,11 +2511,18 @@ including arg name, arg type (contains both type and shape).)pbdoc") const std::vector& provider_types = {}, const ProviderOptionsVector& provider_options = {}, const std::unordered_set& disabled_optimizer_names = {}) { - InitializeSession(sess->GetSessionHandle(), - ep_registration_fn, - provider_types, - provider_options, - disabled_optimizer_names); + // If the user did not explicitly specify providers when creating InferenceSession and the SessionOptions + // has provider information (i.e., either explicit EPs or an EP selection policy), then use the information + // in the session options to initialize the session. + if (provider_types.empty() && sess->HasProvidersInSessionOptions()) { + OrtPybindThrowIfError(InitializeSessionEpsFromSessionOptions(*sess, disabled_optimizer_names)); + } else { + InitializeSession(sess->GetSessionHandle(), + ep_registration_fn, + provider_types, + provider_options, + disabled_optimizer_names); + } }, R"pbdoc(Load a model saved in ONNX or ORT format.)pbdoc") .def("run", @@ -2395,6 +2781,58 @@ including arg name, arg type (contains both type and shape).)pbdoc") .value("kNextPowerOfTwo", onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo) .value("kSameAsRequested", onnxruntime::ArenaExtendStrategy::kSameAsRequested) .export_values(); + + py::class_(m, "ModelCompiler", + R"pbdoc(This is the class used to compile an ONNX model.)pbdoc") + .def(py::init([](const PySessionOptions& sess_options, + std::string path_or_bytes, + bool is_path, + bool embed_compiled_data_into_model = false, + std::string external_initializers_file_path = {}, + size_t external_initializers_size_threshold = 1024) { +#if !defined(ORT_MINIMAL_BUILD) + std::unique_ptr result; + OrtPybindThrowIfError(PyModelCompiler::Create(result, GetEnv(), sess_options, + std::move(path_or_bytes), is_path, + embed_compiled_data_into_model, + external_initializers_file_path, + external_initializers_size_threshold)); + return result; +#else + ORT_UNUSED_PARAMETER(sess_options); + ORT_UNUSED_PARAMETER(path_or_bytes); + ORT_UNUSED_PARAMETER(is_path); + ORT_UNUSED_PARAMETER(embed_compiled_data_into_model); + ORT_UNUSED_PARAMETER(external_initializers_file_path); + ORT_UNUSED_PARAMETER(external_initializers_size_threshold); + ORT_THROW("Compile API is not supported in this build."); +#endif + })) + .def( + "compile_to_file", + [](PyModelCompiler* model_compiler, std::string output_model_path = {}) { +#if !defined(ORT_MINIMAL_BUILD) + OrtPybindThrowIfError(model_compiler->CompileToFile(output_model_path)); +#else + ORT_UNUSED_PARAMETER(model_compiler); + ORT_UNUSED_PARAMETER(output_model_path); + ORT_THROW("Compile API is not supported in this build."); +#endif + }, + R"pbdoc(Compile an ONNX model into a file.)pbdoc") + .def( + "compile_to_bytes", + [](PyModelCompiler* model_compiler) -> py::bytes { +#if !defined(ORT_MINIMAL_BUILD) + std::string output_bytes; + OrtPybindThrowIfError(model_compiler->CompileToBytes(output_bytes)); + return py::bytes(output_bytes); +#else + ORT_UNUSED_PARAMETER(model_compiler); + ORT_THROW("Compile API is not supported in this build."); +#endif + }, + R"pbdoc(Compile an ONNX model into a buffer.)pbdoc"); } bool CreateInferencePybindStateModule(py::module& m) { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 3ae5c0d289c21..4114bd4078799 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -4,6 +4,8 @@ #pragma once +#include + #include "core/common/logging/logging.h" #include "core/common/logging/sinks/cerr_sink.h" #include "core/common/optional.h" @@ -229,18 +231,31 @@ extern OrtDevice::DeviceId cuda_device_id; // TODO remove deprecated global config extern size_t gpu_mem_limit; -using PySessionOptions = OrtSessionOptions; +#if !defined(ORT_MINIMAL_BUILD) +using PyEpSelectionDelegate = std::function(const std::vector& ep_devices, + const std::unordered_map& model_metadata, + const std::unordered_map& runtime_metadata, + size_t max_selections)>; +#endif + +// Thin wrapper over internal C OrtSessionOptions to store additional state. +struct PySessionOptions : public OrtSessionOptions { +#if !defined(ORT_MINIMAL_BUILD) + // Callback function from Python application that allows the user to specify custom EP selection logic. + PyEpSelectionDelegate py_ep_selection_delegate; +#endif // !defined(ORT_MINIMAL_BUILD) +}; // Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user struct PyInferenceSession { PyInferenceSession(std::shared_ptr env, const PySessionOptions& so) - : env_(std::move(env)) { + : env_(std::move(env)), session_options_(so) { sess_ = std::make_unique(so.value, *env_); } #if !defined(ORT_MINIMAL_BUILD) PyInferenceSession(std::shared_ptr env, const PySessionOptions& so, const std::string& arg, bool is_arg_file_name) - : env_(std::move(env)) { + : env_(std::move(env)), session_options_(so) { if (is_arg_file_name) { // Given arg is the file path. Invoke the corresponding ctor(). sess_ = std::make_unique(so.value, *env_, arg); @@ -252,6 +267,24 @@ struct PyInferenceSession { } #endif + // Returns true if the session options have provider information from either + // setting explicit providers, setting a provider that supports a OrtEpDevice(s), or + // setting a selection policy (e.g., prefer gpu). + bool HasProvidersInSessionOptions() const { + return !session_options_.provider_factories.empty() || + session_options_.value.ep_selection_policy.enable; + } + + // Returns (and updates) a reference to the OrtSessionOptions for this inference session. + OrtSessionOptions& GetOrtSessionOptions() { + if (sess_) { + // Copy internal value from InferenceSession as it is the source of truth + // and the option configurations may have changed. + session_options_.value = sess_->GetSessionOptions(); + } + return session_options_; + } + InferenceSession* GetSessionHandle() const { return sess_.get(); } virtual ~PyInferenceSession() = default; @@ -264,6 +297,7 @@ struct PyInferenceSession { private: std::shared_ptr env_; std::unique_ptr sess_; + OrtSessionOptions session_options_; }; inline const PySessionOptions& GetDefaultCPUSessionOptions() { diff --git a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py index f70a7b545e60a..5428898b1c642 100644 --- a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py @@ -8,7 +8,6 @@ import argparse import copy -import importlib import logging import os @@ -16,11 +15,11 @@ import numpy.typing as npt import onnx from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto -from packaging import version from onnxruntime.capi._pybind_state import quantize_matmul_4bits, quantize_matmul_8bits, quantize_qdq_matmul_4bits from .calibrate import CalibrationDataReader +from .neural_compressor import gptq_quantize, rtn_quantize from .onnx_model import ONNXModel from .quant_utils import QuantFormat, attribute_to_kwarg @@ -98,6 +97,40 @@ def __init__( self.ratios = ratios +class KQuantWeightOnlyQuantConfig(WeightOnlyQuantConfig): + def __init__( + self, + ratios=None, + quant_format=QuantFormat.QOperator, + op_types_to_quantize: tuple[str, ...] | None = None, + customized_weight_config: dict | None = None, + ): + """ + This is a class for k-quant algorithm Weight Only Quant Configuration. + + Args: + ratios: + percentile of clip. Defaults to {}. + quant_format (QuantFormat{QOperator, QDQ}, optional): + QOperator format quantizes the model with quantized operators directly. + QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor. + Defaults to QuantFormat.QOperator. + op_types_to_quantize (optional): + set of operator types to quantize. + """ + assert quant_format == QuantFormat.QOperator, "k-quant only supports QOperator format" + + if ratios is None: + ratios = {} + super().__init__( + algorithm="k_quant", + quant_format=quant_format, + op_types_to_quantize=op_types_to_quantize, + customized_weight_config=customized_weight_config, + ) + self.ratios = ratios + + class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, @@ -1313,14 +1346,12 @@ def inc_dataloader(): algorithm = self.algo_config.algorithm logger.info(f"start to quantize model with {algorithm} algorithm...") - if algorithm == "RTN": - from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize - + if algorithm in ["RTN", "k_quant"]: kwargs["ratios"] = self.algo_config.ratios + kwargs["algorithm"] = algorithm """ - neural-compressor uses fp32 to represent the node that skip quantization, it does not mean this node is fp32 type though. - https://github.com/intel/neural-compressor/blob/a617115b1490bbe6163c0024fb55bd260c8914df/neural_compressor/adaptor/ox_utils/weight_only.py#L343 + We uses fp32 to represent the node that skip quantization, it does not mean this node is fp32 type though. """ for n in self.nodes_to_exclude: weight_only_node_config[n] = "fp32" @@ -1331,8 +1362,6 @@ def inc_dataloader(): **kwargs, ) elif algorithm == "GPTQ": - from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize - kwargs["percdamp"] = self.algo_config.percdamp kwargs["blocksize"] = self.algo_config.block_size kwargs["actorder"] = self.algo_config.actorder @@ -1380,21 +1409,7 @@ def process(self): self.model = ONNXModel(self.model) # Ensure the model is wrapped back into ONNXModel self.model.clean_initializers() else: - # use Intel® Neural Compressor for RTN or GPTQ weight-only quantize algorithm - try: - importlib.import_module("neural_compressor") - except Exception as e: - logging.error(f"{e}.") - raise RuntimeError( - "neural-compressor is not correctly installed. Please check your environment." - ) from e - - import neural_compressor - - assert version.parse(neural_compressor.__version__) >= version.parse("2.3.2"), ( - "Require neural-compressor >= 2.3.2 to support weight only quantization!" - ) - + # RTN or GPTQ weight-only quantize algorithm self.int4_quant_algo() @@ -1425,7 +1440,7 @@ def parse_args(): "--quant_method", default="default", type=str, - choices=["default", "hqq", "rtn", "gptq", "nvidia_awq"], + choices=["default", "hqq", "rtn", "k_quant", "gptq", "nvidia_awq"], help="the algorithm used to quantize weight, \nrtn and gptq leverage Intel® Neural Compressor", ) parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight") @@ -1555,6 +1570,8 @@ def parse_args(): ) elif args.quant_method == "rtn": quant_config = RTNWeightOnlyQuantConfig(op_types_to_quantize=op_types_to_quantize) + elif args.quant_method == "k_quant": + quant_config = KQuantWeightOnlyQuantConfig(op_types_to_quantize=op_types_to_quantize) elif args.quant_method == "gptq": quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size, op_types_to_quantize=op_types_to_quantize) elif args.quant_method == "nvidia_awq": diff --git a/onnxruntime/python/tools/quantization/neural_compressor/__init__.py b/onnxruntime/python/tools/quantization/neural_compressor/__init__.py new file mode 100644 index 0000000000000..08b9a38624c98 --- /dev/null +++ b/onnxruntime/python/tools/quantization/neural_compressor/__init__.py @@ -0,0 +1 @@ +from .weight_only import gptq_quantize, rtn_quantize # noqa: F401 diff --git a/onnxruntime/python/tools/quantization/neural_compressor/onnx_model.py b/onnxruntime/python/tools/quantization/neural_compressor/onnx_model.py new file mode 100644 index 0000000000000..f931045c4e349 --- /dev/null +++ b/onnxruntime/python/tools/quantization/neural_compressor/onnx_model.py @@ -0,0 +1,1264 @@ +# +# The implementation of this file is based on: +# https://github.com/intel/neural-compressor/tree/master/neural_compressor +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Class for ONNX model.""" + +import logging +import os +import sys +from pathlib import Path + +import onnx + +from .util import MAXIMUM_PROTOBUF, find_by_name + +logger = logging.getLogger("neural_compressor") + +# TODO: Check https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/onnx_model.py to see if we can integrate with it. + + +class ONNXModel: + """Build ONNX model.""" + + def __init__(self, model, **kwargs): + """Initialize an ONNX model. + + Args: + model (str or ModelProto): path to onnx model or loaded ModelProto model object. + ignore_warning (bool): ignore large model warning. Default is False. + load_external_data (bool): load external data for large model. Default is True. + """ + self._model = model if not isinstance(model, str) else onnx.load(model, load_external_data=False) + self._model_path = None if not isinstance(model, str) else model + + self.check_is_large_model() + if self._is_large_model and self._model_path is None and not kwargs.get("ignore_warning", False): + logger.warning("Model size > 2GB. Please use model path instead of onnx model object to quantize") + + if self._is_large_model and isinstance(model, str) and kwargs.get("load_external_data", True): + from onnx.external_data_helper import load_external_data_for_model + + load_external_data_for_model(self._model, os.path.dirname(self._model_path)) + + self._config = None + if isinstance(model, str) and os.path.exists(Path(model).parent.joinpath("config.json").as_posix()): + from transformers import AutoConfig + + self._config = AutoConfig.from_pretrained(Path(model).parent.as_posix()) + + self.node_name_counter = {} + self._output_name_to_node = {} + self._input_name_to_nodes = {} + self._get_input_name_to_nodes(self._model.graph.node) + self._get_output_name_to_node(self._model.graph.node) + self._graph_info = {} + self._get_graph_info() + self._q_config = None + + def check_is_large_model(self): + """Check model > 2GB.""" + init_size = 0 + for init in self._model.graph.initializer: + # if initializer has external data location, return True + if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL: + self._is_large_model = True + return + # if raise error of initializer size > 2GB, return True + try: + init_bytes = init.SerializeToString() + init_size += sys.getsizeof(init_bytes) + except Exception as e: + if "exceeds maximum protobuf size of 2GB" in str(e): + self._is_large_model = True + return + else: # pragma: no cover + raise e + if init_size > MAXIMUM_PROTOBUF: + self._is_large_model = True + return + self._is_large_model = False + + @property + def is_large_model(self): + """Check the onnx model is over 2GB.""" + return self._is_large_model + + @property + def model_path(self): + """Return model path.""" + return self._model_path + + @model_path.setter + def model_path(self, path): + """Set model path.""" + self._model_path = path + + def framework(self): + """Return framework.""" + return "onnxruntime" + + @property + def q_config(self): + """Return q_config.""" + return self._q_config + + @q_config.setter + def q_config(self, q_config): + """Set q_config.""" + self._q_config = q_config + + @property + def hf_config(self): + """Return huggingface config if model is Transformer-based.""" + return self._config + + @property + def model(self): + """Return model itself.""" + return self._model + + @model.setter + def model(self, model): + """Set model itself.""" + self._model = model + self._graph_info = {} + self._get_graph_info() + self._output_name_to_node = {} + self._input_name_to_nodes = {} + self._get_input_name_to_nodes(self._model.graph.node) + self._get_output_name_to_node(self._model.graph.node) + + def input(self): + """Return input of model.""" + return [i.name for i in self._model.graph.input] + + def output(self): + """Return output of model.""" + return [i.name for i in self._model.graph.output] + + def update(self): + """Update model info.""" + self._graph_info = {} + self._get_graph_info() + self._output_name_to_node = {} + self._input_name_to_nodes = {} + self._get_input_name_to_nodes(self._model.graph.node) + self._get_output_name_to_node(self._model.graph.node) + + @property + def graph_info(self): + """Return ORT Graph Info object holding information about backend graph.""" + return self._graph_info + + def _get_graph_info(self): + """Update graph info.""" + for node in self._model.graph.node: + self.graph_info.update({node.name: node.op_type}) + + def save(self, root): + """Save ONNX model.""" + if os.path.split(root)[0] != "" and not os.path.exists(os.path.split(root)[0]): + raise ValueError('"root" directory does not exists.') + if self.is_large_model: # pragma: no cover + from onnx.external_data_helper import load_external_data_for_model + + load_external_data_for_model(self._model, os.path.split(self._model_path)[0]) + onnx.save_model( + self._model, + root, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=root.split("/")[-1] + "_data", + size_threshold=1024, + convert_attribute=False, + ) + else: + onnx.save(self._model, root) + + if self._config is not None: + model_type = "" if not hasattr(self._config, "model_type") else self._config.model_type + self._config.__class__.model_type = model_type + output_config_file = Path(root).parent.joinpath("config.json").as_posix() + self._config.to_json_file(output_config_file, use_diff=False) + + def nodes(self): + """Return model nodes.""" + return self._model.graph.node + + def initializer(self): + """Return model initializer.""" + return self._model.graph.initializer + + def graph(self): + """Return model graph.""" + return self._model.graph + + def ir_version(self): + """Return model ir_version.""" + return self._model.ir_version + + def opset_import(self): + """Return model opset_import.""" + return self._model.opset_import + + def remove_node(self, node): + """Remove a node from model.""" + if node in self._model.graph.node: + self._model.graph.node.remove(node) + + def remove_nodes(self, nodes_to_remove): + """Remove nodes from model.""" + for node in nodes_to_remove: + self.remove_node(node) + + def add_node(self, node): + """Add a node to model.""" + self._model.graph.node.extend([node]) + + def add_nodes(self, nodes_to_add): + """Add nodes to model.""" + self._model.graph.node.extend(nodes_to_add) + + def add_initializer(self, tensor): + """Add a initializer to model.""" + if find_by_name(tensor.name, self._model.graph.initializer) is None: + self._model.graph.initializer.extend([tensor]) + + def add_initializers(self, tensors): + """Add initializers to model.""" + for tensor in tensors: + self.add_initializer(tensor) + + def get_initializer(self, name): + """Get an initializer by name.""" + for tensor in self._model.graph.initializer: + if tensor.name == name: + return tensor + return None + + def get_initializer_share_num(self, name): + """Get the number of shares of initializer.""" + num = 0 + if self.get_initializer(name) is None: + return num + + for node in self.nodes(): + if name in node.input: + num += 1 + return num + + def get_node(self, name): + """Get a node by name.""" + for node in self._model.graph.node: + if node.name == name: + return node + return None + + def remove_initializer(self, tensor): + """Remove an initializer from model.""" + if tensor in self._model.graph.initializer: + self._model.graph.initializer.remove(tensor) + + def remove_initializers(self, init_to_remove): + """Remove initializers from model.""" + for initializer in init_to_remove: + self.remove_initializer(initializer) + + def set_initializer(self, tensor, array, raw=False): + """Update initializer.""" + old_tensor = self.get_initializer(tensor) + self.remove_initializer(old_tensor) + dims = old_tensor.dims + data_type = old_tensor.data_type + new_tensor = ( + onnx.helper.make_tensor(tensor, data_type, dims, array.flatten().tolist()) + if not raw + else onnx.helper.make_tensor(tensor, data_type, dims, array.tostring(), raw=raw) + ) + self.add_initializer(new_tensor) + + @property + def input_name_to_nodes(self): + """Return input names of nodes.""" + return self._input_name_to_nodes + + def _get_input_name_to_nodes(self, nodes): + """Get input names of nodes.""" + for node in nodes: + attrs = [ + attr + for attr in node.attribute + if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS + ] + if len(attrs) > 0: + for attr in attrs: + self._get_input_name_to_nodes(attr.g.node) + for input_name in node.input: + if len(input_name.strip()) != 0: + if input_name not in self._input_name_to_nodes: + self._input_name_to_nodes[input_name] = [node] + else: + self._input_name_to_nodes[input_name].append(node) + + @property + def output_name_to_node(self): + """Return output names of nodes.""" + return self._output_name_to_node + + def _get_output_name_to_node(self, nodes): + """Get output names of nodes.""" + for node in nodes: + attrs = [ + attr + for attr in node.attribute + if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS + ] + if len(attrs) > 0: + for attr in attrs: + self._get_output_name_to_node(attr.g.node) + for output_name in node.output: + if len(output_name.strip()) != 0: + self._output_name_to_node[output_name] = node + + def get_siblings(self, node): + """Get siblings nodes.""" + siblings = [] + for parent in self.get_parents(node): + for child in self.get_children(parent): + if child.name != node.name: + siblings.append(child) + return siblings + + def get_children(self, node, input_name_to_nodes=None): + """Get children nodes.""" + if input_name_to_nodes is None: + input_name_to_nodes = self._input_name_to_nodes + + children = [] + for output in node.output: + if output in input_name_to_nodes: + for child in input_name_to_nodes[output]: + children.append(child) # noqa: PERF402 + return children + + def get_parents(self, node, output_name_to_node=None): + """Get parents nodes.""" + if output_name_to_node is None: + output_name_to_node = self._output_name_to_node + + parents = [] + for input in node.input: + if input in output_name_to_node: + parents.append(output_name_to_node[input]) + return parents + + def get_parent(self, node, idx, output_name_to_node=None): + """Get parent node by idx.""" + if output_name_to_node is None: + output_name_to_node = self._output_name_to_node + + if len(node.input) <= idx: + return None + + input = node.input[idx] + if input not in output_name_to_node: + return None + + return output_name_to_node[input] + + def find_node_by_name(self, node_name, new_nodes_list, graph): + """Find out node by name.""" + graph_nodes_list = list(graph.node) # deep copy + graph_nodes_list.extend(new_nodes_list) + node = find_by_name(node_name, graph_nodes_list) + return node + + def find_nodes_by_initializer(self, graph, initializer): + """Find all nodes with given initializer as an input.""" + nodes = [] + for node in graph.node: + for node_input in node.input: + if node_input == initializer.name: + nodes.append(node) + return nodes + + def get_scale_zero(self, tensor): + """Help function to get scale and zero_point.""" + if not tensor.endswith("_quantized"): + logger.debug(f"Find {tensor} in the quantized graph is not quantized.") + return None, None + + def _searcher(tensor_name): + """Search scale and zero point tensor recursively.""" + node = self._input_name_to_nodes[tensor_name][0] + parent = self._output_name_to_node.get(tensor_name, None) + direct_int8 = ["Reshape", "Transpose", "Squeeze", "Unsqueeze", "MaxPool", "Pad", "Split"] + if parent is not None and parent.op_type in direct_int8: + fp32_tensor_name = ( + parent.input[0] + .replace("_quantized", "") + .replace("_QuantizeLinear", "") + .replace("_QuantizeInput", "") + ) + elif node.op_type in ["Gather"]: # pragma: no cover + fp32_tensor_name = ( + node.output[0] + .replace("_quantized", "") + .replace("_QuantizeLinear", "") + .replace("_QuantizeInput", "") + ) + else: + fp32_tensor_name = ( + tensor_name.replace("_quantized", "").replace("_QuantizeLinear", "").replace("_QuantizeInput", "") + ) + scale = fp32_tensor_name + "_scale" + scale_tensor = self.get_initializer(scale) + zo = fp32_tensor_name + "_zero_point" + zo_tensor = self.get_initializer(zo) + + if scale_tensor is None or zo_tensor is None: + if parent is not None: + scale_tensor, zo_tensor = _searcher(parent.input[0]) + return scale_tensor, zo_tensor + + node = self._input_name_to_nodes[tensor][0] + # TODO check if scale_tensor and zero_point is needed + # for bias of qlinearconv, scale and zero_point is not needed + if (node.op_type == "QLinearConv" and tensor == node.input[-1]) or ( + node.op_type == "QGemm" and tensor == node.input[-3] + ): + return None, None + else: + scale_tensor, zo_tensor = _searcher(tensor) + assert scale_tensor, f"missing scale for tensor {tensor}" + assert zo_tensor, f"missing zero point for tensor {tensor}" + return scale_tensor, zo_tensor + + def save_model_to_file(self, output_path, use_external_data_format=False): + """Save model to external data, which is needed for model size > 2GB.""" + from onnx.external_data_helper import convert_model_to_external_data + + if use_external_data_format: + convert_model_to_external_data( + self._model, all_tensors_to_one_file=True, location=Path(output_path).name + ".data" + ) + onnx.save_model(self._model, output_path) + + @staticmethod + def replace_node_input(node, old_input_name, new_input_name): + """Replace input of a node.""" + assert isinstance(old_input_name, str) and isinstance(new_input_name, str) + for j in range(len(node.input)): + if node.input[j] == old_input_name: + node.input[j] = new_input_name + + def replace_input_of_all_nodes(self, old_input_name, new_input_name, white_optype=None, black_optype=None): + """Replace inputs of all nodes.""" + if white_optype is None: + white_optype = [] + if black_optype is None: + black_optype = [] + if len(white_optype) > 0: + for node in self.model.graph.node: + if node.op_type in white_optype: + ONNXModel.replace_node_input(node, old_input_name, new_input_name) + else: + for node in self.model.graph.node: + if node.op_type not in black_optype: + ONNXModel.replace_node_input(node, old_input_name, new_input_name) + + @staticmethod + def replace_node_output(node, old_output_name, new_output_name): + """Replace output of a node.""" + assert isinstance(old_output_name, str) and isinstance(new_output_name, str) + for j in range(len(node.output)): + if node.output[j] == old_output_name: + node.output[j] = new_output_name + + def replace_output_of_all_nodes(self, old_output_name, new_output_name, white_optype=None, black_optype=None): + """Replace outputs of all nodes.""" + if white_optype is None: + white_optype = [] + if black_optype is None: + black_optype = [] + if len(white_optype) > 0: + for node in self.model.graph.node: + if node.op_type in white_optype: + ONNXModel.replace_node_output(node, old_output_name, new_output_name) + else: + for node in self.model.graph.node: + if node.op_type not in black_optype: + ONNXModel.replace_node_output(node, old_output_name, new_output_name) + + def remove_unused_nodes(self): + """Remove unused nodes.""" + unused_nodes = [] + nodes = self.nodes() + for node in nodes: + if ( + node.op_type == "Constant" + and node.output[0] not in self._model.graph.output + and node.output[0] not in self._input_name_to_nodes + ): + unused_nodes.append(node) + elif ( + node.op_type == "QuantizeLinear" + and len(self.get_children(node)) == 1 + and self.get_children(node)[0].op_type == "DequantizeLinear" + and node.input[0] not in self._output_name_to_node + and self.get_children(node)[0].output[0] not in self._input_name_to_nodes + ): + unused_nodes.append(node) + unused_nodes.extend(self.get_children(node)) + else: + # remove the node if it does not serve as the input or output of any other nodes + unused = True + for output in node.output: + if output in self._input_name_to_nodes or output in self.output(): + unused = False + break + for input in node.input: + if self.get_initializer(input) is not None: + continue + elif input in self._output_name_to_node or input in self.input(): + unused = False + break + if unused: + unused_nodes.append(node) + self.remove_nodes(unused_nodes) + + ununsed_weights = [] + for w in self._model.graph.initializer: + if w.name not in self._input_name_to_nodes and w.name not in self._model.graph.output: + ununsed_weights.append(w) + # Remove from graph.input + for graph_input in self.graph().input: + if graph_input.name == w.name: + self.graph().input.remove(graph_input) + + self.remove_initializers(ununsed_weights) + self.update() + + def topological_sort(self, enable_subgraph=False): + """Topological sort the model.""" + import copy + from collections import deque + + if not enable_subgraph: + input_name_to_nodes = {} + output_name_to_node = {} + for node in self.model.graph.node: + for input_name in node.input: + if len(input_name.strip()) != 0: + if input_name not in input_name_to_nodes: + input_name_to_nodes[input_name] = [node] + else: + input_name_to_nodes[input_name].append(node) + for output_name in node.output: + if len(output_name.strip()) != 0: + output_name_to_node[output_name] = node + else: # pragma: no cover + input_name_to_nodes = self._input_name_to_nodes + output_name_to_node = self._output_name_to_node + + all_nodes = {} + q = deque() + wait = deque() + for inp in self.model.graph.input: + q.extend(input_name_to_nodes[inp.name]) + for n in self.model.graph.node: + if all(i not in output_name_to_node and i not in self.input() for i in n.input): + q.append(n) + + while q: + n = q.popleft() + if not all(output_name_to_node[i].name in all_nodes for i in n.input if i in output_name_to_node): + if n not in wait: + wait.append(n) + continue + + all_nodes[n.name] = n + for out in n.output: + if out in input_name_to_nodes: + q.extend([i for i in input_name_to_nodes[out] if i.name not in all_nodes and i not in q]) + if len(q) == 0 and len(wait) != 0: + q = copy.deepcopy(wait) + wait.clear() + nodes = [i[1] for i in all_nodes.items()] + assert len(list({n.name for n in nodes})) == len(list({n.name for n in self.model.graph.node})) + self.model.graph.ClearField("node") + self.model.graph.node.extend(nodes) + + def get_nodes_chain(self, start, stop, result_chain=None): + """Get nodes chain with given start node and stop node.""" + from collections import deque + + from onnx import NodeProto + + if result_chain is None: + result_chain = [] + # process start node list + start_node = deque() + for node in start: + if isinstance(node, str): + start_node.append(node) + elif isinstance(node, NodeProto): + start_node.append(node.name) + else: + assert False, "'get_nodes_chain' function only support list[string]or list[NodeProto] params" # noqa: B011 + + # process stop node list + stop_node = [] + for node in stop: + if isinstance(node, str): + stop_node.append(node) + elif isinstance(node, NodeProto): + stop_node.append(node.name) + else: + assert False, "'get_nodes_chain' function only support list[string]or list[NodeProto] params" # noqa: B011 + + while start_node: + node_name = start_node.popleft() + if node_name in stop_node: + continue + if node_name not in result_chain: + result_chain.append(node_name) + else: + continue + + node = find_by_name(node_name, list(self.model.graph.node)) + for parent in self.get_parents(node): + start_node.append(parent.name) + + return result_chain + + def find_split_node_for_layer_wise_quantization(self): + """Find split node for layer wise quantization.""" + # find split nodes of decoder blocks + # embed -> decoder.0 -(split_node)-> ... -(split_node)-> decoder.n -(split_node)-> norm -> head + # after split: embed -> decoder.0, + # decoder.1, + # decoder.2, + # ..., + # decoder.n, + # norm -> head + start_nodes = [] + for node in self._model.graph.node: + start_node, qkv_nodes_list = None, None + if node.op_type == "SkipLayerNormalization": + start_node = node + qkv_nodes_list = [ + self.match_parent_path( + start_node, + ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [None, 0, 0, 0, 0], + ), + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ), + ] + if node.op_type == "Add": + start_node = node + qkv_nodes_list = [ + # match base attention structure + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [0, None, 0, 0, 0], + ), + self.match_parent_path( + start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0] + ), + # match gpt attention no past structure + self.match_parent_path( + start_node, + ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"], + [None, 0, 0, 0, 0, 0], + output_name_to_node=self.output_name_to_node, + return_indice=[], + ), + # match bart attention structure + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [0, None, 0, 0, 0, 0], + ), + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [1, None, 0, 0, 0, 0], + ), + self.match_parent_path( + start_node, + ["MatMul", "Mul", "MatMul", "Mul", "Div", "Add"], + [None, 0, None, 0, None, 0], + ), + self.match_parent_path( + start_node, + ["MatMul", "Mul", "MatMul", "SimplifiedLayerNormalization", "Add"], + [None, 0, None, 0, 0], + ), + ] + if not start_node: + continue + if not any(qkv_nodes_list): + continue + start_nodes.append(start_node) + return start_nodes + + def find_qkv_in_attention(self, find_all=False): + """Find qkv MatMul in Attention. + + Args: + find_all (bool, optional): find all qkv MatMul. Defaults to False + + Returns: + qkv (list): qkv MatMul list + """ + qkv = [] + for node in self._model.graph.node: + if node.op_type == "Attention": + qkv.append([node.name]) + continue + start_node, qkv_nodes_list = None, None + if node.op_type == "SkipLayerNormalization": + start_node = node + qkv_nodes_list = [ + self.match_parent_path( + start_node, + ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [None, 0, 0, 0, 0], + ), + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ), + ] + if node.op_type == "Add": + start_node = node + qkv_nodes_list = [ + # match base attention structure + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [0, None, 0, 0, 0], + ), + self.match_parent_path( + start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0] + ), + # match gpt attention no past structure + self.match_parent_path( + start_node, + ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"], + [None, 0, 0, 0, 0, 0], + output_name_to_node=self.output_name_to_node, + return_indice=[], + ), + # match bart attention structure + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [0, None, 0, 0, 0, 0], + ), + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [1, None, 0, 0, 0, 0], + ), + ] + if not start_node: + continue + if not any(qkv_nodes_list): + continue + qkv_nodes = [qkv for qkv in qkv_nodes_list if qkv is not None][-1] + other_inputs = [] + for input in start_node.input: + if input not in self.output_name_to_node: + continue + if input == qkv_nodes[0].output[0]: + continue + other_inputs.append(input) + if len(other_inputs) != 1: + continue + root_input = other_inputs[0] + input_name_to_nodes = self.input_name_to_nodes + children = input_name_to_nodes[root_input] + children_types = [child.op_type for child in children] + if children_types.count("MatMul") == 3: + qkv.append([child.name for child in children if child.op_type == "MatMul"]) + if not find_all: + break + return qkv + + def find_ffn_matmul(self, attention_index, attention_matmul_list, block_len): + """Find MatMul in FFN. + + Args: + attention_index (list): index of Attention + attention_matmul_list (list): list of Attention and MatMul nodes + block_len (int): block length + + Returns: + list: list of MatMul in FFN + """ + ffn_matmul = [] + for idx in range(len(attention_index)): + if idx != len(attention_index) - 1: + index = attention_index[idx + 1] + if index - 2 >= 0: + ffn_matmul.append([attention_matmul_list[index - 2], attention_matmul_list[index - 1]]) + else: + index = attention_index[idx] + if index + block_len - 1 < len(attention_matmul_list): + ffn_matmul.append( + [attention_matmul_list[index + block_len - 2], attention_matmul_list[index + block_len - 1]] + ) + return ffn_matmul + + def export(self, save_path, conf): + """Export Qlinear to QDQ model.""" + from neural_compressor.config import ONNXQlinear2QDQConfig + from neural_compressor.utils.export import onnx_qlinear_to_qdq + + if isinstance(conf, ONNXQlinear2QDQConfig): + add_nodes, remove_nodes, inits = onnx_qlinear_to_qdq(self._model, self._input_name_to_nodes) + self.add_nodes(add_nodes) + self.remove_nodes(remove_nodes) + self.add_initializers(inits) + self.update() + self.remove_unused_nodes() + self.topological_sort() + self.save(save_path) + else: + logger.warning("Unsupported config for export, only ONNXQlinear2QDQConfig is supported!") + exit(0) + + def add_tensors_to_outputs(self, tensor_names): + """Add the tensors to the model outputs to gets their values. + + Args: + tensor_names: The names of tensors to be dumped. + """ + added_outputs = [] + for tensor in tensor_names: + if tensor not in self.output(): + added_tensor = onnx.helper.ValueInfoProto() + added_tensor.name = tensor + added_outputs.append(added_tensor) + self._model.graph.output.extend(added_outputs) # pylint: disable=no-member + + def remove_tensors_from_outputs(self, tensor_names): + """Remove the tensors from the model outputs. + + Args: + tensor_names: The names of tensors to be removed. + """ + removed_outputs = [] + for tensor in tensor_names: + if tensor in self.output(): + removed_outputs.append(self._model.graph.output[self.output().index(tensor)]) + for output in removed_outputs: + self._model.graph.output.remove(output) + + def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude=None): + """Find parent node based on constraints on op_type. + + Args: + node (str): current node name. + parent_op_type (str): constraint of parent node op_type. + output_name_to_node (dict): dictionary with output name as key, and node as value. + exclude (list): list of nodes that are excluded (not allowed to match as parent). + + Returns: + parent: The matched parent node. None if not found. + index: The input index of matched parent node. None if not found. + """ + if exclude is None: + exclude = [] + for i, input in enumerate(node.input): + if input in output_name_to_node: + parent = output_name_to_node[input] + if parent.op_type == parent_op_type and parent not in exclude: + return parent, i + return None, None + + def match_parent( + self, + node, + parent_op_type, + input_index=None, + output_name_to_node=None, + exclude=None, + return_indice=None, + ): + """Find parent node based on constraints on op_type and index. + + Args: + node (str): current node name. + parent_op_type (str): constraint of parent node op_type. + input_index (int or None): only check the parent given input index of current node. + output_name_to_node (dict): dictionary with output name as key, and node as value. + exclude (list): list of nodes that are excluded (not allowed to match as parent). + return_indice (list): a list to append the input index when input_index is None. + + Returns: + parent: The matched parent node. + """ + assert node is not None + assert input_index is None or input_index >= 0 + if exclude is None: + exclude = [] + if output_name_to_node is None: + output_name_to_node = self._output_name_to_node + + if input_index is None: + parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude) + if return_indice is not None: + return_indice.append(index) + return parent + + if input_index >= len(node.input): + return None + + parent = self.get_parent(node, input_index, output_name_to_node) + if parent is not None and parent.op_type == parent_op_type and parent not in exclude: + return parent + + return None + + def match_parent_path( + self, + node, + parent_op_types, + parent_input_index, + output_name_to_node=None, + return_indice=None, + ): + """Find a sequence of input edges based on constraints on parent op_type and index. + + Args: + node (str): current node name. + parent_op_types (str): constraint of parent node op_type of each input edge. + parent_input_index (list): constraint of input index of each input edge. + None means no constraint. + output_name_to_node (dict): dictionary with output name as key, and node as value. + return_indice (list): a list to append the input index when there is + no constraint on input index of an edge. + + Returns: + parents: a list of matched parent node. + """ + assert len(parent_input_index) == len(parent_op_types) + + if output_name_to_node is None: + output_name_to_node = self._output_name_to_node + + current_node = node + matched_parents = [] + for i, op_type in enumerate(parent_op_types): + matched_parent = self.match_parent( + current_node, + op_type, + parent_input_index[i], + output_name_to_node, + exclude=[], + return_indice=return_indice, + ) + if matched_parent is None: + return None + + matched_parents.append(matched_parent) + current_node = matched_parent + + return matched_parents + + def is_smoothquant_model(self): + """Check the model is smooth quantized or not. + + Returns: + bool: the model is smooth quantized or not. + """ + for init in self.model.graph.initializer: # noqa: SIM110 + if "_smooth_scale" in init.name: + return True + return False + + def find_split_nodes(self): + """Find split nodes for layer-wise quantization.""" + split_nodes = self.find_split_node_for_layer_wise_quantization() + return split_nodes + + def split_model_with_node( + self, split_node_name, path_of_model_to_split, shape_infer=True, save_both_split_models=True + ): + """Split model into two parts at a given node. + + Args: + split_node_name (str): name of the node where the model is split at> + path_of_model_to_split (str): path of model to be split. + shape_infer (bool): do shape inference. Default is True. + save_both_split_models (bool): whether to save the two split models. + False means only save the first split model. + True means save both the two split models. + Default id True. + + Returns: + tuple: the first split model, the second split model + """ + # origin model : ... -> node_1 -> split_node -> node_2 -> ... + # split model 1: ... -> node_1 -> split_node + # split model 2: node_2 -> ... + + split_model_part_1 = onnx.ModelProto() + split_model_part_1.CopyFrom(self._model) + split_model_part_1.graph.ClearField("node") + + split_model_part_2 = onnx.ModelProto() + split_model_part_2.CopyFrom(self._model) + split_model_part_2.graph.ClearField("node") + + split_node_output = None + part_idx = 1 + for node in self._model.graph.node: + if part_idx == 1: + split_model_part_1.graph.node.append(node) + elif part_idx == 2: + split_model_part_2.graph.node.append(node) + + if node.name == split_node_name: + split_node_output = node.output + part_idx = 2 + + assert len(split_node_output) == 1, ( + f"Only support split at node with 1 output tensor, while current split node {split_node_name} has {len(split_node_output)} output tensors" + ) + split_tensor_name = split_node_output[0] + + # infer shape of the model to be split + if shape_infer: + try: + from neural_compressor.adaptor.ox_utils.util import infer_shapes + + self._model = infer_shapes(self._model, auto_merge=True, base_dir=os.path.dirname(self._model_path)) + except Exception as e: # pragma: no cover + logger.error( + "Shape infer fails for layer-wise quantization. " + "We would recommend checking the graph optimization level of your model " + "and setting it to 'DISABLE_ALL' or 'ENABLE_BASIC', " + "as this may help avoid this error." + ) + raise e + + split_tensor_type, split_tensor_shape = self._get_output_type_shape_by_tensor_name(split_tensor_name) + split_tensor = onnx.helper.make_tensor_value_info(split_tensor_name, split_tensor_type, split_tensor_shape) + + split_model_part_1 = ONNXModel(split_model_part_1, ignore_warning=True) + split_model_part_2 = ONNXModel(split_model_part_2, ignore_warning=True) + + # remove unused input & output + split_model_part_1._remove_unused_input_output() + split_model_part_2._remove_unused_input_output() + + split_model_part_1.model.graph.output.append(split_tensor) + split_model_part_2.model.graph.input.append(split_tensor) + + insert_output_for_model_1 = [] + insert_input_for_model_2 = [] + for output in split_model_part_1.output_name_to_node: + if output in split_model_part_2.input_name_to_nodes: + output_type, output_shape = self._get_output_type_shape_by_tensor_name(output) + output_tensor = onnx.helper.make_tensor_value_info(output, output_type, output_shape) + if output_tensor not in split_model_part_1.model.graph.output: + insert_output_for_model_1.append(output_tensor) + if output_tensor not in split_model_part_2.model.graph.input: + insert_input_for_model_2.append(output_tensor) + + # insert model 1 output + for output in insert_output_for_model_1: + split_model_part_1.model.graph.output.append(output) + + # insert model 2 input + for input in insert_input_for_model_2: + split_model_part_2.model.graph.input.append(input) + + # remove unused init + split_model_part_1.remove_unused_init() + split_model_part_2.remove_unused_init() + + split_model_part_1.update() + split_model_part_2.update() + + dir_of_model_to_split = os.path.dirname(path_of_model_to_split) + + split_model_part_1.load_model_initializer_by_tensor(dir_of_model_to_split) + split_model_part_1_path = os.path.join(dir_of_model_to_split, "split_model_part_1.onnx") + split_model_part_1.model_path = split_model_part_1_path + split_model_part_1._save_split_model(split_model_part_1_path) + split_model_part_1.check_is_large_model() + logger.debug(f"save split model part 1 to {split_model_part_1_path} for layer wise quantization") + + if save_both_split_models: + split_model_part_2.load_model_initializer_by_tensor(dir_of_model_to_split) + split_model_part_2_path = os.path.join(dir_of_model_to_split, "split_model_part_2.onnx") + split_model_part_2.model_path = split_model_part_2_path + split_model_part_2._save_split_model(split_model_part_2_path) + split_model_part_2.check_is_large_model() + logger.debug(f"save split model part 2 to {split_model_part_2_path} for layer wise quantization") + return split_model_part_1, split_model_part_2 + else: + return split_model_part_1, split_model_part_2 + + def _save_split_model(self, save_path): + """Save split model as external data for layer wise quantization. + + Args: + save_path (str): the path to save the split model + """ + if os.path.exists(save_path + "_data"): + os.remove(save_path + "_data") + onnx.save_model( + self._model, + save_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=save_path.split("/")[-1] + "_data", + size_threshold=1024, + convert_attribute=False, + ) + + def _get_output_type_shape_by_tensor_name(self, tensor_name): + """Get output type and shape with a tensor name. + + Args: + tensor_name (str): name of a tensor + + Returns: + tuple: output type and shape + """ + elem_type = onnx.TensorProto.FLOAT + shape = None + for output in self._model.graph.value_info: + if output.name == tensor_name: + elem_type = output.type.tensor_type.elem_type + shape = [ + dim.dim_value if dim.HasField("dim_value") else -1 for dim in output.type.tensor_type.shape.dim + ] + break + return elem_type, shape + + def _remove_unused_input_output(self): + """Remove unused input & output for split model.""" + remove_outputs = [] + remove_inputs = [] + for output in self._model.graph.output: + if output.name not in self.output_name_to_node: + remove_outputs.append(output) + + for input in self._model.graph.input: + if input.name not in self.input_name_to_nodes: + remove_inputs.append(input) + + for output in remove_outputs: + self._model.graph.output.remove(output) + for input in remove_inputs: + self._model.graph.input.remove(input) + + def remove_unused_init(self): + """Remove unused init.""" + remov_inits = [] + for init in self._model.graph.initializer: + if init.name not in self.input_name_to_nodes: + remov_inits.append(init) + self.remove_initializers(remov_inits) + + def load_model_initializer_by_tensor(self, data_path=None): + """Load model initializer by tensor. + + Args: + data_path (str, optional): the directory of saved initializer. Defaults to None. + """ + from onnx.external_data_helper import load_external_data_for_tensor + + if data_path is None: + data_path = os.path.dirname(self._model_path) + for init in self._model.graph.initializer: + if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL: + load_external_data_for_tensor(init, data_path) + + def write_external_data_to_new_location(self, external_data_location="external.data", overwrite=False): + """Write external data of merged quantized model to new location to save memory. + + Args: + external_data_location (str, optional): external data location of merged quantized model. + Defaults to "external.data". + overwrite (bool, optional): if True, remove existed externa data. Defaults to False. + """ + from onnx.external_data_helper import convert_model_to_external_data, write_external_data_tensors + + if overwrite and os.path.exists(os.path.join(os.path.dirname(self._model_path), external_data_location)): + os.remove(os.path.join(os.path.dirname(self._model_path), external_data_location)) + self.load_model_initializer_by_tensor() + convert_model_to_external_data(self._model, location=external_data_location) + # TODO : if init is already saved, skip write it + write_external_data_tensors(self._model, filepath=os.path.dirname(self._model_path)) + + def merge_split_models(self, to_merge_model): + """Merge two split model into final model.""" + to_merge_model.write_external_data_to_new_location() + self.add_nodes(list(to_merge_model.nodes())) + self.add_initializers(list(to_merge_model.initializer())) + self.update() + + # add new output + for output in to_merge_model.graph().output: + if output.name not in self.output(): + self._model.graph.output.append(output) + + # remove unused output + remove_output = [] + for output in self._model.graph.output: + if output.name in to_merge_model.input(): + remove_output.append(output) + for output in remove_output: + self._model.graph.output.remove(output) + + # add new input + for input in to_merge_model.graph().input: + if ( + input.name not in self.input() + and input.name not in self.output() + and input.name not in self.output_name_to_node + ): + self._model.graph.input.append(input) + + def re_org_output(self, origin_output): + """Re-org output of merged model for layer-wise quantization.""" + outputs = {} + tmp_remove = [] + for output in self._model.graph.output: + outputs[output.name] = output + tmp_remove.append(output) + + for output in tmp_remove: + self._model.graph.output.remove(output) + + for out_name in origin_output: + self._model.graph.output.append(outputs[out_name]) diff --git a/onnxruntime/python/tools/quantization/neural_compressor/util.py b/onnxruntime/python/tools/quantization/neural_compressor/util.py new file mode 100644 index 0000000000000..aae01b4defd1f --- /dev/null +++ b/onnxruntime/python/tools/quantization/neural_compressor/util.py @@ -0,0 +1,80 @@ +# +# The implementation of this file is based on: +# https://github.com/intel/neural-compressor/tree/master/neural_compressor +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper classes or functions for onnxrt adaptor.""" + +import importlib +import logging + +import numpy as np + +logger = logging.getLogger("neural_compressor") + + +MAXIMUM_PROTOBUF = 2147483648 + + +def simple_progress_bar(total, i): + """Progress bar for cases where tqdm can't be used.""" + progress = i / total + bar_length = 20 + bar = "#" * int(bar_length * progress) + spaces = " " * (bar_length - len(bar)) + percentage = progress * 100 + print(f"\rProgress: [{bar}{spaces}] {percentage:.2f}%", end="") + + +def find_by_name(name, item_list): + """Helper function to find item by name in a list.""" + items = [] + for item in item_list: + assert hasattr(item, "name"), f"{item} should have a 'name' attribute defined" # pragma: no cover + if item.name == name: + items.append(item) + if len(items) > 0: + return items[0] + else: + return None + + +def to_numpy(data): + """Convert to numpy ndarrays.""" + import torch + + if not isinstance(data, np.ndarray): + if not importlib.util.find_spec("torch"): + logger.error( + "Please install torch to enable subsequent data type check and conversion, " + "or reorganize your data format to numpy array." + ) + exit(0) + if isinstance(data, torch.Tensor): + if data.dtype is torch.bfloat16: # pragma: no cover + return data.detach().cpu().to(torch.float32).numpy() + if data.dtype is torch.chalf: # pragma: no cover + return data.detach().cpu().to(torch.cfloat).numpy() + return data.detach().cpu().numpy() + else: + try: + return np.array(data) + except Exception: + assert False, ( # noqa: B011 + f"The input data for onnx model is {type(data)}, which is not supported to convert to numpy ndarrays." + ) + else: + return data diff --git a/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py b/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py new file mode 100644 index 0000000000000..558415f028c7b --- /dev/null +++ b/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py @@ -0,0 +1,932 @@ +# +# The implementation of this file is based on: +# https://github.com/intel/neural-compressor/tree/master/neural_compressor +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications: +# Add k-quant quantization method. +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""WeightOnly for onnxrt adaptor.""" + +import copy +import logging +import os +import sys + +import numpy as np +import onnx +from onnx import numpy_helper +from onnx.helper import np_dtype_to_tensor_dtype + +import onnxruntime as ort + +from .onnx_model import ONNXModel +from .util import simple_progress_bar + +logger = logging.getLogger("neural_compressor") + + +def make_matmul_weight_only_node( + node, + weight_shape, + num_bits, + group_size, + k_blocks, + q_weight, + scale, + zero_point, + accuracy_level=0, +): # pragma: no cover + """Build MatMulNBits node. + + Args: + node: original matmul node + weight_shape: original weight shape + num_bits (int): num_bits + group_size (int): how many elements share one scale/zp + k_blocks (int): block number + q_weight (array): quantized weight + scale (array): scale + zero_point (array): zero point + accuracy_level (int): accuracy level. Support 0 (unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8). + + Returns: + matmul_weight_only_node: MatMulNBits node + new_inits: initializers of the new node + """ + blob_size = group_size * num_bits // 8 + packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8") + q_weight_name = node.input[1] + f"_Q{num_bits!s}G{group_size!s}" + input_names = [node.input[0], q_weight_name] + new_inits = [] + kwargs = {} + + op_type = "MatMulNBits" + + # pack quantized weight + if num_bits == 4: + q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4 + packed[:, :] = q_weight_pairs[:, :blob_size] + elif num_bits == 8: + packed = q_weight + else: + logger.error(f"MatMulNBits does not have kernel support for num_bits = {num_bits}.") + + packed = np.reshape(packed, (-1, k_blocks, blob_size)) + + # build scale tensor + scale = np.reshape(scale, (-1, k_blocks)) + assert scale.dtype == np.float32 or scale.dtype == np.float16 + scale_tensor = onnx.helper.make_tensor( + name=node.input[1] + "_scale", + data_type=np_dtype_to_tensor_dtype(scale.dtype), + dims=scale.shape, + vals=scale.tobytes(), + raw=True, + ) + input_names.append(scale_tensor.name) + new_inits.append(scale_tensor) + + # build zero_point tensor + if zero_point is not None: + if num_bits == 8: + packed_zp = zero_point.astype("uint8") + elif num_bits == 4: + # For 4-bit case, the default zeros is 0x8. So it is 0x88 = 136 if we fill lower/higher 4 bits with 0x8. + packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8") + # create an index array + idx = np.arange(zero_point.shape[0] // k_blocks * k_blocks).reshape(-1) + # separate odd and even indices + even_idx = idx[::2] + odd_idx = idx[1::2] + # vectorized operation for even and odd indices + packed_zp[even_idx // 2] = (packed_zp[even_idx // 2] & 0xF0) | zero_point[even_idx].ravel() + packed_zp[odd_idx // 2] = (packed_zp[odd_idx // 2] & 0x0F) | (zero_point[odd_idx].ravel() << 4) + else: + raise ValueError(f"MatMulNBits does not have kernel support for num_bits = {num_bits}.") + + packed_zp = np.reshape(packed_zp, (weight_shape[1], -1)) + zp_tensor = onnx.helper.make_tensor( + name=node.input[1] + "_zp", data_type=2, dims=packed_zp.shape, vals=packed_zp.tobytes(), raw=True + ) + input_names.append(zp_tensor.name) + new_inits.append(zp_tensor) + + # set kwargs + kwargs["K"] = weight_shape[0] + kwargs["N"] = weight_shape[1] + kwargs["bits"] = num_bits + kwargs["block_size"] = group_size + if accuracy_level > 0: + # require onnxruntime > 1.16.3 + kwargs["accuracy_level"] = accuracy_level + + q_weight_tensor = onnx.helper.make_tensor( + name=q_weight_name, + data_type=2, + dims=packed.shape, + vals=packed.tobytes(), + raw=True, + ) + new_inits.append(q_weight_tensor) + + matmul_weight_only_node = onnx.helper.make_node( + op_type, + inputs=input_names, + outputs=node.output, + name=node.name + "_Q" + str(num_bits) if node.name else "_Q" + str(num_bits), + domain="com.microsoft", + **kwargs, + ) + return matmul_weight_only_node, new_inits + + +def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0): + """Quantize tensor per group. + + Args: + data : input weight + num_bits (int, optional): num_bits. Defaults to 4. + group_size (int, optional): how many elements share one scale/zp. Defaults to 4. + scheme (str, optional): quantization scheme. Defaults to "asym". + dtype (str, optional): data type. Defaults to "int". + ratio (float, optional): percentile of clip. Defaults to 1.0. + + Returns: + output: quantized weight + scale: scale + zero_point: zero point + """ + data = np.reshape(data, (-1, group_size)) + if scheme == "asym" or dtype == "uint": + maxq = 2**num_bits - 1 + minq = 0 + elif scheme == "sym": + maxq = 2 ** (num_bits - 1) - 1 if num_bits != 1 else 0 + minq = -(2 ** (num_bits - 1)) if num_bits != 1 else -1 + + rmin = np.min(data, axis=1, keepdims=True) * ratio + rmax = np.max(data, axis=1, keepdims=True) * ratio + if scheme == "sym": + max_range = np.maximum(np.abs(rmin), np.abs(rmax)) + scale = np.ones(rmax.shape) + mask = max_range > 0 + scale[mask] = (max_range[mask] * 2.0).astype(np.float64) / (maxq - minq) + zero_point = ( + np.zeros(scale.shape) if dtype == "int" else np.ones(rmax.shape, dtype="uint8") * (1 << (num_bits - 1)) + ) + else: + scale = np.ones(rmax.shape) + scale[rmin != rmax] = np.array( + [float(i) / (maxq - minq) for i in (rmax - rmin)[rmin != rmax].flatten().tolist()] + ) + zero_point = ( + ((np.zeros(scale.shape) - rmin) / scale).round() + if dtype == "int" + else np.maximum(0, np.minimum(maxq, ((np.zeros(scale.shape) - rmin) / scale).round())).astype("uint8") + ) + + q_weight = np.empty_like(data, dtype=scale.dtype) + np.divide(data, scale, out=q_weight) + np.add(q_weight, zero_point, out=q_weight) + np.round(q_weight, out=q_weight) + np.clip(q_weight, minq, maxq, out=q_weight) + + return q_weight, scale, zero_point + + +def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32): + """Quantize tensor per group based on k quant. + + Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c + + Args: + data : input weight + num_bits (int, optional): num_bits. Defaults to 4. + group_size (int, optional): how many elements share one scale/zp. Defaults to 32. + + Returns: + output: quantized weight + scale: scale + zero_point: zero point + """ + data = np.reshape(data, (-1, group_size)).astype(np.float32) # nb = data.shape[0], (nb, group_size) + maxq = 2**num_bits - 1 + minq = 0 + sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1) + av_x = np.sqrt(sum_x2 / group_size) # (nb, 1) + weights = np.add(av_x, np.abs(data)) # (nb, group_size) + rmin = np.min(data, axis=1, keepdims=True) # (nb, 1) + rmax = np.max(data, axis=1, keepdims=True) # (nb, 1) + sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1) + sum_x = np.sum(weights * data, axis=1, keepdims=True) # (nb, group_size) + iscale = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1) + mask = rmin != rmax + iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask]) + scale = 1 / iscale + quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size) + diff = scale * quant_data + rmin - data # (nb, group_size) + best_mad = np.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1) + nstep = 20 + rdelta = 0.1 + # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1 + rrmin = -1 + for is_ in range(nstep): + iscale_new = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1) + factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0] + mask = rmin != rmax + iscale_new[mask] = factor / (rmax[mask] - rmin[mask]) + quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size) + mul_weights_quant_data_new = weights * quant_data_new + sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1) + sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1) + sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1) + D = np.subtract(sum_w * sum_l2, sum_l**2) # noqa: N806 + + this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1) + this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1) + + diff = this_scale * quant_data_new + this_min - data # (nb, group_size) + mad = np.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1) + + mad_1 = np.array(mad) + best_mad_1 = np.array(best_mad) + idx_to_replace = np.where(mad_1 < best_mad_1)[0] + quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :] + best_mad[idx_to_replace] = mad[idx_to_replace] + scale[idx_to_replace] = this_scale[idx_to_replace] + rmin[idx_to_replace] = this_min[idx_to_replace] + + zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8") + scale = scale.astype(np.float64) + q_weight = np.empty_like(data, dtype=scale.dtype) + np.divide(data, scale, out=q_weight) + np.add(q_weight, zero_point, out=q_weight) + np.round(q_weight, out=q_weight) + np.clip(q_weight, minq, maxq, out=q_weight) + + return q_weight, scale, zero_point + + +def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32): + """Quantize tensor per group based on k quant. + + Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c + + Args: + data : input weight + num_bits (int, optional): num_bits. Defaults to 4. + group_size (int, optional): how many elements share one scale/zp. Defaults to 4. + + Returns: + output: quantized weight + scale: scale + zero_point: zero point + """ + try: + import cupy as cp + import torch + + if torch.cuda.is_available(): + data = cp.asarray(data) + data = data.reshape((-1, group_size)).astype(cp.float32) # nb = data.shape[0], (nb, group_size) + maxq = 2**num_bits - 1 + minq = 0 + sum_x2 = cp.sum(data**2, axis=1, keepdims=True) # (nb, 1) + av_x = cp.sqrt(sum_x2 / group_size) # (nb, 1) + weights = cp.add(av_x, cp.abs(data)) # (nb, group_size) + rmin = cp.min(data, axis=1, keepdims=True) # (nb, 1) + rmax = cp.max(data, axis=1, keepdims=True) # (nb, 1) + sum_w = cp.sum(weights, axis=1, keepdims=True) # (nb, 1) + sum_x = cp.sum(weights * data, axis=1, keepdims=True) # (nb, group_size) + iscale = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1) + mask = rmin != rmax + iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask]) + scale = 1 / iscale + quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size) + diff = scale * quant_data + rmin - data # (nb, group_size) + best_mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1) + nstep = 20 + rdelta = 0.1 + rrmin = -1 + for is_ in range(nstep): + iscale_new = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1) + factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0] + mask = rmin != rmax + iscale_new[mask] = factor / (rmax[mask] - rmin[mask]) + quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size) + mul_weights_quant_data_new = weights * quant_data_new + sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1) + sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1) + sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1) + D = cp.subtract(sum_w * sum_l2, sum_l**2) # noqa: N806 + + this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1) + this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1) + + diff = this_scale * quant_data_new + this_min - data # (nb, group_size) + mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1) + + mad_1 = cp.array(mad) + best_mad_1 = cp.array(best_mad) + idx_to_replace = cp.where(mad_1 < best_mad_1)[0] + quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :] + best_mad[idx_to_replace] = mad[idx_to_replace] + scale[idx_to_replace] = this_scale[idx_to_replace] + rmin[idx_to_replace] = this_min[idx_to_replace] + + zero_point = cp.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8") + scale = scale.astype(cp.float64) + q_weight = cp.empty_like(data, dtype=scale.dtype) + cp.divide(data, scale, out=q_weight) + cp.add(q_weight, zero_point, out=q_weight) + cp.round(q_weight, out=q_weight) + cp.clip(q_weight, minq, maxq, out=q_weight) + + return q_weight.get(), scale.get(), zero_point.get() + else: + logger.warning( + "Try to use k-quant quantization on CUDA. However, CUDA is not available." + "Fall back to k-quant quantization on CPU." + ) + return quant_tensor_k_quant_cpu(data, num_bits, group_size) + except ImportError: + logger.info( + "Now we are using k-quant quantization on cpu, which is time consuming." + "Please consider install cupy to speed up on CUDA. See https://cupy.dev/" + "Please also install torch to check CUDA availability." + ) + return quant_tensor_k_quant_cpu(data, num_bits, group_size) + + +def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0): + """Quant dequant tensor per group. + + Args: + data : input weight + num_bits (int, optional): num_bits. Defaults to 4. + group_size (int, optional): how many elements share one scale/zp. Defaults to 4. + scheme (str, optional): quantization scheme. Defaults to "asym". + dtype (str, optional): data type. Defaults to "int". + ratio (float, optional): percentile of clip. Defaults to 1.0. + + Returns: + output: quant-dequant weight + """ + org_shape = data.shape + weight, scale, zp = quant_tensor(data, num_bits, group_size, scheme, dtype, ratio) + return np.reshape(scale * (weight - zp), org_shape) + + +def pad_tensor(weight, group_size, k_blocks): + """Pad tensor rowi so that it can be is divisible by group_size. + + Args: + weight (array): weight + group_size (int): how many elements share one scale/zp + k_blocks (int): the number of block + + Returns: + weight: paded weight + """ + if group_size == -1: + return weight + + org_w_shape = weight.shape + padded_rows = k_blocks * group_size + pad_len = padded_rows - org_w_shape[0] + + if pad_len > 0: + weight = np.pad(weight, ((0, pad_len), (0, 0)), "constant") + + return weight + + +def rtn_quantize( + model, + weight_config={}, # noqa: B006 + num_bits=4, + group_size=32, + scheme="asym", + ratios={}, # noqa: B006 + accuracy_level=0, + providers=["CPUExecutionProvider"], # noqa: B006 + algorithm="k_quant", +): + """Quant the model with round to nearst method. + + Args: + model (ModelProto or ONNXModel): onnx model + weight_config (dict): quantization config + For example, + weight_config = { + 'fc2': + { + 'bits': 4, + 'group_size': 32, + 'scheme': 'sym', + 'algorithm': 'RTN' + } + } + num_bits (int, optional): num_bits. Default is 4. + group_size (int, optional): how many elements share one scale/zp. Default is 32. + scheme (str, optional): sym or asym. Defaults to "asym". + ratios (dict, optional): percentile of clip. Defaults to {}. + accuracy_level (int): accuracy level. Support 0 (unset),1(fp32), 2(fp16), 3(bf16), or 4(int8). + providers (list): providers to use + + Returns: + model: fake quantized ONNXModel + """ + model = ONNXModel(model) + base_dir = os.path.dirname(model.model_path) if model.model_path is not None else "" + new_nodes = [] + remove_nodes = [] + total_num = len([i for i in model.nodes() if i.op_type in ["MatMul"]]) + curr_id = 0 + for node in model.nodes(): + if node.op_type in ["MatMul"]: + curr_id += 1 + simple_progress_bar(total_num, curr_id) + if ( + node.op_type in ["MatMul"] + and model.get_initializer(node.input[1]) is not None + and weight_config.get(node.name, {}) != "fp32" + ): + weight_tensor = model.get_initializer(node.input[1]) + weight = numpy_helper.to_array(weight_tensor, base_dir=base_dir).copy() + if len(weight.shape) != 2: + continue + + dtype = weight.dtype + + if node.name in weight_config: + num_bits = weight_config[node.name]["bits"] + group_size = weight_config[node.name]["group_size"] + scheme = weight_config[node.name]["scheme"] + + org_w_shape = weight.shape # ic, oc + group_size = group_size if group_size != -1 else org_w_shape[0] + + k_blocks = (org_w_shape[0] - 1) // group_size + 1 + init_share_num = model.get_initializer_share_num(node.input[1]) + + weight = pad_tensor(weight, group_size, k_blocks) + + satisfy_MatMulNBits_condition = num_bits == 4 or num_bits == 8 # noqa: N806 + + if satisfy_MatMulNBits_condition: # pragma: no cover + if algorithm == "k_quant": + q_weight, scale, zp = quant_tensor_k_quant_cuda(weight.T, num_bits, group_size) + else: + q_weight, scale, zp = quant_tensor( + weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1) + ) + + q_matmul_node, new_inits = make_matmul_weight_only_node( + node=node, + weight_shape=org_w_shape, + num_bits=num_bits, + group_size=group_size, + k_blocks=k_blocks, + q_weight=q_weight.astype("uint8"), + scale=scale.astype(dtype), + zero_point=zp if scheme == "asym" or algorithm == "k_quant" else None, + accuracy_level=accuracy_level, + ) + + model.add_initializers(new_inits) + remove_nodes.append(node) + new_nodes.append(q_matmul_node) + else: + q_weight = qdq_tensor(weight.T, num_bits, group_size, scheme, "int", ratios.get(node.input[1], 1)) + q_weight = np.reshape(q_weight, (org_w_shape[1], -1)) + q_weight = np.transpose(q_weight) + q_weight = q_weight[: org_w_shape[0], :].astype(dtype) + q_weight_tensor = onnx.helper.make_tensor( + name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}", + data_type=np_dtype_to_tensor_dtype(dtype), + dims=weight.shape, + vals=q_weight.tobytes(), + raw=True, + ) + model.add_initializer(q_weight_tensor) + node.input[1] = q_weight_tensor.name + if init_share_num == 1: + model.remove_initializer(weight_tensor) + + model.add_nodes(new_nodes) + model.remove_nodes(remove_nodes) + model.topological_sort() + return model + + +def get_weight_scale(weight, group_size): + """Get the scale of weight.""" + org_shape = weight.shape + weight = np.reshape(weight, (-1, group_size)) if group_size != -1 else weight + scale = np.mean(np.reshape(np.abs(weight) / np.max(np.abs(weight), axis=1, keepdims=True), org_shape), axis=0) + return scale + + +def prepare_inputs(model, n_samples, dataloader, providers): + """Prepare inputs for weight only quantization. + + Args: + model (ModelProto or ONNXModel): onnx model + n_samples (int, optional): calibration sample number. -1 means all samples. + dataloader (object): dataloader for calibration. + providers (list): providers to use + + Returns: + inputs: prepared inputs. + so: session options + """ + from importlib.util import find_spec + + from .util import to_numpy + + so = ort.SessionOptions() + if sys.version_info < (3, 11) and find_spec("onnxruntime_extensions"): # pragma: no cover + from onnxruntime_extensions import get_library_path + + so.register_custom_ops_library(get_library_path()) + if model.is_large_model: + onnx.save_model( + model.model, + model.model_path + "_augment.onnx", + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + + session = ( + ort.InferenceSession(model.model.SerializeToString(), so, providers=providers) + if not model.is_large_model + else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers) + ) + inputs_names = [i.name for i in session.get_inputs()] + del session + + inputs = [] + for i, data in enumerate(dataloader): + if n_samples != -1 and ((i + 1) * dataloader.batch_size) > n_samples: + break + if len(inputs_names) != 1 or isinstance(data[0], dict): + assert len(data[0]) == len(inputs_names), ( + f"Input number mismatch, require {len(inputs_names)} but get {len(data[0])}" + ) + + if isinstance(data[0], dict): + inputs.append(dict([(name, to_numpy(inp_data)) for name, inp_data in data[0].items()])) # noqa: C404 + elif isinstance(data[0], np.ndarray): # pragma: no cover + inputs.append(dict([(name, inp) for name, inp in zip(inputs_names, [data[0]], strict=False)])) # noqa: C404 + else: # pragma: no cover + inputs.append(dict([(name, to_numpy(inp)) for name, inp in zip(inputs_names, data[0], strict=False)])) # noqa: C404 + return inputs, so + + +def gptq( + W, + H, + num_bits=4, + group_size=32, + scheme="asym", + blocksize=128, + percdamp=0.01, + actorder=False, + mse=False, + perchannel=True, +): + """Quant the weight with GPTQ method. + + Args: + W (array): weight. + H (array): Hessian matrix. + num_bits (int, optional): num_bits. Default is 4. + group_size (int, optional): how many elements share one scale/zp. Default is 32. + scheme (str, optional): sym or asym. Defaults to "asym". + blocksize (int, optional): blocksize to quantize weight. + percdamp (float, optional): percent of the average Hessian diagonal to use for dampening. + actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value. + mse (bool, optional): whether get scale and zero point with mse error. + perchannel (bool, optional): whether quantize weight per-channel. + + Returns: + Q: fake quantized weight + """ + maxq = 2**num_bits - 1 + grid = 100 + maxshrink = 0.8 + norm = 2.4 + + def find_params(weight): + org_shape = weight.shape + # find zp, scale + if not perchannel: + weight = np.expand_dims(weight.flatten(), axis=1) + tmp = np.zeros(weight.shape[1]) + xmin = np.minimum(np.min(weight, axis=0), tmp) + xmax = np.maximum(np.max(weight, axis=0), tmp) + if scheme == "sym": + xmax = np.maximum(np.abs(xmin), xmax) + tmp = xmin < 0 + if np.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + scale = (xmax - xmin) / maxq + if scheme == "sym": + zero = np.ones(scale.shape) * (maxq + 1) / 2 + else: + zero = np.round(-xmin / scale) + if mse: + best = np.ones([weight.shape[1]]) * float("inf") + for i in range(int(maxshrink * grid)): + p = 1 - i / grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / maxq + zero1 = np.round(-xmin1 / scale1) if scheme != "sym" else zero + q = np.clip(np.round(weight / scale1) + zero1, 0, maxq) + q -= weight + q = np.power(np.abs(q), norm) + err = np.sum(q, 0) + tmp = err < best + if np.any(tmp): + best[tmp] = err[tmp] + scale[tmp] = scale1[tmp] + zero[tmp] = zero1[tmp] + if not perchannel: + tmp = org_shape[1] + scale = np.repeat(scale, tmp) + zero = np.repeat(zero, tmp) + shape = [-1] + [1] * (len(org_shape) - 1) + scale = np.reshape(scale, shape) + zero = np.reshape(zero, shape) + return scale, zero + + shape = W.shape + scale, zp = find_params(W) + dead = np.diag(H) == 0 + H[dead, dead] = 1 + W[dead, :] = 0 # such channel makes no contribution to quantization computation + + # rearrange considering the diag's value + if actorder: + perm = np.argsort(np.diag(H))[::-1] + W = W[perm, :] # noqa: N806 + H = H[perm, :][:, perm] # noqa: N806 + Losses = np.zeros_like(W) # noqa: N806 + Q = np.zeros_like(W) # noqa: N806 + damp = percdamp * np.mean(np.diag(H)) + diag = np.arange(shape[0]) + H[diag, diag] += damp # add a average value of + H = np.linalg.cholesky(np.linalg.inv(H)).T # noqa: N806 + Hinv = H # noqa: N806 + for i1 in range(0, shape[0], blocksize): + i2 = min(i1 + blocksize, shape[0]) + count = i2 - i1 + + W1 = copy.deepcopy(W[i1:i2, :]) # noqa: N806 + Q1 = np.zeros_like(W1) # noqa: N806 + Err1 = np.zeros_like(W1) # noqa: N806 + Losses1 = np.zeros_like(W1) # noqa: N806 + Hinv1 = Hinv[i1:i2, i1:i2] # noqa: N806 + + for i in range(count): # within a block, channel wise + w = W1[i, :] + d = Hinv1[i, i] + + if group_size != -1: + if (i1 + i) % group_size == 0: + scale, zp = find_params(W[(i1 + i) : (i1 + i + group_size), :]) + + q = (scale * (np.clip(np.round(w[:, np.newaxis] / scale) + zp, 0, maxq) - zp)).flatten() + Q1[i, :] = q + Losses1[i, :] = (w - q) ** 2 / d**2 + + err1 = (w - q) / d + W1[i:, :] -= np.matmul(np.expand_dims(Hinv1[i:, i], axis=1), np.expand_dims(err1, axis=0)) + Err1[i, :] = err1 + + Q[i1:i2, :] = Q1 + Losses[i1:i2, :] = Losses1 / 2 + + W[i2:, :] -= np.matmul(Hinv[i2:, i1:i2], Err1) + + if actorder: + invperm = np.argsort(perm) + Q = Q[invperm, :] # noqa: N806 + + Q = np.reshape(Q, W.shape) # noqa: N806 + del W + return Q + + +def gptq_quantize( + model, + dataloader, + weight_config={}, # noqa: B006 + num_bits=4, + group_size=32, + scheme="asym", + n_samples=128, + percdamp=0.01, + blocksize=128, + actorder=False, + mse=False, + perchannel=True, + accuracy_level=0, + providers=["CPUExecutionProvider"], # noqa: B006 +): + """Quant the model with GPTQ method. + + Args: + model (ModelProto or ONNXModel): onnx model + dataloader (object): dataloader for calibration. + weight_config (dict): quantization config + For example, + weight_config = { + 'fc2': + { + 'bits': 4, + 'group_size': 32, + 'scheme': 'sym', + 'algorithm': 'GPTQ' + } + } + num_bits (int, optional): num_bits. Default is 4. + group_size (int, optional): how many elements share one scale/zp. Default is 32. + scheme (str, optional): sym or asym. Defaults to "asym". + n_samples (int, optional): calibration sample number. + percdamp (float, optional): percent of the average Hessian diagonal to use for dampening. + blocksize (int, optional): blocksize to quantize weight. + actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value. + mse (bool, optional): whether get scale and zero point with mse error. + perchannel (bool, optional): whether quantize weight per-channel. + accuracy_level (int): accuracy level. Support 0 (unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8). + providers (list): providers to use + + Returns: + model: fake quantized ONNXModel + """ + model = ONNXModel(model) + base_dir = os.path.dirname(model.model_path) if model.model_path is not None else "" + + inputs, so = prepare_inputs(model, n_samples, dataloader, providers) + del dataloader + org_output = copy.deepcopy(model.model.graph.output) + model.remove_tensors_from_outputs([i.name for i in org_output]) + output_names = [] + for node in model.nodes(): + if ( + node.op_type in ["MatMul"] + and weight_config.get(node.name, {}) != "fp32" + and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ" + ): + output_names.append(node.input[0]) + output_names = list(set(output_names)) + model.add_tensors_to_outputs(output_names) + if model.is_large_model: + onnx.save_model( + model.model, + model.model_path + "_augment.onnx", + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + + session = ( + ort.InferenceSession(model.model.SerializeToString(), so, providers=providers) + if not model.is_large_model + else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers) + ) + + for idx, input_name in enumerate(output_names): + simple_progress_bar(len(output_names), idx + 1) + node_list = [] + weights = [] + + for node in model.input_name_to_nodes[input_name]: + if ( + node.op_type in ["MatMul"] + and weight_config.get(node.name, {}) != "fp32" + and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ" + and model.get_initializer(node.input[1]) is not None + ): + weight = numpy_helper.to_array( + model.get_initializer(model.get_node(node.name).input[1]), base_dir + ).copy() + if len(weight.shape) != 2: + continue + + weights.append(weight) + node_list.append(model.get_node(node.name)) + + if len(weights) == 0: + continue + + Hs = [np.zeros((i.shape[0], i.shape[0])) for i in weights] # noqa: N806 + nsamples = 0 + for data in inputs: + inp = session.run([input_name], data)[0] + tmp = inp.shape[0] + inp = np.reshape(inp, (-1, inp.shape[-1])) + Hs = [i * (nsamples / (nsamples + tmp)) for i in Hs] # noqa: N806 + nsamples += tmp + inp = np.sqrt(2 / nsamples) * inp + Hs = [i + np.matmul(inp.T, inp) for i in Hs] # noqa: N806 + + for ( + node, + weight, + H, # noqa: N806 + ) in zip(node_list, weights, Hs, strict=False): + if node.name in weight_config: + num_bits = weight_config[node.name]["bits"] + group_size = weight_config[node.name]["group_size"] + scheme = weight_config[node.name]["scheme"] + group_size = group_size if group_size != -1 else weight.shape[0] + dtype = weight.dtype + + q_weight = gptq( + weight, + H, + num_bits=num_bits, + group_size=group_size, + scheme=scheme, + blocksize=blocksize, + percdamp=percdamp, + actorder=actorder, + mse=mse, + perchannel=perchannel, + ) + + weight_tensor = model.get_initializer(node.input[1]) + init_share_num = model.get_initializer_share_num(node.input[1]) + + satisfy_MatMulNBits_condition = num_bits == 4 # noqa: N806 + + if satisfy_MatMulNBits_condition: # pragma: no cover + org_shape = weight.shape + k_blocks = (org_shape[0] + group_size - 1) // group_size + q_weight = pad_tensor(q_weight, group_size, k_blocks) + q_weight, scale, zp = quant_tensor(q_weight.T, num_bits, group_size, scheme, "uint") + q_matmul_node, new_inits = make_matmul_weight_only_node( + node=node, + weight_shape=org_shape, + num_bits=num_bits, + group_size=group_size, + k_blocks=k_blocks, + q_weight=q_weight.astype("uint8"), + scale=scale.astype(dtype), + zero_point=zp if scheme == "asym" else None, + accuracy_level=accuracy_level, + ) + + model.add_initializers(new_inits) + model.remove_node(node) + model.add_node(q_matmul_node) + else: + q_weight_tensor = onnx.helper.make_tensor( + name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}", + data_type=np_dtype_to_tensor_dtype(dtype), + dims=q_weight.shape, + vals=q_weight.astype(dtype).tobytes(), + raw=True, + ) + model.add_initializer(q_weight_tensor) + node.input[1] = q_weight_tensor.name + if init_share_num == 1: + model.remove_initializer(weight_tensor) + + model.remove_tensors_from_outputs(output_names) + model.model.graph.output.MergeFrom(org_output) + + model.topological_sort() + + # reload external data to prevent external data file path errors + if model.is_large_model: + from onnx.external_data_helper import load_external_data_for_model + + load_external_data_for_model(model.model, os.path.split(model.model_path)[0]) + + return model diff --git a/onnxruntime/test/autoep/test_autoep_selection.cc b/onnxruntime/test/autoep/test_autoep_selection.cc index 619f0a4bcda33..cea1299adc26f 100644 --- a/onnxruntime/test/autoep/test_autoep_selection.cc +++ b/onnxruntime/test/autoep/test_autoep_selection.cc @@ -64,7 +64,11 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod const std::vector& expected_dims_y, const std::vector& expected_values_y, bool auto_select = true, // auto select vs SessionOptionsAppendExecutionProvider_V2 + // manual select using functor const std::function&)>& select_devices = nullptr, + // auto select using policy + std::optional policy = std::nullopt, + std::optional delegate = std::nullopt, bool test_session_creation_only = false) { Ort::SessionOptions session_options; @@ -74,16 +78,22 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod } if (auto_select) { - // manually specify EP to select for now - session_options.AddConfigEntry("test.ep_to_select", ep_to_select.c_str()); - - // add the provider options to the session options with the required prefix - const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(ep_to_select.c_str()); - std::vector keys, values; - ep_options.GetKeyValuePairs(keys, values); - for (size_t i = 0, end = keys.size(); i < end; ++i) { - // add the default value with prefix - session_options.AddConfigEntry((option_prefix + keys[i]).c_str(), values[i]); + if (delegate) { + session_options.SetEpSelectionPolicy(*delegate, nullptr); + } else if (policy) { + session_options.SetEpSelectionPolicy(*policy); + } else { + // manually specify EP to select + session_options.AddConfigEntry("test.ep_to_select", ep_to_select.c_str()); + + // add the provider options to the session options with the required prefix + const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(ep_to_select.c_str()); + std::vector keys, values; + ep_options.GetKeyValuePairs(keys, values); + for (size_t i = 0, end = keys.size(); i < end; ++i) { + // add the default value with prefix + session_options.AddConfigEntry((option_prefix + keys[i]).c_str(), values[i]); + } } } else { std::vector devices; @@ -188,7 +198,7 @@ TEST(AutoEpSelection, DmlEP) { devices.push_back(ep_device); } else { // if this is available, 0 == best performance - auto* perf_index = c_api->GetKeyValue(kvps, "HighPerformanceIndex"); + auto* perf_index = c_api->GetKeyValue(kvps, "DxgiHighPerformanceIndex"); if (perf_index && strcmp(perf_index, "0") == 0) { devices[0] = ep_device; // replace as this is the higher performance device } @@ -213,20 +223,27 @@ TEST(AutoEpSelection, WebGpuEP) { TEST(AutoEpSelection, MiscApiTests) { const OrtApi* c_api = &Ort::GetApi(); - // nullptr and empty input to OrtKeyValuePairs + // nullptr and empty input to OrtKeyValuePairs. also test RemoveKeyValuePair { OrtKeyValuePairs* kvps = nullptr; c_api->CreateKeyValuePairs(&kvps); c_api->AddKeyValuePair(kvps, "key1", nullptr); // should be ignored c_api->AddKeyValuePair(kvps, nullptr, "value1"); // should be ignored c_api->RemoveKeyValuePair(kvps, nullptr); // should be ignored - - c_api->AddKeyValuePair(kvps, "", "value2"); // empty key should be ignored + c_api->AddKeyValuePair(kvps, "", "value2"); // should be ignored ASSERT_EQ(c_api->GetKeyValue(kvps, ""), nullptr); + c_api->AddKeyValuePair(kvps, "key1", "value1"); c_api->AddKeyValuePair(kvps, "key2", ""); // empty value is allowed ASSERT_EQ(c_api->GetKeyValue(kvps, "key2"), std::string("")); + c_api->RemoveKeyValuePair(kvps, "key1"); + const char* const* keys = nullptr; + const char* const* values = nullptr; + size_t num_entries = 0; + c_api->GetKeyValuePairs(kvps, &keys, &values, &num_entries); + ASSERT_EQ(num_entries, 1); + c_api->ReleaseKeyValuePairs(kvps); } @@ -259,6 +276,230 @@ TEST(AutoEpSelection, MiscApiTests) { } } +TEST(AutoEpSelection, PreferCpu) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_PREFER_CPU); +} + +// this should fallback to CPU if no GPU +TEST(AutoEpSelection, PreferGpu) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_PREFER_GPU); +} + +// this should fallback to CPU if no NPU +TEST(AutoEpSelection, PreferNpu) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_PREFER_NPU); +} + +static OrtStatus* ORT_API_CALL PolicyDelegate(_In_ const OrtEpDevice** ep_devices, + _In_ size_t num_devices, + _In_ const OrtKeyValuePairs* model_metadata, + _In_opt_ const OrtKeyValuePairs* /*runtime_metadata*/, + _Inout_ const OrtEpDevice** selected, + _In_ size_t max_selected, + _Out_ size_t* num_selected, + _In_ void* /*state*/) { + *num_selected = 0; + + if (max_selected <= 2) { + return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Expected to be able to select 2 devices."); + } + + if (model_metadata->entries.empty()) { + return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Model metadata was empty."); + } + + selected[0] = ep_devices[0]; + *num_selected = 1; + if (num_devices > 1) { + // CPU EP is always last. + selected[1] = ep_devices[num_devices - 1]; + *num_selected = 2; + } + + return nullptr; +} + +static OrtStatus* ORT_API_CALL PolicyDelegateSelectNone(_In_ const OrtEpDevice** /*ep_devices*/, + _In_ size_t /*num_devices*/, + _In_ const OrtKeyValuePairs* /*model_metadata*/, + _In_opt_ const OrtKeyValuePairs* /*runtime_metadata*/, + _Inout_ const OrtEpDevice** /*selected*/, + _In_ size_t /*max_selected*/, + _Out_ size_t* num_selected, + _In_ void* /*state*/) { + *num_selected = 0; + + return nullptr; +} + +static OrtStatus* ORT_API_CALL PolicyDelegateReturnError(_In_ const OrtEpDevice** /*ep_devices*/, + _In_ size_t /*num_devices*/, + _In_ const OrtKeyValuePairs* /*model_metadata*/, + _In_opt_ const OrtKeyValuePairs* /*runtime_metadata*/, + _Inout_ const OrtEpDevice** /*selected*/, + _In_ size_t /*max_selected*/, + _Out_ size_t* num_selected, + _In_ void* /*state*/) { + *num_selected = 0; + + return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Selection error."); +} + +// test providing a delegate +TEST(AutoEpSelection, PolicyDelegate) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + std::nullopt, + PolicyDelegate); +} + +// test providing a delegate +TEST(AutoEpSelection, PolicyDelegateSelectsNothing) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + ASSERT_THROW( + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + std::nullopt, + PolicyDelegateSelectNone, + /*test_session_creation_only*/ true), + Ort::Exception); +} + +TEST(AutoEpSelection, PolicyDelegateReturnsError) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + ASSERT_THROW( + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + std::nullopt, + PolicyDelegateReturnError, + /*test_session_creation_only*/ true), + Ort::Exception); +} + namespace { struct ExamplePluginInfo { const std::filesystem::path library_path = diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index b29fc5181eb46..257d3b3efdf9c 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #ifndef ORT_MINIMAL_BUILD -#if (defined(MLAS_TARGET_AMD64_IX86) && !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML)) || defined(USE_CUDA) +#if (defined(MLAS_TARGET_AMD64_IX86) && !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML)) || defined(USE_CUDA) || defined(USE_WEBGPU) #include @@ -186,6 +186,10 @@ void RunTest8Bits(const TestOptions8Bits& opts) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_WEBGPU + execution_providers.emplace_back(DefaultWebGpuExecutionProvider()); +#endif +#if defined(USE_CUDA) || defined(USE_WEBGPU) test.ConfigEps(std::move(execution_providers)); test.RunWithConfig(); execution_providers.clear(); @@ -226,8 +230,8 @@ void TestMatMul8BitsTyped() { RunTest8Bits(opts); } -// CUDA does not support bias for MatMulNBits -#if not defined(USE_CUDA) +// CUDA/WEBGPU does not support bias for MatMulNBits +#if !defined(USE_CUDA) && !defined(USE_WEBGPU) { TestOptions8Bits opts = base_opts; opts.has_bias = true; @@ -279,7 +283,7 @@ TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float) { TestMatMul8BitsTyped(); } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_WEBGPU) TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float16) { TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index 7c3dc617ffb12..2104f7a35c078 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -671,6 +671,7 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) { } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize8) { + ROCM_GTEST_SKIP("ROCm MHA skip - missing support for ROCm on Radeon"); AttentionTestData data; GetCrossAttentionData_HeadSize8_NoBias(data); RunMultiHeadAttentionTests(data, DISABLE_CUDA); diff --git a/onnxruntime/test/perftest/strings_helper.cc b/onnxruntime/test/perftest/strings_helper.cc index e09c8fac70887..9fd49da1d0486 100644 --- a/onnxruntime/test/perftest/strings_helper.cc +++ b/onnxruntime/test/perftest/strings_helper.cc @@ -41,7 +41,7 @@ void ParseSessionConfigs(const std::string& configs_string, available_keys_str += ", "; } ORT_THROW("[ERROR] wrong key type entered : `", key, - "`. The following runtime key options are avaible: [", available_keys_str, "]"); + "`. The following runtime key options are available: [", available_keys_str, "]"); } auto it = session_configs.find(key); diff --git a/onnxruntime/test/providers/qnn/pool_op_test.cpp b/onnxruntime/test/providers/qnn/pool_op_test.cpp index ae194bd2ef920..d777b1134d060 100644 --- a/onnxruntime/test/providers/qnn/pool_op_test.cpp +++ b/onnxruntime/test/providers/qnn/pool_op_test.cpp @@ -188,6 +188,80 @@ TEST_F(QnnHTPBackendTests, MaxPool_Large_Input_HTP_u8) { QDQTolerance(0.00417f)); } +TEST_F(QnnHTPBackendTests, MaxPool1D_ReshapeNodesPresent) { + auto build_test_case = [](ModelTestBuilder& builder) { + NodeArg* input = builder.MakeInput(std::vector{1, 3, 3}, + GetFloatDataInRange(-10.0f, 10.0f, 9)); + NodeArg* output = builder.MakeOutput(); + auto& maxpool_node = builder.AddNode("MaxPool", {input}, {output}); + maxpool_node.AddAttribute("kernel_shape", std::vector{3}); + maxpool_node.AddAttribute("strides", std::vector{3}); + maxpool_node.AddAttribute("pads", std::vector{0, 0}); + maxpool_node.AddAttribute("ceil_mode", static_cast(0)); + maxpool_node.AddAttribute("storage_order", static_cast(0)); + maxpool_node.AddAttribute("auto_pad", "NOTSET"); + }; + + // Build and serialize the model + auto& logging_manager = DefaultLoggingManager(); + onnxruntime::Model model("maxpool1d", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {}, {}, + logging_manager.DefaultLogger()); + ModelTestBuilder builder(model.MainGraph()); + build_test_case(builder); + builder.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + // Setup session options and register QNN HTP EP + SessionOptions so; + ProviderOptions options; + options["backend_type"] = "htp"; + + InferenceSessionWrapper session{so, GetEnvironment()}; + auto qnn_ep = QnnExecutionProviderWithOptions(options, &so); + ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(qnn_ep))); + ASSERT_STATUS_OK(session.Load(model_data.data(), static_cast(model_data.size()))); + ASSERT_STATUS_OK(session.Initialize()); + const Graph& graph = session.GetGraph(); + int number_of_nodes = graph.NumberOfNodes(); + + // The Reshape -> Pool -> Reshape gets fused to a single QNN node + EXPECT_EQ(number_of_nodes, 1) << "Expected 1 QNN fused node for MaxPool rank-3 input."; +} + +// 1-D MaxPool HTP test for rank-3 without ceil +TEST_F(QnnHTPBackendTests, MaxPool_Rank3_HTP_u8) { + RunQDQPoolOpTest( + "MaxPool", + TestInputDef({1, 3, 3}, false, -10.0f, 10.0f), + // A single 1-D kernel of length 3 + {utils::MakeAttribute("kernel_shape", std::vector{3}), + utils::MakeAttribute("strides", std::vector{3}), + // 1-D pad: only two values + utils::MakeAttribute("pads", std::vector{0, 0}), + utils::MakeAttribute("dilations", std::vector{1}), + utils::MakeAttribute("ceil_mode", static_cast(0)), + utils::MakeAttribute("storage_order", static_cast(0)), + utils::MakeAttribute("auto_pad", "NOTSET")}, + ExpectedEPNodeAssignment::All); +} + +// 1-D MaxPool HTP test for rank-3 with ceil_mode=1 +TEST_F(QnnHTPBackendTests, MaxPool_Rank3_Ceil_HTP_u8) { + RunQDQPoolOpTest( + "MaxPool", + TestInputDef({1, 3, 3}, false, -10.0f, 10.0f), + {utils::MakeAttribute("kernel_shape", std::vector{3}), + utils::MakeAttribute("strides", std::vector{3}), + utils::MakeAttribute("pads", std::vector{0, 0}), + utils::MakeAttribute("dilations", std::vector{1}), + utils::MakeAttribute("ceil_mode", static_cast(1)), + utils::MakeAttribute("storage_order", static_cast(0)), + utils::MakeAttribute("auto_pad", "NOTSET")}, + ExpectedEPNodeAssignment::All); +} + TEST_F(QnnHTPBackendTests, MaxPool_Ceil_HTP_u8) { RunQDQPoolOpTest("MaxPool", TestInputDef({1, 2, 3, 3}, false, -10.0f, 10.0f), // Dynamic input with range [-10, 10] @@ -219,6 +293,19 @@ TEST_F(QnnHTPBackendTests, DISABLED_MaxPool_Large_Input2_Ceil_HTP_u8) { ExpectedEPNodeAssignment::All); } +TEST_F(QnnHTPBackendTests, MaxPool_Large_Input3_AutoPadValid_HTP_u8) { + RunQDQPoolOpTest("MaxPool", + TestInputDef({1, 160, 14, 20}, false, -10.0f, 10.0f), // Dynamic input with range [-10, 10] + {utils::MakeAttribute("kernel_shape", std::vector{2, 2}), + utils::MakeAttribute("strides", std::vector{2, 2}), + utils::MakeAttribute("pads", std::vector{0, 0, 0, 0}), + utils::MakeAttribute("dilations", std::vector{1, 1}), + utils::MakeAttribute("ceil_mode", static_cast(0)), + utils::MakeAttribute("storage_order", static_cast(0)), + utils::MakeAttribute("auto_pad", "VALID")}, + ExpectedEPNodeAssignment::All); +} + // QNN v2.13: Certain large input sizes cause the QNN graph to fail to finalize with error 1002 (QNN_COMMON_ERROR_MEM_ALLOC). // Fixed in QNN v2.14.1. TEST_F(QnnHTPBackendTests, MaxPool_LargeInput_1Pads_u8) { diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index f736abcd3006d..0212dacadbced 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -16,6 +16,7 @@ #include "core/session/onnxruntime_run_options_config_keys.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/util/include/api_asserts.h" #include "gtest/gtest.h" #include "gmock/gmock.h" @@ -37,24 +38,24 @@ namespace test { // TODO: When we need QNN in a minimal build we should add an ORT format version of the model #if !defined(ORT_MINIMAL_BUILD) +static bool SessionHasEp(Ort::Session& session, const char* ep_name) { + // Access the underlying InferenceSession. + const OrtSession* ort_session = session; + const InferenceSession* s = reinterpret_cast(ort_session); + bool has_ep = false; + + for (const auto& provider : s->GetRegisteredProviderTypes()) { + if (provider == ep_name) { + has_ep = true; + break; + } + } + return has_ep; +} + // Tests that the QNN EP is registered when added via the public C++ API. // Loads a simple ONNX model that adds floats. TEST_F(QnnHTPBackendTests, TestAddEpUsingPublicApi) { - auto session_has_qnn_ep = [](Ort::Session& session) -> bool { - // Access the underlying InferenceSession. - const OrtSession* ort_session = session; - const InferenceSession* s = reinterpret_cast(ort_session); - bool have_qnn_ep = false; - - for (const auto& provider : s->GetRegisteredProviderTypes()) { - if (provider == kQnnExecutionProvider) { - have_qnn_ep = true; - break; - } - } - return have_qnn_ep; - }; - onnxruntime::ProviderOptions options; #if defined(_WIN32) options["backend_path"] = "QnnHtp.dll"; @@ -77,8 +78,9 @@ TEST_F(QnnHTPBackendTests, TestAddEpUsingPublicApi) { so.AppendExecutionProvider("QNN", options); Ort::Session session(*ort_env, ort_model_path, so); - ASSERT_TRUE(session_has_qnn_ep(session)) << "QNN EP was not found in registered providers for session " - << "when added to session with name 'QNN'"; + ASSERT_TRUE(SessionHasEp(session, kQnnExecutionProvider)) + << "QNN EP was not found in registered providers for session " + << "providers for session when added to session with name 'QNN'"; } { @@ -92,8 +94,9 @@ TEST_F(QnnHTPBackendTests, TestAddEpUsingPublicApi) { so.AppendExecutionProvider(kQnnExecutionProvider, options); Ort::Session session(*ort_env, ort_model_path, so); - ASSERT_TRUE(session_has_qnn_ep(session)) << "QNN EP was not found in registered providers for session " - << "when added to session with name '" << kQnnExecutionProvider << "'"; + ASSERT_TRUE(SessionHasEp(session, kQnnExecutionProvider)) + << "QNN EP was not found in registered providers for session " + << "when added to session with name '" << kQnnExecutionProvider << "'"; } } @@ -1265,6 +1268,24 @@ TEST_F(QnnHTPBackendTests, LoadingAndUnloadingOfQnnLibrary_FixSegFault) { } #endif // !BUILD_QNN_EP_STATIC_LIB +#if defined(WIN32) && !BUILD_QNN_EP_STATIC_LIB +// Tests autoEP feature to automatically select an EP that supports the NPU. +// Currently only works on Windows. +TEST_F(QnnHTPBackendTests, AutoEp_PreferNpu) { + ASSERT_ORTSTATUS_OK(Ort::GetApi().RegisterExecutionProviderLibrary(*ort_env, kQnnExecutionProvider, + ORT_TSTR("onnxruntime_providers_qnn.dll"))); + + Ort::SessionOptions so; + so.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_NPU); + + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx"; + Ort::Session session(*ort_env, ort_model_path, so); + EXPECT_TRUE(SessionHasEp(session, kQnnExecutionProvider)); + + ASSERT_ORTSTATUS_OK(Ort::GetApi().UnregisterExecutionProviderLibrary(*ort_env, kQnnExecutionProvider)); +} +#endif // defined(WIN32) && !BUILD_QNN_EP_STATIC_LIB + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/resize_test.cc b/onnxruntime/test/providers/qnn/resize_test.cc index fbd729fa998d9..702d4e6eddb1b 100644 --- a/onnxruntime/test/providers/qnn/resize_test.cc +++ b/onnxruntime/test/providers/qnn/resize_test.cc @@ -399,6 +399,15 @@ TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestHalfPixelRoundPreferFloor_Unsupport ExpectedEPNodeAssignment::None); // No longer supported as of QNN SDK 2.21 } +// Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "half_pixel", nearest_mode: "round_prefer_Ceil" +// Maps to QNN's ResizeNearesetNeighbor operator. +TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestHalfPixelRoundPreferCeil) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), + {1, 3, 8, 8}, "nearest", "half_pixel", "round_prefer_ceil", + ExpectedEPNodeAssignment::All); +} + // Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "align_corners", nearest_mode: "round_prefer_ceil" // Maps to QNN's Resize operator. // UPDATE: "round_prefer_ceil" is supported as of QNN SDK 2.21 if using "align_corners". (Unsupported in QNN SDK 2.19). diff --git a/onnxruntime/test/python/autoep_helper.py b/onnxruntime/test/python/autoep_helper.py new file mode 100644 index 0000000000000..e3b214afa6e62 --- /dev/null +++ b/onnxruntime/test/python/autoep_helper.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +from __future__ import annotations + +import os +import tempfile +import unittest + +import onnxruntime as onnxrt +from onnxruntime.capi.onnxruntime_pybind11_state import Fail + + +class AutoEpTestCase(unittest.TestCase): + """ + Base class for TestCase classes that need to register and unregister EP libraries. + Because EP libraries are registered with the ORT environment and all unit tests share + the same environment, this class tracks which libraries have already been registered + so that they are not erroneously registered or unregistered. + + Derived classes must use 'self.register_execution_provider_library()' and + 'self.unregister_execution_provider_library()' to benefit from these utilities. + """ + + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.autoep_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + # Track registered EP libraries across all tests. + cls._registered_providers = set() + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def register_execution_provider_library(self, ep_registration_name: str, ep_lib_path: os.PathLike | str): + if ep_registration_name in self._registered_providers: + return # Already registered + + try: + onnxrt.register_execution_provider_library(ep_registration_name, ep_lib_path) + except Fail as onnxruntime_error: + if "already registered" in str(onnxruntime_error): + pass # Allow register to fail if the EP library was previously registered. + else: + raise onnxruntime_error + + # Add this EP library to set of registered EP libraries. + # If the unit test itself does not unregister the library, tearDown() will try. + self._registered_providers.add(ep_registration_name) + + def unregister_execution_provider_library(self, ep_registration_name: str): + if ep_registration_name not in self._registered_providers: + return # Not registered + + try: + onnxrt.unregister_execution_provider_library(ep_registration_name) + except Fail as onnxruntime_error: + if "was not registered" in str(onnxruntime_error): + pass # Allow unregister to fail if the EP library was never registered. + else: + raise onnxruntime_error + + self._registered_providers.remove(ep_registration_name) diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index f3ebc92409f77..0558b44ae6b47 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -1360,13 +1360,20 @@ def test_register_custom_ops_library(self): ) def test_ort_value(self): + providers_to_test = onnxrt.get_available_providers() numpy_arr_input = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) numpy_arr_output = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) - def test_session_with_ortvalue_input(ortvalue): - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) + def test_session_with_ortvalue_input(ortvalue, providers): + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=providers) res = sess.run(["Y"], {"X": ortvalue}) - self.assertTrue(np.array_equal(res[0], numpy_arr_output)) + + if "QNNExecutionProvider" in providers: + # QNN runs float32 with fp16 precision, so relax accuracy expectations + np.testing.assert_allclose(numpy_arr_output, res[0], rtol=1e-04, atol=1e-06) + else: + self.assertTrue(np.array_equal(res[0], numpy_arr_output)) + vect = sess._sess.run_with_ort_values({"X": ortvalue._get_c_value()}, ["Y"], RunOptions()) self.assertIsInstance(vect, OrtValueVector) @@ -1375,10 +1382,12 @@ def test_session_with_ortvalue_input(ortvalue): self.assertEqual(ortvalue1.shape(), [3, 2]) self.assertEqual(ortvalue1.data_type(), "tensor(float)") self.assertEqual(ortvalue1.is_tensor(), True) + # Assumes float32 and shape {3, 2} as above + self.assertEqual(ortvalue1.tensor_size_in_bytes(), 4 * 2 * 3) self.assertTrue(np.array_equal(ortvalue1.numpy(), numpy_arr_input)) # Pass in the constructed OrtValue to a session via Run() and check results - test_session_with_ortvalue_input(ortvalue1) + test_session_with_ortvalue_input(ortvalue1, providers_to_test) # The constructed OrtValue should still be valid after being used in a session self.assertTrue(np.array_equal(ortvalue1.numpy(), numpy_arr_input)) @@ -1392,7 +1401,7 @@ def test_session_with_ortvalue_input(ortvalue): self.assertEqual(float_tensor_data_type, ort_value_with_type.element_type()) self.assertEqual([3, 2], ort_value_with_type.shape()) - if "CUDAExecutionProvider" in onnxrt.get_available_providers(): + if "CUDAExecutionProvider" in providers_to_test: ortvalue2 = onnxrt.OrtValue.ortvalue_from_numpy(numpy_arr_input, "cuda", 0) self.assertEqual(ortvalue2.device_name(), "cuda") self.assertEqual(ortvalue2.shape(), [3, 2]) @@ -1401,7 +1410,7 @@ def test_session_with_ortvalue_input(ortvalue): self.assertTrue(np.array_equal(ortvalue2.numpy(), numpy_arr_input)) # Pass in the constructed OrtValue to a session via Run() and check results - test_session_with_ortvalue_input(ortvalue2) + test_session_with_ortvalue_input(ortvalue2, providers_to_test) # The constructed OrtValue should still be valid after being used in a session self.assertTrue(np.array_equal(ortvalue2.numpy(), numpy_arr_input)) diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py new file mode 100644 index 0000000000000..417a6e27fb7b2 --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py @@ -0,0 +1,240 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +from __future__ import annotations + +import os +import platform +import sys +import unittest +from collections.abc import Sequence + +import numpy as np +from autoep_helper import AutoEpTestCase +from helper import get_name + +import onnxruntime as onnxrt +from onnxruntime.capi.onnxruntime_pybind11_state import Fail, InvalidArgument + +# handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed. +if platform.system() == "Windows" and sys.version_info.major >= 3 and sys.version_info.minor >= 8: # noqa: YTT204 + os.add_dll_directory(os.getcwd()) + +available_providers = list(onnxrt.get_available_providers()) + + +class TestAutoEP(AutoEpTestCase): + def test_cuda_ep_register_and_inference(self): + """ + Test registration of CUDA EP, adding its OrtDevice to the SessionOptions, and running inference. + """ + ep_lib_path = "onnxruntime_providers_cuda.dll" + ep_name = "CUDAExecutionProvider" + + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + + if ep_name not in available_providers: + self.skipTest("Skipping test because it needs to run on CUDA EP") + + self.register_execution_provider_library(ep_name, ep_lib_path) + + ep_devices = onnxrt.get_ep_devices() + has_cpu_ep = False + cuda_ep_device = None + for ep_device in ep_devices: + if ep_device.ep_name == "CPUExecutionProvider": + has_cpu_ep = True + if ep_device.ep_name == ep_name: + cuda_ep_device = ep_device + + self.assertTrue(has_cpu_ep) + self.assertIsNotNone(cuda_ep_device) + self.assertEqual(cuda_ep_device.ep_vendor, "Microsoft") + + hw_device = cuda_ep_device.device + self.assertEqual(hw_device.type, onnxrt.OrtHardwareDeviceType.GPU) + + # Add CUDA's OrtEpDevice to session options + sess_options = onnxrt.SessionOptions() + sess_options.add_provider_for_devices([cuda_ep_device], {"prefer_nhwc": "1"}) + self.assertTrue(sess_options.has_providers()) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + del sess # Delete session before unregistering library + self.unregister_execution_provider_library(ep_name) + + def test_cuda_prefer_gpu_and_inference(self): + """ + Test selecting CUDA EP via the PREFER_GPU policy and running inference. + """ + ep_lib_path = "onnxruntime_providers_cuda.dll" + ep_name = "CUDAExecutionProvider" + + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + + if ep_name not in available_providers: + self.skipTest("Skipping test because it needs to run on CUDA EP") + + self.register_execution_provider_library(ep_name, ep_lib_path) + + # Set a policy to prefer GPU. Cuda should be selected. + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + del sess # Delete session before unregistering library + self.unregister_execution_provider_library(ep_name) + + def test_cuda_ep_selection_delegate_and_inference(self): + """ + Test selecting CUDA EP via the custom EP selection delegate function and then run inference. + """ + ep_lib_path = "onnxruntime_providers_cuda.dll" + ep_name = "CUDAExecutionProvider" + + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + + if ep_name not in available_providers: + self.skipTest("Skipping test because it needs to run on CUDA EP") + + self.register_execution_provider_library(ep_name, ep_lib_path) + + # User's custom EP selection function. + def my_delegate( + ep_devices: Sequence[onnxrt.OrtEpDevice], + model_metadata: dict[str, str], + runtime_metadata: dict[str, str], + max_selections: int, + ) -> Sequence[onnxrt.OrtEpDevice]: + self.assertGreater(len(model_metadata), 0) + self.assertGreaterEqual(len(ep_devices), 2) + self.assertGreaterEqual(max_selections, 2) + + cuda_ep_device = next((d for d in ep_devices if d.ep_name == ep_name), None) + self.assertIsNotNone(cuda_ep_device) + + # Select the CUDA device and the ORT CPU EP device (should always be last) + return [cuda_ep_device, ep_devices[-1]] + + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy_delegate(my_delegate) + self.assertTrue(sess_options.has_providers()) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + del sess # Delete session before unregistering library + self.unregister_execution_provider_library(ep_name) + + def test_custom_ep_selection_delegate_that_raises_error(self): + """ + Test a custom EP selection delegate function that raises a Python exception. ORT should re-raise as FAIL. + """ + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + + # User's custom EP selection function. + custom_error_message = "MY ERROR" + + def my_delegate_that_fails( + ep_devices: Sequence[onnxrt.OrtEpDevice], + model_metadata: dict[str, str], + runtime_metadata: dict[str, str], + max_selections: int, + ) -> Sequence[onnxrt.OrtEpDevice]: + self.assertGreaterEqual(len(ep_devices), 1) + raise ValueError(custom_error_message) + + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy_delegate(my_delegate_that_fails) + + # Create session and expect ORT to raise a Fail exception that contains our message. + with self.assertRaises(Fail) as context: + onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + self.assertIn(custom_error_message, str(context.exception)) + + def test_example_plugin_ep_devices(self): + """ + Test registration of an example EP plugin and retrieval of its OrtEpDevice. + """ + if sys.platform != "win32": + self.skipTest("Skipping test because it device discovery is only supported on Windows") + + ep_lib_path = "example_plugin_ep.dll" + try: + ep_lib_path = get_name("example_plugin_ep.dll") + except FileNotFoundError: + self.skipTest(f"Skipping test because EP library '{ep_lib_path}' cannot be found") + + ep_name = "example_ep" + self.register_execution_provider_library(ep_name, os.path.realpath(ep_lib_path)) + + ep_devices = onnxrt.get_ep_devices() + has_cpu_ep = False + test_ep_device = None + for ep_device in ep_devices: + if ep_device.ep_name == "CPUExecutionProvider": + has_cpu_ep = True + if ep_device.ep_name == ep_name: + test_ep_device = ep_device + + self.assertTrue(has_cpu_ep) + self.assertIsNotNone(test_ep_device) + + # Test the OrtEpDevice getters. Expected values are from /onnxruntime/test/autoep/library/example_plugin_ep.cc + self.assertEqual(test_ep_device.ep_vendor, "Contoso") + + ep_metadata = test_ep_device.ep_metadata + self.assertEqual(ep_metadata["version"], "0.1") + + ep_options = test_ep_device.ep_options + self.assertEqual(ep_options["run_really_fast"], "true") + + # The CPU hw device info will vary by machine so check for the common values. + hw_device = test_ep_device.device + self.assertEqual(hw_device.type, onnxrt.OrtHardwareDeviceType.CPU) + self.assertGreaterEqual(hw_device.vendor_id, 0) + self.assertGreaterEqual(hw_device.device_id, 0) + self.assertGreater(len(hw_device.vendor), 0) + + hw_metadata = hw_device.metadata + self.assertGreater(len(hw_metadata), 0) # Should have at least SPDRP_HARDWAREID on Windows + + # Test adding this EP plugin's OrtEpDevice to the SessionOptions. + sess_options = onnxrt.SessionOptions() + with self.assertRaises(InvalidArgument) as context: + # Will raise InvalidArgument because ORT currently only supports provider bridge APIs. + # Actual plugin EPs will be supported in the future. + sess_options.add_provider_for_devices([test_ep_device], {"opt1": "val1"}) + self.assertIn("EP is not currently supported", str(context.exception)) + + self.unregister_execution_provider_library(ep_name) + + +if __name__ == "__main__": + unittest.main(verbosity=1) diff --git a/onnxruntime/test/python/onnxruntime_test_python_backend.py b/onnxruntime/test/python/onnxruntime_test_python_backend.py index 1f6cd78f28334..6ed7dfe59b1f6 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_backend.py +++ b/onnxruntime/test/python/onnxruntime_test_python_backend.py @@ -40,18 +40,28 @@ def test_allocation_plan_works_with_only_execute_path_to_fetches_option(self): This case is handled specifically in ExecutionFrame::AllocateAsPerAllocationPlan(). This test is to ensure that the case is covered. """ + providers = onnxrt.get_available_providers() + has_qnn_ep = "QNNExecutionProvider" in providers name = get_name("alloc_tensor_reuse.onnx") - sess = onnxrt.InferenceSession(name, providers=onnxrt.get_available_providers()) + sess = onnxrt.InferenceSession(name, providers=providers) run_options = onnxrt.RunOptions() run_options.only_execute_path_to_fetches = True inp0, inp1 = np.ones((10,), dtype=np.float32), np.ones((10,), dtype=np.float32) session_run_results = sess.run(["outp0"], {"inp0": inp0, "inp1": inp1}, run_options) - assert_allclose(session_run_results[0], -(inp0 + inp1)) + if has_qnn_ep: + # QNN EP runs fp32 with fp16 precision, so relax tolerance. + assert_allclose(session_run_results[0], -(inp0 + inp1), rtol=1e-6, atol=1e-6) + else: + assert_allclose(session_run_results[0], -(inp0 + inp1)) session_run_results = sess.run(["outp1"], {"inp0": inp0, "inp1": inp1}, run_options) - assert_allclose(session_run_results[0], -(inp0 - inp1)) + if has_qnn_ep: + # QNN EP runs fp32 with fp16 precision, so relax tolerance. + assert_allclose(session_run_results[0], -(inp0 - inp1), rtol=1e-6, atol=1e-6) + else: + assert_allclose(session_run_results[0], -(inp0 - inp1)) if __name__ == "__main__": diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py new file mode 100644 index 0000000000000..7a410d4bbeb6a --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py @@ -0,0 +1,226 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +from __future__ import annotations + +import os +import platform +import sys +import unittest +from collections.abc import Sequence + +import onnx +from autoep_helper import AutoEpTestCase +from helper import get_name + +import onnxruntime as onnxrt +from onnxruntime.capi.onnxruntime_pybind11_state import ModelRequiresCompilation + +# handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed. +if platform.system() == "Windows" and sys.version_info.major >= 3 and sys.version_info.minor >= 8: # noqa: YTT204 + os.add_dll_directory(os.getcwd()) + +available_providers = list(onnxrt.get_available_providers()) + + +class TestCompileApi(AutoEpTestCase): + def test_compile_with_files_prefer_npu_policy(self): + """ + Tests compiling a model (to/from files) using an EP selection policy (PREFER_NPU). + """ + if "QNNExecutionProvider" not in available_providers: + self.skipTest("Skipping test because it needs to run on QNN EP") + + if sys.platform != "win32": + self.skipTest("Skipping test because provider selection policies are only supported on Windows") + + ep_lib_path = "onnxruntime_providers_qnn.dll" + ep_name = "QNNExecutionProvider" + self.register_execution_provider_library(ep_name, ep_lib_path) + + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled0.onnx") + + session_options = onnxrt.SessionOptions() + session_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_NPU) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + model_compiler.compile_to_file(output_model_path) + self.assertTrue(os.path.exists(output_model_path)) + self.unregister_execution_provider_library(ep_name) + + def test_compile_with_ep_selection_delegate(self): + """ + Tests compiling a model (to/from files) using an EP selection delegate callback. + """ + if sys.platform != "win32": + self.skipTest("Skipping test because provider selection policies are only supported on Windows") + + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled.delegate.onnx") + + # User's custom EP selection function. + def my_delegate( + ep_devices: Sequence[onnxrt.OrtEpDevice], + model_metadata: dict[str, str], + runtime_metadata: dict[str, str], + max_selections: int, + ) -> Sequence[onnxrt.OrtEpDevice]: + self.assertGreater(len(ep_devices), 0) + self.assertGreater(len(model_metadata), 0) + self.assertGreater(max_selections, 0) + + # Select the first and last devices (if there are more than one) + selected_devices = [ep_devices[0]] + if max_selections > 2 and len(ep_devices) > 1: + selected_devices.append(ep_devices[-1]) # ORT CPU EP is always last + + return selected_devices + + session_options = onnxrt.SessionOptions() + session_options.set_provider_selection_policy_delegate(my_delegate) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + model_compiler.compile_to_file(output_model_path) + self.assertTrue(os.path.exists(output_model_path)) + + def test_compile_with_input_and_output_files(self): + """ + Tests compiling a model (to/from files) using explicit EP. + """ + provider = None + provider_options = dict() + if "QNNExecutionProvider" in available_providers: + provider = "QNNExecutionProvider" + provider_options["backend_type"] = "htp" + # TODO(adrianlizarraga): Allow test to run for other compiling EPs (e.g., OpenVINO) + + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled1.onnx") + + session_options = onnxrt.SessionOptions() + if provider: + session_options.add_provider(provider, provider_options) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + model_compiler.compile_to_file(output_model_path) + self.assertTrue(os.path.exists(output_model_path)) + + def test_compile_to_file_with_input_model_in_buffer(self): + """ + Tests compiling an input model that is stored in a buffer. The output is saved to a file. + """ + provider = None + provider_options = dict() + if "QNNExecutionProvider" in available_providers: + provider = "QNNExecutionProvider" + provider_options["backend_type"] = "htp" + # TODO(adrianlizarraga): Allow test to run for other compiling EPs (e.g., OpenVINO) + + input_onnx_model = onnx.load(get_name("nhwc_resize_scales_opset18.onnx")) + input_model_bytes = input_onnx_model.SerializeToString() + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled2.onnx") + + session_options = onnxrt.SessionOptions() + if provider: + session_options.add_provider(provider, provider_options) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_bytes, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + model_compiler.compile_to_file(output_model_path) + self.assertTrue(os.path.exists(output_model_path)) + + def test_compile_from_buffer_to_buffer(self): + """ + Tests compiling an input model that is stored in a buffer. The output is stored in a buffer too. + """ + provider = None + provider_options = dict() + if "QNNExecutionProvider" in available_providers: + provider = "QNNExecutionProvider" + provider_options["backend_type"] = "htp" + # TODO(adrianlizarraga): Allow test to run for other compiling EPs (e.g., OpenVINO) + + input_onnx_model = onnx.load(get_name("nhwc_resize_scales_opset18.onnx")) + input_model_bytes = input_onnx_model.SerializeToString() + + session_options = onnxrt.SessionOptions() + if provider: + session_options.add_provider(provider, provider_options) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_bytes, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + output_model_bytes = model_compiler.compile_to_bytes() + self.assertTrue(isinstance(output_model_bytes, bytes)) + self.assertGreater(len(output_model_bytes), 0) + + def test_fail_load_uncompiled_model_and_then_compile(self): + """ + Tests compiling scenario: + - Load uncompiled model into session that disables JIT compilation. + - Expect an error (ModelRequiresCompilation) + - Compile model and retry creating an inference session successfully. + """ + if "QNNExecutionProvider" not in available_providers: + self.skipTest("Skipping test because it needs to run on a compiling EP") + + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + + session_options = onnxrt.SessionOptions() + session_options.add_session_config_entry("session.disable_model_compile", "1") # Disable JIT model compilation! + session_options.add_provider("QNNExecutionProvider", {"backend_type": "htp"}) + + # Session creation should fail with error ORT_MODEL_REQUIRES_COMPILATION because the input model + # is not compiled and we disabled JIT compilation for this session. + with self.assertRaises(ModelRequiresCompilation) as context: + onnxrt.InferenceSession( + input_model_path, + sess_options=session_options, + enable_fallback=False, + ) + self.assertIn("needs to compile", str(context.exception)) + + # Try to compile the model now. + compiled_model_path = os.path.join(self._tmp_dir_path, "model.compiled3.onnx") + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path="external_weights.bin", + external_initializers_size_threshold=128, + ) + model_compiler.compile_to_file(compiled_model_path) + + self.assertTrue(os.path.exists(compiled_model_path)) + self.assertEqual(session_options.get_session_config_entry("session.disable_model_compile"), "1") + self.assertTrue(session_options.has_providers()) + + # Creating the session with the compiled model should not fail. + sess = onnxrt.InferenceSession(compiled_model_path, sess_options=session_options) + self.assertIsNotNone(sess) + + +if __name__ == "__main__": + unittest.main(verbosity=1) diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 0e739055b1772..03f4791c580e6 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -360,6 +360,8 @@ def test_quantize_matmul_int4_offsets_qdq(self): def test_quantize_matmul_int4_using_rtn_algo(self): if not find_spec("neural_compressor"): self.skipTest("skip test_smooth_quant since neural_compressor is not installed") + if not find_spec("torch"): + self.skipTest("skip test_quantize_matmul_int4_using_rtn_algo since torch is not installed") model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) self.construct_model_matmul(model_fp32_path, symmetric=False) data_reader = self.input_feeds(1, {"input": (100, 52)}) @@ -371,6 +373,8 @@ def test_quantize_matmul_int4_using_rtn_algo(self): def test_quantize_matmul_int4_using_gptq_algo(self): if not find_spec("neural_compressor"): self.skipTest("skip test_smooth_quant since neural_compressor is not installed") + if not find_spec("torch"): + self.skipTest("skip test_quantize_matmul_int4_using_gptq_algo since torch is not installed") model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) self.construct_model_matmul(model_fp32_path, symmetric=False) data_reader = self.input_feeds(1, {"input": (100, 52)}) diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 0a4ea71933724..6460e3cb3aec4 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -425,6 +425,39 @@ TEST_P(CApiTestWithProvider, simple) { nullptr, nullptr); } +template +void TestGetTensorSizeInBytes(Ort::ConstMemoryInfo cpu_meminfo) { + constexpr const size_t expected_size_in_bytes = sizeof(T) * element_count_to_create; + constexpr const std::array dims = {1, static_cast(element_count_to_create)}; + std::array data; + std::fill(data.begin(), data.end(), T{1}); + + auto value = Ort::Value::CreateTensor(cpu_meminfo, data.data(), + data.size(), dims.data(), dims.size()); + + auto type_info = value.GetTypeInfo(); + ASSERT_EQ(type_info.GetONNXType(), ONNX_TYPE_TENSOR); + auto tensor_type_info = type_info.GetTensorTypeAndShapeInfo(); + const auto element_count = tensor_type_info.GetElementCount(); + ASSERT_EQ(expected_size_in_bytes / sizeof(T), element_count); + ASSERT_EQ(expected_size_in_bytes, value.GetTensorSizeInBytes()); +} + +TEST(CApiTest, TestGetTensorSizeInBytes) { + Ort::MemoryInfo cpu_meminfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); +} + TEST(CApiTest, dim_param) { Ort::SessionOptions session_options; Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index df09920bddebd..8232a286d4480 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -104,6 +104,7 @@ Module["jsepInit"] = (name, params) => { Module["webnnEnsureTensor"], Module.webnnUploadTensor, Module["webnnDownloadTensor"], + Module["webnnEnableTraceEvent"], ] = params.slice(1); // This function is called from both JS and an EM_ASM block, it needs both a minifiable name and an explicit name. diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index 02408f6ed17e8..37df56216bde9 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -3,6 +3,6 @@ lintrunner==0.12.7 lintrunner-adapters==0.12.4 # RUFF -ruff==0.11.6 +ruff==0.11.9 # CLANGFORMAT clang-format==19.1.7 diff --git a/setup.py b/setup.py index 1e426ea8e060b..c45657c0c2873 100644 --- a/setup.py +++ b/setup.py @@ -517,6 +517,7 @@ def finalize_options(self): "onnxruntime.quantization.CalTableFlatBuffers", "onnxruntime.quantization.fusions", "onnxruntime.quantization.execution_providers.qnn", + "onnxruntime.quantization.neural_compressor", "onnxruntime.transformers", "onnxruntime.transformers.models.bart", "onnxruntime.transformers.models.bert", diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 03b51790e0ef6..8dce6be731402 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1228,9 +1228,62 @@ def generate_build_tree( ] env = {} if args.use_vcpkg: - env["VCPKG_KEEP_ENV_VARS"] = "TRT_UPLOAD_AUTH_TOKEN;EMSDK;EMSDK_NODE;EMSDK_PYTHON" + vcpkg_keep_env_vars = ["TRT_UPLOAD_AUTH_TOKEN"] + if args.build_wasm: - env["EMSDK"] = emsdk_dir + emsdk_vars = ["EMSDK", "EMSDK_NODE", "EMSDK_PYTHON"] + + # If environment variables 'EMSDK' is not set, run emsdk_env to set them + if "EMSDK" not in os.environ: + if is_windows(): + # Run `cmd /s /c call .\\emsdk_env && set` to run emsdk_env and dump the environment variables + emsdk_env = run_subprocess( + ["cmd", "/s", "/c", "call .\\emsdk_env && set"], + cwd=emsdk_dir, + capture_stdout=True, + ) + else: + # Run `sh -c ". ./emsdk_env.sh && printenv"` to run emsdk_env and dump the environment variables + emsdk_env = run_subprocess( + ["sh", "-c", ". ./emsdk_env.sh && printenv"], + cwd=emsdk_dir, + capture_stdout=True, + ) + + # check for EMSDK environment variables and set them in the environment + for line in emsdk_env.stdout.decode().splitlines(): + if "=" in line: + key, value = line.rstrip().split("=", 1) + if key in emsdk_vars: + os.environ[key] = value + + for var in emsdk_vars: + if var in os.environ: + env[var] = os.environ[var] + elif var == "EMSDK": + # EMSDK must be set, but EMSDK_NODE and EMSDK_PYTHON are optional + raise BuildError( + "EMSDK environment variable is not set correctly. Please run `emsdk_env` to set them." + ) + + vcpkg_keep_env_vars += emsdk_vars + + # + # Workaround for vcpkg failed to find the correct path of Python + # + # Since vcpkg does not inherit the environment variables `PATH` from the parent process, CMake will fail to + # find the Python executable if the Python executable is not in the default location. This usually happens + # to the Python installed by Anaconda. + # + # To minimize the impact of this problem, we set the `Python3_ROOT_DIR` environment variable to the + # directory of current Python executable. + # + # see https://cmake.org/cmake/help/latest/module/FindPython3.html + # + env["Python3_ROOT_DIR"] = str(Path(os.path.dirname(sys.executable)).resolve()) + vcpkg_keep_env_vars += ["Python3_ROOT_DIR"] + + env["VCPKG_KEEP_ENV_VARS"] = ";".join(vcpkg_keep_env_vars) run_subprocess( [*temp_cmake_args, f"-DCMAKE_BUILD_TYPE={config}"], @@ -1667,6 +1720,9 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): [sys.executable, "onnxruntime_test_python.py"], cwd=cwd, dll_path=dll_path, python_path=python_path ) + log.info("Testing AutoEP feature") + run_subprocess([sys.executable, "onnxruntime_test_python_autoep.py"], cwd=cwd, dll_path=dll_path) + if not args.disable_contrib_ops: run_subprocess([sys.executable, "onnxruntime_test_python_sparse_matmul.py"], cwd=cwd, dll_path=dll_path) @@ -1761,6 +1817,12 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): cwd=cwd, ) + if not args.disable_contrib_ops: + log.info("Testing Python Compile API") + run_subprocess( + [sys.executable, "onnxruntime_test_python_compile_api.py"], cwd=cwd, dll_path=dll_path + ) + if not args.skip_onnx_tests: run_subprocess([os.path.join(cwd, "onnx_test_runner"), "test_models"], cwd=cwd) if config != "Debug": @@ -2189,7 +2251,7 @@ def main(): cmake_extra_defines = normalize_arg_list(args.cmake_extra_defines) - if args.use_tensorrt or args.use_nv_tensorrt_rtx: + if args.use_tensorrt: args.use_cuda = True if args.build_wheel or args.gen_doc or args.enable_training: diff --git a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml index de5df97d37d3d..c6ebb80f98e12 100644 --- a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml @@ -37,7 +37,7 @@ variables: - name: render value: 109 - name: RocmVersion - value: 6.2.3 + value: 6.4 jobs: - job: Linux_Build diff --git a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml index af5a8d1decb6e..7388ed6d5a1e9 100644 --- a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml @@ -37,7 +37,7 @@ variables: - name: render value: 109 - name: RocmVersion - value: 6.3.2 + value: 6.4 jobs: - job: Linux_Build diff --git a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml index e5d45f9f07bfb..f6d404c3bde62 100644 --- a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml @@ -114,6 +114,12 @@ extends: targetFolder: $(Build.ArtifactStagingDirectory)\validation-scripts displayName: 'Copy validation scripts' + - script: | + echo "== Source Branch ==" + echo "$(Build.SourceBranch)" + echo "$(Build.SourceBranch)" > $(Build.ArtifactStagingDirectory)\node-artifacts\_branch.txt + displayName: 'Extract Source Branch' + - task: 1ES.PublishPipelineArtifact@1 inputs: artifactName: 'validation_scripts' diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 3b307abe5fcef..2a09eba776353 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -149,14 +149,20 @@ stages: artifactName: ${{ parameters.ArtifactName }} targetPath: '$(Build.ArtifactStagingDirectory)' - - task: PublishSymbols@2 - displayName: 'Publish Build Symbols' - condition: and (succeeded(), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))) - inputs: - SymbolsFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' - SearchPattern: 'onnxruntime.pdb' - SymbolServerType: teamServices - SymbolExpirationInDays: 365 + + - ${{if or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))}}: + - template: publish-symbolrequestprod-api.yml + parameters: + ${{if eq(variables['Build.SourceBranch'], 'refs/heads/main')}}: + symbolExpiryTime: 60 + includePublicSymbolServer: true + symbolsArtifactName: ${{parameters.artifactNameNoVersionString}} + symbolsVersion: $(Build.BuildId) + symbolProject: 'ONNX Runtime' + subscription: 'OnnxrunTimeCodeSign_20240611' + searchPattern: | + $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime.pdb + $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime_providers_*.pdb # Node.js Publish - ${{ if eq(parameters['DoNodejsPack'], 'true') }}: diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-publishing-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-publishing-pipeline.yml index 230c391c00ebd..016c09e6c01da 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-publishing-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-publishing-pipeline.yml @@ -1,7 +1,7 @@ resources: pipelines: - pipeline: build - source: 'Python CUDA12 Package Test Pipeline' + source: 'Python CUDA Package Test Pipeline' trigger: branches: include: @@ -37,4 +37,4 @@ extends: stages: - template: stages/py-cuda-publishing-stage.yml parameters: - artifact_feed: $(ArtifactFeed) \ No newline at end of file + artifact_feed: $(ArtifactFeed) diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index 4c18fb73cd779..9928a68b6df06 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -123,12 +123,28 @@ stages: --skip_submodule_sync --cmake_generator "Visual Studio 17 2022" --enable_pybind - --enable_onnx_tests --use_vcpkg --use_vcpkg_ms_internal_asset_cache + --enable_onnx_tests --use_vcpkg --use_vcpkg_ms_internal_asset_cache ${{ parameters.build_py_parameters }} --parallel --use_binskim_compliant_compile_flags --update --build $(TelemetryOption) workingDirectory: '$(Build.BinariesDirectory)' + + - ${{if or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))}}: + - template: ../templates/publish-symbolrequestprod-api.yml + parameters: + ${{if eq(variables['Build.SourceBranch'], 'refs/heads/main')}}: + symbolExpiryTime: 60 + includePublicSymbolServer: true + symbolsArtifactName: onnxruntime_cpu_win_x64_$(PythonVersion) + symbolsVersion: $(Build.BuildId) + symbolProject: 'ONNX Runtime' + subscription: 'OnnxrunTimeCodeSign_20240611' + searchPattern: | + $(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime.pdb + $(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime_providers_shared.pdb + $(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime_pybind11_state.pdb + # Esrp signing - template: ../templates/win-esrp-dll.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml index fe2b85976d38b..0a88391dd4ad6 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml @@ -133,6 +133,27 @@ stages: $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} ${{ variables.trt_build_flag }} workingDirectory: '$(Build.BinariesDirectory)' + + - ${{if or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))}}: + - template: ../templates/publish-symbolrequestprod-api.yml + parameters: + ${{if eq(variables['Build.SourceBranch'], 'refs/heads/main')}}: + symbolExpiryTime: 60 + includePublicSymbolServer: true + symbolsFolder: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}' + symbolsArtifactName: onnxruntime_gpu_win_x64_${{ parameters.PYTHON_VERSION }} + symbolsVersion: $(Build.BuildId) + symbolProject: 'ONNX Runtime' + subscription: 'OnnxrunTimeCodeSign_20240611' + searchPattern: | + $(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime.pdb + $(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime_providers_shared.pdb + $(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime_providers_cuda.pdb + $(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime_providers_tensorrt.pdb + $(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime_providers_dml.pdb + $(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime_pybind11_state.pdb + + # Esrp signing - template: ../templates/win-esrp-dll.yml parameters: @@ -170,7 +191,6 @@ stages: pool: name: ${{parameters.MACHINE_POOL}} steps: - - checkout: self clean: true submodules: none diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml index f15a2992e0d00..adf9c91e602a0 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml @@ -31,6 +31,21 @@ parameters: default: true steps: + - ${{if or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))}}: + - template: publish-symbolrequestprod-api.yml + parameters: + ${{if eq(variables['Build.SourceBranch'], 'refs/heads/main')}}: + symbolExpiryTime: 60 + includePublicSymbolServer: true + symbolsArtifactName: ${{parameters.artifactNameNoVersionString}} + symbolsVersion: $(Build.BuildId) + symbolProject: 'ONNX Runtime' + subscription: 'OnnxrunTimeCodeSign_20240611' + searchPattern: | + $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime.pdb + $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_*.pdb + + - task: CmdLine@2 displayName: 'Copy build artifacts for zipping' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/publish-symbolrequestprod-api.yml b/tools/ci_build/github/azure-pipelines/templates/publish-symbolrequestprod-api.yml new file mode 100644 index 0000000000000..b2a3eaca0280f --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/publish-symbolrequestprod-api.yml @@ -0,0 +1,87 @@ +# This file was copied from https://github.com/microsoft/devhome/blob/main/build/templates/publish-symbolrequestprod-api.yml#L71 +parameters: + - name: includePublicSymbolServer + type: boolean + default: false + - name: searchPattern + type: string + default: '**/*.pdb' + - name: jobName + type: string + default: PublishSymbols + - name: indexSources + type: boolean + default: true + - name: symbolExpiryTime + type: string + default: 36530 # This is the default from PublishSymbols@2 + - name: symbolsArtifactName + type: string + default: '' + - name: symbolsVersion + type: string + default: '' + - name: symbolProject + type: string + - name: subscription + type: string + +steps: + - powershell: |- + Get-PackageProvider -Name NuGet -ForceBootstrap + Install-Module -Verbose -AllowClobber -Force Az.Accounts, Az.Storage, Az.Network, Az.Resources, Az.Compute + displayName: Install Azure Module Dependencies + + # Transit the Azure token from the Service Connection into a secret variable for the rest of the pipeline to use. + - task: AzurePowerShell@5 + displayName: Generate an Azure Token + inputs: + azureSubscription: ${{ parameters.subscription }} + azurePowerShellVersion: LatestVersion + pwsh: true + ScriptType: InlineScript + Inline: |- + $AzToken = (Get-AzAccessToken -ResourceUrl api://30471ccf-0966-45b9-a979-065dbedb24c1).Token + Write-Host "##vso[task.setvariable variable=SymbolAccessToken;issecret=true]$AzToken" + + - task: PublishSymbols@2 + displayName: Publish Symbols (to current Azure DevOps tenant) + continueOnError: True + inputs: + SearchPattern: ${{ parameters.searchPattern }} + IndexSources: ${{ parameters.indexSources }} + DetailedLog: true + SymbolsMaximumWaitTime: 30 + SymbolServerType: 'TeamServices' + SymbolsProduct: 'onnxruntime' + SymbolsVersion: ${{ parameters.symbolsVersion }} + SymbolsArtifactName: '${{ parameters.symbolsArtifactName }}_${{ parameters.symbolsVersion }}' + SymbolExpirationInDays: ${{ parameters.symbolExpiryTime }} + env: + LIB: $(Build.SourcesDirectory) + + - pwsh: |- + # Prepare the defaults for IRM + $PSDefaultParameterValues['Invoke-RestMethod:Headers'] = @{ Authorization = "Bearer $(SymbolAccessToken)" } + $PSDefaultParameterValues['Invoke-RestMethod:ContentType'] = "application/json" + $PSDefaultParameterValues['Invoke-RestMethod:Method'] = "POST" + + $BaseUri = "https://symbolrequestprod.trafficmanager.net/projects/${{ parameters.symbolProject }}/requests" + + # Prepare the request + $expiration = (Get-Date).Add([TimeSpan]::FromDays(${{ parameters.symbolExpiryTime }})) + $createRequestBody = @{ + requestName = "${{ parameters.symbolsArtifactName }}_${{ parameters.symbolsVersion }}"; + expirationTime = $expiration.ToString(); + } + Write-Host "##[debug]Starting request $($createRequestBody.requestName) with expiration date of $($createRequestBody.expirationTime)" + Invoke-RestMethod -Uri "$BaseUri" -Body ($createRequestBody | ConvertTo-Json -Compress) -Verbose + + # Request symbol publication + $publishRequestBody = @{ + publishToInternalServer = $true; + publishToPublicServer = $${{ parameters.includePublicSymbolServer }}; + } + Write-Host "##[debug]Submitting request $($createRequestBody.requestName) ($($publishRequestBody | ConvertTo-Json -Compress))" + Invoke-RestMethod -Uri "$BaseUri/$($createRequestBody.requestName)" -Body ($publishRequestBody | ConvertTo-Json -Compress) -Verbose + displayName: Publish Symbols using internal REST API diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage.yml deleted file mode 100644 index a1f326ebaafa8..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage.yml +++ /dev/null @@ -1,94 +0,0 @@ -parameters: -- name: build_py_parameters - displayName: > - Extra parameters to pass to build.py. Don't put newlines in here. - type: string - default: '' - -- name: torch_version - displayName: > - torch_version. - type: string - -- name: opset_version - displayName: > - opset_version. - type: string - -- name: cuda_version - displayName: > - cuda_version. - type: string - -- name: cmake_cuda_architectures - displayName: > - cmake_cuda_architectures - type: string - -- name: docker_file - displayName: > - docker_file. - type: string - -- name: agent_pool - displayName: > - agent_pool. - type: string - -- name: upload_wheel - displayName: > - upload_wheel. - type: string - default: '' - -- name: debug_build - displayName: > - debug_build. - type: boolean - default: false - -- name: SpecificArtifact - displayName: Use Specific Artifact - type: boolean - default: false - -- name: BuildId - displayName: Specific Artifact's BuildId - type: string - default: '0' - -- name: build_pool_name - displayName: > - build_pool_name. - type: string - -- name: PythonVersionList - displayName: Python Version List - type: object - default: - - name: '38' - version: '3.8' - - name: '39' - version: '3.9' - - name: '310' - version: '3.10' - - name: '311' - version: '3.11' - -stages: -- ${{ each python_version in parameters.PythonVersionList }}: - - template: py-packaging-training-cuda-stage-steps.yml - parameters: - build_py_parameters: ${{ parameters.build_py_parameters }} - torch_version: ${{ parameters.torch_version }} - opset_version: ${{ parameters.opset_version }} - cuda_version: ${{ parameters.cuda_version }} - cmake_cuda_architectures: ${{ parameters.cmake_cuda_architectures }} - docker_file: ${{ parameters.docker_file }} - upload_wheel: ${{ parameters.upload_wheel }} - debug_build: ${{ parameters.debug_build }} - stage_name: 'Linux_py_Training_Cuda_Wheels_${{ python_version.name }}' - python_version: ${{ python_version.version }} - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - build_pool_name: ${{ parameters.build_pool_name }} diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 6df46bfc8e1b0..1a00d67bdbb2a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -91,7 +91,7 @@ jobs: --use_qnn --qnn_home $(QnnSDKRootDir) --enable_pybind - --parallel --update --arm64ec + --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --update --arm64ec $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} workingDirectory: '$(Build.BinariesDirectory)' diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index e4888ffd62df3..d739724f8744a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -20,7 +20,7 @@ stages: name: ${{ parameters.qnn_ep_build_pool_name }} variables: OrtPackageId: ${{ parameters.OrtNugetPackageId }} - commonBuildArgs: '--compile_no_warning_as_error --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_binskim_compliant_compile_flags ' + commonBuildArgs: '--compile_no_warning_as_error --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags ' steps: - template: set-version-number-variables-step.yml @@ -98,7 +98,7 @@ stages: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' platform: 'Any CPU' configuration: ${{ parameters.build_config }} - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:TargetArchitecture=arm64x' + msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:TargetArchitecture=arm64' workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: CopyFiles@2 diff --git a/tools/ci_build/github/js/validate-npm-packages.py b/tools/ci_build/github/js/validate-npm-packages.py index b009330764973..37b5b3d9807a3 100644 --- a/tools/ci_build/github/js/validate-npm-packages.py +++ b/tools/ci_build/github/js/validate-npm-packages.py @@ -129,6 +129,16 @@ if RELEASE_WEB and RELEASE_REACT_NATIVE and ort_web_ver != ort_react_native_ver: raise Exception("version number is different for onnxruntime-web and onnxruntime-react-native") +# @dev build has to match the following pattern: +# "x.y.z-dev.*" +if tag == "dev": + if RELEASE_NODE and "-dev" not in ort_node_ver: + raise Exception(f'@dev build version should contain "-dev". ort_node_ver={ort_node_ver}') + if RELEASE_WEB and "-dev" not in ort_web_ver: + raise Exception(f'@dev build version should contain "-dev". ort_web_ver={ort_web_ver}') + if RELEASE_REACT_NATIVE and "-dev" not in ort_react_native_ver: + raise Exception(f'@dev build version should contain "-dev". ort_react_native_ver={ort_react_native_ver}') + print("====== validated versions ======") print(f"source_branch={source_branch}") print(f"tag={tag}") diff --git a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile index 667ffe03d8922..7b02a5e658d31 100644 --- a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile @@ -1,7 +1,7 @@ # Refer to https://github.com/RadeonOpenCompute/ROCm-docker/blob/master/dev/Dockerfile-ubuntu-22.04-complete FROM ubuntu:22.04 -ARG ROCM_VERSION=6.2.3 +ARG ROCM_VERSION=6.4 ARG AMDGPU_VERSION=${ROCM_VERSION} ARG APT_PREF='Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' diff --git a/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile index 1cd5f289dd1c9..83a4e04435b95 100644 --- a/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile @@ -1,7 +1,7 @@ # Refer to https://github.com/RadeonOpenCompute/ROCm-docker/blob/master/dev/Dockerfile-ubuntu-22.04-complete FROM ubuntu:22.04 -ARG ROCM_VERSION=6.3.2 +ARG ROCM_VERSION=6.4 ARG AMDGPU_VERSION=${ROCM_VERSION} ARG APT_PREF='Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' diff --git a/tools/perf_view/ort_perf_view.html b/tools/perf_view/ort_perf_view.html index 509fe5593f6a1..9cd65f06d9337 100644 --- a/tools/perf_view/ort_perf_view.html +++ b/tools/perf_view/ort_perf_view.html @@ -3,9 +3,9 @@ Onnxruntime Perf View - + - +