diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index 1f57b4c6d2ba2..51d008a34a964 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -92,7 +92,6 @@ jobs: ${{ env.common_build_args }} \ --build_dir ${{ github.workspace }}/build/wasm_inferencing_webgpu \ --use_webgpu \ - --use_jsep \ --use_webnn \ --target onnxruntime_webassembly \ --skip_tests @@ -113,8 +112,8 @@ jobs: if: ${{ inputs.skip_publish != true && inputs.build_webgpu == true }} run: | mkdir -p ${{ github.workspace }}/artifacts/wasm_webgpu/ - cp ${{ github.workspace }}/build/wasm_inferencing_webgpu/${{ inputs.build_config }}/ort-wasm-simd-threaded.jsep.wasm ${{ github.workspace }}/artifacts/wasm_webgpu/ - cp ${{ github.workspace }}/build/wasm_inferencing_webgpu/${{ inputs.build_config }}/ort-wasm-simd-threaded.jsep.mjs ${{ github.workspace }}/artifacts/wasm_webgpu/ + cp ${{ github.workspace }}/build/wasm_inferencing_webgpu/${{ inputs.build_config }}/ort-wasm-simd-threaded.asyncify.wasm ${{ github.workspace }}/artifacts/wasm_webgpu/ + cp ${{ github.workspace }}/build/wasm_inferencing_webgpu/${{ inputs.build_config }}/ort-wasm-simd-threaded.asyncify.mjs ${{ github.workspace }}/artifacts/wasm_webgpu/ - name: Upload WASM artifacts if: ${{ inputs.skip_publish != true }} diff --git a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml index b45663a6145e3..a2a0e8d36def8 100644 --- a/.github/workflows/windows-web-ci-workflow.yml +++ b/.github/workflows/windows-web-ci-workflow.yml @@ -16,9 +16,6 @@ on: package_name: type: string default: "NPM_packages" - run_webgpu_tests: - type: boolean - default: true jobs: build_onnxruntime_web: @@ -86,6 +83,22 @@ jobs: run: | copy ${{ github.workspace }}\artifacts_wasm\ort-*.mjs ${{ github.workspace }}\js\web\dist\ + - name: Download WebAssembly WebGPU artifacts + uses: actions/download-artifact@v4 + with: + name: ${{ inputs.build_config }}_wasm_webgpu + path: ${{ github.workspace }}/artifacts_wasm_webgpu + + - name: Binplace dist files (.wasm) for WebGPU + shell: cmd + run: | + copy ${{ github.workspace }}\artifacts_wasm_webgpu\ort-*.wasm ${{ github.workspace }}\js\web\dist\ + + - name: Binplace dist files (.mjs) for WebGPU + shell: cmd + run: | + copy ${{ github.workspace }}\artifacts_wasm_webgpu\ort-*.mjs ${{ github.workspace }}\js\web\dist\ + - name: npm ci for /js/ run: npm ci working-directory: ${{ github.workspace }}/js @@ -115,17 +128,7 @@ jobs: run: | Get-WmiObject Win32_Process -Filter "name = 'chrome.exe'" | Format-List CommandLine - - name: Run ort-web tests (wasm,webgl backend) - if: ${{ inputs.run_webgpu_tests != true }} - shell: cmd - run: | - mkdir ${{ runner.temp }}\web\test\01 - dir ${{ runner.temp }}\web\test\01 - npm test -- -e=chrome -b=webgl,wasm --user-data-dir=${{ runner.temp }}\web\test\01 --chromium-flags=--enable-logging --chromium-flags=--v=1 - working-directory: ${{ github.workspace }}\js\web - - name: Run ort-web tests (ALL backends) - if: ${{ inputs.run_webgpu_tests == true }} shell: cmd run: | mkdir ${{ runner.temp }}\web\test\02 @@ -134,7 +137,6 @@ jobs: working-directory: ${{ github.workspace }}\js\web - name: Run ort-web tests (Suite1, webgpu, IO-binding=gpu-tensor) - if: ${{ inputs.run_webgpu_tests == true }} shell: cmd run: | mkdir ${{ runner.temp }}\web\test\03 @@ -143,7 +145,6 @@ jobs: working-directory: ${{ github.workspace }}\js\web - name: Run ort-web tests (Suite1, webgpu, IO-binding=gpu-location) - if: ${{ inputs.run_webgpu_tests == true }} shell: cmd run: | mkdir ${{ runner.temp }}\web\test\04 @@ -169,27 +170,7 @@ jobs: working-directory: ${{ github.workspace }}\js\web # WebGPU EP tests - - name: Download WebAssembly WebGPU artifacts - if: ${{ inputs.run_webgpu_tests == true }} - uses: actions/download-artifact@v4 - with: - name: ${{ inputs.build_config }}_wasm_webgpu - path: ${{ github.workspace }}/artifacts_wasm_webgpu - - - name: Binplace dist files (.wasm) for WebGPU - if: ${{ inputs.run_webgpu_tests == true }} - shell: cmd - run: | - copy /Y ${{ github.workspace }}\artifacts_wasm_webgpu\ort-*.wasm ${{ github.workspace }}\js\web\dist\ - - - name: Binplace dist files (.mjs) for WebGPU - if: ${{ inputs.run_webgpu_tests == true }} - shell: cmd - run: | - copy /Y ${{ github.workspace }}\artifacts_wasm_webgpu\ort-*.mjs ${{ github.workspace }}\js\web\dist\ - - name: Run ort-web tests - WebGPU EP - if: ${{ inputs.run_webgpu_tests == true }} continue-on-error: true shell: cmd run: | @@ -199,7 +180,7 @@ jobs: working-directory: ${{ github.workspace }}\js\web - name: Validate shader keys - WebGPU EP - if: ${{ inputs.run_webgpu_tests == true && inputs.build_config == 'Debug' }} + if: ${{ inputs.build_config == 'Debug' }} uses: ./.github/actions/webgpu-validate-shader-key with: log_file_path: ${{ runner.temp }}\web\test\07\chrome_debug.log @@ -207,7 +188,7 @@ jobs: # this step is added to help investigate the shader validation failure which is hard to reproduce - name: Upload WebGPU shader validation log on failure - if: ${{ failure() && inputs.run_webgpu_tests == true && inputs.build_config == 'Debug' }} + if: ${{ failure() && inputs.build_config == 'Debug' }} uses: actions/upload-artifact@v4 with: name: webgpu-shader-validation-logs diff --git a/.github/workflows/windows_x86.yml b/.github/workflows/windows_x86.yml index 507eacf21cc5a..4652757c1d292 100644 --- a/.github/workflows/windows_x86.yml +++ b/.github/workflows/windows_x86.yml @@ -87,7 +87,7 @@ jobs: - name: Build and Test shell: pwsh run: | - python.exe "${{ github.workspace }}\tools\ci_build\build.py" --config RelWithDebInfo --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests --build_wheel --msbuild_extra_options "IncludeMobileTargets=false" --build_nuget --use_vcpkg --use_vcpkg_ms_internal_asset_cache + python.exe "${{ github.workspace }}\tools\ci_build\build.py" --config RelWithDebInfo --build_dir "${{ github.workspace }}\build" --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests --build_wheel --msbuild_extra_options "IncludeMobileTargets=false" --build_nuget --compile_no_warning_as_error --use_vcpkg --use_vcpkg_ms_internal_asset_cache if ($LASTEXITCODE -ne 0) { exit $LASTEXITCODE } diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 47bfa3f312eec..c6e16d4a3920f 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -372,7 +372,7 @@ if (onnxruntime_USE_ROCM) if (HIPIFY_PERL_PATH-NOTFOUND) MESSAGE(FATAL_ERROR "hipify-perl not found") endif() - MESSAGE("HIPIFY PATH:"${HIPIFY_PERL_PATH}/hipify-perl) + MESSAGE("HIPIFY PATH: ${HIPIFY_PERL_PATH}/hipify-perl") set(onnxruntime_HIPIFY_PERL ${HIPIFY_PERL_PATH}/hipify-perl) endif() @@ -1336,7 +1336,7 @@ function(onnxruntime_configure_target target_name) if(WIN32 AND onnxruntime_ENABLE_STATIC_ANALYSIS AND onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES) set_target_properties(${target_name} PROPERTIES VS_USER_PROPS ${PROJECT_SOURCE_DIR}/EnableVisualStudioCodeAnalysis.props) endif() - target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} ${abseil_cpp_SOURCE_DIR}) + target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) if (onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(${target_name} PRIVATE ${ORTTRAINING_ROOT}) endif() @@ -1669,6 +1669,10 @@ if (onnxruntime_ENABLE_DLPACK) add_compile_definitions(ENABLE_DLPACK) endif() +if (onnxruntime_CALLER_FRAMEWORK) + add_definitions(-DORT_CALLER_FRAMEWORK="${onnxruntime_CALLER_FRAMEWORK}") +endif() + if (UNIX OR onnxruntime_USE_NCCL) # Find NCCL if (onnxruntime_USE_NCCL) diff --git a/cmake/deps.txt b/cmake/deps.txt index 2df433b0353c6..0f2f02305a992 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -9,7 +9,7 @@ #since the file contains a version string: "lts_20230802". However, the file is for debugging purposes only and would #not affect built binaries. # -abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240722.0.zip;36ee53eb1466fb6e593fc5c286680de31f8a494a +abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20250512.0.zip;3d6ff7e7ce144d9a53a53bef1f1bf79e1da4b8e1 coremltools;https://github.com/apple/coremltools/archive/refs/tags/7.1.zip;f1bab0f30966f2e217d8e01207d518f230a1641a cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 @@ -56,5 +56,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0 composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96 -dawn;https://github.com/google/dawn/archive/4cb1f9be152a4fa6bb695c08cd707ab078a1e2fb.zip;de39336b7715f53c14eec61072293b85cc73b691 +dawn;https://github.com/google/dawn/archive/9733be39e18186961d503e064874afe3e9ceb8d1.zip;2a4017c32892b90d072a9102eba90ae691fae36d kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.4.0.tar.gz;22d3b57b54a61c194ab256ff11b0353a3b220244 diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake index 488df5a4e0de8..eede60a4a977a 100644 --- a/cmake/external/abseil-cpp.cmake +++ b/cmake/external/abseil-cpp.cmake @@ -27,7 +27,7 @@ else() endif() # NB! Advancing Abseil version changes its internal namespace, -# currently absl::lts_20240116 which affects abseil-cpp.natvis debugger +# currently absl::lts_20250512 which affects abseil-cpp.natvis debugger # visualization file, that must be adjusted accordingly, unless we eliminate # that namespace at build time. onnxruntime_fetchcontent_declare( @@ -36,7 +36,7 @@ onnxruntime_fetchcontent_declare( URL_HASH SHA1=${DEP_SHA1_abseil_cpp} EXCLUDE_FROM_ALL PATCH_COMMAND ${ABSL_PATCH_COMMAND} - FIND_PACKAGE_ARGS 20240722 NAMES absl + FIND_PACKAGE_ARGS 20250512 NAMES absl ) onnxruntime_fetchcontent_makeavailable(abseil_cpp) diff --git a/cmake/external/abseil-cpp.natvis b/cmake/external/abseil-cpp.natvis index e995e215432a2..75374e0fa9fba 100644 --- a/cmake/external/abseil-cpp.natvis +++ b/cmake/external/abseil-cpp.natvis @@ -1,6 +1,6 @@ - + @@ -24,7 +24,7 @@ - + @@ -51,7 +51,7 @@ - + *($T1 *){value} (*($T1 *){value}) @@ -60,7 +60,7 @@ - + *($T1 *)this (*($T1 *)this) @@ -68,7 +68,7 @@ - + {value.first}, {value.second} ({value.first}, {value.second}) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 4f6bcc8c90419..304a8c83959d8 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -723,36 +723,22 @@ if (onnxruntime_USE_WEBGPU) ) else() set(ONNXRUNTIME_Dawn_PATCH_COMMAND - # The dawn.patch contains the following changes: + # The dawn_destroy_buffer_on_destructor.patch contains the following changes: # # - (private) Allow WGPUBufferImpl class to destroy the buffer in the destructor # In native implementation, wgpuBufferRelease will trigger the buffer destroy (if refcount decreased to 0). But # in emwgpu implementation, the buffer destroy won't happen. This change adds a destructor to the buffer class # to destroy the buffer when the refcount is 0 for non-external buffers. # - # - (private) Remove hard-coded CMAKE_OSX_DEPLOYMENT_TARGET in Dawn's CMake files - # https://github.com/microsoft/onnxruntime/pull/23729 - # - # - (private) Reduce unsafe buffer usage warning in aligned_storage.h - # https://github.com/microsoft/onnxruntime/pull/24308 - # The patch disables the UNSAFE_BUFFER_USAGE warning around the AlignedStorage struct in aligned_storage.h. This is done - # by using TINT_BEGIN_DISABLE_WARNING and TINT_END_DISABLE_WARNING macros, which helps in warnings related to unsafe buffer usage - # usage when compiling the code, making the build process cleaner and faster. - # - ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch && + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_destroy_buffer_on_destructor.patch && # The dawn_force_enable_f16_nvidia_vulkan.patch contains the following changes: # # - (private) Force enable f16 support for NVIDIA Vulkan # Dawn disabled f16 support for NVIDIA Vulkan by default because of crashes in f16 CTS tests (crbug.com/tint/2164). # Since the crashes are limited to specific GPU models, we patched Dawn to remove the restriction. - ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch && - - # The dawn_fix_copy_dxil_dll.patch contains the following changes: # - # - (private) Fix copy of dxil.dll in Dawn - # The patch ensures the copy of dxil.dll to be done after the build step of `dxcompiler` target. - ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_fix_copy_dxil_dll.patch) + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn_force_enable_f16_nvidia_vulkan.patch) onnxruntime_fetchcontent_declare( dawn diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index f6130f8c518a6..c4a8641e02444 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -22,6 +22,7 @@ endif() function(get_c_cxx_api_headers HEADERS_VAR) set(_headers "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_c_api.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_ep_c_api.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_api.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_inline.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h" diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index 1e26eede8a66f..5dcc2b2628bf4 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -120,16 +120,14 @@ if (onnxruntime_USE_MIMALLOC) target_link_libraries(onnxruntime_common PRIVATE onnxruntime_mimalloc_shim) endif() -if(NOT onnxruntime_DISABLE_ABSEIL) - target_include_directories(onnxruntime_common PRIVATE ${ABSEIL_SOURCE_DIR}) - if (MSVC) - set(ABSEIL_NATVIS_FILE "abseil-cpp.natvis") - target_sources( - onnxruntime_common - INTERFACE $) - endif() +if (MSVC) + set(ABSEIL_NATVIS_FILE "abseil-cpp.natvis") + target_sources( + onnxruntime_common + INTERFACE $) endif() + if (MSVC) set(EIGEN_NATVIS_FILE ${eigen_SOURCE_DIR}/debug/msvc/eigen.natvis) if (EXISTS ${EIGEN_NATVIS_FILE}) diff --git a/cmake/onnxruntime_lora.cmake b/cmake/onnxruntime_lora.cmake index 26ee21c645584..e0579f571e1e0 100644 --- a/cmake/onnxruntime_lora.cmake +++ b/cmake/onnxruntime_lora.cmake @@ -10,8 +10,7 @@ file(GLOB onnxruntime_lora_srcs CONFIGURE_DEPENDS source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_lora_srcs}) onnxruntime_add_static_library(onnxruntime_lora ${onnxruntime_lora_srcs}) -onnxruntime_add_include_to_target(onnxruntime_lora onnx flatbuffers::flatbuffers Boost::mp11 ${GSL_TARGET}) -target_link_libraries(onnxruntime_lora onnxruntime_framework) +onnxruntime_add_include_to_target(onnxruntime_lora onnxruntime_framework onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 ${GSL_TARGET}) if(onnxruntime_ENABLE_INSTRUMENT) target_compile_definitions(onnxruntime_lora PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index da46f29dacf5f..2e3589a1506d1 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -264,6 +264,11 @@ if("90" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG) target_compile_options(${target} PRIVATE $<$:-Xptxas=-w>) target_compile_definitions(${target} PRIVATE COMPILE_HOPPER_TMA_GEMMS) + if (MSVC) + target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /bigobj>") + target_compile_options(${target} PRIVATE "$<$:--diag-suppress=177>") + target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4172>") + endif() endif() if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 5639b295f0787..b177074a1bc02 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -537,6 +537,7 @@ set(onnxruntime_mobile_util_srcs ${REPO_ROOT}/tools/python/util/pytorch_export_helpers.py ${REPO_ROOT}/tools/python/util/reduced_build_config_parser.py ${REPO_ROOT}/tools/python/util/update_onnx_opset.py + ${REPO_ROOT}/tools/python/remove_initializer_from_input.py ) file(GLOB onnxruntime_ort_format_model_srcs CONFIGURE_DEPENDS ${REPO_ROOT}/tools/python/util/ort_format_model/*.py diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 26ef7970fa2b6..9b33313b6147c 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -495,6 +495,7 @@ set (ONNXRUNTIME_GLOBAL_THREAD_POOLS_TEST_SRC_DIR "${TEST_SRC_DIR}/global_thread set (ONNXRUNTIME_CUSTOM_OP_REGISTRATION_TEST_SRC_DIR "${TEST_SRC_DIR}/custom_op_registration") set (ONNXRUNTIME_LOGGING_APIS_TEST_SRC_DIR "${TEST_SRC_DIR}/logging_apis") set (ONNXRUNTIME_AUTOEP_TEST_SRC_DIR "${TEST_SRC_DIR}/autoep") +set (ONNXRUNTIME_EP_GRAPH_TEST_SRC_DIR "${TEST_SRC_DIR}/ep_graph") set (onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h @@ -843,9 +844,15 @@ set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxru ${onnxruntime_test_flatbuffers_src} ${onnxruntime_test_lora_src}) if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) + if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD AND NOT onnxruntime_DISABLE_CONTRIB_OPS) + set(onnxruntime_test_cuda_kernels_src_patterns "${TEST_SRC_DIR}/contrib_ops/cuda_kernels/*.cc") + endif() + file(GLOB onnxruntime_test_providers_cuda_ut_src CONFIGURE_DEPENDS "${TEST_SRC_DIR}/providers/cuda/test_cases/*" + ${onnxruntime_test_cuda_kernels_src_patterns} ) + # onnxruntime_providers_cuda_ut is only for unittests. onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $) config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut) @@ -1335,12 +1342,15 @@ endif() # shared lib if (onnxruntime_BUILD_SHARED_LIB) if(WIN32) - AddTest(DYN - TARGET onnxruntime_shared_lib_dlopen_test - SOURCES ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/dlopen_main.cc - LIBS onnxruntime - DEPENDS ${all_dependencies} - ) + onnxruntime_add_executable(onnxruntime_shared_lib_dlopen_test ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/dlopen_main.cc) + add_dependencies(onnxruntime_shared_lib_dlopen_test ${all_dependencies} onnxruntime) + add_test(NAME onnxruntime_shared_lib_dlopen_test COMMAND onnxruntime_shared_lib_dlopen_test WORKING_DIRECTORY $) + set_target_properties(onnxruntime_shared_lib_dlopen_test PROPERTIES FOLDER "ONNXRuntimeTest") + + if (MSVC) + # set VS debugger working directory to the test program's directory + set_target_properties(onnxruntime_shared_lib_dlopen_test PROPERTIES VS_DEBUGGER_WORKING_DIRECTORY $) + endif() endif() onnxruntime_add_static_library(onnxruntime_mocked_allocator ${TEST_SRC_DIR}/util/test_allocator.cc) target_include_directories(onnxruntime_mocked_allocator PUBLIC ${TEST_SRC_DIR}/util/include) @@ -1831,6 +1841,8 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_MINIMAL_BUILD) onnxruntime_add_shared_library_module(example_plugin_ep + ${TEST_SRC_DIR}/autoep/library/example_plugin_ep_utils.h + ${TEST_SRC_DIR}/autoep/library/example_plugin_ep_utils.cc ${TEST_SRC_DIR}/autoep/library/example_plugin_ep.cc) target_include_directories(example_plugin_ep PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) target_link_libraries(example_plugin_ep PRIVATE onnxruntime) @@ -1852,8 +1864,8 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND ${ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG}) # test library - file(GLOB_RECURSE onnxruntime_autoep_test_SRC "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.h" - "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.cc") + file(GLOB onnxruntime_autoep_test_SRC "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.h" + "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.cc") set(onnxruntime_autoep_test_LIBS onnxruntime_mocked_allocator ${ONNXRUNTIME_TEST_LIBS} onnxruntime_test_utils onnx_proto onnx ${onnxruntime_EXTERNAL_LIBRARIES}) @@ -1992,4 +2004,34 @@ if (onnxruntime_USE_WEBGPU AND WIN32 AND onnxruntime_BUILD_SHARED_LIB AND NOT CM ) endif() +# onnxruntime_ep_graph_test tests the implementation of the public OrtGraph APIs for use in plugin EPs (OrtEp). +if (onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_MINIMAL_BUILD) + file(GLOB_RECURSE onnxruntime_ep_graph_test_SRC "${ONNXRUNTIME_EP_GRAPH_TEST_SRC_DIR}/*.h" + "${ONNXRUNTIME_EP_GRAPH_TEST_SRC_DIR}/*.cc") + + set(onnxruntime_ep_graph_test_LIBS ${ONNXRUNTIME_TEST_LIBS} onnxruntime_test_utils ${onnxruntime_EXTERNAL_LIBRARIES}) + if (CMAKE_SYSTEM_NAME MATCHES "AIX") + list(APPEND onnxruntime_ep_graph_test_LIBS onnxruntime_session onnxruntime_util onnxruntime_lora onnxruntime_framework + onnxruntime_common onnxruntime_graph onnxruntime_providers onnxruntime_mlas + onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 + ${PROTOBUF_LIB} onnx onnx_proto) + endif() + + if(NOT WIN32) + list(APPEND onnxruntime_ep_graph_test_LIBS ${CMAKE_DL_LIBS}) + endif() + + if (onnxruntime_USE_TENSORRT OR onnxruntime_USE_NV) + # Need this because unittest_main_src defines a global nvinfer1::IBuilder variable. + list(APPEND onnxruntime_ep_graph_test_LIBS ${TENSORRT_LIBRARY_INFER}) + endif() + + AddTest(DYN + TARGET onnxruntime_ep_graph_test + SOURCES ${onnxruntime_ep_graph_test_SRC} ${onnxruntime_unittest_main_src} + LIBS ${onnxruntime_ep_graph_test_LIBS} + DEPENDS ${all_dependencies} + ) +endif() + include(onnxruntime_fuzz_test.cmake) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index c0b6efb0eb75d..486439a68b7ff 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -84,6 +84,10 @@ function(bundle_static_library bundled_target_name) add_dependencies(${bundled_target_name} bundling_target) endfunction() +if (onnxruntime_USE_JSEP AND onnxruntime_USE_WEBGPU) + message(FATAL_ERROR "onnxruntime_USE_JSEP and onnxruntime_USE_WEBGPU cannot be enabled at the same time.") +endif() + if (NOT onnxruntime_ENABLE_WEBASSEMBLY_THREADS) add_compile_definitions( BUILD_MLAS_NO_ONNXRUNTIME @@ -406,6 +410,16 @@ jsepDownload:_pp_") list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/post-webgpu.js") endif() + if (onnxruntime_USE_WEBNN) + target_compile_definitions(onnxruntime_webassembly PRIVATE USE_WEBNN=1) + if (NOT onnxruntime_USE_JSEP) + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:--post-js \"${ONNXRUNTIME_ROOT}/wasm/post-webnn.js\"" + ) + list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/post-webnn.js") + endif() + endif() + if (onnxruntime_USE_JSEP OR onnxruntime_USE_WEBGPU OR onnxruntime_USE_WEBNN) # if any of the above is enabled, we need to use the asyncify library target_link_options(onnxruntime_webassembly PRIVATE @@ -499,6 +513,9 @@ jsepDownload:_pp_") if (onnxruntime_USE_JSEP) string(APPEND target_name ".jsep") + elseif (onnxruntime_USE_WEBGPU OR onnxruntime_USE_WEBNN) + string(APPEND target_name ".asyncify") + # TODO: support JSPI and add ".jspi" once JSPI build is supported endif() set_target_properties(onnxruntime_webassembly PROPERTIES OUTPUT_NAME ${target_name} SUFFIX ".mjs") diff --git a/cmake/patches/abseil/absl_windows.patch b/cmake/patches/abseil/absl_windows.patch index c50e147aa4a7d..036ead7b88e46 100644 --- a/cmake/patches/abseil/absl_windows.patch +++ b/cmake/patches/abseil/absl_windows.patch @@ -1,163 +1,44 @@ -diff --git a/absl/base/attributes.h b/absl/base/attributes.h -index 5ea5ee3e..f4949898 100644 ---- a/absl/base/attributes.h -+++ b/absl/base/attributes.h -@@ -559,7 +559,7 @@ - #undef ABSL_ATTRIBUTE_UNUSED - #define ABSL_ATTRIBUTE_UNUSED __attribute__((__unused__)) - #else --#define ABSL_ATTRIBUTE_UNUSED -+#define ABSL_ATTRIBUTE_UNUSED [[maybe_unused]] - #endif - - // ABSL_ATTRIBUTE_INITIAL_EXEC diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h -index d4fe8f5c..27418d13 100644 +index 3effc441..c339e269 100644 --- a/absl/container/internal/raw_hash_set.h +++ b/absl/container/internal/raw_hash_set.h -@@ -1924,7 +1924,7 @@ HashtablezInfoHandle SampleHashtablezInfo(size_t sizeof_slot, size_t sizeof_key, - // In SOO, we sample on the first insertion so if this is an empty SOO case - // (e.g. when reserve is called), then we still need to sample. - if (kSooEnabled && was_soo && c.size() == 0) { -- return Sample(sizeof_slot, sizeof_key, sizeof_value, SooCapacity()); -+ return Sample(sizeof_slot, sizeof_key, sizeof_value, (int16_t)SooCapacity()); - } - // For non-SOO cases, we sample whenever the capacity is increasing from zero - // to non-zero. -@@ -3525,7 +3525,7 @@ class raw_hash_set { - assert(is_soo()); - if (!ShouldSampleHashtablezInfo()) return HashtablezInfoHandle{}; - return Sample(sizeof(slot_type), sizeof(key_type), sizeof(value_type), -- SooCapacity()); -+ (int16_t)SooCapacity()); +@@ -1121,11 +1121,12 @@ class CommonFields : public CommonFieldsGenerationInfo { + #ifdef NDEBUG + f(); + return; +-#endif ++#else + const size_t cap = capacity(); + set_capacity(InvalidCapacity::kReentrance); + f(); + set_capacity(cap); ++#endif } - inline void destroy_slots() { -diff --git a/absl/copts/GENERATED_AbseilCopts.cmake b/absl/copts/GENERATED_AbseilCopts.cmake -index da2282fe..4c7fc26f 100644 ---- a/absl/copts/GENERATED_AbseilCopts.cmake -+++ b/absl/copts/GENERATED_AbseilCopts.cmake -@@ -181,8 +181,6 @@ list(APPEND ABSL_MSVC_FLAGS - "/wd4005" - "/wd4068" - "/wd4180" -- "/wd4244" -- "/wd4267" - "/wd4503" - "/wd4800" - "/DNOMINMAX" -diff --git a/absl/copts/GENERATED_copts.bzl b/absl/copts/GENERATED_copts.bzl -index b9e0071e..dd8410ec 100644 ---- a/absl/copts/GENERATED_copts.bzl -+++ b/absl/copts/GENERATED_copts.bzl -@@ -182,8 +182,6 @@ ABSL_MSVC_FLAGS = [ - "/wd4005", - "/wd4068", - "/wd4180", -- "/wd4244", -- "/wd4267", - "/wd4503", - "/wd4800", - "/DNOMINMAX", -diff --git a/absl/copts/copts.py b/absl/copts/copts.py -index 2d85ac74..4875d668 100644 ---- a/absl/copts/copts.py -+++ b/absl/copts/copts.py -@@ -118,10 +118,6 @@ MSVC_WARNING_FLAGS = [ - "/wd4068", # unknown pragma - # qualifier applied to function type has no meaning; ignored - "/wd4180", -- # conversion from 'type1' to 'type2', possible loss of data -- "/wd4244", -- # conversion from 'size_t' to 'type', possible loss of data -- "/wd4267", - # The decorated name was longer than the compiler limit - "/wd4503", - # forcing value to bool 'true' or 'false' (performance warning) -diff --git a/absl/debugging/symbolize.cc b/absl/debugging/symbolize.cc -index 638d3954..6b817075 100644 ---- a/absl/debugging/symbolize.cc -+++ b/absl/debugging/symbolize.cc -@@ -14,7 +14,7 @@ - - #include "absl/debugging/symbolize.h" - --#ifdef _WIN32 -+#if defined(_WIN32) && !defined(NDEBUG) - #include - #if !(WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP)) || \ - WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) -diff --git a/absl/debugging/symbolize_win32.inc b/absl/debugging/symbolize_win32.inc -index 53a099a1..34d210d6 100644 ---- a/absl/debugging/symbolize_win32.inc -+++ b/absl/debugging/symbolize_win32.inc -@@ -35,15 +35,15 @@ ABSL_NAMESPACE_BEGIN - - static HANDLE process = NULL; - --void InitializeSymbolizer(const char*) { -- if (process != nullptr) { -- return; -- } -+namespace { -+void InitializeSymbolizerImpl() { -+ - process = GetCurrentProcess(); - - // Symbols are not loaded until a reference is made requiring the - // symbols be loaded. This is the fastest, most efficient way to use - // the symbol handler. -+ - SymSetOptions(SYMOPT_DEFERRED_LOADS | SYMOPT_UNDNAME); - if (!SymInitialize(process, nullptr, true)) { - // GetLastError() returns a Win32 DWORD, but we assign to -@@ -54,6 +54,36 @@ void InitializeSymbolizer(const char*) { - } - } - -+bool LookupAndInitialize(const void* pc, SYMBOL_INFO* symbol) { -+ auto hProcess = (process != NULL) ? process : GetCurrentProcess(); -+ if (SymFromAddr(hProcess, reinterpret_cast(pc), nullptr, symbol) != TRUE) { -+ if (GetLastError() == ERROR_INVALID_HANDLE && process == NULL) { -+ InitializeSymbolizerImpl(); -+ if (SymFromAddr(process, reinterpret_cast(pc), nullptr, symbol) != TRUE) { -+ return false; -+ } -+ } else { -+ return false; -+ } -+ return false; + private: +@@ -3344,11 +3345,14 @@ class raw_hash_set { + + // Asserts that hash and equal functors provided by the user are consistent, + // meaning that `eq(k1, k2)` implies `hash(k1)==hash(k2)`. +- template +- void AssertHashEqConsistent(const K& key) { + #ifdef NDEBUG ++ template ++ void AssertHashEqConsistent(const K&) { + return; +-#endif + } -+ return true; -+} -+} -+ -+void InitializeSymbolizer(const char*) { -+ if (process != nullptr) { -+ return; -+ } -+ -+ alignas(SYMBOL_INFO) char buf[sizeof(SYMBOL_INFO) + MAX_SYM_NAME]; -+ SYMBOL_INFO* symbol = reinterpret_cast(buf); -+ symbol->SizeOfStruct = sizeof(SYMBOL_INFO); -+ symbol->MaxNameLen = MAX_SYM_NAME; -+ -+ static_cast(LookupAndInitialize(reinterpret_cast(&InitializeSymbolizer), symbol)); -+} -+ - bool Symbolize(const void* pc, char* out, int out_size) { - if (out_size <= 0) { - return false; -@@ -62,9 +92,11 @@ bool Symbolize(const void* pc, char* out, int out_size) { - SYMBOL_INFO* symbol = reinterpret_cast(buf); - symbol->SizeOfStruct = sizeof(SYMBOL_INFO); - symbol->MaxNameLen = MAX_SYM_NAME; -- if (!SymFromAddr(process, reinterpret_cast(pc), nullptr, symbol)) { -+ -+ if(!LookupAndInitialize(pc, symbol)) { - return false; ++#else ++ template ++ void AssertHashEqConsistent(const K& key) { + // If the hash/eq functors are known to be consistent, then skip validation. + if (std::is_same::value && + std::is_same::value) { +@@ -3386,6 +3390,7 @@ class raw_hash_set { + if (capacity() > 16) return; + IterateOverFullSlots(common(), sizeof(slot_type), assert_consistent); } -+ - const size_t out_size_t = static_cast(out_size); - strncpy(out, symbol->Name, out_size_t); - if (out[out_size_t - 1] != '\0') { ++#endif + + // Attempts to find `key` in the table; if it isn't found, returns an iterator + // where the value can be inserted into, with the control byte already set to diff --git a/cmake/patches/dawn/dawn.patch b/cmake/patches/dawn/dawn.patch deleted file mode 100644 index 1fe66d2cf917d..0000000000000 --- a/cmake/patches/dawn/dawn.patch +++ /dev/null @@ -1,59 +0,0 @@ -diff --git a/src/cmake/DawnCompilerPlatformFlags.cmake b/src/cmake/DawnCompilerPlatformFlags.cmake -index 50638e2456..efa42711e6 100644 ---- a/src/cmake/DawnCompilerPlatformFlags.cmake -+++ b/src/cmake/DawnCompilerPlatformFlags.cmake -@@ -63,7 +63,3 @@ endif () - if (MSVC AND NOT COMPILER_IS_CLANG_CL) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") - endif () -- --if (TARGET_MACOS) -- set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS version" FORCE) --endif () -\ No newline at end of file -diff --git a/third_party/emdawnwebgpu/webgpu.cpp b/third_party/emdawnwebgpu/webgpu.cpp -index 5bfac41dcc..71a153daaa 100644 ---- a/third_party/emdawnwebgpu/webgpu.cpp -+++ b/third_party/emdawnwebgpu/webgpu.cpp -@@ -692,6 +692,7 @@ struct WGPUBufferImpl final : public EventSource, - WGPUBufferImpl(const EventSource* source, bool mappedAtCreation); - // Injection constructor used when we already have a backing Buffer. - WGPUBufferImpl(const EventSource* source, WGPUBufferMapState mapState); -+ ~WGPUBufferImpl(); - - void Destroy(); - const void* GetConstMappedRange(size_t offset, size_t size); -@@ -1361,6 +1362,12 @@ WGPUBufferImpl::WGPUBufferImpl(const EventSource* source, - RefCountedWithExternalCount(kImportedFromJS), - mMapState(mapState) {} - -+WGPUBufferImpl::~WGPUBufferImpl() { -+ if (!IsImported()) { -+ Destroy(); -+ } -+} -+ - void WGPUBufferImpl::Destroy() { - emwgpuBufferDestroy(this); - AbortPendingMap("Buffer was destroyed before mapping was resolved."); -diff --git a/src/tint/utils/memory/aligned_storage.h b/src/tint/utils/memory/aligned_storage.h -index c532c4fc38..19c950af4c 100644 ---- a/src/tint/utils/memory/aligned_storage.h -+++ b/src/tint/utils/memory/aligned_storage.h -@@ -31,6 +31,9 @@ - #include - - #include "src/tint/utils/memory/bitcast.h" -+#include "src/tint/utils/macros/compiler.h" -+ -+TINT_BEGIN_DISABLE_WARNING(UNSAFE_BUFFER_USAGE); - - namespace tint { - -@@ -50,4 +53,6 @@ struct alignas(alignof(T)) AlignedStorage { - - } // namespace tint - -+TINT_END_DISABLE_WARNING(UNSAFE_BUFFER_USAGE); -+ - #endif // SRC_TINT_UTILS_MEMORY_ALIGNED_STORAGE_H_ diff --git a/cmake/patches/dawn/dawn_destroy_buffer_on_destructor.patch b/cmake/patches/dawn/dawn_destroy_buffer_on_destructor.patch new file mode 100644 index 0000000000000..6bfbe6427ab83 --- /dev/null +++ b/cmake/patches/dawn/dawn_destroy_buffer_on_destructor.patch @@ -0,0 +1,25 @@ +diff --git a/third_party/emdawnwebgpu/pkg/webgpu/src/webgpu.cpp b/third_party/emdawnwebgpu/pkg/webgpu/src/webgpu.cpp +index be0bb93781..4fe1d34b29 100644 +--- a/third_party/emdawnwebgpu/pkg/webgpu/src/webgpu.cpp ++++ b/third_party/emdawnwebgpu/pkg/webgpu/src/webgpu.cpp +@@ -734,6 +734,7 @@ struct WGPUBufferImpl final : public EventSource, + WGPUBufferImpl(const EventSource* source, bool mappedAtCreation); + // Injection constructor used when we already have a backing Buffer. + WGPUBufferImpl(const EventSource* source, WGPUBufferMapState mapState); ++ ~WGPUBufferImpl(); + + void Destroy(); + const void* GetConstMappedRange(size_t offset, size_t size); +@@ -1416,6 +1417,12 @@ WGPUBufferImpl::WGPUBufferImpl(const EventSource* source, + RefCountedWithExternalCount(kImportedFromJS), + mMapState(mapState) {} + ++WGPUBufferImpl::~WGPUBufferImpl() { ++ if (!IsImported()) { ++ Destroy(); ++ } ++} ++ + void WGPUBufferImpl::Destroy() { + emwgpuBufferDestroy(this); + AbortPendingMap("Buffer was destroyed before mapping was resolved."); diff --git a/cmake/patches/dawn/dawn_fix_copy_dxil_dll.patch b/cmake/patches/dawn/dawn_fix_copy_dxil_dll.patch deleted file mode 100644 index cd4d53b4cbdb7..0000000000000 --- a/cmake/patches/dawn/dawn_fix_copy_dxil_dll.patch +++ /dev/null @@ -1,13 +0,0 @@ -diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt -index cdfde38819..fc5ff76421 100644 ---- a/third_party/CMakeLists.txt -+++ b/third_party/CMakeLists.txt -@@ -352,6 +352,8 @@ function(AddSubdirectoryDXC) - TARGET copy_dxil_dll - COMMAND ${CMAKE_COMMAND} -E copy_if_different ${DXIL_DLL_PATH} $ - COMMENT "Copying ${DXIL_DLL_PATH} to $") -+ # Ensure folder "$" exists when copying the dll -+ add_dependencies(copy_dxil_dll dxcompiler) - # Make dxc target depend on copy_dxil_dll - add_dependencies(dxc copy_dxil_dll) - endif() diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index c782db4b6d64d..30d5a44a1d1cc 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index 6fe5c96e..087a7780 100644 +index 8b5af303..7fe05a5a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,7 @@ option(ONNX_USE_LITE_PROTO "Use lite protobuf instead of full." OFF) @@ -55,10 +55,26 @@ index 6fe5c96e..087a7780 100644 set(ONNX_API_DEFINE "-DONNX_API=__attribute__\(\(__visibility__\(\"default\"\)\)\)") set_target_properties(onnx_proto PROPERTIES CXX_VISIBILITY_PRESET hidden) set_target_properties(onnx_proto PROPERTIES VISIBILITY_INLINES_HIDDEN 1) -@@ -620,21 +636,11 @@ if(MSVC) +@@ -595,13 +611,6 @@ if(ONNX_BUILD_PYTHON) + target_link_libraries(onnx_cpp2py_export PRIVATE ${Python3_LIBRARIES}) + target_compile_options(onnx_cpp2py_export + PRIVATE /MP +- /wd4146 # unary minus operator applied to unsigned type, +- # result still unsigned +- /wd4244 # 'argument': conversion from 'google:: +- # protobuf::uint64' to 'int', possible +- # loss of data +- /wd4267 # Conversion from 'size_t' to 'int', +- # possible loss of data + ${EXTRA_FLAGS}) + add_msvc_runtime_flag(onnx_cpp2py_export) + add_onnx_global_defines(onnx_cpp2py_export) +@@ -618,23 +627,9 @@ endif() + if(MSVC) + target_compile_options(onnx_proto PRIVATE /MP - /wd4146 # unary minus operator applied to unsigned type, - # result still unsigned +- /wd4146 # unary minus operator applied to unsigned type, +- # result still unsigned - /wd4244 #'argument': conversion from 'google:: - #protobuf::uint64' to 'int', possible - # loss of data @@ -67,8 +83,8 @@ index 6fe5c96e..087a7780 100644 ${EXTRA_FLAGS}) target_compile_options(onnx PRIVATE /MP - /wd4146 # unary minus operator applied to unsigned type, - # result still unsigned +- /wd4146 # unary minus operator applied to unsigned type, +- # result still unsigned - /wd4244 # 'argument': conversion from 'google:: - # protobuf::uint64' to 'int', possible - # loss of data @@ -134,7 +150,7 @@ index c0ed3a39..6c8e2909 100644 auto direction = getAttribute(ctx, "direction", "forward"); diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h -index 42318d82..a33cf342 100644 +index acf3aac7..5bef6e72 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -980,10 +980,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { diff --git a/cmake/vcpkg-ports/abseil/absl_windows.patch b/cmake/vcpkg-ports/abseil/absl_windows.patch index c50e147aa4a7d..036ead7b88e46 100644 --- a/cmake/vcpkg-ports/abseil/absl_windows.patch +++ b/cmake/vcpkg-ports/abseil/absl_windows.patch @@ -1,163 +1,44 @@ -diff --git a/absl/base/attributes.h b/absl/base/attributes.h -index 5ea5ee3e..f4949898 100644 ---- a/absl/base/attributes.h -+++ b/absl/base/attributes.h -@@ -559,7 +559,7 @@ - #undef ABSL_ATTRIBUTE_UNUSED - #define ABSL_ATTRIBUTE_UNUSED __attribute__((__unused__)) - #else --#define ABSL_ATTRIBUTE_UNUSED -+#define ABSL_ATTRIBUTE_UNUSED [[maybe_unused]] - #endif - - // ABSL_ATTRIBUTE_INITIAL_EXEC diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h -index d4fe8f5c..27418d13 100644 +index 3effc441..c339e269 100644 --- a/absl/container/internal/raw_hash_set.h +++ b/absl/container/internal/raw_hash_set.h -@@ -1924,7 +1924,7 @@ HashtablezInfoHandle SampleHashtablezInfo(size_t sizeof_slot, size_t sizeof_key, - // In SOO, we sample on the first insertion so if this is an empty SOO case - // (e.g. when reserve is called), then we still need to sample. - if (kSooEnabled && was_soo && c.size() == 0) { -- return Sample(sizeof_slot, sizeof_key, sizeof_value, SooCapacity()); -+ return Sample(sizeof_slot, sizeof_key, sizeof_value, (int16_t)SooCapacity()); - } - // For non-SOO cases, we sample whenever the capacity is increasing from zero - // to non-zero. -@@ -3525,7 +3525,7 @@ class raw_hash_set { - assert(is_soo()); - if (!ShouldSampleHashtablezInfo()) return HashtablezInfoHandle{}; - return Sample(sizeof(slot_type), sizeof(key_type), sizeof(value_type), -- SooCapacity()); -+ (int16_t)SooCapacity()); +@@ -1121,11 +1121,12 @@ class CommonFields : public CommonFieldsGenerationInfo { + #ifdef NDEBUG + f(); + return; +-#endif ++#else + const size_t cap = capacity(); + set_capacity(InvalidCapacity::kReentrance); + f(); + set_capacity(cap); ++#endif } - inline void destroy_slots() { -diff --git a/absl/copts/GENERATED_AbseilCopts.cmake b/absl/copts/GENERATED_AbseilCopts.cmake -index da2282fe..4c7fc26f 100644 ---- a/absl/copts/GENERATED_AbseilCopts.cmake -+++ b/absl/copts/GENERATED_AbseilCopts.cmake -@@ -181,8 +181,6 @@ list(APPEND ABSL_MSVC_FLAGS - "/wd4005" - "/wd4068" - "/wd4180" -- "/wd4244" -- "/wd4267" - "/wd4503" - "/wd4800" - "/DNOMINMAX" -diff --git a/absl/copts/GENERATED_copts.bzl b/absl/copts/GENERATED_copts.bzl -index b9e0071e..dd8410ec 100644 ---- a/absl/copts/GENERATED_copts.bzl -+++ b/absl/copts/GENERATED_copts.bzl -@@ -182,8 +182,6 @@ ABSL_MSVC_FLAGS = [ - "/wd4005", - "/wd4068", - "/wd4180", -- "/wd4244", -- "/wd4267", - "/wd4503", - "/wd4800", - "/DNOMINMAX", -diff --git a/absl/copts/copts.py b/absl/copts/copts.py -index 2d85ac74..4875d668 100644 ---- a/absl/copts/copts.py -+++ b/absl/copts/copts.py -@@ -118,10 +118,6 @@ MSVC_WARNING_FLAGS = [ - "/wd4068", # unknown pragma - # qualifier applied to function type has no meaning; ignored - "/wd4180", -- # conversion from 'type1' to 'type2', possible loss of data -- "/wd4244", -- # conversion from 'size_t' to 'type', possible loss of data -- "/wd4267", - # The decorated name was longer than the compiler limit - "/wd4503", - # forcing value to bool 'true' or 'false' (performance warning) -diff --git a/absl/debugging/symbolize.cc b/absl/debugging/symbolize.cc -index 638d3954..6b817075 100644 ---- a/absl/debugging/symbolize.cc -+++ b/absl/debugging/symbolize.cc -@@ -14,7 +14,7 @@ - - #include "absl/debugging/symbolize.h" - --#ifdef _WIN32 -+#if defined(_WIN32) && !defined(NDEBUG) - #include - #if !(WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP)) || \ - WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) -diff --git a/absl/debugging/symbolize_win32.inc b/absl/debugging/symbolize_win32.inc -index 53a099a1..34d210d6 100644 ---- a/absl/debugging/symbolize_win32.inc -+++ b/absl/debugging/symbolize_win32.inc -@@ -35,15 +35,15 @@ ABSL_NAMESPACE_BEGIN - - static HANDLE process = NULL; - --void InitializeSymbolizer(const char*) { -- if (process != nullptr) { -- return; -- } -+namespace { -+void InitializeSymbolizerImpl() { -+ - process = GetCurrentProcess(); - - // Symbols are not loaded until a reference is made requiring the - // symbols be loaded. This is the fastest, most efficient way to use - // the symbol handler. -+ - SymSetOptions(SYMOPT_DEFERRED_LOADS | SYMOPT_UNDNAME); - if (!SymInitialize(process, nullptr, true)) { - // GetLastError() returns a Win32 DWORD, but we assign to -@@ -54,6 +54,36 @@ void InitializeSymbolizer(const char*) { - } - } - -+bool LookupAndInitialize(const void* pc, SYMBOL_INFO* symbol) { -+ auto hProcess = (process != NULL) ? process : GetCurrentProcess(); -+ if (SymFromAddr(hProcess, reinterpret_cast(pc), nullptr, symbol) != TRUE) { -+ if (GetLastError() == ERROR_INVALID_HANDLE && process == NULL) { -+ InitializeSymbolizerImpl(); -+ if (SymFromAddr(process, reinterpret_cast(pc), nullptr, symbol) != TRUE) { -+ return false; -+ } -+ } else { -+ return false; -+ } -+ return false; + private: +@@ -3344,11 +3345,14 @@ class raw_hash_set { + + // Asserts that hash and equal functors provided by the user are consistent, + // meaning that `eq(k1, k2)` implies `hash(k1)==hash(k2)`. +- template +- void AssertHashEqConsistent(const K& key) { + #ifdef NDEBUG ++ template ++ void AssertHashEqConsistent(const K&) { + return; +-#endif + } -+ return true; -+} -+} -+ -+void InitializeSymbolizer(const char*) { -+ if (process != nullptr) { -+ return; -+ } -+ -+ alignas(SYMBOL_INFO) char buf[sizeof(SYMBOL_INFO) + MAX_SYM_NAME]; -+ SYMBOL_INFO* symbol = reinterpret_cast(buf); -+ symbol->SizeOfStruct = sizeof(SYMBOL_INFO); -+ symbol->MaxNameLen = MAX_SYM_NAME; -+ -+ static_cast(LookupAndInitialize(reinterpret_cast(&InitializeSymbolizer), symbol)); -+} -+ - bool Symbolize(const void* pc, char* out, int out_size) { - if (out_size <= 0) { - return false; -@@ -62,9 +92,11 @@ bool Symbolize(const void* pc, char* out, int out_size) { - SYMBOL_INFO* symbol = reinterpret_cast(buf); - symbol->SizeOfStruct = sizeof(SYMBOL_INFO); - symbol->MaxNameLen = MAX_SYM_NAME; -- if (!SymFromAddr(process, reinterpret_cast(pc), nullptr, symbol)) { -+ -+ if(!LookupAndInitialize(pc, symbol)) { - return false; ++#else ++ template ++ void AssertHashEqConsistent(const K& key) { + // If the hash/eq functors are known to be consistent, then skip validation. + if (std::is_same::value && + std::is_same::value) { +@@ -3386,6 +3390,7 @@ class raw_hash_set { + if (capacity() > 16) return; + IterateOverFullSlots(common(), sizeof(slot_type), assert_consistent); } -+ - const size_t out_size_t = static_cast(out_size); - strncpy(out, symbol->Name, out_size_t); - if (out[out_size_t - 1] != '\0') { ++#endif + + // Attempts to find `key` in the table; if it isn't found, returns an iterator + // where the value can be inserted into, with the control byte already set to diff --git a/cmake/vcpkg-ports/abseil/portfile.cmake b/cmake/vcpkg-ports/abseil/portfile.cmake index 16a9bd86b06a5..0017b8ef74b40 100644 --- a/cmake/vcpkg-ports/abseil/portfile.cmake +++ b/cmake/vcpkg-ports/abseil/portfile.cmake @@ -6,19 +6,11 @@ vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO abseil/abseil-cpp REF "${VERSION}" - SHA512 bd2cca8f007f2eee66f51c95a979371622b850ceb2ce3608d00ba826f7c494a1da0fba3c1427728f2c173fe50d59b701da35c2c9fdad2752a5a49746b1c8ef31 + SHA512 92542db666e0c628cf56bf8ad09412af9c8b622e4f26e72d1e1b092ceec430a5c105f6561e2d9983af565f55da07f67e770cafe373b20cc4cb29a893a6a236fc HEAD_REF master PATCHES absl_windows.patch ) -# With ABSL_PROPAGATE_CXX_STD=ON abseil automatically detect if it is being -# compiled with C++14 or C++17, and modifies the installed `absl/base/options.h` -# header accordingly. This works even if CMAKE_CXX_STANDARD is not set. Abseil -# uses the compiler default behavior to update `absl/base/options.h` as needed. -set(ABSL_USE_CXX17_OPTION "") -if("cxx17" IN_LIST FEATURES) - set(ABSL_USE_CXX17_OPTION "-DCMAKE_CXX_STANDARD=17") -endif() set(ABSL_STATIC_RUNTIME_OPTION "") if(VCPKG_TARGET_IS_WINDOWS AND VCPKG_CRT_LINKAGE STREQUAL "static") diff --git a/cmake/vcpkg-ports/abseil/vcpkg.json b/cmake/vcpkg-ports/abseil/vcpkg.json index 1b8bccfbae03b..ca184edf2cdb7 100644 --- a/cmake/vcpkg-ports/abseil/vcpkg.json +++ b/cmake/vcpkg-ports/abseil/vcpkg.json @@ -1,6 +1,6 @@ { "name": "abseil", - "version": "20240722.0", + "version": "20250512.0", "description": [ "Abseil is an open-source collection of C++ library code designed to augment the C++ standard library. The Abseil library code is collected from Google's own C++ code base, has been extensively tested and used in production, and is the same code we depend on in our daily coding lives.", "In some cases, Abseil provides pieces missing from the C++ standard; in others, Abseil provides alternatives to the standard for special needs we've found through usage in the Google code base. We denote those cases clearly within the library code we provide you.", @@ -17,10 +17,5 @@ "name": "vcpkg-cmake-config", "host": true } - ], - "features": { - "cxx17": { - "description": "Enable compiler C++17." - } - } + ] } diff --git a/cmake/vcpkg-ports/onnx/binskim.patch b/cmake/vcpkg-ports/onnx/binskim.patch index c782db4b6d64d..30d5a44a1d1cc 100644 --- a/cmake/vcpkg-ports/onnx/binskim.patch +++ b/cmake/vcpkg-ports/onnx/binskim.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index 6fe5c96e..087a7780 100644 +index 8b5af303..7fe05a5a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,7 @@ option(ONNX_USE_LITE_PROTO "Use lite protobuf instead of full." OFF) @@ -55,10 +55,26 @@ index 6fe5c96e..087a7780 100644 set(ONNX_API_DEFINE "-DONNX_API=__attribute__\(\(__visibility__\(\"default\"\)\)\)") set_target_properties(onnx_proto PROPERTIES CXX_VISIBILITY_PRESET hidden) set_target_properties(onnx_proto PROPERTIES VISIBILITY_INLINES_HIDDEN 1) -@@ -620,21 +636,11 @@ if(MSVC) +@@ -595,13 +611,6 @@ if(ONNX_BUILD_PYTHON) + target_link_libraries(onnx_cpp2py_export PRIVATE ${Python3_LIBRARIES}) + target_compile_options(onnx_cpp2py_export + PRIVATE /MP +- /wd4146 # unary minus operator applied to unsigned type, +- # result still unsigned +- /wd4244 # 'argument': conversion from 'google:: +- # protobuf::uint64' to 'int', possible +- # loss of data +- /wd4267 # Conversion from 'size_t' to 'int', +- # possible loss of data + ${EXTRA_FLAGS}) + add_msvc_runtime_flag(onnx_cpp2py_export) + add_onnx_global_defines(onnx_cpp2py_export) +@@ -618,23 +627,9 @@ endif() + if(MSVC) + target_compile_options(onnx_proto PRIVATE /MP - /wd4146 # unary minus operator applied to unsigned type, - # result still unsigned +- /wd4146 # unary minus operator applied to unsigned type, +- # result still unsigned - /wd4244 #'argument': conversion from 'google:: - #protobuf::uint64' to 'int', possible - # loss of data @@ -67,8 +83,8 @@ index 6fe5c96e..087a7780 100644 ${EXTRA_FLAGS}) target_compile_options(onnx PRIVATE /MP - /wd4146 # unary minus operator applied to unsigned type, - # result still unsigned +- /wd4146 # unary minus operator applied to unsigned type, +- # result still unsigned - /wd4244 # 'argument': conversion from 'google:: - # protobuf::uint64' to 'int', possible - # loss of data @@ -134,7 +150,7 @@ index c0ed3a39..6c8e2909 100644 auto direction = getAttribute(ctx, "direction", "forward"); diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h -index 42318d82..a33cf342 100644 +index acf3aac7..5bef6e72 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -980,10 +980,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 7ba2f820e9bdb..f6cc816b45ed2 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -67,6 +67,7 @@ Do not modify directly.* * com.microsoft.PackedAttention * com.microsoft.PackedMultiHeadAttention * com.microsoft.Pad + * com.microsoft.PagedAttention * com.microsoft.QAttention * com.microsoft.QGemm * com.microsoft.QLinearAdd @@ -3683,6 +3684,100 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.PagedAttention** + + Paged Attention. + + This op leverages a block-based KV cache to enable continuous batching for LLMs. Currently, it is designed to work with + the CUDA Execution Provider only. + + In other attention ops, batch entries typically aren't of the same length, so they are padded. + Below is a batch with 3 sequences where * denotes a padding token. + Sequence_0: 0, 1*, 2*, 3* + Sequence_1: 4, 5, 6*, 7* + Sequence_2: 8, 9, 10, 11 + + PagedAttention is designed to take in packed input, i.e., only the real tokens without padding. + For example, the input shown above will be packed into 3 tensors like below: + - query ([q0, q4, q5, q8, q9, q10, q11]) + - key ([k0, k4, k5, k8, k9, k10, k11]) + - value ([v0, v4, v5, v8, v9, v10, v11]) + - cumulative_sequence_length: 0, 1, 1+2, 1+2+4 + This packing omits padding tokens. + + The query, key and value tensors contain result of hidden embedding of real tokens after input projections. + cumulative_sequence_length records cumulated length of each sequence length. + + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
do_rotary : int
+
Whether to use rotary position embedding. Default value is 0.
+
kv_num_heads : int (required)
+
Number of attention heads for k and v
+
local_window_size : int
+
left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
+
num_heads : int (required)
+
Number of attention heads for q
+
rotary_interleaved : int
+
Rotate using interleaved pattern. Default value is 0 (False).
+
scale : float
+
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
+
softcap : float
+
Softcap value for attention weights. Default value is 0.
+
+ +#### Inputs (8 - 10) + +
+
query : T
+
Query with shape (num_tokens, hidden_size), or packed QKV with shape (num_tokens, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).
+
key (optional) : T
+
Key with shape (num_tokens, kv_hidden_size)
+
value (optional) : T
+
Value with shape (num_tokens, kv_hidden_size)
+
key_cache : T
+
Block-based key cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is updated in place within the op.
+
value_cache : T
+
Block-based value cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is updated in place within the op. This should be the same shape as key_cache.
+
cumulative_sequence_length : S
+
A tensor with shape (batch_size + 1). It specifies the cumulative sequence lengths between the packed entries in Q/K/V.
+
past_seqlens : S
+
A tensor with shape (batch_size). It specifies the past lengths of cached sequence in the KV cache.
+
block_table : S
+
2D tensor with shape (batch_size, max_blocks_per_sequence) that maps each sequence in the batch to itscorresponding blocks in the KV cache.
+
cos_cache (optional) : T
+
2D tensor with shape (max total seqlen, head_size / 2).
+
sin_cache (optional) : T
+
2D tensor with shape (max total seqlen, head_size / 2).
+
+ +#### Outputs (1 - 3) + +
+
output : T
+
3D output tensor with shape (num_tokens, hidden_size)
+
key_cache_out (optional) : T
+
Block-based key cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is always the same tensor as key_cache.
+
value_cache_out (optional) : T
+
Block-based value cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is always the same tensor as value_cache.
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(bfloat16)
+
Constrain input and output to float tensors.
+
S : tensor(int32)
+
Constrain Positional inputs to int tensor.
+
+ + ### **com.microsoft.QAttention** Quantization of Multi-Head Self Attention. @@ -6345,3 +6440,5 @@ No versioning maintained for experimental ops.
T : tensor(float)
Constrain input and output types to float32 tensors.
+ + diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index b657c828fbde1..4544b0daf93cd 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -371,6 +371,7 @@ Do not modify directly.* |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|16+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |||[10, 15]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| +|RotaryEmbedding|*in* X:**T**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**M**
*out* Y:**T**|23+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |Round|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(double), tensor(float), tensor(float16)| |||[11, 21]|**T** = tensor(double), tensor(float), tensor(float16)| |STFT|*in* signal:**T1**
*in* frame_step:**T2**
*in* window:**T1**
*in* frame_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| @@ -952,6 +953,7 @@ Do not modify directly.* |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |PackedAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|PagedAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* key_cache:**T**
*in* value_cache:**T**
*in* cumulative_sequence_length:**S**
*in* past_seqlens:**S**
*in* block_table:**S**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* key_cache_out:**T**
*out* value_cache_out:**T**|1+|**S** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| |QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float16)
**T1** = tensor(uint8)| |QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* attention_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| diff --git a/include/onnxruntime/core/common/const_pointer_container.h b/include/onnxruntime/core/common/const_pointer_container.h index 1d821ba609205..80343b4fea4d6 100644 --- a/include/onnxruntime/core/common/const_pointer_container.h +++ b/include/onnxruntime/core/common/const_pointer_container.h @@ -79,6 +79,10 @@ class ConstPointerContainer { return data_[index]; } + const T* const* data() const { + return data_.data(); + } + private: const Container& data_; }; diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index c84d34cfd3cbe..0660cc874ffb7 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -265,19 +265,23 @@ bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, siz return CalcMemSizeForArrayWithAlignment(nmemb, size, alignment, out); } +using AllocatorPtr = std::shared_ptr; +using AllocatorMap = std::map; + class CPUAllocator : public IAllocator { public: explicit CPUAllocator(const OrtMemoryInfo& memory_info) : IAllocator(memory_info) {} + // Creates a function local static and returns a shared pointer to it. + // Re-use in all places where we need a standalone CPUAllocator instance + static AllocatorPtr DefaultInstance(); + CPUAllocator() : IAllocator(OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)) {} void* Alloc(size_t size) override; void Free(void* p) override; }; -using AllocatorPtr = std::shared_ptr; -using AllocatorMap = std::map; - void* AllocatorDefaultAlloc(size_t size); void AllocatorDefaultFree(void* p); void* AllocatorDefaultAllocAligned(size_t size, size_t alignment); diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 2245ff5791feb..30a5735c4e493 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -170,7 +170,7 @@ class IExecutionProvider { /** Get the device id of current execution provider */ - virtual int GetDeviceId() const { return 0; }; + virtual int GetDeviceId() const { return default_device_.Id(); }; /** Get execution provider's configuration options. diff --git a/include/onnxruntime/core/framework/ort_value.h b/include/onnxruntime/core/framework/ort_value.h index a071f3182faad..0ed427dfb7695 100644 --- a/include/onnxruntime/core/framework/ort_value.h +++ b/include/onnxruntime/core/framework/ort_value.h @@ -18,7 +18,7 @@ class SparseTensor; class TensorSeq; } // namespace onnxruntime -#endif +#endif // SHARED_PROVIDER /** Represents both tensors and non-tensors. @@ -37,8 +37,8 @@ struct OrtValue { type_ = type; } - void Init(void* pData, onnxruntime::MLDataType type, const std::function& deleter) { - data_.reset(pData, deleter); + void Init(void* pData, onnxruntime::MLDataType type, std::function deleter) { + data_.reset(pData, std::move(deleter)); type_ = type; } diff --git a/include/onnxruntime/core/framework/ortdevice.h b/include/onnxruntime/core/framework/ortdevice.h index 2a377238e0e27..ffc0da918c9df 100644 --- a/include/onnxruntime/core/framework/ortdevice.h +++ b/include/onnxruntime/core/framework/ortdevice.h @@ -4,41 +4,82 @@ #pragma once #include +#include "core/common/common.h" #include "core/common/hash_combine.h" +// fix clash with INTEL that is defined in +// MacOSX14.2.sdk/System/Library/Frameworks/Security.framework/Headers/oidsbase.h +#if defined(__APPLE__) +#undef INTEL +#endif + // Struct to represent a physical device. struct OrtDevice { using DeviceType = int8_t; using MemoryType = int8_t; using DeviceId = int16_t; + using VendorId = uint32_t; using Alignment = size_t; // Pre-defined device types. static const DeviceType CPU = 0; - static const DeviceType GPU = 1; // Nvidia or AMD + static const DeviceType GPU = 1; static const DeviceType FPGA = 2; - static const DeviceType NPU = 3; // Ascend + static const DeviceType NPU = 3; + // this is used in the python API so we need to keep it for backward compatibility + // it is only used in the OrtDevice ctor, and is mapped to GPU + VendorIds::MICROSOFT static const DeviceType DML = 4; struct MemType { - // Pre-defined memory types. static const MemoryType DEFAULT = 0; - static const MemoryType CUDA_PINNED = 1; - static const MemoryType HIP_PINNED = 2; - static const MemoryType CANN_PINNED = 3; - static const MemoryType QNN_HTP_SHARED = 4; + + // deprecated values. MemType + VendorId is used to identify the memory type. + enum Deprecated : MemoryType { + CUDA_PINNED = 1, + HIP_PINNED = 2, + CANN_PINNED = 3, + QNN_HTP_SHARED = 4, + }; + + static const MemoryType HOST_ACCESSIBLE = 5; // Device memory that is accessible from host and device. + }; + + // PCI vendor ids + enum VendorIds : VendorId { + // No vendor ID. Valid for DeviceType::CPU + MemType::DEFAULT or for generic allocators like WebGPU. + NONE = 0x0000, + AMD = 0x1002, // ROCm, MIGraphX EPs + NVIDIA = 0x10DE, // CUDA/TensorRT + ARM = 0x13B5, // ARM GPU EP + MICROSOFT = 0x1414, // DML EP + HUAWEI = 0x19E5, // CANN EP + QUALCOMM = 0x5143, // QNN DP + INTEL = 0x8086, // OpenVINO }; - constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_, Alignment alignment) noexcept + constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, VendorId vendor_id_, DeviceId device_id_, + Alignment alignment) /*noexcept*/ : device_type(device_type_), memory_type(memory_type_), device_id(device_id_), - alignment(alignment) {} + vendor_id(vendor_id_), + alignment(alignment) { + // temporary to make sure we haven't missed any places where the deprecated values were used + // ctor can go back to noexcept once everything is validated and this is removed`1 + ORT_ENFORCE(memory_type == MemType::DEFAULT || memory_type == MemType::HOST_ACCESSIBLE, + "Invalid memory type: ", static_cast(memory_type)); + + if (device_type == DML) { + device_type = GPU; + vendor_id = VendorIds::MICROSOFT; + } + } - constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_) noexcept - : OrtDevice(device_type_, memory_type_, device_id_, 0) {} + constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, VendorId vendor_id_, + DeviceId device_id_) noexcept + : OrtDevice(device_type_, memory_type_, vendor_id_, device_id_, /*alignment*/ 0) {} - constexpr OrtDevice() noexcept : OrtDevice(CPU, MemType::DEFAULT, 0) {} + constexpr OrtDevice() noexcept : OrtDevice(CPU, MemType::DEFAULT, VendorIds::NONE, 0) {} DeviceType Type() const noexcept { return device_type; @@ -48,6 +89,10 @@ struct OrtDevice { return memory_type; } + VendorId Vendor() const noexcept { + return vendor_id; + } + DeviceId Id() const noexcept { return device_id; } @@ -61,6 +106,7 @@ struct OrtDevice { ostr << "Device:[" << "DeviceType:" << static_cast(device_type) << " MemoryType:" << static_cast(memory_type) + << " VendorId:" << vendor_id << " DeviceId:" << device_id << " Alignment:" << alignment << "]"; @@ -71,6 +117,7 @@ struct OrtDevice { size_t Hash() const { auto h = std::hash()(device_type); onnxruntime::HashCombine(memory_type, h); + onnxruntime::HashCombine(vendor_id, h); onnxruntime::HashCombine(device_id, h); onnxruntime::HashCombine(alignment, h); return h; @@ -82,6 +129,8 @@ struct OrtDevice { return device_type < other.device_type; if (memory_type != other.memory_type) return memory_type < other.memory_type; + if (vendor_id != other.vendor_id) + return vendor_id < other.vendor_id; if (device_id != other.device_id) return device_id < other.device_id; @@ -98,12 +147,18 @@ struct OrtDevice { // Device index. int32_t device_id : 16; + uint32_t vendor_id; + // Required alignment Alignment alignment; }; inline bool operator==(const OrtDevice& left, const OrtDevice& other) { - return left.Id() == other.Id() && left.MemType() == other.MemType() && left.Type() == other.Type() && left.GetAlignment() == other.GetAlignment(); + return left.Type() == other.Type() && + left.MemType() == other.MemType() && + left.Vendor() == other.Vendor() && + left.Id() == other.Id() && + left.GetAlignment() == other.GetAlignment(); } inline bool operator!=(const OrtDevice& left, const OrtDevice& other) { diff --git a/include/onnxruntime/core/framework/ortmemoryinfo.h b/include/onnxruntime/core/framework/ortmemoryinfo.h index 82f581e994904..d930b2289170d 100644 --- a/include/onnxruntime/core/framework/ortmemoryinfo.h +++ b/include/onnxruntime/core/framework/ortmemoryinfo.h @@ -14,19 +14,17 @@ struct OrtMemoryInfo { // use string for name, so we could have customized allocator in execution provider. const char* name = nullptr; - int id = -1; OrtMemType mem_type = OrtMemTypeDefault; OrtAllocatorType alloc_type = OrtInvalidAllocator; OrtDevice device; - constexpr OrtMemoryInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(), int id_ = 0, + constexpr OrtMemoryInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(), OrtMemType mem_type_ = OrtMemTypeDefault) #if ((defined(__GNUC__) && __GNUC__ > 4) || defined(__clang__)) // this causes a spurious error in CentOS gcc 4.8 build so disable if GCC version < 5 __attribute__((nonnull)) #endif : name(name_), - id(id_), mem_type(mem_type_), alloc_type(type_), device(device_) { @@ -38,18 +36,17 @@ struct OrtMemoryInfo { return alloc_type < other.alloc_type; if (mem_type != other.mem_type) return mem_type < other.mem_type; - if (id != other.id) - return id < other.id; + if (device != other.device) + return device < other.device; return strcmp(name, other.name) < 0; } // This is to make OrtMemoryInfo a valid key in hash tables - // we ignore device id size_t Hash() const { auto h = std::hash()(alloc_type); onnxruntime::HashCombine(mem_type, h); - onnxruntime::HashCombine(id, h); + onnxruntime::HashCombine(device.Hash(), h); onnxruntime::HashCombine(name, h); return h; } @@ -58,7 +55,6 @@ struct OrtMemoryInfo { std::ostringstream ostr; ostr << "OrtMemoryInfo:[" << "name:" << name - << " id:" << id << " OrtMemType:" << mem_type << " OrtAllocatorType:" << alloc_type << " " << device.ToString() @@ -71,7 +67,7 @@ struct OrtMemoryInfo { inline bool operator==(const OrtMemoryInfo& left, const OrtMemoryInfo& other) { return left.mem_type == other.mem_type && left.alloc_type == other.alloc_type && - left.id == other.id && + left.device == other.device && strcmp(left.name, other.name) == 0; } diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 35b568e3f8e28..6883d3ef644d8 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -567,6 +567,13 @@ class Node { friend class Graph; Node(NodeIndex index, Graph& graph) : index_(index), graph_(&graph), can_be_saved_(true) {} + protected: +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + // internal only method to allow selected classes to directly alter the input/output definitions and arg counts + // made protected to facilitate testing + Definitions& MutableDefinitions() noexcept; +#endif + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Node); @@ -588,9 +595,6 @@ class Node { #endif #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - // internal only method to allow selected classes to directly alter the input/output definitions and arg counts - Definitions& MutableDefinitions() noexcept; - // internal only method to allow selected classes to directly alter the links between nodes. Relationships& MutableRelationships() noexcept; @@ -721,11 +725,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi /** Replaces the initializer tensor with the same name as the given initializer tensor. The replacement initializer tensor must have the same type and shape as the existing initializer tensor. + The new_initializer is expected to be either small or have external data reference stored in OrtValue. Note: This currently has linear time complexity. There is room for improvement but it would likely require changes to how initializer tensors are stored and tracked. */ - common::Status ReplaceInitializedTensor(ONNX_NAMESPACE::TensorProto new_initializer); + common::Status ReplaceInitializedTensor(const ONNX_NAMESPACE::TensorProto& new_initializer, const OrtValue& ort_value); #if !defined(DISABLE_EXTERNAL_INITIALIZERS) /** This function takes externally provided data for initializers with external data @@ -745,6 +750,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** Add an initializer tensor to the Graph. */ void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto); + + /// + /// Add initializer to the Graph. This method takes a tensor proto that contains + /// a data pointer to ort_value. For small tensors (LT utils::kSmallTensorExternalDataThreshold), + /// the data would still be contained within tensor_proto, and + /// OrtValue would be unallocated in this case, and not added to ortvalue_initializers_. + /// + /// tensor proto with external data pointing to OrtValue. + /// value that contains the initializer tensor. This may + /// be unallocated for small tensors. + Status AddInitializedOrtValue(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const OrtValue& ort_value_initializer); #endif /** Remove the initializer tensor with the provided name from the Graph. */ @@ -769,7 +786,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi /** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name. */ - bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const; + bool GetOrtValueInitializer(const std::string& name, OrtValue& value, bool check_outer_scope = false) const; /** Gets all the initializer tensors in this Graph. */ const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return name_to_initial_tensor_; } @@ -780,6 +797,9 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi /** Returns true if an initializer value can be overridden by a graph input with the same name. */ bool CanOverrideInitializer() const noexcept { return ir_version_ >= 4; } + /** Returns the ONNX IR version for the model. */ + Version GetOnnxIRVersion() const noexcept { return ir_version_; } + /** returns the initializer's TensorProto if 'name' is an initializer, is constant and cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned. @param check_outer_scope If true and the graph is a subgraph, @@ -795,6 +815,18 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi */ const ONNX_NAMESPACE::TensorProto* GetInitializer(const std::string& name, bool check_outer_scope) const; + /// + /// Returns the initializer's TensorProto if 'name' is an initializer (either constant or overridable). + /// If the initializer is not found, a nullptr is returned. An output parameter is set to true if the initializer + /// is constant. + /// + /// The initializer's name. + /// Checks outer scope if set to true and the graph is a subgraph. + /// Output parameter set to true if the initializer is a constant. + /// The initializer's TensorProto or nullptr. + const ONNX_NAMESPACE::TensorProto* GetInitializer(const std::string& name, bool check_outer_scope, + bool& is_constant) const; + /** Gets the Graph inputs excluding initializers. These are the required inputs to the Graph as the initializers can be optionally overridden via graph inputs. @remarks Contains no nullptr values. */ @@ -1507,6 +1539,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph); + int32_t weight_data_type_freq_[ONNX_NAMESPACE::TensorProto_DataType_DataType_ARRAYSIZE] = {0}; + private: void InitializeStateFromModelFileGraphProto(); @@ -1645,8 +1679,16 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // so they can be used to resolve outer scope dependencies when running BuildConnections for the subgraphs. common::Status SetOuterScopeNodeArgs(const std::unordered_set& outer_scope_node_args); - // Implementation for initializer replacement - Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, bool is_external); + /// + /// Replace initializer with new_initializer. + /// + /// + /// ort_value with data, may be empty + /// This is true when we replace the initializer with external data + /// with OrtValue from the customer, in which case we enforce that the original initializer must have external data + /// + Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, + OrtValue ort_value, bool must_replace_external); template // range-initializer returning std::string std::vector CreateNodeArgs(const StringRange& names, diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index 6a664d8be9c05..73ec34fd45f02 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -57,6 +57,11 @@ class GraphViewer { /** Returns true if an initializer value can be overridden by a graph input with the same name. */ bool CanOverrideInitializer() const noexcept; + /** Returns the ONNX IR version for the model. */ + Version GetOnnxIRVersion() const noexcept { + return graph_->GetOnnxIRVersion(); + } + /** Gets the Graph inputs, excluding initializers. @returns Collection of NodeArg pointers for the graph inputs, excluding inputs that have matching initializers. diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index 6f07ead935f4a..7535b704cd4f0 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -57,8 +57,7 @@ InlinedVector> GenerateTransformers( const IExecutionProvider& execution_provider /*required by constant folding*/, const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable = {}, - concurrency::ThreadPool* intra_op_thread_pool = nullptr, - std::unordered_map>* p_buffered_tensors = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) @@ -89,8 +88,7 @@ InlinedVector> GenerateTransformersForMinimalB const IExecutionProvider& cpu_execution_provider, const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable = {}, - concurrency::ThreadPool* intra_op_thread_pool = nullptr, - std::unordered_map>* p_buffered_tensors = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index a0053ffd3e3e3..1d5b2d5513044 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -51,6 +51,16 @@ class Environment { const OrtThreadingOptions* tp_options = nullptr, bool create_global_thread_pools = false); + /** + * Set the global threading options for the environment, if no global thread pools have been created yet. + * + * This function is not safe to call simultaneously from multiple threads, and will return a FAIL status on all calls + * after the first. + * @param tp_options set of parameters controlling the number of intra and inter op threads for the global + threadpools. + */ + Status SetGlobalThreadingOptions(const OrtThreadingOptions& tp_options); + logging::LoggingManager* GetLoggingManager() const { return logging_manager_.get(); } diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 0892accec40b0..c860e0794abed 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -142,6 +142,9 @@ extern "C" { // __VA_ARGS__ on Windows and Linux are different #define ORT_API(RETURN_TYPE, NAME, ...) RETURN_TYPE ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION +#define ORT_API_T(RETURN_TYPE, NAME, ...) \ + RETURN_TYPE(ORT_API_CALL* NAME)(__VA_ARGS__) NO_EXCEPTION + #define ORT_API_STATUS(NAME, ...) \ _Success_(return == 0) _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) \ NO_EXCEPTION ORT_MUST_USE_RESULT @@ -316,6 +319,7 @@ ORT_RUNTIME_CLASS(ModelCompilationOptions); ORT_RUNTIME_CLASS(HardwareDevice); ORT_RUNTIME_CLASS(EpDevice); ORT_RUNTIME_CLASS(KeyValuePairs); +ORT_RUNTIME_CLASS(ArrayOfConstObjects); #ifdef _MSC_VER typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -330,10 +334,16 @@ typedef OrtStatus* OrtStatusPtr; * When an allocator is passed to any function, be sure that the allocator object is not destroyed until the last allocated object using it is freed. */ typedef struct OrtAllocator { - uint32_t version; ///< Must be initialized to ORT_API_VERSION - void*(ORT_API_CALL* Alloc)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes - void(ORT_API_CALL* Free)(struct OrtAllocator* this_, void* p); ///< Free a block of memory previously allocated with OrtAllocator::Alloc - const struct OrtMemoryInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_); ///< Return a pointer to an ::OrtMemoryInfo that describes this allocator + uint32_t version; ///< Must be initialized to ORT_API_VERSION + + /// Returns a pointer to an allocated block of `size` bytes + void*(ORT_API_CALL* Alloc)(struct OrtAllocator* this_, size_t size); + + /// Free a block of memory previously allocated with OrtAllocator::Alloc + void(ORT_API_CALL* Free)(struct OrtAllocator* this_, void* p); + + /// Return a pointer to an ::OrtMemoryInfo that describes this allocator + const struct OrtMemoryInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_); /** * @brief Optional allocation function to use for memory allocations made during session initialization. * Use this function if you want to separate allocations made by ORT during Run() calls from @@ -420,12 +430,19 @@ typedef enum OrtMemType { OrtMemTypeDefault = 0, ///< The default allocator for execution provider } OrtMemType; +/** \brief This matches OrtDevice::MemoryType values */ +typedef enum OrtDeviceMemoryType { + OrtDeviceMemoryType_DEFAULT = 0, ///< Device memory + OrtDeviceMemoryType_HOST_ACCESSIBLE = 5, ///< Shared/pinned memory for transferring between CPU and the device +} OrtDeviceMemoryType; + /** \brief This mimics OrtDevice type constants so they can be returned in the API */ typedef enum OrtMemoryInfoDeviceType { OrtMemoryInfoDeviceType_CPU = 0, OrtMemoryInfoDeviceType_GPU = 1, - OrtMemoryInfoDeviceType_FPGA = 2 + OrtMemoryInfoDeviceType_FPGA = 2, + OrtMemoryInfoDeviceType_NPU = 3, } OrtMemoryInfoDeviceType; typedef enum OrtHardwareDeviceType { @@ -474,6 +491,16 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e _Out_ size_t* num_selected, _In_ void* state); +/** \brief Enum tags for ORT runtime types used to identify the type of elements in containers, + * like OrtArrayOfConstObjects. + */ +typedef enum OrtTypeTag { + ORT_TYPE_TAG_Void, + ORT_TYPE_TAG_OrtValueInfo, + ORT_TYPE_TAG_OrtNode, + ORT_TYPE_TAG_OrtGraph, +} OrtTypeTag; + /** \brief Algorithm to use for cuDNN Convolution Op */ typedef enum OrtCudnnConvAlgoSearch { @@ -768,6 +795,9 @@ typedef struct OrtCompileApi OrtCompileApi; struct OrtEpApi; typedef struct OrtEpApi OrtEpApi; +struct OrtNodeComputeInfo; +typedef struct OrtNodeComputeInfo OrtNodeComputeInfo; + /** \brief The helper interface to get the right version of OrtApi * * Get a pointer to this structure through ::OrtGetApiBase @@ -2401,6 +2431,8 @@ struct OrtApi { /// @{ /** \brief Create an allocator for an ::OrtSession following an ::OrtMemoryInfo + * + * The allocator wraps the internal allocator from the OrtSession and becomes invalid when the session does. * * \param[in] session * \param[in] mem_info valid ::OrtMemoryInfo instance @@ -2912,7 +2944,7 @@ struct OrtApi { * crossing which the current chunk is chunked into 2. * "initial_growth_chunk_size_bytes": (Possible) Size of the second allocation in the arena. * Only relevant if arena strategy is `kNextPowerOfTwo`. Use -1 to allow ORT to choose the default. - * "max_power_of_two_extend_bytes": The maximum enxtend size if arena strategy is `kNextPowerOfTwo`. + * "max_power_of_two_extend_bytes": The maximum extend size if arena strategy is `kNextPowerOfTwo`. * It is not an allocation limit, it is only a limit for extension when requested byte is less than the limit. * When requested bytes is more than the limit, allocator will still return as requested. * Use -1 to allow ORT to choose the default 1GB for max_power_of_two_extend_bytes. @@ -3646,7 +3678,7 @@ struct OrtApi { * * \param[in] name Name of the attribute * \param[in] data Data content of the attribute - * \param[in] len Number of bytes stored in data + * \param[in] len Number of elements if data represents an array (e.g., ORT_OP_ATTR_INTS). Otherwise, set to 1. * \param[in] type Data type * \param[out] op_attr Attribute that has been created, which must be released by OrtApi::ReleaseOpAttr * @@ -3832,6 +3864,11 @@ struct OrtApi { * assigned to QNN EP is dumped to a separate file. * "json_qnn_graph_dir": Directory in which to dump QNN JSON graphs. If not specified, QNN graphs are dumped in the * program's current working directory. Ignored if "dump_json_qnn_graph" is not set. + * "op_packages": QNN UDO op_package for QNN EP, allowed format: + *   ::[:],::[:], + *   where op_type is the name of the operation, op_package_path is the path to the op package shared library, + * interface is the symbol name to register the op life cycle functions, and target is the backend type. For more + * details, refer to: https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/op_packages.html * * XNNPACK supported keys: * "intra_op_num_threads": number of thread-pool size to use for XNNPACK execution provider. @@ -5328,6 +5365,548 @@ struct OrtApi { * \since Version 1.23. */ ORT_API2_STATUS(AllocatorGetStats, _In_ const OrtAllocator* ort_allocator, _Outptr_ OrtKeyValuePairs** out); + + /** \brief Create an ::OrtMemoryInfo + * + * \param[in] name Arbitrary name. + * \param[in] device_type Device type. + * \param[in] vendor_id PCI Vendor ID. Use 0 for a generic allocator (e.g. WebGPU). + * \param[in] device_id Device ID if there are multiple devices of the same type. e.g. 2 GPU devices. + * \param[in] mem_type Memory type. Use OrtDeviceMemoryType_DEFAULT for device memory, and + * OrtDeviceMemoryType_HOST_ACCESSIBLE (if applicable) for memory used to transfer + * between the device and the CPU. + * \param[in] alignment Alignment of the memory if required. Pass 0 for default alignment. + * \param[in] allocator_type Allocator type. If OrtAllocatorType::OrtArenaAllocator, the ORT arena will be used. + * Caveat: Support for OrtArenaAllocator is currently limited to usage of internal ORT + * allocators via CreateAllocator/CreateAndRegisterAllocator/CreateAndRegisterAllocatorV2. + * \param[out] out Newly created ::OrtMemoryInfo. Must be freed with OrtAPi::ReleaseMemoryInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(CreateMemoryInfo_V2, _In_ const char* name, _In_ enum OrtMemoryInfoDeviceType device_type, + _In_ uint32_t vendor_id, _In_ int16_t device_id, _In_ enum OrtDeviceMemoryType mem_type, + _In_ size_t alignment, enum OrtAllocatorType allocator_type, + _Outptr_ OrtMemoryInfo** out); + + // + // OrtArrayOfConstObjects + // + + /** \brief Create an OrtArrayOfConstObjects instance, which represents an array of + * pointers to constant opaque objects (i.e., each element is a 'const void*'). + * + * The OrtArrayOfConstObjects instance does not own the underlying objects, only the pointers + * to them. + * + * An OrtArrayOfConstObjects instance stores elements of type 'const void*'. Users + * must check the object's type via ArrayOfConstObjects_GetObjectType before casting objects + * to their actual type. + * + * \param[in] object_type The object's type as indicated by the OrtTypeTag enum. + * \param[in] initial_size The backing array's initial size. Can be set to 0. + * \param[in] initial_value Each element's initial value. Can be set to NULL. + * \param[out] out A pointer to a newly created OrtArrayOfConstObjects instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \note Must be released by calling ReleaseArrayOfConstObjects. + * + * \since Version 1.23. + */ + ORT_API2_STATUS(CreateArrayOfConstObjects, _In_ OrtTypeTag object_type, _In_ size_t initial_size, + _In_ const void* initial_value, _Outptr_ OrtArrayOfConstObjects** out); + + ORT_CLASS_RELEASE(ArrayOfConstObjects); + + /** \brief Get a tag that represents the type of the opaque objects stored in a OrtArrayOfConstObjects instance. + * + * Refer to OrtTypeTag for valid values. + * + * \param[in] array The OrtArrayOfConstObjects instance. + * \param[out] type_tag Output parameter set to the type tag that corresponds to the object type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ArrayOfConstObjects_GetObjectType, _In_ const OrtArrayOfConstObjects* array, + _Out_ OrtTypeTag* type_tag); + + /** \brief Get a pointer to a data buffer of contiguous elements, where each element is a constant pointer to a + * constant opaque object (i.e., each element is a 'const void* const'). + * + * Caller must cast the objects to the appropriate type indicated by ArrayOfConstObjects_GetObjectType. + * + * \param[in] array The OrtArrayOfConstObjects instance. + * \param[out] data Output parameter set to the contiguous data buffer that stores all elements. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ArrayOfConstObjects_GetData, _In_ const OrtArrayOfConstObjects* array, + _Outptr_ const void* const** data); + + /** \brief Get a pointer to a data buffer of contiguous elements, where each element is a pointer to a + * constant opaque object (i.e., each element is a 'const void*'). + * + * Caller must cast the objects to the appropriate type indicated by ArrayOfConstObjects_GetObjectType. + * + * \param[in] array The OrtArrayOfConstObjects instance. + * \param[out] data Output parameter set to the contiguous data buffer that stores all elements. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ArrayOfConstObjects_GetMutableData, _In_ OrtArrayOfConstObjects* array, _Outptr_ const void*** data); + + /** \brief Get the number of elements contained by the given OrtArrayOfConstObjects instance. + * + * \param[in] array The OrtArrayOfConstObjects instance. + * \param[out] size Output parameter set to the number of elements in the array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ArrayOfConstObjects_GetSize, _In_ const OrtArrayOfConstObjects* array, _Out_ size_t* size); + + /** \brief Get the element at the given index. Returns an error status if the index is outside the array bounds. + * + * Caller must cast the object to the appropriate type indicated by ArrayOfConstObjects_GetObjectType. + * Example: + * // Assume OrtTypeTag is ORT_TYPE_TAG_OrtNode and there is at least one node in the array. + * const OrtNode* node = nullptr; + * OrtStatus status = ort_api.ArrayOfConstObjects_GetElementAt(nodes, 0, reinterpret_cast(&node))); + * + * \param[in] array The OrtArrayOfConstObjects instance. + * \param[in] index The index of the element. + * \param[out] out Output parameter set to the element at the given index. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ArrayOfConstObjects_GetElementAt, _In_ const OrtArrayOfConstObjects* array, _In_ size_t index, + _Outptr_ const void** out); + + /** \brief Set the element at the given index. Returns an error status if the index is outside the array bounds. + * + * \param[in] array The OrtArrayOfConstObjects instance. + * \param[in] index The index of the element. + * \param[in] element The element to set. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ArrayOfConstObjects_SetElementAt, _In_ OrtArrayOfConstObjects* array, _In_ size_t index, + _In_ const void* element); + + /** \brief Appends an element to the end of the array, which increases the size of the array by one. + * + * \param[in] array The OrtArrayOfConstObjects instance. + * \param[in] element The element to append. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ArrayOfConstObjects_AppendElement, _In_ OrtArrayOfConstObjects* array, _In_ const void* element); + + // + // OrtValueInfo + // + + /** \brief Get the OrtNode that produces the value represented by the given OrtValueInfo. + * Optionally returns the associated output index. + * + * \param[in] value_info The OrtValueInfo instance. + * \param[out] producer_node Output parameter set to the OrtNode that produces the OrtValueInfo. + * \param[out] producer_output_index Optional output parameter set to the OrtNode instance's output index + * that produces the value. Ignored if set to NULL. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.23. + */ + ORT_API2_STATUS(ValueInfo_GetValueProducer, _In_ const OrtValueInfo* value_info, + _Outptr_ const OrtNode** producer_node, _Out_opt_ size_t* producer_output_index); + + /** \brief Get the number of consumers of a value as a node input. + * + * Only nodes are considered "consumers" by this function. To check if an OrtValueInfo is a graph output, + * call ValueInfo_IsGraphOutput(). + * + * A single OrtNode may use a single value for more than one input (e.g., Mul(x, x)), so the returned + * `num_consumers` may be larger than the number of unique OrtNode instances that consume the value. + * + * \param[in] value_info The OrtValueInfo instance. + * \param[out] num_consumers Output parameter set to the number of consumers of the value. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ValueInfo_GetValueNumConsumers, _In_ const OrtValueInfo* value_info, _Out_ size_t* num_consumers); + + /** \brief Returns information (OrtNode and input index) for all consumer nodes that use the value as an input. + * + * Only nodes are considered "consumers" by this function. + * + * Caller provides 2 pre-allocated arrays that will be filled with the OrtNode and input index values. + * Use ValueInfo_GetValueNumConsumers() to get the number of consumers of the value. + * + * An OrtNode instance may appear multiple times if it uses the given value more than once. + * Example: For a node MulNode(x, x) that consumes the value 'x' twice, the following is returned: + * - nodes: [MulNode, MulNode] + * - input_indices: [0, 1] + * + * \param[in] value_info The OrtValueInfo instance. + * \param[out] nodes Pre-allocated array of size `max_num_consumers` that will be filled with OrtNode instances. + * \param[out] input_indices Pre-allocated array of `max_num_consumers` elements that will be filled + * with input indices. Index is set to -1 for an "implicit" input to a consumer node + * that contains a subgraph (e.g., If, Loop) with nodes that use the value internally. + * \param[in] max_num_consumers The maximum size of the `consumer_nodes` and `consumer_input_indices` arrays. + * Typical usage sets this to the value of ValueInfo_GetValueNumConsumers(). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ValueInfo_GetValueConsumers, _In_ const OrtValueInfo* value_info, + _Out_writes_all_(max_num_consumers) const OrtNode** nodes, + _Out_writes_all_(max_num_consumers) int64_t* input_indices, + _In_ size_t max_num_consumers); + + /** \brief Get the underlying initializer value, as an OrtValue, from the given OrtValueInfo. + * + * Sets the output parameter to NULL if the given OrtValueInfo does not represent an initializer. + * Does not return an error status in this case. + * + * Supports initializers defined in an outer scope (i.e., a parent graph). + * + * \param[in] value_info The OrtValueInfo instance. + * \param[out] initializer_value Output parameter set to the initializer value or NULL. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.23. + */ + ORT_API2_STATUS(ValueInfo_GetInitializerValue, _In_ const OrtValueInfo* value_info, + _Outptr_ const OrtValue** initializer_value); + + /** \brief Returns a boolean indicating if the given value is a required graph input. + * + * For ONNX IR version < 4, all graph inputs without a matching initializer are required. + * + * For ONNX IR version >=4, a graph input with a matching initializer is an optional graph input + * with the initializer serving as the default value. + * + * \param[in] value_info The OrtValueInfo instance representing the graph value. + * \param[out] is_required_graph_input Output parameter set to true if the graph value is a required graph input. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ValueInfo_IsRequiredGraphInput, _In_ const OrtValueInfo* value_info, + _Out_ bool* is_required_graph_input); + + /** \brief Returns a boolean indicating if the given value is an optional graph input. + * + * Optional graph inputs were introduced in ONNX IR version 4. For ONNX IR version >=4, a graph input with a + * matching initializer is an optional graph input with the initializer serving as the default value. + * The matching initializer is also known as a non-constant initializer. + * + * \param[in] value_info The OrtValueInfo instance representing the graph value. + * \param[out] is_optional_graph_input Output parameter set to true if the graph value is an optional graph input. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ValueInfo_IsOptionalGraphInput, _In_ const OrtValueInfo* value_info, + _Out_ bool* is_optional_graph_input); + + /** \brief Returns a boolean indicating if the given value is a graph output. + * + * \param[in] value_info The OrtValueInfo instance representing the graph value. + * \param[out] is_graph_output Output parameter set to true if the graph value is a graph output. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ValueInfo_IsGraphOutput, _In_ const OrtValueInfo* value_info, _Out_ bool* is_graph_output); + + /** \brief Returns a boolean indicating if the given value is a constant initializer. + * + * For ONNX IR version < 4, all initializers are constant. + * + * For ONNX IR version >=4, an initializer that serves as the default value for a matching graph input is not a + * constant initializer. + * + * \param[in] value_info The OrtValueInfo instance representing the graph value. + * \param[out] is_constant_initializer Output parameter set to true if the graph value is a constant initializer. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ValueInfo_IsConstantInitializer, _In_ const OrtValueInfo* value_info, + _Out_ bool* is_constant_initializer); + + /** \brief Returns a boolean indicating if the given value is defined in an outer scope. + * + * Certain operator types (e.g., If and Loop) contain nested subgraphs. This function enables + * determining whether a value is defined in a parent node's graph. + * + * \param[in] value_info The OrtValueInfo instance representing the graph value. + * \param[out] is_from_outer_scope Output parameter set to true if the value is defined in an outer + * scope (i.e., a parent graph). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_info, + _Out_ bool* is_from_outer_scope); + + // + // OrtGraph + // + + /** \brief Returns a graph's name. + * + * \param[in] graph The OrtGraph instance. + * \param[out] graph_name Output parameter set to the graph's name. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name); + + /** \brief Returns the ONNX IR version. + * + * \param[in] graph The OrtGraph instance. + * \param[out] onnx_ir_version Output parameter set to the ONNX IR version. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); + + /** \brief Returns the graph's inputs as OrtValueInfo instances. + * + * Includes initializers that are included in the list of graph inputs. + * + * \param[in] graph The OrtGraph instance. + * \param[out] inputs Output parameter set to a new OrtArrayOfConstObjects instance containing the graph inputs + * as OrtValueInfo instances. Must be released by calling ReleaseArrayOfConstObjects. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetInputs, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** inputs); + + /** \brief Returns the graph's outputs as OrtValueInfo instances. + * + * \param[in] graph The OrtGraph instance. + * \param[out] outputs Output parameter set to a new OrtArrayOfConstObjects instance containing the graph outputs + * as OrtValueInfo instances. Must be released by calling ReleaseArrayOfConstObjects. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetOutputs, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** outputs); + + /** \brief Returns the graph's initializers as OrtValueInfo instances. Includes constant and non-constant + * initializers. + * + * For ONNX IR version < 4, all initializers are constant. + * + * For ONNX IR version >= 4, an initializer with a name that matches a graph input is considered a + * non-constant initializer. + * + * Call ValueInfo_GetInitializerValue to get the initializer's data. + * + * \param[in] graph The OrtGraph instance. + * \param[out] initializers Output parameter set to a new OrtArrayOfConstObjects instance containing the graph's + * initializers as OrtValueInfo instances. + * Must be released by calling ReleaseArrayOfConstObjects. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetInitializers, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** initializers); + + /** \brief Returns the graph's nodes as OrtNode instances. + * + * The nodes are sorted using a stable topological ordering. Callers are responsible for maintaining their + * own node ordering if a different order is required. + * + * \param[in] graph The OrtGraph instance. + * \param[out] nodes Output parameter set to a new OrtArrayOfConstObjects instance containing the graph's nodes as + * OrtNode instances. Must be released by calling ReleaseArrayOfConstObjects. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetNodes, const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** nodes); + + /** \brief Get the parent node for the given graph, if any exists. + * + * Certain operator types (e.g., If and Loop) contain nested subgraphs. This function enables + * access to the parent node (e.g., the If and Loop node) from a nested subgraph. + * + * \param[in] graph The OrtGraph instance. + * \param[out] node Output parameter that is set to the graph's parent node. + * Set to NULL if a parent node does not exist (e.g., for a top-level graph). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); + + // + // OrtNode + // + + /** \brief Returns a node's identifier. + * + * The node's identifier is only unique in the node's parent graph. Different nested subgraphs + * (e.g., subgraphs contained by If and Loop nodes) may reuse identifiers. + * + * \param[in] node The OrtNode instance. + * \param[out] node_id Output parameter set to the node's identifier. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetId, _In_ const OrtNode* node, _Out_ size_t* node_id); + + /** \brief Returns a node's name. Can be an empty string. + * + * \param[in] node The OrtNode instance. + * \param[out] node_name Output parameter set to the node's name. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetName, _In_ const OrtNode* node, _Outptr_ const char** node_name); + + /** \brief Returns a node's operator type (e.g., "Conv"). + * + * \param[in] node The OrtNode instance. + * \param[out] operator_type Output parameter set to the name of the node's operator type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetOperatorType, _In_ const OrtNode* node, _Outptr_ const char** operator_type); + + /** \brief Returns a node's domain name. + * + * \param[in] node The OrtNode instance. + * \param[out] domain_name Output parameter set to the node's domain name. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetDomain, _In_ const OrtNode* node, _Outptr_ const char** domain_name); + + /** \brief Get the opset version in which the given node's operator type was first defined. + * + * \param[in] node The OrtNode instance. + * \param[out] since_version The opset version in which the node's operator type was first defined. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetSinceVersion, _In_ const OrtNode* node, _Out_ int* since_version); + + /** \brief Returns a node's inputs as OrtValueInfo instances. + * + * \param[in] node The OrtNode instance. + * \param[out] inputs Output parameter set to the OrtArrayOfConstObjects instance containing the node's inputs + * as OrtValueInfo instances. Must be released by calling ReleaseArrayOfConstObjects. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** inputs); + + /** \brief Returns a node's outputs as OrtValueInfo instances. + * + * \param[in] node The OrtNode instance. + * \param[out] outputs Output parameter set to a new OrtArrayOfConstObjects instance containing the node's outputs + * as OrtValueInfo instances. Must be released by calling ReleaseArrayOfConstObjects. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetOutputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** outputs); + + /** \brief Get the implicit inputs, as OrtValueInfo instances, that are used within the given node's subgraphs. + * + * Certain operator types (e.g., If and Loop) contain nested subgraphs. The internal nodes within the nested subgraphs + * may use values from the outer scope. Those "outer scope" values are considered implicit inputs to the node that + * contains the subgraphs (e.g., the If or Loop node). + * + * \param[in] node The OrtNode instance. + * \param[out] implicit_inputs Output parameter set to a new OrtArrayOfConstObjects instance containing the node's + * implicit inputs as OrtValueInfo instances. + * Must be released by calling ReleaseArrayOfConstObjects. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetImplicitInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** implicit_inputs); + + /** \brief Get the subgraphs, as OrtGraph instances, contained by the given node. + * + * Certain operator types (e.g., If and Loop) contain nested subgraphs. + * + * \param[in] node The OrtNode instance. + * \param[out] subgraphs Output parameter set to a new OrtArrayOfConstObjects instance containing the node's + * subgraphs as OrtGraph instances. Must be released by calling ReleaseArrayOfConstObjects. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetSubgraphs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** subgraphs); + + /** \brief Get the node's parent OrtGraph instance. + * + * Can return NULL if the OrtNode was created without an owning graph. + * + * \param[in] node The OrtNode instance. + * \param[out] parent_graph Output parameter set to the node's parent OrtGraph. Can be set to NULL + * if the node is not currently contained by a graph. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetParentGraph, _In_ const OrtNode* node, + _Outptr_result_maybenull_ const OrtGraph** parent_graph); }; /* @@ -6055,198 +6634,6 @@ struct OrtCompileApi { size_t flags); }; -ORT_RUNTIME_CLASS(Ep); -ORT_RUNTIME_CLASS(EpFactory); - -struct OrtEpApi { - /** \brief Create an OrtEpDevice for the EP and an OrtHardwareDevice. - * \param[in] ep_factory Execution provider factory that is creating the instance. - * \param[in] hardware_device Hardware device that the EP can utilize. - * \param[in] ep_metadata Optional OrtKeyValuePairs instance for execution provider metadata that may be used - * during execution provider selection and passed to CreateEp. - * ep_device will copy this instance and the user should call ReleaseKeyValuePairs. - * \param[in] ep_options Optional OrtKeyValuePairs instance for execution provider options that will be added - * to the Session configuration options if the execution provider is selected. - * ep_device will copy this instance and the user should call ReleaseKeyValuePairs. - * \param ep_device OrtExecutionDevice that is created. - * - * \since Version 1.22. - */ - ORT_API2_STATUS(CreateEpDevice, _In_ OrtEpFactory* ep_factory, - _In_ const OrtHardwareDevice* hardware_device, - _In_opt_ const OrtKeyValuePairs* ep_metadata, - _In_opt_ const OrtKeyValuePairs* ep_options, - _Out_ OrtEpDevice** ep_device); - - ORT_CLASS_RELEASE(EpDevice); -}; - -/** - * \brief The OrtEp struct provides functions to implement for an execution provider. - * \since Version 1.22. - */ -struct OrtEp { - /** \brief The ONNX Runtime version the execution provider was compiled with. - * - * Implementation should set to ORT_API_VERSION. - * ORT will use this to ensure it does not call functions that were not available when the library was compiled. - * - * \since Version 1.22. - */ - uint32_t ort_version_supported; - - /** \brief Get the execution provider name. - * - * \param[in] this_ptr The OrtEp instance. - * \return The execution provider name. - * - * \note Returned string is owned by ORT and valid until UnregisterExecutionProviderLibrary is called. - * - * \since Version 1.22. - */ - const char*(ORT_API_CALL* GetName)(const OrtEp* this_ptr); - - // OrtStatus* GetCapability(OrtEp* ep, const OrtGraph* graph, - // size_t* num_supported_subgraphs, - // OrtIndexedSubgraph** supported_subgraphs, OrtAllocator* allocator); - - // OrtStatus* Compile(OrtEp* ep, const OrtGraph** graphs, OrtNode** fused_graph_nodes, - // size_t count, OrtNodeComputeInfo* node_compute_infos); - - // TODO: Implement OrtEpApi and the complete OrtEp interface as the next step. -}; - -/** \brief The function signature that ORT will call to create OrtEpFactory instances. - * - * This must be available in a function called 'CreateEpFactories' in the execution provider library. - * - * \param[in] registered_name The name the execution library is registered with by RegisterExecutionProviderLibrary - * \param[in] ort_api_base The OrtApiBase instance that is used by the factory to get the OrtApi instance for the - * version of ORT that the library was compiled against. - * \param[in,out] factories The implementation should create and add OrtEpFactory instances to this - * pre-allocated array. - * i.e. usage is `factories[0] = new MyEpFactory();` - * \param[in] max_factories The maximum number of OrtEpFactory instances that can be added to `factories`. - * Current default is to allow 4 factories. This can be increased in the future if needed. - * \param[out] num_factories The number of OrtEpFactory instances created by the factory and added to `factories`. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ -typedef OrtStatus* (*CreateEpApiFactoriesFn)(_In_ const char* registered_name, _In_ const OrtApiBase* ort_api_base, - _Inout_ OrtEpFactory** factories, _In_ size_t max_factories, - _Out_ size_t* num_factories); - -/** \brief The function signature that ORT will call to release an OrtEpFactory instance. - * - * This must be available in a function called 'ReleaseEpFactory' in the execution provider library. - * - * \param[in] factory The OrtEpFactory instance to release. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ -typedef OrtStatus* (*ReleaseEpApiFactoryFn)(_In_ OrtEpFactory* factory); - -/** - * \brief The OrtEpFactory provides functions to create and manage execution providers. - * \since Version 1.22. - */ -struct OrtEpFactory { - /** \brief The ONNX Runtime version the execution provider was compiled with. - * - * Implementation should set to ORT_API_VERSION. - * ORT will use this to ensure it does not call functions that were not available when the library was compiled. - * - * \since Version 1.22. - */ - uint32_t ort_version_supported; - - /** \brief Get the name the of the execution provider that the factory creates. - * - * \param[in] this_ptr The OrtEpFactory instance. - * \return The name of the execution provider the factory creates. - * - * \since Version 1.22. - */ - const char*(ORT_API_CALL* GetName)(const OrtEpFactory* this_ptr); - - /** \brief Get the name of vendor who owns the execution provider that the factory creates. - * - * \param[in] this_ptr The OrtEpFactory instance. - * \return vendor The vendor name of the execution provider the factory creates. - * - * \since Version 1.22. - */ - const char*(ORT_API_CALL* GetVendor)(const OrtEpFactory* this_ptr); // return EP vendor - - /** \brief Get information from the execution provider if it supports the OrtHardwareDevice. - * - * \param[in] this_ptr The OrtEpFactory instance. - * Non-const as the factory is passed through to the CreateEp call via the OrtEpDevice. - * \param[in] devices The OrtHardwareDevice instances that are available. - * \param[in] num_devices The number of OrtHardwareDevice instances. - * \param[out] ep_devices OrtEpDevice instances for each OrtHardwareDevice that the EP can use. - * The implementation should call OrtEpApi::CreateEpDevice to create, and add the OrtEpDevice - * instances to this pre-allocated array. ORT will take ownership of the values returned. - * i.e. usage is `ep_devices[0] = ;` - * \param[in] max_ep_devices The maximum number of OrtEpDevices that can be added to ep_devices. - * Current default is 8. This can be increased if needed. - * \param[out] num_ep_devices The number of EP devices added to ep_devices. - * \return true if the factory can create an execution provider that uses `device`. - * - * \note ORT will take ownership or ep_metadata and/or ep_options if they are not null. - * - * \since Version 1.22. - */ - OrtStatus*(ORT_API_CALL* GetSupportedDevices)(_In_ OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_ size_t num_devices, - _Inout_ OrtEpDevice** ep_devices, - _In_ size_t max_ep_devices, - _Out_ size_t* num_ep_devices); - - /** \brief Function to create an OrtEp instance for use in a Session. - * - * ORT will call ReleaseEp to release the instance when it is no longer needed. - * - * \param[in] this_ptr The OrtEpFactory instance. - * \param[in] devices The OrtHardwareDevice instances that the execution provider was selected to use. - * \param[in] ep_metadata_pairs Execution provider metadata that was provided to OrtEpApi::CreateEpDevice, for each - * device. - * \param[in] num_devices The number of devices the execution provider was selected for. - * \param[in] session_options The OrtSessionOptions instance that contains the configuration options for the - * session. This will include ep_options from GetSupportedDevices as well as any - * user provided overrides. - * Execution provider options will have been added with a prefix of 'ep.[ep name].'. - * The OrtSessionOptions instance will NOT be valid after this call and should not be - * stored for later use. - * \param[in] logger The OrtLogger instance for the session that the execution provider should use for logging. - * \param[out] ep The OrtEp instance created by the factory. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version [coming soon]. This is a placeholder. - */ - OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, _Outptr_ OrtEp** ep); - - /** \brief Release the OrtEp instance. - * - * \param[in] this_ptr The OrtEpFactory instance. - * \param[in] ep The OrtEp instance to release. - * - * \since Version [coming soon]. This is a placeholder. - */ - void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); -}; - /* * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists @@ -6297,3 +6684,5 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessio } #endif /// @} + +#include "onnxruntime_ep_c_api.h" diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h new file mode 100644 index 0000000000000..68b6992177b0d --- /dev/null +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -0,0 +1,370 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Do not include this file directly. Please include "onnxruntime_c_api.h" instead. + +#ifdef __cplusplus +extern "C" { +#endif + +ORT_RUNTIME_CLASS(Ep); +ORT_RUNTIME_CLASS(EpFactory); +ORT_RUNTIME_CLASS(EpGraphSupportInfo); +ORT_RUNTIME_CLASS(NodeComputeContext); + +/** + * \brief The OrtNodeComputeInfo struct provides functions that an OrtEp implements to specify the compute + * function for a compiled OrtGraph instance. + * \since Version 1.23. + */ +struct OrtNodeComputeInfo { + /** \brief The ONNX Runtime version the OrtNodeComputeInfo was compiled with. + * + * Implementation should set to ORT_API_VERSION. + * ORT will use this to ensure it does not call functions that were not available when the library was compiled. + * + * \since Version 1.23. + */ + uint32_t ort_version_supported; + + /** \brief Creates an opaque compute state object that is then passed to the Compute() function during inference. + * \param[in] this_ptr The OrtNodeComputeInfo instance. + * \param[in] compute_context OrtNodeComputeContext instance that contains compiled/fused node's name and host + * memory allocation functions. Can optionally be used to build the compute state. + * \param[out] compute_state Output parameter that is assigned the opaque computation state. ONNX Runtime calls + * ReleaseState() (after calling Compute()) to allow the implementer to release the + * compute state. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + OrtStatus*(ORT_API_CALL* CreateState)(_In_ OrtNodeComputeInfo* this_ptr, + _In_ OrtNodeComputeContext* compute_context, + _Outptr_ void** compute_state); + + /** \brief Computation function called to execute the fused node compiled by an OrtEp instance. + * \param[in] this_ptr The OrtNodeComputeInfo instance. + * \param[in] compute_state The opaque computation state returned by CreateState(). + * \param[in] kernel_context The OrtKernelContext instance used to access inputs/outputs. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + OrtStatus*(ORT_API_CALL* Compute)(_In_ OrtNodeComputeInfo* this_ptr, _In_ void* compute_state, + _In_ OrtKernelContext* kernel_context); + + /** \brief Releases the compute state returned by CreateState(). + * \param[in] this_ptr The OrtNodeComputeInfo instance. + * \param[inout] compute_state The opaque compute state returned by CreateState(). + * + * \since Version 1.23. + */ + void(ORT_API_CALL* ReleaseState)(_In_ OrtNodeComputeInfo* this_ptr, _Frees_ptr_opt_ void* compute_state); +}; + +struct OrtEpApi { + /** \brief Create an OrtEpDevice for the EP and an OrtHardwareDevice. + * \param[in] ep_factory Execution provider factory that is creating the instance. + * \param[in] hardware_device Hardware device that the EP can utilize. + * \param[in] ep_metadata Optional OrtKeyValuePairs instance for execution provider metadata that may be used + * during execution provider selection and passed to CreateEp. + * ep_device will copy this instance and the user should call ReleaseKeyValuePairs. + * \param[in] ep_options Optional OrtKeyValuePairs instance for execution provider options that will be added + * to the Session configuration options if the execution provider is selected. + * ep_device will copy this instance and the user should call ReleaseKeyValuePairs. + * \param ep_device OrtExecutionDevice that is created. + * + * \since Version 1.22. + */ + ORT_API2_STATUS(CreateEpDevice, _In_ OrtEpFactory* ep_factory, + _In_ const OrtHardwareDevice* hardware_device, + _In_opt_ const OrtKeyValuePairs* ep_metadata, + _In_opt_ const OrtKeyValuePairs* ep_options, + _Out_ OrtEpDevice** ep_device); + + ORT_CLASS_RELEASE(EpDevice); + + /** \brief Specify nodes that are supported by an OrtEp and should be fused into one node. + * + * IMPORTANT: This is not the final version of this API function. This is currently experimental but will + * be stabilized by the ONNX Runtime 1.23 release. + * + * Because the nodes will be fused into one "fused node", there must not exist an unsupported node in + * a path between two of the provided nodes. Otherwise, the graph will become invalid. + * + * This function can be called multiple times. A subsequent call to this function will force the next set of + * nodes to be fused into a different node. + * + * \param[in] graph_support_info OrtEpGraphSupportInfo instance to which to add the supported nodes. + * \param[in] nodes Array of nodes supported by the EP that should be fused/compiled. + * \param[in] num_nodes The number of supported nodes. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_reads_(num_nodes) const OrtNode* const* nodes, _In_ size_t num_nodes + /*, OrtFusedNodeSchema* optional_fused_node_schema, OrtNodesToOptimizeInfo* nodes_to_opt*/); + + /** \brief Specify a node that is supported by an OrtEp and should be run with a registered EP kernel. + * + * \param[in] graph_support_info OrtEpGraphSupportInfo instance to which to add the supported node. + * \param[in] node The supported OrtNode instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(EpGraphSupportInfo_AddSingleNode, _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_ const OrtNode* node); + + /** \brief Query a OrtNodeComputeContext for the name of the node that encapsulates the compiled/fused node. + * + * Used in OrtNodeComputeInfo::CreateComputeState(). + * + * \param[in] context The OrtNodeComputeContext instance to query. + * \return The node's name. + * + * \note Returned string is owned by ORT and valid only while OrtNodeComputeInfo::CreateComputeState() is called. + * + * \since Version 1.23. + */ + ORT_API_T(const char*, NodeComputeContext_NodeName, _In_ const OrtNodeComputeContext* context); +}; + +/** + * \brief The OrtEp struct provides functions to implement for an execution provider. + * \since Version 1.22. + */ +struct OrtEp { + /** \brief The ONNX Runtime version the execution provider was compiled with. + * + * Implementation should set to ORT_API_VERSION. + * ORT will use this to ensure it does not call functions that were not available when the library was compiled. + * + * \since Version 1.22. + */ + uint32_t ort_version_supported; + + /** \brief Get the execution provider name. + * + * \param[in] this_ptr The OrtEp instance. + * \return The execution provider name. + * + * \note Returned string is owned by ORT and valid until UnregisterExecutionProviderLibrary is called. + * + * \since Version 1.22. + */ + const char*(ORT_API_CALL* GetName)(_In_ const OrtEp* this_ptr); + + /** \brief Get information about the nodes supported by the OrtEp instance. + * + * IMPORTANT: This is not the final version of this API function. This is currently experimental but will + * be stabilized by the ONNX Runtime 1.23 release. + * + * \param[in] this_ptr The OrtEp instance. + * \param[in] graph The OrtGraph instance for which to populate node support. The OrtGraph could be a nested subgraph + * contained by a node (e.g., an If or Loop node). ONNX Runtime calls this function separately + * for each nested subgraph. + * \param[inout] graph_support_info OrtEpGraphSupportInfo instance that the implementer must fill out in order to + * specify the supported nodes. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + OrtStatus*(ORT_API_CALL* GetCapability)(_In_ OrtEp* this_ptr, _In_ const OrtGraph* graph, + _Inout_ OrtEpGraphSupportInfo* graph_support_info); + + /** \brief Compile OrtGraph instances assigned to the OrtEp. Implementer must set a OrtNodeComputeInfo instance + * for each OrtGraph in order to define its computation function. + * + * If the session is configured to generate a pre-compiled model, the execution provider must return EPContext nodes, + * as OrtNode instances, that ONNX Runtime uses to create a pre-compiled model, known as an "EPContext model". + * An EPContext model contains EPContext nodes. Each EPContext node encapsulates the pre-compiled binary data for a + * OrtGraph compiled for a specific execution provider. For more details about the EPContext design, refer to: + * \htmlonly + * EPContext design document. + * \endhtmlonly + * + * \param[in] this_ptr The OrtEp instance. + * \param[in] graphs Array of `count` OrtGraph instances to compile. Each graph contains only the nodes for + * which the execution provider indicated support. Nested subgraphs contained by a + * node, such as an If or Loop, have separate OrtGraph instances. + * \param[in] fused_nodes Array of `count` fused nodes that will replace the compiled graphs. + * Each fused node is an OrtNode initialized with the intended fused node name and + * input/output information. + * \param[in] count The number of OrtGraph instances to compile. + * \param[out] node_compute_infos Array of `count` OrtNodeComputeInfo instances that define each OrtGraph instance's + * computation function. The implementer allocates the OrtNodeComputeInfo instances. + * ORT calls ReleaseNodeComputeInfos() to release multiple instances in a batch. + * \param[out] ep_context_nodes Output array of `count` OrtNode instances, each representing an EPContext + * node for a compiled OrtGraph. The execution provider must use + * OrtModelEditorApi::CreateNode to create the OrtNode instances. ONNX Runtime takes + * ownership of the OrtNode instances, so the execution provider must NOT call + * OrtApi::ReleaseNode. Should be ignored if the session is not configured to generate an + * EPContext model. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \note Do NOT cache the provided OrtGraph instances in any of the OrtNodeComputeInfo functions because the + * graphs are only valid for the duration of the call to Compile. Any graph/node/input/output + * names that are needed by the OrtNodeComputeInfo functions must be copied and stored by the OrtEp. + * + * \since Version 1.23. + */ + OrtStatus*(ORT_API_CALL* Compile)(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** ep_context_nodes); + + /** \brief Release OrtNodeComputeInfo instances. + * + * \param[in] this_ptr The OrtEp instance. + * \param[inout] node_compute_infos The OrtNodeComputeInfo instances to release. + * \param[in] num_node_compute_infos The number of OrtNodeComputeInfo instances. + * + * \since Version 1.23. + */ + void(ORT_API_CALL* ReleaseNodeComputeInfos)(_In_ OrtEp* this_ptr, + OrtNodeComputeInfo** node_compute_infos, + _In_ size_t num_node_compute_infos); +}; + +/** \brief The function signature that ORT will call to create OrtEpFactory instances. + * + * This must be available in a function called 'CreateEpFactories' in the execution provider library. + * + * \param[in] registered_name The name the execution library is registered with by RegisterExecutionProviderLibrary + * \param[in] ort_api_base The OrtApiBase instance that is used by the factory to get the OrtApi instance for the + * version of ORT that the library was compiled against. + * \param[in,out] factories The implementation should create and add OrtEpFactory instances to this + * pre-allocated array. + * i.e. usage is `factories[0] = new MyEpFactory();` + * \param[in] max_factories The maximum number of OrtEpFactory instances that can be added to `factories`. + * Current default is to allow 4 factories. This can be increased in the future if needed. + * \param[out] num_factories The number of OrtEpFactory instances created by the factory and added to `factories`. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ +typedef OrtStatus* (*CreateEpApiFactoriesFn)(_In_ const char* registered_name, _In_ const OrtApiBase* ort_api_base, + _Inout_ OrtEpFactory** factories, _In_ size_t max_factories, + _Out_ size_t* num_factories); + +/** \brief The function signature that ORT will call to release an OrtEpFactory instance. + * + * This must be available in a function called 'ReleaseEpFactory' in the execution provider library. + * + * \param[in] factory The OrtEpFactory instance to release. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ +typedef OrtStatus* (*ReleaseEpApiFactoryFn)(_In_ OrtEpFactory* factory); + +/** + * \brief The OrtEpFactory provides functions to create and manage execution providers. + * \since Version 1.22. + */ +struct OrtEpFactory { + /** \brief The ONNX Runtime version the execution provider was compiled with. + * + * Implementation should set to ORT_API_VERSION. + * ORT will use this to ensure it does not call functions that were not available when the library was compiled. + * + * \since Version 1.22. + */ + uint32_t ort_version_supported; + + /** \brief Get the name of the execution provider that the factory creates. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \return The name of the execution provider the factory creates. + * + * \since Version 1.22. + */ + const char*(ORT_API_CALL* GetName)(const OrtEpFactory* this_ptr); + + /** \brief Get the name of vendor who owns the execution provider that the factory creates. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \return vendor The vendor name of the execution provider the factory creates. + * + * \since Version 1.22. + */ + const char*(ORT_API_CALL* GetVendor)(const OrtEpFactory* this_ptr); // return EP vendor + + /** \brief Get information from the execution provider about OrtHardwareDevice support. + * + * \param[in] this_ptr The OrtEpFactory instance. + * Non-const as the factory is passed through to the CreateEp call via the OrtEpDevice. + * \param[in] devices The OrtHardwareDevice instances that are available. + * \param[in] num_devices The number of OrtHardwareDevice instances. + * \param[out] ep_devices OrtEpDevice instances for each OrtHardwareDevice that the EP can use. + * The implementation should call OrtEpApi::CreateEpDevice to create, and add the OrtEpDevice + * instances to this pre-allocated array. ORT will take ownership of the values returned. + * i.e. usage is `ep_devices[0] = ;` + * \param[in] max_ep_devices The maximum number of OrtEpDevices that can be added to ep_devices. + * Current default is 8. This can be increased if needed. + * \param[out] num_ep_devices The number of EP devices added to ep_devices. + * \return true if the factory can create an execution provider that uses `device`. + * + * \since Version 1.22. + */ + OrtStatus*(ORT_API_CALL* GetSupportedDevices)(_In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices); + + /** \brief Function to create an OrtEp instance for use in a Session. + * + * ORT will call ReleaseEp to release the instance when it is no longer needed. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] devices The OrtHardwareDevice instances that the execution provider was selected to use. + * May be a subset of the OrtHardwareDevice instances that the execution provider's factory + * set as supported in the call to OrtEpFactory::GetSupportedDevices. + * \param[in] ep_metadata_pairs Execution provider metadata that was provided to OrtEpApi::CreateEpDevice, for each + * device. + * \param[in] num_devices The number of devices the execution provider was selected for. + * \param[in] session_options The OrtSessionOptions instance that contains the configuration options for the + * session. This will include ep_options from GetSupportedDevices as well as any + * user provided overrides. + * Execution provider options will have been added with a prefix of 'ep.[ep name].'. + * The OrtSessionOptions instance will NOT be valid after this call and should not be + * stored for later use. + * \param[in] logger The OrtLogger instance for the session that the execution provider should use for logging. + * \param[out] ep The OrtEp instance created by the factory. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, _Outptr_ OrtEp** ep); + + /** \brief Release the OrtEp instance. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] ep The OrtEp instance to release. + * + * \since Version 1.22. + */ + void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); +}; + +#ifdef __cplusplus +} +#endif diff --git a/js/build_webgpu.bat b/js/build_webgpu.bat index 95413509e701d..47478be74654b 100644 --- a/js/build_webgpu.bat +++ b/js/build_webgpu.bat @@ -69,11 +69,11 @@ popd set PATH=C:\Program Files\Git\usr\bin;%PATH% call %ROOT%build.bat --config %CONFIG% %CONFIG_EXTRA_FLAG% --skip_submodule_sync --build_wasm --target onnxruntime_webassembly --skip_tests^ - --enable_wasm_simd --enable_wasm_threads --use_jsep --use_webnn --use_webgpu --build_dir %BUILD_DIR% + --enable_wasm_simd --enable_wasm_threads --use_webnn --use_webgpu --build_dir %BUILD_DIR% IF NOT "%ERRORLEVEL%" == "0" ( exit /b %ERRORLEVEL% ) -copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.jsep.wasm %ROOT%js\web\dist\ -copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.jsep.mjs %ROOT%js\web\dist\ +copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.asyncify.wasm %ROOT%js\web\dist\ +copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.asyncify.mjs %ROOT%js\web\dist\ diff --git a/js/build_webgpu.sh b/js/build_webgpu.sh new file mode 100755 index 0000000000000..5fbcee7885e39 --- /dev/null +++ b/js/build_webgpu.sh @@ -0,0 +1,117 @@ +#!/bin/bash +# Exit immediately if a command exits with a non-zero status. +set -e + +# build_webgpu.sh --- build onnxruntime-web with WebGPU EP +# +# Usage: +# build_webgpu.sh config [clean] +# +# Options: +# config Build configuration, "d" (Debug) or "r" (Release) +# clean Perform a clean build (optional) + +# Determine the root directory of the project (one level up from the script's directory) +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +BUILD_DIR="$ROOT_DIR/build_webgpu" + +CONFIG="" +CONFIG_EXTRA_FLAG="" # This will be empty by default + +# Parse config argument +if [ "$1" = "d" ]; then + CONFIG="Debug" + CONFIG_EXTRA_FLAG="--enable_wasm_profiling --wasm_run_tests_in_browser --cmake_extra_defines onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL=1 --enable_wasm_debug_info" +elif [ "$1" = "r" ]; then + CONFIG="Release" + CONFIG_EXTRA_FLAG="--enable_wasm_api_exception_catching --disable_rtti" +else + echo "Error: Invalid configuration \"$1\"." + echo "Configuration must be 'd' (Debug) or 'r' (Release)." + echo "Usage: $0 [d|r] [clean]" + exit 1 +fi + +CLEAN_BUILD_REQUESTED=false +if [ "$2" = "clean" ]; then + CLEAN_BUILD_REQUESTED=true +fi + +# Perform clean if requested +if [ "$CLEAN_BUILD_REQUESTED" = true ]; then + echo "--- Performing clean build ---" + if [ -d "$BUILD_DIR" ]; then + echo "Removing build directory: $BUILD_DIR" + rm -rf "$BUILD_DIR" + fi + + echo "Synchronizing and updating submodules..." + pushd "$ROOT_DIR" > /dev/null + git submodule sync --recursive + git submodule update --init --recursive + popd > /dev/null +fi + +# Determine if npm ci needs to be run +# It needs to run if: +# 1. A clean build was requested (which implies js/web/dist will be missing or stale) +# 2. The js/web/dist directory does not exist (e.g., first build or manually removed) +PERFORM_NPM_CI=false +if [ "$CLEAN_BUILD_REQUESTED" = true ]; then + PERFORM_NPM_CI=true +elif [ ! -d "$ROOT_DIR/js/web/dist" ]; then + echo "Directory $ROOT_DIR/js/web/dist not found." + PERFORM_NPM_CI=true +fi + +if [ "$PERFORM_NPM_CI" = true ]; then + echo "--- Running npm ci and pulling WASM artifacts ---" + echo "Running npm ci in $ROOT_DIR/js" + pushd "$ROOT_DIR/js" > /dev/null + npm ci + popd > /dev/null + + echo "Running npm ci in $ROOT_DIR/js/common" + pushd "$ROOT_DIR/js/common" > /dev/null + npm ci + popd > /dev/null + + echo "Running npm ci and pull:wasm in $ROOT_DIR/js/web" + pushd "$ROOT_DIR/js/web" > /dev/null + npm ci + npm run pull:wasm + popd > /dev/null +fi + +echo "--- Building WebAssembly modules ---" + +echo "Calling $ROOT_DIR/build.sh to build WebAssembly..." +# Note: If $CONFIG_EXTRA_FLAG is empty, it will be omitted from the command due to shell expansion. +"$ROOT_DIR/build.sh" \ + --config "$CONFIG" \ + --parallel \ + ${CONFIG_EXTRA_FLAG} \ + --skip_submodule_sync \ + --build_wasm \ + --target onnxruntime_webassembly \ + --skip_tests \ + --enable_wasm_simd \ + --enable_wasm_threads \ + --use_webnn \ + --use_webgpu \ + --build_dir "$BUILD_DIR" + +# The 'set -e' command at the beginning of the script ensures that the script will exit +# immediately if the build.sh command (or any other command) fails. + +echo "--- Copying build artifacts ---" +# Ensure the dist directory exists before copying files +mkdir -p "$ROOT_DIR/js/web/dist" + +echo "Copying ort-wasm-simd-threaded.asyncify.wasm to $ROOT_DIR/js/web/dist/" +cp -f "$BUILD_DIR/$CONFIG/ort-wasm-simd-threaded.asyncify.wasm" "$ROOT_DIR/js/web/dist/" + +echo "Copying ort-wasm-simd-threaded.asyncify.mjs to $ROOT_DIR/js/web/dist/" +cp -f "$BUILD_DIR/$CONFIG/ort-wasm-simd-threaded.asyncify.mjs" "$ROOT_DIR/js/web/dist/" + +echo "--- WebGPU build process completed successfully ---" diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 112c8a1f78851..98b74a6474331 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -15,6 +15,7 @@ export declare namespace Env { * If not modified, the filename of the .wasm file is: * - `ort-wasm-simd-threaded.wasm` for default build * - `ort-wasm-simd-threaded.jsep.wasm` for JSEP build (with WebGPU and WebNN) + * - `ort-wasm-simd-threaded.asyncify.wasm` for WebGPU build with Asyncify (with WebNN) */ wasm?: URL | string; /** @@ -25,6 +26,7 @@ export declare namespace Env { * If not modified, the filename of the .mjs file is: * - `ort-wasm-simd-threaded.mjs` for default build * - `ort-wasm-simd-threaded.jsep.mjs` for JSEP build (with WebGPU and WebNN) + * - `ort-wasm-simd-threaded.asyncify.mjs` for WebGPU build with Asyncify (with WebNN) */ mjs?: URL | string; } diff --git a/js/scripts/prepare-onnx-node-tests.ts b/js/scripts/prepare-onnx-node-tests.ts index 02c33892d57d5..c70297c6e55df 100644 --- a/js/scripts/prepare-onnx-node-tests.ts +++ b/js/scripts/prepare-onnx-node-tests.ts @@ -32,6 +32,9 @@ const JS_TEST_ROOT = path.join(JS_ROOT, 'test'); const JS_TEST_DATA_ROOT = path.join(JS_TEST_ROOT, 'data'); const JS_TEST_DATA_NODE_ROOT = path.join(JS_TEST_DATA_ROOT, 'node'); +// Configuration for download retries +const MAX_DOWNLOAD_RETRY_TIMES = 3; + const main = async () => { log.info('PrepareTestData', 'Preparing node tests ...'); @@ -49,7 +52,7 @@ const main = async () => { const folderPrefix = `onnx-rel-${onnxVersion}/onnx/backend/test/data/node`; - const buffer = await downloadZip(resourceUri); + const buffer = await downloadZip(resourceUri, MAX_DOWNLOAD_RETRY_TIMES); const zip = await jszip.loadAsync(buffer); const entries = zip.filter((relativePath) => relativePath.startsWith(folderPrefix)); diff --git a/js/scripts/utils.ts b/js/scripts/utils.ts index 5d032dc01957c..84a3cbb67468a 100644 --- a/js/scripts/utils.ts +++ b/js/scripts/utils.ts @@ -11,37 +11,61 @@ import { JSZipObject } from 'jszip'; // See https://github.com/gajus/global-agent/blob/v3.0.0/README.md#environment-variables for details. globalAgentBootstrap(); -export const downloadZip = async (url: string): Promise => - new Promise((resolve, reject) => { - https.get(url, (res) => { - const { statusCode } = res; - const contentType = res.headers['content-type']; - - if (statusCode === 301 || statusCode === 302) { - downloadZip(res.headers.location!).then( - (buffer) => resolve(buffer), - (reason) => reject(reason), - ); - return; - } else if (statusCode !== 200) { - throw new Error(`Failed to download build list. HTTP status code = ${statusCode}`); - } - if (!contentType || !/^application\/zip/.test(contentType)) { - throw new Error(`unexpected content type: ${contentType}`); - } +export const downloadZip = async (url: string, maxRetryTimes = 3): Promise => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let lastError: any; - const chunks: Buffer[] = []; - res.on('data', (chunk) => { - chunks.push(chunk); - }); - res.on('end', () => { - resolve(Buffer.concat(chunks)); - }); - res.on('error', (err) => { - reject(`${err}`); + for (let attempt = 0; attempt <= maxRetryTimes; attempt++) { + try { + return await new Promise((resolve, reject) => { + https + .get(url, (res) => { + const { statusCode } = res; + const contentType = res.headers['content-type']; + + if (statusCode === 301 || statusCode === 302) { + downloadZip(res.headers.location!, maxRetryTimes).then( + (buffer) => resolve(buffer), + (reason) => reject(reason), + ); + return; + } else if (statusCode !== 200) { + reject(new Error(`Failed to download build list. HTTP status code = ${statusCode}`)); + return; + } + if (!contentType || !/^application\/zip/.test(contentType)) { + reject(new Error(`unexpected content type: ${contentType}`)); + return; + } + + const chunks: Buffer[] = []; + res.on('data', (chunk) => { + chunks.push(chunk); + }); + res.on('end', () => { + resolve(Buffer.concat(chunks)); + }); + res.on('error', (err) => { + reject(err); + }); + }) + .on('error', (err) => { + reject(err); + }); }); - }); - }); + } catch (error) { + lastError = error; + if (attempt < maxRetryTimes) { + // Wait before retrying (exponential backoff) + const delay = Math.pow(2, attempt) * 10000; // 10s, 20s, 40s, etc. + await new Promise((resolve) => setTimeout(resolve, delay)); + console.warn(`Download attempt ${attempt + 1} failed, retrying in ${delay}ms...`); + } + } + } + + throw lastError || new Error(`Failed to download after ${maxRetryTimes + 1} attempts`); +}; export const extractFile = async (entry: JSZipObject, ostream: WriteStream): Promise => new Promise((resolve, reject) => { diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index 83a52ebaefe05..89a2b4a6ff1be 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -17,11 +17,21 @@ interface BuildDefinitions { */ readonly DISABLE_WEBGL: boolean; /** - * defines whether to disable the whole WebGpu/WebNN backend in the build. + * defines whether to disable the JSEP support in the build. */ readonly DISABLE_JSEP: boolean; + /** + * defines whether to disable the WebGPU EP support in the build. + */ + readonly DISABLE_WEBGPU: boolean; + /** + * defines whether to disable the WebNN EP support in the build. + */ + readonly DISABLE_WEBNN: boolean; /** * defines whether to disable the whole WebAssembly backend in the build. + * + * When this build flag is set to `true`, only WebGL backend will be available. */ readonly DISABLE_WASM: boolean; /** @@ -35,18 +45,12 @@ interface BuildDefinitions { * It is usually one of the following files: * - `ort-wasm-simd-threaded.mjs` * - `ort-wasm-simd-threaded.jsep.mjs` + * - `ort-wasm-simd-threaded.asyncify.mjs` * * The value is valid only when it's an ESM build. */ readonly ENABLE_BUNDLE_WASM_JS: boolean; - /** - * defines whether to use WebGPU EP instead of JSEP for WebGPU backend. - * - * This flag requires the corresponding WebAssembly artifact to be built with `--use_webgpu` flag. - */ - readonly USE_WEBGPU_EP: boolean; - // #endregion // #region Build definitions for ESM diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 776c0d026bc97..59a5151960a1f 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -19,10 +19,26 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { registerBackend('webgl', onnxjsBackend, -10); } +if (!BUILD_DEFS.DISABLE_JSEP && !BUILD_DEFS.DISABLE_WEBGPU) { + throw new Error( + 'The current build is specified to enable both JSEP and WebGPU EP. This is not a valid configuration. ' + + 'JSEP and WebGPU EPs cannot be enabled at the same time.', + ); +} + +if (!BUILD_DEFS.DISABLE_WEBNN && BUILD_DEFS.DISABLE_JSEP && BUILD_DEFS.DISABLE_WEBGPU) { + throw new Error( + 'The current build is specified to enable WebNN EP without JSEP or WebGPU EP. This is not a valid configuration. ' + + 'WebNN EP requires either JSEP or WebGPU EP to be enabled.', + ); +} + if (!BUILD_DEFS.DISABLE_WASM) { const wasmBackend = require('./backend-wasm').wasmBackend; - if (!BUILD_DEFS.DISABLE_JSEP) { + if (!BUILD_DEFS.DISABLE_JSEP || !BUILD_DEFS.DISABLE_WEBGPU) { registerBackend('webgpu', wasmBackend, 5); + } + if (!BUILD_DEFS.DISABLE_WEBNN) { registerBackend('webnn', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 413e89111740e..e486e4b0e043d 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -428,7 +428,7 @@ export class WebGpuBackend { console.log( `[profiling] kernel "${kernelId}|${kernelType}|${kernelName}|${programName}" ${inputShapes}${ outputShapes - }execution time: ${endTime - startTime} ns`, + }start time: ${startTime} ns, execution time: ${endTime - startTime} ns`, ); } TRACE('GPU', `${programName}::${startTimeU64}::${endTimeU64}`); diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 463e26d0208e5..50fb26fef1d41 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -197,101 +197,106 @@ export const init = async ( } if (name === 'webgpu') { - if (!BUILD_DEFS.USE_WEBGPU_EP) { - // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires - const webGpuBackendImpl = require('./backend-webgpu').WebGpuBackend; - const backend = new webGpuBackendImpl(); - await backend.initialize(env, gpuAdapter!); + // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires + const webGpuBackendImpl = require('./backend-webgpu').WebGpuBackend; + const backend = new webGpuBackendImpl(); + await backend.initialize(env, gpuAdapter!); - jsepInit('webgpu', [ - // backend - backend, - - // jsepAlloc() - (size: number) => backend.alloc(Number(size)), + jsepInit('webgpu', [ + // backend + backend, - // jsepFree() - (ptr: number) => backend.free(ptr), + // jsepAlloc() + (size: number) => backend.alloc(Number(size)), - // jsepCopy(src, dst, size, isSourceGpu) - (src: number, dst: number, size: number, isSourceGpu = false) => { - if (isSourceGpu) { - LOG_DEBUG( - 'verbose', - () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`, - ); - backend.memcpy(Number(src), Number(dst)); - } else { - LOG_DEBUG( - 'verbose', - () => - `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${Number(size)}`, - ); - const data = module.HEAPU8.subarray(Number(src >>> 0), Number(src >>> 0) + Number(size)); - backend.upload(Number(dst), data); - } - }, + // jsepFree() + (ptr: number) => backend.free(ptr), - // jsepCopyAsync(src, dst, size) - async (gpuDataId: number, dataOffset: number, size: number): Promise => { + // jsepCopy(src, dst, size, isSourceGpu) + (src: number, dst: number, size: number, isSourceGpu = false) => { + if (isSourceGpu) { LOG_DEBUG( 'verbose', - () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`, + () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`, ); - - await backend.download(Number(gpuDataId), () => - module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0), - ); - }, - - // jsepCreateKernel - (kernelType: string, kernelId: number, attribute: unknown) => - backend.createKernel( - kernelType, - Number(kernelId), - attribute, - module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId))), - ), - - // jsepReleaseKernel - (kernel: number) => backend.releaseKernel(kernel), - - // jsepRun - (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { + backend.memcpy(Number(src), Number(dst)); + } else { LOG_DEBUG( 'verbose', () => - `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`, + `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${Number(size)}`, ); - const context = new ComputeContextImpl(module, backend, Number(contextDataOffset)); - return backend.computeKernel(Number(kernel), context, errors); - }, - // jsepCaptureBegin - () => backend.captureBegin(), - // jsepCaptureEnd - () => backend.captureEnd(), - // jsepReplay - () => backend.replay(), - ]); - } + const data = module.HEAPU8.subarray(Number(src >>> 0), Number(src >>> 0) + Number(size)); + backend.upload(Number(dst), data); + } + }, + + // jsepCopyAsync(src, dst, size) + async (gpuDataId: number, dataOffset: number, size: number): Promise => { + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`, + ); + + await backend.download(Number(gpuDataId), () => + module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0), + ); + }, + + // jsepCreateKernel + (kernelType: string, kernelId: number, attribute: unknown) => + backend.createKernel( + kernelType, + Number(kernelId), + attribute, + module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId))), + ), + + // jsepReleaseKernel + (kernel: number) => backend.releaseKernel(kernel), + + // jsepRun + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { + LOG_DEBUG( + 'verbose', + () => + `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`, + ); + const context = new ComputeContextImpl(module, backend, Number(contextDataOffset)); + return backend.computeKernel(Number(kernel), context, errors); + }, + // jsepCaptureBegin + () => backend.captureBegin(), + // jsepCaptureEnd + () => backend.captureEnd(), + // jsepReplay + () => backend.replay(), + ]); } else { const backend = new WebNNBackend(env); jsepInit('webnn', [ backend, - // jsepReserveTensorId + // webnnReserveTensorId () => backend.reserveTensorId(), - // jsepReleaseTensorId, + // webnnReleaseTensorId (tensorId: number) => backend.releaseTensorId(tensorId), - // jsepEnsureTensor - async (sessionId: number | undefined, tensorId: number, onnxDataType: number, shape: number[], copyOld) => - backend.ensureTensor(sessionId, tensorId, onnxDataType, shape, copyOld), - // jsepUploadTensor + // webnnEnsureTensor + async ( + sessionId: number | undefined, + tensorId: number, + onnxDataType: number, + shape: number[], + copyOld: boolean, + ) => backend.ensureTensor(sessionId, tensorId, onnxDataType, shape, copyOld), + // webnnUploadTensor (tensorId: number, data: Uint8Array) => { backend.uploadTensor(tensorId, data); }, - // jsepDownloadTensor + // webnnDownloadTensor async (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => backend.downloadTensor(tensorId, dstBuffer), - // jsepEnableTraceEvent + // webnnRegisterMLContext + (sessionId: number, mlContext: MLContext) => backend.registerMLContext(sessionId, mlContext), + // webnnEnableTraceEvent !!env.trace, ]); } diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index a6b0ac6d5a051..b725f5c8e80b7 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -133,7 +133,9 @@ export const initializeWebAssemblyAndOrtRuntime = async (): Promise => { message.in!.wasm.wasmPaths = { wasm: !BUILD_DEFS.DISABLE_JSEP ? new URL('ort-wasm-simd-threaded.jsep.wasm', BUILD_DEFS.ESM_IMPORT_META_URL).href - : new URL('ort-wasm-simd-threaded.wasm', BUILD_DEFS.ESM_IMPORT_META_URL).href, + : !BUILD_DEFS.DISABLE_WEBGPU + ? new URL('ort-wasm-simd-threaded.asyncify.wasm', BUILD_DEFS.ESM_IMPORT_META_URL).href + : new URL('ort-wasm-simd-threaded.wasm', BUILD_DEFS.ESM_IMPORT_META_URL).href, }; } proxyWorker.postMessage(message); diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 26d07b4347131..52d40bb403c77 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -93,7 +93,7 @@ const setExecutionProviders = async ( } break; case 'webgpu': - if (BUILD_DEFS.USE_WEBGPU_EP) { + if (!BUILD_DEFS.DISABLE_WEBGPU) { epName = 'WebGPU'; let customDevice: GPUDevice | undefined; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index cfdc0053b3485..ae0a06c9e749b 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -105,65 +105,89 @@ export const initEp = async (env: Env, epName: string): Promise => { // initialize ASYNCIFY support getInstance().asyncInit?.(); - if (epName === 'webgpu' && BUILD_DEFS.USE_WEBGPU_EP) { - getInstance().webgpuInit!((device) => { - env.webgpu.device = device; - }); + // perform WebGPU availability check ( either JSEP or WebGPU EP ) + let webgpuAdapter = env.webgpu.adapter as GPUAdapter | null; + if (epName === 'webgpu') { + if (typeof navigator === 'undefined' || !navigator.gpu) { + throw new Error('WebGPU is not supported in current environment'); + } + if (!webgpuAdapter) { + // if adapter is not set, request a new adapter. + const powerPreference = env.webgpu.powerPreference; + if (powerPreference !== undefined && powerPreference !== 'low-power' && powerPreference !== 'high-performance') { + throw new Error(`Invalid powerPreference setting: "${powerPreference}"`); + } + const forceFallbackAdapter = env.webgpu.forceFallbackAdapter; + if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') { + throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`); + } + webgpuAdapter = await navigator.gpu.requestAdapter({ powerPreference, forceFallbackAdapter }); + if (!webgpuAdapter) { + throw new Error( + 'Failed to get GPU adapter. ' + + 'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.', + ); + } + } else { + // if adapter is set, validate it. + if ( + typeof webgpuAdapter.limits !== 'object' || + typeof webgpuAdapter.features !== 'object' || + typeof webgpuAdapter.requestDevice !== 'function' + ) { + throw new Error('Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.'); + } + } + } + + // perform WebNN availability check ( either JSEP or WebNN EP ) + if (epName === 'webnn') { + if (typeof navigator === 'undefined' || !(navigator as unknown as { ml: unknown }).ml) { + throw new Error('WebNN is not supported in current environment'); + } } if (!BUILD_DEFS.DISABLE_JSEP) { // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires const initJsep = require('./jsep/init').init; - if (epName === 'webgpu' && !BUILD_DEFS.USE_WEBGPU_EP) { - // perform WebGPU availability check - if (typeof navigator === 'undefined' || !navigator.gpu) { - throw new Error('WebGPU is not supported in current environment'); - } - - let adapter = env.webgpu.adapter as GPUAdapter | null; - if (!adapter) { - // if adapter is not set, request a new adapter. - const powerPreference = env.webgpu.powerPreference; - if ( - powerPreference !== undefined && - powerPreference !== 'low-power' && - powerPreference !== 'high-performance' - ) { - throw new Error(`Invalid powerPreference setting: "${powerPreference}"`); - } - const forceFallbackAdapter = env.webgpu.forceFallbackAdapter; - if (forceFallbackAdapter !== undefined && typeof forceFallbackAdapter !== 'boolean') { - throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`); - } - adapter = await navigator.gpu.requestAdapter({ powerPreference, forceFallbackAdapter }); - if (!adapter) { - throw new Error( - 'Failed to get GPU adapter. ' + - 'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.', - ); - } - } else { - // if adapter is set, validate it. - if ( - typeof adapter.limits !== 'object' || - typeof adapter.features !== 'object' || - typeof adapter.requestDevice !== 'function' - ) { - throw new Error('Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.'); - } - } - - await initJsep('webgpu', getInstance(), env, adapter); + if (epName === 'webgpu') { + await initJsep('webgpu', getInstance(), env, webgpuAdapter); } if (epName === 'webnn') { - // perform WebNN availability check - if (typeof navigator === 'undefined' || !(navigator as unknown as { ml: unknown }).ml) { - throw new Error('WebNN is not supported in current environment'); - } - await initJsep('webnn', getInstance(), env); } + } else { + if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') { + getInstance().webgpuInit!((device) => { + env.webgpu.device = device; + }); + } + if (!BUILD_DEFS.DISABLE_WEBNN && epName === 'webnn') { + // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires + const backend = new (require('./jsep/backend-webnn').WebNNBackend)(env); + getInstance().webnnInit!([ + backend, + // webnnReserveTensorId + () => backend.reserveTensorId(), + // webnnReleaseTensorId, + (tensorId: number) => backend.releaseTensorId(tensorId), + // webnnEnsureTensor + async (sessionId: number | undefined, tensorId: number, onnxDataType: number, shape: number[], copyOld) => + backend.ensureTensor(sessionId, tensorId, onnxDataType, shape, copyOld), + // webnnUploadTensor + (tensorId: number, data: Uint8Array) => { + backend.uploadTensor(tensorId, data); + }, + // webnnDownloadTensor + async (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => + backend.downloadTensor(tensorId, dstBuffer), + // webnnRegisterMLContext + (sessionId: number, mlContext: MLContext) => backend.registerMLContext(sessionId, mlContext), + // webnnEnableTraceEvent + !!env.trace, + ]); + } } }; @@ -450,7 +474,7 @@ export const createSession = async ( // use IO binding only when at least one output is preferred to be on GPU. let bindingState: IOBindingState | null = null; if ( - !BUILD_DEFS.DISABLE_JSEP && + (!BUILD_DEFS.DISABLE_JSEP || !BUILD_DEFS.DISABLE_WEBGPU) && outputPreferredLocations.some((l) => l === 'gpu-buffer' || l === 'ml-tensor' || l === 'ml-tensor-cpu-output') ) { ioBindingHandle = wasm._OrtCreateBinding(sessionHandle); @@ -577,7 +601,7 @@ export const prepareInputOutputTensor = async ( const gpuBuffer = tensor[2].gpuBuffer; dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!; - if (BUILD_DEFS.USE_WEBGPU_EP) { + if (!BUILD_DEFS.DISABLE_WEBGPU) { const registerBuffer = wasm.webgpuRegisterBuffer; if (!registerBuffer) { throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); @@ -748,7 +772,7 @@ export const run = async ( wasm.setValue(outputNamesOffset + i * ptrSize, outputNamesUTF8Encoded[outputIndices[i]], '*'); } - if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState && !inputOutputBound) { + if ((!BUILD_DEFS.DISABLE_JSEP || !BUILD_DEFS.DISABLE_WEBGPU) && ioBindingState && !inputOutputBound) { const { handle, outputPreferredLocations, outputPreferredLocationsEncoded } = ioBindingState; if (inputNamesUTF8Encoded.length !== inputCount) { @@ -806,7 +830,7 @@ export const run = async ( wasm.webnnOnRunStart?.(sessionHandle); let errorCode: number; - if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState) { + if ((!BUILD_DEFS.DISABLE_JSEP || !BUILD_DEFS.DISABLE_WEBGPU) && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( sessionHandle, ioBindingState.handle, @@ -895,7 +919,7 @@ export const run = async ( // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU // tensor for it. There is no mapping GPU buffer for an empty tensor. if (preferredLocation === 'gpu-buffer' && size > 0) { - const getBuffer = BUILD_DEFS.USE_WEBGPU_EP ? wasm.webgpuGetBuffer : wasm.jsepGetBuffer; + const getBuffer = !BUILD_DEFS.DISABLE_WEBGPU ? wasm.webgpuGetBuffer : wasm.jsepGetBuffer; if (!getBuffer) { throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.'); } @@ -908,7 +932,7 @@ export const run = async ( // do not release the tensor right now. it will be released when user calls tensor.dispose(). keepOutputTensor = true; - if (BUILD_DEFS.USE_WEBGPU_EP) { + if (!BUILD_DEFS.DISABLE_WEBGPU) { wasm.webgpuRegisterBuffer!(gpuBuffer, sessionId, dataOffset); const downloadDataFunction = wasm.webgpuCreateDownloader!(gpuBuffer, bufferSize, sessionId); output.push([ @@ -1040,7 +1064,7 @@ export const run = async ( wasm.stackRestore(beforeRunStack); - if (BUILD_DEFS.USE_WEBGPU_EP) { + if (!BUILD_DEFS.DISABLE_WEBGPU) { inputTensors.forEach((t) => { if (t && t[3] === 'gpu-buffer') { wasm.webgpuUnregisterBuffer!(t[2].gpuBuffer); diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 29a4028ae46cc..842bcd0c24a07 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -28,19 +28,8 @@ export declare namespace JSEP { type CaptureBeginFunction = () => void; type CaptureEndFunction = () => void; type ReplayFunction = () => void; - type ReserveTensorIdFunction = () => number; - type ReleaseTensorIdFunction = (tensorId: number) => void; - type EnsureTensorFunction = ( - sessionId: number | undefined, - tensorId: number, - dataType: DataType, - shape: readonly number[], - copyOld: boolean, - ) => Promise; - type UploadTensorFunction = (tensorId: number, data: Uint8Array) => void; - type DownloadTensorFunction = (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise; - export interface Module extends WebGpuModule, WebNnModule { + export interface Module extends WebGpuModule, WebNN.Module { /** * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime per * backend. This function initializes Asyncify support. If name is 'webgpu', also initializes WebGPU backend and @@ -65,12 +54,13 @@ export declare namespace JSEP { jsepInit( name: 'webnn', initParams: [ - backend: BackendType, - reserveTensorId: ReserveTensorIdFunction, - releaseTensorId: ReleaseTensorIdFunction, - ensureTensor: EnsureTensorFunction, - uploadTensor: UploadTensorFunction, - downloadTensor: DownloadTensorFunction, + backend: WebNN.BackendType, + reserveTensorId: WebNN.ReserveTensorIdFunction, + releaseTensorId: WebNN.ReleaseTensorIdFunction, + ensureTensor: WebNN.EnsureTensorFunction, + uploadTensor: WebNN.UploadTensorFunction, + downloadTensor: WebNN.DownloadTensorFunction, + registerMLTensor: WebNN.RegisterMLTensorFunction, enableTraceEvent: boolean, ], ): void; @@ -145,8 +135,54 @@ export declare namespace JSEP { */ jsepOnReleaseSession: (sessionId: number) => void; } +} + +export declare namespace WebGpu { + export interface Module { + webgpuInit(setDefaultDevice: (device: GPUDevice) => void): void; + webgpuRegisterDevice( + device?: GPUDevice, + ): undefined | [deviceId: number, instanceHandle: number, deviceHandle: number]; + webgpuOnCreateSession(sessionHandle: number): void; + webgpuOnReleaseSession(sessionHandle: number): void; + webgpuRegisterBuffer(buffer: GPUBuffer, sessionHandle: number, bufferHandle?: number): number; + webgpuUnregisterBuffer(buffer: GPUBuffer): void; + webgpuGetBuffer(bufferHandle: number): GPUBuffer; + webgpuCreateDownloader(gpuBuffer: GPUBuffer, size: number, sessionHandle: number): () => Promise; + } +} + +export declare namespace WebNN { + type BackendType = unknown; + type ReserveTensorIdFunction = () => number; + type ReleaseTensorIdFunction = (tensorId: number) => void; + type EnsureTensorFunction = ( + sessionId: number | undefined, + tensorId: number, + dataType: DataType, + shape: readonly number[], + copyOld: boolean, + ) => Promise; + type UploadTensorFunction = (tensorId: number, data: Uint8Array) => void; + type DownloadTensorFunction = (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise; + type RegisterMLTensorFunction = (sessionId: number, mlContext: MLContext) => void; - export interface WebNnModule { + export interface Module { + /** + * The entry of WebNN initialization when used without JSEP. + */ + webnnInit( + initParams: [ + backend: BackendType, + reserveTensorId: ReserveTensorIdFunction, + releaseTensorId: ReleaseTensorIdFunction, + ensureTensor: EnsureTensorFunction, + uploadTensor: UploadTensorFunction, + downloadTensor: DownloadTensorFunction, + registerMLTensor: RegisterMLTensorFunction, + enableTraceEvent: boolean, + ], + ): void; /** * Active MLContext used to create WebNN EP. */ @@ -321,21 +357,6 @@ export declare namespace JSEP { } } -export declare namespace WebGpu { - export interface Module { - webgpuInit(setDefaultDevice: (device: GPUDevice) => void): void; - webgpuRegisterDevice( - device?: GPUDevice, - ): undefined | [deviceId: number, instanceHandle: number, deviceHandle: number]; - webgpuOnCreateSession(sessionHandle: number): void; - webgpuOnReleaseSession(sessionHandle: number): void; - webgpuRegisterBuffer(buffer: GPUBuffer, sessionHandle: number, bufferHandle?: number): number; - webgpuUnregisterBuffer(buffer: GPUBuffer): void; - webgpuGetBuffer(bufferHandle: number): GPUBuffer; - webgpuCreateDownloader(gpuBuffer: GPUBuffer, size: number, sessionHandle: number): () => Promise; - } -} - export interface OrtInferenceAPIs { _OrtInit(numThreads: number, loggingLevel: number): number; @@ -429,7 +450,8 @@ export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial, - Partial { + Partial, + Partial { // #region emscripten functions stackSave(): number; stackRestore(stack: number): void; diff --git a/js/web/lib/wasm/wasm-utils-import.ts b/js/web/lib/wasm/wasm-utils-import.ts index d9180e220c80c..fa7efa9910c59 100644 --- a/js/web/lib/wasm/wasm-utils-import.ts +++ b/js/web/lib/wasm/wasm-utils-import.ts @@ -214,7 +214,9 @@ const embeddedWasmModule: EmscriptenModuleFactory | undefined = require( !BUILD_DEFS.DISABLE_JSEP ? '../../dist/ort-wasm-simd-threaded.jsep.mjs' - : '../../dist/ort-wasm-simd-threaded.mjs', + : !BUILD_DEFS.DISABLE_WEBGPU + ? '../../dist/ort-wasm-simd-threaded.asyncify.mjs' + : '../../dist/ort-wasm-simd-threaded.mjs', ).default : undefined; @@ -276,7 +278,9 @@ export const importWasmModule = async ( } else { const wasmModuleFilename = !BUILD_DEFS.DISABLE_JSEP ? 'ort-wasm-simd-threaded.jsep.mjs' - : 'ort-wasm-simd-threaded.mjs'; + : !BUILD_DEFS.DISABLE_WEBGPU + ? 'ort-wasm-simd-threaded.asyncify.mjs' + : 'ort-wasm-simd-threaded.mjs'; const wasmModuleUrl = urlOverride ?? normalizeUrl(wasmModuleFilename, prefixOverride); // need to preload if all of the following conditions are met: // 1. not in Node.js. diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 2ea883f739c52..22f10b0b90a8f 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -47,13 +47,10 @@ const DEBUG = process.env.npm_config_debug || args.debug; // boolean|'verbose'|' /** * --webgpu-ep * --no-webgpu-ep (default) - * --webgpu-ep=runtime * * Enable or disable the use of WebGPU EP. If enabled, the WebGPU EP will be used. If disabled, the WebGPU backend will * be used with JSEP. * - * If set to "runtime", it will be determined at runtime based on the value of `globalThis.WEBGPU_EP`. - * * (temporary) This flag is used to test the WebGPU EP integration. It will be removed in the future. */ const USE_WEBGPU_EP = process.env.npm_config_webgpu_ep ?? args['webgpu-ep'] ?? false; @@ -68,11 +65,12 @@ const SOURCE_ROOT_FOLDER = path.join(__dirname, '../..'); */ const DEFAULT_DEFINE = { 'BUILD_DEFS.DISABLE_WEBGL': 'false', - 'BUILD_DEFS.DISABLE_JSEP': 'false', + 'BUILD_DEFS.DISABLE_JSEP': JSON.stringify(!!USE_WEBGPU_EP), 'BUILD_DEFS.DISABLE_WASM': 'false', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', 'BUILD_DEFS.ENABLE_BUNDLE_WASM_JS': 'false', - 'BUILD_DEFS.USE_WEBGPU_EP': USE_WEBGPU_EP === 'runtime' ? 'globalThis.WEBGPU_EP' : JSON.stringify(!!USE_WEBGPU_EP), + 'BUILD_DEFS.DISABLE_WEBGPU': JSON.stringify(!USE_WEBGPU_EP), + 'BUILD_DEFS.DISABLE_WEBNN': 'false', 'BUILD_DEFS.IS_ESM': 'false', 'BUILD_DEFS.ESM_IMPORT_META_URL': 'undefined', @@ -269,7 +267,7 @@ async function buildBundle(options: esbuild.BuildOptions) { * * The distribution code is split into multiple files: * - [output-name][.min].[m]js - * - ort-wasm-simd-threaded[.jsep].mjs + * - ort-wasm-simd-threaded[.jsep|.asyncify].mjs */ async function buildOrt({ isProduction = false, @@ -564,6 +562,7 @@ async function main() { define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', + 'BUILD_DEFS.DISABLE_WEBNN': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', }, @@ -577,6 +576,7 @@ async function main() { define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', + 'BUILD_DEFS.DISABLE_WEBNN': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', }, @@ -624,20 +624,36 @@ async function main() { // ort.webgpu[.min].[m]js await addAllWebBuildTasks({ outputName: 'ort.webgpu', - define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true' }, + define: { + ...DEFAULT_DEFINE, + 'BUILD_DEFS.DISABLE_WEBGPU': 'false', + 'BUILD_DEFS.DISABLE_JSEP': 'true', + 'BUILD_DEFS.DISABLE_WEBGL': 'true', + }, }); // ort.webgpu.bundle.min.mjs await buildOrt({ isProduction: true, outputName: 'ort.webgpu.bundle', format: 'esm', - define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.ENABLE_BUNDLE_WASM_JS': 'true' }, + define: { + ...DEFAULT_DEFINE, + 'BUILD_DEFS.DISABLE_WEBGPU': 'false', + 'BUILD_DEFS.DISABLE_JSEP': 'true', + 'BUILD_DEFS.DISABLE_WEBGL': 'true', + 'BUILD_DEFS.ENABLE_BUNDLE_WASM_JS': 'true', + }, }); // ort.wasm[.min].[m]js await addAllWebBuildTasks({ outputName: 'ort.wasm', - define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true' }, + define: { + ...DEFAULT_DEFINE, + 'BUILD_DEFS.DISABLE_JSEP': 'true', + 'BUILD_DEFS.DISABLE_WEBNN': 'true', + 'BUILD_DEFS.DISABLE_WEBGL': 'true', + }, }); // ort.wasm.bundle.min.mjs await buildOrt({ @@ -647,6 +663,7 @@ async function main() { define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', + 'BUILD_DEFS.DISABLE_WEBNN': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.ENABLE_BUNDLE_WASM_JS': 'true', }, @@ -657,6 +674,7 @@ async function main() { define: { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_JSEP': 'true', + 'BUILD_DEFS.DISABLE_WEBNN': 'true', 'BUILD_DEFS.DISABLE_WASM': 'true', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', }, diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts index 89c57c191de0e..c3300f7272bb9 100644 --- a/js/web/script/pull-prebuilt-wasm-artifacts.ts +++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts @@ -151,6 +151,7 @@ async function downloadArtifactsForRun(run: any): Promise { if (!fs.existsSync(WASM_FOLDER)) { fs.mkdirSync(WASM_FOLDER); } else { + // TODO: revise artifacts download const filesToDelete = ['ort-wasm-simd-threaded.jsep.mjs', 'ort-wasm-simd-threaded.jsep.wasm']; if (!folderName.endsWith('_webgpu')) { filesToDelete.push('ort-wasm-simd-threaded.mjs', 'ort-wasm-simd-threaded.wasm'); diff --git a/js/web/test/e2e/run.js b/js/web/test/e2e/run.js index da02c26b87a4e..4486dcb636a58 100644 --- a/js/web/test/e2e/run.js +++ b/js/web/test/e2e/run.js @@ -168,6 +168,7 @@ function prepareWasmPathOverrideFiles() { fs.copyFileSync(`${sourceFile}.wasm`, path.join(folder, 'ort-wasm-simd-threaded.wasm')); fs.copyFileSync(`${sourceFile}.mjs`, path.join(folder, 'renamed.mjs')); fs.copyFileSync(`${sourceFile}.wasm`, path.join(folder, 'renamed.wasm')); + // TODO: add .asyncify/.jspi fs.copyFileSync(`${sourceFile}.jsep.mjs`, path.join(folder, 'ort-wasm-simd-threaded.jsep.mjs')); fs.copyFileSync(`${sourceFile}.jsep.wasm`, path.join(folder, 'ort-wasm-simd-threaded.jsep.wasm')); fs.copyFileSync(`${sourceFile}.jsep.mjs`, path.join(folder, 'jsep-renamed.mjs')); diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 56268369bf98a..9ef6f4afbdae5 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -55,6 +55,7 @@ register_execution_provider_library, # noqa: F401 set_default_logger_severity, # noqa: F401 set_default_logger_verbosity, # noqa: F401 + set_global_thread_pool_sizes, # noqa: F401 set_seed, # noqa: F401 unregister_execution_provider_library, # noqa: F401 ) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index 52dcb990ab67f..651f270230a75 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -227,7 +227,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, output_parameters->is_unidirectional = is_unidirectional_; output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr); output_parameters->do_rotary = do_rotary_; - output_parameters->rotary_embedding = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_; + output_parameters->rotary_dim = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_; output_parameters->mask_filter_value = mask_filter_value_; output_parameters->scale = scale_; output_parameters->mask_type = mask_type; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h index c3d5128948c6f..77d3089de5d09 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h @@ -22,11 +22,12 @@ struct AttentionParameters { int v_hidden_size; // hidden size of V int v_head_size; // hidden size per head of V int num_heads; - int num_splits; - int rotary_embedding; + int num_splits; // number of splits for splitkv + int rotary_dim = 0; // rotary embedding dimension int beam_width; bool is_unidirectional; bool past_present_share_buffer; + bool is_packed_qkv = false; // whether qkv is packed bool do_rotary; bool broadcast_attn_bias_dim_0; bool broadcast_attn_bias_dim_1; @@ -46,13 +47,11 @@ struct DecoderMaskedMultiHeadAttentionParameters : AttentionParameters { int beam_width = 1; // Only NeoX style rotary embedding is supported - int rotary_embedding_dim = 0; int t_step = 0; // Weather to use multihead attention(excludes matmul and bias) bool is_mha = false; bool is_cross_attention = false; - bool is_packed_qkv = false; // Useful to better use global memory bandwidth on certain CUDA architectures. // Turned off by default for now until we fully understand performance implications @@ -83,15 +82,12 @@ struct DecoderMaskedMultiHeadAttentionParameters : AttentionParameters { // Parameters deduced from node attributes and inputs/outputs. struct GroupQueryAttentionParameters : AttentionParameters { + int kv_num_heads; // number of heads of key or value + int kv_hidden_size; // hidden size of key or value int seqlen_past_kv_cache; // sequence length of past kv tensor int seqlen_present_kv_cache; // sequence length of present kv tensor - int kv_hidden_size; - int kv_num_heads; - int num_splits; // number of splits for splitkv - int rotary_dim; // rotary embedding dimension - int local_window_size; // The window size excludes current token. It only includes tokens on the left side. + int local_window_size; // The window size excludes current token. It only includes tokens on the left side. bool kv_share_buffer; - bool is_packed_qkv; bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1 bool is_first_prompt; // indicates whether this is first decoding step bool rotary_interleaved; @@ -102,18 +98,29 @@ struct GroupQueryAttentionParameters : AttentionParameters { int* zero_ptr; }; +// Parameters deduced from node attributes and inputs/outputs. +struct PagedAttentionParameters : AttentionParameters { + int kv_num_heads; // number of heads of key or value + int kv_hidden_size; // hidden size of key or value + int token_count; // number of tokens in packed query + int block_size; // block size for kv cache + int max_num_blocks_per_seq; // max number of blocks per sequence for kv cache + int num_blocks; // number of blocks in kv cache + int local_window_size; // The window size excludes current token. It only includes tokens on the left side. + bool rotary_interleaved; + float softcap; +}; + // Parameters for sparse attention. struct SparseAttentionParameters : AttentionParameters { int kv_hidden_size; // hidden size of key or value int kv_num_heads; // number of heads of key or value bool do_rotary; // whether to use rotary embedding bool rotary_interleaved; // whether to use interleaved rotary embedding - int rotary_dim; // rotary embedding dimension int sparse_block_size; // block size for sparse attention int num_sparse_layout; // number of sparse layout int stride_col_indices; // shape of block_col_indices is [num_sparse_layout, stride_col_indices] int stride_row_indices; // shape of block_row_indices is [num_sparse_layout, stride_row_indices] - bool is_packed_qkv; // whether qkv is packed int max_rotary_sequence_length; // max sequence length for rotary cos/sin cache int max_cache_sequence_length; // max sequence length for kv cache buffer }; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index fa0d33e891f46..338c34acb3cfb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -12,6 +12,183 @@ namespace onnxruntime { namespace contrib { namespace group_query_attention_helper { +template +Status Check_Q_K_V(const T* query, const T* key, const T* value, const int num_heads, const int kv_num_heads, + int& batch_size, int& sequence_length, int& q_hidden_size, int& kv_hidden_size, int& head_size) { + const auto& query_dims = query->Shape().GetDims(); + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", + query_dims.size()); + } + batch_size = static_cast(query_dims[0]); + sequence_length = static_cast(query_dims[1]); + q_hidden_size = static_cast(query_dims[2]); + head_size = static_cast(q_hidden_size) / num_heads; + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + const auto& key_dims = key->Shape().GetDims(); + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } else if (query_dims[0] != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != key_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 1 (sequence length)"); + } + kv_hidden_size = static_cast(key_dims[2]); + if (kv_hidden_size % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_hidden_size must be a multiple of kv_num_heads. Got kv_hidden_size % kv_num_heads == ", + kv_hidden_size % kv_num_heads); + } else if (kv_hidden_size / kv_num_heads != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_hidden_size / kv_num_heads must be equal to head_size. Got kv_hidden_size / kv_num_heads == ", + kv_hidden_size / kv_num_heads); + } + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } else if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 1 (sequence length)"); + } else if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); + } + return Status::OK(); +} + +template +Status Check_QKV(const T* packed_qkv, const T* value, const int num_heads, const int kv_num_heads, int& batch_size, + int& sequence_length, int& q_hidden_size, int& kv_hidden_size, int& head_size) { + const auto& packed_dims = packed_qkv->Shape().GetDims(); + if (packed_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", + packed_dims.size()); + } + batch_size = static_cast(packed_dims[0]); + sequence_length = static_cast(packed_dims[1]); + head_size = static_cast(static_cast(packed_dims[2])) / (num_heads + 2 * kv_num_heads); + // Check packed qkv + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + q_hidden_size = head_size * num_heads; + kv_hidden_size = head_size * kv_num_heads; + return Status::OK(); +} + +template +Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_num_heads, int head_size, + int& past_sequence_length) { + const auto& past_key_dims = past_key->Shape().GetDims(); + const auto& past_value_dims = past_value->Shape().GetDims(); + + if (past_key_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' is expected to have 4 dimensions, got ", + past_key_dims.size()); + } + if (past_value_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' is expected to have 4 dimensions, got ", + past_value_dims.size()); + } + + if (past_key_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 0 should be batch_size, got ", + past_key_dims[0]); + } + if (past_value_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 0 should be batch_size, got ", + past_value_dims[0]); + } + + if (past_key_dims[2] != past_value_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence" + "length or past sequence length), got ", + past_key_dims[1]); + } + if (past_key_dims[1] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' shall have kv_num_heads"); + } + if (past_value_dims[1] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' shall have kv_num_heads"); + } + // We assume all sequence in past kv are right-padded to max or past sequence length + past_sequence_length = static_cast(past_key_dims[2]); + + if (past_key_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_key' dimension 3 should be same as head_size, got ", + past_key_dims[3]); + } + if (past_value_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 3 should be same as head_size, got ", + past_value_dims[3]); + } + return Status::OK(); +} + +template +Status CheckRotaryCaches(const T* cos_cache, const T* sin_cache, int head_size, int total_sequence_length, + int& rotary_dim) { + const auto& cos_dims = cos_cache->Shape().GetDims(); + const auto& sin_dims = sin_cache->Shape().GetDims(); + + if (head_size % 16 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size shall be a multiple of 16. Got head_size % 16 == ", + head_size % 16); + } + if (cos_dims[0] < total_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 0 shall not be less than total_sequence_length."); + } + if (sin_dims[0] < total_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 0 shall not be less than total_sequence_length."); + } + if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + if (cos_dims[1] != sin_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache and sin_cache dimension 1 must be the same."); + } + rotary_dim = static_cast(cos_dims[1] * 2); + return Status::OK(); +} + template Status CheckInputs(const T* query, const T* key, @@ -37,18 +214,6 @@ Status CheckInputs(const T* query, AttentionQkvFormat qkv_format = Q_K_V_BSNH; AttentionQkvFormat past_kv_format = Q_K_V_BNSH; - const bool is_packed_qkv = key == nullptr; - - const auto& query_dims = query->Shape().GetDims(); - if (query_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", - query_dims.size()); - } - - int batch_size = static_cast(query_dims[0]); - int sequence_length = static_cast(query_dims[1]); - int q_hidden_size = static_cast(query_dims[2]); - int head_size = 0; if (num_heads % kv_num_heads != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -56,115 +221,25 @@ Status CheckInputs(const T* query, num_heads % kv_num_heads); } + int batch_size = 0; + int sequence_length = 0; + int q_hidden_size = 0; int kv_hidden_size = 0; - // Check key and value when not packed + int head_size = 0; + const bool is_packed_qkv = key == nullptr; if (!is_packed_qkv) { - head_size = static_cast(q_hidden_size) / num_heads; - if (head_size % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size must be a multiple of 8. Got head_size % 8 == ", - head_size % 8); - } - if (value == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); - } - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } else if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } else if (query_dims[1] != key_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 1 (sequence length)"); - } - kv_hidden_size = static_cast(key_dims[2]); - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } else if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch size)"); - } else if (query_dims[1] != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 1 (sequence length)"); - } else if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } + ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, kv_num_heads, batch_size, sequence_length, + q_hidden_size, kv_hidden_size, head_size)); } else { - // Check packed qkv - head_size = static_cast(q_hidden_size) / (num_heads + 2 * kv_num_heads); - if (head_size % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size must be a multiple of 8. Got head_size % 8 == ", - head_size % 8); - } - if (value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); - } - q_hidden_size = head_size * num_heads; - kv_hidden_size = head_size * kv_num_heads; + qkv_format = QKV_BS3NH; + ORT_RETURN_IF_ERROR(Check_QKV(query, value, num_heads, kv_num_heads, batch_size, sequence_length, q_hidden_size, + kv_hidden_size, head_size)); } // Check past-present KV int32_t past_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { - const auto& past_key_dims = past_key->Shape().GetDims(); - const auto& past_value_dims = past_value->Shape().GetDims(); - - if (past_key_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' is expected to have 4 dimensions, got ", - past_key_dims.size()); - } - if (past_value_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' is expected to have 4 dimensions, got ", - past_value_dims.size()); - } - - if (past_key_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 0 should be batch_size, got ", - past_key_dims[0]); - } - if (past_value_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 0 should be batch_size, got ", - past_value_dims[0]); - } - - if (past_key_dims[2] != past_value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence" - "length or past sequence length), got ", - past_key_dims[1]); - } - if (past_key_dims[1] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' shall have kv_num_heads"); - } - if (past_value_dims[1] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' shall have kv_num_heads"); - } - // We assume all sequence in past kv are right-padded to max or past sequence length - past_sequence_length = static_cast(past_key_dims[2]); - - if (past_key_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 3 should be same as head_size, got ", - past_key_dims[3]); - } - if (past_value_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 3 should be same as head_size, got ", - past_value_dims[3]); - } + ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, batch_size, kv_num_heads, head_size, past_sequence_length)); } else if (past_key != nullptr || past_value != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' and 'past_value' shall be both present or both absent."); @@ -186,35 +261,7 @@ Status CheckInputs(const T* query, int rotary_dim = 0; if (cos_cache != nullptr && sin_cache != nullptr) { - const auto& cos_dims = cos_cache->Shape().GetDims(); - const auto& sin_dims = sin_cache->Shape().GetDims(); - - if (head_size % 16 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size shall be a multiple of 16. Got head_size % 16 == ", - head_size % 16); - } - if (cos_dims[0] < total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 shall not be less than total_sequence_length."); - } - if (sin_dims[0] < total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 shall not be less than total_sequence_length."); - } - if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); - } - if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); - } - if (cos_dims[1] != sin_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache and sin_cache dimension 1 must be the same."); - } - rotary_dim = static_cast(cos_dims[1] * 2); + ORT_RETURN_IF_ERROR(CheckRotaryCaches(cos_cache, sin_cache, head_size, total_sequence_length, rotary_dim)); } else if (cos_cache != nullptr || sin_cache != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index 9a6c2af022c91..6ac8562e41e0b 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -57,15 +57,20 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete const int position_ids_format = parameters.position_ids_format; const int rotary_emb_dim = parameters.rotary_embedding_dim; const int half_rotary_emb_dim = rotary_emb_dim / 2; - + // Parallel to calculate based on head_size const int loop_len = batch_size * sequence_length * n_heads; - const double cost = static_cast(rotary_emb_dim); + // The cost is calculated as: + // - head_size * sizeof(T) for reading input + // - head_size * sizeof(T) for writing output + // - rotary_emb_dim * 32 for the rotary embedding operations (32 is an approximation of the number of CPU cycles) + const double cost = static_cast(head_size * sizeof(T) * 2 + rotary_emb_dim * 32); ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { const int b = static_cast((ptr / n_heads) / sequence_length); const int s = static_cast((ptr / n_heads) % sequence_length); const int n = static_cast(ptr % n_heads); - + // Identify the index of batch, sequence, and head (specific range) in the input/output tensor + // for read/write const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; const T* input_data = input + block_offset; diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 65e8808190da3..1a4a38282fcc1 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -440,8 +440,6 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, AllocatorPtr& allocator, concurrency::ThreadPool* thread_pool, const MatMulComputeHelper& helper) const { - ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for unpacked compute."); - const auto* a_data = a->Data(); const uint8_t* b_data = b->Data(); const auto* scales_data = scales->Data(); @@ -460,19 +458,37 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_, true); if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { - // dequantize b, only 4b quantization is supported for now - MlasDequantizeBlockwise( - tmp_b_data_ptr.get(), // dequantized output - b_data, // quantized input - scales_data, // quantization scales - static_cast(zero_points_data), // quantization zero points - static_cast(block_size_), // quantization block size - column_wise_quant_, // columnwise quantization or row-wise - static_cast(K_), // number of rows in quantized input - static_cast(N_), // number of columns in quantized input - thread_pool); + if (nbits_ == 4) { + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else { // If it isn't 4bit, it has to be 8-bit quantization + ORT_ENFORCE(nbits_ == 8); + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } } else { + // Hitting any of the below is very rare ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now"); + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for unpacked compute using " + "non-MLAS de-quantization for now"); + // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! if (zero_points && zero_points->IsDataType()) { DequantizeBlockwise( @@ -550,7 +566,6 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, AllocatorPtr& allocator, concurrency::ThreadPool* thread_pool, const MatMulComputeHelper& helper) const { - ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for unpacked compute."); const auto* a_data = a->Data(); const uint8_t* b_data = b->Data(); const auto* scales_data = scales->Data(); @@ -580,19 +595,37 @@ Status MatMulNBits::ComputeBUnpacked(const Tensor* a, auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_, true); if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { - // dequantize b, only 4b quantization is supported for now - MlasDequantizeBlockwise( - tmp_b_data_ptr.get(), // dequantized output - b_data, // quantized input - scales_ptr, // quantization scales - static_cast(zero_points_data), // quantization zero points - static_cast(block_size_), // quantization block size - column_wise_quant_, // columnwise quantization or row-wise - static_cast(K_), // number of rows in quantized input - static_cast(N_), // number of columns in quantized input - thread_pool); + if (nbits_ == 4) { + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_ptr, // quantization scales + static_cast(zero_points_data), // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else { // If it isn't 4bit, it has to be 8-bit quantization + ORT_ENFORCE(nbits_ == 8); + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_ptr, // quantization scales + static_cast(zero_points_data), // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } } else { + // Hitting any of the below is very rare ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now"); + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for unpacked compute using " + "non-MLAS de-quantization for now"); + // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! if (zero_points && zero_points->IsDataType()) { DequantizeBlockwise( @@ -720,6 +753,13 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { } } + // TODO(hasesh): Should this logging level be warning ? + LOGS(ctx->Logger(), INFO) << "Falling back to using unpacked compute mode for the Matmul operation " + "(i.e.) the weights will be de-quantized to fp32 before invoking " + "the fp32 Matmul kernel." + "This is because MLAS doesn't have an optimized quantized kernel " + "for the requested compute configuration."; + return ComputeBUnpacked(a, b, scales, zero_points, reorder_idx, bias, y, allocator, thread_pool, helper); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 09bce9828aa33..5dc0a2efcbe93 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -215,7 +215,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( } else { size_t total_size = static_cast(sequence_length) * static_cast(batch_beam_size); size_t total_size_bytes = total_size * sizeof(int); - AllocatorPtr buffer_allocator = std::make_shared(); + AllocatorPtr buffer_allocator = CPUAllocator::DefaultInstance(); // TODO: not need extra buffer. Copy directly to input_ids_data instead like the user_cuda above. auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size_bytes, false, stream); int* seq_copy_ptr = seq_copy.get(); diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc index bf866d67ffc0d..ad778fb7ef907 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc @@ -167,7 +167,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds( Tensor::InitOrtValue(DataTypeImpl::GetType(), input_ids_shape, allocator, input_ids); int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); - AllocatorPtr buffer_allocator = std::make_shared(); + AllocatorPtr buffer_allocator = CPUAllocator::DefaultInstance(); size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); int* seq_copy_ptr = seq_copy.get(); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index c7b06d50858b4..691391ccef0d0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -180,6 +180,35 @@ struct GroupQueryAttentionData { bool use_memory_efficient_attention = false; }; +template +struct PagedAttentionData { + // Input Tensors + const T* query = nullptr; + const T* key = nullptr; + const T* value = nullptr; + T* key_cache = nullptr; + T* value_cache = nullptr; + const int* cumulative_seqlens_q = nullptr; + const int* past_seqlens = nullptr; + const int* block_table = nullptr; + const int* slot_mappings = nullptr; + const T* cos_cache = nullptr; + const T* sin_cache = nullptr; + + // Flash buffers + T* softmax_lse = nullptr; + int* cumulative_seqlens_kv = nullptr; // Flash api takes cumulative sequence length for kv-cache + + // Fused op buffers + T* workspace_buffer = nullptr; + + // Output Tensors + T* output = nullptr; + + // Kernel Flags + bool use_flash_attention = false; +}; + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index 122e94d9558e3..a7989df3439ae 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -150,7 +150,7 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, - 3, parameters.do_rotary, parameters.rotary_embedding, + 3, parameters.do_rotary, parameters.rotary_dim, parameters.past_sequence_length); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc index a15b59d0c018a..b4643da58eba5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc @@ -195,7 +195,7 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont if (do_rotary_) { ORT_ENFORCE(parameters.head_size == 64 || parameters.head_size == 128, "Current implementation of rotary embedding only supports head size of 64 or 128"); - parameters.rotary_embedding_dim = parameters.head_size; + parameters.rotary_dim = parameters.head_size; parameters.t_step = parameters.past_sequence_length; } diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 75ea7454791b6..6ba5ce66eaa60 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -212,13 +212,13 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio } } - if (params.rotary_embedding_dim > 0) { - const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; + if (params.rotary_dim > 0) { + const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_dim; T* q_smem = reinterpret_cast(smem_); - T* k_smem = q_smem + params.rotary_embedding_dim; + T* k_smem = q_smem + params.rotary_dim; - const int half_rotary_dim = params.rotary_embedding_dim / 2; + const int half_rotary_dim = params.rotary_dim / 2; const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim; const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim; const int smem_pitch = half_rotary_dim; @@ -240,7 +240,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); apply_rotary_embedding( - q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.t_step); + q, k, transpose_idx / tidx_factor, params.rotary_dim, params.t_step); write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h index 08e4293528d5a..586732834f0ad 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h @@ -165,8 +165,8 @@ inline size_t CalcDynamicBlockMemory(const DecoderMaskedMultiHeadAttentionParame size_t red_sz = rows_per_red * params.head_size * sizeof(T) / 2; size_t transpose_rotary_size = 0; - if (params.rotary_embedding_dim > 0) { - transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T); + if (params.rotary_dim > 0) { + transpose_rotary_size = 2 * params.rotary_dim * sizeof(T); } // The max. diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index 4aa633ca45e2b..c24bf88fa729b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -70,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params { int seqlen_k_rounded = 0; int d_rounded = 0; int rotary_dim = 0; + int total_q = 0; // The scaling factors for the kernel. float scale_softmax = 0.0; @@ -129,6 +130,7 @@ struct Flash_fwd_params : public Qkv_params { void* __restrict__ alibi_slopes_ptr = nullptr; index_t alibi_slopes_batch_stride = 0; + bool unpadded_lse = false; const cudaDeviceProp* dprops = nullptr; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 453dffaa2e6e6..b0241c26aafc6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -41,7 +41,8 @@ void set_params_fprop(Flash_fwd_params& params, bool use_smooth_softmax, bool kv_bsnh = true, int window_size_left = -1, - int window_size_right = -1) { + int window_size_right = -1, + const bool unpadded_lse = false) { // Set the pointers and strides. params.q_ptr = q; params.k_ptr = k; @@ -142,6 +143,7 @@ void set_params_fprop(Flash_fwd_params& params, params.window_size_right = window_size_right; params.is_seqlens_k_cumulative = true; + params.unpadded_lse = unpadded_lse; } size_t get_softmax_lse_size(size_t seqlen, size_t batch_size, size_t num_heads) { @@ -149,6 +151,11 @@ size_t get_softmax_lse_size(size_t seqlen, size_t batch_size, size_t num_heads) return bytes; } +size_t get_softmax_lse_size(size_t token_count, size_t num_heads) { + size_t bytes = sizeof(float) * token_count * num_heads; + return bytes; +} + size_t get_softmax_lse_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, size_t seqlen_q) { size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads; return bytes; @@ -336,6 +343,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, return Status::OK(); } +// TODO(aciddelgado): Baiju wants this https://github.com/Dao-AILab/flash-attention/pull/824 + Status mha_varlen_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // half (total_q, num_heads, head_size) @@ -353,10 +362,12 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, int head_size, int max_seqlen_q, int max_seqlen_k, + int total_q, float softmax_scale, const float softcap, bool is_causal, bool is_bf16, + int local_window_size, int max_num_blocks_per_seq, int page_block_size) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; @@ -384,8 +395,11 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, is_bf16, false, true, - -1, - is_causal ? 0 : -1); + local_window_size, + is_causal ? 0 : -1, + /*unpadded_lse*/ true); + + params.total_q = total_q; params.dprops = &dprops; params.num_splits = 0; params.softmax_lseaccum_ptr = nullptr; @@ -394,7 +408,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, params.vnew_ptr = nullptr; params.alibi_slopes_ptr = nullptr; if (paged_KV) { - params.block_table = block_table; // TODO(aciddelgado): cast to int pointer + params.block_table = block_table; params.block_table_batch_stride = max_num_blocks_per_seq; // params.num_blocks = num_blocks; params.page_block_size = page_block_size; @@ -406,7 +420,8 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, // params.num_blocks = 0; params.page_block_size = 1; } - run_mha_fwd(params, stream); + + run_mha_fwd(params, stream, paged_KV); return Status::OK(); } @@ -536,7 +551,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, params.alibi_slopes_ptr = nullptr; if (paged_KV) { - params.block_table = block_table; // TODO(aciddelgado): cast to int pointer + params.block_table = block_table; params.block_table_batch_stride = max_num_blocks_per_seq; // params.num_blocks = num_blocks; params.page_block_size = page_block_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 57752e8237d6e..e28e38ea3ed93 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -77,10 +77,12 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, int head_size, int max_seqlen_q, int max_seqlen_k, + int total_q, float softmax_scale, const float softcap, bool is_causal, bool is_bf16, + int local_window_size = -1, int max_num_blocks_per_seq = 0, int page_block_size = 1); @@ -121,6 +123,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int page_block_size = 1); size_t get_softmax_lse_size(size_t max_seqlen_q, size_t batch_size, size_t num_heads); +size_t get_softmax_lse_size(size_t token_count, size_t num_heads); std::tuple get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads, size_t head_size, size_t num_SMs); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index d46d9597a758f..4110e715c4391 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -34,6 +34,26 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// +template +__forceinline__ __device__ auto get_lse_tile(const Params& params, const int bidb, const int bidh, const int m_block, const BlockInfo& binfo) { + // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path. + // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick. + // Otherwise, it's written as (h, b, seqlen_q). + const bool varlen_q = params.unpadded_lse; + auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0; + auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + lse_offset); + + auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q); + auto lse_stride = params.unpadded_lse + ? make_stride(params.h * params.total_q, params.total_q, 1) + : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1); + + auto lse_layout = make_layout(lse_shape, lse_stride); + Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout); + auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _); + return local_tile(mLSE_slice, Shape>{}, make_coord(m_block)); +} + template inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -70,10 +90,8 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi make_stride(params.o_row_stride, params.o_head_stride, _1{})); Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, make_coord(m_block, 0)); // (kBlockM, kHeadDim) - Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), - make_shape(params.b, params.h, params.seqlen_q), - make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); - Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); + + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); @@ -375,10 +393,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi make_stride(params.o_row_stride, params.o_head_stride, _1{})); Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, make_coord(m_block, 0)); // (kBlockM, kHeadDim) - Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), - make_shape(params.b, params.h, params.seqlen_q), - make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); - Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); @@ -938,8 +953,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; - const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - + const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ? ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)) + m_block * kBlockM; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), Shape, Int>{}, make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); @@ -1047,12 +1061,24 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { const int tidx = threadIdx.x; const int bidx = blockIdx.x; + const index_t lse_size = params.b * params.h * params.seqlen_q; + const index_t row_offset_lse = bidx * kBlockM; Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), Shape, Int>{}, - make_stride(params.b * params.h * params.seqlen_q, _1{})); + make_stride(lse_size, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); + // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}. + Layout flat_layout = make_layout(lse_size); + Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b)); + auto transposed_stride = make_stride(1, params.seqlen_q * params.b, params.seqlen_q); + Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride); + Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout)); + + Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), final_layout); + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; // Read the LSE values from gmem and store them in shared memory, then tranpose them. @@ -1107,7 +1133,14 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) { ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? std::numeric_limits::infinity() : logf(lse_sum) + lse_max; // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { - gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + if (params.unpadded_lse) { + const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; + if (lse_offset < lse_size) { + gLSE_unpadded(lse_offset) = lse_logsum; + } + } else { + gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + } } // Store the scales exp(lse - lse_logsum) in shared memory. #pragma unroll diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index 846d2be7bf2e1..07bca3f7fff99 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -589,6 +589,7 @@ Status FlashAttention( PackedMultiHeadAttentionData& data) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; + const int token_count = parameters.token_count; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; @@ -638,6 +639,7 @@ Status FlashAttention( qk_head_size, sequence_length, sequence_length, + token_count, scale, 0.0, false, // is causal diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc new file mode 100644 index 0000000000000..4189965ab9137 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention.cc @@ -0,0 +1,218 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/cpu/utils/dump_tensor.h" +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" +#include "contrib_ops/cuda/bert/paged_attention_impl.h" +#include "contrib_ops/cuda/bert/paged_attention.h" +#include "contrib_ops/cuda/bert/paged_attention_helper.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + PagedAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("S", DataTypeImpl::GetTensorType()), \ + PagedAttention); + +REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) + +template +PagedAttention::PagedAttention(const OpKernelInfo& info) + : CudaKernel(info) { + int64_t num_heads = 0; + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); + num_heads_ = static_cast(num_heads); + kv_num_heads_ = static_cast(kv_num_heads); + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + scale_ = info.GetAttrOrDefault("scale", 0.0f); + softcap_ = info.GetAttrOrDefault("softcap", 0.0f); + + kernel_options_ = this->GetAttentionKernelOptions(); + disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); +} + +template +Status PagedAttention::ComputeInternal(OpKernelContext* context) const { + const Tensor* query = context->Input(0); + const Tensor* key = context->Input(1); + const Tensor* value = context->Input(2); + const Tensor* key_cache = context->Input(3); + const Tensor* value_cache = context->Input(4); + const Tensor* cumulative_seqlens_q = context->Input(5); + const Tensor* past_seqlens = context->Input(6); + const Tensor* block_table = context->Input(7); + const Tensor* cos_cache = context->Input(8); + const Tensor* sin_cache = context->Input(9); + + auto& device_prop = GetDeviceProp(); + PagedAttentionParameters parameters; + typedef typename ToCudaType::MappedType CudaT; + PagedAttentionData data; + + // Check shapes of inputs to op and set parameters + ORT_RETURN_IF_ERROR(paged_attention_helper::CheckInputs(query, + key, + value, + key_cache, + value_cache, + cumulative_seqlens_q, + past_seqlens, + block_table, + cos_cache, + sin_cache, + ¶meters, + num_heads_, + kv_num_heads_, + scale_, + softcap_, + device_prop.maxThreadsPerBlock)); + parameters.local_window_size = local_window_size_; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; + + DUMP_STRING_INIT(); + DUMP_STRING("Batch size = ", parameters.batch_size); + DUMP_STRING("Token count = ", parameters.token_count); + DUMP_STRING("Q hidden size = ", parameters.hidden_size); + DUMP_STRING("KV hidden size = ", parameters.kv_hidden_size); + DUMP_STRING("Q num heads = ", parameters.num_heads); + DUMP_STRING("KV num heads = ", parameters.kv_num_heads); + DUMP_STRING("Head size = ", parameters.head_size); + DUMP_STRING("Num blocks = ", parameters.num_blocks); + DUMP_STRING("Block size = ", parameters.block_size); + DUMP_STRING("Max num blocks per sequence = ", parameters.max_num_blocks_per_seq); + DUMP_STRING("Rotary dimension = ", parameters.rotary_dim); + DUMP_STRING("Is packed QKV = ", parameters.is_packed_qkv); + + // Check rotary + if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache and sin_cache must be passed to PagedAttention when do_rotary = 1"); + } + + // Set output tensor shapes + TensorShapeVector output_shape(2); + output_shape[0] = static_cast(parameters.token_count); + output_shape[1] = static_cast(parameters.hidden_size); + Tensor* output = context->Output(0, output_shape); + + TensorShapeVector key_cache_out_shape(4); + key_cache_out_shape[0] = static_cast(parameters.num_blocks); + key_cache_out_shape[1] = static_cast(parameters.block_size); + key_cache_out_shape[2] = static_cast(parameters.kv_num_heads); + key_cache_out_shape[3] = static_cast(parameters.head_size); + Tensor* key_cache_out = context->Output(1, key_cache_out_shape); + + TensorShapeVector value_cache_out_shape(4); + value_cache_out_shape[0] = static_cast(parameters.num_blocks); + value_cache_out_shape[1] = static_cast(parameters.block_size); + value_cache_out_shape[2] = static_cast(parameters.kv_num_heads); + value_cache_out_shape[3] = static_cast(parameters.head_size); + Tensor* value_cache_out = context->Output(2, value_cache_out_shape); + + if (key_cache_out != nullptr && key_cache->Data() != key_cache_out->MutableData()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "key_cache and key_cache_out must be the same buffer"); + } else if (value_cache_out != nullptr && value_cache->Data() != value_cache_out->MutableData()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "value_cache and value_cache_out must be the same buffer"); + } + + // Check flash kernel availability and allocate buffers +#if USE_FLASH_ATTENTION + bool use_flash_attention = !disable_flash_attention_ && + onnxruntime::flash::is_supported(device_prop, + parameters.head_size, + parameters.num_heads, + parameters.kv_num_heads); + size_t softmax_lse_bytes = 0; + if (use_flash_attention) { + softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.token_count, + parameters.num_heads); + } + auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); +#else + constexpr bool use_flash_attention = false; + auto softmax_lse_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr +#endif + + if (!use_flash_attention) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Currently PagedAttention is only supported through the FlashAttention kernel."); + } + + size_t cumulative_seqlens_kv_bytes = sizeof(int) * (parameters.batch_size + 1); + auto cumulative_seqlens_kv_buffer = GetScratchBuffer(cumulative_seqlens_kv_bytes, context->GetComputeStream()); + + size_t workspace_buffer_bytes = 0; + if (do_rotary_) { + workspace_buffer_bytes = sizeof(T) * parameters.token_count * (parameters.hidden_size + parameters.kv_hidden_size); + } else if (parameters.is_packed_qkv) { + workspace_buffer_bytes = sizeof(T) * parameters.token_count * parameters.hidden_size; + } + auto workspace_buffer = GetScratchBuffer(workspace_buffer_bytes, context->GetComputeStream()); + + // Print debug info + if (kernel_options_->AllowDebugInfo()) { + AttentionKernelDebugInfo debug_info; + debug_info.use_flash_attention = use_flash_attention; + + debug_info.Print("PagedAttention", + this->Node().Name(), + std::is_same::value, + std::is_same::value); + } + + // Set up data struct for kernel launch + data.query = reinterpret_cast(query->Data()); + data.key = key == nullptr ? nullptr : reinterpret_cast(key->Data()); + data.value = value == nullptr ? nullptr : reinterpret_cast(value->Data()); + data.key_cache = reinterpret_cast(const_cast(key_cache->Data())); + data.value_cache = reinterpret_cast(const_cast(value_cache->Data())); + data.cumulative_seqlens_q = reinterpret_cast(cumulative_seqlens_q->Data()); + data.past_seqlens = reinterpret_cast(past_seqlens->Data()); + data.cumulative_seqlens_kv = reinterpret_cast(cumulative_seqlens_kv_buffer.get()); + data.block_table = reinterpret_cast(block_table->Data()); + data.output = reinterpret_cast(output->MutableData()); + data.use_flash_attention = use_flash_attention; + if (softmax_lse_buffer != nullptr) { + data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); + } + if (workspace_buffer != nullptr) { + data.workspace_buffer = reinterpret_cast(workspace_buffer.get()); + } + if (parameters.do_rotary) { + data.cos_cache = reinterpret_cast(cos_cache->Data()); + data.sin_cache = reinterpret_cast(sin_cache->Data()); + } + + cublasHandle_t cublas = GetCublasHandle(context); + + return QkvToContext( + device_prop, cublas, context->GetComputeStream(), parameters, data); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention.h b/onnxruntime/contrib_ops/cuda/bert/paged_attention.h new file mode 100644 index 0000000000000..a3df144745f61 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/providers/cuda/cuda_kernel.h" +#include "contrib_ops/cuda/bert/paged_attention_impl.h" +#include "contrib_ops/cuda/bert/attention_kernel_options.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class PagedAttention final : public CudaKernel { + public: + PagedAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + int num_heads_; // number of attention heads + int kv_num_heads_; // different for k and v for group query attention + int local_window_size_; + bool do_rotary_; + bool rotary_interleaved_; + float scale_; + float softcap_; + bool disable_flash_attention_; + const AttentionKernelOptions* kernel_options_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/paged_attention_helper.h new file mode 100644 index 0000000000000..6fb8969aa9d0a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention_helper.h @@ -0,0 +1,278 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/common.h" +#include "contrib_ops/cpu/bert/attention_common.h" +#include "contrib_ops/cpu/bert/attention_parameters.h" +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" + +namespace onnxruntime { +namespace contrib { +namespace paged_attention_helper { + +template +Status Check_Q_K_V(const T* query, const T* key, const T* value, const int num_heads, const int kv_num_heads, + int& token_count, int& q_hidden_size, int& kv_hidden_size, int& head_size) { + const auto& query_dims = query->Shape().GetDims(); + if (query_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 2 dimensions, got ", + query_dims.size()); + } + token_count = static_cast(query_dims[0]); + q_hidden_size = static_cast(query_dims[1]); + head_size = static_cast(q_hidden_size) / num_heads; + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + const auto& key_dims = key->Shape().GetDims(); + if (key_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 2 dimensions, got ", + key_dims.size()); + } else if (token_count != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (token count)"); + } + kv_hidden_size = static_cast(key_dims[1]); + if (kv_hidden_size % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_hidden_size must be a multiple of kv_num_heads. Got kv_hidden_size % kv_num_heads == ", + kv_hidden_size % kv_num_heads); + } else if (kv_hidden_size / kv_num_heads != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_hidden_size / kv_num_heads must be equal to head_size. Got kv_hidden_size / kv_num_heads == ", + kv_hidden_size / kv_num_heads); + } + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 2 dimensions, got ", + value_dims.size()); + } else if (token_count != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (token count)"); + } else if (value_dims[1] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); + } + return Status::OK(); +} + +template +Status Check_QKV(const T* packed_qkv, const T* value, const int num_heads, const int kv_num_heads, int& token_count, + int& q_hidden_size, int& kv_hidden_size, int& head_size) { + const auto& packed_dims = packed_qkv->Shape().GetDims(); + if (packed_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 2 dimensions, got ", + packed_dims.size()); + } + token_count = static_cast(packed_dims[0]); + head_size = static_cast(static_cast(packed_dims[1])) / (num_heads + 2 * kv_num_heads); + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + q_hidden_size = head_size * num_heads; + kv_hidden_size = head_size * kv_num_heads; + return Status::OK(); +} + +template +Status CheckKVCache(const T* key_cache, const T* value_cache, const int kv_num_heads, const int head_size, + int& num_blocks, int& block_size) { + const auto& key_cache_dims = key_cache->Shape().GetDims(); + const auto& value_cache_dims = value_cache->Shape().GetDims(); + if (key_cache_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key_cache' is expected to have 4 dimensions, got ", + key_cache_dims.size()); + } + if (value_cache_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'value_cache' is expected to have 4 dimensions, got ", + value_cache_dims.size()); + } + + num_blocks = static_cast(key_cache_dims[0]); + block_size = static_cast(key_cache_dims[1]); + // TODO(aciddelgado): block size multiple of 8 + if (block_size % 256 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "block_size must be a multiple of 256. Got block_size % 256 == ", + block_size % 256); + } + if (value_cache_dims[0] != num_blocks) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'value_cache' dimension 0 should be num_blocks, got ", + value_cache_dims[0]); + } else if (value_cache_dims[1] != block_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'value_cache' dimension 1 should be block_size, got ", + value_cache_dims[0]); + } + + if (key_cache_dims[2] != value_cache_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key_cache' and 'value_cache' dimension 2 (kv num heads) should be the same, got ", + key_cache_dims[2], " and ", value_cache_dims[2]); + } + if (key_cache_dims[2] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key_cache' shall have kv_num_heads, got ", + key_cache_dims[2]); + } + if (value_cache_dims[2] != kv_num_heads) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'value_cache' shall have kv_num_heads, got ", + value_cache_dims[2]); + } + + if (key_cache_dims[3] != value_cache_dims[3]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key_cache' and 'value_cache' dimension 3 (head size) should be the same, got ", + key_cache_dims[3], " and ", value_cache_dims[3]); + } + if (key_cache_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key_cache' dimension 3 should be same as head_size, got ", + key_cache_dims[3]); + } + if (value_cache_dims[3] != head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'past_value' dimension 3 should be same as head_size, got ", + value_cache_dims[3]); + } + return Status::OK(); +} + +template +Status CheckSequenceLengthTensors(const T* cumulative_sequence_length, const T* seqlens, int& batch_size) { + const auto& cumulative_seqlen_dim = cumulative_sequence_length->Shape().GetDims(); + if (cumulative_seqlen_dim.size() != 1 || cumulative_seqlen_dim[0] < 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cumulative_sequence_length must be shape (batch_size + 1)."); + } + batch_size = static_cast(cumulative_seqlen_dim[0]) - 1; + + const auto& seqlens_dim = seqlens->Shape().GetDims(); + if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "seqlens must be shape (batch_size)."); + } + return Status::OK(); +} + +template +Status CheckBlockTable(const T* block_table, const int batch_size, int& max_num_blocks_per_seq) { + const auto& block_table_dims = block_table->Shape().GetDims(); + if (block_table_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "block_table must be 2D."); + } else if (block_table_dims[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "block_table dimension 0 should be batch_size, got ", + block_table_dims[0]); + } + max_num_blocks_per_seq = static_cast(block_table_dims[1]); + return Status::OK(); +} + +template +Status CheckInputs(const T* query, + const T* key, + const T* value, + const T* key_cache, + const T* value_cache, + const T* cumulative_sequence_length, + const T* seqlens, + const T* block_table, + const T* cos_cache, + const T* sin_cache, + void* parameters, + int num_heads, + int kv_num_heads, + float scale, + float softcap, + int max_threads_per_block) { + if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); + } + if (num_heads % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", + num_heads % kv_num_heads); + } + + // Check query, key, and value + int token_count = 0; + int q_hidden_size = 0; + int kv_hidden_size = 0; + int head_size = 0; + const bool is_packed_qkv = key == nullptr; + if (!is_packed_qkv) { + ORT_RETURN_IF_ERROR(Check_Q_K_V(query, key, value, num_heads, kv_num_heads, token_count, q_hidden_size, + kv_hidden_size, head_size)); + } else { + ORT_RETURN_IF_ERROR(Check_QKV(query, value, num_heads, kv_num_heads, token_count, q_hidden_size, kv_hidden_size, + head_size)); + } + + // Check KV-Cache + int num_blocks = 0; + int block_size = 0; + ORT_RETURN_IF_ERROR(CheckKVCache(key_cache, value_cache, kv_num_heads, head_size, num_blocks, block_size)); + + // Check sequence length tensors + int batch_size = 0; + ORT_RETURN_IF_ERROR(CheckSequenceLengthTensors(cumulative_sequence_length, seqlens, batch_size)); + + // Check block table and slot mappings + int max_num_blocks_per_seq = 0; + ORT_RETURN_IF_ERROR(CheckBlockTable(block_table, batch_size, max_num_blocks_per_seq)); + + // Check rotary cache + int rotary_dim = 0; + if (cos_cache != nullptr && sin_cache != nullptr) { + // 0 to bypass checking rotary cache size + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckRotaryCaches(cos_cache, sin_cache, head_size, + 0, rotary_dim)); + } else if (cos_cache != nullptr || sin_cache != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); + } + + if (parameters != nullptr) { + PagedAttentionParameters* output_parameters = reinterpret_cast(parameters); + output_parameters->batch_size = batch_size; + output_parameters->token_count = token_count; + output_parameters->hidden_size = q_hidden_size; + output_parameters->kv_hidden_size = kv_hidden_size; + output_parameters->num_heads = num_heads; + output_parameters->kv_num_heads = kv_num_heads; + output_parameters->head_size = head_size; + output_parameters->block_size = block_size; + output_parameters->max_num_blocks_per_seq = max_num_blocks_per_seq; + output_parameters->num_blocks = num_blocks; + output_parameters->rotary_dim = rotary_dim; + output_parameters->is_packed_qkv = is_packed_qkv; + output_parameters->scale = scale; + output_parameters->softcap = softcap; + } + + return Status::OK(); +} + +} // namespace paged_attention_helper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu new file mode 100644 index 0000000000000..7ecdf51bdde11 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu @@ -0,0 +1,378 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/bert/attention_softmax.h" +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/bert/paged_attention_impl.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include "contrib_ops/cuda/bert/rotary_embedding_impl.h" +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +////////// Auxiliary Kernels + +template +__global__ void UnpackQKVCumulative(const T* packed_qkv, T* unpacked_qkv, const int token_count, const int num_heads, + const int kv_num_heads, const int head_size) { + const int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= token_count * (num_heads + 2 * kv_num_heads) * head_size) { + return; + } + const int q_hidden_size = num_heads * head_size; + const int kv_hidden_size = kv_num_heads * head_size; + const int in_seq_stride = q_hidden_size + 2 * kv_hidden_size; + + int packed_i; + if (tid < token_count * q_hidden_size) { + const int token_id = tid / q_hidden_size; + const int offset = tid % q_hidden_size; + packed_i = token_id * in_seq_stride + offset; + } else if (tid < token_count * (q_hidden_size + kv_hidden_size)) { + const int id = tid - token_count * q_hidden_size; + const int token_id = id / kv_hidden_size; + const int offset = id % kv_hidden_size; + packed_i = token_id * in_seq_stride + q_hidden_size + offset; + } else if (tid < token_count * (q_hidden_size + 2 * kv_hidden_size)) { + const int id = tid - token_count * (q_hidden_size + kv_hidden_size); + const int token_id = id / kv_hidden_size; + const int offset = id % kv_hidden_size; + packed_i = token_id * in_seq_stride + q_hidden_size + kv_hidden_size + offset; + } + unpacked_qkv[tid] = packed_qkv[packed_i]; +} + +// Since QKV is unpacked into a single workspace buffer, this is similar to a transpose +template +Status LaunchUnpackQKVCumulative(const T* packed_qkv, T* unpacked_qkv, const int token_count, const int num_heads, + const int kv_num_heads, const int head_size, cudaStream_t stream, + const int max_threads_per_block) { + const int threads = max_threads_per_block; + const int blocks = (token_count * (num_heads + 2 * kv_num_heads) * head_size + threads - 1) / threads; + UnpackQKVCumulative<<>>(packed_qkv, unpacked_qkv, token_count, num_heads, kv_num_heads, + head_size); + return CUDA_CALL(cudaGetLastError()); +} + +template +__global__ void UnpackV(const T* input, T* output, const int token_count, const int hidden_size, + const int packed_seq_stride) { + const int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid < token_count * hidden_size) { + int offset = tid % hidden_size; + int token_id = tid / hidden_size; + int packed_i = token_id * packed_seq_stride + offset; + output[tid] = input[packed_i]; + } +} + +template +Status LaunchUnpackCumulative(const T* input, T* output, const int token_count, const int hidden_size, + const int packed_seq_stride, cudaStream_t stream, const int max_threads_per_block) { + const int threads = std::min(max_threads_per_block, token_count * hidden_size); + const int blocks = (token_count * hidden_size + threads - 1) / threads; + UnpackV<<>>(input, output, token_count, hidden_size, packed_seq_stride); + return CUDA_CALL(cudaGetLastError()); +} + +template +__global__ void RotaryEmbeddingTNH(T* output, // TxNxH + const T* input, // TxNxH + const T* cos_cache, // Mx(H/2) + const T* sin_cache, // Mx(H/2) + const int32_t* past_seqlens, // B + const int32_t* cumulative_seqlens_q, // B+1 + const int head_size, + const int rotary_embedding_dim, + const bool interleaved, + const int3 in_strides, // TxNxH + const int3 out_strides) { // TxNxH + // Use .x in innermost loop to access global memory efficiently + + const int b = blockIdx.y; + const int s = blockIdx.x; + const int n = blockIdx.z; + const int h = threadIdx.x; + + const int sequence_length = cumulative_seqlens_q[b + 1] - cumulative_seqlens_q[b]; + if (h >= head_size || s >= sequence_length) { + return; + } + + const int t = cumulative_seqlens_q[b] + s; // t is the index of the token in the unpadded input/output + const T* input_data = input + t * in_strides.x + n * in_strides.y; + T* output_data = output + t * out_strides.x + n * out_strides.y; + + if (h >= rotary_embedding_dim) { + output_data[h] = input_data[h]; + return; + } + + // Cache is (M, H/2) + const int half_rotary_embedding_dim = rotary_embedding_dim / 2; + const int position_id = past_seqlens[b] + s; + const int cache_offset = position_id * half_rotary_embedding_dim; + const T* cos_data = cos_cache + cache_offset; + const T* sin_data = sin_cache + cache_offset; + + int cache_idx = 0; + T sign = 0; + int j = 0; + if (interleaved) { + cache_idx = (h / 2) % half_rotary_embedding_dim; + sign = (h % 2 == 0) ? -1 : 1; + j = (h % 2 == 0) ? h + 1 : h - 1; // i - sign + } else { + cache_idx = h % half_rotary_embedding_dim; + sign = (h < half_rotary_embedding_dim) ? -1 : 1; + j = (h + half_rotary_embedding_dim) % rotary_embedding_dim; + } + output_data[h] = input_data[h] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; +} + +template +Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int32_t* past_seqlens, + const int32_t* cumulative_seqlens_q, const T* cos_cache, const T* sin_cache, + const int batch_size, const int max_seqlen_q, const int num_heads, + const int head_size, const int rotary_embedding_dim, const bool interleaved, + const int in_seq_stride, const int max_threads_per_block) { + ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block"); + int3 in_strides = {in_seq_stride <= 0 ? num_heads * head_size : in_seq_stride, head_size, 1}; + int3 out_strides = {num_heads * head_size, head_size, 1}; + int tpb = (head_size + 31) / 32 * 32; + + const dim3 grid(max_seqlen_q, batch_size, num_heads); + const dim3 block(tpb); + RotaryEmbeddingTNH<<>>( + output, input, cos_cache, sin_cache, past_seqlens, cumulative_seqlens_q, head_size, rotary_embedding_dim, + interleaved, in_strides, out_strides); + return CUDA_CALL(cudaGetLastError()); +} + +template +__global__ void GetCumulativeSeqlensKV(int32_t* cumulative_seqlens_kv, const int32_t* cumulative_seqlens_q, + const int32_t* past_seqlens, const int batch_size) { + int id = blockDim.x * blockIdx.x + threadIdx.x; + + if (id == 0) { + cumulative_seqlens_kv[0] = 0; + } + + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + // Sum past_seqlens to new sequence length (which we get by subtracting cumulative_seqlens_q). + // Then do an inclusive sum across present sequence lengths to get the cumulative sequence length + if (id < batch_size) { + cumulative_seqlens_kv[id + 1] = past_seqlens[id] + cumulative_seqlens_q[id + 1] - cumulative_seqlens_q[id]; + int length = cumulative_seqlens_kv[id + 1]; + BlockScan(temp_storage).InclusiveSum(length, length); + cumulative_seqlens_kv[id + 1] = length; + } +} + +Status LaunchGetCumulativeSeqlensKV(int32_t* cumulative_seqlens_kv, const int32_t* cumulative_seqlens_q, + const int32_t* past_seqlens, const int batch_size, cudaStream_t stream) { + const int threads = 256; + const int blocks = (batch_size + threads - 1) / threads; + GetCumulativeSeqlensKV<256><<>>(cumulative_seqlens_kv, cumulative_seqlens_q, past_seqlens, + batch_size); + return CUDA_CALL(cudaGetLastError()); +} + +template +__global__ void ReshapeAndCache(const T* __restrict__ key, const T* __restrict__ value, T* __restrict__ key_cache, + T* __restrict__ value_cache, const int* __restrict__ block_table, + const int* __restrict__ past_seqlens, const int* __restrict__ cumulative_seqlens_q, + const int batch_size, const int max_num_blocks_per_seq, const int token_count, + const int kv_hidden_size, const int block_size, const int key_stride, + const int value_stride) { + const int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= token_count * kv_hidden_size) { + return; + } + const int token_id = tid / kv_hidden_size; + const int hidden_offset = tid % kv_hidden_size; + int batch_id = 0; + for (int i = 0; i < batch_size; ++i) { + if (token_id < cumulative_seqlens_q[i + 1]) { + batch_id = i; + break; + } + } + const int token_offset = token_id - cumulative_seqlens_q[batch_id]; + const int past_length = past_seqlens[batch_id]; + const int block_id = block_table[batch_id * max_num_blocks_per_seq + (past_length + token_offset) / block_size]; + const int block_offset = (past_length + token_offset) % block_size; + + const int key_id = token_id * key_stride + hidden_offset; + const int value_id = token_id * value_stride + hidden_offset; + const int dst_id = block_id * block_size * kv_hidden_size + block_offset * kv_hidden_size + hidden_offset; + key_cache[dst_id] = key[key_id]; + value_cache[dst_id] = value[value_id]; +} + +template +Status LaunchReshapeAndCache(const T* key, const T* value, T* key_cache, T* value_cache, const int* block_table, + const int* past_seqlens, const int* cumulative_seqlens_q, const int batch_size, + const int max_num_blocks_per_seq, const int token_count, const int kv_hidden_size, + const int block_size, const int key_stride, const int value_stride, cudaStream_t stream, + const int max_threads_per_block) { + const int total_size = token_count * kv_hidden_size; + const int threads(std::min(total_size, max_threads_per_block)); + const int blocks((total_size + threads - 1) / threads); + ReshapeAndCache<<>>(key, value, key_cache, value_cache, block_table, past_seqlens, + cumulative_seqlens_q, batch_size, max_num_blocks_per_seq, + token_count, kv_hidden_size, block_size, key_stride, value_stride); + return CUDA_CALL(cudaGetLastError()); +} + +////////// Launch Kernels + +#if USE_FLASH_ATTENTION +template +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::PagedAttentionParameters& parameters, + PagedAttentionData& data, + float scale) { + // Get parameters + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int token_count = parameters.token_count; + const int q_hidden_size = parameters.hidden_size; + const int kv_hidden_size = parameters.kv_hidden_size; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + const float softcap = parameters.softcap; + bool is_bf16 = std::is_same::value; + const int local_window_size = parameters.local_window_size; + const int max_num_blocks_per_seq = parameters.max_num_blocks_per_seq; + const int block_size = parameters.block_size; + // The following are passed to flash api but not used by the kernel, so they can be determined heuristically + const int max_query_len = token_count - batch_size + 1; + const int max_seq_len = parameters.max_num_blocks_per_seq * parameters.block_size; + + T* query = const_cast(data.query); + T* key; + T* value; + if (!parameters.is_packed_qkv) { + key = const_cast(data.key); + value = const_cast(data.value); + } else { + key = reinterpret_cast(query) + static_cast(num_heads * head_size); + value = reinterpret_cast(key) + static_cast(kv_num_heads * head_size); + } + + // Calculate cumulative present sequence length in cumulative_seqlens_kv + int* cumulative_seqlens_q = const_cast(data.cumulative_seqlens_q); + int* past_seqlens = const_cast(data.past_seqlens); + int* cumulative_seqlens_kv = data.cumulative_seqlens_kv; + ORT_RETURN_IF_ERROR(LaunchGetCumulativeSeqlensKV(cumulative_seqlens_kv, cumulative_seqlens_q, past_seqlens, + batch_size, stream)); + + if (parameters.do_rotary) { + // Will unpack Q and K in case of packed_qkv + auto q_buffer = data.workspace_buffer; + auto k_buffer = data.workspace_buffer + token_count * num_heads * head_size; + const int packed_seq_stride = parameters.is_packed_qkv ? (num_heads + 2 * kv_num_heads) * head_size : -1; + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( + stream, q_buffer, query, past_seqlens, cumulative_seqlens_q, data.cos_cache, data.sin_cache, batch_size, + max_query_len, num_heads, head_size, parameters.rotary_dim, parameters.rotary_interleaved, packed_seq_stride, + max_threads_per_block)); + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( + stream, k_buffer, key, past_seqlens, cumulative_seqlens_q, data.cos_cache, data.sin_cache, batch_size, + max_query_len, kv_num_heads, head_size, parameters.rotary_dim, parameters.rotary_interleaved, packed_seq_stride, + max_threads_per_block)); + query = q_buffer; + key = k_buffer; + } else if (parameters.is_packed_qkv) { + // Only unpack Q. K and V are unpacked by ReshapeAndCache. + auto q_buffer = data.workspace_buffer; + const int packed_seq_stride = q_hidden_size + 2 * kv_hidden_size; + ORT_RETURN_IF_ERROR(LaunchUnpackCumulative( + query, q_buffer, token_count, q_hidden_size, packed_seq_stride, stream, max_threads_per_block)); + query = q_buffer; + } + + // Insert key and value into block-based KV cache + int* block_table = const_cast(data.block_table); + const int key_stride = parameters.is_packed_qkv && !parameters.do_rotary ? q_hidden_size + 2 * kv_hidden_size : kv_hidden_size; + const int value_stride = parameters.is_packed_qkv ? q_hidden_size + 2 * kv_hidden_size : kv_hidden_size; + ORT_RETURN_IF_ERROR(LaunchReshapeAndCache(key, value, data.key_cache, data.value_cache, block_table, past_seqlens, + cumulative_seqlens_q, batch_size, max_num_blocks_per_seq, token_count, + kv_hidden_size, block_size, key_stride, value_stride, stream, + max_threads_per_block)); + + // Launch kernel + void* q = reinterpret_cast(query); + void* key_cache = reinterpret_cast(data.key_cache); + void* value_cache = reinterpret_cast(data.value_cache); + void* output = reinterpret_cast(data.output); + void* softmax_lse = reinterpret_cast(data.softmax_lse); + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_varlen_fwd( + device_prop, stream, q, key_cache, value_cache, output, cumulative_seqlens_q, cumulative_seqlens_kv, + /*seqused_k*/ nullptr, block_table, softmax_lse, batch_size, num_heads, kv_num_heads, head_size, max_query_len, + max_seq_len, token_count, scale, softcap, /*is_causal*/ true, is_bf16, local_window_size, max_num_blocks_per_seq, + block_size)); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("flash attention output", data.output, token_count, num_heads, head_size); + + return Status::OK(); +} +#endif + +////////// API Functions + +template +Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& /*cublas*/, + Stream* ort_stream, + contrib::PagedAttentionParameters& parameters, + PagedAttentionData& data) { + auto stream = static_cast(ort_stream->GetHandle()); + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; + +#if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, stream, parameters, data, scale); + } +#endif + + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Paged Attention not implemented."); +} + +template struct PagedAttentionData; +template Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::PagedAttentionParameters& parameters, + PagedAttentionData& data); + +template struct PagedAttentionData; +template Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::PagedAttentionParameters& parameters, + PagedAttentionData& data); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.h new file mode 100644 index 0000000000000..7e27556a5c63f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" +#include +#include +#include "contrib_ops/cpu/bert/attention_common.h" +#include "contrib_ops/cpu/bert/attention_parameters.h" +#include "contrib_ops/cuda/bert/attention_data.h" +#include "core/framework/allocator.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* stream, + contrib::PagedAttentionParameters& parameters, + PagedAttentionData& data); + +template +Status LaunchUnpackQKVCumulative(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads, + const int kv_num_heads, const int head_size, const int token_count, cudaStream_t stream, + const int max_threads_per_block); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cudaDriverWrapper.cc b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cudaDriverWrapper.cc index 6aaeab5dc2447..84e7eec350216 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cudaDriverWrapper.cc +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cudaDriverWrapper.cc @@ -136,10 +136,8 @@ CUresult CUDADriverWrapper::cuLaunchKernel( f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra); } -// Initialize the singleton instance -CUDADriverWrapper CUDADriverWrapper::instance; - const CUDADriverWrapper* CUDADriverWrapper::GetInstance() { + static CUDADriverWrapper instance; return &instance; } diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cudaDriverWrapper.h b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cudaDriverWrapper.h index 10d0677d1173e..ee79905818ab1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cudaDriverWrapper.h +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cudaDriverWrapper.h @@ -100,8 +100,6 @@ class CUDADriverWrapper { CUfunction f, uint32_t gridDimX, uint32_t gridDimY, uint32_t gridDimZ, uint32_t blockDimX, uint32_t blockDimY, uint32_t blockDimZ, uint32_t sharedMemBytes, CUstream hStream, void** kernelParams, void** extra); - - static CUDADriverWrapper instance; }; inline void cuErrCheck_(CUresult stat, const CUDADriverWrapper& wrap, const char* file, int line) { diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 17f3433aed38a..d016d50d6c445 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -99,6 +99,8 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float, MultiHeadAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_MLFloat16, MultiHeadAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, GroupQueryAttention); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PagedAttention); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, PagedAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderAttention); class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, int32_t, DynamicSlice); @@ -311,6 +313,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/llm/common/logger.h b/onnxruntime/contrib_ops/cuda/llm/common/logger.h index a3992e751926d..45c8d0e546455 100644 --- a/onnxruntime/contrib_ops/cuda/llm/common/logger.h +++ b/onnxruntime/contrib_ops/cuda/llm/common/logger.h @@ -5,7 +5,22 @@ #include "core/providers/shared_library/provider_api.h" -#ifndef NDEBUG +#ifdef _MSC_VER +#define PRETTY_FUNCTION __FUNCSIG__ +#else +#define PRETTY_FUNCTION __PRETTY_FUNCTION__ +#endif + +#define ORT_LLM_VERBOSE 0 // Set to 1 for verbose, 2 for max verbosity + +#if ORT_LLM_VERBOSE +#include +#define ORT_LLM_LOG_ENTRY() std::cout << "Enter " << PRETTY_FUNCTION << std::endl; +#else +#define ORT_LLM_LOG_ENTRY() +#endif + +#if ORT_LLM_VERBOSE #define ORT_LLM_LOG_TRACE(msg) LOGS_DEFAULT(VERBOSE) << msg #define ORT_LLM_LOG_DEBUG(msg) LOGS_DEFAULT(VERBOSE) << msg #else diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h index 715397270331b..edb763733f9ce 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -19,6 +19,7 @@ #pragma GCC diagnostic ignored "-Wstrict-aliasing" #endif // __GNUC__ +#include "cutlass/float8.h" #include "cutlass/gemm/kernel/default_gemm.h" #include "contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h" #include "contrib_ops/cuda/llm/cutlass_extensions/gemm/device/gemm_universal_base_compat.h" @@ -39,7 +40,9 @@ #include "contrib_ops/cuda/llm/cutlass_heuristic.h" #include "contrib_ops/cuda/llm/cutlass_type_conversion.h" #include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" +#ifndef EXCLUDE_SM_90 #include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h" +#endif #include "core/providers/cuda/shared_inc/cuda_call.h" namespace tk = onnxruntime::llm::common; @@ -56,7 +59,7 @@ void generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { - ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ORT_LLM_LOG_ENTRY(); static_assert( #ifdef ENABLE_FP8 @@ -116,11 +119,12 @@ void generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const } if constexpr (cutlass::isFinegrained(QuantOp)) { - if constexpr (cutlass::platform::is_same::value) { + if constexpr (cutlass::platform::is_same::value) { if (group_size != 128) { ORT_THROW("Only group size 128 supported for fine grained W4A(fp)8 kernels."); } } + if (group_size != 64 && group_size != 128) { ORT_THROW("Only group size 64 and 128 supported for fine grained kernels."); } @@ -200,7 +204,7 @@ void filter_and_run_mixed_gemm(ActivationType const* A, WeightType const* B, Sca ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { - ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ORT_LLM_LOG_ENTRY(); if constexpr (Stages > 2 && arch::kMinComputeCapability < 80) { // Multistage only supported on Ampere std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); @@ -227,7 +231,7 @@ void dispatch_gemm_config(ActivationType const* A, WeightType const* B, ScaleZer ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { - ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ORT_LLM_LOG_ENTRY(); switch (gemm_config.stages) { case 2: filter_and_run_mixed_gemm() || is_fp8() || is_fp8() || is_fp8() || is_fp8(); @@ -331,7 +335,7 @@ template CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() { - ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ORT_LLM_LOG_ENTRY(); sm_ = ::onnxruntime::llm::common::getSMVersion(); multi_processor_count_ = ::onnxruntime::llm::common::getMultiProcessorCount(); } @@ -340,7 +344,7 @@ template CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() { - ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ORT_LLM_LOG_ENTRY(); } template = 75); - if (sm_ >= 75 && sm_ < 80) { + if (sm_ < 80) { dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); - } else if ((sm_ >= 80 && sm_ < 89) || sm_ >= 100) { - dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, - workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); } else if (sm_ == 89) { #if ENABLE_FP8 && ((__CUDACC_VER_MAJOR__ < 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) if constexpr (cutlass::platform::is_same::value) { @@ -374,14 +375,18 @@ void CutlassFpAIntBGemmRunner(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); +#ifndef EXCLUDE_SM_90 } else if (sm_ == 90) { static_assert(!cutlass::platform::is_same::value || cutlass::platform::is_same::value, "ScaleZeroType must be half for activation=fp8"); sm90_dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); +#endif } else { - ORT_THROW("[fpA_intB_gemm] Error:Arch unsupported for CUTLASS mixed type GEMM"); + dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); } } @@ -391,7 +396,7 @@ void CutlassFpAIntBGemmRunner((ActivationType const*)A, (WeightType const*)B, (ScaleZeroType const*)weight_scales, (ScaleZeroType const*)weight_zero_points, (BiasType const*)biases, @@ -407,7 +412,7 @@ void CutlassFpAIntBGemmRunner::gemm( void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { - ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ORT_LLM_LOG_ENTRY(); if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY) { dispatch_to_arch((ActivationType const*)A, (WeightType const*)B, @@ -433,7 +438,7 @@ template ::gemm( void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { - ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ORT_LLM_LOG_ENTRY(); gemm(A, B, weight_scales, 1.f, C, m, n, k, gemmConfig, workspace_ptr, workspace_bytes, stream); } @@ -441,6 +446,7 @@ template std::vector CutlassFpAIntBGemmRunner::getConfigs() const { + ORT_LLM_LOG_ENTRY(); static constexpr bool is_weight_only = !std::is_same::value; tkc::CutlassGemmConfig::CandidateConfigTypeParam config_type_param = tkc::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER; if (is_weight_only) { @@ -456,8 +462,9 @@ template ::getWorkspaceSize( int const m, int const n, int const /*k*/) { - ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); - // For Hopper, we have to allocate large memory size in case for stream-K + ORT_LLM_LOG_ENTRY(); +// For Hopper, we have to allocate large memory size in case for stream-K +#ifndef EXCLUDE_SM_90 if (sm_ == 90) { // https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L878-L892 // The above lines says sk_tiles = output_tiles - (static_cast(output_tiles / ctas_per_wave) - 1) * @@ -477,6 +484,7 @@ CutlassFpAIntBGemmRunner( max_sk_tiles_with_separate_reduction * MAX_M_TILE_SM90 * MAX_N_TILE_SM90 * sizeof(float)); } +#endif // These are the min tile sizes for each config, which would launch the maximum number of blocks int const max_grid_m = cutlass::ceil_div(m, MIN_M_TILE); int const max_grid_n = cutlass::ceil_div(n, MIN_N_TILE); diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h index 432adb20079b6..e87a04b9c3445 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h @@ -44,7 +44,7 @@ void sm90_dispatch_epilogue_schedules(ActivationType const* A, WeightType const* ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { - ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ORT_LLM_LOG_ENTRY(); switch (gemm_config.epilogue_schedule) { case tkc::EpilogueScheduleType::AUTO: using EpilogueScheduleType = cute::conditional_t(CTAShape{}) == Int<64>{}, @@ -100,7 +100,7 @@ void sm90_dispatch_mainloop_schedules(ActivationType const* A, WeightType const* ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { - ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ORT_LLM_LOG_ENTRY(); constexpr bool tile_shapes_supported = are_tile_shapes_supported(); @@ -133,7 +133,7 @@ void sm90_dispatch_gemm_config(ActivationType const* A, WeightType const* B, Sca ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { - ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ORT_LLM_LOG_ENTRY(); switch (gemm_config.cluster_shape) { case tkc::ClusterShape::ClusterShape_1x1x1: sm90_dispatch_mainloop_schedules= 75, "Unsupported CUDA architecture: ", arch); - if (arch >= 75 && arch < 80) { + if (arch < 80) { return getLayoutDetailsForArch(quant_type); +#ifndef EXCLUDE_SM_90 } else if (arch >= 90 && arch < 100) { return getLayoutDetailsForArch(quant_type); - } else /*if (arch >= 80 && arch < 90 || arch >= 100)*/ { +#endif + } else { return getLayoutDetailsForArch(quant_type); } } diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc index 8112562623791..925a6913a2890 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc @@ -28,7 +28,7 @@ void WeightOnlyGroupwiseQuantGemmPluginProfiler::runTactic( int const originalN = mQuantBits == 8 ? n * FP16_INT8_RATIO : n * FP16_INT4_RATIO; half* actPtr = reinterpret_cast(workspace); void* weightPtr = nextWorkspacePtr(reinterpret_cast(actPtr), m * k * sizeof(half)); - half* inputScalesPtr = reinterpret_cast(nextWorkspacePtr(reinterpret_cast(weightPtr), n * k * sizeof(float))); + half* inputScalesPtr = reinterpret_cast(nextWorkspacePtr(reinterpret_cast(weightPtr), n * k * sizeof(half))); half* zerosPtr = reinterpret_cast( nextWorkspacePtr(reinterpret_cast(inputScalesPtr), k * originalN * sizeof(half) / mGroupSize)); half* biasesPtr = reinterpret_cast( @@ -68,20 +68,19 @@ void WeightOnlyGroupwiseQuantGemmPluginProfiler::runTactic( } } -void WeightOnlyGroupwiseQuantGemmPluginProfiler::computeTmpSize(size_t maxM, size_t n, size_t k) { +size_t WeightOnlyGroupwiseQuantGemmPluginProfiler::computeTmpSize(size_t maxM, size_t n, size_t k) { // Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16) int const originalN = mQuantBits == 8 ? n * FP16_INT8_RATIO : n * FP16_INT4_RATIO; std::vector workspaces = { maxM * k * sizeof(half), // A - k * n * sizeof(float), // B + k * n * sizeof(half), // B k * originalN * sizeof(half) / mGroupSize, // scales k * originalN * sizeof(half) / mGroupSize, // zeros originalN * sizeof(half), // biases maxM * originalN * sizeof(half), // C mRunner->getWorkspaceSize(maxM, originalN, k) // workspace }; - size_t bytes = calculateTotalWorkspaceSize(workspaces.data(), workspaces.size()); - setTmpWorkspaceSizeInBytes(bytes); + return calculateTotalWorkspaceSize(workspaces.data(), workspaces.size()); } std::vector WeightOnlyGroupwiseQuantGemmPluginProfiler::getTactics( diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h index 7be77fa43d85d..b1336f45cab27 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h @@ -68,7 +68,7 @@ class WeightOnlyGroupwiseQuantGemmPluginProfiler void runTactic(int m, int n, int k, Config const& tactic, char* workspace, cudaStream_t const& stream) override; - void computeTmpSize(size_t maxM, size_t n, size_t k) override; + size_t computeTmpSize(size_t maxM, size_t n, size_t k) override; std::vector getTactics(int m, int n, int k) const override; diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu index 32cd607d36480..54ed44c0d68d5 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu @@ -43,21 +43,12 @@ void kernel_launcher(int arch, Params& params, cudaStream_t s) { return; \ } - if (arch >= 75 && arch < 80) { - EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); - EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); - } else if (arch >= 80 && arch < 90 || arch >= 100) { - // if (arch == 89 || arch >= 120) - // { - // EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); - // EXEC_W4A8(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); - // } + ORT_ENFORCE(arch >= 75, "Unsupported CUDA architecture: ", arch); + if (arch < 80) { EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); - - EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); - EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); - } else if (arch >= 90) { +#ifndef EXCLUDE_SM_90 + } else if (arch >= 90 && arch < 100) { // Dispatchers for W4A8 groupwise // EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); // EXEC_W4A8(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); @@ -67,6 +58,18 @@ void kernel_launcher(int arch, Params& params, cudaStream_t s) { EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleavedForHopper, true); EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); +#endif + } else { + // if (arch >= 89) + // { + // EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + // EXEC_W4A8(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + // } + EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + + EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); } #undef EXEC_W4A8 #undef EXEC @@ -81,6 +84,11 @@ bool is_supported(int arch, KernelType kernel_type) { SUPPORT(KernelType::FP16Int8Groupwise); SUPPORT(KernelType::FP16Int4Groupwise); } else if (arch >= 80) { +#ifdef EXCLUDE_SM_90 + if (arch >= 90 && arch < 100) { + return false; + } +#endif SUPPORT(KernelType::FP16Int8Groupwise); SUPPORT(KernelType::FP16Int4Groupwise); diff --git a/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc index 893ff27c068f8..73902a0636fcb 100644 --- a/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc +++ b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc @@ -27,15 +27,6 @@ namespace onnxruntime::llm::kernels::weight_only { template GemmPluginProfiler::GemmPluginProfiler() { mMNKProfileMap = std::make_shared(); - - // set SKIP_GEMM_PLUGIN_PROFILINGS=1 to avoid tactics profilings - auto const skipEnv = std::getenv("SKIP_GEMM_PLUGIN_PROFILINGS"); - mSkip = (skipEnv != NULL && std::stoi(skipEnv)); - if (mSkip) { - ORT_LLM_LOG_DEBUG( - "SKIP_GEMM_PLUGIN_PROFILINGS is set. Skipping GEMM plugin profilings. It could result in runtime error " - "if default tactic is not defined."); - } } // template @@ -106,6 +97,7 @@ template ::profileTactics( RunnerPtr const& runner, nvinfer::DataType const& type, GemmDims const& dims, GemmIdType const& gemmId, bool hasWeightOnlyCudaKernel) { + ORT_LLM_LOG_ENTRY(); writer_lock lock(mMNKProfileMap->mutex); if (!dims.isInitialized()) { @@ -116,7 +108,8 @@ void GemmPluginProfiler::profileT mType = type; int const maxM = std::min(nextPowerOfTwo(dims.maxM), getMaxProfileM()); - computeTmpSize(maxM, dims.n, dims.k); + + size_t workspace_bytes = computeTmpSize(maxM, dims.n, dims.k); if (!mMNKProfileMap->existsMProfileMap(gemmId)) { // Create map for GEMM ID @@ -130,16 +123,22 @@ void GemmPluginProfiler::profileT auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId); bool isAllocated{false}; - auto profileTactics = [&mProfileMap, &isAllocated, this](int m, int n, int k) { + auto profileTactics = [&](int m, int n, int k) { if (mProfileMap->count(m) == 0) { if (!isAllocated) { - // Allocate tmp data to run GEMMs - allocateTmpData(); + this->mWorkspaceTmp = onnxruntime::IAllocator::MakeUniquePtr(mAllocator, workspace_bytes, true); +#if ORT_LLM_VERBOSE + AllocatorStats stats; + this->mAllocator->GetStats(&stats); + std::cout << "Allocator state after " << workspace_bytes << " bytes gemm profiler workspace:" << std::endl + << stats.DebugString() << std::endl; +#endif isAllocated = true; } - initTmpData(m, n, k, mWorkspaceTmp, mTmpWorkspaceSizeInBytes, mStream); - auto tactics = this->getTactics(m, n, k); + initTmpData(m, n, k, this->mWorkspaceTmp.get(), workspace_bytes, this->mStream); + + auto tactics = this->getTactics(m, n, k); // Profile different tactics for particular m and insert best config to the map mProfileMap->insert({m, this->profileTacticsForProblem(m, n, k, tactics)}); } @@ -171,7 +170,7 @@ void GemmPluginProfiler::profileT if (isAllocated) { // Free tmp data - freeTmpData(); + mWorkspaceTmp.reset(); } CUDA_CALL_THROW(cudaStreamDestroy(mStream)); } @@ -179,6 +178,7 @@ void GemmPluginProfiler::profileT template std::optional GemmPluginProfiler::getBestConfig( int m, GemmIdType const& gemmId) const { + ORT_LLM_LOG_ENTRY(); reader_lock lock(mMNKProfileMap->mutex); if (mSkip) { @@ -201,28 +201,19 @@ std::optional GemmPluginProfiler -void GemmPluginProfiler::allocateTmpData() { - ORT_ENFORCE(mTmpWorkspaceSizeInBytes > 0, "tmpWorkspaceSizeInBytes must be larger than 0"); - auto const status = cudaMalloc(&mWorkspaceTmp, mTmpWorkspaceSizeInBytes); - ORT_ENFORCE(status == cudaSuccess, "Can't allocate tmp workspace for GEMM tactics profiling."); -} - -template -void GemmPluginProfiler::freeTmpData() { - auto const status = cudaFree(mWorkspaceTmp); - ORT_ENFORCE(status == cudaSuccess, "Can't free tmp workspace for GEMM tactics profiling."); -} - template std::optional GemmPluginProfiler::profileTacticsForProblem( int m, int n, int k, std::vector const& tactics) { - ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ORT_LLM_LOG_ENTRY(); float bestTime = std::numeric_limits::max(); Config bestConfig; bool foundOne = false; +#if ORT_LLM_VERBOSE > 1 + std::cout << "Total configs to profile:" << tactics.size() << std::endl; +#endif + // Iterate over all tactics for given M, N and K for (size_t ii = 0; ii < tactics.size(); ++ii) { Config const& candidateConfig = tactics[ii]; @@ -233,6 +224,13 @@ std::optional GemmPluginProfiler 1 + if constexpr (std::is_same_v) { + std::cout << "Time=" << time << " for config: " << candidateConfig.toString() << std::endl; + } +#endif + foundOne = true; } catch (std::exception const& e) { std::ostringstream msg; @@ -263,6 +261,10 @@ std::optional GemmPluginProfiler 1 + std::cout << "Best config:" << bestConfig.toString() << std::endl; +#endif + return {bestConfig}; } @@ -276,7 +278,7 @@ float GemmPluginProfiler::profile // Warmup the execution for (int i = 0; i < warmup; ++i) { - runTactic(m, n, k, tactic, mWorkspaceTmp, stream); + runTactic(m, n, k, tactic, mWorkspaceTmp.get(), stream); } cudaEvent_t start; @@ -288,7 +290,7 @@ float GemmPluginProfiler::profile // Profile GEMM for (int i = 0; i < runs; ++i) { - runTactic(m, n, k, tactic, mWorkspaceTmp, stream); + runTactic(m, n, k, tactic, mWorkspaceTmp.get(), stream); } CUDA_CALL_THROW(cudaEventRecord(stop, stream)); diff --git a/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.h b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.h index 0ab9b91e7f43c..44604dc6477a0 100644 --- a/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.h +++ b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.h @@ -30,6 +30,7 @@ #include "contrib_ops/cuda/llm/nv_infer_datatype.h" #include "core/common/common.h" +#include "core/framework/allocator.h" namespace onnxruntime::llm::kernels::weight_only { @@ -190,14 +191,14 @@ class GemmPluginProfiler { mMNKProfileMap = map; } - void setTmpWorkspaceSizeInBytes(size_t bytes) { - mTmpWorkspaceSizeInBytes = bytes; - } - void setSkip(bool skip) { mSkip = mSkip || skip; } + void setAllocator(onnxruntime::AllocatorPtr allocator) { + mAllocator = std::move(allocator); + } + std::optional getBestConfig(int m, GemmIdType const& gemmId) const; virtual int getMaxProfileM() const; @@ -205,7 +206,7 @@ class GemmPluginProfiler { protected: virtual void runTactic(int m, int n, int k, Config const& tactic, char* workspace, cudaStream_t const& stream) = 0; - virtual void computeTmpSize(size_t maxM, size_t n, size_t k) = 0; + virtual size_t computeTmpSize(size_t maxM, size_t n, size_t k) = 0; virtual bool checkTactic(int /*m*/, int /*n*/, int /*k*/, Config const& /*tactic*/) const { return true; @@ -216,10 +217,6 @@ class GemmPluginProfiler { virtual void initTmpData(int m, int n, int k, char* workspace, size_t size, cudaStream_t stream); private: - void allocateTmpData(); - - void freeTmpData(); - std::optional profileTacticsForProblem(int m, int n, int k, std::vector const& tactics); float profileTacticForProblem(int m, int n, int k, Config const& tactic); @@ -242,15 +239,15 @@ class GemmPluginProfiler { private: MNKProfileMapPtr mMNKProfileMap{}; - size_t mTmpWorkspaceSizeInBytes{0}; - - char* mWorkspaceTmp{nullptr}; + onnxruntime::IAllocatorUniquePtr mWorkspaceTmp{nullptr}; cudaStream_t mStream; GemmDims mDims{}; bool mSkip{false}; + + onnxruntime::AllocatorPtr mAllocator; }; template diff --git a/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py b/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py index acaecf71bf1f0..50d2fe07d4a38 100644 --- a/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py +++ b/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py @@ -223,19 +223,20 @@ def get_file_content(launcher_inl_files, operations): instantiate_operation(insts_list, op) instantiations = "\n".join(insts_list) - file_content = f"""{includes} -namespace onnxruntime::llm -{{ -namespace kernels -{{ -namespace cutlass_kernels -{{ + file_content = f""" +#ifndef EXCLUDE_SM_90 +{includes} + +namespace onnxruntime::llm {{ +namespace kernels {{ +namespace cutlass_kernels {{ {instantiations} -}} // namespace cutlass_kernels -}} // namespace kernels -}} // namespace onnxruntime::llm +}} // namespace cutlass_kernels +}} // namespace kernels +}} // namespace onnxruntime::llm +#endif // EXCLUDE_SM_90 """ return file_content diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh index f6f380c8211f6..28dce3937dd23 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -12,9 +12,9 @@ namespace cuda { // A more general block-wise dequantization implementation that supports // different block sizes and block orientations (row-wise/column-wise). template < - int Row_, ///< rows of a matrix - int Column_ ///< columns of a matrix - > + int Row_, ///< rows of a matrix + int Column_ ///< columns of a matrix + > struct Shape2D { static int const kRow = Row_; ///< rows of a matrix static int const kColumn = Column_; ///< columns of a matrix @@ -30,17 +30,17 @@ struct Shape2D { * false: elements in a block come from one single row */ template < - typename ElementT, - int32_t block_size, - int32_t qbits, - bool Columnwise> + typename ElementT, + int32_t block_size, + int32_t qbits, + bool Columnwise> struct BlkQuantTraits { // number of qbit elements to pack into whole bytes static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; - + using ThreadBlk = Shape2D; }; @@ -68,6 +68,26 @@ Status Dequantize8Bits( int block_size, cudaStream_t stream); +template +Status DequantizeNBits( + int bits, + T* output, + const uint8_t* quant_data, + const T* scales_data, + const ZeroT* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream) { + if (bits == 4) { + return Dequantize4Bits(output, quant_data, scales_data, zero_points, reorder_idx, k, n, block_size, stream); + } else { + ORT_ENFORCE(bits == 8); + return Dequantize8Bits(output, quant_data, scales_data, zero_points, reorder_idx, k, n, block_size, stream); + } +} + /** * @brief Dequantize a block-wise quantized matrix, and store the result in a * column major matrix for use in subsequent GEMM. This implementation supports diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 59da5b57dc715..71a84b877b8d1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -15,6 +15,7 @@ #include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" #include "contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h" #include "contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h" +#include "contrib_ops/cuda/llm/common/logger.h" #include "contrib_ops/cpu/quantization/matmul_nbits_helper.h" constexpr int MatMulNBits_Input_B = 1; @@ -72,6 +73,9 @@ void MatMulNBits::InitGemmProfiler(int sm) { gemmProfiler_->setCudaKernelType(cuda_kernel_type, sm); gemmProfiler_->setQuant(nbits_, has_bias_, has_zero_points_); gemmProfiler_->setGroupSize(block_size_); + + auto allocator = this->Info().GetAllocator(OrtMemType::OrtMemTypeDefault); + gemmProfiler_->setAllocator(allocator); } template @@ -158,6 +162,7 @@ Status MatMulNBits::PrePack_B([[maybe_unused]] const Tensor& tensor, {static_cast(k), static_cast(n)}, quant_type); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); DUMP_TENSOR_INIT(); DUMP_TENSOR_D("packed transposed_weight in GPU", packed_transposed_weight, k, n * nbits_ / 8); DUMP_TENSOR_D("preprocessed_weight", reinterpret_cast(preprocessed_weight), k, n * nbits_ / 8); @@ -289,7 +294,10 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { if (has_fpA_intB_gemm_) { auto const& bestTactic = gemmProfiler_->getBestConfig(m, gemmId_); - DUMP_STRING("Best tactic: m=", m, " n=", n, " k=", k, " group_size=", block_size_, bestTactic->toString()); +#if ORT_LLM_VERBOSE > 1 + std::cout << "Best tactic for m=" << m << ", n=" << n << ", k=" << k << "group_size=" << block_size_ + << " is: " << bestTactic->toString() << std::endl; +#endif if (bestTactic->enableCudaKernel) { using onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; @@ -330,31 +338,19 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { } if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { - bool done = (nbits_ == 8) ? TryMatMul8Bits( - reinterpret_cast(Y->MutableData()), - reinterpret_cast(a_data), - blob_data, - reinterpret_cast(scales_data), - static_cast(zero_points_data), - m, - n, - k, - SafeInt(block_size_), - GetDeviceProp().sharedMemPerBlock, - stream) - : TryMatMul4Bits( - reinterpret_cast(Y->MutableData()), - reinterpret_cast(a_data), - blob_data, - reinterpret_cast(scales_data), - static_cast(zero_points_data), - m, - n, - k, - SafeInt(block_size_), - GetDeviceProp().sharedMemPerBlock, - stream); - if (done) { + if (TryMatMulNBits( + nbits_, + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + static_cast(zero_points_data), + m, + n, + k, + SafeInt(block_size_), + GetDeviceProp().sharedMemPerBlock, + stream)) { return Status::OK(); } } diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh index fe7098b92cba8..e97990d884ae5 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh @@ -36,6 +36,33 @@ bool TryMatMul8Bits( size_t shared_mem_per_block, cudaStream_t stream); +template +bool TryMatMulNBits( + int bits, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* scales_data, + const uint8_t* zero_points, + int m, + int n, + int k, + int block_size, + size_t shared_mem_per_block, + cudaStream_t stream) { + if (bits == 8) { + return TryMatMul8Bits(output, a_data, b_data_quant, scales_data, zero_points, + m, n, k, block_size, shared_mem_per_block, stream); + } + + if (bits == 4) { + return TryMatMul4Bits(output, a_data, b_data_quant, scales_data, zero_points, + m, n, k, block_size, shared_mem_per_block, stream); + } + + return false; +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 55fe2b05c7386..fabfbfbdd142e 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -624,10 +624,9 @@ var tile_qk: array; sum += inner_qk_values[local_idx][i]; } - let output_idx = head_idx * total_sequence_length + total_seq_offset + local_idx; - sum = sum + loadAttentionBias(output_idx); + sum = sum + loadAttentionBias(head_idx * total_sequence_length + total_seq_offset + local_idx); tile_qk[local_idx] = sum; - output[output_idx] = sum; + output[head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx] = sum; } workgroupBarrier(); @@ -645,7 +644,7 @@ var tile_qk: array; for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { l_sum += exp(f32(tile_qk[i]) - l_max); } - let meta_offset = head_idx * uniforms.num_total_seq_length_tile + workgroup_idx % uniforms.num_total_seq_length_tile; + let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % uniforms.num_total_seq_length_tile; metadata[meta_offset] = metadata_value_t(l_max, l_sum); } )MAIN_FN"; @@ -655,7 +654,7 @@ var tile_qk: array; Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, const Tensor* attention_bias, Tensor* output, Tensor* present_key, Tensor* metadata, - const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile, uint32_t tile_size) { + const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size) { const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; @@ -682,6 +681,7 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte {static_cast(parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_)}, {static_cast(parameters.n_reps)}, {num_total_seq_length_tile}, + {num_present_sequence_length_tile}, {static_cast(parameters.num_heads_)}}); return context.RunProgram(program); @@ -734,18 +734,18 @@ var qkv_values: array, var g_sum = f32(0); for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++) { - let meta_offset = head_idx * uniforms.num_total_seq_length_tile + i; + let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i; g_max = max(g_max, metadata[meta_offset].x); } for (var i = 0u; i < uniforms.num_total_seq_length_tile; i++) { - let meta_offset = head_idx * uniforms.num_total_seq_length_tile + i; + let meta_offset = head_idx * uniforms.num_present_sequence_length_tile + i; let m_value = metadata[meta_offset]; g_sum += exp(m_value.x - g_max) * m_value.y; } if (total_seq_offset + local_idx < total_sequence_length) { - tile_qk[local_idx] = present_value_element_t(exp(f32(qk[head_idx * total_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum); + tile_qk[local_idx] = present_value_element_t(exp(f32(qk[head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum); } } for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) { @@ -777,7 +777,7 @@ var qkv_values: array, } for (var i = local_idx; i < uniforms.head_size_vec; i += workgroup_size_x) { - let out_offset = head_idx * uniforms.num_total_seq_length_tile * uniforms.head_size_vec + (workgroup_idx % uniforms.num_total_seq_length_tile) * uniforms.head_size_vec + i; + let out_offset = head_idx * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec + (workgroup_idx % uniforms.num_total_seq_length_tile) * uniforms.head_size_vec + i; out_split_vx[out_offset] = tile_output[i]; } )MAIN_FN"; @@ -792,6 +792,7 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte Tensor* present_value, const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile, + uint32_t num_present_sequence_length_tile, uint32_t tile_size) { const int components = 4; int head_size_vec = parameters.v_head_size_ / components; @@ -808,6 +809,7 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte {static_cast(parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_)}, {static_cast(parameters.n_reps)}, num_total_seq_length_tile, + num_present_sequence_length_tile, {static_cast(parameters.num_heads_)}}); return context.RunProgram(program); @@ -831,7 +833,7 @@ var tile_input: array, TILE_SIZE>; shader.MainFunctionBody() << R"MAIN_FN( let head_size_offset = (workgroup_idx % uniforms.num_head_size_tile) * TILE_SIZE; let head_idx = u32(workgroup_idx / uniforms.num_head_size_tile); - let in_offset = head_idx * uniforms.num_total_seq_length_tile * uniforms.head_size_vec; + let in_offset = head_idx * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec; var value = output_value_t(0); let local_row = u32(local_idx / TILE_SIZE); let local_col = local_idx % TILE_SIZE; @@ -868,7 +870,8 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& const Tensor* out_split_vx, Tensor* output, const WebgpuAttentionParameters& parameters, - uint32_t num_total_seq_length_tile) { + uint32_t num_total_seq_length_tile, + uint32_t num_present_sequence_length_tile) { const int components = 4; constexpr int tile_size = 8; int tile_head_size = tile_size * components; @@ -881,6 +884,7 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& .SetWorkgroupSize(tile_size * tile_size) .AddUniformVariables({{static_cast(parameters.v_head_size_ / components)}, num_total_seq_length_tile, + num_present_sequence_length_tile, {num_head_size_tile}, {static_cast(parameters.num_heads_)}}); @@ -891,7 +895,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value)); - + const int present_sequence_length = parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_; if (parameters.sequence_length_ > 1) { const uint32_t tile_size = 64; bool has_attention_bias = attention_bias != nullptr; @@ -913,7 +917,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, is_qualcomm) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(parameters.total_sequence_length_)}, - {static_cast(parameters.past_present_share_buffer_ ? parameters.past_sequence_length_ : parameters.total_sequence_length_)}, + {static_cast(present_sequence_length)}, {static_cast(parameters.total_sequence_length_ - parameters.kv_sequence_length_)}, {static_cast(parameters.is_gqa_ ? 1 : 0)}, {static_cast(parameters.n_reps)}, @@ -923,25 +927,27 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co return context.RunProgram(program); } + // Use present_sequence_length instead of total_sequence_length to make sure the |qk| buffer is static when static qv cache is enabled. const TensorShapeVector qk_dims({parameters.batch_size_, parameters.num_heads_, - parameters.sequence_length_, parameters.total_sequence_length_}); + parameters.sequence_length_, present_sequence_length}); const TensorShape qk_shape(qk_dims); Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape); constexpr uint32_t tile_size = 64; const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size; + const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size; // The metadata is used to store the max and sum of each tile. const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_, - num_total_seq_length_tile, 2}); + num_present_sequence_length_tile, 2}); const TensorShape metadata_shape(metadata_dims); Tensor metadata = context.CreateGPUTensor(DataTypeImpl::GetType(), metadata_shape); ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKT(context, Q, attention_bias, &qk, present_key, &metadata, - parameters, num_total_seq_length_tile, tile_size)); + parameters, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size)); - const TensorShapeVector out_split_vx_dims({parameters.batch_size_, parameters.num_heads_, num_total_seq_length_tile, parameters.head_size_}); + const TensorShapeVector out_split_vx_dims({parameters.batch_size_, parameters.num_heads_, num_present_sequence_length_tile, parameters.head_size_}); const TensorShape out_split_vx_shape(out_split_vx_dims); Tensor out_split_vx = context.CreateGPUTensor(Q->DataType(), out_split_vx_shape); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value, parameters, num_total_seq_length_tile, tile_size)); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, parameters, num_total_seq_length_tile)); + ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value, parameters, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size)); + ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, parameters, num_total_seq_length_tile, num_present_sequence_length_tile)); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 374ede6f4db7c..3f79b80fb73bc 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -82,6 +82,7 @@ class FlashAttentionDecodeQKTProgram final : public Program) -> vec4 + ss << R"ADDNL_FN( + fn DequantizedFrom4BitsTo8Bits(in: vec2, zero: i32) -> vec4 { var out = vec4(0); - var value_lower = vec4(unpack4xU8(in[0] & 0x0F0F0F0Fu)) - vec4(8); - var value_upper = vec4(unpack4xU8((in[0] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + var value_lower = vec4(unpack4xU8(in[0] & 0x0F0F0F0Fu)) - vec4(zero); + var value_upper = vec4(unpack4xU8((in[0] >> 4) & 0x0F0F0F0Fu)) - vec4(zero); out[0] = pack4xI8(vec4(value_lower[0], value_upper[0], value_lower[1], value_upper[1])); out[1] = pack4xI8(vec4(value_lower[2], value_upper[2], value_lower[3], value_upper[3])); - value_lower = vec4(unpack4xU8(in[1] & 0x0F0F0F0Fu)) - vec4(8); - value_upper = vec4(unpack4xU8((in[1] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + value_lower = vec4(unpack4xU8(in[1] & 0x0F0F0F0Fu)) - vec4(zero); + value_upper = vec4(unpack4xU8((in[1] >> 4) & 0x0F0F0F0Fu)) - vec4(zero); out[2] = pack4xI8(vec4(value_lower[0], value_upper[0], value_lower[1], value_upper[1])); out[3] = pack4xI8(vec4(value_lower[2], value_upper[2], value_lower[3], value_upper[3])); return out; @@ -41,10 +45,7 @@ std::string CommonFunctions(uint32_t nbits) { } )ADDNL_FN"; } else { - ORT_ENFORCE(nbits == 8, "Only 4/8 bits are supported for webgpu matmulnbits"); - // For 8bits, in case data overflow when converting from int32 (output of dot4I8Packed) to f16, we force it convert to f32. - // Then do the scale. Finally, convert to output element type. - return R"ADDNL_FN( + ss << R"ADDNL_FN( fn AlignWithZeroPoint(in: vec4) -> vec4 { var out = vec4(0); @@ -54,7 +55,40 @@ std::string CommonFunctions(uint32_t nbits) { out[3] = pack4xI8(vec4(unpack4xU8(in[3])) - vec4(128)); return out; } - + )ADDNL_FN"; + // For 8bits, in case data overflow when converting from int32 (output of dot4I8Packed) to f16, we force it convert to f32. + // Then do the scale. Finally, convert to output element type. + if (has_zero_points) { + // If has_zero_points is true, vec4(unpack4xU8(b_data)) - vec4(zero) may be out of the range [-128, 127] since zero can be any value between [0, 255]. + // To avoid the data overflow when use pack4xI8, we still use |pack4xI8(vec4(unpack4xU8(xxx)) - vec4(128))| to process the b data. In SDP8AI, we use the + // dp4a's result of a and b to subtract dot(vec4(unpack4xI8(a)), vec4(zero - 128)) to get the correct result. + ss << R"ADDNL_FN( + // Scaled dot product of 8 packed unsigned integers. + fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t, zero: i32) -> output_element_t + { + let bias_zero = zero - 128; + var local_sum = dot4I8Packed(a1[0], b1[0]); + var dequantized_a_sum = vec4(unpack4xI8(a1[0])); + local_sum += dot4I8Packed(a1[1], b1[1]); + dequantized_a_sum += vec4(unpack4xI8(a1[1])); + local_sum += dot4I8Packed(a1[2], b1[2]); + dequantized_a_sum += vec4(unpack4xI8(a1[2])); + local_sum += dot4I8Packed(a1[3], b1[3]); + dequantized_a_sum += vec4(unpack4xI8(a1[3])); + local_sum += dot4I8Packed(a2[0], b2[0]); + dequantized_a_sum += vec4(unpack4xI8(a2[0])); + local_sum += dot4I8Packed(a2[1], b2[1]); + dequantized_a_sum += vec4(unpack4xI8(a2[1])); + local_sum += dot4I8Packed(a2[2], b2[2]); + dequantized_a_sum += vec4(unpack4xI8(a2[2])); + local_sum += dot4I8Packed(a2[3], b2[3]); + dequantized_a_sum += vec4(unpack4xI8(a2[3])); + local_sum -= dot(dequantized_a_sum, vec4(bias_zero)); + return output_element_t(f32(local_sum) * f32(scale)); + } + )ADDNL_FN"; + } else { + ss << R"ADDNL_FN( // Scaled dot product of 8 packed unsigned integers. fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t { @@ -69,7 +103,9 @@ std::string CommonFunctions(uint32_t nbits) { return output_element_t(f32(local_sum) * f32(scale)); } )ADDNL_FN"; + } } + return ss.str(); } } // namespace @@ -104,6 +140,9 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddInput("scales_a", ShaderUsage::UseUniform); shader.AddInput("input_b", ShaderUsage::UseUniform); shader.AddInput("scales_b", ShaderUsage::UseUniform); + if (has_zero_points_) { + shader.AddInput("zero_points", ShaderUsage::UseUniform); + } shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); // This shader implements co-operative matrix multiply. The key idea here is to @@ -131,7 +170,7 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { // this shader require A to be int8 quantized with block size 64. B is regular // matmulnbits input with block size 32. - shader.AdditionalImplementation() << CommonFunctions(nbits_) + shader.AdditionalImplementation() << CommonFunctions(nbits_, has_zero_points_) << " const block_size = " << block_size_ << ";"; shader.AdditionalImplementation() << R"ADDNL_FN( @@ -147,7 +186,11 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { var scale_A : array; // 64 x 1 var tile_B : array, tile_size>, tile_size_k_vec>; // 64 x 32 var scale_B : array; // 64 x 1 - + )ADDNL_FN"; + if (nbits_ == 8 && has_zero_points_) { + shader.AdditionalImplementation() << " var zeroes : array;"; + } + shader.AdditionalImplementation() << R"ADDNL_FN( fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) { let a_global = a_global_base + row; @@ -174,11 +217,13 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { } let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; - tile_B[col][row] = DequantizedFrom4BitsTo8Bits(b_value); + let block_idx = kidx_v/(block_size/16); + let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); + tile_B[col][row] = DequantizedFrom4BitsTo8Bits(b_value, zero); if (col == 0) { // kidx_v - each kidx_v covers 16 values of k - scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + kidx_v/(block_size/16)]; + scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + block_idx]; } } )ADDNL_FN"; @@ -198,7 +243,13 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { if (col == 0) { // kidx_v - each kidx_v covers 16 values of k - scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + kidx_v/(block_size/16)]; + let block_idx = kidx_v/(block_size/16); + scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + block_idx]; + )ADDNL_FN"; + if (has_zero_points_) { + shader.AdditionalImplementation() << " zeroes[row] = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);\n"; + } + shader.AdditionalImplementation() << R"ADDNL_FN( } } )ADDNL_FN"; @@ -248,6 +299,64 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { var own_a0: vec4 = tile_A[0][base_A + a_idx]; var own_a1: vec4 = tile_A[1][base_A + a_idx]; var own_scale_a: output_element_t = scale_A[base_A + a_idx]; + )MAIN_FN"; + if (nbits_ == 8 && has_zero_points_) { + shader.MainFunctionBody() << R"MAIN_FN( + if (sg_size == 16) + { + var own_b0: vec4 = tile_B[0][base_B + sg_id]; + var own_b1: vec4 = tile_B[1][base_B + sg_id]; + var own_scale_b: output_element_t = scale_B[base_B + sg_id]; + var zero = zeroes[base_B + sg_id]; + // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. + lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a, subgroupShuffle(zero, 0)); + lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a, subgroupShuffle(zero, 1)); + lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a, subgroupShuffle(zero, 2)); + lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a, subgroupShuffle(zero, 3)); + + lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a, subgroupShuffle(zero, 4)); + lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a, subgroupShuffle(zero, 5)); + lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a, subgroupShuffle(zero, 6)); + lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a, subgroupShuffle(zero, 7)); + + lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a, subgroupShuffle(zero, 8)); + lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a, subgroupShuffle(zero, 9)); + lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a, subgroupShuffle(zero, 10)); + lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a, subgroupShuffle(zero, 11)); + + lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a, subgroupShuffle(zero, 12)); + lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a, subgroupShuffle(zero, 13)); + lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a, subgroupShuffle(zero, 14)); + lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a, subgroupShuffle(zero, 15)); + } + else + { + // Code for other subgroup sizes, simply doesnt use subgroups at all. + // Relies on reads from single location tile_B[][base_B + col] by all + // being optimized by the hardware. + lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0], zeroes[base_B + 0]); + lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1], zeroes[base_B + 1]); + lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2], zeroes[base_B + 2]); + lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3], zeroes[base_B + 3]); + + lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4], zeroes[base_B + 4]); + lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5], zeroes[base_B + 5]); + lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6], zeroes[base_B + 6]); + lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7], zeroes[base_B + 7]); + + lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8], zeroes[base_B + 8]); + lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9], zeroes[base_B + 9]); + lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10], zeroes[base_B + 10]); + lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11], zeroes[base_B + 11]); + + lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12], zeroes[base_B + 12]); + lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13], zeroes[base_B + 13]); + lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14], zeroes[base_B + 14]); + lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15], zeroes[base_B + 15]); + } + )MAIN_FN"; + } else { + shader.MainFunctionBody() << R"MAIN_FN( if (sg_size == 16) { var own_b0: vec4 = tile_B[0][base_B + sg_id]; @@ -299,6 +408,9 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]); lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]); } + )MAIN_FN"; + } + shader.MainFunctionBody() << R"MAIN_FN( workgroupBarrier(); } @@ -324,16 +436,23 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co shader.AddInput("scales_a", ShaderUsage::UseUniform); shader.AddInput("input_b", ShaderUsage::UseUniform); shader.AddInput("scales_b", ShaderUsage::UseUniform); + if (has_zero_points_) { + shader.AddInput("zero_points", ShaderUsage::UseUniform); + } shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + + ORT_ENFORCE(WorkgroupSizeX() % tile_size_k_vec_ == 0 && tile_size_k_vec_ % 4 == 0, "tile_size_k_vec_ must evenly divide workgroup size X and be divisible by 4"); + const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec_; + ORT_ENFORCE(tile_size_ % sub_tile_count == 0, "tile_size_ must be divisible by sub_tile_count"); + // This algorithm works to compute dot product of k parallelly, by processing k at each step amongst tile_size_k_vec threads, // and utilizing the remaining threads in the workgroup to process additional rows of b in parallel (such that the values in shared memory for A can be reused). // For each load of k, the tile_size_k_vec threads also reload B tile_size/num_concurrent_b_rows times to compute partial dot products of other B rows // in order to complete all tile_size b rows in this workgroup and also reusing the loaded in register values of a. - // 1. Each workgroup handles tile_size_k_vec (16) * k_vectorization_in_b (32) columns (total 512) and num_concurrent_b_rows of matrix B at a time, + // 1. Each workgroup handles tile_size_k_vec * k_vectorization_in_b (32) columns and num_concurrent_b_rows of matrix B at a time, // iterating over the columns to compute a partial dot product. // 2. Uses vec4 vectorization where each K represents 32 elements of matrix B - constexpr uint32_t tile_size_k_vec = 16; // 1. Workgroup Responsibility: // - Processes one row of matrix A @@ -346,18 +465,19 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co // - Iterates through columns accumulating results in inter_results // - Performs final reduction sum in inter_results for output shader.AdditionalImplementation() << " const tile_size = " << tile_size_ << "u;\n" - << " const tile_size_k_vec = " << tile_size_k_vec << "u;\n" - << " const double_tile_size_k_vec = " << 2 * tile_size_k_vec << "u;\n" + << " const tile_size_k_vec = " << tile_size_k_vec_ << "u;\n" + << " const double_tile_size_k_vec = " << 2 * tile_size_k_vec_ << "u;\n" // sub_tile_count is the number of concurrent b rows processed by the workgroup. - << " const sub_tile_count = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n" - << " var inter_results: array, tile_size>;\n"; + << " const sub_tile_count = " << sub_tile_count << "u;\n"; - shader.AdditionalImplementation() << CommonFunctions(nbits_) + shader.AdditionalImplementation() << CommonFunctions(nbits_, has_zero_points_) << R"ADDNL_FN( - // Need 2 * tile_size_k_vec (32) to store a tile_A since b is quantized as 4 bits and a is quantized as 8 bits. - var tile_A : array, 32>; - // Need 4 scales value since each tile_A includes 512 (4x4x32) scalars and the block_size is 128. - var scale_A : array; + var inter_results: array, tile_size>; + // Need 2 * tile_size_k_vec to store a tile_A since b is quantized as 4 bits and a is quantized as 8 bits. + var tile_A : array, double_tile_size_k_vec>; + // double_tile_size_k_vec * 16 / 128 + const scale_a_size_in_tile_a = double_tile_size_k_vec / 8; + var scale_A : array; fn loadSHMA(a_global: u32, kidx_v: u32, col: u32) { let k_offset = kidx_v + col; @@ -366,7 +486,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co } tile_A[col] = input_a[a_global*uniforms.K16+k_offset]; - if (col < 4) + if (col < scale_a_size_in_tile_a) { // kidx_v - covers 16 values of k in input_a scale_A[col] = scales_a[a_global*(uniforms.K/128) + kidx_v/8 + col]; @@ -391,32 +511,37 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co var own_a: vec4 = tile_A[local_col * 2]; var own_a1: vec4 = tile_A[local_col * 2 + 1]; var own_scale_a = scale_A[local_col / 4]; - var own_b = vec4(0); - var own_b1 = vec4(0); let k_offset = kidx_v + local_col; + // k_offset - covers 32 values of k in input_b + let block_idx = k_offset * 32 / uniforms.block_size; // calculate intermediate results into inter_results. for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { let b_global = b_global_base + row_offset + local_row; if (b_global < uniforms.N && k_offset < uniforms.K32) { let b_offset = b_global * uniforms.K32 + k_offset; + let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); + let own_scale_b = scales_b[b_global * uniforms.K / uniforms.block_size + block_idx]; )MAIN_FN"; if (nbits_ == 4) { shader.MainFunctionBody() << R"MAIN_FN( let b_value = input_b[b_offset]; - own_b = DequantizedFrom4BitsTo8Bits(b_value.xy); - own_b1 = DequantizedFrom4BitsTo8Bits(b_value.zw); + let own_b = DequantizedFrom4BitsTo8Bits(b_value.xy, zero); + let own_b1 = DequantizedFrom4BitsTo8Bits(b_value.zw, zero); + inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b); )MAIN_FN"; } else { shader.MainFunctionBody() << R"MAIN_FN( - own_b = AlignWithZeroPoint(input_b[b_offset * 2]); - own_b1 = AlignWithZeroPoint(input_b[b_offset * 2 + 1]); + let own_b = AlignWithZeroPoint(input_b[b_offset * 2]); + let own_b1 = AlignWithZeroPoint(input_b[b_offset * 2 + 1]); )MAIN_FN"; + if (has_zero_points_) { + shader.MainFunctionBody() << " inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b, zero);\n"; + } else { + shader.MainFunctionBody() << " inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);\n"; + } } shader.MainFunctionBody() << R"MAIN_FN( - // k_offset - covers 32 values of k in input_b - let own_scale_b = scales_b[b_global * uniforms.K / uniforms.block_size + k_offset * 32 / uniforms.block_size]; - inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b); } } workgroupBarrier(); @@ -440,10 +565,12 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co } Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, + const Tensor* zero_points, uint32_t M, uint32_t N, uint32_t K, uint32_t block_size, + uint32_t zero_blocks_per_col, uint32_t min_M_for_tile_optimization, uint32_t nbits, onnxruntime::webgpu::ComputeContext& context, @@ -464,20 +591,25 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), 1}, {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), 1}}); ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); - + const bool has_zero_points = zero_points != nullptr; if (M < min_M_for_tile_optimization) { - constexpr uint32_t kTileSize = 32; - DP4AMatMulNBitsSmallMProgram mul_program{kTileSize, nbits}; - uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize; + uint32_t tile_size_k_vec = 16; + uint32_t tile_size = 32; + + DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size, nbits, has_zero_points}; + uint32_t num_N_tile = (N + tile_size - 1) / tile_size; mul_program.SetWorkgroupSize(128); mul_program.SetDispatchGroupSize(M * num_N_tile); mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) - .AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile}) + .AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1}) - .CacheHint(nbits); + .CacheHint(nbits, tile_size_k_vec, tile_size, has_zero_points); + if (has_zero_points) { + mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); + } return context.RunProgram(mul_program); } @@ -485,7 +617,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor TensorShape reshaped_y_shape{1, M, N / kVec4Components}; uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize; uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize; - DP4AMatMulNBitsProgram mul_program{block_size, nbits}; + DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points}; mul_program.SetWorkgroupSize(256); mul_program.SetDispatchGroupSize(num_M_tile * num_N_tile); mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, @@ -497,9 +629,13 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor {static_cast(K)}, {static_cast(K / 8)}, {static_cast(K / 16)}, - {num_N_tile}}) + {num_N_tile}, + {zero_blocks_per_col}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast(kVec4Components)}) - .CacheHint("Block" + std::to_string(block_size), nbits); + .CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points); + if (has_zero_points) { + mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); + } return context.RunProgram(mul_program); } @@ -509,8 +645,7 @@ bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context, uint32_t batch_count, uint32_t N, uint32_t K, - uint32_t components_k, - bool has_zero_points) { + uint32_t components_k) { // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support. // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226 // Use 'vendor' to check for metal; 'backend' is always WEBGPU when running under wasm. @@ -518,7 +653,7 @@ bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context, context.AdapterInfo().vendor != std::string_view{"apple"}; return (accuracy_level == 4 && block_size % 32 == 0 && batch_count == 1 && components_k == 4 && K % 128 == 0 && N % 16 == 0 && - !has_zero_points && use_dp4a); + use_dp4a); } } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h index 647e200ce93b7..81ddc411b385b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -20,7 +20,10 @@ class DP4AMatMulQuantizeProgram final : public Program { public: - DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits) : Program{"DP4AMatMulNBits"}, block_size_(block_size), nbits_(nbits) {} + DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits, bool has_zero_points) : Program{"DP4AMatMulNBits"}, + block_size_(block_size), + nbits_(nbits), + has_zero_points_(has_zero_points) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -28,16 +31,22 @@ class DP4AMatMulNBitsProgram final : public Program { {"K", ProgramUniformVariableDataType::Uint32}, {"K8", ProgramUniformVariableDataType::Uint32}, {"K16", ProgramUniformVariableDataType::Uint32}, - {"num_N_tile", ProgramUniformVariableDataType::Uint32}); + {"num_N_tile", ProgramUniformVariableDataType::Uint32}, + {"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32}); private: uint32_t block_size_; uint32_t nbits_; + bool has_zero_points_; }; class DP4AMatMulNBitsSmallMProgram final : public Program { public: - DP4AMatMulNBitsSmallMProgram(uint32_t tile_size, uint32_t nbits) : Program{"DP4AMatMulNBitsSmallMProgram"}, tile_size_(tile_size), nbits_(nbits) {} + DP4AMatMulNBitsSmallMProgram(uint32_t tile_size_k_vec, uint32_t tile_size, uint32_t nbits, bool has_zero_points) : Program{"DP4AMatMulNBitsSmallMProgram"}, + tile_size_k_vec_(tile_size_k_vec), + tile_size_(tile_size), + nbits_(nbits), + has_zero_points_(has_zero_points) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -46,18 +55,23 @@ class DP4AMatMulNBitsSmallMProgram final : public Program #include "contrib_ops/webgpu/quantization/matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" #include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" #include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" @@ -19,44 +20,8 @@ namespace webgpu { namespace { -std::string ReadZeroPoint(uint32_t nbits, bool has_zero_points) { - ORT_ENFORCE(nbits == 8 || nbits == 4, "Only 4/8 bits are supported for webgpu matmulnbits"); - std::stringstream ss; - if (has_zero_points) { - ss << "const elements_in_uint32 = " << (32 / nbits) << "u;\n" - << "const bits = " << nbits << "u;\n"; - ss << R"( -fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> output_element_t { - if (row < r_dim && col < c_dim) { - let offset = row * c_dim + col; - - // u32 holds elements_in_uint32 packed nbits. - let array_index = offset / elements_in_uint32; - let component_index = offset % elements_in_uint32; - let packed_value = zero_points[array_index]; - - // Extract the nbits component - let shift_amount = component_index * bits; -)"; - ss << " let masked_value = (packed_value >> shift_amount) & " << (nbits == 4 ? "0xFu" : "0xFF") << ";\n"; - ss << R"( - return output_element_t(masked_value); - } - return output_element_t(0); -} -)"; - } else { - ss << "const default_zero_point = " << (nbits == 4 ? 8 : 128) << ";\n"; - ss << R"( -fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> output_element_t { - return output_element_t(default_zero_point); -} -)"; - } - return ss.str(); -} - constexpr unsigned int kMinMForTileOptimization = 4; + } // namespace ONNX_OPERATOR_KERNEL_EX( @@ -133,7 +98,7 @@ fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scal << " }\n" << " return output_element_t(0);\n" << "}\n" - << ReadZeroPoint(nbits_, has_zero_points_); + << GenerateZeroPointReadingCode(nbits_, has_zero_points_); shader.AdditionalImplementation() << "\n" << "fn mm_write_y(batch : u32, row : u32, col : u32, value : output_value_t) {\n" @@ -271,7 +236,7 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { } } )ADDNL_FN" - << ReadZeroPoint(nbits_, has_zero_points_); + << GenerateZeroPointReadingCode(nbits_, has_zero_points_); shader.MainFunctionBody() << R"MAIN_FN( let batch = workgroup_idx / (uniforms.M * uniforms.num_N_tile); @@ -394,6 +359,11 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context ORT_ENFORCE(g_idx == nullptr, "group_idx as input is not supported yet."); ORT_ENFORCE(bias == nullptr, "bias as input is not supported yet."); + const bool has_zero_points = zero_points != nullptr; + if (has_zero_points) { + ORT_ENFORCE(zero_points->DataType() == DataTypeImpl::GetType(), "Currently, only uint8 is supported for zero points, but got ", zero_points->DataType()); + } + MatMulComputeHelper helper; TensorShape b_shape({N_, K_}); ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); @@ -416,25 +386,26 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context const uint32_t components_a = GetMaxComponents(K); const uint32_t components_b = GetMaxComponents(blob_size_in_words); uint32_t components = GetMaxComponents(N); + // zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)]. So here we need to check whether n_blocks_per_col is divisible by 8/nbits. + // For bits==4, this is counted by elements of uint4. Need add 1 if not divisible by 2. + uint32_t zero_blocks_per_col = n_blocks_per_col % (8 / nbits) == 0 ? n_blocks_per_col : n_blocks_per_col + 1; - const bool has_zero_points = zero_points != nullptr; - // macOS - Experimental dawn support for subgroup matrix matmul on Metal. - if (M >= kMinMForTileOptimization && - CanApplySubgroupMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, has_zero_points)) { - return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, nbits, context, y); +#if !defined(__wasm__) + int32_t subgroup_matrix_config_index = -1; + // apple|intel - Experimental dawn support for subgroup matrix matmul. + if (M >= kMinMForTileOptimization && (context.AdapterInfo().vendor == std::string_view{"apple"} || context.AdapterInfo().vendor == std::string_view{"intel"}) && + CanApplySubgroupMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, subgroup_matrix_config_index)) { + return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, M, N, K, nbits, zero_blocks_per_col, subgroup_matrix_config_index, context, y); } +#endif // On FP32 only GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M. if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && - CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) { - return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, nbits, context, y); + CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a)) { + return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, nbits, context, y); } - // zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)]. So here we need to check whether n_blocks_per_col is divisible by 8/nbits. - // For bits==4, this is counted by elements of uint4. Need add 1 if not divisible by 2. - uint32_t zero_blocks_per_col = n_blocks_per_col % (8 / nbits) == 0 ? n_blocks_per_col : n_blocks_per_col + 1; - // WideTileProgram // This program is optimized for Block32 prefill using Tile16x128. const bool use_wide_tile_program = block_size == 32 && components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc new file mode 100644 index 0000000000000..58b0f3ada9341 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" +#include +#include "core/common/common.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +std::string GenerateZeroPointReadingCode(uint32_t nbits, bool has_zero_points, + const std::string& output_type) { + ORT_ENFORCE(nbits == 8 || nbits == 4, "Only 4/8 bits are supported for webgpu matmulnbits"); + std::stringstream ss; + + if (has_zero_points) { + ss << "const elements_in_uint32 = " << (32 / nbits) << "u;\n" + << "const bits = " << nbits << "u;\n"; + ss << R"( +fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> )" + << output_type << R"( { + if (row < r_dim && col < c_dim) { + let offset = row * c_dim + col; + + // u32 holds elements_in_uint32 packed nbits. + let array_index = offset / elements_in_uint32; + let component_index = offset % elements_in_uint32; + let packed_value = zero_points[array_index]; + + // Extract the nbits component + let shift_amount = component_index * bits; +)"; + ss << " let masked_value = (packed_value >> shift_amount) & " << (nbits == 4 ? "0xFu" : "0xFF") << ";\n"; + ss << R"( + return )" + << output_type << R"((masked_value); + } + return )" + << output_type << R"((0); +} +)"; + } else { + ss << "const default_zero_point = " << (nbits == 4 ? 8 : 128) << ";\n"; + ss << R"( +fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> )" + << output_type << R"( { + return )" + << output_type << R"((default_zero_point); +} +)"; + } + + return ss.str(); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h new file mode 100644 index 0000000000000..dde18e8a78fd2 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +/** + * Generates WebGPU shader code for reading zero points in quantized matrix multiplication + * + * @param nbits Number of bits for quantization (4 or 8) + * @param has_zero_points Whether zero points are provided as an input + * @param output_type Type name to use for zero point values in the generated code (default: "output_element_t") + * @return String containing the generated WebGPU shader code + */ +std::string GenerateZeroPointReadingCode(uint32_t nbits, bool has_zero_points, + const std::string& output_type = "output_element_t"); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index 674473a173445..519a6e47438e3 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -1,18 +1,222 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#if !defined(__wasm__) +#include + #include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" namespace onnxruntime { namespace contrib { namespace webgpu { -Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - shader.AddInput("input_b", ShaderUsage::UseUniform); - shader.AddInput("scales_b", ShaderUsage::UseUniform); - shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); +constexpr std::string_view ComponentTypeName[] = {"unknown", "f32", "f16", "u32", "i32"}; +template +constexpr bool ValidateComponentTypeName(const std::array& component_type) { + bool matched = true; + for (auto type : component_type) { + switch (type) { + case wgpu::SubgroupMatrixComponentType::F32: + matched = ComponentTypeName[static_cast(wgpu::SubgroupMatrixComponentType::F32)] == "f32"; + break; + case wgpu::SubgroupMatrixComponentType::F16: + matched = ComponentTypeName[static_cast(wgpu::SubgroupMatrixComponentType::F16)] == "f16"; + break; + case wgpu::SubgroupMatrixComponentType::U32: + matched = ComponentTypeName[static_cast(wgpu::SubgroupMatrixComponentType::U32)] == "u32"; + break; + case wgpu::SubgroupMatrixComponentType::I32: + matched = ComponentTypeName[static_cast(wgpu::SubgroupMatrixComponentType::I32)] == "i32"; + break; + default: + return false; + } + if (!matched) { + return matched; + } + } + + return matched; +} +static_assert(ValidateComponentTypeName<4>({wgpu::SubgroupMatrixComponentType::F32, + wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::U32, + wgpu::SubgroupMatrixComponentType::I32}), + "The elements' sequence of ComponentTypeName array do not match wgpu::SubgroupMatrixComponentType"); + +// std::tuple +static const std::tuple + intel_supported_subgroup_matrix_configs[] = { + {"xe-2lpg", wgpu::BackendType::Vulkan, wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 8, 16, 16, 16, 32}, + {"xe-2lpg", wgpu::BackendType::Vulkan, wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F32, 8, 16, 16, 16, 32}}; + +bool IsSubgroupMatrixConfigSupportedOnIntel(onnxruntime::webgpu::ComputeContext& context, int32_t& config_index) { + const wgpu::AdapterInfo& adapter_info = context.AdapterInfo(); + const wgpu::AdapterPropertiesSubgroupMatrixConfigs& subgroup_matrix_configs = context.SubgroupMatrixConfigs(); + int32_t index = 0; + for (auto& supported_config : intel_supported_subgroup_matrix_configs) { + for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { + auto& subgroup_matrix_config = subgroup_matrix_configs.configs[i]; + auto&& config = std::make_tuple(adapter_info.architecture, adapter_info.backendType, + subgroup_matrix_config.componentType, subgroup_matrix_config.resultComponentType, + subgroup_matrix_config.M, subgroup_matrix_config.N, subgroup_matrix_config.K, + adapter_info.subgroupMinSize, adapter_info.subgroupMaxSize); + if (config == supported_config) { + config_index = index; + return true; + } + } + index++; + } + return false; +} + +Status GenerateShaderCodeOnIntel(ShaderHelper& shader, uint32_t nbits, int32_t config_index, bool has_zero_points) { + auto& config = intel_supported_subgroup_matrix_configs[config_index]; + shader.AdditionalImplementation() << "alias component_type = " << ComponentTypeName[static_cast(std::get<2>(config))] << ";\n" + << "alias result_component_type = " << ComponentTypeName[static_cast(std::get<3>(config))] << ";\n" + << "const m_dim = " << std::get<4>(config) << ";\n" + << "const n_dim = " << std::get<5>(config) << ";\n" + << "const k_dim = " << std::get<6>(config) << ";\n"; + + shader.AdditionalImplementation() << R"ADDNL_FN( + const tile_cols = 64; + const tile_rows = 64; + const tile_k = 32; + const subtile_rows = 8; + const quantization_block_size = 32; + + var tile_A: array; // 64 x 32 - RxC + var tile_B: array; // 64 x 32 - RxC + + fn loadSHMA(tile_base: u32, k_idx: u32, row: u32, c_idx:u32) { + let a_global = tile_base + row; + if (a_global >= uniforms.M) { + return; + } + // Each call loads 8 columns, starting at col. + let col = c_idx * 8; + // 256 threads need to load 64 x 32. 4 threads per row or 8 col per thread. + for (var col_offset:u32 = 0; col_offset < 8; col_offset++) + { + tile_A[row * tile_k + col + col_offset] = component_type(input_a[a_global*uniforms.K + k_idx + col + col_offset]); + } + } + )ADDNL_FN" << GenerateZeroPointReadingCode(nbits, has_zero_points, "component_type"); + if (nbits == 4) { + shader.AdditionalImplementation() << R"ADDNL_FN( + fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) { + let b_global = tile_base + row; + if (b_global >= uniforms.N) { + return; + } + // Each call loads 8 columns, starting at col. + let col = c_idx * 8; + // 256 threads need to load 64 x 32. 4 threads per row or 8 col per thread. + // Stored in column major fashion. + let b_idx = u32((b_global * uniforms.K + k_idx + col) / 8); + let scale = component_type(scales_b[(b_global * uniforms.K + k_idx + col) / quantization_block_size]); + let zero = mm_read_zero(b_global, (k_idx + col) / quantization_block_size, uniforms.N, uniforms.zero_blocks_per_col); + let b_value = input_b[b_idx]; + let b_value_lower = (vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(zero)) * scale; + let b_value_upper = (vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(zero)) * scale; + let tile_b_base = row * tile_k + col; + tile_B[tile_b_base] = b_value_lower[0]; + tile_B[tile_b_base + 1] = b_value_upper[0]; + tile_B[tile_b_base + 2] = b_value_lower[1]; + tile_B[tile_b_base + 3] = b_value_upper[1]; + tile_B[tile_b_base + 4] = b_value_lower[2]; + tile_B[tile_b_base + 5] = b_value_upper[2]; + tile_B[tile_b_base + 6] = b_value_lower[3]; + tile_B[tile_b_base + 7] = b_value_upper[3]; + } + )ADDNL_FN"; + } else { + ORT_ENFORCE(nbits == 8, "Only 4/8 bits are supported for webgpu matmulnbits"); + shader.AdditionalImplementation() << R"ADDNL_FN( + fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) { + let b_global = tile_base + row; + if (b_global >= uniforms.N) { + return; + } + // Each call loads 8 columns, starting at col. + let col = c_idx * 8; + // 256 threads need to load 64 x 32. 4 threads per row or 8 col per thread. + // Stored in column major fashion. + let b_idx = u32((b_global * uniforms.K + k_idx + col) / 8); + let scale = component_type(scales_b[(b_global * uniforms.K + k_idx + col) / quantization_block_size]); + let zero = mm_read_zero(b_global, (k_idx + col) / quantization_block_size, uniforms.N, uniforms.zero_blocks_per_col); + let b_value = input_b[b_idx]; + let b_value0 = (vec4(unpack4xU8(b_value[0])) - vec4(zero)) * scale; + let b_value1 = (vec4(unpack4xU8(b_value[1])) - vec4(zero)) * scale; + let tile_b_base = row * tile_k + col; + tile_B[tile_b_base] = b_value0[0]; + tile_B[tile_b_base + 1] = b_value0[1]; + tile_B[tile_b_base + 2] = b_value0[2]; + tile_B[tile_b_base + 3] = b_value0[3]; + tile_B[tile_b_base + 4] = b_value1[0]; + tile_B[tile_b_base + 5] = b_value1[1]; + tile_B[tile_b_base + 6] = b_value1[2]; + tile_B[tile_b_base + 7] = b_value1[3]; + } + )ADDNL_FN"; + } + + shader.MainFunctionBody() << R"MAIN_FN( + let a_global_base = workgroup_id.y * tile_rows; + let b_global_base = workgroup_id.x * tile_cols; + + let subtile_id = u32(local_idx / sg_size); + + var matC00: subgroup_matrix_result; + var matC01: subgroup_matrix_result; + var matC02: subgroup_matrix_result; + var matC03: subgroup_matrix_result; + for (var kidx: u32 = 0; kidx < uniforms.K; kidx += tile_k) { + // Load Phase + loadSHMA(a_global_base, kidx, local_idx / 4, local_idx % 4); + loadSHMB(b_global_base, kidx, local_idx / 4, local_idx % 4); + workgroupBarrier(); + + for (var step: u32 = 0; step < tile_k; step += k_dim) + { + // Load to local memory phase + let matrix_a_offset = subtile_id * subtile_rows * tile_k + step; + // Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride + var matA0: subgroup_matrix_left = subgroupMatrixLoad>(&tile_A, matrix_a_offset, false, tile_k); + + // tile_B is stored as column major. + // [col0-0:32][col1-0:32][col2-0:32]..[col63-0:32] + var matrix_b_offset = step; + var matB0: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset, true, tile_k); + var matB1: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + n_dim * tile_k, true, tile_k); + var matB2: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 2 * n_dim * tile_k, true, tile_k); + var matB3: subgroup_matrix_right = subgroupMatrixLoad>(&tile_B, matrix_b_offset + 3 * n_dim * tile_k, true, tile_k); + + // Compute Phase + // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate + matC00 = subgroupMatrixMultiplyAccumulate(matA0, matB0, matC00); + matC01 = subgroupMatrixMultiplyAccumulate(matA0, matB1, matC01); + matC02 = subgroupMatrixMultiplyAccumulate(matA0, matB2, matC02); + matC03 = subgroupMatrixMultiplyAccumulate(matA0, matB3, matC03); + } + workgroupBarrier(); + } + + // Write out + let matrix_c_offset = (a_global_base) * uniforms.N + b_global_base; + subgroupMatrixStore(&output, matrix_c_offset + subtile_id * m_dim * uniforms.N, matC00, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + subtile_id * m_dim * uniforms.N + n_dim, matC01, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + subtile_id * m_dim * uniforms.N + 2 * n_dim, matC02, false, uniforms.N); + subgroupMatrixStore(&output, matrix_c_offset + subtile_id * m_dim * uniforms.N + 3 * n_dim, matC03, false, uniforms.N); + )MAIN_FN"; + + return Status::OK(); +} + +Status GenerateShaderCodeOnApple(ShaderHelper& shader, uint32_t nbits, bool has_zero_points) { // tile/subtile sizes and work distribution are inspired from metal shaders in llama.cpp (kernel_mul_mm) // https://github.com/ggml-org/llama.cpp/blob/d04e7163c85a847bc61d58c22f2c503596db7aa8/ggml/src/ggml-metal/ggml-metal.metal#L6066 shader.AdditionalImplementation() << R"ADDNL_FN( @@ -41,8 +245,9 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader tile_A[row * tile_k + col + col_offset] = compute_precision(input_a[a_global*uniforms.K + k_idx + col + col_offset]); } } - )ADDNL_FN"; - if (nbits_ == 4) { + )ADDNL_FN" + << GenerateZeroPointReadingCode(nbits, has_zero_points, "compute_precision"); + if (nbits == 4) { shader.AdditionalImplementation() << R"ADDNL_FN( fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) { let b_global = tile_base + row; @@ -54,12 +259,13 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader // 128 threads need to load 64 x 32. 2 threads per row or 16 col per thread. // Stored in column major fashion. let b_idx = u32((b_global*uniforms.K + k_idx + col)/8); - let scale = compute_precision(scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]); + let scale = compute_precision(scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]); + let zero = mm_read_zero(b_global, (k_idx + col) / quantization_block_size, uniforms.N, uniforms.zero_blocks_per_col); for (var step:u32 = 0; step < 2; step++) { var b_value = input_b[b_idx+step]; - var b_value_lower = (vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(8)) * scale; - var b_value_upper = (vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(8)) * scale; + var b_value_lower = (vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(zero)) * scale; + var b_value_upper = (vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(zero)) * scale; let tile_b_base = row * tile_k + col + step * 8; tile_B[tile_b_base] = b_value_lower[0]; tile_B[tile_b_base + 1] = b_value_upper[0]; @@ -73,7 +279,7 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader } )ADDNL_FN"; } else { - ORT_ENFORCE(nbits_ == 8, "Only 4/8 bits are supported for webgpu matmulnbits"); + ORT_ENFORCE(nbits == 8, "Only 4/8 bits are supported for webgpu matmulnbits"); shader.AdditionalImplementation() << R"ADDNL_FN( fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) { let b_global = tile_base + row; @@ -85,12 +291,13 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader // 128 threads need to load 64 x 32. 2 threads per row or 16 col per thread. // Stored in column major fashion. let b_idx = u32((b_global*uniforms.K + k_idx + col)/8); - let scale = compute_precision(scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]); + let scale = compute_precision(scales_b[(b_global*uniforms.K + k_idx + col)/quantization_block_size]); + let zero = mm_read_zero(b_global, (k_idx + col) / quantization_block_size, uniforms.N, uniforms.zero_blocks_per_col); for (var step:u32 = 0; step < 2; step++) { var b_value = input_b[b_idx+step]; - var b_value0 = (vec4(unpack4xU8(b_value[0])) - vec4(128)) * scale; - var b_value1 = (vec4(unpack4xU8(b_value[1])) - vec4(128)) * scale; + var b_value0 = (vec4(unpack4xU8(b_value[0])) - vec4(zero)) * scale; + var b_value1 = (vec4(unpack4xU8(b_value[1])) - vec4(zero)) * scale; let tile_b_base = row * tile_k + col + step * 8; tile_B[tile_b_base] = b_value0[0]; tile_B[tile_b_base + 1] = b_value0[1]; @@ -205,30 +412,62 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader return Status::OK(); } +Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + shader.AddInput("input_b", ShaderUsage::UseUniform); + shader.AddInput("scales_b", ShaderUsage::UseUniform); + if (has_zero_points_) { + shader.AddInput("zero_points", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + + if (!vendor_.compare("apple")) { + return GenerateShaderCodeOnApple(shader, nbits_, has_zero_points_); + } else if (!vendor_.compare("intel")) { + return GenerateShaderCodeOnIntel(shader, nbits_, config_index_, has_zero_points_); + } else { + return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::NOT_IMPLEMENTED, + "onnxruntime does not support subgroup matrix on this verdor."); + } +} + Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, + const Tensor* zero_points, uint32_t M, uint32_t N, uint32_t K, uint32_t nbits, + uint32_t zero_blocks_per_col, + int32_t config_index, onnxruntime::webgpu::ComputeContext& context, Tensor* y) { - constexpr uint32_t kTileSizeA = 32; + uint32_t tile_size_a = 32; + uint32_t work_group_size = 128; constexpr uint32_t kTileSizeB = 64; constexpr uint32_t kU32Components = 4; TensorShape y_shape{1, M, N}; - SubgroupMatrixMatMulNBitsProgram mul_program{nbits}; - mul_program.SetWorkgroupSize(128); + const bool has_zero_points = zero_points != nullptr; + SubgroupMatrixMatMulNBitsProgram mul_program{nbits, config_index, context.AdapterInfo().vendor, has_zero_points}; + if (context.AdapterInfo().vendor == std::string_view{"intel"}) { + tile_size_a = 64; + work_group_size = 256; + } + mul_program.SetWorkgroupSize(work_group_size); mul_program.SetDispatchGroupSize( (N + kTileSizeB - 1) / kTileSizeB, - (M + kTileSizeA - 1) / kTileSizeA, 1); + (M + tile_size_a - 1) / tile_size_a, 1); mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1}, {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(nbits == 4 ? kU32Components : 2 * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) .AddUniformVariables({{static_cast(M)}, {static_cast(N)}, - {static_cast(K)}}) + {static_cast(K)}, + {zero_blocks_per_col}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, 1}) - .CacheHint(nbits); + .CacheHint(nbits, has_zero_points); + if (has_zero_points) { + mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); + } return context.RunProgram(mul_program); } @@ -238,25 +477,28 @@ bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& cont uint32_t batch_count, uint32_t N, uint32_t K, - bool has_zero_points) { -#if !defined(__wasm__) - const bool has_subgroup_matrix = context.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); -#else - const bool has_subgroup_matrix = false; -#endif - // For now SubgroupMatrixMatMulNBits is only supported for accuracy level 4, because with Fp16 there are - // some precision issues with subgroupMatrixMultiplyAccumulate. It is possible to support higher accuracy - // by setting compute_precision to Fp32, but that will be slower. For 1K token prefill FP16 Phi 3.5 is around 5s, - // FP322 is around 7s. + int32_t& config_index) { + bool has_subgroup_matrix = context.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); + if (has_subgroup_matrix) { + if (context.AdapterInfo().vendor == std::string_view{"apple"}) { + // For now SubgroupMatrixMatMulNBits is only supported for accuracy level 4, because with Fp16 there are + // some precision issues with subgroupMatrixMultiplyAccumulate. It is possible to support higher accuracy + // by setting compute_precision to Fp32, but that will be slower. For 1K token prefill FP16 Phi 3.5 is around 5s, + // FP32 is around 7s. + has_subgroup_matrix = accuracy_level == 4; + } else if (context.AdapterInfo().vendor == std::string_view{"intel"}) { + has_subgroup_matrix = IsSubgroupMatrixConfigSupportedOnIntel(context, config_index); + } + } + return has_subgroup_matrix && - context.AdapterInfo().vendor == std::string_view{"apple"} && - accuracy_level == 4 && block_size == 32 && batch_count == 1 && K % 32 == 0 && - N % 64 == 0 && - !has_zero_points; + N % 64 == 0; } } // namespace webgpu } // namespace contrib } // namespace onnxruntime + +#endif diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h index a233e6a54f4fc..cf1fb9a6f7f15 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h @@ -3,6 +3,10 @@ #pragma once +#if !defined(__wasm__) + +#include + #include "core/providers/webgpu/program.h" #include "core/providers/webgpu/compute_context.h" #include "core/providers/webgpu/program.h" @@ -17,22 +21,33 @@ using namespace onnxruntime::webgpu; class SubgroupMatrixMatMulNBitsProgram final : public Program { public: - SubgroupMatrixMatMulNBitsProgram(uint32_t nbits) : Program{"SubgroupMatrixMatMulNBits"}, nbits_(nbits) {} + SubgroupMatrixMatMulNBitsProgram(uint32_t nbits, int32_t config_index, const wgpu::StringView& vendor, bool has_zero_points) : Program{"SubgroupMatrixMatMulNBits"}, + nbits_(nbits), + config_index_(config_index), + vendor_(vendor), + has_zero_points_(has_zero_points) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, {"N", ProgramUniformVariableDataType::Uint32}, - {"K", ProgramUniformVariableDataType::Uint32}); + {"K", ProgramUniformVariableDataType::Uint32}, + {"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32}); private: uint32_t nbits_; + int32_t config_index_; + std::string vendor_; + bool has_zero_points_; }; Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, + const Tensor* zero_points, uint32_t M, uint32_t N, uint32_t K, uint32_t nbits, + uint32_t zero_blocks_per_col, + int32_t config_index, onnxruntime::webgpu::ComputeContext& context, Tensor* y); @@ -42,8 +57,10 @@ bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& cont uint32_t batch_count, uint32_t N, uint32_t K, - bool has_zero_points); + int32_t& config_index); } // namespace webgpu } // namespace contrib } // namespace onnxruntime + +#endif diff --git a/onnxruntime/core/dlpack/dlpack_converter.cc b/onnxruntime/core/dlpack/dlpack_converter.cc index 5307bb32de7d0..652414b8d693a 100644 --- a/onnxruntime/core/dlpack/dlpack_converter.cc +++ b/onnxruntime/core/dlpack/dlpack_converter.cc @@ -84,8 +84,11 @@ OrtDevice GetOrtDevice(const DLDevice& device) { case DLDeviceType::kDLCPU: return OrtDevice(); case DLDeviceType::kDLCUDA: + return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, + static_cast(device.device_id)); case DLDeviceType::kDLROCM: - return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(device.device_id)); + return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, + static_cast(device.device_id)); default: ORT_THROW("Unsupported device type"); } @@ -243,7 +246,7 @@ OrtValue DlpackToOrtValue(DLManagedTensor* dlpack, bool is_bool_tensor) { ORT_ENFORCE(IsContiguousTensor(dlpack->dl_tensor), "ORT only supports contiguous tensor for now."); OrtDevice device = GetOrtDevice(dlpack->dl_tensor.device); MLDataType data_type = GetOrtValueDataType(dlpack->dl_tensor.dtype, is_bool_tensor); - OrtMemoryInfo info(GetOrtDeviceName(device), OrtDeviceAllocator, device, device.Id()); + OrtMemoryInfo info(GetOrtDeviceName(device), OrtDeviceAllocator, device); std::unique_ptr p_tensor = std::make_unique( data_type, TensorShape(dlpack->dl_tensor.shape, static_cast(dlpack->dl_tensor.ndim)), dlpack->dl_tensor.data, info); @@ -257,7 +260,7 @@ OrtValue DlpackToOrtValue(DLManagedTensor* dlpack, bool is_bool_tensor) { deleter(p); }; - ort_value.Init(p_tensor.release(), DataTypeImpl::GetType(), deleter); + ort_value.Init(p_tensor.release(), DataTypeImpl::GetType(), std::move(deleter)); return ort_value; } diff --git a/onnxruntime/core/framework/abi_pointer_array.h b/onnxruntime/core/framework/abi_pointer_array.h new file mode 100644 index 0000000000000..91af1f7f9c6c0 --- /dev/null +++ b/onnxruntime/core/framework/abi_pointer_array.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/session/onnxruntime_c_api.h" + +struct OrtArrayOfConstObjects { + OrtArrayOfConstObjects() = default; + explicit OrtArrayOfConstObjects(OrtTypeTag object_type) : object_type(object_type) {} + OrtArrayOfConstObjects(OrtTypeTag object_type, size_t size, const void* initial_val = nullptr) + : object_type(object_type), storage(size, initial_val) {} + + OrtTypeTag object_type = OrtTypeTag::ORT_TYPE_TAG_Void; + std::vector storage; +}; diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index a58f5ee27b754..5140d3ffaefff 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -99,6 +99,11 @@ void* AllocatorDefaultAlloc(size_t size) { return AllocatorDefaultAllocAligned(size, alignment); } +AllocatorPtr CPUAllocator::DefaultInstance() { + static AllocatorPtr instance = std::make_shared(); + return instance; +} + void* CPUAllocator::Alloc(size_t size) { const auto alignment = std::max(Info().device.GetAlignment(), MlasGetPreferredBufferAlignment()); return AllocatorDefaultAllocAligned(size, alignment); @@ -133,50 +138,98 @@ std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info) { return #endif ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1, enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out) { + auto device_id = static_cast(id1); if (strcmp(name1, onnxruntime::CPU) == 0) { - *out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), id1, mem_type1); - } else if (strcmp(name1, onnxruntime::CUDA) == 0 || - strcmp(name1, onnxruntime::OpenVINO_GPU) == 0 || - strcmp(name1, onnxruntime::HIP) == 0 || - strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 || + *out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), mem_type1); + } else if (strcmp(name1, onnxruntime::CUDA) == 0) { + *out = new OrtMemoryInfo( + name1, type, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, device_id), + mem_type1); + } else if (strcmp(name1, onnxruntime::OpenVINO_GPU) == 0) { + *out = new OrtMemoryInfo( + name1, type, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::INTEL, device_id), + mem_type1); + } else if (strcmp(name1, onnxruntime::HIP) == 0) { + *out = new OrtMemoryInfo( + name1, type, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device_id), + mem_type1); + } else if (strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 || strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0) { *out = new OrtMemoryInfo( - name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, + name1, type, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, device_id), mem_type1); + } else if (strcmp(name1, onnxruntime::DML) == 0) { *out = new OrtMemoryInfo( - name1, type, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, + name1, type, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, device_id), mem_type1); } else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) { *out = new OrtMemoryInfo( - name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, + name1, type, + OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::INTEL, device_id), mem_type1); } else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) { *out = new OrtMemoryInfo( - onnxruntime::CUDA_PINNED, type, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast(id1)), - id1, mem_type1); + name1, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, device_id), + mem_type1); } else if (strcmp(name1, onnxruntime::HIP_PINNED) == 0) { *out = new OrtMemoryInfo( - onnxruntime::HIP_PINNED, type, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast(id1)), - id1, mem_type1); + name1, type, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, device_id), + mem_type1); } else if (strcmp(name1, onnxruntime::QNN_HTP_SHARED) == 0) { *out = new OrtMemoryInfo( - onnxruntime::QNN_HTP_SHARED, type, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, static_cast(id1)), - id1, mem_type1); + name1, type, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::QUALCOMM, device_id), + mem_type1); } else if (strcmp(name1, onnxruntime::CPU_ALIGNED_4K) == 0) { *out = new OrtMemoryInfo( - onnxruntime::CPU_ALIGNED_4K, type, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, static_cast(id1), onnxruntime::kAlloc4KAlignment), - id1, mem_type1); + name1, type, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, device_id, + onnxruntime::kAlloc4KAlignment), + mem_type1); } else { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported."); + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported. Try CreateMemoryInfo_V2."); } return nullptr; } +ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo_V2, _In_ const char* name, _In_ enum OrtMemoryInfoDeviceType device_type, + _In_ uint32_t vendor_id, _In_ int16_t device_id, _In_ enum OrtDeviceMemoryType mem_type, + _In_ size_t alignment, enum OrtAllocatorType type, + _Outptr_ OrtMemoryInfo** out) { + // map the public enum values to internal OrtDevice values + OrtDevice::MemoryType mt = mem_type == OrtDeviceMemoryType_DEFAULT ? OrtDevice::MemType::DEFAULT + : OrtDevice::MemType::HOST_ACCESSIBLE; + + OrtDevice::DeviceType dt; + switch (device_type) { + case OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU: + dt = OrtDevice::CPU; + break; + case OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU: + dt = OrtDevice::GPU; + break; + case OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_NPU: + dt = OrtDevice::NPU; + break; + case OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA: + dt = OrtDevice::FPGA; + break; + default: + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid device type specified."); + } + + *out = new OrtMemoryInfo(name, type, OrtDevice{dt, mt, vendor_id, device_id, alignment}, + mem_type == OrtDeviceMemoryType_DEFAULT ? OrtMemTypeDefault : OrtMemTypeCPU); + return nullptr; +} + ORT_API(void, OrtApis::ReleaseMemoryInfo, _Frees_ptr_opt_ OrtMemoryInfo* p) { delete p; } #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) @@ -187,7 +240,7 @@ ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _ } ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetId, _In_ const OrtMemoryInfo* ptr, _Out_ int* out) { - *out = ptr->id; + *out = ptr->device.Id(); return nullptr; } diff --git a/onnxruntime/core/framework/bfc_arena.cc b/onnxruntime/core/framework/bfc_arena.cc index 6788b4af3b982..6513c4b95a818 100644 --- a/onnxruntime/core/framework/bfc_arena.cc +++ b/onnxruntime/core/framework/bfc_arena.cc @@ -16,7 +16,6 @@ BFCArena::BFCArena(std::unique_ptr resource_allocator, : IAllocator(OrtMemoryInfo(resource_allocator->Info().name, OrtAllocatorType::OrtArenaAllocator, resource_allocator->Info().device, - resource_allocator->Info().id, resource_allocator->Info().mem_type)), arena_type_(ArenaType::BaseArena), device_allocator_(std::move(resource_allocator)), diff --git a/onnxruntime/core/framework/endian_utils.cc b/onnxruntime/core/framework/endian_utils.cc index 8b61aad769ae9..236dda6b2e9e4 100644 --- a/onnxruntime/core/framework/endian_utils.cc +++ b/onnxruntime/core/framework/endian_utils.cc @@ -48,6 +48,16 @@ void SwapByteOrderCopy(size_t element_size_in_bytes, } } +void SwapByteOrderInplace(size_t element_size_in_bytes, gsl::span bytes) { + ORT_ENFORCE(element_size_in_bytes > 0, "Expecting a positive element size"); + ORT_ENFORCE(bytes.size_bytes() % element_size_in_bytes == 0, "Expecting a match"); + if (element_size_in_bytes > 1) { + for (size_t offset = 0, lim = bytes.size_bytes(); offset < lim; offset += element_size_in_bytes) { + std::reverse(bytes.begin() + offset, bytes.begin() + offset + element_size_in_bytes); + } + } +} + namespace detail { Status CopyLittleEndian(size_t element_size_in_bytes, diff --git a/onnxruntime/core/framework/endian_utils.h b/onnxruntime/core/framework/endian_utils.h index 6f084d058d007..c0792302a7141 100644 --- a/onnxruntime/core/framework/endian_utils.h +++ b/onnxruntime/core/framework/endian_utils.h @@ -31,6 +31,21 @@ void SwapByteOrderCopy(size_t element_size_in_bytes, gsl::span source_bytes, gsl::span destination_bytes); +/** + * Swaps the byte order of the elements in the given byte span in place. + * + * This is a low-level function - please be sure to pass in valid arguments. + * In particular: + * - bytes should have a size that is a multiple of element_size_in_bytes. + * - element_size_in_bytes should be greater than zero. + * - bytes should not overlap with itself. + * + * @param element_size_in_bytes The size of an individual element, in bytes. + * @param source_bytes The source byte span. + */ +void SwapByteOrderInplace(size_t element_size_in_bytes, + gsl::span bytes); + namespace detail { /** diff --git a/onnxruntime/core/framework/error_code.cc b/onnxruntime/core/framework/error_code.cc index ce58808db7a60..570252999f0e5 100644 --- a/onnxruntime/core/framework/error_code.cc +++ b/onnxruntime/core/framework/error_code.cc @@ -1,12 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/framework/error_code_helper.h" + +#include +#include + #include "core/session/onnxruntime_c_api.h" #include "core/session/ort_apis.h" #include "core/common/status.h" #include "core/common/safeint.h" -#include "core/framework/error_code_helper.h" -#include + using onnxruntime::common::Status; struct OrtStatus { @@ -26,6 +30,17 @@ inline OrtStatus* NewStatus(size_t clen) { if (buf == nullptr) return nullptr; // OOM. What we can do here? abort()? return new (buf) OrtStatus; } + +inline void DeleteStatus(OrtStatus* ort_status) { +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +#pragma warning(disable : 26409) +#endif + delete[] reinterpret_cast(ort_status); +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif +} } // namespace // Even we say it may not return NULL, indeed it may. @@ -43,6 +58,21 @@ _Check_return_ _Ret_notnull_ OrtStatus* ORT_API_CALL OrtApis::CreateStatus(OrtEr } namespace onnxruntime { + +namespace { + +struct OrtStatusDeleter { + void operator()(OrtStatus* p) const noexcept { + if (p != nullptr) { + DeleteStatus(p); + } + } +}; + +using UniqueOrtStatus = std::unique_ptr; + +} // namespace + _Ret_notnull_ OrtStatus* ToOrtStatus(const Status& st) { if (st.IsOK()) return nullptr; @@ -56,17 +86,21 @@ _Ret_notnull_ OrtStatus* ToOrtStatus(const Status& st) { return p; } -Status ToStatus(const OrtStatus* ort_status, common::StatusCategory category) { +Status ToStatusAndRelease(OrtStatus* ort_status, common::StatusCategory category) { if (ort_status == nullptr) { return Status::OK(); } + auto unique_ort_status = UniqueOrtStatus{ort_status}; return Status(category, static_cast(ort_status->code), &ort_status->msg[0]); } + } // namespace onnxruntime + #ifdef _MSC_VER #pragma warning(pop) #endif + ORT_API(OrtErrorCode, OrtApis::GetErrorCode, _In_ const OrtStatus* status) { return status->code; } @@ -74,7 +108,7 @@ ORT_API(OrtErrorCode, OrtApis::GetErrorCode, _In_ const OrtStatus* status) { ORT_API(const char*, OrtApis::GetErrorMessage, _In_ const OrtStatus* status) { return status->msg; } -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(disable : 26409) -#endif -ORT_API(void, OrtApis::ReleaseStatus, _Frees_ptr_opt_ OrtStatus* value) { delete[] reinterpret_cast(value); } + +ORT_API(void, OrtApis::ReleaseStatus, _Frees_ptr_opt_ OrtStatus* value) { + DeleteStatus(value); +} diff --git a/onnxruntime/core/framework/error_code_helper.h b/onnxruntime/core/framework/error_code_helper.h index b42c6a9ba3e10..cb0a56756d8aa 100644 --- a/onnxruntime/core/framework/error_code_helper.h +++ b/onnxruntime/core/framework/error_code_helper.h @@ -8,11 +8,13 @@ #include "core/session/onnxruntime_c_api.h" namespace onnxruntime { + // Convert onnxruntime::common::Status to OrtStatus*. _Ret_notnull_ OrtStatus* ToOrtStatus(const onnxruntime::common::Status& st); -// Convert OrtStatus* to onnxruntime::common::Status. -Status ToStatus(const OrtStatus* ort_status, common::StatusCategory category = common::StatusCategory::ONNXRUNTIME); +// Convert OrtStatus* to onnxruntime::common::Status and release the OrtStatus*. +Status ToStatusAndRelease(OrtStatus* ort_status, + common::StatusCategory category = common::StatusCategory::ONNXRUNTIME); }; // namespace onnxruntime #ifndef ORT_NO_EXCEPTIONS diff --git a/onnxruntime/core/framework/execution_provider.cc b/onnxruntime/core/framework/execution_provider.cc index f44e1280e8041..df85daa006a43 100644 --- a/onnxruntime/core/framework/execution_provider.cc +++ b/onnxruntime/core/framework/execution_provider.cc @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "core/framework/execution_provider.h" -#include "core/framework/execution_providers.h" #include "core/graph/graph_viewer.h" #include "core/framework/compute_capability.h" @@ -10,8 +9,6 @@ #include "core/framework/murmurhash3.h" #include "core/framework/op_kernel.h" -#include - namespace onnxruntime { std::vector> @@ -40,105 +37,4 @@ common::Status IExecutionProvider::Compile(const std::vector& } #endif - -ExecutionProviders::ExecutionProviders() { -#ifdef _WIN32 - // Register callback for ETW capture state (rundown) - etw_callback_key_ = "ExecutionProviders_rundown_"; - etw_callback_key_.append(std::to_string(reinterpret_cast(this))); - WindowsTelemetry::RegisterInternalCallback( - etw_callback_key_, - [this](LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { this->EtwProvidersCallback(SourceId, IsEnabled, Level, - MatchAnyKeyword, MatchAllKeyword, - FilterData, CallbackContext); }); -#endif -} - -ExecutionProviders::~ExecutionProviders() { -#ifdef _WIN32 - WindowsTelemetry::UnregisterInternalCallback(etw_callback_key_); -#endif -} - -common::Status ExecutionProviders::Add(const std::string& provider_id, - const std::shared_ptr& p_exec_provider) { - // make sure there are no issues before we change any internal data structures - if (provider_idx_map_.find(provider_id) != provider_idx_map_.end()) { - auto status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Provider ", provider_id, " has already been registered."); - LOGS_DEFAULT(ERROR) << status.ErrorMessage(); - return status; - } - - // index that provider will have after insertion - auto new_provider_idx = exec_providers_.size(); - - ORT_IGNORE_RETURN_VALUE(provider_idx_map_.insert({provider_id, new_provider_idx})); - - // update execution provider options - auto providerOptions = p_exec_provider->GetProviderOptions(); - exec_provider_options_[provider_id] = providerOptions; - -#ifdef _WIN32 - LogProviderOptions(provider_id, providerOptions, false); -#endif - - exec_provider_ids_.push_back(provider_id); - exec_providers_.push_back(p_exec_provider); - return Status::OK(); -} - -#ifdef _WIN32 -void ExecutionProviders::EtwProvidersCallback(LPCGUID /* SourceId */, - ULONG IsEnabled, - UCHAR /* Level */, - ULONGLONG MatchAnyKeyword, - ULONGLONG /* MatchAllKeyword */, - PEVENT_FILTER_DESCRIPTOR /* FilterData */, - PVOID /* CallbackContext */) { - // Check if this callback is for capturing state - if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && - ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { - for (size_t i = 0; i < exec_providers_.size(); ++i) { - const auto& provider_id = exec_provider_ids_[i]; - - auto it = exec_provider_options_.find(provider_id); - if (it != exec_provider_options_.end()) { - const auto& options = it->second; - - LogProviderOptions(provider_id, options, true); - } - } - } -} - -void ExecutionProviders::LogProviderOptions(const std::string& provider_id, - const ProviderOptions& providerOptions, - bool captureState) { -#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT - for (const auto& config_pair : providerOptions) { - TraceLoggingWrite( - telemetry_provider_handle, - "ProviderOptions", - TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - TraceLoggingString(provider_id.c_str(), "ProviderId"), - TraceLoggingString(config_pair.first.c_str(), "Key"), - TraceLoggingString(config_pair.second.c_str(), "Value"), - TraceLoggingBool(captureState, "isCaptureState")); - } -#else - ORT_UNUSED_PARAMETER(provider_id); - ORT_UNUSED_PARAMETER(providerOptions); - ORT_UNUSED_PARAMETER(captureState); -#endif -} - -#endif - } // namespace onnxruntime diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index c226db8c37c51..29cf79ec385d8 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -26,24 +26,91 @@ Class for managing lookup of the execution providers in a session. */ class ExecutionProviders { public: - ExecutionProviders(); + ExecutionProviders() { +#ifdef _WIN32 + // Register callback for ETW capture state (rundown) + etw_callback_ = onnxruntime::WindowsTelemetry::EtwInternalCallback( + [this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + for (size_t i = 0; i < exec_providers_.size(); ++i) { + const auto& provider_id = exec_provider_ids_[i]; + + auto it = exec_provider_options_.find(provider_id); + if (it != exec_provider_options_.end()) { + const auto& options = it->second; + + LogProviderOptions(provider_id, options, true); + } + } + } + }); + WindowsTelemetry::RegisterInternalCallback(etw_callback_); +#endif + } + + ~ExecutionProviders() { +#ifdef _WIN32 + WindowsTelemetry ::UnregisterInternalCallback(etw_callback_); +#endif + } - ~ExecutionProviders(); + common::Status + Add(const std::string& provider_id, const std::shared_ptr& p_exec_provider) { + // make sure there are no issues before we change any internal data structures + if (provider_idx_map_.find(provider_id) != provider_idx_map_.end()) { + auto status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Provider ", provider_id, " has already been registered."); + LOGS_DEFAULT(ERROR) << status.ErrorMessage(); + return status; + } - common::Status Add(const std::string& provider_id, const std::shared_ptr& p_exec_provider); + // index that provider will have after insertion + auto new_provider_idx = exec_providers_.size(); + + ORT_IGNORE_RETURN_VALUE(provider_idx_map_.insert({provider_id, new_provider_idx})); + + // update execution provider options + auto providerOptions = p_exec_provider->GetProviderOptions(); + exec_provider_options_[provider_id] = providerOptions; #ifdef _WIN32 + LogProviderOptions(provider_id, providerOptions, false); +#endif - void EtwProvidersCallback(LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext); + exec_provider_ids_.push_back(provider_id); + exec_providers_.push_back(p_exec_provider); + return Status::OK(); + } - void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions, - bool captureState); +#ifdef _WIN32 + void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions, bool captureState) { + for (const auto& config_pair : providerOptions) { + TraceLoggingWrite( + telemetry_provider_handle, + "ProviderOptions", + TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingString(provider_id.c_str(), "ProviderId"), + TraceLoggingString(config_pair.first.c_str(), "Key"), + TraceLoggingString(config_pair.second.c_str(), "Value"), + TraceLoggingBool(captureState, "isCaptureState")); + } + } #endif const IExecutionProvider* Get(const onnxruntime::Node& node) const { @@ -102,7 +169,7 @@ class ExecutionProviders { bool cpu_execution_provider_was_implicitly_added_ = false; #ifdef _WIN32 - std::string etw_callback_key_; + WindowsTelemetry::EtwInternalCallback etw_callback_; #endif }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 9a2991ab02730..2081b8c3c9344 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -17,6 +17,7 @@ #include "core/framework/resource_accountant.h" #include "core/graph/function.h" #include "core/graph/function_utils.h" +#include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" #include "core/graph/model_saving_options.h" @@ -902,9 +903,9 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers } // handle initializers - for (const auto& initialized_tensor : graph.GetAllInitializedTensors()) { - if (ep_graph.GetNodeArg(initialized_tensor.first) != nullptr) { - ep_graph.AddInitializedTensor(*initialized_tensor.second); + for (const auto& [name, _] : graph.GetAllInitializedTensors()) { + if (ep_graph.GetNodeArg(name) != nullptr) { + graph_utils::MakeInitializerCopyIfNotExist(graph, ep_graph, name); } } diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h index 54bb946e0d36b..c44c9fdaa4191 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.h +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h @@ -14,8 +14,9 @@ class DataTypeImpl; } // namespace onnxruntime namespace ONNX_NAMESPACE { +class TensorProto; class TypeProto; -} +} // namespace ONNX_NAMESPACE // These types are only present in the winml adapter c api, so they are forward declared. struct OrtMapTypeInfo; diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 6362a3169f3a3..7d0026cc35558 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -300,17 +300,13 @@ const std::vector& SessionState::GetPerValueAllocPlan() const return p_seq_exec_plan_->allocation_plan; } -Status SessionState::AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d, +Status SessionState::AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, bool constant, bool sparse) { auto p = initialized_tensors_.insert({ort_value_index, ort_value}); if (!p.second) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "duplicated ort_value index:", ort_value_index, ". Do you have duplicated calls to SessionState::AddInitializedTensor function?"); - if (d != nullptr && d->f != nullptr) { - deleter_for_initialized_tensors_.insert_or_assign(ort_value_index, *d); - } - if (constant) { constant_initialized_tensors_.insert({ort_value_index, ort_value}); } @@ -1620,16 +1616,16 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string Status { - ORT_RETURN_IF_ERROR(AddInitializedTensor(idx, value, &d, constant, sparse)); + ORT_RETURN_IF_ERROR(AddInitializedTensor(idx, value, constant, sparse)); if (remove_initializers) { graph_.RemoveInitializedTensor(name); } return Status::OK(); }, logger_, data_transfer_mgr_, external_data_loader_mgr_, *p_seq_exec_plan_, session_options, - memory_profile_func, name_to_buffered_tensor_, graph_.GetPrepacked())); + memory_profile_func, graph_.GetPrepacked())); #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) // Record Weight allocation info on device diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index 964c059e529f9..71b88cb692f6f 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -18,7 +18,6 @@ #include "core/common/logging/logging.h" #include "core/common/profiler.h" #include "core/framework/allocation_planner.h" -#include "core/framework/callback.h" #include "core/framework/data_transfer_manager.h" #include "core/framework/external_data_loader_manager.h" #include "core/framework/execution_providers.h" @@ -102,9 +101,6 @@ class SessionState { AllocatorMap* parent_allocators = nullptr); ~SessionState() { - for (auto& kvp : deleter_for_initialized_tensors_) { - kvp.second.f(kvp.second.param); - } } // Graph viewer. CreateGraphInfo must have been called previously. @@ -143,12 +139,11 @@ class SessionState { /** * Adds an initialized tensor (weight) so that it can be used by the * execution frame to setup the appropriate OrtValue vectors. - * This function will take a shallow copy of d if d is not NULL. * If 'constant' is true the tensor value cannot be overridden by an input at runtime. * If 'sparse' is true the tensor value represents a densified weight that was initially stored in the model * as sparse tensor. */ - Status AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, const OrtCallback* d, bool constant, bool sparse); + Status AddInitializedTensor(int ort_value_index, const OrtValue& ort_value, bool constant, bool sparse); /** * Gets the map of ort_value_index to initialized tensors (weights) so that it can be used by the @@ -310,10 +305,6 @@ class SessionState { const InlinedHashSet* GetToBeExecutedRange(gsl::span fetch_mlvalue_idxs) const; #endif - std::unordered_map>* GetMutableBufferedTensors() { - return &name_to_buffered_tensor_; - } - Status FinalizeSessionState(const std::basic_string& graph_loc, const KernelRegistryManager& kernel_registry_manager, bool remove_initializers = true, @@ -461,19 +452,12 @@ class SessionState { bool operator()(const OrtMemoryInfo& lhs, const OrtMemoryInfo& rhs) const { // if (lhs.alloc_type != rhs.alloc_type) // return lhs.alloc_type < rhs.alloc_type; - if (lhs.mem_type != rhs.mem_type) + if (lhs.mem_type != rhs.mem_type) { return lhs.mem_type < rhs.mem_type; - - if (lhs.id != rhs.id) - return lhs.id < rhs.id; + } if (lhs.device != rhs.device) { - // id should always == device.id so ignore that - if (lhs.device.Type() != rhs.device.Type()) - return lhs.device.Type() < rhs.device.Type(); - - // this is the allocator mem type and not the kernel mem type that OrtMemoryInfo.mem_type represents - return lhs.device.MemType() < rhs.device.MemType(); + return lhs.device < rhs.device; } return false; @@ -509,7 +493,6 @@ class SessionState { // This data structure is for uninitializing string tensors and // munmap memory region and close file descriptor - InlinedHashMap deleter_for_initialized_tensors_; InlinedVector weights_buffers_; std::optional p_seq_exec_plan_; @@ -607,12 +590,6 @@ class SessionState { // flag to indicate whether current session using any EP that create device stream dynamically. bool has_device_stream_enabled_ep_ = false; #endif - - // Holds the tensors which provide memory buffer for TensorProtos - // Use case: in optimizer, transform a TensorProto to a new TensorProto whose the memory buffer is - // allocated by CPU instead by protobuf's arena. Arena style memory allocators do not fully release - // a instance's memory which may result large memory consumption, which is a tradeoff for speed. - std::unordered_map> name_to_buffered_tensor_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index cacd772b61d76..8f0713fcd7cb1 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -47,181 +47,138 @@ static common::Status AllocateBufferUsingDeviceAllocatorFromShapeAndType(const T return Status::OK(); } -// deleter for external data tensors managed by an OrtValue; manages the release of -// the tensor's data buffer (which points to the external data) and the tensor itself -struct ExtDataValueDeleter { - OrtCallback ext_delete_cb; - Tensor* p_tensor; - void operator()(void*) noexcept { - if (ext_delete_cb.f) { - ext_delete_cb.f(ext_delete_cb.param); - } - - delete p_tensor; - } -}; - -// given a tensor proto with external data return an OrtValue with a tensor for -// that data; the pointers for the tensor data and the tensor itself are owned -// by the OrtValue's deleter. -// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and -// buffered_tensor is not null, buffered_tensor holds the real buffer pointed -// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter -// should release the buffer when tensor_proto is released. -static common::Status ExtDataTensorProtoToTensor(const Env& env, - const std::basic_string& proto_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, - Tensor& tensor, OrtCallback& ext_data_deleter, - PrepackedWeightsForGraph& prepacked_for_graph, - Tensor* buffered_tensor = nullptr) { - ORT_ENFORCE(utils::HasExternalData(tensor_proto)); - - void* ext_data_buf = nullptr; - SafeInt ext_data_len = 0; - ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path.c_str(), tensor_proto, - ext_data_buf, ext_data_len, ext_data_deleter, - buffered_tensor, &prepacked_for_graph)); - if constexpr (endian::native != endian::little) { - if (!proto_path.empty() && (proto_path.compare(onnxruntime::utils::kTensorProtoMemoryAddressTag) != 0)) { - utils::ConvertRawDataInTensorProto(const_cast(&tensor_proto), ext_data_buf, ext_data_len); - } - } - - // NB: creating a do-nothing allocator per tensor is wasteful; can perhaps be - // avoided if the Tensor class implements the do-nothing behavior when given a - // nullptr for the allocator argument - const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); - TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); - tensor = Tensor(type, tensor_shape, ext_data_buf, OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); - - return common::Status::OK(); -} - -// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and -// buffered_tensor is not null, buffered_tensor holds the real buffer pointed -// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter -// should release the buffer when tensor_proto is released. +/** + * @brief Deserializes a TensorProto into an OrtValue. + * + * This function handles the complexities of deserializing a tensor, including + * managing memory allocation, handling external data, and transferring data + * between different devices (e.g., CPU to GPU). It can use a pre-allocated + * memory buffer or an allocator to manage the tensor's memory. + * + * @param env The environment object, providing access to logging and other services. + * @param proto_path The file path of the ONNX model, used for resolving external data paths. + * @param tensor_proto The TensorProto message to deserialize. + * @param memory_buffer Optional. A raw memory buffer that is pre-allocated for the tensor. + * If provided, `alloc` must be null. + * @param alloc Optional. An allocator to use for allocating the tensor's memory. + * If provided, `memory_buffer` must be null. + * @param default_cpu_alloc The default CPU allocator, used for intermediate buffers if needed + * (e.g., when copying from CPU to another device). + * @param[out] ort_value The OrtValue to be populated with the deserialized tensor data. + * @param data_transfer_mgr The manager responsible for copying tensor data between different memory locations/devices. + * @param external_data_loader_mgr The manager for handling custom external data loaders. + * @param prepacked_for_graph Reference to an object managing prepacked weights for the graph. + * @param use_device_allocator_for_initializers A flag indicating whether to use the device-specific allocator + * directly for initializers, potentially bypassing arenas. + * @return common::Status indicating success or failure of the deserialization process. + * Returns an error status if both `memory_buffer` and `alloc` are provided or if both are null (unless external data on CPU allows mmap), + * if string tensors are attempted to be copied to non-CPU devices, or if any underlying + * data loading, allocation, or copying operation fails. + */ static common::Status DeserializeTensorProto(const Env& env, const std::basic_string& proto_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer* m, + const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer* memory_buffer, const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc, OrtValue& ort_value, const DataTransferManager& data_transfer_mgr, const ExternalDataLoaderManager& external_data_loader_mgr, PrepackedWeightsForGraph& prepacked_for_graph, - bool use_device_allocator_for_initializers = false, - Tensor* buffered_tensor = nullptr) { - if (bool(alloc) == (m != nullptr)) { + bool use_device_allocator_for_initializers = false) { + if (bool(alloc) == (memory_buffer != nullptr)) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DeserializeTensorProto() takes either pre-allocated buffer or an allocator!"); } - ORT_RETURN_IF(buffered_tensor && !utils::HasExternalData(tensor_proto), - "With buffered tensor, tensor proto must use external location and point to buffered tensor"); - - // Get shape and type of the tensor, and allocate the empty tensor TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); - std::unique_ptr p_tensor; + Tensor tensor; - auto& memory_info = (alloc != nullptr) ? alloc->Info() : m->GetAllocInfo(); - auto device_type = memory_info.device.Type(); + // Get shape and type of the tensor, and allocate the empty tensor + static const auto default_cpu_device = OrtDevice(); + const auto& memory_info = (alloc != nullptr) ? alloc->Info() : memory_buffer->GetAllocInfo(); + const auto device = memory_info.device; if (utils::HasExternalData(tensor_proto)) { auto external_data_loader = external_data_loader_mgr.GetExternalDataLoader(memory_info); if (external_data_loader) { - // if custom external data loader is used, always allocate memory on device - p_tensor - ORT_RETURN_IF_ERROR(AllocateTensor(m, p_tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); - + // if custom external data loader is used, always allocate memory on device + ORT_RETURN_IF_ERROR(AllocateTensor(memory_buffer, tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); ORT_RETURN_IF_ERROR(utils::LoadExtDataToTensorFromTensorProto(env, proto_path, tensor_proto, - *external_data_loader, *p_tensor)); + *external_data_loader, tensor)); - Tensor::InitOrtValue(std::move(*p_tensor), ort_value); + Tensor::InitOrtValue(std::move(tensor), ort_value); return common::Status::OK(); - } else if (device_type == OrtDevice::CPU) { + } else if (device == default_cpu_device) { // for external initializer on CPU we will use mmap for large initializers so don't need to allocate memory in advance - p_tensor = std::make_unique(); // NB: The file containing external data for the tensor is mmap'd. If the tensor will be used on CPU we can - // utilize the mmap'd buffer directly by calling ExtDataTensorProtoToTensor. If we called - // TensorProtoToTensor it would copy the data, causing unnecessary overhead - OrtCallback ext_data_deleter; - ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_tensor, - ext_data_deleter, prepacked_for_graph, - buffered_tensor)); - - ExtDataValueDeleter deleter{ext_data_deleter, p_tensor.get()}; - MLDataType ml_tensor_type = DataTypeImpl::GetType(); - ort_value.Init(p_tensor.release(), ml_tensor_type, deleter); + // utilize the mmap'd buffer directly. + ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path, tensor_proto, + ort_value, + &prepacked_for_graph)); return common::Status::OK(); - } else { // non-cpu tensor - if (tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_STRING) { + } else { // non-cpu tensor or tensor in a cpu accessible memory + if (utils::HasString(tensor_proto)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "string tensor is not supported for copying between allocators"); } // deserialize to CPU first for non-CPU allocator, then copy to device // for external initializer load on non-CPU device: - // 1. allocate memory on device - p_tensor - // 2. load initializer into CPU memory - p_deserialize_tensor, + // 1. allocate memory on device - tensor + // 2. load initializer into CPU memory - deserialized_value, // we will use mmap so no need to allocate memory on CPU in advance - // 3. copy tensor from CPU to device - p_deserialize_tensor -> p_tensor - ORT_RETURN_IF_ERROR(AllocateTensor(m, p_tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); + // 3. copy tensor from CPU to device - deserialized_value -> tensor -> ort_value + ORT_RETURN_IF_ERROR(AllocateTensor(memory_buffer, tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); - std::unique_ptr p_deserialize_tensor = std::make_unique(type, TensorShape(), default_cpu_alloc); + OrtValue deserialized_value; + ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(env, proto_path, tensor_proto, + deserialized_value, + &prepacked_for_graph)); - OrtCallback ext_data_deleter; - std::optional scoped_ort_callback_invoker; - ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_deserialize_tensor, - ext_data_deleter, prepacked_for_graph, - buffered_tensor)); - scoped_ort_callback_invoker.emplace(ext_data_deleter); - // TODO!! Need a temp buffer allocator for non-escape buffers that maybe too big for stack allocation. - - return CopyTensorFromCPUToDevice(data_transfer_mgr, p_deserialize_tensor, p_tensor, ort_value); + return CopyTensorFromCPUToDevice(data_transfer_mgr, deserialized_value.Get(), + std::move(tensor), ort_value); } } else { - // for internal initializer, always allocate memory on device - p_tensor - ORT_RETURN_IF_ERROR(AllocateTensor(m, p_tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); + // for internal initializer, always allocate memory on device - tensor + ORT_RETURN_IF_ERROR(AllocateTensor(memory_buffer, tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); - if (device_type == OrtDevice::CPU) { + if (device == default_cpu_device) { // deserialize directly to CPU tensor - ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, *p_tensor)); - auto ml_tensor = DataTypeImpl::GetType(); - ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); + ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, tensor)); + Tensor::InitOrtValue(std::move(tensor), ort_value); return common::Status::OK(); } else { // non-cpu tensor - if (tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_STRING) { + if (utils::HasString(tensor_proto)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "string tensor is not supported for copying between allocators"); } // deserialize to CPU first for non-CPU allocator, then copy // for internal initializer - // 1. allocate memory on CPU - p_deserialize_tensor - // 2. deserialize tensor_probo into a preallocated tensor (p_deserialize_tensor) - // 3. copy tensor from CPU to device - p_deserialize_tensor -> p_tensor - std::unique_ptr p_deserialize_tensor; - ORT_RETURN_IF_ERROR(AllocateTensorOnDeviceOrMemory(use_device_allocator_for_initializers, tensor_shape, type, default_cpu_alloc, p_deserialize_tensor)); - - ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, *p_deserialize_tensor)); - // TODO!! Need a temp buffer allocator for non-escape buffers that maybe too big for stack allocation. - - return CopyTensorFromCPUToDevice(data_transfer_mgr, p_deserialize_tensor, p_tensor, ort_value); + // 1. allocate memory on CPU - deserialized_tensor + // 2. deserialize tensor_proto into a preallocated tensor (deserialized_tensor) + // 3. copy tensor from CPU to device - deserialized_tensor -> tensor (allocated above) -> ort_value + Tensor deserialized_tensor; + ORT_RETURN_IF_ERROR(AllocateTensorOnDeviceOrMemory(use_device_allocator_for_initializers, tensor_shape, type, + default_cpu_alloc, deserialized_tensor)); + + ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, deserialized_tensor)); + return CopyTensorFromCPUToDevice(data_transfer_mgr, deserialized_tensor, std::move(tensor), ort_value); } } } -common::Status AllocateTensor(const onnxruntime::MemBuffer* m, - std::unique_ptr& p_tensor, +common::Status AllocateTensor(const onnxruntime::MemBuffer* memory_buffer, + Tensor& tensor, const onnxruntime::DataTypeImpl* const& type, onnxruntime::TensorShape& tensor_shape, bool use_device_allocator_for_initializers, const onnxruntime::AllocatorPtr& alloc) { - if (m != nullptr) { - p_tensor = std::make_unique(type, tensor_shape, m->GetBuffer(), m->GetAllocInfo()); - if (m->GetLen() < p_tensor->SizeInBytes()) { + if (memory_buffer != nullptr) { + tensor = Tensor{type, tensor_shape, memory_buffer->GetBuffer(), memory_buffer->GetAllocInfo()}; + if (memory_buffer->GetLen() < tensor.SizeInBytes()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Internal error. The preallocated buffer is too small. Requires ", - p_tensor->SizeInBytes(), ", Got ", m->GetLen()); + tensor.SizeInBytes(), ", Got ", memory_buffer->GetLen()); } } else { - return AllocateTensorOnDeviceOrMemory(use_device_allocator_for_initializers, tensor_shape, type, alloc, p_tensor); + return AllocateTensorOnDeviceOrMemory(use_device_allocator_for_initializers, tensor_shape, type, alloc, tensor); } return common::Status::OK(); } @@ -231,37 +188,36 @@ common::Status AllocateTensorOnDeviceOrMemory( onnxruntime::TensorShape& tensor_shape, const onnxruntime::DataTypeImpl* const& type, const onnxruntime::AllocatorPtr& alloc, - std::unique_ptr& p_tensor) { + Tensor& tensor) { if (use_device_allocator_for_initializers) { void* tensor_buffer = nullptr; ORT_RETURN_IF_ERROR(AllocateBufferUsingDeviceAllocatorFromShapeAndType(tensor_shape, type, alloc, tensor_buffer)); - p_tensor = std::make_unique(type, tensor_shape, tensor_buffer, alloc); + tensor = Tensor{type, tensor_shape, tensor_buffer, alloc}; } else { // If the provided allocator is an arena-based allocator, the call to Alloc() will tap into memory from the arena // (may expand it if there isn't a chunk that can be allotted to the memory request). // If the provided allocator is non-arena based, the device specific Alloc() call will be used to allocate the necessary memory. - p_tensor = std::make_unique(type, tensor_shape, alloc); + tensor = Tensor{type, tensor_shape, alloc}; } return common::Status::OK(); } common::Status CopyTensorFromCPUToDevice( const onnxruntime::DataTransferManager& data_transfer_mgr, - std::unique_ptr& p_deserialize_tensor, - std::unique_ptr& p_tensor, + const Tensor& deserialized_tensor, + Tensor&& tensor, OrtValue& ort_value) { - Status copy_status = data_transfer_mgr.CopyTensor(*p_deserialize_tensor, *p_tensor); + Status copy_status = data_transfer_mgr.CopyTensor(deserialized_tensor, tensor); if (!copy_status.IsOK()) { if (copy_status.ErrorMessage().empty()) { // The windows execution provider does not return any error message today for CopyTensor since it is // not implemented yet. That's the reason we're adding our own error message so that we can debug better. return Status(copy_status.Category(), copy_status.Code(), - "Failed to copy tensor to " + p_tensor->Location().ToString()); + "Failed to copy tensor to " + tensor.Location().ToString()); } return copy_status; } else { - auto ml_tensor = DataTypeImpl::GetType(); - ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); + Tensor::InitOrtValue(std::move(tensor), ort_value); return common::Status::OK(); } } @@ -279,7 +235,6 @@ common::Status SaveInitializedTensors( const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, - std::unordered_map>& buffered_tensors, PrepackedWeightsForGraph& prepacked_for_graph) { LOGS(logger, INFO) << "Saving initialized tensors."; ORT_ENFORCE(ort_value_name_idx_map.MaxIdx() > -1, "OrtValue indexes should have been populated."); @@ -298,13 +253,13 @@ common::Status SaveInitializedTensors( if (!ort_value_name_idx_map.GetIdx(name, ort_value_index).IsOK()) { retval = false; } else { - const auto& planned_mem_info = exec_plan.GetLocation(ort_value_index); + const auto& planned_mem_device = exec_plan.GetLocation(ort_value_index); const auto& user_mem_info = it->second->Get().Location(); - retval = user_mem_info.device == planned_mem_info; + retval = user_mem_info.device == planned_mem_device; if (!retval) { LOGS(logger, WARNING) << "Cannot use user supplied initializer with name: (" << name << ") because the ORT planned memory location device " - << planned_mem_info.ToString() + << planned_mem_device.ToString() << " ) is different from what is supplied (" << user_mem_info.ToString() << ")"; } } @@ -319,7 +274,7 @@ common::Status SaveInitializedTensors( InlinedHashSet user_supplied_initializer_ids; // set containing the ort value ids of all user supplied initializers id_to_initialized_tensor.reserve(initialized_tensor_set.size()); - user_supplied_initializer_ids.reserve(initialized_tensor_set.size()); + user_supplied_initializer_ids.reserve(session_options.initializers_to_share_map.size()); for (const auto& entry : initialized_tensor_set) { int ort_value_index; @@ -330,6 +285,8 @@ common::Status SaveInitializedTensors( id_to_initialized_tensor[ort_value_index] = entry.second; } + static const auto default_cpu_device = OrtDevice(); + // tensors requiring a specific allocation order are traced first, to ensure they are allocated in order // NB1: vector with init allocation order may contain a subset of all tensors (or none at all) // NB2: only skip tracing and planning memory when data is external (i.e mmap) and on CPU. @@ -339,10 +296,21 @@ common::Status SaveInitializedTensors( const auto entry = initialized_tensors_to_allocate.find(ort_value_index); ORT_ENFORCE(entry != initialized_tensors_to_allocate.end(), "OrtValue index: ", ort_value_index, " from initializer_allocation_order not found among initialized tensors"); - if (!(utils::HasExternalData(*entry->second) && exec_plan.GetLocation(ort_value_index).Type() == OrtDevice::CPU)) { - // can not trace string tensor - ORT_ENFORCE(entry->second->data_type() != ONNX_NAMESPACE::TensorProto_DataType_STRING, "Can not trace string tensor"); - ORT_RETURN_IF_ERROR(planner.Trace(entry->first, entry->second)); + const auto* tensor_proto = entry->second; + + // We trace to allocate a single buffer using the planner. This reduces fragmentation. + // We do not trace the following values because it would add to the memory consumption. + // - Values that are on OrtDevice() (default CPU). + // - Values that are external and mapped from disk. We let the OS manage the memory. + // - we do not trace values that are in memory because they may be sitting on top of the user allocated + // memory. + const bool trace_allocation = (exec_plan.GetLocation(ort_value_index) != default_cpu_device) || + !utils::HasExternalData(*tensor_proto); + + if (trace_allocation) { + // can not trace string tensor, and they exist only on CPU + ORT_ENFORCE(!utils::HasString(*tensor_proto), "Can not trace string tensor"); + ORT_RETURN_IF_ERROR(planner.Trace(ort_value_index, tensor_proto)); } initialized_tensors_to_allocate.erase(entry); } @@ -352,7 +320,7 @@ common::Status SaveInitializedTensors( if (user_supplied_initializer_ids.find(entry.first) != user_supplied_initializer_ids.end()) { continue; } - if (entry.second->data_type() == ONNX_NAMESPACE::TensorProto_DataType_STRING) { + if (utils::HasString(*entry.second)) { // do not trace string tensor continue; } @@ -374,8 +342,6 @@ common::Status SaveInitializedTensors( << i.second << " bytes for " << i.first.ToString() << std::endl; } - OrtCallback deleter{nullptr, nullptr}; - // 3. create weight tensors based on weights buffer for (const auto& entry : id_to_initialized_tensor) { // We check for cancellation for every initializer since mapping from disk can be costly @@ -397,39 +363,50 @@ common::Status SaveInitializedTensors( if (user_supplied_initializer_ids.find(entry.first) != user_supplied_initializer_ids.end()) { ort_value = *(session_options.initializers_to_share_map.at(name)); LOGS(logger, INFO) << "Using user supplied initializer with name (" << name << ")."; - - } else if (graph.GetOrtValueInitializer(name, ort_value)) { - // populated OrtValue from the Graph instance } else { const ONNX_NAMESPACE::TensorProto& tensor_proto = *(entry.second); - std::optional m; + std::optional memory_buffer; AllocatorPtr alloc; // TODO: if the tensor need be copied, does it have enough room? - ORT_RETURN_IF_ERROR(planner.GetPreallocatedBuffer(ort_value_index, name, m, alloc)); - bool use_device_allocator_for_initializers = - session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1"; - - Tensor* p_tensor = nullptr; - auto buffered_tensors_iter = buffered_tensors.find(name); - if (buffered_tensors_iter != buffered_tensors.end()) { - p_tensor = buffered_tensors_iter->second.get(); - } - - Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc, - default_cpu_alloc, ort_value, data_transfer_mgr, external_data_loader_mgr, - prepacked_for_graph, - use_device_allocator_for_initializers, p_tensor); - if (!st.IsOK()) { - std::ostringstream oss; - oss << "Deserialize tensor " << name << " failed." << st.ErrorMessage(); - return Status(st.Category(), st.Code(), oss.str()); - } - - if (p_tensor != nullptr) { - // p_tensor was wrapped in a deleter by DeserializeTensorProto so we can simply release it here. - ORT_IGNORE_RETURN_VALUE(buffered_tensors_iter->second.release()); - buffered_tensors.erase(buffered_tensors_iter); + ORT_RETURN_IF_ERROR(planner.GetPreallocatedBuffer(ort_value_index, name, memory_buffer, alloc)); + const bool use_device_allocator_for_initializers = + session_options.config_options.GetConfigOrDefault( + kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1"; + + // Check if we already have an OrtValue for this initializer on CPU + if (OrtValue ort_value_from_graph; + graph.GetOrtValueInitializer(name, ort_value_from_graph)) { + const auto& memory_info = (alloc != nullptr) ? alloc->Info() : memory_buffer->GetAllocInfo(); + if (memory_info.device == default_cpu_device) { + // This is on CPU use directly from the graph + ort_value = std::move(ort_value_from_graph); + } else { + TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); + const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum( + tensor_proto.data_type()) + ->GetElementType(); + Tensor tensor; + ORT_RETURN_IF_ERROR(AllocateTensor((memory_buffer) ? &*memory_buffer : nullptr, tensor, type, + tensor_shape, use_device_allocator_for_initializers, + alloc)); + ORT_RETURN_IF_ERROR(CopyTensorFromCPUToDevice(data_transfer_mgr, + ort_value_from_graph.Get(), + std::move(tensor), ort_value)); + } + } else { + // We need to deserialize the tensor proto into an OrtValue + // using the preallocated buffer or allocator. + + Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (memory_buffer.has_value()) ? &*memory_buffer : nullptr, alloc, + default_cpu_alloc, ort_value, data_transfer_mgr, external_data_loader_mgr, + prepacked_for_graph, + use_device_allocator_for_initializers); + if (!st.IsOK()) { + std::ostringstream oss; + oss << "Deserialize tensor " << name << " failed." << st.ErrorMessage(); + return Status(st.Category(), st.Code(), oss.str()); + } } } @@ -442,9 +419,9 @@ common::Status SaveInitializedTensors( const bool constant = graph.IsConstantInitializer(name, /* check_outer_scope */ false); #if !defined(DISABLE_SPARSE_TENSORS) const bool sparse = graph.GetGraph().IsSparseInitializer(name); - ORT_RETURN_IF_ERROR(save_tensor_func(name, ort_value_index, ort_value, deleter, constant, sparse)); + ORT_RETURN_IF_ERROR(save_tensor_func(name, ort_value_index, ort_value, constant, sparse)); #else - ORT_RETURN_IF_ERROR(save_tensor_func(name, ort_value_index, ort_value, deleter, constant, false)); + ORT_RETURN_IF_ERROR(save_tensor_func(name, ort_value_index, ort_value, constant, false)); #endif } diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h index 17400c45e5f32..3428b38b389a8 100644 --- a/onnxruntime/core/framework/session_state_utils.h +++ b/onnxruntime/core/framework/session_state_utils.h @@ -36,7 +36,7 @@ class Logger; namespace session_state_utils { using SaveTensorFunction = std::function; + bool constant, bool sparse)>; using MemoryProfileFunction = std::function; common::Status SaveInitializedTensors( @@ -51,12 +51,11 @@ common::Status SaveInitializedTensors( const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func, - std::unordered_map>& buffered_tensors, PrepackedWeightsForGraph& prepacked_for_graph); common::Status AllocateTensor( - const onnxruntime::MemBuffer* m, - std::unique_ptr& p_tensor, + const onnxruntime::MemBuffer* memory_buffer, + Tensor& p_tensor, const onnxruntime::DataTypeImpl* const& type, onnxruntime::TensorShape& tensor_shape, bool use_device_allocator_for_initializers, @@ -67,12 +66,12 @@ common::Status AllocateTensorOnDeviceOrMemory( onnxruntime::TensorShape& tensor_shape, const onnxruntime::DataTypeImpl* const& type, const onnxruntime::AllocatorPtr& alloc, - std::unique_ptr& p_tensor); + Tensor& p_tensor); common::Status CopyTensorFromCPUToDevice( const onnxruntime::DataTransferManager& data_transfer_mgr, - std::unique_ptr& p_deserialize_tensor, - std::unique_ptr& p_tensor, + const Tensor& deserialized_tensor, + Tensor&& tensor, OrtValue& ort_value); common::Status SaveInputOutputNamesToNodeMapping(const GraphViewer& graph, diff --git a/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h b/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h index 92c264e57279c..ad88149c89b81 100644 --- a/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h +++ b/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h @@ -47,12 +47,13 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator { } else { buffer = alloc->Alloc(peak_size); } - weights_buffers_.push_back(BufferUniquePtr(buffer, BufferDeleter(alloc))); + + auto buffer_ptr = BufferUniquePtr(buffer, BufferDeleter(std::move(alloc))); auto kvp = buffers_.insert(std::make_pair(location, buffer)); if (!kvp.second) { - alloc->Free(buffer); return Status(common::ONNXRUNTIME, common::FAIL, "duplicated location"); } + weights_buffers_.push_back(std::move(buffer_ptr)); planned_memory_sizes_in_byte[location] += peak_size; } diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 94a2a6677358e..1607d950059c3 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -22,7 +22,6 @@ #include "core/framework/tensor.h" #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/allocator.h" -#include "core/framework/callback.h" #include "core/framework/data_types.h" #include "core/platform/path_lib.h" #include "core/framework/to_tensor_proto_element_type.h" @@ -172,13 +171,21 @@ DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2) Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& tensor_proto_dir, std::vector& unpacked_tensor) { - std::basic_string external_file_path; + PathString external_file_path; onnxruntime::FileOffsetType file_offset; SafeInt tensor_byte_size; ORT_RETURN_IF_ERROR( GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_file_path, file_offset, tensor_byte_size)); unpacked_tensor.resize(tensor_byte_size); + + if (external_file_path == kTensorProtoMemoryAddressTag) { + // The external data is in the same memory as the tensor proto. + // The offset is the address of the data. + std::memcpy(unpacked_tensor.data(), reinterpret_cast(file_offset), tensor_byte_size); + return Status::OK(); + } + ORT_RETURN_IF_ERROR(onnxruntime::Env::Default().ReadFileIntoBuffer( external_file_path.c_str(), file_offset, @@ -216,7 +223,7 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo tensor->SizeInBytes(), ", Got ", m->GetLen()); } } else { - tensor = std::make_unique(type, tensor_shape, alloc); + tensor = std::make_unique(type, tensor_shape, std::move(alloc)); } ORT_RETURN_IF_ERROR(TensorProtoToTensor(env, model_path, tensor_proto, *tensor)); @@ -230,16 +237,55 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo namespace utils { +bool HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto) { + if (HasExternalData(ten_proto)) { + // Retrieve the external data info + for (const auto& entry : ten_proto.external_data()) { + if (entry.key() == "location") { + PathString location = ToWideString(entry.value()); + return location == kTensorProtoMemoryAddressTag; + } + } + } + + return false; // No external data in memory +} + +Status TensorProtoWithExternalDataToTensorProto( + const ONNX_NAMESPACE::TensorProto& ten_proto, + const std::filesystem::path& model_path, + ONNX_NAMESPACE::TensorProto& new_tensor_proto) { + // Check if the input tensor has external data + ORT_RETURN_IF_NOT(HasExternalData(ten_proto), "Input tensor does not have external data."); + + // Copy the metadata from the source tensor to the new tensor + ONNX_NAMESPACE::TensorProto result; + result.set_name(ten_proto.name()); + result.set_data_type(ten_proto.data_type()); + result.mutable_dims()->CopyFrom(ten_proto.dims()); + + // Load the external data into memory + std::vector unpacked_data; + ORT_RETURN_IF_ERROR(ReadExternalDataForTensor(ten_proto, model_path, unpacked_data)); + + // Set the raw data in the new tensor + result.set_raw_data(unpacked_data.data(), unpacked_data.size()); + + new_tensor_proto = std::move(result); + + return Status::OK(); +} + Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& tensor_proto_dir, std::basic_string& external_file_path, onnxruntime::FileOffsetType& file_offset, SafeInt& tensor_byte_size, ExternalDataInfo::PrepackedInfos* prepacked_infos) { - ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto), + ORT_RETURN_IF_NOT(HasExternalData(tensor_proto), "Tensor does not have external data to read from."); - ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto), + ORT_RETURN_IF(!HasDataType(tensor_proto) || HasString(tensor_proto), "External data type cannot be UNDEFINED or STRING."); std::unique_ptr external_data_info; @@ -247,10 +293,10 @@ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, const auto& location = external_data_info->GetRelPath(); - external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location) - : (tensor_proto_dir / location); + external_file_path = location == kTensorProtoMemoryAddressTag ? std::filesystem::path(location) + : (tensor_proto_dir / location); - ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size)); + ORT_RETURN_IF_ERROR(GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size)); const size_t external_data_length = external_data_info->GetLength(); ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size, "TensorProto: ", tensor_proto.name(), @@ -270,38 +316,22 @@ void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::str tensor_proto.set_raw_data(std::move(param)); } -void ConvertRawDataInTensorProto(TensorProto* tensor, - void* ext_data_buf, - size_t ext_data_len) { +void ConvertRawDataInTensorProto(TensorProto& tensor) { size_t element_size = 1; - char* bytes = NULL; + void* bytes = NULL; size_t num_elements = 0; - if (ext_data_buf && !ext_data_len) { - return; - } - switch (tensor->data_type()) { + + switch (tensor.data_type()) { case TensorProto_DataType_FLOAT: - bytes = reinterpret_cast(tensor->mutable_float_data()->mutable_data()); - num_elements = tensor->float_data_size(); + bytes = tensor.mutable_float_data()->mutable_data(); + num_elements = tensor.float_data_size(); element_size = sizeof(float); break; - case TensorProto_DataType_INT32: - bytes = reinterpret_cast(tensor->mutable_int32_data()->mutable_data()); - num_elements = tensor->int32_data_size(); - element_size = sizeof(int32_t); - break; - - case TensorProto_DataType_UINT32: - bytes = reinterpret_cast(tensor->mutable_int32_data()->mutable_data()); - num_elements = tensor->int32_data_size(); - element_size = sizeof(uint32_t); - break; - case TensorProto_DataType_UINT8: case TensorProto_DataType_INT8: - bytes = reinterpret_cast(tensor->mutable_int32_data()->mutable_data()); - num_elements = tensor->int32_data_size(); + bytes = tensor.mutable_int32_data()->mutable_data(); + num_elements = tensor.int32_data_size(); element_size = sizeof(uint8_t); break; @@ -309,56 +339,52 @@ void ConvertRawDataInTensorProto(TensorProto* tensor, case TensorProto_DataType_INT16: case TensorProto_DataType_FLOAT16: case TensorProto_DataType_BFLOAT16: - bytes = reinterpret_cast(tensor->mutable_int32_data()->mutable_data()); - num_elements = tensor->int32_data_size(); - element_size = sizeof(uint16_t); + case TensorProto_DataType_INT32: + bytes = tensor.mutable_int32_data()->mutable_data(); + num_elements = tensor.int32_data_size(); + // We are setting this to int32_t size because we need to swap all 4 bytes + // to represent 16 bits within 32 bits correctly on a LE/BE system. + element_size = sizeof(int32_t); break; + // uint32_t is stored in uint64_t + case TensorProto_DataType_UINT32: case TensorProto_DataType_UINT64: - bytes = reinterpret_cast(tensor->mutable_uint64_data()->mutable_data()); - num_elements = tensor->uint64_data_size(); + bytes = tensor.mutable_uint64_data()->mutable_data(); + num_elements = tensor.uint64_data_size(); element_size = sizeof(uint64_t); break; - case TensorProto_DataType_DOUBLE: - bytes = reinterpret_cast(tensor->mutable_double_data()->mutable_data()); - num_elements = tensor->double_data_size(); - element_size = sizeof(double); - break; - case TensorProto_DataType_INT64: - bytes = reinterpret_cast(tensor->mutable_int64_data()->mutable_data()); - num_elements = tensor->int64_data_size(); + bytes = tensor.mutable_int64_data()->mutable_data(); + num_elements = tensor.int64_data_size(); element_size = sizeof(int64_t); break; + case TensorProto_DataType_DOUBLE: + bytes = tensor.mutable_double_data()->mutable_data(); + num_elements = tensor.double_data_size(); + element_size = sizeof(double); + break; + case TensorProto_DataType_COMPLEX64: - bytes = reinterpret_cast(tensor->mutable_float_data()->mutable_data()); - num_elements = tensor->float_data_size(); + bytes = tensor.mutable_float_data()->mutable_data(); + num_elements = tensor.float_data_size(); element_size = sizeof(float); break; } - if (tensor->has_raw_data()) { - num_elements = (tensor->raw_data().size()) / element_size; - bytes = const_cast(tensor->mutable_raw_data()->c_str()); - } if (element_size == 1) { return; } - if (ext_data_buf) { - ORT_ENFORCE(ext_data_len % element_size == 0); - num_elements = ext_data_len / element_size; - bytes = reinterpret_cast(ext_data_buf); - } - for (size_t i = 0; i < num_elements; ++i) { - char* start_byte = bytes + i * element_size; - char* end_byte = start_byte + element_size - 1; - for (size_t count = 0; count < element_size / 2; ++count) { - std::swap(*start_byte++, *end_byte--); - } + + if (tensor.has_raw_data()) { + num_elements = tensor.raw_data().size() / element_size; + bytes = tensor.mutable_raw_data()->data(); } - return; + + gsl::span span = gsl::make_span(reinterpret_cast(bytes), num_elements * element_size); + SwapByteOrderInplace(element_size, span); } #if !defined(ORT_MINIMAL_BUILD) @@ -984,26 +1010,10 @@ ORT_API(void, OrtUninitializeBuffer, _In_opt_ void* input, size_t input_len, enu #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(disable : 26409) #endif -class AutoDelete { - public: - OrtCallback d{nullptr, nullptr}; - AutoDelete() = default; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(AutoDelete); - ~AutoDelete() { - if (d.f != nullptr) { - d.f(d.param); - } - } -}; - -static void DeleteCharArray(void* param) noexcept { - auto arr = reinterpret_cast(param); - delete[] arr; -} #if !defined(__wasm__) static Status GetFileContent(const Env& env, const std::filesystem::path& file_path, FileOffsetType offset, - size_t length, void*& raw_buffer, OrtCallback& deleter) { + size_t length, IAllocatorUniquePtr& external_data) { // query length if it is 0 if (length == 0) { // The return type of std::filesystem::file_size is uintmax_t which could be bigger than size_t @@ -1015,8 +1025,9 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p Env::MappedMemoryPtr mapped_memory{}; auto status = env.MapFileIntoMemory(file_path.native().c_str(), offset, length, mapped_memory); if (status.IsOK()) { - deleter = mapped_memory.get_deleter().callback; - raw_buffer = mapped_memory.release(); + IAllocatorUniquePtr raw_buffer(mapped_memory.release(), + mapped_memory.get_deleter()); + external_data.swap(raw_buffer); return Status::OK(); } } @@ -1026,22 +1037,24 @@ static Status GetFileContent(const Env& env, const std::filesystem::path& file_p ORT_RETURN_IF_ERROR( env.ReadFileIntoBuffer(file_path.native().c_str(), offset, length, gsl::make_span(buffer.get(), length))); - deleter = OrtCallback{DeleteCharArray, buffer.get()}; - raw_buffer = buffer.release(); + IAllocatorUniquePtr raw_buffer(buffer.release(), [](void* p) { delete[] reinterpret_cast(p); }); + external_data.swap(raw_buffer); return Status::OK(); } #endif -Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, void*& ext_data_buf, - SafeInt& ext_data_len, OrtCallback& ext_data_deleter, - Tensor* buffered_tensor, - PrepackedWeightsForGraph* prepacked_info) { - ORT_ENFORCE(utils::HasExternalData(tensor_proto)); +Status GetExtDataFromTensorProto(const Env& env, + const std::filesystem::path& model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + OrtValue& ort_value, PrepackedWeightsForGraph* prepacked_info) { + ORT_ENFORCE(HasExternalData(tensor_proto), "TensorProto for: ", + tensor_proto.name(), "Expected to have external data"); + std::basic_string tensor_proto_dir; if (!model_path.empty()) { ORT_RETURN_IF_ERROR(GetDirNameFromFilePath(model_path, tensor_proto_dir)); } + std::basic_string external_data_file_path; FileOffsetType file_offset; SafeInt raw_data_safe_len = 0; @@ -1049,20 +1062,24 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo if (prepacked_info != nullptr) { prepacked_infos.emplace(); } + ORT_RETURN_IF_ERROR( GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_data_file_path, file_offset, raw_data_safe_len, (prepacked_info != nullptr) ? &*prepacked_infos : nullptr)); + TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); + const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); + MLDataType ml_tensor_type = DataTypeImpl::GetType(); + const auto& name = tensor_proto.name(); + if (external_data_file_path == onnxruntime::utils::kTensorProtoMemoryAddressTag) { // the value in location is the memory address of the data - ext_data_buf = reinterpret_cast(file_offset); - ext_data_len = raw_data_safe_len; - if (buffered_tensor) { - ext_data_deleter = OrtCallback{[](void* p) noexcept { delete reinterpret_cast(p); }, - reinterpret_cast(buffered_tensor)}; - } else { - ext_data_deleter = OrtCallback{nullptr, nullptr}; - } + void* ext_data_buf = reinterpret_cast(file_offset); + auto tensor = Tensor{type, tensor_shape, ext_data_buf, OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)}; + ORT_RETURN_IF(raw_data_safe_len != tensor.SizeInBytes(), "Weight: ", name, + " kTensorProtoMemoryAddressTag address points to length: ", static_cast(raw_data_safe_len), + " while shape has bytes size: ", tensor.SizeInBytes()); + Tensor::InitOrtValue(std::move(tensor), ort_value); } else { #if defined(__wasm__) ORT_RETURN_IF(file_offset < 0 || file_offset + raw_data_safe_len >= 4294967296, @@ -1071,19 +1088,27 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo " are out of bounds or can not be read in full (>4GB)."); auto buffer = std::make_unique(raw_data_safe_len); - ext_data_deleter = OrtCallback{DeleteCharArray, buffer.get()}; - ext_data_buf = buffer.release(); - ext_data_len = raw_data_safe_len; - ORT_RETURN_IF_ERROR(LoadWebAssemblyExternalData(env, external_data_file_path, file_offset, - ext_data_len, + raw_data_safe_len, ExternalDataLoadType::CPU, - ext_data_buf)); + buffer.get())); + + auto p_tensor = std::make_unique(type, tensor_shape, buffer.get(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); + + std::function deleter = [ext_data = buffer.get()](void* t) { + delete reinterpret_cast(t); + delete[] ext_data; + }; + + ort_value.Init(p_tensor.release(), ml_tensor_type, std::move(deleter)); + buffer.release(); + #else - // The GetFileContent function doesn't report error if the requested data range is invalid. Therefore we need to - // manually check file size first. + // The GetFileContent function doesn't report error if the requested data range is invalid. Therefore we need to + // manually check file size first. std::uintmax_t file_length = std::filesystem::file_size(external_data_file_path); SafeInt end_of_read(file_offset); @@ -1092,9 +1117,35 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo "External initializer: ", tensor_proto.name(), " offset: ", file_offset, " size to read: ", static_cast(raw_data_safe_len), " given file_length: ", file_length, " are out of bounds or can not be read in full."); - ORT_RETURN_IF_ERROR(GetFileContent(env, external_data_file_path.c_str(), file_offset, raw_data_safe_len, - ext_data_buf, ext_data_deleter)); - ext_data_len = raw_data_safe_len; + + IAllocatorUniquePtr ext_data_buf; + ORT_RETURN_IF_ERROR(GetFileContent(env, external_data_file_path, file_offset, raw_data_safe_len, + ext_data_buf)); + + // Data on disk is little endian + if constexpr (endian::native != endian::little) { + if (type->Size() > 1) { + gsl::span data_span{reinterpret_cast(ext_data_buf.get()), raw_data_safe_len}; + SwapByteOrderInplace(type->Size(), data_span); + } + } + + auto p_tensor = std::make_unique(type, tensor_shape, ext_data_buf.get(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); + ORT_RETURN_IF(raw_data_safe_len != p_tensor->SizeInBytes(), "Weight: ", name, + " External file content has length: ", static_cast(raw_data_safe_len), + " while shape has bytes size: ", p_tensor->SizeInBytes()); + + // Will destroy ext_data as a member of the functor + // can not move the unique_ptr as it is not copyable + std::function deleter = [ext_data = ext_data_buf.get(), + d = ext_data_buf.get_deleter()](void* t) { + delete reinterpret_cast(t); + d(ext_data); + }; + + ort_value.Init(p_tensor.release(), ml_tensor_type, std::move(deleter)); + ext_data_buf.release(); if (prepacked_info != nullptr && !prepacked_infos->empty()) { for (const auto& [key, blobs] : *prepacked_infos) { @@ -1109,12 +1160,11 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo ORT_RETURN_IF(blob_offset < 0 || static_cast(end_of_blob) > file_length, "Pre-packed blob: ", key, " offset: ", blob_offset, " file_length: ", file_length, " is out of bounds and can not read in full"); - void* data_ptr; - OrtCallback data_deleter; - ORT_RETURN_IF_ERROR(GetFileContent(env, external_data_file_path.c_str(), blob_offset, blob_length, - data_ptr, data_deleter)); - IAllocatorUniquePtr data_ptr_unique{data_ptr, OrtCallbackInvoker(data_deleter)}; - prepacked_weights.buffers_.push_back(std::move(data_ptr_unique)); + + IAllocatorUniquePtr data_ptr; + ORT_RETURN_IF_ERROR(GetFileContent(env, external_data_file_path, blob_offset, blob_length, + data_ptr)); + prepacked_weights.buffers_.push_back(std::move(data_ptr)); prepacked_weights.buffer_sizes_.push_back(blob_length); } if (!blobs.empty()) { @@ -1132,7 +1182,7 @@ Status LoadExtDataToTensorFromTensorProto(const Env& env, const std::filesystem: const ONNX_NAMESPACE::TensorProto& tensor_proto, const IExternalDataLoader& ext_data_loader, Tensor& tensor) { - ORT_ENFORCE(utils::HasExternalData(tensor_proto)); + ORT_ENFORCE(HasExternalData(tensor_proto)); std::basic_string tensor_proto_dir; if (!model_path.empty()) { ORT_RETURN_IF_ERROR(GetDirNameFromFilePath(model_path, tensor_proto_dir)); @@ -1171,9 +1221,26 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor) { // Validate tensor compatibility TensorShape tensor_shape = GetTensorShapeFromTensorProto(tensor_proto); + + if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), + [](int64_t dim) { + return dim < 0; + })) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "tensor can't contain negative dims"); + } + + if (HasExternalData(tensor_proto)) { + OrtValue ort_value; + ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, ort_value)); + const auto& ext_tensor = ort_value.Get(); + MakeCpuTensorCopy(ext_tensor, tensor); + return Status::OK(); + } + if (tensor_shape != tensor.Shape()) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "TensorProtoToTensor() tensor shape mismatch!"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "TensorProtoToTensor() tensor shape mismatch!"); } + const DataTypeImpl* const source_type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); if (source_type->Size() > tensor.DataType()->Size()) { @@ -1181,15 +1248,12 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa " can not be written into Tensor type ", DataTypeImpl::ToString(tensor.DataType())); } + // Below we handle the case where TensorProto contains data in itself + // find raw data in proto buf void* raw_data = nullptr; SafeInt raw_data_len = 0; - AutoDelete deleter_for_file_data; - OrtCallback& d = deleter_for_file_data.d; - - if (utils::HasExternalData(tensor_proto)) { - ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, raw_data, raw_data_len, d)); - } else if (utils::HasRawData(tensor_proto)) { + if (utils::HasRawData(tensor_proto)) { raw_data = const_cast(tensor_proto.raw_data().data()); // TODO The line above has const-correctness issues. Below is a possible fix which copies the tensor_proto data // into a writeable buffer. However, it requires extra memory which may exceed the limit for certain tests. @@ -1200,25 +1264,19 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa raw_data_len = tensor_proto.raw_data().size(); } - if (nullptr != raw_data && utils::IsPrimitiveDataType(source_type)) { + if (nullptr != raw_data && utils::HasString(tensor_proto)) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "string tensor can not have raw data"); } // unpacking tensor_proto data to preallocated tensor void* preallocated = tensor.MutableDataRaw(); - int64_t tensor_size = 1; - { - for (auto i : tensor_proto.dims()) { - if (i < 0) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "tensor can't contain negative dims"); - } - tensor_size *= i; - } - } + const int64_t tensor_size = tensor_shape.Size(); + // tensor_size could be zero. see test_slice_start_out_of_bounds\test_data_set_0\output_0.pb - if (static_cast(tensor_size) > SIZE_MAX) { + if (narrow(tensor_size) > SIZE_MAX) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "size overflow"); } + switch (tensor_proto.data_type()) { CASE_PROTO(FLOAT, float); CASE_PROTO(DOUBLE, double); @@ -1256,6 +1314,31 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa return Status::OK(); } +common::Status CreateTensorFromTensorProto(const Env& env, const std::filesystem::path& model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor) { + ORT_RETURN_IF_NOT(utils::HasDataType(tensor_proto), "Initializer must have a datatype"); + auto proto_data_type = tensor_proto.data_type(); + + auto proto_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); + Tensor w(DataTypeImpl::TensorTypeFromONNXEnum(proto_data_type)->GetElementType(), proto_shape, + CPUAllocator::DefaultInstance()); + ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, model_path, tensor_proto, w)); + + tensor = std::move(w); + return Status::OK(); +} + +Status GetTensorProtoWithDataIfInMemory( + const ONNX_NAMESPACE::TensorProto& tensor_proto, std::unique_ptr& result) { + if (HasExternalDataInMemory(tensor_proto)) { + result = std::make_unique(); + return TensorProtoWithExternalDataToTensorProto(tensor_proto, {}, *result); + } + + result.reset(); + return Status::OK(); +} + Status TensorProtoToOrtValue(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer& m, OrtValue& value) { return TensorProtoToOrtValueImpl(env, model_path, tensor_proto, &m, nullptr, value); @@ -1318,15 +1401,7 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, } tensor_proto.set_data_type(tensor.GetElementType()); - if (tensor.IsDataTypeString()) { - auto* mutable_string_data = tensor_proto.mutable_string_data(); - auto f = tensor.Data(); - auto end = f + tensor.Shape().Size(); - for (; f < end; ++f) { - *mutable_string_data->Add() = *f; - } - } else if (use_tensor_buffer && tensor.SizeInBytes() > 127) { - // The logic aligns with + if (use_tensor_buffer && tensor.SizeInBytes() > kSmallTensorExternalDataThreshold) { // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/graph_flatbuffers_utils.cc#L302 const auto* raw_data = tensor.DataRaw(); ORT_ENFORCE(raw_data, "Missing raw data for tensor proto. Invalid tensor."); @@ -1341,12 +1416,34 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, offset, tensor.SizeInBytes(), tensor_proto); } else { - utils::SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes()); + if (tensor.IsDataTypeString()) { + auto* mutable_string_data = tensor_proto.mutable_string_data(); + auto f = tensor.Data(); + auto end = f + tensor.Shape().Size(); + for (; f < end; ++f) { + *mutable_string_data->Add() = *f; + } + } else { + SetRawDataInTensorProto(tensor_proto, tensor.DataRaw(), tensor.SizeInBytes()); + } } return tensor_proto; } +ONNX_NAMESPACE::TypeProto TypeProtoFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto) { + TypeProto type_proto; + + type_proto.mutable_tensor_type()->set_elem_type(tensor_proto.data_type()); + auto shape = type_proto.mutable_tensor_type()->mutable_shape(); + + for (auto dim : tensor_proto.dims()) { + shape->add_dim()->set_dim_value(dim); + } + + return type_proto; +} + common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& node, const std::filesystem::path& model_path, ONNX_NAMESPACE::TensorProto& tensor, const std::string& tensor_name) { @@ -1413,6 +1510,15 @@ common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& n return ConstantNodeProtoToTensorProto(node, model_path, tensor, node.output(0)); } +void MakeCpuTensorCopy(const Tensor& src_tensor, Tensor& dst_tensor) { + if (src_tensor.IsDataTypeString()) { + auto src_span = src_tensor.DataAsSpan(); + std::copy(src_span.begin(), src_span.end(), dst_tensor.MutableDataAsSpan().begin()); + } else { + std::memcpy(dst_tensor.MutableDataRaw(), src_tensor.DataRaw(), src_tensor.SizeInBytes()); + } +} + #if !defined(DISABLE_SPARSE_TENSORS) static Status CopySparseData(size_t n_sparse_elements, const ONNX_NAMESPACE::TensorProto& indices, @@ -1847,7 +1953,7 @@ Status UnpackInitializerData(const onnx::TensorProto& initializer, // TODO, if std::vector does not use a custom allocator, the default std::allocator will // allocation the memory aligned to std::max_align_t, need look into allocating // forced aligned memory (align as 16 or larger)for unpacked_tensor - if (initializer.data_location() == TensorProto_DataLocation_EXTERNAL) { + if (HasExternalData(initializer)) { ORT_RETURN_IF_ERROR(ReadExternalDataForTensor( initializer, model_path.parent_path(), diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 79eae48c10411..347885e90bbf5 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -40,19 +40,13 @@ Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, SafeInt& tensor_byte_size, ExternalDataInfo::PrepackedInfos* prepacked_infos = nullptr); /** - * This function is used to convert the endianess of Tensor data. - * If ext_data_buf is provided, then this buffer content's endianess - * will be changed. + * This function is used to convert the endianess of TensorProto data. + * * Mostly, will be used in big endian system to support the model file * generated on little endian system. * @param tensor_proto given initializer tensor - * @param ext_data_buf optional externl data buffer - * @param ext_data_len optional externl data buffer lengeh - * @returns None */ -void ConvertRawDataInTensorProto(ONNX_NAMESPACE::TensorProto* tensor_proto, - void* ext_data_buf = NULL, - size_t ext_data_len = 0); +void ConvertRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto); /** * Wrapper function for set_raw_data. @@ -68,7 +62,7 @@ void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, T1* raw_ using namespace ONNX_NAMESPACE; tensor_proto.set_raw_data(raw_data, raw_data_len); if constexpr (endian::native != endian::little) { - utils::ConvertRawDataInTensorProto((ONNX_NAMESPACE::TensorProto*)&tensor_proto); + utils::ConvertRawDataInTensorProto(tensor_proto); } } @@ -92,7 +86,7 @@ bool operator==(const TensorShapeProto_Dimension& l, const TensorShapeProto_Dime bool operator!=(const TensorShapeProto_Dimension& l, const TensorShapeProto_Dimension& r); } // namespace ONNX_NAMESPACE -#endif +#endif // !defined(SHARED_PROVIDER) namespace onnxruntime { namespace utils { @@ -102,6 +96,17 @@ TensorShape GetTensorShapeFromTensorShapeProto(const ONNX_NAMESPACE::TensorShape TensorShape GetTensorShapeFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto); +/// +/// This function checks if the tensor_proto has external data in memory. +/// If it does, it converts it to a result with data inline, otherwise it does nothing. +/// The function returns a unique_ptr to make it compatible with EPs code. +/// +/// source proto +/// result, can be nullptr if no data in memory, still a success +/// Status +Status GetTensorProtoWithDataIfInMemory(const ONNX_NAMESPACE::TensorProto& tensor_proto, + std::unique_ptr& result); + /** * deserialize a TensorProto into a preallocated memory buffer on CPU. * \param tensor_proto_path A local file path of where the 'input' was loaded from. @@ -137,6 +142,30 @@ common::Status TensorProtoToTensor(const Env& env, const std::filesystem::path& const ONNX_NAMESPACE::TensorProto& tensor_proto, Tensor& tensor); +/** + * @brief Pre-allocates empty tensor and deserializes a TensorProto into it + * @param env + * @param model_path + * @param tensor_proto source data + * @param tensor destination empty tensor + * @return + */ +common::Status CreateTensorFromTensorProto(const Env& env, const std::filesystem::path& model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + Tensor& tensor); + +/// The threshold for small tensors. If the size of the tensor is LE to this value, +/// The data will stay in the TensorProto. Otherwise, the data will be moved to a Tensor instance +/// and TensorProto will contain a kTensorProtoMemoryAddressTag reference as a result of +/// TensorToTensorProto() below. This is because shape inferencing code in onnx for +/// like Reshape parses weights data and it needs to be in the TensorProto. +/// The value of 127 was chosen empirically to be the smallest value that is required +/// for onnx shape inference to work correctly. The value also takes into account the overhead +/// imposed by having external data. The external data requires location/offset/filename so for +/// small values it is better to keep the data inline in the TensorProto, even if they are not used +/// in shape inferencing, it is cheaper to inline them. +constexpr const size_t kSmallTensorExternalDataThreshold = 127; // 127 bytes + /** * @brief Creates a TensorProto from a Tensor. * @param[in] tensor the Tensor whose data and shape will be used to create the TensorProto. @@ -157,6 +186,9 @@ ONNX_NAMESPACE::TensorProto TensorToTensorProto(const Tensor& tensor, ONNXTensorElementDataType CApiElementTypeFromProtoType(int type); ONNXTensorElementDataType GetTensorElementType(const ONNX_NAMESPACE::TensorProto& tensor_proto); +// Creates a TypeProto from a TensorProto. +ONNX_NAMESPACE::TypeProto TypeProtoFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto); + // How much memory it will need for putting the content of this tensor into a plain array // complex64/complex128 tensors are not supported. // The output value could be zero or -1. @@ -173,18 +205,20 @@ address of the memory containing the data. */ constexpr const ORTCHAR_T* kTensorProtoMemoryAddressTag = ORT_TSTR("*/_ORT_MEM_ADDR_/*"); -// Given a tensor proto with external data obtain a pointer to the data and its length. -// The ext_data_deleter argument is updated with a callback that owns/releases the data. -// If tensor_proto's external file path is kTensorProtoMemoryAddressTag, and -// buffered_tensor is not null, buffered_tensor holds the real buffer pointed -// by tensor_proto. buffered_tensor must be the owner of the buffer and deleter -// should release the buffer when tensor_proto is released. +/// +/// Creates a OrtValue with a tensor on top of the external data. +/// If tensor_proto points to a memory address, the OrtValue will be created with a tensor +/// that does not own the memory since the memory is already owned by some other entity. +/// +/// +/// model path +/// tensor proto containing external data +/// output ort value +/// optional pre-packed weight data output container +/// Status common::Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& model_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, - void*& ext_data_buf, SafeInt& ext_data_len, - OrtCallback& ext_data_deleter, - Tensor* buffered_tensor = nullptr, - PrepackedWeightsForGraph* prepacked_for_graph = nullptr); + OrtValue& ort_value, PrepackedWeightsForGraph* prepacked_info = nullptr); // Given a tensor proto with external data obtain a tensor using the specified custom external data loader. common::Status LoadExtDataToTensorFromTensorProto(const Env& env, const std::filesystem::path& model_path, @@ -207,6 +241,13 @@ common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& n const std::filesystem::path& model_path, ONNX_NAMESPACE::TensorProto& tensor); +/// +/// Creates a new CPU based tensor and copies the data from the source tensor. +/// +/// +/// +void MakeCpuTensorCopy(const Tensor& src_tensor, Tensor& dst_tensor); + #if !defined(DISABLE_SPARSE_TENSORS) // Convert a SparseTensorProto to a dense TensorProto // If the SparseTensorProto contains external data then it loads the data and converts to dense tensor proto @@ -226,7 +267,7 @@ common::Status DenseTensorToSparseTensorProto(const ONNX_NAMESPACE::TensorProto& ONNX_NAMESPACE::SparseTensorProto& sparse); #endif // !ORT_MINIMAL_BUILD #endif // !defined(DISABLE_SPARSE_TENSORS) -#endif +#endif // !defined(SHARED_PROVIDER) inline bool HasDimValue(const ONNX_NAMESPACE::TensorShapeProto_Dimension& dim) { return dim.value_case() == ONNX_NAMESPACE::TensorShapeProto_Dimension::kDimValue; @@ -435,7 +476,7 @@ inline bool HasKeyType(const ONNX_NAMESPACE::TypeProto_Map& map_proto) { inline bool HasValueType(const ONNX_NAMESPACE::TypeProto_Map& map_proto) { return map_proto.value_type().value_case() != ONNX_NAMESPACE::TypeProto::VALUE_NOT_SET; } -#endif +#endif // !defined(SHARED_PROVIDER) inline bool HasType(const ONNX_NAMESPACE::ValueInfoProto& vi_proto) { return vi_proto.type().value_case() != ONNX_NAMESPACE::TypeProto::VALUE_NOT_SET; @@ -454,7 +495,26 @@ inline bool HasName(const ONNX_NAMESPACE::TypeProto_Opaque& op_proto) { return !op_proto.name().empty(); } -#endif +/// +/// Quick check if the this tensor proto has external data in memory. +/// +/// tensor_proto +/// true if ten_proto has external data and it is in memory +bool HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& tensor_proto); + +/// +/// This function converts TensorProto with external data to TensorProto with inline data. +/// +/// source +/// model_path, can be empty if data is in memory +/// result +/// Status +Status TensorProtoWithExternalDataToTensorProto( + const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path, + ONNX_NAMESPACE::TensorProto& new_tensor_proto); + +#endif // !defined(SHARED_PROVIDER) inline bool HasType(const ONNX_NAMESPACE::AttributeProto& at_proto) { return at_proto.type() != ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_UNDEFINED; @@ -521,7 +581,7 @@ inline bool HasName(const ONNX_NAMESPACE::NodeProto& node_proto) { // XXX: Figure out proto3 style return node_proto.has_name(); } -#endif +#endif // !defined(SHARED_PROVIDER) // UnpackTensor from raw data or the type specific data field. Does not handle external data. // If the tensor does not contain raw data then raw_data should be nullptr and raw_data_len should be 0. diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h new file mode 100644 index 0000000000000..e92861fc4de63 --- /dev/null +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -0,0 +1,280 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include "core/common/inlined_containers_fwd.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/graph/onnx_protobuf.h" + +#define DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(external_type, internal_type, internal_api) \ + external_type* ToExternal() { return static_cast(this); } \ + const external_type* ToExternal() const { return static_cast(this); } \ + static internal_type* ToInternal(external_type* e) { \ + return e->graph_ir_api == (internal_api) ? static_cast(e) : nullptr; \ + } \ + static const internal_type* ToInternal(const external_type* e) { \ + return e->graph_ir_api == (internal_api) ? static_cast(e) : nullptr; \ + } + +// The public ORT graph IR types (e.g., OrtGraph, OrtNode, etc.) have different implementations for the +// ModelEditor API and EP API. This enum allows a user of the base class (e.g., OrtGraph) to determine +// the API for which the derived class was created. +enum class OrtGraphIrApi { + kInvalid = 0, + kModelEditorApi, + kEpApi, +}; + +/// +/// Public type that represents an ONNX value info. +/// +struct OrtValueInfo { + explicit OrtValueInfo(OrtGraphIrApi graph_ir_api) : graph_ir_api(graph_ir_api) {} + virtual ~OrtValueInfo() = default; + + /// + /// Returns the value's name. + /// + /// The value's name. + virtual const std::string& GetName() const = 0; + + /// + /// Return's an object describing the value's type and shape. + /// + /// OrtTypeInfo with the type and shape. + virtual const OrtTypeInfo* GetTypeInfo() const = 0; + + struct ProducerInfo { + ProducerInfo() = default; + ProducerInfo(const OrtNode* node, size_t output_index) : node(node), output_index(output_index) {} + const OrtNode* node = nullptr; + size_t output_index = 0; + }; + + /// + /// Returns the node (and output index) that produced the value. + /// + /// Output parameter set to the node and the output index that produced the value. + /// A status indicating success or an error. + virtual onnxruntime::Status GetProducerInfo(ProducerInfo& producer_info) const = 0; + + struct ConsumerInfo { + ConsumerInfo() = default; + ConsumerInfo(const OrtNode* node, int64_t input_index) : node(node), input_index(input_index) {} + const OrtNode* node = nullptr; + int64_t input_index = 0; // Negative if it is an implicit input to a node that contains a subgraph (e.g., If). + }; + + /// + /// Returns information on the nodes that consume the value. Includes each consumer node's input index, + /// which could be -1 for an implicit input to the node (e.g., If or Loop node). + /// + /// Output parameter set to the array of ConsumerInfo objects that describe the + /// use of this value (consumer node and input index). + /// A status indicating success or an error. + virtual onnxruntime::Status GetConsumerInfos(std::vector& consumer_infos) const = 0; + + /// + /// Returns the number of consumers for this value. In this context, a consumer is a tuple of the node and the input + /// index that uses the value. + /// + /// Output parameter set to the number of consumers. + /// A status indicating success or an error. + virtual onnxruntime::Status GetNumConsumerInfos(size_t& num_consumers) const = 0; + + /// + /// Returns the associated initializer value if this value represents an initializer (constant or non-constant). + /// + /// Output parameter set to the initializer value or nullptr if this value is not + /// an initializer. + /// A status indicating success or an error. + virtual onnxruntime::Status GetInitializerValue(const OrtValue*& value) const = 0; + + /// + /// Determine if the value is a required graph input. + /// + /// Output parameter set to true if the value is a required graph input. + /// A status indicating success or an error. + virtual onnxruntime::Status IsRequiredGraphInput(bool& is_required_graph_input) const = 0; + + /// + /// Determine if the value is an optional graph input. + /// + /// Output parameter set to true if the value is an optional graph + /// input. + /// A status indicating success or an error. + virtual onnxruntime::Status IsOptionalGraphInput(bool& is_optional_graph_input) const = 0; + + /// + /// Determine if a the value is a graph output. + /// + /// Output parameter set to true if the value is a graph output. + /// A status indicating success or an error. + virtual onnxruntime::Status IsGraphOutput(bool& is_graph_output) const = 0; + + /// + /// Determine if the value is a constant initializer. + /// + /// Output parameter set to true if the value is a constant + /// initializer. + /// A status indicating success or an error. + virtual onnxruntime::Status IsConstantInitializer(bool& is_const_initializer) const = 0; + + /// + /// Determine if the value is defined in an outer scope (i.e., a parent graph). + /// + /// Output parameter set to true if the value is defined in an outer scope. + /// A status indicating success or an error. + virtual onnxruntime::Status IsFromOuterScope(bool& is_outer_scope) const = 0; + + OrtGraphIrApi graph_ir_api = OrtGraphIrApi::kInvalid; +}; + +/// +/// Public type that represents an ONNX attribute. Currently, an OrtOpAttr is interchangeable with AttributeProto. +/// +struct OrtOpAttr { + ONNX_NAMESPACE::AttributeProto attr_proto; +}; + +/// +/// Public type that represents an ONNX node. +/// +struct OrtNode { + explicit OrtNode(OrtGraphIrApi graph_ir_api) : graph_ir_api(graph_ir_api) {} + virtual ~OrtNode() = default; + + /// + /// Returns the node's ID, which is unique in it's graph. + /// + /// The node's ID. + virtual size_t GetId() const = 0; + + /// + /// Returns the node's name. + /// + /// The node's name + virtual const std::string& GetName() const = 0; + + /// + /// Returns the node's operator type (e.g., "Conv"). + /// + /// The node's operator type. + virtual const std::string& GetOpType() const = 0; + + /// + /// Returns the node's domain name. + /// + /// The node's domain name. + virtual const std::string& GetDomain() const = 0; + + /// + /// Gets the opset version in which the node's operator type was first defined. + /// + /// Output parameter set to the node's operator "since version". + /// A status indicating success or an error. + virtual onnxruntime::Status GetSinceVersion(int& since_version) const = 0; + + /// + /// Gets the node's inputs as an array of OrtValueInfo elements wrapped in an OrtArrayOfConstObjects. + /// + /// Output parameter set to the node's inputs. + /// A status indicating success or an error. + virtual onnxruntime::Status GetInputs(std::unique_ptr& inputs) const = 0; + + /// + /// Gets the node's outputs as an array of OrtValueInfo elements wrapped in an OrtArrayOfConstObjects. + /// + /// Output parameter set to the node's outputs. + /// A status indicating success or an error. + virtual onnxruntime::Status GetOutputs(std::unique_ptr& outputs) const = 0; + + /// + /// Gets the node's implicit inputs as an array of OrtValueInfo elements wrapped in an OrtArrayOfConstObjects. + /// Applies to a node that contains a subgraph (e.g., If or Loop). An implicit input is a value consumed by an + /// internal subgraph node that is not defined in the subgraph. + /// + /// Output parameter set to the node's implicit inputs. + /// A status indicating success or an error. + virtual onnxruntime::Status GetImplicitInputs(std::unique_ptr& implicit_inputs) const = 0; + + /// + /// Gets the node's subgraphs (e.g., subgraphs contained by an If or Loop node). + /// + /// Output parameter set to the node's subgraphs as OrtGraph instances. + /// A status indicating success or an error. + virtual onnxruntime::Status GetSubgraphs(std::unique_ptr& subgraphs) const = 0; + + /// + /// Gets the node's parent graph, which is the graph that contains this node. + /// + /// Output parameter set to the node's parent graph. + /// A status indicating success or an error. + virtual onnxruntime::Status GetParentGraph(const OrtGraph*& parent_graph) const = 0; + + OrtGraphIrApi graph_ir_api = OrtGraphIrApi::kInvalid; +}; + +/// +/// Public type that represents an ONNX graph. +/// +struct OrtGraph { + explicit OrtGraph(OrtGraphIrApi graph_ir_api) : graph_ir_api(graph_ir_api) {} + virtual ~OrtGraph() = default; + + /// + /// Returns the graph's name. + /// + /// The graph's name. + virtual const std::string& GetName() const = 0; + + /// + /// Returns the model's ONNX IR version. Important in checking for optional graph inputs + /// (aka non-constant initializers), which were introduced in ONNX IR version 4. + /// + /// The model's ONNX IR version. + virtual int64_t GetOnnxIRVersion() const = 0; + + /// + /// Gets the graph's inputs (including initializers) as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. + /// + /// Output parameter set to the graph's inputs. + /// A status indicating success or an error. + virtual onnxruntime::Status GetInputs(std::unique_ptr& inputs) const = 0; + + /// + /// Gets the graph's outputs as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. + /// + /// Output parameter set to the graph's outputs. + /// A status indicating success or an error. + virtual onnxruntime::Status GetOutputs(std::unique_ptr& outputs) const = 0; + + /// + /// Gets the graph's initializers (both constant and non-constant) as OrtValueInfo instances wrapped in an + /// OrtArrayOfConstObjects. + /// + /// Output parameter set to the graph's initializers. + /// A status indicating success or an error. + virtual onnxruntime::Status GetInitializers(std::unique_ptr& initializers) const = 0; + + /// + /// Gets the graph's nodes as OrtNode instances wrapped in an OrtArrayOfConstObjects. The nodes are sorted in + /// a default "reverse DFS" topological order. + /// + /// Output parameter set to the graph's nodes. + /// A status indicating success or an error. + virtual onnxruntime::Status GetNodes(std::unique_ptr& nodes) const = 0; + + /// + /// Gets the graph's parent node, if any. The parent_node is nullptr if this is not a nested subgraph. + /// + /// Output parameter set to the parent node. + /// A status indicating success or an error. + virtual onnxruntime::Status GetParentNode(const OrtNode*& parent_node) const = 0; + + OrtGraphIrApi graph_ir_api = OrtGraphIrApi::kInvalid; +}; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 238dd8d4573de..f2757c2c96471 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1206,6 +1206,202 @@ ONNX_MS_OPERATOR_SET_SCHEMA( GroupQueryAttentionTypeAndShapeInference(ctx, 3); })); +constexpr const char* PagedAttention_ver1_doc = R"DOC( +Paged Attention. + +This op leverages a block-based KV cache to enable continuous batching for LLMs. Currently, it is designed to work with +the CUDA Execution Provider only. + +In other attention ops, batch entries typically aren't of the same length, so they are padded. +Below is a batch with 3 sequences where * denotes a padding token. + Sequence_0: 0, 1*, 2*, 3* + Sequence_1: 4, 5, 6*, 7* + Sequence_2: 8, 9, 10, 11 + +PagedAttention is designed to take in packed input, i.e., only the real tokens without padding. +For example, the input shown above will be packed into 3 tensors like below: + - query ([q0, q4, q5, q8, q9, q10, q11]) + - key ([k0, k4, k5, k8, k9, k10, k11]) + - value ([v0, v4, v5, v8, v9, v10, v11]) + - cumulative_sequence_length: 0, 1, 1+2, 1+2+4 +This packing omits padding tokens. + +The query, key and value tensors contain result of hidden embedding of real tokens after input projections. +cumulative_sequence_length records cumulated length of each sequence length. + +)DOC"; + +// Shape inference for PagedAttention. Here are the shapes of inputs and output: +// When Q, K and V are not packed: +// Input 'query': (token_count, hidden_size) +// Input 'key': (token_count, kv_hidden_size) +// Input 'value': (token_count, kv_hidden_size) +// When Q, K and V are packed: +// Input 'query': (token_count, (num_heads + 2 * kv_num_heads) * head_size) +// Input 'key': None +// Input 'value': None +// Input 'key_cache': (num_blocks, block_size, kv_num_heads, head_size) +// Input 'value_cache': (num_blocks, block_size, kv_num_heads, head_size) +// Input 'cumulative_sequence_length': (batch_size + 1) +// Input 'seqlens': (batch_size) +// Input 'block_table': (batch_size, max_blocks_per_sequence) +// Input 'cos_cache': (max_seq_len, head_size / 2) +// Input 'sin_cache': (max_seq_len, head_size / 2) +// Output 'output': (token_count, hidden_size) +// Output 'key_cache_out': (num_blocks, block_size, kv_num_heads, head_size) +// Output 'value_cache_out': (num_blocks, block_size, kv_num_heads, head_size) +void PagedAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + + // Shape inference for output tensor + if (hasInputShape(ctx, 0)) { + auto& query_shape = getInputShape(ctx, 0); + auto& query_dims = query_shape.dim(); + + if (query_dims.size() != 2) { + fail_shape_inference("Input 0 (query) shall be 2 dimensions"); + } + + if (ctx.hasInput(2)) { + ONNX_NAMESPACE::TensorShapeProto output_shape; + propagateShapeFromInputToOutput(ctx, 0, 0); + } else { // packed QKV + ONNX_NAMESPACE::TensorShapeProto output_shape; + *output_shape.add_dim() = query_dims[0]; + int64_t num_heads = getAttribute(ctx, "num_heads", 0); + int64_t kv_num_heads = getAttribute(ctx, "kv_num_heads", 0); + int64_t hidden_size = query_dims[1].dim_value(); + if (hidden_size <= 0 || num_heads <= 0 || kv_num_heads < 0) { + fail_shape_inference("Invalid hidden size or number of heads. Hidden size, num_heads and kv_num_heads must be positive integers."); + } else if (hidden_size % (num_heads + 2 * kv_num_heads) != 0) { + fail_shape_inference("Hidden size must be divisible by (num_heads + 2 * kv_num_heads)."); + } + int64_t head_size = hidden_size / (num_heads + 2 * kv_num_heads); + output_shape.add_dim()->set_dim_value(head_size * num_heads); + updateOutputShape(ctx, 0, output_shape); + } + } + + // Shape inference for KV Cache output tensors + if (ctx.getNumOutputs() > 1) { // has kv cache output + if (ctx.getNumOutputs() != 3) { + fail_shape_inference("Key cache and value cache output tensors must be both present or both absent."); + } + // types + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1); + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2); + // shapes + auto& key_cache_shape = getInputShape(ctx, 3); + auto& key_cache_dims = key_cache_shape.dim(); + if (key_cache_dims.size() != 4) { + fail_shape_inference("The block-based KV cache inputs shall be 4 dimensions"); + } + // KV cache in and out share the same buffer, thus they have the same shape + ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 3, 1); + ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 4, 2); + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 3, 1); + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 4, 2); + } +} + +ONNX_MS_OPERATOR_SET_SCHEMA( + PagedAttention, 1, + OpSchema() + .SetDoc(PagedAttention_ver1_doc) + .Attr("num_heads", "Number of attention heads for q", AttributeProto::INT) + .Attr("kv_num_heads", "Number of attention heads for k and v", AttributeProto::INT) + .Attr("scale", + "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", + AttributeProto::FLOAT, + OPTIONAL_VALUE) + .Attr("softcap", + "Softcap value for attention weights. Default value is 0.", + AttributeProto::FLOAT, + OPTIONAL_VALUE) + .Attr("local_window_size", + "left_window_size for local attention (like Mistral). Default value is -1 meaning unused.", + AttributeProto::INT, + static_cast(-1)) + .Attr("do_rotary", + "Whether to use rotary position embedding. Default value is 0.", + AttributeProto::INT, + OPTIONAL_VALUE) + .Attr("rotary_interleaved", + "Rotate using interleaved pattern. Default value is 0 (False).", + AttributeProto::INT, + OPTIONAL_VALUE) + .Input(0, + "query", + "Query with shape (num_tokens, hidden_size), or packed QKV with shape (num_tokens, d) " + "where d is (num_heads * head_size + 2 * kv_num_heads * head_size).", + "T") + .Input(1, + "key", + "Key with shape (num_tokens, kv_hidden_size) ", + "T", + OpSchema::Optional) + .Input(2, + "value", + "Value with shape (num_tokens, kv_hidden_size)", + "T", + OpSchema::Optional) + .Input(3, + "key_cache", + "Block-based key cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is updated in " + "place within the op.", + "T") + .Input(4, + "value_cache", + "Block-based value cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is updated " + "in place within the op. This should be the same shape as key_cache.", + "T") + .Input(5, + "cumulative_sequence_length", + "A tensor with shape (batch_size + 1). It specifies the cumulative sequence lengths between the packed " + "entries in Q/K/V.", + "S") + .Input(6, + "past_seqlens", + "A tensor with shape (batch_size). It specifies the past lengths of cached sequence in the KV cache.", + "S") + .Input(7, + "block_table", + "2D tensor with shape (batch_size, max_blocks_per_sequence) that maps each sequence in the batch to its" + "corresponding blocks in the KV cache.", + "S") + .Input(8, + "cos_cache", + "2D tensor with shape (max total seqlen, head_size / 2).", + "T", + OpSchema::Optional) + .Input(9, + "sin_cache", + "2D tensor with shape (max total seqlen, head_size / 2).", + "T", + OpSchema::Optional) + .Output(0, + "output", + "3D output tensor with shape (num_tokens, hidden_size)", + "T") + .Output(1, + "key_cache_out", + "Block-based key cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is always " + "the same tensor as key_cache.", + "T", + OpSchema::Optional) + .Output(2, + "value_cache_out", + "Block-based value cache with shape (num_blocks, block_size, kv_num_heads, head_size). This is always " + "the same tensor as value_cache.", + "T", + OpSchema::Optional) + .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.") + .TypeConstraint("S", {"tensor(int32)"}, "Constrain Positional inputs to int tensor.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + PagedAttentionTypeAndShapeInference(ctx); + })); + constexpr const char* SparseAttention_ver1_doc = R"DOC( Block Sparse Attention used in Phi-3-small (https://arxiv.org/pdf/2404.14219). diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index a9a89f756b071..6c20aae94d132 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -87,6 +87,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MoE); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QMoE); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, PagedAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad); @@ -197,6 +198,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc new file mode 100644 index 0000000000000..bcef5fda9c0b4 --- /dev/null +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -0,0 +1,654 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/ep_api_types.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/framework/allocator.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/graph/graph_viewer.h" +#include "core/graph/graph.h" + +namespace onnxruntime { + +// Create an EpValueInfo from a NodeArg. +static std::unique_ptr CreateValueInfo(const NodeArg& node_arg, const EpGraph* ep_graph, size_t flags) { + const auto* type_proto = node_arg.TypeAsProto(); + std::unique_ptr type_info = type_proto != nullptr ? OrtTypeInfo::FromTypeProto(*type_proto) + : nullptr; + return std::make_unique(ep_graph, node_arg.Name(), std::move(type_info), flags); +} + +// Convert an array of NodeArgs to an array of EpValueInfos. The value_infos array should be the same size as the +// array of NodeArgs before calling this function. +static void ConvertNodeArgsToValueInfos(const EpGraph* ep_graph, + std::unordered_map>& value_infos_map, + gsl::span node_args, gsl::span value_infos, + std::function set_value_info_flags = nullptr) { + assert(node_args.size() == value_infos.size()); + + for (size_t i = 0; i < node_args.size(); ++i) { + gsl::not_null node_arg = node_args[i]; + const std::string& value_name = node_arg->Name(); + + if (!node_arg->Exists()) { + // A missing optional input/output has a null OrtValueInfo. + value_infos[i] = nullptr; + continue; + } + + auto value_info_iter = value_infos_map.find(value_name); + + if (value_info_iter != value_infos_map.end()) { + EpValueInfo* value_info = value_info_iter->second.get(); + + if (set_value_info_flags) { + set_value_info_flags(value_info); + } + + value_infos[i] = value_info; + } else { + std::unique_ptr value_info = CreateValueInfo(*node_arg, ep_graph, EpValueInfo::Flags::kFlagNone); + + if (set_value_info_flags) { + set_value_info_flags(value_info.get()); + } + + value_infos[i] = value_info.get(); + value_infos_map.emplace(value_name, std::move(value_info)); + } + } +} + +// +// EpNode +// + +EpNode::EpNode(const EpGraph* ep_graph, const Node& node, PrivateTag) + : OrtNode(OrtGraphIrApi::kEpApi), ep_graph_(ep_graph), node_(node) {} + +Status EpNode::Create(const Node& node, const EpGraph* ep_graph, + std::unordered_map>& value_infos_map, + /*out*/ std::unique_ptr& result) { + auto ep_node = std::make_unique(ep_graph, node, PrivateTag{}); + + auto node_inputs = node.InputDefs(); + auto node_outputs = node.OutputDefs(); + InlinedVector ep_node_inputs(node_inputs.size(), nullptr); + InlinedVector ep_node_outputs(node_outputs.size(), nullptr); + + ConvertNodeArgsToValueInfos(ep_graph, value_infos_map, node_inputs, ep_node_inputs); + ConvertNodeArgsToValueInfos(ep_graph, value_infos_map, node_outputs, ep_node_outputs); + + std::vector ep_node_subgraphs; + std::vector ep_node_implicit_inputs; + + if (node.ContainsSubgraph()) { + const auto node_implicit_inputs = node.ImplicitInputDefs(); + ep_node_implicit_inputs.resize(node_implicit_inputs.size(), nullptr); + + ConvertNodeArgsToValueInfos(ep_graph, value_infos_map, node_implicit_inputs, ep_node_implicit_inputs); + + std::vector> node_subgraphs = node.GetSubgraphs(); + ep_node_subgraphs.reserve(node_subgraphs.size()); + + for (gsl::not_null subgraph : node_subgraphs) { + SubgraphState subgraph_state; + subgraph_state.subgraph_viewer = std::make_unique(*subgraph); + ORT_RETURN_IF_ERROR(EpGraph::Create(*subgraph_state.subgraph_viewer, subgraph_state.ep_subgraph)); + subgraph_state.ep_subgraph->SetParentNode(ep_node.get()); + + ep_node_subgraphs.emplace_back(std::move(subgraph_state)); + } + } + + ep_node->inputs_ = std::move(ep_node_inputs); + ep_node->outputs_ = std::move(ep_node_outputs); + ep_node->implicit_inputs_ = std::move(ep_node_implicit_inputs); + ep_node->subgraphs_ = std::move(ep_node_subgraphs); + + result = std::move(ep_node); + + return Status::OK(); +} + +size_t EpNode::GetId() const { return node_.Index(); } + +const std::string& EpNode::GetName() const { return node_.Name(); } + +const std::string& EpNode::GetOpType() const { return node_.OpType(); } + +const std::string& EpNode::GetDomain() const { return node_.Domain(); } + +Status EpNode::GetSinceVersion(int& since_version) const { + since_version = node_.SinceVersion(); + return Status::OK(); +} + +Status EpNode::GetInputs(std::unique_ptr& result) const { + result = std::make_unique(ORT_TYPE_TAG_OrtValueInfo); + result->storage.reserve(inputs_.size()); + + for (const EpValueInfo* input : inputs_) { + result->storage.push_back(input); + } + + return Status::OK(); +} + +Status EpNode::GetOutputs(std::unique_ptr& result) const { + result = std::make_unique(ORT_TYPE_TAG_OrtValueInfo); + result->storage.reserve(outputs_.size()); + + for (const EpValueInfo* output : outputs_) { + result->storage.push_back(output); + } + + return Status::OK(); +} + +Status EpNode::GetImplicitInputs(std::unique_ptr& result) const { + result = std::make_unique(ORT_TYPE_TAG_OrtValueInfo); + result->storage.reserve(implicit_inputs_.size()); + + for (const EpValueInfo* implicit_input : implicit_inputs_) { + result->storage.push_back(implicit_input); + } + + return Status::OK(); +} + +Status EpNode::GetSubgraphs(std::unique_ptr& result) const { + result = std::make_unique(ORT_TYPE_TAG_OrtGraph); + result->storage.reserve(subgraphs_.size()); + + for (const SubgraphState& subgraph : subgraphs_) { + result->storage.push_back(subgraph.ep_subgraph->ToExternal()); + } + + return Status::OK(); +} + +Status EpNode::GetParentGraph(const OrtGraph*& parent_graph) const { + parent_graph = ep_graph_->ToExternal(); + return Status::OK(); +} + +gsl::span EpNode::GetInputsSpan() const { + return inputs_; +} + +gsl::span EpNode::GetImplicitInputsSpan() const { + return implicit_inputs_; +} + +gsl::span EpNode::GetOutputsSpan() const { + return outputs_; +} + +// +// EpValueInfo +// +EpValueInfo::EpValueInfo(const EpGraph* graph, const std::string& name, std::unique_ptr&& type_info, + size_t flags) + : OrtValueInfo(OrtGraphIrApi::kEpApi), + graph_(graph), + name_(name), + type_info_(std::move(type_info)), + flags_(flags) {} + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +static Status GetInputIndices(const EpNode& consumer_node, + const std::string& value_info_name, + /*out*/ std::vector& indices) { + bool found = false; + auto add_input_indices = + [&found, &value_info_name, &indices](gsl::span input_value_infos, + bool is_implicit) -> void { + for (size_t i = 0; i < input_value_infos.size(); i++) { + if (input_value_infos[i]->GetName() == value_info_name) { + indices.push_back(is_implicit ? -1 : static_cast(i)); + found = true; + } + } + }; + + add_input_indices(consumer_node.GetInputsSpan(), false); + add_input_indices(consumer_node.GetImplicitInputsSpan(), true); + + ORT_RETURN_IF_NOT(found, "Did not find OrtValueInfo with name ", value_info_name); + return Status::OK(); +} + +static Status GetOutputIndex(const EpNode& producer_node, + const std::string& value_info_name, + /*out*/ size_t& index) { + bool found = false; + gsl::span outputs = producer_node.GetOutputsSpan(); + + for (size_t i = 0; i < outputs.size(); i++) { + if (outputs[i]->GetName() == value_info_name) { + index = i; + found = true; + } + } + + ORT_RETURN_IF_NOT(found, "Did not find OrtValueInfo with name ", value_info_name); + return Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + +Status EpValueInfo::GetProducerInfo(OrtValueInfo::ProducerInfo& producer_info) const { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + producer_info.node = nullptr; + producer_info.output_index = 0; + + if (graph_ == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to get producer node for OrtValueInfo '", name_, + "' that is not owned by a OrtGraph."); + } + + const Node* node = graph_->GetGraphViewer().GetProducerNode(name_); + if (node == nullptr) { + return Status::OK(); + } + + const EpNode* ep_node = graph_->GetNode(node->Index()); + if (ep_node == nullptr) { + return Status::OK(); // Node is not in this GraphViewer + } + + size_t output_index = 0; + ORT_RETURN_IF_ERROR(GetOutputIndex(*ep_node, name_, output_index)); + + producer_info.node = ep_node->ToExternal(); + producer_info.output_index = output_index; + + return Status::OK(); +#else + ORT_UNUSED_PARAMETER(producer_info); + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "Getting the producer of an OrtValueInfo is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +} + +Status EpValueInfo::GetConsumerInfos(std::vector& result) const { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + if (graph_ == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to get uses of OrtValueInfo '", name_, + "' that is not owned by a OrtGraph."); + } + + std::vector nodes = graph_->GetGraphViewer().GetConsumerNodes(name_); + if (nodes.empty()) { + return Status::OK(); + } + + std::vector consumer_infos; + consumer_infos.reserve(nodes.size()); + + for (const Node* node : nodes) { + const EpNode* ep_node = graph_->GetNode(node->Index()); + if (ep_node == nullptr) { + continue; // Node is not in this GraphViewer + } + + std::vector input_indices; + ORT_RETURN_IF_ERROR(GetInputIndices(*ep_node, name_, input_indices)); + + for (int64_t input_index : input_indices) { + OrtValueInfo::ConsumerInfo use_info(ep_node->ToExternal(), input_index); + consumer_infos.push_back(use_info); + } + } + + result = std::move(consumer_infos); + + return Status::OK(); +#else + ORT_UNUSED_PARAMETER(result); + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "Getting the consumers of an OrtValueInfo is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +} + +Status EpValueInfo::GetNumConsumerInfos(size_t& num_consumers) const { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + num_consumers = 0; + + if (graph_ == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to get number of uses of OrtValueInfo '", name_, + "' that is not owned by a OrtGraph."); + } + + std::vector nodes = graph_->GetGraphViewer().GetConsumerNodes(name_); + if (nodes.empty()) { + return Status::OK(); + } + + for (const Node* node : nodes) { + const EpNode* ep_node = graph_->GetNode(node->Index()); + if (ep_node == nullptr) { + continue; // Node is not in this GraphViewer + } + + std::vector input_indices; + ORT_RETURN_IF_ERROR(GetInputIndices(*ep_node, name_, input_indices)); + + num_consumers += input_indices.size(); // A single OrtNode can use an OrtValueInfo as an input more than once. + } + + return Status::OK(); +#else + ORT_UNUSED_PARAMETER(num_consumers); + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "Getting the consumers of an OrtValueInfo is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) +} + +Status EpValueInfo::GetInitializerValue(const OrtValue*& result) const { + if (!IsFlagSet(kIsConstantInitializer) && !IsFlagSet(kIsOptionalGraphInput)) { + // This OrtValueInfo does not represent an initializer. Set result to nullptr and return an OK status + // to allow user to use this function to check if this is an initializer. + result = nullptr; + return Status::OK(); + } + + ORT_RETURN_IF(graph_ == nullptr, "Unable to get initializer value named '", name_, "': parent graph is NULL"); + + // This gets an initializer value defined in this graph or in a parent graph (as long as the value + // is used in this graph). + result = graph_->GetInitializerValue(name_); + ORT_RETURN_IF(result == nullptr, "Unable to find initializer value named '", name_, "'."); + return Status::OK(); +} + +Status EpValueInfo::IsRequiredGraphInput(bool& is_required_graph_input) const { + is_required_graph_input = IsFlagSet(Flags::kIsRequiredGraphInput); + return Status::OK(); +} + +Status EpValueInfo::IsOptionalGraphInput(bool& is_optional_graph_input) const { + is_optional_graph_input = IsFlagSet(Flags::kIsOptionalGraphInput); + return Status::OK(); +} + +Status EpValueInfo::IsGraphOutput(bool& is_graph_output) const { + is_graph_output = IsFlagSet(Flags::kIsGraphOutput); + return Status::OK(); +} + +Status EpValueInfo::IsConstantInitializer(bool& is_const_initializer) const { + is_const_initializer = IsFlagSet(Flags::kIsConstantInitializer); + return Status::OK(); +} + +Status EpValueInfo::IsFromOuterScope(bool& is_outer_scope) const { + is_outer_scope = IsFlagSet(Flags::kIsOuterScope); + return Status::OK(); +} + +// +// EpGraph +// + +void EpGraph::IndexToEpNodeMap::Resize(NodeIndex min_node_index, NodeIndex max_node_index) { + assert(max_node_index >= min_node_index); + size_t num_elems = (max_node_index - min_node_index) + 1; + + min_node_index_ = min_node_index; + nodes_.resize(num_elems, nullptr); +} + +EpNode* EpGraph::IndexToEpNodeMap::GetEpNode(NodeIndex node_index) const { + size_t i = node_index - min_node_index_; + assert(i < nodes_.size()); + return nodes_[i]; +} + +void EpGraph::IndexToEpNodeMap::SetEpNode(NodeIndex node_index, EpNode* ep_node) { + size_t i = node_index - min_node_index_; + assert(i < nodes_.size()); + nodes_[i] = ep_node; +} + +EpGraph::EpGraph(const GraphViewer& graph_viewer, PrivateTag) + : OrtGraph(OrtGraphIrApi::kEpApi), graph_viewer_(graph_viewer) {} + +// Static class function to create a std::unique_ptr. +Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { + auto ep_graph = std::make_unique(graph_viewer, PrivateTag{}); + + AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance(); + std::unordered_map> value_infos_map; + + // Process graph inputs. + const std::vector& graph_input_node_args = graph_viewer.GetInputsIncludingInitializers(); + InlinedVector graph_input_value_infos(graph_input_node_args.size(), nullptr); + ConvertNodeArgsToValueInfos(ep_graph.get(), value_infos_map, graph_input_node_args, + graph_input_value_infos, + [&graph_viewer](EpValueInfo* v) { + if (!graph_viewer.IsInitializedTensor(v->GetName())) { + v->SetFlag(EpValueInfo::Flags::kIsRequiredGraphInput); + } else if (graph_viewer.CanOverrideInitializer()) { + v->SetFlag(EpValueInfo::Flags::kIsOptionalGraphInput); + } + }); + + // Process graph outputs. + const std::vector& graph_output_node_args = graph_viewer.GetOutputs(); + InlinedVector graph_output_value_infos(graph_output_node_args.size(), nullptr); + ConvertNodeArgsToValueInfos(ep_graph.get(), value_infos_map, graph_output_node_args, + graph_output_value_infos, + [](EpValueInfo* v) { v->SetFlag(EpValueInfo::Flags::kIsGraphOutput); }); + + std::unordered_map> outer_scope_initializer_values; + + // Create OrtValueInfo and OrtValue instances for each initializer. + const InitializedTensorSet initializers = graph_viewer.GetAllInitializedTensors(); + std::vector initializer_value_infos; + std::unordered_map> initializer_values; + + initializer_value_infos.reserve(initializers.size()); + initializer_values.reserve(initializers.size()); + + for (const auto& [initializer_name, tensor_proto] : initializers) { + EpValueInfo* value_info = nullptr; + EpValueInfo::Flags flag = graph_viewer.IsConstantInitializer(initializer_name, /*check_outer_scope*/ false) + ? EpValueInfo::kIsConstantInitializer + : EpValueInfo::kIsOptionalGraphInput; + + auto iter = value_infos_map.find(initializer_name); + if (iter != value_infos_map.end()) { + value_info = iter->second.get(); + value_info->SetFlag(flag); + } else { + auto type_proto = utils::TypeProtoFromTensorProto(*tensor_proto); + std::unique_ptr type_info = OrtTypeInfo::FromTypeProto(type_proto); + auto unique_value_info = std::make_unique(ep_graph.get(), initializer_name, std::move(type_info), + flag); + + value_info = unique_value_info.get(); + value_infos_map.emplace(initializer_name, std::move(unique_value_info)); + } + + initializer_value_infos.push_back(value_info); + + // Temporary: Copy onnx::TensorProto into OrtValue objects owned by this EpGraph. + // TODO: Remove this logic once a separate PR that updates onnxruntime::Graph to store initializers as + // OrtValue instances is merged. + auto initializer_value = std::make_unique(); + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), *tensor_proto, + initializer_allocator, *initializer_value)); + initializer_values.emplace(value_info->GetName(), std::move(initializer_value)); + } + + // Process nodes in topological order, converting Node to EpNode. + std::vector> ep_nodes; + IndexToEpNodeMap index_to_ep_node; + NodeIndex min_node_index = std::numeric_limits::max(); + NodeIndex max_node_index = std::numeric_limits::lowest(); + + ep_nodes.reserve(graph_viewer.NumberOfNodes()); + + const std::vector& node_indices = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT); + + for (NodeIndex node_index : node_indices) { + gsl::not_null node = graph_viewer.GetNode(node_index); + std::unique_ptr ep_node = nullptr; + + ORT_RETURN_IF_ERROR(EpNode::Create(*node, ep_graph.get(), value_infos_map, ep_node)); + ep_nodes.push_back(std::move(ep_node)); + + min_node_index = std::min(min_node_index, node->Index()); + max_node_index = std::max(max_node_index, node->Index()); + } + + // Iterate through nodes again and update the map of NodeIndex to EpNode* + index_to_ep_node.Resize(min_node_index, max_node_index); + for (std::unique_ptr& ep_node : ep_nodes) { + index_to_ep_node.SetEpNode(ep_node->GetInternalNode().Index(), ep_node.get()); + } + + // If this is a subgraph, add the OrtValueInfo and OrtValue objects that come from the outer scope. + // Wait until we have already processed OrtValueInfos consumed and produced by nodes so that we only add + // outer OrtValueInfo/OrtValue if they are actually used by the nodes in this GraphViewer. + if (graph_viewer.IsSubgraph()) { + gsl::not_null parent_graph = graph_viewer.GetGraph().ParentGraph(); + gsl::not_null parent_node = graph_viewer.ParentNode(); + + for (gsl::not_null implicit_node_arg : parent_node->ImplicitInputDefs()) { + const std::string& implicit_name = implicit_node_arg->Name(); + auto value_info_iter = value_infos_map.find(implicit_name); + + if (value_info_iter == value_infos_map.end()) { + continue; // Skip. This implicit value is not used by a node in this GraphViewer. + } + + EpValueInfo* outer_value_info = value_info_iter->second.get(); + bool is_constant = false; + const ONNX_NAMESPACE::TensorProto* outer_initializer = parent_graph->GetInitializer(implicit_name, + /*check_outer_scope*/ true, + is_constant); + outer_value_info->SetFlag(EpValueInfo::kIsOuterScope); + + if (outer_initializer != nullptr) { + outer_value_info->SetFlag(is_constant ? EpValueInfo::kIsConstantInitializer : EpValueInfo::kIsOptionalGraphInput); + } + + // Temporary: Copy onnx::TensorProto into OrtValue objects owned by this EpGraph. + // TODO: Remove this logic once a separate PR that updates onnxruntime::Graph to store initializers as + // OrtValue instances is merged. + if (outer_initializer != nullptr) { + auto initializer_value = std::make_unique(); + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), parent_graph->ModelPath(), + *outer_initializer, initializer_allocator, + *initializer_value)); + outer_scope_initializer_values.emplace(outer_value_info->GetName(), std::move(initializer_value)); + } + } + } + + ep_graph->nodes_ = std::move(ep_nodes); + ep_graph->index_to_ep_node_ = std::move(index_to_ep_node); + ep_graph->value_infos_ = std::move(value_infos_map); + ep_graph->initializer_value_infos_ = std::move(initializer_value_infos); + ep_graph->initializer_values_ = std::move(initializer_values); + ep_graph->outer_scope_initializer_values_ = std::move(outer_scope_initializer_values); + ep_graph->inputs_ = std::move(graph_input_value_infos); + ep_graph->outputs_ = std::move(graph_output_value_infos); + + result = std::move(ep_graph); + + return Status::OK(); +} + +const std::string& EpGraph::GetName() const { return graph_viewer_.Name(); } + +int64_t EpGraph::GetOnnxIRVersion() const { return graph_viewer_.GetOnnxIRVersion(); } + +Status EpGraph::GetInputs(std::unique_ptr& result) const { + result = std::make_unique(ORT_TYPE_TAG_OrtValueInfo); + result->storage.reserve(inputs_.size()); + + for (const EpValueInfo* input : inputs_) { + result->storage.push_back(input); + } + + return Status::OK(); +} + +Status EpGraph::GetOutputs(std::unique_ptr& result) const { + result = std::make_unique(ORT_TYPE_TAG_OrtValueInfo); + result->storage.reserve(outputs_.size()); + + for (const EpValueInfo* output : outputs_) { + result->storage.push_back(output); + } + + return Status::OK(); +} + +Status EpGraph::GetInitializers(std::unique_ptr& result) const { + result = std::make_unique(ORT_TYPE_TAG_OrtValueInfo); + result->storage.reserve(initializer_value_infos_.size()); + + for (const EpValueInfo* initializer_value_info : initializer_value_infos_) { + result->storage.push_back(initializer_value_info); + } + + return Status::OK(); +} + +Status EpGraph::GetNodes(std::unique_ptr& result) const { + result = std::make_unique(ORT_TYPE_TAG_OrtNode); + result->storage.reserve(nodes_.size()); + + for (const std::unique_ptr& ep_node : nodes_) { + result->storage.push_back(ep_node->ToExternal()); + } + + return Status::OK(); +} + +Status EpGraph::GetParentNode(const OrtNode*& result) const { + result = parent_node_ != nullptr ? parent_node_->ToExternal() : nullptr; + return Status::OK(); +} + +void EpGraph::SetParentNode(const EpNode* node) { parent_node_ = node; } + +const GraphViewer& EpGraph::GetGraphViewer() const { return graph_viewer_; } + +const EpNode* EpGraph::GetNode(NodeIndex node_index) const { + return index_to_ep_node_.GetEpNode(node_index); +} + +const OrtValue* EpGraph::GetInitializerValue(std::string_view name) const { + // Check for initializer value in the graph's scope. + if (auto iter = initializer_values_.find(name); + iter != initializer_values_.end()) { + return iter->second.get(); + } + + // Check for the initializer value in an outer scope. + // Only finds a value if the outer initializer value is used within this graph. + if (auto iter = outer_scope_initializer_values_.find(name); + iter != outer_scope_initializer_values_.end()) { + return iter->second.get(); + } + + return nullptr; +} +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h new file mode 100644 index 0000000000000..358379a9b5854 --- /dev/null +++ b/onnxruntime/core/graph/ep_api_types.h @@ -0,0 +1,305 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "core/common/inlined_containers.h" +#include "core/framework/abi_pointer_array.h" +#include "core/framework/allocator.h" +#include "core/graph/basic_types.h" +#include "core/graph/abi_graph_types.h" +#include "core/graph/graph_viewer.h" + +namespace onnxruntime { +struct EpGraph; + +/// +/// Concrete implementation of OrtValueInfo used in the OrtEpApi. +/// +struct EpValueInfo : public OrtValueInfo { + public: + enum Flags { + kFlagNone = 0, + kIsRequiredGraphInput = 1 << 0, + kIsOptionalGraphInput = 1 << 1, + kIsGraphOutput = 1 << 2, + kIsConstantInitializer = 1 << 3, + kIsOuterScope = 1 << 4, + }; + + EpValueInfo(const EpGraph* graph, const std::string& name, std::unique_ptr&& type_info, + size_t flags); + + // Defines ToExternal() and ToInternal() functions to convert between OrtValueInfo and EpValueInfo. + DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtValueInfo, EpValueInfo, OrtGraphIrApi::kEpApi) + + // + // Publicly accessible overrides defined by OrtValueInfo. + // + + // Returns the value's name in the graph. + const std::string& GetName() const override { return name_; } + + // Returns the value's type information, which includes both type and shape. + const OrtTypeInfo* GetTypeInfo() const override { return type_info_.get(); } + + // Gets the information (OrtNode and output index) about the node that produces this value. + Status GetProducerInfo(OrtValueInfo::ProducerInfo& producer_info) const override; + + // Gets information (OrtNode and input index) about the consumer nodes that use this value as an input. + // An OrtNode instance may appear multiple times if it uses the value as an input more than once (e.g., Mul(x, x)). + // The input index is set to -1 if the consumer node uses the value as an "implicit input". + Status GetConsumerInfos(std::vector& consumer_infos) const override; + + // Gets the number of ConsumerInfo instances that will be returned by GetConsumerInfos. + Status GetNumConsumerInfos(size_t& num_consumers) const override; + + // Gets the initializer OrtValue associated with this OrtValueInfo. Returns nullptr if this does not + // represent an initializer (either constant or non-constant). + Status GetInitializerValue(const OrtValue*& value) const override; + + // Check if this value is a required graph input. + Status IsRequiredGraphInput(bool& is_required_graph_input) const override; + + // Check if this value is an optional graph input. + Status IsOptionalGraphInput(bool& is_optional_graph_input) const override; + + // Check if this value is a graph output. + Status IsGraphOutput(bool& is_graph_output) const override; + + // Check if this value is a constant initializer. + Status IsConstantInitializer(bool& is_const_initializer) const override; + + // Check if this value is defined in an outer scope (i.e., an outer graph). + Status IsFromOuterScope(bool& is_outer_scope) const override; + + // + // Helper functions used when working directly with an EpValueInfo. + // + + // Helper to set a flag. + void SetFlag(EpValueInfo::Flags flag) { flags_ |= flag; } + + // Helper to check if a flag is set. + bool IsFlagSet(EpValueInfo::Flags flag) const { return flags_ & flag; } + + private: + // Back pointer to parent graph. If not null, enables retrieval of consumer and producer nodes. + // Is null if the EpValueInfo was created without an owning EpGraph + // (e.g., OrtValueInfo instances created for fused nodes in OrtEp::Compile()). + const EpGraph* graph_ = nullptr; + std::string name_; + std::unique_ptr type_info_; + size_t flags_ = 0; +}; + +/// +/// Concrete implementation of OrtNode used in the OrtEpApi. +/// +struct EpNode : public OrtNode { + private: + struct PrivateTag {}; // Used to prevent use of public constructor (use static EpNode::Create()) + // Need to make the constructor public for std::make_unique(). + + struct SubgraphState { + SubgraphState() = default; + SubgraphState(SubgraphState&& other) = default; + std::unique_ptr subgraph_viewer; // The graph_viewer wrapped by EpGraph below. + std::unique_ptr ep_subgraph; + }; + + public: + EpNode(const EpGraph* ep_graph, const Node& node, PrivateTag); + + /// + /// Creates an instance of EpNode, which wraps an onnxruntime::Node. + /// + /// The actual node to wrap. + /// Optional pointer to the parent graph. Set this to a valid graph to be able to get + /// neighboring nodes from this node's input and output OrtValueInfo instances. + /// Cache of all OrtValueInfo instances in the graph. Can be set to an empty + /// std::unordered_map if creating a node without a graph. + /// The new EpNode instance. + /// A Status indicating success or an error. + static Status Create(const Node& node, const EpGraph* ep_graph, + std::unordered_map>& value_infos, + /*out*/ std::unique_ptr& result); + + // Defines ToExternal() and ToInternal() functions to convert between OrtNode and EpNode. + DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtNode, EpNode, OrtGraphIrApi::kEpApi) + + // + // Publicly accessible overrides defined by OrtNode. + // + + // Returns the node's ID (i.e., NodeIndex). + size_t GetId() const override; + + // Returns the node's name. + const std::string& GetName() const override; + + // Returns the node's operator type (e.g., "Conv"). + const std::string& GetOpType() const override; + + // Returns the node's domain name. + const std::string& GetDomain() const override; + + // Gets the opset version in which this node's operator was first defined. + Status GetSinceVersion(int& since_version) const override; + + // Gets the node's inputs as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. + Status GetInputs(std::unique_ptr& inputs) const override; + + // Gets the node's outputs as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. + Status GetOutputs(std::unique_ptr& outputs) const override; + + // Gets the node's implicit inputs as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. + Status GetImplicitInputs(std::unique_ptr& inputs) const override; + + // Gets the subgraphs contained by this node. + Status GetSubgraphs(std::unique_ptr& subgraphs) const override; + + // Gets this node's parent graph, which is the graph that directly contains this node. + Status GetParentGraph(const OrtGraph*& parent_graph) const override; + + // + // Helper functions used when working directly with an EpNode. + // + + // Returns the internal onnxruntime::Node& that this OrtNode wraps. + const Node& GetInternalNode() const { return node_; } + + // Helper that returns this node's inputs as a span of EpValueInfo pointers. + gsl::span GetInputsSpan() const; + + // Helper that returns this node's implicit inputs as a span of EpValueInfo pointers. + gsl::span GetImplicitInputsSpan() const; + + // Helper that returns this node's outputs as a span of EpValueInfo pointers. + gsl::span GetOutputsSpan() const; + + private: + // Back pointer to containing graph. Useful when traversing through nested subgraphs. + // Will be nullptr if the EpNode was created without an owning graph. + // (e.g., OrtNode instances created for fused nodes in OrtEp::Compile()). + const EpGraph* ep_graph_ = nullptr; + const Node& node_; + + InlinedVector inputs_; + InlinedVector outputs_; + + std::vector implicit_inputs_; + std::vector subgraphs_; +}; + +/// +/// Concrete implementation of OrtGraph used in the OrtEpApi. +/// +struct EpGraph : public OrtGraph { + private: + struct PrivateTag {}; // Used to prevent use of public constructor (use static EpGraph::Create()) + // Need to make the constructor public for std::make_unique(). + + // Class that maps a NodeIndex to an EpNode* using a std::vector. + // This is used a lot and should be more efficient than using an unordered_map. + struct IndexToEpNodeMap { + public: + IndexToEpNodeMap() = default; + IndexToEpNodeMap(IndexToEpNodeMap&& other) = default; + IndexToEpNodeMap& operator=(IndexToEpNodeMap&& other) = default; + void Resize(NodeIndex min_node_index, NodeIndex max_node_index); + EpNode* GetEpNode(NodeIndex node_index) const; + void SetEpNode(NodeIndex node_index, EpNode* ep_node); + + private: + NodeIndex min_node_index_ = 0; + std::vector nodes_; + }; + + public: + EpGraph(const GraphViewer& graph_viewer, PrivateTag); + + /// + /// Creates an instance of EpGraph, which wraps a GraphViewer. + /// + /// + /// + /// + static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); + + // Defines ToExternal() and ToInternal() functions to convert between OrtGraph and EpGraph. + DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtGraph, EpGraph, OrtGraphIrApi::kEpApi) + + // + // Publicly accessible overrides defined by OrtGraph. + // + + // Returns the graph's name. + const std::string& GetName() const override; + + // Returns the model's ONNX IR version. + int64_t GetOnnxIRVersion() const override; + + // Gets the graph's inputs as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. + // Includes initializers that are graph inputs. + Status GetInputs(std::unique_ptr& inputs) const override; + + // Gets the graph's outputs as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. + Status GetOutputs(std::unique_ptr& outputs) const override; + + // Gets the graph's initializers as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects. + // Includes both constant initializers and non-constant initializers (aka optional graph inputs). + Status GetInitializers(std::unique_ptr& initializers) const override; + + // Gets the graph's nodes as OrtNode instances wrapped in an OrtArrayOfConstObjects. + // The nodes are sorted in a default "reverse DFS" topological order. + Status GetNodes(std::unique_ptr& nodes) const override; + + // Gets the graph's parent node or nullptr if this is not a nested subgraph. + Status GetParentNode(const OrtNode*& parent_node) const override; + + // + // Helper functions used when working directly with an EpGraph. + // + + // Sets this graph's parent node. + void SetParentNode(const EpNode* node); + + // Returns the onnxruntime::GraphViewer& wrapped by this OrtGraph. + const GraphViewer& GetGraphViewer() const; + + // Returns the EpNode with the given ID (i.e., a NodeIndex). + // Returns nullptr if this graph does not directly contain a node with the given ID. + const EpNode* GetNode(NodeIndex node_index) const; + + // Returns the OrtValue for an OrtValueInfo that represents an initializer. + // Considers both constant and non-constant initializers. + // Supports initializers defined in an outer scope as long as that initializer is used + // within this graph. + const OrtValue* GetInitializerValue(std::string_view name) const; + + private: + const GraphViewer& graph_viewer_; + const EpNode* parent_node_ = nullptr; + + std::vector> nodes_; + IndexToEpNodeMap index_to_ep_node_; + + std::unordered_map> value_infos_; // All value infos in the graph + + std::vector initializer_value_infos_; + std::unordered_map> initializer_values_; + std::unordered_map> outer_scope_initializer_values_; + + InlinedVector inputs_; + InlinedVector outputs_; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/function.cc b/onnxruntime/core/graph/function.cc index 25e666ecb2c65..62e73b24cca14 100644 --- a/onnxruntime/core/graph/function.cc +++ b/onnxruntime/core/graph/function.cc @@ -4,6 +4,7 @@ #include "core/common/inlined_containers.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/function_impl.h" +#include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" @@ -72,24 +73,11 @@ FunctionImpl::FunctionImpl(onnxruntime::Graph& graph, } for (const auto& input : meta_def->inputs) { - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph.GetInitializedTensor(input, initializer)) { - // meta_def->inputs could have duplicates so make sure we only add once - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!function_body_graph_.GetInitializedTensor(input, subgraph_initializer)) { - function_body_graph_.AddInitializedTensor(*initializer); - } - } + graph_utils::MakeInitializerCopyIfNotExist(graph, function_body_graph_, input); } for (const auto& constant_initializer : meta_def->constant_initializers) { - const ONNX_NAMESPACE::TensorProto* initializer = graph.GetConstantInitializer(constant_initializer, true); - ORT_ENFORCE(initializer != nullptr, "Initializer " + constant_initializer + " is not found or is not constant initializer."); - // meta_def->constant_initializers could have duplicates so make sure we only add once - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!function_body_graph_.GetInitializedTensor(constant_initializer, subgraph_initializer)) { - function_body_graph_.AddInitializedTensor(*initializer); - } + graph_utils::MakeConstantInitializerCopyIfNotExist(graph, function_body_graph_, constant_initializer, true); } // TODO: if we reuse the nodes in parent graph, maybe we don't need to resolve it. diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 8b41460ccce21..c9856b9964495 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -172,16 +172,6 @@ static void RemoveInvalidValues(ONNX_NAMESPACE::TypeProto& type) { } } -static TypeProto TypeProtoFromTensorProto(const TensorProto& tensor) { - TypeProto t; - t.mutable_tensor_type()->set_elem_type(tensor.data_type()); - auto shape = t.mutable_tensor_type()->mutable_shape(); - for (auto dim : tensor.dims()) - shape->add_dim()->set_dim_value(dim); - - return t; -} - static std::string GenerateSchemaKey(const IndexedSubGraph& subgraph_ptr) { return MakeString(subgraph_ptr.GetMetaDef()->domain, "_", subgraph_ptr.GetMetaDef()->name, "_", @@ -1249,14 +1239,14 @@ Graph::Graph(const Model& owning_model, if (attrib.type() == AttributeProto_AttributeType_SPARSE_TENSOR) { const TensorProto& sparse_values = node.attribute(0).sparse_tensor().values(); if ((!(sparse_values.has_raw_data())) && tensor->has_raw_data()) { - onnxruntime::utils::ConvertRawDataInTensorProto(tensor); + onnxruntime::utils::ConvertRawDataInTensorProto(*tensor); } } } ORT_ENFORCE(status.IsOK(), status.ToString()); // Ensure initializers are also graph inputs. if (ir_version_ < 4) { - TypeProto t{TypeProtoFromTensorProto(*tensor)}; + TypeProto t{utils::TypeProtoFromTensorProto(*tensor)}; const NodeArg& node_arg = GetOrCreateNodeArg(tensor->name(), &t); *(graph_proto_->add_input()) = node_arg.ToProto(); } @@ -1340,12 +1330,16 @@ Graph::Graph(const Model& owning_model, } NodeArg* matching_graph_input = GetNodeArg(tensor.name()); - TypeProto t{TypeProtoFromTensorProto(tensor)}; + TypeProto t{utils::TypeProtoFromTensorProto(tensor)}; if (!utils::HasElemType(t.tensor_type())) { ORT_THROW("This is an invalid model. Tensor does not have type information."); } + if (tensor.has_data_type() && (tensor.data_type() < TensorProto_DataType_DataType_ARRAYSIZE)) { + weight_data_type_freq_[tensor.data_type()]++; + } + if (ir_version_ < 4) { // initializers can have matching graph inputs but are treated as constant, // so we prefer the shape from the initializer @@ -1434,20 +1428,18 @@ void Graph::InitializeStateFromModelFileGraphProto() { "Graph state to be loaded into must be empty."); // Name to NodeArg mapping of all graph initializers. - std::unordered_map graph_initializers; - - // Name to NodeArg mapping of all graph inputs. - std::unordered_map graph_inputs; - - // Name to NodeArg mapping of all graph node outputs. - std::unordered_map nodes_outputs; - + InlinedHashMap graph_initializers; + graph_initializers.reserve(graph_proto_->initializer_size()); for (auto& initializer : graph_proto_->initializer()) { auto& initializer_name = initializer.name(); auto initializer_arg = GetNodeArg(initializer_name); graph_initializers.insert({initializer_name, initializer_arg}); } + // Name to NodeArg mapping of all graph inputs. + InlinedHashMap graph_inputs; + graph_inputs.reserve(graph_proto_->input_size()); + // Set graph inputs. // contains inputs exactly specified in proto. // contains inputs without default value (specified as initializer). @@ -1462,6 +1454,9 @@ void Graph::InitializeStateFromModelFileGraphProto() { } } + // Name to NodeArg mapping of all graph node outputs. + InlinedHashMap nodes_outputs; + nodes_outputs.reserve(graph_proto_->node_size() * 2); // rough estimate for (const auto& node : Nodes()) { for (const auto* output_def : node.OutputDefs()) { nodes_outputs.insert({output_def->Name(), output_def}); @@ -3411,18 +3406,28 @@ bool Graph::ResolveContext::IsOuterScopeValue(const std::string& name) const { #endif // !defined(ORT_MINIMAL_BUILD) #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + void Graph::AddInitializedTensor(const TensorProto& tensor) { auto existing = name_to_initial_tensor_.find(tensor.name()); - if (existing != name_to_initial_tensor_.cend()) { + const bool exists = existing != name_to_initial_tensor_.cend(); + if (exists) { ORT_ENFORCE(existing->second == &tensor, "AddInitializedTensor already has tensor with name ", tensor.name(), " but different TensorProto."); return; } + // This overload is used when the tensor does not point to an OrtValue which + // would need to be updated, but it is okay if it is pointing to flatbuffers or some other place at the moment. + // However, if an ort_value present for the name, it must be replaced. + if (utils::HasExternalDataInMemory(tensor)) { + if (ortvalue_initializers_.count(tensor.name()) > 0) { + ORT_THROW("OrtValue needs to be inserted. Use the overload that takes both TensorProto and OrtValue with data"); + } + } const gsl::not_null tensor_added{graph_proto_->add_initializer()}; *(tensor_added) = tensor; name_to_initial_tensor_.emplace(tensor.name(), tensor_added); - SetGraphResolveNeeded(); + if (!is_loaded_from_model_file_ && GetNodeArg(tensor.name()) == nullptr) { // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs. // the shape will be set to the correct value in TypeCheckInputsAndInitializers as we don't yet know whether there @@ -3431,6 +3436,45 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) { t.mutable_tensor_type()->set_elem_type(tensor.data_type()); ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor.name(), &t)); } + + SetGraphResolveNeeded(); +} + +Status Graph::AddInitializedOrtValue(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const OrtValue& ortvalue_initializer) { + ORT_RETURN_IF(name_to_initial_tensor_.count(tensor_proto.name()) > 0, "Attempt to replace the existing tensor"); + + const gsl::not_null tensor_added{graph_proto_->add_initializer()}; + *(tensor_added) = tensor_proto; + name_to_initial_tensor_.emplace(tensor_proto.name(), tensor_added); + + if (ortvalue_initializer.IsAllocated()) { + ORT_RETURN_IF_NOT(utils::HasExternalDataInMemory(tensor_proto), + "TensorProto is expected to refer to the ortvalue_initializer"); + const auto element_type = static_cast(utils::GetTensorElementType(tensor_proto)); + const auto& tensor = ortvalue_initializer.Get(); + ORT_RETURN_IF_NOT(tensor.GetElementType() == element_type, + "Element type mismatch between tensor proto and ortvalue_initializer"); + const auto proto_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); + ORT_RETURN_IF_NOT(proto_shape == tensor.Shape(), "Shape mismatch with ortvalue_initializer"); + + ortvalue_initializers_.insert_or_assign(tensor_proto.name(), ortvalue_initializer); + } else { + ORT_ENFORCE(ortvalue_initializers_.count(tensor_proto.name()) == 0, + "Stray leftover ort_value for a small initializer being inserted."); + } + + SetGraphResolveNeeded(); + if (!is_loaded_from_model_file_ && GetNodeArg(tensor_proto.name()) == nullptr) { + // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs. + // the shape will be set to the correct value in TypeCheckInputsAndInitializers as we don't yet know whether there + // will be a matching graph input for this initializer (we prefer shape info from the graph input). + TypeProto t; + t.mutable_tensor_type()->set_elem_type(tensor_proto.data_type()); + ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor_proto.name(), &t)); + } + + return Status::OK(); } void Graph::FindAllSubgraphs(std::vector& subgraphs) { @@ -3538,7 +3582,8 @@ void Graph::RemoveInitializedTensor(const std::string& tensor_name) { } #if !defined(ORT_MINIMAL_BUILD) -Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, bool is_external) { +Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, + OrtValue ort_value, bool must_replace_external) { // name_to_initial_tensor_ maps from name to const TensorProto*, so we first // look up the const pointer by name, then find and modify the mutable // pointed-to TensorProto in graph_proto_. @@ -3557,9 +3602,16 @@ Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initi return true; }; - ORT_RETURN_IF_NOT(!is_external || utils::HasExternalData(old_initializer), + // This check ensures that we are replacing the right initializer than the users wants to + // replace data that is on disk with a reference to data in memory. + ORT_RETURN_IF_NOT(!must_replace_external || utils::HasExternalData(old_initializer), "Trying to replace non-external initializer with external data"); + // New initializers data generally are within OrtValues + // Small initializers are still stored inside TensorProto + ORT_RETURN_IF_NOT(utils::HasExternalDataInMemory(new_initializer) || !ort_value.IsAllocated(), + "All TensorProtos are expected to point to an OrtValue"); + ORT_RETURN_IF_NOT(dims_eq(), "Replacement tensor's dimensions do not match."); ORT_RETURN_IF_NOT(old_initializer.data_type() == new_initializer.data_type(), "Replacement tensor's data type does not match."); @@ -3573,22 +3625,50 @@ Status Graph::ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initi ORT_ENFORCE(existing_entry != mutable_initializers.pointer_end(), "graph_proto_ is not in sync with name_to_initial_tensor_"); + if (ort_value.IsAllocated()) { + ORT_IGNORE_RETURN_VALUE(ortvalue_initializers_.insert_or_assign(initializer_name, std::move(ort_value))); + } else { + ORT_IGNORE_RETURN_VALUE(ortvalue_initializers_.erase(initializer_name)); + } + **existing_entry = std::move(new_initializer); return Status::OK(); } -Status Graph::ReplaceInitializedTensor(ONNX_NAMESPACE::TensorProto new_initializer) { - return ReplaceInitializedTensorImpl(std::move(new_initializer), false); +common::Status Graph::ReplaceInitializedTensor(const ONNX_NAMESPACE::TensorProto& new_initializer, + const OrtValue& ort_value) { + return ReplaceInitializedTensorImpl(new_initializer, ort_value, false); } #if !defined(DISABLE_EXTERNAL_INITIALIZERS) Status Graph::InjectExternalInitializedTensors(const InlinedHashMap& external_initializers) { - for (const auto& e : external_initializers) { - const auto& name = e.first; - const OrtValue& ort_value = e.second; - auto tensor_proto = utils::TensorToTensorProto(ort_value.Get(), name); - ORT_RETURN_IF_ERROR(ReplaceInitializedTensorImpl(std::move(tensor_proto), true)); + for (const auto& [name, value] : external_initializers) { + const auto& user_tensor = value.Get(); + + OrtValue ort_value; + TensorProto tensor_proto; + constexpr const bool use_tensor_buffer_true = true; + if (user_tensor.SizeInBytes() > utils::kSmallTensorExternalDataThreshold) { + if (user_tensor.OwnsBuffer()) { + // If the user tensor has its own memory, we avoid copying + tensor_proto = utils::TensorToTensorProto(user_tensor, name, use_tensor_buffer_true); + ORT_ENFORCE(utils::HasExternalDataInMemory(tensor_proto), "Expecting this tensor_proto to have a pointer"); + ort_value = value; + } else { + Tensor initializer{user_tensor.DataType(), user_tensor.Shape(), CPUAllocator::DefaultInstance()}; + utils::MakeCpuTensorCopy(user_tensor, initializer); + + tensor_proto = utils::TensorToTensorProto(initializer, name, use_tensor_buffer_true); + ORT_ENFORCE(utils::HasExternalDataInMemory(tensor_proto), "Expecting this tensor_proto to have a pointer"); + Tensor::InitOrtValue(std::move(initializer), ort_value); + } + } else { + constexpr const bool use_tensor_buffer_false = false; + tensor_proto = utils::TensorToTensorProto(user_tensor, name, use_tensor_buffer_false); + } + + ORT_RETURN_IF_ERROR(ReplaceInitializedTensorImpl(std::move(tensor_proto), std::move(ort_value), true)); LOGS(logger_, INFO) << "Replaced external initializer: " << name; } return Status::OK(); @@ -3598,14 +3678,14 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( const InlinedHashMap>& external_initializer_files) { for (const auto& [tensor_name, tensor_proto] : name_to_initial_tensor_) { if (tensor_proto->data_location() == TensorProto_DataLocation_EXTERNAL) { - std::unique_ptr external_data_info; - ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto->external_data(), external_data_info)); + std::unique_ptr external_data_info; + ORT_RETURN_IF_ERROR(ExternalDataInfo::Create(tensor_proto->external_data(), external_data_info)); const auto& external_file = external_data_info->GetRelPath(); onnxruntime::FileOffsetType file_offset = external_data_info->GetOffset(); const size_t external_data_length = external_data_info->GetLength(); SafeInt tensor_byte_size; - ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(*tensor_proto, &tensor_byte_size)); + ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(*tensor_proto, &tensor_byte_size)); ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size, "TensorProto: ", tensor_name, " external data size mismatch. Computed size: ", *&tensor_byte_size, ", external_data.length: ", external_data_length); @@ -3641,7 +3721,17 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(old_initializer); auto tensor = Tensor(type, tensor_shape, tensor_buffer, OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)); - auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name); + + constexpr const bool use_tensor_buffer_true = true; + auto new_tensor_proto = utils::TensorToTensorProto(tensor, tensor_name, use_tensor_buffer_true); + // Implied that external data is in memory + const bool has_external_data_in_memory = utils::HasExternalData(new_tensor_proto); + + OrtValue ort_value; + if (has_external_data_in_memory) { + Tensor::InitOrtValue(std::move(tensor), ort_value); + } + ortvalue_initializers_.insert_or_assign(tensor_name, std::move(ort_value)); **existing_entry = std::move(new_tensor_proto); } } @@ -3662,14 +3752,24 @@ bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorPro return true; } -bool Graph::GetOrtValueInitializer(const std::string& name, OrtValue& value) const { - auto it = ortvalue_initializers_.find(name); - if (it == ortvalue_initializers_.end()) { - return false; +bool Graph::GetOrtValueInitializer(const std::string& name, OrtValue& value, bool check_outer_scope) const { + // We want to make sure that the ort_value is found on the same level as its tensor_proto + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (GetInitializedTensor(name, initializer)) { + auto it = ortvalue_initializers_.find(name); + if (it != ortvalue_initializers_.end()) { + value = it->second; + return true; + } } - value = it->second; - return true; + if (check_outer_scope && IsSubgraph()) { + if (IsOuterScopeValue(name)) { + // make sure there's not a local value with the same name. if there is it shadows any initializer in outer scope. + return parent_graph_->GetOrtValueInitializer(name, value, check_outer_scope); + } + } + return false; } void Graph::CleanAllInitializedTensors() noexcept { @@ -3731,6 +3831,29 @@ const ONNX_NAMESPACE::TensorProto* Graph::GetInitializer(const std::string& init return initializer; } +const ONNX_NAMESPACE::TensorProto* Graph::GetInitializer(const std::string& initializer_name, bool check_outer_scope, + bool& is_constant) const { + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (GetInitializedTensor(initializer_name, initializer)) { + if (CanOverrideInitializer()) { + const auto& graph_inputs = GetInputsIncludingInitializers(); + is_constant = std::none_of(graph_inputs.cbegin(), graph_inputs.cend(), + [&initializer_name](const NodeArg* input) { + return input->Name() == initializer_name; + }); + } else { + is_constant = true; + } + } else if (check_outer_scope && IsSubgraph()) { + // make sure there's not a local value with the same name. if there is it shadows any initializer in outer scope. + if (IsOuterScopeValue(initializer_name)) { + initializer = parent_graph_->GetInitializer(initializer_name, check_outer_scope, is_constant); + } + } + + return initializer; +} + #if !defined(ORT_MINIMAL_BUILD) void Graph::AddValueInfo(const NodeArg* new_value_info) { NodeArg* node_arg = GetNodeArg(new_value_info->Name()); @@ -3821,12 +3944,6 @@ SaveInputsOutputsToOrtFormat(flatbuffers::FlatBufferBuilder& builder, const std: common::Status Graph::SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, flatbuffers::Offset& fbs_graph) const { - if constexpr (endian::native != endian::little) { - auto& tens = GetAllInitializedTensors(); - for (auto& [name, tensor_p] : tens) { - utils::ConvertRawDataInTensorProto(const_cast(tensor_p)); - } - } auto inputs = SaveInputsOutputsToOrtFormat(builder, graph_inputs_including_initializers_); auto outputs = SaveInputsOutputsToOrtFormat(builder, graph_outputs_); @@ -4111,7 +4228,7 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const { output = initializer; // copy any in-memory external data into raw data - if (utils::HasExternalData(initializer)) { + if (utils::HasExternalDataInMemory(initializer)) { const std::filesystem::path ignored; std::basic_string location; onnxruntime::FileOffsetType file_offset; @@ -4119,14 +4236,12 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const { ORT_THROW_IF_ERROR(utils::GetExternalDataInfo(initializer, ignored, location, file_offset, tensor_byte_size)); - if (location == onnxruntime::utils::kTensorProtoMemoryAddressTag) { - // file_offset is address - void* data = reinterpret_cast(file_offset); + // file_offset is address + void* data = reinterpret_cast(file_offset); - // set in raw data - output.clear_data_location(); - output.set_raw_data(data, tensor_byte_size); - } + // set in raw data + output.clear_data_location(); + output.set_raw_data(data, tensor_byte_size); } }; @@ -4934,7 +5049,7 @@ Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& nod ORT_ENFORCE(insert_result.second, "Constant node name: ", tensor->name(), " conflicts with graph initializer. Check that the node names have been made unique."); if (GetNodeArg(tensor->name()) == nullptr) { - TypeProto t{TypeProtoFromTensorProto(*tensor)}; + TypeProto t{utils::TypeProtoFromTensorProto(*tensor)}; ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor->name(), &t)); } @@ -5385,7 +5500,7 @@ Status Graph::InlineFunction(Node& callnode) { ORT_ENFORCE(insert_result.second, "Initializer name: ", tensor->name(), " in inlined subgraph: ", subgraph.Name(), " conflicts with graph initializer. Check Specializing code."); if (GetNodeArg(tensor->name()) == nullptr) { - TypeProto t{TypeProtoFromTensorProto(*tensor)}; + TypeProto t{utils::TypeProtoFromTensorProto(*tensor)}; ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor->name(), &t)); } } @@ -5747,7 +5862,7 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph #if !defined(ORT_MINIMAL_BUILD) namespace { -ValueInfoProto OrtValueInfoToOnnx(const OrtValueInfo& vi) { +ValueInfoProto ModelEditorValueInfoToOnnx(const onnxruntime::ModelEditorValueInfo& vi) { // the model builder API checks that the OrtValueInfo has a complete and valid OrtTypeInfo instance and that the // name is not null/empty. ORT_ENFORCE(vi.type_info->type == ONNX_TYPE_TENSOR, @@ -5789,7 +5904,7 @@ Status Graph::LoadFromModelEditorApiModel(const OrtGraph& api_graph, bool updati // NodeArg for the value using that auto add_graph_inputs_outputs = [&, this]( - const InlinedVector>& graph_inputs_or_outputs, + const InlinedVector>& graph_inputs_or_outputs, bool is_input) { // when updating a model we don't require the inputs or outputs to be set if they're unchanged. if (updating_existing_graph && graph_inputs_or_outputs.empty()) { @@ -5799,7 +5914,7 @@ Status Graph::LoadFromModelEditorApiModel(const OrtGraph& api_graph, bool updati std::vector node_args; node_args.reserve(graph_inputs_or_outputs.size()); for (auto& ort_value_info : graph_inputs_or_outputs) { - ValueInfoProto value_info = OrtValueInfoToOnnx(*ort_value_info); + ValueInfoProto value_info = ModelEditorValueInfoToOnnx(*ort_value_info); name_to_type_map[value_info.name()] = value_info.type(); node_args.push_back(&GetOrCreateNodeArg(value_info.name(), &value_info.type())); @@ -5843,7 +5958,7 @@ Status Graph::LoadFromModelEditorApiModel(const OrtGraph& api_graph, bool updati tensor_proto.set_raw_data(t.DataRaw(), t.SizeInBytes()); } - TypeProto type_proto{TypeProtoFromTensorProto(tensor_proto)}; + TypeProto type_proto{utils::TypeProtoFromTensorProto(tensor_proto)}; ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(name, &type_proto)); name_to_initial_tensor_.emplace(name, &tensor_proto); @@ -5852,19 +5967,22 @@ Status Graph::LoadFromModelEditorApiModel(const OrtGraph& api_graph, bool updati // process graph inputs first as we want the type/shape from them to be preferred if a graph input // has a matching initializer - add_graph_inputs_outputs(api_graph.inputs, /*input*/ true); + const auto* editor_graph = onnxruntime::ModelEditorGraph::ToInternal(&api_graph); + ORT_RETURN_IF(editor_graph == nullptr, "Invalid OrtGraph variant for use in the model editor API."); + + add_graph_inputs_outputs(editor_graph->inputs, /*input*/ true); // add initializers - ortvalue_initializers_.reserve(api_graph.external_initializers.size()); - add_initializers(api_graph.external_initializers, /*is_external*/ true); - add_initializers(api_graph.initializers, /*is_external*/ false); + ortvalue_initializers_.reserve(editor_graph->external_initializers.size()); + add_initializers(editor_graph->external_initializers, /*is_external*/ true); + add_initializers(editor_graph->initializers, /*is_external*/ false); // add graph outputs - add_graph_inputs_outputs(api_graph.outputs, /*input*/ false); + add_graph_inputs_outputs(editor_graph->outputs, /*input*/ false); // add nodes - for (const auto& ort_node : api_graph.nodes) { - const OrtNode& node = *ort_node; + for (const auto& editor_node : editor_graph->nodes) { + const onnxruntime::ModelEditorNode& node = *editor_node; // convert Constant nodes to initializers if (node.operator_name == "Constant" && node.domain_name == kOnnxDomain) { diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index 199aa79cc1dde..6c27bacacf9c2 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -50,7 +50,16 @@ Status SaveInitializerOrtFormat(flatbuffers::FlatBufferBuilder& builder, string_data = builder.CreateVectorOfStrings(string_data_vec); } else { std::vector unpacked_tensor; - ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, model_path, unpacked_tensor)); + // We can not convert this in place, because the session may be used + // after the model was saved in ort format. If the session is continued to be used, then + // we continue with initializers in memory with wrong endianess + if constexpr (endian::native != endian::little) { + auto be_copy{initializer}; + onnxruntime::utils::ConvertRawDataInTensorProto(be_copy); + ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(be_copy, model_path, unpacked_tensor)); + } else { + ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, model_path, unpacked_tensor)); + } if (external_writer && unpacked_tensor.size() >= kMinimumSizeForExternalData) { // write bytes to external buffer/file and record offset for the start of the data diff --git a/onnxruntime/core/graph/graph_proto_serializer.cc b/onnxruntime/core/graph/graph_proto_serializer.cc index eb0fb22346f37..80bb3f13814d1 100644 --- a/onnxruntime/core/graph/graph_proto_serializer.cc +++ b/onnxruntime/core/graph/graph_proto_serializer.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/common/inlined_containers.h" +#include "core/framework/tensorprotoutils.h" #include "core/graph/graph_proto_serializer.h" namespace onnxruntime { @@ -21,13 +23,13 @@ void GraphViewerToProto(const GraphViewer& graph_view, *(graph_proto.mutable_output()->Add()) = output_arg->ToProto(); } - std::unordered_set value_info_ = graph_view.GetValueInfo(); + const auto& value_infos = graph_view.GetValueInfo(); // Reserve memory for the vector to avoid reallocations - std::vector value_info_sorted; - value_info_sorted.reserve(value_info_.size()); + InlinedVector value_info_sorted; + value_info_sorted.reserve(value_infos.size()); + value_info_sorted.assign(value_infos.begin(), value_infos.end()); - value_info_sorted.assign(value_info_.begin(), value_info_.end()); auto sort_predicate = [](const NodeArg* v1, const NodeArg* v2) { return v1->Name() < v2->Name(); }; @@ -58,21 +60,39 @@ void GraphViewerToProto(const GraphViewer& graph_view, } if (include_initializer) { - std::unordered_set current_scope_initializer_set; - - auto& initializers = graph_view.GetAllInitializedTensors(); + const auto& initializers = graph_view.GetAllInitializedTensors(); // Sort initializers to maintain consistency in model proto created across inference requests - std::vector const_inits; - for (auto& it : initializers) { - const_inits.push_back(it.first); + InlinedVector const_inits; + const_inits.reserve(initializers.size()); + for (auto it = initializers.cbegin(), end = initializers.cend(); it != end; ++it) { + const_inits.push_back(it); } - std::sort(const_inits.begin(), const_inits.end()); + std::sort(const_inits.begin(), const_inits.end(), [](const auto& i1, const auto& i2) { + return i1->first < i2->first; + }); + + InlinedHashSet current_scope_initializer_set; + current_scope_initializer_set.reserve(const_inits.size()); + + auto get_initializer_with_data = [&](const ONNX_NAMESPACE::TensorProto& init, + ONNX_NAMESPACE::TensorProto& dest) -> Status { + std::unique_ptr full_init; + ORT_RETURN_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(init, full_init)); + if (full_init) { + dest = std::move(*full_init); + } else { + dest = init; + } + return Status::OK(); + }; - for (auto& it : const_inits) { + // Handle this scope initializers + for (const auto& it : const_inits) { + const auto& [name, init] = *it; + current_scope_initializer_set.insert(name); auto* p_initializer = graph_proto.add_initializer(); - *p_initializer = *(initializers.at(it)); - current_scope_initializer_set.insert(it); + ORT_THROW_IF_ERROR(get_initializer_with_data(*init, *p_initializer)); } // handle outer scope value which is a constant initializer @@ -80,13 +100,15 @@ void GraphViewerToProto(const GraphViewer& graph_view, for (auto& node_idx : graph_view.GetNodesInTopologicalOrder(order)) { const auto& node = graph_view.GetNode(node_idx); for (const auto& input : node->InputDefs()) { - if (current_scope_initializer_set.find(input->Name()) != current_scope_initializer_set.end()) { + if (current_scope_initializer_set.count(std::string_view{input->Name()}) > 0) { continue; } - if (graph_view.IsConstantInitializer(input->Name(), true)) { - auto* p_initializer = graph_proto.add_initializer(); - *p_initializer = *(graph_view.GetConstantInitializer(input->Name(), true)); + + const auto* outer_scope_init = graph_view.GetConstantInitializer(input->Name(), true); + if (outer_scope_init != nullptr) { current_scope_initializer_set.insert(input->Name()); + auto* p_initializer = graph_proto.add_initializer(); + ORT_THROW_IF_ERROR(get_initializer_with_data(*outer_scope_init, *p_initializer)); } } } diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index cc48df4444951..dcf627fc605f4 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -2,6 +2,8 @@ // Licensed under the MIT License. #include "core/graph/graph_utils.h" + +#include "core/framework/tensorprotoutils.h" #include "core/graph/graph.h" #include "core/common/logging/logging.h" @@ -249,6 +251,19 @@ const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const s return iter == attrs.end() ? nullptr : &iter->second; } +static NodeArg& GetOrCreateNodeArg(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) { + ONNX_NAMESPACE::TypeProto new_type; + auto* typeproto_tensor = new_type.mutable_tensor_type(); + typeproto_tensor->set_elem_type(new_initializer.data_type()); + + auto* shape = typeproto_tensor->mutable_shape(); + for (auto dim : new_initializer.dims()) { + shape->add_dim()->set_dim_value(dim); + } + + return graph.GetOrCreateNodeArg(new_initializer.name(), &new_type); +} + NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) { // sanity check as AddInitializedTensor silently ignores attempts to add a duplicate initializer const ONNX_NAMESPACE::TensorProto* existing = nullptr; @@ -256,17 +271,91 @@ NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_ini "Initializer with same name exists. Name:", new_initializer.name()); graph.AddInitializedTensor(new_initializer); + return GetOrCreateNodeArg(graph, new_initializer); +} - ONNX_NAMESPACE::TypeProto new_type; - auto* typeproto_tensor = new_type.mutable_tensor_type(); - typeproto_tensor->set_elem_type(new_initializer.data_type()); +NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) { + ORT_ENFORCE(!utils::HasExternalData(new_initializer), "Expecting an initializer that contains data inline"); - auto* shape = typeproto_tensor->mutable_shape(); - for (auto dim : new_initializer.dims()) { - shape->add_dim()->set_dim_value(dim); + Tensor tensor; + ORT_THROW_IF_ERROR(utils::CreateTensorFromTensorProto(Env::Default(), graph.ModelPath(), + new_initializer, tensor)); + auto tensor_proto_with_ptr = utils::TensorToTensorProto(tensor, new_initializer.name(), true); + return AddInitializerWithExternalData(graph, tensor_proto_with_ptr, std::move(tensor)); +} + +NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer, + Tensor&& tensor) { + OrtValue ort_value; + if (utils::HasExternalDataInMemory(new_initializer)) { + Tensor::InitOrtValue(std::move(tensor), ort_value); } - return graph.GetOrCreateNodeArg(new_initializer.name(), &new_type); + ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(new_initializer, ort_value)); + return GetOrCreateNodeArg(graph, new_initializer); +} + +NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer, + OrtValue ort_value) { + ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(new_initializer, ort_value)); + return GetOrCreateNodeArg(graph, new_initializer); +} + +void MakeInitializerCopyIfNotExist(const Graph& src_graph, Graph& dst_graph, const std::string& name, + bool copy_in_memory_data) { + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (src_graph.GetInitializedTensor(name, initializer)) { + // check if the initializer already exists in the destination graph + const ONNX_NAMESPACE::TensorProto* existing = nullptr; + if (!dst_graph.GetInitializedTensor(name, existing)) { + const bool data_in_memory = utils::HasExternalDataInMemory(*initializer); + if (data_in_memory) { + if (copy_in_memory_data) { + ONNX_NAMESPACE::TensorProto tensor_proto; + ORT_THROW_IF_ERROR(utils::TensorProtoWithExternalDataToTensorProto(*initializer, {}, tensor_proto)); + dst_graph.AddInitializedTensor(tensor_proto); + GetOrCreateNodeArg(dst_graph, tensor_proto); + } else { + OrtValue ort_value; + if (src_graph.GetOrtValueInitializer(name, ort_value)) { + // add the initializer to the destination graph + ORT_THROW_IF_ERROR(dst_graph.AddInitializedOrtValue(*initializer, ort_value)); + } else { + // Data may be in memory, but stored in flatbuffers etc. + dst_graph.AddInitializedTensor(*initializer); + } + GetOrCreateNodeArg(dst_graph, *initializer); + } + } else { + dst_graph.AddInitializedTensor(*initializer); + GetOrCreateNodeArg(dst_graph, *initializer); + } + } + } +} + +void MakeConstantInitializerCopyIfNotExist(const Graph& src_graph, Graph& dst_graph, + const std::string& name, bool check_outer_scope) { + const auto* initializer = src_graph.GetConstantInitializer(name, check_outer_scope); + if (initializer != nullptr) { + const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; + if (!dst_graph.GetInitializedTensor(name, subgraph_initializer)) { + OrtValue ort_value; + ORT_IGNORE_RETURN_VALUE(src_graph.GetOrtValueInitializer(name, ort_value, check_outer_scope)); + ORT_THROW_IF_ERROR(dst_graph.AddInitializedOrtValue(*initializer, ort_value)); + } + } +} + +Status ConvertInMemoryDataToInline(Graph& graph, const std::string& name) { + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (graph.GetInitializedTensor(name, initializer) && utils::HasExternalDataInMemory(*initializer)) { + ONNX_NAMESPACE::TensorProto tensor_proto; + ORT_THROW_IF_ERROR(utils::TensorProtoWithExternalDataToTensorProto(*initializer, {}, tensor_proto)); + graph.RemoveInitializedTensor(name); + graph.AddInitializedTensor(tensor_proto); + } + return Status::OK(); } int GetNodeOutputIndexFromOutputName(const Node& node, const std::string& output_name) { diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index 8710519cdc865..033488d734bd5 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -38,6 +38,71 @@ Checks that new_initializer does not already exist in 'graph' before adding it. */ NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer); +/// +/// Adds a new initializer to 'graph' with new_initializer that points to the OrtValue buffer +/// +/// target graph +/// TensorProto with external data contained in ort_value +/// ort_value with data +/// +NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer, + OrtValue ort_value); + +/** Add a new initializer to 'graph'. + * Checks that new_initializer does not already exist in 'graph' before adding it. + * @param new_initializer tensor proto that has external data pointing to data within the tensor. + * @param tensor with data + * @returns The NodeArg for the new initializer. + * @remarks No matching graph input is created, so the initializer will be constant. + */ +NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer, Tensor&& tensor); + +/** Add a new initializer to 'graph'. + * The function unpacks data into a tensor and converts new_initializer to a TensorProto with external data in memory. + * The initializer is then added to the graph and tensor is wrapped into OrtValue and added to + * Graph::ortvalue_initializers_; + * + * @param graph The graph to which the initializer will be added. + * @param new_initializer tensor proto that actually has data in it + * @returns The NodeArg for the new initializer. + * @remarks No matching graph input is created, so the initializer will be constant. + */ +NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer); + +/// +/// If the initializer with the given name does not exist in the destination graph, but exists in the +/// source graph, copy it to the destination graph. +/// +/// source graph s +/// destination +/// initializers name +/// if external data is in memory, copy data inline. +/// default is false. This is to accomodate EPs who load initializers on their own and do not understand +/// our /*/_ORT_MEM_ADDR_/*/ external data reference +void MakeInitializerCopyIfNotExist(const Graph& src_graph, Graph& dst_graph, const std::string& name, + bool copy_in_memory_data = false); + +/// +/// If the constant initializer with the given name does not exist in the destination graph, but exists in the +/// source graph, copy it to the destination graph along with its OrtValue if present. +/// +/// +/// +/// +/// checks outerscope if true +void MakeConstantInitializerCopyIfNotExist(const Graph& src_graph, Graph& dst_graph, + const std::string& name, bool check_outer_scope); + +/// +/// If the initializer is present with the graph and has external data in memory, +/// convert it to inline data. This is necessary for EPs that can not handle +/// external initializers that are in memory since our in-memory external data is not ONNX standard. +/// +/// Graph +/// intializer name +/// Status +Status ConvertInMemoryDataToInline(Graph& graph, const std::string& name); + /** Gets the index of an output arg with the specified output arg name. */ int GetNodeOutputIndexFromOutputName(const Node& node, const std::string& output_name); diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index d72bd13093b61..5860269193b94 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -1,24 +1,130 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include + #include "core/common/inlined_containers_fwd.h" #include "core/framework/ort_value.h" -#include "core/framework/onnxruntime_typeinfo.h" +#include "core/graph/abi_graph_types.h" #include "core/graph/onnx_protobuf.h" -// ORT C interface types for OrtGraphApi can't be in a namespace. -// We need to define them here so onnxruntime::Model can be created from OrtModel. +namespace onnxruntime { + +/// +/// Concrete implementation of OrtValueInfo used in the ModelEditorApi. +/// +struct ModelEditorValueInfo : public OrtValueInfo { + ModelEditorValueInfo() : OrtValueInfo(OrtGraphIrApi::kModelEditorApi) {} + + // Defines ToExternal() and ToInternal() functions to convert between OrtValueInfo and ModelEditorValueInfo. + DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtValueInfo, ModelEditorValueInfo, OrtGraphIrApi::kModelEditorApi) + + const std::string& GetName() const override { return name; } + + const OrtTypeInfo* GetTypeInfo() const override { return type_info.get(); } + + Status GetProducerInfo(ProducerInfo& /*producer_info*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the producer for OrtValueInfo"); + } + + Status GetConsumerInfos(std::vector& /*consumer_infos*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the consumers for a OrtValueInfo"); + } + + Status GetNumConsumerInfos(size_t& /*num_consumers*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the number of consumers for a OrtValueInfo"); + } + + Status GetInitializerValue(const OrtValue*& /*value*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the initializer value for a OrtValueInfo"); + } + + Status IsRequiredGraphInput(bool& /*is_required_graph_input*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support querying if a graph input is required for OrtValueInfo"); + } + + Status IsOptionalGraphInput(bool& /*is_optional_graph_input*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support querying if OrtValueInfo is an optional graph input."); + } + + Status IsGraphOutput(bool& /*is_graph_output*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support querying if a OrtValueInfo is a graph output."); + } + + Status IsConstantInitializer(bool& /*is_const_initializer*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support querying if a OrtValueInfo is a constant initializer."); + } + + Status IsFromOuterScope(bool& /*is_outer_scope*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support querying if a OrtValueInfo is defined in an outer scope."); + } -struct OrtValueInfo { std::string name; std::unique_ptr type_info; }; -struct OrtOpAttr { - ONNX_NAMESPACE::AttributeProto attr_proto; -}; +/// +/// Concrete implementation of OrtNode used in the ModelEditorApi. +/// +struct ModelEditorNode : public OrtNode { + ModelEditorNode() : OrtNode(OrtGraphIrApi::kModelEditorApi) {} + + // Defines ToExternal() and ToInternal() functions to convert between OrtNode and ModelEditorNode. + DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtNode, ModelEditorNode, OrtGraphIrApi::kModelEditorApi) + + size_t GetId() const override { return id; } + + const std::string& GetName() const override { return node_name; } -struct OrtNode { + const std::string& GetOpType() const override { return operator_name; } + + const std::string& GetDomain() const override { return domain_name; } + + Status GetSinceVersion(int& /*since_version*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting an OrtNode's opset version"); + } + + Status GetInputs(std::unique_ptr& /*inputs*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting input OrtValueInfos for OrtNode"); + } + + Status GetOutputs(std::unique_ptr& /*outputs*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting output OrtValueInfos for OrtNode"); + } + + Status GetImplicitInputs(std::unique_ptr& /*implicit_inputs*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the implicit inputs for OrtNode"); + } + + Status GetSubgraphs(std::unique_ptr& /*subgraphs*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); + } + + Status GetParentGraph(const OrtGraph*& /*parent_graph*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the parent graph for OrtNode"); + } + + size_t id = 0; std::string operator_name; std::string domain_name; std::string node_name; @@ -33,14 +139,56 @@ struct OrtNode { // std::unordered_map subgraphs; }; -struct OrtGraph { - onnxruntime::InlinedVector> inputs; - onnxruntime::InlinedVector> outputs; +/// +/// Concrete implementation of OrtGraph used in the ModelEditorApi. +/// +struct ModelEditorGraph : public OrtGraph { + ModelEditorGraph() : OrtGraph(OrtGraphIrApi::kModelEditorApi) {} + + // Defines ToExternal() and ToInternal() functions to convert between OrtGraph and ModelEditorGraph. + DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtGraph, ModelEditorGraph, OrtGraphIrApi::kModelEditorApi) + + const std::string& GetName() const override { return name; } + + int64_t GetOnnxIRVersion() const override { + return ONNX_NAMESPACE::Version::IR_VERSION; + } + + Status GetInputs(std::unique_ptr& /*result*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the graph inputs."); + } + + Status GetOutputs(std::unique_ptr& /*result*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the graph outputs."); + } + + Status GetInitializers(std::unique_ptr& /*result*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the graph initializers."); + } + + Status GetNodes(std::unique_ptr& /*result*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the graph nodes."); + } + + Status GetParentNode(const OrtNode*& /*parent_node*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the parent node for OrtGraph"); + } + + onnxruntime::InlinedVector> inputs; + onnxruntime::InlinedVector> outputs; std::unordered_map> initializers; std::unordered_map> external_initializers; - std::vector> nodes; + std::vector> nodes; + std::string name = "ModelEditorGraph"; }; +} // namespace onnxruntime + struct OrtModel { std::unique_ptr graph; std::unordered_map domain_to_version; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_wasmrelaxedsimd.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_wasmrelaxedsimd.cpp index a3a0fa758d377..72f0f5d8a4dd4 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_wasmrelaxedsimd.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_wasmrelaxedsimd.cpp @@ -376,181 +376,215 @@ MlasGemmQuantCopyPackB( } } -MLAS_FORCEINLINE -void -MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd( - v128_t ABroadcast, - const uint8_t* B, - v128_t Accumulators[2] -) -{ - v128_t BElements0 = wasm_v128_load(&B[0]); - v128_t BElements1 = wasm_v128_load(&B[16]); - Accumulators[0] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(BElements0, ABroadcast, Accumulators[0]); - Accumulators[1] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(BElements1, ABroadcast, Accumulators[1]); +//-------------------------------------------------------------------------- +// Small helper that performs one (A row) × (8 B columns) FMA step. +//-------------------------------------------------------------------------- +MLAS_FORCEINLINE void DotPairAdd(v128_t ABroadcast, + v128_t BVec0, + v128_t BVec1, + v128_t Acc[2]) { + Acc[0] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(BVec0, ABroadcast, Acc[0]); + Acc[1] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(BVec1, ABroadcast, Acc[1]); } +//-------------------------------------------------------------------------- +// Generic RowCount×8 kernel implementation (RowCount is 6 or 1 at compile‑time) +//-------------------------------------------------------------------------- -template<> -size_t -MlasGemmQuantKernel( +template +static size_t GemmQuantKernelNx8Impl( const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* A, const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* B, int32_t* C, size_t PackedCountK, - size_t CountM, + size_t /*CountM — ignored*/, size_t CountN, size_t ldc, const int32_t* RowSumBuffer, const int32_t* ColumnSumBuffer, const int32_t* ZeroPointB, - bool ZeroMode - ) + bool ZeroMode) { - MLAS_UNREFERENCED_PARAMETER(CountM); - MLAS_UNREFERENCED_PARAMETER(ldc); + constexpr size_t ColBlock = 8; + const auto PackedK = MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedK; + + // Build row‑wise pointer tables (a[r] & c[r]). + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* a[RowCount]; + int32_t* c[RowCount]; + for (size_t r = 0; r < RowCount; ++r) { + a[r] = (const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType*)(A + r * PackedK * PackedCountK); + c[r] = (int32_t*)(C + r * ldc); + } while (CountN > 0) { - - v128_t Accumulators[2]; - - // - // Initialize the accumulators with the row and column sums. - // - - int32_t RowSumValue = RowSumBuffer[0]; - - if (ZeroPointB != nullptr) { - - int32_t ScaledRowSumBuffer[8]; - - for (size_t i = 0; i < 8; i++) { - ScaledRowSumBuffer[i] = RowSumValue * ZeroPointB[i]; - } - + // ------------------------------------------------------------------ + // 1) Initialize accumulators with row & column sums (and zero‑points) + // ------------------------------------------------------------------ + v128_t Acc[RowCount][2]; + + if (ZeroPointB) { + v128_t zp0 = wasm_v128_load(ZeroPointB + 0); + v128_t zp1 = wasm_v128_load(ZeroPointB + 4); ZeroPointB += 8; - Accumulators[0] = wasm_v128_load(&ScaledRowSumBuffer[0]); - Accumulators[1] = wasm_v128_load(&ScaledRowSumBuffer[4]); - + for (size_t r = 0; r < RowCount; ++r) { + v128_t RowSumValues = wasm_v128_load32_splat(RowSumBuffer + r); + Acc[r][0] = wasm_i32x4_mul(RowSumValues, zp0); + Acc[r][1] = wasm_i32x4_mul(RowSumValues, zp1); + } + } else { + for (size_t r = 0; r < RowCount; ++r) { + Acc[r][0] = wasm_v128_load32_splat(RowSumBuffer + r); + Acc[r][1] = Acc[r][0]; + } } - else { - Accumulators[0] = wasm_i32x4_splat(RowSumValue); - Accumulators[1] = Accumulators[0]; + v128_t col0 = wasm_v128_load(ColumnSumBuffer + 0); // first 4 col sums + v128_t col1 = wasm_v128_load(ColumnSumBuffer + 4); // next 4 col sums + for (size_t r = 0; r < RowCount; ++r) { + Acc[r][0] = wasm_i32x4_add(Acc[r][0], col0); + Acc[r][1] = wasm_i32x4_add(Acc[r][1], col1); } - - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&ColumnSumBuffer[0])); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_v128_load(&ColumnSumBuffer[4])); ColumnSumBuffer += 8; - // - // Broadcast each pair of 16-bit values from the matrix A and multiply - // with the pair of 16-bit values from matrix B, and add the 32-bit + // ------------------------------------------------------------------ + // 2) Broadcast each pair of 8-bit values from the matrix A and multiply + // with the pair of 8-bit values from matrix B, and add the 32-bit // intermediate into the accumulator registers. - // - - const uint8_t* a = A; + // ------------------------------------------------------------------ size_t k = PackedCountK; - - while (k >= 4) { - - v128_t AElements = wasm_v128_load((v128_t*)a); - v128_t ABroadcast; - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 0, 0, 0, 0); - MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd(ABroadcast, &B[0], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 1, 1, 1, 1); - MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd(ABroadcast, &B[32], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 2, 2, 2, 2); - MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd(ABroadcast, &B[64], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 3, 3, 3, 3); - MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd(ABroadcast, &B[96], Accumulators); - - a += 4 * 4; - B += 4 * 32; - k -= 4; - } - while (k > 0) { + v128_t ABroadcast[RowCount]; + for (size_t r = 0; r < RowCount; ++r) { + ABroadcast[r] = wasm_v128_load32_splat(a[r]); // broadcast 4 × u8 + a[r] += 4; + } - v128_t ABroadcast = wasm_i32x4_splat(*((int32_t*)a)); - MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd(ABroadcast, &B[0], Accumulators); - - a += 4; + v128_t B0 = wasm_v128_load(&B[0]); + v128_t B1 = wasm_v128_load(&B[16]); + for (size_t r = 0; r < RowCount; ++r) { + DotPairAdd(ABroadcast[r], B0, B1, Acc[r]); + } B += 32; k -= 1; } - // - // Output the accumulator block after optionally accumulating the values + // ------------------------------------------------------------------ + // 3) Output the accumulator block after optionally accumulating the values // from matrix C. - // + // ------------------------------------------------------------------ if (CountN >= 8) { - - if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&C[0])); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_v128_load(&C[4])); - } - - wasm_v128_store(&C[0], Accumulators[0]); - wasm_v128_store(&C[4], Accumulators[1]); - - C += 8; - CountN -= 8; - - } - else { - - // - // Output the remaining partial output block. - // - - if ((CountN & 4) != 0) { - + // ---- Full 8‑column tile ---- + for (size_t r = 0; r < RowCount; ++r) { if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&C[0])); + Acc[r][0] = wasm_i32x4_add(Acc[r][0], wasm_v128_load(c[r] + 0)); + Acc[r][1] = wasm_i32x4_add(Acc[r][1], wasm_v128_load(c[r] + 4)); } - - wasm_v128_store(&C[0], Accumulators[0]); - C += 4; - - Accumulators[0] = Accumulators[1]; + wasm_v128_store(c[r] + 0, Acc[r][0]); + wasm_v128_store(c[r] + 4, Acc[r][1]); + a[r] -= PackedCountK * 4; // Rewind a[r] for next N‑tile (PackedCountK * 4 bytes each). + c[r] += ColBlock; } - - if ((CountN & 2) != 0) { - - if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load64_zero(&C[0])); + CountN -= ColBlock; + } else { + // ---- 4/2/1‑column tails ---- + auto Tail = [&](size_t cols, auto load_c, auto store_c) { + for (size_t r = 0; r < RowCount; ++r) { + if (!ZeroMode) Acc[r][0] = wasm_i32x4_add(Acc[r][0], load_c(c[r])); } - - wasm_v128_store64_lane(&C[0], Accumulators[0], 0); - C += 2; - - Accumulators[0] = wasm_i32x4_shuffle(Accumulators[0], wasm_i32x4_splat(0), 2, 3, 2, 3); + for (size_t r = 0; r < RowCount; ++r) store_c(c[r], Acc[r][0]); + for (size_t r = 0; r < RowCount; ++r) c[r] += cols; + }; + + if (CountN & 4) { + Tail(4, + [](int32_t* p) { return wasm_v128_load(p); }, + [](int32_t* p, v128_t v) { wasm_v128_store(p, v); }); + for (size_t r = 0; r < RowCount; ++r) Acc[r][0] = Acc[r][1]; } - - if ((CountN & 1) != 0) { - - int32_t AccumulatorValue = wasm_i32x4_extract_lane(Accumulators[0], 0); - - if (!ZeroMode) { - AccumulatorValue += C[0]; + if (CountN & 2) { + Tail(2, + [](int32_t* p) { return wasm_v128_load64_zero(p); }, + [](int32_t* p, v128_t v) { wasm_v128_store64_lane(p, v, 0); }); + for (size_t r = 0; r < RowCount; ++r) + Acc[r][0] = wasm_i32x4_shuffle(Acc[r][0], wasm_i32x4_splat(0), 2, 3, 2, 3); + } + if (CountN & 1) { + for (size_t r = 0; r < RowCount; ++r) { + int32_t v = wasm_i32x4_extract_lane(Acc[r][0], 0); + if (!ZeroMode) v += *c[r]; + *c[r] = v; } - - C[0] = AccumulatorValue; } - CountN = 0; } } + return RowCount; +} + + +size_t MlasGemmQuantKernel6x8( + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) { + MLAS_UNREFERENCED_PARAMETER(CountM); + return GemmQuantKernelNx8Impl<6>(A, B, C, PackedCountK, 0, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); +} + +size_t MlasGemmQuantKernel1x8( + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) { + MLAS_UNREFERENCED_PARAMETER(CountM); + return GemmQuantKernelNx8Impl<1>(A, B, C, PackedCountK, 0, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); +} - return 1; + +template <> +size_t +MlasGemmQuantKernel( + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode +) +{ + size_t RowsHandled = 0; + if (CountM >= 6) { + RowsHandled = MlasGemmQuantKernel6x8(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); + } else { + RowsHandled = MlasGemmQuantKernel1x8(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); + } + return RowsHandled; } const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmRelaxedSimd = { @@ -559,5 +593,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmRelaxedSimd = { nullptr, MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedK, 0, - 4 // multiple of kernel stride M + 6 // multiple of kernel stride M }; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp index 1f33d77adf4b9..84f6c6bb92f4b 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp @@ -322,181 +322,209 @@ MlasGemmQuantCopyPackB( } } -MLAS_FORCEINLINE -void -MlasGemmU8X8MultiplyAccumulateRowWasmSimd( - v128_t ABroadcast, - const int16_t* B, - v128_t Accumulators[2] -) -{ - v128_t BElements0 = wasm_v128_load(&B[0]); - v128_t BElements1 = wasm_v128_load(&B[8]); - - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_i32x4_dot_i16x8(BElements0, ABroadcast)); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_i32x4_dot_i16x8(BElements1, ABroadcast)); +//------------------------------------------------------------------ +// Helper – dot‑product add for i16×i16 → i32 pairs. +//------------------------------------------------------------------ +MLAS_FORCEINLINE void DotPairAddI16(v128_t ABroadcast, + v128_t BVec0, + v128_t BVec1, + v128_t Acc[2]) { + Acc[0] = wasm_i32x4_add(Acc[0], wasm_i32x4_dot_i16x8(BVec0, ABroadcast)); + Acc[1] = wasm_i32x4_add(Acc[1], wasm_i32x4_dot_i16x8(BVec1, ABroadcast)); } +//------------------------------------------------------------------ +// Generic RowCount×8 kernel (RowCount = 4 or 1) for WASM SIMD. +//------------------------------------------------------------------ -template<> -size_t -MlasGemmQuantKernel( +template +static size_t GemmQuantKernelNx8Impl( const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* A, const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* B, int32_t* C, size_t PackedCountK, - size_t CountM, + size_t /*CountM — ignored*/, size_t CountN, size_t ldc, const int32_t* RowSumBuffer, const int32_t* ColumnSumBuffer, const int32_t* ZeroPointB, - bool ZeroMode - ) + bool ZeroMode) { - MLAS_UNREFERENCED_PARAMETER(CountM); - MLAS_UNREFERENCED_PARAMETER(ldc); - - while (CountN > 0) { - - v128_t Accumulators[2]; - // - // Initialize the accumulators with the row and column sums. - // - - int32_t RowSumValue = RowSumBuffer[0]; - - if (ZeroPointB != nullptr) { + constexpr size_t ColBlock = 8; + const auto PackedK = MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedK; // ==2 - int32_t ScaledRowSumBuffer[8]; - - for (size_t i = 0; i < 8; i++) { - ScaledRowSumBuffer[i] = RowSumValue * ZeroPointB[i]; - } + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* a[RowCount]; + int32_t* c[RowCount]; + for (size_t r = 0; r < RowCount; ++r) { + a[r] = (const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType*)(A + r * PackedK * PackedCountK); + c[r] = (int32_t*)(C + r * ldc); + } + while (CountN > 0) { + // ------------------------------------------------------------------ + // 1) Initialize accumulators with row & column sums (and zero‑points) + // ------------------------------------------------------------------ + v128_t Acc[RowCount][2]; + v128_t col0 = wasm_v128_load(ColumnSumBuffer + 0); + v128_t col1 = wasm_v128_load(ColumnSumBuffer + 4); + + if (ZeroPointB) { + v128_t zp0 = wasm_v128_load(ZeroPointB + 0); + v128_t zp1 = wasm_v128_load(ZeroPointB + 4); ZeroPointB += 8; - Accumulators[0] = wasm_v128_load(&ScaledRowSumBuffer[0]); - Accumulators[1] = wasm_v128_load(&ScaledRowSumBuffer[4]); - - } - else { - - Accumulators[0] = wasm_i32x4_splat(RowSumValue); - Accumulators[1] = Accumulators[0]; + for (size_t r = 0; r < RowCount; ++r) { + v128_t RowSumValues = wasm_v128_load32_splat(RowSumBuffer + r); + Acc[r][0] = wasm_i32x4_add(wasm_i32x4_mul(RowSumValues, zp0), col0); + Acc[r][1] = wasm_i32x4_add(wasm_i32x4_mul(RowSumValues, zp1), col1); + } + } else { + for (size_t r = 0; r < RowCount; ++r) { + v128_t RowSumValues = wasm_v128_load32_splat(RowSumBuffer + r); + Acc[r][0] = wasm_i32x4_add(RowSumValues, col0); + Acc[r][1] = wasm_i32x4_add(RowSumValues, col1); + } } - - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&ColumnSumBuffer[0])); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_v128_load(&ColumnSumBuffer[4])); ColumnSumBuffer += 8; - // - // Broadcast each pair of 16-bit values from the matrix A and multiply + // ---------------------------------------------------------------------- + // 2) Broadcast each pair of 16-bit values from the matrix A and multiply // with the pair of 16-bit values from matrix B, and add the 32-bit // intermediate into the accumulator registers. - // - - const int16_t* a = A; + // ---------------------------------------------------------------------- size_t k = PackedCountK; + while (k > 0) { + v128_t ABroadcast[RowCount]; + for (size_t r = 0; r < RowCount; ++r) { + ABroadcast[r] = wasm_v128_load32_splat(a[r]); + a[r] += 2; + } - while (k >= 4) { - - v128_t AElements = wasm_v128_load((v128_t*)a); - v128_t ABroadcast; - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 0, 0, 0, 0); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[0], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 1, 1, 1, 1); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[16], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 2, 2, 2, 2); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[32], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 3, 3, 3, 3); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[48], Accumulators); - - a += 4 * 2; - B += 4 * 16; - k -= 4; - } + v128_t B0 = wasm_v128_load(B + 0); // cols 0‑3 (8 i16) + v128_t B1 = wasm_v128_load(B + 8); // cols 4‑7 (8 i16) - while (k > 0) { - v128_t ABroadcast = wasm_i32x4_splat(*((int32_t*)a)); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[0], Accumulators); + for (size_t r = 0; r < RowCount; ++r) { + DotPairAddI16(ABroadcast[r], B0, B1, Acc[r]); + } - a += 2; B += 16; k -= 1; } - // - // Output the accumulator block after optionally accumulating the values + // ------------------------------------------------------------------ + // 3) Output the accumulator block after optionally accumulating the values // from matrix C. - // - + // ------------------------------------------------------------------ if (CountN >= 8) { - - if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&C[0])); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_v128_load(&C[4])); - } - - wasm_v128_store(&C[0], Accumulators[0]); - wasm_v128_store(&C[4], Accumulators[1]); - - C += 8; - CountN -= 8; - - } - else { - - // - // Output the remaining partial output block. - // - - if ((CountN & 4) != 0) { - + for (size_t r = 0; r < RowCount; ++r) { if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&C[0])); + Acc[r][0] = wasm_i32x4_add(Acc[r][0], wasm_v128_load(c[r] + 0)); + Acc[r][1] = wasm_i32x4_add(Acc[r][1], wasm_v128_load(c[r] + 4)); } - - wasm_v128_store(&C[0], Accumulators[0]); - C += 4; - - Accumulators[0] = Accumulators[1]; + wasm_v128_store(c[r] + 0, Acc[r][0]); + wasm_v128_store(c[r] + 4, Acc[r][1]); + c[r] += ColBlock; + a[r] -= PackedCountK * 2; // Rewind a[r] for next N-tile (PackedCountK * 2 elements, 16-bit each). } - - if ((CountN & 2) != 0) { - - if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load64_zero(&C[0])); + CountN -= 8; + } else { + // ---- 4/2/1‑column tails ---- + auto Tail = [&](size_t cols, auto load_c, auto store_c) { + for (size_t r = 0; r < RowCount; ++r) { + if (!ZeroMode) Acc[r][0] = wasm_i32x4_add(Acc[r][0], load_c(c[r])); } - - wasm_v128_store64_lane(&C[0], Accumulators[0], 0); - C += 2; - - Accumulators[0] = wasm_i32x4_shuffle(Accumulators[0], wasm_i32x4_splat(0), 2, 3, 2, 3); + for (size_t r = 0; r < RowCount; ++r) store_c(c[r], Acc[r][0]); + for (size_t r = 0; r < RowCount; ++r) c[r] += cols; + }; + + if (CountN & 4) { + Tail(4, + [](int32_t* p) { return wasm_v128_load(p); }, + [](int32_t* p, v128_t v) { wasm_v128_store(p, v); }); + for (size_t r = 0; r < RowCount; ++r) Acc[r][0] = Acc[r][1]; } - - if ((CountN & 1) != 0) { - - int32_t AccumulatorValue = wasm_i32x4_extract_lane(Accumulators[0], 0); - - if (!ZeroMode) { - AccumulatorValue += C[0]; + if (CountN & 2) { + Tail(2, + [](int32_t* p) { return wasm_v128_load64_zero(p); }, + [](int32_t* p, v128_t v) { wasm_v128_store64_lane(p, v, 0); }); + for (size_t r = 0; r < RowCount; ++r) + Acc[r][0] = wasm_i32x4_shuffle(Acc[r][0], wasm_i32x4_splat(0), 2, 3, 2, 3); + } + if (CountN & 1) { + for (size_t r = 0; r < RowCount; ++r) { + int32_t v = wasm_i32x4_extract_lane(Acc[r][0], 0); + if (!ZeroMode) v += *c[r]; + *c[r] = v; } - - C[0] = AccumulatorValue; } - CountN = 0; } } + return RowCount; +} + +size_t MlasGemmQuantKernel4x8( + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) { + MLAS_UNREFERENCED_PARAMETER(CountM); + return GemmQuantKernelNx8Impl<4>(A, B, C, PackedCountK, 0, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); +} + +size_t MlasGemmQuantKernel1x8( + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) { + MLAS_UNREFERENCED_PARAMETER(CountM); + return GemmQuantKernelNx8Impl<1>(A, B, C, PackedCountK, 0, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); +} - return 1; +template<> +size_t +MlasGemmQuantKernel( + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) +{ + size_t RowsHandled = 0; + if (CountM >= 4) { + RowsHandled = MlasGemmQuantKernel4x8(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); + } else { + RowsHandled = MlasGemmQuantKernel1x8(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); + } + return RowsHandled; } const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd = { diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp index 024e67d14a0d7..e5d1327be414a 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp @@ -235,8 +235,11 @@ RopeKernel_Avx2_fp32_Impl( __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); - float32x8_t sin_val = _mm256_loadu_ps(sin_data+ i / 2); - float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2); + // Use masked loads for sin/cos data to avoid reading beyond buffer bounds + size_t cos_sin_rem = rem / 2; + const __m256i cos_sin_mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - cos_sin_rem)); + float32x8_t sin_val = _mm256_maskload_ps(sin_data + i / 2, cos_sin_mask); + float32x8_t cos_val = _mm256_maskload_ps(cos_data + i / 2, cos_sin_mask); //Compute Real and Imaginary output values float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc index f92451cf7fe6d..616bc1257676f 100644 --- a/onnxruntime/core/optimizer/attention_fusion.cc +++ b/onnxruntime/core/optimizer/attention_fusion.cc @@ -69,9 +69,9 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size, assert(nullptr != q_tensor); assert(nullptr != k_tensor); assert(nullptr != v_tensor); - Initializer q_initializer(*q_tensor, graph.ModelPath()); - Initializer k_initializer(*k_tensor, graph.ModelPath()); - Initializer v_initializer(*v_tensor, graph.ModelPath()); + Initializer q_initializer(graph, *q_tensor, graph.ModelPath()); + Initializer k_initializer(graph, *k_tensor, graph.ModelPath()); + Initializer v_initializer(graph, *v_tensor, graph.ModelPath()); auto data_type = q_tensor->data_type(); ONNX_NAMESPACE::TensorProto initializer; @@ -111,7 +111,7 @@ static NodeArg& MergeQkvWeights(Graph& graph, int64_t hidden_size, utils::SetRawDataInTensorProto(initializer, result.data(), gsl::narrow(element_count) * sizeof(MLFloat16)); } - return graph_utils::AddInitializer(graph, initializer); + return graph_utils::AddInitializerWithExternalData(graph, initializer); } static NodeArg* ConvertMaskToInt32(Graph& graph, NodeArg* mask_input, ProviderType provider_type, diff --git a/onnxruntime/core/optimizer/attention_fusion_helper.h b/onnxruntime/core/optimizer/attention_fusion_helper.h index aa70b347d7b67..ecbb750d0bf19 100644 --- a/onnxruntime/core/optimizer/attention_fusion_helper.h +++ b/onnxruntime/core/optimizer/attention_fusion_helper.h @@ -310,7 +310,16 @@ bool ValidateUnidirMask(const Graph& graph, const NodeArg& mask, bool& is_unidir // Check that the mask shape is 1x1xWxW auto shape = mask.Shape(); - if (shape == nullptr || static_cast(shape->dim_size()) != 4 || !utils::HasDimValue(shape->dim(0)) || static_cast(1) != shape->dim(0).dim_value() || !utils::HasDimValue(shape->dim(1)) || static_cast(1) != shape->dim(1).dim_value() || !utils::HasDimValue(shape->dim(2)) || !utils::HasDimValue(shape->dim(3)) || shape->dim(2).dim_value() != shape->dim(3).dim_value()) { + if ( + shape == nullptr || + static_cast(shape->dim_size()) != 4 || + !utils::HasDimValue(shape->dim(0)) || + static_cast(1) != shape->dim(0).dim_value() || + !utils::HasDimValue(shape->dim(1)) || + static_cast(1) != shape->dim(1).dim_value() || + !utils::HasDimValue(shape->dim(2)) || + !utils::HasDimValue(shape->dim(3)) || + shape->dim(2).dim_value() != shape->dim(3).dim_value()) { DEBUG_LOG("unidir mask shape not expected"); return false; } @@ -320,28 +329,20 @@ bool ValidateUnidirMask(const Graph& graph, const NodeArg& mask, bool& is_unidir return false; } - if (tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - DEBUG_LOG("This optimizer does not support external data for unidirectional mask right now"); - return false; - } - if (tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { size_t bytes; if (!utils::GetSizeInBytesFromTensorProto<0>(*tensor_proto, &bytes).IsOK()) { return false; } - auto data = std::make_unique(bytes); - uint8_t* p = data.get(); - if (!utils::UnpackTensor( - *tensor_proto, - tensor_proto->raw_data().size() ? tensor_proto->raw_data().data() : nullptr, - tensor_proto->raw_data().size(), - p, - bytes) - .IsOK()) { + + std::vector mask_data; + // This takes care of external data in case present + auto status = utils::UnpackInitializerData(*tensor_proto, graph.ModelPath(), mask_data); + if (!status.IsOK()) { + DEBUG_LOG(status.ErrorMessage()); return false; } - std::vector mask_data(p, p + bytes); + if (!ValidateUnidirMask(mask_data, shape->dim(2).dim_value(), is_unidirectional)) { DEBUG_LOG("Mask is neither unidirectional nor all ones"); return false; diff --git a/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc b/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc index 86a7a4d6afbf8..a98d0ea6f978b 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc @@ -189,7 +189,7 @@ NodeArg* CreateInitializerFromVector(Graph& graph, "total_count: ", total_count, " values.size(): ", values.size()); utils::SetRawDataInTensorProto(const_tensor, values.data(), values.size() * sizeof(int64_t)); - return &graph_utils::AddInitializer(graph, const_tensor); + return &graph_utils::AddInitializerWithExternalData(graph, const_tensor); } NodeArg* InsertNodesForValidIndices(Graph& graph, diff --git a/onnxruntime/core/optimizer/concat_slice_elimination.cc b/onnxruntime/core/optimizer/concat_slice_elimination.cc index f7a2b3be4466c..b49bcc186e93d 100644 --- a/onnxruntime/core/optimizer/concat_slice_elimination.cc +++ b/onnxruntime/core/optimizer/concat_slice_elimination.cc @@ -86,7 +86,7 @@ static bool GetSliceInfo(const Graph& graph, auto get_initializer_data = [&graph](const ONNX_NAMESPACE::TensorProto* initializer) -> InlinedVector { - Initializer init(*initializer, graph.ModelPath()); + Initializer init(graph, *initializer, graph.ModelPath()); if (initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) { int32_t* init_data = init.data(); return InlinedVector(init_data, init_data + init.size()); diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index e36eef672c1ed..3d838d8aacfbb 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -95,7 +95,7 @@ static bool ConstantFoldShapeNode(Graph& graph, Node& node) { ONNX_NAMESPACE::TensorShapeProto result_shape; result_shape.add_dim()->set_dim_value(clamped_slice_length); constant_arg_out->SetShape(result_shape); - graph.AddInitializedTensor(shape_constant); + graph_utils::AddInitializerWithExternalData(graph, shape_constant); } return is_concrete_shape; // convert to constant if this is true @@ -118,7 +118,7 @@ static Status ConstantFoldIfNode(Graph& graph, Node& if_node, const logging::Log } // This is a boolean initializer with a single element. - Initializer condition{*initializer}; + Initializer condition{graph, *initializer}; ORT_RETURN_IF_NOT(condition.size() == 1, "If node condition initializer: `", condition_def->Name(), "' is expected to have a single boolean element"); @@ -317,7 +317,11 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, // Build the TensorProto that corresponds to the computed OrtValue and add it as initializer to the graph. auto* constant_arg_out = node->MutableOutputDefs()[fetch_idx]; const Tensor& out_tensor = ort_value.Get(); - ONNX_NAMESPACE::TensorProto out_tensorproto = utils::TensorToTensorProto(out_tensor, constant_arg_out->Name()); + constexpr const bool use_tensor_buffer_true = true; + ONNX_NAMESPACE::TensorProto out_tensorproto = utils::TensorToTensorProto( + out_tensor, + constant_arg_out->Name(), + use_tensor_buffer_true); ONNX_NAMESPACE::TensorShapeProto result_shape; for (auto& dim : out_tensor.Shape().GetDims()) { @@ -325,7 +329,12 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, } constant_arg_out->SetShape(result_shape); - graph.AddInitializedTensor(out_tensorproto); + // The data is too small and has been inlined. + if (!utils::HasExternalData(out_tensorproto)) { + ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(out_tensorproto, OrtValue())); + } else { + ORT_THROW_IF_ERROR(graph.AddInitializedOrtValue(out_tensorproto, ort_value)); + } } } } diff --git a/onnxruntime/core/optimizer/conv_add_fusion.cc b/onnxruntime/core/optimizer/conv_add_fusion.cc index adc1efae5ced4..c349adfccce53 100644 --- a/onnxruntime/core/optimizer/conv_add_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_fusion.cc @@ -62,8 +62,8 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie return Status::OK(); } - Initializer conv_B{*conv_B_tensor_proto, graph.ModelPath()}; - Initializer add_B{*add_B_tensor_proto, graph.ModelPath()}; + Initializer conv_B{graph, *conv_B_tensor_proto, graph.ModelPath()}; + Initializer add_B{graph, *add_B_tensor_proto, graph.ModelPath()}; if (conv_B.size() != add_B.size()) { return Status::OK(); @@ -79,12 +79,14 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie auto new_name = graph.GenerateNodeArgName("ConvAddFusion_B_" + B_input_name); new_conv_B_tensor_proto.set_name(new_name); - NodeArg& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto); + NodeArg& new_conv_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto); graph_utils::ReplaceNodeInput(node, 2, new_conv_B_node_arg); } else { // Create new tensor proto and update shape - ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*add_B_tensor_proto); + Initializer add_B{graph, *add_B_tensor_proto, graph.ModelPath()}; + ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto; + add_B.ToProto(new_conv_B_tensor_proto); int64_t dim = conv_W_tensor_proto->dims(0); new_conv_B_tensor_proto.clear_dims(); new_conv_B_tensor_proto.add_dims(dim); @@ -92,7 +94,7 @@ Status ConvAddFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& modifie auto new_name = graph.GenerateNodeArgName("ConvAddFusion_Add_B_" + add_B_tensor_proto->name()); new_conv_B_tensor_proto.set_name(new_name); - NodeArg& new_add_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto); + NodeArg& new_add_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto); graph_utils::AddNodeInput(node, 2, new_add_B_node_arg); } diff --git a/onnxruntime/core/optimizer/conv_bn_fusion.cc b/onnxruntime/core/optimizer/conv_bn_fusion.cc index 392d03de037cf..8bf5420baddde 100644 --- a/onnxruntime/core/optimizer/conv_bn_fusion.cc +++ b/onnxruntime/core/optimizer/conv_bn_fusion.cc @@ -61,13 +61,13 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff return Status::OK(); } - Initializer bn_scale{*bn_scale_tensor_proto, graph.ModelPath()}; - Initializer bn_B{*bn_B_tensor_proto, graph.ModelPath()}; - Initializer bn_mean{*bn_mean_tensor_proto, graph.ModelPath()}; - Initializer bn_var{*bn_var_tensor_proto, graph.ModelPath()}; - Initializer conv_W{*conv_W_tensor_proto, graph.ModelPath()}; + Initializer bn_scale{graph, *bn_scale_tensor_proto, graph.ModelPath()}; + Initializer bn_B{graph, *bn_B_tensor_proto, graph.ModelPath()}; + Initializer bn_mean{graph, *bn_mean_tensor_proto, graph.ModelPath()}; + Initializer bn_var{graph, *bn_var_tensor_proto, graph.ModelPath()}; + Initializer conv_W{graph, *conv_W_tensor_proto, graph.ModelPath()}; - std::unique_ptr conv_B = nullptr; + std::optional conv_B; const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr; if (conv_inputs.size() == 3) { conv_B_tensor_proto = graph_utils::GetConstantInitializer(graph, conv_inputs[2]->Name()); @@ -79,7 +79,7 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff conv_B_tensor_proto->data_type() != bn_B_tensor_proto->data_type()) { return Status::OK(); } - conv_B = std::make_unique(*conv_B_tensor_proto, graph.ModelPath()); + conv_B.emplace(graph, *conv_B_tensor_proto, graph.ModelPath()); } // Calculate new value of initializers of conv node @@ -98,7 +98,7 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff } // Create new initializers of conv - ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto); + ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto; conv_W.ToProto(new_conv_W_tensor_proto); ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto; @@ -120,10 +120,10 @@ Status ConvBNFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff new_conv_W_tensor_proto.set_name(new_W_name); new_conv_B_tensor_proto.set_name(new_B_name); - NodeArg& new_conv_W_node_arg = graph_utils::AddInitializer(graph, new_conv_W_tensor_proto); + NodeArg& new_conv_W_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_W_tensor_proto); graph_utils::ReplaceNodeInput(node, 1, new_conv_W_node_arg); - auto& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto); + auto& new_conv_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto); if (conv_inputs.size() == 3) { graph_utils::ReplaceNodeInput(node, 2, new_conv_B_node_arg); diff --git a/onnxruntime/core/optimizer/conv_mul_fusion.cc b/onnxruntime/core/optimizer/conv_mul_fusion.cc index 6da6d089d5a71..dc50a150537f7 100644 --- a/onnxruntime/core/optimizer/conv_mul_fusion.cc +++ b/onnxruntime/core/optimizer/conv_mul_fusion.cc @@ -52,11 +52,11 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef } } - Initializer conv_W{*conv_W_tensor_proto, graph.ModelPath()}; - Initializer mul_B{*mul_B_tensor_proto, graph.ModelPath()}; + Initializer conv_W{graph, *conv_W_tensor_proto, graph.ModelPath()}; + Initializer mul_B{graph, *mul_B_tensor_proto, graph.ModelPath()}; const ONNX_NAMESPACE::TensorProto* conv_B_tensor_proto = nullptr; - std::unique_ptr conv_B = nullptr; + std::optional conv_B; const bool is_3d = conv_inputs.size() == 3; if (is_3d) { conv_B_tensor_proto = graph_utils::GetConstantInitializer(graph, conv_inputs[2]->Name()); @@ -68,7 +68,7 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef return Status::OK(); } - conv_B = std::make_unique(*conv_B_tensor_proto, graph.ModelPath()); + conv_B.emplace(graph, *conv_B_tensor_proto, graph.ModelPath()); } // Calculate new value of initializers of conv node @@ -83,24 +83,24 @@ Status ConvMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_ef } // Create new initializers of conv - ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto(*conv_W_tensor_proto); + ONNX_NAMESPACE::TensorProto new_conv_W_tensor_proto; conv_W.ToProto(new_conv_W_tensor_proto); auto new_W_name = graph.GenerateNodeArgName("ConvMulFusion_W_" + conv_W_tensor_proto->name()); new_conv_W_tensor_proto.set_name(new_W_name); // Replace initializers of conv node - NodeArg& new_conv_W_node_arg = graph_utils::AddInitializer(graph, new_conv_W_tensor_proto); + NodeArg& new_conv_W_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_W_tensor_proto); graph_utils::ReplaceNodeInput(conv_node, 1, new_conv_W_node_arg); if (is_3d) { - ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto(*conv_B_tensor_proto); + ONNX_NAMESPACE::TensorProto new_conv_B_tensor_proto; conv_B->ToProto(new_conv_B_tensor_proto); auto new_B_name = graph.GenerateNodeArgName("ConvMulFusion_Mul_B_" + mul_B_tensor_proto->name()); new_conv_B_tensor_proto.set_name(new_B_name); - NodeArg& new_conv_B_node_arg = graph_utils::AddInitializer(graph, new_conv_B_tensor_proto); + NodeArg& new_conv_B_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_conv_B_tensor_proto); graph_utils::ReplaceNodeInput(conv_node, 2, new_conv_B_node_arg); } diff --git a/onnxruntime/core/optimizer/div_mul_fusion.cc b/onnxruntime/core/optimizer/div_mul_fusion.cc index 7184e931cb74e..e2cd66fe73f86 100644 --- a/onnxruntime/core/optimizer/div_mul_fusion.cc +++ b/onnxruntime/core/optimizer/div_mul_fusion.cc @@ -40,7 +40,7 @@ bool DivMulFusion::SatisfyCondition(const Graph& graph, const Node& node, const } int32_t data_type = initializer->data_type(); - Initializer div_A(*initializer, graph.ModelPath()); + Initializer div_A(graph, *initializer, graph.ModelPath()); if (div_A.size() > 1) { return false; } diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc index 22b9dca39dceb..1841dfa2791e0 100644 --- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc +++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc @@ -46,13 +46,13 @@ static bool GetQNodeZeroPointType(const Graph& graph, const Node& q_node, template static void ApplyNewInputValue(Graph& graph, Node& node, QDQ::InputIndex index, T value) { const auto* input_tensor = graph_utils::GetConstantInitializer(graph, node.InputDefs()[index]->Name()); - Initializer input_init{*input_tensor, graph.ModelPath()}; - ONNX_NAMESPACE::TensorProto new_input_tensor(*input_tensor); + Initializer input_init{graph, *input_tensor, graph.ModelPath()}; + ONNX_NAMESPACE::TensorProto new_input_tensor; input_init.data()[0] = value; input_init.ToProto(new_input_tensor); auto new_name = graph.GenerateNodeArgName("DoubleQDQRemoved_" + node.InputDefs()[index]->Name()); new_input_tensor.set_name(new_name); - NodeArg& new_input = graph_utils::AddInitializer(graph, new_input_tensor); + NodeArg& new_input = graph_utils::AddInitializerWithExternalData(graph, new_input_tensor); graph_utils::ReplaceNodeInput(node, index, new_input); } @@ -79,10 +79,10 @@ static bool FindNewZeroPointAndScale(const Graph& graph, const Node& node1, cons graph_utils::GetConstantInitializer(graph, node1_zp_name); const ONNX_NAMESPACE::TensorProto* node2_zp_tensor_proto = graph_utils::GetConstantInitializer(graph, node2_zp_name); - Initializer zero_point_init_1{*node1_zp_tensor_proto, graph.ModelPath()}; - Initializer zero_point_init_2{*node2_zp_tensor_proto, graph.ModelPath()}; - Initializer scale_init_1{*node1_scale_tensor_proto, graph.ModelPath()}; - Initializer scale_init_2{*node2_scale_tensor_proto, graph.ModelPath()}; + Initializer zero_point_init_1{graph, *node1_zp_tensor_proto, graph.ModelPath()}; + Initializer zero_point_init_2{graph, *node2_zp_tensor_proto, graph.ModelPath()}; + Initializer scale_init_1{graph, *node1_scale_tensor_proto, graph.ModelPath()}; + Initializer scale_init_2{graph, *node2_scale_tensor_proto, graph.ModelPath()}; if (zero_point_init_1.data_type() != zero_point_init_2.data_type() || scale_init_1.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT || scale_init_2.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { @@ -181,7 +181,7 @@ static bool TryReduceDoubleQDQSequence(Graph& graph, NodeIndex q1_index) { } // The Q1 and DQ1 nodes must have equal zero-point and scale values (scalar/constant). - if (!QDQ::IsQDQPairSupported(*q1, *dq1, get_constant_initializer, graph.ModelPath())) { + if (!QDQ::IsQDQPairSupported(graph, *q1, *dq1, get_constant_initializer, graph.ModelPath())) { return false; } @@ -218,7 +218,7 @@ static bool TryReduceDoubleQDQSequence(Graph& graph, NodeIndex q1_index) { } // The Q2 and DQ2 nodes must have equal zero-point and scale values (scalar/constant). - if (!QDQ::IsQDQPairSupported(*q2, *dq2, get_constant_initializer, graph.ModelPath())) { + if (!QDQ::IsQDQPairSupported(graph, *q2, *dq2, get_constant_initializer, graph.ModelPath())) { return false; } diff --git a/onnxruntime/core/optimizer/dropout_elimination.cc b/onnxruntime/core/optimizer/dropout_elimination.cc index b82a944125667..d989c4dd80532 100644 --- a/onnxruntime/core/optimizer/dropout_elimination.cc +++ b/onnxruntime/core/optimizer/dropout_elimination.cc @@ -41,7 +41,7 @@ bool EliminateDropout::SatisfyCondition(const Graph& graph, const Node& node, co return false; } int32_t data_type = initializer->data_type(); - Initializer ratio(*initializer, graph.ModelPath()); + Initializer ratio(graph, *initializer, graph.ModelPath()); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: if (*ratio.data() > 0.f) { diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index 103e72072f713..ad25f95ac1186 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -450,7 +450,7 @@ static NodeArg* ExtractEmbedding(Graph& graph, assert(sequence_length > 0); assert(hidden_size > 0); - Initializer old_initializer{*tensor, graph.ModelPath()}; + Initializer old_initializer{graph, *tensor, graph.ModelPath()}; auto data_type = tensor->data_type(); ONNX_NAMESPACE::TensorProto initializer; @@ -474,7 +474,7 @@ static NodeArg* ExtractEmbedding(Graph& graph, utils::SetRawDataInTensorProto(initializer, data, gsl::narrow(element_count) * sizeof(MLFloat16)); } - NodeArg& node_arg = graph_utils::AddInitializer(graph, initializer); + NodeArg& node_arg = graph_utils::AddInitializerWithExternalData(graph, initializer); modified = true; return &node_arg; } diff --git a/onnxruntime/core/optimizer/expand_elimination.cc b/onnxruntime/core/optimizer/expand_elimination.cc index 8aadeb5a1a273..86bf616ea05e2 100644 --- a/onnxruntime/core/optimizer/expand_elimination.cc +++ b/onnxruntime/core/optimizer/expand_elimination.cc @@ -36,12 +36,12 @@ bool ExpandElimination::SatisfyCondition(const Graph& graph, const Node& node, c return false; } - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); - if (initializer->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) { + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); + if (initializer.data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) { return false; } - const int64_t* target_shapes = initializer->data(); + const int64_t* target_shapes = initializer.data(); // Check the dimensions starting at the trailing dimension. int i = input_shape->dim_size() - 1; diff --git a/onnxruntime/core/optimizer/fuse_initializers_transformer.cc b/onnxruntime/core/optimizer/fuse_initializers_transformer.cc index 1516b07fc2049..388ab14dd51fe 100644 --- a/onnxruntime/core/optimizer/fuse_initializers_transformer.cc +++ b/onnxruntime/core/optimizer/fuse_initializers_transformer.cc @@ -116,27 +116,34 @@ static void FuseInitializerWithNode(Graph& graph, } // Get the src initialized tensor at input def index 0 - auto constant_initializer_tensor = graph_utils::GetConstantInitializer(graph, node.InputDefs()[0]->Name()); - ONNX_NAMESPACE::TensorProto src_tensor(*constant_initializer_tensor); + const auto* constant_initializer_tensor = graph_utils::GetConstantInitializer(graph, node.InputDefs()[0]->Name()); Initializer src_init{*constant_initializer_tensor, graph.ModelPath()}; - src_init.ToProto(src_tensor); // Convert to dst tensor - ONNX_NAMESPACE::TensorProto dst_tensor; + std::string new_arg_name = graph.GenerateNodeArgName(NewNodeArgName( + next_node.InputDefs()[next_node_arg_index]->Name())); + + OrtValue new_data; if (next_node_arg_type == DataTypeImpl::GetTensorType()) - dst_tensor = src_init.ToFloat32(graph.GenerateNodeArgName(NewNodeArgName(next_node.InputDefs()[next_node_arg_index]->Name())), thread_pool); + new_data = src_init.ToFloat32(thread_pool); else if (next_node_arg_type == DataTypeImpl::GetTensorType()) - dst_tensor = src_init.ToFP16(graph.GenerateNodeArgName(NewNodeArgName(next_node.InputDefs()[next_node_arg_index]->Name()))); + new_data = src_init.ToFP16(); else if (next_node_arg_type == DataTypeImpl::GetTensorType()) - dst_tensor = src_init.ToBFloat16(graph.GenerateNodeArgName(NewNodeArgName(next_node.InputDefs()[next_node_arg_index]->Name()))); + new_data = src_init.ToBFloat16(); else return; // Remove the edge between the current node output def at index 0 and next node arg at relative arg index. graph.RemoveEdge(node.Index(), next_node.Index(), 0, static_cast(next_node_arg_index)); - // Add the new converted Tensor in next node as initializer - graph_utils::ReplaceNodeInput(next_node, static_cast(next_node_arg_index), graph_utils::AddInitializer(graph, dst_tensor)); + // Add the new converted Tensor in next node as initializer potentially with external data + ONNX_NAMESPACE::TensorProto dst_tensor = utils::TensorToTensorProto(new_data.Get(), new_arg_name, true); + if (!utils::HasExternalData(dst_tensor)) { + new_data = OrtValue(); // Data is inline + } + + auto& new_arg = graph_utils::AddInitializerWithExternalData(graph, dst_tensor, std::move(new_data)); + graph_utils::ReplaceNodeInput(next_node, static_cast(next_node_arg_index), new_arg); } Status FuseInitializersTransformer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/, const logging::Logger& /*logger*/) const { diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc index 9732ec2587b2a..3cd06350df95d 100644 --- a/onnxruntime/core/optimizer/gather_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -27,7 +27,7 @@ static bool GetScalarInt64Initializer(const Graph& graph, const NodeArg& node_ar if (!optimizer_utils::IsScalar(node_arg)) return false; const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node_arg.Name()); if (!tensor_proto || tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; value = *(init_const.data()); rank = tensor_proto->dims_size(); return true; @@ -256,7 +256,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra axes_initializer_proto.add_dims(static_cast(1)); axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); axes_initializer_proto.add_int64_data(axis); - NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto); + NodeArg* axes_arg = &graph_utils::AddInitializerWithExternalData(graph, axes_initializer_proto); Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze"), "Squeeze", "Squeeze for Fused Gather nodes", {split_output_arg, axes_arg}, {original_output_arg}); @@ -272,7 +272,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); split_initializer_proto.add_dims(static_cast(split_values.size())); split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end()); - NodeArg* split_initializer_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); + NodeArg* split_initializer_arg = &graph_utils::AddInitializerWithExternalData(graph, split_initializer_proto); const auto split_node_name = graph.GenerateNodeName(nodes_to_fuse[0].get().Name() + "/GatherSliceToSplitFusion"); Node& split_node = graph.AddNode(split_node_name, "Split", "Split for Fused Gather nodes", {graph.GetNodeArg(node_arg->Name()), split_initializer_arg}, split_outputs); @@ -359,7 +359,7 @@ Status GatherToSliceFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le unsqueeze_axes_initializer_proto.add_dims(static_cast(1)); unsqueeze_axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); unsqueeze_axes_initializer_proto.add_int64_data(static_cast(0)); - NodeArg* unsqueeze_axes_arg = &graph_utils::AddInitializer(graph, unsqueeze_axes_initializer_proto); + NodeArg* unsqueeze_axes_arg = &graph_utils::AddInitializerWithExternalData(graph, unsqueeze_axes_initializer_proto); for (size_t i = 0; i < range_input_defs.size(); ++i) { Node& unsqueeze_node = graph.AddNode(graph.GenerateNodeName("Unsqueeze_" + std::to_string(i)), "Unsqueeze", @@ -386,7 +386,7 @@ Status GatherToSliceFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le } else { slice_axes_initializer_proto.add_int32_data(static_cast(axis)); } - NodeArg* slice_axes_arg = &graph_utils::AddInitializer(graph, slice_axes_initializer_proto); + NodeArg* slice_axes_arg = &graph_utils::AddInitializerWithExternalData(graph, slice_axes_initializer_proto); Node& slice_node = graph.AddNode(graph.GenerateNodeName("Slice"), "Slice", "Slice for Fused Gather nodes", {gather_node.MutableInputDefs()[0], unsqueeze_outputs[0], unsqueeze_outputs[1], slice_axes_arg, unsqueeze_outputs[2]}, diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 24f4ad867d101..062cbce6387e6 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -204,8 +204,7 @@ InlinedVector> GenerateTransformers( const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable, - [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors) { + [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool) { InlinedVector> transformers; const bool disable_quant_qdq = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; @@ -217,7 +216,7 @@ InlinedVector> GenerateTransformers( onnxruntime::kAclExecutionProvider}; #endif const InlinedHashSet dml_ep = {onnxruntime::kDmlExecutionProvider}; - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); switch (level) { case TransformerLevel::Level1: { @@ -348,8 +347,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, SatApplyContextVariant{}, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool, - p_buffered_tensors)); + intra_op_thread_pool)); } transformers.emplace_back(std::make_unique(cpu_ep)); @@ -471,8 +469,7 @@ InlinedVector> GenerateTransformersForMinimalB const IExecutionProvider& cpu_execution_provider, const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable, - [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors) { + [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool) { InlinedVector> transformers; const bool saving = std::holds_alternative(apply_context); @@ -497,8 +494,7 @@ InlinedVector> GenerateTransformersForMinimalB transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, apply_context, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool, - p_buffered_tensors)); + intra_op_thread_pool)); } transformers.emplace_back(std::make_unique(cpu_ep, apply_context)); @@ -521,7 +517,7 @@ InlinedVector> GenerateTransformersForMinimalB // currently the only level 3 optimizer is the NhwcTransformer which is fully supported at runtime if (!saving) { #ifndef DISABLE_CONTRIB_OPS - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry), logger); diff --git a/onnxruntime/core/optimizer/identical_children_consolidation.cc b/onnxruntime/core/optimizer/identical_children_consolidation.cc index bbc8073268f08..4f31db922d078 100644 --- a/onnxruntime/core/optimizer/identical_children_consolidation.cc +++ b/onnxruntime/core/optimizer/identical_children_consolidation.cc @@ -69,7 +69,7 @@ std::string IdenticalChildrenConsolidation::IdentityBuilder(const Graph& graph, if (optimizer_utils::IsScalar(*input_def)) { const auto* data = graph_utils::GetConstantInitializer(graph, name); identity << constant_prefix; - Initializer value{*data, graph.ModelPath()}; + Initializer value{graph, *data, graph.ModelPath()}; switch (static_cast(data->data_type())) { case TensorProto::DataType::TensorProto_DataType_INT8: identity << *value.data(); diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index 81eb50286728f..6fbb4177ce90a 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -15,40 +15,82 @@ namespace onnxruntime { +static inline Tensor* GetTensor(OrtValue& ort_value) { + return ort_value.GetMutable(); +} + Initializer::Initializer(ONNX_NAMESPACE::TensorProto_DataType data_type, std::string_view name, - gsl::span dims) - : name_(name), - data_(DataTypeImpl::TensorTypeFromONNXEnum(data_type)->GetElementType(), dims, - std::make_shared()) { - if (!data_.IsDataTypeString()) { - memset(data_.MutableDataRaw(), 0, data_.SizeInBytes()); + gsl::span dims) : name_(name) { + auto tensor = Tensor(DataTypeImpl::TensorTypeFromONNXEnum(data_type)->GetElementType(), dims, + CPUAllocator::DefaultInstance()); + + if (!tensor.IsDataTypeString()) { + memset(tensor.MutableDataRaw(), 0, tensor.SizeInBytes()); } + + Tensor::InitOrtValue(std::move(tensor), ort_value_); + data_ = GetTensor(ort_value_); } Initializer::Initializer(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& model_path) { - ORT_ENFORCE(utils::HasDataType(tensor_proto), "Initializer must have a datatype"); + ORT_ENFORCE(utils::HasName(tensor_proto), "Initializer must have a name"); + name_ = tensor_proto.name(); + #if !defined(__wasm__) // using full filepath is required by utils::TensorProtoToTensor(). One exception is WebAssembly platform, where // external data is not loaded from real file system. - if (utils::HasExternalData(tensor_proto)) { + if (utils::HasExternalData(tensor_proto) && !utils::HasExternalDataInMemory(tensor_proto)) { ORT_ENFORCE(!model_path.empty(), "model_path must not be empty. Ensure that a path is provided when the model is created or loaded."); } #endif - auto proto_data_type = tensor_proto.data_type(); - if (utils::HasName(tensor_proto)) { - name_ = tensor_proto.name(); + Tensor tensor; + // This creates copy of the data so clients can mutate + ORT_THROW_IF_ERROR(utils::CreateTensorFromTensorProto(Env::Default(), model_path, tensor_proto, tensor)); + Tensor::InitOrtValue(std::move(tensor), ort_value_); + data_ = GetTensor(ort_value_); +} + +Initializer::Initializer(const Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path, bool check_outer_scope) { + ORT_ENFORCE(utils::HasName(tensor_proto), "Initializer must have a name"); + name_ = tensor_proto.name(); + + // Check if the data is in memory. This does not mean, though, that the data is in the ort_value + if (utils::HasExternalDataInMemory(tensor_proto)) { + OrtValue ort_value; + if (graph.GetOrtValueInitializer(name_, ort_value, check_outer_scope)) { + const auto& src_tensor = ort_value.Get(); + // We need to make a copy of the data to ensure that the original data is not mutated + // This is generally inline with TensorProtoToTensor() behavior which copies data from + // TensorProto to Tensor. + Tensor initializer{src_tensor.DataType(), src_tensor.Shape(), CPUAllocator::DefaultInstance()}; + utils::MakeCpuTensorCopy(src_tensor, initializer); + Tensor::InitOrtValue(std::move(initializer), ort_value_); + data_ = GetTensor(ort_value_); + return; + } +#if !defined(__wasm__) + ORT_ENFORCE(!model_path.empty(), + "model_path must not be empty. Ensure that a path is provided when the model is created or loaded."); +#endif } - auto proto_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); + Tensor tensor; + // Creates a copy of the data from tensor_proto + ORT_THROW_IF_ERROR(utils::CreateTensorFromTensorProto(Env::Default(), model_path, tensor_proto, tensor)); + Tensor::InitOrtValue(std::move(tensor), ort_value_); + data_ = GetTensor(ort_value_); +} + +Initializer::~Initializer() = default; - // This must be pre-allocated - Tensor w(DataTypeImpl::TensorTypeFromONNXEnum(proto_data_type)->GetElementType(), proto_shape, - std::make_shared()); - ORT_THROW_IF_ERROR(utils::TensorProtoToTensor(Env::Default(), model_path, tensor_proto, w)); - data_ = std::move(w); +void Initializer::ToProtoWithOrtValue(ONNX_NAMESPACE::TensorProto& tensor_proto, OrtValue& ort_value) const { + constexpr const bool use_tensor_buffer_true = true; + tensor_proto = utils::TensorToTensorProto(*data_, name_, use_tensor_buffer_true); + ort_value = ort_value_; } #if !defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -90,6 +132,26 @@ struct TensorToProtoFP16 { } }; +template +struct TensorToFP16 { + void operator()(const Tensor& data, Tensor& dst) const { + ToFp16 to_fp16; + auto span = data.DataAsSpan(); + auto* dst_data = dst.MutableData(); + for (const auto& v : span) { + *dst_data++ = MLFloat16::FromBits(to_fp16(v)); + } + } +}; + +template <> +struct TensorToFP16 { + void operator()(const Tensor& data, Tensor& dst) const { + const auto count = narrow(data.Shape().Size()); + MlasConvertFloatToHalfBuffer(data.Data(), dst.MutableData(), count); + } +}; + template struct ToBFloat16; @@ -127,6 +189,18 @@ struct TensorToProtoBFloat16 { } }; +template +struct TensorToBFloat16 { + void operator()(const Tensor& data, Tensor& dst) const { + ToBFloat16 to_bfloat16; + auto span = data.DataAsSpan(); + auto* dst_data = dst.MutableData(); + for (const auto& v : span) { + *dst_data++ = BFloat16::FromBits(to_bfloat16(v)); + } + } +}; + template struct ToFloat32; @@ -159,27 +233,24 @@ struct ToFloat32 { }; template -struct TensorToProtoFloat32 { - void operator()(const Tensor& data, ONNX_NAMESPACE::TensorProto& proto, onnxruntime::concurrency::ThreadPool* /*thread_pool*/) const { - auto span = data.DataAsSpan(); +struct TensorToFloat32 { + void operator()(const Tensor& src, Tensor& dst, onnxruntime::concurrency::ThreadPool* /*thread_pool*/) const { + auto src_span = src.DataAsSpan(); + auto* dst_data = dst.MutableData(); ToFloat32 to_float32; - for (const auto& v : span) { - proto.add_float_data(to_float32(v)); + for (const auto& v : src_span) { + *dst_data++ = to_float32(v); } } }; template <> -struct TensorToProtoFloat32 { +struct TensorToFloat32 { void operator()(const Tensor& data, - ONNX_NAMESPACE::TensorProto& proto, + Tensor& dst, onnxruntime::concurrency::ThreadPool* thread_pool) const { - auto source = reinterpret_cast(data.DataRaw()); - auto count = size_t(data.SizeInBytes() / sizeof(MLFloat16)); - auto destination_mem = std::make_unique(count); - auto destination = destination_mem.get(); - MlasConvertHalfToFloatBufferInParallel(source, destination, count, thread_pool); - utils::SetRawDataInTensorProto(proto, destination, count * sizeof(float)); + const auto count = narrow(data.Shape().Size()); + MlasConvertHalfToFloatBufferInParallel(data.Data(), dst.MutableData(), count, thread_pool); } }; @@ -199,26 +270,54 @@ inline void SetNameDims(const std::string& name, ONNX_NAMESPACE::TensorProto Initializer::ToFP16(const std::string& name) const { ONNX_NAMESPACE::TensorProto tensor_proto; - SetNameDims(name, data_.Shape().GetDims(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, tensor_proto); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, tensor_proto); + SetNameDims(name, data_->Shape().GetDims(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, tensor_proto); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, tensor_proto); return tensor_proto; } ONNX_NAMESPACE::TensorProto Initializer::ToBFloat16(const std::string& name) const { ONNX_NAMESPACE::TensorProto tensor_proto; - SetNameDims(name, data_.Shape().GetDims(), ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16, tensor_proto); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, tensor_proto); + SetNameDims(name, data_->Shape().GetDims(), ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16, tensor_proto); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, tensor_proto); return tensor_proto; } -ONNX_NAMESPACE::TensorProto Initializer::ToFloat32(const std::string& name, onnxruntime::concurrency::ThreadPool* thread_pool) const { - ONNX_NAMESPACE::TensorProto tensor_proto; - SetNameDims(name, data_.Shape().GetDims(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT, tensor_proto); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, tensor_proto, thread_pool); - return tensor_proto; +OrtValue onnxruntime::Initializer::ToFP16() const { + if (data_->IsDataType()) { + return ort_value_; + } + OrtValue result; + auto tensor = Tensor(DataTypeImpl::GetType(), data_->Shape().GetDims(), CPUAllocator::DefaultInstance()); + Tensor::InitOrtValue(std::move(tensor), result); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *result.GetMutable()); + return result; +} + +OrtValue Initializer::ToBFloat16() const { + if (data_->IsDataType()) { + return ort_value_; + } + OrtValue result; + auto tensor = Tensor(DataTypeImpl::GetType(), data_->Shape().GetDims(), CPUAllocator::DefaultInstance()); + Tensor::InitOrtValue(std::move(tensor), result); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *result.GetMutable()); + return result; +} + +OrtValue Initializer::ToFloat32(onnxruntime::concurrency::ThreadPool* thread_pool) const { + if (data_->IsDataType()) { + return ort_value_; + } + OrtValue result; + auto tensor = Tensor(DataTypeImpl::GetType(), data_->Shape().GetDims(), CPUAllocator::DefaultInstance()); + Tensor::InitOrtValue(std::move(tensor), result); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *result.GetMutable(), thread_pool); + return result; } namespace { @@ -314,46 +413,46 @@ struct ElementWiseDiv : OpElementWise::typ } // namespace Initializer& Initializer::add(float value) { - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, value); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, value); return *this; } Initializer& Initializer::add(const Initializer& other) { ORT_ENFORCE(data_type() == other.data_type(), "Expecting the same data type"); ORT_ENFORCE(size() == other.size(), "Expecting the same size"); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, other.data_); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *other.data_); return *this; } Initializer& Initializer::sub(const Initializer& other) { ORT_ENFORCE(data_type() == other.data_type(), "Expecting the same data type"); ORT_ENFORCE(size() == other.size(), "Expecting the same size"); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, other.data_); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *other.data_); return *this; } Initializer& Initializer::mul(const Initializer& other) { ORT_ENFORCE(data_type() == other.data_type(), "Expecting the same data type"); ORT_ENFORCE(size() == other.size(), "Expecting the same size"); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, other.data_); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *other.data_); return *this; } Initializer& Initializer::div(const Initializer& other) { ORT_ENFORCE(data_type() == other.data_type(), "Expecting the same data type"); ORT_ENFORCE(size() == other.size(), "Expecting the same size"); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, other.data_); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *other.data_); return *this; } Initializer& Initializer::sqrt() { - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_); return *this; } @@ -395,13 +494,13 @@ struct ScaleByAxis { void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool column_major) { ORT_ENFORCE(axis >= 0, "Axis must be non-negative"); - const size_t block_size = narrow(data_.Shape().SizeFromDimension(gsl::narrow_cast(axis))); + const size_t block_size = narrow(data_->Shape().SizeFromDimension(gsl::narrow_cast(axis))); const size_t num_blocks = size() / block_size; ORT_ENFORCE(scalers.size() == 1 || (column_major ? scalers.size() == block_size : scalers.size() == num_blocks), "Invalid other(scalers) size"); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, scalers.data_, block_size, num_blocks, column_major); + utils::MLTypeCallDispatcher t_disp(data_->GetElementType()); + t_disp.Invoke(*data_, *scalers.data_, block_size, num_blocks, column_major); } #endif // ORT_EXTENDED_MINIMAL_BUILD } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/initializer.h b/onnxruntime/core/optimizer/initializer.h index 17d1ada29d778..96c2ca41f5539 100644 --- a/onnxruntime/core/optimizer/initializer.h +++ b/onnxruntime/core/optimizer/initializer.h @@ -11,10 +11,11 @@ #include "core/common/common.h" #include "core/common/narrow.h" #include "core/framework/allocator.h" -#include "core/optimizer/graph_transformer.h" +#include "core/framework/ort_value.h" #include "core/framework/tensor_shape.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" +#include "core/optimizer/graph_transformer.h" #include "core/util/math.h" namespace onnxruntime { @@ -29,50 +30,82 @@ class Initializer final { Initializer(const ONNX_NAMESPACE::TensorProto& tensor_proto, const std::filesystem::path& model_path = {}); - ~Initializer() = default; + Initializer(const Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path = {}, bool check_outer_scope = false); + ~Initializer(); + + /// + /// This function creates a new tensor_proto with a complete copy of the data + /// + /// output void ToProto(ONNX_NAMESPACE::TensorProto& tensor_proto) const { - tensor_proto = utils::TensorToTensorProto(data_, name_); + tensor_proto = utils::TensorToTensorProto(*data_, name_); } + + /// + /// This function creates a pair of TensorProto and OrtValue. Unless the data + /// is short, tensor_proto will be a reference to the data in OrtValue. + /// Useful when adding a new initializer to the graph with external data in memory. + /// + /// + /// + void ToProtoWithOrtValue(ONNX_NAMESPACE::TensorProto& tensor_proto, OrtValue& ort_value) const; + #if !defined(ORT_EXTENDED_MINIMAL_BUILD) + // XXX: Below two used only in training, convert to OrtValue result ONNX_NAMESPACE::TensorProto ToFP16(const std::string& name) const; - ONNX_NAMESPACE::TensorProto ToBFloat16(const std::string& name) const; - ONNX_NAMESPACE::TensorProto ToFloat32(const std::string& name, onnxruntime::concurrency::ThreadPool* thread_pool = nullptr) const; + OrtValue ToFP16() const; + OrtValue ToBFloat16() const; + OrtValue ToFloat32(onnxruntime::concurrency::ThreadPool* thread_pool = nullptr) const; + #endif // ORT_EXTENDED_MINIMAL_BUILD int data_type() const { - return data_.GetElementType(); + return data_->GetElementType(); } - std::string_view name() const { + const std::string& name() const { return name_; } template T* data() { - return data_.MutableData(); + return data_->MutableData(); } template const T* data() const { - return data_.Data(); + return data_->Data(); + } + + const void* data_raw() const { + return data_->DataRaw(); + } + + void* mutable_data_raw() { + return data_->MutableDataRaw(); } template auto DataAsSpan() const { - return data_.DataAsSpan(); + return data_->DataAsSpan(); } gsl::span DataAsByteSpan() const { - return gsl::make_span(reinterpret_cast(data_.DataRaw()), data_.SizeInBytes()); + return gsl::make_span(reinterpret_cast(data_->DataRaw()), data_->SizeInBytes()); + } + + gsl::span MutableDataAsByteSpan() { + return gsl::make_span(reinterpret_cast(data_->MutableDataRaw()), data_->SizeInBytes()); } gsl::span dims() const { - return data_.Shape().GetDims(); + return data_->Shape().GetDims(); } - size_t size() const { return narrow(data_.Shape().Size()); } + size_t size() const { return narrow(data_->Shape().Size()); } #if !defined(ORT_EXTENDED_MINIMAL_BUILD) Initializer& add(float value); @@ -91,7 +124,8 @@ class Initializer final { #endif // ORT_EXTENDED_MINIMAL_BUILD private: std::string name_; - Tensor data_; + OrtValue ort_value_; + Tensor* data_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 3f19fb46e5ade..1e88ed44b1a8a 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -70,7 +70,7 @@ static std::vector GetAxesFromReduceMeanNode(Node& reduce_mean_node, co const auto* axes = reduce_mean_node.InputDefs()[1]; const auto* axes_const = graph.GetConstantInitializer(axes->Name(), true); if (axes_const != nullptr) { - Initializer initializer{*axes_const, graph.ModelPath()}; + Initializer initializer{graph, *axes_const, graph.ModelPath()}; auto span_axes = initializer.DataAsSpan(); axes_values.insert(axes_values.end(), span_axes.begin(), span_axes.end()); } @@ -480,7 +480,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, add2_node.MutableInputDefs()[1]->Name()); if (tensor_proto != nullptr && tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - Initializer initializer{*tensor_proto, graph.ModelPath()}; + Initializer initializer{graph, *tensor_proto, graph.ModelPath()}; layer_norm_node.AddAttribute("epsilon", initializer.data()[0]); } else { layer_norm_node.AddAttribute("epsilon", DEFAULT_LAYERNORM_EPSILON); @@ -727,7 +727,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, add_node.MutableInputDefs()[1]->Name()); if (tensor_proto != nullptr && tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - Initializer initializer{*tensor_proto, graph.ModelPath()}; + Initializer initializer{graph, *tensor_proto, graph.ModelPath()}; layer_norm_node.AddAttribute("epsilon", initializer.data()[0]); } else { layer_norm_node.AddAttribute("epsilon", DEFAULT_LAYERNORM_EPSILON); diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index d02efe9890f1c..a6c422e59aeef 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc @@ -188,7 +188,7 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, shape_initializer_proto.add_dims(static_cast(shape.size())); shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); utils::SetRawDataInTensorProto(shape_initializer_proto, shape.data(), shape.size() * sizeof(int64_t)); - NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto); + NodeArg* shape_arg = &graph_utils::AddInitializerWithExternalData(graph, shape_initializer_proto); ONNX_NAMESPACE::TypeProto new_arg_type; const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( gemm_input_defs[0]->TypeAsProto()->tensor_type().elem_type()); diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index 6b76dc626fba0..725cb3fc33f04 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -193,11 +193,11 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& * temp = scale / sqrt(var + epsilon) * output = (temp * Input) - ((temp * mean) + bias) */ - Initializer scale(*scale_tensor, graph.ModelPath()); - Initializer bias(*bias_tensor, graph.ModelPath()); - Initializer mean(*mean_tensor, graph.ModelPath()); - Initializer var(*var_tensor, graph.ModelPath()); - Initializer matmul_b(*matmul_b_tensor, graph.ModelPath()); + Initializer scale(graph, *scale_tensor, graph.ModelPath()); + Initializer bias(graph, *bias_tensor, graph.ModelPath()); + Initializer mean(graph, *mean_tensor, graph.ModelPath()); + Initializer var(graph, *var_tensor, graph.ModelPath()); + Initializer matmul_b(graph, *matmul_b_tensor, graph.ModelPath()); var.add(epsilon); var.sqrt(); @@ -208,18 +208,18 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& bias.sub(mean); // create B tensorProto for new Gemm node from initializer. - ONNX_NAMESPACE::TensorProto new_gemm_b_tensor(*matmul_b_tensor); + ONNX_NAMESPACE::TensorProto new_gemm_b_tensor; matmul_b.ToProto(new_gemm_b_tensor); const std::string new_gemm_b_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmul_b_tensor->name()); new_gemm_b_tensor.set_name(new_gemm_b_name); - NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializer(graph, new_gemm_b_tensor); + NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_gemm_b_tensor); // create bias tensorProto for new Gemm node from initializer. - ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor(*bias_tensor); + ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor; bias.ToProto(new_gemm_bias_tensor); const std::string new_gemm_bias_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias"); new_gemm_bias_tensor.set_name(new_gemm_bias_name); - NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor); + NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializerWithExternalData(graph, new_gemm_bias_tensor); Node& gemm_node = graph.AddNode( graph.GenerateNodeArgName("MatMulBnFusion_Gemm"), diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 46f306b92bed5..335209dbfadaf 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -408,7 +408,7 @@ void NchwcTransformerImpl::TransformConv(Node& node) { // Reuse the existing NodeArg. nchwc_conv_W_arg = filters_it->second; } else { - Initializer conv_W{*conv_W_tensor_proto, graph_.ModelPath()}; + Initializer conv_W{graph_, *conv_W_tensor_proto, graph_.ModelPath()}; const auto conv_W_dims = conv_W.dims(); int64_t reordered_filter_size = nchwc_output_channels * filter_input_channels; @@ -437,7 +437,7 @@ void NchwcTransformerImpl::TransformConv(Node& node) { nchwc_conv_W_tensor_proto.add_dims(conv_W_dims[i]); } - nchwc_conv_W_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_W_tensor_proto); + nchwc_conv_W_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_W_tensor_proto); filters_map->emplace(input_defs[1], nchwc_conv_W_arg); } @@ -449,7 +449,7 @@ void NchwcTransformerImpl::TransformConv(Node& node) { // Reuse the existing NodeArg. nchwc_conv_B_arg = biases_it->second; } else { - Initializer conv_B{*conv_B_tensor_proto, graph_.ModelPath()}; + Initializer conv_B{graph_, *conv_B_tensor_proto, graph_.ModelPath()}; InlinedVector aligned_bias(gsl::narrow(nchwc_output_channels)); ORT_ENFORCE(output_channels <= nchwc_output_channels, "Buffer overflow"); @@ -464,7 +464,7 @@ void NchwcTransformerImpl::TransformConv(Node& node) { nchwc_conv_B_tensor_proto.add_dims(nchwc_output_channels); - nchwc_conv_B_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_B_tensor_proto); + nchwc_conv_B_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_B_tensor_proto); aligned_biases_.emplace(input_defs[2], nchwc_conv_B_arg); } } @@ -580,7 +580,7 @@ Node& NchwcTransformerImpl::InsertReshape(NodeArg* input_arg, } shape_tensor_proto.add_dims(split_channels ? kNchwcDims + 1 : kNchwcDims); - shape_arg = &graph_utils::AddInitializer(graph_, shape_tensor_proto); + shape_arg = &graph_utils::AddInitializerWithExternalData(graph_, shape_tensor_proto); } Node& reshape_node = graph_.AddNode(graph_.GenerateNodeName("Reshape"), @@ -863,10 +863,10 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) { return; } - Initializer bn_scale{*bn_scale_tensor_proto, graph_.ModelPath()}; - Initializer bn_B{*bn_B_tensor_proto, graph_.ModelPath()}; - Initializer bn_mean{*bn_mean_tensor_proto, graph_.ModelPath()}; - Initializer bn_var{*bn_var_tensor_proto, graph_.ModelPath()}; + Initializer bn_scale{graph_, *bn_scale_tensor_proto, graph_.ModelPath()}; + Initializer bn_B{graph_, *bn_B_tensor_proto, graph_.ModelPath()}; + Initializer bn_mean{graph_, *bn_mean_tensor_proto, graph_.ModelPath()}; + Initializer bn_var{graph_, *bn_var_tensor_proto, graph_.ModelPath()}; // Calculate the scale and bias for the replacement convolution. bn_var.add(epsilon); @@ -892,7 +892,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) { nchwc_conv_W_tensor_proto.add_dims(1); nchwc_conv_W_tensor_proto.add_dims(1); - auto* nchwc_conv_W_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_W_tensor_proto); + auto* nchwc_conv_W_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_W_tensor_proto); std::copy_n(bn_B.data(), channels, padded_buffer.data()); @@ -903,7 +903,7 @@ void NchwcTransformerImpl::TransformBatchNormalization(Node& node) { gsl::narrow(nchwc_channels) * sizeof(float)); nchwc_conv_B_tensor_proto.add_dims(nchwc_channels); - auto* nchwc_conv_B_arg = &graph_utils::AddInitializer(graph_, nchwc_conv_B_tensor_proto); + auto* nchwc_conv_B_arg = &graph_utils::AddInitializerWithExternalData(graph_, nchwc_conv_B_tensor_proto); // Create the replacement node. std::string nchwc_node_name = graph_.GenerateNodeName(output_defs[0]->Name() + "_bn_nchwc"); @@ -1045,7 +1045,7 @@ void NchwcTransformerImpl::TransformResize(Node& node) { return; } - Initializer sizes{*sizes_tensor_proto, graph_.ModelPath()}; + Initializer sizes{graph_, *sizes_tensor_proto, graph_.ModelPath()}; auto* sizes_data = sizes.data(); // The sizes data can only be used if the input shape is static and the @@ -1075,7 +1075,7 @@ void NchwcTransformerImpl::TransformResize(Node& node) { return; } - Initializer scales{*scales_tensor_proto, graph_.ModelPath()}; + Initializer scales{graph_, *scales_tensor_proto, graph_.ModelPath()}; auto* scales_data = scales.data(); // Cast the scales to integers and verify that the scales are positive and diff --git a/onnxruntime/core/optimizer/noop_elimination.cc b/onnxruntime/core/optimizer/noop_elimination.cc index bba39b698a27a..6dafd9cd97799 100644 --- a/onnxruntime/core/optimizer/noop_elimination.cc +++ b/onnxruntime/core/optimizer/noop_elimination.cc @@ -68,7 +68,7 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con op_type == "Mul" || op_type == "Div") { int32_t data_type = initializer->data_type(); - Initializer add_init(*initializer, graph.ModelPath()); + Initializer add_init(graph, *initializer, graph.ModelPath()); float value = 0.0f; switch (data_type) { diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc index b2e8e491c361c..8c26d7a9ce209 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc @@ -7,7 +7,6 @@ #include "core/common/logging/logging.h" #include "core/common/logging/macros.h" #include "core/common/status.h" -#include "core/framework/callback.h" #include "core/framework/data_transfer_manager.h" #include "core/framework/data_types.h" #include "core/framework/fuse_nodes_funcs.h" @@ -37,7 +36,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, : execution_provider_(execution_provider), is_sparse_initializer_func_(is_sparse_initializer_func), logger_(logger) { - allocator_ptr_ = std::make_shared(); + allocator_ptr_ = CPUAllocator::DefaultInstance(); ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer"); ORT_THROW_IF_ERROR(data_transfer_mgr_.RegisterDataTransfer(std::make_unique())); @@ -86,7 +85,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, : execution_provider_(execution_provider), is_sparse_initializer_func_(is_sparse_initializer_func), logger_(logger) { - allocator_ptr_ = std::make_shared(); + allocator_ptr_ = CPUAllocator::DefaultInstance(); ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer"); ORT_THROW_IF_ERROR(data_transfer_mgr_.RegisterDataTransfer(std::make_unique())); diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.h b/onnxruntime/core/optimizer/optimizer_execution_frame.h index 24a23312feba9..feb51514c8b2d 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.h +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.h @@ -13,7 +13,6 @@ #include "core/framework/execution_frame.h" #include "core/framework/ort_value_name_idx_map.h" #include "core/framework/ort_value.h" -#include "core/framework/callback.h" namespace onnxruntime { class DataTransferManager; diff --git a/onnxruntime/core/optimizer/pad_fusion.cc b/onnxruntime/core/optimizer/pad_fusion.cc index bfd32a384335d..d0b6d42fd46c9 100644 --- a/onnxruntime/core/optimizer/pad_fusion.cc +++ b/onnxruntime/core/optimizer/pad_fusion.cc @@ -117,7 +117,7 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log // constant_value should be zero because Conv and MaxPool allow only 0 as padding value. if (node.InputDefs().size() > 2) { const auto* pad_constant_value_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name()); - Initializer pad_constant_value{*pad_constant_value_proto, graph.ModelPath()}; + Initializer pad_constant_value{graph, *pad_constant_value_proto, graph.ModelPath()}; if (std::any_of(pad_constant_value.DataAsByteSpan().begin(), pad_constant_value.DataAsByteSpan().end(), [](const uint8_t byte) { return byte != 0; })) { return false; } @@ -152,7 +152,7 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef if (pad_node.SinceVersion() >= 11) { const auto* pads_proto = graph_utils::GetConstantInitializer(graph, pad_node.InputDefs()[1]->Name()); - Initializer pads{*pads_proto, graph.ModelPath()}; + Initializer pads{graph, *pads_proto, graph.ModelPath()}; pads_values.assign(pads.DataAsSpan().begin(), pads.DataAsSpan().end()); } else { pads_values.assign(pad_node.GetAttributes().at("pads").ints().begin(), pad_node.GetAttributes().at("pads").ints().end()); diff --git a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc index 5538aa54801cc..42cd31b5bd7b4 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc @@ -96,10 +96,10 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const log } bool should_convert = false; - Initializer w_temp(*weight_tensor_proto, graph.ModelPath()); + Initializer w_temp(graph, *weight_tensor_proto, graph.ModelPath()); { int8_t* p = w_temp.data(); - for (size_t i = 0; i < w_temp.size(); i++) { + for (size_t i = 0, lim = w_temp.size(); i < lim; i++) { if (*p < -64 || *p > 64) { should_convert = true; } @@ -108,10 +108,10 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const log } } - Initializer r_temp(*r_tensor_proto, graph.ModelPath()); + Initializer r_temp(graph, *r_tensor_proto, graph.ModelPath()); { int8_t* p = r_temp.data(); - for (size_t i = 0; i < r_temp.size(); i++) { + for (size_t i = 0, lim = r_temp.size(); i < lim; i++) { if (*p < -64 || *p > 64) { should_convert = true; } @@ -130,22 +130,22 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const log weights_proto_u8.set_name(weight_tensor_proto->name() + "_s8_2_u8"); weights_proto_u8.mutable_dims()->CopyFrom(weight_tensor_proto->dims()); utils::SetRawDataInTensorProto(weights_proto_u8, w_temp.data(), static_cast(w_temp.size())); - input_defs[w_idx] = &graph_utils::AddInitializer(graph, weights_proto_u8); + input_defs[w_idx] = &graph_utils::AddInitializerWithExternalData(graph, weights_proto_u8); ONNX_NAMESPACE::TensorProto weight_zp_proto_u8; QDQ::Int8TensorProto2Uint8(weight_zp_tensor_proto, weight_zp_proto_u8, graph, true); - input_defs[w_zp_idx] = &graph_utils::AddInitializer(graph, weight_zp_proto_u8); + input_defs[w_zp_idx] = &graph_utils::AddInitializerWithExternalData(graph, weight_zp_proto_u8); ONNX_NAMESPACE::TensorProto r_proto_u8; r_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); r_proto_u8.set_name(r_tensor_proto->name() + "_s8_2_u8"); r_proto_u8.mutable_dims()->CopyFrom(r_tensor_proto->dims()); utils::SetRawDataInTensorProto(r_proto_u8, r_temp.data(), static_cast(r_temp.size())); - input_defs[r_idx] = &graph_utils::AddInitializer(graph, r_proto_u8); + input_defs[r_idx] = &graph_utils::AddInitializerWithExternalData(graph, r_proto_u8); ONNX_NAMESPACE::TensorProto r_zp_proto_u8; QDQ::Int8TensorProto2Uint8(r_zp_tensor_proto, r_zp_proto_u8, graph, true); - input_defs[r_zp_idx] = &graph_utils::AddInitializer(graph, r_zp_proto_u8); + input_defs[r_zp_idx] = &graph_utils::AddInitializerWithExternalData(graph, r_zp_proto_u8); return true; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc index 72ca1cb74f1fd..a1859b9d7071b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc @@ -30,7 +30,7 @@ static bool GetQConstantLowerUpper(const Graph& graph, const Node& node, float& return false; } - Initializer s_initializer(*s_tensor_proto, graph.ModelPath()); + Initializer s_initializer(graph, *s_tensor_proto, graph.ModelPath()); if (s_initializer.dims().size() != 0 || s_initializer.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { return false; @@ -45,7 +45,7 @@ static bool GetQConstantLowerUpper(const Graph& graph, const Node& node, float& return false; } - Initializer zp_initializer(*zp_tensor_proto, graph.ModelPath()); + Initializer zp_initializer(graph, *zp_tensor_proto, graph.ModelPath()); if (zp_initializer.dims().size() != 0) { return false; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_final_cleanup.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_final_cleanup.cc index 507bc71709b2f..691cf1183eb0e 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_final_cleanup.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_final_cleanup.cc @@ -51,7 +51,8 @@ bool CleanUpNodeSequence(NodeSequence node_sequence_type, Graph& graph, NodeInde const auto output_edges_count = second_node_ptr->GetOutputEdgesCount(); if (!match_second(*second_node_ptr) || - !QDQ::IsQDQPairSupported(first_node, *second_node_ptr, get_constant_initializer, graph.ModelPath(), false) || + !QDQ::IsQDQPairSupported(graph, first_node, *second_node_ptr, get_constant_initializer, + graph.ModelPath(), false) || (produces_graph_output && output_edges_count != 0) || (!produces_graph_output && output_edges_count != 1)) { return false; diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc index f2033dcbc1b03..98c818b0c761b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_s8_to_u8.cc @@ -41,8 +41,8 @@ static bool QDQ_S8_to_U8(Graph& graph, Node& q_node, Node& dq_node) { // TODO(fuchen): need to augment this when we support per row quantization using ONNX_TENSOR_ELEM_TYPE = ONNX_NAMESPACE::TensorProto::DataType; - Initializer q_zero_point(*q_zp_tensor_proto, graph.ModelPath()); - Initializer dq_zero_point(*dq_zp_tensor_proto, graph.ModelPath()); + Initializer q_zero_point(graph, *q_zp_tensor_proto, graph.ModelPath()); + Initializer dq_zero_point(graph, *dq_zp_tensor_proto, graph.ModelPath()); if (q_zero_point.size() != 1 || dq_zero_point.size() != 1 || q_zero_point.data_type() != ONNX_TENSOR_ELEM_TYPE::TensorProto_DataType_INT8 || @@ -61,7 +61,7 @@ static bool QDQ_S8_to_U8(Graph& graph, Node& q_node, Node& dq_node) { zp_tensor_proto_u8.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); zp_tensor_proto_u8.set_name(graph.GenerateNodeArgName("qdq_s8_to_u8_zp_conversion")); utils::SetRawDataInTensorProto(zp_tensor_proto_u8, &q_zp_value, sizeof(uint8_t)); - NodeArg* zp_u8_arg = &graph_utils::AddInitializer(graph, zp_tensor_proto_u8); + NodeArg* zp_u8_arg = &graph_utils::AddInitializerWithExternalData(graph, zp_tensor_proto_u8); auto q_output_node_arg_name = graph.GenerateNodeArgName("qdq_s8_to_u8_quant"); NodeArg* q_output_arg = &graph.GetOrCreateNodeArg(q_output_node_arg_name, nullptr); diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc index fe5874d067b95..3ecdbf0ede6b3 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc @@ -15,6 +15,7 @@ namespace onnxruntime::QDQ { bool IsQDQPairSupported( + const Graph& graph, const Node& q_node, const Node& dq_node, const GetConstantInitializerFn& get_const_initializer, const std::filesystem::path& model_path, @@ -56,10 +57,10 @@ bool IsQDQPairSupported( } // check Q/DQ have same scale and zero point - Initializer q_zp(*q_zp_tensor_proto, model_path); - Initializer q_scale(*q_scale_tensor_proto, model_path); - Initializer dq_zp(*dq_zp_tensor_proto, model_path); - Initializer dq_scale(*dq_scale_tensor_proto, model_path); + Initializer q_zp(graph, *q_zp_tensor_proto, model_path); + Initializer q_scale(graph, *q_scale_tensor_proto, model_path); + Initializer dq_zp(graph, *dq_zp_tensor_proto, model_path); + Initializer dq_scale(graph, *dq_scale_tensor_proto, model_path); if (q_zp.data_type() != dq_zp.data_type() || q_scale.data_type() != dq_scale.data_type() || @@ -84,6 +85,7 @@ bool IsQDQPairSupported( } bool IsDQQConversion( + const Graph& graph, const Node& dq_node, const Node& q_node, const GetConstantInitializerFn& get_const_initializer, const std::filesystem::path& model_path) { @@ -118,10 +120,10 @@ bool IsDQQConversion( } // check Q/DQ have same scale type and different zero point type - Initializer q_zp(*q_zp_tensor_proto, model_path); - Initializer q_scale(*q_scale_tensor_proto, model_path); - Initializer dq_zp(*dq_zp_tensor_proto, model_path); - Initializer dq_scale(*dq_scale_tensor_proto, model_path); + Initializer q_zp(graph, *q_zp_tensor_proto, model_path); + Initializer q_scale(graph, *q_scale_tensor_proto, model_path); + Initializer dq_zp(graph, *dq_zp_tensor_proto, model_path); + Initializer dq_scale(graph, *dq_scale_tensor_proto, model_path); return (dq_zp.data_type() != q_zp.data_type()) && (dq_scale.data_type() == q_scale.data_type()); } @@ -167,6 +169,7 @@ bool QOrDQNodeHasConstantScalarScaleAndZeroPoint( } bool IsQOrDQScalePositiveConstantScalar( + const Graph& graph, const Node& q_or_dq_node, const GetConstantInitializerFn& get_const_initializer, const std::filesystem::path& model_path) { auto q_or_dq_input_defs = q_or_dq_node.InputDefs(); @@ -183,7 +186,7 @@ bool IsQOrDQScalePositiveConstantScalar( return false; } - Initializer q_or_dq_scale(*q_or_dq_scale_tensor_proto, model_path); + Initializer q_or_dq_scale(graph, *q_or_dq_scale_tensor_proto, model_path); switch (q_or_dq_scale.data_type()) { case ONNX_NAMESPACE::TensorProto::FLOAT: @@ -250,7 +253,7 @@ bool GetQScalarScaleZp(const Graph& graph, const Node& q_node, float& scale, int } // Support scalar float scale only for now. Need to extend to other float types if needed. - Initializer scale_initializer(*scale_tensor_proto, graph.ModelPath()); + Initializer scale_initializer(graph, *scale_tensor_proto, graph.ModelPath()); if (scale_initializer.dims().size() != 0 || scale_initializer.data_type() != ONNX_NAMESPACE::TensorProto::FLOAT) { return false; } @@ -275,7 +278,7 @@ bool GetQScalarScaleZp(const Graph& graph, const Node& q_node, float& scale, int return false; } - Initializer zp_initializer(*zp_tensor_proto, graph.ModelPath()); + Initializer zp_initializer(graph, *zp_tensor_proto, graph.ModelPath()); if (zp_initializer.dims().size() != 0) { return false; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h index 25bd557b799c6..0648a3fc1f188 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h @@ -36,6 +36,7 @@ using GetConstantInitializerFn = std::function()[0] != -128) || diff --git a/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc index f094f3c199f2a..616144c0ccde0 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.cc @@ -43,12 +43,12 @@ bool ConvertS8WeightToU8(Graph& graph, Node& op_node, // The weights fits into S7, overflow is not a problem, no need to convert to U8 return false; } - input_defs[weights_idx] = &graph_utils::AddInitializer(graph, weights_proto_u8); + input_defs[weights_idx] = &graph_utils::AddInitializerWithExternalData(graph, weights_proto_u8); // Convert weight zero point to uint8 ONNX_NAMESPACE::TensorProto weight_zp_proto_u8; Int8TensorProto2Uint8(weight_zp_tensor_proto, weight_zp_proto_u8, graph, true); - input_defs[weight_zp_idx] = &graph_utils::AddInitializer(graph, weight_zp_proto_u8); + input_defs[weight_zp_idx] = &graph_utils::AddInitializerWithExternalData(graph, weight_zp_proto_u8); return true; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.h b/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.h index 1c1341fe5a127..a96f088c48306 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.h +++ b/onnxruntime/core/optimizer/qdq_transformer/s8_to_u8.h @@ -47,7 +47,7 @@ inline bool Int8TensorProto2Uint8( // principle. A better solution is to provide an efficient const iterator for // TensorProto. This require coordination with onnx side. - Initializer temp(*src, graph.ModelPath()); + Initializer temp(graph, *src, graph.ModelPath()); int8_t* p = temp.data(); bool should_convert = false; for (size_t i = 0; i < temp.size(); i++) { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 8f99b7409d4fe..dce69e2913582 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -197,7 +197,8 @@ void SetOptionalZeroPoint::UpdateNodes(Graph& graph, const NodesToOptimize& sele const ONNX_NAMESPACE::TensorProto* dummy_zp_tensor_proto; if (!graph.GetInitializedTensor(zp_tensor_proto.name(), dummy_zp_tensor_proto)) { - graph.AddInitializedTensor(zp_tensor_proto); + // Zero points are small, no need for external data + graph_utils::AddInitializer(graph, zp_tensor_proto); } auto& node_arg = graph.GetOrCreateNodeArg(zp_tensor_proto.name(), nullptr); @@ -280,8 +281,7 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction( int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors) + concurrency::ThreadPool* intra_op_thread_pool) : accuracy_level_{accuracy_level}, domain_{kMSDomain}, op_type_{"MatMulNBits"}, @@ -291,8 +291,7 @@ DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction( MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), MoveAll(target, ArgType::kOutput)}; }()}, - intra_op_thread_pool_{intra_op_thread_pool}, - p_buffered_tensors_{p_buffered_tensors} { + intra_op_thread_pool_{intra_op_thread_pool} { ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); } @@ -317,7 +316,6 @@ DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, const NodesToOptimize& selected_nodes, Node& replacement_node) const { - ORT_RETURN_IF_NOT(p_buffered_tensors_, "Buffered tensors map cannot be null"); const auto* dq_node = selected_nodes.Input(0); const auto* weight_arg = dq_node->InputDefs()[0]; const auto* scale_arg = dq_node->InputDefs()[1]; @@ -325,11 +323,16 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, const auto& attrs = dq_node->GetAttributes(); const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr; + ORT_RETURN_IF_NOT(graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto), + "Missing required weight: ", weight_arg->Name(), " for node: ", dq_node->Name()); + const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr; + ORT_RETURN_IF_NOT(graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto), + "Missing required scale: ", scale_arg->Name(), " for node: ", dq_node->Name()); const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr; - graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto); - graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto); if (zp_arg) { + // zero point is optional, one can have a NodeArg for a missing optional + // if the name is an empty string, and the below would not return ptr to a proto. graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto); } @@ -343,37 +346,38 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, // external file, a raw buffer, or a repeated field depending on the data // type. UnpackTensor() already contains some of these logic and is closest // to what we need. But it does not handle external data. - Initializer weight_src(*weight_tensor_proto, graph.ModelPath()); - Initializer scale_src(*scale_tensor_proto, graph.ModelPath()); + Initializer weight_src(graph, *weight_tensor_proto, graph.ModelPath()); + Initializer scale_src(graph, *scale_tensor_proto, graph.ModelPath()); auto uint8_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT8)->GetElementType(); auto scale_type = DataTypeImpl::TensorTypeFromONNXEnum(scale_src.data_type())->GetElementType(); - std::optional zp_src_ptr; - auto cpu_allocator = std::make_shared(); + + std::optional zp_src; + auto cpu_allocator = CPUAllocator::DefaultInstance(); auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T"); - auto weight_dst_ptr = std::make_unique(uint8_type, - TensorShape{N, quant_num, blob_bytes}, - cpu_allocator); + auto weight_dst = Tensor(uint8_type, + TensorShape{N, quant_num, blob_bytes}, + cpu_allocator); auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T"); auto scale_size = (TensorShape{N, quant_num}).Size(); - auto scale_dst_ptr = std::make_unique(scale_type, - TensorShape{scale_size}, - cpu_allocator); + auto scale_dst = Tensor(scale_type, + TensorShape{scale_size}, + cpu_allocator); std::string zp_dst_name; - std::unique_ptr zp_dst_ptr; + std::optional zp_dst; auto zp_size = (TensorShape{N, (quant_num + 1) / 2}).Size(); if (zp_tensor_proto) { - zp_src_ptr.emplace(*zp_tensor_proto, graph.ModelPath()); + zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath()); zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T"); - zp_dst_ptr = std::make_unique(uint8_type, - TensorShape{zp_size}, - cpu_allocator); + zp_dst = Tensor(uint8_type, + TensorShape{zp_size}, + cpu_allocator); } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { zp_dst_name = graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"); - zp_dst_ptr = std::make_unique(uint8_type, - TensorShape{zp_size}, - cpu_allocator); - memset(zp_dst_ptr->MutableDataRaw(), 0, zp_dst_ptr->SizeInBytes()); + zp_dst = Tensor(uint8_type, + TensorShape{zp_size}, + cpu_allocator); + memset(zp_dst->MutableDataRaw(), 0, zp_dst->SizeInBytes()); } if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { @@ -381,10 +385,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, - weight_dst_ptr->MutableData(), - scale_dst_ptr->MutableData(), - zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.MutableData(), + scale_dst.MutableData(), + zp_dst ? zp_dst->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -394,10 +398,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, - weight_dst_ptr->MutableData(), - scale_dst_ptr->MutableData(), - zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.MutableData(), + scale_dst.MutableData(), + zp_dst ? zp_dst->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -409,10 +413,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, - weight_dst_ptr->MutableData(), - scale_dst_ptr->MutableData(), - zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.MutableData(), + scale_dst.MutableData(), + zp_dst ? zp_dst->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -423,10 +427,10 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, MlasQDQTransposeBlockwiseQuantized( weight_src.DataAsByteSpan().data(), scale_src.data(), - zp_src_ptr ? zp_src_ptr->DataAsByteSpan().data() : nullptr, - weight_dst_ptr->MutableData(), - scale_dst_ptr->MutableData(), - zp_dst_ptr ? zp_dst_ptr->MutableData() : nullptr, + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.MutableData(), + scale_dst.MutableData(), + zp_dst ? zp_dst->MutableData() : nullptr, true, static_cast(K), static_cast(N), @@ -435,43 +439,24 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, } } - auto weight_T_tp = utils::TensorToTensorProto(*weight_dst_ptr, weight_dst_name, true); - auto scale_T_tp = utils::TensorToTensorProto(*scale_dst_ptr, scale_dst_name, true); + auto weight_T_tp = utils::TensorToTensorProto(weight_dst, weight_dst_name, true); + auto scale_T_tp = utils::TensorToTensorProto(scale_dst, scale_dst_name, true); std::optional zp_T_tp; - if (zp_dst_ptr) { - zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst_ptr, zp_dst_name, true)); + if (zp_dst) { + zp_T_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); } auto& input_defs = replacement_node.MutableInputDefs(); - input_defs.push_back(&graph_utils::AddInitializer(graph, weight_T_tp)); + input_defs.push_back(&graph_utils::AddInitializerWithExternalData(graph, weight_T_tp, std::move(weight_dst))); replacement_node.MutableInputArgsCount().push_back(1); - if (weight_T_tp.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - // If tensor is too small, tensor proto directly copies data from tensor. The tensor allocated - // here can be directly destructed. - // Only keep the tensor in p_buffered_tensors_ when the tensor proto is using external data location - // and pointing the location to tensor's buffer. - ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(weight_dst_name, std::move(weight_dst_ptr)).second, - "Failed to add buffered tensor ", - weight_dst_name); - } - input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp)); + input_defs.push_back(&graph_utils::AddInitializerWithExternalData(graph, scale_T_tp, std::move(scale_dst))); replacement_node.MutableInputArgsCount().push_back(1); - if (scale_T_tp.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(scale_dst_name, std::move(scale_dst_ptr)).second, - "Failed to add buffered tensor ", - scale_dst_name); - } if (zp_T_tp) { - input_defs.push_back(&graph_utils::AddInitializer(graph, zp_T_tp.value())); + input_defs.push_back(&graph_utils::AddInitializerWithExternalData(graph, zp_T_tp.value(), std::move(*zp_dst))); replacement_node.MutableInputArgsCount().push_back(1); - if (zp_T_tp->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - ORT_RETURN_IF_NOT(p_buffered_tensors_->emplace(zp_dst_name, std::move(zp_dst_ptr)).second, - "Failed to add buffered tensor ", - zp_dst_name); - } } return Status::OK(); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index d25077ca4b491..02a8353707599 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -86,8 +86,7 @@ struct MatMulReplaceWithQLinear : public Action { // used together with DQMatMulNodeGroupSelector, which does the sanity check struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { DQMatMulToMatMulNBitsAction(int64_t accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors); + concurrency::ThreadPool* intra_op_thread_pool); private: std::string OpType(const RuntimeState&) const override { return op_type_; } @@ -106,7 +105,6 @@ struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { const std::string op_type_; const std::vector value_moves_; concurrency::ThreadPool* intra_op_thread_pool_; - std::unordered_map>* p_buffered_tensors_; }; struct GemmReplaceWithQuant : public Action { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index ae89af1f256d1..93eb33628105c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -282,8 +282,7 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_registry, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors) { + concurrency::ThreadPool* intra_op_thread_pool) { // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. // DQ's weight is int4/uint4. DQ's scale is float/float16. // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. @@ -291,8 +290,7 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi std::unique_ptr action = std::make_unique(qdq_matmulnbits_accuracy_level, - intra_op_thread_pool, - p_buffered_tensors); + intra_op_thread_pool); #if !defined(ORT_MINIMAL_BUILD) std::vector providers = {kCpuExecutionProvider, kCudaExecutionProvider, kDmlExecutionProvider}; @@ -353,8 +351,7 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { SelectorActionRegistry CreateSelectorActionRegistry( bool is_int8_allowed, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors) { + concurrency::ThreadPool* intra_op_thread_pool) { SelectorActionRegistry qdq_selector_action_registry; SplitQDQRules(qdq_selector_action_registry); DropQDQNodesRules(qdq_selector_action_registry); @@ -368,8 +365,7 @@ SelectorActionRegistry CreateSelectorActionRegistry( WhereQDQRules(qdq_selector_action_registry); DQMatMulToMatMulNBitsRules(qdq_selector_action_registry, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool, - p_buffered_tensors); + intra_op_thread_pool); return qdq_selector_action_registry; } @@ -380,12 +376,11 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer( bool is_int8_allowed, const SatApplyContextVariant& apply_context, int64_t qdq_matmulnbits_accuracy_level, - concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors) + concurrency::ThreadPool* intra_op_thread_pool) : SelectorActionTransformer{ "QDQSelectorActionTransformer", CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, - intra_op_thread_pool, p_buffered_tensors), + intra_op_thread_pool), apply_context, // this transformer is compatible with CPU, DML, ACL and CUDA EP. // There is further EP control on the rule level. diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index 627ddd35b9919..dce1cd44fd3ea 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -29,8 +29,7 @@ class QDQSelectorActionTransformer : public SelectorActionTransformer { QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {}, int64_t qdq_matmulnbits_accuracy_level = 4, - concurrency::ThreadPool* intra_op_thread_pool = nullptr, - std::unordered_map>* p_buffered_tensors = nullptr); + concurrency::ThreadPool* intra_op_thread_pool = nullptr); }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 255714054cdaa..becf33db80bdc 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -173,12 +173,13 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node if (!allow_nonpositive_scale_) { // IsQDQPairSupported will check that the scale is the same between q_node and dq_node. - if (!IsQOrDQScalePositiveConstantScalar(q_node, get_const_initializer, graph_viewer.ModelPath())) { + if (!IsQOrDQScalePositiveConstantScalar(graph_viewer.GetGraph(), q_node, get_const_initializer, + graph_viewer.ModelPath())) { return false; } } - return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); + return IsQDQPairSupported(graph_viewer.GetGraph(), q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); } bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, @@ -345,7 +346,7 @@ bool SplitNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& } if (req_equal_quant_params_ && - !IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath())) { + !IsQDQPairSupported(graph_viewer.GetGraph(), q_node, dq_node, get_const_initializer, graph_viewer.ModelPath())) { return false; } } @@ -432,6 +433,37 @@ bool EinsumNodeGroupSelector::Check(const GraphViewer& graph_viewer, return true; } +bool ReciprocalNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, const Node* redundant_clip_node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + if (!CheckQDQNodes(graph_viewer, node, redundant_clip_node, dq_nodes, q_nodes, /*num_dq_inputs=*/-1, + /*is_empty_q_nodes_allowed=*/true)) { + return false; + } + size_t num_dq_inputs = dq_nodes.size(); + for (size_t i = 0; i < num_dq_inputs; ++i) { + int32_t dt_input = dq_nodes[i]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (!allow_int8_ && dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) { + return false; + } + if (!allow_16bit_ && Is16BitIntType(dt_input)) { + return false; + } + if (!allow_4bit_ && Is4BitIntType(dt_input)) { + return false; + } + } + if (!q_nodes.empty()) { + int32_t dt_input0 = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_input0 != dt_output) { + return false; + } + } + return true; +} + bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { @@ -761,7 +793,7 @@ bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& n return graph_viewer.GetConstantInitializer(initializer_name, true); }; - return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); + return IsQDQPairSupported(graph_viewer.GetGraph(), q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); } bool CumSumNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 6f7e153ec6ecb..3f062ebfb9ee5 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -198,6 +198,21 @@ class EinsumNodeGroupSelector : public NodeGroupSelector { bool allow_4bit_; }; +class ReciprocalNodeGroupSelector : public NodeGroupSelector { + public: + explicit ReciprocalNodeGroupSelector(bool allow_int8 = true, bool allow_16bit = true, bool allow_4bit = true) + : allow_int8_(allow_int8), allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} + + private: + bool Check(const GraphViewer& graph_viewer, + const Node& node, const Node* redundant_clip_node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; + bool allow_int8_; + bool allow_16bit_; + bool allow_4bit_; +}; + // 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not // The lack of a trailing Q isn't really a QDQ node group, so we default support for that to off. class MatMulNodeGroupSelector : public NodeGroupSelector { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index a39a6e8cc0e93..82dc9b500cea1 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -116,6 +116,11 @@ static const OpVersionsAndSelector::OpVersionsMap GetConvTransposeOpVersionsMap( static const OpVersionsAndSelector::OpVersionsMap GetEinsumOpVersionsMap() { return {{"Einsum", {}}}; } + +static const OpVersionsAndSelector::OpVersionsMap GetReciprocalOpVersionsMap() { + return {{"Reciprocal", {}}}; +} + static const OpVersionsAndSelector::OpVersionsMap GetMatMulOpVersionsMap() { return {{"MatMul", {}}}; } @@ -215,6 +220,13 @@ void RegisterEinsumSelector(Selectors& qdq_selectors) { std::move(selector)); } +void RegisterReciprocalSelector(Selectors& qdq_selectors) { + /* register selector for Reciprocal op */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetReciprocalOpVersionsMap(), + std::move(selector)); +} + void RegisterMatMulSelector(Selectors& qdq_selectors) { /* register selector for matmul op */ std::unique_ptr selector = std::make_unique(); @@ -288,6 +300,7 @@ void SelectorManager::CreateSelectors() { RegisterConvSelector(qdq_selectors_); RegisterConvTransposeSelector(qdq_selectors_); RegisterEinsumSelector(qdq_selectors_); + RegisterReciprocalSelector(qdq_selectors_); RegisterMatMulSelector(qdq_selectors_); RegisterGemmSelector(qdq_selectors_); RegisterInstanceAndLayerNormalizationSelector(qdq_selectors_); diff --git a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc index 58e90ea3c71c2..aa6f9c5409de7 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc @@ -117,28 +117,28 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph NodeArg* weight_scale_arg = nullptr; if (!dq_1) { - auto initializer = std::make_unique(*weight_proto, graph.ModelPath()); - const float* weight_data = initializer->data(); + Initializer initializer(graph, *weight_proto, graph.ModelPath()); + const float* weight_data = initializer.data(); // Quantize float32 weight to int8_t (per-tensor, symmetric). // int8_t quantization of input[1] works with input[0] of all types. float scale; int8_t zp; - GetQuantizationParameter(weight_data, static_cast(initializer->size()), scale, zp, nullptr); + GetQuantizationParameter(weight_data, static_cast(initializer.size()), scale, zp, nullptr); // Weight scale initializer. ONNX_NAMESPACE::TensorProto weight_scale_proto; weight_scale_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_weight_scale")); weight_scale_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); weight_scale_proto.mutable_float_data()->Add(scale); - weight_scale_arg = &graph_utils::AddInitializer(graph, weight_scale_proto); + weight_scale_arg = &graph_utils::AddInitializerWithExternalData(graph, weight_scale_proto); // Weight zero point initializer. ONNX_NAMESPACE::TensorProto weight_zp_proto; weight_zp_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_weight_zp")); weight_zp_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT8); weight_zp_proto.mutable_int32_data()->Add(static_cast(zp)); - NodeArg& weight_zp_arg = graph_utils::AddInitializer(graph, weight_zp_proto); + NodeArg& weight_zp_arg = graph_utils::AddInitializerWithExternalData(graph, weight_zp_proto); // Q from float32 to int8. ONNX_NAMESPACE::TypeProto weight_q_type_proto; diff --git a/onnxruntime/core/optimizer/quick_gelu_fusion.cc b/onnxruntime/core/optimizer/quick_gelu_fusion.cc index b09ef1c460b8e..54236c9a27980 100644 --- a/onnxruntime/core/optimizer/quick_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/quick_gelu_fusion.cc @@ -37,7 +37,7 @@ Status QuickGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, if (!optimizer_utils::IsScalar(input_arg)) continue; const TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); if (!tensor_proto) continue; - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; const auto data_type = tensor_proto->data_type(); if (data_type == TensorProto_DataType_FLOAT) { alpha = *(init_const.data()); diff --git a/onnxruntime/core/optimizer/relu_clip_fusion.cc b/onnxruntime/core/optimizer/relu_clip_fusion.cc index ae12c7bdfd4ac..efd7022ab764b 100644 --- a/onnxruntime/core/optimizer/relu_clip_fusion.cc +++ b/onnxruntime/core/optimizer/relu_clip_fusion.cc @@ -56,7 +56,7 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff data_type = initializer->data_type(); // construct an initializer to gracefully handle typed or raw data in the TensorProto - Initializer i(*initializer, graph.ModelPath()); + Initializer i(graph, *initializer, graph.ModelPath()); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: if (*i.data() < 0.f) { @@ -97,12 +97,7 @@ Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_eff mutable_next_node->AddAttribute("min", 0.f); } else { // Add the initialized tensor to the graph - graph.AddInitializedTensor(replacement_min); - - // Create a corresponding NodeArg for the initialized tensor - ONNX_NAMESPACE::TypeProto t; - t.mutable_tensor_type()->set_elem_type(replacement_min.data_type()); - NodeArg* replacement_min_nodearg = &graph.GetOrCreateNodeArg(replacement_min.name(), &t); + auto* replacement_min_nodearg = &graph_utils::AddInitializerWithExternalData(graph, replacement_min); // Replace the input def at the appropriate index of the Clip node auto& mutable_input_defs = mutable_next_node->MutableInputDefs(); diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index 324905f953eec..36213609f6b61 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -438,7 +438,7 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo shape_initializer_proto.add_dims(static_cast(shape_value.size())); shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); utils::SetRawDataInTensorProto(shape_initializer_proto, shape_value.data(), shape_value.size() * sizeof(int64_t)); - auto& new_node_arg = graph_utils::AddInitializer(graph, shape_initializer_proto); + auto& new_node_arg = graph_utils::AddInitializerWithExternalData(graph, shape_initializer_proto); // Safely remove concat parent nodes which have only one output for (int i = 0; i < concat_input_count; ++i) { @@ -492,7 +492,7 @@ bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph) { shape_initializer_proto.add_dims(static_cast(shape_value.size())); shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); utils::SetRawDataInTensorProto(shape_initializer_proto, shape_value.data(), shape_value.size() * sizeof(int64_t)); - NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto); + NodeArg* shape_arg = &graph_utils::AddInitializerWithExternalData(graph, shape_initializer_proto); Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_new_reshape"), "Reshape", "Reshape for " + name, {contiguous_reshapes[0].get().MutableInputDefs()[0], shape_arg}, {contiguous_reshapes.back().get().MutableOutputDefs()[0]}); diff --git a/onnxruntime/core/optimizer/slice_elimination.cc b/onnxruntime/core/optimizer/slice_elimination.cc index d3ec2dd459fd3..c4066097e43f1 100644 --- a/onnxruntime/core/optimizer/slice_elimination.cc +++ b/onnxruntime/core/optimizer/slice_elimination.cc @@ -61,7 +61,7 @@ bool EliminateSlice::SatisfyCondition(const Graph& graph, const Node& node, cons auto get_initializer_data = [&graph](const ONNX_NAMESPACE::TensorProto* initializer) -> InlinedVector { - Initializer init(*initializer, graph.ModelPath()); + Initializer init(graph, *initializer, graph.ModelPath()); if (initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) { int32_t* init_data = init.data(); return InlinedVector(init_data, init_data + init.size()); diff --git a/onnxruntime/core/optimizer/stft_decomposition.cc b/onnxruntime/core/optimizer/stft_decomposition.cc index 5c09e5225ab9c..74121508132dc 100644 --- a/onnxruntime/core/optimizer/stft_decomposition.cc +++ b/onnxruntime/core/optimizer/stft_decomposition.cc @@ -46,7 +46,7 @@ NodeArg* AddInitializer(Graph& graph, const char* name, const int64_t (&shape)[T proto.add_dims(shape[i]); } utils::SetRawDataInTensorProto(proto, begin, element_count * sizeof(TDataType)); - return &graph_utils::AddInitializer(graph, proto); + return &graph_utils::AddInitializerWithExternalData(graph, proto); } template diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index 2aa3cf30813b6..a320de2ee7a13 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -5,7 +5,9 @@ #include "core/common/logging/logging.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/execution_providers.h" +#include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" +#include "core/graph/graph_utils.h" using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -371,26 +373,31 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker const InitializedTensorSet& initializers_consumed, const logging::Logger& logger) { std::map replacements; - for (const auto& pair : initializers_consumed) { - const auto& name = pair.first; + for (const auto& [name, tensor_proto] : initializers_consumed) { const onnxruntime::NodeArg* provider_def = FindNodeArg(provider_input_defs_, name); const onnxruntime::NodeArg* non_provider_def = FindNodeArg(non_provider_input_defs_, name); if (provider_def != nullptr && non_provider_def != nullptr) { std::string new_def_name = graph_.GenerateNodeArgName(name); auto& new_def = graph_.GetOrCreateNodeArg(new_def_name, provider_def->TypeAsProto()); - // We make a copy of the initializer that is to be consumed by the provider Node so that - // session state initializer can copy it over to the provider device during its operation - // TODO: The copy being made is possibly redundant if this occurs in a subgraph - // When multiple subgraphs consume the same initializer as an implicit input, - // multiple copies of the initializer will be made into the provider device - // This should not directly affect runtime performance as the copies occur during initialization - // but overuse of the provider device's memory is definitely inefficient - // In future, we need to "statefully" make the copy only once and use it in all subgraphs referencing the initializer - const TensorProto* tensor_proto = pair.second; TensorProto new_tensor_proto = *tensor_proto; *(new_tensor_proto.mutable_name()) = new_def_name; - graph_.AddInitializedTensor(new_tensor_proto); + + // Query any OrtValue existing for the original initializer + // We are checking outer scope because GetInitializer is called with true, therefore, we potentially + // have references to parent graphs. + // We are doing this so the same OrtValue is re-used in subgraphs and no copies made for big items. + constexpr const bool check_outer_scope_true = true; + OrtValue ort_value; + // The initializer can be in memory with OrtValue or it can be a flatbuffer mapped. + if (utils::HasExternalDataInMemory(new_tensor_proto) && + graph_.GetOrtValueInitializer(name, ort_value, check_outer_scope_true)) { + // Re-use the same ort_value and proto that points to the same buffer + ORT_IGNORE_RETURN_VALUE(graph_utils::AddInitializerWithExternalData(graph_, new_tensor_proto, + std::move(ort_value))); + } else { + ORT_IGNORE_RETURN_VALUE(graph_utils::AddInitializer(graph_, new_tensor_proto)); + } replacements.insert(std::make_pair(provider_def, &new_def)); } diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index f87df746234fa..48ea54434b805 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -14,6 +14,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" +#include "core/optimizer/initializer.h" #include "core/optimizer/layout_transformation/layout_transformation_potentially_added_ops.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" #include "core/optimizer/transpose_optimization/ort_transpose_optimization.h" @@ -558,8 +559,8 @@ void ApiGraph::TransposeInitializer(std::string_view name, const std::vector new_tensor_shape_dims; - std::vector permutations; + TensorShapeVector new_tensor_shape_dims; + InlinedVector permutations; permutations.reserve(perm.size()); new_tensor_shape_dims.reserve(perm.size()); for (int64_t p : perm) { @@ -568,12 +569,12 @@ void ApiGraph::TransposeInitializer(std::string_view name, const std::vector& shape) { @@ -607,14 +610,19 @@ void ApiGraph::ReshapeInitializer(std::string_view name, const std::vectordata_type(); if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { const float* val = init_const.data(); @@ -110,7 +110,7 @@ bool IsInitializerWithExpectedValue(const Graph& graph, const NodeArg& input_arg return false; } - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; const auto data_type = tensor_proto->data_type(); if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) { const int64_t* val = init_const.data(); @@ -171,7 +171,7 @@ bool AppendTensorFromInitializer(const Graph& graph, const NodeArg& input_arg, I return false; } - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; const auto data_type = tensor_proto->data_type(); if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) { const int64_t* val = init_const.data(); @@ -333,7 +333,7 @@ bool GetClipConstantMinMax(const Graph& graph, const Node& node, float& min, flo bool is_constant = true; const ONNX_NAMESPACE::TensorProto* initializer = graph.GetConstantInitializer(input->Name(), true); if (initializer) { - Initializer i(*initializer, graph.ModelPath()); + Initializer i(graph, *initializer, graph.ModelPath()); switch (initializer->data_type()) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: value = *i.data(); @@ -421,7 +421,7 @@ bool GetScalarInitializerValue(const onnxruntime::Graph& graph, const onnxruntim return false; } - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; const T* val = init_const.data(); value = *val; diff --git a/onnxruntime/core/platform/env.h b/onnxruntime/core/platform/env.h index 7dbc3fe82db47..e100d3626f76b 100644 --- a/onnxruntime/core/platform/env.h +++ b/onnxruntime/core/platform/env.h @@ -26,7 +26,6 @@ limitations under the License. #include "core/common/common.h" #include "core/common/path_string.h" -#include "core/framework/callback.h" #include "core/platform/env_time.h" #include "core/platform/telemetry.h" #include "core/session/onnxruntime_c_api.h" @@ -179,7 +178,7 @@ class Env { virtual common::Status ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* file_path, FileOffsetType offset, size_t length, gsl::span buffer) const = 0; - using MappedMemoryPtr = std::unique_ptr; + using MappedMemoryPtr = std::unique_ptr>; /** * Maps the content of the file into memory. diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index 94aadf3df4d7e..0e43d054d5c5e 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -62,15 +62,8 @@ namespace { constexpr int OneMillion = 1000000; -class UnmapFileParam { - public: - void* addr; - size_t len; -}; - -static void UnmapFile(void* param) noexcept { - std::unique_ptr p(reinterpret_cast(param)); - int ret = munmap(p->addr, p->len); +static void UnmapFile(void* addr, size_t len) noexcept { + int ret = munmap(addr, len); if (ret != 0) { auto [err_no, err_msg] = GetErrnoInfo(); LOGS_DEFAULT(ERROR) << "munmap failed. error code: " << err_no << " error msg: " << err_msg; @@ -451,7 +444,9 @@ class PosixEnv : public Env { mapped_memory = MappedMemoryPtr{reinterpret_cast(mapped_base) + offset_to_page, - OrtCallbackInvoker{OrtCallback{UnmapFile, new UnmapFileParam{mapped_base, mapped_length}}}}; + [mapped_base, mapped_length](void*) { + UnmapFile(mapped_base, mapped_length); + }}; return Status::OK(); } diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index e9c7830f6d7a4..888ff1d0aa91e 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -52,7 +52,11 @@ void Telemetry::LogEvaluationStart() const { void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, const std::string& model_producer_version, const std::string& model_domain, const std::unordered_map& domain_to_version_map, + const std::string& model_file_name, const std::string& model_graph_name, + const std::string& model_weight_type, + const std::string& model_graph_hash, + const std::string& model_weight_hash, const std::unordered_map& model_metadata, const std::string& loadedFrom, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const { @@ -62,7 +66,11 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons ORT_UNUSED_PARAMETER(model_producer_version); ORT_UNUSED_PARAMETER(model_domain); ORT_UNUSED_PARAMETER(domain_to_version_map); + ORT_UNUSED_PARAMETER(model_file_name); ORT_UNUSED_PARAMETER(model_graph_name); + ORT_UNUSED_PARAMETER(model_weight_type); + ORT_UNUSED_PARAMETER(model_graph_hash); + ORT_UNUSED_PARAMETER(model_weight_hash); ORT_UNUSED_PARAMETER(model_metadata); ORT_UNUSED_PARAMETER(loadedFrom); ORT_UNUSED_PARAMETER(execution_provider_ids); @@ -79,10 +87,12 @@ void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& statu ORT_UNUSED_PARAMETER(line); } -void Telemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const { +void Telemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, + std::unordered_map duration_per_batch_size) const { ORT_UNUSED_PARAMETER(session_id); ORT_UNUSED_PARAMETER(total_runs_since_last); ORT_UNUSED_PARAMETER(total_run_duration_since_last); + ORT_UNUSED_PARAMETER(duration_per_batch_size); } void Telemetry::LogExecutionProviderEvent(LUID* adapterLuid) const { diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index d9afcace2fb81..99199c34f0464 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -57,7 +57,11 @@ class Telemetry { virtual void LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, const std::string& model_producer_version, const std::string& model_domain, const std::unordered_map& domain_to_version_map, + const std::string& model_file_name, const std::string& model_graph_name, + const std::string& model_weight_type, + const std::string& model_graph_hash, + const std::string& model_weight_hash, const std::unordered_map& model_metadata, const std::string& loadedFrom, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const; @@ -65,7 +69,8 @@ class Telemetry { virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const; - virtual void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const; + virtual void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, + std::unordered_map duration_per_batch_size) const; virtual void LogExecutionProviderEvent(LUID* adapterLuid) const; diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 9fdd323b365d6..36c6b54a1fce0 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -45,15 +45,8 @@ EXTERN_C IMAGE_DOS_HEADER __ImageBase; namespace onnxruntime { -class UnmapFileParam { - public: - void* addr; - size_t len; -}; - -static void UnmapFile(void* param) noexcept { - std::unique_ptr p(reinterpret_cast(param)); - bool ret = UnmapViewOfFile(p->addr); +static void UnmapFile(void* addr) noexcept { + bool ret = UnmapViewOfFile(addr); if (!ret) { const auto error_code = GetLastError(); LOGS_DEFAULT(ERROR) << "unmap view of file failed. error code: " << error_code @@ -467,9 +460,12 @@ Status WindowsEnv::MapFileIntoMemory(_In_z_ const ORTCHAR_T* file_path, static_cast(mapped_offset & 0xFFFFFFFF), mapped_length); GSL_SUPPRESS(r.11) + mapped_memory = MappedMemoryPtr{reinterpret_cast(mapped_base) + offset_to_page, - OrtCallbackInvoker{OrtCallback{UnmapFile, new UnmapFileParam{mapped_base, mapped_length}}}}; + [mapped_base](void*) { + UnmapFile(mapped_base); + }}; return Status::OK(); } diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.cc b/onnxruntime/core/platform/windows/logging/etw_sink.cc index 6a7f292d83b72..489cd19b11302 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.cc +++ b/onnxruntime/core/platform/windows/logging/etw_sink.cc @@ -106,15 +106,18 @@ HRESULT EtwRegistrationManager::Status() const { return etw_status_; } -void EtwRegistrationManager::RegisterInternalCallback(const std::string& cb_key, EtwInternalCallback callback) { +void EtwRegistrationManager::RegisterInternalCallback(const EtwInternalCallback& callback) { std::lock_guard lock(callbacks_mutex_); - [[maybe_unused]] auto result = callbacks_.emplace(cb_key, std::move(callback)); - assert(result.second); + callbacks_.push_back(&callback); } -void EtwRegistrationManager::UnregisterInternalCallback(const std::string& cb_key) { +void EtwRegistrationManager::UnregisterInternalCallback(const EtwInternalCallback& callback) { std::lock_guard lock(callbacks_mutex_); - callbacks_.erase(cb_key); + auto new_end = std::remove_if(callbacks_.begin(), callbacks_.end(), + [&callback](const EtwInternalCallback* ptr) { + return ptr == &callback; + }); + callbacks_.erase(new_end, callbacks_.end()); } void NTAPI EtwRegistrationManager::ORT_TL_EtwEnableCallback( @@ -135,12 +138,21 @@ void NTAPI EtwRegistrationManager::ORT_TL_EtwEnableCallback( manager.InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); } -EtwRegistrationManager::EtwRegistrationManager() - : initialization_status_(InitializationStatus::NotInitialized), - is_enabled_(false), - level_(), - keyword_(0), - etw_status_(S_OK) { +EtwRegistrationManager::~EtwRegistrationManager() { + std::lock_guard lock(callbacks_mutex_); + callbacks_.clear(); + if (initialization_status_ == InitializationStatus::Initialized || + initialization_status_ == InitializationStatus::Initializing) { + std::lock_guard init_lock(init_mutex_); + assert(initialization_status_ != InitializationStatus::Initializing); + if (initialization_status_ == InitializationStatus::Initialized) { + ::TraceLoggingUnregister(etw_provider_handle); + initialization_status_ = InitializationStatus::NotInitialized; + } + } +} + +EtwRegistrationManager::EtwRegistrationManager() { } void EtwRegistrationManager::LazyInitialize() { @@ -161,13 +173,6 @@ void EtwRegistrationManager::LazyInitialize() { } } -EtwRegistrationManager::~EtwRegistrationManager() { - if (initialization_status_ == InitializationStatus::Initialized) { - ::TraceLoggingUnregister(etw_provider_handle); - initialization_status_ = InitializationStatus::NotInitialized; - } -} - void EtwRegistrationManager::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext) { @@ -177,9 +182,10 @@ void EtwRegistrationManager::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, } std::lock_guard lock(callbacks_mutex_); - for (const auto& entry : callbacks_) { - const auto& cb = entry.second; - cb(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + for (const auto& callback : callbacks_) { + if (callback != nullptr) { + (*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + } } } diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.h b/onnxruntime/core/platform/windows/logging/etw_sink.h index 308770252f85a..62b762886ca82 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.h +++ b/onnxruntime/core/platform/windows/logging/etw_sink.h @@ -20,7 +20,6 @@ #include #include #include -#include #include #include "core/common/logging/capture.h" @@ -78,9 +77,9 @@ class EtwRegistrationManager { // Get the ETW registration status HRESULT Status() const; - void RegisterInternalCallback(const std::string& cb_key, EtwInternalCallback callback); + void RegisterInternalCallback(const EtwInternalCallback& callback); - void UnregisterInternalCallback(const std::string& cb_key); + void UnregisterInternalCallback(const EtwInternalCallback& callback); private: EtwRegistrationManager(); @@ -101,11 +100,11 @@ class EtwRegistrationManager { _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData, _In_opt_ PVOID CallbackContext); - std::mutex init_mutex_; - std::atomic initialization_status_ = InitializationStatus::NotInitialized; - std::unordered_map callbacks_; + std::vector callbacks_; std::mutex callbacks_mutex_; mutable std::mutex provider_change_mutex_; + std::mutex init_mutex_; + InitializationStatus initialization_status_ = InitializationStatus::NotInitialized; bool is_enabled_; UCHAR level_; ULONGLONG keyword_; @@ -134,8 +133,8 @@ class EtwRegistrationManager { Severity MapLevelToSeverity() { return Severity::kFATAL; } uint64_t Keyword() const { return 0; } HRESULT Status() const { return 0; } - void RegisterInternalCallback(const std::string& cb_key, EtwInternalCallback callback) {} - void UnregisterInternalCallback(const std::string& cb_key) {} + void RegisterInternalCallback(const EtwInternalCallback& callback) {} + void UnregisterInternalCallback(const EtwInternalCallback& callback) {} private: EtwRegistrationManager() = default; diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 5d2cfd216dffc..39cd805b96d6e 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -57,15 +57,18 @@ TRACELOGGING_DEFINE_PROVIDER(telemetry_provider_handle, "Microsoft.ML.ONNXRuntim #pragma warning(pop) #endif +#ifndef ORT_CALLER_FRAMEWORK +#define ORT_CALLER_FRAMEWORK "" +#endif + std::mutex WindowsTelemetry::mutex_; std::mutex WindowsTelemetry::provider_change_mutex_; uint32_t WindowsTelemetry::global_register_count_ = 0; -std::atomic_bool WindowsTelemetry::enabled_{true}; +bool WindowsTelemetry::enabled_ = true; uint32_t WindowsTelemetry::projection_ = 0; -std::atomic WindowsTelemetry::level_{0}; -std::atomic WindowsTelemetry::keyword_{0}; - -std::unordered_map WindowsTelemetry::callbacks_; +UCHAR WindowsTelemetry::level_ = 0; +UINT64 WindowsTelemetry::keyword_ = 0; +std::vector WindowsTelemetry::callbacks_; std::mutex WindowsTelemetry::callbacks_mutex_; WindowsTelemetry::WindowsTelemetry() { @@ -80,45 +83,49 @@ WindowsTelemetry::WindowsTelemetry() { } WindowsTelemetry::~WindowsTelemetry() { - { - std::lock_guard lock(mutex_); - if (global_register_count_ > 0) { - global_register_count_ -= 1; - if (global_register_count_ == 0) { - TraceLoggingUnregister(telemetry_provider_handle); - } + std::lock_guard lock(mutex_); + if (global_register_count_ > 0) { + global_register_count_ -= 1; + if (global_register_count_ == 0) { + TraceLoggingUnregister(telemetry_provider_handle); } } + + std::lock_guard lock_callbacks(callbacks_mutex_); + callbacks_.clear(); } bool WindowsTelemetry::IsEnabled() const { + std::lock_guard lock(provider_change_mutex_); return enabled_; } UCHAR WindowsTelemetry::Level() const { + std::lock_guard lock(provider_change_mutex_); return level_; } UINT64 WindowsTelemetry::Keyword() const { + std::lock_guard lock(provider_change_mutex_); return keyword_; } -void WindowsTelemetry::RegisterInternalCallback(const std::string& callback_key, EtwInternalCallback callback) { +// HRESULT WindowsTelemetry::Status() { +// return etw_status_; +// } + +void WindowsTelemetry::RegisterInternalCallback(const EtwInternalCallback& callback) { std::lock_guard lock_callbacks(callbacks_mutex_); - auto result = callbacks_.emplace(callback_key, std::move(callback)); - if (!result.second) { - result.first->second.IncrementRef(); - } + callbacks_.push_back(&callback); } -void WindowsTelemetry::UnregisterInternalCallback(const std::string& callback_key) { +void WindowsTelemetry::UnregisterInternalCallback(const EtwInternalCallback& callback) { std::lock_guard lock_callbacks(callbacks_mutex_); - auto hit = callbacks_.find(callback_key); - if (hit != callbacks_.end()) { - if (hit->second.DecrementRef() < 1) { - callbacks_.erase(hit); - } - } + auto new_end = std::remove_if(callbacks_.begin(), callbacks_.end(), + [&callback](const EtwInternalCallback* ptr) { + return ptr == &callback; + }); + callbacks_.erase(new_end, callbacks_.end()); } void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( @@ -129,12 +136,10 @@ void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( _In_ ULONGLONG MatchAllKeyword, _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData, _In_opt_ PVOID CallbackContext) { - { - std::lock_guard lock(provider_change_mutex_); - enabled_ = (IsEnabled != 0); - level_ = Level; - keyword_ = MatchAnyKeyword; - } + std::lock_guard lock(provider_change_mutex_); + enabled_ = (IsEnabled != 0); + level_ = Level; + keyword_ = MatchAnyKeyword; InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); } @@ -143,9 +148,8 @@ void WindowsTelemetry::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext) { std::lock_guard lock_callbacks(callbacks_mutex_); - for (const auto& entry : callbacks_) { - const auto& cb = entry.second.cb; - cb(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + for (const auto& callback : callbacks_) { + (*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); } } @@ -184,7 +188,8 @@ void WindowsTelemetry::LogProcessInfo() const { TraceLoggingUInt8(0, "schemaVersion"), TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingBool(IsDebuggerPresent(), "isDebuggerAttached"), - TraceLoggingBool(isRedist, "isRedist")); + TraceLoggingBool(isRedist, "isRedist"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); process_info_logged = true; } @@ -220,7 +225,11 @@ void WindowsTelemetry::LogEvaluationStart() const { void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, const std::string& model_producer_version, const std::string& model_domain, const std::unordered_map& domain_to_version_map, + const std::string& model_file_name, const std::string& model_graph_name, + const std::string& model_weight_type, + const std::string& model_graph_hash, + const std::string& model_weight_hash, const std::unordered_map& model_metadata, const std::string& loaded_from, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const { @@ -285,7 +294,11 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingString(model_domain.c_str(), "modelDomain"), TraceLoggingBool(use_fp16, "usefp16"), TraceLoggingString(domain_to_version_string.c_str(), "domainToVersionMap"), + TraceLoggingString(model_file_name.c_str(), "modelFileName"), TraceLoggingString(model_graph_name.c_str(), "modelGraphName"), + TraceLoggingString(model_weight_type.c_str(), "modelWeightType"), + TraceLoggingString(model_graph_hash.c_str(), "modelGraphHash"), + TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"), TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"), TraceLoggingString(loaded_from.c_str(), "loadedFrom"), TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds")); @@ -307,7 +320,11 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingString(model_domain.c_str(), "modelDomain"), TraceLoggingBool(use_fp16, "usefp16"), TraceLoggingString(domain_to_version_string.c_str(), "domainToVersionMap"), + TraceLoggingString(model_file_name.c_str(), "modelFileName"), TraceLoggingString(model_graph_name.c_str(), "modelGraphName"), + TraceLoggingString(model_weight_type.c_str(), "modelWeightType"), + TraceLoggingString(model_graph_hash.c_str(), "modelGraphHash"), + TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"), TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"), TraceLoggingString(loaded_from.c_str(), "loadedFrom"), TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds")); @@ -356,10 +373,22 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status #endif } -void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const { +void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, + std::unordered_map duration_per_batch_size) const { if (global_register_count_ == 0 || enabled_ == false) return; + // Convert duration_per_batch_size to a formatted string + std::string total_duration_per_batch_size; + for (const auto& entry : duration_per_batch_size) { + if (!total_duration_per_batch_size.empty()) { + total_duration_per_batch_size += ", "; + } + total_duration_per_batch_size += std::to_string(entry.first); + total_duration_per_batch_size += ": "; + total_duration_per_batch_size += std::to_string(entry.second); + } + TraceLoggingWrite(telemetry_provider_handle, "RuntimePerf", TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), @@ -369,7 +398,8 @@ void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_s TraceLoggingUInt8(0, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingUInt32(total_runs_since_last, "totalRuns"), - TraceLoggingInt64(total_run_duration_since_last, "totalRunDuration")); + TraceLoggingInt64(total_run_duration_since_last, "totalRunDuration"), + TraceLoggingString(total_duration_per_batch_size.c_str(), "totalRunDurationPerBatchSize")); } void WindowsTelemetry::LogExecutionProviderEvent(LUID* adapterLuid) const { diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index ff4afcae79ac5..1b4cc7b5408a5 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -2,9 +2,7 @@ // Licensed under the MIT License. #pragma once - #include -#include #include #include "core/platform/telemetry.h" @@ -13,6 +11,8 @@ #include #include "core/platform/windows/TraceLoggingConfig.h" +static constexpr size_t TelemetrySampleCount = 10; + namespace onnxruntime { /** @@ -36,6 +36,9 @@ class WindowsTelemetry : public Telemetry { // Get the current keyword UINT64 Keyword() const override; + // Get the ETW registration status + // static HRESULT Status(); + void LogProcessInfo() const override; void LogSessionCreationStart() const override; @@ -47,7 +50,11 @@ class WindowsTelemetry : public Telemetry { void LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, const std::string& model_producer_version, const std::string& model_domain, const std::unordered_map& domain_to_version_map, + const std::string& model_file_name, const std::string& model_graph_name, + const std::string& model_weight_type, + const std::string& model_graph_hash, + const std::string& model_weight_hash, const std::unordered_map& model_metadata, const std::string& loadedFrom, const std::vector& execution_provider_ids, bool use_fp16, bool captureState) const override; @@ -55,7 +62,8 @@ class WindowsTelemetry : public Telemetry { void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const override; - void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const override; + void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, + std::unordered_map duration_per_batch_size) const override; void LogExecutionProviderEvent(LUID* adapterLuid) const override; @@ -67,30 +75,21 @@ class WindowsTelemetry : public Telemetry { ULONGLONG MatchAnyKeyword, ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext)>; - static void RegisterInternalCallback(const std::string& callback_key, EtwInternalCallback callback); + static void RegisterInternalCallback(const EtwInternalCallback& callback); - static void UnregisterInternalCallback(const std::string& callback_key); + static void UnregisterInternalCallback(const EtwInternalCallback& callback); private: static std::mutex mutex_; static uint32_t global_register_count_; + static bool enabled_; static uint32_t projection_; - struct CallbackRecord { - EtwInternalCallback cb; - int ref = 1; - explicit CallbackRecord(EtwInternalCallback cb) : cb(std::move(cb)) {} - void IncrementRef() { ++ref; } - int DecrementRef() { return --ref; } - }; - - static std::unordered_map callbacks_; + static std::vector callbacks_; static std::mutex callbacks_mutex_; - - static std::atomic_bool enabled_; - static std::atomic level_; - static std::atomic keyword_; static std::mutex provider_change_mutex_; + static UCHAR level_; + static ULONGLONG keyword_; static void InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext); diff --git a/onnxruntime/core/providers/cann/cann_allocator.h b/onnxruntime/core/providers/cann/cann_allocator.h index 1022374b51d9f..14daf46e45b16 100644 --- a/onnxruntime/core/providers/cann/cann_allocator.h +++ b/onnxruntime/core/providers/cann/cann_allocator.h @@ -15,8 +15,9 @@ class CANNAllocator : public IAllocator { CANNAllocator(OrtDevice::DeviceId device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, device_id), - device_id, OrtMemTypeDefault)) {} + OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::HUAWEI, + device_id), + OrtMemTypeDefault)) {} void* Alloc(size_t size) override; void Free(void* p) override; }; @@ -26,8 +27,9 @@ class CANNPinnedAllocator : public IAllocator { CANNPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CANN_PINNED, device_id), - device_id, OrtMemTypeCPUOutput)) {} + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::HUAWEI, + device_id), + OrtMemTypeCPUOutput)) {} void* Alloc(size_t size) override; void Free(void* p) override; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index be09eefba791b..47ba70d4a5529 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1028,7 +1028,10 @@ Status RegisterCANNKernels(KernelRegistry& kernel_registry) { } // namespace cann CANNExecutionProvider::CANNExecutionProvider(const CANNExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kCannExecutionProvider, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_{info} { + : IExecutionProvider{onnxruntime::kCannExecutionProvider, + OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::HUAWEI, + info.device_id)}, + info_{info} { InitProviderOrtApi(); CANN_CALL_THROW(aclrtSetDevice(info_.device_id)); @@ -1485,8 +1488,10 @@ void CANNExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& } OrtDevice CANNExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { - if (mem_type == OrtMemTypeCPUInput) return OrtDevice(); - if (mem_type == OrtMemTypeCPUOutput) return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CANN_PINNED, 0); + if (mem_type == OrtMemTypeCPUInput) + return OrtDevice(); + if (mem_type == OrtMemTypeCPUOutput) + return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::HUAWEI, 0); return default_device_; } diff --git a/onnxruntime/core/providers/cann/npu_data_transfer.cc b/onnxruntime/core/providers/cann/npu_data_transfer.cc index 2f51c550b207e..7821926a98a94 100644 --- a/onnxruntime/core/providers/cann/npu_data_transfer.cc +++ b/onnxruntime/core/providers/cann/npu_data_transfer.cc @@ -98,7 +98,7 @@ common::Status NPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, static_cast(stream.GetHandle()))); } } else { - if (src_device.MemType() == OrtDevice::MemType::CANN_PINNED) { + if (src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE) { // sync the stream first to make sure the data arrived CANN_RETURN_IF_ERROR(aclrtSynchronizeStream(static_cast(stream.GetHandle()))); } diff --git a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc index a4609eb2a0584..fb3d3c80ec372 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc @@ -59,8 +59,8 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co NodeAttrHelper helper(node); if (input_defs.size() > 1 && input_defs[1]->Exists()) { - auto& axes_tensor = *model_builder.GetConstantInitializer(input_defs[1]->Name()); - Initializer axes_initializer(axes_tensor); + const auto& axes_tensor = *model_builder.GetConstantInitializer(input_defs[1]->Name()); + Initializer axes_initializer(model_builder.GetGraphViewer().GetGraph(), axes_tensor); int64_t* data = axes_initializer.data(); int64_t size = axes_initializer.size(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc index b35d6971623ed..e3781ed7d388b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc @@ -44,7 +44,7 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& data_name = input_defs[0]->Name(); const auto& new_shape_name = input_defs[1]->Name(); - Initializer unpacked_tensor(*model_builder.GetConstantInitializer(new_shape_name)); + Initializer unpacked_tensor(model_builder.GetGraphViewer().GetGraph(), *model_builder.GetConstantInitializer(new_shape_name)); TensorShapeVector new_shape = ToShapeVector(unpacked_tensor.DataAsSpan()); // ReshapeHelper applies the ONNX rules to create the concrete output shape @@ -75,7 +75,8 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, +bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, + const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& new_shape_name = input_defs[1]->Name(); @@ -87,7 +88,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP return false; } - Initializer unpacked_tensor(*new_shape_tensor); + Initializer unpacked_tensor(input_params.graph_viewer.GetGraph(), *new_shape_tensor); auto new_shape = unpacked_tensor.DataAsSpan(); if (new_shape.empty()) { LOGS(logger, VERBOSE) << "New shape of reshape cannot be empty"; diff --git a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc index 837573003e515..9b1545035104c 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc @@ -77,7 +77,8 @@ bool GetValidatedResizeScales(const GraphViewer& graph_viewer, return false; } - Initializer unpacked_tensor(*scales_tensor); + const auto& graph = graph_viewer.GetGraph(); + Initializer unpacked_tensor(graph, *scales_tensor, graph.ModelPath()); auto scales_data = unpacked_tensor.DataAsSpan(); scales.assign(scales_data.begin(), scales_data.end()); @@ -108,7 +109,7 @@ bool GetValidatedResizeSizes(const GraphViewer& graph_viewer, return false; } - Initializer unpacked_tensor(*sizes_tensor); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *sizes_tensor, graph_viewer.ModelPath()); auto sizes_data = unpacked_tensor.DataAsSpan(); sizes.assign(sizes_data.begin(), sizes_data.end()); diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index bf72fbbf1ace4..1a0f4e4de2e09 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -59,7 +59,7 @@ Status PrepareSliceComputeMetadata(const Node& slice_node, const auto* tensor_proto = graph_viewer.GetConstantInitializer(input_defs[input_idx]->Name()); ORT_RETURN_IF_NOT(tensor_proto, "Failed to get constant initializer."); - Initializer unpacked_tensor(*tensor_proto, graph_viewer.ModelPath()); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *tensor_proto, graph_viewer.ModelPath()); const auto data_type = unpacked_tensor.data_type(); if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) { auto tensor_data = unpacked_tensor.DataAsSpan(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index 717d344982473..4ee9b54cebd16 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -63,7 +63,8 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (input_defs.size() > 1) { // if "split" is explicitly provided as an input - Initializer unpacked_tensor(*model_builder.GetConstantInitializer(input_defs[1]->Name())); + const auto& const_init = *model_builder.GetConstantInitializer(input_defs[1]->Name()); + Initializer unpacked_tensor(const_init); auto split_span = unpacked_tensor.DataAsSpan(); AddOperationInput(*split_op, "split_sizes", model_builder.AddConstant(split_op->type(), "split_sizes", split_span)); @@ -102,7 +103,8 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (input_defs.size() > 1) { // if "split" is explicitly provided as an input // const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); - Initializer unpacked_tensor(*model_builder.GetConstantInitializer(input_defs[1]->Name())); + const auto& const_init = *model_builder.GetConstantInitializer(input_defs[1]->Name()); + Initializer unpacked_tensor(model_builder.GetGraphViewer().GetGraph(), const_init); auto split_span = unpacked_tensor.DataAsSpan(); for (const auto& split_size : split_span) { coreml_splitnd->add_splitsizes(split_size); @@ -164,7 +166,8 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar return false; } - Initializer unpacked_tensor(*splits_tensor); + Initializer unpacked_tensor(input_params.graph_viewer.GetGraph(), *splits_tensor, + input_params.graph_viewer.ModelPath()); auto splits_span = unpacked_tensor.DataAsSpan(); int64_t sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), int64_t{0}); if (sum_of_splits != split_dims_at_axis) { diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 51b3146c25a73..6a89fc6234f0f 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1332,6 +1332,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Si class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, float, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, double, RMSNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, MLFloat16, RMSNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, float, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, MLFloat16, RotaryEmbedding); // !!PLEASE READ BELOW!! Following that, add new entries above this comment @@ -3318,6 +3320,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { RMSNormalization)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/core/providers/cpu/cpu_provider_factory.cc b/onnxruntime/core/providers/cpu/cpu_provider_factory.cc index 22cf7024663a9..5ca1328011312 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_factory.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_factory.cc @@ -54,6 +54,6 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessio #endif ORT_API_STATUS_IMPL(OrtApis::CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type, _Outptr_ OrtMemoryInfo** out) { - *out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), 0, mem_type); + *out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), mem_type); return nullptr; } diff --git a/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc b/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc new file mode 100644 index 0000000000000..616374eee6ff1 --- /dev/null +++ b/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cpu/llm/rotary_embedding.h" +#include "core/providers/cpu/llm/rotary_embedding_helper.h" + +#include "core/mlas/inc/mlas.h" +#include "core/platform/threadpool.h" + +using onnxruntime::concurrency::ThreadPool; +using namespace onnxruntime::rotary_embedding_helper; + +namespace onnxruntime { + +#define REGISTER_ONNX_KERNEL_TYPED(T) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + RotaryEmbedding, \ + 23, \ + T, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + RotaryEmbedding); + +REGISTER_ONNX_KERNEL_TYPED(float) +REGISTER_ONNX_KERNEL_TYPED(MLFloat16) + +template +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { + num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); + rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); // Turn 0/1 into bool + + if (rotary_embedding_dim > 0) { + ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified"); + } +} + +// TODO: rotary embedding in place +template +Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters parameters, const T* input, + const int64_t* position_ids, const T* cos_cache, const T* sin_cache, T* output, + bool interleaved) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int n_heads = parameters.num_heads; + const int head_size = parameters.head_size; + const int head_stride = parameters.head_stride; + const int seq_stride = parameters.seq_stride; + const int batch_stride = parameters.batch_stride; + const int position_ids_format = parameters.position_ids_format; + const int rotary_emb_dim = parameters.rotary_embedding_dim; + const int half_rotary_emb_dim = rotary_emb_dim / 2; + // Parallel to calculate based on head_size + const int loop_len = batch_size * sequence_length * n_heads; + // The cost is calculated as: + // - head_size * sizeof(T) for reading input + // - head_size * sizeof(T) for writing output + // - rotary_emb_dim * 32 for the rotary embedding operations (32 is an approximation of the number of CPU cycles) + const double cost = static_cast(head_size * sizeof(T) * 2 + rotary_emb_dim * 32); + ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { + const int b = static_cast((ptr / n_heads) / sequence_length); + const int s = static_cast((ptr / n_heads) % sequence_length); + const int n = static_cast(ptr % n_heads); + // Identify the index of batch, sequence, and head (specific range) in the input/output tensor + // for read/write + const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; + const T* input_data = input + block_offset; + T* output_data = output + block_offset; + + const T* cos_data; + const T* sin_data; + int cache_offset; + if (position_ids_format == 0) { + cache_offset = (b * sequence_length + s) * half_rotary_emb_dim; + } else { + // Cache is (M, H/2) or (M, rotary_embedding_dim/2) + const int position_id = static_cast(position_ids[b * sequence_length + s]); + cache_offset = position_id * half_rotary_emb_dim; + } + cos_data = cos_cache + cache_offset; + sin_data = sin_cache + cache_offset; + + MlasRotaryEmbedOneRow(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data); + + if (rotary_emb_dim < head_size) { + std::memcpy(output_data + rotary_emb_dim, + input_data + rotary_emb_dim, + (head_size - rotary_emb_dim) * sizeof(T)); + } + } + }); + + return Status::OK(); +} + +template Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters parameters, const float* input, + const int64_t* position_ids, const float* cos_cache, const float* sin_cache, float* output, + bool interleaved); + +template Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters parameters, const MLFloat16* input, + const int64_t* position_ids, const MLFloat16* cos_cache, const MLFloat16* sin_cache, + MLFloat16* output, bool interleaved); + +template +Status RotaryEmbedding::Compute(OpKernelContext* context) const { + const Tensor* X = context->Input(0); + const Tensor* cos_cache = context->Input(1); + const Tensor* sin_cache = context->Input(2); + // Optional position_ids input, can be nullptr + const Tensor* position_ids = context->Input(3); + + RotaryParameters parameters = {}; + ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(X, + position_ids, + cos_cache, + sin_cache, + num_heads, + rotary_embedding_dim, + ¶meters)); + + Tensor* output = context->Output(0, X->Shape()); + + const T* x_src = X->Data(); + const int64_t* pos_ids_data = (nullptr == position_ids) ? nullptr : position_ids->Data(); + const T* cos_cache_data = cos_cache->Data(); + const T* sin_cache_data = sin_cache->Data(); + T* output_dest = output->MutableData(); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + auto* tp = context->GetOperatorThreadPool(); + + return RunRotaryEmbedding(tp, parameters, x_src, pos_ids_data, cos_cache_data, sin_cache_data, output_dest, + interleaved); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/llm/rotary_embedding.h b/onnxruntime/core/providers/cpu/llm/rotary_embedding.h new file mode 100644 index 0000000000000..7f958fbed8030 --- /dev/null +++ b/onnxruntime/core/providers/cpu/llm/rotary_embedding.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/llm/rotary_embedding_helper.h" + +namespace onnxruntime { + +template +Status RunRotaryEmbedding(onnxruntime::concurrency::ThreadPool* tp, rotary_embedding_helper::RotaryParameters parameters, const T* input, + const int64_t* position_ids, const T* cos_cache, const T* sin_cache, T* output, + bool interleaved); + +template +class RotaryEmbedding final : public OpKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + protected: + int num_heads; + int rotary_embedding_dim; + int interleaved; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/llm/rotary_embedding_helper.h b/onnxruntime/core/providers/cpu/llm/rotary_embedding_helper.h new file mode 100644 index 0000000000000..d9f8e03cddcb3 --- /dev/null +++ b/onnxruntime/core/providers/cpu/llm/rotary_embedding_helper.h @@ -0,0 +1,207 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace rotary_embedding_helper { + +// Parameters deduced from node attributes and inputs/outputs. +struct RotaryParameters { + int batch_size; // Batch size used by input + int sequence_length; // Sequence length used by input + int hidden_size; // Hidden size used by input + int head_size; // Head size + int rotary_embedding_dim; // Rotary embedding dimension. + int num_heads; // num_heads = hidden_size / head_size + int max_sequence_length; // Sequence length used by cos/sin cache + int head_stride; // Head stride + int seq_stride; // Sequence stride + int batch_stride; // Batch stride + int position_ids_format; // Format of position ids - 0 is (0), 1 is (batch_size, sequence_length) + bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) +}; + +template +Status CheckInputs(const T* input, + const T* position_ids, + const T* cos_cache, + const T* sin_cache, + int num_heads, + int rotary_embedding_dim, + void* parameters) { + // input : (batch_size, sequence_length, hidden_size) or (batch_size, num_heads, sequence_length, head_size) + // IF position ids : (0) + // rotary_embedding_dim == 0: + // cos_cache : (batch_size, sequence_length, head_size / 2) + // sin_cache : (batch_size, sequence_length, head_size / 2) + // rotary_embedding_dim > 0: + // cos_cache : (batch_size, sequence_length, rotary_embedding_dim / 2) + // sin_cache : (batch_size, sequence_length, rotary_embedding_dim / 2) + // ELSE position ids : (batch_size, sequence_length) + // rotary_embedding_dim == 0: + // cos_cache : (max_position_id_plus_1, head_size / 2) + // sin_cache : (max_position_id_plus_1, head_size / 2) + // rotary_embedding_dim > 0: + // cos_cache : (max_position_id_plus_1, rotary_embedding_dim / 2) + // sin_cache : (max_position_id_plus_1, rotary_embedding_dim / 2) + + // Check input is either 3d or 4d + const auto& input_dims = input->Shape().GetDims(); + + // Get attributes from inputs + int batch_size = static_cast(input_dims[0]); + int sequence_length; + int hidden_size; + int head_size; + + // If it's 4d, it is expected to have shape [batch, num_heads, seq_len, head_size]. + bool transposed = false; + if (input_dims.size() == 4) { + sequence_length = static_cast(input_dims[2]); + num_heads = static_cast(input_dims[1]); + head_size = static_cast(input_dims[3]); + hidden_size = num_heads * head_size; + transposed = true; + } else if (input_dims.size() == 3) { + // If it's 3d, it is expected to have shape [batch, seq_len, hidden_size]. + sequence_length = static_cast(input_dims[1]); + hidden_size = static_cast(input_dims[2]); + head_size = hidden_size / num_heads; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 or 4 dimensions, got ", + input_dims.size()); + } + + int position_ids_format = 0; + int max_sequence_length = 0; + // if position_ids is not provided, cos_cache and sin_cache are expected to have 3 dimensions + // else they are expected to have 2 dimensions. + if (nullptr == position_ids) { + // Check cos_cache and sin_cache + const auto& cos_cache_dims = cos_cache->Shape().GetDims(); + if (cos_cache_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 3 dimensions, got ", + cos_cache_dims.size()); + } + const auto& sin_cache_dims = sin_cache->Shape().GetDims(); + if (sin_cache_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 3 dimensions, got ", + sin_cache_dims.size()); + } + if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1] || cos_cache_dims[2] != sin_cache_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ", + "the same shape"); + } + // Make sure cos_cache and sin_cache have the same batch size and sequence length as input x + // when position_ids is not provided. + if (cos_cache_dims[0] != batch_size || cos_cache_dims[1] != sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ", + "the same shape as input 'x', got ", cos_cache_dims[0], " and ", cos_cache_dims[1]); + } + + max_sequence_length = static_cast(cos_cache_dims[1]); + + if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ", + "head_size"); + } + // Check cos_cache input shapes + if (cos_cache_dims[2] != (rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size) / 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 2 should be same as ", + "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[2]); + } + } else { + // Check cos_cache and sin_cache + const auto& cos_cache_dims = cos_cache->Shape().GetDims(); + if (cos_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ", + cos_cache_dims.size()); + } + const auto& sin_cache_dims = sin_cache->Shape().GetDims(); + if (sin_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ", + sin_cache_dims.size()); + } + if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ", + "the same shape"); + } + // Check position_ids + const auto& position_ids_dims = position_ids->Shape().GetDims(); + if (position_ids_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 2 ", + "dimensions, got ", position_ids_dims.size()); + } + + max_sequence_length = static_cast(cos_cache_dims[0]); + + if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ", + "head_size"); + } + + // Check position_ids input shapes + if (batch_size != static_cast(position_ids_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size ", + "batch_size, got ", position_ids_dims[0]); + } + if (sequence_length != static_cast(position_ids_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size ", + "sequence_length, got ", position_ids_dims[1]); + } + position_ids_format = 1; + + // Check cos_cache input shapes + if (cos_cache_dims[1] != (rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size) / 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ", + "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]); + } + } + + if (sequence_length > max_sequence_length) { + // Launch update_cos_sin_cache kernel with scale + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); + } + + num_heads = num_heads > 0 ? num_heads : static_cast(hidden_size / head_size); + // Calculate stride values + int head_stride; + int seq_stride; + int batch_stride; + if (transposed) { + // Transposed input tensor shape is [batch, n_heads, seq_len, head_size] + seq_stride = head_size; + head_stride = sequence_length * seq_stride; + batch_stride = num_heads * head_stride; + } else { + // Default input tensor shape is [batch, seq_len, hidden_size] + head_stride = head_size; + seq_stride = num_heads * head_stride; + batch_stride = sequence_length * seq_stride; + } + + // Set rotary parameters + if (parameters != nullptr) { + RotaryParameters* output_parameters = reinterpret_cast(parameters); + output_parameters->batch_size = batch_size; + output_parameters->sequence_length = sequence_length; + output_parameters->hidden_size = hidden_size; + output_parameters->head_size = head_size; + output_parameters->num_heads = num_heads; + output_parameters->max_sequence_length = max_sequence_length; + output_parameters->head_stride = head_stride; + output_parameters->seq_stride = seq_stride; + output_parameters->batch_stride = batch_stride; + output_parameters->position_ids_format = position_ids_format; + output_parameters->transposed = transposed; + output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size; + } + + return Status::OK(); +} + +} // namespace rotary_embedding_helper +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc index d4fea3c5a75c7..9043593e5fc9e 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc @@ -395,7 +395,7 @@ static Status DoTransposeInt4(const gsl::span& permutations, const "Expected to transpose int4 tensor"); // Convert to Tensor, transpose, and then repack back to Tensor. - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); Tensor input_unpacked; Tensor output_unpacked(DataTypeImpl::GetType(), output.Shape(), cpu_allocator); diff --git a/onnxruntime/core/providers/cpu/tensor/utils.h b/onnxruntime/core/providers/cpu/tensor/utils.h index 6adcfec852690..313e9ea4b9948 100644 --- a/onnxruntime/core/providers/cpu/tensor/utils.h +++ b/onnxruntime/core/providers/cpu/tensor/utils.h @@ -441,6 +441,8 @@ struct SliceIterator : public SliceIteratorBase { }; inline void CopyCpuTensor(const Tensor* src, Tensor* tgt) { + ORT_ENFORCE(src->SizeInBytes() == tgt->SizeInBytes(), "Destination size does not match source."); + void* target = tgt->MutableDataRaw(); const void* source = src->DataRaw(); diff --git a/onnxruntime/core/providers/cuda/cuda_allocator.cc b/onnxruntime/core/providers/cuda/cuda_allocator.cc index 8c96d8f57a0ba..f371f76832582 100644 --- a/onnxruntime/core/providers/cuda/cuda_allocator.cc +++ b/onnxruntime/core/providers/cuda/cuda_allocator.cc @@ -14,7 +14,7 @@ void CUDAAllocator::CheckDevice(bool throw_when_fail) const { int current_device; auto cuda_err = cudaGetDevice(¤t_device); if (cuda_err == cudaSuccess) { - ORT_ENFORCE(current_device == Info().id); + ORT_ENFORCE(current_device == Info().device.Id()); } else if (throw_when_fail) { CUDA_CALL_THROW(cuda_err); } @@ -27,7 +27,7 @@ void CUDAAllocator::SetDevice(bool throw_when_fail) const { int current_device; auto cuda_err = cudaGetDevice(¤t_device); if (cuda_err == cudaSuccess) { - int allocator_device_id = Info().id; + int allocator_device_id = Info().device.Id(); if (current_device != allocator_device_id) { cuda_err = cudaSetDevice(allocator_device_id); } diff --git a/onnxruntime/core/providers/cuda/cuda_allocator.h b/onnxruntime/core/providers/cuda/cuda_allocator.h index 2d94e2b1cda89..004ec2e876ec0 100644 --- a/onnxruntime/core/providers/cuda/cuda_allocator.h +++ b/onnxruntime/core/providers/cuda/cuda_allocator.h @@ -14,8 +14,9 @@ class CUDAAllocator : public IAllocator { CUDAAllocator(OrtDevice::DeviceId device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id), - device_id, OrtMemTypeDefault)) {} + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, + device_id), + OrtMemTypeDefault)) {} void* Alloc(size_t size) override; void Free(void* p) override; @@ -55,8 +56,9 @@ class CUDAPinnedAllocator : public IAllocator { CUDAPinnedAllocator(const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0 /*CPU device always with id 0*/), - 0, OrtMemTypeCPUOutput)) {} + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, + 0 /*CPU device always with id 0*/), + OrtMemTypeCPUOutput)) {} void* Alloc(size_t size) override; void Free(void* p) override; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index b6ef787e89e1a..e42422c1ce2b5 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -272,7 +272,9 @@ void OverrideTunableOpInfoByEnv(CUDAExecutionProviderInfo& info) { } CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kCudaExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, + : IExecutionProvider{onnxruntime::kCudaExecutionProvider, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, + info.device_id)}, info_{info}, tuning_context_(this, &info_.tunable_op) { #ifndef ENABLE_CUDA_NHWC_OPS @@ -2841,8 +2843,11 @@ void CUDAExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& OrtDevice CUDAExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { // TODO(leca): For CpuInput, return default OrtDevice to make it consistent with previous logic, otherwise, it will fail GradientCheckerTest.TileGrad // in Windows training scenario. However, we need to figure out why PINNED memType fails - if (mem_type == OrtMemTypeCPUInput) return OrtDevice(); - if (mem_type == OrtMemTypeCPUOutput) return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0 /*CPU device id always be 0*/); + if (mem_type == OrtMemTypeCPUInput) + return OrtDevice(); + if (mem_type == OrtMemTypeCPUOutput) + return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, + 0 /*CPU device id always be 0*/); return default_device_; } diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index c4520fe38cd2a..b160c14bdc359 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -118,9 +118,9 @@ struct ProviderInfo_CUDA_Impl final : ProviderInfo_CUDA { int device; CUDA_CALL_THROW(cudaGetDevice(&device)); - if (device != src_location.id) { + if (device != src_location.device.Id()) { // Need to switch to the allocating device. - CUDA_CALL_THROW(cudaSetDevice(src_location.id)); + CUDA_CALL_THROW(cudaSetDevice(src_location.device.Id())); // Copy from GPU to CPU. CUDA_CALL_THROW(cudaMemcpy(dst_ptr, src_ptr, size, cudaMemcpyDeviceToHost)); // Switch back to current device. diff --git a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc index 4dafbda409cd3..8127b4697de22 100644 --- a/onnxruntime/core/providers/cuda/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/cuda/gpu_data_transfer.cc @@ -8,8 +8,8 @@ namespace onnxruntime { bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::CUDA_PINNED || - dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED; + return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE || + dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE; } common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { @@ -33,7 +33,7 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const } else { // copy from other CPU memory to GPU, this is blocking CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice)); - if (src_device.MemType() != OrtDevice::MemType::CUDA_PINNED) { + if (src_device.MemType() != OrtDevice::MemType::HOST_ACCESSIBLE) { // For cudaMemcpy from pageable host memory to device memory, DMA to final destination may not have completed. // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); @@ -62,20 +62,23 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, if (dst_device.Type() == OrtDevice::GPU) { if (src_device.Type() == OrtDevice::CPU) { // copy from pinned or non-pinned CPU memory to GPU - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, static_cast(stream.GetHandle()))); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, + static_cast(stream.GetHandle()))); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking if (dst_data != src_data) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, + static_cast(stream.GetHandle()))); } } } else if (src_device.Type() == OrtDevice::GPU) { if (dst_device.Type() == OrtDevice::CPU) { // copy from GPU to pinned or non-pinned CPU memory. - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, static_cast(stream.GetHandle()))); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, + static_cast(stream.GetHandle()))); } } else { - if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) { + if (src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE) { // sync the stream first to make sure the data arrived CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast(stream.GetHandle()))); } diff --git a/onnxruntime/core/providers/cuda/math/clip_impl.cu b/onnxruntime/core/providers/cuda/math/clip_impl.cu index f0ff22269006b..cf4d79eec0280 100644 --- a/onnxruntime/core/providers/cuda/math/clip_impl.cu +++ b/onnxruntime/core/providers/cuda/math/clip_impl.cu @@ -8,10 +8,14 @@ namespace onnxruntime { namespace cuda { template __global__ void _Clip(const T* input, T* output, const T* min, const T* max, T min_default, T max_default, size_t N) { - auto min_val = (min) ? *min : min_default; - auto max_val = (max) ? *max : max_default; + auto min_val = (min) ? *min : min_default; + auto max_val = (max) ? *max : max_default; CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); - output[id] = (input[id] < min_val) ? min_val : ((input[id] > max_val) ? max_val : input[id]); + + // output = Min(max, Max(input, min)). Note that min might be larger than max, so we need to compute in two steps. + auto value = input[id]; + value = (value < min_val) ? min_val : value; + output[id] = (value > max_val) ? max_val : value; } template diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index 334a40b979bda..18b4b4593f537 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -46,7 +46,7 @@ namespace Dml OrtMemoryInfo( "DML", OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))), + OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, 0))), m_device(device), m_heapProperties(heapProps), m_heapFlags(heapFlags), diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h index f07b9540ff3fd..c99d686349e94 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h @@ -24,7 +24,7 @@ namespace Dml OrtMemoryInfo( "DML", OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0) + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, 0) )) { m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 868b2103586f9..a5066a41981e5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -74,7 +74,9 @@ namespace Dml bool enableGraphCapture, bool enableSyncSpinning, bool disableMemoryArena) : - IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0)) + IExecutionProvider(onnxruntime::kDmlExecutionProvider, + OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, + 0)) { D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue(); if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE) @@ -86,7 +88,8 @@ namespace Dml ComPtr device; GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf()))); - m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena); + m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), executionContext, enableMetacommands, + enableGraphCapture, enableSyncSpinning, disableMemoryArena); } std::vector> diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index aa3d8b0b4a409..6799b85747994 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -249,8 +249,8 @@ namespace Dml bool CanCopy(const OrtDevice& srcDevice, const OrtDevice& dstDevice) const final { - return (srcDevice.Type() == OrtDevice::DML) || - (dstDevice.Type() == OrtDevice::DML); + return ((srcDevice.Type() == OrtDevice::GPU && srcDevice.Vendor() == OrtDevice::VendorIds::MICROSOFT) || + (dstDevice.Type() == OrtDevice::GPU && dstDevice.Vendor() == OrtDevice::VendorIds::MICROSOFT)); } private: diff --git a/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.cc b/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.cc index 79ec47b8f8443..b3886234fa238 100644 --- a/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.cc +++ b/onnxruntime/core/providers/dml/GraphTransformers/bn_add_fusion.cc @@ -62,8 +62,8 @@ Status BatchNormalizationAddFusion::Apply(Graph& graph, Node& node, RewriteRuleE return Status::OK(); } - Initializer BatchNormalization_B{*BatchNormalization_B_tensor_proto, graph.ModelPath()}; - Initializer add_B{*add_B_tensor_proto, graph.ModelPath()}; + Initializer BatchNormalization_B{graph, *BatchNormalization_B_tensor_proto, graph.ModelPath()}; + Initializer add_B{graph, *add_B_tensor_proto, graph.ModelPath()}; if (BatchNormalization_B.size() != add_B.size()) { return Status::OK(); @@ -73,11 +73,12 @@ Status BatchNormalizationAddFusion::Apply(Graph& graph, Node& node, RewriteRuleE // Create new initializers of BatchNormalization ONNX_NAMESPACE::TensorProto new_BatchNormalization_B_tensor_proto; - BatchNormalization_B.ToProto(new_BatchNormalization_B_tensor_proto); + OrtValue ort_value; + BatchNormalization_B.ToProtoWithOrtValue(new_BatchNormalization_B_tensor_proto, ort_value); // Replace initializers of BatchNormalization node graph.RemoveInitializedTensor(BatchNormalization_inputs[2]->Name()); - graph.AddInitializedTensor(new_BatchNormalization_B_tensor_proto); + ORT_RETURN_IF_ERROR(graph.AddInitializedOrtValue(new_BatchNormalization_B_tensor_proto, ort_value)); // Remove Add node. auto* add_node_to_remove = graph.GetNode(add_node.Index()); diff --git a/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.cc b/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.cc index 02f16b4d3d467..21c85c7f67d30 100644 --- a/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.cc +++ b/onnxruntime/core/providers/dml/GraphTransformers/bn_mul_fusion.cc @@ -13,7 +13,8 @@ using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; namespace onnxruntime { -Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const onnxruntime::logging::Logger&) const { +Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, + const onnxruntime::logging::Logger&) const { auto& BatchNormalization_node = node; const auto& mul_node = *BatchNormalization_node.OutputNodesBegin(); const auto& BatchNormalization_inputs = BatchNormalization_node.InputDefs(); @@ -54,8 +55,8 @@ Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleE } } - Initializer BatchNormalization_Scale{*BatchNormalization_Scale_tensor_proto, graph.ModelPath()}; - Initializer mul_B{*mul_B_tensor_proto, graph.ModelPath()}; + Initializer BatchNormalization_Scale{graph, *BatchNormalization_Scale_tensor_proto, graph.ModelPath()}; + Initializer mul_B{graph, *mul_B_tensor_proto, graph.ModelPath()}; const ONNX_NAMESPACE::TensorProto* BatchNormalization_B_tensor_proto = nullptr; if (!graph.GetInitializedTensor(BatchNormalization_inputs[2]->Name(), BatchNormalization_B_tensor_proto)) @@ -67,7 +68,7 @@ Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleE BatchNormalization_B_tensor_proto->dims_size() != 1) { return Status::OK(); } - Initializer BatchNormalization_B{*BatchNormalization_B_tensor_proto, graph.ModelPath()}; + Initializer BatchNormalization_B{graph, *BatchNormalization_B_tensor_proto, graph.ModelPath()}; // Calculate new value of initializers of BatchNormalization node BatchNormalization_Scale.scale_by_axis(mul_B, 1); @@ -79,17 +80,20 @@ Status BatchNormalizationMulFusion::Apply(Graph& graph, Node& node, RewriteRuleE } // Create new initializers of BatchNormalization - ONNX_NAMESPACE::TensorProto new_BatchNormalization_Scale_tensor_proto(*BatchNormalization_Scale_tensor_proto); - BatchNormalization_Scale.ToProto(new_BatchNormalization_Scale_tensor_proto); + ONNX_NAMESPACE::TensorProto new_BatchNormalization_Scale_tensor_proto; + OrtValue ort_value_scale; + BatchNormalization_Scale.ToProtoWithOrtValue(new_BatchNormalization_Scale_tensor_proto, ort_value_scale); // Replace initializers of BatchNormalization node graph.RemoveInitializedTensor(BatchNormalization_inputs[1]->Name()); - graph.AddInitializedTensor(new_BatchNormalization_Scale_tensor_proto); + ORT_RETURN_IF_ERROR(graph.AddInitializedOrtValue(new_BatchNormalization_Scale_tensor_proto, ort_value_scale)); + + ONNX_NAMESPACE::TensorProto new_BatchNormalization_B_tensor_proto; + OrtValue ort_value_B_scale; + BatchNormalization_B.ToProtoWithOrtValue(new_BatchNormalization_B_tensor_proto, ort_value_B_scale); - ONNX_NAMESPACE::TensorProto new_BatchNormalization_B_tensor_proto(*BatchNormalization_B_tensor_proto); - BatchNormalization_B.ToProto(new_BatchNormalization_B_tensor_proto); graph.RemoveInitializedTensor(BatchNormalization_inputs[2]->Name()); - graph.AddInitializedTensor(new_BatchNormalization_B_tensor_proto); + ORT_RETURN_IF_ERROR(graph.AddInitializedOrtValue(new_BatchNormalization_B_tensor_proto, ort_value_B_scale)); // Remove Mul node. auto* mul_node_to_remove = graph.GetNode(mul_node.Index()); diff --git a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc index 01f44e91fd49c..bb5d942ecb14a 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_node_capability.cc @@ -672,30 +672,35 @@ bool DnnlErfNodeCapability::Supported(const Node* node, const GraphViewer& graph return true; } -bool DnnlErfNodeCapability::IsInitilizedWithExpectedValue(const GraphViewer& graph_viewer, const NodeArg* node_arg, float expected_value) const { - // TypeAsProto()->tensor_type().elem_type() - if ((ORT_DataType)node_arg->TypeAsProto()->tensor_type().elem_type() == type_float32) { - const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; - graph_viewer.GetInitializedTensor(node_arg->Name(), tensor_proto); - const float* val = reinterpret_cast(tensor_proto->raw_data().data()); - - // Check for NaN and Inf - if (std::isnan(val[0]) || std::isinf(val[0])) { - if (std::isinf(val[0]) && std::isinf(expected_value) && (std::signbit(val[0]) == std::signbit(expected_value))) { - return true; - } - return false; - } +bool DnnlErfNodeCapability::IsInitilizedWithExpectedValue(const GraphViewer& graph_viewer, const NodeArg* node_arg, + float expected_value) const { + const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; + if (!graph_viewer.GetInitializedTensor(node_arg->Name(), tensor_proto)) { + return false; + } - const float atol = 1e-8f; - const float rtol = 1e-5f; - float diff = std::abs(val[0] - expected_value); - if (diff > (atol + rtol * std::abs(expected_value))) { - return false; + onnxruntime::Initializer erf_weight{graph_viewer.GetGraph(), *tensor_proto, graph_viewer.ModelPath()}; + if (erf_weight.data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + return false; + } + + const float* val = erf_weight.data(); + + // Check for NaN and Inf + if (std::isnan(val[0]) || std::isinf(val[0])) { + if (std::isinf(val[0]) && std::isinf(expected_value) && (std::signbit(val[0]) == std::signbit(expected_value))) { + return true; } - return true; + return false; } - return false; + + const float atol = 1e-8f; + const float rtol = 1e-5f; + float diff = std::abs(val[0] - expected_value); + if (diff > (atol + rtol * std::abs(expected_value))) { + return false; + } + return true; } const Node* DnnlErfNodeCapability::FirstParentByType(const Node& node, const std::string& parent_type) const { diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.cc index f6497e381d0f7..fdebe51865f4b 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_subgraph_transformer.cc @@ -181,13 +181,11 @@ bool DnnlGraphTransformer::IsInitilizedWithExpectedValue(const onnxruntime::Grap return false; } - if (!tensor_proto->has_raw_data()) { - return false; - } - const auto data_type = input_arg.Type(); if (data_type == dnnl::memory::data_type::f32) { - const float* val = reinterpret_cast(tensor_proto->raw_data().data()); + onnxruntime::Initializer initializer(onnx_subgraph_viewer.GetGraph(), + *tensor_proto, onnx_subgraph_viewer.ModelPath()); + const float* val = initializer.data(); if (std::isnan(val[0]) || std::isinf(val[0])) { if (std::isinf(val[0]) && std::isinf(expected_value) && (std::signbit(val[0]) == std::signbit(expected_value))) { return true; @@ -775,9 +773,8 @@ void DnnlGraphTransformer::RemoveMatMulIntegerZP(DnnlSubgraph& subgraph, const o // check if b_zp is all zeros, assume data is s8 since only s8 weight is supported in onednn bool all_zero = true; - std::vector unpacked_tensor; - unpacked_tensor.resize(num_elements, 1); - ORT_THROW_IF_ERROR(onnxruntime::utils::UnpackTensor(*tensor_proto, tensor_proto->has_raw_data() ? tensor_proto->raw_data().data() : nullptr, tensor_proto->has_raw_data() ? tensor_proto->raw_data().size() : 0, reinterpret_cast(unpacked_tensor.data()), num_elements)); + std::vector unpacked_tensor; + ORT_THROW_IF_ERROR(onnxruntime::utils::UnpackInitializerData(*tensor_proto, unpacked_tensor)); for (const auto& val : unpacked_tensor) { if (val != 0) { all_zero = false; diff --git a/onnxruntime/core/providers/js/allocator.h b/onnxruntime/core/providers/js/allocator.h index aafb0bb22da7e..d76a26e14e120 100644 --- a/onnxruntime/core/providers/js/allocator.h +++ b/onnxruntime/core/providers/js/allocator.h @@ -14,8 +14,8 @@ class WebGpuAllocator : public IAllocator { WebGpuAllocator() : IAllocator( OrtMemoryInfo(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), - 0, OrtMemTypeDefault)) { + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0), + OrtMemTypeDefault)) { } virtual void* Alloc(size_t size) override; diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index d8e24ff1f5053..2be286440bcf4 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -772,7 +772,8 @@ std::unique_ptr RegisterKernels() { using namespace js; JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info, const SessionOptions* session_options) - : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)}, + : IExecutionProvider{kJsExecutionProvider, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0)}, preferred_data_layout_{info.data_layout} { if (session_options) { enable_graph_capture_ = session_options->config_options.GetConfigOrDefault("enableGraphCapture", "false") == "true"; diff --git a/onnxruntime/core/providers/js/operators/resize.h b/onnxruntime/core/providers/js/operators/resize.h index 3e8ccf40753c8..bf04bbd3825c3 100644 --- a/onnxruntime/core/providers/js/operators/resize.h +++ b/onnxruntime/core/providers/js/operators/resize.h @@ -107,7 +107,6 @@ class Resize : public JsKernel, public UpsampleBase { } virtual Status SerializeCustomData(OpKernelContext* context, AllocatorPtr alloc, void** ptr, size_t* size) const { - TensorShapeVector output_dims; std::vector roi_array; std::vector scales_array; diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc index a643f6b208f94..01d9ee99f07fe 100644 --- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc @@ -10,8 +10,8 @@ namespace onnxruntime { bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED || - dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED; + return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE || + dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE; } common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { @@ -34,7 +34,7 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const } else { // copy from other CPU memory to GPU, this is blocking HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); - if (src_device.MemType() != OrtDevice::MemType::HIP_PINNED) { + if (src_device.MemType() != OrtDevice::MemType::HOST_ACCESSIBLE) { // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here. HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } @@ -63,20 +63,24 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, if (src_device.Type() == OrtDevice::CPU) { // If source are not pinned, the memory copy will be performed synchronously. // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. - HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); + HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, + static_cast(stream.GetHandle()))); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking - HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); + HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, + static_cast(stream.GetHandle()))); } else { // copy from other CPU memory to GPU, this is blocking - HIP_CALL_THROW(hipMemcpyWithStream(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); + HIP_CALL_THROW(hipMemcpyWithStream(dst_data, src_data, bytes, hipMemcpyHostToDevice, + static_cast(stream.GetHandle()))); } } else if (src_device.Type() == OrtDevice::GPU) { // If dest are not pinned, the memory copy will be performed synchronously. // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. - HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); + HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, + static_cast(stream.GetHandle()))); } else { - if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) { + if (src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE) { // sync the stream first to make sure the data arrived HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); } diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.h b/onnxruntime/core/providers/migraphx/migraphx_allocator.h index 2a84445897391..8dcfe63796c89 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.h @@ -14,8 +14,9 @@ class MIGraphXAllocator : public IAllocator { MIGraphXAllocator(int device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(device_id)), - device_id, OrtMemTypeDefault)) {} + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, + static_cast(device_id)), + OrtMemTypeDefault)) {} virtual void* Alloc(size_t size) override; virtual void Free(void* p) override; @@ -54,8 +55,9 @@ class MIGraphXPinnedAllocator final : public IAllocator { MIGraphXPinnedAllocator(const int device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast(device_id)), - device_id, OrtMemTypeCPUOutput)) {} + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, + static_cast(device_id)), + OrtMemTypeCPUOutput)) {} void* Alloc(size_t size) override; void Free(void* p) override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index ed373466198a6..716c2c39cd837 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -106,7 +106,10 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c } MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info) { + : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, + info.device_id)}, + info_(info) { InitProviderOrtApi(); get_flags_from_session_info(info); metadef_id_generator_ = ModelMetadefIdGenerator::Create(); @@ -1614,8 +1617,11 @@ void MIGraphXExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegis } OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { - if (mem_type == OrtMemTypeCPUInput) return OrtDevice(); - if (mem_type == OrtMemTypeCPUOutput) return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0 /*CPU device id always be 0*/); + if (mem_type == OrtMemTypeCPUInput) + return OrtDevice(); + if (mem_type == OrtMemTypeCPUOutput) + return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, + 0 /*CPU device id always be 0*/); return default_device_; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc index 5108f90fc763a..c37b068d988a4 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc @@ -203,7 +203,7 @@ common::Status GetQuantizationScaleAndZeroPoint(const GraphViewer& graph_viewer, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, name, " is not a constant initializer"); }; - Initializer unpacked_tensor(*s, model_path); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *s, model_path); // The scale should be one or more floats scale = unpacked_tensor.DataAsSpan()[0]; } @@ -215,7 +215,7 @@ common::Status GetQuantizationScaleAndZeroPoint(const GraphViewer& graph_viewer, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, name, " is not a constant initializer"); }; - Initializer unpacked_tensor(*zp, model_path); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *zp, model_path); // Onnx quantization uses uint8 [int8 not yet supported], need to cast to int32_t used by NNAPI zero_point = static_cast(unpacked_tensor.DataAsByteSpan()[0]); } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc index 8127de0a0f05f..83727f7c9d960 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc @@ -80,7 +80,7 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No const auto* pads_initializer = model_builder.GetConstantInitializer(pads); ORT_RETURN_IF_NOT(pads_initializer, "pads must be a constant"); - Initializer pads_initializer_raw_data(*pads_initializer); + Initializer pads_initializer_raw_data(model_builder.GetGraphViewer().GetGraph(), *pads_initializer); // assume pads_initializer has int64 data, per ONNX spec std::vector converted_pads_data{}; converted_pads_data.reserve(2 * data_rank); @@ -102,7 +102,7 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No const auto& constant_value = inputs[2].node_arg.Name(); const auto* constant_value_initializer = model_builder.GetConstantInitializer(constant_value); ORT_RETURN_IF_NOT(constant_value_initializer, "constant_value must be a constant"); - Initializer pad_value_raw_data_init(*constant_value_initializer); + Initializer pad_value_raw_data_init(model_builder.GetGraphViewer().GetGraph(), *constant_value_initializer); pad_value = pad_value_raw_data_init.DataAsSpan()[0]; } @@ -158,7 +158,7 @@ bool PadOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const Node return false; } - Initializer unpacked_tensor(*pads_initializer); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *pads_initializer); auto tensor_data = unpacked_tensor.DataAsSpan(); for (size_t i = 0; i < unpacked_tensor.size(); i++) { if (tensor_data[i] < 0) { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc index af5aeba6c8236..c4f1e5f402491 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc @@ -249,7 +249,7 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const N return false; } - const Initializer unpacked_tensor(*scales); + const Initializer unpacked_tensor(graph_viewer.GetGraph(), *scales); auto scales_data = unpacked_tensor.DataAsSpan(); input_is_nchw = scales_data[1] == 1.0F; const float scale_n = scales_data[0]; @@ -287,7 +287,7 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const N return false; } - Initializer unpacked_tensor(*sizes); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *sizes); auto sizes_data = unpacked_tensor.DataAsSpan(); input_is_nchw = sizes_data[1] == input_shape[1]; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc index 7509fd15f1c5e..aa715068f432c 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc @@ -104,7 +104,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const No return false; } - Initializer unpacked_tensor(*splits); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *splits); auto splits_span = unpacked_tensor.DataAsSpan(); uint32_t sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), SafeInt(0)); if (sum_of_splits != split_dims_at_axis) { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc index 1c82d5e7452fd..c64a2df1ee8ce 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc @@ -1060,7 +1060,6 @@ Status AddReshapeOperator(ModelBuilder& model_builder, const auto& operand_types(model_builder.GetOperandTypes()); const auto& output = node_unit.Outputs()[0].node_arg.Name(); - const auto input_shape = shaper[input]; const auto output_shape = shaper[output]; // For reshape, the output type should be the same as the input type except the shape is different diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc index a44ab93ccca8b..f8a722d782653 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc @@ -14,7 +14,7 @@ void CUDAAllocator::CheckDevice(bool throw_when_fail) const { int current_device; auto cuda_err = cudaGetDevice(¤t_device); if (cuda_err == cudaSuccess) { - ORT_ENFORCE(current_device == Info().id); + ORT_ENFORCE(current_device == Info().device.Id()); } else if (throw_when_fail) { CUDA_CALL_THROW(cuda_err); } @@ -27,7 +27,7 @@ void CUDAAllocator::SetDevice(bool throw_when_fail) const { int current_device; auto cuda_err = cudaGetDevice(¤t_device); if (cuda_err == cudaSuccess) { - int allocator_device_id = Info().id; + int allocator_device_id = Info().device.Id(); if (current_device != allocator_device_id) { cuda_err = cudaSetDevice(allocator_device_id); } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h index a3f05bded5de9..b4b638ccb82f1 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h @@ -13,10 +13,10 @@ namespace onnxruntime { class CUDAAllocator : public IAllocator { public: CUDAAllocator(OrtDevice::DeviceId device_id, const char* name) - : IAllocator( - OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id), - device_id, OrtMemTypeDefault)) {} + : IAllocator(OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, + device_id), + OrtMemTypeDefault)) {} void* Alloc(size_t size) override; void Free(void* p) override; @@ -56,8 +56,9 @@ class CUDAPinnedAllocator : public IAllocator { CUDAPinnedAllocator(const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0 /*CPU device always with id 0*/), - 0, OrtMemTypeCPUOutput)) {} + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, + 0 /*CPU device always with id 0*/), + OrtMemTypeCPUOutput)) {} void* Alloc(size_t size) override; void Free(void* p) override; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.cc index 4779ddd1a9556..0dfc5b8f8f7d4 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.cc @@ -10,8 +10,8 @@ #define CUDA_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDA_CALL(expr)) namespace onnxruntime { bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::CUDA_PINNED || - dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED; + return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE || + dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE; } common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { @@ -35,7 +35,7 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const } else { // copy from other CPU memory to GPU, this is blocking CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice)); - if (src_device.MemType() != OrtDevice::MemType::CUDA_PINNED) { + if (src_device.MemType() != OrtDevice::MemType::HOST_ACCESSIBLE) { // For cudaMemcpy from pageable host memory to device memory, DMA to final destination may not have completed. // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); @@ -64,20 +64,23 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, if (dst_device.Type() == OrtDevice::GPU) { if (src_device.Type() == OrtDevice::CPU) { // copy from pinned or non-pinned CPU memory to GPU - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, static_cast(stream.GetHandle()))); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, + static_cast(stream.GetHandle()))); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking if (dst_data != src_data) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, + static_cast(stream.GetHandle()))); } } } else if (src_device.Type() == OrtDevice::GPU) { if (dst_device.Type() == OrtDevice::CPU) { // copy from GPU to pinned or non-pinned CPU memory. - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, static_cast(stream.GetHandle()))); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, + static_cast(stream.GetHandle()))); } } else { - if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) { + if (src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE) { // sync the stream first to make sure the data arrived CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast(stream.GetHandle()))); } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 6f3acd76212a7..f06ed1424eb24 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1022,7 +1022,7 @@ NvExecutionProvider::PerThreadContext& NvExecutionProvider::GetPerThreadContext( NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kNvTensorRTRTXExecutionProvider, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, narrow(info.device_id))}, info_(info), device_id_(info.device_id) { @@ -1584,6 +1584,11 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t // Add node and node args // If node output is also parent graph output, the output will be added to the // subgraph's output list + // + // Initializers that refer to a memory location in OrtValue + // can not be handled by TRT (unlike those that are on disk). + // This prevents us from sharing the data and we have to make a copy here. + constexpr const bool load_initializers_inline_true = true; std::vector subgraph_output_names; for (const auto& index : group.first) { const auto& node = graph.GetNode(node_index[index]); @@ -1591,24 +1596,15 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t for (auto input : node->InputDefs()) { auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); inputs.push_back(&n_input); - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph.GetInitializedTensor(input->Name(), initializer)) { - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) { - graph_build.AddInitializedTensor(*(initializer)); - } - } + graph_utils::MakeInitializerCopyIfNotExist(graph.GetGraph(), graph_build, input->Name(), + load_initializers_inline_true); } for (auto input : node->ImplicitInputDefs()) { - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph.GetInitializedTensor(input->Name(), initializer)) { - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) { - graph_build.AddInitializedTensor(*(initializer)); - } - } + graph_utils::MakeInitializerCopyIfNotExist(graph.GetGraph(), graph_build, input->Name(), + load_initializers_inline_true); } + for (auto output : node->OutputDefs()) { auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); outputs.push_back(&n_output); @@ -2565,7 +2561,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr onnx, onnx_size, trt_engine.get(), - true /* serialize refitted engine to disk */, + false /* serialize refitted engine to disk */, detailed_build_log_); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); @@ -2682,8 +2678,9 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr int num_outputs = static_cast(output_indexes.size()); std::unordered_set input_names; - OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow(device_id_)); - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_); + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, + narrow(device_id_)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device); if (alloc_ == nullptr) { Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); } @@ -2748,7 +2745,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr onnx_model_bytestream_, onnx_model_bytestream_size_, trt_engine, - true /* serialize refitted engine to disk */, + false /* serialize refitted engine to disk */, detailed_build_log_); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); @@ -3061,8 +3058,9 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input - OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow(device_id_)); - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_); + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, + narrow(device_id_)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device); if (alloc_ == nullptr) { Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); } @@ -3259,8 +3257,11 @@ void NvExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& s } OrtDevice NvExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { - if (mem_type == OrtMemTypeCPUInput) return OrtDevice(); - if (mem_type == OrtMemTypeCPUOutput) return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0 /*CPU device id always be 0*/); + if (mem_type == OrtMemTypeCPUInput) + return OrtDevice(); + if (mem_type == OrtMemTypeCPUOutput) + return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, + 0 /*CPU device id always be 0*/); return default_device_; } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc index d99db5acb94ff..d9d3507a4687b 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc @@ -65,17 +65,29 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi info.onnx_bytestream = onnx_bytestream; // EP context settings - const auto embed_enable = session_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0"); - if (embed_enable == "0") { + // when EP context is enabled, default is to embed the engine in the context model + // weight stripped engine is always enabled when EP context is enabled + + const auto ep_context_enable = session_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0"); + if (ep_context_enable == "0") { info.dump_ep_context_model = false; - } else if (embed_enable == "1") { + } else if (ep_context_enable == "1") { info.dump_ep_context_model = true; + info.weight_stripped_engine_enable = true; } else { ORT_THROW("Invalid ", kOrtSessionOptionEpContextEnable, " must 0 or 1"); } info.ep_context_file_path = session_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); - const auto embed_mode = std::stoi(session_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "1")); + // If embed mode is not specified, default to 1 if dump_ep_context_model is true, otherwise 0 + const auto embed_mode = std::stoi(session_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "-1")); + if (embed_mode == -1) { + if (info.dump_ep_context_model) + embed_mode = 1; + else + embed_mode = 0; + } + if (0 <= embed_mode || embed_mode < 2) { info.ep_context_embed_mode = embed_mode; } else { diff --git a/onnxruntime/core/providers/openvino/ov_allocator.cc b/onnxruntime/core/providers/openvino/ov_allocator.cc index 9e4ac6009e2e3..248c859b20dee 100644 --- a/onnxruntime/core/providers/openvino/ov_allocator.cc +++ b/onnxruntime/core/providers/openvino/ov_allocator.cc @@ -10,7 +10,13 @@ namespace onnxruntime { using namespace openvino_ep; -OVRTAllocator::OVRTAllocator(ov::Core& core, OrtDevice::DeviceType device_type, OrtDevice::DeviceId device_id, const char* name) : IAllocator(OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(device_type, OrtDevice::MemType::DEFAULT, device_id), device_id, OrtMemTypeCPUInput)), core_(core) { +OVRTAllocator::OVRTAllocator(ov::Core& core, OrtDevice::DeviceType device_type, OrtDevice::DeviceId device_id, + const char* name) + : IAllocator(OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(device_type, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::INTEL, + device_id), + OrtMemTypeCPUInput)), + core_(core) { if (device_type == OrtDevice::NPU) { remote_ctx_ = core_.get_default_context("NPU").as(); } else { diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index 860cfb5713903..24e8892622175 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -14,6 +14,7 @@ #include "core/providers/shared_library/provider_api.h" #include "core/providers/openvino/qdq_transformations/qdq_stripping.h" +#include "core/common/inlined_containers.h" namespace onnxruntime { namespace openvino_ep { @@ -643,16 +644,12 @@ static void AddInitializerAsInput(onnxruntime::Graph& dst_graph, const onnxruntime::GraphViewer& src_graph, const std::string& initializer_name) { // Get the initializer from source graph - const auto& src_initializers = src_graph.GetAllInitializedTensors(); - auto init_iter = src_initializers.find(initializer_name); - - if (init_iter == src_initializers.end()) { + const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; + if (!src_graph.GetInitializedTensor(initializer_name, tensor_proto)) { // Initializer not found return; } - const auto* tensor_proto = init_iter->second; - // Create TypeProto for the initializer auto type_proto = ONNX_NAMESPACE::TypeProto::Create(); auto* tensor_type = type_proto->mutable_tensor_type(); @@ -789,17 +786,21 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, } // Copy initializers to dst graph. + const auto& initializers = src_graph.GetAllInitializedTensors(); - std::unordered_set current_scope_initializer_set; - - auto& initializers = src_graph.GetAllInitializedTensors(); + InlinedHashSet current_scope_initializer_set; + current_scope_initializer_set.reserve(initializers.size()); // Sort initializers to maintain consistency in model proto created across inference requests - std::vector const_inits; - for (auto& it : initializers) { - const_inits.push_back(it.first); + + InlinedVector all_inits; + all_inits.reserve(initializers.size()); + for (auto it = initializers.cbegin(), end = initializers.cend(); it != end; ++it) { + all_inits.push_back(it); } - std::sort(const_inits.begin(), const_inits.end()); + std::sort(all_inits.begin(), all_inits.end(), [](const auto& i1, const auto& i2) { + return i1->first < i2->first; + }); // initialize map for creating metadata for initilizers with external weights auto& metadata = shared_weights.metadata; @@ -832,41 +833,53 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, metadata.emplace(key, std::move(value)); }; - // Handle constant initializers - for (auto& it : const_inits) { - const auto& initializer_tensor = *initializers.at(it); + // Handle initializers + for (const auto& it : all_inits) { + const auto& [name, init] = *it; + const auto& initializer_tensor = *init; + + std::unique_ptr init_with_data; + ORT_RETURN_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(initializer_tensor, init_with_data)); // Check if the initializer has external data - if (initializer_tensor.has_data_location() && - initializer_tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL && + if (!init_with_data && + utils::HasExternalData(initializer_tensor) && enable_ovep_weight_sharing) { insert_metadata(initializer_tensor); // Add initializer with external data as input - AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, it); - + AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, name); } else { // Add as an initialized tensor if it does not have external data - if (initializers_to_keep.count(it)) - dst_graph.AddInitializedTensor(*(initializers.at(it))); + if (initializers_to_keep.count(name) > 0) { + if (init_with_data) { + dst_graph.AddInitializedTensor(*init_with_data); + } else { + dst_graph.AddInitializedTensor(initializer_tensor); + } + } } - current_scope_initializer_set.insert(it); + current_scope_initializer_set.insert(name); } - // Handle outer-scope constant initializers + // Handle outer-scope initializers for (auto& node_idx : src_graph.GetNodesInTopologicalOrder()) { const auto& node = src_graph.GetNode(node_idx); for (const auto& input : node->InputDefs()) { - if (current_scope_initializer_set.find(input->Name()) != current_scope_initializer_set.end()) { + if (current_scope_initializer_set.count(input->Name()) > 0) { continue; } if (src_graph.IsConstantInitializer(input->Name(), true)) { const auto& initializer_tensor = *src_graph.GetConstantInitializer(input->Name(), true); + + std::unique_ptr init_with_data; + ORT_RETURN_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(initializer_tensor, init_with_data)); + // Check if the initializer has external data - if (initializer_tensor.has_data_location() && - initializer_tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL && + if (!init_with_data && + utils::HasExternalData(initializer_tensor) && enable_ovep_weight_sharing) { insert_metadata(initializer_tensor); @@ -876,7 +889,11 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, } else { // Add as an initialized tensor if it does not have external data if (initializers_to_keep.count(input->Name())) { - dst_graph.AddInitializedTensor(*(src_graph.GetConstantInitializer(input->Name(), true))); + if (init_with_data) { + dst_graph.AddInitializedTensor(*init_with_data); + } else { + dst_graph.AddInitializedTensor(initializer_tensor); + } } } diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index 53fef09aec0fa..3d36fe5e8ff31 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -10,6 +10,8 @@ namespace onnxruntime { namespace qnn { +static OpBuilderRegistrations op_registrations; + OpBuilderRegistrations::OpBuilderRegistrations() { { CreateSimpleOpBuilder("Add", *this); @@ -166,6 +168,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateTransposeOpBuilder("Transpose", *this); } + { + CreateReciprocalOpBuilder("Reciprocal", *this); + } + { CreatePadOpBuilder("Pad", *this); } @@ -191,8 +197,11 @@ OpBuilderRegistrations::OpBuilderRegistrations() { } } +void RegisterUDOBuilder(const std::string& op_type, const std::string& op_package) { + CreateUDOBuilder(op_type, op_package, op_registrations); +} + const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) { - static const OpBuilderRegistrations op_registrations; return op_registrations.GetOpBuilderByOnnxOpType(onnx_op_type); } diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index 1cc8e12068cca..b9b3c34467855 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -3,9 +3,10 @@ #pragma once +#include #include +#include #include "op_builder.h" - namespace onnxruntime { namespace qnn { @@ -40,6 +41,20 @@ class OpBuilderRegistrations { } } + void RegisterUDOBuilder(const std::string& op_type, std::unique_ptr builder) { + auto builder_type = builder->GetOpBuilderType(); + auto pos_in_builder_type_map = builder_type_builder_map_.find(builder_type); + if (pos_in_builder_type_map != builder_type_builder_map_.end()) { + // already have this builder type, re-use it for this op_type + op_builder_map_[op_type] = pos_in_builder_type_map->second; + } else { + // New Op builder, add to vector and all the maps + builders_.push_back(std::move(builder)); + op_builder_map_[op_type] = builders_.back().get(); + builder_type_builder_map_[builder_type] = builders_.back().get(); + } + } + private: std::vector> builders_; // @@ -47,6 +62,8 @@ class OpBuilderRegistrations { // std::unordered_map builder_type_builder_map_; }; +void RegisterUDOBuilder(const std::string& op_type, const std::string& op_package); + const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type); void CreateSimpleOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); @@ -93,6 +110,8 @@ void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateReciprocalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); @@ -107,5 +126,6 @@ void CreateLSTMOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_ void CreateCumSumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateUDOBuilder(const std::string& op_type, const std::string& op_package, OpBuilderRegistrations& op_registrations); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index a83e8e064c7d0..2ccfad206a38a 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -124,6 +124,9 @@ class BaseOpBuilder : public IOpBuilder { } else if (std::is_same::value) { qnn_scalar.dataType = QNN_DATATYPE_INT_32; qnn_scalar.int32Value = static_cast(scalar); + } else if (std::is_same::value) { + qnn_scalar.dataType = QNN_DATATYPE_INT_64; + qnn_scalar.int64Value = static_cast(scalar); } else if (std::is_same::value) { qnn_scalar.dataType = QNN_DATATYPE_BOOL_8; qnn_scalar.bool8Value = static_cast(scalar); @@ -136,6 +139,21 @@ class BaseOpBuilder : public IOpBuilder { return Status::OK(); } + Status AddQnnScalar(QnnModelWrapper& qnn_model_wrapper, + const NodeIndex& node_index, + const std::string& node_name, + const std::string& scalar, + const std::string& qnn_scalar_param_name, + std::vector& param_names) const { + Qnn_Scalar_t qnn_scalar = QNN_SCALAR_INIT; + qnn_scalar.dataType = QNN_DATATYPE_STRING; + qnn_scalar.stringValue = scalar.c_str(); + QnnParamWrapper qnn_param_wrapper(node_index, node_name, qnn_scalar_param_name, qnn_scalar); + param_names.push_back(qnn_param_wrapper.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(qnn_param_wrapper)); + return Status::OK(); + } + Status SetOutputQParamEqualToInputIfNearlyEqual(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc index 1dd12abc2baf9..477a2445d9369 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc @@ -20,6 +20,11 @@ class ExpandOpBuilder : public BaseOpBuilder { const logging::Logger& logger, std::vector& input_names, bool do_op_validation) const override ORT_MUST_USE_RESULT; + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; Status OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger, @@ -47,7 +52,6 @@ Status ExpandOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_UNUSED_PARAMETER(do_op_validation); const auto& inputs = node_unit.Inputs(); ORT_RETURN_IF(inputs.size() != 2, "Expand should has 2 inputs!"); - ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); // Process shape input @@ -124,6 +128,10 @@ Status ExpandOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, FillShapeInputData(shape_data, shape_size, static_cast(1)); break; } + case QNN_DATATYPE_BOOL_8: { + FillShapeInputData(shape_data, shape_size, static_cast(1)); + break; + } default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported."); } // switch @@ -135,12 +143,29 @@ Status ExpandOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, std::move(quantize_param), std::move(input_shape), std::move(shape_data)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); - input_names.push_back(shape_input_name); return Status::OK(); } +Status ExpandOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const { + if (input_names.size() < 1) { + return Status::OK(); + } + const auto* input_proto = node_unit.Inputs()[0].node_arg.TypeAsProto(); + Qnn_DataType_t qnn_data_type{}; + ORT_RETURN_IF_ERROR(utils::GetQnnDataType(false, input_proto, qnn_data_type)); + // Boolean expand is implemented as an element-wise and operation, element-wise multiply otherwise. + const std::string target_op = qnn_data_type == QNN_DATATYPE_BOOL_8 ? "And" : node_unit.OpType(); + ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), {}, + logger, do_op_validation, GetQnnOpType(target_op))); + return Status::OK(); +} + Status ExpandOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc index d77d9534bf1c4..e370501871e81 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc @@ -25,6 +25,11 @@ class InstanceNormOpBuilder : public BaseOpBuilder { std::vector& input_names, bool do_op_validation) const override ORT_MUST_USE_RESULT; + Status ProcessScale(QnnModelWrapper& qnn_model_wrapper, + const NodeUnitIODef& input, + const logging::Logger& logger, + std::vector& input_names) const; + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, @@ -143,12 +148,49 @@ Status InstanceNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); // Input 0 } - ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[1], logger, input_names)); // Scale + ORT_RETURN_IF_ERROR(ProcessScale(qnn_model_wrapper, inputs[1], logger, input_names)); // Scale ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[2], logger, input_names)); // Bias return Status::OK(); } +Status InstanceNormOpBuilder::ProcessScale(QnnModelWrapper& qnn_model_wrapper, + const NodeUnitIODef& input, + const logging::Logger& logger, + std::vector& input_names) const { + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, input, logger, input_names)); + + // Turn SFIXED scale of InstanceNorm into UFIXED when it is constant + const auto& input_name = input.node_arg.Name(); + bool is_const = qnn_model_wrapper.IsConstantInput(input_name); + bool is_npu = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + if (is_npu && is_const) { + TensorInfo tensor_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input, tensor_info)); + const Qnn_QuantizeParams_t& quant_param = tensor_info.quant_param.Get(); + if (tensor_info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { + std::string convert_input_name = input_names.back(); + std::string convert_output_name = convert_input_name + "_convert_s8_to_u8"; + Status status = utils::InsertConvertOp( + qnn_model_wrapper, + convert_input_name, + convert_output_name, + QNN_DATATYPE_SFIXED_POINT_8, + QNN_DATATYPE_UFIXED_POINT_8, + quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, + tensor_info.shape, + false, // asymmetric + false // do_op_validation + ); + input_names.pop_back(); + input_names.push_back(convert_output_name); + } + } + + return Status::OK(); +} + Status InstanceNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/reciprocal_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/reciprocal_op_builder.cc new file mode 100644 index 0000000000000..bd55df8650b97 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/reciprocal_op_builder.cc @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_utils.h" + +namespace onnxruntime { +namespace qnn { + +class ReciprocalOpBuilder : public BaseOpBuilder { + public: + ReciprocalOpBuilder() : BaseOpBuilder("ReciprocalOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ReciprocalOpBuilder); + + Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; + + protected: + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; +}; + +Status ReciprocalOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + + const auto& inputs = node_unit.Inputs(); + ORT_RETURN_IF_NOT(inputs.size() == 1, "Reciprocal operator must have exactly 1 input."); + + const auto& outputs = node_unit.Outputs(); + ORT_RETURN_IF_NOT(outputs.size() == 1, "Reciprocal operator must have exactly 1 output."); + + // Check input type is float for CPU. + ORT_RETURN_IF_ERROR(DataTypeCheckForCpuBackend(qnn_model_wrapper, inputs[0].node_arg.Type())); + + return Status::OK(); +} + +Status ReciprocalOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(logger); + + // Create a constant tensor for the divisor (1.0) + std::string divisor_name = node_unit.Name() + "_divisor"; + std::vector divisor_shape{1}; + std::vector divisor_data; + + TensorInfo input_info{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_info)); + + QnnQuantParamsWrapper divisor_quant_param = input_info.quant_param.Copy(); + Qnn_DataType_t divisor_qnn_data_type = input_info.qnn_data_type; + + if (input_info.quant_param.IsQuantized()) { + // Create a quantized divisor tensor + double divisor_value = 1.0; + int quantized_divisor_value; + ORT_RETURN_IF_ERROR(utils::Quantize(divisor_value, divisor_quant_param.Get().scaleOffsetEncoding.scale, + divisor_quant_param.Get().scaleOffsetEncoding.offset, + divisor_qnn_data_type, quantized_divisor_value)); + size_t element_size = qnn::utils::GetElementSizeByType(divisor_qnn_data_type); + divisor_data.resize(element_size); + std::memcpy(divisor_data.data(), &quantized_divisor_value, element_size); + } else { + // Create a float divisor tensor + divisor_data.resize(sizeof(float)); + float one = 1.0f; + std::memcpy(divisor_data.data(), &one, sizeof(float)); + } + + QnnTensorWrapper divisor_tensorwrapper(divisor_name, QNN_TENSOR_TYPE_STATIC, divisor_qnn_data_type, + std::move(divisor_quant_param), std::move(divisor_shape), std::move(divisor_data)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(divisor_tensorwrapper)), "Failed to add divisor tensor."); + + // Create the Div node + const auto& outputs = node_unit.Outputs(); + const std::string& output_name = outputs[0].node_arg.Name(); + bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_name); + Qnn_TensorType_t tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + TensorInfo output_info{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(outputs[0], output_info)); + QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, output_info.qnn_data_type, + output_info.quant_param.Copy(), std::move(output_info.shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add output tensor."); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( + utils::GetNodeName(node_unit), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_ELEMENT_WISE_DIVIDE, + {divisor_name, input_names[0]}, + {output_name}, + {}, + do_op_validation), + "Failed to create Div node."); + + return Status::OK(); +} + +void CreateReciprocalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc index 6fd67a72b64e1..574df61152a0b 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/reshape_op_builder.cc @@ -60,7 +60,7 @@ Status ReshapeOpBuilder::OverrideOutputQuantParam(QnnModelWrapper& qnn_model_wra } // Force Reshape output to use the same quantization parameters as the input if nearly equal. - // This helps the HTP backend emply certain optimizations. + // This helps the HTP backend employ certain optimizations. return SetOutputQParamEqualToInputIfNearlyEqual(qnn_model_wrapper, node_unit, logger, input_names, 0 /*input_index*/, output_index, qnn_data_type, quant_param); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/udo_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/udo_builder.cc new file mode 100644 index 0000000000000..339c521952bcf --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/udo_builder.cc @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_utils.h" + +namespace onnxruntime { +namespace qnn { + +class UDOBuilder : public BaseOpBuilder { + public: + UDOBuilder(const std::string& op_type, const std::string& op_package) : BaseOpBuilder(op_type + "_UDOBuilder"), op_type_(op_type), op_package_(op_package) {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(UDOBuilder); + + protected: + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + private: + const std::string op_type_; + const std::string op_package_; +}; + +Status UDOBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(logger); + std::string node_name = utils::GetNodeName(node_unit); + const auto& outputs = node_unit.Outputs(); + std::vector output_names; + for (size_t i = 0; i < outputs.size(); ++i) { + const auto& output_name = outputs[i].node_arg.Name(); + + TensorInfo output_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(outputs[i], output_info)); + bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_name); + + Qnn_TensorType_t tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + QnnTensorWrapper output_tensorwrapper(output_name, + tensor_type, + output_info.qnn_data_type, + std::move(output_info.quant_param), + std::move(output_info.shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); + output_names.emplace_back(output_name); + } + std::vector param_names; + NodeAttrHelper node_helper(node_unit); + auto& attrs = node_unit.GetNode().GetAttributes(); + for (auto& attr : attrs) { + std::string attr_name = attr.first; + auto& attr_value = attr.second; + LOGS(logger, VERBOSE) << "Parse attr name: " << attr_name << " for op " << node_name; + switch (attr_value.type()) { + case ONNX_NAMESPACE::AttributeProto::FLOAT: { + auto optional_float = node_helper.GetFloat(attr_name); + ORT_RETURN_IF_NOT(optional_float.has_value(), + "Failed to get values from attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, optional_float.value(), attr_name, param_names)); + break; + } + case ONNX_NAMESPACE::AttributeProto::FLOATS: { + auto optional_floats = node_helper.GetFloats(attr_name); + ORT_RETURN_IF_NOT(optional_floats.has_value(), + "Failed to get values from attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); + std::vector floats_data(optional_floats.value().begin(), optional_floats.value().end()); + auto param_wrapper = createQnnParamWrapper(node_unit.Index(), node_name, attr_name, + {static_cast(floats_data.size())}, std::move(floats_data)); + param_names.push_back(param_wrapper.GetParamTensorName()); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddParamWrapper(std::move(param_wrapper)), + "Failed to add tensor attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); + break; + } + case ONNX_NAMESPACE::AttributeProto::INT: { + auto optional_int64 = node_helper.GetInt64(attr_name); + ORT_RETURN_IF_NOT(optional_int64.has_value(), + "Failed to get values from attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, optional_int64.value(), attr_name, param_names)); + break; + } + case ONNX_NAMESPACE::AttributeProto::INTS: { + auto optional_int64s = node_helper.GetInt64s(attr_name); + ORT_RETURN_IF_NOT(optional_int64s.has_value(), + "Failed to get values from attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); + std::vector int64s_data(optional_int64s.value().begin(), optional_int64s.value().end()); + auto param_wrapper = createQnnParamWrapper(node_unit.Index(), node_name, attr_name, + {static_cast(int64s_data.size())}, std::move(int64s_data)); + param_names.push_back(param_wrapper.GetParamTensorName()); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddParamWrapper(std::move(param_wrapper)), + "Failed to add tensor attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); + break; + } + case ONNX_NAMESPACE::AttributeProto::STRING: { + auto optional_string = node_helper.GetString(attr_name); + ORT_RETURN_IF_NOT(optional_string.has_value(), + "Failed to get values from attr ", attr_name, " in op ", node_name, " to qnn_model_wrapper."); + ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit.Index(), node_name, optional_string.value(), attr_name, param_names)); + break; + } + default: { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to add scalar attr ", attr_name, " data_type ", attr_value.type(), " in op ", node_name, " to qnn_model_wrapper."); + } + } + } + + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, + op_package_, + op_type_, + std::move(input_names), + std::move(output_names), + std::move(param_names), + do_op_validation), + "Failed to add node."); + return Status::OK(); +} + +void CreateUDOBuilder(const std::string& op_type, const std::string& op_package, OpBuilderRegistrations& op_registrations) { + op_registrations.RegisterUDOBuilder(op_type, std::make_unique(op_type, op_package)); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 901569b54e049..2c7098749a985 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -24,6 +24,7 @@ #include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/qnn_allocator.h" #include "core/providers/qnn/qnn_telemetry.h" +#include "core/providers/qnn/shared_context.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/providers/qnn/builder/qnn_configs_helper.h" #include "core/providers/qnn/builder/qnn_utils.h" @@ -709,6 +710,135 @@ Status SetQnnContextConfig(ContextPriority context_priority, QnnContext_Config_t return Status::OK(); } +// callback required to add context handles to class list +// when using contextCreateFromBinaryListAsync() +void ContextCreateAsyncCallback(Qnn_ContextHandle_t context, + Qnn_GraphHandle_t graph, + const char* graphName, + QnnContext_createFromBinaryAsyncNotifyType_t notifyType, + void* notifyParam, + Qnn_ErrorHandle_t status) { + auto qnn_backend_manager = SharedContext::GetInstance().GetSharedQnnBackendManager(); + + if (context) { + qnn_backend_manager->ProcessContextFromBinListAsync(context, notifyParam); + } + + if (nullptr == graphName || graph || notifyType || status) { + // Avoid compilation unused var warning error + } +} + +void QnnBackendManager::ProcessContextFromBinListAsync(Qnn_ContextHandle_t context, void* notifyParam) { + std::lock_guard guard(ep_context_handle_map_mutex_); + if (!notifyParam) { + LOGS(*logger_, WARNING) << "No known node names associated with context handle: " << context; + return; + } + + std::vector* ep_node_names = reinterpret_cast*>(notifyParam); + for (const auto& node_name : *ep_node_names) { + if (!(ep_context_handle_map_.emplace(node_name, context).second)) { + LOGS(*logger_, VERBOSE) << "Unable to map " << context << " to " << node_name; + } + } + + auto s = AddQnnContextHandle(context); + if (s != Status::OK()) { + LOGS(*logger_, WARNING) << "Unable to add context " << context; + } +} + +Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unordered_map>>& context_bin_map) { +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26) + QnnContext_Config_t context_config_resource_sharing = QNN_CONTEXT_CONFIG_INIT; + QnnHtpContext_CustomConfig_t resource_sharing_custom_config; + resource_sharing_custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_SHARE_RESOURCES; + resource_sharing_custom_config.shareResources = true; + context_config_resource_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; + context_config_resource_sharing.customConfig = &resource_sharing_custom_config; + + QnnHtpContext_CustomConfig_t context_config_resource_sharing_opt_type; + context_config_resource_sharing_opt_type.option = QNN_HTP_CONTEXT_CONFIG_OPTION_SHARE_RESOURCES_OPTIMIZATION_TYPE; + context_config_resource_sharing_opt_type.shareResOptType = SEQUENTIAL_WITHOUT_VA_OPTIMIZATION; + QnnContext_Config_t resource_sharing_opt_type_config; + resource_sharing_opt_type_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; + resource_sharing_opt_type_config.customConfig = &context_config_resource_sharing_opt_type; + + QnnContext_Config_t context_config_weight_sharing = QNN_CONTEXT_CONFIG_INIT; + QnnHtpContext_CustomConfig_t custom_config; + custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED; + custom_config.weightSharingEnabled = true; + context_config_weight_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; + context_config_weight_sharing.customConfig = &custom_config; +#else + LOGS(*logger_, WARNING) << "Called CreateContextVtcmBackupBufferSharingEnabled() but QNN API version is older than 2.26!"; +#endif + QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT; + ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, context_priority_config)); + + const QnnContext_Config_t* configs[] = {&context_priority_config, +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26) + &context_config_resource_sharing, + &resource_sharing_opt_type_config, + &context_config_weight_sharing, +#endif + nullptr}; + + std::vector context_params_list; + std::vector context_paramsv1_list; + std::vector context_params_ptr_list(context_bin_map.size() + 1); + std::vector> buffer_list; + + size_t idx = 0; + for (auto& it : context_bin_map) { + auto context_bin_filepath = it.first; + + std::ifstream cache_file(context_bin_filepath.c_str(), std::ifstream::binary); + ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to retrieve context binary from: ", context_bin_filepath); + + cache_file.seekg(0, cache_file.end); + size_t buffer_size = static_cast(cache_file.tellg()); + ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered."); + + cache_file.seekg(0, cache_file.beg); + std::unique_ptr buffer = std::make_unique(buffer_size); + ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for cache file."); + const auto& read_result = cache_file.read(buffer.get(), buffer_size); + ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file."); + + cache_file.close(); + QnnContext_ParamsV1_t context_params_v1 = {nullptr, + buffer.get(), + buffer_size, + nullptr, + ContextCreateAsyncCallback, + it.second.get()}; + + QnnContext_Params_t context_params = {QnnContext_ParamsVersion_t::QNN_CONTEXT_PARAMS_VERSION_1, + context_params_v1}; + + buffer_list.push_back(std::move(buffer)); + context_params_list.push_back(std::move(context_params)); + context_paramsv1_list.push_back(std::move(context_params_v1)); + context_params_ptr_list[idx++] = &context_params_list.back(); + } + context_params_ptr_list[idx] = nullptr; + auto result = qnn_interface_.contextCreateFromBinaryListAsync(backend_handle_, + device_handle_, + context_params_ptr_list.data(), + configs, + nullptr); + + context_params_ptr_list.clear(); + context_paramsv1_list.clear(); + context_params_list.clear(); + buffer_list.clear(); + + ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. Error: ", QnnErrorHandleToString(result), ", Code:", result); + return Status::OK(); +} + Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) { if (true == context_created_) { LOGS_DEFAULT(INFO) << "Context created already."; @@ -728,6 +858,7 @@ Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) { const QnnContext_Config_t* npu_context_configs[] = {&context_priority_config, &context_config_weight_sharing, nullptr}; + const QnnContext_Config_t* empty_context_configs[] = {nullptr}; const QnnContext_Config_t** configs = nullptr; @@ -751,12 +882,14 @@ Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) { } Qnn_ContextHandle_t context = nullptr; - Qnn_ErrorHandle_t result = qnn_interface_.contextCreate(backend_handle_, - device_handle_, - configs, - &context); + Qnn_ErrorHandle_t result = 0; - ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. Error: ", QnnErrorHandleToString(result)); + result = qnn_interface_.contextCreate(backend_handle_, + device_handle_, + configs, + &context); + + ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. Error: ", QnnErrorHandleToString(result), ", Code:", result); ORT_RETURN_IF_ERROR(AddQnnContextHandle(context)); @@ -936,43 +1069,60 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context."); LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count; - QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT; - ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config)); + Qnn_ContextHandle_t context = nullptr; +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26) + if (vtcm_backup_buffer_sharing_enabled_) { + if (ep_context_handle_map_.find(node_name) != ep_context_handle_map_.end()) { + context = ep_context_handle_map_.at(node_name); + } + ORT_RETURN_IF(nullptr == context, "Failed to retrieve context for ", node_name); + + } else { +#endif + QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT; + ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config)); - // Register spill fill buffer for multi context - QnnContext_Config_t spill_fill_config = QNN_CONTEXT_CONFIG_INIT; + // Register spill fill buffer for multi context + QnnContext_Config_t spill_fill_config = QNN_CONTEXT_CONFIG_INIT; - // The spill fill buffer is available since 2.28, API version starts from 2.21 + // The spill fill buffer is available since 2.28, API version starts from 2.21 #if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21) - QnnHtpContext_CustomConfig_t custom_config; - custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS; - QnnHtpContext_GroupRegistration_t group_info; - size_t current_contexts_size = GetQnnContextSize(); - // set to 0x0 (new group) if this is the first context, otherwise point to the first context handle - // note that we already move the context with max spill fill size to the beginning of the list - group_info.firstGroupHandle = (max_spill_fill_size > 0 && current_contexts_size > 0) ? GetQnnContext(0) : 0x0; - group_info.maxSpillFillBuffer = max_spill_fill_size; // Max spill-fill buffer across contexts. Must be >0 - custom_config.groupRegistration = group_info; - spill_fill_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; - spill_fill_config.customConfig = &custom_config; + QnnHtpContext_CustomConfig_t custom_config; + custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS; + QnnHtpContext_GroupRegistration_t group_info; + size_t current_contexts_size = GetQnnContextSize(); + // set to 0x0 (new group) if this is the first context, otherwise point to the first context handle + // note that we already move the context with max spill fill size to the beginning of the list + group_info.firstGroupHandle = (max_spill_fill_size > 0 && current_contexts_size > 0) ? GetQnnContext(0) : 0x0; + group_info.maxSpillFillBuffer = max_spill_fill_size; // Max spill-fill buffer across contexts. Must be >0 + custom_config.groupRegistration = group_info; + spill_fill_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; + spill_fill_config.customConfig = &custom_config; + #endif - QnnContext_Config_t* spill_fill_config_pointer = max_spill_fill_size > 0 ? &spill_fill_config : nullptr; - LOGS(*logger_, VERBOSE) << "Max spill fill buffer size:" << max_spill_fill_size; - const QnnContext_Config_t* context_configs[] = {&qnn_context_config, spill_fill_config_pointer, nullptr}; + QnnContext_Config_t* spill_fill_config_pointer = max_spill_fill_size > 0 ? &spill_fill_config : nullptr; + LOGS(*logger_, VERBOSE) << "Max spill fill buffer size:" << max_spill_fill_size; + + const QnnContext_Config_t* context_configs[] = {&qnn_context_config, spill_fill_config_pointer, nullptr}; + + ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary, + "Invalid function pointer for contextCreateFromBinary."); + + rt = qnn_interface_.contextCreateFromBinary(backend_handle_, + device_handle_, + context_configs, + static_cast(buffer), + buffer_length, + &context, + profile_backend_handle_); + ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt); + ORT_RETURN_IF_ERROR(AddQnnContextHandle(context)); + +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26) + } +#endif - ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary, - "Invalid function pointer for contextCreateFromBinary."); - Qnn_ContextHandle_t context = nullptr; - rt = qnn_interface_.contextCreateFromBinary(backend_handle_, - device_handle_, - context_configs, - static_cast(buffer), - buffer_length, - &context, - profile_backend_handle_); - ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt); - ORT_RETURN_IF_ERROR(AddQnnContextHandle(context)); if (1 == graph_count) { // in case the EPContext node is generated from script // the graph name from the context binary may not match the EPContext node name @@ -1002,13 +1152,33 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context, bool need_load_system_lib, - bool share_ep_contexts) { + bool share_ep_contexts, + bool enable_vtcm_backup_buffer_sharing, + std::unordered_map>>& context_bin_map) { std::lock_guard lock(logger_recursive_mutex_); if (backend_setup_completed_) { LOGS(logger, VERBOSE) << "Backend setup already!"; + +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26) + if (vtcm_backup_buffer_sharing_enabled_) { + LOGS(logger, VERBOSE) << "Mapping contexts to new EP main context nodes"; + + for (auto& it : context_bin_map) { + auto context_bin_filepath = it.first; + auto ep_node_names = *(it.second); + + auto context = ep_context_handle_map_.at(context_bin_filepath); + for (auto node_name : ep_node_names) { + ep_context_handle_map_.emplace(node_name, context); + } + } + } +#endif return Status::OK(); } + vtcm_backup_buffer_sharing_enabled_ = enable_vtcm_backup_buffer_sharing; + Status status = Status::OK(); if (!qnn_serializer_config_) { status = LoadBackend(); @@ -1057,6 +1227,11 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, LOGS(logger, VERBOSE) << "InitializeProfiling succeed."; } + if (status.IsOK()) { + ORT_RETURN_IF_ERROR(LoadOpPackage()); + LOGS(logger, VERBOSE) << "LoadOpPackage succeed."; + } + bool enable_htp_weight_sharing = false; if (share_ep_contexts && !load_from_cached_context) { #if defined(__aarch64__) || defined(_M_ARM64) @@ -1066,10 +1241,10 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, #endif } - if (!load_from_cached_context) { - if (status.IsOK()) { - status = CreateContext(enable_htp_weight_sharing); - } + if (status.IsOK() && (vtcm_backup_buffer_sharing_enabled_ || !load_from_cached_context)) { + status = vtcm_backup_buffer_sharing_enabled_ ? CreateContextVtcmBackupBufferSharingEnabled(context_bin_map) + : CreateContext(enable_htp_weight_sharing); + if (status.IsOK()) { LOGS(logger, VERBOSE) << "CreateContext succeed."; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index b8e8081f77f27..371dc6dd4fc4a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -24,6 +24,7 @@ #include "System/QnnSystemInterface.h" #include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_context_mem_handle_manager.h" #include "core/providers/qnn/builder/qnn_def.h" @@ -88,6 +89,13 @@ class QnnSerializerConfig { std::string graph_name_{"graph"}; }; +struct OpPackage { + std::string op_type; + std::string path; + std::string interface; + std::string target; +}; + // configuration values for QnnBackendManager creation struct QnnBackendManagerConfig { std::string backend_path; @@ -99,6 +107,7 @@ struct QnnBackendManagerConfig { uint32_t device_id; QnnHtpDevice_Arch_t htp_arch; uint32_t soc_model; + std::vector op_packages; }; class QnnBackendManager : public std::enable_shared_from_this { @@ -122,7 +131,8 @@ class QnnBackendManager : public std::enable_shared_from_this qnn_serializer_config_(config.qnn_serializer_config), device_id_(config.device_id), htp_arch_(config.htp_arch), - soc_model_(config.soc_model) { + soc_model_(config.soc_model), + op_packages_(config.op_packages) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnBackendManager); @@ -139,7 +149,9 @@ class QnnBackendManager : public std::enable_shared_from_this // Initializes handles to QNN resources (device, logger, etc.). // NOTE: This function locks the internal `logger_recursive_mutex_`. Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context, - bool need_load_system_lib, bool share_ep_contexts); + bool need_load_system_lib, bool share_ep_contexts, + bool enable_vtcm_backup_buffer_sharing, + std::unordered_map>>& context_bin_map); Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id); @@ -199,6 +211,13 @@ class QnnBackendManager : public std::enable_shared_from_this QnnSerializerConfig* GetQnnSerializerConfig(); + // Handler to be called upon successful context creation via contextCreateFromBinaryListAsync() + // This handler is expected to be called in the callback ContextCreateAsyncCallback() in the .cc file + // Takes in the context and the notifyParam objects received by the callback function + // notifyParam is expected to be a pointer to a vector of node names associated with that context handle + // For each node name, a mapping to the context handle will be created + void ProcessContextFromBinListAsync(Qnn_ContextHandle_t handle, void* notifyParam); + private: Status LoadBackend(); @@ -216,6 +235,9 @@ class QnnBackendManager : public std::enable_shared_from_this Status CreateContext(bool enable_htp_weight_sharing); + Status CreateContextVtcmBackupBufferSharingEnabled(std::unordered_map>>& context_bin_map); + Status ReleaseContext(); // Sets the ORT logger and creates a corresponding QNN logger with the same log level. @@ -300,7 +322,7 @@ class QnnBackendManager : public std::enable_shared_from_this #endif // Adds a new QNN context. - // Transfers ownership of `context_handle` (i.e., responsibility of freeing it) to this instance. + // Transfers ownership of `context_handle` (i.e., responsibility of freeing it) to this instance Status AddQnnContextHandle(Qnn_ContextHandle_t context_handle); private: @@ -314,6 +336,70 @@ class QnnBackendManager : public std::enable_shared_from_this std::unique_ptr mem_handles; }; + Status LoadOpPackage() { + // assume op_packages passed in represented in + // op_packages|:::,::: + for (const auto& op_package : op_packages_) { + ORT_RETURN_IF(nullptr == qnn_interface_.backendRegisterOpPackage, "backendRegisterOpPackageFnHandle is nullptr."); + + Qnn_ErrorHandle_t result = qnn_interface_.backendRegisterOpPackage( + backend_handle_, + op_package.path.c_str(), + op_package.interface.c_str(), + op_package.target.c_str()); + + if (result != QNN_SUCCESS) { + switch (result) { + case QNN_BACKEND_ERROR_INVALID_ARGUMENT: + LOGS(*logger_, ERROR) << "Invalid argument, please check if op package path or interface provider is NULL."; + break; + case QNN_BACKEND_ERROR_OP_PACKAGE_NOT_FOUND: + LOGS(*logger_, ERROR) << "Could not open op package path. op_pack_path: " << op_package.path; + break; + case QNN_BACKEND_ERROR_OP_PACKAGE_IF_PROVIDER_NOT_FOUND: + LOGS(*logger_, ERROR) << "Could not find interfaceProvider symbol in op package library."; + break; + case QNN_BACKEND_ERROR_OP_PACKAGE_REGISTRATION_FAILED: + LOGS(*logger_, ERROR) << "Op package registration failed."; + break; + case QNN_BACKEND_ERROR_OP_PACKAGE_UNSUPPORTED_VERSION: + LOGS(*logger_, ERROR) << "Op package has interface version not supported by this backend."; + break; + case QNN_BACKEND_ERROR_NOT_SUPPORTED: + LOGS(*logger_, ERROR) << "Op package registration is not supported."; + break; + case QNN_BACKEND_ERROR_INVALID_HANDLE: + LOGS(*logger_, ERROR) << "backend is not a valid handle."; + break; + case QNN_BACKEND_ERROR_OP_PACKAGE_DUPLICATE: + LOGS(*logger_, ERROR) << "OpPackageName+OpName must be unique. Op package content information can be be obtained with \ + QnnOpPackage interface. Indicates that an Op with the same package name and op name was already registered."; + break; + case QNN_COMMON_ERROR_SYSTEM_COMMUNICATION: + LOGS(*logger_, ERROR) << "SSR occurrence (successful recovery)."; + break; + case QNN_COMMON_ERROR_SYSTEM_COMMUNICATION_FATAL: + LOGS(*logger_, ERROR) << "SSR occurrence (unsuccessful recovery)."; + break; + default: + LOGS(*logger_, ERROR) << "Unknown error occurred while initializing logging in the QNN backend."; + break; + } + } + ORT_RETURN_IF(QNN_SUCCESS != result, "Failed to register op package to backend. Error: ", QnnErrorHandleToString(result)); + LOGS(*logger_, VERBOSE) << "Successfully register the op package."; + std::string op_package_for_registration = std::filesystem::path(op_package.path).stem().string(); + // remove lib prefix in Linux + std::string prefix = "lib"; + if (op_package_for_registration.compare(0, prefix.size(), prefix) == 0) { + op_package_for_registration = op_package_for_registration.substr(prefix.size()); + } + qnn::RegisterUDOBuilder(op_package.op_type, op_package_for_registration); + } + + return Status::OK(); + } + private: const std::string backend_path_; std::recursive_mutex logger_recursive_mutex_; @@ -333,6 +419,10 @@ class QnnBackendManager : public std::enable_shared_from_this // HtpSharedMemoryAllocator allocation cleanup callback. std::unordered_map> context_map_; + // Map of EP Main Context Node names to Qnn_ContextHandle_t + std::mutex ep_context_handle_map_mutex_; + std::unordered_map ep_context_handle_map_; + // Vector of Qnn_ContextHandle_t. The context handles are owned by context_map_. std::vector contexts_; @@ -344,10 +434,10 @@ class QnnBackendManager : public std::enable_shared_from_this bool device_created_ = false; bool context_created_ = false; bool backend_setup_completed_ = false; + bool vtcm_backup_buffer_sharing_enabled_ = false; // NPU backend requires quantized model QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; Qnn_ProfileHandle_t profile_backend_handle_ = nullptr; - std::vector op_package_paths_; ContextPriority context_priority_; std::string sdk_build_version_ = ""; #ifdef _WIN32 @@ -357,6 +447,7 @@ class QnnBackendManager : public std::enable_shared_from_this uint32_t device_id_ = 0; QnnHtpDevice_Arch_t htp_arch_ = QNN_HTP_DEVICE_ARCH_NONE; uint32_t soc_model_ = QNN_SOC_MODEL_UNKNOWN; + const std::vector op_packages_; }; } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc index d3a086ea1bc9f..f3d81d7d2fdd7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc @@ -139,26 +139,6 @@ void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector& ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); } -void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector& client_buf) { - if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { - auto size = client_buf.size() * sizeof(uint32_t); - qnn_tensor.v1.clientBuf.data = const_cast(static_cast(client_buf.data())); - qnn_tensor.v1.clientBuf.dataSize = static_cast(size); - return; - } - -#ifdef QNN_TENSOR_V2_INIT - if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { - auto size = client_buf.size() * sizeof(uint32_t); - qnn_tensor.v2.clientBuf.data = const_cast(static_cast(client_buf.data())); - qnn_tensor.v2.clientBuf.dataSize = static_cast(size); - return; - } -#endif // QNN_TENSOR_V2_INIT - - ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); -} - void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, void* buf_data, uint32_t buf_size) { if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { qnn_tensor.v1.clientBuf.data = buf_data; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index a95628ae9cc7f..6fba6d847cb74 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -107,7 +107,6 @@ void SetQnnTensorDataType(Qnn_Tensor_t& qnn_tensor, Qnn_DataType_t data_type); void SetQnnTensorDim(Qnn_Tensor_t& qnn_tensor, const std::vector& dimensions); void SetQnnTensorMemType(Qnn_Tensor_t& qnn_tensor, Qnn_TensorMemType_t mem_type); void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector& client_buf); -void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector& client_buf); void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, void* buf_data, uint32_t buf_size); void SetQnnTensorClientBufSize(Qnn_Tensor_t& qnn_tensor, uint32_t client_buf_size); void SetQnnTensorClientBufData(Qnn_Tensor_t& qnn_tensor, void* client_buf_data); @@ -305,12 +304,34 @@ class QnnParamWrapper { qnn_param_.scalarParam = scalarParam; } + QnnParamWrapper(NodeIndex node_index, + const std::string& node_name, + const std::string& name, + Qnn_DataType_t data_type, + std::vector&& shape, + std::vector&& param_data) : name_(name), shape_(std::move(shape)), param_data_(std::move(param_data)) { + qnn_param_.paramType = QNN_PARAMTYPE_TENSOR; + qnn_param_.name = name_.c_str(); + std::stringstream ss; + ss << node_name << "_" << node_index << "_" << name; + tensor_name_ = ss.str(); + qnn_param_.tensorParam = QNN_TENSOR_INIT; + SetQnnTensorType(qnn_param_.tensorParam, QNN_TENSOR_TYPE_STATIC); + SetQnnTensorName(qnn_param_.tensorParam, tensor_name_.c_str()); + SetQnnTensorDataType(qnn_param_.tensorParam, data_type); + SetQnnTensorDim(qnn_param_.tensorParam, shape_); + SetQnnTensorMemType(qnn_param_.tensorParam, QNN_TENSORMEMTYPE_RAW); + SetQnnTensorClientBuf(qnn_param_.tensorParam, param_data_); + } + QnnParamWrapper(NodeIndex node_index, const std::string& node_name, const std::string& name, std::vector&& shape, std::vector&& param_data, - bool is_signed = false) : name_(name), shape_(std::move(shape)), param_data_(std::move(param_data)) { + bool is_signed = false) : name_(name), shape_(std::move(shape)) { + param_data_.resize(param_data.size() * sizeof(uint32_t)); + std::memcpy(param_data_.data(), const_cast(static_cast(param_data.data())), param_data_.size()); qnn_param_.paramType = QNN_PARAMTYPE_TENSOR; qnn_param_.name = name_.c_str(); std::stringstream ss; @@ -324,6 +345,7 @@ class QnnParamWrapper { SetQnnTensorMemType(qnn_param_.tensorParam, QNN_TENSORMEMTYPE_RAW); SetQnnTensorClientBuf(qnn_param_.tensorParam, param_data_); } + ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnParamWrapper); QnnParamWrapper(QnnParamWrapper&& other) noexcept { std::swap(name_, other.name_); @@ -367,10 +389,34 @@ class QnnParamWrapper { std::string name_; std::string tensor_name_; std::vector shape_; - std::vector param_data_; + std::vector param_data_; Qnn_Param_t qnn_param_ = QNN_PARAM_INIT; }; +template +QnnParamWrapper createQnnParamWrapper(NodeIndex node_index, + const std::string& node_name, + const std::string& name, + std::vector&& shape, + std::vector&& param_data) { + Qnn_DataType_t qnn_data_type = QNN_DATATYPE_UNDEFINED; + if (std::is_same::value) { + qnn_data_type = QNN_DATATYPE_FLOAT_32; + } else if (std::is_same::value) { + qnn_data_type = QNN_DATATYPE_UINT_32; + } else if (std::is_same::value) { + qnn_data_type = QNN_DATATYPE_INT_32; + } else if (std::is_same::value) { + qnn_data_type = QNN_DATATYPE_INT_64; + } else if (std::is_same::value) { + qnn_data_type = QNN_DATATYPE_BOOL_8; + } + std::vector new_param_data; + new_param_data.resize(param_data.size() * sizeof(T)); + std::memcpy(new_param_data.data(), param_data.data(), new_param_data.size()); + return QnnParamWrapper(node_index, node_name, name, qnn_data_type, std::move(shape), std::move(new_param_data)); +} + class QnnOpConfigWrapper { public: QnnOpConfigWrapper(const std::string& name, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index 407fce4a4374c..c6b0b4b6668f3 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -133,7 +133,7 @@ std::ostream& operator<<(std::ostream& out, const Qnn_Scalar_t& scalar) { out << scalar.int32Value; break; case QNN_DATATYPE_INT_64: - out << "int64_t is not supported"; + out << "int64_t is not supported in QNN except for UDO"; break; case QNN_DATATYPE_UINT_8: out << static_cast(scalar.uint8Value); @@ -145,7 +145,7 @@ std::ostream& operator<<(std::ostream& out, const Qnn_Scalar_t& scalar) { out << scalar.uint32Value; break; case QNN_DATATYPE_UINT_64: - out << "uint64_t is not supported"; + out << "uint64_t is not supported in QNN except for UDO"; break; case QNN_DATATYPE_FLOAT_16: break; diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index 29e5dc0c25564..d248644c13ddb 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -114,8 +114,9 @@ AllocationTracker& GlobalAllocationTracker() { OrtMemoryInfo HtpSharedMemoryAllocator::AssociatedMemoryInfo() { return OrtMemoryInfo{QNN_HTP_SHARED, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice{OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, /* device_id */ 0}, - /* id */ 0, OrtMemTypeDefault}; + OrtDevice{OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::QUALCOMM, + /*device_id*/ 0}, + OrtMemTypeDefault}; } HtpSharedMemoryAllocator::HtpSharedMemoryAllocator(std::shared_ptr rpcmem_lib, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index bc69d38edf482..70850dc7162c8 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -8,6 +8,7 @@ #include #include +#include "core/common/string_utils.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_def.h" @@ -178,6 +179,45 @@ static void ParseHtpArchitecture(const std::string& htp_arch_string, QnnHtpDevic } } +static void ParseOpPackages(const std::string& op_packages_string, std::vector& op_packages) { + for (const auto& op_package : utils::SplitString(op_packages_string, ",", true)) { + auto splitStrings = utils::SplitString(op_package, ":", true); + if (splitStrings.size() < 3 || splitStrings.size() > 4) { + LOGS_DEFAULT(WARNING) << "Invalid op_package passed, expected ::[:], got " << op_package; + LOGS_DEFAULT(WARNING) << "Skip registration."; + continue; + } + + std::string op_type = std::string(splitStrings[0]); + std::string op_package_path = std::string(splitStrings[1]); + std::string op_package_interface = std::string(splitStrings[2]); + std::string op_package_target; + + if (op_type.empty()) { + LOGS_DEFAULT(WARNING) << "Op type is empty. Skip registration"; + continue; + } + + if (op_package_path.empty()) { + LOGS_DEFAULT(WARNING) << "Op package path is empty. Skip registration"; + continue; + } + + if (op_package_interface.empty()) { + LOGS_DEFAULT(WARNING) << "Op package interface is empty. Skip registration"; + continue; + } + + LOGS_DEFAULT(VERBOSE) << "Loading op package from path: " << op_package_path << " for op " << op_type; + LOGS_DEFAULT(VERBOSE) << "Op package interface: " << op_package_interface; + if (splitStrings.size() > 3 && splitStrings[3].size()) { + op_package_target = std::string(splitStrings[3]); + LOGS_DEFAULT(VERBOSE) << "Op package target: " << op_package_target; + } + op_packages.push_back({op_type, op_package_path, op_package_interface, op_package_target}); + } +} + static bool ParseBoolOption(const std::string& key, bool default_value, const std::unordered_map& options) { bool result = default_value; @@ -406,6 +446,26 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } + static const std::string HTP_VTCM_BACKUP_BUFFER_SHARING = "enable_vtcm_backup_buffer_sharing"; + auto htp_vtcm_backup_buffer_sharing_pos = provider_options_map.find(HTP_VTCM_BACKUP_BUFFER_SHARING); + if (htp_vtcm_backup_buffer_sharing_pos != provider_options_map.end()) { + if ("1" == htp_vtcm_backup_buffer_sharing_pos->second) { + enable_vtcm_backup_buffer_sharing_ = true; + } else if ("0" != htp_vtcm_backup_buffer_sharing_pos->second) { + LOGS_DEFAULT(WARNING) << "Invalid value entered for " << HTP_VTCM_BACKUP_BUFFER_SHARING + << ": " << htp_vtcm_backup_buffer_sharing_pos->second + << ", only 1 or 0 are allowed. Setting to 0."; + } + + LOGS_DEFAULT(VERBOSE) << "User specified enable_vtcm_backup_buffer_sharing: " << enable_vtcm_backup_buffer_sharing_; + +#if QNN_API_VERSION_MAJOR < 2 || ((QNN_API_VERSION_MAJOR) == 2 && (QNN_API_VERSION_MINOR < 26)) + if (enable_vtcm_backup_buffer_sharing_) { + LOGS_DEFAULT(WARNING) << "User specified enable_vtcm_backup_buffer_sharing but QNN API version is older than 2.26."; + } +#endif + } + static const std::string QNN_DEVICE_ID = "device_id"; auto dev_id_pos = provider_options_map.find(QNN_DEVICE_ID); if (dev_id_pos != provider_options_map.end()) { @@ -438,6 +498,13 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } + static const std::string QNN_OP_PACKAGES = "op_packages"; + std::vector op_packages; + auto op_packages_pos = provider_options_map.find(QNN_OP_PACKAGES); + if (op_packages_pos != provider_options_map.end()) { + ParseOpPackages(op_packages_pos->second, op_packages); + } + static const std::string QNN_HTP_FP16_MODE = "enable_htp_fp16_precision"; auto htp_fp16_mode_pos = provider_options_map.find(QNN_HTP_FP16_MODE); if (htp_fp16_mode_pos != provider_options_map.end()) { @@ -455,6 +522,10 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio LOGS_DEFAULT(ERROR) << "[EP context generation:] Weight sharing enabled conflict with EP context embed mode. Inference will not work as expected!"; } + if (qnn_context_embed_mode_ && enable_vtcm_backup_buffer_sharing_) { + LOGS_DEFAULT(ERROR) << "[EP context generation:] VTCM backup buffer sharing enabled conflict with EP context embed mode. Inference will not work as expected!"; + } + // Add this option because this feature requires QnnSystem lib and it's no supported for Windows x86_64 platform enable_spill_fill_buffer_ = ParseBoolOption("enable_htp_spill_fill_buffer", false, provider_options_map); @@ -494,7 +565,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio // For context binary generation with weight sharing enabled, use the QnnBackendManager from the shared context if it exits // So that all graphs from later sessions will be compiled into the same QNN context - if (context_cache_enabled_ && share_ep_contexts_ && SharedContext::GetInstance().GetSharedQnnBackendManager()) { + if (((context_cache_enabled_ && share_ep_contexts_) || enable_vtcm_backup_buffer_sharing_) && SharedContext::GetInstance().GetSharedQnnBackendManager()) { qnn_backend_manager_ = SharedContext::GetInstance().GetSharedQnnBackendManager(); // Clear the QnnBackendManager from singleton to stop the resource share if (stop_share_ep_contexts_) { @@ -510,14 +581,18 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio std::move(qnn_serializer_config), device_id_, htp_arch, - soc_model}); + soc_model, + op_packages}); + if (enable_vtcm_backup_buffer_sharing_) { + SharedContext::GetInstance().SetSharedQnnBackendManager(qnn_backend_manager_); + } } #if defined(_WIN32) if (onnxruntime::logging::EtwRegistrationManager::SupportsETW()) { auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); // Register callback for ETW capture state (rundown) - auto etw_callback = + callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( [&etwRegistrationManager, this]( LPCGUID SourceId, ULONG IsEnabled, @@ -557,10 +632,8 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio // (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID); (void)qnn_backend_manager_->ResetQnnLogLevel(std::nullopt); } - }; - callback_ETWSink_key_ = "QnnExecutionProvider_"; - callback_ETWSink_key_.append(std::to_string(reinterpret_cast(this))); - etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_key_, etw_callback); + }); + etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); } #endif } @@ -576,7 +649,9 @@ QNNExecutionProvider::~QNNExecutionProvider() { // Unregister the ETW callback #if defined(_WIN32) - logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_key_); + if (callback_ETWSink_provider_ != nullptr) { + logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + } #endif } @@ -720,6 +795,27 @@ static bool EpSharedContextsHasAllGraphs(const std::vector& ep_context_nodes, + const logging::Logger& logger) { + for (const auto& node : graph_viewer.Nodes()) { + NodeAttrHelper node_helper(node); + bool is_main_context = node_helper.Get(qnn::MAIN_CONTEXT, static_cast(0)); + std::string cache_source = node_helper.Get(qnn::SOURCE, ""); + + std::transform(cache_source.begin(), + cache_source.end(), + cache_source.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + + if (is_main_context && qnn::EPCONTEXT_OP == node.OpType() && (cache_source == "qnnexecutionprovider" || cache_source == "qnn")) { + LOGS(logger, VERBOSE) << "EPContext Node found: [1] index: [" << node.Index() + << "] name: [" << node.Name(); + ep_context_nodes.insert(&node); + } + } +} + // For model with EPContext, filter in EPContext nodes only, and make sure each partition only has one single EPContext node static void PartitionCtxModel(const onnxruntime::GraphViewer& graph_viewer, const size_t num_nodes_in_graph, @@ -769,6 +865,18 @@ static void PartitionCtxModel(const onnxruntime::GraphViewer& graph_viewer, return; } +// Figure out the context cache Onnx file path to decide the folder location +static void GetContextOnnxModelFilePath(const std::string& user_context_cache_path, + const onnxruntime::PathString& model_path_string, + onnxruntime::PathString& context_model_path) { + // always try the path set by user first, it's the only way to set it if load model from memory + if (!user_context_cache_path.empty()) { + context_model_path = ToPathString(user_context_cache_path); + } else if (!model_path_string.empty()) { // model loaded from file + context_model_path = model_path_string; + } +} + std::vector> QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, @@ -801,12 +909,41 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer } } + std::unordered_map>> context_bin_map; + if (enable_vtcm_backup_buffer_sharing_) { + std::unordered_set ep_ctx_nodes; + GetMainEPCtxNodes(graph_viewer, ep_ctx_nodes, logger); + + onnxruntime::PathString context_model_path; + GetContextOnnxModelFilePath(context_cache_path_cfg_, graph_viewer.ModelPath().native(), context_model_path); + + std::filesystem::path parent_path = std::filesystem::path(context_model_path).parent_path(); + + for (auto& ep_ctx_node : ep_ctx_nodes) { + NodeAttrHelper node_helper(*ep_ctx_node); + std::string context_bin_filepath(parent_path.string()); + context_bin_filepath.append("/").append(node_helper.Get(qnn::EP_CACHE_CONTEXT, "")); + + if (context_bin_map.find(context_bin_filepath) == context_bin_map.end()) { + context_bin_map.emplace(context_bin_filepath, std::make_unique>()); + // Push context bin filepath for lookup between sessions + context_bin_map.at(context_bin_filepath)->push_back(context_bin_filepath); + } + context_bin_map.at(context_bin_filepath)->push_back(ep_ctx_node->Name()); + } + } + // It will load the QnnSystem lib if is_qnn_ctx_model=true, and // delay the Qnn context creation to Compile() using the cached context binary // or generate context cache enable, need to use use QnnSystem lib to parse the binary to get the max spill fill buffer size auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model, context_cache_enabled_ && enable_spill_fill_buffer_, - share_ep_contexts_); + share_ep_contexts_, + enable_vtcm_backup_buffer_sharing_, + context_bin_map); + + context_bin_map.clear(); + if (Status::OK() != rt) { LOGS(logger, ERROR) << "QNN SetupBackend failed " << rt.ErrorMessage(); return result; @@ -1054,18 +1191,6 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { const auto& logger = *GetLogger(); @@ -1185,6 +1310,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused qnn_models_, context_model_path, qnn_context_embed_mode_, + max_spill_fill_buffer_size, logger, share_ep_contexts_, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index bf022ae0e0018..7115708d42d8c 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -86,6 +86,7 @@ class QNNExecutionProvider : public IExecutionProvider { bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session. bool qnn_context_embed_mode_ = true; int32_t vtcm_size_in_mb_ = 0; + bool enable_vtcm_backup_buffer_sharing_ = false; std::unique_ptr qnn_ep_context_model_; std::unique_ptr metadef_id_generator_; uint32_t device_id_ = 0; @@ -96,7 +97,7 @@ class QNNExecutionProvider : public IExecutionProvider { bool stop_share_ep_contexts_ = false; bool enable_spill_fill_buffer_ = false; #if defined(_WIN32) - std::string callback_ETWSink_key_; + onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr; #endif qnn::ModelSettings model_settings_ = {}; bool dump_json_qnn_graph_ = false; diff --git a/onnxruntime/core/providers/qnn/qnn_telemetry.cc b/onnxruntime/core/providers/qnn/qnn_telemetry.cc index 9acc85c5feebe..b2c8350bfe8ca 100644 --- a/onnxruntime/core/providers/qnn/qnn_telemetry.cc +++ b/onnxruntime/core/providers/qnn/qnn_telemetry.cc @@ -55,7 +55,6 @@ TRACELOGGING_DEFINE_PROVIDER(telemetry_provider_handle, "Microsoft.ML.ONNXRuntim #endif // !BUILD_QNN_EP_STATIC_LIB #include "core/providers/qnn/ort_api.h" -#include namespace onnxruntime { namespace qnn { @@ -67,7 +66,7 @@ uint32_t QnnTelemetry::global_register_count_ = 0; bool QnnTelemetry::enabled_ = true; UCHAR QnnTelemetry::level_ = 0; UINT64 QnnTelemetry::keyword_ = 0; -std::unordered_map QnnTelemetry::callbacks_; +std::vector QnnTelemetry::callbacks_; std::mutex QnnTelemetry::callbacks_mutex_; #endif // !BUILD_QNN_EP_STATIC_LIB @@ -158,21 +157,25 @@ void QnnTelemetry::LogQnnProfileEvent(uint64_t timestamp, TraceLoggingString(eventIdentifier, "Event Identifier")); } -void QnnTelemetry::RegisterInternalCallback(const std::string& cb_key, EtwInternalCallback callback) { +void QnnTelemetry::RegisterInternalCallback(const EtwInternalCallback& callback) { #if BUILD_QNN_EP_STATIC_LIB - WindowsTelemetry::RegisterInternalCallback(cb_key, std::move(callback)); + WindowsTelemetry::RegisterInternalCallback(callback); #else std::lock_guard lock_callbacks(callbacks_mutex_); - callbacks_.insert_or_assign(cb_key, std::move(callback)); + callbacks_.push_back(&callback); #endif } -void QnnTelemetry::UnregisterInternalCallback(const std::string& cb_key) { +void QnnTelemetry::UnregisterInternalCallback(const EtwInternalCallback& callback) { #if BUILD_QNN_EP_STATIC_LIB - WindowsTelemetry::UnregisterInternalCallback(cb_key); + WindowsTelemetry::UnregisterInternalCallback(callback); #else std::lock_guard lock_callbacks(callbacks_mutex_); - callbacks_.erase(cb_key); + auto new_end = std::remove_if(callbacks_.begin(), callbacks_.end(), + [&callback](const EtwInternalCallback* ptr) { + return ptr == &callback; + }); + callbacks_.erase(new_end, callbacks_.end()); #endif } @@ -185,12 +188,10 @@ void NTAPI QnnTelemetry::ORT_TL_EtwEnableCallback( _In_ ULONGLONG MatchAllKeyword, _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData, _In_opt_ PVOID CallbackContext) { - { - std::lock_guard lock(provider_change_mutex_); - enabled_ = (IsEnabled != 0); - level_ = Level; - keyword_ = MatchAnyKeyword; - } + std::lock_guard lock(provider_change_mutex_); + enabled_ = (IsEnabled != 0); + level_ = Level; + keyword_ = MatchAnyKeyword; InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); } @@ -199,9 +200,8 @@ void QnnTelemetry::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Leve ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext) { std::lock_guard lock_callbacks(callbacks_mutex_); - for (const auto& entry : callbacks_) { - const auto& cb = entry.second; - cb(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + for (const auto& callback : callbacks_) { + (*callback)(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); } } #endif // !BUILD_QNN_EP_STATIC_LIB diff --git a/onnxruntime/core/providers/qnn/qnn_telemetry.h b/onnxruntime/core/providers/qnn/qnn_telemetry.h index 4d68f14969e9e..a2d42c518c1ac 100644 --- a/onnxruntime/core/providers/qnn/qnn_telemetry.h +++ b/onnxruntime/core/providers/qnn/qnn_telemetry.h @@ -12,7 +12,6 @@ #include #include #include -#include #include #include "core/providers/qnn/ort_api.h" @@ -59,9 +58,9 @@ class QnnTelemetry { ULONGLONG MatchAnyKeyword, ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext)>; - static void RegisterInternalCallback(const std::string& cb_key, EtwInternalCallback callback); + static void RegisterInternalCallback(const EtwInternalCallback& callback); - static void UnregisterInternalCallback(const std::string& cb_key); + static void UnregisterInternalCallback(const EtwInternalCallback& callback); private: QnnTelemetry(); @@ -73,7 +72,7 @@ class QnnTelemetry { static uint32_t global_register_count_; static bool enabled_; - static std::unordered_map callbacks_; + static std::vector callbacks_; static std::mutex callbacks_mutex_; static std::mutex provider_change_mutex_; static UCHAR level_; diff --git a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc index e9343e2b2e06a..312733cb2ba0f 100644 --- a/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc +++ b/onnxruntime/core/providers/rknpu/rknpu_execution_provider.cc @@ -102,8 +102,8 @@ RknpuExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view const auto graph_outputs = graph_viewer.GetOutputs(); // Add initializer to graph_viewer const auto& init_tensors = graph_viewer.GetAllInitializedTensors(); - for (const auto& tensor : init_tensors) { - graph_build.AddInitializedTensor(*(tensor.second)); + for (const auto& [name, _] : init_tensors) { + graph_utils::MakeInitializerCopyIfNotExist(graph_viewer.GetGraph(), graph_build, name); } ORT_ENFORCE(graph_build.Resolve().IsOK()); diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc index 281a6f35a2808..2c593a4adc41b 100644 --- a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc @@ -10,8 +10,8 @@ namespace onnxruntime { bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED || - dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED; + return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE || + dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE; } common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { @@ -34,7 +34,7 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const } else { // copy from other CPU memory to GPU, this is blocking HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); - if (src_device.MemType() != OrtDevice::MemType::HIP_PINNED) { + if (src_device.MemType() != OrtDevice::MemType::HOST_ACCESSIBLE) { // Follow core/providers/cuda/gpu_data_transfer.cc to synchronize the default stream here. HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } @@ -63,19 +63,22 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, if (src_device.Type() == OrtDevice::CPU) { // If source are not pinned, the memory copy will be performed synchronously. // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. - HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); + HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, + static_cast(stream.GetHandle()))); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking if (dst_data != src_data) { - HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); + HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, + static_cast(stream.GetHandle()))); } } } else if (src_device.Type() == OrtDevice::GPU) { // If dest are not pinned, the memory copy will be performed synchronously. // For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. - HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); + HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, + static_cast(stream.GetHandle()))); } else { - if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) { + if (src_device.MemType() == OrtDevice::MemType::HOST_ACCESSIBLE) { // sync the stream first to make sure the data arrived HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); } diff --git a/onnxruntime/core/providers/rocm/rocm_allocator.h b/onnxruntime/core/providers/rocm/rocm_allocator.h index ef13fc2e25cda..ae7982ae6c618 100644 --- a/onnxruntime/core/providers/rocm/rocm_allocator.h +++ b/onnxruntime/core/providers/rocm/rocm_allocator.h @@ -14,8 +14,8 @@ class ROCMAllocator : public IAllocator { ROCMAllocator(OrtDevice::DeviceId device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id), - device_id, OrtMemTypeDefault)) {} + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device_id), + OrtMemTypeDefault)) {} void* Alloc(size_t size) override; void Free(void* p) override; @@ -55,8 +55,9 @@ class ROCMPinnedAllocator : public IAllocator { ROCMPinnedAllocator(const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0 /*CPU device always with id 0*/), - 0, OrtMemTypeCPUOutput)) {} + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, + 0 /*CPU device always with id 0*/), + OrtMemTypeCPUOutput)) {} void* Alloc(size_t size) override; void Free(void* p) override; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 49771488efc44..6fcf23e346b6a 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -257,7 +257,9 @@ void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) { } ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kRocmExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, + : IExecutionProvider{onnxruntime::kRocmExecutionProvider, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, + info.device_id)}, info_{info}, tuning_context_(this, &info_.tunable_op) { HIP_CALL_THROW(hipSetDevice(info_.device_id)); @@ -2524,8 +2526,11 @@ void ROCMExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& } OrtDevice ROCMExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { - if (mem_type == OrtMemTypeCPUInput) return OrtDevice(); - if (mem_type == OrtMemTypeCPUOutput) return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0 /*CPU device id always be 0*/); + if (mem_type == OrtMemTypeCPUInput) + return OrtDevice(); + if (mem_type == OrtMemTypeCPUOutput) + return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, + 0 /*CPU device id always be 0*/); return default_device_; } diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index 4281b5e53c5fd..0e0f559d2e0f1 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -27,7 +27,7 @@ bool GetType(const NodeArg& node_arg, int32_t& type, const logging::Logger& logg namespace { bool GetClipMinMaxImpl(std::function get_const_initializer, - const Node& node, float& min, float& max, const logging::Logger& logger) { + const Graph& graph, const Node& node, float& min, float& max, const logging::Logger& logger) { const auto& node_name = node.Name(); int32_t input_type; if (!GetType(*node.InputDefs()[0], input_type, logger)) { @@ -50,7 +50,7 @@ bool GetClipMinMaxImpl(std::function()[0]; @@ -97,7 +97,7 @@ bool GetClipMinMax(const GraphViewer& graph_viewer, const Node& node, float& min [&graph_viewer](const std::string& name) -> const ONNX_NAMESPACE::TensorProto* { return graph_viewer.GetConstantInitializer(name); }, - node, min, max, logger); + graph_viewer.GetGraph(), node, min, max, logger); } NodeAttrHelper::NodeAttrHelper(const onnxruntime::Node& node) diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 58d4461c7c32a..4d3ae4f4a7e07 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -418,8 +418,32 @@ inline std::unique_ptr MakeComputeCapability(const GraphViewe return g_host->Utils__MakeComputeCapability(graph_viewer, group, generate_metadef_name, execution_provider_name, drop_constant_initializers); } + +inline Status GetTensorProtoWithDataIfInMemory( + const ONNX_NAMESPACE::TensorProto& tensor_proto, std::unique_ptr& result) { + return g_host->Utils__GetTensorProtoWithDataIfInMemory(tensor_proto, result); +} + +inline bool HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto) { + return g_host->Utils__HasExternalDataInMemory(ten_proto); +} + } // namespace utils +namespace graph_utils { +inline NodeArg& AddInitializerWithExternalData(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) { + return g_host->GraphUtils__AddInitializerWithExternalData(graph, new_initializer); +} +inline void MakeInitializerCopyIfNotExist(const Graph& src_graph, Graph& dst_graph, const std::string& name, + bool load_inline = false) { + g_host->GraphUtils__MakeInitializerCopyIfNotExist(src_graph, dst_graph, name, load_inline); +} + +inline Status ConvertInMemoryDataToInline(Graph& graph, const std::string& name) { + return g_host->GraphUtils__ConvertInMemoryDataToInline(graph, name); +} +} // namespace graph_utils + namespace QDQ { inline std::pair>, std::unordered_map> GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) { @@ -436,6 +460,18 @@ inline Env& GetDefaultEnv() { return g_host->Env__Default(); } +template +inline const T* Initializer::data() const { + constexpr const int data_type = static_cast(utils::GetONNXTensorElementDataType()); + return reinterpret_cast(g_host->Initializer__data(*this_ptr_, data_type)); +} + +template +inline T* Initializer::data() { + constexpr const int data_type = static_cast(utils::GetONNXTensorElementDataType()); + return reinterpret_cast(g_host->Initializer__mutable_data(*this_ptr_, data_type)); +} + } // namespace onnxruntime #define CREATE_MESSAGE(logger, severity, category, datatype) \ diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index f20760fcc86fd..f843b86375e78 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -37,9 +37,11 @@ namespace onnxruntime { struct ProviderHost; struct ProviderHostCPU; +class ExternalDataInfo; class PhiloxGenerator; using ProviderType = const std::string&; class RandomGenerator; +class Initializer; class IOnnxRuntimeOpSchemaCollection; struct ModelSavingOptions; @@ -314,11 +316,10 @@ struct ProviderHost { virtual logging::Severity logging__EtwRegistrationManager__MapLevelToSeverity(logging::EtwRegistrationManager* p) = 0; virtual void logging__EtwRegistrationManager__RegisterInternalCallback( logging::EtwRegistrationManager* p, - const std::string& cb_key, - logging::EtwRegistrationManager_EtwInternalCallback callback) = 0; + const logging::EtwRegistrationManager_EtwInternalCallback& callback) = 0; virtual void logging__EtwRegistrationManager__UnregisterInternalCallback( logging::EtwRegistrationManager* p, - const std::string& cb_key) = 0; + const logging::EtwRegistrationManager_EtwInternalCallback& callback) = 0; #endif // defined(_WIN32) // Env @@ -977,6 +978,12 @@ struct ProviderHost { const std::function& generate_metadef_name, const std::string& execution_provider_name, bool drop_constant_initializers) = 0; + + virtual Status Utils__GetTensorProtoWithDataIfInMemory( + const ONNX_NAMESPACE::TensorProto& tensor_proto, std::unique_ptr& result) = 0; + + virtual bool Utils__HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto) = 0; + // Model virtual std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList* local_registries, @@ -1004,6 +1011,8 @@ struct ProviderHost { virtual const std::unordered_map& Graph__DomainToVersionMap(const Graph* p) const noexcept = 0; virtual Status Graph__Resolve(Graph* p) = 0; virtual void Graph__AddInitializedTensor(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor) = 0; + // We pass OrtValue by reference here (as opposed to the original Graph function) to avoid header inclusion + virtual Status Graph__AddInitializedOrtValue(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor, const OrtValue& value) = 0; virtual Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, const NodeAttributes* attributes, const std::string& domain) = 0; virtual Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, NodeAttributes&& attributes, const std::string& domain) = 0; virtual Node& Graph__AddNode(Graph* p, const Node& other) = 0; @@ -1099,6 +1108,37 @@ struct ProviderHost { virtual std::unique_ptr ConstGraphNodes__cend(const ConstGraphNodes* p) = 0; virtual bool ConstGraphNodes__empty(const ConstGraphNodes* p) noexcept = 0; + // graph_util + virtual NodeArg& GraphUtils__AddInitializerWithExternalData(Graph& graph, + const ONNX_NAMESPACE::TensorProto& new_initializer) = 0; + virtual void GraphUtils__MakeInitializerCopyIfNotExist(const Graph& src_graph, Graph& dst_graph, + const std::string& name, bool load_inline) = 0; + + virtual Status GraphUtils__ConvertInMemoryDataToInline(Graph& graph, const std::string& name) = 0; + + // Initializer + virtual Initializer* Initializer__constructor(ONNX_NAMESPACE::TensorProto_DataType data_type, + std::string_view name, + gsl::span dims) = 0; + virtual Initializer* Initializer__constructor(const Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path = {}, + bool check_outer_scope = false) = 0; + + virtual void Initializer__destructor(Initializer*) = 0; + virtual void Initializer__ToProto(const Initializer&, + ONNX_NAMESPACE::TensorProto& tensor_proto) = 0; + virtual void Initializer__ToProtoWithOrtValue(const Initializer&, + ONNX_NAMESPACE::TensorProto& tensor_proto, OrtValue& ort_value) = 0; + virtual int Initializer__data_type(const Initializer&) = 0; + virtual const std::string& Initializer__name(const Initializer&) = 0; + virtual gsl::span Initializer__dims(const Initializer&) = 0; + virtual size_t Initializer__size(const Initializer&) = 0; + // data() template helper + virtual void* Initializer__mutable_data(Initializer&, int data_type) = 0; + virtual const void* Initializer__data(const Initializer&, int data_type) = 0; + virtual void* Initializer__mutable_data_raw(Initializer&) = 0; + virtual const void* Initializer__data_raw(const Initializer&) = 0; + // OpKernel virtual const Node& OpKernel__Node(const OpKernel* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 5fadd0b0966e8..80b5e26db8680 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -58,11 +58,11 @@ struct EtwRegistrationManager final { static EtwRegistrationManager& Instance() { return g_host->logging__EtwRegistrationManager__Instance(); } static bool SupportsETW() { return g_host->logging__EtwRegistrationManager__SupportsETW(); } Severity MapLevelToSeverity() { return g_host->logging__EtwRegistrationManager__MapLevelToSeverity(this); } - void RegisterInternalCallback(const std::string& cb_key, EtwInternalCallback callback) { - g_host->logging__EtwRegistrationManager__RegisterInternalCallback(this, cb_key, std::move(callback)); + void RegisterInternalCallback(const EtwInternalCallback& callback) { + g_host->logging__EtwRegistrationManager__RegisterInternalCallback(this, callback); } - void UnregisterInternalCallback(const std::string& cb_key) { - g_host->logging__EtwRegistrationManager__UnregisterInternalCallback(this, cb_key); + void UnregisterInternalCallback(const EtwInternalCallback& callback) { + g_host->logging__EtwRegistrationManager__UnregisterInternalCallback(this, callback); } }; #endif // defined(_WIN32) @@ -1038,6 +1038,9 @@ struct Graph final { const std::unordered_map& DomainToVersionMap() const noexcept { return g_host->Graph__DomainToVersionMap(this); } Status Resolve() { return g_host->Graph__Resolve(this); } void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto& tensor) { return g_host->Graph__AddInitializedTensor(this, tensor); } + Status AddInitializedOrtValue(const ONNX_NAMESPACE::TensorProto& tensor, const OrtValue& ort_value) { + return g_host->Graph__AddInitializedOrtValue(this, tensor, ort_value); + } Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, gsl::span input_args, gsl::span output_args, const NodeAttributes* attributes, const std::string& domain) { return g_host->Graph__AddNode(this, name, op_type, description, input_args, output_args, attributes, domain); } Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, gsl::span input_args, gsl::span output_args, NodeAttributes&& attributes, const std::string& domain) { return g_host->Graph__AddNode(this, name, op_type, description, input_args, output_args, std::move(attributes), domain); } Node& AddNode(const Node& other) { return g_host->Graph__AddNode(this, other); } @@ -1177,6 +1180,69 @@ struct ConstGraphNodes final { PROVIDER_DISALLOW_ALL(ConstGraphNodes) }; +class Initializer { + public: + Initializer(ONNX_NAMESPACE::TensorProto_DataType data_type, + std::string_view name, + gsl::span dims) { + this_ptr_ = g_host->Initializer__constructor(data_type, name, dims); + } + + Initializer(const Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path = {}, + bool check_outer_scope = false) { + this_ptr_ = g_host->Initializer__constructor(graph, tensor_proto, model_path, check_outer_scope); + } + + ~Initializer() { + g_host->Initializer__destructor(this_ptr_); + } + + PROVIDER_DISALLOW_ALL(Initializer); + + void ToProto(ONNX_NAMESPACE::TensorProto& tensor_proto) const { + g_host->Initializer__ToProto(*this_ptr_, tensor_proto); + } + + void ToProtoWithOrtValue(ONNX_NAMESPACE::TensorProto& tensor_proto, OrtValue& ort_value) const { + g_host->Initializer__ToProtoWithOrtValue(*this_ptr_, tensor_proto, ort_value); + } + + int data_type() const { + return g_host->Initializer__data_type(*this_ptr_); + } + + const std::string& name() const { + return g_host->Initializer__name(*this_ptr_); + } + + gsl::span dims() const { + return g_host->Initializer__dims(*this_ptr_); + } + + size_t size() const { + return g_host->Initializer__size(*this_ptr_); + } + + // See definition for the below templates in provider_api.h + template + const T* data() const; + + template + T* data(); + + const void* data_raw() const { + return g_host->Initializer__data_raw(*this_ptr_); + } + + void* mutable_data_raw() { + return g_host->Initializer__mutable_data_raw(*this_ptr_); + } + + private: + Initializer* this_ptr_; +}; + struct OpKernelContext final { template const T& RequiredInput(int index) const; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index fc8281ce51a1b..32d93c305273d 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -46,7 +46,8 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::logging; namespace { // Check if cycle exists in the graph after partitioning -bool FindCycleHelper(size_t i, const std::list* adjacency_map, bool visited[], bool* st, std::vector& cycles) { +bool FindCycleHelper(size_t i, gsl::span> adjacency_map, gsl::span visited, gsl::span st, + InlinedVector& cycles) { if (!visited[i]) { visited[i] = true; st[i] = true; @@ -263,7 +264,6 @@ struct ShutdownProtobuf { } g_protobuf; namespace onnxruntime { - namespace cuda { template <> void Impl_Cast( @@ -1332,7 +1332,7 @@ std::vector ParseTrtPreviewFeatures(const std::string& TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, narrow(info.device_id))}, info_(info), device_id_(info.device_id) { @@ -2204,28 +2204,22 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect // subgraph's output list std::vector subgraph_output_names; for (const auto& index : group.first) { + // Initializers that refer to a memory location in OrtValue + // can not be handled by TRT (unlike those that are on disk). + // This prevents us from sharing the data and we have to make a copy here. + constexpr const bool load_initializers_inline_true = true; const auto& node = graph.GetNode(node_index[index]); std::vector inputs, outputs; for (auto input : node->InputDefs()) { auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); inputs.push_back(&n_input); - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph.GetInitializedTensor(input->Name(), initializer)) { - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) { - graph_build.AddInitializedTensor(*(initializer)); - } - } + graph_utils::MakeInitializerCopyIfNotExist(graph.GetGraph(), graph_build, input->Name(), + load_initializers_inline_true); } for (auto input : node->ImplicitInputDefs()) { - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph.GetInitializedTensor(input->Name(), initializer)) { - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!graph_build.GetInitializedTensor(input->Name(), subgraph_initializer)) { - graph_build.AddInitializedTensor(*(initializer)); - } - } + graph_utils::MakeInitializerCopyIfNotExist(graph.GetGraph(), graph_build, input->Name(), + load_initializers_inline_true); } for (auto output : node->OutputDefs()) { auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); @@ -2471,7 +2465,7 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& // Create adjacency list size_t graph_size = node_to_index_map.size(); - std::list* adjacency_map = new std::list[graph_size]; + std::vector> adjacency_map(graph_size); for (const auto& node : node_to_outputs_map) { for (auto iter = node.second.begin(); iter != node.second.end(); ++iter) { const auto& loc = input_to_nodes_map.find(*iter); @@ -2486,14 +2480,14 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& } // Check cycle in the graph - bool* visited = new bool[graph_size]; - bool* st = new bool[graph_size]; + InlinedVector visited(graph_size); + InlinedVector st(graph_size); for (size_t i = 0; i < graph_size; ++i) { visited[i] = false; st[i] = false; } - std::vector cycles; + InlinedVector cycles; bool has_cycle = false; for (size_t i = 0; i < graph_size; ++i) { if (FindCycleHelper(i, adjacency_map, visited, st, cycles)) { @@ -2514,10 +2508,6 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t& } } } - - delete[] adjacency_map; - delete[] visited; - delete[] st; } return cycle_detected; } @@ -3638,8 +3628,9 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView bool context_update = false; std::unordered_set input_names; - OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow(device_id_)); - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_); + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, + narrow(device_id_)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device); if (alloc_ == nullptr) { Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); } @@ -4315,8 +4306,9 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input - OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow(device_id_)); - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_); + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, + narrow(device_id_)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device); if (alloc_ == nullptr) { Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); } @@ -4522,8 +4514,11 @@ void TensorrtExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegis } OrtDevice TensorrtExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { - if (mem_type == OrtMemTypeCPUInput) return OrtDevice(); - if (mem_type == OrtMemTypeCPUOutput) return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0 /*CPU device id always be 0*/); + if (mem_type == OrtMemTypeCPUInput) + return OrtDevice(); + if (mem_type == OrtMemTypeCPUOutput) + return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, + 0 /*CPU device id always be 0*/); return default_device_; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index b00c800999f3b..7e02cf7590f66 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -540,7 +540,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { * and save those information in subgraph context data structure. It's useful for building a valid graph and * make Graph::Resolve() happy especially when dealing with nested control-flow op graph. */ - void BuildSubGraphContext(const Graph& build_graph) const; + void BuildSubGraphContext(Graph& build_graph) const; /** * Set outer scope values for subgraphs and add thoes values as top-level graph's inputs if needed. diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc index b99cb4f52ed59..c123a7d8d4590 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_helper.cc @@ -71,7 +71,7 @@ bool TensorrtExecutionProvider::IsLocalValue(const Graph& graph, * and save those information in subgraph context data structure. It's useful for building a valid graph and * make Graph::Resolve() happy especially when dealing with nested control-flow op graph. */ -void TensorrtExecutionProvider::BuildSubGraphContext(const Graph& graph) const { +void TensorrtExecutionProvider::BuildSubGraphContext(Graph& graph) const { // Iterate all the nodes and recurse into inner most subgraph first for (int i = 0; i < graph.MaxNodeIndex(); ++i) { auto node = graph.GetNode(i); @@ -79,9 +79,9 @@ void TensorrtExecutionProvider::BuildSubGraphContext(const Graph& graph) const { continue; } - auto subgraph_map = node->GetAttributeNameToSubgraphMap(); + auto& subgraph_map = node->GetAttributeNameToMutableSubgraphMap(); for (auto& entry : subgraph_map) { - const Graph* subgraph = entry.second; + Graph* subgraph = entry.second; BuildSubGraphContext(*subgraph); } } @@ -121,6 +121,7 @@ void TensorrtExecutionProvider::BuildSubGraphContext(const Graph& graph) const { } // This input arg is not the output of another node so must come from either a graph input or an initializer. context->inputs_and_initializers[input->Name()] = input; + ORT_THROW_IF_ERROR(graph_utils::ConvertInMemoryDataToInline(graph, input->Name())); } } } diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc index 88bdbaed40c73..20ae1cfbfa2c1 100644 --- a/onnxruntime/core/providers/vitisai/imp/graph.cc +++ b/onnxruntime/core/providers/vitisai/imp/graph.cc @@ -177,7 +177,9 @@ void graph_save(const Graph& graph, const std::string& filename, const std::stri auto graph_proto_subgraph = graph.ToGraphProto(); *model_proto->mutable_graph() = *graph_proto_subgraph; auto& logger = logging::LoggingManager::DefaultLogger(); - auto model = Model::Create(std::move(*model_proto), ToPathString(filename), nullptr, logger); + // Reading initializer data from an external data file below will access the data file based on the directory of the model_path + // parameter. Thus, the path to the original model must be used here to make reading initializer data from an external file work. + auto model = Model::Create(std::move(*model_proto), graph.ModelPath(), nullptr, logger); auto status = model->MainGraph().Resolve(); vai_assert(status.IsOK(), "graph resolve error:" + status.ErrorMessage()); if (initializer_size_threshold == std::numeric_limits::max()) { diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 1d812779da265..b672aa8bb35be 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -27,7 +27,7 @@ constexpr const char* VITISAI = "VITISAI"; VitisAIExecutionProvider::VitisAIExecutionProvider( const ProviderOptions& info) : IExecutionProvider{onnxruntime::kVitisAIExecutionProvider, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, DEFAULT_CPU_ALLOCATOR_DEVICE_ID, kAlloc4KAlignment)}, info_(info) { @@ -155,7 +155,7 @@ std::vector VitisAIExecutionProvider::CreatePreferredAllocators() return std::make_unique( OrtMemoryInfo( onnxruntime::CPU_ALIGNED_4K, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, device_id, kAlloc4KAlignment))); }, diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h index 19cbe4e7f3e48..f4bf2be17ee56 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h @@ -72,7 +72,7 @@ class PadOpBuilder : public BaseOpBuilder { return false; } - Initializer unpacked_tensor(*pads_initializer); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *pads_initializer); auto tensor_data = unpacked_tensor.DataAsSpan(); for (size_t i = 0; i < unpacked_tensor.size(); i++) { if (tensor_data[i] < 0) { diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h index e08416bda70d4..b58d272d011b1 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h @@ -63,7 +63,7 @@ class SplitOpBuilder : public BaseOpBuilder { LOGS_DEFAULT(WARNING) << "Optional input 'split' must be a constant initializer if provided."; return false; } - Initializer unpacked_tensor(*splits); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *splits); auto split_sizes_ = unpacked_tensor.DataAsSpan(); splits_list.assign(split_sizes_.begin(), split_sizes_.end()); split_provided = true; diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc index 3b70ab3c9241b..e21e479bc29b2 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc @@ -41,7 +41,7 @@ namespace onnxruntime { VSINPUExecutionProvider::VSINPUExecutionProvider(const VSINPUExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kVSINPUExecutionProvider, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, - DEFAULT_CPU_ALLOCATOR_DEVICE_ID, + DEFAULT_CPU_ALLOCATOR_DEVICE_ID, OrtDevice::VendorIds::NONE, kAlloc4KAlignment)}, device_id_(info.device_id) { } @@ -279,9 +279,8 @@ std::vector VSINPUExecutionProvider::CreatePreferredAllocators() { return std::make_unique( OrtMemoryInfo( onnxruntime::CPU_ALIGNED_4K, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, - device_id, - kAlloc4KAlignment))); + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, + device_id, kAlloc4KAlignment))); }, device_id_, use_arena_false}; diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_util.cc b/onnxruntime/core/providers/vsinpu/vsinpu_util.cc index 5d2f701ceac20..5034cb5a3525c 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_util.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_util.cc @@ -425,7 +425,7 @@ void GetQuantizationScaleAndZeroPoint( if (!s) { LOGS_DEFAULT(ERROR) << name + " is not a constant initializer"; }; - Initializer unpacked_tensor(*s, model_path); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *s, model_path); scale = unpacked_tensor.DataAsSpan()[0]; // per channel quantized handling @@ -442,7 +442,7 @@ void GetQuantizationScaleAndZeroPoint( if (!s) { LOGS_DEFAULT(ERROR) << name + " is not a constant initializer"; }; - Initializer unpacked_tensor(*s, model_path); + Initializer unpacked_tensor(graph_viewer.GetGraph(), *s, model_path); bool is_i8_zp = unpacked_tensor.data_type() == onnx::TensorProto_DataType_INT8; // some qdq conv bias is int32 quantized bool is_int32_zp = unpacked_tensor.data_type() == onnx::TensorProto_DataType_INT32; diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h index d98661884c659..0b27f713777bc 100644 --- a/onnxruntime/core/providers/webgpu/allocator.h +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -16,8 +16,8 @@ class GpuBufferAllocator : public IAllocator { GpuBufferAllocator(const WebGpuContext& context) : IAllocator( OrtMemoryInfo(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), - 0, OrtMemTypeDefault)), + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0), + OrtMemTypeDefault)), context_{context} { } diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index f2569fce6b5eb..7a9cf1ecf85ba 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -40,6 +40,11 @@ class ComputeContext { inline bool HasFeature(wgpu::FeatureName feature) const { return webgpu_context_.DeviceHasFeature(feature); } +#if !defined(__wasm__) + inline const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs() const { + return webgpu_context_.SubgroupMatrixConfigs(); + } +#endif // // Get the kernel context. diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 3c4a802e4bcde..380d11fd4ab85 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -289,6 +289,14 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { round_str = "round"; } + std::string use_sqrt_for_pow; + if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT || lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { + // use sqrt instead of pow when base (a) is a positive float and exponent (b) is 0.5 + use_sqrt_for_pow = + " else if (a >= input_a_element_t(0.0) && b == 0.5) {\n" + " return sqrt(a);\n" + " }\n"; + } s << "fn pow_custom(a : input_a_element_t, b : f32) -> input_a_element_t {\n" " if (b == 0.0) {\n" @@ -296,6 +304,7 @@ std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { " } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n" " return input_a_element_t(pow(f32(a), b)); // NaN\n" " }\n" + << use_sqrt_for_pow << " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n" << "}\n" "fn pow_v(a : vec4, b : vec4) -> vec4 {\n" diff --git a/onnxruntime/core/providers/webgpu/nn/instance_norm.cc b/onnxruntime/core/providers/webgpu/nn/instance_norm.cc index 7b39980f85605..bc3fe538346ca 100644 --- a/onnxruntime/core/providers/webgpu/nn/instance_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/instance_norm.cc @@ -25,11 +25,11 @@ Status ComputeChannelScaleShiftProgram::GenerateShaderCode(ShaderHelper& shader) shader.MainFunctionBody() << " let batch = workgroup_idx / uniforms.x_shape[1];\n" << " let channel = workgroup_idx % uniforms.x_shape[1];\n" - << " let hight = uniforms.x_shape[2];\n" - << " // initialize workgroup memory<< \n" + << " let height = uniforms.x_shape[2];\n" + << " // initialize workgroup memory\n" << " var sum = f32_val_t(0);\n" << " var squared_sum = f32_val_t(0);\n" - << " for (var h = local_idx; h < hight; h += workgroup_size) {\n" + << " for (var h = local_idx; h < height; h += workgroup_size) {\n" << " let indices = x_indices_t(batch, channel, h);\n" << " let value = f32_val_t(" << input.GetByIndices("indices") << ");\n" << " sum += value;\n" @@ -46,8 +46,8 @@ Status ComputeChannelScaleShiftProgram::GenerateShaderCode(ShaderHelper& shader) << " workgroupBarrier();\n" << " }\n" << " if (local_idx == 0) {\n" - << " let sum_final = " << SumVector("workgroup_shared_sum[0]", components_) << " / f32(hight * " << components_ << ");\n" - << " let squared_sum_final = " << SumVector("workgroup_shared_squared_sum[0]", components_) << " / f32(hight * " << components_ << ");\n" + << " let sum_final = " << SumVector("workgroup_shared_sum[0]", components_) << " / f32(height * " << components_ << ");\n" + << " let squared_sum_final = " << SumVector("workgroup_shared_squared_sum[0]", components_) << " / f32(height * " << components_ << ");\n" << " let inv_std_dev = inverseSqrt(squared_sum_final - sum_final * sum_final + f32(" << std::to_string(epsilon_) << "));\n" << " let channel_scale = inv_std_dev * f32(" << scale.GetByOffset("channel") << ");\n" << " let channel_shift = f32(" << bias.GetByOffset("channel") << ") - sum_final * channel_scale;\n" @@ -194,17 +194,19 @@ Status InstanceNorm::ComputeInternal(ComputeContext& context) const { const auto spatial_size = input->Shape().SizeFromDimension(2); Tensor channel_scale_shift; ORT_RETURN_IF_ERROR(ComputeChannelScaleAndShift(context, input, scale, bias, epsilon_, &channel_scale_shift)); - const auto output_shape(input_shape_vector); + TensorShape output_shape(input_shape_vector); Tensor* output = context.Output(0, output_shape); const auto components = GetMaxComponents(spatial_size); TensorShapeVector modified_input_shape_vector = {batch_size, channels, spatial_size / components}; TensorShape modified_input_shape(modified_input_shape_vector); TensorShape modified_output_shape(modified_input_shape_vector); - auto output_size = (modified_output_shape.Size() + components - 1) / components; + auto output_size = modified_output_shape.Size(); + TensorShapeVector channel_scale_shift_shape_vector = {batch_size, channels, 1}; + TensorShape reduced_channel_scale_shift_shape(channel_scale_shift_shape_vector); InstanceNormProgram program; program .AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank, modified_input_shape, components}, - {&channel_scale_shift, ProgramTensorMetadataDependency::TypeAndRank, channel_scale_shift.Shape(), 2}}) + {&channel_scale_shift, ProgramTensorMetadataDependency::TypeAndRank, reduced_channel_scale_shift_shape, 2}}) .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, modified_output_shape, components}) .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .AddUniformVariables({static_cast(output_size)}); diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 36f6b512a0a93..a22d21d8d798b 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -379,6 +379,10 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha #if !defined(__wasm__) if (device_.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { ss << "enable chromium_experimental_subgroup_matrix;\n"; + + // Dawn enforces the subgroup matrix builtin arguments to be uniform in change https://dawn-review.googlesource.com/c/dawn/+/236054 + // Since we use `subgroup_id` as the subgroup matrix builtin argument, we have to turn off this restriction + ss << "diagnostic (off, chromium.subgroup_matrix_uniformity);\n"; } #endif diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 27380645baf57..4bb41c2eb0ba6 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -140,8 +140,6 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi // cache device queue device_queue_ = device_.GetQueue(); - // cache adapter info - ORT_ENFORCE(Device().GetAdapterInfo(&adapter_info_)); // cache device limits ORT_ENFORCE(Device().GetLimits(&device_limits_)); // cache device features @@ -150,6 +148,13 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi for (size_t i = 0; i < supported_features.featureCount; i++) { device_features_.insert(supported_features.features[i]); } + // cache adapter info +#if !defined(__wasm__) + if (DeviceHasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + adapter_info_.nextInChain = &subgroup_matrix_configs_; + } +#endif + ORT_ENFORCE(Device().GetAdapterInfo(&adapter_info_)); // create buffer manager buffer_mgr_ = BufferManagerFactory::Create(*this, @@ -452,6 +457,9 @@ std::vector WebGpuContext::GetEnabledAdapterToggles() const { constexpr const char* toggles[] = { "use_dxc", "allow_unsafe_apis", +#if defined(DAWN_ENABLE_VULKAN) + "use_vulkan_memory_model", +#endif }; return std::vector(std::begin(toggles), std::end(toggles)); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 68005d8afec16..4111f809b1627 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -81,6 +81,9 @@ class WebGpuContext final { const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; } const wgpu::Limits& DeviceLimits() const { return device_limits_; } bool DeviceHasFeature(wgpu::FeatureName feature) const { return device_features_.find(feature) != device_features_.end(); } +#if !defined(__wasm__) + const wgpu::AdapterPropertiesSubgroupMatrixConfigs& SubgroupMatrixConfigs() const { return subgroup_matrix_configs_; } +#endif const wgpu::CommandEncoder& GetCommandEncoder() { if (!current_command_encoder_) { @@ -214,6 +217,9 @@ class WebGpuContext final { wgpu::AdapterInfo adapter_info_; wgpu::Limits device_limits_; std::unordered_set device_features_; +#if !defined(__wasm__) + wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs_; +#endif wgpu::CommandEncoder current_command_encoder_; wgpu::ComputePassEncoder current_compute_pass_encoder_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 9ea79e4cf28a3..46e0347f0c0fd 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -766,7 +766,8 @@ using namespace webgpu; WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, WebGpuContext& context, WebGpuExecutionProviderConfig&& config) - : IExecutionProvider{kWebGpuExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)}, + : IExecutionProvider{kWebGpuExecutionProvider, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0)}, context_id_{context_id}, context_{context}, preferred_data_layout_{config.data_layout}, diff --git a/onnxruntime/core/providers/webnn/allocator.h b/onnxruntime/core/providers/webnn/allocator.h index c06da909801cc..484c30dc42e50 100644 --- a/onnxruntime/core/providers/webnn/allocator.h +++ b/onnxruntime/core/providers/webnn/allocator.h @@ -15,7 +15,10 @@ namespace webnn { class WebNNTensorAllocator : public IAllocator { public: - WebNNTensorAllocator() : IAllocator(OrtMemoryInfo(WEBNN_TENSOR, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), 0, OrtMemTypeDefault)) {} + WebNNTensorAllocator() + : IAllocator(OrtMemoryInfo(WEBNN_TENSOR, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0), + OrtMemTypeDefault)) {} void* Alloc(size_t size) override; diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index cec218ea94e58..d59788600f997 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -15,6 +15,7 @@ #include #include +using onnxruntime::common::Status; namespace onnxruntime { class GraphViewer; @@ -92,14 +93,33 @@ inline std::vector GetNarrowedIntfromInt64(gsl::span int64_vec return vec; } +bool inline UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer, + std::vector& unpacked_tensor, + const GraphViewer& graph_viewer, + const logging::Logger& logger) { + Status status = Status::OK(); + if (utils::HasExternalData(initializer)) { + status = onnxruntime::utils::UnpackInitializerData(initializer, graph_viewer.ModelPath(), unpacked_tensor); + } else { + status = onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor); + } + + if (!status.IsOK()) { + LOGS(logger, ERROR) << "Error while unpacking initializer data: " << status.ErrorMessage(); + return false; + } + + return true; +} + template -bool ReadIntArrayFrom1DTensor(const onnx::TensorProto& tensor, std::vector& array, const logging::Logger& logger) { +bool ReadIntArrayFrom1DTensor(const onnx::TensorProto& tensor, std::vector& array, + const GraphViewer& graph_viewer, const logging::Logger& logger) { std::vector unpacked_tensor; - auto status = onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor); - if (!status.IsOK()) { - LOGS(logger, ERROR) << "Error while unpacking shape: " << status.ErrorMessage(); + if (!UnpackInitializerData(tensor, unpacked_tensor, graph_viewer, logger)) { return false; } + const auto& dims = tensor.dims(); if (dims.size() != 1) { LOGS(logger, VERBOSE) << "The tensor must be 1D."; @@ -130,13 +150,13 @@ bool ReadIntArrayFrom1DTensor(const onnx::TensorProto& tensor, std::vector& a return true; } -inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::val& scalar, const logging::Logger& logger) { +inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::val& scalar, + const GraphViewer& graph_viewer, const logging::Logger& logger) { std::vector unpacked_tensor; - auto status = onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor); - if (!status.IsOK()) { - LOGS(logger, ERROR) << "Error while unpacking tensor: " << status.ErrorMessage(); + if (!UnpackInitializerData(tensor, unpacked_tensor, graph_viewer, logger)) { return false; } + switch (tensor.data_type()) { case ONNX_NAMESPACE::TensorProto_DataType_BOOL: case ONNX_NAMESPACE::TensorProto_DataType_UINT8: @@ -203,11 +223,11 @@ std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewe const emscripten::val& wnn_limits, const logging::Logger& logger); -// Retrieve the first input name of a WebNN op used for validating supported input data types. +// Retrieve the first input name of an ONNX op's corresponding WebNN op used for validating supported input data types. // WebNN ops have various first input names such as 'a', 'input', 'inputs', etc. // All WebNN op inputs are recorded in op_inputs_map. -inline std::string_view GetWebNNOpFirstInputName(const std::string_view webnn_op_type) { - auto it = op_inputs_map.find(webnn_op_type); +inline std::string_view GetWebNNOpFirstInputName(const std::string_view onnx_op_type) { + auto it = op_inputs_map.find(onnx_op_type); if (it != op_inputs_map.end()) { for (const auto& input : it->second.inputs) { if (input.index == 0) { @@ -218,9 +238,9 @@ inline std::string_view GetWebNNOpFirstInputName(const std::string_view webnn_op return "input"; } -inline std::string_view GetWebNNOpType(const std::string_view op_type) { - auto it = op_inputs_map.find(op_type); - // Return an empty string if the op_type is not listed in the op_inputs_map. +inline std::string_view GetWebNNOpType(const std::string_view onnx_op_type) { + auto it = op_inputs_map.find(onnx_op_type); + // Return an empty string if the onnx_op_type is not listed in the op_inputs_map. return (it != op_inputs_map.end()) ? it->second.opType : ""; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index d5683454c89b7..b0ec006db6986 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -66,7 +66,7 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, if (webnn_op_type.empty()) return false; - const std::string_view webnn_input_name = GetWebNNOpFirstInputName(webnn_op_type); + const std::string_view webnn_input_name = GetWebNNOpFirstInputName(op_type); return IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, input_type, wnn_limits, webnn_input_name, "input", logger); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc index 14324415b3659..7528d9ad2ff51 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc @@ -50,7 +50,8 @@ Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const const std::string axis_name = GetTensorName(input_defs, 1); const auto axis_tensor = *initializers.at(axis_name); emscripten::val axis = emscripten::val::undefined(); - ORT_RETURN_IF_NOT(ReadScalarTensorData(axis_tensor, axis, logger), "Cannot get axis value"); + ORT_RETURN_IF_NOT(ReadScalarTensorData(axis_tensor, axis, model_builder.GetGraphViewer(), logger), + "Cannot get axis value"); int64_t webnn_axis = HandleNegativeAxis(axis.as(), input_rank); NodeAttrHelper helper(node); diff --git a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc index 2c28786b788f9..e9c03420de445 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc @@ -44,7 +44,8 @@ Status ExpandOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& initializers(model_builder.GetInitializerTensors()); const auto& shape_tensor = *initializers.at(input_defs[1]->Name()); std::vector new_shape; - ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(shape_tensor, new_shape, logger), "Cannot get shape."); + ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(shape_tensor, new_shape, model_builder.GetGraphViewer(), logger), + "Cannot get shape."); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input's shape."); @@ -84,8 +85,7 @@ bool ExpandOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, } std::vector new_shape; - if (!ReadIntArrayFrom1DTensor(shape_tensor, new_shape, logger)) { - LOGS(logger, VERBOSE) << "Cannot get shape."; + if (!ReadIntArrayFrom1DTensor(shape_tensor, new_shape, graph_viewer, logger)) { return false; } if (std::any_of(new_shape.begin(), new_shape.end(), [](int64_t dimension) { return dimension == 0; })) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc index 3d21f3c0d42a5..75ce80462544e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc @@ -392,7 +392,7 @@ bool GroupQueryAttentionOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_vi const auto total_sequence_length_tensor = *total_sequence_length_initializer; emscripten::val total_sequence_length = emscripten::val::undefined(); - if (!ReadScalarTensorData(total_sequence_length_tensor, total_sequence_length, logger)) { + if (!ReadScalarTensorData(total_sequence_length_tensor, total_sequence_length, graph_viewer, logger)) { return false; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index 68abc8ce834f9..dfe80dd419092 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -143,8 +143,7 @@ bool GruOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const Node const auto& sequence_lens_tensor = *seq_initializer; std::vector sequence_lens; - if (!ReadIntArrayFrom1DTensor(sequence_lens_tensor, sequence_lens, logger)) { - LOGS(logger, ERROR) << "Cannot read sequence lens tensor"; + if (!ReadIntArrayFrom1DTensor(sequence_lens_tensor, sequence_lens, graph_viewer, logger)) { return false; } if (!std::all_of(sequence_lens.begin(), sequence_lens.end(), diff --git a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc index 2e5d3d6b5228a..8936bda875aef 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lrn_op_builder.cc @@ -156,9 +156,10 @@ bool LRNOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, } // Check if the input data type is supported by each decomposed WebNN op. - // Decomposed ops include: "add", "averagePool2d", "div", "mul", "pad", "pow" and "transpose". - for (const std::string_view webnn_op_type : decomposed_op_map.at(op_type)) { - const std::string_view webnn_input_name = GetWebNNOpFirstInputName(webnn_op_type); + // Decomposed ops include: "Add", "AveragePool", "Div", "Mul", "Pad", "Pow" and "Transpose". + for (const std::string_view decomposed_op_type : decomposed_op_map.at(op_type)) { + const std::string_view webnn_op_type = GetWebNNOpType(decomposed_op_type); + const std::string_view webnn_input_name = GetWebNNOpFirstInputName(decomposed_op_type); if (!IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, input_type, wnn_limits, webnn_input_name, "X", logger)) { return false; } @@ -178,7 +179,8 @@ bool LRNOpBuilder::HasSupportedOutputsImpl(const Node& node, } // Check if the output data type is supported by every decomposed WebNN op. - for (const std::string_view webnn_op_type : decomposed_op_map.at(op_type)) { + for (const std::string_view decomposed_op_type : decomposed_op_map.at(op_type)) { + const std::string_view webnn_op_type = GetWebNNOpType(decomposed_op_type); if (!IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, output_type, wnn_limits, "output", "Y", logger)) { return false; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc index 7b5757ecc0faa..09e584bc66f8a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -149,8 +149,7 @@ bool LstmOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const Nod const auto& sequence_lens_tensor = *sequence_lens_init; std::vector sequence_lens; - if (!ReadIntArrayFrom1DTensor(sequence_lens_tensor, sequence_lens, logger)) { - LOGS(logger, ERROR) << "Cannot read sequence lens tensor"; + if (!ReadIntArrayFrom1DTensor(sequence_lens_tensor, sequence_lens, graph_viewer, logger)) { return false; } if (std::any_of(sequence_lens.begin(), sequence_lens.end(), diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index ea382150db315..148eacac98e4a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -350,9 +350,10 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const No if (op_type == "SimplifiedLayerNormalization" || op_type == "SkipSimplifiedLayerNormalization") { // SkipSimplifiedLayerNormalization and SimplifiedLayerNormalization are supported by decomposed WebNN ops. // Check if the input data type is supported by each decomposed WebNN op. - // Decomposed ops include: "add", "div", "mul", "pow", "reduceMean" and "sqrt". - for (const std::string_view webnn_op_type : decomposed_op_map.at(op_type)) { - const std::string_view webnn_input_name = GetWebNNOpFirstInputName(webnn_op_type); + // Decomposed ops include: "Add", "Div", "Mul", "Pow", "ReduceMean" and "Sqrt". + for (const std::string_view decomposed_op_type : decomposed_op_map.at(op_type)) { + const std::string_view webnn_op_type = GetWebNNOpType(decomposed_op_type); + const std::string_view webnn_input_name = GetWebNNOpFirstInputName(decomposed_op_type); if (!IsDataTypeSupportedByWebNNOp( op_type, webnn_op_type, input0_type, wnn_limits, webnn_input_name, "input", logger)) { return false; @@ -376,7 +377,8 @@ bool NormalizationOpBuilder::HasSupportedOutputsImpl(const Node& node, if (op_type == "SimplifiedLayerNormalization" || op_type == "SkipSimplifiedLayerNormalization") { // Check if the output data type is supported by every decomposed WebNN op. - for (const std::string_view webnn_op_type : decomposed_op_map.at(op_type)) { + for (const std::string_view decomposed_op_type : decomposed_op_map.at(op_type)) { + const std::string_view webnn_op_type = GetWebNNOpType(decomposed_op_type); if (!IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, output_type, wnn_limits, "output", "output", logger)) { return false; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc index f17d87d41f9ae..0a0a0d8dc93c2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc @@ -83,16 +83,18 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto opset = node.SinceVersion(); // From opset 11, pads, constant value and axes are inputs. if (opset >= 11) { + const auto& graph_viewer = model_builder.GetGraphViewer(); ORT_RETURN_IF(input_defs.size() < 2, "Pads is required at opset ", opset); std::vector pads; const auto& pads_tensor = *initializers.at(input_defs[1]->Name()); - ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(pads_tensor, pads, logger), "Error while read pads tensor"); + ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(pads_tensor, pads, graph_viewer, logger), + "Error while reading pads tensor"); // Constant value and axes are optional. Make sure they are not empty. if (!GetTensorName(input_defs, 2).empty()) { const auto value_tensor = *initializers.at(input_defs[2]->Name()); emscripten::val value = emscripten::val::object(); - ORT_RETURN_IF_NOT(ReadScalarTensorData(value_tensor, value, logger), "Cannot read constant value"); + ORT_RETURN_IF_NOT(ReadScalarTensorData(value_tensor, value, graph_viewer, logger), "Cannot read constant value"); options.set("value", value); } @@ -100,7 +102,8 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto input_rank = input_shape.size(); std::vector axes; const auto& axes_tensor = *initializers.at(input_defs[3]->Name()); - ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(axes_tensor, axes, logger), "Error while read axes tensor"); + ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(axes_tensor, axes, graph_viewer, logger), + "Error while reading axes tensor"); std::vector axes_index; std::transform( axes.begin(), axes.end(), std::back_inserter(axes_index), diff --git a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc index da5e034c38c8e..8cbb381e0f53e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc @@ -93,9 +93,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const auto& perm_tensor = *perm_init; std::vector unpacked_tensor; - auto status = onnxruntime::utils::UnpackInitializerData(perm_tensor, unpacked_tensor); - if (!status.IsOK()) { - LOGS(logger, ERROR) << "Error while unpacking perm_tensor: " << status.ErrorMessage(); + if (!UnpackInitializerData(perm_tensor, unpacked_tensor, graph_viewer, logger)) { return false; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index f71ec2f98d112..ca5fb5150aa5b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -69,11 +69,9 @@ bool GetResizeScalesAndAxes(const GraphViewer& graph_viewer, } std::vector unpacked_tensor; - auto status = onnxruntime::utils::UnpackInitializerData(scales_tensor, unpacked_tensor); - if (!status.IsOK()) { - LOGS(logger, ERROR) << "Error while unpacking scales_tensor: " << status.ErrorMessage(); + if (!UnpackInitializerData(scales_tensor, unpacked_tensor, graph_viewer, logger)) { return false; - } + }; const float* scales_data = reinterpret_cast(unpacked_tensor.data()); if (has_axes) { @@ -137,9 +135,7 @@ bool GetResizeSizesAndAxes(const GraphViewer& graph_viewer, } std::vector unpacked_tensor; - auto status = onnxruntime::utils::UnpackInitializerData(sizes_tensor, unpacked_tensor); - if (!status.IsOK()) { - LOGS(logger, ERROR) << "Error while unpacking sizes_tensor: " << status.ErrorMessage(); + if (!UnpackInitializerData(sizes_tensor, unpacked_tensor, graph_viewer, logger)) { return false; } const int64_t* sizes_data = reinterpret_cast(unpacked_tensor.data()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc index ad22758028f2c..893ca9d2419c7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc @@ -436,9 +436,10 @@ bool RotaryEmbeddingOpBuilder::HasSupportedInputsImpl(const GraphViewer&, } // Check if the input data type is supported by each decomposed WebNN op. - // Decomposed ops include: "add", "concat", "gather", "mul", "reshape" and "split". - for (const std::string_view webnn_op_type : decomposed_op_map.at(op_type)) { - const std::string_view webnn_input_name = GetWebNNOpFirstInputName(webnn_op_type); + // Decomposed ops include: "Add", "Concat", "Gather", "Mul", "Reshape" and "Split". + for (const std::string_view decomposed_op_type : decomposed_op_map.at(op_type)) { + const std::string_view webnn_op_type = GetWebNNOpType(decomposed_op_type); + const std::string_view webnn_input_name = GetWebNNOpFirstInputName(decomposed_op_type); if (!IsDataTypeSupportedByWebNNOp( op_type, webnn_op_type, input_type, wnn_limits, webnn_input_name, "input", logger)) { return false; @@ -459,7 +460,8 @@ bool RotaryEmbeddingOpBuilder::HasSupportedOutputsImpl(const Node& node, } // Check if the output data type is supported by every decomposed WebNN op. - for (const std::string_view webnn_op_type : decomposed_op_map.at(op_type)) { + for (const std::string_view decomposed_op_type : decomposed_op_map.at(op_type)) { + const std::string_view webnn_op_type = GetWebNNOpType(decomposed_op_type); const std::string_view webnn_output_name = webnn_op_type == "split" ? "outputs" : "output"; if (!IsDataTypeSupportedByWebNNOp( op_type, webnn_op_type, output_type, wnn_limits, webnn_output_name, "output", logger)) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index 6206ac23e4bd4..8853891ff8ed6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -72,9 +72,8 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const input_name = input_defs[input_idx]->Name(); const auto& initializers(model_builder.GetInitializerTensors()); const auto& tensor = *initializers.at(input_name); - if (!ReadIntArrayFrom1DTensor(tensor, data, logger)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Data type for starts and ends inputs is not supported in this build."); - } + ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(tensor, data, model_builder.GetGraphViewer(), logger), + "Data type for starts or ends inputs is not supported in this build."); return Status::OK(); }; @@ -176,7 +175,7 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const GraphViewer& graph_viewer, con if (TensorExists(input_defs, 4)) { std::vector steps; const auto* init = graph_viewer.GetConstantInitializer(input_defs[4]->Name()); - if (!init || !ReadIntArrayFrom1DTensor(*init, steps, logger)) + if (!init || !ReadIntArrayFrom1DTensor(*init, steps, graph_viewer, logger)) return false; if (std::any_of(steps.begin(), steps.end(), [](int64_t step) { return step < 0; })) { if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger)) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc index 8094d3024a321..d55f953b214d4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc @@ -66,7 +66,8 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } else if (GetTensorName(input_defs, 1).size()) { const auto& initializers(model_builder.GetInitializerTensors()); const auto& split_tensor = *initializers.at(input_defs[1]->Name()); - ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(split_tensor, splits, logger), "Cannot get input for split."); + ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(split_tensor, splits, model_builder.GetGraphViewer(), logger), + "Cannot get input for split."); } else if (!helper.HasAttr("split")) { split_count = node.OutputDefs().size(); } @@ -125,8 +126,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, LOGS(logger, VERBOSE) << "The type of tensor's element data must be INT64."; return false; } - if (!ReadIntArrayFrom1DTensor(split_tensor, split, logger)) { - LOGS(logger, VERBOSE) << "Cannot get split."; + if (!ReadIntArrayFrom1DTensor(split_tensor, split, graph_viewer, logger)) { return false; } } else { diff --git a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc index ca98d8264fdcd..5a267557b9454 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/triangular_op_builder.cc @@ -57,7 +57,9 @@ Status TriangularOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto diagonal_tensor = *initializers.at(diagonal_name); std::vector unpacked_tensor; - ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(diagonal_tensor, unpacked_tensor)); + ORT_RETURN_IF_NOT(UnpackInitializerData(diagonal_tensor, unpacked_tensor, + model_builder.GetGraphViewer(), logger), + "Failed to unpack diagonal tensor data"); const auto diagonal = *reinterpret_cast(unpacked_tensor.data()); options.set("diagonal", SafeInt(diagonal).Ref()); } diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h index 59408ba244842..5e860eea7cac9 100644 --- a/onnxruntime/core/providers/webnn/builders/map_info.h +++ b/onnxruntime/core/providers/webnn/builders/map_info.h @@ -43,18 +43,20 @@ constexpr std::array supported_fallback }; // Some ONNX ops are supported by decomposed WebNN ops. +// This map defines the relationship between ONNX ops and their corresponding decomposed ONNX ops. +// Use ONNX-to-ONNX op mapping to improve the search complexity for WebNN ops in the op_inputs_map. const std::map> decomposed_op_map = { - {"ConvInteger", {"cast", "conv2d", "dequantizeLinear"}}, + {"ConvInteger", {"Cast", "Conv", "DequantizeLinear"}}, {"GroupQueryAttention", - {"add", "cast", "concat", "constant", "cumulativeSum", "div", "expand", "lesser", "matmul", "reshape", "scatterND", - "softmax", "transpose", "where"}}, - {"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}}, - {"MatMulInteger", {"cast", "dequantizeLinear", "matmul"}}, - {"MatMulNBits", {"add", "dequantizeLinear", "matmul", "reshape", "transpose"}}, - {"MultiHeadAttention", {"add", "cast", "concat", "constant", "div", "matmul", "reshape", "softmax", "transpose"}}, - {"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "slice", "split"}}, - {"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}}, - {"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}}, + {"Add", "Cast", "Concat", "CumSum", "Div", "Expand", "Less", "MatMul", "Reshape", "ScatterND", + "Softmax", "Transpose", "Where"}}, + {"LRN", {"Add", "AveragePool", "Div", "Mul", "Pad", "Pow", "Transpose"}}, + {"MatMulInteger", {"Cast", "DequantizeLinear", "MatMul"}}, + {"MatMulNBits", {"Add", "DequantizeLinear", "MatMul", "Reshape", "Transpose"}}, + {"MultiHeadAttention", {"Add", "Cast", "Concat", "Div", "MatMul", "Reshape", "Softmax", "Transpose"}}, + {"RotaryEmbedding", {"Add", "Concat", "Gather", "Mul", "Reshape", "Slice", "Split"}}, + {"SimplifiedLayerNormalization", {"Add", "Div", "Mul", "Pow", "ReduceMean", "Sqrt"}}, + {"SkipSimplifiedLayerNormalization", {"Add", "Div", "Mul", "Pow", "ReduceMean", "Sqrt"}}, }; /** diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index ef829e82823d0..00e78c12ccd45 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -161,7 +161,7 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap(); emscripten::val console = emscripten::val::global("console"); if (trace) { - console.call("time", emscripten::val("ORT::Dispatch::jsepEnsureTensor")); + console.call("time", emscripten::val("ORT::Dispatch::webnnEnsureTensor")); } for (const auto& [_, tensor] : inputs) { emscripten::val shape = emscripten::val::array(); @@ -182,7 +182,7 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap("push", ml_tensor); } if (trace) { - console.call("timeEnd", emscripten::val("ORT::Dispatch::jsepEnsureTensor")); + console.call("timeEnd", emscripten::val("ORT::Dispatch::webnnEnsureTensor")); } auto ml_tensors = emscripten::val::global("Promise").call("all", promises).await(); for (const auto& [name, _] : inputs) { diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index c89e95704d0dc..372f9b2fd273a 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -104,7 +104,7 @@ Status ModelBuilder::RegisterConstant(const onnx::TensorProto& tensor, emscripte const bool should_convert_int64_to_int32 = !IsInt64Supported() && data_type == ONNX_NAMESPACE::TensorProto_DataType_INT64; - if (utils::HasExternalData(tensor)) { + if (utils::HasExternalData(tensor) && !utils::HasExternalDataInMemory(tensor)) { // Create WebNN Constant from external data. std::basic_string external_file_path; onnxruntime::FileOffsetType data_offset; @@ -127,7 +127,8 @@ Status ModelBuilder::RegisterConstant(const onnx::TensorProto& tensor, emscripte if (tensor.has_raw_data()) { tensor_ptr = reinterpret_cast(const_cast(tensor.raw_data().c_str())); } else { - ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor)); + ORT_RETURN_IF_NOT(UnpackInitializerData(tensor, unpacked_tensor, graph_viewer_, logger), + "Failed to unpack initializer data for tensor: " + tensor.name()); tensor_ptr = reinterpret_cast(unpacked_tensor.data()); } diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 2da7c6499933a..036de5e17efba 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -31,6 +31,7 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f OrtDevice( webnn::IsMLTensorSupported() ? OrtDevice::GPU : OrtDevice::CPU, OrtDevice::MemType::DEFAULT, + OrtDevice::VendorIds::NONE, 0)}, wnn_device_type_(webnn::DeviceTypeFromString(webnn_device_flags)) { wnn_context_ = emscripten::val::module_property("currentContext"); diff --git a/onnxruntime/core/providers/xnnpack/detail/utils.cc b/onnxruntime/core/providers/xnnpack/detail/utils.cc index 2adf8339b4b66..4f9243e592009 100644 --- a/onnxruntime/core/providers/xnnpack/detail/utils.cc +++ b/onnxruntime/core/providers/xnnpack/detail/utils.cc @@ -362,7 +362,7 @@ TensorQuantType GetTensorQuantType(const NodeUnit& node_unit, int32_t io_index, } else if (scales_dim == tensor_shape[0]) { // default 0 for zero-point if zero_dim == 0 if (zero_tensor != nullptr) { - Initializer zp_val(*zero_tensor, node_unit.ModelPath()); + Initializer zp_val(graph_viewer.GetGraph(), *zero_tensor, node_unit.ModelPath()); auto zero_points = zp_val.DataAsSpan(); for (size_t i = 0; i < zp_val.size(); i++) { if (zero_points[i] != 0) { diff --git a/onnxruntime/core/providers/xnnpack/math/softmax.cc b/onnxruntime/core/providers/xnnpack/math/softmax.cc index 6786c29e1f056..c0246c2f0da34 100644 --- a/onnxruntime/core/providers/xnnpack/math/softmax.cc +++ b/onnxruntime/core/providers/xnnpack/math/softmax.cc @@ -31,13 +31,13 @@ bool IsQuantSoftmaxSupported(const NodeUnit& node_unit, const GraphViewer& graph // idealy, QlinearSoftmax or QDQSoftmax will keep this output scale and zp, but we have to handle some // qdq models converted from other framework auto [scale_tensor, zero_tensor] = GetQuantizationZeroPointAndScale(graph, node_unit.Outputs()[0]); - Initializer q_scale(*scale_tensor, node_unit.ModelPath()); + Initializer q_scale(graph.GetGraph(), *scale_tensor, node_unit.ModelPath()); if (fabs(q_scale.DataAsSpan()[0] - 1.0f / 256.0f) > 0.0001f) { break; } if (zero_tensor) { - Initializer q_zp(*zero_tensor, node_unit.ModelPath()); + Initializer q_zp(graph.GetGraph(), *zero_tensor, node_unit.ModelPath()); if (q_zp.DataAsSpan()[0] != 0) { break; } diff --git a/onnxruntime/core/providers/xnnpack/tensor/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc index 82cf1fc9bb87d..0bb1194643743 100644 --- a/onnxruntime/core/providers/xnnpack/tensor/resize.cc +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc @@ -65,7 +65,7 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit, // check the scale for the second dim is 1 or the size of the second dim matches the input shape. // if not, it is not the C dim as a Resize will not change the number of channels. if (scale_tensor) { - const Initializer scale_val(*scale_tensor, node_unit.ModelPath()); + const Initializer scale_val(graph_viewer.GetGraph(), *scale_tensor, node_unit.ModelPath()); const auto scales = scale_val.DataAsSpan(); if (scales[1] != 1.0F) { break; @@ -90,7 +90,7 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit, } if (size_tensor) { - const Initializer size_val(*size_tensor, node_unit.ModelPath()); + const Initializer size_val(graph_viewer.GetGraph(), *size_tensor, node_unit.ModelPath()); if (size_val.DataAsSpan()[1] != x_shape->dim(1).dim_value()) { break; } diff --git a/onnxruntime/core/session/abi_ep_types.cc b/onnxruntime/core/session/abi_ep_types.cc new file mode 100644 index 0000000000000..719f55b4e6b38 --- /dev/null +++ b/onnxruntime/core/session/abi_ep_types.cc @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/abi_ep_types.h" + +#include +#include +#include "core/common/common.h" +#include "core/common/inlined_containers.h" +#include "core/graph/ep_api_types.h" +#include "core/session/abi_devices.h" + +onnxruntime::Status OrtEpGraphSupportInfo::AddNodesToFuse(gsl::span nodes) { + std::vector ep_nodes; + ep_nodes.reserve(nodes.size()); + + for (const OrtNode* node : nodes) { + const auto* ep_node = onnxruntime::EpNode::ToInternal(node); + ORT_RETURN_IF(ep_node == nullptr, "Invalid OrtNode variant for use in OrtEpApi."); + ep_nodes.push_back(ep_node); + } + + node_groupings.emplace_back(NodeGroupingKind::kFusedNode, std::move(ep_nodes)); + return onnxruntime::Status::OK(); +} + +onnxruntime::Status OrtEpGraphSupportInfo::AddSingleNode(const OrtNode* node) { + std::vector ep_nodes; + ep_nodes.push_back(onnxruntime::EpNode::ToInternal(node)); + node_groupings.emplace_back(NodeGroupingKind::kSingleAssignedNode, std::move(ep_nodes)); + + return onnxruntime::Status::OK(); +} diff --git a/onnxruntime/core/session/abi_ep_types.h b/onnxruntime/core/session/abi_ep_types.h new file mode 100644 index 0000000000000..b19a03a57a78a --- /dev/null +++ b/onnxruntime/core/session/abi_ep_types.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "core/common/inlined_containers_fwd.h" +#include "core/common/status.h" +#include "core/session/onnxruntime_c_api.h" + +namespace onnxruntime { +struct EpGraph; +struct EpNode; +} // namespace onnxruntime + +/// +/// Class used specify the nodes an EP supports. An instance of this class is passed to OrtEp's +/// GetCapability() function. An OrtEp adds groups of supported nodes to the OrtEpGraphSupportInfo instance. +/// +struct OrtEpGraphSupportInfo { + enum class NodeGroupingKind { + kInvalidGrouping = 0, + kSingleAssignedNode, + kFusedNode, + }; + + // A grouping of supported nodes that should be handled in a single ComputeCapability. + struct NodeGrouping { + NodeGrouping(NodeGroupingKind kind, std::vector&& nodes) + : kind(kind), nodes(std::move(nodes)) {} + + NodeGroupingKind kind = NodeGroupingKind::kInvalidGrouping; + std::vector nodes; + }; + + explicit OrtEpGraphSupportInfo(const onnxruntime::EpGraph& graph) : ort_graph(graph) {} + + onnxruntime::Status AddNodesToFuse(gsl::span nodes); + onnxruntime::Status AddSingleNode(const OrtNode* node); + + const onnxruntime::EpGraph& ort_graph; + std::vector node_groupings; +}; diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 695819457bc79..7a17423112144 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -63,18 +63,17 @@ onnxruntime::Status OrtSessionOptions::RegisterCustomOpsLibrary(onnxruntime::Pat ORT_RETURN_IF_ERROR(platform_env.GetSymbolFromLibrary(library_handle, "RegisterCustomOps", (void**)&RegisterCustomOps)); - // Call the exported RegisterCustomOps function and store the return value in a unique_ptr. - const std::unique_ptr status(RegisterCustomOps(this, OrtGetApiBase()), - OrtApis::ReleaseStatus); + // Call the exported RegisterCustomOps function. + auto status = onnxruntime::ToStatusAndRelease(RegisterCustomOps(this, OrtGetApiBase())); - if (status) { // A non-nullptr status indicates an error registering custom ops. + if (!status.IsOK()) { auto unload_status = platform_env.UnloadDynamicLibrary(library_handle); if (!unload_status.IsOK()) { LOGS_DEFAULT(WARNING) << "Failed to unload handle for dynamic library " << onnxruntime::PathToUTF8String(library_name) << ": " << unload_status; } - return onnxruntime::ToStatus(status.get()); + return status; } // The internal onnxruntime::SessionOptions will manage the lifetime of library handles. diff --git a/onnxruntime/core/session/allocator_adapters.cc b/onnxruntime/core/session/allocator_adapters.cc index 1ff3487113358..5d1f84ba96cf2 100644 --- a/onnxruntime/core/session/allocator_adapters.cc +++ b/onnxruntime/core/session/allocator_adapters.cc @@ -11,7 +11,14 @@ namespace onnxruntime { namespace { +// The ORT API maintains ABI and backward compatibility, allowing applications to be built with an older version +// and run with a newer one. Users may call `RegisterAllocator` with a custom allocator. However, any new +// function pointers introduced in the newer version may contain invalid values, as the older application +// is unaware of them. +// Therefore, it's necessary to check the version value in `OrtAllocatorImplWrappingIAllocator` and +// `IAllocatorImplWrappingOrtAllocator` to ensure compatibility. constexpr uint32_t kOrtAllocatorReserveMinVersion = 18; +constexpr uint32_t kOrtAllocatorStatsMinVersion = 23; } // namespace OrtAllocatorImplWrappingIAllocator::OrtAllocatorImplWrappingIAllocator(onnxruntime::AllocatorPtr&& i_allocator) @@ -27,16 +34,18 @@ OrtAllocatorImplWrappingIAllocator::OrtAllocatorImplWrappingIAllocator(onnxrunti OrtAllocator::Reserve = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Reserve(size); }; } - OrtAllocator::GetStats = - [](const OrtAllocator* this_, OrtKeyValuePairs** stats) noexcept -> OrtStatusPtr { - API_IMPL_BEGIN - auto kvp = std::make_unique(); - auto stats_map = static_cast(this_)->Stats(); - kvp->Copy(stats_map); - *stats = reinterpret_cast(kvp.release()); - return nullptr; - API_IMPL_END - }; + if (OrtAllocator::version >= kOrtAllocatorStatsMinVersion) { + OrtAllocator::GetStats = + [](const OrtAllocator* this_, OrtKeyValuePairs** stats) noexcept -> OrtStatusPtr { + API_IMPL_BEGIN + auto kvp = std::make_unique(); + auto stats_map = static_cast(this_)->Stats(); + kvp->Copy(stats_map); + *stats = reinterpret_cast(kvp.release()); + return nullptr; + API_IMPL_END + }; + } } void* OrtAllocatorImplWrappingIAllocator::Alloc(size_t size) { @@ -98,6 +107,43 @@ void IAllocatorImplWrappingOrtAllocator::Free(void* p) { return ort_allocator_->Free(ort_allocator_, p); } +void IAllocatorImplWrappingOrtAllocator::GetStats(AllocatorStats* stats) { + *stats = {}; + + if (ort_allocator_->version >= kOrtAllocatorStatsMinVersion && ort_allocator_->GetStats) { + OrtKeyValuePairs* kvps = nullptr; + Ort::ThrowOnError(ort_allocator_->GetStats(ort_allocator_, &kvps)); + + auto release_fn = [](OrtKeyValuePairs** kvp) { + OrtApis::ReleaseKeyValuePairs(*kvp); + }; + + std::unique_ptr kvp_guard(&kvps, release_fn); + + for (size_t i = 0; i < kvps->keys.size(); ++i) { + if (strcmp(kvps->keys[i], "Limit") == 0) { + stats->bytes_limit = std::stoll(kvps->values[i]); + } else if (strcmp(kvps->keys[i], "InUse") == 0) { + stats->bytes_in_use = std::stoll(kvps->values[i]); + } else if (strcmp(kvps->keys[i], "TotalAllocated") == 0) { + stats->total_allocated_bytes = std::stoll(kvps->values[i]); + } else if (strcmp(kvps->keys[i], "MaxInUse") == 0) { + stats->max_bytes_in_use = std::stoll(kvps->values[i]); + } else if (strcmp(kvps->keys[i], "NumAllocs") == 0) { + stats->num_allocs = std::stoll(kvps->values[i]); + } else if (strcmp(kvps->keys[i], "NumReserves") == 0) { + stats->num_reserves = std::stoll(kvps->values[i]); + } else if (strcmp(kvps->keys[i], "NumArenaExtensions") == 0) { + stats->num_arena_extensions = std::stoll(kvps->values[i]); + } else if (strcmp(kvps->keys[i], "NumArenaShrinkages") == 0) { + stats->num_arena_shrinkages = std::stoll(kvps->values[i]); + } else if (strcmp(kvps->keys[i], "MaxAllocSize") == 0) { + stats->max_alloc_size = std::stoll(kvps->values[i]); + } + } + } +} + } // namespace onnxruntime #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(disable : 26409) diff --git a/onnxruntime/core/session/allocator_adapters.h b/onnxruntime/core/session/allocator_adapters.h index eb2ce20244da9..8a180db75ec67 100644 --- a/onnxruntime/core/session/allocator_adapters.h +++ b/onnxruntime/core/session/allocator_adapters.h @@ -52,6 +52,8 @@ class IAllocatorImplWrappingOrtAllocator final : public IAllocator { void Free(void* p) override; + void GetStats(AllocatorStats* stats) override; + ORT_DISALLOW_COPY_AND_ASSIGNMENT(IAllocatorImplWrappingOrtAllocator); private: diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 7dfed4cbe787d..cefcee8f408d7 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -583,21 +583,19 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_tensor, _In_ const OrtKernel onnxruntime::TensorShape tensor_shape = onnxruntime::utils::GetTensorShapeFromTensorProto(tensor_proto); const auto* type = onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); onnxruntime::AllocatorPtr alloc_ptr = std::make_shared(allocator); - auto tensorp = std::make_unique(type, tensor_shape, std::move(alloc_ptr)); + auto tensor = onnxruntime::Tensor{type, tensor_shape, std::move(alloc_ptr)}; // Deserialize TensorProto into pre-allocated, empty Tensor. // TODO: here the TensorProto loses model path information, so it cannot be an external tensor. status = onnxruntime::utils::TensorProtoToTensor(onnxruntime::Env::Default(), std::filesystem::path(), - tensor_proto, *tensorp); + tensor_proto, tensor); if (!status.IsOK()) { return onnxruntime::ToOrtStatus(status); } // Initialize OrtValue from Tensor. - auto ml_tensor = onnxruntime::DataTypeImpl::GetType(); auto value = std::make_unique(); - value->Init(tensorp.release(), ml_tensor, ml_tensor->GetDeleteFunc()); - + onnxruntime::Tensor::InitOrtValue(std::move(tensor), *value); *out = value.release(); return nullptr; }); @@ -797,8 +795,7 @@ struct CustomOpKernel : OpKernel { Status Compute(OpKernelContext* ctx) const override { if (op_.version >= min_ort_version_with_compute_v2_support && op_.KernelComputeV2) { - auto status_ptr = op_.KernelComputeV2(op_kernel_, reinterpret_cast(ctx)); - return ToStatus(status_ptr); + return ToStatusAndRelease(op_.KernelComputeV2(op_kernel_, reinterpret_cast(ctx))); } else { op_.KernelCompute(op_kernel_, reinterpret_cast(ctx)); return Status::OK(); diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 1ded826638250..20b7410045333 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -90,7 +90,6 @@ static bool AreOrtMemoryInfosEquivalent( return left == right; } else { return left.mem_type == right.mem_type && - left.id == right.id && left.device == right.device && strcmp(left.name, right.name) == 0; } @@ -209,17 +208,7 @@ Status Environment::Initialize(std::unique_ptr logging_ // create thread pools if (create_global_thread_pools) { - create_global_thread_pools_ = true; - OrtThreadPoolParams to = tp_options->intra_op_thread_pool_params; - if (to.name == nullptr) { - to.name = ORT_TSTR("intra-op"); - } - intra_op_thread_pool_ = concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP); - to = tp_options->inter_op_thread_pool_params; - if (to.name == nullptr) { - to.name = ORT_TSTR("inter-op"); - } - inter_op_thread_pool_ = concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTER_OP); + ORT_RETURN_IF_ERROR(SetGlobalThreadingOptions(*tp_options)); } ORT_TRY { @@ -346,6 +335,24 @@ Internal copy node return status; } +Status Environment::SetGlobalThreadingOptions(const OrtThreadingOptions& tp_options) { + if (create_global_thread_pools_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Global thread pools have already been created, cannot replace them"); + } + create_global_thread_pools_ = true; + OrtThreadPoolParams to = tp_options.intra_op_thread_pool_params; + if (to.name == nullptr) { + to.name = ORT_TSTR("intra-op"); + } + intra_op_thread_pool_ = concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP); + to = tp_options.inter_op_thread_pool_params; + if (to.name == nullptr) { + to.name = ORT_TSTR("inter-op"); + } + inter_op_thread_pool_ = concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTER_OP); + return Status::OK(); +} + Status Environment::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo& mem_info, const std::unordered_map& options, const OrtArenaCfg* arena_cfg) { @@ -520,7 +527,7 @@ Status Environment::EpInfo::Create(std::unique_ptr library_in, std::u std::array ep_devices{nullptr}; size_t num_ep_devices = 0; - ORT_RETURN_IF_ERROR(ToStatus( + ORT_RETURN_IF_ERROR(ToStatusAndRelease( factory.GetSupportedDevices(&factory, sorted_devices.data(), sorted_devices.size(), ep_devices.data(), ep_devices.size(), &num_ep_devices))); diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/ep_api.cc index 0cac00326392c..ffb5a286730ba 100644 --- a/onnxruntime/core/session/ep_api.cc +++ b/onnxruntime/core/session/ep_api.cc @@ -3,8 +3,13 @@ #include "core/session/ep_api.h" +#include +#include #include "core/framework/error_code_helper.h" +#include "core/framework/func_api.h" +#include "core/graph/ep_api_types.h" #include "core/session/abi_devices.h" +#include "core/session/abi_ep_types.h" #include "core/session/ort_apis.h" using namespace onnxruntime; @@ -38,12 +43,59 @@ ORT_API(void, ReleaseEpDevice, _Frees_ptr_opt_ OrtEpDevice* device) { delete device; } +ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* ort_graph_support_info, + _In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes) { + API_IMPL_BEGIN + if (ort_graph_support_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid OrtGraph instance"); + } + + if (num_nodes == 0 || nodes == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of 1 or more supported nodes"); + } + + gsl::span nodes_span(nodes, nodes + num_nodes); + ORT_API_RETURN_IF_STATUS_NOT_OK(ort_graph_support_info->AddNodesToFuse(nodes_span)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddSingleNode, _In_ OrtEpGraphSupportInfo* ort_graph_support_info, + _In_ const OrtNode* node) { + API_IMPL_BEGIN + if (ort_graph_support_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid OrtGraph instance"); + } + + if (node == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a non-null OrtNode"); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(ort_graph_support_info->AddSingleNode(node)); + return nullptr; + API_IMPL_END +} + +// +// OrtCompiledNodeComputeContext +// + +ORT_API(const char*, NodeComputeContext_NodeName, _In_ const OrtNodeComputeContext* context) { + const auto* compute_context = reinterpret_cast(context); + return compute_context->node_name; +} + static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). &OrtExecutionProviderApi::CreateEpDevice, &OrtExecutionProviderApi::ReleaseEpDevice, + // End of Version 22 - DO NOT MODIFY ABOVE + + &OrtExecutionProviderApi::EpGraphSupportInfo_AddNodesToFuse, + &OrtExecutionProviderApi::EpGraphSupportInfo_AddSingleNode, + &OrtExecutionProviderApi::NodeComputeContext_NodeName, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/ep_api.h b/onnxruntime/core/session/ep_api.h index 23cd31cbdd861..84c8781a70adb 100644 --- a/onnxruntime/core/session/ep_api.h +++ b/onnxruntime/core/session/ep_api.h @@ -16,4 +16,11 @@ ORT_API_STATUS_IMPL(CreateEpDevice, _In_ OrtEpFactory* ep_factory, _Out_ OrtEpDevice** ep_device); ORT_API(void, ReleaseEpDevice, _Frees_ptr_opt_ OrtEpDevice* device); + +ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddNodesToFuse, _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_reads_(num_nodes) const OrtNode* const* nodes, size_t num_nodes); +ORT_API_STATUS_IMPL(EpGraphSupportInfo_AddSingleNode, _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_ const OrtNode* node); +ORT_API(const char*, NodeComputeContext_NodeName, _In_ const OrtNodeComputeContext* context); + } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/ep_factory_internal.cc index fd907302b6b8d..354e609a6301c 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/ep_factory_internal.cc @@ -68,7 +68,7 @@ void EpFactoryInternal::ReleaseEp(OrtEp* /*ep*/) { } InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, - const std::vector& ep_devices) + gsl::span ep_devices) : ep_factory_{ep_factory} { devices_.reserve(ep_devices.size()); ep_metadata_.reserve(ep_devices.size()); @@ -83,10 +83,11 @@ std::unique_ptr InternalExecutionProviderFactory::CreateProvider(const OrtSessionOptions& session_options, const OrtLogger& session_logger) { std::unique_ptr ep; - OrtStatus* status = ep_factory_.CreateIExecutionProvider(devices_.data(), ep_metadata_.data(), devices_.size(), - &session_options, &session_logger, &ep); - if (status != nullptr) { - ORT_THROW("Error creating execution provider: ", ToStatus(status).ToString()); + auto status = ToStatusAndRelease(ep_factory_.CreateIExecutionProvider(devices_.data(), ep_metadata_.data(), + devices_.size(), &session_options, + &session_logger, &ep)); + if (!status.IsOK()) { + ORT_THROW("Error creating execution provider: ", status); } return ep; diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/ep_factory_internal.h index 2dcc769ec635e..3853949e94375 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/ep_factory_internal.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include "core/common/common.h" @@ -74,7 +75,7 @@ class EpFactoryInternal : public OrtEpFactory { // IExecutionProviderFactory for EpFactoryInternal that is required for SessionOptionsAppendExecutionProvider_V2 struct InternalExecutionProviderFactory : public IExecutionProviderFactory { public: - InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, const std::vector& ep_devices); + InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, gsl::span ep_devices); std::unique_ptr CreateProvider(const OrtSessionOptions& session_options, const OrtLogger& session_logger) override; diff --git a/onnxruntime/core/session/ep_library_plugin.cc b/onnxruntime/core/session/ep_library_plugin.cc index 3c873ec4a9aeb..f7220c20f9ac9 100644 --- a/onnxruntime/core/session/ep_library_plugin.cc +++ b/onnxruntime/core/session/ep_library_plugin.cc @@ -23,11 +23,8 @@ Status EpLibraryPlugin::Load() { std::vector factories{4, nullptr}; size_t num_factories = 0; - OrtStatus* ort_status = create_fn_(registration_name_.c_str(), OrtGetApiBase(), - factories.data(), factories.size(), &num_factories); - if (ort_status != nullptr) { - return ToStatus(ort_status); - } + ORT_RETURN_IF_ERROR(ToStatusAndRelease(create_fn_(registration_name_.c_str(), OrtGetApiBase(), + factories.data(), factories.size(), &num_factories))); for (size_t i = 0; i < num_factories; ++i) { factories_.push_back(factories[i]); @@ -64,11 +61,11 @@ Status EpLibraryPlugin::Unload() { continue; } - OrtStatus* status = release_fn_(factory); - if (status != nullptr) { + auto status = ToStatusAndRelease(release_fn_(factory)); + if (!status.IsOK()) { // log it and treat it as released LOGS_DEFAULT(ERROR) << "ReleaseEpFactory failed for: " << library_path_ << " with error: " - << ToStatus(status).ErrorMessage(); + << status.ErrorMessage(); } factories_[idx] = nullptr; // clear the pointer in case there's a failure before all are released diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc new file mode 100644 index 0000000000000..ebd74dd51774c --- /dev/null +++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc @@ -0,0 +1,393 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/ep_plugin_provider_interfaces.h" + +#include +#include +#include +#include +#include +#include "core/framework/abi_pointer_array.h" +#include "core/framework/compute_capability.h" +#include "core/framework/error_code_helper.h" +#include "core/framework/model_metadef_id_generator.h" +#include "core/graph/ep_api_types.h" +#include "core/graph/model_editor_api_types.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_ep_types.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/allocator_adapters.h" +#include "core/session/ort_apis.h" +#include "core/providers/partitioning_utils.h" + +namespace onnxruntime { + +// +// PluginExecutionProviderFactory +// + +PluginExecutionProviderFactory::PluginExecutionProviderFactory(OrtEpFactory& ep_factory, + gsl::span ep_devices) + : ep_factory_{ep_factory} { + devices_.reserve(ep_devices.size()); + ep_metadata_.reserve(ep_devices.size()); + + for (const auto* ep_device : ep_devices) { + devices_.push_back(ep_device->device); + ep_metadata_.push_back(&ep_device->ep_metadata); + } +} + +std::unique_ptr +PluginExecutionProviderFactory::CreateProvider(const OrtSessionOptions& session_options, + const OrtLogger& session_logger) { + OrtEp* ort_ep = nullptr; + Status status = ToStatusAndRelease(ep_factory_.CreateEp(&ep_factory_, devices_.data(), ep_metadata_.data(), + devices_.size(), &session_options, &session_logger, &ort_ep)); + + if (!status.IsOK()) { + ORT_THROW("Error creating execution provider: ", status.ToString()); + } + + auto ep_wrapper = std::make_unique(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)), + session_options); + ep_wrapper->SetLogger(session_logger.ToInternal()); + + return ep_wrapper; +} + +/// +/// Functor used to generate a Metadef name for a subgraph supported by a plugin EP. +/// The generated name is a concatenation of a prefix (i.e., the EP name) with +/// the model's hash and a unique ID. +/// +struct PluginEpMetaDefNameFunctor { + explicit PluginEpMetaDefNameFunctor(const ModelMetadefIdGenerator& generator, + const GraphViewer& graph_viewer, + const std::string& prefix) + : generator_(generator), graph_viewer_(graph_viewer), prefix_(prefix) {} + + std::string operator()() { + uint64_t model_hash = 0; + int id = generator_.GenerateId(graph_viewer_, model_hash); + return MakeString(prefix_, "_", model_hash, "_", id); + } + + const ModelMetadefIdGenerator& generator_; + const GraphViewer& graph_viewer_; + const std::string& prefix_; +}; + +// +// PluginExecutionProvider +// + +PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options) + : IExecutionProvider(ep->GetName(ep.get()), OrtDevice()), // TODO: What to do about OrtDevice for plugins? + ort_ep_(std::move(ep)) { + generate_ep_ctx_model_ = session_options.value.GetEpContextGenerationOptions().enable; +} + +PluginExecutionProvider::~PluginExecutionProvider() { + if (ort_ep_ && !api_node_compute_infos_.empty()) { + ort_ep_->ReleaseNodeComputeInfos(ort_ep_.get(), api_node_compute_infos_.data(), + api_node_compute_infos_.size()); + } +} + +std::vector> +PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& graph_optimizer_registry, + IResourceAccountant* resource_accountant) const { + ORT_UNUSED_PARAMETER(graph_optimizer_registry); // TODO: Add support + ORT_UNUSED_PARAMETER(resource_accountant); // TODO: Add support? Not used by prioritized EPs + ORT_UNUSED_PARAMETER(kernel_lookup); // TODO: Add support? Not used by prioritized EPs, so probably not needed? + + std::unique_ptr ep_graph = nullptr; + if (Status status = EpGraph::Create(graph_viewer, ep_graph); !status.IsOK()) { + LOGS_DEFAULT(ERROR) << "Failed to create OrtGraph: " << status.ToString(); + return {}; + } + + OrtEpGraphSupportInfo api_graph_support_info(*ep_graph); + Status status = ToStatusAndRelease(ort_ep_->GetCapability(ort_ep_.get(), ep_graph->ToExternal(), &api_graph_support_info)); + + if (!status.IsOK()) { + LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() failed with error: " << status.ToString(); + return {}; + } + + std::vector> result; + result.reserve(api_graph_support_info.node_groupings.size()); + if (api_graph_support_info.node_groupings.empty()) { + return {}; + } + + ModelMetadefIdGenerator generator; + + // Create ComputeCapability instances from OrtEpGraphSupportInfo::NodeGrouping instances. + for (const OrtEpGraphSupportInfo::NodeGrouping& node_grouping : api_graph_support_info.node_groupings) { + if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kSingleAssignedNode) { + auto indexed_sub_graph = std::make_unique(); + + indexed_sub_graph->nodes.push_back(node_grouping.nodes[0]->GetInternalNode().Index()); + result.push_back(std::make_unique(std::move(indexed_sub_graph))); + } else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) { + std::unordered_set node_set; + node_set.reserve(node_grouping.nodes.size()); + for (const EpNode* ep_node : node_grouping.nodes) { + node_set.insert(&ep_node->GetInternalNode()); + } + + // We now require the OrtEp to only provide individual groups of supported nodes that each maps to exactly + // one ComputeCapability. Calling utils::CreateSupportedPartitions() may create multiple ComputeCapability + // instances, and if so, log an error and return. + // + // TODO(adrianlizarraga): Do not use the heavy-weight CreateSupportedPartitions just to check if the user + // provided a single partition. Use utils::MakeCapability() and create a new helper to check that there are no + // unsupported nodes in any path between supported nodes. + std::vector> capabilities = utils::CreateSupportedPartitions( + graph_viewer, node_set, /*stop_ops*/ {}, PluginEpMetaDefNameFunctor(generator, graph_viewer, this->Type()), + this->Type(), this->Type(), /*node_unit_map*/ nullptr); + + if (capabilities.size() > 1) { + LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() set nodes that cannot be fused together. " + << "Please ensure that the nodes provided to EpGraphSupportInfo_AddFusedNodes() do not " + << "have an unsupported node in any path between two of the supported nodes."; + return {}; + } + + // Enforce that the nodes in node_set match the nodes in capabilities[0] + // TODO(adrianlizarraga): This check can be removed when we stop using utils::CreateSupportedPartitions() above. + std::vector& capability_node_indices = capabilities[0]->sub_graph->nodes; + std::unordered_set capability_node_indices_set(capability_node_indices.begin(), + capability_node_indices.end()); + + ORT_ENFORCE(node_set.size() == capability_node_indices_set.size()); + ORT_ENFORCE(std::all_of(node_set.begin(), node_set.end(), [&capability_node_indices_set](const Node* node) { + return capability_node_indices_set.count(node->Index()) != 0; + })); + + result.push_back(std::move(capabilities[0])); + } else { + LOGS_DEFAULT(ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: " + << static_cast(node_grouping.kind); + return {}; + } + } + + return result; +} + +Status PluginExecutionProvider::FusedNodeState::AddFusedNode(const Node& fused_node, /*out*/ EpNode*& added_ep_node) { + std::unique_ptr unique_ep_fused_node = nullptr; + ORT_RETURN_IF_ERROR(EpNode::Create(fused_node, /*parent graph*/ nullptr, this->value_infos, unique_ep_fused_node)); + this->nodes.push_back(std::move(unique_ep_fused_node)); + added_ep_node = this->nodes.back().get(); + return Status::OK(); +} + +/// +/// Converts the EPContext nodes provided by the plugin EP (OrtNode instances) to onnxruntime::Node instances. +/// Note that the EP plugin uses the model editor API to create the OrtNode instances. +/// +/// Name of the plugin EP. +/// EPContext nodes provided by the plugin EP. +/// Output parameter set to the resulting array of EPContext nodes. +/// Output parameter that stores the NodeArgs used by the EPContext nodes. +/// A status indicating success or an error. +static Status ConvertEpContextNodes(const std::string& ep_name, const std::vector plugin_ep_context_nodes, + /*out*/ std::vector>& result_nodes, + /*out*/ std::vector>& result_node_args) { +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + if (plugin_ep_context_nodes.empty()) { + return Status::OK(); // No EPContext nodes. + } + + std::vector> ep_context_nodes_holder; + std::vector> ep_context_node_args_holder; + + ep_context_nodes_holder.reserve(plugin_ep_context_nodes.size()); + + for (const OrtNode* ort_node : plugin_ep_context_nodes) { + ORT_RETURN_IF_NOT(ort_node != nullptr, ep_name, ": OrtEp::Compile() returned a NULL EPContext node."); + + const ModelEditorNode* editor_node = ModelEditorNode::ToInternal(ort_node); + ORT_RETURN_IF_NOT(editor_node != nullptr, ep_name, ": OrtEp::Compile() returned OrtNode objects ", + "that were not created with OrtModelEditorApi."); + + // Create NodeArg for each input/output. + std::vector input_node_args; + std::vector output_node_args; + + input_node_args.reserve(editor_node->input_names.size()); + output_node_args.reserve(editor_node->output_names.size()); + + for (const std::string& input_name : editor_node->input_names) { + auto node_arg = std::make_unique(input_name, /*p_arg_type*/ nullptr); // Graph.Resolve() sets type. + input_node_args.push_back(node_arg.get()); + ep_context_node_args_holder.push_back(std::move(node_arg)); + } + + for (const std::string& output_name : editor_node->output_names) { + auto node_arg = std::make_unique(output_name, /*p_arg_type*/ nullptr); // Graph.Resolve() sets type. + output_node_args.push_back(node_arg.get()); + ep_context_node_args_holder.push_back(std::move(node_arg)); + } + + // Create a name -> attribute map. + NodeAttributes attributes; + attributes.reserve(editor_node->attributes.size()); + + for (const ONNX_NAMESPACE::AttributeProto& attr : editor_node->attributes) { + attributes.emplace(attr.name(), attr); + } + + // Create Node + auto internal_node = std::make_unique(editor_node->node_name, + editor_node->operator_name, + "EPContext node for " + ep_name, + input_node_args, + output_node_args, + &attributes, + editor_node->domain_name); + + ep_context_nodes_holder.push_back(std::move(internal_node)); + } + + result_nodes = std::move(ep_context_nodes_holder); + result_node_args = std::move(ep_context_node_args_holder); + + return Status::OK(); +#else + ORT_UNUSED_PARAMETER(ep_name); + ORT_UNUSED_PARAMETER(plugin_ep_context_nodes); + ORT_UNUSED_PARAMETER(result_nodes); + ORT_UNUSED_PARAMETER(result_node_args); + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Creating EPContext models is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +} + +common::Status PluginExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_infos) { + const logging::Logger* logger = GetLogger(); + const size_t num_graphs = fused_nodes_and_graphs.size(); + std::vector> api_graphs_holder; + std::vector api_graphs; + std::vector api_node_compute_infos(num_graphs, nullptr); + std::vector api_fused_nodes; + + // Push a new FusedNodeState to store the EpNode instances that we'll create to wrap the original fused nodes. + // Fused nodes must be valid throughout model inference because they may be cached in NodeComputeInfo instances. + fused_node_states_.push_back(FusedNodeState()); + FusedNodeState& fused_node_state = fused_node_states_.back(); + + fused_node_state.nodes.reserve(num_graphs); + api_graphs_holder.reserve(num_graphs); + api_graphs.reserve(num_graphs); + api_fused_nodes.reserve(num_graphs); + api_node_compute_infos_.reserve(api_node_compute_infos_.size() + num_graphs); + + // Wrap GraphViewers into OrtGraphs and fused Nodes into OrtNodes. + for (const FusedNodeAndGraph& node_and_graph : fused_nodes_and_graphs) { + const GraphViewer& graph_viewer = node_and_graph.filtered_graph; + const Node& fused_node = node_and_graph.fused_node; + + std::unique_ptr ep_graph = nullptr; + ORT_RETURN_IF_ERROR(EpGraph::Create(graph_viewer, ep_graph)); + api_graphs.push_back(ep_graph->ToExternal()); + api_graphs_holder.push_back(std::move(ep_graph)); + + EpNode* ep_fused_node = nullptr; + ORT_RETURN_IF_ERROR(fused_node_state.AddFusedNode(fused_node, ep_fused_node)); + api_fused_nodes.push_back(ep_fused_node->ToExternal()); + } + + // Provide an output buffer for the plugin EP to store EPContext nodes if it needs to (i.e., enabled in session options). + std::vector> plugin_ep_context_nodes_holder; + std::vector plugin_ep_context_nodes; + plugin_ep_context_nodes_holder.reserve(num_graphs); + plugin_ep_context_nodes.resize(num_graphs, nullptr); + + Status compile_status = ToStatusAndRelease(ort_ep_->Compile(ort_ep_.get(), api_graphs.data(), api_fused_nodes.data(), + num_graphs, api_node_compute_infos.data(), + plugin_ep_context_nodes.data())); + + // Store any EPContext nodes provided by the plugin EP in std::unique_ptr so that they are always properly released. + for (OrtNode* ort_node : plugin_ep_context_nodes) { + auto unique_ort_node = std::unique_ptr(ort_node, OrtApis::ReleaseNode); + plugin_ep_context_nodes_holder.push_back(std::move(unique_ort_node)); + } + + // Save OrtNodeComputeInfo created by OrtEp instance. They're freed when this IExecutionProvider + // is destroyed. + for (size_t i = 0; i < num_graphs; i++) { + if (api_node_compute_infos[i] != nullptr) { + api_node_compute_infos_.push_back(api_node_compute_infos[i]); + } + } + + ORT_RETURN_IF_ERROR(compile_status); + + // Initialize node_compute_infos as wrappers to api_node_compute_infos. + for (size_t i = 0; i < num_graphs; i++) { + OrtNodeComputeInfo* api_node_compute_info = api_node_compute_infos[i]; + ORT_RETURN_IF(api_node_compute_info == nullptr, "OrtEp::Compile() did not set a valid OrtNodeComputeInfo ", + "instance for graph at index ", i); + + NodeComputeInfo compute_info; + compute_info.create_state_func = [api_node_compute_info, logger](ComputeContext* context, + FunctionState* compute_state) -> int { + Status status = ToStatusAndRelease( + api_node_compute_info->CreateState(api_node_compute_info, + reinterpret_cast(context), + compute_state)); + const bool success = status.IsOK(); + if (!success) { + LOGS(*logger, ERROR) << "OrtNodeComputeInfo::CreateComputeState() failed with error: " + << status.ErrorMessage(); + } + + return success ? 0 : 1; + }; + + compute_info.release_state_func = [api_node_compute_info](FunctionState compute_state) -> void { + api_node_compute_info->ReleaseState(api_node_compute_info, compute_state); + }; + + compute_info.compute_func = [api_node_compute_info](FunctionState compute_state, + const OrtApi* /*c_api*/, + OrtKernelContext* kernel_context) -> Status { + ORT_RETURN_IF_ERROR(ToStatusAndRelease((api_node_compute_info->Compute(api_node_compute_info, compute_state, + kernel_context)))); + return Status::OK(); + }; + + node_compute_infos.push_back(std::move(compute_info)); + } + + // Convert the EPContext nodes provided by the plugin EP into onnxruntime::Node instances. + // We store the converted Node and NodeArg instances as members to ensure they can be returned to the ORT graph + // partitioner via a call to IExecutionProvider::GetEpContextNodes(). + if (generate_ep_ctx_model_) { + ORT_RETURN_IF_ERROR(ConvertEpContextNodes(Type(), plugin_ep_context_nodes, + /*out*/ ep_context_nodes_, /*out*/ ep_context_node_args_)); + } + + return Status::OK(); +} + +const InlinedVector PluginExecutionProvider::GetEpContextNodes() const { + InlinedVector result; + + for (const std::unique_ptr& node : ep_context_nodes_) { + result.push_back(node.get()); + } + + return result; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/ep_plugin_provider_interfaces.h new file mode 100644 index 0000000000000..2b88c7f5d494f --- /dev/null +++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.h @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/framework/execution_provider.h" +#include "core/providers/providers.h" +#include "core/session/onnxruntime_c_api.h" + +namespace onnxruntime { +struct EpNode; +struct EpValueInfo; +class NodeArg; + +/// +/// IExecutionProviderFactory that wraps a OrtEpFactory. Required for SessionOptionsAppendExecutionProvider_V2. +/// +struct PluginExecutionProviderFactory : public IExecutionProviderFactory { + public: + PluginExecutionProviderFactory(OrtEpFactory& ep_factory, gsl::span ep_devices); + + std::unique_ptr CreateProvider(const OrtSessionOptions& session_options, + const OrtLogger& session_logger) override; + + std::unique_ptr CreateProvider() override { + ORT_NOT_IMPLEMENTED("CreateProvider without parameters is not supported."); + } + + private: + OrtEpFactory& ep_factory_; + std::vector devices_; + std::vector ep_metadata_; +}; + +/// +/// Functor that deletes an instance of OrtEp. Used to create an std::unique_ptr. +/// +struct OrtEpDeleter { + explicit OrtEpDeleter(OrtEpFactory& ort_ep_factory) : ort_ep_factory_(ort_ep_factory) {} + void operator()(OrtEp* ort_ep) { + ort_ep_factory_.ReleaseEp(&ort_ep_factory_, ort_ep); + } + OrtEpFactory& ort_ep_factory_; +}; + +/// +/// Type that represents a std::unique_ptr for an instance of OrtEp. +/// +using UniqueOrtEp = std::unique_ptr; + +/// +/// IExecutionProvider that wraps an instance of OrtEp. +/// +class PluginExecutionProvider : public IExecutionProvider { + public: + explicit PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options); + ~PluginExecutionProvider(); + + std::vector> + GetCapability(const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& graph_optimizer_registry, + IResourceAccountant* resource_accountant = nullptr) const override; + + common::Status Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) override; + + const InlinedVector GetEpContextNodes() const override; + + private: + struct FusedNodeState { + FusedNodeState() = default; + FusedNodeState(FusedNodeState&& other) = default; + FusedNodeState(const FusedNodeState& other) = delete; + Status AddFusedNode(const Node& fused_node, /*out*/ EpNode*& added_ep_node); + + std::vector> nodes; + std::unordered_map> value_infos; + }; + + UniqueOrtEp ort_ep_; + bool generate_ep_ctx_model_ = false; + std::vector api_node_compute_infos_; + + // Fused nodes have to be valid throughout model inference because they may be cached in NodeComputeInfo instances. + // For each fused node, the Compile() function creates EpNode and EpValueInfo instances on the heap, + // which are then passed to the underlying OrtEp instance. This class stores this "fused node state" + // so that it is not destroyed until the EP itself is destroyed. + std::vector fused_node_states_; + + // Stores the EPContext Nodes created from the OrtNode instances returned by the underlying plugin EP. + // Need to store both the Node and NodeArg instances so that they are available when the GraphPartitioner + // calls IExecutionProvider::GetEpContextNodes(). + std::vector> ep_context_nodes_; + std::vector> ep_context_node_args_; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c34769f43ae1d..468639a9f25bb 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -251,10 +251,10 @@ Status GetMinimalBuildOptimizationHandling( } // namespace std::atomic InferenceSession::global_session_id_{1}; +std::map InferenceSession::active_sessions_; #ifdef _WIN32 std::mutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_ -std::map InferenceSession::active_sessions_; -const std::string InferenceSession::callback_etw_provider_key_{"InferenceSessionML_ORT_provider"}; +onnxruntime::WindowsTelemetry::EtwInternalCallback InferenceSession::callback_ML_ORT_provider_; #endif static Status FinalizeSessionOptions(const SessionOptions& user_provided_session_options, @@ -375,6 +375,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, session_id_ = global_session_id_.fetch_add(1); SetLoggingManager(session_options, session_env); + // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked // after the invocation of FinalizeSessionOptions. InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point. @@ -507,43 +508,87 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, telemetry_ = {}; #ifdef _WIN32 - - { - std::lock_guard lock(active_sessions_mutex_); - auto result = active_sessions_.insert_or_assign(session_id_, this); - ORT_ENFORCE(result.second, "active_sessions has not been cleaned up for session_id", session_id_); - } - - WindowsTelemetry::RegisterInternalCallback(callback_etw_provider_key_, EtwProviderCallbackLogAllSessions); - - // If the __ctor does not finish for some reason, make sure that we still unregister - // whatever has been registered. - auto_etw_unregistrar_.emplace([this]() { UnregisterEtwCallbacks(); }); + std::lock_guard lock(active_sessions_mutex_); + active_sessions_[session_id_] = this; + + // Register callback for ETW capture state (rundown) for Microsoft.ML.ONNXRuntime provider + callback_ML_ORT_provider_ = onnxruntime::WindowsTelemetry::EtwInternalCallback( + [](LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + ORT_UNUSED_PARAMETER(SourceId); + ORT_UNUSED_PARAMETER(Level); + ORT_UNUSED_PARAMETER(MatchAnyKeyword); + ORT_UNUSED_PARAMETER(MatchAllKeyword); + ORT_UNUSED_PARAMETER(FilterData); + ORT_UNUSED_PARAMETER(CallbackContext); + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + InferenceSession::LogAllSessions(); + } + }); + WindowsTelemetry::RegisterInternalCallback(callback_ML_ORT_provider_); // Register callback for ETW start / stop so that LOGS tracing can be adjusted dynamically after session start - callback_etw_sink_key_ = "InferenceSession_Start_Stop_"; - callback_etw_sink_key_.append(std::to_string(session_id_)); auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); + callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( + [&etwRegistrationManager, this](LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + ORT_UNUSED_PARAMETER(SourceId); + ORT_UNUSED_PARAMETER(Level); + ORT_UNUSED_PARAMETER(MatchAnyKeyword); + ORT_UNUSED_PARAMETER(MatchAllKeyword); + ORT_UNUSED_PARAMETER(FilterData); + ORT_UNUSED_PARAMETER(CallbackContext); + + if (logging_manager_ != nullptr) { + auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); + + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0 && + IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { + LOGS(*session_logger_, VERBOSE) << "Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; + logging_manager_->AddSinkOfType( + onnxruntime::logging::SinkType::EtwSink, + []() -> std::unique_ptr { return std::make_unique(); }, + ortETWSeverity); + onnxruntime::logging::LoggingManager::GetDefaultInstance()->AddSinkOfType( + onnxruntime::logging::SinkType::EtwSink, + []() -> std::unique_ptr { return std::make_unique(); }, + ortETWSeverity); + LOGS(*session_logger_, INFO) << "Done Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; + } + if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { + LOGS(*session_logger_, INFO) << "Removing ETW Sink from logger"; + logging_manager_->RemoveSink(onnxruntime::logging::SinkType::EtwSink); + LOGS(*session_logger_, VERBOSE) << "Done Removing ETW Sink from logger"; + } + } + }); // Register callback for ETW capture state (rundown) - etwRegistrationManager.RegisterInternalCallback(callback_etw_sink_key_, - [&etwRegistrationManager, this]( - LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { - EtwProviderSinkControlCallback(etwRegistrationManager, SourceId, IsEnabled, Level, - MatchAnyKeyword, MatchAllKeyword, FilterData, - CallbackContext); - }); -#endif -} - -void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState, - const logging::Logger& logger) { + etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); + +#endif +} + +void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState, const logging::Logger& logger) { ORT_UNUSED_PARAMETER(captureState); // Otherwise Linux build error LOGS(logger, INFO) << session_options; @@ -692,6 +737,18 @@ InferenceSession::~InferenceSession() { } } + // Unregister the session and ETW callbacks +#ifdef _WIN32 + std::lock_guard lock(active_sessions_mutex_); + if (callback_ML_ORT_provider_ != nullptr) { + WindowsTelemetry::UnregisterInternalCallback(callback_ML_ORT_provider_); + } + if (callback_ETWSink_provider_ != nullptr) { + logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + } +#endif + active_sessions_.erase(session_id_); + #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT if (session_activity_started_) TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity"); @@ -1300,7 +1357,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool transform_layout_fn = [this](Graph& graph_to_transform, bool& modified, const IExecutionProvider& execution_provider, const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); ORT_RETURN_IF_ERROR_SESSIONID_( layout_transformation::TransformLayoutForEP(graph_to_transform, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn)); @@ -1673,6 +1730,137 @@ static bool ModelHasFP16Inputs(const Graph& graph) { return false; } +#if !defined(ORT_MINIMAL_BUILD) +[[maybe_unused]] static std::string ModelWeightDataType(const Graph& graph) { + std::string data_type_list; + + for (int i = 0; i < ONNX_NAMESPACE::TensorProto_DataType_DataType_ARRAYSIZE; ++i) { + if (graph.weight_data_type_freq_[i] > 0) { + if (!data_type_list.empty()) { + data_type_list += ", "; + } + data_type_list += TensorProto_DataType_Name(i); + data_type_list += ": "; + data_type_list += std::to_string(graph.weight_data_type_freq_[i]); + } + } + + return data_type_list; +} +#endif + +#ifdef _WIN32 +[[maybe_unused]] static std::size_t GetStringHash(const std::string& string, std::size_t prev_hash) { + std::size_t hash = 0; + std::hash hasher; + const uint64_t golden_ratio = 0x9e3779b9; + + /* + Combine the current string's hash into the final hash using a mixing function. + The mixing function ensures that the order of the string affects the final hash + and reduces the likelihood of hash collisions. + Here's the breakdown: + - hasher(string): The hash of the current string being processed. + - 0x9e3779b9: A constant derived from the golden ratio, often used in hash functions + to improve distribution and reduce collisions. + - (prev_hash << 6) + (prev_hash >> 2): A bitwise operation that shifts the bits to + introduce additional entropy. + */ + + hash = hasher(string) + golden_ratio + (prev_hash << 6) + (prev_hash >> 2); + return hash; +} +#endif + +#ifdef _WIN32 +[[maybe_unused]] static std::string ComputeModelGraphHash(const Graph& graph) { + // Skip hashing if the graph contains an EPContext node. + const auto& nodes = graph.Nodes(); + for (const auto& node : nodes) { + if (node.OpType() == "EPContext") { + return "0"; + } + } + + // Graph Hash + std::size_t final_hash = 0; + const std::size_t node_hash_count = TelemetrySampleCount; + const std::size_t total_nodes = graph.NumberOfNodes(); + const std::size_t node_step = (total_nodes > node_hash_count) ? (total_nodes / node_hash_count) : 1; + + size_t index = 0; + for (const auto& node : nodes) { + if (index % node_step != 0) { + ++index; + continue; + } + + // Combine the hash of each node component using GetStringHash + final_hash = GetStringHash(node.Name(), final_hash); + final_hash = GetStringHash(node.OpType(), final_hash); + final_hash = GetStringHash(node.Domain(), final_hash); + + // Hash the input definitions + for (const auto& input : node.InputDefs()) { + if (input->Exists()) { + final_hash = GetStringHash(input->Name(), final_hash); + } + } + + // Hash the output definitions + for (const auto& output : node.OutputDefs()) { + if (output->Exists()) { + final_hash = GetStringHash(output->Name(), final_hash); + } + } + + ++index; + } + + // Convert the final hash to a string + std::ostringstream hash_stream; + hash_stream << std::hex << final_hash; + return hash_stream.str(); +} +#endif + +#ifdef _WIN32 +[[maybe_unused]] static std::string ComputeModelWeightHash(const InitializedTensorSet& initializers) { + std::size_t final_hash = 0; + const std::size_t node_hash_count = TelemetrySampleCount; + + // Weight Hash + const size_t total_initializers = initializers.size(); + const size_t initializer_step = (total_initializers > node_hash_count) ? (total_initializers / node_hash_count) : 1; + + size_t index = 0; + for (const auto& [tensor_name, tensor] : initializers) { + if (index % initializer_step != 0) { + ++index; + continue; + } + + // Combine the hash of each tensor component using GetStringHash + final_hash = GetStringHash(tensor_name, final_hash); + + if (tensor->has_data_type()) { + final_hash = GetStringHash(std::to_string(tensor->data_type()), final_hash); + } + + if (tensor->has_raw_data()) { + final_hash = GetStringHash(tensor->raw_data(), final_hash); + } + + ++index; + } + + // Convert the final hash to a string + std::ostringstream hash_stream; + hash_stream << std::hex << final_hash; + return hash_stream.str(); +} +#endif + common::Status InferenceSession::AddPrePackedWeightsContainer(PrepackedWeightsContainer* prepacked_weights_container) { if (prepacked_weights_container == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -1716,7 +1904,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, [](Graph& graph_to_transform, bool& modified, const IExecutionProvider& execution_provider, const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); return layout_transformation::TransformLayoutForEP(graph_to_transform, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn); }; @@ -1749,8 +1937,7 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, Status ApplyOrtFormatModelRuntimeOptimizations( onnxruntime::Graph& graph, const logging::Logger& logger, const SessionOptions& session_options, const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep, - concurrency::ThreadPool* intra_op_thread_pool, - std::unordered_map>* p_buffered_tensors) { + concurrency::ThreadPool* intra_op_thread_pool) { bool modified = false; for (int level = static_cast(TransformerLevel::Level2); @@ -1758,7 +1945,7 @@ Status ApplyOrtFormatModelRuntimeOptimizations( ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, logger, - optimizers_to_disable, intra_op_thread_pool, p_buffered_tensors); + optimizers_to_disable, intra_op_thread_pool); for (const auto& transformer : transformers) { ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); @@ -1838,10 +2025,11 @@ common::Status InferenceSession::Initialize() { // Verify that there are no external initializers in the graph if external data is disabled. onnxruntime::Graph& graph = model_->MainGraph(); + #ifdef DISABLE_EXTERNAL_INITIALIZERS const InitializedTensorSet& initializers = graph.GetAllInitializedTensors(); for (const auto& it : initializers) { - if (utils::HasExternalData(*it.second)) { + if (utils::HasExternalData(*it.second) && !utils::HasExternalDataInMemory(*it.second)) { return common::Status(common::ONNXRUNTIME, common::FAIL, "Initializer tensors with external data is not allowed."); } @@ -1884,6 +2072,35 @@ common::Status InferenceSession::Initialize() { TraceLoggingWriteStart(session_activity, "OrtInferenceSessionActivity"); session_activity_started_ = true; #endif + // Generate and cache telemetry data for the model when caller framework is WinAI + std::string model_weight_type, model_graph_hash, model_weight_hash; +#ifdef ORT_CALLER_FRAMEWORK + if (std::string_view(ORT_CALLER_FRAMEWORK) == "WinAI") { + InitializedTensorSet initializers = graph.GetAllInitializedTensors(); +#if !defined(ORT_MINIMAL_BUILD) + model_weight_type = ModelWeightDataType(graph); + SetWeightDataType(model_weight_type); +#endif +#ifdef _WIN32 + // Check if model metadata contains a "model_hash" field + const auto& metadata = model_->MetaData(); + auto model_hash_it = metadata.find("model_hash"); + + if (model_hash_it != metadata.end()) { + // Use the model_hash from metadata + model_graph_hash = model_hash_it->second; + model_weight_hash = model_hash_it->second; + } else { + // Compute hashes + model_graph_hash = ComputeModelGraphHash(graph); + model_weight_hash = (model_graph_hash == "0") ? "0" : ComputeModelWeightHash(initializers); + } + + SetGraphHash(model_graph_hash); + SetWeightHash(model_weight_hash); +#endif + } +#endif // now that we have all the execution providers, create the session state session_state_ = std::make_unique( @@ -2193,8 +2410,7 @@ common::Status InferenceSession::Initialize() { const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); ORT_RETURN_IF_ERROR_SESSIONID_( ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, - cpu_ep, GetIntraOpThreadPoolToUse(), - session_state_->GetMutableBufferedTensors())); + cpu_ep, GetIntraOpThreadPoolToUse())); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } @@ -2267,14 +2483,17 @@ common::Status InferenceSession::Initialize() { session_state_->PruneRemovableAttributes(); // and log telemetry + std::filesystem::path model_path = graph.ModelPath(); + std::string model_file_name = model_path.filename().string(); bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); env.GetTelemetryProvider().LogSessionCreation( session_id_, model_->IrVersion(), model_->ProducerName(), model_->ProducerVersion(), model_->Domain(), - graph.DomainToVersionMap(), graph.Name(), model_->MetaData(), - telemetry_.event_name_, execution_providers_.GetIds(), model_has_fp16_inputs, false); + graph.DomainToVersionMap(), model_file_name, graph.Name(), model_weight_type, model_graph_hash, model_weight_hash, + model_->MetaData(), telemetry_.event_name_, execution_providers_.GetIds(), model_has_fp16_inputs, false); LOGS(*session_logger_, INFO) << "Session successfully initialized."; } + ORT_CATCH(const NotImplementedException& ex) { ORT_HANDLE_EXCEPTION([&]() { status = ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Exception during initialization: ", ex.what()); @@ -2460,6 +2679,7 @@ common::Status InferenceSession::ValidateInputsOutputs(gsl::span(); if (expected_type->IsSparseTensorType()) { auto expected_element_type = expected_type->AsSparseTensorType()->GetElementType(); @@ -2670,7 +2890,7 @@ Status InferenceSession::Run(const RunOptions& run_options, gsl::span feed_names, gsl::span feeds, gsl::span output_names, std::vector* p_fetches, const std::vector* p_fetches_device_info) { - TimePoint tp; + TimePoint tp = std::chrono::high_resolution_clock::now(); if (session_profiler_.IsEnabled()) { tp = session_profiler_.Start(); } @@ -2838,18 +3058,40 @@ Status InferenceSession::Run(const RunOptions& run_options, } // keep track of telemetry - ++telemetry_.total_runs_since_last_; - telemetry_.total_run_duration_since_last_ += TimeDiffMicroSeconds(tp); + int64_t batch_size = 1; + for (const auto& feed : feeds) { + if (!feed.IsTensor()) { + continue; + } + + const Tensor& tensor = feed.Get(); + const TensorShape& shape = tensor.Shape(); + if (shape.NumDimensions() > 0) { + batch_size = shape[0]; // Extract batch size + } + // Exit the loop after finding the first tensor since subsequent feeds will have the same batch size. + break; + } // time to send telemetry? - if (TimeDiffMicroSeconds(telemetry_.time_sent_last_) > Telemetry::kDurationBetweenSending) { - // send the telemetry - env.GetTelemetryProvider().LogRuntimePerf(session_id_, telemetry_.total_runs_since_last_, - telemetry_.total_run_duration_since_last_); - // reset counters - telemetry_.time_sent_last_ = std::chrono::high_resolution_clock::now(); - telemetry_.total_runs_since_last_ = 0; - telemetry_.total_run_duration_since_last_ = 0; + { + // Adding lock_guard here to ensure that telemetry updates are thread-safe. + std::lock_guard telemetry_lock(telemetry_mutex_); + ++telemetry_.total_runs_since_last_; + telemetry_.total_run_duration_since_last_ += TimeDiffMicroSeconds(tp); + telemetry_.duration_per_batch_size_[batch_size] += TimeDiffMicroSeconds(tp); + + if (TimeDiffMicroSeconds(telemetry_.time_sent_last_) > Telemetry::kDurationBetweenSending) { + // send the telemetry + env.GetTelemetryProvider().LogRuntimePerf(session_id_, telemetry_.total_runs_since_last_, + telemetry_.total_run_duration_since_last_, + telemetry_.duration_per_batch_size_); + // reset counters + telemetry_.time_sent_last_ = std::chrono::high_resolution_clock::now(); + telemetry_.total_runs_since_last_ = 0; + telemetry_.total_run_duration_since_last_ = 0; + telemetry_.duration_per_batch_size_.clear(); + } } // log evaluation stop to trace logging provider @@ -3228,19 +3470,22 @@ common::Status InferenceSession::ValidateAndParseShrinkArenaString(const std::st } // Shrink if it is an arena based allocator - auto alloc = session_state_->GetAllocator(OrtDevice(device_type, memory_type, device_id)); - - if (alloc == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Did not find an arena based allocator registered for device-id ", - " combination in the memory arena shrink list: ", device_id_pair); + // Iterate through the registered allocators as we could have multiple allocators for the device+type + // if they differ by vendor_id. + for (const auto& [device, allocator_ptr] : session_state_->GetAllocators()) { + if (device.Type() == device_type && device.MemType() == memory_type && device.Id() == device_id) { + if (allocator_ptr->Info().alloc_type == OrtAllocatorType::OrtArenaAllocator) { + arenas_to_shrink.push_back(allocator_ptr); + break; + } + } } - if (alloc->Info().alloc_type != OrtAllocatorType::OrtArenaAllocator) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The registered allocator for device-id ", - " combination is not an arena based allocator: ", device_id_pair); + if (arenas_to_shrink.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Did not find an arena based allocator registered for device-id ", + "combination in the memory arena shrink list: ", device_id_pair); } - - arenas_to_shrink.push_back(std::move(alloc)); } return Status::OK(); @@ -3397,8 +3642,7 @@ common::Status InferenceSession::AddPredefinedTransformers( if (use_full_build_optimizations) { return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, logger, optimizers_to_disable_, - GetIntraOpThreadPoolToUse(), - session_state_->GetMutableBufferedTensors()); + GetIntraOpThreadPoolToUse()); } else { const auto sat_context = minimal_build_optimization_handling == @@ -3409,8 +3653,7 @@ common::Status InferenceSession::AddPredefinedTransformers( return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, logger, optimizers_to_disable_, - GetIntraOpThreadPoolToUse(), - session_state_->GetMutableBufferedTensors()); + GetIntraOpThreadPoolToUse()); } }(); @@ -3462,32 +3705,6 @@ IOBinding* SessionIOBinding::Get() { } #ifdef _WIN32 - -void InferenceSession::UnregisterEtwCallbacks() { - if (!callback_etw_sink_key_.empty()) { - logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_etw_sink_key_); - } - WindowsTelemetry::UnregisterInternalCallback(callback_etw_provider_key_); - { - std::lock_guard lock(active_sessions_mutex_); - active_sessions_.erase(session_id_); - } -} - -void InferenceSession::EtwProviderCallbackLogAllSessions(LPCGUID /* SourceId */, - ULONG IsEnabled, - UCHAR /* Level */, - ULONGLONG MatchAnyKeyword, - ULONGLONG /* MatchAllKeyword */, - PEVENT_FILTER_DESCRIPTOR /* FilterData */, - PVOID /* CallbackContext */) { - // Check if this callback is for capturing state - if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && - ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { - InferenceSession::LogAllSessions(); - } -} - void InferenceSession::LogAllSessions() { const Env& env = Env::Default(); @@ -3501,45 +3718,22 @@ void InferenceSession::LogAllSessions() { auto model = session->model_; if (nullptr != model) { - const onnxruntime::Graph& graph = model->MainGraph(); - const bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); + onnxruntime::Graph& graph = model->MainGraph(); + std::filesystem::path model_path = graph.ModelPath(); + std::string model_file_name = model_path.filename().string(); + bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); + std::string model_weight_type = session->GetWeightDataType(); + std::string model_graph_hash = session->GetGraphHash(); + std::string model_weight_hash = session->GetWeightHash(); env.GetTelemetryProvider().LogSessionCreation( session->session_id_, model->IrVersion(), model->ProducerName(), model->ProducerVersion(), model->Domain(), - graph.DomainToVersionMap(), graph.Name(), model->MetaData(), - session->telemetry_.event_name_, session->execution_providers_.GetIds(), model_has_fp16_inputs, true); + graph.DomainToVersionMap(), model_file_name, graph.Name(), model_weight_type, model_graph_hash, model_weight_hash, + model->MetaData(), session->telemetry_.event_name_, session->execution_providers_.GetIds(), model_has_fp16_inputs, true); } InferenceSession::TraceSessionOptions(session->session_options_, true, *session->session_logger_); } } - -void InferenceSession::EtwProviderSinkControlCallback(logging::EtwRegistrationManager& etwRegistrationManager, - LPCGUID, ULONG IsEnabled, UCHAR, ULONGLONG MatchAnyKeyword, - ULONGLONG, PEVENT_FILTER_DESCRIPTOR, PVOID) { - if (logging_manager_ != nullptr) { - auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); - - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0 && - IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { - LOGS(*session_logger_, VERBOSE) << "Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; - logging_manager_->AddSinkOfType( - onnxruntime::logging::SinkType::EtwSink, - []() -> std::unique_ptr { return std::make_unique(); }, - ortETWSeverity); - onnxruntime::logging::LoggingManager::GetDefaultInstance()->AddSinkOfType( - onnxruntime::logging::SinkType::EtwSink, - []() -> std::unique_ptr { return std::make_unique(); }, - ortETWSeverity); - LOGS(*session_logger_, INFO) << "Done Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; - } - if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { - LOGS(*session_logger_, INFO) << "Removing ETW Sink from logger"; - logging_manager_->RemoveSink(onnxruntime::logging::SinkType::EtwSink); - LOGS(*session_logger_, VERBOSE) << "Done Removing ETW Sink from logger"; - } - } -} - #endif } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 244fbac1bd9a8..7670197455a9d 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -134,6 +134,12 @@ class InferenceSession { }; using InputOutputDefMetaMap = InlinedHashMap; + static std::map active_sessions_; +#ifdef _WIN32 + static std::mutex active_sessions_mutex_; // Protects access to active_sessions_ + static onnxruntime::WindowsTelemetry::EtwInternalCallback callback_ML_ORT_provider_; + onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_; +#endif public: #if !defined(ORT_MINIMAL_BUILD) @@ -600,6 +606,30 @@ class InferenceSession { const Model& GetModel() const; const Environment& GetEnvironment() const; + void SetWeightDataType(const std::string& type) { + weight_data_type_ = type; + } + + const std::string& GetWeightDataType() const { + return weight_data_type_; + } + + void SetGraphHash(const std::string& hash) { + graph_hash_ = hash; + } + + const std::string& GetGraphHash() const { + return graph_hash_; + } + + void SetWeightHash(const std::string& hash) { + weight_hash_ = hash; + } + + const std::string& GetWeightHash() const { + return weight_hash_; + } + protected: #if !defined(ORT_MINIMAL_BUILD) @@ -672,6 +702,17 @@ class InferenceSession { // The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx PathString model_location_; + // Input, Output and Weight tensor data types + std::string input_data_type_; + std::string output_data_type_; + std::string weight_data_type_; + + // Graph hash of the model + std::string graph_hash_; + + // Weight hash of the model + std::string weight_hash_; + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession); void SetLoggingManager(const SessionOptions& session_options, @@ -758,6 +799,10 @@ class InferenceSession { */ void ShrinkMemoryArenas(gsl::span arenas_to_shrink); +#ifdef _WIN32 + static void LogAllSessions(); +#endif + #if !defined(ORT_MINIMAL_BUILD) virtual common::Status AddPredefinedTransformers( GraphTransformerManager& transformer_manager, @@ -873,15 +918,18 @@ class InferenceSession { struct Telemetry { Telemetry() : time_sent_last_() {} - uint32_t total_runs_since_last_ = 0; // the total number of Run() calls since the last report - long long total_run_duration_since_last_ = 0; // the total duration (us) of Run() calls since the last report - std::string event_name_; // where the model is loaded from: ["model_loading_uri", "model_loading_proto", "model_loading_istream"] + uint32_t total_runs_since_last_ = 0; // the total number of Run() calls since the last report + long long total_run_duration_since_last_ = 0; // the total duration (us) of Run() calls since the last report + std::string event_name_; // where the model is loaded from: ["model_loading_uri", "model_loading_proto", "model_loading_istream"] + std::unordered_map duration_per_batch_size_; // the duration (us) of Run() calls per batch size since the last report TimePoint time_sent_last_; // the TimePoint of the last report // Event Rate per provider < 20 peak events per second constexpr static long long kDurationBetweenSending = 1000 * 1000 * 60 * 10; // duration in (us). send a report every 10 mins } telemetry_; + mutable std::mutex telemetry_mutex_; // to ensure thread-safe access to telemetry data + #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT bool session_activity_started_ = false; TraceLoggingActivity session_activity; @@ -974,57 +1022,6 @@ class InferenceSession { // Enable nodestats collection std::optional node_stats_recorder_; #endif - -#ifdef _WIN32 - static std::mutex active_sessions_mutex_; // Protects access to active_sessions_ - static std::map active_sessions_; - // Single callback for all sessions. Registers when the first session comes up - // and unregister when the last session goes away. - static const std::string callback_etw_provider_key_; - std::string callback_etw_sink_key_; // Session Start Stop - - void UnregisterEtwCallbacks(); - - struct AutoEtwUnregistrar { - std::function unregister_callback; - explicit AutoEtwUnregistrar(std::function func) - : unregister_callback(std::move(func)) {} - ~AutoEtwUnregistrar() { - if (unregister_callback) { - unregister_callback(); - } - } - }; - - // Automatically cleans up all outstanding registrations - // in case session loading fails and ETW callbacks are already registered. - // We want callbacks to stop before any other members of the object are - // destroyed. - std::optional auto_etw_unregistrar_; - - // This callback is registered globally for all sessions - // It is unregistered when the last session goes away. - static void EtwProviderCallbackLogAllSessions(LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext); - - static void LogAllSessions(); - - // This callback is registered per session - void EtwProviderSinkControlCallback(logging::EtwRegistrationManager& etwRegistrationManager, - LPCGUID /*SourceId */, - ULONG IsEnabled, - UCHAR /* Level */, - ULONGLONG MatchAnyKeyword, - ULONGLONG /* MatchAllKeyword */, - PEVENT_FILTER_DESCRIPTOR /* FilterData */, - PVOID /* CallbackContext */); - -#endif }; struct SessionIOBinding { diff --git a/onnxruntime/core/session/model_editor_c_api.cc b/onnxruntime/core/session/model_editor_c_api.cc index 7d5d45b7b531b..4f74e258cc943 100644 --- a/onnxruntime/core/session/model_editor_c_api.cc +++ b/onnxruntime/core/session/model_editor_c_api.cc @@ -41,11 +41,11 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateValueInfo, _In_ const char* name, _ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tensor_type_info cannot be null"); } - auto vi = std::make_unique(); + auto vi = std::make_unique(); vi->name = name; vi->type_info = type_info->Clone(); - *value_info = vi.release(); + *value_info = vi.release()->ToExternal(); return nullptr; API_IMPL_END @@ -58,7 +58,7 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateNode, const char* operator_name, co _In_reads_(attribs_len) _Inout_opt_ OrtOpAttr** attributes, _In_opt_ size_t attribs_len, _Outptr_ OrtNode** node) { API_IMPL_BEGIN - auto n = std::make_unique(); + auto n = std::make_unique(); n->operator_name = operator_name; n->domain_name = domain_name == kOnnxDomainAlias ? kOnnxDomain : domain_name; n->node_name = node_name; @@ -83,35 +83,48 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateNode, const char* operator_name, co } } - *node = n.release(); + *node = n.release()->ToExternal(); return nullptr; API_IMPL_END } ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateGraph, _Outptr_ OrtGraph** graph) { API_IMPL_BEGIN - auto g = std::make_unique(); + auto g = std::make_unique(); // do some reserves to reduce reallocation. if we had a hint about sizes upfront that would be optimal g->initializers.reserve(32); g->external_initializers.reserve(32); g->nodes.reserve(64); - *graph = g.release(); + *graph = g.release()->ToExternal(); return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphInputs, _In_ OrtGraph* graph, +ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphInputs, _In_ OrtGraph* ort_graph, _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len) { API_IMPL_BEGIN + onnxruntime::ModelEditorGraph* graph = onnxruntime::ModelEditorGraph::ToInternal(ort_graph); + + if (graph == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Invalid OrtGraph variant for use in the OrtModelEditorApi"); + } + graph->inputs.clear(); for (size_t i = 0; i < inputs_len; ++i) { if (inputs[i] == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "inputs cannot contain null entries"); } - graph->inputs.push_back(std::unique_ptr(inputs[i])); // take ownership + onnxruntime::ModelEditorValueInfo* input = onnxruntime::ModelEditorValueInfo::ToInternal(inputs[i]); + if (input == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Invalid OrtValueInfo variant for use in the OrtModelEditorApi"); + } + + graph->inputs.push_back(std::unique_ptr(input)); // take ownership inputs[i] = nullptr; } @@ -119,16 +132,29 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphInputs, _In_ OrtGraph* graph, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphOutputs, _In_ OrtGraph* graph, +ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphOutputs, _In_ OrtGraph* ort_graph, _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len) { API_IMPL_BEGIN + onnxruntime::ModelEditorGraph* graph = onnxruntime::ModelEditorGraph::ToInternal(ort_graph); + + if (graph == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Invalid OrtGraph variant for use in the OrtModelEditorApi"); + } + graph->outputs.clear(); for (size_t i = 0; i < outputs_len; ++i) { if (outputs[i] == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "outputs cannot contain null entries"); } - graph->outputs.push_back(std::unique_ptr(outputs[i])); // take ownership + onnxruntime::ModelEditorValueInfo* output = onnxruntime::ModelEditorValueInfo::ToInternal(outputs[i]); + if (output == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Invalid OrtValueInfo variant for use in the OrtModelEditorApi"); + } + + graph->outputs.push_back(std::unique_ptr(output)); // take ownership outputs[i] = nullptr; } @@ -136,9 +162,16 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::SetGraphOutputs, _In_ OrtGraph* graph, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, +ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddInitializerToGraph, _In_ OrtGraph* ort_graph, _In_ const char* name, _Inout_ OrtValue* tensor, bool data_is_external) { API_IMPL_BEGIN + onnxruntime::ModelEditorGraph* graph = onnxruntime::ModelEditorGraph::ToInternal(ort_graph); + + if (graph == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Invalid OrtGraph variant for use in the OrtModelEditorApi"); + } + if (!tensor->IsTensor()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Only Tensor is currently supported."); } @@ -172,9 +205,24 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddInitializerToGraph, _In_ OrtGraph* gra API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddNodeToGraph, _In_ OrtGraph* graph, _Inout_ OrtNode* node) { +ORT_API_STATUS_IMPL(OrtModelEditorAPI::AddNodeToGraph, _In_ OrtGraph* ort_graph, _Inout_ OrtNode* ort_node) { API_IMPL_BEGIN - graph->nodes.push_back(std::unique_ptr(node)); // take ownership + onnxruntime::ModelEditorGraph* graph = onnxruntime::ModelEditorGraph::ToInternal(ort_graph); + + if (graph == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Invalid OrtGraph variant for use in the OrtModelEditorApi"); + } + + onnxruntime::ModelEditorNode* node = onnxruntime::ModelEditorNode::ToInternal(ort_node); + + if (node == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Invalid OrtNode variant for use in the OrtModelEditorApi"); + } + + node->id = graph->nodes.size(); + graph->nodes.push_back(std::unique_ptr(node)); // take ownership return nullptr; API_IMPL_END } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 9e2705e3cfb57..2a6dc31344f6b 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1,10 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include #include #include +#include #include #include "core/common/common.h" @@ -13,6 +15,7 @@ #include "core/common/safeint.h" #include "core/common/status.h" #include "core/common/string_helper.h" +#include "core/framework/abi_pointer_array.h" #include "core/framework/allocator.h" #include "core/framework/callback.h" #include "core/framework/data_types.h" @@ -2383,19 +2386,485 @@ ORT_API(void, OrtApis::ReleaseModel, _Frees_ptr_opt_ OrtModel* model) { delete model; } +ORT_API_STATUS_IMPL(OrtApis::CreateArrayOfConstObjects, _In_ OrtTypeTag elem_type, _In_ size_t initial_size, + _In_ const void* initial_value, _Outptr_ OrtArrayOfConstObjects** out) { + API_IMPL_BEGIN + auto array = std::make_unique(elem_type, initial_size, initial_value); + *out = array.release(); + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtApis::ReleaseArrayOfConstObjects, _Frees_ptr_opt_ OrtArrayOfConstObjects* array) { + delete array; +} + +ORT_API_STATUS_IMPL(OrtApis::ArrayOfConstObjects_GetObjectType, _In_ const OrtArrayOfConstObjects* array, + _Out_ OrtTypeTag* type_tag) { + API_IMPL_BEGIN + if (type_tag == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'type_tag' argument is NULL"); + } + + *type_tag = array->object_type; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ArrayOfConstObjects_GetData, _In_ const OrtArrayOfConstObjects* array, + _Outptr_ const void* const** data) { + API_IMPL_BEGIN + if (data == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'data' argument is NULL"); + } + + *data = array->storage.data(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ArrayOfConstObjects_GetMutableData, _In_ OrtArrayOfConstObjects* array, + _Outptr_ const void*** data) { + API_IMPL_BEGIN + if (data == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'data' argument is NULL"); + } + + *data = array->storage.data(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ArrayOfConstObjects_GetSize, _In_ const OrtArrayOfConstObjects* array, + _Out_ size_t* size) { + API_IMPL_BEGIN + if (size == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'size' argument is NULL"); + } + + *size = array->storage.size(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ArrayOfConstObjects_GetElementAt, _In_ const OrtArrayOfConstObjects* array, + _In_ size_t index, _Outptr_ const void** out) { + API_IMPL_BEGIN + if (out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'out' argument is NULL"); + } + + if (index >= array->storage.size()) { + std::ostringstream oss; + oss << "'index' value (" << index << ") is out of bounds for array of size " << array->storage.size(); + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); + } + + *out = array->storage[index]; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ArrayOfConstObjects_SetElementAt, _In_ OrtArrayOfConstObjects* array, _In_ size_t index, + _In_ const void* element) { + API_IMPL_BEGIN + if (index >= array->storage.size()) { + std::ostringstream oss; + oss << "'index' value (" << index << ") is out of bounds for array of size " << array->storage.size(); + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); + } + + array->storage[index] = element; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ArrayOfConstObjects_AppendElement, _In_ OrtArrayOfConstObjects* array, + _In_ const void* element) { + API_IMPL_BEGIN + array->storage.push_back(element); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name) { API_IMPL_BEGIN - *name = value_info->name.c_str(); + *name = value_info->GetName().c_str(); return nullptr; API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info) { API_IMPL_BEGIN + *type_info = nullptr; + + const OrtTypeInfo* type_info_internal = value_info->GetTypeInfo(); + if (type_info_internal == nullptr) { + std::ostringstream oss; + oss << "OrtValueInfo '" << value_info->GetName() << "' does not have valid type information"; + return OrtApis::CreateStatus(ORT_FAIL, oss.str().c_str()); + } + + *type_info = type_info_internal; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ValueInfo_GetValueProducer, _In_ const OrtValueInfo* value_info, + _Outptr_ const OrtNode** producer_node, _Out_opt_ size_t* producer_output_index) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + OrtValueInfo::ProducerInfo producer_info; + ORT_API_RETURN_IF_STATUS_NOT_OK(value_info->GetProducerInfo(producer_info)); + + *producer_node = producer_info.node; + if (producer_output_index != nullptr) { + *producer_output_index = producer_info.output_index; + } + + return nullptr; +#else + ORT_UNUSED_PARAMETER(value_info); + ORT_UNUSED_PARAMETER(producer_node); + ORT_UNUSED_PARAMETER(producer_output_index); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "ValueInfo_GetValueProducer() is not supported in this build."); +#endif + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ValueInfo_GetValueNumConsumers, _In_ const OrtValueInfo* value_info, + _Out_ size_t* num_consumers) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + ORT_API_RETURN_IF_STATUS_NOT_OK(value_info->GetNumConsumerInfos(*num_consumers)); + return nullptr; +#else + ORT_UNUSED_PARAMETER(value_info); + ORT_UNUSED_PARAMETER(num_consumers); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "ValueInfo_GetValueNumConsumers() is not supported in this build."); +#endif + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ValueInfo_GetValueConsumers, _In_ const OrtValueInfo* value_info, + _Out_writes_all_(max_num_consumers) const OrtNode** nodes, + _Out_writes_all_(max_num_consumers) int64_t* input_indices, + _In_ size_t max_num_consumers) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + std::vector consumer_infos; + ORT_API_RETURN_IF_STATUS_NOT_OK(value_info->GetConsumerInfos(consumer_infos)); + size_t num_uses = std::min(max_num_consumers, consumer_infos.size()); + + for (size_t i = 0; i < num_uses; ++i) { + nodes[i] = consumer_infos[i].node; + input_indices[i] = consumer_infos[i].input_index; + } + + return nullptr; +#else + ORT_UNUSED_PARAMETER(value_info); + ORT_UNUSED_PARAMETER(nodes); + ORT_UNUSED_PARAMETER(input_indices); + ORT_UNUSED_PARAMETER(max_num_consumers); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "ValueInfo_GetValueConsumers() is not supported in this build."); +#endif + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ValueInfo_GetInitializerValue, _In_ const OrtValueInfo* value_info, + _Outptr_ const OrtValue** initializer_value) { + API_IMPL_BEGIN + ORT_API_RETURN_IF_STATUS_NOT_OK(value_info->GetInitializerValue(*initializer_value)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ValueInfo_IsRequiredGraphInput, _In_ const OrtValueInfo* value_info, + _Out_ bool* is_required_graph_input) { + API_IMPL_BEGIN + if (is_required_graph_input == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'is_required_graph_input' argument is NULL"); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(value_info->IsRequiredGraphInput(*is_required_graph_input)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ValueInfo_IsOptionalGraphInput, _In_ const OrtValueInfo* value_info, + _Out_ bool* is_optional_graph_input) { + API_IMPL_BEGIN + if (is_optional_graph_input == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'is_optional_graph_input' argument is NULL"); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(value_info->IsOptionalGraphInput(*is_optional_graph_input)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ValueInfo_IsGraphOutput, _In_ const OrtValueInfo* value_info, + _Out_ bool* is_graph_output) { + API_IMPL_BEGIN + if (is_graph_output == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'is_graph_output' argument is NULL"); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(value_info->IsGraphOutput(*is_graph_output)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ValueInfo_IsConstantInitializer, _In_ const OrtValueInfo* value_info, + _Out_ bool* is_constant_initializer) { + API_IMPL_BEGIN + if (is_constant_initializer == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'is_constant_initializer' argument is NULL"); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(value_info->IsConstantInitializer(*is_constant_initializer)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_info, + _Out_ bool* is_outer_scope) { + API_IMPL_BEGIN + if (is_outer_scope == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'is_outer_scope' argument is NULL"); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(value_info->IsFromOuterScope(*is_outer_scope)); + return nullptr; + API_IMPL_END +} + +// +// OrtGraph +// + +ORT_API_STATUS_IMPL(OrtApis::Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name) { + API_IMPL_BEGIN + if (graph_name == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'graph_name' argument is NULL"); + } + + *graph_name = graph->GetName().c_str(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* ir_version) { + API_IMPL_BEGIN + if (ir_version == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'ir_version' argument is NULL"); + } + + *ir_version = graph->GetOnnxIRVersion(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Graph_GetInputs, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** inputs) { + API_IMPL_BEGIN + if (inputs == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'inputs' argument is NULL"); + } - *type_info = value_info->type_info.get(); + std::unique_ptr array; + ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetInputs(array)); + *inputs = array.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Graph_GetOutputs, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** outputs) { + API_IMPL_BEGIN + if (outputs == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'outputs' argument is NULL"); + } + + std::unique_ptr array; + ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetOutputs(array)); + + *outputs = array.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Graph_GetInitializers, _In_ const OrtGraph* graph, + _Outptr_ OrtArrayOfConstObjects** initializers) { + API_IMPL_BEGIN + if (initializers == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'initializers' argument is NULL"); + } + + std::unique_ptr array; + ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetInitializers(array)); + + *initializers = array.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Graph_GetNodes, const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** nodes) { + API_IMPL_BEGIN + if (nodes == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'nodes' argument is NULL"); + } + + std::unique_ptr array; + ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetNodes(array)); + + *nodes = array.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_ const OrtNode** node) { + API_IMPL_BEGIN + if (node == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'node' argument is NULL"); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(graph->GetParentNode(*node)); + return nullptr; + API_IMPL_END +} + +// +// OrtNode +// + +ORT_API_STATUS_IMPL(OrtApis::Node_GetId, _In_ const OrtNode* node, _Out_ size_t* node_id) { + API_IMPL_BEGIN + if (node_id == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'node_id' argument is NULL"); + } + + *node_id = node->GetId(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Node_GetName, _In_ const OrtNode* node, _Outptr_ const char** node_name) { + API_IMPL_BEGIN + if (node_name == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'node_name' argument is NULL"); + } + + *node_name = node->GetName().c_str(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Node_GetOperatorType, _In_ const OrtNode* node, _Outptr_ const char** operator_type) { + API_IMPL_BEGIN + if (operator_type == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'operator_type' argument is NULL"); + } + + *operator_type = node->GetOpType().c_str(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Node_GetDomain, _In_ const OrtNode* node, _Outptr_ const char** domain_name) { + API_IMPL_BEGIN + if (domain_name == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'operator_type' argument is NULL"); + } + + *domain_name = node->GetDomain().c_str(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Node_GetSinceVersion, _In_ const OrtNode* node, _Out_ int* since_version) { + API_IMPL_BEGIN + if (since_version == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'since_version' argument is NULL"); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetSinceVersion(*since_version)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Node_GetInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** inputs) { + API_IMPL_BEGIN + if (inputs == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'inputs' argument is NULL"); + } + + std::unique_ptr array; + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetInputs(array)); + + *inputs = array.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Node_GetOutputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** outputs) { + API_IMPL_BEGIN + if (outputs == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'outputs' argument is NULL"); + } + + std::unique_ptr array; + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetOutputs(array)); + + *outputs = array.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Node_GetImplicitInputs, _In_ const OrtNode* node, + _Outptr_ OrtArrayOfConstObjects** implicit_inputs) { + API_IMPL_BEGIN + if (implicit_inputs == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'implicit_inputs' argument is NULL"); + } + + std::unique_ptr array; + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetImplicitInputs(array)); + + *implicit_inputs = array.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Node_GetSubgraphs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** subgraphs) { + API_IMPL_BEGIN + if (subgraphs == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'subgraphs' argument is NULL"); + } + + std::unique_ptr array; + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetSubgraphs(array)); + + *subgraphs = array.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Node_GetParentGraph, _In_ const OrtNode* node, + _Outptr_result_maybenull_ const OrtGraph** parent_graph) { + API_IMPL_BEGIN + if (parent_graph == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'parent_graph' argument is NULL"); + } + + *parent_graph = nullptr; + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetParentGraph(*parent_graph)); return nullptr; API_IMPL_END } @@ -3029,8 +3498,48 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::GetEpApi, // End of Version 22 - DO NOT MODIFY ABOVE (see above text for more information) + &OrtApis::GetTensorSizeInBytes, &OrtApis::AllocatorGetStats, + &OrtApis::CreateMemoryInfo_V2, + + &OrtApis::CreateArrayOfConstObjects, + &OrtApis::ReleaseArrayOfConstObjects, + &OrtApis::ArrayOfConstObjects_GetObjectType, + &OrtApis::ArrayOfConstObjects_GetData, + &OrtApis::ArrayOfConstObjects_GetMutableData, + &OrtApis::ArrayOfConstObjects_GetSize, + &OrtApis::ArrayOfConstObjects_GetElementAt, + &OrtApis::ArrayOfConstObjects_SetElementAt, + &OrtApis::ArrayOfConstObjects_AppendElement, + + &OrtApis::ValueInfo_GetValueProducer, + &OrtApis::ValueInfo_GetValueNumConsumers, + &OrtApis::ValueInfo_GetValueConsumers, + &OrtApis::ValueInfo_GetInitializerValue, + &OrtApis::ValueInfo_IsRequiredGraphInput, + &OrtApis::ValueInfo_IsOptionalGraphInput, + &OrtApis::ValueInfo_IsGraphOutput, + &OrtApis::ValueInfo_IsConstantInitializer, + &OrtApis::ValueInfo_IsFromOuterScope, + &OrtApis::Graph_GetName, + &OrtApis::Graph_GetOnnxIRVersion, + &OrtApis::Graph_GetInputs, + &OrtApis::Graph_GetOutputs, + &OrtApis::Graph_GetInitializers, + &OrtApis::Graph_GetNodes, + &OrtApis::Graph_GetParentNode, + &OrtApis::Node_GetId, + &OrtApis::Node_GetName, + &OrtApis::Node_GetOperatorType, + &OrtApis::Node_GetDomain, + &OrtApis::Node_GetSinceVersion, + &OrtApis::Node_GetInputs, + &OrtApis::Node_GetOutputs, + &OrtApis::Node_GetImplicitInputs, + &OrtApis::Node_GetSubgraphs, + &OrtApis::Node_GetParentGraph, + }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index dcd1b3069bcac..32319152d7e01 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -602,4 +602,68 @@ ORT_API(const OrtEpApi*, GetEpApi); ORT_API_STATUS_IMPL(GetTensorSizeInBytes, _In_ const OrtValue* ort_value, _Out_ size_t* size); ORT_API_STATUS_IMPL(AllocatorGetStats, _In_ const OrtAllocator* ptr, _Outptr_ OrtKeyValuePairs** out); + +ORT_API_STATUS_IMPL(CreateMemoryInfo_V2, _In_ const char* name, _In_ enum OrtMemoryInfoDeviceType device_type, + _In_ uint32_t vendor_id, _In_ int16_t device_id, _In_ enum OrtDeviceMemoryType mem_type, + _In_ size_t alignment, enum OrtAllocatorType allocator_type, + _Outptr_ OrtMemoryInfo** out); + +// OrtArrayOfConstObjects +ORT_API_STATUS_IMPL(CreateArrayOfConstObjects, _In_ OrtTypeTag object_type, _In_ size_t initial_size, + _In_ const void* initial_value, _Outptr_ OrtArrayOfConstObjects** out); +ORT_API(void, ReleaseArrayOfConstObjects, _Frees_ptr_opt_ OrtArrayOfConstObjects* array); +ORT_API_STATUS_IMPL(ArrayOfConstObjects_GetObjectType, _In_ const OrtArrayOfConstObjects* array, + _Out_ OrtTypeTag* type_tag); +ORT_API_STATUS_IMPL(ArrayOfConstObjects_GetData, _In_ const OrtArrayOfConstObjects* array, + _Outptr_ const void* const** data); +ORT_API_STATUS_IMPL(ArrayOfConstObjects_GetMutableData, _In_ OrtArrayOfConstObjects* array, _Outptr_ const void*** data); +ORT_API_STATUS_IMPL(ArrayOfConstObjects_GetSize, _In_ const OrtArrayOfConstObjects* array, _Out_ size_t* size); +ORT_API_STATUS_IMPL(ArrayOfConstObjects_GetElementAt, _In_ const OrtArrayOfConstObjects* array, _In_ size_t index, + _Outptr_ const void** out); +ORT_API_STATUS_IMPL(ArrayOfConstObjects_SetElementAt, _In_ OrtArrayOfConstObjects* array, _In_ size_t index, + _In_ const void* element); +ORT_API_STATUS_IMPL(ArrayOfConstObjects_AppendElement, _In_ OrtArrayOfConstObjects* array, _In_ const void* element); + +// OrtValueInfo +ORT_API_STATUS_IMPL(ValueInfo_GetValueProducer, _In_ const OrtValueInfo* value_info, + _Outptr_ const OrtNode** producer_node, _Out_opt_ size_t* producer_output_index); +ORT_API_STATUS_IMPL(ValueInfo_GetValueNumConsumers, _In_ const OrtValueInfo* value_info, _Out_ size_t* num_consumers); +ORT_API_STATUS_IMPL(ValueInfo_GetValueConsumers, _In_ const OrtValueInfo* value_info, + _Out_writes_all_(max_num_consumers) const OrtNode** nodes, + _Out_writes_all_(max_num_consumers) int64_t* input_indices, + _In_ size_t max_num_consumers); +ORT_API_STATUS_IMPL(ValueInfo_GetInitializerValue, _In_ const OrtValueInfo* value_info, + _Outptr_ const OrtValue** initializer_value); +ORT_API_STATUS_IMPL(ValueInfo_IsRequiredGraphInput, _In_ const OrtValueInfo* value_info, + _Out_ bool* is_required_graph_input); +ORT_API_STATUS_IMPL(ValueInfo_IsOptionalGraphInput, _In_ const OrtValueInfo* value_info, + _Out_ bool* is_optional_graph_input); +ORT_API_STATUS_IMPL(ValueInfo_IsGraphOutput, _In_ const OrtValueInfo* value_info, _Out_ bool* is_graph_output); +ORT_API_STATUS_IMPL(ValueInfo_IsConstantInitializer, _In_ const OrtValueInfo* value_info, + _Out_ bool* is_constant_initializer); +ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_info, + _Out_ bool* is_from_outer_scope); + +// OrtGraph +ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name); +ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); +ORT_API_STATUS_IMPL(Graph_GetInputs, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** inputs); +ORT_API_STATUS_IMPL(Graph_GetOutputs, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** outputs); +ORT_API_STATUS_IMPL(Graph_GetInitializers, _In_ const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** initializers); +ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph, _Outptr_ OrtArrayOfConstObjects** nodes); +ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); + +// OrtNode +ORT_API_STATUS_IMPL(Node_GetId, _In_ const OrtNode* node, _Out_ size_t* node_id); +ORT_API_STATUS_IMPL(Node_GetName, _In_ const OrtNode* node, _Outptr_ const char** node_name); +ORT_API_STATUS_IMPL(Node_GetOperatorType, _In_ const OrtNode* node, _Outptr_ const char** operator_type); +ORT_API_STATUS_IMPL(Node_GetDomain, _In_ const OrtNode* node, _Outptr_ const char** domain_name); +ORT_API_STATUS_IMPL(Node_GetSinceVersion, _In_ const OrtNode* node, _Out_ int* since_version); +ORT_API_STATUS_IMPL(Node_GetInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** inputs); +ORT_API_STATUS_IMPL(Node_GetOutputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** outputs); +ORT_API_STATUS_IMPL(Node_GetImplicitInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** implicit_inputs); +ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** subgraphs); +ORT_API_STATUS_IMPL(Node_GetParentGraph, _In_ const OrtNode* node, + _Outptr_result_maybenull_ const OrtGraph** parent_graph); + } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 197fb4320e6bf..422668ef1a27f 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -47,6 +47,7 @@ #include "core/session/provider_bridge_ort.h" #include "core/util/math.h" #include "onnx/shape_inference/implementation.h" +#include "core/optimizer/initializer.h" #ifdef ENABLE_TRAINING #ifdef ENABLE_TRAINING_TORCH_INTEROP @@ -462,14 +463,13 @@ struct ProviderHostImpl : ProviderHost { } void logging__EtwRegistrationManager__RegisterInternalCallback( logging::EtwRegistrationManager* p, - const std::string& cb_key, - logging::EtwRegistrationManager_EtwInternalCallback callback) override { - p->RegisterInternalCallback(cb_key, std::move(callback)); + const logging::EtwRegistrationManager_EtwInternalCallback& callback) override { + p->RegisterInternalCallback(callback); } void logging__EtwRegistrationManager__UnregisterInternalCallback( logging::EtwRegistrationManager* p, - const std::string& cb_key) override { - p->UnregisterInternalCallback(cb_key); + const logging::EtwRegistrationManager_EtwInternalCallback& callback) override { + p->UnregisterInternalCallback(callback); } #endif // defined(_WIN32) @@ -1238,6 +1238,15 @@ struct ProviderHostImpl : ProviderHost { execution_provider_name, drop_constant_initializers); } + Status Utils__GetTensorProtoWithDataIfInMemory( + const ONNX_NAMESPACE::TensorProto& tensor_proto, std::unique_ptr& result) override { + return onnxruntime::utils::GetTensorProtoWithDataIfInMemory(tensor_proto, result); + } + + bool Utils__HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto) override { + return onnxruntime::utils::HasExternalDataInMemory(ten_proto); + } + // Model (wrapped) std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList* local_registries, @@ -1273,6 +1282,8 @@ struct ProviderHostImpl : ProviderHost { Status Graph__Resolve(Graph* p) override { return p->Resolve(); } void Graph__AddInitializedTensor(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor) override { p->AddInitializedTensor(tensor); } + Status Graph__AddInitializedOrtValue(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor, + const OrtValue& value) override { return p->AddInitializedOrtValue(tensor, value); } Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, const NodeAttributes* attributes, const std::string& domain) override { return p->AddNode(name, op_type, description, input_args, output_args, attributes, domain); } @@ -1423,6 +1434,75 @@ struct ProviderHostImpl : ProviderHost { } bool ConstGraphNodes__empty(const ConstGraphNodes* p) noexcept override { return p->empty(); } + NodeArg& GraphUtils__AddInitializerWithExternalData(Graph& graph, + const ONNX_NAMESPACE::TensorProto& new_initializer) override { + return graph_utils::AddInitializerWithExternalData(graph, new_initializer); + } + + void GraphUtils__MakeInitializerCopyIfNotExist(const Graph& src_graph, Graph& dst_graph, + const std::string& name, bool load_in_memory) override { + graph_utils::MakeInitializerCopyIfNotExist(src_graph, dst_graph, name, load_in_memory); + } + + // Initializer (wrapped) + Initializer* Initializer__constructor(ONNX_NAMESPACE::TensorProto_DataType data_type, + std::string_view name, + gsl::span dims) override { + return new Initializer(data_type, name, dims); + } + + Initializer* Initializer__constructor(const Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& model_path, + bool check_outer_scope) override { + return new Initializer(graph, tensor_proto, model_path, check_outer_scope); + } + void Initializer__destructor(Initializer* p) override { delete p; } + void Initializer__ToProto(const Initializer& initializer, + ONNX_NAMESPACE::TensorProto& tensor_proto) override { + initializer.ToProto(tensor_proto); + } + void Initializer__ToProtoWithOrtValue(const Initializer& initializer, + ONNX_NAMESPACE::TensorProto& tensor_proto, OrtValue& ort_value) override { + initializer.ToProtoWithOrtValue(tensor_proto, ort_value); + } + int Initializer__data_type(const Initializer& initializer) override { + return initializer.data_type(); + } + const std::string& Initializer__name(const Initializer& initializer) override { + return initializer.name(); + } + gsl::span Initializer__dims(const Initializer& initializer) override { + return initializer.dims(); + } + size_t Initializer__size(const Initializer& initializer) override { + return initializer.size(); + } + + void* Initializer__mutable_data(Initializer& initializer, int data_type) override { + if (data_type != initializer.data_type()) { + throw std::invalid_argument("Initializer mutable data type mismatch"); + } + return initializer.mutable_data_raw(); + } + + const void* Initializer__data(const Initializer& initializer, int data_type) override { + if (data_type != initializer.data_type()) { + throw std::invalid_argument("Initializer data type mismatch"); + } + return initializer.data_raw(); + } + + void* Initializer__mutable_data_raw(Initializer& initializer) override { + return initializer.mutable_data_raw(); + } + const void* Initializer__data_raw(const Initializer& initializer) override { + return initializer.data_raw(); + } + + Status GraphUtils__ConvertInMemoryDataToInline(Graph& graph, const std::string& name) override { + return graph_utils::ConvertInMemoryDataToInline(graph, name); + } + // OpKernel (direct) const Node& OpKernel__Node(const OpKernel* p) override { return p->OpKernel::Node(); } diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index a4e0c16b411a1..edd937c870260 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -6,10 +6,12 @@ #include "core/session/provider_policy_context.h" #include +#include #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/ep_factory_internal.h" +#include "core/session/ep_plugin_provider_interfaces.h" #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_c_api.h" @@ -281,26 +283,16 @@ Status ProviderPolicyContext::CreateExecutionProvider(const Environment& env, Or if (internal_factory) { // this is a factory we created and registered internally for internal and provider bridge EPs - OrtStatus* status = internal_factory->CreateIExecutionProvider(info.devices.data(), info.ep_metadata.data(), - info.devices.size(), &options, &logger, - &ep); - if (status != nullptr) { - return ToStatus(status); - } + ORT_RETURN_IF_ERROR(ToStatusAndRelease( + internal_factory->CreateIExecutionProvider(info.devices.data(), info.ep_metadata.data(), + info.devices.size(), &options, &logger, + &ep))); } else { - // in the real setup we need an IExecutionProvider wrapper implementation that uses the OrtEp internally, - // and we would add that IExecutionProvider to the InferenceSession. - // but first we need OrtEp and the OrtEpApi to be implemented. - ORT_NOT_IMPLEMENTED("IExecutionProvider that wraps OrtEp has not been implemented."); - - // OrtEp* api_ep = nullptr; - //// add the ep_options to session options but leave any existing entries (user provided overrides) untouched. - // auto status = info.ep_factory->CreateEp(info.ep_factory, info.devices.data(), info.ep_metadata.data(), - // info.devices.size(), &options, &logger, - // &api_ep); - // if (status != nullptr) { - // return ToStatus(status); - // } + OrtEp* api_ep = nullptr; + ORT_RETURN_IF_ERROR(ToStatusAndRelease(info.ep_factory->CreateEp(info.ep_factory, info.devices.data(), + info.ep_metadata.data(), info.devices.size(), + &options, &logger, &api_ep))); + ep = std::make_unique(UniqueOrtEp(api_ep, OrtEpDeleter(*info.ep_factory)), options); } return Status::OK(); diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 8ca4ef6af1f44..d0f2e862d61d9 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -3,6 +3,9 @@ #include "core/session/utils.h" +#include +#include + #include "core/framework/error_code_helper.h" #include "core/framework/execution_provider.h" #include "core/framework/provider_options.h" @@ -17,6 +20,7 @@ #if !defined(ORT_MINIMAL_BUILD) #include "core/session/ep_factory_internal.h" +#include "core/session/ep_plugin_provider_interfaces.h" #include "core/session/ep_library_plugin.h" #include "core/session/ep_library_provider_bridge.h" #include "core/session/model_compilation_options.h" @@ -68,12 +72,8 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con if (internal_factory) { // this is a factory we created and registered. internal or provider bridge EP. - OrtStatus* status = internal_factory->CreateIExecutionProvider( - devices.data(), ep_metadata.data(), devices.size(), &ort_so, &api_session_logger, &ep); - - if (status != nullptr) { - return ToStatus(status); - } + ORT_RETURN_IF_ERROR(ToStatusAndRelease(internal_factory->CreateIExecutionProvider( + devices.data(), ep_metadata.data(), devices.size(), &ort_so, &api_session_logger, &ep))); } else { // in the real setup we need an IExecutionProvider wrapper implementation that uses the OrtEp internally, // and we would add that IExecutionProvider to the InferenceSession. @@ -81,13 +81,9 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con /* OrtEp* api_ep = nullptr; - auto status = ep_device->ep_factory->CreateEp( + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_device->ep_factory->CreateEp( ep_device->ep_factory, devices.data(), ep_metadata.data(), devices.size(), - &ort_so, &api_session_logger, &api_ep); - - if (status != nullptr) { - return ToStatus(status); - } + &ort_so, &api_session_logger, &api_ep))); */ } @@ -270,17 +266,18 @@ Status CompileModel(const Environment& env, const ModelCompilationOptions& model if (model_compile_options.InputModelComesFromFile()) { PathString input_model_path = ToPathString(model_compile_options.GetInputModelPath()); - ORT_RETURN_IF_ERROR(ToStatus(CreateSessionAndLoadModelImpl(session_options, env, - input_model_path.c_str(), - nullptr, 0, session))); + ORT_RETURN_IF_ERROR(ToStatusAndRelease(CreateSessionAndLoadModelImpl(session_options, env, + input_model_path.c_str(), + nullptr, 0, session))); } else { - ORT_RETURN_IF_ERROR(ToStatus(CreateSessionAndLoadModelImpl(session_options, env, nullptr, - model_compile_options.GetInputModelData(), - model_compile_options.GetInputModelDataSize(), - session))); + ORT_RETURN_IF_ERROR( + ToStatusAndRelease(CreateSessionAndLoadModelImpl(session_options, env, nullptr, + model_compile_options.GetInputModelData(), + model_compile_options.GetInputModelDataSize(), + session))); } - ORT_RETURN_IF_ERROR(ToStatus(InitializeSession(session_options, *session))); + ORT_RETURN_IF_ERROR(ToStatusAndRelease(InitializeSession(session_options, *session))); return Status::OK(); } @@ -341,21 +338,17 @@ Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, } const auto& ep_name = ep_devices[0]->ep_name; + OrtEpFactory* ep_factory = ep_devices[0]->ep_factory; bool all_match = std::all_of(ep_devices.begin() + 1, ep_devices.end(), - [&ep_name](const OrtEpDevice* ep_device) { return ep_device->ep_name == ep_name; }); + [&ep_name, &ep_factory](const OrtEpDevice* ep_device) { + return (ep_device->ep_name == ep_name) && (ep_device->ep_factory == ep_factory); + }); if (!all_match) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "All OrtEpDevice values in ep_devices must have the same execution provider."); } - EpFactoryInternal* internal_factory = nullptr; for (const OrtEpDevice* ep_device : ep_devices) { - // we expect the internal factory to be available for internal and provider bridge EPs, which is all we support. - internal_factory = env.GetEpFactoryInternal(ep_device->ep_factory); - if (!internal_factory) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "EP is not currently supported by this API"); - } - // add the options to the session options with the EP prefix. // first add the default values with prefix followed by user specified values so those win const std::string prefix = OrtSessionOptions::GetProviderOptionPrefix(ep_device->ep_name.c_str()); @@ -373,13 +366,14 @@ Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, } } - if (!internal_factory) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "EP is not currently supported by this API"); + EpFactoryInternal* internal_factory = env.GetEpFactoryInternal(ep_factory); + + if (internal_factory) { + out = std::make_unique(*internal_factory, ep_devices); + } else { + out = std::make_unique(*ep_factory, ep_devices); } - out = std::make_unique(*internal_factory, - std::vector(ep_devices.begin(), - ep_devices.end())); return Status::OK(); } #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/lora/adapter_format_utils.cc b/onnxruntime/lora/adapter_format_utils.cc index 7986082da06f7..ad8c9c13bc9ca 100644 --- a/onnxruntime/lora/adapter_format_utils.cc +++ b/onnxruntime/lora/adapter_format_utils.cc @@ -127,7 +127,7 @@ struct ReadDataForBigEndian { // If BE, we a allocate memory within the tensor and copy there swapping bytes [[maybe_unused]] static Status CreateOrtValueForBePlatforms(const Parameter& param, const MLDataType elem_type, gsl::span shape, OrtValue& result) { - static const AllocatorPtr cpu_allocator = std::make_shared(); + static const AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); auto src_span = ReinterpretAsSpan( gsl::make_span(param.raw_data()->data(), param.raw_data()->size())); diff --git a/onnxruntime/python/onnxruntime_pybind_iobinding.cc b/onnxruntime/python/onnxruntime_pybind_iobinding.cc index f26e188187412..b82dd1474bdf6 100644 --- a/onnxruntime/python/onnxruntime_pybind_iobinding.cc +++ b/onnxruntime/python/onnxruntime_pybind_iobinding.cc @@ -43,7 +43,7 @@ void BindOutput(SessionIOBinding* io_binding, const std::string& name, const Ort } OrtValue ml_value; - OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device, device.Id()); + OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device); Tensor::InitOrtValue(element_type, gsl::make_span(shape), reinterpret_cast(data_ptr), info, ml_value); auto status = io_binding->Get()->BindOutput(name, ml_value); @@ -94,7 +94,7 @@ void addIoBindingMethods(pybind11::module& m) { .def("bind_input", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, int32_t element_type, const std::vector& shape, int64_t data_ptr) -> void { auto ml_type = OnnxTypeToOnnxRuntimeTensorType(element_type); OrtValue ml_value; - OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device, device.Id()); + OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device); Tensor::InitOrtValue(ml_type, gsl::make_span(shape), reinterpret_cast(data_ptr), info, ml_value); auto status = io_binding->Get()->BindInput(name, ml_value); @@ -111,7 +111,7 @@ void addIoBindingMethods(pybind11::module& m) { int type_num = dtype->type_num; Py_DECREF(dtype); - OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device, device.Id()); + OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device); auto ml_type = NumpyTypeToOnnxRuntimeTensorType(type_num); OrtValue ml_value; Tensor::InitOrtValue(ml_type, gsl::make_span(shape), reinterpret_cast(data_ptr), info, ml_value); diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 0428b19357d51..958c9fc46bcd8 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -336,7 +336,7 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { const std::unordered_map* GetDmlToHostMemCpyFunction() { static std::unordered_map map{ - {OrtDevice::DML, DmlToCpuMemCpy}}; + {OrtDevice::GPU, DmlToCpuMemCpy}}; return ↦ } diff --git a/onnxruntime/python/onnxruntime_pybind_module.cc b/onnxruntime/python/onnxruntime_pybind_module.cc index 43e5c30082be6..a067f8d548799 100644 --- a/onnxruntime/python/onnxruntime_pybind_module.cc +++ b/onnxruntime/python/onnxruntime_pybind_module.cc @@ -7,6 +7,7 @@ #include "core/providers/get_execution_providers.h" #include "onnxruntime_config.h" #include "core/common/common.h" +#include "core/session/environment.h" #include "core/session/ort_env.h" #include "core/session/inference_session.h" #include "core/session/provider_bridge_ort.h" @@ -21,6 +22,8 @@ std::unique_ptr CreateExecutionProviderInstance( const ProviderOptionsMap& provider_options_map); bool InitArray(); static OrtEnv* ort_env = nullptr; +static OrtThreadingOptions global_tp_options; +static bool use_global_tp = false; onnxruntime::Environment& GetEnv() { return ort_env->GetEnvironment(); } @@ -32,7 +35,7 @@ static Status CreateOrtEnv() { Env::Default().GetTelemetryProvider().SetLanguageProjection(OrtLanguageProjection::ORT_PROJECTION_PYTHON); OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, "Default"}; Status status; - ort_env = OrtEnv::GetInstance(lm_info, status); + ort_env = OrtEnv::GetInstance(lm_info, status, use_global_tp ? &global_tp_options : nullptr); if (!status.IsOK()) return status; // Keep the ort_env alive, don't free it. It's ok to leak the memory. #if !defined(__APPLE__) && !defined(ORT_MINIMAL_BUILD) @@ -44,6 +47,17 @@ static Status CreateOrtEnv() { return Status::OK(); } +void SetGlobalThreadingOptions(const OrtThreadingOptions&& tp_options) { + if (ort_env != nullptr) { + OrtPybindThrowIfError(GetEnv().SetGlobalThreadingOptions(tp_options)); + } + global_tp_options = tp_options; + use_global_tp = true; +} +bool CheckIfUsingGlobalThreadPool() { + return use_global_tp; +} + namespace py = pybind11; /* diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index 23617f4fce76f..d1d4d6f3cdad5 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -31,7 +31,13 @@ std::unique_ptr OrtValueFromShapeAndType(const std::vector& s throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); } allocator = GetCudaAllocator(device.Id()); -#elif USE_ROCM +#else + throw std::runtime_error( + "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " + "Please use the CUDA package of OnnxRuntime to use this feature."); +#endif + } else if (strcmp(GetDeviceName(device), HIP) == 0) { +#if USE_ROCM if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); } @@ -40,8 +46,8 @@ std::unique_ptr OrtValueFromShapeAndType(const std::vector& s allocator = GetMIGraphXAllocator(device.Id()); #else throw std::runtime_error( - "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " - "Please use the CUDA package of OnnxRuntime to use this feature."); + "Can't allocate memory on the AMD device using this package of OnnxRuntime. " + "Please use the ROCm package of OnnxRuntime to use this feature."); #endif } else if (strcmp(GetDeviceName(device), DML) == 0) { #if USE_DML @@ -80,55 +86,53 @@ void addOrtValueMethods(pybind11::module& m) { CreateGenericMLValue(nullptr, GetAllocator(), "", array_on_cpu, ml_value.get(), true); } else if (device.Type() == OrtDevice::GPU) { - // The tensor's memory is allocated on CUDA - +#if USE_DML + if (device.Vendor() == OrtDevice::VendorIds::MICROSOFT) { + // InputDeflist is null because OrtValue creation is not tied to a specific model + // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) + // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML + CreateGenericMLValue( + nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy); + } else +#endif #ifdef USE_CUDA - if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { - throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); - } + if (device.Vendor() == OrtDevice::VendorIds::NVIDIA) { + if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { + throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); + } - // InputDeflist is null because OrtValue creation is not tied to a specific model - // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA - CreateGenericMLValue(nullptr, GetCudaAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToCudaMemCpy); -#elif USE_ROCM - if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { - throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); - } - - // InputDeflist is null because OrtValue creation is not tied to a specific model - // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in ROCm - CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy); -#elif USE_MIGRAPHX - // InputDeflist is null because OrtValue creation is not tied to a specific model - // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in MIGraphX - CreateGenericMLValue(nullptr, GetMIGraphXAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToMIGraphXMemCpy); -#elif USE_DML - // InputDeflist is null because OrtValue creation is not tied to a specific model - // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML - CreateGenericMLValue( - nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy); -#else - throw std::runtime_error( - "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " - "Please use the CUDA package of OnnxRuntime to use this feature."); + // InputDeflist is null because OrtValue creation is not tied to a specific model + // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) + // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA + CreateGenericMLValue(nullptr, GetCudaAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToCudaMemCpy); + } else #endif - } else if (device.Type() == OrtDevice::DML) { -#if USE_DML - // InputDeflist is null because OrtValue creation is not tied to a specific model - // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML - CreateGenericMLValue( - nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy); -#else - throw std::runtime_error( - "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " - "Please use the CUDA package of OnnxRuntime to use this feature."); +#ifdef USE_ROCM + if (device.Vendor() == OrtDevice::VendorIds::AMD) { + if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { + throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); + } + + // InputDeflist is null because OrtValue creation is not tied to a specific model + // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) + // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA + CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy); + } else #endif - } else if (device.Type() == OrtDevice::NPU) { +#if USE_MIGRAPHX + if (device.Vendor() == OrtDevice::VendorIds::AMD) { + // InputDeflist is null because OrtValue creation is not tied to a specific model + // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) + // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in MIGraphX + CreateGenericMLValue(nullptr, GetMIGraphXAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToMIGraphXMemCpy); + } else +#endif + { + throw std::runtime_error( + "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " + "Please use the CUDA package of OnnxRuntime to use this feature."); + } + } else if (device.Type() == OrtDevice::NPU && device.Vendor() == OrtDevice::VendorIds::HUAWEI) { #ifdef USE_CANN if (!IsCannDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { throw std::runtime_error("The provided device id doesn't match any available NPUs on the machine."); @@ -166,41 +170,53 @@ void addOrtValueMethods(pybind11::module& m) { CpuToCpuMemCpy); } else if (device.Type() == OrtDevice::GPU) { #ifdef USE_CUDA - if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { - throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); - } + if (device.Vendor() == OrtDevice::VendorIds::NVIDIA) { + if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { + throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); + } - onnxruntime::python::CopyDataToTensor( - py_values, - values_type, - *(ml_value->GetMutable()), - CpuToCudaMemCpy); -#elif USE_ROCM - if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { - throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); - } + onnxruntime::python::CopyDataToTensor( + py_values, + values_type, + *(ml_value->GetMutable()), + CpuToCudaMemCpy); + } else +#endif +#if USE_ROCM + if (device.Vendor() == OrtDevice::VendorIds::AMD) { + if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { + throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); + } - onnxruntime::python::CopyDataToTensor( - py_values, - values_type, - *(ml_value->GetMutable()), - CpuToRocmMemCpy); -#elif USE_MIGRAPHX - onnxruntime::python::CopyDataToTensor( - py_values, - values_type, - *(ml_value->GetMutable()), - CpuToMIGraphXMemCpy); -#elif USE_DML - onnxruntime::python::CopyDataToTensor( - py_values, - values_type, - *(ml_value->GetMutable()), - CpuToDmlMemCpy); -#else - throw std::runtime_error( - "Unsupported GPU device: Cannot find the supported GPU device."); + onnxruntime::python::CopyDataToTensor( + py_values, + values_type, + *(ml_value->GetMutable()), + CpuToRocmMemCpy); + } else +#endif +#if USE_MIGRAPHX + if (device.Vendor() == OrtDevice::VendorIds::AMD) { + onnxruntime::python::CopyDataToTensor( + py_values, + values_type, + *(ml_value->GetMutable()), + CpuToMIGraphXMemCpy); + } else #endif +#if USE_DML + if (device.Vendor() == OrtDevice::VendorIds::MICROSOFT) { + onnxruntime::python::CopyDataToTensor( + py_values, + values_type, + *(ml_value->GetMutable()), + CpuToDmlMemCpy); + } else +#endif + { + throw std::runtime_error( + "Unsupported GPU device: Cannot find the supported GPU device."); + } } else if (device.Type() == OrtDevice::DML) { #if USE_DML onnxruntime::python::CopyDataToTensor( @@ -424,7 +440,7 @@ void addOrtValueMethods(pybind11::module& m) { auto ml_type = NumpyTypeToOnnxRuntimeTensorType(type_num); auto device = devices.at(i); - OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device, device.Id()); + OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device); OrtValue ml_value; Tensor::InitOrtValue(ml_type, gsl::make_span(shape), reinterpret_cast(data_ptr), info, ml_value); v->push_back(ml_value); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index abdb58f4f1801..2b0849f02a143 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -3,6 +3,7 @@ // Licensed under the MIT License. #include +#include #include "python/onnxruntime_pybind_exceptions.h" #include "python/onnxruntime_pybind_mlvalue.h" #include "python/onnxruntime_pybind_state_common.h" @@ -39,6 +40,7 @@ #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/provider_bridge_ort.h" +#include "core/session/onnxruntime_cxx_api.h" #include "core/session/lora_adapters.h" @@ -76,6 +78,7 @@ const OrtDevice::DeviceType OrtDevice::GPU; #include #include +#include namespace onnxruntime { namespace python { @@ -303,14 +306,27 @@ const char* GetDeviceName(const OrtDevice& device) { case OrtDevice::CPU: return CPU; case OrtDevice::GPU: - return CUDA; + switch (device.Vendor()) { + case OrtDevice::VendorIds::NVIDIA: + return CUDA; + case OrtDevice::VendorIds::AMD: + return HIP; + case OrtDevice::VendorIds::MICROSOFT: + return DML; + } + + return CUDA; // default to CUDA for backwards compatibility + case OrtDevice::DML: return DML; case OrtDevice::FPGA: return "FPGA"; case OrtDevice::NPU: #ifdef USE_CANN - return CANN; + if (device.Vendor() == OrtDevice::VendorIds::HUAWEI) { + return CANN; + } + return "NPU"; #else return "NPU"; #endif @@ -1565,6 +1581,16 @@ static void LogDeprecationWarning( #endif void addGlobalMethods(py::module& m) { + m.def("set_global_thread_pool_sizes", [](int intra_op_num_threads, int inter_op_num_threads) { + static std::mutex global_thread_pool_mutex; + OrtThreadingOptions to; + to.intra_op_thread_pool_params.thread_pool_size = intra_op_num_threads; + to.inter_op_thread_pool_params.thread_pool_size = inter_op_num_threads; + std::lock_guard lock(global_thread_pool_mutex); + SetGlobalThreadingOptions(std::move(to)); }, + py::arg("intra_op_num_threads") = 0, // Default value for intra_op_num_threads + py::arg("inter_op_num_threads") = 0, // Default value for inter_op_num_threads + "Set the number of threads used by the global thread pools for intra and inter op parallelism."); m.def("get_default_session_options", &GetDefaultCPUSessionOptions, "Return a default session_options instance."); m.def("get_session_initializer", &SessionObjectInitializer::Get, "Return a default session object initializer."); m.def( @@ -1843,10 +1869,31 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .value("CPU", OrtMemTypeCPU) .value("DEFAULT", OrtMemTypeDefault); - py::class_ device(m, "OrtDevice", R"pbdoc(ONNXRuntime device informaion.)pbdoc"); - device.def(py::init()) + py::class_ device(m, "OrtDevice", R"pbdoc(ONNXRuntime device information.)pbdoc"); + device.def(py::init()) + .def(py::init([](OrtDevice::DeviceType type, + OrtDevice::MemoryType mem_type, + OrtDevice::DeviceId device_id) { + // backwards compatibility. generally there's only one GPU EP in the python package, with the exception + // of a build with CUDA and DML. + OrtDevice::VendorId vendor = OrtDevice::VendorIds::NONE; + if (type == OrtDevice::DML) { + type = OrtDevice::GPU; + vendor = OrtDevice::VendorIds::MICROSOFT; + } else if (type == OrtDevice::GPU) { +#if USE_CUDA + vendor = OrtDevice::VendorIds::NVIDIA; +#elif USE_ROCM || USE_MIGRAPHX + vendor = OrtDevice::VendorIds::AMD; +#endif + } + + return OrtDevice(type, mem_type, vendor, device_id); + }), + R"pbdoc(Constructor with vendor_id defaulted to 0 for backward compatibility.)pbdoc") .def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc") .def("device_type", &OrtDevice::Type, R"pbdoc(Device Type.)pbdoc") + .def("vendor_id", &OrtDevice::Vendor, R"pbdoc(Vendor Id.)pbdoc") .def_static("cpu", []() { return OrtDevice::CPU; }) .def_static("cuda", []() { return OrtDevice::GPU; }) .def_static("cann", []() { return OrtDevice::NPU; }) @@ -1971,15 +2018,19 @@ for model inference.)pbdoc"); py::class_ ort_memory_info_binding(m, "OrtMemoryInfo"); ort_memory_info_binding.def(py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { if (strcmp(name, onnxruntime::CPU) == 0) { - return std::make_unique(onnxruntime::CPU, type, OrtDevice(), id, mem_type); + return std::make_unique(onnxruntime::CPU, type, OrtDevice(), mem_type); } else if (strcmp(name, onnxruntime::CUDA) == 0) { return std::make_unique( - onnxruntime::CUDA, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id)), id, + onnxruntime::CUDA, type, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, + static_cast(id)), mem_type); } else if (strcmp(name, onnxruntime::CUDA_PINNED) == 0) { return std::make_unique( - onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast(id)), - id, mem_type); + onnxruntime::CUDA_PINNED, type, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, + static_cast(id)), + mem_type); } else { throw std::runtime_error("Specified device is not supported."); } @@ -2154,6 +2205,13 @@ Serialized model format will default to ONNX unless: }, R"pbdoc(VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0.)pbdoc") + .def_property( + "use_per_session_threads", + [](const PySessionOptions* options) -> bool { return options->value.use_per_session_threads; }, + [](PySessionOptions* options, bool use_per_session_threads) -> void { + options->value.use_per_session_threads = use_per_session_threads; + }, + R"pbdoc(Whether to use per-session thread pool. Default is True.)pbdoc") .def_property( "intra_op_num_threads", [](const PySessionOptions* options) -> int { return options->value.intra_op_param.thread_pool_size; }, @@ -2436,6 +2494,14 @@ including arg name, arg type (contains both type and shape).)pbdoc") bool load_config_from_model = false) { std::unique_ptr sess; + if (CheckIfUsingGlobalThreadPool() && so.value.use_per_session_threads) { + ORT_THROW("use_per_session_threads must be false when using a global thread pool"); + } + + if (CheckIfUsingGlobalThreadPool() && (so.value.intra_op_param.thread_pool_size != 0 || so.value.inter_op_param.thread_pool_size != 0)) { + LOGS_DEFAULT(WARNING) << "session options intra_op_param.thread_pool_size and inter_op_param.thread_pool_size are ignored when using a global thread pool"; + } + // separate creation of the session from model loading unless we have to read the config from the model. // in a minimal build we only support load via Load(...) and not at session creation time if (load_config_from_model) { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 1515879f61419..edb10bc28a871 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -321,7 +321,7 @@ inline const PySessionOptions& GetDefaultCPUSessionOptions() { } inline AllocatorPtr& GetAllocator() { - static AllocatorPtr alloc = std::make_shared(); + static AllocatorPtr alloc = CPUAllocator::DefaultInstance(); return alloc; } @@ -436,6 +436,9 @@ class SessionObjectInitializer { #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif + +void SetGlobalThreadingOptions(const OrtThreadingOptions&& tp_options); +bool CheckIfUsingGlobalThreadPool(); Environment& GetEnv(); OrtEnv* GetOrtEnv(); diff --git a/onnxruntime/python/tools/qnn/gen_qnn_ctx_onnx_model.py b/onnxruntime/python/tools/qnn/gen_qnn_ctx_onnx_model.py index 7a3e364a08cfd..9943eda54f8b0 100644 --- a/onnxruntime/python/tools/qnn/gen_qnn_ctx_onnx_model.py +++ b/onnxruntime/python/tools/qnn/gen_qnn_ctx_onnx_model.py @@ -11,13 +11,16 @@ class QnnTensorStruct: - def __init__(self): - self.name = "" - self.onnx_data_type = TensorProto.FLOAT - self.is_quantized = False - self.scale = 0.0 - self.offset = 0 - self.dim = [] + def __init__( + self, name="", onnx_data_type=TensorProto.FLOAT, is_quantized=False, scale=0.0, offset=0, dim=None, id=None + ): + self.name = name + self.onnx_data_type = onnx_data_type + self.is_quantized = is_quantized + self.scale = scale + self.offset = offset + self.dim = [] if dim is None else dim + self.id = id def is_quantized_data_type(qnn_data_type, is_converter_json): @@ -113,41 +116,36 @@ def parse_qnn_converter_json_file(qnn_convert_json, qnn_input_tensor_dic, qnn_ou for qnn_tensor_name, qnn_tensor_attribute in qnn_convert_json["graph"]["tensors"].items(): # type:0 - QNN input tensor, type:1 - QNN output tensor assert ( - "type" in qnn_tensor_attribute and "data_type" in qnn_tensor_attribute and "dims" in qnn_tensor_attribute + "type" in qnn_tensor_attribute + and "data_type" in qnn_tensor_attribute + and "dims" in qnn_tensor_attribute + and "id" in qnn_tensor_attribute + and "quant_params" in qnn_tensor_attribute ), "QNN converted json file not valid. Can't find some keys from tensors" - # Get all graph inputs + # If tensor is not IO, ignore it + if qnn_tensor_attribute["type"] not in [0, 1]: + continue + + # Get all graph inputs & output + qnn_tensor = QnnTensorStruct( + name=qnn_tensor_name, + onnx_data_type=qnn_data_type_to_onnx_data_type(qnn_tensor_attribute["data_type"], is_qnn_converter_json), + is_quantized=is_quantized_data_type(qnn_tensor_attribute["data_type"], is_qnn_converter_json), + dim=qnn_tensor_attribute["dims"], + id=qnn_tensor_attribute["id"], + ) + + if ( + qnn_tensor_attribute["quant_params"]["definition"] == 1 + and qnn_tensor_attribute["quant_params"]["encoding"] == 0 + ): + qnn_tensor.scale = qnn_tensor_attribute["quant_params"]["scale_offset"]["scale"] + qnn_tensor.offset = -qnn_tensor_attribute["quant_params"]["scale_offset"]["offset"] + if qnn_tensor_attribute["type"] == 0: - qnn_tensor = QnnTensorStruct() - qnn_tensor.name = qnn_tensor_name - qnn_tensor.onnx_data_type = qnn_data_type_to_onnx_data_type( - qnn_tensor_attribute["data_type"], is_qnn_converter_json - ) - qnn_tensor.is_quantized = is_quantized_data_type(qnn_tensor_attribute["data_type"], is_qnn_converter_json) - qnn_tensor.dim = qnn_tensor_attribute["dims"] - if ( - qnn_tensor_attribute["quant_params"]["definition"] == 1 - and qnn_tensor_attribute["quant_params"]["encoding"] == 0 - ): - qnn_tensor.scale = qnn_tensor_attribute["quant_params"]["scale_offset"]["scale"] - qnn_tensor.offset = 0 - qnn_tensor_attribute["quant_params"]["scale_offset"]["offset"] qnn_input_tensor_dic[qnn_tensor_name] = qnn_tensor - - # Get all graph outputs - if qnn_tensor_attribute["type"] == 1: - qnn_tensor = QnnTensorStruct() - qnn_tensor.name = qnn_tensor_name - qnn_tensor.onnx_data_type = qnn_data_type_to_onnx_data_type( - qnn_tensor_attribute["data_type"], is_qnn_converter_json - ) - qnn_tensor.is_quantized = is_quantized_data_type(qnn_tensor_attribute["data_type"], is_qnn_converter_json) - qnn_tensor.dim = qnn_tensor_attribute["dims"] - if ( - qnn_tensor_attribute["quant_params"]["definition"] == 1 - and qnn_tensor_attribute["quant_params"]["encoding"] == 0 - ): - qnn_tensor.scale = qnn_tensor_attribute["quant_params"]["scale_offset"]["scale"] - qnn_tensor.offset = 0 - qnn_tensor_attribute["quant_params"]["scale_offset"]["offset"] + else: qnn_output_tensor_dic[qnn_tensor_name] = qnn_tensor assert len(qnn_input_tensor_dic) >= 1 and len(qnn_output_tensor_dic) >= 1, ( @@ -170,7 +168,7 @@ def generate_wrapper_onnx_file( value_infos = [] model_inputs = [] - for qnn_input in qnn_input_tensor_dic.values(): + for qnn_input in sorted(qnn_input_tensor_dic.values(), key=lambda inp: inp.id): if qnn_input.is_quantized and not quantized_IO: q_scale_input_name = qnn_input.name + "_scale" q_offset_input_name = qnn_input.name + "_zp" @@ -215,7 +213,7 @@ def generate_wrapper_onnx_file( graph_nodes.append(qnn_ep_context_node) model_outputs = [] - for qnn_output in qnn_output_tensor_dic.values(): + for qnn_output in sorted(qnn_output_tensor_dic.values(), key=lambda out: out.id): if qnn_output.is_quantized and not quantized_IO: dq_scale_input_name = qnn_output.name + "_scale" dq_offset_input_name = qnn_output.name + "_zp" diff --git a/onnxruntime/python/tools/qnn/preprocess.py b/onnxruntime/python/tools/qnn/preprocess.py index b7ddf1de9dc34..14c62d2cd26e0 100644 --- a/onnxruntime/python/tools/qnn/preprocess.py +++ b/onnxruntime/python/tools/qnn/preprocess.py @@ -69,6 +69,23 @@ def _parse_arguments(): help="List of graph output names to be transposed into channel-last.", ) + # Fix dynamic input shapes. + parser.add_argument( + "--dynamic_input_shapes", + nargs=2, + action="append", + type=str, + default=None, + help="Model input name and desired static shape in comma seprated format, for example: 'input' 1,3,256,256", + ) + + # Exclude initializer from input + parser.add_argument( + "--exclude_initializer_from_input", + action="store_true", + help="Whether to exclude initializer from input if model.ir_version >= 4", + ) + return parser.parse_args() @@ -83,6 +100,8 @@ def qnn_preprocess_model( external_data_convert_attribute: bool = False, inputs_to_make_channel_last: list[str] | None = None, outputs_to_make_channel_last: list[str] | None = None, + dynamic_input_shapes: list[tuple[str, str]] | None = None, + exclude_initializer_from_input: bool = False, ) -> bool: """Preprocess ONNX model for QNN. @@ -105,6 +124,9 @@ def qnn_preprocess_model( Defaults to None. outputs_to_make_channel_last: A list of strs specifying graph output names to be transposed into channel-last. Defaults to None. + dynamic_input_shapes: A list of tuples specifying model input name to and its static shape in comma seprated + format, for example: [('input', '1,3,256,256')]. Defaults to None. + exclude_initializer_from_input: A bool specifying whether to exclude initializer from input. Defaults to False. Returns: A bool indicating whether the model is modified. @@ -120,6 +142,8 @@ def qnn_preprocess_model( external_data_convert_attribute=external_data_convert_attribute, inputs_to_make_channel_last=inputs_to_make_channel_last, outputs_to_make_channel_last=outputs_to_make_channel_last, + dynamic_input_shapes=dynamic_input_shapes, + exclude_initializer_from_input=exclude_initializer_from_input, ) @@ -136,4 +160,6 @@ def qnn_preprocess_model( external_data_convert_attribute=args.external_data_convert_attribute, inputs_to_make_channel_last=args.inputs_to_make_channel_last, outputs_to_make_channel_last=args.outputs_to_make_channel_last, + dynamic_input_shapes=args.dynamic_input_shapes, + exclude_initializer_from_input=args.exclude_initializer_from_input, ) diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 7bf8b2846d73b..9a297e451213a 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -24,6 +24,7 @@ QUANT_OP_NAME, TENSOR_NAME_QUANT_SUFFIX, find_by_name, + get_opset_version, model_has_infer_metadata, normalize_axis, pack_bytes_to_4bit, @@ -86,6 +87,7 @@ def __init__( self.value_infos.update({it.name: it for it in model.graph.input}) self.model = ONNXModel(model) + self.opset_version = get_opset_version(model) self.per_channel = per_channel # weight-pack per channel self.reduce_range = reduce_range @@ -127,8 +129,6 @@ def __init__( self.nodes_to_exclude = nodes_to_exclude # specific nodes to exclude self.op_types_to_quantize = op_types_to_quantize - self.opset_version = self.check_opset_version() - # Get tensor-level quantization overrides and ensure they are valid. self.tensor_quant_overrides = TensorQuantOverridesHelper(self.extra_options.get("TensorQuantOverrides", {})) @@ -188,41 +188,6 @@ def should_quantize_node(self, node): return True - def check_opset_version(self): - ai_onnx_domain = [ - opset for opset in self.model.model.opset_import if not opset.domain or opset.domain == "ai.onnx" - ] - if len(ai_onnx_domain) != 1: - raise ValueError("Failed to find proper ai.onnx domain") - opset_version = ai_onnx_domain[0].version - - if opset_version == 10: - logging.warning( - f"The original model opset version is {opset_version}, which does not support node fusions. Please update the model to opset >= 11 for better performance." - ) - return 10 - - if opset_version < 10: - logging.warning( - f"The original model opset version is {opset_version}, which does not support quantization. Please update the model to opset >= 11. Updating the model automatically to opset 11. Please verify the quantized model." - ) - self.model.model.opset_import.remove(ai_onnx_domain[0]) - self.model.model.opset_import.extend([onnx.helper.make_opsetid("", 11)]) - opset_version = 11 - - if opset_version < 19 and self.weight_qType == onnx.TensorProto.FLOAT8E4M3FN: - logging.warning( - f"The original model opset version is {opset_version}, which does not support quantization to float 8. " - "Please update the model to opset >= 19. Updating the model automatically to opset 19. " - "Please verify the quantized model." - ) - self.model.model.opset_import.remove(ai_onnx_domain[0]) - self.model.model.opset_import.extend([onnx.helper.make_opsetid("", 19)]) - self.model.model.ir_version = 9 - opset_version = 19 - - return opset_version - def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1.0): """ Quantized the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py index 44ff7e4aba10b..191edc4c6390d 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -10,6 +10,8 @@ import onnx +from ....tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed +from ....tools.remove_initializer_from_input import remove_initializer_from_input from ...fusions import FusionGelu, FusionLayerNormalization from ...onnx_model import ONNXModel from ...quant_utils import save_and_reload_model_with_shape_infer @@ -20,6 +22,7 @@ def qnn_preprocess_model( model_input: str | Path | onnx.ModelProto, model_output: str | Path, + exclude_initializer_from_input: bool = False, fuse_layernorm: bool = False, save_as_external_data: bool = False, all_tensors_to_one_file: bool = False, @@ -28,6 +31,7 @@ def qnn_preprocess_model( external_data_convert_attribute: bool = False, inputs_to_make_channel_last: list[str] | None = None, outputs_to_make_channel_last: list[str] | None = None, + dynamic_input_shapes: list[tuple[str, str]] | None = None, ) -> bool: """ If necessary, this method creates a new "pre-processed" model in preparation for @@ -41,6 +45,8 @@ def qnn_preprocess_model( Args: model_input: Path to the input model file or ModelProto. model_output: Path the output model file, which is only created if this method returns True. + exclude_initializer_from_input: A bool specifying whether to exclude initializer from input. + Defaults to False. fuse_layernorm: True if ReduceMean sequences should be fused into LayerNormalization nodes. Defaults to False. save_as_external_data: True if output model should be saved with external data. Defaults to false. @@ -82,12 +88,26 @@ def qnn_preprocess_model( This can potentially improve inference latency for QDQ models running on QNN EP because the additional transpose node may allow other transpose nodes inserted during ORT layout transformation to cancel out. + dynamic_input_shapes: A list of tuples specifying model input name to and its static shape in comma seprated + format, for example: [('input', '1,3,256,256')]. Defaults to None. """ modified = False model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load_model(model_input) model = save_and_reload_model_with_shape_infer(model) onnx_model = ONNXModel(model) + # Optionally, fix the dynamic input shapes. + if dynamic_input_shapes: + for input_name, input_shape_str in dynamic_input_shapes: + input_shape = [int(i) for i in input_shape_str.split(",")] + make_input_shape_fixed(onnx_model.graph(), input_name, input_shape) + fix_output_shapes(onnx_model.model) + modified = True + + # Exclude initializer from input if model.ir_version >= 4 + if exclude_initializer_from_input: + modified |= remove_initializer_from_input(onnx_model.model) + # Fuse Erf sequence into a single Gelu fusion_gelu = FusionGelu(onnx_model) if fusion_gelu.apply(): diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 48cd1c52be2e2..28536f52c1e56 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -971,6 +971,51 @@ def model_has_infer_metadata(model: ModelProto) -> bool: return False +def get_opset_version(model: ModelProto) -> int: + ai_onnx_domain = [opset for opset in model.opset_import if not opset.domain or opset.domain == "ai.onnx"] + if len(ai_onnx_domain) != 1: + raise ValueError("Failed to find proper ai.onnx domain") + opset_version = ai_onnx_domain[0].version + + return opset_version + + +def update_opset_version(model: ModelProto, weight_type: QuantType) -> ModelProto: + opset_version = get_opset_version(model) + target_opset_version = opset_version + weight_quant_type = getattr(weight_type, "tensor_type", weight_type) + + if opset_version < 19 and weight_quant_type == onnx.TensorProto.FLOAT8E4M3FN: + logging.warning( + f"The original model opset version is {opset_version}, which does not support quantization to float 8. " + "Please update the model to opset >= 19. Automatically update the model to opset 19. " + "Please verify the quantized model." + ) + target_opset_version = 19 + + elif opset_version == 10: + logging.warning( + f"The original model opset version is {opset_version}, which does not support node fusions. " + "Please update the model to opset >= 11 for better performance." + ) + + elif opset_version < 10: + logging.warning( + f"The original model opset version is {opset_version}, which does not support quantization. " + "Please update the model to opset >= 11. Automatically update the model to opset 11. " + "Please verify the quantized model." + ) + target_opset_version = 11 + + if target_opset_version != opset_version: + model = onnx.version_converter.convert_version(model, target_opset_version) + # Additional nodes may be added to the model during the opset version conversion. Run shape inference + # to ensure all nodes are included in model.graph.value_info. + model = save_and_reload_model_with_shape_infer(model) + + return model + + def load_model_with_shape_infer(model_path: Path) -> ModelProto: inferred_model_path = generate_identified_filename(model_path, "-inferred") onnx.shape_inference.infer_shapes_path(str(model_path), str(inferred_model_path)) diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index 96cd3d4fb6792..d0f752eea486d 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -25,6 +25,7 @@ load_model_with_shape_infer, model_has_pre_process_metadata, save_and_reload_model_with_shape_infer, + update_opset_version, ) from .registry import IntegerOpsRegistry, QDQRegistry, QLinearOpsRegistry from .tensor_quant_overrides import TensorQuantOverridesHelper @@ -680,8 +681,6 @@ def quantize_static( logging.error(f"{e}.") raise RuntimeError("neural-compressor is not correctly installed. Please check your environment.") from e - import copy - from neural_compressor.adaptor.ox_utils.smooth_quant import ORTSmoothQuant def inc_dataloader(): @@ -700,9 +699,18 @@ def inc_dataloader(): nodes_to_exclude.extend([i.name for i in model.model.graph.node if i.name not in orig_nodes]) model = load_model_with_shape_infer(Path(model_input)) # use smooth quant model for calibration + updated_model = update_opset_version(model, weight_type) + is_model_updated = updated_model is not model + if is_model_updated: + model = updated_model + with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir: + if is_model_updated: + # Update model_input and avoid to use the original one + model_input = copy.deepcopy(model) + if isinstance(model_input, onnx.ModelProto): - output_path = str(Path(quant_tmp_dir) / "model_input.onnx") + output_path = Path(quant_tmp_dir).joinpath("model_input.onnx").as_posix() onnx.save_model( model_input, output_path, @@ -864,6 +872,8 @@ def quantize_dynamic( if "MatMulConstBOnly" not in extra_options: extra_options["MatMulConstBOnly"] = True + model = update_opset_version(model, weight_type) + quantizer = ONNXQuantizer( model, per_channel, diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 2cc002928b16e..b6b62dc3bb3a1 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -535,6 +535,19 @@ def main(argv=None): if path in output_paths: output_paths.remove(path) + else: + # Create ancillary JSON files for ONNX Runtime GenAI and/or Hugging Face's Optimum + WhisperHelper.save_processing( + args.model_name_or_path, + args.provider, + args.separate_encoder_and_decoder_init, + args.output_cross_qk, + next(iter(filter(lambda path: "encoder" in path, output_paths))), + next(iter(filter(lambda path: "decoder" in path, output_paths))), + output_dir, + cache_dir, + ) + logger.info(f"Done! Outputs: {output_paths}") return max_diff diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 3cb6c23848f13..08118ccb551eb 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -3,16 +3,17 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- - +import json import logging import os +import textwrap from pathlib import Path import numpy as np import torch from convert_generation import add_cache_indirection_to_mha, add_output_qk_to_mha, fix_past_sequence_length from optimizer import optimize_model -from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor +from transformers import AutoTokenizer, WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor from whisper_decoder import WhisperDecoder from whisper_encoder import WhisperEncoder from whisper_encoder_decoder_init import WhisperEncoderDecoderInit @@ -67,6 +68,569 @@ def get_onnx_path( directory = os.path.join(output_dir, model_name) if new_folder else output_dir return os.path.join(directory, model_name + ".onnx") + @staticmethod + def save_processing( + model_name_or_path: str, + provider: str, + separate_encoder_and_decoder_init: bool, + output_qk: bool, + encoder_path: str, + decoder_path: str, + output_dir: str, + cache_dir: str, + ) -> None: + config = WhisperConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir) + config.save_pretrained(output_dir) + + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir) + tokenizer.save_pretrained(output_dir) + + processor = WhisperProcessor.from_pretrained(model_name_or_path, cache_dir=cache_dir) + processor.save_pretrained(output_dir) + + # Return early since the next files are for ONNX Runtime GenAI + if separate_encoder_and_decoder_init: + return + + audio_processor_json = textwrap.dedent("""\ + { + "feature_extraction": { + "sequence": [ + { + "operation": { + "name": "audio_decoder", + "type": "AudioDecoder" + } + }, + { + "operation": { + "name": "STFT", + "type": "STFTNorm", + "attrs": { + "n_fft": 400, + "frame_length": 400, + "hop_length": 160, + "_comment": [ + 0.0, + 0.0000616908073425293, + 0.0002467334270477295, + 0.0005550682544708252, + 0.000986635684967041, + 0.0015413463115692139, + 0.0022190213203430176, + 0.0030195116996765137, + 0.003942638635635376, + 0.004988163709640503, + 0.006155818700790405, + 0.007445335388183594, + 0.008856385946273804, + 0.010388582944869995, + 0.012041628360748291, + 0.013815045356750488, + 0.01570841670036316, + 0.01772129535675049, + 0.019853144884109497, + 0.022103488445281982, + 0.02447172999382019, + 0.026957333087921143, + 0.029559612274169922, + 0.03227800130844116, + 0.03511175513267517, + 0.03806024789810181, + 0.0411226749420166, + 0.044298380613327026, + 0.04758647084236145, + 0.05098623037338257, + 0.05449673533439636, + 0.058117181062698364, + 0.06184667348861694, + 0.0656842589378357, + 0.06962898373603821, + 0.07367992401123047, + 0.0778360664844513, + 0.08209633827209473, + 0.08645972609519958, + 0.09092515707015991, + 0.09549149870872498, + 0.10015767812728882, + 0.10492250323295593, + 0.1097848117351532, + 0.11474338173866272, + 0.11979702115058899, + 0.12494447827339172, + 0.13018447160720825, + 0.1355157196521759, + 0.14093685150146484, + 0.1464466154575348, + 0.15204361081123352, + 0.1577264666557312, + 0.16349375247955322, + 0.16934409737586975, + 0.1752760112285614, + 0.18128803372383118, + 0.18737870454788208, + 0.19354650378227234, + 0.1997898817062378, + 0.20610737800598145, + 0.21249738335609436, + 0.21895831823349, + 0.2254886031150818, + 0.23208662867546082, + 0.23875075578689575, + 0.24547931551933289, + 0.2522706985473633, + 0.25912320613861084, + 0.26603513956069946, + 0.27300477027893066, + 0.2800304591655731, + 0.2871103882789612, + 0.29424285888671875, + 0.30142611265182495, + 0.30865830183029175, + 0.31593772768974304, + 0.3232625722885132, + 0.3306310474872589, + 0.3380413055419922, + 0.34549152851104736, + 0.352979838848114, + 0.3605044484138489, + 0.3680635094642639, + 0.37565508484840393, + 0.38327735662460327, + 0.3909284174442291, + 0.39860638976097107, + 0.4063093662261963, + 0.41403549909591675, + 0.42178282141685486, + 0.4295494258403778, + 0.43733343482017517, + 0.44513291120529175, + 0.45294591784477234, + 0.46077051758766174, + 0.46860480308532715, + 0.4764467775821686, + 0.4842946231365204, + 0.492146372795105, + 0.5, + 0.5078536868095398, + 0.515705406665802, + 0.5235532522201538, + 0.5313953161239624, + 0.5392295718193054, + 0.5470541715621948, + 0.5548672080039978, + 0.562666654586792, + 0.5704506635665894, + 0.5782172679901123, + 0.5859646201133728, + 0.5936906933784485, + 0.6013936996459961, + 0.609071671962738, + 0.6167227625846863, + 0.6243450045585632, + 0.6319366097450256, + 0.6394955515861511, + 0.6470202207565308, + 0.6545085310935974, + 0.6619587540626526, + 0.6693689823150635, + 0.6767374277114868, + 0.6840623021125793, + 0.691341757774353, + 0.6985740065574646, + 0.7057572603225708, + 0.7128896713256836, + 0.719969630241394, + 0.7269952893257141, + 0.7339649796485901, + 0.7408769130706787, + 0.7477294206619263, + 0.7545207738876343, + 0.761249303817749, + 0.7679134607315063, + 0.774511456489563, + 0.7810417413711548, + 0.7875027060508728, + 0.7938927412033081, + 0.800210177898407, + 0.8064535856246948, + 0.8126214146614075, + 0.8187121152877808, + 0.8247240781784058, + 0.8306560516357422, + 0.8365063667297363, + 0.8422735929489136, + 0.8479564785957336, + 0.8535534143447876, + 0.8590631484985352, + 0.8644843101501465, + 0.8698155879974365, + 0.8750555515289307, + 0.8802030086517334, + 0.8852566480636597, + 0.8902152180671692, + 0.8950775265693665, + 0.899842381477356, + 0.9045084714889526, + 0.9090749025344849, + 0.9135403037071228, + 0.9179036617279053, + 0.9221639633178711, + 0.9263200759887695, + 0.9303710460662842, + 0.9343158006668091, + 0.9381533861160278, + 0.941882848739624, + 0.945503294467926, + 0.9490138292312622, + 0.9524135589599609, + 0.9557017087936401, + 0.9588773250579834, + 0.961939811706543, + 0.9648882746696472, + 0.9677220582962036, + 0.9704403877258301, + 0.9730427265167236, + 0.9755282998085022, + 0.9778965711593628, + 0.9801468849182129, + 0.9822787046432495, + 0.9842916131019592, + 0.9861849546432495, + 0.9879584312438965, + 0.9896113872528076, + 0.9911436438560486, + 0.9925546646118164, + 0.9938441514968872, + 0.9950118064880371, + 0.996057391166687, + 0.9969804883003235, + 0.997780978679657, + 0.9984586238861084, + 0.999013364315033, + 0.9994449615478516, + 0.9997532367706299, + 0.9999383091926575, + 1, + 0.9999383091926575, + 0.9997532367706299, + 0.9994449615478516, + 0.999013364315033, + 0.9984586238861084, + 0.997780978679657, + 0.9969804286956787, + 0.9960573315620422, + 0.9950118064880371, + 0.9938441514968872, + 0.9925546646118164, + 0.9911435842514038, + 0.9896113872528076, + 0.9879583716392517, + 0.9861849546432495, + 0.9842915534973145, + 0.9822787046432495, + 0.9801468253135681, + 0.9778964519500732, + 0.9755282402038574, + 0.9730426073074341, + 0.9704403877258301, + 0.9677219390869141, + 0.9648882150650024, + 0.9619396924972534, + 0.9588772654533386, + 0.9557015895843506, + 0.9524134397506714, + 0.9490137100219727, + 0.9455032348632812, + 0.9418827295303345, + 0.9381532669067383, + 0.9343156814575195, + 0.9303709268569946, + 0.9263200759887695, + 0.9221639633178711, + 0.9179036617279053, + 0.913540244102478, + 0.9090747833251953, + 0.9045084714889526, + 0.8998422622680664, + 0.8950774669647217, + 0.8902151584625244, + 0.8852565884590149, + 0.8802029490470886, + 0.8750554919242859, + 0.869815468788147, + 0.8644842505455017, + 0.8590630888938904, + 0.853553295135498, + 0.8479562997817993, + 0.842273473739624, + 0.836506187915802, + 0.8306558728218079, + 0.8247239589691162, + 0.8187118768692017, + 0.8126212358474731, + 0.8064534664154053, + 0.8002099990844727, + 0.793892502784729, + 0.7875025272369385, + 0.7810416221618652, + 0.7745113372802734, + 0.767913281917572, + 0.7612491846084595, + 0.7545205950737, + 0.7477291822433472, + 0.7408767342567444, + 0.7339648008346558, + 0.7269951105117798, + 0.7199694514274597, + 0.7128894925117493, + 0.7057570219039917, + 0.6985738277435303, + 0.6913415789604187, + 0.684062123298645, + 0.6767372488975525, + 0.6693688035011292, + 0.6619585752487183, + 0.6545083522796631, + 0.6470199823379517, + 0.6394953727722168, + 0.6319363117218018, + 0.6243447661399841, + 0.6167224645614624, + 0.6090714335441589, + 0.601393461227417, + 0.5936904549598694, + 0.5859643220901489, + 0.5782170295715332, + 0.5704504251480103, + 0.5626664161682129, + 0.5548669099807739, + 0.5470539331436157, + 0.5392293334007263, + 0.5313950181007385, + 0.5235530138015747, + 0.5157051682472229, + 0.507853627204895, + 0.5, + 0.4921463429927826, + 0.484294593334198, + 0.4764467477798462, + 0.46860471367836, + 0.4607704281806946, + 0.4529458284378052, + 0.4451328217983246, + 0.437333345413208, + 0.42954933643341064, + 0.4217827320098877, + 0.4140354096889496, + 0.4063093066215515, + 0.3986063003540039, + 0.39092832803726196, + 0.3832772672176361, + 0.37565499544143677, + 0.36806342005729675, + 0.3605043888092041, + 0.35297977924346924, + 0.3454914391040802, + 0.338041216135025, + 0.33063095808029175, + 0.3232625126838684, + 0.3159376382827759, + 0.3086581826210022, + 0.3014259934425354, + 0.2942427396774292, + 0.28711026906967163, + 0.2800303101539612, + 0.2730046510696411, + 0.2660350203514099, + 0.2591230869293213, + 0.25227057933807373, + 0.24547919631004333, + 0.2387506067752838, + 0.23208650946617126, + 0.22548848390579224, + 0.21895819902420044, + 0.2124972641468048, + 0.2061072587966919, + 0.19978976249694824, + 0.1935463547706604, + 0.18737855553627014, + 0.18128788471221924, + 0.17527586221694946, + 0.1693439483642578, + 0.16349363327026367, + 0.15772631764411926, + 0.15204349160194397, + 0.14644649624824524, + 0.1409367322921753, + 0.13551557064056396, + 0.1301843225955963, + 0.12494435906410217, + 0.11979690194129944, + 0.11474326252937317, + 0.10978469252586365, + 0.10492238402366638, + 0.10015755891799927, + 0.09549137949943542, + 0.09092503786087036, + 0.08645960688591003, + 0.08209621906280518, + 0.07783591747283936, + 0.07367980480194092, + 0.06962886452674866, + 0.06568413972854614, + 0.06184655427932739, + 0.0581170916557312, + 0.0544966459274292, + 0.05098611116409302, + 0.04758638143539429, + 0.044298261404037476, + 0.04112258553504944, + 0.038060128688812256, + 0.03511166572570801, + 0.03227788209915161, + 0.02955952286720276, + 0.02695724368095398, + 0.024471670389175415, + 0.02210339903831482, + 0.01985308527946472, + 0.017721205949783325, + 0.015708357095718384, + 0.0138150155544281, + 0.012041598558425903, + 0.010388582944869995, + 0.008856356143951416, + 0.007445335388183594, + 0.006155818700790405, + 0.004988163709640503, + 0.003942638635635376, + 0.0030195116996765137, + 0.0022190213203430176, + 0.0015413165092468262, + 0.000986635684967041, + 0.0005550682544708252, + 0.0002467334270477295, + 0.0000616908073425293 + ] + } + } + }, + { + "operation": { + "name": "log_mel_spectrogram", + "type": "LogMelSpectrum", + "attrs": { + "chunk_size": 30, + "hop_length": 160, + "n_fft": 400, + "n_mel": 80 + } + } + } + ] + } + } + """) + with open(os.path.join(output_dir, "audio_processor_config.json"), "w") as f: + f.write(audio_processor_json) + + provider_options = [] if "cpu" in provider else [{f"{provider}": {}}] + genai_config = { + "model": { + "bos_token_id": config.bos_token_id, + "context_length": config.max_length, + "decoder": { + "session_options": { + "log_id": "onnxruntime-genai", + "provider_options": provider_options, + }, + "filename": os.path.basename(decoder_path), + "head_size": config.d_model // config.decoder_attention_heads, + "hidden_size": config.d_model, + "inputs": { + "input_ids": "input_ids", + "past_key_names": "past_key_self_%d", + "past_value_names": "past_value_self_%d", + "cross_past_key_names": "past_key_cross_%d", + "cross_past_value_names": "past_value_cross_%d", + }, + "outputs": { + "logits": "logits", + "present_key_names": "present_key_self_%d", + "present_value_names": "present_value_self_%d", + }, + "num_attention_heads": config.decoder_attention_heads, + "num_hidden_layers": config.decoder_layers, + "num_key_value_heads": config.decoder_attention_heads, + }, + "encoder": { + "session_options": { + "log_id": "onnxruntime-genai", + "provider_options": provider_options, + }, + "filename": os.path.basename(encoder_path), + "head_size": config.d_model // config.encoder_attention_heads, + "hidden_size": config.d_model, + "inputs": {"audio_features": "audio_features"}, + "outputs": { + "encoder_hidden_states": "encoder_hidden_states", + "cross_present_key_names": "present_key_cross_%d", + "cross_present_value_names": "present_value_cross_%d", + }, + "num_attention_heads": config.encoder_attention_heads, + "num_hidden_layers": config.encoder_layers, + "num_key_value_heads": config.encoder_attention_heads, + }, + "eos_token_id": config.eos_token_id, + "pad_token_id": config.pad_token_id, + "type": "whisper", + "vocab_size": config.vocab_size, + }, + "search": { + "diversity_penalty": 0.0, + "do_sample": False, + "early_stopping": True, + "length_penalty": 1.0, + "max_length": config.max_length, + "min_length": 0, + "no_repeat_ngram_size": 0, + "num_beams": 1, + "num_return_sequences": 1, + "past_present_share_buffer": provider == "cuda", + "repetition_penalty": 1.0, + "temperature": 1.0, + "top_k": 1, + "top_p": 1.0, + }, + } + + # Requirements for the DMMHA kernel (which is currently + # enabled for CUDA only): + # - Buffer sharing = true + # - New input: past_sequence_length + # - New input: cache_indirection + # Otherwise, buffer sharing should be false and the new inputs + # should not be added for beam search to work in ORT GenAI. + + if provider == "cuda": + # Add inputs for DMMHA kernel + genai_config["model"]["decoder"]["inputs"].update( + { + "past_sequence_length": "past_sequence_length", + "cache_indirection": "cache_indirection", + } + ) + + if output_qk: + genai_config["model"]["decoder"]["outputs"].update( + { + "output_cross_qk_names": "output_cross_qk_%d", + } + ) + + with open(os.path.join(output_dir, "genai_config.json"), "w") as f: + json.dump(genai_config, f, indent=4) + @staticmethod def load_model( model_name_or_path: str, diff --git a/onnxruntime/test/autoep/library/example_plugin_ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep.cc index 2c82b9ace3c61..9978189267a40 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep.cc @@ -1,45 +1,399 @@ -#include "onnxruntime_cxx_api.h" - +#include #include +#include #include #include +#include #include -#define RETURN_IF_ERROR(fn) \ - do { \ - OrtStatus* status = (fn); \ - if (status != nullptr) { \ - return status; \ - } \ - } while (0) +#include "example_plugin_ep_utils.h" + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +struct ExampleEp; + +/// +/// Example implementation of ONNX Mul. Does not handle many things like broadcasting. +/// +struct MulKernel { + MulKernel(const OrtApi& ort_api, const OrtLogger& logger) : ort_api(ort_api), logger(logger) {} + + OrtStatus* Compute(OrtKernelContext* kernel_context) { + RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); + size_t num_inputs = 0; + RETURN_IF_ERROR(ort_api.KernelContext_GetInputCount(kernel_context, &num_inputs)); + RETURN_IF(num_inputs != 2, ort_api, "Expected 2 inputs for MulKernel"); + + size_t num_outputs = 0; + RETURN_IF_ERROR(ort_api.KernelContext_GetOutputCount(kernel_context, &num_outputs)); + RETURN_IF(num_outputs != 1, ort_api, "Expected 1 output for MulKernel"); + + const OrtValue* input0 = nullptr; + const OrtValue* input1 = nullptr; + RETURN_IF_ERROR(ort_api.KernelContext_GetInput(kernel_context, 0, &input0)); + RETURN_IF_ERROR(ort_api.KernelContext_GetInput(kernel_context, 1, &input1)); + + OrtTensorTypeAndShapeInfo* type_shape0 = nullptr; + OrtTensorTypeAndShapeInfo* type_shape1 = nullptr; + DeferOrtRelease release_type0(&type_shape0, ort_api.ReleaseTensorTypeAndShapeInfo); + DeferOrtRelease release_type1(&type_shape1, ort_api.ReleaseTensorTypeAndShapeInfo); + + RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(input0, &type_shape0)); + RETURN_IF_ERROR(ort_api.GetTensorTypeAndShape(input1, &type_shape1)); + + ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape0, &elem_type)); + RETURN_IF(elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ort_api, "Expected float32 inputs"); + + size_t num_dims0 = 0; + size_t num_dims1 = 0; + RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape0, &num_dims0)); + RETURN_IF_ERROR(ort_api.GetDimensionsCount(type_shape1, &num_dims1)); + RETURN_IF((num_dims0 == 0) || (num_dims1 == 0), ort_api, "Input has 0 dimensions"); + RETURN_IF(num_dims0 != num_dims1, ort_api, "Expected same dimensions for both inputs"); // No broadcasting + + std::vector dims0(num_dims0, 0); + std::vector dims1(num_dims1, 0); + RETURN_IF_ERROR(ort_api.GetDimensions(type_shape0, dims0.data(), dims0.size())); + RETURN_IF_ERROR(ort_api.GetDimensions(type_shape1, dims1.data(), dims1.size())); + RETURN_IF(dims0 != dims1, ort_api, "Expected same dimensions for both inputs"); // No broadcasting. + + const float* input_data0 = nullptr; + const float* input_data1 = nullptr; + RETURN_IF_ERROR(ort_api.GetTensorMutableData(const_cast(input0), (void**)&input_data0)); // No const-correct API? + RETURN_IF_ERROR(ort_api.GetTensorMutableData(const_cast(input1), (void**)&input_data1)); + + OrtValue* output = nullptr; + RETURN_IF_ERROR(ort_api.KernelContext_GetOutput(kernel_context, 0, dims0.data(), dims0.size(), &output)); + + float* output_data = nullptr; + RETURN_IF_ERROR(ort_api.GetTensorMutableData(output, reinterpret_cast(&output_data))); + + int64_t num_elems = 1; + for (int64_t dim : dims0) { + RETURN_IF(dim < 0, ort_api, "Invalid dimension: negative value detected"); + num_elems *= dim; + } + + for (size_t i = 0; i < static_cast(num_elems); ++i) { + output_data[i] = input_data0[i] * input_data1[i]; + } + + return nullptr; + } + + const OrtApi& ort_api; + const OrtLogger& logger; +}; + +/// +/// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. +/// +struct ExampleNodeComputeInfo : OrtNodeComputeInfo { + explicit ExampleNodeComputeInfo(ExampleEp& ep); + + static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, + OrtNodeComputeContext* compute_context, + void** compute_state); + static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context); + static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state); + + ExampleEp& ep; +}; struct ApiPtrs { const OrtApi& ort_api; - // const OrtEpApi& ep_api; TODO: Add this when we flesh out the EP API. + const OrtEpApi& ep_api; + const OrtModelEditorApi& model_editor_api; }; +/// +/// Example EP that can compile a single Mul operator. +/// struct ExampleEp : OrtEp, ApiPtrs { - ExampleEp(ApiPtrs apis, const std::string& name, const OrtSessionOptions& session_options, const OrtLogger& logger) - : ApiPtrs(apis), name_{name}, session_options_{session_options}, logger_{logger} { - // Initialize the execution provider. + struct Config { + bool enable_ep_context = false; + // Other EP configs (typically extracted from OrtSessionOptions or OrtHardwareDevice(s)) + }; + ExampleEp(ApiPtrs apis, const std::string& name, const Config& config, const OrtLogger& logger) + : ApiPtrs(apis), name_{name}, config_{config}, logger_{logger} { + // Initialize the execution provider. auto status = ort_api.Logger_LogMessage(&logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, ("ExampleEp has been created with name " + name_).c_str(), ORT_FILE, __LINE__, __FUNCTION__); // ignore status for now (void)status; + + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. + GetName = GetNameImpl; + GetCapability = GetCapabilityImpl; + Compile = CompileImpl; + ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl; } ~ExampleEp() { // Clean up the execution provider } + static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) { + const auto* ep = static_cast(this_ptr); + return ep->name_.c_str(); + } + + static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) { + ExampleEp* ep = static_cast(this_ptr); + + OrtArrayOfConstObjects* nodes_array = nullptr; + DeferOrtRelease release_nodes_array(&nodes_array, ep->ort_api.ReleaseArrayOfConstObjects); + + size_t num_nodes = 0; + + RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graph, &nodes_array)); + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(nodes_array, &num_nodes)); + + if (num_nodes == 0) { + return nullptr; // No nodes to process + } + + const void* const* nodes_data = nullptr; + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetData(nodes_array, &nodes_data)); + auto nodes_span = gsl::span(reinterpret_cast(nodes_data), num_nodes); + + std::vector supported_nodes; + + for (const OrtNode* node : nodes_span) { + const char* op_type = nullptr; + RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node, &op_type)); + + if (std::strncmp(op_type, "Mul", 4) == 0) { + // Check that Mul has inputs/output of type float + OrtArrayOfConstObjects* inputs_array = nullptr; + OrtArrayOfConstObjects* outputs_array = nullptr; + DeferOrtRelease release_inputs(&inputs_array, ep->ort_api.ReleaseArrayOfConstObjects); + DeferOrtRelease release_outputs(&outputs_array, ep->ort_api.ReleaseArrayOfConstObjects); + + RETURN_IF_ERROR(ep->ort_api.Node_GetInputs(node, &inputs_array)); + RETURN_IF_ERROR(ep->ort_api.Node_GetOutputs(node, &outputs_array)); + + size_t num_inputs = 0; + size_t num_outputs = 0; + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(inputs_array, &num_inputs)); + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(outputs_array, &num_outputs)); + RETURN_IF(num_inputs != 2 || num_outputs != 1, ep->ort_api, "Mul should have 2 inputs and 1 output"); + + const void* const* inputs_data = nullptr; + const void* const* outputs_data = nullptr; + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetData(inputs_array, &inputs_data)); + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetData(outputs_array, &outputs_data)); + + std::array is_float = {false, false, false}; + RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, static_cast(inputs_data[0]), is_float[0])); + RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, static_cast(inputs_data[1]), is_float[1])); + RETURN_IF_ERROR(IsFloatTensor(ep->ort_api, static_cast(outputs_data[0]), is_float[2])); + if (!is_float[0] || !is_float[1] || !is_float[2]) { + continue; // Input or output is not of type float + } + + supported_nodes.push_back(node); // Only support a single Mul for now. + break; + } + } + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, supported_nodes.data(), + supported_nodes.size())); + return nullptr; + } + + static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** ep_context_nodes) { + ExampleEp* ep = static_cast(this_ptr); + + if (count != 1) { + return ep->ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single graph"); + } + + OrtArrayOfConstObjects* nodes_array = nullptr; + DeferOrtRelease release_nodes(&nodes_array, ep->ort_api.ReleaseArrayOfConstObjects); + size_t num_nodes = 0; + + RETURN_IF_ERROR(ep->ort_api.Graph_GetNodes(graphs[0], &nodes_array)); + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetSize(nodes_array, &num_nodes)); + + if (num_nodes != 1) { + return ep->ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); + } + + const OrtNode* node_to_compile = nullptr; + RETURN_IF_ERROR(ep->ort_api.ArrayOfConstObjects_GetElementAt(nodes_array, 0, + reinterpret_cast(&node_to_compile))); + + const char* node_op_type = nullptr; + RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node_to_compile, &node_op_type)); + + if (std::strncmp(node_op_type, "Mul", 4) != 0) { + return ep->ort_api.CreateStatus(ORT_EP_FAIL, "Expected to compile a single Mul node"); + } + + // Now we know we're compiling a single Mul node. + // Associate the name of the fused node with our MulKernel. + const char* fused_node_name = nullptr; + RETURN_IF_ERROR(ep->ort_api.Node_GetName(fused_nodes[0], &fused_node_name)); + + ep->kernels.emplace(std::string(fused_node_name), std::make_unique(ep->ort_api, ep->logger_)); + + // Update the OrtNodeComputeInfo associated with the graph. + auto node_compute_info = std::make_unique(*ep); + node_compute_infos[0] = node_compute_info.release(); + + // Create EpContext nodes for the fused nodes we compiled. + if (ep->config_.enable_ep_context) { + assert(ep_context_nodes != nullptr); + RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), + gsl::span(ep_context_nodes, count))); + } + + return nullptr; + } + + static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, + OrtNodeComputeInfo** node_compute_infos, + size_t num_node_compute_infos) { + (void)this_ptr; + for (size_t i = 0; i < num_node_compute_infos; i++) { + delete node_compute_infos[i]; + } + } + + // Creates EPContext nodes from the given fused nodes. + // This is an example implementation that can be used to generate an EPContext model. However, this example EP + // cannot currently run the EPContext model. + OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, + /*out*/ gsl::span ep_context_nodes) { + assert(fused_nodes.size() == ep_context_nodes.size()); + + // Helper to collect input or output names from an array of OrtValueInfo instances. + auto collect_input_output_names = [&](const OrtArrayOfConstObjects& value_infos, + std::vector& result) -> OrtStatus* { + size_t num_values = 0; + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetSize(&value_infos, &num_values)); + + std::vector value_names(num_values, nullptr); + + for (size_t i = 0; i < num_values; i++) { + const void* value_info = nullptr; // Is a const OrtValueInfo* + RETURN_IF_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(&value_infos, i, &value_info)); + RETURN_IF_ERROR(ort_api.GetValueInfoName(static_cast(value_info), &value_names[i])); + } + + result = std::move(value_names); + return nullptr; + }; + + // Create an "EPContext" node for every fused node. + for (size_t i = 0; i < fused_nodes.size(); ++i) { + const OrtNode* fused_node = fused_nodes[i]; + const char* fused_node_name = nullptr; + + RETURN_IF_ERROR(ort_api.Node_GetName(fused_node, &fused_node_name)); + + OrtArrayOfConstObjects* fused_node_inputs = nullptr; + OrtArrayOfConstObjects* fused_node_outputs = nullptr; + DeferOrtRelease defer_release0(&fused_node_inputs, ort_api.ReleaseArrayOfConstObjects); + DeferOrtRelease defer_release1(&fused_node_outputs, ort_api.ReleaseArrayOfConstObjects); + + RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node, &fused_node_inputs)); + RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node, &fused_node_outputs)); + + std::vector input_names; + std::vector output_names; + + RETURN_IF_ERROR(collect_input_output_names(*fused_node_inputs, /*out*/ input_names)); + RETURN_IF_ERROR(collect_input_output_names(*fused_node_outputs, /*out*/ output_names)); + + int64_t is_main_context = (i == 0); + int64_t embed_mode = 1; + + // Create node attributes. The CreateNode() function copies the attributes, so we have to release them. + std::array attributes = {}; + DeferOrtRelease defer_release_attrs(attributes.data(), attributes.size(), ort_api.ReleaseOpAttr); + + RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", "binary_data", 1, ORT_OP_ATTR_STRING, &attributes[0])); + RETURN_IF_ERROR(ort_api.CreateOpAttr("main_context", &is_main_context, 1, ORT_OP_ATTR_INT, &attributes[1])); + RETURN_IF_ERROR(ort_api.CreateOpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT, &attributes[2])); + RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_sdk_version", "1", 1, ORT_OP_ATTR_STRING, &attributes[3])); + RETURN_IF_ERROR(ort_api.CreateOpAttr("partition_name", fused_node_name, 1, ORT_OP_ATTR_STRING, &attributes[4])); + RETURN_IF_ERROR(ort_api.CreateOpAttr("source", this->name_.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[5])); + + RETURN_IF_ERROR(model_editor_api.CreateNode("EPContext", "com.microsoft", fused_node_name, + input_names.data(), input_names.size(), + output_names.data(), output_names.size(), + attributes.data(), attributes.size(), + &ep_context_nodes[i])); + } + + return nullptr; + } + std::string name_; - const OrtSessionOptions& session_options_; + Config config_{}; const OrtLogger& logger_; + std::unordered_map> kernels; }; +// +// Implementation of ExampleNodeComuteInfo +// +ExampleNodeComputeInfo::ExampleNodeComputeInfo(ExampleEp& ep) : ep(ep) { + ort_version_supported = ORT_API_VERSION; + CreateState = CreateStateImpl; + Compute = ComputeImpl; + ReleaseState = ReleaseStateImpl; +} + +OrtStatus* ExampleNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, + OrtNodeComputeContext* compute_context, + void** compute_state) { + auto* node_compute_info = static_cast(this_ptr); + ExampleEp& ep = node_compute_info->ep; + + std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context); + auto kernel_it = ep.kernels.find(fused_node_name); + if (kernel_it == ep.kernels.end()) { + std::string message = "Unable to get kernel for fused node with name " + fused_node_name; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str()); + } + + MulKernel& kernel = *kernel_it->second; + *compute_state = &kernel; + return nullptr; +} + +OrtStatus* ExampleNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context) { + (void)this_ptr; + MulKernel& kernel = *reinterpret_cast(compute_state); + return kernel.Compute(kernel_context); +} + +void ExampleNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) { + (void)this_ptr; + MulKernel& kernel = *reinterpret_cast(compute_state); + (void)kernel; + // Do nothing for this example. +} + +/// +/// Example EP factory that can create an OrtEp and return information about the supported hardware devices. +/// struct ExampleEpFactory : OrtEpFactory, ApiPtrs { ExampleEpFactory(const char* ep_name, ApiPtrs apis) : ApiPtrs(apis), ep_name_{ep_name} { ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. @@ -139,7 +493,16 @@ struct ExampleEpFactory : OrtEpFactory, ApiPtrs { // const OrtHardwareDevice* device = devices[0]; // const OrtKeyValuePairs* ep_metadata = ep_metadata[0]; - auto dummy_ep = std::make_unique(*factory, factory->ep_name_, *session_options, *logger); + // Create EP configuration from session options, if needed. + // Note: should not store a direct reference to the session options object as its lifespan is not guaranteed. + std::string ep_context_enable; + RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(factory->ort_api, *session_options, + "ep.context_enable", "0", ep_context_enable)); + + ExampleEp::Config config = {}; + config.enable_ep_context = ep_context_enable == "1"; + + auto dummy_ep = std::make_unique(*factory, factory->ep_name_, config, *logger); *ep = dummy_ep.release(); return nullptr; @@ -168,11 +531,13 @@ extern "C" { EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const OrtApiBase* ort_api_base, OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); - // const OrtEpApi* ep_api = ort_api->GetEpApi(); + const OrtEpApi* ort_ep_api = ort_api->GetEpApi(); + const OrtModelEditorApi* ort_model_editor_api = ort_api->GetModelEditorApi(); // Factory could use registration_name or define its own EP name. std::unique_ptr factory = std::make_unique(registration_name, - ApiPtrs{*ort_api}); + ApiPtrs{*ort_api, *ort_ep_api, + *ort_model_editor_api}); if (max_factories < 1) { return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc new file mode 100644 index 0000000000000..549551931c647 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.cc @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "example_plugin_ep_utils.h" + +#include + +OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& ort_api, const OrtSessionOptions& session_options, + const char* config_key, const std::string& default_val, + /*out*/ std::string& config_val) { + int has_config = 0; + RETURN_IF_ERROR(ort_api.HasSessionConfigEntry(&session_options, config_key, &has_config)); + + if (has_config != 1) { + config_val = default_val; + return nullptr; + } + + size_t size = 0; + RETURN_IF_ERROR(ort_api.GetSessionConfigEntry(&session_options, config_key, nullptr, &size)); + + config_val.resize(size); + RETURN_IF_ERROR(ort_api.GetSessionConfigEntry(&session_options, config_key, config_val.data(), &size)); + config_val.resize(size - 1); // remove the terminating '\0' + + return nullptr; +} + +OrtStatus* IsFloatTensor(const OrtApi& ort_api, const OrtValueInfo* value_info, bool& result) { + result = false; + + const OrtTypeInfo* type_info = nullptr; + RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(value_info, &type_info)); + + ONNXType onnx_type = ONNX_TYPE_UNKNOWN; + RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(type_info, &onnx_type)); + if (onnx_type != ONNX_TYPE_TENSOR) { + return nullptr; + } + + const OrtTensorTypeAndShapeInfo* type_shape = nullptr; + RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(type_info, &type_shape)); + + ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); + if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { + return nullptr; + } + + result = true; + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h new file mode 100644 index 0000000000000..ae0a86bbb7222 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +#define RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* status = (fn); \ + if (status != nullptr) { \ + return status; \ + } \ + } while (0) + +#define RETURN_IF(cond, ort_api, msg) \ + do { \ + if ((cond)) { \ + return (ort_api).CreateStatus(ORT_EP_FAIL, (msg)); \ + } \ + } while (0) + +// Helper to release Ort one or more objects obtained from the public C API at the end of their scope. +template +struct DeferOrtRelease { + DeferOrtRelease(T** object_ptr, std::function release_func) + : objects_(object_ptr), count_(1), release_func_(release_func) {} + + DeferOrtRelease(T** objects, size_t count, std::function release_func) + : objects_(objects), count_(count), release_func_(release_func) {} + + ~DeferOrtRelease() { + if (objects_ != nullptr && count_ > 0) { + for (size_t i = 0; i < count_; ++i) { + if (objects_[i] != nullptr) { + release_func_(objects_[i]); + objects_[i] = nullptr; + } + } + } + } + T** objects_ = nullptr; + size_t count_ = 0; + std::function release_func_ = nullptr; +}; + +// Returns an entry in the session option configurations, or a default value if not present. +OrtStatus* GetSessionConfigEntryOrDefault(const OrtApi& ort_api, const OrtSessionOptions& session_options, + const char* config_key, const std::string& default_val, + /*out*/ std::string& config_val); + +// Returns true (via output parameter) if the given OrtValueInfo represents a float tensor. +OrtStatus* IsFloatTensor(const OrtApi& ort_api, const OrtValueInfo* value_info, bool& result); diff --git a/onnxruntime/test/autoep/test_autoep_selection.cc b/onnxruntime/test/autoep/test_autoep_selection.cc index cea1299adc26f..eed2068f92f53 100644 --- a/onnxruntime/test/autoep/test_autoep_selection.cc +++ b/onnxruntime/test/autoep/test_autoep_selection.cc @@ -583,6 +583,128 @@ TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); } +static void RunModelWithPluginEp(Ort::SessionOptions& session_options) { + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); + + // Create input + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector shape = {3, 2}; + std::vector input0_data(6, 2.0f); + std::vector ort_inputs; + std::vector ort_input_names; + + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); + ort_input_names.push_back("X"); + + // Run session and get outputs + std::array output_names{"Y"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check expected output values + Ort::Value& ort_output = ort_outputs[0]; + const float* output_data = ort_output.GetTensorData(); + gsl::span output_span(output_data, 6); + EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); +} + +// Creates a session with the example plugin EP and runs a model with a single Mul node. +// Uses AppendExecutionProvider_V2 to append the example plugin EP to the session. +TEST(OrtEpLibrary, PluginEp_AppendV2_MulInference) { + const std::filesystem::path& library_path = example_plugin_info.library_path; + const std::string& registration_name = example_plugin_info.registration_name; + + ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); + + { + std::vector ep_devices = ort_env->GetEpDevices(); + + // Find the OrtEpDevice associated with our example plugin EP. + Ort::ConstEpDevice plugin_ep_device; + for (Ort::ConstEpDevice& device : ep_devices) { + if (std::string(device.EpName()) == registration_name) { + plugin_ep_device = device; + break; + } + } + ASSERT_NE(plugin_ep_device, nullptr); + + // Create session with example plugin EP + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); + + RunModelWithPluginEp(session_options); + } + + ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); +} + +// Creates a session with the example plugin EP and runs a model with a single Mul node. +// Uses the PREFER_CPU policy to append the example plugin EP to the session. +TEST(OrtEpLibrary, PluginEp_PreferCpu_MulInference) { + const std::filesystem::path& library_path = example_plugin_info.library_path; + const std::string& registration_name = example_plugin_info.registration_name; + + ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); + + { + // PREFER_CPU pick our example EP over ORT CPU EP. TODO: Actually assert this. + Ort::SessionOptions session_options; + session_options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_CPU); + RunModelWithPluginEp(session_options); + } + + ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); +} + +// Generate an EPContext model with a plugin EP. +// This test uses the OrtCompileApi but could also be done by setting the appropriate session option configs. +TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { + const std::filesystem::path& library_path = example_plugin_info.library_path; + const std::string& registration_name = example_plugin_info.registration_name; + + ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); + + { + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("plugin_ep_mul_1_ctx.onnx"); + std::filesystem::remove(output_model_file); + + std::vector ep_devices = ort_env->GetEpDevices(); + + // Find the OrtEpDevice associated with our example plugin EP. + Ort::ConstEpDevice plugin_ep_device; + for (Ort::ConstEpDevice& device : ep_devices) { + if (std::string(device.EpName()) == registration_name) { + plugin_ep_device = device; + break; + } + } + ASSERT_NE(plugin_ep_device, nullptr); + + // Create session with example plugin EP + Ort::SessionOptions session_options; + std::unordered_map ep_options; + + session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); + + // Create model compilation options from the session options. + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + + // Compile the model. + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); + + // Make sure the compiled model was generated. + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + } + + ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/cuda_kernels/fpA_intB_gemm_kernel_test.cc b/onnxruntime/test/contrib_ops/cuda_kernels/fpA_intB_gemm_kernel_test.cc new file mode 100644 index 0000000000000..cac6d46226ef8 --- /dev/null +++ b/onnxruntime/test/contrib_ops/cuda_kernels/fpA_intB_gemm_kernel_test.cc @@ -0,0 +1,622 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Test can be run like the following: +// ./onnxruntime_test_all --gtest_filter=CUDA_EP_Unittest.* + +#include +#include +#include + +#include "cutlass/numeric_types.h" +#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h" +#include "contrib_ops/cuda/quantization/matmul_nbits.cuh" +#include "contrib_ops/cuda/quantization/dequantize_blockwise.cuh" +#include "core/providers/cuda/shared_inc/fpgeneric.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace wo = onnxruntime::llm::kernels::fpA_intB_gemv; +using onnxruntime::llm::cutlass_extensions::CutlassGemmConfig; + +namespace { +constexpr bool kPipelineMode = true; // CI pipeline? + +std::vector get_m_list() { + if (kPipelineMode) { + return {1, 14}; + } else { + return {1, 4, 8, 14, 256, 512, 1024, 2048}; + } +} + +std::vector> get_n_k_list() { + if (kPipelineMode) { + return {{5120, 3072}}; + } else { + // N and K of phi4 mini. + return {{5120, 3072}, {8192, 3072}, {3072, 8192}, {200064, 3072}}; + } +} + +struct CudaBuffer { + void* _data; + size_t _bytes; + + CudaBuffer(size_t size_in_bytes) : _bytes(size_in_bytes) { + cudaMalloc(&_data, _bytes); + } + + template + T* data() { + return reinterpret_cast(_data); + } + + void to_cpu(void* dst) { + cudaMemcpy(dst, _data, _bytes, cudaMemcpyDeviceToHost); + } + + void from_cpu(void* src) { + cudaMemcpy(_data, src, _bytes, cudaMemcpyHostToDevice); + } + + ~CudaBuffer() { + cudaFree(_data); + } +}; + +template +float compare(void* a, void* b, size_t size, float scale) { + auto pa = reinterpret_cast(a); + auto pb = reinterpret_cast(b); + float max_diff = 0.f; + float total_diff = 0.f; + float max_val = 0.f; + int diff_count = 0; + float threshold = 1e-7; + for (size_t n = 0; n < size; ++n) { + float va = static_cast(pa[n]); + float vb = static_cast(pb[n]); + max_val = std::max(max_val, vb); + float diff = std::abs(va - vb); + if (diff > threshold) { + max_diff = std::max(max_diff, diff); + total_diff += diff; + ++diff_count; + } + } + + float diff_threshold = max_val * scale; + if constexpr (std::is_same_v) { + // fp16 precision is about 3.3 decimal digits, and bf16 is about 2.0–2.3 decimal digits, so we use 10x threshold. + diff_threshold *= 15.f; + } else { + diff_threshold *= 1.5f; + } + + bool passed = max_diff <= diff_threshold; + if (!passed) { + printf("max diff %f (threshold %f), avg diff %f, diff count %d/%zu\n", + max_diff, diff_threshold, total_diff / diff_count, diff_count, size); + } + + return max_diff <= diff_threshold; +} + +template +void random_fill(std::vector& vec, T2 min_value, T2 max_value) { + std::mt19937 gen(rand()); + std::uniform_real_distribution dis(static_cast(min_value), static_cast(max_value)); + for (auto& v : vec) { + v = static_cast(dis(gen)); + } +} + +std::vector filter_gemm_configs(const std::vector& configs, int k) { + std::vector rets; + for (auto config : configs) { + if (config.stages >= 5) { + continue; + } + + if (config.split_k_style != onnxruntime::llm::cutlass_extensions::SplitKStyle::NO_SPLIT_K) { + int k_size = (k + config.split_k_factor - 1) / config.split_k_factor; + if (k_size % 64) { + continue; + } + } + rets.push_back(config); + } + return rets; +} + +template +struct cutlassTypeMapper { +}; + +#define CUTLASS_TYPE_MAPPER_REGISTRY( \ + CudaKernelType, CudaAType, CutlassWType, WElemBits, CutlassQuantOp) \ + template <> \ + struct cutlassTypeMapper { \ + using AType = CudaAType; \ + using WType = CutlassWType; \ + static constexpr cutlass::WeightOnlyQuantOp QuantOp = CutlassQuantOp; \ + static constexpr int WSizeInBits = WElemBits; \ + static std::string ATypeStr() { return std::is_same_v ? "Fp16" : "BF16"; } \ + static std::string WTypeStr() { return WSizeInBits == 4 ? "Int4" : "Int8"; } \ + }; + +CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int8Groupwise, half, uint8_t, 8, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS); +CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int8Groupwise, __nv_bfloat16, uint8_t, 8, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS); +CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int4Groupwise, half, cutlass::uint4b_t, 4, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS); +CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int4Groupwise, __nv_bfloat16, cutlass::uint4b_t, 4, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS); + +template +float measure_kernel_time(Func kernel_launcher, int warmup, int repeats, cudaStream_t s) { + cudaEvent_t begin, end; + cudaEventCreate(&begin); + cudaEventCreate(&end); + + for (int i = 0; i < warmup; ++i) { + kernel_launcher(); + } + cudaEventRecord(begin, s); + for (int i = 0; i < repeats; ++i) { + kernel_launcher(); + } + cudaEventRecord(end, s); + cudaEventSynchronize(end); + float time; + cudaEventElapsedTime(&time, begin, end); + cudaEventDestroy(begin); + cudaEventDestroy(end); + return time / repeats; +} + +template +void run_cutlass_kernel([[maybe_unused]] void* scaled_act, Runner& runner, wo::Params& params, Config& config, + char* ws, size_t ws_size, cudaStream_t stream) { + static constexpr cutlass::WeightOnlyQuantOp QuantOp = cutlassTypeMapper::QuantOp; + void* act = params.act; + if (params.act_scale) { + ORT_THROW("act_scale is not supported in this test fixture."); + } + if (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) { + runner.gemm(act, params.weight, params.scales, params.zeros, params.bias, params.out, params.m, params.n, + params.k, params.groupsize, config, ws, ws_size, stream); + } +} + +struct BenchmarkResult { + std::string a_type; + std::string b_type; + int m; + int n; + int k; + int block_size; + float cuda_time_us; + float cutlass_time_us; + float nbits_time_us; + float naive_time_us; + float speedup_cuda_vs_cutlass; + float speedup_cuda_vs_nbits; + float speedup_best_vs_naive; + + float best_time_us() const { + float best = cutlass_time_us; + if (cuda_time_us > 0.f) { + best = std::min(best, cuda_time_us); + } + if (nbits_time_us > 0.f) { + best = std::min(best, nbits_time_us); + } + return best; + } +}; + +void PrintBenchmarkSummary(std::vector& benchmark_results) { + std::cout << "\nBenchmark of FpA_IntB_GEMV, FpA_IntB_GEMM, MatMulNBits, Naive (DQ + GEMM) kernels (latency in microseconds):\n"; + constexpr size_t kLength = 139; + std::cout << std::string(kLength, '-') << std::endl; + std::cout << std::left << std::setw(6) << "A" + << std::setw(6) << "W" + << std::setw(6) << "m" + << std::setw(8) << "n" + << std::setw(7) << "k" + << std::setw(12) << "block_size" + << std::setw(12) << "gemv (us)" + << std::setw(12) << "gemm (us)" + << std::setw(12) << "nbits (us)" + << std::setw(12) << "best (us)" + << std::setw(12) << "naive (us)" + << std::setw(12) << "gemm/gemv" + << std::setw(12) << "nbits/gemv" + << std::setw(12) << "best/naive" + << std::endl; + std::cout << std::string(kLength, '-') << std::endl; + + std::cout << std::fixed << std::setprecision(3); + + for (const auto& result : benchmark_results) { + std::cout << std::left << std::setw(6) << result.a_type + << std::setw(6) << result.b_type + << std::setw(6) << result.m + << std::setw(8) << result.n + << std::setw(7) << result.k + << std::setw(12) << result.block_size + << std::setw(12) << result.cuda_time_us + << std::setw(12) << result.cutlass_time_us + << std::setw(12) << result.nbits_time_us + << std::setw(12) << result.best_time_us() + << std::setw(12) << result.naive_time_us + << std::setw(12) << result.speedup_cuda_vs_cutlass + << std::setw(12) << result.speedup_cuda_vs_nbits + << std::setw(12) << result.speedup_best_vs_naive + << std::endl; + } + std::cout << std::string(kLength, '-') << std::endl; +} + +template +class KernelTestFixture : public ::testing::Test { + protected: + int m_, n_, k_, block_size_; + int warmup_ = 10; + int repeats_ = 30; + cudaDeviceProp device_prop_; + std::shared_ptr d_act_, d_act_scale_, d_weight_, d_scales_, d_zeros_, d_bias_, d_out_; + std::vector::AType> h_act_, h_act_scale_, h_scales_, h_zeros_, h_bias_, h_out1_, h_out2_; + std::vector h_weight_; + std::vector benchmark_results_; + cudaStream_t s_; + cublasHandle_t cublas_handle_; + + static constexpr int WSizeInBits = cutlassTypeMapper::WSizeInBits; + + void SetUp() override { + int device; + CUDA_CALL_THROW(cudaGetDevice(&device)); + CUDA_CALL_THROW(cudaGetDeviceProperties(&device_prop_, device)); + std::srand(20240123); + cudaStreamCreate(&s_); + CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_)); + CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, s_)); + } + + void TearDown() override { + PrintBenchmarkSummary(benchmark_results_); + cudaStreamDestroy(s_); + cublasDestroy(cublas_handle_); + } + + void InitBuffers(int m, int n, int k, int block_size) { + m_ = m; + n_ = n; + k_ = k; + block_size_ = block_size; + + if (cutlassTypeMapper::QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) { + ORT_ENFORCE(block_size_ == 64 || block_size_ == 128); + ORT_ENFORCE(k_ % block_size_ == 0); + } + + using AType = typename cutlassTypeMapper::AType; + + constexpr size_t ATypeBytes = sizeof(AType); + const size_t m_x_k = static_cast(m_) * static_cast(k_); + const size_t n_x_k = static_cast(n_) * static_cast(k_); + const size_t m_x_n = static_cast(m_) * static_cast(n_); + d_act_ = std::make_shared(m_x_k * ATypeBytes); + d_act_scale_ = std::make_shared(static_cast(k_) * ATypeBytes); + d_weight_ = std::make_shared(n_x_k * WSizeInBits / static_cast(8)); + d_scales_ = std::make_shared(n_x_k / static_cast(block_size_) * ATypeBytes); + d_zeros_ = std::make_shared(n_x_k / static_cast(block_size_) * ATypeBytes); + d_bias_ = std::make_shared(static_cast(n_) * ATypeBytes); + d_out_ = std::make_shared(m_x_n * ATypeBytes); + + h_act_.resize(m_x_k); + h_act_scale_.resize(static_cast(k_)); + h_weight_.resize(n_x_k); + h_scales_.resize(n_x_k / static_cast(block_size_)); + h_zeros_.resize(n_x_k / static_cast(block_size_)); + h_bias_.resize(static_cast(n_)); + h_out1_.resize(m_x_n); + h_out2_.resize(m_x_n); + + random_fill(h_act_, -1.f, 1.f); + random_fill(h_act_scale_, -1.f, 1.f); + random_fill(h_scales_, -1.f, 1.f); + random_fill(h_zeros_, -1.f, 1.f); + random_fill(h_bias_, -1.f, 1.f); + + for (uint8_t& v : h_weight_) { + v = rand() % 256; + } + + d_act_->from_cpu(h_act_.data()); + d_act_scale_->from_cpu(h_act_scale_.data()); + d_weight_->from_cpu(h_weight_.data()); + d_scales_->from_cpu(h_scales_.data()); + d_zeros_->from_cpu(h_zeros_.data()); + d_bias_->from_cpu(h_bias_.data()); + } + + bool BenchmarkAndVerifyKernel() { + std::cout << "m=" << m_ << ", n=" << n_ << ", k=" << k_ << ", block_size=" << block_size_ << std::endl; + + void* p_act_scale = nullptr; + void* p_zeros = nullptr; + void* p_bias = nullptr; + + if (block_size_ != 0) { + p_zeros = d_zeros_->data(); + if constexpr (has_bias) { + p_bias = d_bias_->data(); + } + if constexpr (has_act_scale) { + p_act_scale = d_act_scale_->data(); + } + } + + wo::Params params(d_act_->data(), p_act_scale, d_weight_->data(), d_scales_->data(), p_zeros, p_bias, + d_out_->data(), 1.f, m_, n_, k_, block_size_, KT); + + //------------------------ + // Run FpA_IntB_Gemv CUDA kernel + float cuda_time_ms = 0.f; + if (m_ < 16) { + cuda_time_ms = measure_kernel_time( + [&]() { + int arch = onnxruntime::llm::common::getSMVersion(); + ORT_ENFORCE(wo::is_supported(arch, params.type)); + wo::kernel_launcher(arch, params, s_); + }, + warmup_, repeats_, s_); + d_out_->to_cpu(h_out1_.data()); + } + + // ------------------------ + // Run FpA_IntB_Gemm CUTLASS kernel + using AType = typename cutlassTypeMapper::AType; + using WType = typename cutlassTypeMapper::WType; + using onnxruntime::llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner; + auto runner = std::make_shared::QuantOp>>(); + auto& gemm_runner = *runner; + int ws_bytes = gemm_runner.getWorkspaceSize(m_, n_, k_); + CudaBuffer ws_buffer(ws_bytes); + char* ws_ptr = reinterpret_cast(ws_buffer.data()); + + auto configs = gemm_runner.getConfigs(); + + if constexpr (filter_configs) { + configs = filter_gemm_configs(configs, k_); + } + + float fast_time_ms = std::numeric_limits::max(); + CutlassGemmConfig best_config = configs[0]; + + for (auto& config : configs) { + float time = std::numeric_limits::max(); + try { + time = measure_kernel_time( + [&]() { + run_cutlass_kernel(d_act_->data(), gemm_runner, params, config, ws_ptr, ws_bytes, s_); + }, + 2, 5, s_); + } catch (std::exception const& e) { + std::ostringstream msg; + msg << "Failed to profile m=" << params.m << ", n=" << params.n << ", k=" << params.k << "for configuration:\n"; + msg << config.toString(); + msg << "\nException:" << e.what() << "\n"; + std::cout << msg.str(); + cudaGetLastError(); // Reset the last cudaError to cudaSuccess. + continue; + } + if (time < fast_time_ms) { + fast_time_ms = time; + best_config = config; + } + } + + float cutlass_time_ms = measure_kernel_time( + [&]() { + run_cutlass_kernel(d_act_->data(), gemm_runner, params, best_config, ws_ptr, ws_bytes, s_); + }, + warmup_, repeats_, s_); + d_out_->to_cpu(h_out2_.data()); + + // ------------------------ + // Compare FpA_IntB_Gemv and FpA_IntB_Gemm outputs. + bool pass = true; + if (m_ < 16) { + float quant_scale = 1.f / (1 << (WSizeInBits - 1)); + const size_t m_x_n = static_cast(m_) * static_cast(n_); + pass = compare(h_out1_.data(), h_out2_.data(), m_x_n, quant_scale); + } + + // ------------------------ + // Run MatMulNBits kernel. + // Note that it runs on random data, so the output is not compared. + float nbits_time_ms = 0.f; + float naive_time_ms = 0.f; + if constexpr (KT == wo::KernelType::FP16Int8Groupwise || KT == wo::KernelType::FP16Int4Groupwise) { + const size_t n_x_k = static_cast(n_) * static_cast(k_); + std::vector h_uint8_zeros(n_x_k / static_cast(block_size_)); + for (uint8_t& v : h_uint8_zeros) { + v = rand() % 256; + } + + ORT_ENFORCE(k_ / block_size_ * WSizeInBits % 8 == 0); + CudaBuffer d_uint8_zeros(n_x_k / static_cast(block_size_) * WSizeInBits / static_cast(8)); + d_uint8_zeros.from_cpu(h_uint8_zeros.data()); + + if (m_ == 1) { + nbits_time_ms = measure_kernel_time( + [&]() { + onnxruntime::contrib::cuda::TryMatMulNBits(WSizeInBits, + reinterpret_cast(d_out_->data()), + reinterpret_cast(d_act_->data()), + reinterpret_cast(d_weight_->data()), + reinterpret_cast(d_scales_->data()), + static_cast(d_uint8_zeros.data()), + m_, n_, k_, block_size_, device_prop_.sharedMemPerBlock, s_); + }, + warmup_, repeats_, s_); + } + + CudaBuffer d_dequantized_weight(n_x_k * sizeof(AType)); + + naive_time_ms = measure_kernel_time( + [&]() { + auto status = onnxruntime::contrib::cuda::DequantizeNBits( + WSizeInBits, + reinterpret_cast(d_dequantized_weight.data()), + reinterpret_cast(d_weight_->data()), + reinterpret_cast(d_scales_->data()), + reinterpret_cast(d_uint8_zeros.data()), + nullptr, + k_, + n_, + block_size_, + s_); + + ORT_THROW_IF_ERROR(status); + + const AType alpha = AType(1.f); + const AType zero = AType(0.f); + constexpr bool use_tf32 = false; + CUBLAS_CALL_THROW(cublasGemmHelper( + cublas_handle_, + CUBLAS_OP_T, + CUBLAS_OP_N, + n_, + m_, + k_, + &alpha, + reinterpret_cast(d_dequantized_weight.data()), + k_, + reinterpret_cast(d_act_->data()), + k_, + &zero, + reinterpret_cast(d_out_->data()), + n_, + device_prop_, + use_tf32)); + }, + warmup_, repeats_, s_); + } + + // Store benchmark results + BenchmarkResult result; + result.a_type = cutlassTypeMapper::ATypeStr(); + result.b_type = cutlassTypeMapper::WTypeStr(); + result.m = m_; + result.n = n_; + result.k = k_; + result.block_size = block_size_; + result.cuda_time_us = cuda_time_ms * 1000.0f; + result.cutlass_time_us = cutlass_time_ms * 1000.0f; + result.nbits_time_us = nbits_time_ms * 1000.0f; + result.naive_time_us = naive_time_ms * 1000.0f; + result.speedup_cuda_vs_cutlass = cuda_time_ms > 0.f ? cutlass_time_ms / cuda_time_ms : 0.f; + result.speedup_cuda_vs_nbits = cuda_time_ms > 0.f ? nbits_time_ms / cuda_time_ms : 0.f; + result.speedup_best_vs_naive = result.naive_time_us / result.best_time_us(); + benchmark_results_.push_back(result); + + return pass; + } +}; + +} // namespace + +using Fp16Int8GroupwiseTest = KernelTestFixture; +using Fp16Int4GroupwiseTest = KernelTestFixture; +using Bf16Int8GroupwiseTest = KernelTestFixture; +using Bf16Int4GroupwiseTest = KernelTestFixture; + +TEST_F(Fp16Int8GroupwiseTest, Fp16_Int8_Gemm_CudaKernel) { + int const arch = onnxruntime::llm::common::getSMVersion(); + if (arch < 75) { + std::cout << "Skip fp16 int8 groupwise GEMM kernel for SM < 75" << std::endl; + return; + } + + for (auto m : get_m_list()) { + for (const auto& [n, k] : get_n_k_list()) { + InitBuffers(m, n, k, 64); + EXPECT_TRUE(BenchmarkAndVerifyKernel()); + InitBuffers(m, n, k, 128); + EXPECT_TRUE(BenchmarkAndVerifyKernel()); + } + } +} + +TEST_F(Fp16Int4GroupwiseTest, Fp16_Int4_Gemm_CudaKernel) { + int const arch = onnxruntime::llm::common::getSMVersion(); + if (arch < 75) { + std::cout << "Skip fp16 int4 groupwise GEMM kernel for SM < 75" << std::endl; + return; + } + + for (auto m : get_m_list()) { + for (const auto& [n, k] : get_n_k_list()) { + InitBuffers(m, n, k, 64); + EXPECT_TRUE(BenchmarkAndVerifyKernel()); + InitBuffers(m, n, k, 128); + EXPECT_TRUE(BenchmarkAndVerifyKernel()); + } + } +} + +TEST_F(Bf16Int8GroupwiseTest, BF16_Int8_Gemm_CudaKernel) { + int const arch = onnxruntime::llm::common::getSMVersion(); + if (arch < 80) { + std::cout << "Skip bf16 int8 groupwise GEMM kernel test for SM < 80" << std::endl; + return; + } + + for (auto m : get_m_list()) { + for (const auto& [n, k] : get_n_k_list()) { + InitBuffers(m, n, k, 64); + EXPECT_TRUE(BenchmarkAndVerifyKernel()); + InitBuffers(m, n, k, 128); + EXPECT_TRUE(BenchmarkAndVerifyKernel()); + } + } +} + +TEST_F(Bf16Int4GroupwiseTest, BF16_Int4_Gemm_CudaKernel) { + int const arch = onnxruntime::llm::common::getSMVersion(); + if (arch < 80) { + std::cout << "Skip bf16 int4 groupwise GEMM kernel test for SM < 80" << std::endl; + return; + } + + for (auto m : get_m_list()) { + for (const auto& [n, k] : get_n_k_list()) { + InitBuffers(m, n, k, 64); + EXPECT_TRUE(BenchmarkAndVerifyKernel()); + InitBuffers(m, n, k, 128); + EXPECT_TRUE(BenchmarkAndVerifyKernel()); + } + } +} diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index 9bcdb5389386f..fc802054036d8 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -159,6 +159,7 @@ void RunTest8Bits(const TestOptions8Bits& opts) { test.AddOptionalInputEdge(); } + // Account for deprecated "g_idx" input test.AddOptionalInputEdge(); if (bias.has_value()) { @@ -284,6 +285,45 @@ TEST(MatMulNBits, Float32_8b_AccuracyLevel4) { TestMatMul8BitsTyped(); } +TEST(MatMulNBits, Float32_8b_AccuracyLevel1) { + // At the time of writing these tests, Fp32 activations + 8 bit weights + Accuracy level 1 + // do not have MLAS optimized kernels on any platform and hence this will use the "unpacked" + // compute mode (i.e.) de-quantize the 8 bit weights to fp32 and invoke vanilla fp32 Gemm + // in MLAS. This test helps keep that path tested. + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); + TestMatMul8BitsTyped(); +} + #if defined(USE_CUDA) || defined(USE_WEBGPU) TEST(MatMulNBits, Float16_8b_AccuracyLevel4) { constexpr float abs_error = 0.055f; diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc index 625c5ec35e20d..4aa6b5a98c22d 100644 --- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -186,7 +186,7 @@ static void RunTests(const std::vector& input_data, } // Interleaved = true, pos ids shape = (1) -TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) { +TEST(ContribOpRotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) { int batch_size = 1; int sequence_length = 3; int num_heads = 2; @@ -230,7 +230,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) { } // Interleaved = true, pos ids shape = (1) -TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) { +TEST(ContribOpRotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) { int batch_size = 2; int sequence_length = 8; int num_heads = 4; @@ -430,7 +430,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) { } // Interleaved = false, pos ids shape = (1) -TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) { +TEST(ContribOpRotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) { int batch_size = 2; int sequence_length = 8; int num_heads = 4; @@ -630,7 +630,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) { } // Interleaved = false, pos ids shape = (batch_size, sequence_length) -TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) { +TEST(ContribOpRotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) { int batch_size = 1; int sequence_length = 2; int num_heads = 3; @@ -677,7 +677,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) { interleaved); } -TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi) { +TEST(ContribOpRotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi) { int batch_size = 1; int sequence_length = 2; int num_heads = 1; @@ -718,7 +718,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi) { true /*use_fp16*/); } -TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi_Packed_Batching) { +TEST(ContribOpRotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi_Packed_Batching) { int batch_size = 1; int sequence_length = 3; int num_heads = 1; diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc new file mode 100644 index 0000000000000..64935929db070 --- /dev/null +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -0,0 +1,491 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/tensor_type_and_shape.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/session/onnxruntime_cxx_api.h" + +#include "test/ep_graph/test_ep_graph_utils.h" +#include "test/util/include/api_asserts.h" +#include "test/util/include/asserts.h" +#include "test/util/include/test_environment.h" + +// defined in unittest_main/test_main.cc +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +// forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent +// to a graph represented by the internal ORT GraphViewer class. +static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph); + +// +// Tests +// + +// Checks that an OrtGraph is initialized correctly and tests basic usage of the C API +// by traversing the OrtGraph and checking validity of nodes and value infos. +TEST(EpGraphTest, BasicCApiUse) { + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/mnist.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); +} + +// Use public C APIs to check that the OrtGraph for a model with subgraphs is correct. +// Traverse OrtGraph with Scan nodes, which tests handling of subgraphs, implicit inputs, and variadic I/O. +TEST(EpGraphTest, CheckModelWithSubgraphs) { + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/scan_1.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); +} + +// Use public C APIs to check that the OrtGraph for bart_tiny.onnx is correct. +// This model is used in an example topological sort implementation. +TEST(EpGraphTest, CheckModelBartTiny) { + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/bart_tiny.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); +} + +TEST(EpGraphTest, Check3LayerNestedSubgraph) { + // The main graph contains a 'If' node: 'graph_0__if_0' + // Inside the then-branch of 'graph_0__if_0', there is a nested 'If' node: 'graph_0__if_0__else__if_0' + // This 3-layer nested graph consumes the same initializer in different subgraphs. + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/three_layer_nested_subgraph.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); +} + +// +// Utils for traversing an OrtGraph and checking against GraphViewer. +// + +// Convert an OrtArrayOfConstObjects into a span of Ort___ pointers. +template +static void GetSpanFromArrayOfConstObjects(const OrtArrayOfConstObjects* ort_array, + /*out*/ gsl::span& span) { + const OrtApi& ort_api = Ort::GetApi(); + + size_t size = 0; + ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetSize(ort_array, &size)); + + const void* const* raw_data = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetData(ort_array, &raw_data)); + + auto data = reinterpret_cast(raw_data); + span = gsl::span(data, size); +} + +static void CheckArrayObjectType(const OrtArrayOfConstObjects* ort_array, OrtTypeTag expected_object_type) { + const OrtApi& ort_api = Ort::GetApi(); + + OrtTypeTag api_object_type = ORT_TYPE_TAG_Void; + ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetObjectType(ort_array, &api_object_type)); + ASSERT_EQ(api_object_type, expected_object_type); +} + +// Checks that the OrtTypeInfo obtained from the public C API matches another OrtTypeInfo +// obtained from the internal ORT graph IR. +static void CheckTypeInfo(const OrtTypeInfo* api_type_info, const OrtTypeInfo* type_info) { + const OrtApi& ort_api = Ort::GetApi(); + + ASSERT_NE(api_type_info, nullptr); + ASSERT_NE(type_info, nullptr); + + ONNXType api_onnx_type = ONNX_TYPE_UNKNOWN; + ASSERT_ORTSTATUS_OK(ort_api.GetOnnxTypeFromTypeInfo(api_type_info, &api_onnx_type)); + ASSERT_EQ(api_onnx_type, type_info->type); + + if (api_onnx_type == ONNX_TYPE_TENSOR) { + // Only validating Tensors (not checking Map, Sequence, etc.) values because these C APIs for getting + // type/shape information existed long before the new ORT graph IR APIs and are tested elsewhere. + const OrtTensorTypeAndShapeInfo* api_type_shape = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.CastTypeInfoToTensorInfo(api_type_info, &api_type_shape)); + ASSERT_NE(api_type_shape, nullptr); + + ONNXTensorElementDataType api_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ASSERT_ORTSTATUS_OK(ort_api.GetTensorElementType(api_type_shape, &api_elem_type)); + ASSERT_EQ(api_elem_type, type_info->tensor_type_info->type); + + size_t api_num_dims = 0; + ASSERT_ORTSTATUS_OK(ort_api.GetDimensionsCount(api_type_shape, &api_num_dims)); + ASSERT_EQ(api_num_dims, type_info->tensor_type_info->shape.NumDimensions()); + + std::vector api_dims(api_num_dims, 0); + ASSERT_ORTSTATUS_OK(ort_api.GetDimensions(api_type_shape, api_dims.data(), api_dims.size())); + ASSERT_EQ(gsl::span(api_dims), type_info->tensor_type_info->shape.GetDims()); + + std::vector api_dim_syms(api_num_dims, nullptr); + ASSERT_ORTSTATUS_OK(ort_api.GetSymbolicDimensions(api_type_shape, api_dim_syms.data(), api_dim_syms.size())); + const std::vector& dim_syms = type_info->tensor_type_info->dim_params; + for (size_t dim_idx = 0; dim_idx < api_num_dims; dim_idx++) { + ASSERT_EQ(std::string(api_dim_syms[dim_idx]), dim_syms[dim_idx]); + } + } +} + +// Checks that the given OrtNode matches the onnxruntime::Node. +static void CheckNode(const Node* node, const OrtNode* api_node) { + const OrtApi& ort_api = Ort::GetApi(); + + size_t api_node_id = 0; + const char* api_node_name = nullptr; + const char* api_op_type = nullptr; + const char* api_domain = nullptr; + + ASSERT_ORTSTATUS_OK(ort_api.Node_GetId(api_node, &api_node_id)); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetName(api_node, &api_node_name)); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetOperatorType(api_node, &api_op_type)); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetDomain(api_node, &api_domain)); + + ASSERT_EQ(api_node_id, node->Index()); + ASSERT_EQ(std::string(api_node_name), node->Name()); + ASSERT_EQ(std::string(api_op_type), node->OpType()); + ASSERT_EQ(std::string(api_domain), node->Domain()); +} + +// Checks that the producer of a OrtValueInfo obtained from the public C API is valid. +static void CheckValueInfoProducer(const GraphViewer& graph_viewer, const OrtValueInfo* value_info, + const NodeArg* node_arg) { + const OrtApi& ort_api = Ort::GetApi(); + + if (!node_arg->Exists()) { + return; + } + + const OrtNode* api_producer_node = nullptr; + size_t api_producer_output_index = 0; + ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetValueProducer(value_info, &api_producer_node, &api_producer_output_index)); + + const Node* producer_node = graph_viewer.GetProducerNode(node_arg->Name()); + if (producer_node == nullptr) { + ASSERT_EQ(api_producer_node, nullptr); + } else { + bool within_graph_viewer = graph_viewer.GetNode(producer_node->Index()) != nullptr; + if (!within_graph_viewer) { + ASSERT_EQ(api_producer_node, nullptr); // Producer is outside the graph viewer, so C API should return null + } else { + CheckNode(producer_node, api_producer_node); + + size_t output_index = 0; + ASSERT_STATUS_OK(GetOutputIndex(*producer_node, node_arg->Name(), output_index)); + ASSERT_EQ(api_producer_output_index, output_index); + } + } +} + +// Checks that consumers of a OrtValueInfo obtained from the public C API are valid by comparing to the original graph. +static void CheckValueInfoConsumers(const GraphViewer& graph_viewer, const OrtValueInfo* value_info, + const NodeArg* node_arg) { + const OrtApi& ort_api = Ort::GetApi(); + + if (!node_arg->Exists()) { + return; + } + + std::vector node_arg_consumers; + ASSERT_STATUS_OK(GetNodeArgConsumers(graph_viewer, *node_arg, node_arg_consumers)); + + size_t api_num_consumers = 0; + ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetValueNumConsumers(value_info, &api_num_consumers)); + ASSERT_EQ(api_num_consumers, node_arg_consumers.size()); + + std::vector api_node_consumers(api_num_consumers, nullptr); + std::vector api_input_indices(api_num_consumers, 0); + ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetValueConsumers(value_info, api_node_consumers.data(), + api_input_indices.data(), api_num_consumers)); + + for (size_t i = 0; i < api_num_consumers; i++) { + CheckNode(node_arg_consumers[i].node, api_node_consumers[i]); + ASSERT_EQ(api_input_indices[i], static_cast(node_arg_consumers[i].input_index)); + } +} + +static void CheckInitializerValueInfo(const OrtValueInfo* api_value_info, + const ONNX_NAMESPACE::TensorProto* tensor_proto) { + const OrtApi& ort_api = Ort::GetApi(); + + const OrtValue* api_initializer_value = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetInitializerValue(api_value_info, &api_initializer_value)); + ASSERT_NE(api_initializer_value, nullptr); + + const char* api_initializer_name = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(api_value_info, &api_initializer_name)); + ASSERT_NE(api_initializer_name, nullptr); + + // Check initializer type. + const ONNX_NAMESPACE::TypeProto type_proto = utils::TypeProtoFromTensorProto(*tensor_proto); + auto type_info = OrtTypeInfo::FromTypeProto(type_proto); + + const OrtTypeInfo* api_type_info = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoTypeInfo(api_value_info, &api_type_info)); + CheckTypeInfo(api_type_info, type_info.get()); +} + +static void CheckInitializerValueInfosCApi(gsl::span initializer_value_infos, + const InitializedTensorSet& initializer_tensor_protos) { + const OrtApi& ort_api = Ort::GetApi(); + + for (size_t i = 0; i < initializer_value_infos.size(); i++) { + const OrtValueInfo* api_value_info = initializer_value_infos[i]; + + const char* api_initializer_name = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(api_value_info, &api_initializer_name)); + ASSERT_NE(api_initializer_name, nullptr); + + auto tensor_proto_iter = initializer_tensor_protos.find(api_initializer_name); + ASSERT_NE(tensor_proto_iter, initializer_tensor_protos.end()); + + const ONNX_NAMESPACE::TensorProto* tensor_proto = tensor_proto_iter->second; + ASSERT_NE(tensor_proto, nullptr); + + CheckInitializerValueInfo(api_value_info, tensor_proto); + } +} + +// Checks that the OrtValueInfos obtained from the public C API are "equivalent" to the NodeArgs +// in the original graph. +static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span value_infos, + gsl::span node_args) { + ASSERT_EQ(value_infos.size(), node_args.size()); + const OrtApi& ort_api = Ort::GetApi(); + const auto& graph_viewer_inputs = graph_viewer.GetInputsIncludingInitializers(); + const auto& graph_viewer_outputs = graph_viewer.GetOutputs(); + + for (size_t i = 0; i < value_infos.size(); i++) { + const NodeArg* node_arg = node_args[i]; + const OrtValueInfo* value_info = value_infos[i]; + + if (node_arg->Exists()) { + const auto& value_name = node_arg->Name(); + + ASSERT_NE(value_info, nullptr); + + const char* api_name = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(value_info, &api_name)); + ASSERT_EQ(std::string(api_name), value_name); + + bool is_graph_input = std::any_of(graph_viewer_inputs.begin(), graph_viewer_inputs.end(), + [&node_arg](const NodeArg* graph_input) { + return node_arg->Name() == graph_input->Name(); + }); + + bool is_graph_output = std::any_of(graph_viewer_outputs.begin(), graph_viewer_outputs.end(), + [&node_arg](const NodeArg* graph_output) { + return node_arg->Name() == graph_output->Name(); + }); + bool is_const_initializer = false; + const ONNX_NAMESPACE::TensorProto* initializer = graph_viewer.GetGraph().GetInitializer(value_name, true, + is_const_initializer); + bool can_override_initializer = graph_viewer.CanOverrideInitializer(); + + bool api_is_req_graph_input = false; + bool api_is_opt_graph_input = false; + ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_IsRequiredGraphInput(value_info, &api_is_req_graph_input)); + ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_IsOptionalGraphInput(value_info, &api_is_opt_graph_input)); + ASSERT_EQ(api_is_req_graph_input, is_graph_input && (initializer == nullptr)); + ASSERT_EQ(api_is_opt_graph_input, can_override_initializer && (initializer != nullptr) && !is_const_initializer); + + bool api_is_graph_output = false; + ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_IsGraphOutput(value_info, &api_is_graph_output)); + ASSERT_EQ(api_is_graph_output, is_graph_output); + + bool is_outer_scope = graph_viewer.GetGraph().IsOuterScopeValue(node_arg->Name()); + bool api_is_outer_scope = false; + ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_IsFromOuterScope(value_info, &api_is_outer_scope)); + ASSERT_EQ(api_is_outer_scope, is_outer_scope); + + bool api_is_const_initializer = false; + ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_IsConstantInitializer(value_info, &api_is_const_initializer)); + ASSERT_EQ(api_is_const_initializer, is_const_initializer); + + if (is_const_initializer || api_is_opt_graph_input) { + CheckInitializerValueInfo(value_info, initializer); + } else { + auto node_arg_type_info = OrtTypeInfo::FromTypeProto(*node_arg->TypeAsProto()); + const OrtTypeInfo* api_type_info = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoTypeInfo(value_info, &api_type_info)); + CheckTypeInfo(api_type_info, node_arg_type_info.get()); + } + + CheckValueInfoProducer(graph_viewer, value_info, node_arg); + CheckValueInfoConsumers(graph_viewer, value_info, node_arg); + } else { + ASSERT_EQ(value_info, nullptr); // A missing optional input has a null OrtValueInfo. + } + } +} + +// Checks that the contents of the original GraphViewer matches the contents of the OrtGraph. +// Uses the public C APIs to traverse the OrtGraph. +static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) { + const OrtApi& ort_api = Ort::GetApi(); + + // Check graph inputs. + const auto& graph_input_node_args = graph_viewer.GetInputsIncludingInitializers(); + + OrtArrayOfConstObjects* api_graph_inputs_container = nullptr; + DeferOrtRelease release_graph_inputs(&api_graph_inputs_container, + ort_api.ReleaseArrayOfConstObjects); + gsl::span api_graph_inputs{}; + + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInputs(&api_graph, &api_graph_inputs_container)); + + CheckArrayObjectType(api_graph_inputs_container, ORT_TYPE_TAG_OrtValueInfo); + GetSpanFromArrayOfConstObjects(api_graph_inputs_container, api_graph_inputs); + + ASSERT_EQ(api_graph_inputs.size(), graph_input_node_args.size()); + CheckValueInfosCApi(graph_viewer, api_graph_inputs, graph_input_node_args); + + // Check graph outputs. + const auto& graph_output_node_args = graph_viewer.GetOutputs(); + + OrtArrayOfConstObjects* api_graph_outputs_container = nullptr; + DeferOrtRelease release_graph_outputs(&api_graph_outputs_container, + ort_api.ReleaseArrayOfConstObjects); + gsl::span api_graph_outputs{}; + + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetOutputs(&api_graph, &api_graph_outputs_container)); + + CheckArrayObjectType(api_graph_outputs_container, ORT_TYPE_TAG_OrtValueInfo); + GetSpanFromArrayOfConstObjects(api_graph_outputs_container, api_graph_outputs); + + ASSERT_EQ(api_graph_outputs.size(), graph_output_node_args.size()); + CheckValueInfosCApi(graph_viewer, api_graph_outputs, graph_output_node_args); + + // Check graph initializers + const auto& graph_initializers = graph_viewer.GetAllInitializedTensors(); + + OrtArrayOfConstObjects* api_initializers_container = nullptr; + DeferOrtRelease release_initializers(&api_initializers_container, + ort_api.ReleaseArrayOfConstObjects); + gsl::span api_initializers{}; + + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInitializers(&api_graph, &api_initializers_container)); + + CheckArrayObjectType(api_initializers_container, ORT_TYPE_TAG_OrtValueInfo); + GetSpanFromArrayOfConstObjects(api_initializers_container, api_initializers); + + ASSERT_EQ(api_initializers.size(), graph_initializers.size()); + CheckInitializerValueInfosCApi(api_initializers, graph_initializers); + + // Check if it has a parent node. + const Node* parent_node = graph_viewer.ParentNode(); + const bool has_parent_node = parent_node != nullptr; + const OrtNode* api_parent_node = nullptr; + + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetParentNode(&api_graph, &api_parent_node)); + const bool api_has_parent_node = api_parent_node != nullptr; + ASSERT_EQ(api_has_parent_node, has_parent_node); + + if (has_parent_node) { + CheckNode(parent_node, api_parent_node); + } + + // Check all nodes. + OrtArrayOfConstObjects* api_nodes_container = nullptr; + DeferOrtRelease release_nodes(&api_nodes_container, + ort_api.ReleaseArrayOfConstObjects); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, &api_nodes_container)); + CheckArrayObjectType(api_nodes_container, ORT_TYPE_TAG_OrtNode); + + size_t api_num_nodes = 0; + ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetSize(api_nodes_container, &api_num_nodes)); + ASSERT_EQ(api_num_nodes, graph_viewer.NumberOfNodes()); + + std::vector node_indices = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT); + for (size_t node_idx = 0; node_idx < api_num_nodes; node_idx++) { + // Check basic node properties. + const Node* node = graph_viewer.GetNode(node_indices[node_idx]); + const OrtNode* api_node = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetElementAt(api_nodes_container, node_idx, + reinterpret_cast(&api_node))); + CheckNode(node, api_node); + + int api_since_version = 0; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetSinceVersion(api_node, &api_since_version)); + ASSERT_EQ(api_since_version, node->SinceVersion()); + + // Check node inputs + const auto input_node_args = node->InputDefs(); + + OrtArrayOfConstObjects* api_node_inputs_container = nullptr; + DeferOrtRelease release_node_inputs(&api_node_inputs_container, + ort_api.ReleaseArrayOfConstObjects); + gsl::span api_node_inputs{}; + + ASSERT_ORTSTATUS_OK(ort_api.Node_GetInputs(api_node, &api_node_inputs_container)); + + CheckArrayObjectType(api_node_inputs_container, ORT_TYPE_TAG_OrtValueInfo); + GetSpanFromArrayOfConstObjects(api_node_inputs_container, api_node_inputs); + ASSERT_EQ(api_node_inputs.size(), input_node_args.size()); + + CheckValueInfosCApi(graph_viewer, api_node_inputs, input_node_args); + + // Check node outputs + const auto output_node_args = node->OutputDefs(); + OrtArrayOfConstObjects* api_node_outputs_container = nullptr; + DeferOrtRelease release_node_outputs(&api_node_outputs_container, + ort_api.ReleaseArrayOfConstObjects); + gsl::span api_node_outputs{}; + + ASSERT_ORTSTATUS_OK(ort_api.Node_GetOutputs(api_node, &api_node_outputs_container)); + + CheckArrayObjectType(api_node_outputs_container, ORT_TYPE_TAG_OrtValueInfo); + GetSpanFromArrayOfConstObjects(api_node_outputs_container, api_node_outputs); + ASSERT_EQ(api_node_outputs.size(), output_node_args.size()); + + CheckValueInfosCApi(graph_viewer, api_node_outputs, output_node_args); + + // Check node subgraphs + std::vector> node_subgraphs = node->GetSubgraphs(); + + if (!node_subgraphs.empty()) { + // Check node's implicit inputs to its subgraph nodes. + const auto implicit_input_node_args = node->ImplicitInputDefs(); + OrtArrayOfConstObjects* api_node_implicit_inputs_container = nullptr; + DeferOrtRelease release_node_implicit(&api_node_implicit_inputs_container, + ort_api.ReleaseArrayOfConstObjects); + gsl::span api_node_implicit_inputs{}; + + ASSERT_ORTSTATUS_OK(ort_api.Node_GetImplicitInputs(api_node, &api_node_implicit_inputs_container)); + + CheckArrayObjectType(api_node_implicit_inputs_container, ORT_TYPE_TAG_OrtValueInfo); + GetSpanFromArrayOfConstObjects(api_node_implicit_inputs_container, api_node_implicit_inputs); + ASSERT_EQ(api_node_implicit_inputs.size(), implicit_input_node_args.size()); + + CheckValueInfosCApi(graph_viewer, api_node_implicit_inputs, implicit_input_node_args); + + // Recursively check subgraphs. + OrtArrayOfConstObjects* api_node_subgraphs = nullptr; + DeferOrtRelease release_node_subgraphs(&api_node_subgraphs, + ort_api.ReleaseArrayOfConstObjects); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetSubgraphs(api_node, &api_node_subgraphs)); + CheckArrayObjectType(api_node_subgraphs, ORT_TYPE_TAG_OrtGraph); + + for (size_t subgraph_idx = 0; subgraph_idx < node_subgraphs.size(); subgraph_idx++) { + auto subgraph_viewer = std::make_unique(*node_subgraphs[subgraph_idx]); + + const OrtGraph* api_subgraph = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetElementAt(api_node_subgraphs, subgraph_idx, + reinterpret_cast(&api_subgraph))); + CheckGraphCApi(*subgraph_viewer, *api_subgraph); + } + } + } +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc b/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc new file mode 100644 index 0000000000000..5816037c2845d --- /dev/null +++ b/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/session/onnxruntime_cxx_api.h" + +#include "test/ep_graph/test_ep_graph_utils.h" + +// +// Test implementation of Kahn's Topological sort using public C graph APIs and C++ STL. +// + +#define RETURN_IF_API_ERROR(fn) \ + do { \ + Ort::Status status(fn); \ + if (!status.IsOK()) { \ + return status; \ + } \ + } while (0) + +namespace onnxruntime { +namespace test { +template +struct VisitorPriorityQueue { + using ComparatorType = std::function; + std::list list_; + const ComparatorType comparator_ = nullptr; + VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} + + void push(T node) { + list_.insert( + std::upper_bound(list_.begin(), list_.end(), node, comparator_), + node); + } + bool empty() { return list_.empty(); } + T top() { return list_.back(); } + void pop() { list_.pop_back(); } +}; + +// Get the number of input edges that come from another node upstream. +static Ort::Status GetNodeInputEdgeCount(const OrtNode* node, size_t& num_input_edges) { + const OrtApi& ort_api = Ort::GetApi(); + + OrtArrayOfConstObjects* inputs = nullptr; + DeferOrtRelease release_inputs(&inputs, ort_api.ReleaseArrayOfConstObjects); + RETURN_IF_API_ERROR(ort_api.Node_GetInputs(node, &inputs)); + + size_t num_inputs = 0; + RETURN_IF_API_ERROR(ort_api.ArrayOfConstObjects_GetSize(inputs, &num_inputs)); + + // Sum the number of inputs with a producer node. + num_input_edges = 0; + + for (size_t i = 0; i < num_inputs; ++i) { + const OrtValueInfo* input = nullptr; + RETURN_IF_API_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(inputs, i, reinterpret_cast(&input))); + if (input == nullptr) continue; // Skip missing optional input + + const OrtNode* producer_node = nullptr; + RETURN_IF_API_ERROR(ort_api.ValueInfo_GetValueProducer(input, &producer_node, /*output_index*/ nullptr)); + num_input_edges += static_cast(producer_node != nullptr); + } + + return Ort::Status{nullptr}; +} + +// Get all output nodes that consume an output from the given node. +static Ort::Status GetOutputNodes(const OrtNode* node, std::vector& result) { + const OrtApi& ort_api = Ort::GetApi(); + + OrtArrayOfConstObjects* outputs = nullptr; + DeferOrtRelease release_outputs(&outputs, ort_api.ReleaseArrayOfConstObjects); + RETURN_IF_API_ERROR(ort_api.Node_GetOutputs(node, &outputs)); + + size_t num_outputs = 0; + RETURN_IF_API_ERROR(ort_api.ArrayOfConstObjects_GetSize(outputs, &num_outputs)); + + std::vector output_nodes; + output_nodes.reserve(num_outputs); // May have more than `num_outputs` + + // Gather the OrtNode consumers of every output. + for (size_t i = 0; i < num_outputs; ++i) { + const OrtValueInfo* output = nullptr; + RETURN_IF_API_ERROR(ort_api.ArrayOfConstObjects_GetElementAt(outputs, i, reinterpret_cast(&output))); + if (output == nullptr) continue; // Skip missing optional output + + size_t num_consumers = 0; + RETURN_IF_API_ERROR(ort_api.ValueInfo_GetValueNumConsumers(output, &num_consumers)); + + std::vector node_consumers(num_consumers, nullptr); + std::vector input_indices(num_consumers, 0); + RETURN_IF_API_ERROR(ort_api.ValueInfo_GetValueConsumers(output, node_consumers.data(), + input_indices.data(), num_consumers)); + + for (const OrtNode* consumer : node_consumers) { + output_nodes.push_back(consumer); + } + } + + result = std::move(output_nodes); + return Ort::Status{nullptr}; +} + +// Kahn's topological sort. +// Adapted from onnxruntime/core/graph/graph.cc to use public C API graph types. +static Ort::Status KahnsTopologicalSort(const OrtGraph& graph, + const std::function& enter, + const std::function& comp) { + const OrtApi& ort_api = Ort::GetApi(); + + // Get all nodes + OrtArrayOfConstObjects* nodes_array = nullptr; + DeferOrtRelease release_nodes(&nodes_array, ort_api.ReleaseArrayOfConstObjects); + + size_t num_nodes = 0; + const void* const* nodes_raw_data = nullptr; + + RETURN_IF_API_ERROR(ort_api.Graph_GetNodes(&graph, &nodes_array)); + RETURN_IF_API_ERROR(ort_api.ArrayOfConstObjects_GetSize(nodes_array, &num_nodes)); + RETURN_IF_API_ERROR(ort_api.ArrayOfConstObjects_GetData(nodes_array, &nodes_raw_data)); + + auto nodes_span = gsl::span(reinterpret_cast(nodes_raw_data), num_nodes); + + // Get the maximum node ID. Not really required if we chose to represent the `in_degree` as a map instead of vector. + size_t max_node_id = 0; + for (const OrtNode* node : nodes_span) { + size_t node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); + max_node_id = std::max(max_node_id, node_id); + } + + std::vector in_degree(max_node_id + 1, 0); + std::vector topo_order; + VisitorPriorityQueue to_visit(comp); + + topo_order.reserve(num_nodes); + + // Initialize in_degree and initial nodes to visit first. + for (const OrtNode* node : nodes_span) { + size_t input_edge_count = 0; + RETURN_IF_API_ERROR(GetNodeInputEdgeCount(node, input_edge_count)); + + size_t node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); + + in_degree[node_id] = input_edge_count; + if (input_edge_count == 0) { + to_visit.push(node); + } + } + + while (!to_visit.empty()) { + const OrtNode* current_node = to_visit.top(); + to_visit.pop(); + + if (!current_node) continue; + + if (enter) { + enter(current_node); + } + + std::vector output_nodes; + GetOutputNodes(current_node, output_nodes); + + for (const OrtNode* output_node : output_nodes) { + size_t output_node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(output_node, &output_node_id)); + + auto& node_in_degree = in_degree[output_node_id]; + node_in_degree--; + + if (node_in_degree == 0) { + to_visit.push(output_node); + } + } + + size_t current_node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(current_node, ¤t_node_id)); + topo_order.push_back(current_node_id); + } + + if (num_nodes != topo_order.size()) { + return Ort::Status("Some nodes are not included in the topological sort: graph has a cycle", ORT_FAIL); + } + + return Ort::Status{nullptr}; +} + +// Node comparison functor copied from onnxruntime/core/graph/graph.cc +struct PriorityNodeCompare { + inline bool IsHighPri(const OrtNode* n) const { + // local statics so we can compare std::strings in the checks + static constexpr std::string_view shape_op("Shape"); + static constexpr std::string_view size_op("Size"); + + const char* op_type = nullptr; + Ort::Status status(Ort::GetApi().Node_GetOperatorType(n, &op_type)); + ORT_ENFORCE(status.IsOK()); + + return shape_op == op_type || size_op == op_type; + } + + // Used for std::priority_queue + // If return false, n1 will be output first + // If return true, n2 will be output first + bool operator()(const OrtNode* n1, const OrtNode* n2) const { + // nodes in global high priority list will be output first + const bool isN1HighPri = IsHighPri(n1); + const bool isN2HighPri = IsHighPri(n2); + if (isN1HighPri != isN2HighPri) { + return isN2HighPri; + } + + // nodes with lower priority value will be output first + const auto n1_priority = 0; // n1->Priority(); // Looks to always be 0 inside ORT? + const auto n2_priority = 0; // n2->Priority(); // Looks to always be 0 inside ORT? + if (n1_priority != n2_priority) { + return n1_priority > n2_priority; + } + + // otherwise, nodes with lower index will be output first + size_t n1_id = 0; + Ort::Status status1(Ort::GetApi().Node_GetId(n1, &n1_id)); + ORT_ENFORCE(status1.IsOK()); + + size_t n2_id = 0; + Ort::Status status2(Ort::GetApi().Node_GetId(n2, &n2_id)); + ORT_ENFORCE(status2.IsOK()); + + return n1_id > n2_id; + } +}; + +TEST(EpGraphTest, BasicKahnTopoSort) { + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/bart_tiny.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // Sort OrtGraph with a custom Kahn's topological sorting algorithm. + std::vector api_nodes_topo_sort_with_priority; + Ort::Status status(KahnsTopologicalSort( + test_graph->GetOrtGraph(), + [&](const OrtNode* node) { + size_t node_id = 0; + Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); + ORT_ENFORCE(status.IsOK()); + + api_nodes_topo_sort_with_priority.push_back(node_id); + }, + PriorityNodeCompare())); + ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); + + // Use ORT's built in sorting with priority. + std::vector ort_topo_sort_with_priority = test_graph->GetGraphViewer() + .GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + + // Check that they are equal. + ASSERT_EQ(api_nodes_topo_sort_with_priority, ort_topo_sort_with_priority); +} +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.cc b/onnxruntime/test/ep_graph/test_ep_graph_utils.cc new file mode 100644 index 0000000000000..b7743e65061de --- /dev/null +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.cc @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test/ep_graph/test_ep_graph_utils.h" + +#include "core/graph/ep_api_types.h" +#include "core/graph/model.h" + +namespace onnxruntime { +namespace test { + +TestGraph::TestGraph(std::shared_ptr model) + : model(model), graph_viewer(model->MainGraph()) { + std::unique_ptr ep_graph = nullptr; + ORT_ENFORCE(EpGraph::Create(graph_viewer, ep_graph).IsOK()); + api_graph = std::move(ep_graph); +} + +TestGraph::~TestGraph() {} + +std::unique_ptr TestGraph::Load(const ORTCHAR_T* model_path) { + std::shared_ptr model; + auto status = Model::Load(model_path, model, nullptr, DefaultLoggingManager().DefaultLogger()); + if (!status.IsOK()) { + return nullptr; + } + + return std::make_unique(model); +} + +const OrtGraph& TestGraph::GetOrtGraph() const { return *api_graph; } +const GraphViewer& TestGraph::GetGraphViewer() const { return graph_viewer; } + +static Status GetInputIndices(const Node& consumer_node, const std::string& name, + /*out*/ std::vector& indices) { + bool found = false; + auto add_input_indices = + [&found, &name, &indices](ConstPointerContainer> input_defs, + bool is_implicit) -> void { + for (size_t i = 0; i < input_defs.size(); i++) { + if (input_defs[i]->Name() == name) { + indices.push_back(is_implicit ? -1 : static_cast(i)); + found = true; + } + } + }; + + add_input_indices(consumer_node.InputDefs(), false); + add_input_indices(consumer_node.ImplicitInputDefs(), true); + + ORT_RETURN_IF(!found, "Did not find input indices for NodeArg ", name); + return Status::OK(); +} + +Status GetOutputIndex(const Node& producer_node, const std::string& name, /*out*/ size_t& index) { + const auto outputs = producer_node.OutputDefs(); + + bool found = false; + for (size_t i = 0; i < outputs.size(); i++) { + if (outputs[i]->Name() == name) { + index = i; + found = true; + } + } + ORT_RETURN_IF(!found, "Did not find output index of NodeArg ", name); + return Status::OK(); +} + +Status GetNodeArgConsumers(const GraphViewer& graph_viewer, const NodeArg& node_arg, + /*out*/ std::vector& consumers) { + std::vector nodes = graph_viewer.GetConsumerNodes(node_arg.Name()); + if (nodes.empty()) { + return Status::OK(); + } + + consumers.reserve(nodes.size()); + for (const Node* node : nodes) { + bool within_graph_viewer = node != nullptr && graph_viewer.GetNode(node->Index()) != nullptr; + if (!within_graph_viewer) { + continue; // Node is not in this GraphViewer + } + + std::vector input_indices; + ORT_RETURN_IF_ERROR(GetInputIndices(*node, node_arg.Name(), input_indices)); + + for (int64_t input_index : input_indices) { + consumers.emplace_back(node, input_index); + } + } + return Status::OK(); +} +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.h b/onnxruntime/test/ep_graph/test_ep_graph_utils.h new file mode 100644 index 0000000000000..9c04a72a42248 --- /dev/null +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.h @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" +#include "core/graph/model.h" +#include "core/session/onnxruntime_cxx_api.h" + +#include "test/util/include/test_environment.h" + +struct OrtGraph; +namespace onnxruntime { +namespace test { + +/// +/// Utility that loads a model from file and provides a OrtGraph view of the model for testing the public graph APIs. +/// +class TestGraph { + public: + explicit TestGraph(std::shared_ptr model); + ~TestGraph(); + + static std::unique_ptr Load(const ORTCHAR_T* model_path); + const OrtGraph& GetOrtGraph() const; + const GraphViewer& GetGraphViewer() const; + + private: + std::shared_ptr model; + GraphViewer graph_viewer; + std::unique_ptr api_graph; +}; + +// Helper to release a C API Ort object at the end of its scope. +// Useful when not using the public C++ API. +// Example: +// { +// OrtTensorTypeAndShapeInfo* info = nullptr; +// DeferOrtRelease defer_release(&info, c_api.ReleaseTensorTypeAndShapeInfo); +// ... +// } /* Release is called at end of scope*/ +template +struct DeferOrtRelease { + DeferOrtRelease(T** obj_ptr, std::function release_func) : obj_ptr_(obj_ptr), release_func_(release_func) {} + ~DeferOrtRelease() { + if (obj_ptr_ != nullptr && *obj_ptr_ != nullptr) { + release_func_(*obj_ptr_); + *obj_ptr_ = nullptr; + } + } + T** obj_ptr_ = nullptr; + std::function release_func_ = nullptr; +}; + +struct NodeArgConsumer { + NodeArgConsumer(const Node* node, int64_t index) : node(node), input_index(index) {} + const Node* node = nullptr; + int64_t input_index = -1; +}; + +// Returns consumers (i.e., consumer node + input index) of a NodeArg from the original graph. +Status GetNodeArgConsumers(const GraphViewer& graph_viewer, const NodeArg& node_arg, + /*out*/ std::vector& consumers); + +// Get output index for the given NodeArg name. Returns error if the node does not produce that node arg as an output. +Status GetOutputIndex(const Node& producer_node, const std::string& name, /*out*/ size_t& index); +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc b/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc index 467c5e773589a..2ca04235329ef 100644 --- a/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc +++ b/onnxruntime/test/flatbuffers/flatbuffer_utils_test.cc @@ -116,7 +116,7 @@ ONNX_NAMESPACE::TensorProto CreateInitializer(const std::string& name, } if constexpr (endian::native != endian::little) { - utils::ConvertRawDataInTensorProto(&tp); + utils::ConvertRawDataInTensorProto(tp); } return tp; @@ -262,7 +262,7 @@ TEST(FlatbufferUtilsTest, ExternalWriteReadWithLoadInitializers) { ONNX_NAMESPACE::TensorProto initializer; ASSERT_STATUS_OK(LoadInitializerOrtFormat(*fbs_tensor, initializer, options, reader)); if constexpr (endian::native != endian::little) { - utils::ConvertRawDataInTensorProto(&initializer); + utils::ConvertRawDataInTensorProto(initializer); } loaded_initializers.emplace_back(std::move(initializer)); // also check that the loaded flatbuffer tensors have accurately written to the external_data_offset field diff --git a/onnxruntime/test/framework/TestAllocatorManager.cc b/onnxruntime/test/framework/TestAllocatorManager.cc index 6431faf9ca4c1..30f2686cd62f5 100644 --- a/onnxruntime/test/framework/TestAllocatorManager.cc +++ b/onnxruntime/test/framework/TestAllocatorManager.cc @@ -13,7 +13,6 @@ class DummyArena : public IAllocator { : IAllocator(OrtMemoryInfo(resource_allocator->Info().name, OrtAllocatorType::OrtDeviceAllocator, resource_allocator->Info().device, - resource_allocator->Info().id, resource_allocator->Info().mem_type)), allocator_(std::move(resource_allocator)) { } @@ -50,7 +49,7 @@ static Status RegisterAllocator(std::unordered_map& m std::unique_ptr allocator, size_t /*memory_limit*/, bool use_arena) { auto& info = allocator->Info(); - auto allocator_id = GetAllocatorId(info.name, info.id, use_arena); + auto allocator_id = GetAllocatorId(info.name, info.device.Id(), use_arena); auto status = Status::OK(); if (map.find(allocator_id) != map.end()) diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index eaebac177ca91..c957f54e51a9c 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -19,6 +19,7 @@ using json = nlohmann::json; #include "core/framework/allocation_planner.h" #include "core/session/inference_session.h" #include "core/graph/model.h" +#include "core/graph/graph_utils.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/util/thread_utils.h" @@ -1022,7 +1023,7 @@ TEST_F(PlannerTest, LocationPlanningForInitializersOnlyUsedInANestedSubgraph) { tensor.add_float_data(1.0f); tensor.set_data_type(TensorProto_DataType_FLOAT); tensor.set_name("init_data"); - main_graph.AddInitializedTensor(tensor); + graph_utils::AddInitializerWithExternalData(main_graph, tensor); // Main graph's inputs/outputs main_graph.SetInputs({&abs_data_in, &if_in}); @@ -1129,7 +1130,7 @@ TEST_F(PlannerTest, LocationPlanningForInitializersUsedOnDifferentDevicesInMainG tensor.add_int64_data(1); tensor.set_data_type(TensorProto_DataType_INT64); tensor.set_name("init_data"); - main_graph.AddInitializedTensor(tensor); + graph_utils::AddInitializerWithExternalData(main_graph, tensor); // Main graph's inputs/outputs main_graph.SetInputs({&abs_data_in, &if_in}); @@ -1554,7 +1555,7 @@ TEST_F(PlannerTest, ParaPlanCreation) { for (int i = 0; i < 64 * 3 * 7 * 7; ++i) conv_0_weight_tensor.add_float_data(0.234f); conv_0_weight_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_0_weight_tensor.set_name("conv_0_weight"); - main_graph.AddInitializedTensor(conv_0_weight_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_0_weight_tensor); ONNX_NAMESPACE::TensorProto conv_1_weight_tensor; conv_1_weight_tensor.add_dims(64L); @@ -1564,7 +1565,7 @@ TEST_F(PlannerTest, ParaPlanCreation) { conv_1_weight_tensor.set_data_type(TensorProto_DataType_FLOAT); for (int i = 0; i < 64 * 64; ++i) conv_1_weight_tensor.add_float_data(1.017f); conv_1_weight_tensor.set_name("conv_1_weight"); - main_graph.AddInitializedTensor(conv_1_weight_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_1_weight_tensor); ONNX_NAMESPACE::TensorProto conv_2_weight_tensor; conv_2_weight_tensor.add_dims(64L); @@ -1574,7 +1575,7 @@ TEST_F(PlannerTest, ParaPlanCreation) { for (int i = 0; i < 64 * 64 * 3 * 3; ++i) conv_2_weight_tensor.add_float_data(2.317f); conv_2_weight_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_2_weight_tensor.set_name("conv_2_weight"); - main_graph.AddInitializedTensor(conv_2_weight_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_2_weight_tensor); ONNX_NAMESPACE::TensorProto conv_3_weight_tensor; conv_3_weight_tensor.add_dims(256L); @@ -1584,7 +1585,7 @@ TEST_F(PlannerTest, ParaPlanCreation) { for (int i = 0; i < 256 * 64; ++i) conv_3_weight_tensor.add_float_data(1.256f); conv_3_weight_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_3_weight_tensor.set_name("conv_3_weight"); - main_graph.AddInitializedTensor(conv_3_weight_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_3_weight_tensor); ONNX_NAMESPACE::TensorProto conv_4_weight_tensor; conv_4_weight_tensor.add_dims(256L); @@ -1594,7 +1595,7 @@ TEST_F(PlannerTest, ParaPlanCreation) { for (int i = 0; i < 256 * 64; ++i) conv_4_weight_tensor.add_float_data(1.913f); conv_4_weight_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_4_weight_tensor.set_name("conv_4_weight"); - main_graph.AddInitializedTensor(conv_4_weight_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_4_weight_tensor); auto& conv_0_weight = main_graph.GetOrCreateNodeArg("conv_0_weight", &conv_0_weight_type); auto& conv_1_weight = main_graph.GetOrCreateNodeArg("conv_1_weight", &conv_1_weight_type); @@ -1607,35 +1608,35 @@ TEST_F(PlannerTest, ParaPlanCreation) { conv_0_bias_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_0_bias_tensor.set_name("conv_0_bias"); for (int i = 0; i < 64; ++i) conv_0_bias_tensor.add_float_data(1.123f); - main_graph.AddInitializedTensor(conv_0_bias_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_0_bias_tensor); ONNX_NAMESPACE::TensorProto conv_1_bias_tensor; conv_1_bias_tensor.add_dims(64L); for (int i = 0; i < 64; ++i) conv_1_bias_tensor.add_float_data(2.234f); conv_1_bias_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_1_bias_tensor.set_name("conv_1_bias"); - main_graph.AddInitializedTensor(conv_1_bias_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_1_bias_tensor); ONNX_NAMESPACE::TensorProto conv_2_bias_tensor; conv_2_bias_tensor.add_dims(64L); for (int i = 0; i < 64; ++i) conv_2_bias_tensor.add_float_data(0.121f); conv_2_bias_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_2_bias_tensor.set_name("conv_2_bias"); - main_graph.AddInitializedTensor(conv_2_bias_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_2_bias_tensor); ONNX_NAMESPACE::TensorProto conv_3_bias_tensor; conv_3_bias_tensor.add_dims(256L); for (int i = 0; i < 256; ++i) conv_3_bias_tensor.add_float_data(1.201f); conv_3_bias_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_3_bias_tensor.set_name("conv_3_bias"); - main_graph.AddInitializedTensor(conv_3_bias_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_3_bias_tensor); ONNX_NAMESPACE::TensorProto conv_4_bias_tensor; conv_4_bias_tensor.add_dims(256L); for (int i = 0; i < 256; ++i) conv_4_bias_tensor.add_float_data(0.897f); conv_4_bias_tensor.set_data_type(TensorProto_DataType_FLOAT); conv_4_bias_tensor.set_name("conv_4_bias"); - main_graph.AddInitializedTensor(conv_4_bias_tensor); + graph_utils::AddInitializerWithExternalData(main_graph, conv_4_bias_tensor); auto& conv_0_bias = main_graph.GetOrCreateNodeArg("conv_0_bias", &conv_0_bias_type); auto& conv_1_bias = main_graph.GetOrCreateNodeArg("conv_1_bias", &conv_1_bias_type); diff --git a/onnxruntime/test/framework/allocator_test.cc b/onnxruntime/test/framework/allocator_test.cc index fa6c4966d6953..3efba6f1b6e52 100644 --- a/onnxruntime/test/framework/allocator_test.cc +++ b/onnxruntime/test/framework/allocator_test.cc @@ -14,7 +14,7 @@ TEST(AllocatorTest, CPUAllocatorTest) { auto cpu_arena = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; ASSERT_STREQ(cpu_arena->Info().name, CPU); - EXPECT_EQ(cpu_arena->Info().id, 0); + EXPECT_EQ(cpu_arena->Info().device.Id(), 0); const auto expected_allocator_type = DoesCpuAllocatorSupportArenaUsage() ? OrtAllocatorType::OrtArenaAllocator diff --git a/onnxruntime/test/framework/cuda/fence_cuda_test.cc b/onnxruntime/test/framework/cuda/fence_cuda_test.cc index e28327941dda4..b86f3efeefafd 100644 --- a/onnxruntime/test/framework/cuda/fence_cuda_test.cc +++ b/onnxruntime/test/framework/cuda/fence_cuda_test.cc @@ -16,6 +16,7 @@ #include "core/framework/session_state.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/graph_viewer.h" +#include "core/graph/graph_utils.h" #include "core/graph/model.h" #include "core/graph/op.h" #include "core/providers/cpu/math/element_wise_ops.h" @@ -66,10 +67,7 @@ static common::Status LoadInferenceSessionFromModel(FenceCudaTestInferenceSessio tensor_proto.set_data_type(PROTO_DATATYPE); \ for (auto v : value) tensor_proto.PROTO_ADD_DATA(v); \ tensor_proto.set_name(name); \ - graph.AddInitializedTensor(tensor_proto); \ - TypeProto type_proto; \ - type_proto.mutable_tensor_type()->set_elem_type(PROTO_DATATYPE); \ - return graph.GetOrCreateNodeArg(name, &type_proto); \ + return graph_utils::AddInitializerWithExternalData(graph, tensor_proto); \ } CREATE_INITIALIZER_FUNC(float, TensorProto_DataType_FLOAT, add_float_data) diff --git a/onnxruntime/test/framework/endian_test.cc b/onnxruntime/test/framework/endian_test.cc index 7b8f56bd97073..694967c70d136 100644 --- a/onnxruntime/test/framework/endian_test.cc +++ b/onnxruntime/test/framework/endian_test.cc @@ -1,10 +1,10 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - #include "core/framework/endian.h" #include "core/framework/endian_utils.h" +#include "core/graph/onnx_protobuf.h" // For TensorProto +#include "core/framework/tensorprotoutils.h" // For ConvertRawDataInTensorProto #include +#include // For std::byte #include "gtest/gtest.h" @@ -47,6 +47,327 @@ TEST(EndianTest, SwapByteOrderCopy) { } } +// Test fixture for SwapByteOrderInplace tests +class SwapByteOrderInplaceTest : public ::testing::Test {}; + +TEST_F(SwapByteOrderInplaceTest, ElementSize1) { + std::vector data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}}; + std::vector expected_data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(1, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, ElementSize2_SingleElement) { + std::vector data = {std::byte{0x01}, std::byte{0x02}}; + std::vector expected_data = {std::byte{0x02}, std::byte{0x01}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(2, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, ElementSize2_MultipleElements) { + std::vector data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}, std::byte{0x05}, std::byte{0x06}}; + std::vector expected_data = { + std::byte{0x02}, std::byte{0x01}, std::byte{0x04}, std::byte{0x03}, std::byte{0x06}, std::byte{0x05}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(2, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, ElementSize4_SingleElement) { + std::vector data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}}; + std::vector expected_data = { + std::byte{0x04}, std::byte{0x03}, std::byte{0x02}, std::byte{0x01}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(4, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, ElementSize4_MultipleElements) { + std::vector data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}, + std::byte{0x05}, std::byte{0x06}, std::byte{0x07}, std::byte{0x08}}; + std::vector expected_data = { + std::byte{0x04}, std::byte{0x03}, std::byte{0x02}, std::byte{0x01}, + std::byte{0x08}, std::byte{0x07}, std::byte{0x06}, std::byte{0x05}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(4, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, ElementSize8_SingleElement) { + std::vector data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}, + std::byte{0x05}, std::byte{0x06}, std::byte{0x07}, std::byte{0x08}}; + std::vector expected_data = { + std::byte{0x08}, std::byte{0x07}, std::byte{0x06}, std::byte{0x05}, + std::byte{0x04}, std::byte{0x03}, std::byte{0x02}, std::byte{0x01}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(8, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, ElementSize8_MultipleElements) { + std::vector data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}, + std::byte{0x05}, std::byte{0x06}, std::byte{0x07}, std::byte{0x08}, + std::byte{0x11}, std::byte{0x12}, std::byte{0x13}, std::byte{0x14}, + std::byte{0x15}, std::byte{0x16}, std::byte{0x17}, std::byte{0x18}}; + std::vector expected_data = { + std::byte{0x08}, std::byte{0x07}, std::byte{0x06}, std::byte{0x05}, + std::byte{0x04}, std::byte{0x03}, std::byte{0x02}, std::byte{0x01}, + std::byte{0x18}, std::byte{0x17}, std::byte{0x16}, std::byte{0x15}, + std::byte{0x14}, std::byte{0x13}, std::byte{0x12}, std::byte{0x11}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(8, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, EmptyBuffer) { + std::vector data = {}; + std::vector expected_data = {}; + gsl::span data_span = gsl::make_span(data); + + // Should not crash or throw for valid element sizes, e.g., 2 or 4 + // The ORT_ENFORCE checks will pass as 0 % element_size == 0 + // The loop for swapping will not execute. + utils::SwapByteOrderInplace(2, data_span); + EXPECT_EQ(data, expected_data); + + utils::SwapByteOrderInplace(4, data_span); + EXPECT_EQ(data, expected_data); +} + +TEST_F(SwapByteOrderInplaceTest, ElementSize3_OddElementSize) { + std::vector data = { + std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, + std::byte{0x04}, std::byte{0x05}, std::byte{0x06}}; + std::vector expected_data = { + std::byte{0x03}, std::byte{0x02}, std::byte{0x01}, + std::byte{0x06}, std::byte{0x05}, std::byte{0x04}}; + gsl::span data_span = gsl::make_span(data); + + utils::SwapByteOrderInplace(3, data_span); + EXPECT_EQ(data, expected_data); +} + +// Test fixture for ConvertRawDataInTensorProto tests +class ConvertRawDataInTensorProtoTest : public ::testing::Test { + protected: + // Helper function to set up a TensorProto with float data + void SetupFloatTensor(ONNX_NAMESPACE::TensorProto& tensor, const std::vector& values) { + tensor.Clear(); + tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + for (float value : values) { + tensor.add_float_data(value); + } + } + + // Helper function to set up a TensorProto with int32 data + void SetupInt32Tensor(ONNX_NAMESPACE::TensorProto& tensor, const std::vector& values) { + tensor.Clear(); + tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + for (int32_t value : values) { + tensor.add_int32_data(value); + } + } + + // Helper function to set up a TensorProto with int16 data (stored in int32 container) + void SetupInt16Tensor(ONNX_NAMESPACE::TensorProto& tensor, const std::vector& values) { + tensor.Clear(); + tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT16); + for (int16_t value : values) { + tensor.add_int32_data(value); + } + } + + // Helper function to set up a TensorProto with raw data + template + void SetupRawDataTensor(ONNX_NAMESPACE::TensorProto& tensor, ONNX_NAMESPACE::TensorProto_DataType data_type, + const std::vector& values) { + tensor.Clear(); + tensor.set_data_type(data_type); + tensor.set_raw_data(values.data(), values.size() * sizeof(T)); + } + + // Helper to compare float data before and after conversion + void CompareFloatData(const ONNX_NAMESPACE::TensorProto& tensor, const std::vector& expected_values) { + ASSERT_EQ(tensor.float_data_size(), static_cast(expected_values.size())); + for (int i = 0; i < tensor.float_data_size(); i++) { + // We swap bytes so the actual value might change if we're converting endianness + // But a double swap should restore the original value + if constexpr (endian::native == endian::little) { + EXPECT_EQ(tensor.float_data(i), expected_values[i]); + } else { + // Just verify the value is different after one swap on big-endian + // We can't predict the exact value without manual byte swapping + if (expected_values[i] != 0) { // Skip zero values as they're invariant to byte swapping + EXPECT_NE(tensor.float_data(i), expected_values[i]); + } + } + } + } + + // Helper to compare int32 data before and after conversion + void CompareInt32Data(const ONNX_NAMESPACE::TensorProto& tensor, const std::vector& expected_values) { + ASSERT_EQ(tensor.int32_data_size(), static_cast(expected_values.size())); + for (int i = 0; i < tensor.int32_data_size(); i++) { + // Same logic as float comparison + if constexpr (endian::native == endian::little) { + EXPECT_EQ(tensor.int32_data(i), expected_values[i]); + } else { + if (expected_values[i] != 0) { + EXPECT_NE(tensor.int32_data(i), expected_values[i]); + } + } + } + } +}; + +TEST_F(ConvertRawDataInTensorProtoTest, FloatData) { + ONNX_NAMESPACE::TensorProto tensor; + std::vector values = {1.0f, 2.0f, 3.0f, 4.0f}; + SetupFloatTensor(tensor, values); + + // Save original values + std::vector original_values; + for (int i = 0; i < tensor.float_data_size(); i++) { + original_values.push_back(tensor.float_data(i)); + } + + // Convert once + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // Convert back - should restore original values + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + CompareFloatData(tensor, original_values); +} + +TEST_F(ConvertRawDataInTensorProtoTest, Int32Data) { + ONNX_NAMESPACE::TensorProto tensor; + std::vector values = {1, 2, 3, 4}; + SetupInt32Tensor(tensor, values); + + // Save original values + std::vector original_values; + for (int i = 0; i < tensor.int32_data_size(); i++) { + original_values.push_back(tensor.int32_data(i)); + } + + // Convert once + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // Convert back - should restore original values + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + CompareInt32Data(tensor, original_values); +} + +TEST_F(ConvertRawDataInTensorProtoTest, Int16Data) { + ONNX_NAMESPACE::TensorProto tensor; + std::vector values = {1, 2, 3, 4}; + SetupInt16Tensor(tensor, values); + + // Save original values + std::vector original_values; + for (int i = 0; i < tensor.int32_data_size(); i++) { + original_values.push_back(tensor.int32_data(i)); + } + + // Convert once + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // Convert back - should restore original values + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // When we swap bytes on int16 values stored in int32 containers, the test should pass + // on both little-endian and big-endian systems + ASSERT_EQ(tensor.int32_data_size(), static_cast(original_values.size())); + for (int i = 0; i < tensor.int32_data_size(); i++) { + EXPECT_EQ(tensor.int32_data(i), original_values[i]); + } +} + +TEST_F(ConvertRawDataInTensorProtoTest, RawFloatData) { + ONNX_NAMESPACE::TensorProto tensor; + std::vector values = {1.0f, 2.0f, 3.0f, 4.0f}; + SetupRawDataTensor(tensor, ONNX_NAMESPACE::TensorProto_DataType_FLOAT, values); + + // Save original raw data + std::string original_raw_data = tensor.raw_data(); + + // Convert once + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // Convert back - should restore original bytes + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + EXPECT_EQ(tensor.raw_data(), original_raw_data); +} + +TEST_F(ConvertRawDataInTensorProtoTest, UInt8NoConversion) { + ONNX_NAMESPACE::TensorProto tensor; + tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + std::vector values = {1, 2, 3, 4}; + for (auto val : values) { + tensor.add_int32_data(val); + } + + // Save original data + std::vector original_values; + for (int i = 0; i < tensor.int32_data_size(); i++) { + original_values.push_back(tensor.int32_data(i)); + } + + // Convert - for 1-byte elements, no conversion should happen + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // Verify no change occurred + ASSERT_EQ(tensor.int32_data_size(), static_cast(original_values.size())); + for (int i = 0; i < tensor.int32_data_size(); i++) { + EXPECT_EQ(tensor.int32_data(i), original_values[i]); + } +} + +TEST_F(ConvertRawDataInTensorProtoTest, DoubleConversionAndRestore) { + // Test with double values + ONNX_NAMESPACE::TensorProto tensor; + tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE); + std::vector values = {1.1, 2.2, 3.3, 4.4}; + for (auto val : values) { + tensor.add_double_data(val); + } + + // Save original data + std::vector original_values; + for (int i = 0; i < tensor.double_data_size(); i++) { + original_values.push_back(tensor.double_data(i)); + } + + // Convert once + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // Convert again - this should restore original values + onnxruntime::utils::ConvertRawDataInTensorProto(tensor); // Pass by reference, not pointer + + // Verify restored values + ASSERT_EQ(tensor.double_data_size(), static_cast(original_values.size())); + for (int i = 0; i < tensor.double_data_size(); i++) { + EXPECT_EQ(tensor.double_data(i), original_values[i]); + } +} + } // namespace test } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index a7dc2ad8fc3ca..1e6167f862ea1 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -1224,7 +1224,7 @@ TEST(InferenceSessionTests, TestBindCudaPreallocateOutputOnCpu2) { } TEST(InferenceSessionTests, TestBindCudaSpecifyOutputDeviceOnCuda) { - OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0); + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0); TestBindHelper("TestBindCudaPreallocateOutputOnCuda", kGpuExecutionProvider, @@ -2217,7 +2217,8 @@ TEST(InferenceSessionTests, TestArenaShrinkageAfterRun) { ASSERT_STATUS_OK(session_object.Initialize()); // Fetch the CUDA allocator to analyze its stats - OrtMemoryInfo mem_info(CUDA, OrtArenaAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)); + OrtMemoryInfo mem_info(CUDA, OrtArenaAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0)); auto cuda_alloc = session_object.GetAllocator(mem_info); AllocatorStats alloc_stats; diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 76399743c97f8..6ad21fa9f5cf5 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -275,7 +275,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { graph, session_state.GetMutableFuncMgr(), [](Graph& graph, bool& modified, const IExecutionProvider& execution_provider, const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); return layout_transformation::TransformLayoutForEP( graph, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn); }, @@ -319,7 +319,7 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { GTEST_SKIP() << "CPU allocator does not support arena usage."; } - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); // Part 1: Feature turned ON (i.e.) allocate from non-arena memory { std::basic_ostringstream oss; diff --git a/onnxruntime/test/framework/sparse_kernels_test.cc b/onnxruntime/test/framework/sparse_kernels_test.cc index 7bd6b47f52b7d..43de3a945526c 100644 --- a/onnxruntime/test/framework/sparse_kernels_test.cc +++ b/onnxruntime/test/framework/sparse_kernels_test.cc @@ -706,7 +706,7 @@ struct InsertIndices { std::vector indices(indices_data.cbegin(), indices_data.cend()); indices_tp.mutable_raw_data()->assign(reinterpret_cast(indices.data()), indices.size() * sizeof(T)); if constexpr (endian::native != endian::little) { - utils::ConvertRawDataInTensorProto((ONNX_NAMESPACE::TensorProto*)&indices_tp); + utils::ConvertRawDataInTensorProto(indices_tp); } } } diff --git a/onnxruntime/test/framework/tensor_test.cc b/onnxruntime/test/framework/tensor_test.cc index fba099f9c55b3..2ac1a93013932 100644 --- a/onnxruntime/test/framework/tensor_test.cc +++ b/onnxruntime/test/framework/tensor_test.cc @@ -30,7 +30,7 @@ void CPUTensorTest(std::vector dims, const int offset_elements = 0) { EXPECT_EQ(t.DataType(), DataTypeImpl::GetType()); auto& location = t.Location(); EXPECT_STREQ(location.name, CPU); - EXPECT_EQ(location.id, 0); + EXPECT_EQ(location.device.Id(), 0); const T* t_data = t.Data(); EXPECT_EQ(first_element, t_data); @@ -48,7 +48,7 @@ void CPUTensorTest(std::vector dims, const int offset_elements = 0) { EXPECT_EQ(new_t.DataType(), DataTypeImpl::GetType()); auto& new_location = new_t.Location(); ASSERT_STREQ(new_location.name, CPU); - EXPECT_EQ(new_location.id, 0); + EXPECT_EQ(new_location.device.Id(), 0); } } @@ -136,7 +136,7 @@ TEST(TensorTest, EmptyTensorTest) { auto& location = t.Location(); ASSERT_STREQ(location.name, CPU); - EXPECT_EQ(location.id, 0); + EXPECT_EQ(location.device.Id(), 0); const auto expected_allocator_type = DoesCpuAllocatorSupportArenaUsage() ? OrtAllocatorType::OrtArenaAllocator @@ -161,7 +161,7 @@ TEST(TensorTest, StringTensorTest) { EXPECT_EQ(t.DataType(), DataTypeImpl::GetType()); auto& location = t.Location(); ASSERT_STREQ(location.name, CPU); - EXPECT_EQ(location.id, 0); + EXPECT_EQ(location.device.Id(), 0); std::string* new_data = t.MutableData(); EXPECT_TRUE(new_data); diff --git a/onnxruntime/test/framework/test_tensor_loader.cc b/onnxruntime/test/framework/test_tensor_loader.cc index 73bf351b6c556..1abb0ad14660d 100644 --- a/onnxruntime/test/framework/test_tensor_loader.cc +++ b/onnxruntime/test/framework/test_tensor_loader.cc @@ -31,7 +31,7 @@ TEST(CApiTensorTest, load_simple_float_tensor_not_enough_space) { // deserialize it std::vector output(1); OrtValue value; - OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); + OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), OrtMemTypeDefault); ASSERT_STATUS_NOT_OK( utils::TensorProtoToOrtValue(Env::Default(), std::filesystem::path(), p, @@ -53,7 +53,7 @@ TEST(CApiTensorTest, load_simple_float_tensor_membuffer) { // deserialize it std::vector output(3); OrtValue value; - OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); + OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), OrtMemTypeDefault); ASSERT_STATUS_OK( utils::TensorProtoToOrtValue(Env::Default(), std::filesystem::path(), p, MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), @@ -80,7 +80,7 @@ TEST(CApiTensorTest, load_simple_float_tensor_allocator) { // save it to a buffer ASSERT_TRUE(p.SerializeToString(&s)); // deserialize it - AllocatorPtr tmp_allocator = std::make_shared(); + AllocatorPtr tmp_allocator = CPUAllocator::DefaultInstance(); OrtValue value; ASSERT_STATUS_OK(utils::TensorProtoToOrtValue(Env::Default(), std::filesystem::path(), p, tmp_allocator, value)); @@ -153,7 +153,7 @@ static void run_external_data_test() { #endif } OrtValue value; - OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); + OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), OrtMemTypeDefault); ASSERT_STATUS_OK(utils::TensorProtoToOrtValue( Env::Default(), std::filesystem::path(), p, MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), value)); @@ -204,7 +204,7 @@ TEST(CApiTensorTest, load_huge_tensor_with_external_data) { // deserialize it std::vector output(total_ele_count); OrtValue value; - OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); + OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), OrtMemTypeDefault); ASSERT_STATUS_OK( utils::TensorProtoToOrtValue(Env::Default(), std::filesystem::path(), p, MemBuffer(output.data(), output.size() * sizeof(int), cpu_memory_info), value)); diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index f6b7bdb1a001c..e2b54950e7b24 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -1698,7 +1698,7 @@ TEST_F(GraphTest, ReplaceInitializedTensor) { ONNX_NAMESPACE::TensorProto bad_name = original; bad_name.set_name("invalid"); - status = graph.ReplaceInitializedTensor(std::move(bad_name)); + status = graph.ReplaceInitializedTensor(bad_name, OrtValue()); ASSERT_FALSE(status.IsOK()); } @@ -1706,7 +1706,7 @@ TEST_F(GraphTest, ReplaceInitializedTensor) { ONNX_NAMESPACE::TensorProto bad_type = original; bad_type.set_data_type(TensorProto_DataType_FLOAT16); - status = graph.ReplaceInitializedTensor(std::move(bad_type)); + status = graph.ReplaceInitializedTensor(bad_type, OrtValue()); ASSERT_FALSE(status.IsOK()); } @@ -1716,7 +1716,7 @@ TEST_F(GraphTest, ReplaceInitializedTensor) { bad_dims.add_dims(2); bad_dims.add_dims(1); - status = graph.ReplaceInitializedTensor(std::move(bad_dims)); + status = graph.ReplaceInitializedTensor(bad_dims, OrtValue()); ASSERT_FALSE(status.IsOK()); } @@ -1726,26 +1726,39 @@ TEST_F(GraphTest, ReplaceInitializedTensor) { valid_replacement.add_int32_data(3); valid_replacement.add_int32_data(4); - status = graph.ReplaceInitializedTensor(valid_replacement); + status = graph.ReplaceInitializedTensor(valid_replacement, OrtValue()); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); - auto tensor_data_matches = [](const ONNX_NAMESPACE::TensorProto& a, const ONNX_NAMESPACE::TensorProto& b) { - if (a.int32_data_size() != b.int32_data_size()) return false; - for (int i = 0; i < a.int32_data_size(); ++i) { - if (a.int32_data(i) != b.int32_data(i)) return false; + auto tensor_data_matches = [](const Graph& graph, const ONNX_NAMESPACE::TensorProto& a, + const ONNX_NAMESPACE::TensorProto& b) -> bool { + // For simplicity. We do not want to deal with external and raw data combinations. + Tensor tensor_a; + EXPECT_TRUE(utils::CreateTensorFromTensorProto(Env::Default(), graph.ModelPath(), a, tensor_a).IsOK()); + Tensor tensor_b; + EXPECT_TRUE(utils::CreateTensorFromTensorProto(Env::Default(), graph.ModelPath(), b, tensor_b).IsOK()); + + EXPECT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, tensor_a.GetElementType()); + + if (tensor_a.GetElementType() != tensor_b.GetElementType()) { + return false; + } + if (tensor_a.Shape() != tensor_b.Shape()) { + return false; } - return true; + const auto span_a = tensor_a.DataAsSpan(); + const auto span_b = tensor_b.DataAsSpan(); + return std::equal(span_a.begin(), span_a.end(), span_b.begin()); }; // check retrieved tensor const ONNX_NAMESPACE::TensorProto* result; ASSERT_TRUE(graph.GetInitializedTensor(initializer_name, result)); - ASSERT_TRUE(tensor_data_matches(*result, valid_replacement)); + ASSERT_TRUE(tensor_data_matches(graph, *result, valid_replacement)); // check GraphProto content const ONNX_NAMESPACE::GraphProto graph_proto = graph.ToGraphProto(); ASSERT_EQ(graph_proto.initializer_size(), 1); - ASSERT_TRUE(tensor_data_matches(graph_proto.initializer(0), valid_replacement)); + ASSERT_TRUE(tensor_data_matches(graph, graph_proto.initializer(0), valid_replacement)); } } @@ -1822,13 +1835,13 @@ TEST_F(GraphTest, InjectExternalInitializedTensors) { const TensorProto* with_data = nullptr; ASSERT_TRUE(graph.GetInitializedTensor(initializer_name, with_data)); - // No longer has external data if (with_data) { - ASSERT_FALSE(utils::HasExternalData(*with_data)); + // This proto still has external data, but now it points to the OrtValue. + ASSERT_TRUE(utils::HasExternalData(*with_data)); const auto& original_tensor = ort_value.Get(); - Tensor replaced_tensor(original_tensor.DataType(), data_shape, std::make_shared()); - ASSERT_STATUS_OK(utils::TensorProtoToTensor(Env::Default(), tensor_data_dir_path, *with_data, - replaced_tensor)); + Tensor replaced_tensor; + ASSERT_STATUS_OK(utils::CreateTensorFromTensorProto(Env::Default(), tensor_data_dir_path, *with_data, + replaced_tensor)); ASSERT_EQ(original_tensor.GetElementType(), replaced_tensor.GetElementType()); const auto original_span = original_tensor.DataAsSpan(); const auto replaced_span = replaced_tensor.DataAsSpan(); @@ -2124,6 +2137,187 @@ TEST_F(GraphTest, SubgraphOutputIsOuterScopeValue) { ::testing::ContainsRegex("Subgraph output \\(.*\\) is an outer scope value being returned directly.")); } +static void CreateIntializerWithDataInMemory(const std::string& name, const AllocatorPtr& allocator, int64_t size, + TensorProto& tensor_proto, OrtValue& ort_value) { + TensorShape shape({size}); + Tensor::InitOrtValue(DataTypeImpl::GetType(), shape, allocator, ort_value); + float v = 0; + auto* data = ort_value.GetMutable()->MutableData(); + for (int64_t i = 0; i < size; ++i) { + *data++ = v++; + } + + tensor_proto = utils::TensorToTensorProto(ort_value.Get(), name, true); +} + +TEST(GraphGetOrtValueInitializerTest, ReturnsOrtValueForExistingInitializer) { + Model model("TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {}, {}, + DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + + // Create a simple TensorProto initializer + const std::string name = "init1"; + auto allocator = CPUAllocator::DefaultInstance(); + constexpr const int64_t kTensorSize = 256; + TensorProto tensor_proto; + OrtValue ort_value; + CreateIntializerWithDataInMemory(name, allocator, kTensorSize, tensor_proto, ort_value); + + ASSERT_STATUS_OK(graph.AddInitializedOrtValue(tensor_proto, ort_value)); + + // Test retrieval + OrtValue retrieved; + EXPECT_TRUE(graph.GetOrtValueInitializer(name, retrieved, false)); + const Tensor& t = retrieved.Get(); + EXPECT_EQ(t.Shape().Size(), kTensorSize); +} + +TEST(GraphGetOrtValueInitializerTest, ReturnsFalseForNonExistentInitializer) { + Model model("TestModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {}, {}, + DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + + OrtValue retrieved; + EXPECT_FALSE(graph.GetOrtValueInitializer("does_not_exist", retrieved, false)); +} + +namespace { +// Casing only, do not add members +class NodeWrapper : public Node { + public: + Node::Definitions& MutableDefinitions() { + return Node::MutableDefinitions(); + } +}; +} // namespace + +TEST(GraphGetOrtValueInitializerTest, ReturnsOrtValueFromOuterScope) { + // Create parent graph with initializer + Model parent_model("ParentModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {}, {}, + DefaultLoggingManager().DefaultLogger()); + Graph& parent_graph = parent_model.MainGraph(); + + const std::string outer_init_name = "outer_init"; + auto allocator = CPUAllocator::DefaultInstance(); + constexpr const int64_t kTensorSize = 256; + TensorProto tensor_proto; + OrtValue ort_value; + CreateIntializerWithDataInMemory(outer_init_name, allocator, kTensorSize, tensor_proto, ort_value); + + ASSERT_STATUS_OK(parent_graph.AddInitializedOrtValue(tensor_proto, ort_value)); + + // Create a node in parent graph that will be the parent node for the subgraph + TypeProto tensor_type; + tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + auto& input_arg = parent_graph.GetOrCreateNodeArg("node_input", &tensor_type); + auto& output_arg = parent_graph.GetOrCreateNodeArg("node_output", &tensor_type); + NodeArg* inputs[] = {&input_arg}; + NodeArg* outputs[] = {&output_arg}; + + // Create parent node with a subgraph attribute + auto& parent_node = parent_graph.AddNode("parent_node", "If", "parent node with subgraph", inputs, outputs); + // Add the initializer name to the parent node's implicit input defs + NodeArg* outer_init_nodearg = parent_graph.GetNodeArg(outer_init_name); + ASSERT_NE(outer_init_nodearg, nullptr); + { + // Test hack to tweak an internal structure. + auto& node_wrapper = static_cast(parent_node); + node_wrapper.MutableDefinitions().implicit_input_defs.push_back(outer_init_nodearg); + } + + // Create subgraph + GraphProto subgraph_proto; + subgraph_proto.set_name("Subgraph"); + Graph subgraph(parent_model, &subgraph_proto, parent_graph.DomainToVersionMap(), parent_model.IrVersion(), + nullptr, &parent_graph, &parent_node, DefaultLoggingManager().DefaultLogger(), false); + + // Test retrieval from outer scope + OrtValue retrieved; + EXPECT_TRUE(subgraph.GetOrtValueInitializer("outer_init", retrieved, true)); + const Tensor& t = retrieved.Get(); + EXPECT_EQ(t.Shape().Size(), kTensorSize); +} + +TEST_F(GraphTest, AddInitializedOrtValueWithExternalData) { + Model model("TestAddInitializedOrtValue", false, *logger_); + Graph& graph = model.MainGraph(); + + // Create a TensorProto with external data reference + const std::string external_data_init = "external_data_init"; + auto allocator = CPUAllocator::DefaultInstance(); + constexpr const int64_t kTensorSize = 256; + TensorProto tensor_proto; + OrtValue ort_value; + CreateIntializerWithDataInMemory(external_data_init, allocator, kTensorSize, tensor_proto, ort_value); + + // Test adding the initialized OrtValue with external data reference + ASSERT_STATUS_OK(graph.AddInitializedOrtValue(tensor_proto, ort_value)); + + // Verify the initializer was added correctly + OrtValue retrieved_value; + ASSERT_TRUE(graph.GetOrtValueInitializer(external_data_init, retrieved_value, false)); + + // Verify the tensor data + const Tensor& retrieved_tensor = retrieved_value.Get(); + ASSERT_EQ(retrieved_tensor.Shape().Size(), kTensorSize); + ASSERT_EQ(retrieved_tensor.DataType(), DataTypeImpl::GetType()); + + // Verify the TensorProto was also added and has external data location + const TensorProto* retrieved_proto = nullptr; + ASSERT_TRUE(graph.GetInitializedTensor(external_data_init, retrieved_proto)); + ASSERT_NE(retrieved_proto, nullptr); + ASSERT_EQ(retrieved_proto->name(), external_data_init); + ASSERT_TRUE(utils::HasExternalDataInMemory(tensor_proto)); +} + +TEST_F(GraphTest, AddInitializedOrtValueMismatch) { + Model model("TestAddInitializedOrtValue_Mismatch", false, *logger_); + Graph& graph = model.MainGraph(); + + // Create a TensorProto with external data reference + const std::string name = "init"; + constexpr const int64_t kTensorSize = 256; + auto allocator = CPUAllocator::DefaultInstance(); + TensorProto tensor_proto; + OrtValue ort_value; + TensorShape shape({kTensorSize}); + CreateIntializerWithDataInMemory(name, allocator, kTensorSize, tensor_proto, ort_value); + + OrtValue ort_value_diff; + // Now try to create a value that has a different data type + Tensor::InitOrtValue(DataTypeImpl::GetType(), shape, allocator, ort_value_diff); + Status status = graph.AddInitializedOrtValue(tensor_proto, ort_value_diff); + ASSERT_FALSE(status.IsOK()); + + // Create OrtValue with different shape [2] + TensorShape diff_shape({2}); + Tensor::InitOrtValue(DataTypeImpl::GetType(), diff_shape, allocator, ort_value_diff); + + // Fails on shape mismatch + status = graph.AddInitializedOrtValue(tensor_proto, ort_value_diff); + ASSERT_FALSE(status.IsOK()); +} + +TEST_F(GraphTest, AddInitializedOrtValueDuplicate) { + Model model("TestAddInitializedOrtValue_Duplicate", false, *logger_); + Graph& graph = model.MainGraph(); + + // Create a TensorProto with external data reference + const std::string name = "init"; + constexpr const int64_t kTensorSize = 256; + auto allocator = CPUAllocator::DefaultInstance(); + TensorProto tensor_proto; + OrtValue ort_value; + CreateIntializerWithDataInMemory(name, allocator, kTensorSize, tensor_proto, ort_value); + + // Add the first initializer successfully + ASSERT_STATUS_OK(graph.AddInitializedOrtValue(tensor_proto, ort_value)); + + // try again + Status status = graph.AddInitializedOrtValue(tensor_proto, ort_value); + ASSERT_FALSE(status.IsOK()); +} + #ifdef ENABLE_TRAINING TEST_F(GraphTest, GraphConstruction_MemoryEfficientTopologicalSort_Recompute) { diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 6dca258601339..16555eafcd897 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -632,7 +632,7 @@ void OnnxTestCase::ConvertTestData(const ONNX_NAMESPACE::TensorProto& test_data_ void* p = len == 0 ? nullptr : b.AllocMemory(len); Ort::Value v1{nullptr}; onnxruntime::test::OrtCallback d; - OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); + OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), OrtMemTypeDefault); status = onnxruntime::test::TensorProtoToMLValue(test_data_pb, onnxruntime::test::MemBuffer(p, len, cpu_memory_info), v1, d); if (!status.IsOK()) { @@ -683,7 +683,7 @@ void OnnxTestCase::ConvertTestData(const ONNX_NAMESPACE::SequenceProto& test_dat void* p = len == 0 ? nullptr : b.AllocMemory(len); Ort::Value v1{nullptr}; onnxruntime::test::OrtCallback d; - OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); + OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), OrtMemTypeDefault); status = onnxruntime::test::TensorProtoToMLValue(*it, onnxruntime::test::MemBuffer(p, len, cpu_memory_info), v1, d); if (!status.IsOK()) { @@ -1448,7 +1448,11 @@ std::unique_ptr> GetBrokenTests(const std::string& provider broken_tests->insert({"gridsample_reflection_padding", "result differs"}); broken_tests->insert({"gridsample_volumetric_nearest_align_corners_0", "unknown version"}); broken_tests->insert({"gridsample_volumetric_nearest_align_corners_1", "unknown version"}); + broken_tests->insert({"rotary_embedding", "unknown version"}); + broken_tests->insert({"rotary_embedding_no_position_ids", "unknown version"}); + broken_tests->insert({"rotary_embedding_interleaved", "unknown version"}); broken_tests->insert({"rotary_embedding_no_position_ids_expanded", "unknown version"}); + broken_tests->insert({"rotary_embedding_no_position_ids_interleaved", "unknown version"}); broken_tests->insert({"rotary_embedding_no_position_ids_interleaved_expanded", "unknown version"}); broken_tests->insert({"spacetodepth", "result differs"}); broken_tests->insert({"reduce_sum_square_empty_set_expanded", "unknown version"}); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 893ba1d7b5cac..bef0bdd5295be 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -75,6 +75,8 @@ void usage() { "\t [QNN only] [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" "\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n" "\t 'high_power_saver', 'low_balanced', 'extreme_power_saver', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n" + "\t [QNN only] [op_packages]: QNN UDO package, allowed format: \n" + "\t op_packages|::[:],::[:]. \n" "\t [QNN only] [qnn_context_priority]: QNN context priority, options: 'low', 'normal', 'normal_high', 'high'. Default to 'normal'. \n" "\t 0 means dump the QNN context binary into separate bin file and set the path in the Onnx skeleton model.\n" "\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n" @@ -584,6 +586,10 @@ int real_main(int argc, char* argv[], Ort::Env& env) { std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_performance_mode. select from: " + str); } + } else if (key == "op_packages") { + if (value.empty()) { + ORT_THROW("Please provide the valid op_packages."); + } } else if (key == "qnn_context_priority") { std::set supported_qnn_context_priority = {"low", "normal", "normal_high", "high"}; if (supported_qnn_context_priority.find(value) == supported_qnn_context_priority.end()) { @@ -622,7 +628,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { ORT_THROW( "Wrong key type entered. Choose from options: ['backend_type', 'backend_path', " "'profiling_level', 'profiling_file_path', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', " - "'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', " + "'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'op_packages', 'qnn_context_priority', " "'soc_model', 'htp_arch', 'device_id', 'enable_htp_fp16_precision', 'offload_graph_io_quantization']"); } diff --git a/onnxruntime/test/onnx/microbenchmark/activation.cc b/onnxruntime/test/onnx/microbenchmark/activation.cc index df36135bd3017..65a0b2b93b4d2 100644 --- a/onnxruntime/test/onnx/microbenchmark/activation.cc +++ b/onnxruntime/test/onnx/microbenchmark/activation.cc @@ -24,7 +24,7 @@ extern OrtEnv* env; class Allocs : public IExecutionProvider { private: - std::shared_ptr alloc = std::make_shared(); + AllocatorPtr alloc = CPUAllocator::DefaultInstance(); public: Allocs() : IExecutionProvider("fake") {}; diff --git a/onnxruntime/test/onnx/microbenchmark/batchnorm2.cc b/onnxruntime/test/onnx/microbenchmark/batchnorm2.cc index ed1dc808871ec..ec81830156381 100644 --- a/onnxruntime/test/onnx/microbenchmark/batchnorm2.cc +++ b/onnxruntime/test/onnx/microbenchmark/batchnorm2.cc @@ -12,38 +12,35 @@ using namespace onnxruntime; template void SetRandom(Tensor& input) { - int64_t size = input.Shape().Size(); std::random_device rd; std::mt19937 gen(rd()); std::uniform_real_distribution distr(0, 1); - T* data = input.MutableData(); - for (int64_t i = 0; i != size; ++i) { - data[i] = distr(gen); - } + auto span = input.MutableDataAsSpan(); + std::generate(span.begin(), span.end(), [&]() { return distr(gen); }); } static void BM_BatchNormOldEigen(benchmark::State& state) { - std::shared_ptr alloc = std::make_shared(); + AllocatorPtr alloc = CPUAllocator::DefaultInstance(); const int64_t batch_size = state.range(0); const TensorShape shape = {batch_size, 64, 75, 75}; using T = float; - Tensor* X = new Tensor(DataTypeImpl::GetType(), shape, alloc); - SetRandom(*X); - const TensorShape& x_shape = X->Shape(); - Tensor* Y = new Tensor(DataTypeImpl::GetType(), shape, alloc); - Tensor* scale = new Tensor(DataTypeImpl::GetType(), {shape[1]}, alloc); - SetRandom(*scale); - Tensor* mean = new Tensor(DataTypeImpl::GetType(), {shape[1]}, alloc); - SetRandom(*mean); + Tensor X(DataTypeImpl::GetType(), shape, alloc); + SetRandom(X); + const TensorShape& x_shape = X.Shape(); + Tensor Y(DataTypeImpl::GetType(), shape, alloc); + Tensor scale(DataTypeImpl::GetType(), {shape[1]}, alloc); + SetRandom(scale); + Tensor mean(DataTypeImpl::GetType(), {shape[1]}, alloc); + SetRandom(mean); - Tensor* B = new Tensor(DataTypeImpl::GetType(), {shape[1]}, alloc); - SetRandom(*B); + Tensor B(DataTypeImpl::GetType(), {shape[1]}, alloc); + SetRandom(B); - Tensor* var = new Tensor(DataTypeImpl::GetType(), {shape[1]}, alloc); - SetRandom(*var); + Tensor var(DataTypeImpl::GetType(), {shape[1]}, alloc); + SetRandom(var); bool is_spatial_ = true; double epsilon_ = 1e-5; @@ -60,26 +57,26 @@ static void BM_BatchNormOldEigen(benchmark::State& state) { // calculate sample_size (including all channels) size_t sample_size_incl_all_channels = sample_size * C; for (auto _ : state) { - ConstEigenVectorArrayMap scale_arr(scale->Data(), is_spatial_ ? C : sample_size_incl_all_channels); - ConstEigenVectorArrayMap bias_arr(B->Data(), is_spatial_ ? C : sample_size_incl_all_channels); + ConstEigenVectorArrayMap scale_arr(scale.Data(), is_spatial_ ? C : sample_size_incl_all_channels); + ConstEigenVectorArrayMap bias_arr(B.Data(), is_spatial_ ? C : sample_size_incl_all_channels); // Regardless of training or testing, we will apply the estimated mean // and standard deviation to the input. For testing, they are // specified directly by the input, and for training, they are computed // by the op. Eigen::Array inv_std(is_spatial_ ? C : sample_size_incl_all_channels); - ConstEigenVectorArrayMap var_arr(var->Data(), is_spatial_ ? C : sample_size_incl_all_channels); + ConstEigenVectorArrayMap var_arr(var.Data(), is_spatial_ ? C : sample_size_incl_all_channels); inv_std = (var_arr + epsilon_).sqrt().inverse(); - ConstEigenVectorArrayMap mean_arr(mean->Data(), is_spatial_ ? C : sample_size_incl_all_channels); + ConstEigenVectorArrayMap mean_arr(mean.Data(), is_spatial_ ? C : sample_size_incl_all_channels); // We can fuse the output computation as follows: // ((x - est_mean) * (inv_var) * scale + bias // to // (x * inv_var * scale) + (bias - est_mean * inv_var * scale) Eigen::Array new_scale = inv_std * scale_arr; Eigen::Array new_bias = bias_arr - mean_arr * new_scale; - EigenArrayMap Y_arr(Y->MutableData(), is_spatial_ ? sample_size : sample_size_incl_all_channels, + EigenArrayMap Y_arr(Y.MutableData(), is_spatial_ ? sample_size : sample_size_incl_all_channels, is_spatial_ ? N * C : N); - ConstEigenArrayMap X_arr(X->Data(), is_spatial_ ? sample_size : sample_size_incl_all_channels, + ConstEigenArrayMap X_arr(X.Data(), is_spatial_ ? sample_size : sample_size_incl_all_channels, is_spatial_ ? N * C : N); if (is_spatial_) { // spatial == 1 for (size_t nc = 0; nc < N * C; ++nc) { diff --git a/onnxruntime/test/onnx/microbenchmark/main.cc b/onnxruntime/test/onnx/microbenchmark/main.cc index 70faa6f11989d..24d02caa96aa1 100644 --- a/onnxruntime/test/onnx/microbenchmark/main.cc +++ b/onnxruntime/test/onnx/microbenchmark/main.cc @@ -27,7 +27,7 @@ OrtEnv* env = nullptr; using namespace onnxruntime; static void BM_CPUAllocator(benchmark::State& state) { - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); const size_t len = state.range(0); for (auto _ : state) { void* p = cpu_allocator->Alloc(len); diff --git a/onnxruntime/test/onnx/microbenchmark/resize.cc b/onnxruntime/test/onnx/microbenchmark/resize.cc index 020680c12b8f5..2ccf588dfef54 100644 --- a/onnxruntime/test/onnx/microbenchmark/resize.cc +++ b/onnxruntime/test/onnx/microbenchmark/resize.cc @@ -30,7 +30,7 @@ static void BM_NhwcUpsampleBilinear(benchmark::State& state) { const T* const XdataBase = GenerateArrayWithRandomValue(XdataBaseSize, std::numeric_limits::min(), std::numeric_limits::max()); const size_t YdataBaseSize = batch_size * num_channels * output_height * output_width; T* const YdataBase = (T*)aligned_alloc(sizeof(T) * YdataBaseSize, 64); - AllocatorPtr alloc = std::make_shared(); + AllocatorPtr alloc = CPUAllocator::DefaultInstance(); const GetOriginalCoordinateFunc& get_original_coordinate = [](float x_resized, float x_scale, float, float, float, float) { return x_resized / x_scale; diff --git a/onnxruntime/test/onnx/tensorprotoutils.cc b/onnxruntime/test/onnx/tensorprotoutils.cc index 92c4b5bc88fe7..0b4ec1bab192a 100644 --- a/onnxruntime/test/onnx/tensorprotoutils.cc +++ b/onnxruntime/test/onnx/tensorprotoutils.cc @@ -3,21 +3,22 @@ #include "tensorprotoutils.h" -#include #include #include +#include #include -#include "mem_buffer.h" +#include "callback.h" +#include "core/common/make_string.h" #include "core/common/safeint.h" #include "core/common/status.h" -#include "core/common/make_string.h" +#include "core/framework/allocator.h" #include "core/framework/data_types.h" #include "core/framework/endian.h" -#include "core/framework/allocator.h" -#include "core/session/onnxruntime_cxx_api.h" +#include "core/framework/endian_utils.h" #include "core/graph/onnx_protobuf.h" -#include "callback.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "mem_buffer.h" struct OrtStatus { OrtErrorCode code; @@ -69,21 +70,13 @@ static void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_length ORT_CXX_API_THROW(MakeString("UnpackTensor: the pre-allocated size does not match the raw data size, expected ", expected_size_in_bytes, ", got ", raw_data_length), OrtErrorCode::ORT_FAIL); - memcpy(p_data, raw_data, raw_data_length); - if constexpr (endian::native != endian::little) { - /* Convert Endianness */ - char* bytes = reinterpret_cast(p_data); - size_t element_size = sizeof(T); - size_t num_elements = raw_data_length / element_size; - - for (size_t i = 0; i < num_elements; ++i) { - char* start_byte = bytes + i * element_size; - char* end_byte = start_byte + element_size - 1; - /* keep swapping */ - for (size_t count = 0; count < element_size / 2; ++count) { - std::swap(*start_byte++, *end_byte--); - } - } + + /* Convert Endianness */ + if constexpr (endian::native != endian::little && sizeof(T) > 1) { + utils::SwapByteOrderCopy(sizeof(T), gsl::make_span(reinterpret_cast(raw_data), raw_data_length), + gsl::make_span(reinterpret_cast(p_data), raw_data_length)); + } else { + memcpy(p_data, raw_data, raw_data_length); } } diff --git a/onnxruntime/test/opaque_api/test_opaque_api.cc b/onnxruntime/test/opaque_api/test_opaque_api.cc index c1c98cbf0ff5b..5bccf5ab1ac0d 100644 --- a/onnxruntime/test/opaque_api/test_opaque_api.cc +++ b/onnxruntime/test/opaque_api/test_opaque_api.cc @@ -73,7 +73,7 @@ struct NonTensorTypeConverter { // Create and populate Tensor TensorShape shape({1}); - std::shared_ptr allocator = std::make_shared(); + std::shared_ptr allocator = CPUAllocator::DefaultInstance(); std::unique_ptr tp(new Tensor(DataTypeImpl::GetType(), shape, allocator)); *tp->MutableData() = input.Get().str_; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 35d50cbec678f..a6a5004a2e2e2 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1431,7 +1431,7 @@ TEST_F(GraphTransformationTests, FusePadWithConv) { auto& node = *graph.GetNode(node_index); if (node.OpType() == "Pad") { const auto* pads_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); - Initializer pads{*pads_proto, graph.ModelPath()}; + Initializer pads{graph, *pads_proto, graph.ModelPath()}; gsl::span pads_values = pads.DataAsSpan(); expected_pads.resize(pads_values.size() - 4); @@ -1484,7 +1484,7 @@ TEST_F(GraphTransformationTests, FusePadWithNoPadsConv) { auto& node = *graph.GetNode(node_index); if (node.OpType() == "Pad") { const auto* pads_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); - Initializer pads{*pads_proto, graph.ModelPath()}; + Initializer pads{graph, *pads_proto, graph.ModelPath()}; gsl::span pads_values = pads.DataAsSpan(); expected_pads.resize(pads_values.size() - 4); @@ -1532,7 +1532,7 @@ TEST_F(GraphTransformationTests, FusePadWithMaxPool) { auto& node = *graph.GetNode(node_index); if (node.OpType() == "Pad") { const auto* pads_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); - Initializer pads{*pads_proto, graph.ModelPath()}; + Initializer pads{graph, *pads_proto, graph.ModelPath()}; gsl::span pads_values = pads.DataAsSpan(); expected_pads.resize(pads_values.size() - 4); @@ -3804,11 +3804,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionTest) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 4U); + EXPECT_EQ(initializer.size(), 4U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], 12); @@ -3840,11 +3840,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionOneConstTest) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], 768); @@ -3875,11 +3875,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionInternalNodeIsOutput) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], -1); @@ -3911,11 +3911,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionInternalReuseTest) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 5U); + EXPECT_EQ(initializer.size(), 5U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 128); EXPECT_EQ(val[2], 0); @@ -3970,11 +3970,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionMultipleValuesInInitializerSubgrap const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 1); EXPECT_EQ(val[1], 200); EXPECT_EQ(val[2], -1); @@ -4003,11 +4003,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionMultipleValuesInInitializerApplies const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 1); EXPECT_EQ(val[1], 200); EXPECT_EQ(val[2], 0); @@ -4073,11 +4073,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionConcatSubgraphMultipleOutputs) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], -1); @@ -4107,11 +4107,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionConcatSubgraph) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], -1); @@ -4141,11 +4141,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionWithSlice1) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], -1); @@ -4211,11 +4211,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionConcatSubgraphWithDiv) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], -1); @@ -4247,11 +4247,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionConcatSubgraphWithMul) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 3U); + EXPECT_EQ(initializer.size(), 3U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], 0); EXPECT_EQ(val[2], -1); @@ -4281,11 +4281,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionDistilBertTest) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); ASSERT_TRUE(tensor_proto != nullptr); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_INT64); - EXPECT_EQ(initializer->size(), 4U); + EXPECT_EQ(initializer.size(), 4U); - const int64_t* val = initializer->data(); + const int64_t* val = initializer.data(); EXPECT_EQ(val[0], 0); EXPECT_EQ(val[1], -1); EXPECT_EQ(val[2], 2); @@ -4476,8 +4476,8 @@ static void ValidateAttention(Graph& graph) { ASSERT_TRUE(tensor_proto != nullptr); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); - EXPECT_EQ(initializer->size(), 192U); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); + EXPECT_EQ(initializer.size(), 192U); // Validate two rows (2x24 items) for sanity check. std::vector expected_value = { @@ -4531,7 +4531,7 @@ static void ValidateAttention(Graph& graph) { -0.0101165771484375, -0.00490570068359375}; - const float* data = initializer->data(); + const float* data = initializer.data(); for (size_t i = 0; i < expected_value.size(); i++) { EXPECT_EQ(data[i], static_cast(expected_value[i])); } @@ -4540,8 +4540,8 @@ static void ValidateAttention(Graph& graph) { ASSERT_TRUE(tensor_proto != nullptr); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - auto initializer2 = std::make_unique(*tensor_proto, graph.ModelPath()); - EXPECT_EQ(initializer2->size(), 24U); + Initializer initializer2(graph, *tensor_proto, graph.ModelPath()); + EXPECT_EQ(initializer2.size(), 24U); std::vector expected_value2 = { -0.23681640625, @@ -4569,7 +4569,7 @@ static void ValidateAttention(Graph& graph) { 0.0535888671875, 0.0091094970703125}; - const float* data2 = initializer2->data(); + const float* data2 = initializer2.data(); for (size_t i = 0; i < expected_value2.size(); i++) { EXPECT_EQ(data2[i], static_cast(expected_value2[i])); } @@ -7011,7 +7011,7 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShareFloatOrHalfTypedInitialize if (entry.first.compare(mul_initializer->Name()) == 0) { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; int32_t data_type = tensor_proto->data_type(); - onnxruntime::Initializer float_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer float_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(float_const.size() == 1U); float float_const_value; if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { @@ -7134,7 +7134,7 @@ TEST_F(GraphTransformationTests, ConstantSharing_Share2DFloatOrHalfTypedInitiali if (entry.first.compare(mul_initializer->Name()) == 0) { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; int32_t data_type = tensor_proto->data_type(); - onnxruntime::Initializer float_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer float_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(float_const.size() == 8U); for (int i = 0; i < 8; ++i) { float float_const_value; @@ -7240,7 +7240,7 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShareFloatAndHalfTypedInitializ for (const auto& entry : initialized_tensor_set) { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; int32_t data_type = tensor_proto->data_type(); - onnxruntime::Initializer float_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer float_const{graph, *tensor_proto, graph.ModelPath()}; if (entry.first.compare(mul_initializer->Name()) == 0) { TEST_RETURN_IF_NOT(float_const.size() == 1U); TEST_RETURN_IF_NOT(data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); @@ -7369,7 +7369,7 @@ TEST_F(GraphTransformationTests, ConstantSharing_Share2DFloatAndHalfTypedInitial for (const auto& entry : initialized_tensor_set) { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; int32_t data_type = tensor_proto->data_type(); - onnxruntime::Initializer float_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer float_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(float_const.size() == 8U); if (entry.first.compare(mul_initializer->Name()) == 0) { TEST_RETURN_IF_NOT(data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); @@ -7507,13 +7507,13 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShareIntMaxOrFloatInfinityIniti for (const auto& entry : initialized_tensor_set) { if (entry.first.compare(mul_initializer->Name()) == 0) { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; - onnxruntime::Initializer int64_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer int64_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(int64_const.size() == 1U); int64_t int64_const_value = *(int64_const.data()); TEST_RETURN_IF_NOT(int64_const_value == std::numeric_limits::max()); } else if (entry.first.compare(sub_initializer->Name()) == 0) { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; - onnxruntime::Initializer float_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer float_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(float_const.size() == 1U); float float_const_value = *(float_const.data()); TEST_RETURN_IF_NOT(float_const_value == std::numeric_limits::infinity()); @@ -7606,13 +7606,13 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShouldNotShareForGraphOutput) { for (const auto& entry : initialized_tensor_set) { if (entry.first.compare("y_scale") == 0) { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; - onnxruntime::Initializer int64_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer int64_const{graph, *tensor_proto, graph.ModelPath()}; ASSERT_TRUE(int64_const.size() == 1U); float float_const_value = *(int64_const.data()); ASSERT_TRUE(float_const_value == 1); } else { const ONNX_NAMESPACE::TensorProto* tensor_proto = entry.second; - onnxruntime::Initializer uint8_const{*tensor_proto, graph.ModelPath()}; + onnxruntime::Initializer uint8_const{graph, *tensor_proto, graph.ModelPath()}; ASSERT_TRUE(uint8_const.size() == 1U); uint8_t uint8_const_value = *(uint8_const.data()); ASSERT_TRUE(uint8_const_value == static_cast(1)); @@ -7688,7 +7688,7 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_AllGather) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); } @@ -7828,7 +7828,7 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Combined) { const NodeArg& input_arg = *(node.InputDefs()[1]); const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); TEST_RETURN_IF_NOT(1 == static_cast(*(init_const.data()))); } @@ -7881,7 +7881,7 @@ TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Consume_Initializer) { const NodeArg& input_arg = *(node.InputDefs()[1]); const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); } @@ -8051,7 +8051,7 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32); TEST_RETURN_IF_NOT(2 == *(init_const.data())); } @@ -8090,7 +8090,7 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); } diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index 7b700922f4306..627a68f38b585 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -882,19 +882,18 @@ static void EmbedLayerNormFusionFormat5(const std::basic_string& file EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); // Validate the position embedding input. + double expected_value[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 8.0, 7.0, 6.0}; for (const Node& node : graph.Nodes()) { if (node.OpType() == "EmbedLayerNormalization") { const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[3]->Name()); ASSERT_TRUE(tensor_proto != nullptr); EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); - EXPECT_EQ(initializer->size(), 12U); + Initializer initializer{graph, *tensor_proto, graph.ModelPath()}; + EXPECT_EQ(initializer.size(), std::size(expected_value)); - std::vector expected_value = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 8.0, 7.0, 6.0}; - - const float* data = initializer->data(); - for (size_t i = 0; i < expected_value.size(); i++) { + const float* data = initializer.data(); + for (size_t i = 0; i < std::size(expected_value); i++) { EXPECT_EQ(data[i], static_cast(expected_value[i])); } } diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 9e0db487dbfc0..843875a881f0a 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -91,6 +91,8 @@ namespace perftest { "\t [QNN only] [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" "\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n" "\t 'high_power_saver', 'low_balanced', 'extreme_power_saver', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n" + "\t [QNN only] [op_packages]: QNN UDO package, allowed format: \n" + "\t op_packages|::[:],::[:]. \n" "\t [QNN only] [qnn_context_priority]: QNN context priority, options: 'low', 'normal', 'normal_high', 'high'. Default to 'normal'. \n" "\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n" "\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: \n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index d036375874c4b..7a210ca8482a4 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -219,7 +219,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #endif ParseSessionConfigs(option_string, provider_options, {"backend_type", "backend_path", "profiling_file_path", "profiling_level", - "rpc_control_latency", "vtcm_mb", "soc_model", "device_id", "htp_performance_mode", + "rpc_control_latency", "vtcm_mb", "soc_model", "device_id", "htp_performance_mode", "op_packages", "qnn_saver_path", "htp_graph_finalization_optimization_mode", "qnn_context_priority", "htp_arch", "enable_htp_fp16_precision", "offload_graph_io_quantization", "enable_htp_spill_fill_buffer", "enable_htp_shared_memory_allocator", "dump_json_qnn_graph", @@ -250,6 +250,10 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string str = str_stream.str(); ORT_THROW("Supported htp_performance_mode: " + str); } + } else if (key == "op_packages") { + if (value.empty()) { + ORT_THROW("Please provide the valid op_packages."); + } } else if (key == "qnn_saver_path") { // no validation } else if (key == "htp_graph_finalization_optimization_mode") { diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index a2c9881ab5169..2449f7c962e83 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -137,7 +137,7 @@ TEST(CoreMLExecutionProviderTest, FunctionTest) { std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; OrtValue ml_value_x; - AllocatorPtr allocator = std::make_shared(); + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); CreateMLValue(allocator, dims_mul_x, values_mul_x, &ml_value_x); OrtValue ml_value_y; CreateMLValue(allocator, dims_mul_x, values_mul_x, &ml_value_y); @@ -169,7 +169,7 @@ TEST(CoreMLExecutionProviderTest, ArgMaxCastTest) { std::vector dims_mul_x = {3, 2, 2}; std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; OrtValue ml_value_x; - AllocatorPtr allocator = std::make_shared(); + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); CreateMLValue(allocator, dims_mul_x, values_mul_x, &ml_value_x); NameMLValMap feeds; @@ -198,7 +198,7 @@ TEST(CoreMLExecutionProviderTest, ArgMaxUnsupportedCastTest) { std::vector dims_mul_x = {3, 2, 2}; std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; OrtValue ml_value_x; - AllocatorPtr allocator = std::make_shared(); + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); CreateMLValue(allocator, dims_mul_x, values_mul_x, &ml_value_x); NameMLValMap feeds; @@ -338,7 +338,7 @@ TEST(CoreMLExecutionProviderTest, TestModelCache) { std::vector dims_mul_x = {3, 2, 2}; std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; OrtValue ml_value_x; - AllocatorPtr allocator = std::make_shared(); + AllocatorPtr allocator = CPUAllocator::DefaultInstance(); CreateMLValue(allocator, dims_mul_x, values_mul_x, &ml_value_x); NameMLValMap feeds; diff --git a/onnxruntime/test/providers/cpu/llm/rotary_embedding_op_test.cc b/onnxruntime/test/providers/cpu/llm/rotary_embedding_op_test.cc new file mode 100644 index 0000000000000..482e23a4b0fb3 --- /dev/null +++ b/onnxruntime/test/providers/cpu/llm/rotary_embedding_op_test.cc @@ -0,0 +1,1091 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "gtest/gtest.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +namespace { +enum class TensorType { + kFloat, + kFloat16, + kBFloat16 +}; +} // anonymous namespace + +static void RunTest( + const std::vector& input_data, + const std::vector& position_ids, + const std::vector& cos_cache, + const std::vector& sin_cache, + const std::vector& output_data, + int batch_size, + int sequence_length, + int head_size, + int rotary_embedding_dim, + int num_heads, + int max_sequence_length, + int64_t interleaved, + TensorType tensor_type, + bool input_is_4d, + bool disable_cpu, + bool disable_cuda, + bool disable_dml) { + // input : (batch_size, sequence_length, hidden_size) or (batch_size, num_heads, sequence_length, head_size) + // position ids : (0) or (batch_size, sequence_length) + // cos cache : (max_sequence_length, head_size / 2) + // sin cache : (max_sequence_length, head_size / 2) + // interleaved : 0 = false, 1 = true + + int hidden_size = num_heads * head_size; + std::vector input_dims; + if (input_is_4d) { + input_dims = {batch_size, num_heads, sequence_length, head_size}; + } else { + input_dims = {batch_size, sequence_length, hidden_size}; + } + std::vector pos_dims; + + std::vector cache_dims; + if (position_ids.size() != 0) { + cache_dims = {max_sequence_length, rotary_embedding_dim > 0 ? rotary_embedding_dim / 2 : head_size / 2}; + } else { + cache_dims = {batch_size, sequence_length, rotary_embedding_dim > 0 ? rotary_embedding_dim / 2 : head_size / 2}; + } + + assert(hidden_size != 0 && head_size != 0 && num_heads != 0 && max_sequence_length != 0); + assert(max_sequence_length >= sequence_length); + if (position_ids.size() == 0) { + pos_dims = {}; + } else { + pos_dims = {batch_size, sequence_length}; + } + + std::string op_type = "RotaryEmbedding"; + std::vector> execution_providers; + + int min_cuda_architecture = (tensor_type == TensorType::kBFloat16) + ? 800 + : (tensor_type == TensorType::kFloat16) ? 530 + : 0; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; + bool enable_webgpu = nullptr != DefaultWebGpuExecutionProvider().get(); + + if (enable_cuda && !disable_cuda) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + if (enable_dml && !disable_dml) { + execution_providers.push_back(DefaultDmlExecutionProvider()); + } + if ((tensor_type == TensorType::kFloat || tensor_type == TensorType::kFloat16) && !disable_cpu) { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } + if (enable_webgpu) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } + if (execution_providers.size() == 0) { + // Return early if CI pipeline does not support EP (e.g. CUDA EP for CPU CI pipeline) + return; + } + + for (auto& ep : execution_providers) { + OpTester test(op_type.c_str(), 23, onnxruntime::kOnnxDomain); + test.AddAttribute("interleaved", interleaved); + test.AddAttribute("num_heads", num_heads); + + if (rotary_embedding_dim > 0) { + test.AddAttribute("rotary_embedding_dim", rotary_embedding_dim); + } + + if (tensor_type == TensorType::kFloat) { + test.AddInput("input", input_dims, input_data); + test.AddInput("cos_cache", cache_dims, cos_cache); + test.AddInput("sin_cache", cache_dims, sin_cache); + if (position_ids.size()) { + test.AddInput("position_ids", pos_dims, position_ids); + } else { + test.AddOptionalInputEdge(); + } + test.AddOutput("output", input_dims, output_data); + } else if (tensor_type == TensorType::kFloat16) { + test.AddInput("input", input_dims, ToFloat16(input_data)); + test.AddInput("cos_cache", cache_dims, ToFloat16(cos_cache)); + test.AddInput("sin_cache", cache_dims, ToFloat16(sin_cache)); + if (position_ids.size()) { + test.AddInput("position_ids", pos_dims, position_ids); + } else { + test.AddOptionalInputEdge(); + } + test.AddOutput("output", input_dims, ToFloat16(output_data)); + } else { + test.AddInput("input", input_dims, FloatsToBFloat16s(input_data)); + test.AddInput("cos_cache", cache_dims, FloatsToBFloat16s(cos_cache)); + test.AddInput("sin_cache", cache_dims, FloatsToBFloat16s(sin_cache)); + if (position_ids.size()) { + test.AddInput("position_ids", pos_dims, position_ids); + } else { + test.AddOptionalInputEdge(); + } + test.AddOutput("output", input_dims, FloatsToBFloat16s(output_data)); + } + if (tensor_type == TensorType::kBFloat16) { + test.SetOutputAbsErr("output", 0.03f); + } else { + test.SetOutputAbsErr("output", 0.002f); + } + + std::vector> test_execution_providers; + test_execution_providers.push_back(std::move(ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &test_execution_providers); + } +} + +static void RunTests(const std::vector& input_data, + const std::vector& position_ids, + const std::vector& cos_cache, + const std::vector& sin_cache, + const std::vector& output_data, + int batch_size, + int sequence_length, + int head_size = 0, + int rotary_embedding_dim = 0, + int num_heads = 0, + int max_sequence_length = 0, + int64_t interleaved = 0, + bool use_float16 = true, + bool input_is_4d = false) { + // FP32 test for CPU, CUDA and DML + RunTest(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + rotary_embedding_dim, + num_heads, + max_sequence_length, + interleaved, + TensorType::kFloat, + input_is_4d, + false, /* disable_cpu */ + false, /* disable_cuda */ + false /* disable_dml */); + + // FP16 test for CPU, CUDA and DML + if (use_float16) { + RunTest(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + rotary_embedding_dim, + num_heads, + max_sequence_length, + interleaved, + TensorType::kFloat16, + input_is_4d, + false, /* disable_cpu */ + false, /* disable_cuda*/ + false /* disable_dml */); + } +} + +// Interleaved = true, pos ids shape = (batch_size, sequence_length) +TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) { + int batch_size = 1; + int sequence_length = 3; + int num_heads = 2; + int head_size = 4; + int rotary_embedding_dim = 0; + int max_sequence_length = 8; + int64_t interleaved = 1; // true + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f, + -1.2188f, 1.1676f, -1.0574f, -0.1188f, -0.7396f, -1.2425f, -0.1752f, 0.6990f, + -0.8110f, 0.6737f, -1.1233f, -0.0919f, -0.6861f, 0.7202f, 0.1963f, 0.6142f}; + + std::vector position_ids = {0, 1, 2}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 0.5403f, 0.9999f, -0.4161f, 0.9998f, -0.9900f, 0.9996f, + -0.6536f, 0.9992f, 0.2837f, 0.9988f, 0.9602f, 0.9982f, 0.7539f, 0.9976f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.8415f, 0.0100f, 0.9093f, 0.0200f, 0.1411f, 0.0300f, + -0.7568f, 0.0400f, -0.9589f, 0.0500f, -0.2794f, 0.0600f, 0.6570f, 0.0699f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f, + -1.6411f, -0.3948f, -1.0561f, -0.1294f, 0.6460f, -1.2937f, -0.1822f, 0.6972f, + -0.2751f, -1.0178f, -1.1212f, -0.1143f, -0.3694f, -0.9235f, 0.1840f, 0.6180f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + rotary_embedding_dim, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = true, pos ids shape = (batch_size, sequence_length) +TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT_4D_Input) { + int batch_size = 1; + int sequence_length = 3; + int num_heads = 2; + int head_size = 4; + int rotary_embedding_dim = 0; + int max_sequence_length = 8; + int64_t interleaved = 1; // true + + std::vector input_data = { + // Head 0: sequence 0, 1, 2 + -1.0408f, 0.9166f, -1.3042f, -1.1097f, // seq 0 + -1.2188f, 1.1676f, -1.0574f, -0.1188f, // seq 1 + -0.8110f, 0.6737f, -1.1233f, -0.0919f, // seq 2 + // Head 1: sequence 0, 1, 2 + -0.1320f, -0.2751f, -0.2350f, 0.0937f, // seq 0 + -0.7396f, -1.2425f, -0.1752f, 0.6990f, // seq 1 + -0.6861f, 0.7202f, 0.1963f, 0.6142f}; // seq 2 + + std::vector position_ids = {0, 1, 2}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 0.5403f, 0.9999f, -0.4161f, 0.9998f, -0.9900f, 0.9996f, + -0.6536f, 0.9992f, 0.2837f, 0.9988f, 0.9602f, 0.9982f, 0.7539f, 0.9976f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.8415f, 0.0100f, 0.9093f, 0.0200f, 0.1411f, 0.0300f, + -0.7568f, 0.0400f, -0.9589f, 0.0500f, -0.2794f, 0.0600f, 0.6570f, 0.0699f}; + + // Expected output in 4D layout: [batch=1, num_heads=2, seq_len=3, head_size=4] + std::vector output_data = { + // Head 0: sequence 0, 1, 2 + -1.0408f, 0.9166f, -1.3042f, -1.1097f, // seq 0 (no change) + -1.6411f, -0.3948f, -1.0561f, -0.1294f, // seq 1 (rotated) + -0.2751f, -1.0178f, -1.1212f, -0.1143f, // seq 2 (rotated) + // Head 1: sequence 0, 1, 2 + -0.1320f, -0.2751f, -0.2350f, 0.0937f, // seq 0 (no change) + 0.6460f, -1.2937f, -0.1822f, 0.6972f, // seq 1 (rotated) + -0.3694f, -0.9235f, 0.1840f, 0.6180f}; // seq 2 (rotated) + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + rotary_embedding_dim, + num_heads, + max_sequence_length, + interleaved, + true, // use_float16 + true); // input_is_4d +} + +// Interleaved = true, position_ids shape = (batch_size, sequence_length) +TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) { + int batch_size = 2; + int sequence_length = 8; + int num_heads = 4; + int head_size = 6; + int rotary_embedding_dim = 0; + int max_sequence_length = 16; + int64_t interleaved = 1; // true + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -1.0574f, + -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.5912f, 1.1312f, 0.7562f, -1.2023f, -0.5833f, + -0.4407f, 0.1766f, 1.0224f, -0.4826f, -0.5421f, + -0.5342f, -0.6413f, 1.3314f, -0.4498f, 0.5493f, + 0.0539f, 0.2601f, 0.8570f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -1.9791f, + 0.7787f, -0.7749f, -0.1398f, 1.1414f, -0.6354f, + 0.0352f, -0.4765f, -0.0409f, 1.1993f, 0.5374f, + -0.1930f, 2.5211f, -0.0452f, -0.3105f, -0.9407f, + -0.0034f, 1.5199f, -0.8480f, 0.5266f, 0.0299f, + -0.0498f, 1.0651f, 0.8860f, -1.4702f, -0.2134f, + -0.8707f, 1.6159f, -0.2356f, 0.9444f, 0.5937f, + 0.7203f, 0.5061f, 1.5192f, -0.4897f, 0.9231f, + 0.2654f, -0.1441f, 0.5407f, -1.5476f, 0.6455f, + -1.1382f, 0.4640f, -0.4986f, 0.1289f, 2.7631f, + 0.1405f, 1.1191f, 2.1134f, -0.9754f, 0.1757f, + -0.1319f, -0.2735f, 0.3355f, -0.6008f, -1.1164f, + 0.2577f, -0.7226f, -0.9244f, 1.8737f, 0.6052f, + 1.1904f, 1.2195f, -0.0470f, -1.0914f, 1.0223f, + 0.3152f, 1.7528f, -0.7650f, 1.8299f, -0.2784f, + -0.2719f, 0.1885f, 2.1432f, 0.8527f, 0.0965f, + -0.0625f, 0.8269f, 1.0122f, -1.4482f, -0.0644f, + 0.3215f, 0.5908f, -1.4197f, 0.2113f, 0.0306f, + 0.3604f, 0.3166f, -0.8975f, -0.6393f, -1.2944f, + -0.0243f, -0.2354f, -0.7087f, 1.1566f, 0.4296f, + 0.5599f, -0.7776f, 0.3339f, 0.1759f, 2.1108f, + 1.0702f, 0.8279f, -0.2969f, 0.7120f, -0.2068f, + -0.1548f, 0.1553f, 0.6207f, -0.1690f, -0.5816f, + 1.2632f, 0.0695f, 1.1862f, -1.1874f, -0.7468f, + -0.9320f, -0.8579f, -0.9647f, -0.0991f, 0.0195f, + 1.1213f, -1.4873f, -0.2043f, -1.0466f, -1.5772f, + -0.0489f, 0.3430f, 0.1264f, 0.1519f, -1.3639f, + -1.6593f, 1.8127f, -1.4459f, -0.2158f, -0.9792f, + -1.4392f, 0.6508f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, -0.5996f, -1.0962f, 1.6327f, 1.3951f, + 0.8784f, 0.3389f, 1.2907f, 0.3124f, 0.7299f, + 1.4220f, 0.3375f, 0.0438f, 1.8698f, -0.2635f, + -2.0799f, -0.6313f, 0.4090f, -1.1458f, 0.0784f, + -1.8848f, -1.6165f, 0.6179f, 0.9905f, -0.0729f, + 0.5054f, -0.6681f, -1.4382f, 1.7547f, -0.9605f, + -0.4558f, -1.6105f, 0.2979f, 1.1537f, -1.5604f, + 1.2779f, -1.2514f, 0.6056f, 0.5763f, -3.3558f, + 0.2836f, 0.6909f, -0.7631f, 2.4451f, -0.3500f, + 1.3289f, -0.6494f, 0.3478f, 1.0038f, -0.2937f, + 0.9238f, -1.2185f, 0.4138f, 0.5033f, 0.9174f, + 1.8131f, 1.4436f, -0.4207f, 0.0220f, -0.6807f, + -1.3306f, 1.5646f, 0.3338f, 0.7105f, 0.4683f, + -0.6179f, 0.0818f, -0.0488f, -0.9810f, -1.3632f, + 0.0929f, -1.7926f, -0.2921f, -0.4792f, 0.6756f, + -0.3413f, -0.2242f, -0.2111f, 0.6282f, 0.1667f, + -1.4055f, 1.5895f, 1.0838f, -0.9077f, -0.8060f, + 0.7967f, -2.9351f, 2.4179f, -0.4026f, 0.6451f, + 1.6845f, -0.0901f, 0.6106f, 2.3603f, 1.3908f, + -0.7917f, -0.6734f, -0.1213f, -1.1116f, -0.7401f, + -0.7879f, 0.0606f, -2.3337f, -1.2603f, -1.7245f, + -0.3533f, -0.9421f, -0.1776f, 0.3992f, -1.7142f, + -0.5319f, -0.8848f, 0.6513f, 1.0002f, -1.4699f, + -1.4254f, 0.7013f, 0.2414f, 0.2551f, -0.7457f, + 0.3133f, -1.0941f, -0.3682f, -0.0163f, -0.0645f, + -0.8101f, 0.1415f, 0.0551f, 0.5873f, -0.5887f, + -1.4733f, -0.8565f, 0.7400f, -0.5033f, 0.0553f, + 0.9265f, -0.8652f, -0.0288f, -0.2209f, 0.0610f, + 0.6776f, 0.4361f, -0.8052f, 0.3955f, 0.8988f, + 0.8238f, 0.2262f, 1.2912f, 0.6488f, 1.2114f, + 1.3569f, 0.2983f, 0.4718f, -1.1936f, 0.7928f, + -0.8665f, 0.9468f, 1.1629f, 0.0616f, -1.3136f, + -0.2764f, 0.0277f, -0.1126f, 0.2342f, -0.5866f, + -1.8219f, 1.1079f, 0.5795f, -1.4249f}; + + std::vector position_ids = {0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f, -0.6536f, 0.9828f, 1.0000f, 0.2837f, + 0.9732f, 0.9999f, 0.9602f, 0.9615f, 0.9999f, 0.7539f, 0.9477f, 0.9999f, + -0.1455f, 0.9318f, 0.9999f, -0.9111f, 0.9140f, 0.9998f, -0.8391f, 0.8942f, + 0.9998f, 0.0044f, 0.8725f, 0.9997f, 0.8439f, 0.8488f, 0.9997f, 0.9074f, + 0.8234f, 0.9996f, 0.1367f, 0.7962f, 0.9995f, -0.7597f, 0.7673f, 0.9995f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, + 0.0043f, 0.1411f, 0.1388f, 0.0065f, -0.7568f, 0.1846f, 0.0086f, -0.9589f, + 0.2300f, 0.0108f, -0.2794f, 0.2749f, 0.0129f, 0.6570f, 0.3192f, 0.0151f, + 0.9894f, 0.3629f, 0.0172f, 0.4121f, 0.4057f, 0.0194f, -0.5440f, 0.4477f, + 0.0215f, -1.0000f, 0.4887f, 0.0237f, -0.5366f, 0.5286f, 0.0259f, 0.4202f, + 0.5675f, 0.0280f, 0.9906f, 0.6050f, 0.0302f, 0.6503f, 0.6413f, 0.0323f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -0.4713f, + -0.9540f, -0.9229f, 0.3027f, -0.5708f, -0.2363f, + -1.2713f, 0.1137f, 0.8112f, -1.1659f, -0.5824f, + -0.4419f, -0.7649f, 0.7011f, -0.4569f, -0.5639f, + -0.5328f, -0.6424f, 1.0979f, 0.8773f, 0.5462f, + 0.0793f, 0.2582f, 0.8576f, 0.2653f, 1.2295f, + -0.1839f, -0.4517f, -1.5052f, -0.4651f, 0.1155f, + -2.1237f, -0.7586f, -0.2110f, 1.1441f, -0.6304f, + 0.4186f, 0.2303f, -0.1519f, 1.1903f, 0.5382f, + -0.1906f, -1.0080f, 2.3112f, -0.2220f, -0.9655f, + -0.0099f, 1.5198f, 0.7652f, -0.6410f, 0.0365f, + -0.0452f, 1.0593f, 0.8929f, 1.4856f, 0.0038f, + -1.0865f, 1.4794f, -0.2417f, 0.9428f, -0.6894f, + -0.6293f, 0.2904f, 1.5747f, -0.4956f, 0.9199f, + -0.2424f, 0.1801f, 0.7503f, -1.4576f, 0.6529f, + -1.1340f, -0.6807f, -0.0252f, -0.3834f, 2.7394f, + 0.1308f, 1.1203f, -2.1196f, -0.9618f, 0.1970f, + -0.0972f, -0.2764f, 0.3332f, -0.4522f, 1.1844f, + 0.3867f, -0.6626f, -0.9405f, 1.8656f, 0.5053f, + -1.2361f, 1.2072f, 0.1789f, -1.1002f, 1.0129f, + 1.7702f, 0.1949f, -1.1653f, 1.6049f, -0.2755f, + -0.2749f, 2.1087f, 0.4272f, 0.8076f, 0.2900f, + -0.0714f, 0.8261f, -1.1016f, -1.3814f, -0.1366f, + 0.2981f, 0.6060f, -1.4132f, 0.0893f, -0.1939f, + 0.2779f, 0.3910f, -0.8906f, -0.6489f, -1.2496f, + 0.3383f, -0.0315f, -0.7461f, 1.1510f, 0.4445f, + 0.3203f, -0.9031f, 0.2727f, 0.2609f, 2.0968f, + 1.0974f, 0.7120f, -0.5164f, 0.7415f, -0.0031f, + -0.1568f, 0.1533f, 0.5487f, -0.3357f, -0.9064f, + 1.0546f, 0.0542f, 1.1870f, -0.4045f, -1.3431f, + -0.6094f, -1.1105f, -0.9631f, -0.1137f, -0.7219f, + 0.8582f, -1.3443f, -0.6684f, -1.0227f, -1.5929f, + -0.2622f, 0.2264f, 0.0713f, 0.1843f, -1.3387f, + -1.6797f, 2.3165f, 0.1009f, 0.1081f, -0.9969f, + -1.4488f, 0.6291f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, 0.5985f, -1.0968f, 1.5662f, 1.4693f, + 0.8776f, 0.3408f, 0.4345f, 1.2549f, 0.6631f, + 1.4543f, 0.3374f, 0.0445f, 1.2320f, 1.4311f, + -2.0483f, -0.7272f, 0.4114f, -1.1449f, 1.6283f, + -0.9524f, -1.6435f, 0.5422f, 0.9907f, -0.0708f, + 0.3972f, 0.7376f, -1.5947f, 1.6138f, -0.9586f, + -0.4600f, 0.3993f, -1.5884f, 1.2934f, -1.4467f, + 1.2833f, -1.2459f, -0.7760f, 0.3108f, -3.3677f, + -0.0287f, 0.6942f, -0.7601f, -0.6993f, 2.3690f, + 1.3834f, -0.5234f, 0.3435f, 1.0053f, 0.1604f, + -0.9560f, -1.2641f, 0.2406f, 0.4973f, 0.9206f, + -1.9987f, -1.1733f, -0.4197f, -0.0366f, -0.6720f, + -1.3350f, -1.5960f, -0.1097f, 0.6386f, 0.5624f, + -0.6184f, 0.0778f, 0.1867f, 0.9643f, -1.3629f, + -0.0972f, -1.7907f, -0.3037f, 0.8245f, -0.0789f, + -0.2940f, -0.2833f, -0.2165f, 0.6264f, -1.1726f, + 0.7926f, 1.3621f, 1.3586f, -0.9007f, -0.8138f, + -2.7421f, 1.3155f, 2.4507f, 0.0507f, 0.6305f, + 1.6900f, 0.5210f, -0.3309f, 2.0630f, 1.8026f, + -0.7859f, -0.6802f, -1.1003f, -0.1990f, -0.5391f, + -0.9370f, 0.0857f, -2.3330f, -2.0112f, 0.7193f, + -0.1272f, -0.9981f, -0.1818f, 0.3973f, -0.9963f, + 1.4929f, -1.0109f, 0.4304f, 1.0160f, -1.4590f, + 0.2682f, 1.5658f, 0.1762f, 0.3038f, -0.7491f, + 0.3052f, -1.1534f, -0.0478f, 0.0021f, -0.0665f, + -0.8118f, 0.1310f, 0.2171f, 0.5485f, -0.1610f, + -1.5784f, -0.8660f, 0.7289f, -0.4678f, 0.1937f, + 1.1287f, -0.5772f, -0.0259f, -0.2212f, 0.2479f, + 0.6336f, 0.6407f, -0.6543f, 0.3838f, 0.9039f, + 0.4724f, 0.7117f, 1.0165f, 1.0270f, 1.1908f, + 1.3750f, -0.0850f, 0.5517f, -1.3842f, 0.3703f, + -0.8806f, 0.9336f, 0.8362f, 0.8105f, -1.1566f, + -0.6813f, 0.0294f, -0.1122f, 0.5620f, -0.2884f, + -2.0803f, 0.4684f, 0.6009f, -1.4160f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + rotary_embedding_dim, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = false, pos ids shape = (batch_size, sequence_length) +TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) { + int batch_size = 2; + int sequence_length = 8; + int num_heads = 4; + int head_size = 6; + int rotary_embedding_dim = 0; + int max_sequence_length = 16; + int64_t interleaved = 0; // false + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -1.0574f, + -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.5912f, 1.1312f, 0.7562f, -1.2023f, -0.5833f, + -0.4407f, 0.1766f, 1.0224f, -0.4826f, -0.5421f, + -0.5342f, -0.6413f, 1.3314f, -0.4498f, 0.5493f, + 0.0539f, 0.2601f, 0.8570f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -1.9791f, + 0.7787f, -0.7749f, -0.1398f, 1.1414f, -0.6354f, + 0.0352f, -0.4765f, -0.0409f, 1.1993f, 0.5374f, + -0.1930f, 2.5211f, -0.0452f, -0.3105f, -0.9407f, + -0.0034f, 1.5199f, -0.8480f, 0.5266f, 0.0299f, + -0.0498f, 1.0651f, 0.8860f, -1.4702f, -0.2134f, + -0.8707f, 1.6159f, -0.2356f, 0.9444f, 0.5937f, + 0.7203f, 0.5061f, 1.5192f, -0.4897f, 0.9231f, + 0.2654f, -0.1441f, 0.5407f, -1.5476f, 0.6455f, + -1.1382f, 0.4640f, -0.4986f, 0.1289f, 2.7631f, + 0.1405f, 1.1191f, 2.1134f, -0.9754f, 0.1757f, + -0.1319f, -0.2735f, 0.3355f, -0.6008f, -1.1164f, + 0.2577f, -0.7226f, -0.9244f, 1.8737f, 0.6052f, + 1.1904f, 1.2195f, -0.0470f, -1.0914f, 1.0223f, + 0.3152f, 1.7528f, -0.7650f, 1.8299f, -0.2784f, + -0.2719f, 0.1885f, 2.1432f, 0.8527f, 0.0965f, + -0.0625f, 0.8269f, 1.0122f, -1.4482f, -0.0644f, + 0.3215f, 0.5908f, -1.4197f, 0.2113f, 0.0306f, + 0.3604f, 0.3166f, -0.8975f, -0.6393f, -1.2944f, + -0.0243f, -0.2354f, -0.7087f, 1.1566f, 0.4296f, + 0.5599f, -0.7776f, 0.3339f, 0.1759f, 2.1108f, + 1.0702f, 0.8279f, -0.2969f, 0.7120f, -0.2068f, + -0.1548f, 0.1553f, 0.6207f, -0.1690f, -0.5816f, + 1.2632f, 0.0695f, 1.1862f, -1.1874f, -0.7468f, + -0.9320f, -0.8579f, -0.9647f, -0.0991f, 0.0195f, + 1.1213f, -1.4873f, -0.2043f, -1.0466f, -1.5772f, + -0.0489f, 0.3430f, 0.1264f, 0.1519f, -1.3639f, + -1.6593f, 1.8127f, -1.4459f, -0.2158f, -0.9792f, + -1.4392f, 0.6508f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, -0.5996f, -1.0962f, 1.6327f, 1.3951f, + 0.8784f, 0.3389f, 1.2907f, 0.3124f, 0.7299f, + 1.4220f, 0.3375f, 0.0438f, 1.8698f, -0.2635f, + -2.0799f, -0.6313f, 0.4090f, -1.1458f, 0.0784f, + -1.8848f, -1.6165f, 0.6179f, 0.9905f, -0.0729f, + 0.5054f, -0.6681f, -1.4382f, 1.7547f, -0.9605f, + -0.4558f, -1.6105f, 0.2979f, 1.1537f, -1.5604f, + 1.2779f, -1.2514f, 0.6056f, 0.5763f, -3.3558f, + 0.2836f, 0.6909f, -0.7631f, 2.4451f, -0.3500f, + 1.3289f, -0.6494f, 0.3478f, 1.0038f, -0.2937f, + 0.9238f, -1.2185f, 0.4138f, 0.5033f, 0.9174f, + 1.8131f, 1.4436f, -0.4207f, 0.0220f, -0.6807f, + -1.3306f, 1.5646f, 0.3338f, 0.7105f, 0.4683f, + -0.6179f, 0.0818f, -0.0488f, -0.9810f, -1.3632f, + 0.0929f, -1.7926f, -0.2921f, -0.4792f, 0.6756f, + -0.3413f, -0.2242f, -0.2111f, 0.6282f, 0.1667f, + -1.4055f, 1.5895f, 1.0838f, -0.9077f, -0.8060f, + 0.7967f, -2.9351f, 2.4179f, -0.4026f, 0.6451f, + 1.6845f, -0.0901f, 0.6106f, 2.3603f, 1.3908f, + -0.7917f, -0.6734f, -0.1213f, -1.1116f, -0.7401f, + -0.7879f, 0.0606f, -2.3337f, -1.2603f, -1.7245f, + -0.3533f, -0.9421f, -0.1776f, 0.3992f, -1.7142f, + -0.5319f, -0.8848f, 0.6513f, 1.0002f, -1.4699f, + -1.4254f, 0.7013f, 0.2414f, 0.2551f, -0.7457f, + 0.3133f, -1.0941f, -0.3682f, -0.0163f, -0.0645f, + -0.8101f, 0.1415f, 0.0551f, 0.5873f, -0.5887f, + -1.4733f, -0.8565f, 0.7400f, -0.5033f, 0.0553f, + 0.9265f, -0.8652f, -0.0288f, -0.2209f, 0.0610f, + 0.6776f, 0.4361f, -0.8052f, 0.3955f, 0.8988f, + 0.8238f, 0.2262f, 1.2912f, 0.6488f, 1.2114f, + 1.3569f, 0.2983f, 0.4718f, -1.1936f, 0.7928f, + -0.8665f, 0.9468f, 1.1629f, 0.0616f, -1.3136f, + -0.2764f, 0.0277f, -0.1126f, 0.2342f, -0.5866f, + -1.8219f, 1.1079f, 0.5795f, -1.4249f}; + + std::vector position_ids = {0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f, -0.6536f, 0.9828f, 1.0000f, 0.2837f, + 0.9732f, 0.9999f, 0.9602f, 0.9615f, 0.9999f, 0.7539f, 0.9477f, 0.9999f, + -0.1455f, 0.9318f, 0.9999f, -0.9111f, 0.9140f, 0.9998f, -0.8391f, 0.8942f, + 0.9998f, 0.0044f, 0.8725f, 0.9997f, 0.8439f, 0.8488f, 0.9997f, 0.9074f, + 0.8234f, 0.9996f, 0.1367f, 0.7962f, 0.9995f, -0.7597f, 0.7673f, 0.9995f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, + 0.0043f, 0.1411f, 0.1388f, 0.0065f, -0.7568f, 0.1846f, 0.0086f, -0.9589f, + 0.2300f, 0.0108f, -0.2794f, 0.2749f, 0.0129f, 0.6570f, 0.3192f, 0.0151f, + 0.9894f, 0.3629f, 0.0172f, 0.4121f, 0.4057f, 0.0194f, -0.5440f, 0.4477f, + 0.0215f, -1.0000f, 0.4887f, 0.0237f, -0.5366f, 0.5286f, 0.0259f, 0.4202f, + 0.5675f, 0.0280f, 0.9906f, 0.6050f, 0.0302f, 0.6503f, 0.6413f, 0.0323f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -0.8618f, + -0.0922f, -0.9073f, -0.7032f, -0.5762f, -0.2371f, + 0.6923f, 1.1571f, 0.7572f, -1.1471f, -0.5302f, + -0.4391f, 0.5516f, 1.0461f, -0.4812f, -0.1443f, + -0.4862f, -0.6423f, 0.6740f, -0.4614f, 0.5475f, + 1.1495f, 0.2389f, 0.8582f, -0.0259f, -0.6099f, + -0.2230f, 1.0963f, -1.5704f, -0.4595f, 0.9507f, + 0.6696f, -0.7721f, -1.7415f, 1.2087f, -0.6387f, + -1.1052f, -0.5243f, -0.0400f, -0.4671f, 0.4909f, + -0.1931f, -0.1937f, -0.0447f, -0.3171f, 2.6839f, + -0.0076f, 1.5185f, 0.8465f, 0.3737f, 0.0242f, + -0.0703f, 1.1279f, 0.8862f, 1.2275f, -0.1786f, + -0.8767f, -1.8072f, -0.2630f, 0.9387f, -0.8021f, + 0.7813f, 0.5001f, -1.4202f, -0.3850f, 0.9263f, + -0.0443f, -0.2323f, 0.5480f, 1.5696f, 0.6193f, + -1.1346f, 1.7878f, -0.5160f, 0.1192f, -2.1572f, + 0.0460f, 1.1202f, -1.4812f, -0.9082f, 0.1728f, + -1.5132f, -0.4489f, 0.3370f, -0.1541f, -0.9266f, + 0.2416f, 0.9270f, -1.1146f, 1.8758f, -0.4312f, + 1.3714f, 1.2106f, -0.4272f, -0.8529f, 1.0328f, + 1.8441f, 1.7698f, -0.7620f, 0.2168f, 0.1322f, + -0.2802f, 0.1460f, 2.1002f, 0.8437f, -0.1534f, + 0.4321f, 0.8360f, 0.5955f, -1.5452f, -0.0491f, + -0.8794f, 0.2418f, -1.4203f, 0.3635f, 0.2362f, + 0.3672f, -0.1128f, -0.8664f, -0.6354f, -1.4409f, + -0.3413f, -0.2409f, -0.3188f, 1.1054f, 0.4265f, + 0.5867f, -1.3279f, 0.3201f, 0.0125f, 1.8157f, + 1.0745f, 0.7372f, -0.2429f, 0.7100f, -0.4299f, + -0.2304f, 0.1645f, 0.9489f, -0.1816f, -0.5968f, + 1.0394f, 0.0204f, 1.1786f, -0.3315f, -0.3997f, + -0.9304f, -1.4268f, -1.1526f, -0.1132f, 0.1490f, + 1.3967f, -1.4634f, -0.1412f, -0.6339f, -1.5995f, + -0.1366f, 0.7604f, 0.1514f, 0.0824f, -1.1830f, + -1.6572f, 2.0099f, -0.9108f, -0.2256f, 0.4527f, + -1.8254f, 0.6475f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, -1.4979f, -1.1358f, 1.6320f, 0.2493f, + 0.8266f, 0.3424f, -0.4992f, 0.2964f, 0.7298f, + 1.8544f, 0.3516f, 0.0454f, 1.5415f, -0.2822f, + -2.0774f, 1.2323f, 0.3963f, -1.1503f, -0.4775f, + -1.9287f, -1.6164f, 0.3998f, 0.9020f, -0.0764f, + -1.8059f, -0.5762f, -1.4362f, -0.2706f, -1.0183f, + -0.4620f, 2.0891f, 0.1782f, 1.1591f, -0.8151f, + 1.3000f, -1.2464f, -0.5099f, 0.5098f, -3.3525f, + 0.4326f, 0.7414f, -0.7775f, -0.4271f, -0.3807f, + 1.3245f, 2.4936f, 0.3139f, 1.0095f, 0.2323f, + 0.8450f, -1.2244f, -0.4511f, 0.6266f, 0.9095f, + -1.7981f, 1.5241f, -0.4121f, 0.2341f, -0.4737f, + -1.3333f, -1.6150f, 0.4164f, 0.7100f, -0.2429f, + -0.5656f, 0.0863f, 0.0352f, -0.7227f, -1.3613f, + -0.0988f, -1.9114f, -0.3009f, 0.1435f, 0.7029f, + -0.3467f, 0.5092f, -0.0828f, 0.6253f, 0.7113f, + -1.2138f, 1.5964f, -0.8346f, -1.1515f, -0.7923f, + -0.8254f, -3.0038f, 2.4033f, -0.3398f, 0.0922f, + 1.7053f, 1.1114f, 0.7462f, 2.3660f, -0.8409f, + -0.6654f, -0.6530f, -0.7899f, -1.0957f, -0.7149f, + -0.1072f, -0.1967f, -2.3416f, -1.2609f, -1.6375f, + -0.3576f, 0.9413f, -0.5694f, 0.3954f, 0.1383f, + -0.7477f, -0.8689f, 1.8286f, 0.8510f, -1.4793f, + -0.1597f, 0.8541f, 0.2380f, 1.4392f, -0.5644f, + 0.3158f, -1.0686f, -0.1313f, -0.0181f, 0.2438f, + -0.8801f, 0.1413f, -0.3587f, 0.8002f, -0.5982f, + -1.4301f, -0.6620f, 0.7324f, -0.7250f, 0.0610f, + 0.9293f, -0.6902f, -0.0125f, -0.2089f, -0.1664f, + 0.5428f, 0.4245f, -0.7901f, 0.5665f, 0.9044f, + 0.1948f, -0.1723f, 1.2705f, 1.0303f, 1.2202f, + 1.3762f, -0.2959f, 0.7237f, -1.2077f, 0.7937f, + -0.6705f, 0.9287f, 1.0583f, 0.0496f, -1.3118f, + 0.5556f, 0.0459f, -0.1324f, -0.5513f, -0.7409f, + -1.8002f, 0.9892f, 0.3619f, -1.4522f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + rotary_embedding_dim, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = false, pos ids shape = (batch_size, sequence_length) +TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT_4D_Input) { + int batch_size = 2; + int sequence_length = 8; + int num_heads = 4; + int head_size = 6; + int rotary_embedding_dim = 0; + int max_sequence_length = 16; + int64_t interleaved = 0; // false + + // Input data in 4D layout: [batch_size=2, num_heads=4, sequence_length=8, head_size=6] + std::vector input_data = { + // Batch 0, Head 0: 8 sequences of 6 elements each + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, // seq 0 + -1.0190f, 0.3157f, -1.6036f, 1.8493f, 0.0447f, 1.5853f, // seq 1 + 0.1036f, -0.3514f, 0.2421f, 0.6463f, 0.8730f, -0.9276f, // seq 2 + 1.0311f, -1.9557f, -0.1482f, 1.7376f, 2.2039f, -0.6589f, // seq 3 + -1.0574f, -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, // seq 4 + -0.5912f, 1.1312f, 0.7562f, -1.2023f, -0.5833f, -0.4407f, // seq 5 + 0.1766f, 1.0224f, -0.4826f, -0.5421f, -0.5342f, -0.6413f, // seq 6 + 1.3314f, -0.4498f, 0.5493f, 0.0539f, 0.2601f, 0.8570f, // seq 7 + + // Batch 0, Head 1: 8 sequences of 6 elements each + 1.0076f, -0.7529f, -0.2250f, -0.4327f, -1.5071f, -0.4586f, // seq 0 + -1.9791f, 0.7787f, -0.7749f, -0.1398f, 1.1414f, -0.6354f, // seq 1 + 0.0352f, -0.4765f, -0.0409f, 1.1993f, 0.5374f, -0.1930f, // seq 2 + 2.5211f, -0.0452f, -0.3105f, -0.9407f, -0.0034f, 1.5199f, // seq 3 + -0.8480f, 0.5266f, 0.0299f, -0.0498f, 1.0651f, 0.8860f, // seq 4 + -1.4702f, -0.2134f, -0.8707f, 1.6159f, -0.2356f, 0.9444f, // seq 5 + 0.5937f, 0.7203f, 0.5061f, 1.5192f, -0.4897f, 0.9231f, // seq 6 + 0.2654f, -0.1441f, 0.5407f, -1.5476f, 0.6455f, -1.1382f, // seq 7 + + // Batch 0, Head 2: 8 sequences of 6 elements each + 0.4640f, -0.4986f, 0.1289f, 2.7631f, 0.1405f, 1.1191f, // seq 0 + 2.1134f, -0.9754f, 0.1757f, -0.1319f, -0.2735f, 0.3355f, // seq 1 + -0.6008f, -1.1164f, 0.2577f, -0.7226f, -0.9244f, 1.8737f, // seq 2 + 0.6052f, 1.1904f, 1.2195f, -0.0470f, -1.0914f, 1.0223f, // seq 3 + 0.3152f, 1.7528f, -0.7650f, 1.8299f, -0.2784f, -0.2719f, // seq 4 + 0.1885f, 2.1432f, 0.8527f, 0.0965f, -0.0625f, 0.8269f, // seq 5 + 1.0122f, -1.4482f, -0.0644f, 0.3215f, 0.5908f, -1.4197f, // seq 6 + 0.2113f, 0.0306f, 0.3604f, 0.3166f, -0.8975f, -0.6393f, // seq 7 + + // Batch 0, Head 3: 8 sequences of 6 elements each + -1.2944f, -0.0243f, -0.2354f, -0.7087f, 1.1566f, 0.4296f, // seq 0 + 0.5599f, -0.7776f, 0.3339f, 0.1759f, 2.1108f, 1.0702f, // seq 1 + 0.8279f, -0.2969f, 0.7120f, -0.2068f, -0.1548f, 0.1553f, // seq 2 + 0.6207f, -0.1690f, -0.5816f, 1.2632f, 0.0695f, 1.1862f, // seq 3 + -1.1874f, -0.7468f, -0.9320f, -0.8579f, -0.9647f, -0.0991f, // seq 4 + 0.0195f, 1.1213f, -1.4873f, -0.2043f, -1.0466f, -1.5772f, // seq 5 + -0.0489f, 0.3430f, 0.1264f, 0.1519f, -1.3639f, -1.6593f, // seq 6 + 1.8127f, -1.4459f, -0.2158f, -0.9792f, -1.4392f, 0.6508f, // seq 7 + + // Batch 1, Head 0: 8 sequences of 6 elements each + 0.8964f, 0.5717f, -0.2390f, 0.6983f, -1.3416f, 0.2715f, // seq 0 + -0.2852f, 0.6051f, 0.2167f, -0.2181f, -1.6306f, 1.4788f, // seq 1 + 0.2754f, -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, // seq 2 + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, 0.1527f, // seq 3 + -0.5996f, -1.0962f, 1.6327f, 1.3951f, 0.8784f, 0.3389f, // seq 4 + 1.2907f, 0.3124f, 0.7299f, 1.4220f, 0.3375f, 0.0438f, // seq 5 + 1.8698f, -0.2635f, -2.0799f, -0.6313f, 0.4090f, -1.1458f, // seq 6 + 0.0784f, -1.8848f, -1.6165f, 0.6179f, 0.9905f, -0.0729f, // seq 7 + + // Batch 1, Head 1: 8 sequences of 6 elements each + 0.5054f, -0.6681f, -1.4382f, 1.7547f, -0.9605f, -0.4558f, // seq 0 + -1.6105f, 0.2979f, 1.1537f, -1.5604f, 1.2779f, -1.2514f, // seq 1 + 0.6056f, 0.5763f, -3.3558f, 0.2836f, 0.6909f, -0.7631f, // seq 2 + 2.4451f, -0.3500f, 1.3289f, -0.6494f, 0.3478f, 1.0038f, // seq 3 + -0.2937f, 0.9238f, -1.2185f, 0.4138f, 0.5033f, 0.9174f, // seq 4 + 1.8131f, 1.4436f, -0.4207f, 0.0220f, -0.6807f, -1.3306f, // seq 5 + 1.5646f, 0.3338f, 0.7105f, 0.4683f, -0.6179f, 0.0818f, // seq 6 + -0.0488f, -0.9810f, -1.3632f, 0.0929f, -1.7926f, -0.2921f, // seq 7 + + // Batch 1, Head 2: 8 sequences of 6 elements each + -0.4792f, 0.6756f, -0.3413f, -0.2242f, -0.2111f, 0.6282f, // seq 0 + 0.1667f, -1.4055f, 1.5895f, 1.0838f, -0.9077f, -0.8060f, // seq 1 + 0.7967f, -2.9351f, 2.4179f, -0.4026f, 0.6451f, 1.6845f, // seq 2 + -0.0901f, 0.6106f, 2.3603f, 1.3908f, -0.7917f, -0.6734f, // seq 3 + -0.1213f, -1.1116f, -0.7401f, -0.7879f, 0.0606f, -2.3337f, // seq 4 + -1.2603f, -1.7245f, -0.3533f, -0.9421f, -0.1776f, 0.3992f, // seq 5 + -1.7142f, -0.5319f, -0.8848f, 0.6513f, 1.0002f, -1.4699f, // seq 6 + -1.4254f, 0.7013f, 0.2414f, 0.2551f, -0.7457f, 0.3133f, // seq 7 + + // Batch 1, Head 3: 8 sequences of 6 elements each + -1.0941f, -0.3682f, -0.0163f, -0.0645f, -0.8101f, 0.1415f, // seq 0 + 0.0551f, 0.5873f, -0.5887f, -1.4733f, -0.8565f, 0.7400f, // seq 1 + -0.5033f, 0.0553f, 0.9265f, -0.8652f, -0.0288f, -0.2209f, // seq 2 + 0.0610f, 0.6776f, 0.4361f, -0.8052f, 0.3955f, 0.8988f, // seq 3 + 0.8238f, 0.2262f, 1.2912f, 0.6488f, 1.2114f, 1.3569f, // seq 4 + 0.2983f, 0.4718f, -1.1936f, 0.7928f, -0.8665f, 0.9468f, // seq 5 + 1.1629f, 0.0616f, -1.3136f, -0.2764f, 0.0277f, -0.1126f, // seq 6 + 0.2342f, -0.5866f, -1.8219f, 1.1079f, 0.5795f, -1.4249f}; // seq 7 + + std::vector position_ids = {0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f, -0.6536f, 0.9828f, 1.0000f, 0.2837f, + 0.9732f, 0.9999f, 0.9602f, 0.9615f, 0.9999f, 0.7539f, 0.9477f, 0.9999f, + -0.1455f, 0.9318f, 0.9999f, -0.9111f, 0.9140f, 0.9998f, -0.8391f, 0.8942f, + 0.9998f, 0.0044f, 0.8725f, 0.9997f, 0.8439f, 0.8488f, 0.9997f, 0.9074f, + 0.8234f, 0.9996f, 0.1367f, 0.7962f, 0.9995f, -0.7597f, 0.7673f, 0.9995f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, + 0.0043f, 0.1411f, 0.1388f, 0.0065f, -0.7568f, 0.1846f, 0.0086f, -0.9589f, + 0.2300f, 0.0108f, -0.2794f, 0.2749f, 0.0129f, 0.6570f, 0.3192f, 0.0151f, + 0.9894f, 0.3629f, 0.0172f, 0.4121f, 0.4057f, 0.0194f, -0.5440f, 0.4477f, + 0.0215f, -1.0000f, 0.4887f, 0.0237f, -0.5366f, 0.5286f, 0.0259f, 0.4202f, + 0.5675f, 0.0280f, 0.9906f, 0.6050f, 0.0302f, 0.6503f, 0.6413f, 0.0323f}; + + // Expected output in 4D layout: [batch_size=2, num_heads=4, sequence_length=8, head_size=6] + // This is derived from the 3D test case output, reordered to match 4D layout + std::vector output_data = { + // Batch 0, Head 0: 8 sequences of 6 elements each + -1.04079998e+00f, 9.16599989e-01f, -1.30420005e+00f, -1.10969996e+00f, -1.21879995e+00f, 1.16760004e+00f, // seq 0 + -2.10675168e+00f, 3.13278645e-01f, -1.60708773e+00f, 1.41688287e-01f, 5.92993088e-02f, 1.58177209e+00f, // seq 1 + -6.30788624e-01f, -4.30816084e-01f, 2.46088684e-01f, -1.74721941e-01f, 8.36671352e-01f, -9.26558971e-01f, // seq 2 + -1.26596439e+00f, -2.24263120e+00f, -1.43917158e-01f, -1.57473576e+00f, 1.91107118e+00f, -6.59863293e-01f, // seq 3 + 9.52363968e-01f, -1.12946555e-02f, -9.05778170e-01f, 5.74617565e-01f, -5.83404124e-01f, -2.42907077e-01f, // seq 4 + -1.32060885e+00f, 1.23504281e+00f, 7.60883927e-01f, 2.25809187e-01f, -3.07491541e-01f, -4.32488948e-01f, // seq 5 + 1.81085765e-02f, 1.12988913e+00f, -4.74278957e-01f, -5.69866478e-01f, -2.32575566e-01f, -6.47461414e-01f, // seq 6 + 9.68330145e-01f, -5.09299397e-01f, 5.36304355e-01f, 9.15365040e-01f, 1.02920607e-01f, 8.65208685e-01f, // seq 7 + + // Batch 0, Head 1: 8 sequences of 6 elements each + 1.00759995e+00f, -7.52900004e-01f, -2.24999994e-01f, -4.32700008e-01f, -1.50709999e+00f, -4.58600014e-01f, // seq 0 + -9.51666117e-01f, 7.24882483e-01f, -7.73502111e-01f, -1.74094665e+00f, 1.17627621e+00f, -6.37104750e-01f, // seq 1 + -1.10517037e+00f, -5.24268031e-01f, -4.00700979e-02f, -4.67021376e-01f, 4.90917653e-01f, -1.93175867e-01f, // seq 2 + -2.36315632e+00f, -4.42896411e-02f, -3.20379347e-01f, 1.28702021e+00f, -9.64078028e-03f, 1.51788175e+00f, // seq 3 + 5.16564190e-01f, 3.20925027e-01f, 2.22803988e-02f, 6.74315631e-01f, 1.14399064e+00f, 8.86257112e-01f, // seq 4 + 1.13239074e+00f, -1.53492898e-01f, -8.80812466e-01f, 1.86820555e+00f, -2.78367937e-01f, 9.34901953e-01f, // seq 5 + 9.94535208e-01f, 8.27187002e-01f, 4.94141400e-01f, 1.29285610e+00f, -2.72836059e-01f, 9.29536343e-01f, // seq 6 + 1.21685827e+00f, -3.42607170e-01f, 5.57832778e-01f, -9.92367864e-01f, 5.65743625e-01f, -1.12992167e+00f, // seq 7 + + // Batch 0, Head 2: 8 sequences of 6 elements each + 4.63999987e-01f, -4.98600006e-01f, 1.28900006e-01f, 2.76309991e+00f, 1.40499994e-01f, 1.11909997e+00f, // seq 0 + 1.25286388e+00f, -9.61636603e-01f, 1.74961895e-01f, 1.70716047e+00f, -3.18457693e-01f, 3.35886538e-01f, // seq 1 + 9.07053053e-01f, -1.02590752e+00f, 2.49643087e-01f, -2.45633602e-01f, -1.02391529e+00f, 1.87480819e+00f, // seq 2 + -5.92516303e-01f, 1.33033943e+00f, 1.21285498e+00f, 1.31923720e-01f, -9.15585876e-01f, 1.03022671e+00f, // seq 3 + 1.17885363e+00f, 1.77404451e+00f, -7.62661636e-01f, -1.43456602e+00f, 4.99553680e-02f, -2.78479010e-01f, // seq 4 + 1.46011293e-01f, 2.10013723e+00f, 8.43684196e-01f, -1.53375596e-01f, 4.32110995e-01f, 8.36026430e-01f, // seq 5 + 1.06174159e+00f, -1.55485511e+00f, -4.60794270e-02f, 2.58956552e-02f, 1.69944048e-01f, -1.42038882e+00f, // seq 6 + -4.87071276e-02f, 3.15481633e-01f, 3.70017350e-01f, 3.77508819e-01f, -8.40793192e-01f, -6.33794010e-01f, // seq 7 + + // Batch 0, Head 3: 8 sequences of 6 elements each + -1.29439998e+00f, -2.42999997e-02f, -2.35400006e-01f, -7.08700001e-01f, 1.15660000e+00f, 4.29600000e-01f, // seq 0 + 1.54494107e-01f, -8.74685705e-01f, 3.31545562e-01f, 5.66194594e-01f, 2.07239747e+00f, 1.07093453e+00f, // seq 1 + -1.56445935e-01f, -2.81273365e-01f, 7.11332202e-01f, 8.38858962e-01f, -1.81656986e-01f, 1.58361614e-01f, // seq 2 + -7.92730570e-01f, -1.77007288e-01f, -5.89310288e-01f, -1.16298723e+00f, 4.53686491e-02f, 1.18241966e+00f, // seq 3 + 1.26825869e-01f, -5.55871427e-01f, -9.31147695e-01f, 1.45934772e+00f, -1.08596635e+00f, -1.07115202e-01f, // seq 4 + -1.90371126e-01f, 1.33196723e+00f, -1.47011745e+00f, -7.66584575e-02f, -7.60652125e-01f, -1.59310520e+00f, // seq 5 + -4.51292470e-03f, 7.04730570e-01f, 1.47792324e-01f, 1.59517035e-01f, -1.21709907e+00f, -1.65750349e+00f, // seq 6 + 2.00992894e+00f, -9.10886765e-01f, -2.25605503e-01f, 4.52725053e-01f, -1.82546115e+00f, 6.47476315e-01f, // seq 7 + + // Batch 1, Head 0: 8 sequences of 6 elements each + 8.96399975e-01f, 5.71699977e-01f, -2.38999993e-01f, 6.98300004e-01f, -1.34159994e+00f, 2.71499991e-01f, // seq 0 + 2.94375867e-02f, 6.80094242e-01f, 2.13446647e-01f, -3.57835233e-01f, -1.60072970e+00f, 1.47927678e+00f, // seq 1 + 3.98796827e-01f, 7.03182518e-02f, -4.64302182e-01f, 4.85351264e-01f, -1.03685224e+00f, 5.79914272e-01f, // seq 2 + -1.20705771e+00f, 1.76035818e-02f, 1.53230751e+00f, 1.23830867e+00f, -1.24155864e-01f, 1.62666455e-01f, // seq 3 + 1.44771028e+00f, -1.23949802e+00f, 1.62978542e+00f, -4.58060056e-01f, 6.60933018e-01f, 3.52941215e-01f, // seq 4 + 1.72972739e+00f, 2.26402700e-01f, 7.29353964e-01f, -8.34230781e-01f, 4.00307000e-01f, 5.16785383e-02f, // seq 5 + 1.61899674e+00f, -3.65789354e-01f, -2.06491113e+00f, -1.12859631e+00f, 3.20817351e-01f, -1.17251611e+00f, // seq 6 + -3.46854568e-01f, -2.10239267e+00f, -1.61523759e+00f, 5.17343640e-01f, 3.37068677e-01f, -9.73018557e-02f, // seq 7 + + // Batch 1, Head 1: 8 sequences of 6 elements each + 5.05400002e-01f, -6.68099999e-01f, -1.43820000e+00f, 1.75469995e+00f, -9.60500002e-01f, -4.55799997e-01f, // seq 0 + 4.42923486e-01f, 2.38277763e-01f, 1.15645313e+00f, -2.19831991e+00f, 1.29031682e+00f, -1.24886191e+00f, // seq 1 + -5.09867668e-01f, 5.09775519e-01f, -3.35251856e+00f, 4.32666153e-01f, 7.41352141e-01f, -7.77529955e-01f, // seq 2 + -2.32901859e+00f, -3.94879639e-01f, 1.32237530e+00f, 9.87909675e-01f, 2.95846343e-01f, 1.01243794e+00f, // seq 3 + 5.05126178e-01f, 8.15001488e-01f, -1.22638965e+00f, -4.81875092e-02f, 6.65176749e-01f, 9.06920910e-01f, // seq 4 + 5.35472274e-01f, 1.56147265e+00f, -4.06287462e-01f, -1.73234010e+00f, -3.30429256e-01f, -1.33501053e+00f, // seq 5 + 1.63317192e+00f, 4.90809381e-01f, 7.09373713e-01f, 1.25124454e-02f, -5.02349257e-01f, 9.09572691e-02f, // seq 6 + -9.78256166e-02f, -3.57495785e-01f, -1.35865283e+00f, 3.79757136e-02f, -2.01198220e+00f, -3.12655121e-01f, // seq 7 + + // Batch 1, Head 2: 8 sequences of 6 elements each + -4.79200006e-01f, 6.75599992e-01f, -3.41300011e-01f, -2.24199995e-01f, -2.11099997e-01f, 6.28199995e-01f, // seq 0 + -8.21949601e-01f, -1.36183679e+00f, 1.59127319e+00f, 7.25855172e-01f, -9.71916676e-01f, -8.02503109e-01f, // seq 1 + 3.45773101e-02f, -2.98228002e+00f, 2.41065669e+00f, 8.91961157e-01f, 3.70242298e-01f, 1.69489694e+00f, // seq 2 + -1.07042886e-01f, 7.14565158e-01f, 2.36467719e+00f, -1.38960505e+00f, -6.99269295e-01f, -6.58058047e-01f, // seq 3 + -5.17001033e-01f, -1.10366726e+00f, -7.20030189e-01f, 6.06771231e-01f, -1.45643681e-01f, -2.34006476e+00f, // seq 4 + -1.26092672e+00f, -1.63743532e+00f, -3.57576013e-01f, 9.41227913e-01f, -5.69475293e-01f, 3.95344406e-01f, // seq 5 + -1.46400166e+00f, -7.86376834e-01f, -8.65749776e-01f, 1.10432577e+00f, 8.15473020e-01f, -1.48116696e+00f, // seq 6 + -1.24220979e+00f, 9.02649522e-01f, 2.36645028e-01f, -7.44167924e-01f, -4.82844949e-01f, 3.16913843e-01f, // seq 7 + + // Batch 1, Head 3: 8 sequences of 6 elements each + -1.09410000e+00f, -3.68200004e-01f, -1.63000003e-02f, -6.44999966e-02f, -8.10100019e-01f, 1.41499996e-01f, // seq 0 + 1.26955235e+00f, 6.26395524e-01f, -5.90327978e-01f, -7.49657393e-01f, -8.28307152e-01f, 7.38704860e-01f, // seq 1 + 9.96149480e-01f, 5.77319711e-02f, 9.27449882e-01f, -9.76410210e-02f, -2.35498492e-02f, -2.16916054e-01f, // seq 2 + 5.32237142e-02f, 6.16131902e-01f, 4.30257797e-01f, 8.05755079e-01f, 4.85714525e-01f, 9.01634693e-01f, // seq 3 + -4.74238396e-02f, -1.31507218e-03f, 1.27953064e+00f, -1.04750752e+00f, 1.23232043e+00f, 1.36800432e+00f, // seq 4 + 8.44843626e-01f, 6.58450782e-01f, -1.20370615e+00f, -6.11225069e-02f, -7.34763801e-01f, 9.33814406e-01f, // seq 5 + 1.03939044e+00f, 5.16136698e-02f, -1.31201601e+00f, -5.90313554e-01f, 4.35673892e-02f, -1.29534170e-01f, // seq 6 + -5.51326931e-01f, -7.40897238e-01f, -1.80020177e+00f, 9.89115238e-01f, 3.61949444e-01f, -1.45226824e+00f}; // seq 7 + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + rotary_embedding_dim, + num_heads, + max_sequence_length, + interleaved, + true, // use_float16 + true); // input_is_4d +} + +// Interleaved = false, pos ids shape = (batch_size, sequence_length) +TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) { + int batch_size = 1; + int sequence_length = 2; + int num_heads = 3; + int head_size = 6; + int rotary_embedding_dim = 0; + int max_sequence_length = 4; + int64_t interleaved = 0; // false + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f, + -0.9320f, -0.8579f, -1.0574f, -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.8480f, 0.5266f, -1.2944f, -0.0243f, -0.2354f, -0.7087f, -0.9647f, -0.0991f, + -0.2994f, -0.0650f, -1.5720f, -1.3211f}; + + std::vector position_ids = {0, 1}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, 0.0043f, + 0.1411f, 0.1388f, 0.0065f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f, + -0.9320f, -0.8579f, -0.8618f, -0.0922f, -0.9073f, -0.7032f, -0.5762f, -0.2371f, + -0.4377f, 0.5370f, -1.2929f, -0.7267f, -0.2107f, -0.7115f, -0.4666f, -0.0261f, + -0.2965f, -0.8469f, -1.5749f, -1.3217f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + rotary_embedding_dim, + num_heads, + max_sequence_length, + interleaved); +} + +TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi) { + int batch_size = 1; + int sequence_length = 2; + int num_heads = 1; + int head_size = 6; + int rotary_embedding_dim = 4; + int max_sequence_length = 2; + int64_t interleaved = 0; // false + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f}; + + std::vector position_ids = {0, 1}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.0427f, + -0.2250f, -0.8673f, -1.5071f, -0.4586f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + rotary_embedding_dim, + num_heads, + max_sequence_length, + interleaved, + true /*use_fp16*/); +} + +TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_NoPosIds_SmallData_LlamaMSFT) { + int batch_size = 1; + int sequence_length = 2; + int num_heads = 3; + int head_size = 6; + int rotary_embedding_dim = 0; + int max_sequence_length = 4; + int64_t interleaved = 0; // false + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f, + -0.9320f, -0.8579f, -1.0574f, -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.8480f, 0.5266f, -1.2944f, -0.0243f, -0.2354f, -0.7087f, -0.9647f, -0.0991f, + -0.2994f, -0.0650f, -1.5720f, -1.3211f}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f}; + + std::vector position_ids = {}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f, + -0.9320f, -0.8579f, -0.8618f, -0.0922f, -0.9073f, -0.7032f, -0.5762f, -0.2371f, + -0.4377f, 0.5370f, -1.2929f, -0.7267f, -0.2107f, -0.7115f, -0.4666f, -0.0261f, + -0.2965f, -0.8469f, -1.5749f, -1.3217f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + rotary_embedding_dim, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = true, pos ids = nullptr +TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_NoPosIds_SmallData_LlamaMSFT) { + int batch_size = 1; + int sequence_length = 3; + int num_heads = 2; + int head_size = 4; + int rotary_embedding_dim = 0; + int max_sequence_length = 8; + int64_t interleaved = 1; // true + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f, + -1.2188f, 1.1676f, -1.0574f, -0.1188f, -0.7396f, -1.2425f, -0.1752f, 0.6990f, + -0.8110f, 0.6737f, -1.1233f, -0.0919f, -0.6861f, 0.7202f, 0.1963f, 0.6142f}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 0.5403f, 0.9999f, -0.4161f, 0.9998f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.8415f, 0.0100f, 0.9093f, 0.0200f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f, + -1.6411f, -0.3948f, -1.0561f, -0.1294f, 0.6460f, -1.2937f, -0.1822f, 0.6972f, + -0.2751f, -1.0178f, -1.1212f, -0.1143f, -0.3694f, -0.9235f, 0.1840f, 0.6180f}; + + std::vector position_ids = {}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + rotary_embedding_dim, + num_heads, + max_sequence_length, + interleaved); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index fbd9d10a56c77..0008b68d14f41 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -4,6 +4,9 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" +#include "test/util/include/current_test_name.h" +#include "test/util/include/test_utils.h" +#include "test/framework/test_utils.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" #include "test/common/trt_op_test_utils.h" @@ -1446,6 +1449,37 @@ TEST(MathOpTest, Pow_float_float16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } #endif + +#if defined(USE_WEBGPU) +// WebGPU EP currently handles a special case for supporting Pow op: +// A Pow followed by a Cast to int64 type. +TEST(MathOpTest, Pow_float_sqrt) { + const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/webgpu_pow_cast_test.onnx"); + AllocatorPtr allocator = std::make_shared(); + + std::vector dims_x = {1}; + std::vector values_x = {576.}; + OrtValue ml_value_x; + CreateMLValue(allocator, dims_x, values_x, &ml_value_x); + + std::vector dims_y = {1}; + std::vector values_y = {0.5}; + OrtValue ml_value_y; + CreateMLValue(allocator, dims_y, values_y, &ml_value_y); + + NameMLValMap feeds; + feeds.insert(std::make_pair("x", ml_value_x)); + feeds.insert(std::make_pair("y", ml_value_y)); + + EPVerificationParams verification_params{}; + + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), + DefaultWebGpuExecutionProvider(), + feeds, + verification_params); +} +#endif + #if defined(USE_DNNL) TEST(MathOpTest, Exp_bfloat16) { #ifdef USE_DNNL diff --git a/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc index a5c0cda8dcf3a..9e0516fd394ce 100644 --- a/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/instance_norm_op_test.cc @@ -4,6 +4,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" #include "test/common/tensor_op_test_utils.h" +#include "test/util/include/default_providers.h" using namespace std; namespace onnxruntime { @@ -290,5 +291,51 @@ TEST(InstanceNormalizationOpTest, InstanceNormNCHW) { }); } +#ifdef USE_WEBGPU +TEST(InstanceNormalizationOpTest, InstanceNormNCHW_webgpu) { + OpTester test("InstanceNormalization"); + test.AddAttribute("epsilon", 0.009999999776482582f); + + vector input = {1.0f, 2.0f, 3.0f, 2.0f, 2.0f, 2.0f}; + vector input_dims = {1, 2, 1, 3}; + test.AddInput("input", input_dims, input); + + vector scale = {1.0f, 1.0f}; + vector scale_dims = {2}; + test.AddInput("scale", scale_dims, scale); + + vector B = {0.0f, 2.0f}; + vector B_dims = {2}; + test.AddInput("B", B_dims, B); + + vector expected_output = {-1.21566f, 0.0f, 1.21566f, 2.0f, 2.0f, 2.0f}; + test.AddOutput("Y", input_dims, expected_output); + + test.ConfigEp(DefaultWebGpuExecutionProvider(false)).RunWithConfig(); +} + +TEST(InstanceNormalizationOpTest, InstanceNormNCHW_webgpu_2) { + OpTester test("InstanceNormalization"); + test.AddAttribute("epsilon", 0.009999999776482582f); + + vector input = {1.0f, 2.0f, 3.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}; + vector input_dims = {1, 2, 2, 2}; + test.AddInput("input", input_dims, input); + + vector scale = {1.0f, 1.0f}; + vector scale_dims = {2}; + test.AddInput("scale", scale_dims, scale); + + vector B = {0.0f, 2.0f}; + vector B_dims = {2}; + test.AddInput("B", B_dims, B); + + vector expected_output = {-1.40028f, 0.0f, 1.40028f, 0.0f, 2.0f, 2.0f, 2.0f, 2.0f}; + test.AddOutput("Y", input_dims, expected_output); + + test.ConfigEp(DefaultWebGpuExecutionProvider(false)).RunWithConfig(); +} +#endif + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc b/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc index b413d04fe81e8..0410e51ce207d 100644 --- a/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/allocator_cuda_test.cc @@ -25,7 +25,7 @@ TEST(AllocatorTest, CUDAAllocatorTest) { size_t size = 1024; EXPECT_STREQ(cuda_arena->Info().name, CUDA); - EXPECT_EQ(cuda_arena->Info().id, cuda_device_id); + EXPECT_EQ(cuda_arena->Info().device.Id(), cuda_device_id); EXPECT_EQ(cuda_arena->Info().mem_type, OrtMemTypeDefault); EXPECT_EQ(cuda_arena->Info().alloc_type, OrtArenaAllocator); @@ -39,7 +39,7 @@ TEST(AllocatorTest, CUDAAllocatorTest) { auto pinned_allocator = CreateAllocator(pinned_memory_info); EXPECT_STREQ(pinned_allocator->Info().name, CUDA_PINNED); - EXPECT_EQ(pinned_allocator->Info().id, 0); + EXPECT_EQ(pinned_allocator->Info().device.Id(), 0); EXPECT_EQ(pinned_allocator->Info().mem_type, OrtMemTypeCPUOutput); EXPECT_EQ(pinned_allocator->Info().alloc_type, OrtArenaAllocator); @@ -51,7 +51,7 @@ TEST(AllocatorTest, CUDAAllocatorTest) { [](int) { return std::make_unique(); }, true); const auto& cpu_arena = CreateAllocator(cpu_memory_info); EXPECT_STREQ(cpu_arena->Info().name, CPU); - EXPECT_EQ(cpu_arena->Info().id, 0); + EXPECT_EQ(cpu_arena->Info().device.Id(), 0); EXPECT_EQ(cpu_arena->Info().mem_type, OrtMemTypeDefault); EXPECT_EQ(cpu_arena->Info().alloc_type, OrtArenaAllocator); diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc index 72357ec7e02d2..7f6f6ba3bb4b0 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc @@ -60,7 +60,8 @@ TEST(TestDeferredRelease, WithoutArena) { onnxruntime::RunOptions run_opts; run_opts.run_tag = "log1"; - OrtDevice pinned_device{OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, DEFAULT_CPU_ALLOCATOR_DEVICE_ID}; + OrtDevice pinned_device{OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, + DEFAULT_CPU_ALLOCATOR_DEVICE_ID}; // Create allocator without BFCArena AllocatorCreationInfo pinned_memory_info( [](OrtDevice::DeviceId) { diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc index 1af7bdea68b67..a96c8d05ee64f 100644 --- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc +++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc @@ -52,7 +52,7 @@ TEST(NnapiExecutionProviderTest, ReshapeFlattenTest) { std::vector dims_mul_y = {3, 2, 2}; std::vector values_mul_y = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; OrtValue ml_value_x; - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); CreateMLValue(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_x); OrtValue ml_value_y; @@ -80,7 +80,7 @@ TEST(NnapiExecutionProviderTest, SigmoidSupportedInputRankTest) { std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}; OrtValue ml_value_x; - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); CreateMLValue(std::move(cpu_allocator), dims_mul_x, values_mul_x, &ml_value_x); NameMLValMap feeds; @@ -107,7 +107,7 @@ TEST(NnapiExecutionProviderTest, DynamicGraphInputTest) { std::vector dims_mul_x = {1, 1, 4, 4}; std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}; OrtValue ml_value_x; - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); CreateMLValue(std::move(cpu_allocator), dims_mul_x, values_mul_x, &ml_value_x); @@ -138,7 +138,7 @@ TEST(NnapiExecutionProviderTest, InternalUint8SupportTest) { std::vector dims_x = {1, 1, 1, 3}; std::vector values_x = {0.0f, 256.0f, 512.0f}; OrtValue ml_value_x; - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); CreateMLValue(std::move(cpu_allocator), dims_x, values_x, &ml_value_x); NameMLValMap feeds; @@ -195,7 +195,7 @@ TEST(NnapiExecutionProviderTest, FunctionTest) { std::vector dims_mul_x = {1, 1, 3, 2}; std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; OrtValue ml_value_x; - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); CreateMLValue(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_x); OrtValue ml_value_y; @@ -522,7 +522,7 @@ TEST(NnapiExecutionProviderTest, SharedInitializersDoNotGetSkipped) { constexpr auto* model_file_name = ORT_TSTR("testdata/clip_div_shared_initializer.onnx"); #if defined(__ANDROID__) - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); std::vector x_dims{3, 2}; std::vector x_values(3.0f, 3 * 2); diff --git a/onnxruntime/test/providers/qnn/cast_test.cc b/onnxruntime/test/providers/qnn/cast_test.cc index fa26c764c1b7a..2a63d98ebb37e 100644 --- a/onnxruntime/test/providers/qnn/cast_test.cc +++ b/onnxruntime/test/providers/qnn/cast_test.cc @@ -127,7 +127,9 @@ TEST_F(QnnHTPBackendTests, TestCastInt32ToFloatHTP) { } // Cast uint8_t to float on HTP -TEST_F(QnnHTPBackendTests, TestCastUInt8ToFloatHTP) { +// Fails with QNN SDK 2.35.0: +// value pair (13, 1.00000012) at index #0 don't match, which is -12 from 13 +TEST_F(QnnHTPBackendTests, DISABLED_TestCastUInt8ToFloatHTP) { RunCastOpTest({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, ExpectedEPNodeAssignment::All, true, false); } diff --git a/onnxruntime/test/providers/qnn/clip_op_test.cc b/onnxruntime/test/providers/qnn/clip_op_test.cc index 512403bc5a10b..83296d342e62b 100644 --- a/onnxruntime/test/providers/qnn/clip_op_test.cc +++ b/onnxruntime/test/providers/qnn/clip_op_test.cc @@ -76,7 +76,9 @@ TEST_F(QnnCPUBackendTests, Clip_5D_f32) { // // Test Clip with float32 on HTP -TEST_F(QnnHTPBackendTests, Clip_f32) { +// Fails with QNN SDK 2.35.0: +// value pair (-4.54545403, -4.54687548) at index #3 don't match, which is -0.00142145 from -4.54545 +TEST_F(QnnHTPBackendTests, DISABLED_Clip_f32) { bool on_cpu_backend = false; RunClipTest(TestInputDef({1, 1, 3, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 12)), {TestInputDef({}, true, {-5.0f}), diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index 8232742f35a31..15ac5d3cd6369 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -654,7 +654,9 @@ TEST_F(QnnCPUBackendTests, ConvTranspose1Df32_DynamicWeights_DefaultBias) { // It has to be QDQ model, because the DQ node with initializer on Conv gets processed first // and DQ node requires its node unit to be processed // So, Conv gets processed before Mul node -TEST_F(QnnHTPBackendTests, Test_QDQConvWithDynamicWeightsFromMul) { +// +// Since at least QAIRT 2.33 value pair (3.549, 3.588) at index #12709 don't match, which is 0.039 from 3.549 +TEST_F(QnnHTPBackendTests, DISABLED_Test_QDQConvWithDynamicWeightsFromMul) { ProviderOptions provider_options; provider_options["backend_type"] = "htp"; provider_options["offload_graph_io_quantization"] = "0"; @@ -2114,7 +2116,9 @@ TEST_F(QnnHTPBackendTests, ConvTranspose1DU8U8S32_AutoPadLower) { 13); } -TEST_F(QnnHTPBackendTests, ConvU8U8S32_large_input1_padding_bias_initializer) { +// Fails with QNN SDK 2.35.0: +// value pair (-4.54545403, -4.54687548) at index #3 don't match, which is -0.00142145 from -4.54545 +TEST_F(QnnHTPBackendTests, DISABLED_ConvU8U8S32_large_input1_padding_bias_initializer) { RunHTPConvOpTest("Conv", TestInputDef({1, 3, 60, 452}, false, 0.f, 10.f), // Dynamic input TestInputDef({16, 3, 3, 3}, true, -1.f, 1.f), // Static weights diff --git a/onnxruntime/test/providers/qnn/cumsum_op_htp_test.cc b/onnxruntime/test/providers/qnn/cumsum_op_htp_test.cc index f3affc18d8a9a..cfe6523639e96 100644 --- a/onnxruntime/test/providers/qnn/cumsum_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/cumsum_op_htp_test.cc @@ -37,7 +37,9 @@ static void RunCumSumOpTest(const std::string& op_type, } // Non-QDQ model, CumSum with float input and axis input as initializer with axis 0 -TEST_F(QnnHTPBackendTests, CumSum_float_int32_e0_r0_axis_0) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_CumSum_float_int32_e0_r0_axis_0) { RunCumSumOpTest("CumSum", TestInputDef({3, 2}, false, {1.3f, 7.2f, 0.4f, 3.4f, 5.7f, 0.8f}), TestInputDef({}, true, {0}), @@ -48,7 +50,9 @@ TEST_F(QnnHTPBackendTests, CumSum_float_int32_e0_r0_axis_0) { } // Non-QDQ model, CumSum with float input and axis input as initializer with axis -1 -TEST_F(QnnHTPBackendTests, CumSum_float_int32_e0_r0_axis_neg1) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_CumSum_float_int32_e0_r0_axis_neg1) { RunCumSumOpTest("CumSum", TestInputDef({3, 2}, false, {1.3f, 7.2f, 0.4f, 3.4f, 5.7f, 0.8f}), TestInputDef({}, true, {-1}), diff --git a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc index 326354dffa8ae..22459bb4f6941 100644 --- a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc @@ -178,7 +178,9 @@ static void RunOpTest(const std::string& op_type, } // Non-QDQ model, Gather with static input and dynamic int64 indices -TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt64) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_GatherOp_IndicesStaticInt64) { RunOpTest("Gather", TestInputDef({3, 2}, true, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, false, {0, 1, 1, 2}), diff --git a/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc b/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc index 4823b152b0269..648fb00da611d 100644 --- a/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc +++ b/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc @@ -17,24 +17,24 @@ namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) // Function that builds a QDQ model with an InstanceNormalization operator. -template -static GetTestQDQModelFn BuildQDQInstanceNormTestCase(const TestInputDef& input_def, - const TestInputDef& scale_def, - const TestInputDef& bias_def, - const std::vector& attrs, - bool use_contrib_qdq = false) { +template +static GetTestQDQModelFn BuildQDQInstanceNormTestCase(const TestInputDef& input_def, + const TestInputDef& scale_def, + const TestInputDef& bias_def, + const std::vector& attrs, + bool use_contrib_qdq = false) { return [input_def, scale_def, bias_def, attrs, use_contrib_qdq](ModelTestBuilder& builder, - std::vector>& output_qparams) { + std::vector>& output_qparams) { // input => Q => DQ => NodeArg* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, use_contrib_qdq); // scale => Q => DQ => NodeArg* scale = MakeTestInput(builder, scale_def); - QuantParams scale_qparams = GetTestInputQuantParams(scale_def); + QuantParams scale_qparams = GetTestInputQuantParams(scale_def); NodeArg* scale_qdq = AddQDQNodePair(builder, scale, scale_qparams.scale, scale_qparams.zero_point, use_contrib_qdq); @@ -51,8 +51,8 @@ static GetTestQDQModelFn BuildQDQInstanceNormTestCase(const TestInput } // Add instance_norm_output -> Q -> output_u8 - AddQDQNodePairWithOutputAsGraphOutput(builder, instance_norm_output, output_qparams[0].scale, - output_qparams[0].zero_point, use_contrib_qdq); + AddQDQNodePairWithOutputAsGraphOutput(builder, instance_norm_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); }; } @@ -66,7 +66,7 @@ static GetTestQDQModelFn BuildQDQInstanceNormTestCase(const TestInput * \param attrs The node's attributes. The only valid attribute for InstanceNormalization is 'epsilon'. * \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None). */ -template +template static void RunInstanceNormQDQTest(const TestInputDef& input_def, const TestInputDef& scale_def, const TestInputDef& bias_def, @@ -79,7 +79,7 @@ static void RunInstanceNormQDQTest(const TestInputDef& input_def, // Runs model with DQ-> InstanceNorm -> Q and compares the outputs of the CPU and QNN EPs. TestQDQModelAccuracy(BuildOpTestCase("InstanceNormalization", {input_def, scale_def, bias_def}, {}, attrs), - BuildQDQInstanceNormTestCase(input_def, scale_def, bias_def, attrs, use_contrib_qdq), + BuildQDQInstanceNormTestCase(input_def, scale_def, bias_def, attrs, use_contrib_qdq), provider_options, 18, expected_ep_assignment); @@ -87,7 +87,7 @@ static void RunInstanceNormQDQTest(const TestInputDef& input_def, // Check that QNN compiles DQ -> InstanceNormalization -> Q as a single unit. // Use an input of rank 4. -TEST_F(QnnHTPBackendTests, InstanceNormU8) { +TEST_F(QnnHTPBackendTests, InstanceNormU8U8) { // fails with QNN 2.15.1 with the following fixed input. std::vector input_data = {3.21289f, -5.9981f, -1.72799f, 6.27263f, 3.36205f, -1.93515f, -5.40113f, 3.75648f, 6.15357f, -5.25769f, 2.73637f, -0.901382f, -6.55612f, 1.99497f, -4.79228f, 2.69813f, 8.3064f, 0.0362501f}; @@ -100,22 +100,35 @@ TEST_F(QnnHTPBackendTests, InstanceNormU8) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, InstanceNormU16) { +TEST_F(QnnHTPBackendTests, InstanceNormU8S8) { + // Check SFIXED scale std::vector input_data = {3.21289f, -5.9981f, -1.72799f, 6.27263f, 3.36205f, -1.93515f, -5.40113f, 3.75648f, 6.15357f, -5.25769f, 2.73637f, -0.901382f, -6.55612f, 1.99497f, -4.79228f, 2.69813f, 8.3064f, 0.0362501f}; std::vector scale_data = {-0.148738f, -1.45158f}; std::vector bias_data = {-2.2785083772f, 2.3338717017f}; - RunInstanceNormQDQTest(TestInputDef({1, 2, 3, 3}, false, input_data).OverrideValueRange(-10.0f, 10.0f), - TestInputDef({2}, true, scale_data).OverrideValueRange(-2.0f, 2.0f), - TestInputDef({2}, true, bias_data).OverrideValueRange(-3.0f, 3.0f), - {}, - ExpectedEPNodeAssignment::All, - true); // Use contrib Q/DQ ops for 16bit support. + RunInstanceNormQDQTest(TestInputDef({1, 2, 3, 3}, false, input_data).OverrideValueRange(-10.0f, 10.0f), + TestInputDef({2}, true, scale_data).OverrideValueRange(-2.0f, 2.0f), + TestInputDef({2}, true, bias_data).OverrideValueRange(-3.0f, 3.0f), + {}, + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, InstanceNormU16U8) { + std::vector input_data = {3.21289f, -5.9981f, -1.72799f, 6.27263f, 3.36205f, -1.93515f, -5.40113f, 3.75648f, 6.15357f, + -5.25769f, 2.73637f, -0.901382f, -6.55612f, 1.99497f, -4.79228f, 2.69813f, 8.3064f, 0.0362501f}; + std::vector scale_data = {-0.148738f, -1.45158f}; + std::vector bias_data = {-2.2785083772f, 2.3338717017f}; + RunInstanceNormQDQTest(TestInputDef({1, 2, 3, 3}, false, input_data).OverrideValueRange(-10.0f, 10.0f), + TestInputDef({2}, true, scale_data).OverrideValueRange(-2.0f, 2.0f), + TestInputDef({2}, true, bias_data).OverrideValueRange(-3.0f, 3.0f), + {}, + ExpectedEPNodeAssignment::All, + true); // Use contrib Q/DQ ops for 16bit support. } // Check that QNN compiles DQ -> InstanceNormalization -> Q as a single unit. // Use an input of rank 3. -TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3) { +TEST_F(QnnHTPBackendTests, InstanceNormU8U8Rank3) { RunInstanceNormQDQTest(TestInputDef({1, 2, 3}, false, {6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f}), TestInputDef({2}, true, {1.0f, 2.0f}), TestInputDef({2}, true, {1.0f, 3.0f}), @@ -125,7 +138,7 @@ TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3) { // Test 8-bit QDQ InstanceNormalization with an input of rank 3 with N != 1, // which requires wrapping the QNN InstanceNorm op with reshapes. -TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3_BatchSizeNot1) { +TEST_F(QnnHTPBackendTests, InstanceNormU8U8Rank3_BatchSizeNot1) { std::vector input_data = {6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f, -8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f}; RunInstanceNormQDQTest(TestInputDef({2, 2, 3}, false, input_data), @@ -137,21 +150,21 @@ TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3_BatchSizeNot1) { // Test 16-bit QDQ InstanceNormalization with an input of rank 3 with N != 1, // which requires wrapping the QNN InstanceNorm op with reshapes. -TEST_F(QnnHTPBackendTests, InstanceNormU16Rank3_BatchSizeNot1) { +TEST_F(QnnHTPBackendTests, InstanceNormU16U8Rank3_BatchSizeNot1) { std::vector input_data = {6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f, -8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f}; - RunInstanceNormQDQTest(TestInputDef({2, 2, 3}, false, input_data), - TestInputDef({2}, true, {1.0f, 2.0f}), - TestInputDef({2}, true, {1.0f, 3.0f}), - {}, - ExpectedEPNodeAssignment::All, - true); // Use contrib Q/DQ ops for 16bit support. + RunInstanceNormQDQTest(TestInputDef({2, 2, 3}, false, input_data), + TestInputDef({2}, true, {1.0f, 2.0f}), + TestInputDef({2}, true, {1.0f, 3.0f}), + {}, + ExpectedEPNodeAssignment::All, + true); // Use contrib Q/DQ ops for 16bit support. } // Test 8-bit QDQ InstanceNormalization with an input of rank 3 with N != 1, // which requires wrapping the QNN InstanceNorm op with reshapes. // Input 0 is an initializer. -TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3_BatchSizeNot1_Initializer) { +TEST_F(QnnHTPBackendTests, InstanceNormU8U8Rank3_BatchSizeNot1_Initializer) { std::vector input_data = {6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f, -8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f}; RunInstanceNormQDQTest(TestInputDef({2, 2, 3}, true, input_data), @@ -164,19 +177,19 @@ TEST_F(QnnHTPBackendTests, InstanceNormU8Rank3_BatchSizeNot1_Initializer) { // Test 16-bit QDQ InstanceNormalization with an input of rank 3 with N != 1, // which requires wrapping the QNN InstanceNorm op with reshapes. // Input 0 is an initializer. -TEST_F(QnnHTPBackendTests, InstanceNormU16Rank3_BatchSizeNot1_Initializer) { +TEST_F(QnnHTPBackendTests, InstanceNormU16U8Rank3_BatchSizeNot1_Initializer) { std::vector input_data = {6.0f, 4.0f, 2.0f, 6.0f, 8.0f, 2.0f, -8.0f, -6.0f, 0.0f, 1.0f, 3.0f, 6.0f}; - RunInstanceNormQDQTest(TestInputDef({2, 2, 3}, true, input_data), - TestInputDef({2}, true, {1.0f, 2.0f}), - TestInputDef({2}, false, {1.0f, 3.0f}), - {}, - ExpectedEPNodeAssignment::All, - true); // Use contrib Q/DQ ops for 16-bit support. + RunInstanceNormQDQTest(TestInputDef({2, 2, 3}, true, input_data), + TestInputDef({2}, true, {1.0f, 2.0f}), + TestInputDef({2}, false, {1.0f, 3.0f}), + {}, + ExpectedEPNodeAssignment::All, + true); // Use contrib Q/DQ ops for 16-bit support. } // Check that QNN InstanceNorm operator does not handle inputs with rank > 4. -TEST_F(QnnHTPBackendTests, InstanceNormU8Rank5) { +TEST_F(QnnHTPBackendTests, InstanceNormU8U8Rank5) { RunInstanceNormQDQTest(TestInputDef({1, 2, 3, 3, 3}, false, -10.0f, 10.0f), TestInputDef({2}, true, -2.0f, 2.0f), TestInputDef({2}, true, -3.0f, 3.0f), diff --git a/onnxruntime/test/providers/qnn/lstm_test.cc b/onnxruntime/test/providers/qnn/lstm_test.cc index 4b011b9bf1108..5d20806d3ea4d 100644 --- a/onnxruntime/test/providers/qnn/lstm_test.cc +++ b/onnxruntime/test/providers/qnn/lstm_test.cc @@ -316,7 +316,9 @@ static void RunCpuFP32LSTMOpTest(const TestInputDef& X_def, // TODO: Add P to unit test below once finalize issue is resolved // HTP QDQ -TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_forward) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_QDQ_sanity_forward) { std::string direction = "forward"; uint32_t num_direction = 1; uint32_t batch_size = 3; @@ -342,7 +344,9 @@ TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_forward) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_reverse) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_QDQ_sanity_reverse) { std::string direction = "reverse"; uint32_t num_direction = 1; uint32_t batch_size = 3; @@ -368,7 +372,9 @@ TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_reverse) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_QDQ_sanity_bidirectional) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; @@ -394,7 +400,9 @@ TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_wo_B) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_QDQ_sanity_bidirectional_wo_B) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; @@ -419,7 +427,9 @@ TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_wo_B) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_wo_H) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_QDQ_sanity_bidirectional_wo_H) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; @@ -444,7 +454,9 @@ TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_wo_H) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_wo_C) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_QDQ_sanity_bidirectional_wo_C) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; @@ -469,7 +481,9 @@ TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_wo_C) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_all_initializer) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_QDQ_sanity_bidirectional_all_initializer) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; @@ -497,7 +511,9 @@ TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_all_initializer) { QDQTolerance(0.008f)); } -TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_Y_only) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_QDQ_sanity_bidirectional_Y_only) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; @@ -523,7 +539,9 @@ TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_Y_only) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_Y_h_only) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_QDQ_sanity_bidirectional_Y_h_only) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; @@ -549,7 +567,9 @@ TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_Y_h_only) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_Y_c_only) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_QDQ_sanity_bidirectional_Y_c_only) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; @@ -576,7 +596,9 @@ TEST_F(QnnHTPBackendTests, LSTM_QDQ_sanity_bidirectional_Y_c_only) { } // HTP Fp16 -TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_forward) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_Fp16_sanity_forward) { std::string direction = "forward"; uint32_t num_direction = 1; uint32_t batch_size = 3; @@ -602,7 +624,9 @@ TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_forward) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_reverse) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_Fp16_sanity_reverse) { std::string direction = "reverse"; uint32_t num_direction = 1; uint32_t batch_size = 3; @@ -628,7 +652,9 @@ TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_reverse) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_Fp16_sanity_bidirectional) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; @@ -655,7 +681,9 @@ TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_wo_B) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_Fp16_sanity_bidirectional_wo_B) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; @@ -681,7 +709,9 @@ TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_wo_B) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_wo_H) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_Fp16_sanity_bidirectional_wo_H) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; @@ -707,7 +737,9 @@ TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_wo_H) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_wo_C) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_Fp16_sanity_bidirectional_wo_C) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; @@ -733,7 +765,9 @@ TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_wo_C) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_all_initializer) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_Fp16_sanity_bidirectional_all_initializer) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; @@ -760,7 +794,9 @@ TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_all_initializer) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_Y_only) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_Fp16_sanity_bidirectional_Y_only) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; @@ -787,7 +823,9 @@ TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_Y_only) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_Y_h_only) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_Fp16_sanity_bidirectional_Y_h_only) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; @@ -814,7 +852,9 @@ TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_Y_h_only) { ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, LSTM_Fp16_sanity_bidirectional_Y_c_only) { +// Fails with QNN SDK 2.35.0: +// Failed to finalize QNN graph. Error code: 1002 +TEST_F(QnnHTPBackendTests, DISABLED_LSTM_Fp16_sanity_bidirectional_Y_c_only) { std::string direction = "bidirectional"; uint32_t num_direction = 2; uint32_t batch_size = 3; diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 6ef831c8ecd6f..87385b7964d98 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -1758,6 +1758,107 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions) { std::remove(qnn_ctx_binary_file_name1.c_str()); } +TEST_F(QnnHTPBackendTests, VTCMBackupBufferSharing) { + ProviderOptions provider_options; + provider_options["offload_graph_io_quantization"] = "0"; + provider_options["backend_type"] = "htp"; + + // Create QDQ models + std::vector onnx_model_paths{"./weight_share1.onnx", "./weight_share2.onnx"}; + // cleanup in case some failure test doesn't remove them + for (auto model_path : onnx_model_paths) { + std::remove(model_path.c_str()); + } + + std::vector ctx_model_paths; + for (auto model_path : onnx_model_paths) { + CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); + auto pos = model_path.find_last_of("."); + if (pos != std::string::npos) { + model_path = model_path.substr(0, pos) + "_ctx.onnx"; + } else { + model_path = model_path + "_ctx.onnx"; + } + ctx_model_paths.push_back(model_path); + } + for (auto ctx_model_path : ctx_model_paths) { + std::remove(ctx_model_path.c_str()); + } + + DumpModelWithSharedCtx(provider_options, onnx_model_paths[0], onnx_model_paths[1]); + + std::string qnn_ctx_binary_file_name1; + GetContextBinaryFileName(ctx_model_paths[0], qnn_ctx_binary_file_name1, + DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(!qnn_ctx_binary_file_name1.empty()); + + std::string qnn_ctx_binary_file_name2; + GetContextBinaryFileName(ctx_model_paths[1], qnn_ctx_binary_file_name2, + DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(!qnn_ctx_binary_file_name2.empty()); + // 2 *_ctx.onn point to same .bin file + EXPECT_TRUE(qnn_ctx_binary_file_name1 == qnn_ctx_binary_file_name2); + auto file_size_1 = std::filesystem::file_size(qnn_ctx_binary_file_name1); + EXPECT_TRUE(file_size_1 > 0); + + provider_options["enable_vtcm_backup_buffer_sharing"] = "1"; + // only load and run the session on real device +#if defined(__aarch64__) || defined(_M_ARM64) + Ort::SessionOptions so1; + so1.SetLogId("so1"); + so1.AppendExecutionProvider("QNN", provider_options); + Ort::SessionOptions so2; + so2.SetLogId("so2"); + so2.AppendExecutionProvider("QNN", provider_options); + + EXPECT_TRUE(2 == ctx_model_paths.size()); +#ifdef _WIN32 + std::wstring ctx_model_file1(ctx_model_paths[0].begin(), ctx_model_paths[0].end()); + std::wstring ctx_model_file2(ctx_model_paths[1].begin(), ctx_model_paths[1].end()); +#else + std::string ctx_model_file1(ctx_model_paths[0].begin(), ctx_model_paths[0].end()); + std::string ctx_model_file2(ctx_model_paths[1].begin(), ctx_model_paths[1].end()); +#endif + Ort::Session session1(*ort_env, ctx_model_file1.c_str(), so1); + Ort::Session session2(*ort_env, ctx_model_file2.c_str(), so2); + + std::vector input_names; + std::vector output_names; + GetModelInputNames(ctx_model_paths[1], input_names, output_names, + DefaultLoggingManager().DefaultLogger()); + + // Run sessions + // prepare input + std::vector input_dim{2, 3}; + std::vector input_value(2 * 3, 0.0f); + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + std::vector ort_inputs; + std::vector input_names_c; + for (size_t i = 0; i < input_names.size(); ++i) { + auto input_tensor = Ort::Value::CreateTensor(info, input_value.data(), input_value.size(), + input_dim.data(), input_dim.size()); + ort_inputs.push_back(std::move(input_tensor)); + input_names_c.push_back(input_names[i].c_str()); + } + std::vector output_names_c; + for (size_t i = 0; i < output_names.size(); ++i) { + output_names_c.push_back(output_names[i].c_str()); + } + + auto ort_outputs1 = session1.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); +#endif + + for (auto model_path : onnx_model_paths) { + std::remove(model_path.c_str()); + } + for (auto ctx_model_path : ctx_model_paths) { + std::remove(ctx_model_path.c_str()); + } + std::remove(qnn_ctx_binary_file_name1.c_str()); +} + // For Ort sessions to generate the context binary, with session option ep.share_ep_contexts enabled // Ort sessions will share the QnnBackendManager, so that all graphs from all models compile into the same Qnn context TEST_F(QnnHTPBackendTests, QnnContextGenWeightSharingSessionAPI) { diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc index aa8dc492a95c9..eda04b954f590 100644 --- a/onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc @@ -64,7 +64,7 @@ ProviderOptions GetProviderOptions() { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarInitializer) { +TEST_F(QnnHTPBackendTests, DISABLED_ScaleSoftmaxFusionScalarInitializer) { ProviderOptions provider_options = GetProviderOptions(); auto input_def = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); @@ -75,7 +75,7 @@ TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarInitializer) { /*fp32_abs_err=*/1e-2f); } -TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarConstant) { +TEST_F(QnnHTPBackendTests, DISABLED_ScaleSoftmaxFusionScalarConstant) { ProviderOptions provider_options = GetProviderOptions(); auto input_def = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); @@ -86,7 +86,7 @@ TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarConstant) { /*fp32_abs_err=*/1e-2f); } -TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarInitializerReversed) { +TEST_F(QnnHTPBackendTests, DISABLED_ScaleSoftmaxFusionScalarInitializerReversed) { ProviderOptions provider_options = GetProviderOptions(); auto input_def = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); RunQnnModelTest(BuildTestCaseScalar(input_def, 0.375f, /*use_constant=*/false, /*reverse_input_order=*/true), @@ -96,7 +96,7 @@ TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarInitializerReversed) { /*fp32_abs_err=*/1e-2f); } -TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarConstantReversed) { +TEST_F(QnnHTPBackendTests, DISABLED_ScaleSoftmaxFusionScalarConstantReversed) { ProviderOptions provider_options = GetProviderOptions(); auto input_def = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); RunQnnModelTest(BuildTestCaseScalar(input_def, 0.125f, /*use_constant=*/true, /*reverse_input _order=*/true), @@ -106,7 +106,7 @@ TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarConstantReversed) { /*fp32_abs_err=*/1e-2f); } -TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionSoftmaxNegativeAxis) { +TEST_F(QnnHTPBackendTests, DISABLED_ScaleSoftmaxFusionSoftmaxNegativeAxis) { ProviderOptions provider_options = GetProviderOptions(); auto input_def = TestInputDef({1, 3, 5, 5}, false, -0.5f, 0.5f); RunQnnModelTest(BuildTestCaseScalar(input_def, 0.125f, diff --git a/onnxruntime/test/providers/qnn/reshape_expand_op_test.cc b/onnxruntime/test/providers/qnn/reshape_expand_op_test.cc index d32be64fa5229..45626d63d1970 100644 --- a/onnxruntime/test/providers/qnn/reshape_expand_op_test.cc +++ b/onnxruntime/test/providers/qnn/reshape_expand_op_test.cc @@ -289,6 +289,16 @@ TEST_F(QnnHTPBackendTests, Expand_HTP_int64) { 19); // Opset } +// Test that bool Expand runs on HTP backend. +TEST_F(QnnHTPBackendTests, Expand_HTP_bool) { + RunReshapeExpandTestOnHTP("Expand", + TestInputDef({1}, false, {true}), + TestInputDef({3}, true, {1, 2, 3}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + // Test QDQ Expand TEST_F(QnnHTPBackendTests, Expand_4D) { RunQDQReshapeExpandTestOnHTP("Expand", diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index b441af4a0efe9..761bf63976bec 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -991,6 +991,23 @@ TEST_F(QnnHTPBackendTests, BinaryOp_And4D) { ExpectedEPNodeAssignment::All); } +// Test Reciprocal on HTP +TEST_F(QnnHTPBackendTests, Reciprocal_Basic_FLOAT) { + RunOpTest("Reciprocal", + {TestInputDef({2, 2}, false, {1.0f, 2.0f, 0.5f, 4.0f})}, + {}, // No attributes + 13, + ExpectedEPNodeAssignment::All); +} + +TEST_F(QnnHTPBackendTests, Reciprocal_QU8) { + RunQDQOpTest("Reciprocal", + {TestInputDef({2, 2}, false, GetFloatDataInRange(1.0f, 5.0f, 4))}, + {}, // No attributes + 13, + ExpectedEPNodeAssignment::All); +} + // Test ScatterND op on HTP TEST_F(QnnHTPBackendTests, ScatterND_int64_int64) { std::vector data = {0, 1, 2, 3}; diff --git a/onnxruntime/test/providers/qnn/transpose_htp_test.cc b/onnxruntime/test/providers/qnn/transpose_htp_test.cc index f206e517408bf..83ff6440c8399 100644 --- a/onnxruntime/test/providers/qnn/transpose_htp_test.cc +++ b/onnxruntime/test/providers/qnn/transpose_htp_test.cc @@ -120,7 +120,9 @@ TEST_F(QnnHTPBackendTests, TransposeInt32OnHTP) { } // Check that QNN supports Transpose with float32 data input on HTP -TEST_F(QnnHTPBackendTests, TransposeFloatOnHTP) { +// Fails with QNN SDK 2.35.0: +// value pair (0.183528364, 0.183471695) at index #0 don't match, which is -5.66691e-05 from 0.183528 +TEST_F(QnnHTPBackendTests, DISABLED_TransposeFloatOnHTP) { RunTransposeNonQDQOnHTP(TestInputDef({1, 3, 224, 128}, false, 0, 10.0f), {utils::MakeAttribute("perm", std::vector{0, 2, 3, 1})}, ExpectedEPNodeAssignment::All, false); diff --git a/onnxruntime/test/providers/rknpu/rknpu_basic_test.cc b/onnxruntime/test/providers/rknpu/rknpu_basic_test.cc index 63a672615c27b..fffd081a692c0 100644 --- a/onnxruntime/test/providers/rknpu/rknpu_basic_test.cc +++ b/onnxruntime/test/providers/rknpu/rknpu_basic_test.cc @@ -60,7 +60,7 @@ TEST(RknpuExecutionProviderTest, FunctionTest) { std::vector dims_mul_x = {1, 1, 3, 2}; std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - AllocatorPtr cpu_allocator = std::make_shared(); + AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); OrtValue ml_value_x; CreateMLValue(cpu_allocator, dims_mul_x, values_mul_x, &ml_value_x); OrtValue ml_value_y; diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index ddbcfd4931835..553059932db90 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -662,13 +662,10 @@ TEST(TensorrtExecutionProviderTest, ExcludeOpsTest) { params.trt_engine_cache_enable = 1; params.trt_op_types_to_exclude = "MaxPool"; std::unique_ptr execution_provider = TensorrtExecutionProviderWithOptions(¶ms); - EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); - auto status = session_object.Load(model_name); - ASSERT_TRUE(status.IsOK()); - status = session_object.Initialize(); - ASSERT_TRUE(status.IsOK()); - status = session_object.Run(run_options, feeds, output_names, &fetches); - ASSERT_TRUE(status.IsOK()); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(execution_provider))); + ASSERT_STATUS_OK(session_object.Load(model_name)); + ASSERT_STATUS_OK(session_object.Initialize()); + ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches)); std::vector engine_files; engine_files = GetCachesByType("./", ".engine"); diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py index 417a6e27fb7b2..f1c924a1ade94 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py +++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py @@ -13,7 +13,7 @@ from helper import get_name import onnxruntime as onnxrt -from onnxruntime.capi.onnxruntime_pybind11_state import Fail, InvalidArgument +from onnxruntime.capi.onnxruntime_pybind11_state import Fail # handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed. if platform.system() == "Windows" and sys.version_info.major >= 3 and sys.version_info.minor >= 8: # noqa: YTT204 @@ -225,14 +225,22 @@ def test_example_plugin_ep_devices(self): hw_metadata = hw_device.metadata self.assertGreater(len(hw_metadata), 0) # Should have at least SPDRP_HARDWAREID on Windows - # Test adding this EP plugin's OrtEpDevice to the SessionOptions. + # Add EP plugin's OrtEpDevice to the SessionOptions. sess_options = onnxrt.SessionOptions() - with self.assertRaises(InvalidArgument) as context: - # Will raise InvalidArgument because ORT currently only supports provider bridge APIs. - # Actual plugin EPs will be supported in the future. - sess_options.add_provider_for_devices([test_ep_device], {"opt1": "val1"}) - self.assertIn("EP is not currently supported", str(context.exception)) + sess_options.add_provider_for_devices([test_ep_device], {"opt1": "val1"}) + sess_options.log_severity_level = 1 # INFO + self.assertTrue(sess_options.has_providers()) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + del sess # Delete session before unregistering library self.unregister_execution_provider_library(ep_name) diff --git a/onnxruntime/test/python/onnxruntime_test_python_global_threadpool.py b/onnxruntime/test/python/onnxruntime_test_python_global_threadpool.py new file mode 100644 index 0000000000000..dc2796c47db35 --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_python_global_threadpool.py @@ -0,0 +1,36 @@ +# pylint: disable=C0115,W0212,C0103,C0114 +import unittest + +import numpy as np +from helper import get_name + +import onnxruntime as onnxrt + + +class TestGlobalThreadPool(unittest.TestCase): + @classmethod + def setUpClass(cls): + onnxrt.set_global_thread_pool_sizes(2, 2) + + def test_global_threadpool(self): + session_opts = onnxrt.SessionOptions() + session_opts.execution_mode = onnxrt.ExecutionMode.ORT_PARALLEL + session_opts.graph_optimization_level = onnxrt.GraphOptimizationLevel.ORT_DISABLE_ALL + session_opts.use_per_session_threads = False + session = onnxrt.InferenceSession( + get_name("mnist.onnx"), session_opts, providers=onnxrt.get_available_providers() + ) + input = np.ones([1, 1, 28, 28], np.float32) + session.run(None, {"Input3": input}) + + def test_raise_error_if_use_per_session_threads(self): + session_opts = onnxrt.SessionOptions() + session_opts.execution_mode = onnxrt.ExecutionMode.ORT_PARALLEL + session_opts.graph_optimization_level = onnxrt.GraphOptimizationLevel.ORT_DISABLE_ALL + session_opts.use_per_session_threads = True + with self.assertRaises(RuntimeError): + onnxrt.InferenceSession(get_name("mnist.onnx"), session_opts, providers=onnxrt.get_available_providers()) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_paged_attention_cuda.py b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py new file mode 100644 index 0000000000000..cc9a02a8074c0 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_paged_attention_cuda.py @@ -0,0 +1,735 @@ +# -------------------------------------------------------------------------- +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import math +import platform +import random +import unittest + +import numpy +import torch +from einops import rearrange, repeat +from onnx import TensorProto, helper +from packaging import version +from parameterized import parameterized +from test_gqa_cpu import smooth_softmax_ref + +from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers + +torch.manual_seed(0) + +pipeline_mode = True # Reduces number of tests so pipeline doesn't time out + + +class Config: + batch_size = 0 + sequence_length = 0 + total_sequence_length = 0 + num_heads = 0 + kv_num_heads = 0 + head_size = 0 + paged_kv_block_size = 0 + local = False + rotary = False + rotary_interleaved = False + packed = False + softcap = 0.0 + ep = "CUDAExecutionProvider" + + def __init__( + self, + batch_size, + sequence_length, + total_sequence_length, + num_heads, + kv_num_heads, + head_size, + paged_kv_block_size, + local, + rotary, + rotary_interleaved, + packed, + softcap, + ): + self.batch_size = batch_size + self.sequence_length = sequence_length + self.total_sequence_length = total_sequence_length + self.num_heads = num_heads + self.kv_num_heads = kv_num_heads + self.head_size = head_size + self.paged_kv_block_size = paged_kv_block_size + self.local = local + self.rotary = rotary + self.rotary_interleaved = rotary_interleaved + self.packed = packed + self.softcap = softcap + + def __repr__(self): + short_ep = self.ep[: -len("ExecutionProvider")].lower() + return ( + f"Config(batch_size={self.batch_size}, sequence_length={self.sequence_length}, " + f"total_sequence_length={self.total_sequence_length}, num_heads={self.num_heads}, " + f"kv_num_heads={self.kv_num_heads}, head_size={self.head_size}, " + f"paged_kv_block_size={self.paged_kv_block_size} rotary={self.rotary}, " + f"rotary_interleaved={self.rotary_interleaved}, packed={self.packed}, softcap={self.softcap}, " + f"ep={short_ep})" + ) + + +def create_paged_attention_graph( + config, + num_tokens, + num_blocks, + max_blocks_per_sequence, + local_window_size=-1, +): + nodes = [ + helper.make_node( + "PagedAttention", + [ + "query", + "key" if not config.packed else "", + "value" if not config.packed else "", + "key_cache", + "value_cache", + "cumulative_sequence_length", + "past_seqlens", + "block_table", + "cos_cache" if config.rotary else "", + "sin_cache" if config.rotary else "", + ], + ["output", "key_cache_out", "value_cache_out"], + "PagedAttention_0", + num_heads=config.num_heads, + kv_num_heads=config.kv_num_heads, + local_window_size=local_window_size, + do_rotary=config.rotary, + rotary_interleaved=config.rotary_interleaved, + softcap=config.softcap, + domain="com.microsoft", + ), + ] + + graph_input = [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + num_tokens, + (config.num_heads * config.head_size) + if not config.packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size), + ], + ), + helper.make_tensor_value_info( + "key_cache", + TensorProto.FLOAT16, + [ + num_blocks, + config.paged_kv_block_size, + config.kv_num_heads, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "value_cache", + TensorProto.FLOAT16, + [ + num_blocks, + config.paged_kv_block_size, + config.kv_num_heads, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "cumulative_sequence_length", + TensorProto.INT32, + [config.batch_size + 1], + ), + helper.make_tensor_value_info( + "past_seqlens", + TensorProto.INT32, + [config.batch_size], + ), + helper.make_tensor_value_info( + "block_table", + TensorProto.INT32, + [config.batch_size, max_blocks_per_sequence], + ), + ] + if not config.packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + num_tokens, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + num_tokens, + config.kv_num_heads * config.head_size, + ], + ), + ] + if config.rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT16, + [ + config.total_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT16, + [ + config.total_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] + + graph_output = [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [num_tokens, config.num_heads * config.head_size], + ), + helper.make_tensor_value_info( + "key_cache_out", + TensorProto.FLOAT16, + [ + num_blocks, + config.paged_kv_block_size, + config.kv_num_heads, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "value_cache_out", + TensorProto.FLOAT16, + [ + num_blocks, + config.paged_kv_block_size, + config.kv_num_heads, + config.head_size, + ], + ), + ] + + graph = helper.make_graph( + nodes, + "PagedAttention_Graph", + graph_input, + graph_output, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def rotary_options_for_current_os(): + # Reference implementation of rotary uses triton, which is not available in Windows. + # So we only test rotary in Linux right now. + return [(False, False)] if platform.system() != "Linux" else [(True, False), (True, True), (False, False)] + + +def paged_attention_func( + config, + query, + key, + value, + key_cache, + value_cache, + cumulative_sequence_length, + past_seqlens, + block_table, + cos=None, + sin=None, + window_size=-1, +): + num_tokens = cumulative_sequence_length[-1].item() + num_blocks = key_cache.shape[0] + max_blocks_per_sequence = block_table.shape[1] + onnx_model_str = create_paged_attention_graph( + config, + num_tokens, + num_blocks, + max_blocks_per_sequence, + local_window_size=window_size, + ) + ort_inputs = { + "query": query.detach().cpu().numpy(), + "key_cache": OrtValue.ortvalue_from_numpy(key_cache.detach().cpu().numpy(), "cuda", 0), + "value_cache": OrtValue.ortvalue_from_numpy(value_cache.detach().cpu().numpy(), "cuda", 0), + "cumulative_sequence_length": cumulative_sequence_length.detach().cpu().numpy(), + "past_seqlens": past_seqlens.detach().cpu().numpy(), + "block_table": block_table.detach().cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep]) + io_binding = ort_session.io_binding() + if key is not None and value is not None: + ort_inputs["key"] = key.detach().cpu().numpy() + ort_inputs["value"] = value.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None and sin is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_input( + "key_cache", "cuda", 0, numpy.float16, ort_inputs["key_cache"].shape(), ort_inputs["key_cache"].data_ptr() + ) + io_binding.bind_input( + "value_cache", "cuda", 0, numpy.float16, ort_inputs["value_cache"].shape(), ort_inputs["value_cache"].data_ptr() + ) + io_binding.bind_cpu_input("cumulative_sequence_length", ort_inputs["cumulative_sequence_length"]) + io_binding.bind_cpu_input("past_seqlens", ort_inputs["past_seqlens"]) + io_binding.bind_cpu_input("block_table", ort_inputs["block_table"]) + io_binding.bind_output("output") + io_binding.bind_ortvalue_output("key_cache_out", ort_inputs["key_cache"]) + io_binding.bind_ortvalue_output("value_cache_out", ort_inputs["value_cache"]) + ort_session.run_with_iobinding(io_binding) + output, key_cache_out, value_cache_out = io_binding.copy_outputs_to_cpu() + output = torch.tensor(numpy.array(output)) + return output, key_cache_out, value_cache_out + + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, + upcast=True, + reorder_ops=False, + use_smooth_softmax=False, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if softcap > 0: + scores = scores / softcap + scores = scores.tanh() + scores = scores * softcap + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) + + if use_smooth_softmax: + attention = smooth_softmax_ref(scores) + else: + attention = torch.softmax(scores, dim=-1) + + # Some rows might be completely masked out so we fill them with zero instead of NaN + if window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +def rotary_embedding(*args, **kwargs): + # Use local import since triton is not available in Windows. + from rotary_flash import apply_rotary_emb + + return apply_rotary_emb(*args, **kwargs) + + +def unpad_qkv(config: Config, q, k, v, cum_seqlens): + token_count = cum_seqlens[-1] + q_unpad = torch.zeros( + token_count, + config.num_heads * config.head_size, + dtype=torch.float16, + device="cuda", + ) + k_unpad = torch.zeros( + token_count, + config.kv_num_heads * config.head_size, + dtype=torch.float16, + device="cuda", + ) + v_unpad = torch.zeros( + token_count, + config.kv_num_heads * config.head_size, + dtype=torch.float16, + device="cuda", + ) + for i in range(config.batch_size): + new_seqlen = cum_seqlens[i + 1] - cum_seqlens[i] + q_unpad[cum_seqlens[i] : cum_seqlens[i + 1]] = rearrange(q[i, :new_seqlen], "s n h -> s (n h)") + k_unpad[cum_seqlens[i] : cum_seqlens[i + 1]] = rearrange(k[i, :new_seqlen], "s n h -> s (n h)") + v_unpad[cum_seqlens[i] : cum_seqlens[i + 1]] = rearrange(v[i, :new_seqlen], "s n h -> s (n h)") + return q_unpad, k_unpad, v_unpad + + +def generate_block_kvcache(config: Config, device, dtype): + num_blocks = math.ceil(config.total_sequence_length / config.paged_kv_block_size) * config.batch_size * 3 + k_cache_paged = torch.randn( + num_blocks, config.paged_kv_block_size, config.kv_num_heads, config.head_size, device=device, dtype=dtype + ) + v_cache_paged = torch.randn( + num_blocks, config.paged_kv_block_size, config.kv_num_heads, config.head_size, device=device, dtype=dtype + ) + block_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=config.batch_size, + ) + k_cache = rearrange( + # pytorch 1.12 doesn't have indexing with int32 + k_cache_paged[block_table.to(dtype=torch.long).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=config.batch_size, + )[:, : config.total_sequence_length] + v_cache = rearrange( + v_cache_paged[block_table.to(dtype=torch.long).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=config.batch_size, + )[:, : config.total_sequence_length] + return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged + + +def parity_check_paged_attention( + config: Config, + rtol=1e-3, + atol=1e-3, +): + # Generate padded inputs + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + k_new = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v_new = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + + # Generate random sequence lengths + past_seqlens = torch.randint( + 0, + config.total_sequence_length - config.sequence_length + 1, # one above highest integer to be drawn + (config.batch_size,), + dtype=torch.int32, + device="cuda", + ) + new_seqlens = torch.randint( + 1, + config.sequence_length + 1, + (config.batch_size,), + dtype=torch.int32, + device="cuda", + ) + cum_seqlens = torch.cat( + (torch.tensor([0], dtype=torch.int32, device="cuda"), torch.cumsum(new_seqlens, dim=0)) + ).type(torch.int32) + total_seqlens = past_seqlens + new_seqlens + + q_unpad, k_unpad, v_unpad = unpad_qkv(config, q, k_new, v_new, cum_seqlens) + + # Generate kv cache and associated block-based data structures + k_cache, v_cache, block_table, k_cache_paged, v_cache_paged = generate_block_kvcache(config, "cuda", torch.float16) + + # Set window size for local / causal + window_size = (-1, -1) + left_window_size = -1 + if config.local: + left_window_size = random.randint(0, config.total_sequence_length - 1) # random.randint is inclusive + window_size = (left_window_size, 0) + else: + left_window_size = -1 + window_size = (-1, 0) + + # Apply rotary embedding for reference implementation + if config.rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.total_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=past_seqlens, interleaved=config.rotary_interleaved) + k_ro = rotary_embedding(k_new, cos, sin, seqlen_offsets=past_seqlens, interleaved=config.rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, k_new + + # Update reference kv cache + k_cache_ref = k_cache.clone() + v_cache_ref = v_cache.clone() + total_range = rearrange(torch.arange(config.total_sequence_length, device="cuda"), "s -> 1 s") + past_seqlens_expanded = rearrange(past_seqlens, "b -> b 1") + update_mask = torch.logical_and( + past_seqlens_expanded <= total_range, total_range < past_seqlens_expanded + config.sequence_length + ) + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") + v_cache_ref[update_mask] = rearrange(v_new, "b s ... -> (b s) ...") + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + + # Create padding masks for reference implementation + total_seqlens_expanded = rearrange(total_seqlens, "b -> b 1") + key_padding_mask = total_range < total_seqlens_expanded + query_range = rearrange(torch.arange(config.sequence_length, device="cuda"), "s -> 1 s") + new_seqlens_expanded = rearrange(new_seqlens, "b -> b 1") + query_padding_mask = query_range < new_seqlens_expanded + + # Run reference implementation of attention + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + 0.0, + None, + causal=True, + window_size=window_size, + softcap=config.softcap, + ) + out_ref = out_ref.detach().cpu().numpy() + + if config.packed: + q_unpad = torch.concatenate([q_unpad, k_unpad, v_unpad], dim=1) + k_unpad = None + v_unpad = None + out, updated_k_cache_paged, updated_v_cache_paged = paged_attention_func( + config, + q_unpad, + k_unpad, + v_unpad, + k_cache_paged, + v_cache_paged, + cum_seqlens, + past_seqlens, + block_table, + cos, + sin, + left_window_size, + ) + num_tokens = q_unpad.shape[0] + out = torch.reshape(out, (num_tokens, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + err_msg = f" with {config}" + # Make sure past-present buffer updating correctly + present_k = rearrange( + updated_k_cache_paged[block_table.to(dtype=torch.long).flatten().cpu()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=config.batch_size, + )[:, : config.total_sequence_length] + present_v = rearrange( + updated_v_cache_paged[block_table.to(dtype=torch.long).flatten().cpu()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=config.batch_size, + )[:, : config.total_sequence_length] + for i in range(config.batch_size): + numpy.testing.assert_allclose( + present_k[i, : total_seqlens[i]], + k_cache_ref[i, : total_seqlens[i]].detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + err_msg=err_msg, + ) + numpy.testing.assert_allclose( + present_v[i, : total_seqlens[i]], + v_cache_ref[i, : total_seqlens[i]].detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + err_msg=err_msg, + ) + new_seqlen = cum_seqlens[i + 1] - cum_seqlens[i] + out_i = out[cum_seqlens[i] : cum_seqlens[i + 1]] + out_ref_i = out_ref[i, :new_seqlen] + numpy.testing.assert_allclose(out_i, out_ref_i, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) + + +def has_flash_attention(): + if not torch.cuda.is_available(): + return False + if "CUDAExecutionProvider" not in get_available_providers(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 8 and ( + platform.system() == "Linux" + or (platform.system() == "Windows" and version.parse(torch.version.cuda) >= version.parse("12.0")) + ) + + +def paged_attention_test_cases(): + batches = [4] if pipeline_mode else [1, 3, 5] + seqs = ( + [(1025, 2047)] + if pipeline_mode + else [ + (3, 1024), + (1, 339), + (408, 800), + (333, 799), + (64, 2048), + (837, 4000), + (17, 49), + (257, 257), + (459, 459), + ] + ) + num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + block_sizes = [256] if pipeline_mode else [256, 512] + + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for block_size in block_sizes: + for local in [False, True]: + for rotary, rotary_interleaved in rotary_options_for_current_os(): + for packed in [False, True]: + for softcap in [0.0, 50.0]: + if rotary and h % 16 > 0: + continue + + config = Config( + b, + s, + s2, + n, + n2, + h, + block_size, + local, + rotary, + rotary_interleaved, + packed, + softcap, + ) + yield ( + str(config), + config, + ) + + +@unittest.skipIf(not has_flash_attention(), reason="Flash Attention is not available, skipping tests.") +class TestPagedAttention(unittest.TestCase): + @parameterized.expand(paged_attention_test_cases()) + def test_paged_attention(self, _, config): + parity_check_paged_attention(config, rtol=5e-3, atol=5e-3) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxruntime/test/shared_lib/custom_op_utils.cc b/onnxruntime/test/shared_lib/custom_op_utils.cc index a624479bcd00b..432d78927a1ab 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.cc +++ b/onnxruntime/test/shared_lib/custom_op_utils.cc @@ -35,9 +35,11 @@ void MyCustomKernel::Compute(OrtKernelContext* context) { int64_t size = output_info.GetElementCount(); #ifdef USE_CUDA - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0)); #else - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtArenaAllocator, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtArenaAllocator, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0)); #endif OrtAllocator* allocator; Ort::ThrowOnError(ort_.KernelContext_GetAllocator(context, &mem_info, &allocator)); diff --git a/onnxruntime/test/shared_lib/dlopen_main.cc b/onnxruntime/test/shared_lib/dlopen_main.cc index 4eb9f9ca2f567..09a0647543521 100644 --- a/onnxruntime/test/shared_lib/dlopen_main.cc +++ b/onnxruntime/test/shared_lib/dlopen_main.cc @@ -142,7 +142,7 @@ int main() { const OrtApi* g_ort_api_instance = nullptr; std::cout << "Attempting to test ONNX Runtime dynamic load/unload..." << std::endl; - + int retval = 0; try { #ifdef _DEBUG _CrtMemCheckpoint(&s1); @@ -290,8 +290,7 @@ int main() { } else { std::cout << "\nHEAP_DEBUG: s3_diff_final did not show a net increase in _NORMAL_BLOCKs." << std::endl; } - } else { - std::cout << "\nHEAP_DEBUG: No overall memory difference detected between s1 (before load) and s2 (after unload)." << std::endl; + retval = EXIT_FAILURE; } } #endif @@ -310,5 +309,5 @@ int main() { } std::cout << "Program finished." << std::endl; - return 0; + return retval; } \ No newline at end of file diff --git a/onnxruntime/test/shared_lib/test_allocator.cc b/onnxruntime/test/shared_lib/test_allocator.cc index f3d390f64d4b6..29f3dfad0f11d 100644 --- a/onnxruntime/test/shared_lib/test_allocator.cc +++ b/onnxruntime/test/shared_lib/test_allocator.cc @@ -3,8 +3,12 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/providers/cpu/cpu_provider_factory.h" +#include "test/shared_lib/test_fixture.h" +#include "test/util/include/test_allocator.h" #include +extern std::unique_ptr ort_env; + TEST(CApiTest, allocation_info) { auto cpu_mem_info_1 = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); auto cpu_mem_info_2 = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); @@ -36,3 +40,48 @@ TEST(CApiTest, DefaultAllocator) { Ort::KeyValuePairs stats = default_allocator.GetStats(); ASSERT_EQ(0, stats.GetKeyValuePairs().size()); } + +#if !defined(ORT_MINIMAL_BUILD) +TEST(CApiTest, CustomAllocator) { + constexpr PATH_TYPE model_path = TSTR("testdata/mul_1.onnx"); + + const auto& api = Ort::GetApi(); + + // Case 1: Register a custom allocator. + { + MockedOrtAllocator mocked_allocator; + ASSERT_TRUE(api.RegisterAllocator(*ort_env, &mocked_allocator) == nullptr); + + Ort::SessionOptions session_options; + session_options.AddConfigEntry("session.use_env_allocators", "1"); + Ort::Session session(*ort_env, model_path, session_options); + + Ort::Allocator allocator(session, mocked_allocator.Info()); + + auto stats = allocator.GetStats(); + ASSERT_EQ(mocked_allocator.NumAllocations(), std::stoll(stats.GetValue("NumAllocs"))); + ASSERT_EQ(mocked_allocator.NumReserveAllocations(), std::stoll(stats.GetValue("NumReserves"))); + + ASSERT_TRUE(api.UnregisterAllocator(*ort_env, mocked_allocator.Info()) == nullptr); + } + + // Case 2: Register a custom allocator with an older API version which does not support GetStats. + { + MockedOrtAllocator mocked_allocator; + mocked_allocator.version = 22; + ASSERT_TRUE(api.RegisterAllocator(*ort_env, &mocked_allocator) == nullptr); + + Ort::SessionOptions session_options; + session_options.AddConfigEntry("session.use_env_allocators", "1"); + Ort::Session session(*ort_env, model_path, session_options); + + Ort::Allocator allocator(session, mocked_allocator.Info()); + + // Custom allocator does not implement GetStats, we expect the stats to be empty. + auto stats = allocator.GetStats(); + ASSERT_EQ(0, stats.GetKeyValuePairs().size()); + + ASSERT_TRUE(api.UnregisterAllocator(*ort_env, mocked_allocator.Info()) == nullptr); + } +} +#endif diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/my_allocator.cc b/onnxruntime/test/testdata/custom_execution_provider_library/my_allocator.cc index aca03e7a6b4f6..3519c61d0abb5 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/my_allocator.cc +++ b/onnxruntime/test/testdata/custom_execution_provider_library/my_allocator.cc @@ -7,7 +7,8 @@ namespace onnxruntime { MyEPAllocator::MyEPAllocator(OrtDevice::DeviceId device_id) : IAllocator(OrtMemoryInfo(MyEP, OrtAllocatorType::OrtArenaAllocator, - OrtDevice(MyEPDevice, OrtDevice::MemType::DEFAULT, device_id))) { + OrtDevice(MyEPDevice, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, + device_id))) { } #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(disable : 26400) diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 341d48342dce8..02a0c5c82b255 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -309,7 +309,29 @@ "^test_dequantizelinear_int4", "^test_dequantizelinear_uint4", "^test_quantizelinear_int4", - "^test_quantizelinear_uint4" + "^test_quantizelinear_uint4", + // ONNX 1.18 has invalid test model for RMSNorm operator. + "^test_rms_normalization_2d_axis0_expanded", + "^test_rms_normalization_2d_axis1_expanded", + "^test_rms_normalization_2d_axis_negative_1_expanded", + "^test_rms_normalization_2d_axis_negative_2_expanded", + "^test_rms_normalization_3d_axis0_epsilon_expanded", + "^test_rms_normalization_3d_axis1_epsilon_expanded", + "^test_rms_normalization_3d_axis2_epsilon_expanded", + "^test_rms_normalization_3d_axis_negative_1_epsilon_expanded", + "^test_rms_normalization_3d_axis_negative_2_epsilon_expanded", + "^test_rms_normalization_3d_axis_negative_3_epsilon_expanded", + "^test_rms_normalization_4d_axis0_expanded", + "^test_rms_normalization_4d_axis1_expanded", + "^test_rms_normalization_4d_axis2_expanded", + "^test_rms_normalization_4d_axis3_expanded", + "^test_rms_normalization_4d_axis_negative_1_expanded", + "^test_rms_normalization_4d_axis_negative_2_expanded", + "^test_rms_normalization_4d_axis_negative_3_expanded", + "^test_rms_normalization_4d_axis_negative_4_expanded", + "^test_rms_normalization_default_axis_expanded", + // topk uint64 is not implemented in ORT yet. + "^test_top_k_uint64" ], "current_failing_tests_x86": [ "^test_vgg19", diff --git a/onnxruntime/test/testdata/three_layer_nested_subgraph.onnx b/onnxruntime/test/testdata/three_layer_nested_subgraph.onnx new file mode 100644 index 0000000000000..647e80a7f3b38 Binary files /dev/null and b/onnxruntime/test/testdata/three_layer_nested_subgraph.onnx differ diff --git a/onnxruntime/test/testdata/webgpu_pow_cast_test.onnx b/onnxruntime/test/testdata/webgpu_pow_cast_test.onnx new file mode 100644 index 0000000000000..28eb5cf26d9b6 --- /dev/null +++ b/onnxruntime/test/testdata/webgpu_pow_cast_test.onnx @@ -0,0 +1,23 @@ + +:‹ + +x +y +pow_output/Pow"Pow +) + +pow_outputout/Cast"Cast* +to  +Main_graphZ +x + + +Z +y + + +b +out + + +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/webgpu_pow_cast_test.py b/onnxruntime/test/testdata/webgpu_pow_cast_test.py new file mode 100644 index 0000000000000..b3d41aefa0119 --- /dev/null +++ b/onnxruntime/test/testdata/webgpu_pow_cast_test.py @@ -0,0 +1,40 @@ +import onnx +from onnx import TensorProto, helper + +# tests fix for precision error in florence2 model by using WebGPU built-in sqrt(x) function instead of pow(x, y) when the exponent is 0.5. +# The sqrt(x) built-in is both faster and more stable than using pow(x, 0.5). +# Example: +# Cast: input = 576 (int), output = 576 (f32) +# Pow: input = 576(f32), 0.5(f32), output = 23.99999(f32) +# Cast: input = 23.9999(f32), output = 23(int) + + +graph_proto = helper.make_graph( + [ + helper.make_node( + "Pow", + inputs=["x", "y"], + outputs=["pow_output"], + name="/Pow", + ), + helper.make_node( + "Cast", + inputs=["pow_output"], + outputs=["out"], + name="/Cast", + to=TensorProto.INT64, + ), + ], + "Main_graph", + [ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("y", TensorProto.FLOAT, [1]), + ], + [ + helper.make_tensor_value_info("out", TensorProto.INT64, [1]), + ], +) + +model = helper.make_model(graph_proto) +onnx.checker.check_model(model, True) +onnx.save(model, "webgpu_pow_cast_test.onnx") diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index a52043ee20207..9e563b3342dae 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -307,15 +307,22 @@ std::unique_ptr DefaultXnnpackExecutionProvider() { #endif } -std::unique_ptr DefaultWebGpuExecutionProvider() { +std::unique_ptr DefaultWebGpuExecutionProvider(bool is_nhwc) { #ifdef USE_WEBGPU ConfigOptions config_options{}; // Disable storage buffer cache ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kStorageBufferCacheMode, webgpu::options::kBufferCacheMode_Disabled) .IsOK()); + if (!is_nhwc) { + // Enable NCHW support + ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kPreferredLayout, + webgpu::options::kPreferredLayout_NCHW) + .IsOK()); + } return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider(); #else + ORT_UNUSED_PARAMETER(is_nhwc); return nullptr; #endif } diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 3595c6f71633a..ce6434991051c 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -63,7 +63,7 @@ std::unique_ptr DefaultQnnExecutionProvider(); std::unique_ptr QnnExecutionProviderWithOptions(const ProviderOptions& options, const SessionOptions* session_options = nullptr); std::unique_ptr DefaultXnnpackExecutionProvider(); -std::unique_ptr DefaultWebGpuExecutionProvider(); +std::unique_ptr DefaultWebGpuExecutionProvider(bool is_nhwc = true); std::unique_ptr DefaultCannExecutionProvider(); std::unique_ptr DefaultDmlExecutionProvider(); diff --git a/onnxruntime/test/util/include/test_allocator.h b/onnxruntime/test/util/include/test_allocator.h index c700098c87f33..1366ce37603c1 100644 --- a/onnxruntime/test/util/include/test_allocator.h +++ b/onnxruntime/test/util/include/test_allocator.h @@ -15,6 +15,7 @@ struct MockedOrtAllocator : OrtAllocator { void Free(void* p); const OrtMemoryInfo* Info() const; void* Reserve(size_t size); + OrtKeyValuePairs* Stats() const; size_t NumAllocations() const; size_t NumReserveAllocations() const; diff --git a/onnxruntime/test/util/test_allocator.cc b/onnxruntime/test/util/test_allocator.cc index 05dd454e875d5..393f6aeb7eef1 100644 --- a/onnxruntime/test/util/test_allocator.cc +++ b/onnxruntime/test/util/test_allocator.cc @@ -10,6 +10,10 @@ MockedOrtAllocator::MockedOrtAllocator() { OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; OrtAllocator::Reserve = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Reserve(size); }; + OrtAllocator::GetStats = [](const OrtAllocator* this_, OrtKeyValuePairs** stats) noexcept -> OrtStatusPtr { + *stats = static_cast(this_)->Stats(); + return nullptr; + }; Ort::ThrowOnError(Ort::GetApi().CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpu_memory_info)); } @@ -56,6 +60,18 @@ const OrtMemoryInfo* MockedOrtAllocator::Info() const { return cpu_memory_info; } +OrtKeyValuePairs* MockedOrtAllocator::Stats() const { + Ort::KeyValuePairs kvps; + + auto str = std::to_string(num_allocations.load()); + kvps.Add("NumAllocs", str.c_str()); + + str = std::to_string(num_reserve_allocations.load()); + kvps.Add("NumReserves", str.c_str()); + + return kvps.release(); +} + size_t MockedOrtAllocator::NumAllocations() const { return num_allocations.load(); } diff --git a/onnxruntime/wasm/post-webnn.js b/onnxruntime/wasm/post-webnn.js new file mode 100644 index 0000000000000..3120e87b13384 --- /dev/null +++ b/onnxruntime/wasm/post-webnn.js @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This file contains the post-run code for the ORT WebAssembly module. The code in this file will be injected into the +// final module using Emscripten's `--post-js` option. +// +// This file will only be used in build with flag `--use_webnn`. + +/** + * This function is called only once when initializing the WebNN backend. + * + * @param params WebNN initialization parameters. + */ +Module["webnnInit"] = (params) => { + // Functions called from EM_ASM need to be assigned in a way that can be minified. + // Functions called via emscripten::val::module_property need to be assigned by name so that the minifier doesn't + // change the name. + + const backend = params[0]; + [ + Module.webnnReserveTensorId, + Module.webnnReleaseTensorId, + Module["webnnEnsureTensor"], + Module.webnnUploadTensor, + Module["webnnDownloadTensor"], + Module.webnnRegisterMLContext, + Module["webnnEnableTraceEvent"], + ] = params.slice(1); + + // This function is called from both JS and an EM_ASM block, it needs both a minifiable name and an explicit name. + Module["webnnReleaseTensorId"] = Module.webnnReleaseTensorId; + Module["webnnUploadTensor"] = Module.webnnUploadTensor; + Module["webnnRegisterMLContext"] = Module.webnnRegisterMLContext; + + // Functions called from JS also need to have explicit names. + Module["webnnOnRunStart"] = (sessionId) => { + return backend["onRunStart"](sessionId); + }; + Module["webnnOnRunEnd"] = backend["onRunEnd"].bind(backend); + Module["webnnOnReleaseSession"] = (sessionId) => { + backend["onReleaseSession"](sessionId); + }; + Module["webnnCreateMLTensorDownloader"] = (tensorId, type) => { + return backend["createMLTensorDownloader"](tensorId, type); + }; + Module["webnnRegisterMLTensor"] = (sessionId, tensor, dataType, shape) => { + return backend["registerMLTensor"](sessionId, tensor, dataType, shape); + }; + Module["webnnCreateMLContext"] = (optionsOrGpuDevice) => { + return backend["createMLContext"](optionsOrGpuDevice); + }; + Module["webnnRegisterMLConstant"] = ( + externalFilePath, + dataOffset, + dataLength, + builder, + desc, + shouldConvertInt64ToInt32 + ) => { + return backend["registerMLConstant"]( + externalFilePath, + dataOffset, + dataLength, + builder, + desc, + Module.MountedFiles, + shouldConvertInt64ToInt32 + ); + }; + Module["webnnRegisterGraphInput"] = + backend["registerGraphInput"].bind(backend); + Module["webnnIsGraphInput"] = backend["isGraphInput"].bind(backend); + Module["webnnRegisterGraphOutput"] = + backend["registerGraphOutput"].bind(backend); + Module["webnnIsGraphOutput"] = backend["isGraphOutput"].bind(backend); + + Module["webnnCreateTemporaryTensor"] = + backend["createTemporaryTensor"].bind(backend); + Module["webnnIsGraphInputOutputTypeSupported"] = + backend["isGraphInputOutputTypeSupported"].bind(backend); +}; diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 8232a286d4480..21f32859c576b 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -104,21 +104,20 @@ Module["jsepInit"] = (name, params) => { Module["webnnEnsureTensor"], Module.webnnUploadTensor, Module["webnnDownloadTensor"], + Module.webnnRegisterMLContext, Module["webnnEnableTraceEvent"], ] = params.slice(1); // This function is called from both JS and an EM_ASM block, it needs both a minifiable name and an explicit name. Module["webnnReleaseTensorId"] = Module.webnnReleaseTensorId; Module["webnnUploadTensor"] = Module.webnnUploadTensor; + Module["webnnRegisterMLContext"] = Module.webnnRegisterMLContext; // Functions called from JS also need to have explicit names. Module["webnnOnRunStart"] = (sessionId) => { return backend["onRunStart"](sessionId); }; Module["webnnOnRunEnd"] = backend["onRunEnd"].bind(backend); - Module["webnnRegisterMLContext"] = (sessionId, mlContext) => { - backend["registerMLContext"](sessionId, mlContext); - }; Module["webnnOnReleaseSession"] = (sessionId) => { backend["onReleaseSession"](sessionId); }; diff --git a/orttraining/orttraining/core/graph/graph_augmenter.cc b/orttraining/orttraining/core/graph/graph_augmenter.cc index 19b200efcf6bb..1fde22b32451b 100644 --- a/orttraining/orttraining/core/graph/graph_augmenter.cc +++ b/orttraining/orttraining/core/graph/graph_augmenter.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/graph/graph_utils.h" #include "orttraining/core/graph/graph_augmenter.h" #include "core/common/logging/logging.h" diff --git a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc index a4143e7c817fd..888664dda8806 100644 --- a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc +++ b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc @@ -529,7 +529,7 @@ Status TransformGraphForMixedPrecision(Graph& graph, // Add new FP16/BFloat16 initializers to the graph for (const auto& kv : mixed_precision_initializers) { const ONNX_NAMESPACE::TensorProto* tensor_proto = kv.second; - Initializer initializer(*tensor_proto, graph.ModelPath()); + Initializer initializer(graph, *tensor_proto, graph.ModelPath()); ONNX_NAMESPACE::TensorProto weight_tensor_proto = mixed_precision_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 ? initializer.ToFP16(kv.first) : initializer.ToBFloat16(kv.first); graph.AddInitializedTensor(weight_tensor_proto); } diff --git a/orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc b/orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc index 3f4034a9db222..8c9072614a4b0 100644 --- a/orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc +++ b/orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc @@ -125,8 +125,8 @@ static std::vector AddPartitionsForParameter( ORT_ENFORCE(dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); // Find the initializer partition to read out. - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); - const float* initializer_data = initializer->data(); + auto initializer = Initializer{graph, *tensor_proto, graph.ModelPath()}; + const float* initializer_data = initializer.data(); // Create new initializer tensor proto. ONNX_NAMESPACE::TensorProto initializer_partition; diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc index ff220fcb067b8..90be9e24d3dd4 100644 --- a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc +++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc @@ -121,7 +121,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv, const std::string transformer_name initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); InlinedVector initializer_proto_value{weight_squeeze_axis}; initializer_proto.set_raw_data(initializer_proto_value.data(), initializer_proto_value.size() * sizeof(int64_t)); - auto& axes_input = graph_utils::AddInitializer(graph, initializer_proto); + auto& axes_input = graph_utils::AddInitializerWithExternalData(graph, initializer_proto); // Squeeze node doesn't have opschema here, so we need to set input args count manually weight_squeeze.MutableInputArgsCount().resize(2); graph_utils::AddNodeInput(weight_squeeze, 1, axes_input); diff --git a/orttraining/orttraining/core/optimizer/megatron_transformer.cc b/orttraining/orttraining/core/optimizer/megatron_transformer.cc index 25e16304789b6..55286379fd273 100644 --- a/orttraining/orttraining/core/optimizer/megatron_transformer.cc +++ b/orttraining/orttraining/core/optimizer/megatron_transformer.cc @@ -171,8 +171,8 @@ bool MegatronTransformer::PartitionWeightByColumn(const Graph& graph, const Node LOGS_DEFAULT(WARNING) << "Checkpointing is not currently supported for graphs requiring partitioning of weight with stride > 1"; } - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); - const float* a_weight = initializer->data(); + auto initializer = Initializer{graph, *tensor_proto, graph.ModelPath()}; + const float* a_weight = initializer.data(); std::string new_initializer_name = original_name + "_column_rank_" + std::to_string(horizontal_parallel_rank_); @@ -306,8 +306,8 @@ bool MegatronTransformer::PartitionWeightByRow(const Graph& graph, const NodeArg << horizontal_parallel_size_ << ", not supported currently."; return false; } - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); - const float* a_weight = initializer->data(); + auto initializer = Initializer{graph, *tensor_proto, graph.ModelPath()}; + const float* a_weight = initializer.data(); std::string new_initializer_name = original_name + "_row_rank_" + std::to_string(horizontal_parallel_rank_); @@ -453,15 +453,15 @@ Status MegatronTransformer::TransformGPT2MLP(Graph& graph, bool& modified, return skip_status; } - NodeArg& a_weight_partition_arg = graph_utils::AddInitializer(graph, a_weight_initializer_partition); + NodeArg& a_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, a_weight_initializer_partition); graph_utils::ReplaceNodeInput(node, 1, a_weight_partition_arg); updated_weight_names_.insert({a_weight_arg->Name(), a_weight_partition_arg.Name()}); - NodeArg& a_bias_partition_arg = graph_utils::AddInitializer(graph, a_bias_initializer_partition); + NodeArg& a_bias_partition_arg = graph_utils::AddInitializerWithExternalData(graph, a_bias_initializer_partition); graph_utils::ReplaceNodeInput(add_node, 1, a_bias_partition_arg); updated_weight_names_.insert({b_weight_arg->Name(), a_bias_partition_arg.Name()}); - NodeArg& b_weight_partition_arg = graph_utils::AddInitializer(graph, b_weight_initializer_partition); + NodeArg& b_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, b_weight_initializer_partition); graph_utils::ReplaceNodeInput(matmul2_node, 1, b_weight_partition_arg); updated_weight_names_.insert({a_bias_arg->Name(), b_weight_partition_arg.Name()}); @@ -600,15 +600,15 @@ Status MegatronTransformer::TransformBARTMLP(Graph& graph, bool& modified, return skip_status; } - NodeArg& dense_wi_weight_partition_arg = graph_utils::AddInitializer(graph, dense_wi_weight_initializer_partition); + NodeArg& dense_wi_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, dense_wi_weight_initializer_partition); graph_utils::ReplaceNodeInput(*second_op, 0, dense_wi_weight_partition_arg); updated_weight_names_.insert({dense_wi_weight_arg->Name(), dense_wi_weight_partition_arg.Name()}); - NodeArg& dense_wi_bias_partition_arg = graph_utils::AddInitializer(graph, dense_wi_bias_initializer_partition); + NodeArg& dense_wi_bias_partition_arg = graph_utils::AddInitializerWithExternalData(graph, dense_wi_bias_initializer_partition); graph_utils::ReplaceNodeInput(biasgelu_node, 1, dense_wi_bias_partition_arg); updated_weight_names_.insert({dense_wi_bias_arg->Name(), dense_wi_bias_partition_arg.Name()}); - NodeArg& dense_wo_weight_partition_arg = graph_utils::AddInitializer(graph, dense_wo_weight_initializer_partition); + NodeArg& dense_wo_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, dense_wo_weight_initializer_partition); graph_utils::ReplaceNodeInput(*transpose_op_ptr, 0, dense_wo_weight_partition_arg); updated_weight_names_.insert({dense_wo_weight_arg->Name(), dense_wo_weight_partition_arg.Name()}); @@ -787,15 +787,15 @@ Status MegatronTransformer::TransformGPT2Attention(Graph& graph, bool& modified, // The number of the values should be more than 2, and the 3rd value should be divisible by parallel size, // i.e., the attention head number should be divisible by parallel size. - auto init_const = std::make_unique(*tensor, graph.ModelPath()); - if (init_const->size() != 3 && init_const->size() != 4) { + auto init_const = Initializer{graph, *tensor, graph.ModelPath()}; + if (init_const.size() != 3 && init_const.size() != 4) { is_reshape_valid = false; break; } - const int64_t* val = init_const->data(); + const int64_t* val = init_const.data(); if (val[2] % horizontal_parallel_size_ != 0) { - LOGS_DEFAULT(WARNING) << (init_const->size() == 3 ? "Hidden size " : "Number of attention heads ") << val[2] + LOGS_DEFAULT(WARNING) << (init_const.size() == 3 ? "Hidden size " : "Number of attention heads ") << val[2] << " is not divisible by horizontal_parallel_size_ " << horizontal_parallel_size_ << ", not supported currently."; is_reshape_valid = false; @@ -814,15 +814,15 @@ Status MegatronTransformer::TransformGPT2Attention(Graph& graph, bool& modified, [](Node* node_ptr) { return node_ptr != nullptr; }); // Replace by the partition weights. - NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partition); + NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, qkv_weight_initializer_partition); graph_utils::ReplaceNodeInput(node, 1, qkv_weight_partition_arg); updated_weight_names_.insert({qkv_weight_arg->Name(), qkv_weight_partition_arg.Name()}); - NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partition); + NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializerWithExternalData(graph, qkv_bias_initializer_partition); graph_utils::ReplaceNodeInput(add_node, 1, qkv_bias_partition_arg); updated_weight_names_.insert({qkv_bias_arg->Name(), qkv_bias_partition_arg.Name()}); - NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition); + NodeArg& dense_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, dense_weight_initializer_partition); graph_utils::ReplaceNodeInput(matmul_node, 1, dense_weight_partition_arg); updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()}); @@ -836,9 +836,9 @@ Status MegatronTransformer::TransformGPT2Attention(Graph& graph, bool& modified, const ONNX_NAMESPACE::TensorProto* tensor; graph.GetInitializedTensor(shape_arg->Name(), tensor); auto data_type = tensor->data_type(); - auto init_const = std::make_unique(*tensor, graph.ModelPath()); - const int64_t* val = init_const->data(); - int64_t size = init_const->size(); + auto init_const = Initializer{graph, *tensor, graph.ModelPath()}; + const int64_t* val = init_const.data(); + int64_t size = init_const.size(); ONNX_NAMESPACE::TensorProto tensor_partition; tensor_partition.set_name(graph.GenerateNodeArgName("partition_" + shape_arg->Name())); tensor_partition.set_data_type(data_type); @@ -849,7 +849,7 @@ Status MegatronTransformer::TransformGPT2Attention(Graph& graph, bool& modified, val_partition.insert(val_partition.end(), val, val + size); val_partition[2] /= horizontal_parallel_size_; tensor_partition.set_raw_data(val_partition.data(), size * sizeof(int64_t)); - NodeArg& node_arg_partition = graph_utils::AddInitializer(graph, tensor_partition); + NodeArg& node_arg_partition = graph_utils::AddInitializerWithExternalData(graph, tensor_partition); graph_utils::ReplaceNodeInput(*node_ptr, 1, node_arg_partition); graph.RemoveInitializedTensor(shape_arg->Name()); } @@ -1068,12 +1068,12 @@ Status MegatronTransformer::TransformBARTAttention(Graph& graph, bool& modified, } // The number of the values should be more than idx, and the idx'th value should be divisible by parallel size, // i.e., the attention head number should be divisible by parallel size. - auto init_const = std::make_unique(*tensor, graph.ModelPath()); - if (init_const->size() <= idx) { + auto init_const = Initializer{graph, *tensor, graph.ModelPath()}; + if (init_const.size() <= idx) { is_reshape_valid = false; break; } - const int64_t* val = init_const->data(); + const int64_t* val = init_const.data(); if (val[idx] % horizontal_parallel_size_ != 0) { LOGS_DEFAULT(WARNING) << "dim[" << idx << "]: " << val[idx] << " is not divisible by horizontal_parallel_size_ " @@ -1130,7 +1130,7 @@ Status MegatronTransformer::TransformBARTAttention(Graph& graph, bool& modified, size_t i = 0; for (auto trans_ptr : weight_transpose_node_ptrs) { auto weight_name = trans_ptr->MutableInputDefs()[0]->Name(); - NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializer(graph, qkv_weight_initializer_partitions[i]); + NodeArg& qkv_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, qkv_weight_initializer_partitions[i]); graph_utils::ReplaceNodeInput(*trans_ptr, 0, qkv_weight_partition_arg); graph.RemoveInitializedTensor(weight_name); updated_weight_names_.insert({weight_name, qkv_weight_partition_arg.Name()}); @@ -1139,14 +1139,14 @@ Status MegatronTransformer::TransformBARTAttention(Graph& graph, bool& modified, i = 0; for (auto add_ptr : bias_add_node_ptrs) { auto bias_name = add_ptr->MutableInputDefs()[1]->Name(); - NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializer(graph, qkv_bias_initializer_partitions[i]); + NodeArg& qkv_bias_partition_arg = graph_utils::AddInitializerWithExternalData(graph, qkv_bias_initializer_partitions[i]); graph_utils::ReplaceNodeInput(*add_ptr, 1, qkv_bias_partition_arg); graph.RemoveInitializedTensor(bias_name); updated_weight_names_.insert({bias_name, qkv_bias_partition_arg.Name()}); i++; } - NodeArg& dense_weight_partition_arg = graph_utils::AddInitializer(graph, dense_weight_initializer_partition); + NodeArg& dense_weight_partition_arg = graph_utils::AddInitializerWithExternalData(graph, dense_weight_initializer_partition); graph_utils::ReplaceNodeInput(*last_transpose, 0, dense_weight_partition_arg); graph.RemoveInitializedTensor(dense_weight_arg->Name()); updated_weight_names_.insert({dense_weight_arg->Name(), dense_weight_partition_arg.Name()}); @@ -1162,11 +1162,12 @@ Status MegatronTransformer::TransformBARTAttention(Graph& graph, bool& modified, int64_t idx = x.second; auto shape_arg = node_ptr->MutableInputDefs()[1]; const ONNX_NAMESPACE::TensorProto* tensor; - graph.GetInitializedTensor(shape_arg->Name(), tensor); + ORT_RETURN_IF_NOT(graph.GetInitializedTensor(shape_arg->Name(), tensor), + "Expecting initializer present: ", shape_arg->Name()); auto data_type = tensor->data_type(); - auto init_const = std::make_unique(*tensor, graph.ModelPath()); - const int64_t* val = init_const->data(); - int64_t size = init_const->size(); + auto init_const = Initializer{graph, *tensor, graph.ModelPath()}; + const int64_t* val = init_const.data(); + int64_t size = init_const.size(); ONNX_NAMESPACE::TensorProto tensor_partition; tensor_partition.set_name(graph.GenerateNodeArgName("partition_" + shape_arg->Name())); tensor_partition.set_data_type(data_type); @@ -1177,7 +1178,7 @@ Status MegatronTransformer::TransformBARTAttention(Graph& graph, bool& modified, val_partition.insert(val_partition.end(), val, val + size); val_partition[idx] /= horizontal_parallel_size_; tensor_partition.set_raw_data(val_partition.data(), size * sizeof(int64_t)); - NodeArg& node_arg_partition = graph_utils::AddInitializer(graph, tensor_partition); + NodeArg& node_arg_partition = graph_utils::AddInitializerWithExternalData(graph, tensor_partition); graph_utils::ReplaceNodeInput(*node_ptr, 1, node_arg_partition); graph.RemoveInitializedTensor(shape_arg->Name()); } diff --git a/orttraining/orttraining/core/optimizer/qdq_fusion.cc b/orttraining/orttraining/core/optimizer/qdq_fusion.cc index fc9a6d213f794..4a5bdc1f8fcd2 100644 --- a/orttraining/orttraining/core/optimizer/qdq_fusion.cc +++ b/orttraining/orttraining/core/optimizer/qdq_fusion.cc @@ -21,22 +21,31 @@ int ReplaceOrCreateZeroPointInitializer(Graph& graph, Node& quantize_node) { ONNX_NAMESPACE::TensorProto zero_point_tensor_float; if (quant_node_input_defs.size() >= 3) { // The quantize node has the zero point input - auto zero_point_tensor_int = graph.GetInitializer(quant_node_input_defs[2]->Name(), true); - ORT_ENFORCE(zero_point_tensor_int != nullptr, "Expected: zero point initializer with name ", + constexpr const bool check_outer_scope_true = true; + const auto* zero_point_tensor_proto = graph.GetInitializer(quant_node_input_defs[2]->Name(), check_outer_scope_true); + ORT_ENFORCE(zero_point_tensor_proto != nullptr, "Expected: zero point initializer with name ", quant_node_input_defs[2]->Name(), " to be present in the graph. Actual: not found."); - zero_point_type = zero_point_tensor_int->data_type(); - zero_point_tensor_float.set_name(graph.GenerateNodeArgName(zero_point_tensor_int->name())); + Initializer zero_point_tensor_int(graph, *zero_point_tensor_proto, graph.ModelPath(), check_outer_scope_true); + zero_point_type = zero_point_tensor_int.data_type(); + zero_point_tensor_float.set_name(graph.GenerateNodeArgName(zero_point_tensor_int.name())); zero_point_tensor_float.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - for (const auto val : zero_point_tensor_int->int32_data()) { - zero_point_tensor_float.add_float_data(static_cast(val)); + if (zero_point_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { + for (const auto val : zero_point_tensor_int.DataAsSpan()) { + zero_point_tensor_float.add_float_data(static_cast(val)); + } + } else { + for (const auto val : zero_point_tensor_int.DataAsSpan()) { + zero_point_tensor_float.add_float_data(static_cast(val)); + } } - for (const auto& dim : zero_point_tensor_int->dims()) { + for (const auto dim : zero_point_tensor_int.dims()) { zero_point_tensor_float.add_dims(dim); } - graph.RemoveInitializedTensor(zero_point_tensor_int->name()); + graph.RemoveInitializedTensor(zero_point_tensor_int.name()); // Since the quantize node has the zero point initializer input, replace it - graph_utils::ReplaceNodeInput(quantize_node, 2, graph_utils::AddInitializer(graph, zero_point_tensor_float)); + graph_utils::ReplaceNodeInput(quantize_node, 2, + graph_utils::AddInitializerWithExternalData(graph, zero_point_tensor_float)); } else { // The quantize node does not have the zero point optional input. // Create the zero point initializer to be 0. @@ -45,7 +54,8 @@ int ReplaceOrCreateZeroPointInitializer(Graph& graph, Node& quantize_node) { zero_point_tensor_float.add_float_data(0.0f); // Since the input did not exist, add the newly created initializer as an input - graph_utils::AddNodeInput(quantize_node, 2, graph_utils::AddInitializer(graph, zero_point_tensor_float)); + graph_utils::AddNodeInput(quantize_node, 2, + graph_utils::AddInitializerWithExternalData(graph, zero_point_tensor_float)); } return zero_point_type; diff --git a/orttraining/orttraining/core/optimizer/scaled_sum_fusion.cc b/orttraining/orttraining/core/optimizer/scaled_sum_fusion.cc index e719a21118028..e6319952dfae7 100644 --- a/orttraining/orttraining/core/optimizer/scaled_sum_fusion.cc +++ b/orttraining/orttraining/core/optimizer/scaled_sum_fusion.cc @@ -74,7 +74,7 @@ bool IsScaleOperator(Graph& graph, Node& node, return false; } - Initializer init_const{*tensor_proto, graph.ModelPath()}; + Initializer init_const{graph, *tensor_proto, graph.ModelPath()}; const auto data_type = tensor_proto->data_type(); if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { const MLFloat16* val = init_const.data(); diff --git a/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc b/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc index 84bf715c7c85a..8c9c12ceb4497 100644 --- a/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc +++ b/orttraining/orttraining/core/optimizer/sce_loss_grad_bias_fusion.cc @@ -83,7 +83,7 @@ Status SceLossGradBiasFusion::ApplyImpl(Graph& graph, bool& modified, int graph_ ignore_index_initializer_proto.set_name(graph.GenerateNodeArgName("sce_grad_ignore_index")); ignore_index_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); ignore_index_initializer_proto.add_int64_data(static_cast(-1)); - new_scegrad_node_inputs.emplace_back(&graph_utils::AddInitializer(graph, ignore_index_initializer_proto)); + new_scegrad_node_inputs.emplace_back(&graph_utils::AddInitializerWithExternalData(graph, ignore_index_initializer_proto)); } new_scegrad_node_inputs.emplace_back(bias_def); if (!p_reshape) { diff --git a/orttraining/orttraining/core/optimizer/triton_fusion.cc b/orttraining/orttraining/core/optimizer/triton_fusion.cc index f2cb3c2b8c6db..026f39712ffe6 100644 --- a/orttraining/orttraining/core/optimizer/triton_fusion.cc +++ b/orttraining/orttraining/core/optimizer/triton_fusion.cc @@ -64,7 +64,7 @@ bool CheckAxes(const Graph& graph, const Node& node, bool single_axis, const std if (!axes_const) { return false; } - Initializer initializer{*axes_const, graph.ModelPath()}; + Initializer initializer{graph, *axes_const, graph.ModelPath()}; axes_values.insert(axes_values.end(), initializer.DataAsSpan().begin(), initializer.DataAsSpan().end()); } else { diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc index b03f1b1eadb3b..650ed69578210 100644 --- a/orttraining/orttraining/core/session/training_session.cc +++ b/orttraining/orttraining/core/session/training_session.cc @@ -4,9 +4,11 @@ #include "orttraining/core/session/training_session.h" #include "core/framework/data_transfer_utils.h" +#include "core/graph/graph_utils.h" #include "core/graph/model.h" #include "core/graph/model_saving_options.h" #include "core/session/IOBinding.h" +#include "core/optimizer/initializer.h" #include "core/optimizer/rule_based_graph_transformer.h" #include "core/providers/cpu/controlflow/utils.h" #include "core/providers/cpu/cpu_execution_provider.h" @@ -977,22 +979,18 @@ static Status UpdateWeightsBeforeSaving( if (!graph.GetInitializedTensor(name_and_ml_value.first, old_tensor_proto)) { continue; } - ONNX_NAMESPACE::TensorProto new_tensor_proto = *old_tensor_proto; - if (new_tensor_proto.has_raw_data()) { - auto* const raw_data = new_tensor_proto.mutable_raw_data(); - auto dst_span = gsl::make_span(&(*raw_data)[0], raw_data->size()); - ORT_RETURN_IF_ERROR(CopyTensorDataToByteSpan( - data_transfer_manager, src_tensor, cpu_alloc_info, dst_span)); - } else { - ORT_ENFORCE(new_tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT); - auto* const float_data = new_tensor_proto.mutable_float_data(); - auto dst_span = gsl::make_span(float_data->mutable_data(), float_data->size()); - ORT_RETURN_IF_ERROR(CopyTensorDataToSpan( - data_transfer_manager, src_tensor, cpu_alloc_info, dst_span)); - } + + Initializer initializer{graph, *old_tensor_proto, graph.ModelPath()}; + const auto chars_span = ReinterpretAsSpan(initializer.MutableDataAsByteSpan()); + ORT_RETURN_IF_ERROR(CopyTensorDataToByteSpan( + data_transfer_manager, src_tensor, cpu_alloc_info, chars_span)); + + TensorProto new_tensor_proto; + OrtValue ort_value; + initializer.ToProtoWithOrtValue(new_tensor_proto, ort_value); // Replace the TensorProto in the model. - ORT_RETURN_IF_ERROR(graph.ReplaceInitializedTensor(new_tensor_proto)); + ORT_RETURN_IF_ERROR(graph.ReplaceInitializedTensor(new_tensor_proto, ort_value)); } return Status::OK(); } diff --git a/orttraining/orttraining/models/runner/training_util.cc b/orttraining/orttraining/models/runner/training_util.cc index 6af3bf4410065..fd638d4b102b9 100644 --- a/orttraining/orttraining/models/runner/training_util.cc +++ b/orttraining/orttraining/models/runner/training_util.cc @@ -50,7 +50,7 @@ common::Status DataSet::AddData(const vector& featu size_t cpu_tensor_length; ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &cpu_tensor_length)); OrtValue ort_value; - OrtMemoryInfo info("Cpu", OrtDeviceAllocator, OrtDevice{}, 0, OrtMemTypeDefault); + OrtMemoryInfo info("Cpu", OrtDeviceAllocator, OrtDevice{}, OrtMemTypeDefault); std::unique_ptr buffer = std::make_unique(cpu_tensor_length); ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue( Env::Default(), std::filesystem::path(), tensor_proto, MemBuffer(buffer.get(), cpu_tensor_length, info), ort_value)); diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 897ae7e6a94c9..a06310d4c7fbc 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -173,6 +173,8 @@ void ORTTrainingPythonEnv::ClearExecutionProviderInstances() { } static ORTTrainingPythonEnv* ort_training_env = nullptr; +static OrtThreadingOptions global_tp_options; +static bool use_global_tp = false; OrtEnv* GetOrtEnv() { return &ort_training_env->GetORTEnv(); @@ -185,7 +187,7 @@ static Status CreateOrtEnv() { Env::Default().GetTelemetryProvider().SetLanguageProjection(OrtLanguageProjection::ORT_PROJECTION_PYTHON); OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, "Default"}; Status status; - std::unique_ptr ort_env(OrtEnv::GetInstance(lm_info, status)); + std::unique_ptr ort_env(OrtEnv::GetInstance(lm_info, status, use_global_tp ? &global_tp_options : nullptr)); if (!status.IsOK()) return status; #if !defined(__APPLE__) && !defined(ORT_MINIMAL_BUILD) if (!InitProvidersSharedLibrary()) { @@ -201,6 +203,17 @@ ORTTrainingPythonEnv& GetTrainingEnv() { return *ort_training_env; } +void SetGlobalThreadingOptions(const OrtThreadingOptions&& tp_options) { + if (ort_training_env != nullptr) { + OrtPybindThrowIfError(GetEnv().SetGlobalThreadingOptions(tp_options)); + } + global_tp_options = tp_options; + use_global_tp = true; +} +bool CheckIfUsingGlobalThreadPool() { + return use_global_tp; +} + void ResolveExtraProviderOptions(const std::vector& provider_types, const ProviderOptionsMap& original_provider_options_map, ProviderOptionsMap& merged_options) { diff --git a/orttraining/orttraining/test/framework/checkpointing_test.cc b/orttraining/orttraining/test/framework/checkpointing_test.cc index a7ee776b9bc39..615a3e86a8a2f 100644 --- a/orttraining/orttraining/test/framework/checkpointing_test.cc +++ b/orttraining/orttraining/test/framework/checkpointing_test.cc @@ -52,7 +52,7 @@ void CompareOrtValuesToTensorProtoValues( ASSERT_EQ(name_to_ort_value.size(), name_to_tensor_proto.size()); NameMLValMap name_to_ort_value_from_tensor_proto{}; - AllocatorPtr tmp_allocator = std::make_shared(); + AllocatorPtr tmp_allocator = CPUAllocator::DefaultInstance(); for (const auto& name_and_tensor_proto : name_to_tensor_proto) { const auto& name = name_and_tensor_proto.first; diff --git a/orttraining/orttraining/test/optimizer/horizontal_parallel_test_utils.cc b/orttraining/orttraining/test/optimizer/horizontal_parallel_test_utils.cc index 1cb9518a06193..9461f751aecdd 100644 --- a/orttraining/orttraining/test/optimizer/horizontal_parallel_test_utils.cc +++ b/orttraining/orttraining/test/optimizer/horizontal_parallel_test_utils.cc @@ -183,10 +183,13 @@ Status GetDataAndShapeFromTensorProto(const Graph& graph, const NodeArg* input_a } const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; - graph.GetInitializedTensor(input_arg->Name(), tensor_proto); - auto init_const = std::make_unique(*tensor_proto, graph.ModelPath()); - const float* data_float = init_const->data(); - data.insert(data.end(), data_float, data_float + element_count); + if (!graph.GetInitializedTensor(input_arg->Name(), tensor_proto)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to get tensor proto for ", input_arg->Name()); + } + auto init_const = Initializer{graph, *tensor_proto, graph.ModelPath()}; + auto data_float = init_const.DataAsSpan(); + data.insert(data.end(), data_float.begin(), data_float.end()); return Status::OK(); } diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 60708b05626c5..9e12fdcd2bb53 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -108,6 +108,7 @@ Status TransformModelInputsForInference(Graph& inference_graph, ORT_ENFORCE(!inference_graph.IsInitializedTensor(named_parameter_it->first), "The eval graph is invalid. Expected model parameter ", named_parameter_it->first, " to be a graph input, not a graph initializer."); + inference_graph.AddInitializedTensor(utils::CopyTensorToTensorProto( named_parameter_it->second->Data().Get(), named_parameter_it->first, data_transfer_manager)); diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 82372645d364f..4c1ddd94fda5a 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -381,6 +381,9 @@ def generate_build_tree( "-Donnxruntime_USE_TELEMETRY=" + ("ON" if args.use_telemetry else "OFF"), "-Donnxruntime_ENABLE_PIX_FOR_WEBGPU_EP=" + ("ON" if args.enable_pix_capture else "OFF"), ] + + if args.caller_framework: + cmake_args.append("-Donnxruntime_CALLER_FRAMEWORK=" + args.caller_framework) if args.winml_root_namespace_override: cmake_args.append("-Donnxruntime_WINML_NAMESPACE_OVERRIDE=" + args.winml_root_namespace_override) if args.disable_memleak_checker or args.enable_address_sanitizer: @@ -1719,6 +1722,9 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): [sys.executable, "onnxruntime_test_python.py"], cwd=cwd, dll_path=dll_path, python_path=python_path ) + log.info("Testing Global Thread Pool feature") + run_subprocess([sys.executable, "onnxruntime_test_python_global_threadpool.py"], cwd=cwd, dll_path=dll_path) + log.info("Testing AutoEP feature") run_subprocess([sys.executable, "onnxruntime_test_python_autoep.py"], cwd=cwd, dll_path=dll_path) diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index 807c8b327c780..561eab7f2d61d 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -397,6 +397,7 @@ def add_windows_specific_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--windows_sdk_version", help="Windows SDK version (e.g., 10.0.19041.0).") parser.add_argument("--enable_msvc_static_runtime", action="store_true", help="Statically link MSVC runtimes.") parser.add_argument("--use_telemetry", action="store_true", help="Enable telemetry (official builds only).") + parser.add_argument("--caller_framework", type=str, help="Name of the framework calling ONNX Runtime.") # Cross-compilation targets hosted on Windows parser.add_argument( diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index ab10bdfba0e0f..8c16b6b7caa69 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.34.0.250424 + default: 2.35.0.250530 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index aab62a723d71c..8d23a2576b1ef 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -60,7 +60,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.34.0.250424 + default: 2.35.0.250530 resources: repositories: diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index 6ee64e4870fd5..b925619401c27 100644 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -6,7 +6,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.34.0.250424 + default: 2.35.0.250530 - name: IsReleaseBuild displayName: Is a release build? Set it to true if you are doing an Onnx Runtime release. diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 580f565310661..1006936403d45 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.34.0.250424 + default: 2.35.0.250530 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 035b4b6c17222..96ae6952f8827 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.34.0.250424 + default: 2.35.0.250530 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 63fb41ab24c68..742bd68fad104 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.34.0.250424 + default: 2.35.0.250530 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml index e5d15db5c062a..ed6c4c799c26d 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml @@ -45,7 +45,7 @@ stages: msbuildPlatform: x64 packageName: x64-cuda CudaVersion: ${{ parameters.CudaVersion }} - buildparameter: --use_cuda --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90-virtual" + buildparameter: --use_cuda --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90a-virtual" runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: ${{ parameters.buildJava }} java_artifact_id: onnxruntime_gpu @@ -62,7 +62,7 @@ stages: msbuildPlatform: x64 CudaVersion: ${{ parameters.CudaVersion }} packageName: x64-tensorrt - buildparameter: --use_tensorrt --tensorrt_home=${{ parameters.win_trt_home }} --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90-virtual" + buildparameter: --use_tensorrt --tensorrt_home=${{ parameters.win_trt_home }} --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90a-virtual" runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: ${{ parameters.buildJava }} java_artifact_id: onnxruntime_gpu diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index 728927f33886a..8440a2c98bb6a 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.34.0.250424 + default: 2.35.0.250530 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml index eea9b672eef3d..c865048456f3f 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml @@ -55,7 +55,7 @@ stages: PYTHON_VERSION: ${{ python_version }} EP_NAME: gpu CudaVersion: ${{ parameters.cuda_version }} - EP_BUILD_FLAGS: --enable_lto --use_cuda --cuda_home=$(Agent.TempDirectory)\v${{ parameters.cuda_version }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90-virtual" + EP_BUILD_FLAGS: --enable_lto --use_cuda --cuda_home=$(Agent.TempDirectory)\v${{ parameters.cuda_version }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90a-virtual" use_tensorrt: True - ${{ if eq(parameters.enable_linux_cuda, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index 0c70a4f82c566..68e1e1b39c56c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -19,7 +19,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.34.0.250424' + default: '2.35.0.250530' - name: enableWebGpu displayName: Enable WebGPU test diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index c94969d9e9d41..4c5801664dda9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -53,7 +53,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.34.0.250424' + default: '2.35.0.250530' - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index d4da4d56a9766..ed2d914df8d81 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -47,7 +47,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: 2.34.0.250424 + default: 2.35.0.250530 - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 01dbfc5292aa9..4d343be496475 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.34.0.250424' + default: '2.35.0.250530' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index 13cc9314caf77..062b70a2249f6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.34.0.250424' + default: '2.35.0.250530' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index aa434699fbe02..3eae6ec9c3fdf 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -150,8 +150,8 @@ jobs: - ${{ if eq(parameters.BuildWebGPU, true) }}: - script: | mkdir -p $(Build.ArtifactStagingDirectory)/wasm_webgpu/ - cp $(Build.BinariesDirectory)/wasm_inferencing_webgpu/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.jsep.wasm $(Build.ArtifactStagingDirectory)/wasm_webgpu/ - cp $(Build.BinariesDirectory)/wasm_inferencing_webgpu/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.jsep.mjs $(Build.ArtifactStagingDirectory)/wasm_webgpu/ + cp $(Build.BinariesDirectory)/wasm_inferencing_webgpu/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.asyncify.wasm $(Build.ArtifactStagingDirectory)/wasm_webgpu/ + cp $(Build.BinariesDirectory)/wasm_inferencing_webgpu/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.asyncify.mjs $(Build.ArtifactStagingDirectory)/wasm_webgpu/ displayName: 'Create Artifacts (WebGPU EP)' - ${{ if eq(parameters.is1ES, false) }}: - task: PublishPipelineArtifact@1 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index a0bfd6a46a43c..b31a8f4a2190f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -26,7 +26,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.34.0.250424 + default: 2.35.0.250530 - name: is1ES displayName: 'Whether the pipeline is running in 1ES' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index d28b3e9604c5d..3f6b554a679ca 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.34.0.250424 + default: 2.35.0.250530 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index f300d845579bf..8d3c7a2914672 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.34.0.250424 + default: 2.35.0.250530 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index ce22142e6c5bd..263992a034ffe 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.34.0.250424 + default: 2.35.0.250530 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 0b8c493ae124d..0960ae5ebda83 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.34.0.250424' + QnnSdk: '2.35.0.250530' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index 9c80d330854b5..f52cbbb5b3c24 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -157,7 +157,7 @@ stages: displayName: 'Generate cmake config' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --build --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} $(timeoutParameter) $(buildJavaParameter)' + arguments: '--parallel 16 --use_vcpkg --use_vcpkg_ms_internal_asset_cache --config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --build --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} $(timeoutParameter) $(buildJavaParameter)' workingDirectory: '$(Build.BinariesDirectory)' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 9c06edb4d03e8..e3774fc4476ec 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.34.0.250424 + default: 2.35.0.250530 jobs: - job: 'BUILD_QNN_EP' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 3b41394b97bd3..9b96d3eb8e304 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.34.0.250424 + default: 2.35.0.250530 jobs: - job: 'BUILD_QNN_EP' diff --git a/tools/ci_build/github/linux/build_cuda_c_api_package.sh b/tools/ci_build/github/linux/build_cuda_c_api_package.sh index e7bf487d07dd8..fe417db7f2559 100755 --- a/tools/ci_build/github/linux/build_cuda_c_api_package.sh +++ b/tools/ci_build/github/linux/build_cuda_c_api_package.sh @@ -2,4 +2,4 @@ set -e -x docker run -e SYSTEM_COLLECTIONURI --rm --volume \ $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}build \ -/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --enable_lto --build_java --build_nodejs --build_dir /build --config Release --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --use_cuda --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr/local/cuda-$CUDA_VERSION --skip_tests --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90' && cd /build/Release && make install DESTDIR=/build/installed" +/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --enable_lto --build_java --build_nodejs --build_dir /build --config Release --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --use_cuda --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr/local/cuda-$CUDA_VERSION --skip_tests --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90a-real;90a-virtual' && cd /build/Release && make install DESTDIR=/build/installed" diff --git a/tools/ci_build/github/linux/build_linux_python_package.sh b/tools/ci_build/github/linux/build_linux_python_package.sh index 99e507df8e8f7..fd23cd9bc37f1 100755 --- a/tools/ci_build/github/linux/build_linux_python_package.sh +++ b/tools/ci_build/github/linux/build_linux_python_package.sh @@ -70,7 +70,7 @@ fi if [ "$BUILD_DEVICE" == "GPU" ]; then SHORT_CUDA_VERSION=$(echo $CUDA_VERSION | sed 's/\([[:digit:]]\+\.[[:digit:]]\+\)\.[[:digit:]]\+/\1/') #Enable CUDA and TRT EPs. - BUILD_ARGS+=("--use_cuda" "--use_tensorrt" "--cuda_version=$SHORT_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--nvcc_threads=1" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;86-real;90") + BUILD_ARGS+=("--use_cuda" "--use_tensorrt" "--cuda_version=$SHORT_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--nvcc_threads=1" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;86-real;90a-real;90a-virtual") fi if [ "$BUILD_DEVICE" == "NPU" ]; then diff --git a/tools/ci_build/github/linux/build_nodejs_package.sh b/tools/ci_build/github/linux/build_nodejs_package.sh index 29ee91a122e39..cc6443cc7fab6 100755 --- a/tools/ci_build/github/linux/build_nodejs_package.sh +++ b/tools/ci_build/github/linux/build_nodejs_package.sh @@ -3,4 +3,4 @@ set -e -x mkdir -p $HOME/.onnx docker run -e SYSTEM_COLLECTIONURI --rm --volume /data/onnx:/data/onnx:ro --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build \ --volume /data/models:/build/models:ro --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}xtrt86build \ -/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_tests --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --build_nodejs --use_webgpu --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90' --use_vcpkg --use_vcpkg_ms_internal_asset_cache && cd /build/Release && make install DESTDIR=/build/installed" +/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_tests --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --build_nodejs --use_webgpu --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90a-real;90a-virtual' --use_vcpkg --use_vcpkg_ms_internal_asset_cache && cd /build/Release && make install DESTDIR=/build/installed" diff --git a/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh b/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh index 42ac04e91a035..54e671a8196be 100755 --- a/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh +++ b/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh @@ -3,4 +3,4 @@ set -e -x mkdir -p $HOME/.onnx docker run -e SYSTEM_COLLECTIONURI --rm --volume /data/onnx:/data/onnx:ro --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build \ --volume /data/models:/build/models:ro --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}xtrt86build \ -/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_tests --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --build_java --build_nodejs --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90' --use_vcpkg --use_vcpkg_ms_internal_asset_cache && cd /build/Release && make install DESTDIR=/build/installed" +/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_tests --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --build_java --build_nodejs --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90a-real;90a-virtual' --use_vcpkg --use_vcpkg_ms_internal_asset_cache && cd /build/Release && make install DESTDIR=/build/installed" diff --git a/tools/ci_build/github/linux/ort_minimal/nnapi_minimal_build_minimal_ort_and_run_tests.sh b/tools/ci_build/github/linux/ort_minimal/nnapi_minimal_build_minimal_ort_and_run_tests.sh index 02f89f4e91e5c..1d9d7f43fd10d 100755 --- a/tools/ci_build/github/linux/ort_minimal/nnapi_minimal_build_minimal_ort_and_run_tests.sh +++ b/tools/ci_build/github/linux/ort_minimal/nnapi_minimal_build_minimal_ort_and_run_tests.sh @@ -34,7 +34,7 @@ python3 $ORT_ROOT/tools/ci_build/build.py \ --disable_ml_ops \ --disable_exceptions \ --include_ops_by_config $ORT_ROOT/onnxruntime/test/testdata/required_ops_and_types.config \ - --skip_tests + --skip_tests --use_vcpkg --use_vcpkg_ms_internal_asset_cache # Push onnxruntime_test_all and testdata to emulator adb push $MIN_BUILD_DIR/Debug/onnxruntime_test_all /data/local/tmp/ diff --git a/tools/python/remove_initializer_from_input.py b/tools/python/remove_initializer_from_input.py index 935d5a44c75fe..dd8db1b06074c 100644 --- a/tools/python/remove_initializer_from_input.py +++ b/tools/python/remove_initializer_from_input.py @@ -11,25 +11,27 @@ def get_args(): return args -def remove_initializer_from_input(): - args = get_args() - - model = onnx.load(args.input) +def remove_initializer_from_input(model: onnx.ModelProto) -> bool: if model.ir_version < 4: - print("Model with ir_version below 4 requires to include initilizer in graph input") - return + print("Model with ir_version below 4 requires to include initializer in graph input") + return False inputs = model.graph.input name_to_input = {} for input in inputs: name_to_input[input.name] = input + modified = False for initializer in model.graph.initializer: if initializer.name in name_to_input: + modified = True inputs.remove(name_to_input[initializer.name]) - onnx.save(model, args.output) + return modified if __name__ == "__main__": - remove_initializer_from_input() + args = get_args() + model = onnx.load(args.input) + remove_initializer_from_input(model) + onnx.save(model, args.output)