diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 9e1a491d154cf..7f7ff74959d52 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -71,8 +71,8 @@ jobs: run: | set -e -x BINARY_SIZE_THRESHOLD_ARGS="" - echo "Binary size threshold in bytes: 1722565" - BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1722565" + echo "Binary size threshold in bytes: 1436672" + BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1436672" # Ensure ANDROID_NDK_HOME is available and get its real path if [ -z "$ANDROID_NDK_HOME" ]; then diff --git a/.github/workflows/macos-ci-build-and-test-workflow.yml b/.github/workflows/macos-ci-build-and-test-workflow.yml index c7c35fb234013..329584c68d7d1 100644 --- a/.github/workflows/macos-ci-build-and-test-workflow.yml +++ b/.github/workflows/macos-ci-build-and-test-workflow.yml @@ -62,6 +62,7 @@ jobs: --build_objc --build_java --build_wheel + ${{ matrix.target == 'arm64' && '--enable_arm_neon_nchwc' || '' }} ${{ inputs.use_webgpu && '--use_webgpu' || '' }} ${{ inputs.use_xnnpack && '--use_xnnpack' || '' }} ${{ inputs.use_coreml && '--use_coreml --skip_onnx_tests' || '' }} diff --git a/.github/workflows/react_native.yml b/.github/workflows/react_native.yml index 9327bd0ecfe3c..08426f3d0ccb5 100644 --- a/.github/workflows/react_native.yml +++ b/.github/workflows/react_native.yml @@ -102,7 +102,7 @@ jobs: run: sudo apt-get update && sudo apt-get install -y ninja-build - name: Download Android AAR artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: name: onnxruntime-android-full-aar path: ${{ runner.temp }}/android-full-aar @@ -221,7 +221,7 @@ jobs: uses: actions/checkout@v5 - name: Download iOS pod artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: name: ios_pod path: ${{ runner.temp }}/ios_pod @@ -277,7 +277,7 @@ jobs: uses: actions/checkout@v5 - name: Download iOS pod artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: name: ios_pod path: ${{ runner.temp }}/ios_pod diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index bcab5e9e6fa1b..793207f5b6d76 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -88,13 +88,9 @@ option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_USE_SVE "Build with SVE support in MLAS" OFF) +option(onnxruntime_USE_ARM_NEON_NCHWC "Build with ARM Neon NCHWc kernels in MLAS" OFF) option(onnxruntime_USE_KLEIDIAI "Build with KleidiAI integration in MLAS" OFF) -# iOS simulator build explicitly builds targets with USE_KLEIDIAI=ON so attempting to force override if so -if(APPLE AND CMAKE_OSX_ARCHITECTURES MATCHES "x86_64") - message(WARNING "Disabling KleidiAI: not supported on Apple x86_64 platforms") - set(onnxruntime_USE_KLEIDIAI OFF CACHE BOOL "" FORCE) -endif() option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) option(onnxruntime_BUILD_OBJC "Build Objective-C library" OFF) @@ -258,6 +254,8 @@ if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_ message(FATAL_ERROR "GCC version must be greater than or equal to 11.1") endif() +include(detect_onnxruntime_target_platform.cmake) + # ENABLE_TRAINING includes all training functionality # The following 2 entry points # 1. 
ORTModule @@ -434,14 +432,6 @@ set(ORTTRAINING_SOURCE_DIR ${ORTTRAINING_ROOT}/orttraining) include(adjust_global_compile_flags.cmake) -if (APPLE) - if (NOT CMAKE_OSX_ARCHITECTURES) - message("Building ONNX Runtime for ${CMAKE_HOST_SYSTEM_PROCESSOR} CPU ARCH") - endif() -elseif (NOT WIN32 AND NOT APPLE) - message("Building ONNX Runtime for ${onnxruntime_target_platform} CPU ARCH") -endif() - # We need to link with libatomic on systems that do not have built-in atomics, or # don't have built-in support for 8 byte atomics # Derived from https://github.com/protocolbuffers/protobuf/blob/master/cmake/CMakeLists.txt @@ -513,6 +503,66 @@ if (onnxruntime_BUILD_SHARED_LIB OR onnxruntime_ENABLE_PYTHON) endif() endif() +if (onnxruntime_USE_ARM_NEON_NCHWC) + message(STATUS "Building MLAS with ARM Neon NCHWc kernels") +endif() + +if(onnxruntime_USE_SVE) + if(LINUX AND onnxruntime_target_platform STREQUAL "aarch64") + check_cxx_compiler_flag("-march=armv8.2-a+sve" HAS_ARM64_SVE) + if(HAS_ARM64_SVE) + message(STATUS "Compiler supports SVE!") + else() + message(WARNING "onnxruntime_USE_SVE was set but compiler does not support SVE. It will be disabled.") + set(onnxruntime_USE_SVE OFF) + endif() + else() + message(WARNING "onnxruntime_USE_SVE was set but it is not supported on this platform. It will be disabled.") + set(onnxruntime_USE_SVE OFF) + endif() +endif() + +if(onnxruntime_USE_KLEIDIAI) + function(is_kleidiai_supported is_supported_var) + # check for supported target platforms + if(NOT (onnxruntime_target_platform STREQUAL "aarch64" OR + onnxruntime_target_platform STREQUAL "ARM64" OR + onnxruntime_target_platform STREQUAL "arm64")) + message(WARNING "KleidiAI is not supported on this platform.") + + set(${is_supported_var} FALSE PARENT_SCOPE) + return() + endif() + + # check for compiler support + if(MSVC) + # TODO detect on MSVC + else() + check_cxx_compiler_flag(-march=armv8.2-a+dotprod HAS_ARM64_DOTPROD) + check_cxx_compiler_flag(-march=armv8.2-a+i8mm HAS_ARM64_I8MM) + if(NOT HAS_ARM64_DOTPROD) + message(WARNING "The compiler doesn't support dotprod instructions.") + endif() + if(NOT HAS_ARM64_I8MM) + message(WARNING "The compiler doesn't support i8mm instructions.") + endif() + if(NOT HAS_ARM64_DOTPROD OR NOT HAS_ARM64_I8MM) + set(${is_supported_var} FALSE PARENT_SCOPE) + return() + endif() + endif() + + set(${is_supported_var} TRUE PARENT_SCOPE) + endfunction() + + is_kleidiai_supported(is_kleidiai_supported_result) + + if(NOT is_kleidiai_supported_result) + message(WARNING "onnxruntime_USE_KLEIDIAI was set but it is not supported. It will be disabled.") + set(onnxruntime_USE_KLEIDIAI OFF) + endif() +endif() + #Dependencies begin get_filename_component(ONNXRUNTIME_ROOT "${ONNXRUNTIME_ROOT}" ABSOLUTE) get_filename_component(ORTTRAINING_ROOT "${ORTTRAINING_ROOT}" ABSOLUTE) @@ -663,43 +713,6 @@ else() endif() endif() -if(onnxruntime_USE_SVE) - if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64" AND CMAKE_SYSTEM_NAME STREQUAL "Linux") - check_cxx_compiler_flag("-march=armv8.2-a+sve" HAS_ARM64_SVE) - if(HAS_ARM64_SVE) - message(STATUS "Compiler supports SVE!") - else() - message(WARNING "onnxruntime_USE_SVE was set but compiler does not support SVE. It will be disabled.") - set(onnxruntime_USE_SVE OFF) - endif() - else() - message(WARNING "onnxruntime_USE_SVE was set but it is not supported on this platform. 
It will be disabled.") - set(onnxruntime_USE_SVE OFF) - endif() -endif() - -if (onnxruntime_USE_KLEIDIAI AND ( - (onnxruntime_target_platform STREQUAL "aarch64") OR - (onnxruntime_target_platform STREQUAL "ARM64") OR - (APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64"))) - - # TODO Add checks for MSVC Compilation - if(NOT MSVC) - check_cxx_compiler_flag(-march=armv8.2-a+dotprod HAS_ARM64_DOTPROD) - check_cxx_compiler_flag(-march=armv8.2-a+i8mm HAS_ARM64_I8MM) - if (NOT HAS_ARM64_DOTPROD) - message(FATAL_ERROR "The compiler doesn't support dotprod") - endif() - if (NOT HAS_ARM64_I8MM) - message(FATAL_ERROR "The compiler doesn't support i8mm") - endif() - else() - message(STATUS "Skipping -march= checks on MSVC (not supported), assuming dotprod/i8mm support manually.") - set(HAS_ARM64_DOTPROD TRUE) - set(HAS_ARM64_I8MM TRUE) - endif() -endif() - #names in this var must match the directory names under onnxruntime/core/providers #ONNXRUNTIME_PROVIDER_NAMES is the list of providers that needs to export additional symbols in the global namespace. #For example CUDA EP exports "OrtSessionOptionsAppendExecutionProvider_CUDA", which is a global function. diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 6d517003fa6b6..502a60ec8d7b8 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -217,30 +217,20 @@ endmacro() #Set global compile flags for all the source code(including third_party code like protobuf) #This section must be before any add_subdirectory, otherwise build may fail because /MD,/MT mismatch if (MSVC) - if (CMAKE_VS_PLATFORM_NAME) - # Multi-platform generator - set(onnxruntime_target_platform ${CMAKE_VS_PLATFORM_NAME}) - else() - set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) - endif() - if (onnxruntime_target_platform STREQUAL "ARM64") - set(onnxruntime_target_platform "ARM64") - enable_language(ASM_MARMASM) - elseif (onnxruntime_target_platform STREQUAL "ARM64EC") + if (onnxruntime_target_platform STREQUAL "ARM64" OR + onnxruntime_target_platform STREQUAL "ARM64EC" OR + onnxruntime_target_platform STREQUAL "ARM") enable_language(ASM_MARMASM) - elseif (onnxruntime_target_platform STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM") - set(onnxruntime_target_platform "ARM") - enable_language(ASM_MARMASM) - elseif (onnxruntime_target_platform STREQUAL "x64" OR onnxruntime_target_platform STREQUAL "x86_64" OR onnxruntime_target_platform STREQUAL "AMD64" OR CMAKE_GENERATOR MATCHES "Win64") - set(onnxruntime_target_platform "x64") - enable_language(ASM_MASM) - elseif (onnxruntime_target_platform STREQUAL "Win32" OR onnxruntime_target_platform STREQUAL "x86" OR onnxruntime_target_platform STREQUAL "i386" OR onnxruntime_target_platform STREQUAL "i686") - set(onnxruntime_target_platform "x86") + elseif (onnxruntime_target_platform STREQUAL "x64" OR + onnxruntime_target_platform STREQUAL "x86") enable_language(ASM_MASM) - message("Enabling SAFESEH for x86 build") - set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh") + + if (onnxruntime_target_platform STREQUAL "x86") + message("Enabling SAFESEH for x86 build") + set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh") + endif() else() - message(FATAL_ERROR "Unknown CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") + message(FATAL_ERROR "Unsupported onnxruntime_target_platform value: ${onnxruntime_target_platform}") endif() #Always enable exception handling, even for Windows ARM @@ -269,34 +259,6 @@ if (MSVC) 
set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /Gw /GL") endif() else() - if (NOT APPLE) - #XXX: Sometimes the value of CMAKE_SYSTEM_PROCESSOR is set but it's wrong. For example, if you run an armv7 docker - #image on an aarch64 machine with an aarch64 Ubuntu host OS, in the docker instance cmake may still report - # CMAKE_SYSTEM_PROCESSOR as aarch64 by default. Given compiling this code may need more than 2GB memory, we do not - # support compiling for ARM32 natively(only support cross-compiling), we will ignore this issue for now. - if(NOT CMAKE_SYSTEM_PROCESSOR) - message(WARNING "CMAKE_SYSTEM_PROCESSOR is not set. Please set it in your toolchain cmake file.") - # Try to detect it - if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" OR "${CMAKE_C_COMPILER_ID}" STREQUAL "Clang") - execute_process( - COMMAND "${CMAKE_C_COMPILER}" -dumpmachine - OUTPUT_VARIABLE GCC_DUMP_MACHINE_OUT OUTPUT_STRIP_TRAILING_WHITESPACE - ERROR_VARIABLE _err - RESULT_VARIABLE _res - ) - if(NOT _res EQUAL 0) - message(SEND_ERROR "Failed to run 'gcc -dumpmachine':\n ${_res}") - endif() - string(REPLACE "-" ";" GCC_DUMP_MACHINE_OUT_LIST "${GCC_DUMP_MACHINE_OUT}") - list(LENGTH GCC_DUMP_MACHINE_OUT_LIST GCC_TRIPLET_LEN) - if(GCC_TRIPLET_LEN EQUAL 4) - list(GET GCC_DUMP_MACHINE_OUT_LIST 0 CMAKE_SYSTEM_PROCESSOR) - message("Setting CMAKE_SYSTEM_PROCESSOR to ${CMAKE_SYSTEM_PROCESSOR}") - endif() - endif() - endif() - set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) - endif() if (onnxruntime_BUILD_FOR_NATIVE_MACHINE) string(APPEND CMAKE_CXX_FLAGS " -march=native -mtune=native") string(APPEND CMAKE_C_FLAGS " -march=native -mtune=native") diff --git a/cmake/deps.txt b/cmake/deps.txt index 3d419f7fd913b..7b243ff15cd80 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -51,7 +51,7 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/de0ce7c7251372892e53c re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cde7fcef990bf4a8e87b047882e88 safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 -cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.9.2.zip;b7f8dc4a879765127ce31dfeabd31c556c80ec79 +cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v4.2.1.zip;5d2b21b10478556c5e209dd7229e298a5c9f0b02 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0c12f53da76d0c31b03b9f0f8ec8f3b4.zip;239063aee4946a9af147b473a4c3da78ba7413b4 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96 diff --git a/cmake/detect_onnxruntime_target_platform.cmake b/cmake/detect_onnxruntime_target_platform.cmake new file mode 100644 index 0000000000000..2f5f6ee9ca80c --- /dev/null +++ b/cmake/detect_onnxruntime_target_platform.cmake @@ -0,0 +1,80 @@ +# This file will set the onnxruntime_target_platform variable, if applicable. +# onnxruntime_target_platform identifies the platform to compile for. 
+block(PROPAGATE onnxruntime_target_platform)
+
+unset(onnxruntime_target_platform)
+
+if (MSVC)
+  if (CMAKE_VS_PLATFORM_NAME)
+    # Multi-platform generator
+    set(onnxruntime_target_platform ${CMAKE_VS_PLATFORM_NAME})
+  else()
+    set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR})
+  endif()
+
+  if (onnxruntime_target_platform STREQUAL "ARM64" OR
+      onnxruntime_target_platform STREQUAL "ARM64EC")
+    # Do nothing. We'll just use the current value of onnxruntime_target_platform.
+  elseif (onnxruntime_target_platform STREQUAL "ARM" OR
+          CMAKE_GENERATOR MATCHES "ARM")
+    set(onnxruntime_target_platform "ARM")
+  elseif (onnxruntime_target_platform STREQUAL "x64" OR
+          onnxruntime_target_platform STREQUAL "x86_64" OR
+          onnxruntime_target_platform STREQUAL "AMD64" OR
+          CMAKE_GENERATOR MATCHES "Win64")
+    set(onnxruntime_target_platform "x64")
+  elseif (onnxruntime_target_platform STREQUAL "Win32" OR
+          onnxruntime_target_platform STREQUAL "x86" OR
+          onnxruntime_target_platform STREQUAL "i386" OR
+          onnxruntime_target_platform STREQUAL "i686")
+    set(onnxruntime_target_platform "x86")
+  else()
+    message(FATAL_ERROR "Unknown target platform: ${onnxruntime_target_platform}")
+  endif()
+elseif(APPLE)
+  if(DEFINED CMAKE_OSX_ARCHITECTURES)
+    # We'll only set onnxruntime_target_platform when CMAKE_OSX_ARCHITECTURES specifies a single architecture.
+    list(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_LEN)
+    if(CMAKE_OSX_ARCHITECTURES_LEN EQUAL 1)
+      set(onnxruntime_target_platform ${CMAKE_OSX_ARCHITECTURES})
+    endif()
+  else()
+    set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR})
+  endif()
+else()
+  #XXX: Sometimes the value of CMAKE_SYSTEM_PROCESSOR is set but it's wrong. For example, if you run an armv7 docker
+  #image on an aarch64 machine with an aarch64 Ubuntu host OS, in the docker instance cmake may still report
+  # CMAKE_SYSTEM_PROCESSOR as aarch64 by default. Given compiling this code may need more than 2GB memory, we do not
+  # support compiling for ARM32 natively(only support cross-compiling), we will ignore this issue for now.
+  if(NOT CMAKE_SYSTEM_PROCESSOR)
+    message(WARNING "CMAKE_SYSTEM_PROCESSOR is not set. Please set it in your toolchain cmake file.")
+    # Try to detect it
+    if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" OR "${CMAKE_C_COMPILER_ID}" STREQUAL "Clang")
+      execute_process(
+        COMMAND "${CMAKE_C_COMPILER}" -dumpmachine
+        OUTPUT_VARIABLE GCC_DUMP_MACHINE_OUT
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+        ERROR_VARIABLE _err
+        RESULT_VARIABLE _res
+      )
+      if(NOT _res EQUAL 0)
+        message(SEND_ERROR "Failed to run 'gcc -dumpmachine':\n ${_res}")
+      endif()
+      string(REPLACE "-" ";" GCC_DUMP_MACHINE_OUT_LIST "${GCC_DUMP_MACHINE_OUT}")
+      list(LENGTH GCC_DUMP_MACHINE_OUT_LIST GCC_TRIPLET_LEN)
+      if(GCC_TRIPLET_LEN EQUAL 4)
+        list(GET GCC_DUMP_MACHINE_OUT_LIST 0 CMAKE_SYSTEM_PROCESSOR)
+        message("Setting CMAKE_SYSTEM_PROCESSOR to ${CMAKE_SYSTEM_PROCESSOR}")
+      endif()
+    endif()
+  endif()
+  set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR})
+endif()
+
+if(DEFINED onnxruntime_target_platform)
+  message(STATUS "onnxruntime_target_platform = ${onnxruntime_target_platform}")
+else()
+  message(WARNING "onnxruntime_target_platform is not set")
+endif()
+
+endblock()
diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake
index 65b0d61270b75..44b794d9e2f78 100644
--- a/cmake/external/cutlass.cmake
+++ b/cmake/external/cutlass.cmake
@@ -4,6 +4,7 @@ onnxruntime_fetchcontent_declare(
   URL ${DEP_URL_cutlass}
   URL_HASH SHA1=${DEP_SHA1_cutlass}
   EXCLUDE_FROM_ALL
+  PATCH_COMMAND ${Patch_EXECUTABLE} --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass_4.2.1_maybe_unused.patch
 )
 
 FetchContent_GetProperties(cutlass)
diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 66c654e4a29e7..3b7c6a95ba98f 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -109,8 +109,6 @@ function(setup_mlas_source_for_windows)
     ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
     ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp
     ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
-    ${MLAS_SRC_DIR}/sconv_kernel_neon.cpp
-    ${MLAS_SRC_DIR}/spool_kernel_neon.cpp
   )
 
   set(mlas_platform_preprocess_srcs
@@ -134,7 +132,11 @@ function(setup_mlas_source_for_windows)
     ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDotLd64.asm
   )
 
-  if (onnxruntime_USE_KLEIDIAI)
+  if (onnxruntime_USE_ARM_NEON_NCHWC)
+    setup_arm_neon_nchwc()
+  endif()
+
+  if (onnxruntime_USE_KLEIDIAI)
     setup_kleidiai()
   endif()
 else()
@@ -289,6 +291,15 @@ function(setup_kleidiai)
   endif()
 endfunction()
 
+function (setup_arm_neon_nchwc)
+  target_sources(onnxruntime_mlas PRIVATE
+    ${MLAS_SRC_DIR}/sconv.h
+    ${MLAS_SRC_DIR}/sconv_kernel_neon.cpp
+    ${MLAS_SRC_DIR}/spool_kernel_neon.cpp
+  )
+  target_compile_definitions(onnxruntime_mlas PRIVATE MLAS_USE_ARM_NEON_NCHWC)
+endfunction ()
+
 if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
   if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD)
     file(GLOB_RECURSE mlas_platform_srcs
@@ -433,8 +444,6 @@ else()
     ${MLAS_SRC_DIR}/eltwise_kernel_neon.h
     ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp
     ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
-    ${MLAS_SRC_DIR}/sconv_kernel_neon.cpp
-    ${MLAS_SRC_DIR}/spool_kernel_neon.cpp
   )
 
   # Conditionally add the SVE implementation if compiler supports it
@@ -445,7 +454,11 @@ else()
     target_compile_definitions(onnxruntime_mlas PRIVATE MLAS_USE_SVE)
   endif()
 
-  if (onnxruntime_USE_KLEIDIAI)
+  if (onnxruntime_USE_ARM_NEON_NCHWC)
+    setup_arm_neon_nchwc()
+  endif()
+
+  if (onnxruntime_USE_KLEIDIAI)
     setup_kleidiai()
   endif()
   set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 
41983d63f6afe..60c82f2bd16c5 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1325,6 +1325,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) ${BENCHMARK_DIR}/layer_normalization.cc) target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${ONNXRUNTIME_ROOT}/core/mlas/inc) target_compile_definitions(onnxruntime_benchmark PRIVATE BENCHMARK_STATIC_DEFINE) + if (onnxruntime_USE_SVE) + target_compile_definitions(onnxruntime_benchmark PRIVATE MLAS_USE_SVE) + endif() if(WIN32) target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd4141>" "$<$>:/wd4141>") @@ -1352,6 +1355,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_include_directories(onnxruntime_mlas_benchmark PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc) target_link_libraries(onnxruntime_mlas_benchmark PRIVATE benchmark::benchmark onnxruntime_util ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS}) target_compile_definitions(onnxruntime_mlas_benchmark PRIVATE BENCHMARK_STATIC_DEFINE) + if (onnxruntime_USE_SVE) + target_compile_definitions(onnxruntime_mlas_benchmark PRIVATE MLAS_USE_SVE) + endif() if(WIN32) target_link_libraries(onnxruntime_mlas_benchmark PRIVATE debug Dbghelp) # Avoid using new and delete. But this is a benchmark program, it's ok if it has a chance to leak. @@ -1649,6 +1655,9 @@ endif() XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" ) endif() + if (onnxruntime_USE_SVE) + target_compile_definitions(onnxruntime_mlas_test PRIVATE MLAS_USE_SVE) + endif() target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}) target_link_libraries(onnxruntime_mlas_test PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) diff --git a/cmake/patches/cutlass/cutlass_4.2.1_maybe_unused.patch b/cmake/patches/cutlass/cutlass_4.2.1_maybe_unused.patch new file mode 100644 index 0000000000000..03d5972823839 --- /dev/null +++ b/cmake/patches/cutlass/cutlass_4.2.1_maybe_unused.patch @@ -0,0 +1,13 @@ +diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp +index cb161369..2fdff179 100644 +--- a/include/cute/layout.hpp ++++ b/include/cute/layout.hpp +@@ -1487,7 +1487,7 @@ nullspace(Layout const& layout) + [[maybe_unused]] auto flat_stride = flatten(layout.stride()); + + // Select all indices corresponding to stride-0s +- auto iseq = cute::fold(make_seq>{}, cute::tuple<>{}, ++ [[maybe_unused]] auto iseq = cute::fold(make_seq>{}, cute::tuple<>{}, + [&](auto init, auto i){ + if constexpr (is_constant_v<0, decltype(get(flat_stride))>) { return append(init, i); } + else { return init; } diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index e2b2aff2011fe..457bb23c42f21 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -225,7 +225,8 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi, bool get_symbolic_dims, /*out*/ ONNXTensorElementDataType& elem_type, /*out*/ std::vector& dims, - /*out*/ std::vector& symbolic_dims); + /*out*/ std::vector& symbolic_dims, + /*out*/ bool& has_shape); static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, onnx::ValueInfoProto& value_info_proto); static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr ort_attr, onnx::AttributeProto& attr_proto); @@ 
-390,9 +391,10 @@ Ort::Status OrtGraphToProto(const OrtGraph& graph, std::vector initializer_dims; std::vector initializer_sym_dims; ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + bool has_shape = false; ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(initializer_value_info, /*get_sym_dims*/ false, initializer_elem_type, initializer_dims, - initializer_sym_dims)); + initializer_sym_dims, has_shape)); onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); tensor_proto->set_name(initializer_name); @@ -493,28 +495,29 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi, bool get_symbolic_dims, /*out*/ ONNXTensorElementDataType& elem_type, /*out*/ std::vector& dims, - /*out*/ std::vector& symbolic_dims) { + /*out*/ std::vector& symbolic_dims, + /*out*/ bool& has_shape) { try { Ort::ConstTypeInfo ort_type_info = vi.TypeInfo(); ONNXType ort_onnx_type = ort_type_info.GetONNXType(); ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, "Expected OrtValueInfo to represent a Tensor"); Ort::ConstTensorTypeAndShapeInfo ort_type_shape = ort_type_info.GetTensorTypeAndShapeInfo(); - ONNXTensorElementDataType ort_elem_type = ort_type_shape.GetElementType(); + elem_type = ort_type_shape.GetElementType(); + has_shape = ort_type_shape.HasShape(); - size_t num_dims = ort_type_shape.GetDimensionsCount(); - std::vector ort_dims = ort_type_shape.GetShape(); + if (has_shape) { + const size_t num_dims = ort_type_shape.GetDimensionsCount(); + dims = ort_type_shape.GetShape(); - elem_type = ort_elem_type; - dims = std::move(ort_dims); + if (get_symbolic_dims) { + std::vector ort_dim_syms(num_dims, nullptr); + ort_type_shape.GetSymbolicDimensions(ort_dim_syms.data(), ort_dim_syms.size()); - if (get_symbolic_dims) { - std::vector ort_dim_syms(num_dims, nullptr); - ort_type_shape.GetSymbolicDimensions(ort_dim_syms.data(), ort_dim_syms.size()); - - symbolic_dims.reserve(num_dims); - for (const char* sym_dim : ort_dim_syms) { - symbolic_dims.push_back(sym_dim); + symbolic_dims.reserve(num_dims); + for (const char* sym_dim : ort_dim_syms) { + symbolic_dims.push_back(sym_dim); + } } } } catch (const Ort::Exception& ex) { @@ -533,17 +536,18 @@ static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; // We currently only support ONNX tensors. Support for other types (e.g., ONNX_TYPE_SEQUENCE) can be added later. + bool has_shape = false; ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true, - ort_elem_type, ort_dims, ort_dim_syms)); + ort_elem_type, ort_dims, ort_dim_syms, + has_shape)); value_info_proto.set_name(ort_value_info.GetName()); onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); type_proto_tensor->set_elem_type(ort_elem_type); - // If there are no dimensions in the shape, do not set a TensorShapeProto. Otherwise, it always looks - // like a scalar value. - if (!ort_dims.empty()) { + // If there is no shape, do not set a TensorShapeProto. 
+ if (has_shape) { onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 196eeae17b8d4..e6bbebdbf3ab8 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6580,6 +6580,13 @@ struct OrtApi { _In_ size_t byte_size, _Outptr_ OrtExternalInitializerInfo** out); /// @} + /** \brief Fetch whether the tensor has shape information. + * \param[in] info The OrtTensorTypeAndShapeInfo instance. + * \return true if the tensor has shape information, false otherwise. + * + * \since Version 1.23 + */ + ORT_API_T(bool, TensorTypeAndShape_HasShape, _In_ const OrtTensorTypeAndShapeInfo* info); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 3045648d17cd2..d3a8856455c49 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1782,6 +1782,7 @@ struct TensorTypeAndShapeInfoImpl : Base { void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions std::vector GetSymbolicDimensions() const; + bool HasShape() const; ///< Wraps OrtApi::TensorTypeAndShape_HasShape std::vector GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index cb6448ad12a81..8ee057f51eb20 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1982,7 +1982,12 @@ template inline size_t TensorTypeAndShapeInfoImpl::GetElementCount() const { size_t out; ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out)); - return static_cast(out); + return out; +} + +template +inline bool TensorTypeAndShapeInfoImpl::HasShape() const { + return GetApi().TensorTypeAndShape_HasShape(this->p_); } template @@ -2004,8 +2009,12 @@ inline void TensorTypeAndShapeInfoImpl::GetSymbolicDimensions(const char** va template inline std::vector TensorTypeAndShapeInfoImpl::GetSymbolicDimensions() const { - std::vector out(GetDimensionsCount(), nullptr); - ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, out.data(), out.size())); + std::vector out; + size_t dim_count = GetDimensionsCount(); + if (dim_count > 0) { + out.resize(dim_count, nullptr); + ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, out.data(), out.size())); + } return out; } diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 7eb5f7659a365..64a434e2fe301 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -408,3 +408,10 @@ static const char* const kOrtSessionOptionsDisableModelCompile = "session.disabl // Note: UNSUPPORTED models always fail regardless of this setting. 
static const char* const kOrtSessionOptionsFailOnSuboptimalCompiledModel = "session.fail_on_suboptimal_compiled_model"; + +// THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME +// Meant to be used with SetEpDynamicOptions +// options for HTP performance mode: "burst", "balanced", "default", "high_performance", +// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", +// "sustained_high_performance". Default to "default". +static const char* const kOrtEpDynamicOptionsQnnHtpPerformanceMode = "ep.dynamic.qnn_htp_performance_mode"; diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index 1c446840b7938..dd6c072ffb656 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -170,7 +170,7 @@ std::unique_ptr OrtTypeInfo::FromOrtValue(const OrtValue& value) { const Tensor& tensor = value.Get(); const auto* tensor_data_type = tensor.DataType(); if (tensor_data_type != nullptr) { - auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.Shape(), *tensor_data_type); + auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(&tensor.Shape(), *tensor_data_type); return MakePtr(ONNX_TYPE_TENSOR, std::move(type_shape)); } return MakePtr(ONNX_TYPE_TENSOR); @@ -181,7 +181,7 @@ std::unique_ptr OrtTypeInfo::FromOrtValue(const OrtValue& value) { const SparseTensor& tensor = value.Get(); const auto* tensor_data_type = tensor.DataType(); if (tensor_data_type != nullptr) { - auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(tensor.DenseShape(), *tensor_data_type); + auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(&tensor.DenseShape(), *tensor_data_type); return MakePtr(ONNX_TYPE_SPARSETENSOR, std::move(type_shape)); } return MakePtr(ONNX_TYPE_SPARSETENSOR); @@ -194,8 +194,7 @@ std::unique_ptr OrtTypeInfo::FromOrtValue(const OrtValue& value) { const auto* tensor_data_type = value.Get().DataType(); ORT_ENFORCE(tensor_data_type != nullptr, "OrtValue is TensorSequence type but has no element Tensor DataType."); - TensorShape void_shape = {}; - auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(void_shape, *tensor_data_type); + auto type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(nullptr, *tensor_data_type); auto type_info = MakePtr(ONNX_TYPE_TENSOR, std::move(type_shape)); auto sequence_type_info = std::make_unique(std::move(type_info)); return MakePtr(std::move(sequence_type_info)); @@ -303,9 +302,9 @@ std::unique_ptr OrtTypeInfo::FromTypeProto(const ONNX_NAMESPACE::Ty assert(false); } } - type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(std::move(shape_data), &dim_params, input); + type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(&shape_data, &dim_params, input); } else { - type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(TensorShape(), nullptr, input); + type_shape = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(nullptr, nullptr, input); } result = MakePtr(ten_type, std::move(type_shape)); diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index bef8df51f6d03..cbf1a953819d3 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -44,7 +44,7 @@ ORT_API(void, OrtApis::ReleaseTensorTypeAndShapeInfo, _Frees_ptr_opt_ OrtTensorT 
ORT_API_STATUS_IMPL(OrtApis::SetTensorElementType, _Inout_ OrtTensorTypeAndShapeInfo* this_ptr, enum ONNXTensorElementDataType type) { API_IMPL_BEGIN - this_ptr->type = type; + this_ptr->SetElementType(type); return nullptr; API_IMPL_END } @@ -56,10 +56,16 @@ ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* info, return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "dim_values must be -1 (symbolic dimension) or larger."); } - auto num_dims = std::max(dim_count, info->dim_params.size()); + OrtTensorTypeAndShapeInfo::ShapeInfo shape_info; + size_t num_dims = dim_count; + if (info->HasShape()) { + shape_info.dim_params = *info->GetDimParams(); + num_dims = std::max(num_dims, shape_info.dim_params.size()); + } // make shape and dim_values consistent - info->dim_params.resize(num_dims, ""); + // and preserve existing symbolic dimension names if any + shape_info.dim_params.resize(num_dims); onnxruntime::TensorShapeVector dims; dims.resize(num_dims, -1); @@ -68,7 +74,9 @@ ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* info, dims[idx] = dim_values[idx]; } - info->shape = onnxruntime::TensorShape(dims); + shape_info.shape = onnxruntime::TensorShape(dims); + + info->SetShape(std::move(shape_info)); return nullptr; API_IMPL_END @@ -76,58 +84,77 @@ ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* info, ORT_API_STATUS_IMPL(OrtApis::GetTensorElementType, _In_ const struct OrtTensorTypeAndShapeInfo* info, _Out_ ONNXTensorElementDataType* out) { - *out = info->type; + *out = info->GetElementType(); return nullptr; } ORT_API_STATUS_IMPL(OrtApis::GetDimensionsCount, _In_ const struct OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out) { - *out = info->shape.NumDimensions(); + *out = (info->HasShape()) ? info->GetShape()->NumDimensions() : 0; return nullptr; } ORT_API_STATUS_IMPL(OrtApis::GetDimensions, _In_ const struct OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) { - info->shape.CopyDims(dim_values, dim_values_length); + if (info->HasShape()) { + info->GetShape()->CopyDims(dim_values, dim_values_length); + } + // else we should probably return an error, but for backward compatibility with the previous implementation we don't. return nullptr; } ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, _In_ const struct OrtTensorTypeAndShapeInfo* info, _Out_writes_all_(dim_params_length) const char** names, size_t dim_params_length) { - for (size_t idx = 0, end = std::min(info->dim_params.size(), dim_params_length); idx < end; ++idx) { - names[idx] = info->dim_params[idx].c_str(); + if (info->HasShape()) { + size_t end = info->GetShape()->NumDimensions(); + end = std::min(end, dim_params_length); + const auto& symbolic_dims = *info->GetDimParams(); + for (size_t idx = 0; idx < end; ++idx) { + names[idx] = symbolic_dims[idx].c_str(); + } } - + // else we should probably return an error, but for backward compatibility with the previous implementation we don't. 
return nullptr; } +ORT_API(bool, OrtApis::TensorTypeAndShape_HasShape, _In_ const struct OrtTensorTypeAndShapeInfo* info) { + return info->HasShape(); +} + ORT_API_STATUS_IMPL(OrtApis::SetSymbolicDimensions, _In_ struct OrtTensorTypeAndShapeInfo* info, _In_ const char** names, _In_ size_t dim_params_length) { - auto num_dims = std::max(info->shape.NumDimensions(), dim_params_length); - - // make shape and dim_values consistent - if (num_dims > info->shape.NumDimensions()) { - auto dim_values = info->shape.AsShapeVector(); - dim_values.resize(num_dims, -1); - info->shape = onnxruntime::TensorShape(dim_values); + size_t num_dims = dim_params_length; + onnxruntime::TensorShapeVector shape_vec; + if (info->HasShape()) { + num_dims = std::max(num_dims, info->GetShape()->NumDimensions()); + if (num_dims > 0) { + shape_vec = info->GetShape()->AsShapeVector(); + } } - info->dim_params.clear(); - info->dim_params.resize(num_dims, ""); + OrtTensorTypeAndShapeInfo::ShapeInfo shape_info; + shape_vec.resize(num_dims, -1); + shape_info.shape = onnxruntime::TensorShape(shape_vec); + std::vector dim_params(num_dims); for (size_t idx = 0; idx < dim_params_length; ++idx) { - info->dim_params[idx] = names[idx]; + if (names[idx] != nullptr) { + dim_params[idx] = names[idx]; + } } + shape_info.dim_params = std::move(dim_params); + info->SetShape(std::move(shape_info)); + return nullptr; } ORT_API_STATUS_IMPL(OrtApis::GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* this_ptr, _Out_ size_t* out) { API_IMPL_BEGIN - *out = SafeInt{this_ptr->shape.Size()}; + *out = SafeInt{(this_ptr->HasShape()) ? this_ptr->GetShape()->Size() : 0}; return nullptr; API_IMPL_END } @@ -220,33 +247,47 @@ ONNXTensorElementDataType MLDataTypeToOnnxRuntimeTensorElementDataType( std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper( ONNXTensorElementDataType type, - onnxruntime::TensorShape shape, + const onnxruntime::TensorShape* shape, const std::vector* dim_params) { auto type_and_shape = std::make_unique(); - type_and_shape->type = type; - type_and_shape->shape = std::move(shape); + type_and_shape->SetElementType(type); + + if (shape == nullptr && dim_params == nullptr) { + return type_and_shape; + } + + ShapeInfo shape_info; + size_t num_dims = (shape != nullptr) ? shape->NumDimensions() : 0; + num_dims = std::max(num_dims, (dim_params != nullptr) ? 
dim_params->size() : 0); + + onnxruntime::TensorShapeVector shape_vec; + if (shape != nullptr) { + shape_vec = shape->AsShapeVector(); + } + shape_vec.resize(num_dims, -1); + shape_info.shape = onnxruntime::TensorShape(shape_vec); if (dim_params != nullptr) { - type_and_shape->dim_params = *dim_params; - } else { - type_and_shape->dim_params.resize(type_and_shape->shape.NumDimensions(), ""); + shape_info.dim_params = *dim_params; } + shape_info.dim_params.resize(num_dims); + type_and_shape->SetShape(std::move(shape_info)); return type_and_shape; } std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndType( - onnxruntime::TensorShape shape, + const onnxruntime::TensorShape* shape, const onnxruntime::DataTypeImpl& tensor_data_type) { ONNXTensorElementDataType type = MLDataTypeToOnnxRuntimeTensorElementDataType(&tensor_data_type); if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) { ORT_NOT_IMPLEMENTED("Tensor type is undefined"); } - return GetTensorShapeAndTypeHelper(type, std::move(shape), nullptr); + return GetTensorShapeAndTypeHelper(type, shape, nullptr); } std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorShapeAndType( - onnxruntime::TensorShape shape, + const onnxruntime::TensorShape* shape, const std::vector* dim_params, const ONNX_NAMESPACE::TypeProto& type_proto) { auto value_case = type_proto.value_case(); @@ -259,7 +300,8 @@ std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorS if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) { ORT_NOT_IMPLEMENTED("Tensor type is undefined"); } - return GetTensorShapeAndTypeHelper(type, std::move(shape), dim_params); + + return GetTensorShapeAndTypeHelper(type, shape, dim_params); } ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, @@ -276,14 +318,14 @@ ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, const Tensor& tensor = v->Get(); shape = &tensor.Shape(); data_type = tensor.DataType(); - auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(shape, *data_type); *out = ptr.release(); } else { #if !defined(DISABLE_SPARSE_TENSORS) const SparseTensor& tensor = v->Get(); shape = &tensor.DenseShape(); data_type = tensor.DataType(); - auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(*shape, *data_type); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(shape, *data_type); *out = ptr.release(); #else ORT_NOT_IMPLEMENTED("SparseTensor is not supported in this build."); @@ -302,7 +344,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorValuesTypeAndShape, _In_ const OrtVa #if !defined(DISABLE_SPARSE_TENSORS) const auto& sparse_tensor = SparseTensor::GetSparseTensorFromOrtValue(*v); const auto& values = sparse_tensor.Values(); - auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(values.Shape(), *values.DataType()); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(&values.Shape(), *values.DataType()); *out = ptr.release(); return nullptr; #else @@ -344,7 +386,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetSparseTensorIndicesTypeShape, _In_ const OrtValu API_IMPL_BEGIN #if !defined(DISABLE_SPARSE_TENSORS) const Tensor& indices_tensor = GetIndicesTensor(*v, indices_format); - auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(indices_tensor.Shape(), *indices_tensor.DataType()); + auto ptr = OrtTensorTypeAndShapeInfo::GetTensorShapeAndType(&indices_tensor.Shape(), *indices_tensor.DataType()); *out = ptr.release(); return nullptr; #else diff --git a/onnxruntime/core/framework/tensor_type_and_shape.h 
b/onnxruntime/core/framework/tensor_type_and_shape.h index 4bc1f46c00132..eae5e1a8319af 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.h +++ b/onnxruntime/core/framework/tensor_type_and_shape.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include #include @@ -19,11 +20,36 @@ class DataTypeImpl; struct OrtTensorTypeAndShapeInfo { public: - ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - onnxruntime::TensorShape shape; - // dim_param values. empty string if dim_value or no dim_param was specified. - // one entry per dimension in shape. only guaranteed to be populated for graph inputs and outputs - std::vector dim_params; + void SetElementType(ONNXTensorElementDataType element_type) noexcept { + this->type = element_type; + } + + ONNXTensorElementDataType GetElementType() const noexcept { + return type; + } + + struct ShapeInfo { + onnxruntime::TensorShape shape; + // dim_param values. empty string if dim_value or no dim_param was specified. + // one entry per dimension in shape. only guaranteed to be populated for graph inputs and outputs + std::vector dim_params; + }; + + bool HasShape() const noexcept { + return shape_info.has_value(); + } + + const onnxruntime::TensorShape* GetShape() const noexcept { + return shape_info ? &shape_info->shape : nullptr; + } + + const std::vector* GetDimParams() const noexcept { + return shape_info ? &shape_info->dim_params : nullptr; + } + + void SetShape(ShapeInfo shape) { + this->shape_info = std::move(shape); + } OrtTensorTypeAndShapeInfo(); ~OrtTensorTypeAndShapeInfo(); @@ -31,15 +57,15 @@ struct OrtTensorTypeAndShapeInfo { // Utils static std::unique_ptr GetTensorShapeAndTypeHelper( ONNXTensorElementDataType type, - onnxruntime::TensorShape shape, + const onnxruntime::TensorShape* shape, const std::vector* dim_params); static std::unique_ptr GetTensorShapeAndType( - onnxruntime::TensorShape shape, + const onnxruntime::TensorShape* shape, const onnxruntime::DataTypeImpl& tensor_data_type); static std::unique_ptr GetTensorShapeAndType( - onnxruntime::TensorShape shape, + const onnxruntime::TensorShape* shape, const std::vector* dim_params, const ONNX_NAMESPACE::TypeProto&); @@ -54,6 +80,10 @@ struct OrtTensorTypeAndShapeInfo { // Copy ops are public because std::make_unique above requires them to be accessible OrtTensorTypeAndShapeInfo(const OrtTensorTypeAndShapeInfo& other); OrtTensorTypeAndShapeInfo& operator=(const OrtTensorTypeAndShapeInfo& other); + + private: + ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + std::optional shape_info; }; constexpr ONNXTensorElementDataType TensorDataTypeToOnnxRuntimeTensorElementDataType(int32_t dtype); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 9a97711996343..3f6443aa73d4c 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -6190,21 +6190,25 @@ ValueInfoProto ModelEditorValueInfoToOnnx(const onnxruntime::ModelEditorValueInf value_info_proto.set_name(vi.name); auto* tensor = value_info_proto.mutable_type()->mutable_tensor_type(); - const OrtTensorTypeAndShapeInfo& tensor_info = *vi.type_info->tensor_type_info.get(); - tensor->set_elem_type(tensor_info.type); - - auto& shape = *tensor->mutable_shape(); - - size_t idx = 0; - for (auto dim : tensor_info.shape.GetDims()) { - auto& dim_proto = *shape.add_dim(); - if (dim >= 0) { - dim_proto.set_dim_value(dim); - } else { - const std::string& dim_param = tensor_info.dim_params[idx]; - // if empty leave the new dim_proto with 
neither dim_value nor dim_param set. this represents an 'unknown' dim - if (!dim_param.empty()) { - dim_proto.set_dim_param(dim_param); + const OrtTensorTypeAndShapeInfo& tensor_info = *vi.type_info->tensor_type_info; + tensor->set_elem_type(tensor_info.GetElementType()); + + if (tensor_info.HasShape()) { + auto& shape = *tensor->mutable_shape(); + + size_t idx = 0; + const auto dims = tensor_info.GetShape()->GetDims(); + const auto& dim_params = *tensor_info.GetDimParams(); + for (auto dim : dims) { + auto& dim_proto = *shape.add_dim(); + if (dim >= 0) { + dim_proto.set_dim_value(dim); + } else { + const std::string& dim_param = dim_params[idx]; + // if empty leave the new dim_proto with neither dim_value nor dim_param set. this represents an 'unknown' dim + if (!dim_param.empty()) { + dim_proto.set_dim_param(dim_param); + } } } } diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 2e93584095343..8ed6352e7baa7 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -953,7 +953,7 @@ extern "C" { MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero; MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd; #endif -#if defined(MLAS_TARGET_ARM64) +#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeon; MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon; MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon; @@ -1347,12 +1347,14 @@ struct MLAS_PLATFORM { const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmS8S8Dispatch; +#if defined(MLAS_USE_ARM_NEON_NCHWC) MLAS_CONV_FLOAT_KERNEL* ConvNchwFloatKernel; MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel; MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel; MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel; MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount]; uint32_t NchwcBlockSize; +#endif #endif const MLAS_SYMM_QGEMM_DISPATCH* SymmQgemmDispatch{nullptr}; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 81067095401e7..46fa150395d75 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -560,6 +560,7 @@ Return Value: this->SoftmaxDispatch = &MlasSoftmaxDispatchNeon; this->EltwiseDispatch = &MlasEltwiseDispatchNeon; +#if defined(MLAS_USE_ARM_NEON_NCHWC) this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeon; this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon; this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon; @@ -568,6 +569,7 @@ Return Value: this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelNeon; this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelNeon; this->NchwcBlockSize = MLAS_NEON_NCHWC_BLOCK_SIZE; +#endif // // Check if the processor supports ASIMD dot product instructions. diff --git a/onnxruntime/core/mlas/lib/sconv.h b/onnxruntime/core/mlas/lib/sconv.h index 94e657638975a..12ccff2b7ea33 100644 --- a/onnxruntime/core/mlas/lib/sconv.h +++ b/onnxruntime/core/mlas/lib/sconv.h @@ -19,7 +19,11 @@ Module Name: // Define the convolution kernel flags. 
// +#if defined(MLAS_USE_ARM_NEON_NCHWC) + #define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001 #define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002 #define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004 -#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008 \ No newline at end of file +#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008 + +#endif diff --git a/onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp index 3ecad66a32886..4c5f50adb929c 100644 --- a/onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp @@ -14,6 +14,8 @@ Module Name: --*/ +#if defined(MLAS_USE_ARM_NEON_NCHWC) + #include "mlasi.h" #include "sconv.h" @@ -58,7 +60,7 @@ void const size_t InputWidthElements = InputWidth / sizeof(float); const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); - (void)InputStride; + MLAS_UNREFERENCED_PARAMETER(InputStride); const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; @@ -100,7 +102,7 @@ void const float* input_base = Input + output_idx * StrideWidthElements + kh * DilatedInputWidthElements + kw * DilationWidthElements; - if (IsNchwcFormat) { + if constexpr (IsNchwcFormat) { for (size_t filterBlock = 0; filterBlock < BlockSize; filterBlock++) { const float* input_element = input_base + filterBlock; const float* input_row_start = InputBase + kh * DilatedInputWidthElements; @@ -343,7 +345,7 @@ void const size_t InputStrideElements = InputStride / sizeof(float); const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); - (void)InputStrideElements; + MLAS_UNREFERENCED_PARAMETER(InputStrideElements); const size_t InputWidthElements = InputWidth / sizeof(float); @@ -518,3 +520,5 @@ void } } } + +#endif diff --git a/onnxruntime/core/mlas/lib/snchwc.cpp b/onnxruntime/core/mlas/lib/snchwc.cpp index 2fc27d6d4ad7f..6f3423a792509 100644 --- a/onnxruntime/core/mlas/lib/snchwc.cpp +++ b/onnxruntime/core/mlas/lib/snchwc.cpp @@ -101,7 +101,7 @@ Return Value: --*/ { -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) return GetMlasPlatform().NchwcBlockSize; #else return 1; @@ -674,7 +674,7 @@ struct MLAS_NCHWC_CONV_NCHWC_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwcFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwcFloatKernel; @@ -784,7 +784,7 @@ struct MLAS_NCHWC_CONV_NCHW_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwFloatKernel; @@ -879,7 +879,7 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t FilterStrideBytes = BlockSize * InputChannels * 
sizeof(float); const size_t OutputStrideBytes = BlockSize * OutputSize * sizeof(float); -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel; #else MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = MlasConvPointwiseFloatKernel; @@ -1016,7 +1016,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvDepthwiseFloatKernel; #else MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = MlasConvDepthwiseFloatKernel; @@ -1093,7 +1093,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM { -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !defined(MLAS_TARGET_ARM64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !(defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) static MLAS_POOL_FLOAT_KERNEL* const PoolKernels[]; #endif @@ -1131,7 +1131,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM const size_t DilatedInputWidthBytes = BlockSize * DilationHeight * InputWidth * sizeof(float); const size_t InputStrideBytes = DilatedInputWidthBytes - KernelWidth * DilationWidthBytes; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_POOL_FLOAT_KERNEL* Kernel = GetMlasPlatform().PoolFloatKernel[WorkBlock->PoolingKind]; #else MLAS_POOL_FLOAT_KERNEL* Kernel = PoolKernels[WorkBlock->PoolingKind]; @@ -1197,7 +1197,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM } }; -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !defined(MLAS_TARGET_ARM64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !(defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_POOL_FLOAT_KERNEL* const MLAS_NCHWC_POOL_ALGORITHM::PoolKernels[] = { @@ -1621,7 +1621,7 @@ Return Value: } } -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !defined(MLAS_TARGET_ARM64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !(defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) // // Convolution and pooling kernel stubs for architectures that do not yet have diff --git a/onnxruntime/core/mlas/lib/spool_kernel_neon.cpp b/onnxruntime/core/mlas/lib/spool_kernel_neon.cpp index 8cca036d54c3a..588362584791b 100644 --- a/onnxruntime/core/mlas/lib/spool_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/spool_kernel_neon.cpp @@ -14,6 +14,8 @@ Module Name: --*/ +#if defined(MLAS_USE_ARM_NEON_NCHWC) + #include "mlasi.h" constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE; @@ -287,3 +289,5 @@ void false // ExcludePad = false ); } + +#endif diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index f094a48e10c33..3d1e5ccfdc4d5 100644 --- 
a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -609,6 +609,10 @@ void NchwcTransformerImpl::TransformBinary(Node& node, bool add_node) { nchwc_inputs.push_back(nchwc_input); } + if (nchwc_inputs.empty()) { + return; + } + auto* nchwc_input_0 = nchwc_inputs[0]; const int64_t channels = nchwc_inputs[0]->channels_; diff --git a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc index 4efaec325292a..c6898ca1da2e5 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc @@ -4,6 +4,7 @@ #include "core/optimizer/qdq_transformer/weight_bias_quantization.h" #include "core/common/common.h" +#include "core/providers/common.h" #include "core/util/qmath.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" @@ -128,6 +129,14 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph int64_t axis = 1; if (auto axis_iter = dq_attrs.find("axis"); axis_iter != dq_attrs.end()) { axis = axis_iter->second.i(); + const ONNX_NAMESPACE::TensorShapeProto* weight_shape = weight_arg->Shape(); + if (!weight_shape && dq_1->InputDefs()[0]) { + weight_shape = dq_1->InputDefs()[0]->Shape(); + } + if (axis < 0 && !weight_shape) { + continue; + } + axis = HandleNegativeAxis(axis, weight_shape->dim_size()); } int64_t expected_axis = 0; diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index 9d49c16391f78..46311303639ab 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -16,7 +16,7 @@ namespace onnxruntime { static ProviderTypeToProviderMap GetProvidersByType( const InlinedVector>& providers) { ProviderTypeToProviderMap providers_by_type{}; - for (const auto provider : providers) { + for (const auto& provider : providers) { providers_by_type.emplace(provider->Type(), provider); } return providers_by_type; @@ -100,7 +100,7 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st // and mainly provides the subgraph recursion functionality Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { - for (const auto provider : providers_) { + for (const auto& provider : providers_) { const auto& provider_type = provider->Type(); if (!utils::ProviderIsCpuBased(*provider)) { TransformerMemcpyImpl copy_impl(graph, *provider, providers_by_type_); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index b404a8924745b..2bdbfb9c1c62e 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -1576,6 +1576,17 @@ Status QNNExecutionProvider::SetEpDynamicOptions(gsl::span ke LOGS_DEFAULT(ERROR) << "Invalid EP Workload Type: " << value; return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid EP Workload Type."); } + } else if (key == kOrtEpDynamicOptionsQnnHtpPerformanceMode) { + auto backend_type = qnn_backend_manager_->GetQnnBackendType(); + if (qnn::QnnBackendType::HTP != backend_type && qnn::QnnBackendType::DSP != backend_type) { + return Status::OK(); + } + qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; + 
ParseHtpPerformanceMode(value, htp_performance_mode); + if (GetPerThreadContext().IsHtpPowerConfigIdValid()) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), + htp_performance_mode)); + } } else { LOGS_DEFAULT(ERROR) << "EP Dynamic Option \"" << key << "\" is not currently supported."; return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported EP Dynamic Option"); diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 142d64caa64aa..9056ea6d86f65 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -226,13 +226,27 @@ bool AreDataTypesSame(const std::string_view op_type, return true; } -bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) { +bool IsSupportedDataType(const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string_view webnn_op_type, + const std::string_view webnn_input_output_name) { auto it = onnx_to_webnn_data_type_map.find(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(onnx_data_type)); if (it == onnx_to_webnn_data_type_map.end()) return false; const std::string_view webnn_data_type = it->second; + // MLOpSupportLimits has a different structure per entry: regular WebNN ops have named inputs and outputs, + // while special cases like 'constant', 'input' and 'output' have no input or output name. + emscripten::val webnn_supported_data_types = + webnn_input_output_name.empty() + ? wnn_limits[std::string(webnn_op_type)]["dataTypes"] + : wnn_limits[std::string(webnn_op_type)][std::string(webnn_input_output_name)]["dataTypes"]; + + if (webnn_supported_data_types.isUndefined()) { + return false; + } + // Check if WebNN supports the data type. bool is_supported = webnn_supported_data_types.call<emscripten::val>("includes", emscripten::val(std::string(webnn_data_type))) @@ -240,7 +254,8 @@ bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& we if (webnn_data_type == "int64" && !is_supported && - webnn_supported_data_types.call<emscripten::val>("includes", emscripten::val("int32")).as<bool>() + webnn_supported_data_types.call<emscripten::val>("includes", emscripten::val("int32")).as<bool>() && + !wnn_limits["constant"]["dataTypes"].call<emscripten::val>("includes", emscripten::val("int64")).as<bool>()) { // Current context doesn't support int64, but int32 is supported. // We can use int32 as a workaround.
is_supported = true; @@ -280,8 +295,7 @@ bool IsDataTypeSupportedByWebNNOp(const std::string_view onnx_op_type, << webnn_input_output_name << "]"; return false; } - if (!IsSupportedDataType( - onnx_data_type, wnn_limits[std::string(webnn_op_type)][std::string(webnn_input_output_name)]["dataTypes"])) { + if (!IsSupportedDataType(onnx_data_type, wnn_limits, webnn_op_type, webnn_input_output_name)) { LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] " << onnx_input_output_name << "'s data type: [" << onnx_data_type << "] is not supported by WebNN op [" << webnn_op_type << "] for now"; return false; diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index da1fac6d1ad05..baedb98a34c28 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -268,7 +268,10 @@ inline bool GetWebNNOpInputs(const std::string_view onnx_op_type, bool AreDataTypesSame(const std::string_view op_type, gsl::span input_types, const logging::Logger& logger); -bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types); +bool IsSupportedDataType(const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string_view webnn_op_type, + const std::string_view webnn_input_output_name); bool IsDataTypeSupportedByOp(const std::string_view onnx_op_type, const int32_t onnx_data_type, const emscripten::val& wnn_limits, diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index 95e75a3083cc2..01b3353b5c908 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -237,15 +237,15 @@ bool GruOpBuilder::HasSupportedOutputsImpl(const Node& node, bool Y_h_supported = has_Y_h && GetType(*output_defs[1], Y_h_type, logger); if (Y_supported && !Y_h_supported) { - return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger); + return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "output1", "Y", logger); } else if (!Y_supported && Y_h_supported) { - return IsDataTypeSupportedByOp(op_type, Y_h_type, wnn_limits, "outputs", "Y_h", logger); + return IsDataTypeSupportedByOp(op_type, Y_h_type, wnn_limits, "output0", "Y_h", logger); } else if (Y_supported && Y_h_supported) { if (Y_type != Y_h_type) { LOGS(logger, VERBOSE) << "[GRU] Output data types must be the same."; return false; } - return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger); + return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "output1", "Y", logger); } else { LOGS(logger, VERBOSE) << "[GRU] No output found."; return false; diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc index 04d59e2f30d15..97d611fe4b817 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -259,13 +259,13 @@ bool LstmOpBuilder::HasSupportedOutputsImpl(const Node& node, bool has_Y_c = TensorExists(output_defs, 2); if (has_Y && GetType(*output_defs[0], Y_type, logger)) { - return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "outputs", "Y", logger); + return IsDataTypeSupportedByOp(op_type, Y_type, wnn_limits, "output2", "Y", logger); } if (has_Y_h && 
GetType(*output_defs[1], Y_h_type, logger)) { - return IsDataTypeSupportedByOp(op_type, Y_h_type, wnn_limits, "outputs", "Y_h", logger); + return IsDataTypeSupportedByOp(op_type, Y_h_type, wnn_limits, "output0", "Y_h", logger); } if (has_Y_c && GetType(*output_defs[2], Y_c_type, logger)) { - return IsDataTypeSupportedByOp(op_type, Y_c_type, wnn_limits, "outputs", "Y_c", logger); + return IsDataTypeSupportedByOp(op_type, Y_c_type, wnn_limits, "output1", "Y_c", logger); } return false; diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 5ee4a9daa1407..d12806cbcfbb1 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -229,7 +229,7 @@ Status ModelBuilder::RegisterInitializers() { desc.set("shape", emscripten::val::array(dims)); const auto data_type = tensor.data_type(); emscripten::val operand = emscripten::val::object(); - if (IsSupportedDataType(data_type, wnn_limits_["constant"]["dataTypes"])) { + if (IsSupportedDataType(data_type, wnn_limits_, "constant", "")) { ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "WebNN backend does not support data type: ", data_type); ORT_RETURN_IF_ERROR(RegisterConstant(tensor, operand, desc, logger_)); } else { diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 2a898a2b0bf9f..00f5017a55db0 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -80,7 +80,7 @@ struct OrtShapeInferContext { auto tensor_shape = ::onnxruntime::utils::GetTensorShapeFromTensorShapeProto(shape_proto); auto symbolic_dims = GetSymbolicDims(shape_proto); input_type_shapes_.emplace_back( - OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, tensor_shape, &symbolic_dims).release()); + OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, &tensor_shape, &symbolic_dims)); } } @@ -94,19 +94,21 @@ struct OrtShapeInferContext { onnxruntime::Status SetOutputTypeShape(size_t index, const OrtTensorTypeAndShapeInfo* info) const { ORT_RETURN_IF_NOT(info, "Invalid shape info"); ONNX_NAMESPACE::TensorShapeProto shape_proto; - const auto& symbolic_dims = info->dim_params; - const auto& integer_dims = info->shape.GetDims(); - ORT_RETURN_IF_NOT(symbolic_dims.size() == integer_dims.size(), "symbolic and integer dims mismatch!"); - for (size_t ith = 0; ith < symbolic_dims.size(); ith++) { - auto* dim_proto = shape_proto.add_dim(); - if (symbolic_dims[ith].size() > 0) { - dim_proto->set_dim_param(symbolic_dims[ith]); - } else { - dim_proto->set_dim_value(integer_dims[ith]); + if (info->HasShape()) { + const auto& symbolic_dims = *info->GetDimParams(); + const auto integer_dims = info->GetShape()->GetDims(); + ORT_RETURN_IF_NOT(symbolic_dims.size() == integer_dims.size(), "symbolic and integer dims mismatch!"); + for (size_t ith = 0, end = symbolic_dims.size(); ith < end; ith++) { + auto* dim_proto = shape_proto.add_dim(); + if (symbolic_dims[ith].size() > 0) { + dim_proto->set_dim_param(symbolic_dims[ith]); + } else { + dim_proto->set_dim_value(integer_dims[ith]); + } } } ONNX_NAMESPACE::updateOutputShape(ctx_, index, shape_proto); - ONNX_NAMESPACE::updateOutputElemType(ctx_, index, info->type); + ONNX_NAMESPACE::updateOutputElemType(ctx_, index, info->GetElementType()); return onnxruntime::Status::OK(); } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc 
b/onnxruntime/core/session/onnxruntime_c_api.cc index d0fe6291c2e03..9b258d0983570 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4228,6 +4228,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Graph_GetModelMetadata, &OrtApis::GetModelCompatibilityForEpDevices, &OrtApis::CreateExternalInitializerInfo, + &OrtApis::TensorTypeAndShape_HasShape, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 78616c7b3973e..f016bb3215330 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -150,6 +150,7 @@ ORT_API_STATUS_IMPL(GetDimensionsCount, _In_ const OrtTensorTypeAndShapeInfo* in ORT_API_STATUS_IMPL(GetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length); ORT_API_STATUS_IMPL(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_writes_all_(dim_params_length) const char* dim_params[], size_t dim_params_length); +ORT_API(bool, TensorTypeAndShape_HasShape, _In_ const OrtTensorTypeAndShapeInfo* info); ORT_API_STATUS_IMPL(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); ORT_API_STATUS_IMPL(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out); ORT_API_STATUS_IMPL(GetTypeInfo, _In_ const OrtValue* value, _Outptr_result_maybenull_ OrtTypeInfo** out); diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index c8829423fbe26..55245420db37a 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -3,6 +3,7 @@ #include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" +#include #include #include #include @@ -117,6 +118,17 @@ static OrtDevice GetOrtDeviceForPluginEp(gsl::span ep_ return device_memory_info != nullptr ? device_memory_info->device : OrtDevice(); } +static const Node* FindFirstNodeAssignedToOtherEP(const std::string& ep_type, + gsl::span ep_nodes) { + auto node_iter = std::find_if(ep_nodes.begin(), ep_nodes.end(), + [&ep_type](const EpNode* node) -> bool { + const auto& node_ep_type = node->GetInternalNode().GetExecutionProviderType(); + return !node_ep_type.empty() && node_ep_type != ep_type; + }); + + return node_iter != ep_nodes.end() ? &(*node_iter)->GetInternalNode() : nullptr; +} + PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, OrtEpFactory& ep_factory, gsl::span ep_devices, @@ -158,9 +170,11 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie ORT_UNUSED_PARAMETER(resource_accountant); // TODO: Add support? Not used by prioritized EPs ORT_UNUSED_PARAMETER(kernel_lookup); // TODO: Add support? Not used by prioritized EPs, so probably not needed? + const logging::Logger& logger = GetLogger() != nullptr ? 
*GetLogger() : logging::LoggingManager::DefaultLogger(); + std::unique_ptr<EpGraph> ep_graph = nullptr; if (Status status = EpGraph::Create(graph_viewer, ep_graph); !status.IsOK()) { - LOGS_DEFAULT(ERROR) << "Failed to create OrtGraph: " << status.ToString(); + LOGS(logger, ERROR) << "Failed to create OrtGraph for " << Type() << ": " << status.ToString(); return {}; } @@ -168,7 +182,7 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie Status status = ToStatusAndRelease(ort_ep_->GetCapability(ort_ep_.get(), ep_graph->ToExternal(), &api_graph_support_info)); if (!status.IsOK()) { - LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() failed with error: " << status.ToString(); + LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " failed with error: " << status.ToString(); return {}; } @@ -182,12 +196,39 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // Create ComputeCapability instances from OrtEpGraphSupportInfo::NodeGrouping instances. for (const OrtEpGraphSupportInfo::NodeGrouping& node_grouping : api_graph_support_info.node_groupings) { + // Skip this node grouping if any node has already been assigned to another EP. + if (const Node* node_for_other_ep = FindFirstNodeAssignedToOtherEP(Type(), node_grouping.nodes); + node_for_other_ep != nullptr) { + LOGS(logger, WARNING) << "OrtEp::GetCapability() specified nodes that cannot be assigned to " << Type() << ". " + << "Found one or more nodes that were already assigned to a different EP named '" + << node_for_other_ep->GetExecutionProviderType() << "'. Ex: " + << node_for_other_ep->OpType() << " node with name '" + << node_for_other_ep->Name() << "'."; + continue; + } + if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kSingleAssignedNode) { + if (node_grouping.nodes.size() != 1) { + // The EpGraphSupportInfo_AddSingleNode() C API should already return an error if the EP tries to provide + // an invalid node. However, we check here too just in case this changes. + LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " did not specify exactly one valid node " + << "when calling EpGraphSupportInfo_AddSingleNode()."; + return {}; + } + auto indexed_sub_graph = std::make_unique<IndexedSubGraph>(); indexed_sub_graph->nodes.push_back(node_grouping.nodes[0]->GetInternalNode().Index()); result.push_back(std::make_unique<ComputeCapability>(std::move(indexed_sub_graph))); } else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) { + if (node_grouping.nodes.empty()) { + // The EpGraphSupportInfo_AddNodesToFuse() C API should already return an error if the EP tries to provide + // an empty array of nodes from OrtEp::GetCapability(). However, we check here too just in case this changes. + LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " set an empty array of nodes " + << "when specifying supported nodes."; + return {}; + } + std::unordered_set<const Node*> node_set; node_set.reserve(node_grouping.nodes.size()); @@ -207,27 +248,29 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie this->Type(), this->Type(), /*node_unit_map*/ nullptr, node_grouping.fusion_options.drop_constant_initializers); - if (capabilities.size() > 1) { - LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() set nodes that cannot be fused together. 
" - << "Please ensure that the nodes provided to EpGraphSupportInfo_AddFusedNodes() do not " + if (capabilities.size() != 1) { + LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " set nodes that cannot be fused together. " + << "Please ensure that the nodes provided to EpGraphSupportInfo_AddNodesToFuse() do not " << "have an unsupported node in any path between two of the supported nodes."; return {}; } - // Enforce that the nodes in node_set match the nodes in capabilities[0] + // Log an error if the nodes in node_set do not match the nodes in capabilities[0]. We expect this to always + // be true because we've already checked that the EP did not try to claim nodes already assigned to another EP. // TODO(adrianlizarraga): This check can be removed when we stop using utils::CreateSupportedPartitions() above. std::vector& capability_node_indices = capabilities[0]->sub_graph->nodes; std::unordered_set capability_node_indices_set(capability_node_indices.begin(), capability_node_indices.end()); - ORT_ENFORCE(node_set.size() == capability_node_indices_set.size()); - ORT_ENFORCE(std::all_of(node_set.begin(), node_set.end(), [&capability_node_indices_set](const Node* node) { - return capability_node_indices_set.count(node->Index()) != 0; - })); + if (node_set.size() != capability_node_indices_set.size()) { + LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() + << " set nodes that cannot all be fused together."; + return {}; + } result.push_back(std::move(capabilities[0])); } else { - LOGS_DEFAULT(ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: " + LOGS(logger, ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: " << static_cast(node_grouping.kind); return {}; } diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 7e6d157799d86..a1eda6d6057ba 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -581,40 +581,39 @@ TEST(EpGraphTest, SerializeToProto_3LayerSubgraphs) { // Checks that the OrtTypeInfo obtained from the public C API matches another OrtTypeInfo // obtained from the internal ORT graph IR. -static void CheckTypeInfo(const OrtTypeInfo* api_type_info, const OrtTypeInfo* type_info) { - const OrtApi& ort_api = Ort::GetApi(); +static void CheckTypeInfo(const OrtTypeInfo* ort_api_type_info, const OrtTypeInfo* ort_type_info) { + ASSERT_NE(ort_api_type_info, nullptr); + ASSERT_NE(ort_type_info, nullptr); - ASSERT_NE(api_type_info, nullptr); - ASSERT_NE(type_info, nullptr); + Ort::ConstTypeInfo api_type_info{ort_api_type_info}; + Ort::ConstTypeInfo type_info{ort_type_info}; - ONNXType api_onnx_type = ONNX_TYPE_UNKNOWN; - ASSERT_ORTSTATUS_OK(ort_api.GetOnnxTypeFromTypeInfo(api_type_info, &api_onnx_type)); - ASSERT_EQ(api_onnx_type, type_info->type); + ASSERT_EQ(api_type_info.GetONNXType(), type_info.GetONNXType()); - if (api_onnx_type == ONNX_TYPE_TENSOR) { + if (api_type_info.GetONNXType() == ONNX_TYPE_TENSOR) { // Only validating Tensors (not checking Map, Sequence, etc.) values because these C APIs for getting // type/shape information existed long before the new ORT graph IR APIs and are tested elsewhere. 
- const OrtTensorTypeAndShapeInfo* api_type_shape = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.CastTypeInfoToTensorInfo(api_type_info, &api_type_shape)); + auto api_type_shape = api_type_info.GetTensorTypeAndShapeInfo(); + auto type_info_shape = type_info.GetTensorTypeAndShapeInfo(); ASSERT_NE(api_type_shape, nullptr); - ONNXTensorElementDataType api_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ASSERT_ORTSTATUS_OK(ort_api.GetTensorElementType(api_type_shape, &api_elem_type)); - ASSERT_EQ(api_elem_type, type_info->tensor_type_info->type); + ONNXTensorElementDataType api_elem_type = api_type_shape.GetElementType(); + ASSERT_EQ(api_elem_type, type_info_shape.GetElementType()); - size_t api_num_dims = 0; - ASSERT_ORTSTATUS_OK(ort_api.GetDimensionsCount(api_type_shape, &api_num_dims)); - ASSERT_EQ(api_num_dims, type_info->tensor_type_info->shape.NumDimensions()); + ASSERT_EQ(api_type_shape.HasShape(), type_info_shape.HasShape()); + if (api_type_shape.HasShape()) { + const size_t api_num_dims = api_type_shape.GetDimensionsCount(); + ASSERT_EQ(api_num_dims, type_info_shape.GetDimensionsCount()); - std::vector api_dims(api_num_dims, 0); - ASSERT_ORTSTATUS_OK(ort_api.GetDimensions(api_type_shape, api_dims.data(), api_dims.size())); - ASSERT_EQ(gsl::span(api_dims), type_info->tensor_type_info->shape.GetDims()); + auto api_dims = api_type_shape.GetShape(); + ASSERT_EQ(api_dims, type_info_shape.GetShape()); - std::vector api_dim_syms(api_num_dims, nullptr); - ASSERT_ORTSTATUS_OK(ort_api.GetSymbolicDimensions(api_type_shape, api_dim_syms.data(), api_dim_syms.size())); - const std::vector& dim_syms = type_info->tensor_type_info->dim_params; - for (size_t dim_idx = 0; dim_idx < api_num_dims; dim_idx++) { - ASSERT_EQ(std::string(api_dim_syms[dim_idx]), dim_syms[dim_idx]); + const std::vector api_dim_syms = api_type_shape.GetSymbolicDimensions(); + const std::vector dim_syms = type_info_shape.GetSymbolicDimensions(); + ASSERT_EQ(api_dim_syms.size(), dim_syms.size()); + for (size_t dim_idx = 0; dim_idx < api_num_dims; dim_idx++) { + ASSERT_EQ(std::string(api_dim_syms[dim_idx]), dim_syms[dim_idx]); + } } } } diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 35f7d06fb0912..30595d5ce97b2 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -3,9 +3,14 @@ #include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" +#include #include "gsl/gsl" #include "gtest/gtest.h" +#include "core/common/logging/sinks/file_sink.h" +#include "core/graph/graph_viewer.h" +#include "core/graph/model.h" +#include "core/optimizer/graph_optimizer_registry.h" #include "core/session/abi_devices.h" #include "core/session/onnxruntime_cxx_api.h" #include "test/util/include/asserts.h" @@ -23,6 +28,14 @@ struct ApiPtrs { const gsl::not_null ep_api; }; +static void CheckStringInFile(const PathString& filename, const std::string& look_for) { + std::ifstream ifs{filename}; + std::string content(std::istreambuf_iterator{ifs}, + std::istreambuf_iterator{}); + + EXPECT_NE(content.find(look_for), std::string::npos); +} + // Normally, a plugin EP would be implemented in a separate library. // The `test_plugin_ep` namespace contains a local implementation intended for unit testing. 
namespace test_plugin_ep { @@ -114,6 +127,10 @@ MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = { return result; } +class MockKernelLookup : public IExecutionProvider::IKernelLookup { + const KernelCreateInfo* LookUpKernel(const Node& /*node*/) const override { return nullptr; } +}; + } // namespace test_plugin_ep TEST(PluginExecutionProviderTest, GetPreferredLayout) { @@ -317,4 +334,218 @@ TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) { #endif // !defined(ORT_NO_EXCEPTIONS) } +static void LoadModelAndAssignNodesToEp(const ORTCHAR_T* model_path, + const char* ep_name, + const std::unordered_set& ep_node_names, + /*out*/ std::shared_ptr& model) { + ASSERT_STATUS_OK(Model::Load(model_path, model, nullptr, + DefaultLoggingManager().DefaultLogger())); + + Graph& graph = model->MainGraph(); + + for (Node& node : graph.Nodes()) { + if (ep_node_names.count(node.Name()) > 0) { + node.SetExecutionProviderType(ep_name); + } + } +} + +static OrtStatus* ORT_API_CALL GetCapabilityTakeAllNodesOneGroup(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + auto* this_ep = static_cast(this_ptr); + + size_t num_nodes = 0; + if (OrtStatus* st = this_ep->ort_api->Graph_GetNumNodes(graph, &num_nodes); st != nullptr) { + return st; + } + + std::vector nodes(num_nodes); + if (OrtStatus* st = this_ep->ort_api->Graph_GetNodes(graph, nodes.data(), nodes.size()); st != nullptr) { + return st; + } + + if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + nodes.data(), nodes.size(), nullptr); + st != nullptr) { + return st; + } + + return nullptr; +} + +static OrtStatus* ORT_API_CALL GetCapabilityTakeAllNodesTwoGroups(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + auto* this_ep = static_cast(this_ptr); + + size_t num_nodes = 0; + if (OrtStatus* st = this_ep->ort_api->Graph_GetNumNodes(graph, &num_nodes); st != nullptr) { + return st; + } + + std::vector nodes(num_nodes); + if (OrtStatus* st = this_ep->ort_api->Graph_GetNodes(graph, nodes.data(), nodes.size()); st != nullptr) { + return st; + } + + // Expect at least 2 nodes. If not, this is really a testing/setup error. 
+ if (num_nodes < 2) { + return this_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, + "Expected at least two nodes in call to GetCapability"); + } + + std::vector node_group1; + std::vector node_group2; + + for (size_t i = 0; i < num_nodes; i++) { + if (i < num_nodes / 2) { + node_group1.push_back(nodes[i]); + } else { + node_group2.push_back(nodes[i]); + } + } + + if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + node_group1.data(), node_group1.size(), + nullptr); + st != nullptr) { + return st; + } + + if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + node_group2.data(), node_group2.size(), + nullptr); + st != nullptr) { + return st; + } + + return nullptr; +} + +static OrtStatus* ORT_API_CALL GetCapabilityTakeSingleNode(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + auto* this_ep = static_cast(this_ptr); + + size_t num_nodes = 0; + if (OrtStatus* st = this_ep->ort_api->Graph_GetNumNodes(graph, &num_nodes); st != nullptr) { + return st; + } + + std::vector nodes(num_nodes); + if (OrtStatus* st = this_ep->ort_api->Graph_GetNodes(graph, nodes.data(), nodes.size()); st != nullptr) { + return st; + } + + // Take only the first node using EpGraphSupportInfo_AddSingleNode(). + if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddSingleNode(graph_support_info, nodes[0]); + st != nullptr) { + return st; + } + + return nullptr; +} + +// Tests that GetCapability() doesn't crash if a plugin EP tries to claim a mix of unassigned nodes and +// nodes that are already assigned to another EP. +TEST(PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) { + std::filesystem::path log_file = ORT_TSTR("log_get_capability.txt"); + + // Helper function that loads a model (Add -> Mul -> Add) and assigns some or all of the nodes to another EP. + // Then, IExecutionProvider::GetCapability() is called to test the expected behavior. + auto run_test = [&log_file](IExecutionProvider& ep, + const std::unordered_set& nodes_for_other_ep, + const std::unordered_set& nodes_for_this_ep, + const char* expected_log_string) { + std::shared_ptr model; + ASSERT_NO_FATAL_FAILURE(LoadModelAndAssignNodesToEp(ORT_TSTR("testdata/add_mul_add.onnx"), + "OtherEp", nodes_for_other_ep, model)); + + std::filesystem::remove(log_file); + + // Call IExecutionProvider::GetCapability and check results + logs. + { + logging::LoggingManager log_manager{std::make_unique(log_file, false, false), + logging::Severity::kWARNING, false, + logging::LoggingManager::InstanceType::Temporal}; + auto file_logger = log_manager.CreateLogger("FileLogger"); + ep.SetLogger(file_logger.get()); // Make EP log to a file. + + GraphViewer graph_viewer(model->MainGraph()); + auto compute_capabilities = ep.GetCapability(graph_viewer, + test_plugin_ep::MockKernelLookup{}, + GraphOptimizerRegistry(nullptr, nullptr, file_logger.get()), + nullptr); + + ASSERT_EQ(compute_capabilities.size(), nodes_for_this_ep.empty() ? 
0 : 1); + + if (compute_capabilities.size() == 1) { + ASSERT_EQ(compute_capabilities[0]->sub_graph->nodes.size(), nodes_for_this_ep.size()); + + for (NodeIndex node_index : compute_capabilities[0]->sub_graph->nodes) { + const Node* node = graph_viewer.GetNode(node_index); + ASSERT_NE(node, nullptr); + EXPECT_EQ(nodes_for_this_ep.count(node->Name()), 1); + } + } + } + + ASSERT_TRUE(std::filesystem::exists(log_file)); + EXPECT_NO_FATAL_FAILURE(CheckStringInFile(log_file, expected_log_string)); + }; + + constexpr std::array node_names = {"add_0", "mul_0", "add_1"}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(); + + // Load a model and assign all of its nodes to another EP named 'OtherEp'. + // The plugin EP tries to claim all nodes in a single group via EpGraphSupportInfo_AddNodesToFuse. + // IExecutionProvider::GetCapability() should return an empty result and log a warning. + ort_ep->GetCapability = GetCapabilityTakeAllNodesOneGroup; + std::unordered_set nodes_for_other_ep = {"add_0", "mul_0", "add_1"}; + std::unordered_set nodes_for_this_ep; + run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, + "Found one or more nodes that were already assigned to a different EP named 'OtherEp'"); + + // Load a model and assign only one node to another EP named 'OtherEp'. + // The plugin EP tries to claim all nodes in a single group. + // IExecutionProvider::GetCapability() should return an empty result and log a warning. + ort_ep->GetCapability = GetCapabilityTakeAllNodesOneGroup; + for (const char* node_name : node_names) { + nodes_for_other_ep = std::unordered_set{node_name}; + nodes_for_this_ep = std::unordered_set{}; + run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, + "Found one or more nodes that were already assigned to a different EP named 'OtherEp'"); + } + + // Load a model and assign only the last Add node to another EP named 'OtherEp'. + // The plugin EP tries to claim all nodes in the following 2 groups: (add_0), (mul_0, add_1). + // IExecutionProvider::GetCapability() will only return (add_0) because the second group has a node + // that was assigned to 'OtherEp'. + ort_ep->GetCapability = GetCapabilityTakeAllNodesTwoGroups; + nodes_for_other_ep = std::unordered_set{"add_1"}; + nodes_for_this_ep = std::unordered_set{"add_0"}; + run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, + "Found one or more nodes that were already assigned to a different EP named 'OtherEp'"); + + // Load a model and assign only the first Add node to another EP named 'OtherEp'. + // The plugin EP tries to claim all nodes in the following 2 groups: (add_0), (mul_0, add_1). + // IExecutionProvider::GetCapability() will only return (mul_0, add_1) because the first group has a node + // that was assigned to 'OtherEp'. + ort_ep->GetCapability = GetCapabilityTakeAllNodesTwoGroups; + nodes_for_other_ep = std::unordered_set{"add_0"}; + nodes_for_this_ep = std::unordered_set{"mul_0", "add_1"}; + run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, + "Found one or more nodes that were already assigned to a different EP named 'OtherEp'"); + + // Load a model and assign the first Add node to another EP named 'OtherEp'. + // The plugin EP will try to take only the first Add node with a single call to EpGraphSupportInfo_AddSingleNode. + // IExecutionProvider::GetCapability() will return an empty result and log a warning. 
+ ort_ep->GetCapability = GetCapabilityTakeSingleNode; + nodes_for_other_ep = std::unordered_set{"add_0"}; + nodes_for_this_ep = std::unordered_set{}; + run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, + "Found one or more nodes that were already assigned to a different EP named 'OtherEp'"); + + std::filesystem::remove(log_file); +} + } // namespace onnxruntime::test diff --git a/onnxruntime/test/framework/type_info_test.cc b/onnxruntime/test/framework/type_info_test.cc index d8ef668bf1c7e..5b9e86a6bc10a 100644 --- a/onnxruntime/test/framework/type_info_test.cc +++ b/onnxruntime/test/framework/type_info_test.cc @@ -23,8 +23,10 @@ TEST(TypeInfoTests, TensorProto) { auto tensor_type_info = OrtTypeInfo::FromTypeProto(tensor_type.value); ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info->type); ASSERT_NE(nullptr, tensor_type_info->tensor_type_info); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info->tensor_type_info->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info->tensor_type_info->shape.GetDims())); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info->tensor_type_info->GetElementType()); + ASSERT_TRUE(tensor_type_info->tensor_type_info->HasShape()); + const auto* shape = tensor_type_info->tensor_type_info->GetShape(); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), shape->GetDims())); } TEST(TypeInfoTests, SequenceWithTensorElement) { @@ -38,8 +40,9 @@ TEST(TypeInfoTests, SequenceWithTensorElement) { ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info.type); ASSERT_NE(nullptr, tensor_type_info.tensor_type_info); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.tensor_type_info->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.tensor_type_info->shape.GetDims())); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.tensor_type_info->GetElementType()); + ASSERT_TRUE(tensor_type_info.tensor_type_info->HasShape()); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.tensor_type_info->GetShape()->GetDims())); } TEST(TypeInfoTests, OptionalWithTensorProto) { @@ -55,8 +58,9 @@ TEST(TypeInfoTests, OptionalWithTensorProto) { const auto& contained_type = *optional_type_info->optional_type_info->contained_type_; ASSERT_EQ(ONNX_TYPE_TENSOR, contained_type.type); ASSERT_NE(nullptr, contained_type.tensor_type_info); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, contained_type.tensor_type_info->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), contained_type.tensor_type_info->shape.GetDims())); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, contained_type.tensor_type_info->GetElementType()); + ASSERT_TRUE(contained_type.tensor_type_info->HasShape()); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), contained_type.tensor_type_info->GetShape()->GetDims())); } #if !defined(DISABLE_ML_OPS) @@ -75,8 +79,9 @@ TEST(TypeInfoTests, MapWithTensorValue) { ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info.type); ASSERT_NE(nullptr, tensor_type_info.tensor_type_info); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.tensor_type_info->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.tensor_type_info->shape.GetDims())); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.tensor_type_info->GetElementType()); + ASSERT_TRUE(tensor_type_info.tensor_type_info->HasShape()); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.tensor_type_info->GetShape()->GetDims())); } #endif diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc 
b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 79e3073b944ff..ba2b942f73320 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -5669,6 +5669,52 @@ TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_Weight) { test_case(true); } +TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_HandleNegativeDqAxis) { + auto test_case = [](bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + NodeArg* input_arg = + builder.MakeInput({2, 16}, std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); + NodeArg* weight_arg = builder.MakeInitializer({16, 16}, std::numeric_limits<uint8_t>::min(), + std::numeric_limits<uint8_t>::max()); + NodeArg* bias_arg = builder.MakeInitializer({16}, -0.1f, 0.1f); + + NodeArg* input_dq_arg = builder.MakeIntermediate(); + NodeArg* weight_dq_arg = builder.MakeIntermediate(); + NodeArg* gemm_dq_arg = builder.MakeIntermediate(); + NodeArg* output_arg = builder.MakeOutput(); + + builder.AddDequantizeLinearNode(input_arg, 0.001f, static_cast<uint8_t>(0), input_dq_arg, use_contrib_qdq); + + // Per-channel quantized weight with negative axis as DQ attribute + std::vector<float> scales = std::vector<float>(16, 0.05f); + std::vector<uint8_t> zp = std::vector<uint8_t>(16, static_cast<uint8_t>(0)); + auto& dq_node = builder.AddDequantizeLinearNode(weight_arg, scales, zp, weight_dq_arg, nullptr, use_contrib_qdq); + dq_node.AddAttribute("axis", static_cast<int64_t>(-1)); + + builder.AddNode("Gemm", {input_dq_arg, weight_dq_arg, bias_arg}, {gemm_dq_arg}); + builder.AddQuantizeLinearNode(gemm_dq_arg, 0.144f, static_cast<uint8_t>(69), output_arg, use_contrib_qdq); + }; + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["QuantizeLinear"] + op_to_count["com.microsoft.QuantizeLinear"], 1); + EXPECT_EQ(op_to_count["DequantizeLinear"] + op_to_count["com.microsoft.DequantizeLinear"], 2 + 1); + }; + + TransformerTester(build_test_case, + check_transformed_graph, + TransformerLevel::Default, + TransformerLevel::Level1, + /*opset_version=*/20, + /*per_sample_tolerance=*/0.01, + /*relative_per_sample_tolerance=*/0.01, + /*transformer=*/std::make_unique<WeightBiasQuantization>()); + }; + + test_case(false); + test_case(true); +} + TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_Weight_Bias) { auto test_case = [](bool use_contrib_qdq) { auto build_test_case = [&](ModelTestBuilder& builder) { diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 8d16f85b9598c..8382258bf39b4 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -273,14 +273,13 @@ TEST(ConvFp16Test, Conv1D_Bias) { } TEST(ConvBF16Test, Conv2D_1) { -#ifdef USE_CUDA +#ifndef USE_CUDA + GTEST_SKIP() << "BFloat16 tests are only enabled on CUDA builds"; +#else if (!CudaHasBF16Support()) { LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16"; return; } -#else - return; -#endif OpTester test("Conv", 22); @@ -332,6 +331,7 @@ TEST(ConvBF16Test, Conv2D_1) { test.AddOutput("Y", Y_shape, expected_vals, /*no sort*/ false, 0.002f, 0.0f); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +#endif } TEST(ConvFp16Test, Conv2D_1) { diff --git a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc index 4147cab70103a..b2b7f1701107a 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc +++ 
b/onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc @@ -323,15 +323,13 @@ TEST(PoolTest, MaxPool_DilationPadding_3d) { } TEST(PoolBF16Test, AveragePool) { -#ifdef USE_CUDA +#ifndef USE_CUDA + GTEST_SKIP() << "BFloat16 tests are only enabled on CUDA builds"; +#else if (!CudaHasBF16Support()) { LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16"; return; } -#else - return; -#endif - OpTester test("AveragePool", 22); test.AddAttribute("auto_pad", ""); @@ -411,6 +409,7 @@ TEST(PoolBF16Test, AveragePool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +#endif } TEST(PoolFp16Test, AveragePool) { diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 1c8cc6f78fe63..a2f1b9b56538b 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -2076,6 +2076,21 @@ TEST_F(QnnHTPBackendTests, QnnEpDynamicOptions) { } catch (const std::exception& e) { EXPECT_STREQ("Unsupported EP Dynamic Option", e.what()); } + + const char* const htp_perf_mode_type[] = {"ep.dynamic.qnn_htp_performance_mode"}; + const char* const eps_type[] = {"extreme_power_saver"}; + const char* const shp_type[] = {"sustained_high_performance"}; + session.SetEpDynamicOptions(htp_perf_mode_type, shp_type, 1); + ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + + session.SetEpDynamicOptions(htp_perf_mode_type, eps_type, 1); + ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + + session.SetEpDynamicOptions(htp_perf_mode_type, shp_type, 1); + ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); } // Implementation of OrtOutStreamWriteFunc that writes the compiled model to a file. 
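Note on the new ep.dynamic.qnn_htp_performance_mode option exercised by the QNN test above: an application can switch the HTP performance mode between Run() calls through the public SetEpDynamicOptions API. A minimal C++ sketch follows; it assumes a session already created with the QNN EP on the HTP backend, the helper name is illustrative, and the option key and values ("sustained_high_performance", "extreme_power_saver") are taken directly from the test above.

#include <onnxruntime_cxx_api.h>

// Illustrative helper: raise the HTP performance mode around latency-critical inferences,
// then drop back to a power-saving mode. Assumes `session` was created with the QNN EP (HTP backend).
void ToggleHtpPerformanceMode(Ort::Session& session) {
  const char* keys[] = {"ep.dynamic.qnn_htp_performance_mode"};
  const char* high[] = {"sustained_high_performance"};
  const char* low[] = {"extreme_power_saver"};

  session.SetEpDynamicOptions(keys, high, 1);  // before latency-critical Run() calls
  // ... session.Run(...) ...
  session.SetEpDynamicOptions(keys, low, 1);   // afterwards, reduce power draw
}

Per the provider change earlier in this diff, the value is parsed with ParseHtpPerformanceMode and applied via SetHtpPowerConfig only when the backend is HTP or DSP and a valid HTP power config id exists; on other backends the call returns OK without taking effect.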
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 8c2928670934a..f2a7ee71a363a 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -490,6 +490,29 @@ TEST(CApiTest, dim_param) { ASSERT_EQ(strcmp(dim_param, ""), 0); } +static std::pair<bool, bool> LoadAndGetInputShapePresent(const ORTCHAR_T* const model_url) { + Ort::Session session(*ort_env, model_url, Ort::SessionOptions{}); + const auto input_num = session.GetInputCount(); + EXPECT_EQ(input_num, 1U); + const bool input_shape_present = session.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().HasShape(); + const auto output_num = session.GetOutputCount(); + EXPECT_EQ(output_num, 1U); + const bool output_shape_present = session.GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().HasShape(); + return {input_shape_present, output_shape_present}; +} + +TEST(CApiTest, OptionalShape) { + const ORTCHAR_T* const input_shape_model = TSTR("testdata/abs_0d_input.onnx"); + auto result = LoadAndGetInputShapePresent(input_shape_model); + ASSERT_TRUE(result.first); + ASSERT_TRUE(result.second); + + const ORTCHAR_T* const no_shape_model = TSTR("testdata/abs_0d_lostdim.onnx"); + result = LoadAndGetInputShapePresent(no_shape_model); + ASSERT_FALSE(result.first); + ASSERT_FALSE(result.second); +} + INSTANTIATE_TEST_SUITE_P(CApiTestWithProviders, CApiTestWithProvider, ::testing::Values(0, 1, 2, 3, 4)); diff --git a/onnxruntime/test/testdata/abs_0d_input.onnx b/onnxruntime/test/testdata/abs_0d_input.onnx new file mode 100644 index 0000000000000..b6e7ad7e93748 Binary files /dev/null and b/onnxruntime/test/testdata/abs_0d_input.onnx differ diff --git a/onnxruntime/test/testdata/abs_0d_input.py b/onnxruntime/test/testdata/abs_0d_input.py new file mode 100644 index 0000000000000..e6e41e0801f0e --- /dev/null +++ b/onnxruntime/test/testdata/abs_0d_input.py @@ -0,0 +1,58 @@ +""" +Run this script to recreate the original onnx model. 
+Example usage: +python abs_0d_input.py out_model_path.onnx +""" + +import sys + +from onnx import TensorProto, helper, save + + +def clear_field(proto, field): + proto.ClearField(field) + return proto + + +def order_repeated_field(repeated_proto, key_name, order): + order = list(order) + repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name))) + + +def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs): + node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs) + if doc_string == "": + node.doc_string = "" + order_repeated_field(node.attribute, "name", kwargs.keys()) + return node + + +def make_graph(*args, doc_string=None, **kwargs): + graph = helper.make_graph(*args, doc_string=doc_string, **kwargs) + if doc_string == "": + graph.doc_string = "" + return graph + + +model = helper.make_model( + opset_imports=[ + clear_field(helper.make_operatorsetid("", 21), "domain"), + clear_field(helper.make_operatorsetid("", 1), "domain"), + clear_field(helper.make_operatorsetid("", 1), "domain"), + clear_field(helper.make_operatorsetid("", 21), "domain"), + ], + ir_version=11, + producer_name="ort_ep_utils::OrtGraphToProto", + producer_version="", + model_version=0, + graph=make_graph( + name="OpenVINOExecutionProvider_11295571201636618024_0", + inputs=[helper.make_tensor_value_info("absInput_1", TensorProto.FLOAT, shape=[])], + outputs=[helper.make_tensor_value_info("absOutput_0", TensorProto.FLOAT, shape=[])], + nodes=[make_node("Abs", inputs=["absInput_1"], outputs=["absOutput_0"], name="_0", domain="")], + ), +) + +if __name__ == "__main__" and len(sys.argv) == 2: + _, out_path = sys.argv + save(model, out_path) diff --git a/onnxruntime/test/testdata/abs_0d_lostdim.onnx b/onnxruntime/test/testdata/abs_0d_lostdim.onnx new file mode 100644 index 0000000000000..288f21241b947 Binary files /dev/null and b/onnxruntime/test/testdata/abs_0d_lostdim.onnx differ diff --git a/onnxruntime/test/testdata/abs_0d_lostdim.py b/onnxruntime/test/testdata/abs_0d_lostdim.py new file mode 100644 index 0000000000000..e91118a19739e --- /dev/null +++ b/onnxruntime/test/testdata/abs_0d_lostdim.py @@ -0,0 +1,54 @@ +""" +Run this script to recreate the original onnx model. 
+Example usage: +python abs_0d_lostdim.py out_model_path.onnx +""" + +import sys + +import onnx +from onnx import TensorProto, helper + + +def order_repeated_field(repeated_proto, key_name, order): + order = list(order) + repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name))) + + +def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs): + node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs) + if doc_string == "": + node.doc_string = "" + order_repeated_field(node.attribute, "name", kwargs.keys()) + return node + + +def make_graph(*args, doc_string=None, **kwargs): + graph = helper.make_graph(*args, doc_string=doc_string, **kwargs) + if doc_string == "": + graph.doc_string = "" + return graph + + +model = helper.make_model( + opset_imports=[ + helper.make_operatorsetid("", 21), + helper.make_operatorsetid("com.microsoft", 1), + helper.make_operatorsetid("com.microsoft.nchwc", 1), + helper.make_operatorsetid("com.ms.internal.nhwc", 21), + ], + ir_version=11, + producer_name="ort_ep_utils::OrtGraphToProto", + doc_string="Serialized from OrtGraph", + graph=make_graph( + name="OpenVINOExecutionProvider_11295571201636618024_0", + inputs=[helper.make_tensor_value_info("absInput_1", TensorProto.FLOAT, shape=None)], + outputs=[helper.make_tensor_value_info("absOutput_0", TensorProto.FLOAT, shape=None)], + doc_string="Serialized from OrtGraph", + nodes=[make_node("Abs", inputs=["absInput_1"], outputs=["absOutput_0"], name="_0", domain="")], + ), +) + +if __name__ == "__main__" and len(sys.argv) == 2: + _, out_path = sys.argv + onnx.save(model, out_path) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 4438ddba014d0..e4cc4d70229b2 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -879,22 +879,11 @@ def generate_build_tree( if args.use_snpe: cmake_args += ["-Donnxruntime_USE_SNPE=ON"] - # Set onnxruntime_USE_KLEIDIAI based on: - # * Default value above is NO. - # * Leave disabled if "no_kleidiai" argument was specified. - # * Enable if the target is Android and args.android_abi contains arm64* - # * Enable for a Windows cross compile build if compile target is an Arm one. - # * Finally enable if platform.machine contains "arm64" and not a WebAssembly build. 
This should cover the following cases: - # * Linux on Arm - # * MacOs (case must be ignored) - # * TODO Delegate responsibility for Onnxruntime_USE_KLEIDIAI = ON to CMake logic if not args.no_kleidiai: - if ( - (args.android and "arm64" in args.android_abi.lower()) - or (is_windows() and (args.arm64 or args.arm64ec or args.arm) and platform.architecture()[0] != "AMD64") - or ("arm64" in platform.machine().lower() and not args.build_wasm) - ): - cmake_args += ["-Donnxruntime_USE_KLEIDIAI=ON"] + cmake_args += ["-Donnxruntime_USE_KLEIDIAI=ON"] + + if args.enable_arm_neon_nchwc: + cmake_args += ["-Donnxruntime_USE_ARM_NEON_NCHWC=ON"] if not args.no_sve: cmake_args += ["-Donnxruntime_USE_SVE=ON"] diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index 996d46974716e..8c04f8dd46016 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -629,8 +629,18 @@ def add_execution_provider_args(parser: argparse.ArgumentParser) -> None: help="Enable CUDA kernel profiling (requires CUPTI in PATH).", ) + # --- CPU --- cpu_group = parser.add_argument_group("CPU Execution Provider") cpu_group.add_argument("--no_sve", action="store_true", help="Disable building with SVE support.") + # The following enables building ORT with NCHWc Neon ARM kernels. + # At the time of writing, it is turned OFF by default because its performance relative to "regular" NCHW kernels + # is not good at smaller thread counts. But its speed-up is non-negligible with higher thread counts on supporting + # ARM platforms. + # Once the gap is closed for smaller thread counts, it can be turned on by default. + # See https://github.com/microsoft/onnxruntime/pull/25580#issuecomment-3335056846 for benchmarking details. + cpu_group.add_argument( + "--enable_arm_neon_nchwc", action="store_true", help="Enables building with NCHWc ARM kernels." + ) # --- DNNL (formerly MKL-DNN / oneDNN) --- dnnl_group = parser.add_argument_group("DNNL Execution Provider")
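Usage note for the new build flag added above (a sketch; other required build.py arguments depend on the host and target): passing --enable_arm_neon_nchwc to tools/ci_build/build.py (or its build.sh/build.bat wrappers) adds -Donnxruntime_USE_ARM_NEON_NCHWC=ON to the CMake configuration, which in turn is expected to define MLAS_USE_ARM_NEON_NCHWC and enable the ARM64 NCHWc kernel paths shown in the MLAS changes earlier in this diff. For example, on an ARM64 Linux host:

python3 tools/ci_build/build.py --build_dir build/Linux --config Release --parallel --enable_arm_neon_nchwc

As the build_args.py comment notes, the option stays off by default until the small-thread-count performance gap against the regular NCHW kernels is closed.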