diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index f9c8a423be241..695afc661dd99 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -71,8 +71,8 @@ jobs: run: | set -e -x BINARY_SIZE_THRESHOLD_ARGS="" - echo "Binary size threshold in bytes: 1436672" - BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1436672" + echo "Binary size threshold in bytes: 1722565" + BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1722565" # Ensure ANDROID_NDK_HOME is available and get its real path if [ -z "$ANDROID_NDK_HOME" ]; then diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index 6929e0ad21b85..183b299f34455 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -120,7 +120,7 @@ jobs: path: ${{ github.workspace }}/artifacts/wasm - name: Test (Node.js) (simd + threads) - # onnxruntime_test_all is currently only supported in Debug build because it requires exception, which is disabled in Release build. + # unit tests are currently only supported in Debug builds because they require exceptions, which are disabled in Release builds. if: ${{ inputs.build_config == 'Debug' }} run: | python ./tools/ci_build/build.py \ @@ -130,14 +130,14 @@ jobs: working-directory: ${{ github.workspace }} - name: Test (browser) (simd + threads) - # onnxruntime_test_all is currently only supported in Debug build because it requires exception, which is disabled in Release build. + # unit tests are currently only supported in Debug builds because they require exceptions, which are disabled in Release builds. if: ${{ inputs.build_config == 'Debug' }} run: | python ./tools/ci_build/build.py \ ${{ env.common_build_args }} \ --build_dir ${{ github.workspace }}/build/wasm_inferencing \ --wasm_run_tests_in_browser \ - --target onnxruntime_test_all \ + --targets onnxruntime_test_all onnxruntime_provider_test \ --update --build --test working-directory: ${{ github.workspace }} diff --git a/.github/workflows/windows_qnn_x64.yml b/.github/workflows/windows_qnn_x64.yml new file mode 100644 index 0000000000000..549af2b63b97e --- /dev/null +++ b/.github/workflows/windows_qnn_x64.yml @@ -0,0 +1,82 @@ +name: Windows x64 QNN CI Pipeline + +on: + push: + branches: + - main + - rel-* + pull_request: + branches: + - main + - rel-* + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} + cancel-in-progress: true + +jobs: + build_test_qnn_ep: + name: Windows x64 QNN CI Pipeline (${{ matrix.QnnLibKind }}) + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] + timeout-minutes: 120 + strategy: + matrix: + QnnLibKind: [shared_lib, static_lib] + env: + AZCOPY_AUTO_LOGIN_TYPE: MSI + AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' + + steps: + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env + with: + architecture: x64 + + - name: Download QNN SDK + working-directory: ${{ runner.temp }} + run: | + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/qnnsdk/qnn-v2.38.0.250901 . + dir + shell: pwsh + + - name: Set QNN_SDK_ROOT environment variable + shell: pwsh + run: | + $qnn_sdk_path = Join-Path $env:RUNNER_TEMP "qnn-v2.38.0.250901" + echo "QNN_SDK_ROOT=$qnn_sdk_path" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append + echo "QNN SDK Root: $qnn_sdk_path" + dir $qnn_sdk_path + + - name: Build and Test + shell: cmd + run: | + python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --build_dir ${{ runner.temp }}\build --cmake_generator "Visual Studio 17 2022" --build_java --build_shared_lib --use_qnn ${{ matrix.QnnLibKind }} --qnn_home %QNN_SDK_ROOT% --use_binskim_compliant_compile_flags --update --build --test --enable_onnx_tests --parallel + + - name: Run ONNX Tests + shell: cmd + working-directory: ${{ runner.temp }}\build\RelWithDebInfo\RelWithDebInfo + run: | + .\onnx_test_runner -j 1 -e qnn -i "backend_path|%QNN_SDK_ROOT%\lib\x86_64-windows-msvc\QnnCpu.dll" ${{ github.workspace }}\cmake\external\onnx\onnx\backend\test\data\node + + - name: Run float32 model tests + shell: cmd + working-directory: ${{ runner.temp }}\build\RelWithDebInfo\RelWithDebInfo + run: | + rem This step assumes the model data exists at C:\data\float32_models on the runner + if exist C:\data\float32_models ( + .\onnx_test_runner -j 1 -e qnn -i "backend_path|%QNN_SDK_ROOT%\lib\x86_64-windows-msvc\QnnCpu.dll" C:\data\float32_models + ) else ( + echo "Skipping float32 model tests: C:\data\float32_models not found." + ) diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index a2d34b2bc6338..f849bdda0dff3 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -127,22 +127,23 @@ jobs: } Remove-Item "${{ github.workspace }}\RelWithDebInfo" -Include "*.obj" -Recurse - - name: Run tests (onnxruntime_test_all) with verbose logging + - name: Run tests (onnxruntime_test_all, onnxruntime_provider_test) with verbose logging shell: pwsh run: | $env:ORT_UNIT_TEST_MAIN_LOG_LEVEL = "0" - .\onnxruntime_test_all.exe 2>.\onnxruntime_test_all_stderr.log + .\onnxruntime_test_all.exe 2> .\onnxruntime_test_stderr.log + .\onnxruntime_provider_test.exe 2>> .\onnxruntime_test_stderr.log working-directory: ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo - name: Check log file shell: cmd run: | - dir ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\onnxruntime_test_all_stderr.log + dir ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\onnxruntime_test_stderr.log - name: Validate shader keys uses: ./.github/actions/webgpu-validate-shader-key with: - log_file_path: ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\onnxruntime_test_all_stderr.log + log_file_path: ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\onnxruntime_test_stderr.log - name: Validate C# native delegates run: python tools\ValidateNativeDelegateAttributes.py diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 3530ab03c822a..b922a78a6929d 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -109,6 +109,8 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp + ${MLAS_SRC_DIR}/sconv_kernel_neon.cpp + ${MLAS_SRC_DIR}/spool_kernel_neon.cpp ) set(mlas_platform_preprocess_srcs @@ -431,6 +433,8 @@ else() ${MLAS_SRC_DIR}/eltwise_kernel_neon.h ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp + ${MLAS_SRC_DIR}/sconv_kernel_neon.cpp + ${MLAS_SRC_DIR}/spool_kernel_neon.cpp ) if (onnxruntime_USE_KLEIDIAI) setup_kleidiai() diff --git a/cmake/onnxruntime_test_pch.cmake b/cmake/onnxruntime_test_pch.cmake index 7c58c2d787596..f989774ade35b 100644 --- a/cmake/onnxruntime_test_pch.cmake +++ b/cmake/onnxruntime_test_pch.cmake @@ -5,7 +5,10 @@ if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") target_precompile_headers(onnxruntime_test_all PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/test_pch.h" ) - endif() + target_precompile_headers(onnxruntime_provider_test PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/test_pch.h" + ) +endif() # Exclude certain files that might conflict with PCH set(PCH_EXCLUDE_FILES diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index b31849440c426..27354adb88557 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -10,11 +10,58 @@ if (onnxruntime_ENABLE_TRAINING) list(APPEND TEST_INC_DIR ${ORTTRAINING_ROOT}) endif() +# Exclude files based on CMake options. +function(filter_test_srcs test_srcs_var) + set(excluded_path_prefixes) + + if(onnxruntime_DISABLE_CONTRIB_OPS) + list(APPEND excluded_path_prefixes ${TEST_SRC_DIR}/contrib_ops) + endif() + + if(onnxruntime_DISABLE_ML_OPS) + list(APPEND excluded_path_prefixes ${TEST_SRC_DIR}/providers/cpu/ml) + endif() + + list(LENGTH excluded_path_prefixes num_excluded_path_prefixes) + + if("${num_excluded_path_prefixes}" GREATER "0") + set(filtered_test_srcs) + + foreach(test_src ${${test_srcs_var}}) + set(is_excluded false) + + foreach(excluded_path_prefix ${excluded_path_prefixes}) + cmake_path(ABSOLUTE_PATH test_src OUTPUT_VARIABLE test_src_absolute) + + cmake_path(IS_PREFIX excluded_path_prefix ${test_src_absolute} NORMALIZE is_excluded) + + if (is_excluded) + break() + endif() + endforeach() + + if(NOT is_excluded) + list(APPEND filtered_test_srcs ${test_src}) + endif() + endforeach() + + set(${test_srcs_var} ${filtered_test_srcs} PARENT_SCOPE) + endif() +endfunction() + set(disabled_warnings) function(AddTest) cmake_parse_arguments(_UT "DYN" "TARGET" "LIBS;SOURCES;DEPENDS;TEST_ARGS" ${ARGN}) list(REMOVE_DUPLICATES _UT_SOURCES) + filter_test_srcs(_UT_SOURCES) + + message(VERBOSE "AddTest() TARGET: ${_UT_TARGET}") + message(VERBOSE "AddTest() SOURCES:") + foreach(ut_src ${_UT_SOURCES}) + message(VERBOSE " ${ut_src}") + endforeach() + if (IOS) onnxruntime_add_executable(${_UT_TARGET} ${TEST_SRC_DIR}/xctest/orttestmain.m) else() @@ -253,6 +300,50 @@ function(AddTest) endif() endfunction(AddTest) +# Given a list of test source files, in variable `all_srcs_var`, partition it into two lists: +# - a list of the provider-related test source files, stored in `provider_test_srcs_var` +# - a list of the other remaining test source files, stored in `other_srcs_var` +# +# In particular, provider-related test source files are located in these root paths: +# - onnxruntime/test/contrib_ops +# - onnxruntime/test/providers +function(partition_provider_test_srcs + all_srcs_var provider_test_srcs_var other_srcs_var) + set(provider_test_src_roots + ${TEST_SRC_DIR}/contrib_ops + ${TEST_SRC_DIR}/providers + ) + + function(is_provider_test_src src_var result_var) + cmake_path(ABSOLUTE_PATH ${src_var} OUTPUT_VARIABLE src_absolute) + + foreach(provider_test_src_root ${provider_test_src_roots}) + cmake_path(IS_PREFIX provider_test_src_root ${src_absolute} NORMALIZE src_matches_root) + if(src_matches_root) + set(${result_var} true PARENT_SCOPE) + return() + endif() + endforeach() + + set(${result_var} false PARENT_SCOPE) + endfunction() + + set(provider_test_srcs) + set(other_srcs) + + foreach(src ${${all_srcs_var}}) + is_provider_test_src(src is_provider_test_src_result) + if(is_provider_test_src_result) + list(APPEND provider_test_srcs ${src}) + else() + list(APPEND other_srcs ${src}) + endif() + endforeach() + + set(${provider_test_srcs_var} ${provider_test_srcs} PARENT_SCOPE) + set(${other_srcs_var} ${other_srcs} PARENT_SCOPE) +endfunction() + # general program entrypoint for C++ unit tests set(onnxruntime_unittest_main_src "${TEST_SRC_DIR}/unittest_main/test_main.cc") @@ -377,29 +468,24 @@ if(NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD) "${TEST_SRC_DIR}/providers/*.h" "${TEST_SRC_DIR}/providers/*.cc" "${TEST_SRC_DIR}/opaque_api/test_opaque_api.cc" - "${TEST_SRC_DIR}/framework/TestAllocatorManager.cc" - "${TEST_SRC_DIR}/framework/TestAllocatorManager.h" - "${TEST_SRC_DIR}/framework/test_utils.cc" - "${TEST_SRC_DIR}/framework/test_utils.h" + "${TEST_SRC_DIR}/contrib_ops/*.h" + "${TEST_SRC_DIR}/contrib_ops/*.cc" + "${TEST_SRC_DIR}/contrib_ops/math/*.h" + "${TEST_SRC_DIR}/contrib_ops/math/*.cc" ) - if(NOT onnxruntime_DISABLE_CONTRIB_OPS) - list(APPEND onnxruntime_test_providers_src_patterns - "${TEST_SRC_DIR}/contrib_ops/*.h" - "${TEST_SRC_DIR}/contrib_ops/*.cc" - "${TEST_SRC_DIR}/contrib_ops/math/*.h" - "${TEST_SRC_DIR}/contrib_ops/math/*.cc") - endif() - else() set(onnxruntime_test_providers_src_patterns - "${TEST_SRC_DIR}/framework/test_utils.cc" - "${TEST_SRC_DIR}/framework/test_utils.h" # TODO: Add anything that is needed for testing a minimal build ) endif() -file(GLOB onnxruntime_test_providers_src CONFIGURE_DEPENDS ${onnxruntime_test_providers_src_patterns}) +list(LENGTH onnxruntime_test_providers_src_patterns onnxruntime_test_providers_src_patterns_length) +if(onnxruntime_test_providers_src_patterns_length GREATER 0) + file(GLOB onnxruntime_test_providers_src CONFIGURE_DEPENDS ${onnxruntime_test_providers_src_patterns}) +else() + set(onnxruntime_test_providers_src) +endif() if(NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD) file(GLOB_RECURSE onnxruntime_test_providers_cpu_src CONFIGURE_DEPENDS @@ -407,10 +493,6 @@ if(NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD) ) endif() -if(onnxruntime_DISABLE_ML_OPS) - list(FILTER onnxruntime_test_providers_cpu_src EXCLUDE REGEX ".*/ml/.*") -endif() - list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_cpu_src}) if (onnxruntime_USE_CUDA AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD) @@ -481,11 +563,12 @@ if (onnxruntime_USE_RKNPU) list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_rknpu_src}) endif() +set(onnxruntime_test_internal_testing_ep_src) if (NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_EXTENDED_MINIMAL_BUILD) file(GLOB_RECURSE onnxruntime_test_providers_internal_testing_src CONFIGURE_DEPENDS - "${TEST_SRC_DIR}/providers/internal_testing/*" + "${TEST_SRC_DIR}/internal_testing_ep/*" ) - list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_internal_testing_src}) + list(APPEND onnxruntime_test_internal_testing_ep_src ${onnxruntime_test_providers_internal_testing_src}) endif() set (ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR "${TEST_SRC_DIR}/shared_lib") @@ -534,42 +617,10 @@ set (onnxruntime_webgpu_delay_load_test_SRC # the order of libraries should be maintained, with higher libraries being added first in the list set(onnxruntime_test_common_libs - onnxruntime_test_utils - onnxruntime_common -) - -set(onnxruntime_test_ir_libs - onnxruntime_test_utils - onnxruntime_graph - onnxruntime_common -) - -set(onnxruntime_test_optimizer_libs - onnxruntime_test_utils - onnxruntime_framework - onnxruntime_util - onnxruntime_graph + onnxruntime_unittest_utils onnxruntime_common ) -set(onnxruntime_test_framework_libs - onnxruntime_test_utils - onnxruntime_framework - onnxruntime_util - onnxruntime_graph - ${ONNXRUNTIME_MLAS_LIBS} - onnxruntime_common - ) - -set(onnxruntime_test_server_libs - onnxruntime_test_utils - onnxruntime_test_utils_for_server -) - -if(WIN32) - list(APPEND onnxruntime_test_framework_libs Advapi32) -endif() - set (onnxruntime_test_providers_dependencies ${onnxruntime_EXTERNAL_DEPENDENCIES}) if(onnxruntime_USE_CUDA) @@ -609,7 +660,7 @@ if(onnxruntime_USE_DNNL) endif() if(onnxruntime_USE_MIGRAPHX) - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx) endif() if(onnxruntime_USE_COREML) @@ -674,7 +725,6 @@ set(onnxruntime_test_providers_libs if(onnxruntime_USE_TENSORRT) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/tensorrt/*) list(APPEND onnxruntime_test_framework_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/tensorrt_execution_provider_utils.h") - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_tensorrt) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_tensorrt onnxruntime_providers_shared) list(APPEND onnxruntime_test_providers_libs ${TENSORRT_LIBRARY_INFER}) endif() @@ -682,7 +732,6 @@ endif() if(onnxruntime_USE_NV) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/nv_tensorrt_rtx/*) list(APPEND onnxruntime_test_framework_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/nv_tensorrt_rtx/nv_execution_provider_utils.h") - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_nv_tensorrt_rtx) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_nv_tensorrt_rtx onnxruntime_providers_shared) list(APPEND onnxruntime_test_providers_libs ${TENSORRT_LIBRARY_INFER}) endif() @@ -694,21 +743,18 @@ endif() if(onnxruntime_USE_NNAPI_BUILTIN) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/nnapi/*) - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_nnapi) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_nnapi) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_nnapi) endif() if(onnxruntime_USE_JSEP) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/js/*) - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_js) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_js) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_js) endif() if(onnxruntime_USE_WEBGPU) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/webgpu/*) - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_webgpu) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_webgpu) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_webgpu) endif() @@ -719,7 +765,6 @@ if(onnxruntime_USE_QNN AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_RED list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/*) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/qnn_node_group/*) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/optimizer/*) - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_qnn) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_qnn) if(NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_shared) @@ -728,14 +773,12 @@ endif() if(onnxruntime_USE_SNPE) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/snpe/*) - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_snpe) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_snpe) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_snpe) endif() if(onnxruntime_USE_RKNPU) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/rknpu/*) - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_rknpu) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_rknpu) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_rknpu) endif() @@ -745,28 +788,24 @@ if(onnxruntime_USE_COREML) if(APPLE) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/coreml/*.mm) endif() - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml coreml_proto) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml coreml_proto) endif() if(onnxruntime_USE_XNNPACK) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/xnnpack/*) - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_xnnpack) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_xnnpack) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_xnnpack) endif() if(onnxruntime_USE_AZURE) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/azure/*) - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_azure) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_azure) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_azure) endif() if (onnxruntime_USE_OPENVINO) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/openvino/*) - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_openvino) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_openvino) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_shared) endif() @@ -803,6 +842,58 @@ target_include_directories(onnxruntime_test_utils PUBLIC "${TEST_SRC_DIR}/util/i set_target_properties(onnxruntime_test_utils PROPERTIES FOLDER "ONNXRuntimeTest") source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_test_utils_src}) +# onnxruntime_unittest_utils +# This is static library containing utilities that are specifically for unit tests. +# Unlike onnxruntime_test_utils, the source files here may have dependencies on internal onnxruntime code. +# Thus, onnxruntime_unittest_utils is not suitable for use in programs that don't link with internal onnxruntime +# libraries. +block() + +file(GLOB onnxruntime_unittest_utils_src CONFIGURE_DEPENDS + "${TEST_SRC_DIR}/unittest_util/*.h" + "${TEST_SRC_DIR}/unittest_util/*.cc") + +if(onnxruntime_MINIMAL_BUILD OR onnxruntime_REDUCED_OPS_BUILD) + # some exclusions from a minimal or reduced ops build + list(REMOVE_ITEM onnxruntime_unittest_utils_src + "${TEST_SRC_DIR}/unittest_util/base_tester.cc" + "${TEST_SRC_DIR}/unittest_util/base_tester.h" + "${TEST_SRC_DIR}/unittest_util/function_test_util.cc" + "${TEST_SRC_DIR}/unittest_util/function_test_util.h" + "${TEST_SRC_DIR}/unittest_util/graph_transform_test_builder.cc" + "${TEST_SRC_DIR}/unittest_util/graph_transform_test_builder.h" + "${TEST_SRC_DIR}/unittest_util/model_tester.h" + "${TEST_SRC_DIR}/unittest_util/op_tester.cc" + "${TEST_SRC_DIR}/unittest_util/op_tester.h" + "${TEST_SRC_DIR}/unittest_util/qdq_test_utils.cc" + "${TEST_SRC_DIR}/unittest_util/qdq_test_utils.h" + ) + + if (onnxruntime_MINIMAL_BUILD) + list(REMOVE_ITEM onnxruntime_unittest_utils_src + "${TEST_SRC_DIR}/unittest_util/test_dynamic_plugin_ep.cc" + "${TEST_SRC_DIR}/unittest_util/test_dynamic_plugin_ep.h" + ) + endif() +endif() + +onnxruntime_add_static_library(onnxruntime_unittest_utils ${onnxruntime_unittest_utils_src}) + +target_link_libraries(onnxruntime_unittest_utils PUBLIC + onnx + GTest::gtest + GTest::gmock + onnxruntime_test_utils + ${ONNXRUNTIME_TEST_LIBS} + ${onnxruntime_EXTERNAL_LIBRARIES} + ) + +set_target_properties(onnxruntime_unittest_utils PROPERTIES FOLDER "ONNXRuntimeTest") + +source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_unittest_utils_src}) + +endblock() + if(NOT IOS) set(onnx_test_runner_src_dir ${TEST_SRC_DIR}/onnx) file(GLOB onnx_test_runner_common_srcs CONFIGURE_DEPENDS @@ -832,12 +923,20 @@ if(NOT IOS) set(onnx_test_runner_common_lib onnx_test_runner_common) endif() -set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src} - ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src} ${onnxruntime_test_quantization_src} - ${onnxruntime_test_flatbuffers_src} ${onnxruntime_test_lora_src}) +set(all_tests + ${onnxruntime_test_common_src} + ${onnxruntime_test_ir_src} + ${onnxruntime_test_optimizer_src} + ${onnxruntime_test_framework_src} + ${onnxruntime_test_providers_src} + ${onnxruntime_test_internal_testing_ep_src} + ${onnxruntime_test_quantization_src} + ${onnxruntime_test_flatbuffers_src} + ${onnxruntime_test_lora_src} +) if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) - if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD AND NOT onnxruntime_DISABLE_CONTRIB_OPS) + if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD) set(onnxruntime_test_cuda_kernels_src_patterns "${TEST_SRC_DIR}/contrib_ops/cuda_kernels/*.cc") endif() @@ -878,9 +977,6 @@ if (onnxruntime_USE_OPENVINO) list(APPEND all_tests ${onnxruntime_test_openvino_src}) endif() -# this is only added to onnxruntime_test_framework_libs above, but we use onnxruntime_test_providers_libs for the onnxruntime_test_all target. -# for now, add it here. better is probably to have onnxruntime_test_providers_libs use the full onnxruntime_test_framework_libs -# list given it's built on top of that library and needs all the same dependencies. if(WIN32) list(APPEND onnxruntime_test_providers_libs Advapi32) endif() @@ -897,14 +993,15 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") "${TEST_SRC_DIR}/providers/memcpy_test.cc" ) endif() - list(REMOVE_ITEM all_tests "${TEST_SRC_DIR}/providers/cpu/reduction/reduction_ops_test.cc" - "${TEST_SRC_DIR}/providers/cpu/tensor/grid_sample_test.cc") + list(REMOVE_ITEM all_tests + "${TEST_SRC_DIR}/providers/cpu/reduction/reduction_ops_test.cc" + "${TEST_SRC_DIR}/providers/cpu/tensor/grid_sample_test.cc") endif() if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR IOS) # Because we do not run these model tests in our web or iOS CI build pipelines, and some test code uses C++17 # filesystem functions that are not available in the iOS version we target. - message("Disable model tests in onnxruntime_test_all") + message("Disable model tests") list(REMOVE_ITEM all_tests "${TEST_SRC_DIR}/providers/cpu/model_tests.cc" ) @@ -923,9 +1020,13 @@ endif () if(NOT onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) list(REMOVE_ITEM all_tests ${TEST_SRC_DIR}/providers/cuda/cuda_provider_test.cc) endif() + +partition_provider_test_srcs(all_tests onnxruntime_provider_test_srcs onnxruntime_test_all_srcs) + +list(APPEND onnxruntime_test_all_srcs ${onnxruntime_unittest_main_src}) AddTest( TARGET onnxruntime_test_all - SOURCES ${all_tests} ${onnxruntime_unittest_main_src} + SOURCES ${onnxruntime_test_all_srcs} LIBS ${onnx_test_runner_common_lib} ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs} onnx_test_data_proto @@ -995,9 +1096,9 @@ if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_link_libraries(onnxruntime_test_all PRIVATE Python::Python) endif() if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js) + set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_adapter.js) set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js) - set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s INITIAL_MEMORY=536870912 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 -s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm] --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" --pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1") + set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s INITIAL_MEMORY=536870912 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 -s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm] --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_adapter.js\" --pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1") if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " -s DEFAULT_PTHREAD_STACK_SIZE=131072 -s PROXY_TO_PTHREAD=1") endif() @@ -1095,6 +1196,79 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() endif() +# onnxruntime_provider_test +# Execution provider-related tests. +# These also have some support for dynamically specified plugin EPs. +if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD) +block() + set(supporting_test_srcs + ${TEST_SRC_DIR}/common/cuda_op_test_utils.cc + ${TEST_SRC_DIR}/common/cuda_op_test_utils.h + ${TEST_SRC_DIR}/common/tensor_op_test_utils.cc + ${TEST_SRC_DIR}/common/tensor_op_test_utils.h + ) + + list(APPEND onnxruntime_provider_test_srcs + ${supporting_test_srcs} + ${onnxruntime_unittest_main_src} + ) + + set(onnxruntime_provider_test_libs + ${onnx_test_runner_common_lib} + ${onnxruntime_test_providers_libs} + ${onnxruntime_test_common_libs} + onnx_test_data_proto + ) + + set(onnxruntime_provider_test_deps ${onnxruntime_test_providers_dependencies}) + + AddTest( + TARGET onnxruntime_provider_test + SOURCES ${onnxruntime_provider_test_srcs} + LIBS ${onnxruntime_provider_test_libs} + DEPENDS ${onnxruntime_provider_test_deps} + ) + + # enable dynamic plugin EP usage + target_compile_definitions(onnxruntime_provider_test PRIVATE ORT_UNIT_TEST_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) + + # TODO fix shorten-64-to-32 warnings + # there are some in builds where sizeof(size_t) != sizeof(int64_t), e.g., in 'ONNX Runtime Web CI Pipeline' + if (HAS_SHORTEN_64_TO_32 AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) + target_compile_options(onnxruntime_provider_test PRIVATE -Wno-error=shorten-64-to-32) + endif() + + # copied from onnxruntime_test_all + # TODO reuse instead of copy? + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + set_target_properties(onnxruntime_provider_test PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_adapter.js) + set_target_properties(onnxruntime_provider_test PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js) + set_target_properties(onnxruntime_provider_test PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s INITIAL_MEMORY=536870912 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 -s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm] --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_adapter.js\" --pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1") + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + set_property(TARGET onnxruntime_provider_test APPEND_STRING PROPERTY LINK_FLAGS " -s DEFAULT_PTHREAD_STACK_SIZE=131072 -s PROXY_TO_PTHREAD=1") + endif() + if (onnxruntime_USE_JSEP) + set_target_properties(onnxruntime_provider_test PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) + set_property(TARGET onnxruntime_provider_test APPEND_STRING PROPERTY LINK_FLAGS " --pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"") + endif() + + ### + ### if you want to investigate or debug a test failure in onnxruntime_provider_test, replace the following line. + ### those flags slow down the CI test significantly, so we don't use them by default. + ### + # set_property(TARGET onnxruntime_provider_test APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=2 -s SAFE_HEAP=1 -s STACK_OVERFLOW_CHECK=2") + set_property(TARGET onnxruntime_provider_test APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=0 -s SAFE_HEAP=0 -s STACK_OVERFLOW_CHECK=1") + endif() + + if (IOS) + add_custom_command( + TARGET onnxruntime_provider_test POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_directory + ${TEST_DATA_SRC} + $/testdata) + endif() +endblock() +endif() set(onnx_test_libs onnxruntime_test_utils @@ -1427,23 +1601,12 @@ endif() TARGET onnxruntime_test_debug_node_inputs_outputs SOURCES "${TEST_SRC_DIR}/debug_node_inputs_outputs/debug_node_inputs_outputs_utils_test.cc" - "${TEST_SRC_DIR}/framework/TestAllocatorManager.cc" - "${TEST_SRC_DIR}/framework/test_utils.cc" - "${TEST_SRC_DIR}/providers/base_tester.h" - "${TEST_SRC_DIR}/providers/base_tester.cc" - "${TEST_SRC_DIR}/providers/checkers.h" - "${TEST_SRC_DIR}/providers/checkers.cc" - "${TEST_SRC_DIR}/providers/op_tester.h" - "${TEST_SRC_DIR}/providers/op_tester.cc" "${TEST_SRC_DIR}/providers/provider_test_utils.h" - "${TEST_SRC_DIR}/providers/tester_types.h" ${onnxruntime_unittest_main_src} LIBS ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs} DEPENDS ${all_dependencies} ) - - target_compile_definitions(onnxruntime_test_debug_node_inputs_outputs PRIVATE DEBUG_NODE_INPUTS_OUTPUTS) endif(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS) diff --git a/cmake/test_pch.h b/cmake/test_pch.h index d538976367791..b2a940d2cb781 100644 --- a/cmake/test_pch.h +++ b/cmake/test_pch.h @@ -11,7 +11,6 @@ // Core test utilities (most frequently used in tests) #include "test/providers/provider_test_utils.h" -#include "test/providers/checkers.h" // ONNX and Protocol Buffer headers #include "core/graph/onnx_protobuf.h" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 42eedd5c2feb2..e2a9c20cf7abb 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -199,7 +199,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float), tensor(float16)
+
T : tensor(float), tensor(float16), tensor(bfloat16)
Constrain input and output types to float tensors.
M : tensor(int32)
Constrain mask index to integer types
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index cf4daac846ff6..db30b74301db1 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -940,7 +940,7 @@ Do not modify directly.* | | | | |**Operator Domain:** *com.microsoft*|||| -|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* attention_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| +|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* attention_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| |BeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasAdd|*in* X:**T**
*in* bias:**T**
*in* skip:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| diff --git a/include/onnxruntime/core/framework/ortdevice.h b/include/onnxruntime/core/framework/ortdevice.h index fea970b84fd84..935be9c3f00c7 100644 --- a/include/onnxruntime/core/framework/ortdevice.h +++ b/include/onnxruntime/core/framework/ortdevice.h @@ -24,16 +24,16 @@ struct OrtDevice { using Alignment = size_t; // Pre-defined device types. - static const DeviceType CPU = 0; - static const DeviceType GPU = 1; - static const DeviceType FPGA = 2; - static const DeviceType NPU = 3; + static constexpr DeviceType CPU = 0; + static constexpr DeviceType GPU = 1; + static constexpr DeviceType FPGA = 2; + static constexpr DeviceType NPU = 3; // this is used in the python API so we need to keep it for backward compatibility // it is only used in the OrtDevice ctor, and is mapped to GPU + VendorIds::MICROSOFT - static const DeviceType DML = 4; + static constexpr DeviceType DML = 4; struct MemType { - static const MemoryType DEFAULT = 0; + static constexpr MemoryType DEFAULT = 0; // deprecated values. MemType + VendorId is used to identify the memory type. enum Deprecated : MemoryType { @@ -49,7 +49,7 @@ struct OrtDevice { // - When creating an OrtDevice for an EP allocator, you would typically use the same device type and id // that the EP is registered with (i.e. the OrtDevice passed to the base IExecutionProvider constructor). // - Otherwise use OrtDevice::CPU. - static const MemoryType HOST_ACCESSIBLE = 5; + static constexpr MemoryType HOST_ACCESSIBLE = 5; }; // PCI vendor ids diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 59979189eed0f..9c42bf34b5b0f 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1582,11 +1582,13 @@ inline std::vector ConstSessionImpl::GetMemoryInfoForInputs( auto num_inputs = GetInputCount(); std::vector mem_infos; - mem_infos.resize(num_inputs); + if (num_inputs > 0) { + mem_infos.resize(num_inputs); - ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_, - reinterpret_cast(mem_infos.data()), - num_inputs)); + ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_, + reinterpret_cast(mem_infos.data()), + num_inputs)); + } return mem_infos; } @@ -1598,11 +1600,13 @@ inline std::vector ConstSessionImpl::GetMemoryInfoForOutputs auto num_outputs = GetOutputCount(); std::vector mem_infos; - mem_infos.resize(num_outputs); + if (num_outputs > 0) { + mem_infos.resize(num_outputs); - ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, - reinterpret_cast(mem_infos.data()), - num_outputs)); + ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_, + reinterpret_cast(mem_infos.data()), + num_outputs)); + } return mem_infos; } @@ -1631,12 +1635,12 @@ template inline std::vector ConstSessionImpl::GetEpDeviceForInputs() const { auto num_inputs = GetInputCount(); std::vector input_devices; - input_devices.resize(num_inputs); - - ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_, - reinterpret_cast(input_devices.data()), - num_inputs)); - + if (num_inputs > 0) { + input_devices.resize(num_inputs); + ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_, + reinterpret_cast(input_devices.data()), + num_inputs)); + } return input_devices; } diff --git a/java/build.gradle b/java/build.gradle index 2d43d1ead13f0..64a31c89ad322 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -3,8 +3,7 @@ plugins { id 'maven-publish' id 'signing' id 'jacoco' - id "com.diffplug.spotless" version "6.25.0" - id "net.linguica.maven-settings" version "0.5" + id "com.diffplug.spotless" version "7.2.1" } allprojects { @@ -14,17 +13,9 @@ allprojects { } project.group = "com.microsoft.onnxruntime" -version = rootProject.file('../VERSION_NUMBER').text.trim() - // cmake runs will inform us of the build directory of the current run def cmakeBuildDir = System.properties['cmakeBuildDir'] def useCUDA = System.properties['USE_CUDA'] -def useROCM = System.properties['USE_ROCM'] - -def adoArtifact = project.findProperty('adoArtifact') -def adoAccessToken = project.findProperty('adoAccessToken') -// Only publish to ADO feed if all two properties are set -def publishToAdo = adoArtifact != null && adoAccessToken != null boolean enableTrainingApis = (System.properties['ENABLE_TRAINING_APIS'] ?: "0") == "1" def cmakeJavaDir = "${cmakeBuildDir}/java" @@ -33,21 +24,14 @@ def cmakeNativeJniDir = "${cmakeJavaDir}/native-jni" def cmakeNativeTestDir = "${cmakeJavaDir}/native-test" def cmakeBuildOutputDir = "${cmakeJavaDir}/build" -def mavenUser = System.properties['mavenUser'] -def mavenPwd = System.properties['mavenPwd'] - def tmpArtifactId = enableTrainingApis ? project.name + "-training" : project.name -def mavenArtifactId = (useCUDA == null && useROCM == null) ? tmpArtifactId : tmpArtifactId + "_gpu" +def mavenArtifactId = (useCUDA == null) ? tmpArtifactId : tmpArtifactId + "_gpu" def defaultDescription = 'ONNX Runtime is a performance-focused inference engine for ONNX (Open Neural Network Exchange) models.' def trainingDescription = 'ONNX Runtime Training is a training and inference package for ONNX ' + '(Open Neural Network Exchange) models. This package is targeted for Learning on The Edge aka On-Device Training ' + 'See https://github.com/microsoft/onnxruntime-training-examples/tree/master/on_device_training for more details.' -// We need to have a custom settings.xml so codeql can bypass the need for settings.security.xml -mavenSettings { - userSettingsFileName = "${projectDir}/settings.xml" -} java { sourceCompatibility = JavaVersion.VERSION_17 @@ -202,16 +186,27 @@ test { systemProperties System.getProperties().subMap([ 'ENABLE_TRAINING_APIS', 'JAVA_FULL_TEST', + 'USE_ACL', + 'USE_ARMNN', + 'USE_AZURE', + 'USE_CANN', 'USE_COREML', 'USE_CUDA', 'USE_DML', 'USE_DNNL', + 'USE_MIGRAPHX', + 'USE_NNAPI', + 'USE_NV', 'USE_OPENVINO', - 'USE_ROCM', - 'USE_TENSORRT', 'USE_QNN', - 'USE_XNNPACK', + 'USE_RKNPU', + 'USE_SNPE', + 'USE_TENSORRT', + 'USE_VITISAI', + 'USE_VSINPU', 'USE_WEBGPU', + 'USE_WEBNN', + 'USE_XNNPACK', ]) testLogging { events "passed", "skipped", "failed" @@ -233,13 +228,9 @@ publishing { publications { maven(MavenPublication) { groupId = project.group - if(publishToAdo) { - artifactId = 'onnxruntime_gpu' - artifact (adoArtifact) - } else { - artifactId = mavenArtifactId - from components.java - } + artifactId = mavenArtifactId + from components.java + version = project.version pom { name = enableTrainingApis ? 'onnxruntime-training' : 'onnx-runtime' @@ -270,29 +261,6 @@ publishing { } } } - repositories { - if (publishToAdo) { - maven { - url "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/${System.getenv('ADOFeedName')}/maven/v1" - name System.getenv('ADOFeedName') - authentication { - basic(BasicAuthentication) - } - credentials { - username 'aiinfra' - password "${project.findProperty('adoAccessToken')}" - } - } - } else { - maven { - url 'https://oss.sonatype.org/service/local/staging/deploy/maven2/' - credentials { - username mavenUser - password mavenPwd - } - } - } - } } // Generates a task signMavenPublication that will // build all artifacts. @@ -300,12 +268,17 @@ signing { // Queries env vars: // ORG_GRADLE_PROJECT_signingKey // ORG_GRADLE_PROJECT_signingPassword but can be changed to properties - def signingKey = findProperty("signingKey") - def signingPassword = findProperty("signingPassword") - // Skip signing if no key is provided - if (signingKey != null && signingPassword != null) { - useInMemoryPgpKeys(signingKey, signingPassword) - sign publishing.publications.maven - sign publishing.publications.mavenAdo - } + def signingKey = findProperty("signingKey") + def signingPassword = findProperty("signingPassword") + // Skip signing if no key is provided + if (signingKey != null && signingPassword != null) { + useInMemoryPgpKeys(signingKey, signingPassword) + sign publishing.publications.maven + } +} + +tasks.named('generatePomFileForMavenPublication') { + doFirst { + println "AGENT_LOG: Generating POM for version: ${project.version}" + } } diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index c3f9d345078fe..c202b2a9f80e0 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -693,12 +693,6 @@ public void testCUDA() throws OrtException { runProvider(OrtProvider.CUDA); } - @Test - @EnabledIfSystemProperty(named = "USE_ROCM", matches = "1") - public void testROCM() throws OrtException { - runProvider(OrtProvider.ROCM); - } - @Test @EnabledIfSystemProperty(named = "USE_TENSORRT", matches = "1") public void testTensorRT() throws OrtException { @@ -725,6 +719,18 @@ public void testDNNL() throws OrtException { runProvider(OrtProvider.DNNL); } + @Test + @EnabledIfSystemProperty(named = "USE_MIGRAPHX", matches = "1") + public void testMIGRAPHX() throws OrtException { + runProvider(OrtProvider.MI_GRAPH_X); + } + + @Test + @EnabledIfSystemProperty(named = "USE_NNAPI", matches = "1") + public void testNNAPI() throws OrtException { + runProvider(OrtProvider.NNAPI); + } + @Test @EnabledIfSystemProperty(named = "USE_XNNPACK", matches = "1") public void testXNNPACK() throws OrtException { diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h index e494719464d20..39249f842e632 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h @@ -49,7 +49,8 @@ Status CheckInputs(MoEParameters& parameters, const Tensor* fc3_experts_bias, // optional const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) - const bool is_fused_swiglu) { + const bool is_fused_swiglu, + const int64_t block_size = 0) { // block size for block-wise quantization // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later. ASSERT_TENSOR_2D_OR_3D(input); ASSERT_TENSOR_3D(fc1_experts_weights); @@ -90,9 +91,63 @@ Status CheckInputs(MoEParameters& parameters, CHECK_TENSOR_SHAPE(fc2_experts_bias, num_experts, hidden_size); CHECK_TENSOR_SHAPE(fc3_experts_bias, num_experts, inter_size); - CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size); - CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size); - CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size); + // Validate scale tensors: Handle both row-wise and block-wise quantization flexibly + // First, detect the actual quantization method from the tensor shapes + bool is_row_wise_quantization = true; + if (fc1_experts_scales != nullptr) { + const auto& fc1_scales_dims = fc1_experts_scales->Shape().GetDims(); + if (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1) { + is_row_wise_quantization = false; + } + } + + if (block_size > 0 && !is_row_wise_quantization) { + // Block-wise quantization: 3D scale tensors + // For block-wise quantization, we calculate the number of blocks using ceiling division + // to handle cases where the dimension is not perfectly divisible by block_size + const int64_t fc1_blocks_per_row = (hidden_size + block_size - 1) / block_size; + const int64_t fc2_blocks_per_row = (inter_size + block_size - 1) / block_size; + const int64_t fc3_blocks_per_row = (hidden_size + block_size - 1) / block_size; + + CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size, fc1_blocks_per_row); + CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size, fc2_blocks_per_row); + CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size, fc3_blocks_per_row); + } else { + // Row-wise quantization: 2D scale tensors or 3D with last dimension = 1 + // Handle both {num_experts, features} and {num_experts, features, 1} shapes + if (fc1_experts_scales != nullptr) { + const auto& fc1_scales_dims = fc1_experts_scales->Shape().GetDims(); + if (fc1_scales_dims.size() == 2) { + CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size); + } else if (fc1_scales_dims.size() == 3) { + CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size, 1); + } else { + ORT_THROW("fc1_experts_scales must be 2D or 3D tensor"); + } + } + + if (fc2_experts_scales != nullptr) { + const auto& fc2_scales_dims = fc2_experts_scales->Shape().GetDims(); + if (fc2_scales_dims.size() == 2) { + CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size); + } else if (fc2_scales_dims.size() == 3) { + CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size, 1); + } else { + ORT_THROW("fc2_experts_scales must be 2D or 3D tensor"); + } + } + + if (fc3_experts_scales != nullptr) { + const auto& fc3_scales_dims = fc3_experts_scales->Shape().GetDims(); + if (fc3_scales_dims.size() == 2) { + CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size); + } else if (fc3_scales_dims.size() == 3) { + CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size, 1); + } else { + ORT_THROW("fc3_experts_scales must be 2D or 3D tensor"); + } + } + } if (fc3_experts_weights == nullptr) { ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr); diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 5c6c3b919b572..8195c9438d408 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -2,12 +2,16 @@ // Licensed under the MIT License. #include "contrib_ops/cpu/moe/moe_quantization_cpu.h" - #include "core/framework/allocator.h" #include "core/framework/float16.h" #include "core/mlas/inc/mlas.h" +#include "core/mlas/inc/mlas_q4.h" #include "core/platform/threadpool.h" #include "core/providers/cpu/math/gemm_helper.h" +#include "core/providers/cpu/activation/activations.h" +#include "core/common/safeint.h" +#include "core/framework/tensor_type_and_shape.h" +#include "core/util/math.h" #include "contrib_ops/cpu/moe/moe_utils.h" #include "contrib_ops/cpu/moe/moe_helper.h" @@ -17,44 +21,325 @@ #include #include +namespace { +inline int64_t GetOptimalBlockSize(int64_t total_elements, int num_threads) { + if (total_elements <= 0 || num_threads <= 0) return 64; + const int64_t l1_cache_elements = 8192; // ~32KB / 4 bytes per float + const int64_t divisor = std::max(1, num_threads > 1 ? 4 : 2); + const int64_t base_block_size = l1_cache_elements / divisor; + const int64_t max_block = std::max(int64_t{32}, total_elements / std::max(int64_t{1}, int64_t{4})); + return std::clamp(base_block_size, int64_t{32}, std::min(int64_t{512}, max_block)); +} + +inline int64_t GetUnrollFactor(int64_t vector_size) { + if (vector_size <= 0) return 2; + if (vector_size >= 512) return 16; + if (vector_size >= 128) return 8; + if (vector_size >= 32) return 4; + return 2; +} + +inline bool ShouldUseMemcpy(int64_t size) { + return size >= 64; +} + +inline int64_t GetDequantBlockSize(int64_t features, int64_t total_work) { + if (features <= 0 || total_work <= 0) return 16; + const int64_t target_block_size = std::max(int64_t{16}, features / std::max(int64_t{1}, int64_t{8})); + const int64_t work_based_size = std::max(int64_t{16}, total_work / std::max(int64_t{1}, int64_t{4})); + return std::min(target_block_size, work_based_size); +} + +bool CanUseMlasQ4Dequant(int64_t num_bits, int64_t block_size) { + if (num_bits != 4) { + return false; + } + + return true; +} + +bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, + int64_t rows, int64_t cols, MLAS_BLK_QUANT_TYPE& out_qtype) { + if (expert_weight_bits != 4) { + return false; + } + + if (block_size == 64) { + out_qtype = BlkQ4Sym64; + } else if (block_size == 128) { + out_qtype = BlkQ4Sym128; + } else if (block_size == 0) { + out_qtype = BlkQ4Sym; + } else { + return false; + } + + size_t expected_size = MlasQ4GemmPackBSize(out_qtype, static_cast(cols), static_cast(rows)); + return expected_size > 0; +} + +} // namespace + namespace onnxruntime { namespace contrib { -// Helper function to dequantize weights. Supports 4-bit and 8-bit symmetric quantization. -// The source quantized weights are stored as a row-major representation of the transposed -// logical weight matrix (W^T). This function dequantizes it into a float row-major W^T matrix. template -void DequantizeBlock(const uint8_t* quantized_data, - const TScale* scales, - int64_t /*block_size*/, - int64_t num_bits, - int64_t rows, - int64_t cols, - float* dequantized_data) { +void DequantizeBlockWithMlas(const uint8_t* quantized_data, + const TScale* scales, + int64_t block_size, + int64_t num_bits, + int64_t rows, + int64_t cols, + float* dequantized_data, + MLAS_THREADPOOL* thread_pool); + +template +Status ConvertToMlasQ4Format(const uint8_t* quantized_data, + const TScale* scales, + int64_t block_size, + int64_t num_bits, + int64_t rows, + int64_t cols, + MLAS_BLK_QUANT_TYPE qtype, + AllocatorPtr allocator, + IAllocatorUniquePtr& mlas_packed_buffer) { + if (num_bits != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only 4-bit quantization supported for MLAS Q4 format conversion"); + } + + auto temp_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(rows * cols)); + float* temp_float = temp_float_buffer.get(); + + DequantizeBlockWithMlas(quantized_data, scales, block_size, num_bits, rows, cols, temp_float, nullptr); + + size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(cols), static_cast(rows)); + if (packed_size == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MLAS Q4 packing not supported for this configuration"); + } + + mlas_packed_buffer = IAllocator::MakeUniquePtr(allocator, packed_size); + MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float, static_cast(cols), static_cast(rows), static_cast(cols)); + + return Status::OK(); +} + +Status DirectQ4Gemm(const float* A, + const uint8_t* mlas_packed_B, + const float* bias, + float* C, + int64_t M, + int64_t N, + int64_t K, + MLAS_BLK_QUANT_TYPE qtype, + MLAS_THREADPOOL* thread_pool) { + MLAS_Q4_GEMM_DATA_PARAMS params; + params.A = A; + params.lda = static_cast(K); + params.B = mlas_packed_B; + params.Bias = bias; + params.C = C; + params.ldc = static_cast(N); + params.OutputProcessor = nullptr; + + MlasQ4GemmBatch(qtype, static_cast(M), static_cast(N), static_cast(K), 1, ¶ms, thread_pool); + return Status::OK(); +} + +template +void DequantizeBlockWithMlas(const uint8_t* quantized_data, + const TScale* scales, + int64_t block_size, + int64_t num_bits, + int64_t rows, + int64_t cols, + float* dequantized_data, + MLAS_THREADPOOL* thread_pool) { const float zero_point = num_bits == 8 ? 128.0f : 8.0f; - if (num_bits == 8) { - for (int64_t r = 0; r < rows; ++r) { - const float scale = static_cast(scales[r]); - for (int64_t c = 0; c < cols; ++c) { - // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point) - dequantized_data[r * cols + c] = scale * (static_cast(quantized_data[r * cols + c]) - zero_point); + const int64_t blocks_per_row = (block_size > 0) ? ((cols + block_size - 1) / block_size) : 1; + + if (CanUseMlasQ4Dequant(num_bits, block_size)) { + const int64_t packed_cols = (cols + 1) / 2; + + if (block_size == 0) { + for (int64_t r = 0; r < rows; ++r) { + const uint8_t* row_data = quantized_data + r * packed_cols; + float* row_output = dequantized_data + r * cols; + const float scale = static_cast(scales[r]); + + int64_t c = 0; + for (; c + 8 <= cols; c += 8) { + const uint8_t packed_val0 = row_data[(c + 0) / 2]; + const uint8_t packed_val1 = row_data[(c + 2) / 2]; + const uint8_t packed_val2 = row_data[(c + 4) / 2]; + const uint8_t packed_val3 = row_data[(c + 6) / 2]; + + row_output[c + 0] = scale * (static_cast(packed_val0 & 0x0F) - zero_point); + row_output[c + 1] = scale * (static_cast(packed_val0 >> 4) - zero_point); + row_output[c + 2] = scale * (static_cast(packed_val1 & 0x0F) - zero_point); + row_output[c + 3] = scale * (static_cast(packed_val1 >> 4) - zero_point); + row_output[c + 4] = scale * (static_cast(packed_val2 & 0x0F) - zero_point); + row_output[c + 5] = scale * (static_cast(packed_val2 >> 4) - zero_point); + row_output[c + 6] = scale * (static_cast(packed_val3 & 0x0F) - zero_point); + row_output[c + 7] = scale * (static_cast(packed_val3 >> 4) - zero_point); + } + + for (; c < cols; c += 2) { + const uint8_t packed_val = row_data[c / 2]; + const uint8_t val0 = packed_val & 0x0F; + const uint8_t val1 = packed_val >> 4; + + row_output[c] = scale * (static_cast(val0) - zero_point); + if (c + 1 < cols) { + row_output[c + 1] = scale * (static_cast(val1) - zero_point); + } + } } + return; + } else { + for (int64_t r = 0; r < rows; ++r) { + const uint8_t* row_data = quantized_data + r * packed_cols; + float* row_output = dequantized_data + r * cols; + + for (int64_t block_start = 0; block_start < cols; block_start += block_size) { + const int64_t block_end = std::min(block_start + block_size, cols); + const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1); + const int64_t scale_idx = r * blocks_per_row + block_idx; + const float scale = static_cast(scales[scale_idx]); + + int64_t c = block_start; + for (; c + 4 <= block_end; c += 4) { + const uint8_t packed_val0 = row_data[(c + 0) / 2]; + const uint8_t packed_val1 = row_data[(c + 2) / 2]; + + row_output[c + 0] = scale * (static_cast(packed_val0 & 0x0F) - zero_point); + row_output[c + 1] = scale * (static_cast(packed_val0 >> 4) - zero_point); + row_output[c + 2] = scale * (static_cast(packed_val1 & 0x0F) - zero_point); + row_output[c + 3] = scale * (static_cast(packed_val1 >> 4) - zero_point); + } + + for (; c < block_end; c += 2) { + const uint8_t packed_val = row_data[c / 2]; + const uint8_t val0 = packed_val & 0x0F; + const uint8_t val1 = packed_val >> 4; + + row_output[c] = scale * (static_cast(val0) - zero_point); + if (c + 1 < block_end) { + row_output[c + 1] = scale * (static_cast(val1) - zero_point); + } + } + } + } + return; } - } else if (num_bits == 4) { - const int64_t packed_cols = (cols + 1) / 2; + } + + if (num_bits == 8 && block_size == 0) { for (int64_t r = 0; r < rows; ++r) { const float scale = static_cast(scales[r]); - for (int64_t c = 0; c < cols; ++c) { - const uint8_t packed_val = quantized_data[r * packed_cols + c / 2]; - // Unpack the 4-bit value. Low nibble for even columns, high nibble for odd columns. - const uint8_t quantized_val = (c % 2 == 0) ? (packed_val & 0x0F) : (packed_val >> 4); - // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point) - dequantized_data[r * cols + c] = scale * (static_cast(quantized_val) - zero_point); + const uint8_t zero_pt = static_cast(zero_point); + + MlasDequantizeLinear( + quantized_data + r * cols, + dequantized_data + r * cols, + static_cast(cols), + scale, + zero_pt); + } + } else { + if (num_bits == 8) { + for (int64_t r = 0; r < rows; ++r) { + const uint8_t* row_data = quantized_data + r * cols; + float* row_output = dequantized_data + r * cols; + + int64_t c = 0; + if (block_size > 0) { + for (int64_t block_start = 0; block_start < cols; block_start += block_size) { + const int64_t block_end = std::min(block_start + block_size, cols); + const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1); + const int64_t scale_idx = r * blocks_per_row + block_idx; + const float scale = static_cast(scales[scale_idx]); + + for (c = block_start; c + 4 <= block_end; c += 4) { + row_output[c] = scale * (static_cast(row_data[c]) - zero_point); + row_output[c + 1] = scale * (static_cast(row_data[c + 1]) - zero_point); + row_output[c + 2] = scale * (static_cast(row_data[c + 2]) - zero_point); + row_output[c + 3] = scale * (static_cast(row_data[c + 3]) - zero_point); + } + for (; c < block_end; ++c) { + row_output[c] = scale * (static_cast(row_data[c]) - zero_point); + } + } + } else { + const float scale = static_cast(scales[r]); + for (; c + 8 <= cols; c += 8) { + row_output[c] = scale * (static_cast(row_data[c]) - zero_point); + row_output[c + 1] = scale * (static_cast(row_data[c + 1]) - zero_point); + row_output[c + 2] = scale * (static_cast(row_data[c + 2]) - zero_point); + row_output[c + 3] = scale * (static_cast(row_data[c + 3]) - zero_point); + row_output[c + 4] = scale * (static_cast(row_data[c + 4]) - zero_point); + row_output[c + 5] = scale * (static_cast(row_data[c + 5]) - zero_point); + row_output[c + 6] = scale * (static_cast(row_data[c + 6]) - zero_point); + row_output[c + 7] = scale * (static_cast(row_data[c + 7]) - zero_point); + } + for (; c < cols; ++c) { + row_output[c] = scale * (static_cast(row_data[c]) - zero_point); + } + } + } + } else if (num_bits == 4) { + const int64_t packed_cols = (cols + 1) / 2; + for (int64_t r = 0; r < rows; ++r) { + const uint8_t* row_data = quantized_data + r * packed_cols; + float* row_output = dequantized_data + r * cols; + + if (block_size > 0) { + for (int64_t block_start = 0; block_start < cols; block_start += block_size) { + const int64_t block_end = std::min(block_start + block_size, cols); + const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1); + const int64_t scale_idx = r * blocks_per_row + block_idx; + const float scale = static_cast(scales[scale_idx]); + + for (int64_t c = block_start; c < block_end; c += 2) { + const uint8_t packed_val = row_data[c / 2]; + const uint8_t val0 = packed_val & 0x0F; + const uint8_t val1 = packed_val >> 4; + + row_output[c] = scale * (static_cast(val0) - zero_point); + if (c + 1 < block_end) { + row_output[c + 1] = scale * (static_cast(val1) - zero_point); + } + } + } + } else { + const float scale = static_cast(scales[r]); + for (int64_t c = 0; c < cols; c += 2) { + const uint8_t packed_val = row_data[c / 2]; + const uint8_t val0 = packed_val & 0x0F; + const uint8_t val1 = packed_val >> 4; + + row_output[c] = scale * (static_cast(val0) - zero_point); + if (c + 1 < cols) { + row_output[c + 1] = scale * (static_cast(val1) - zero_point); + } + } + } } } } } +template +void DequantizeBlock(const uint8_t* quantized_data, + const TScale* scales, + int64_t block_size, + int64_t num_bits, + int64_t rows, + int64_t cols, + float* dequantized_data, + MLAS_THREADPOOL* thread_pool = nullptr) { + DequantizeBlockWithMlas(quantized_data, scales, block_size, num_bits, rows, cols, dequantized_data, thread_pool); +} + template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info), @@ -63,11 +348,15 @@ QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8, "Attribute 'expert_weight_bits' must be 4 or 8."); block_size_ = op_kernel_info.GetAttrOrDefault("block_size", 0); + + if (block_size_ > 0) { + ORT_ENFORCE(block_size_ >= 16, "block_size must be >= 16 when provided."); + ORT_ENFORCE((block_size_ & (block_size_ - 1)) == 0, "block_size must be a power of 2."); + } } template Status QMoECPU::Compute(OpKernelContext* context) const { - // --- 1. Get Inputs and Attributes --- const auto* input = context->Input(0); const auto* router_probs = context->Input(1); const auto* fc1_experts_weights = context->Input(2); @@ -87,7 +376,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc2_experts_weights, fc2_experts_bias, fc2_scales, fc3_experts_weights, fc3_experts_bias, fc3_scales, expert_weight_bits_ == 4 ? 2 : 1, - true)); + true, + block_size_)); if (fc3_experts_weights || fc3_experts_bias || fc3_scales) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "FC3 gating is not yet implemented on CPU for QMoE"); @@ -109,19 +399,17 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t output_buffer_size = static_cast(output->Shape().Size()); const T* input_data = input->Data(); - const T* router_probs_data = router_probs->Data(); - // --- 2. Routing Logic: Assign tokens to experts --- IAllocatorUniquePtr router_logits_float_buffer; const float* router_logits_float; if constexpr (std::is_same_v) { router_logits_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * num_experts)); router_logits_float = router_logits_float_buffer.get(); - MlasConvertHalfToFloatBuffer(reinterpret_cast(router_probs_data), + MlasConvertHalfToFloatBuffer(reinterpret_cast(router_probs->Data()), const_cast(router_logits_float), static_cast(num_tokens * num_experts)); } else { - router_logits_float = reinterpret_cast(router_probs_data); + router_logits_float = reinterpret_cast(router_probs->Data()); } auto route_expert_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); @@ -129,36 +417,37 @@ Status QMoECPU::Compute(OpKernelContext* context) const { auto route_scale_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); float* route_scale = route_scale_ptr.get(); - // Parallelize the routing logic to improve performance for large token batches. - // Minor performance regression for single-token decoding is an acceptable trade-off - int num_routing_threads = (tp == nullptr || num_tokens < 4096) ? 1 : std::min(static_cast(num_tokens), concurrency::ThreadPool::DegreeOfParallelism(tp)); + const int max_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; + const int64_t thread_divisor = std::max(1, max_threads * 4); + const int64_t min_work_per_thread = std::max(int64_t{32}, static_cast(num_tokens / thread_divisor)); + const int optimal_routing_threads = (tp == nullptr || num_tokens < min_work_per_thread) ? 1 : std::min(static_cast(num_tokens / std::max(int64_t{1}, min_work_per_thread)), max_threads); + const int num_routing_threads = std::max(1, optimal_routing_threads); std::vector>> thread_local_expert_token_maps(num_routing_threads); for (auto& map : thread_local_expert_token_maps) { map.resize(static_cast(num_experts)); + for (auto& expert_tokens : map) { + expert_tokens.reserve(32); + } } concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) { auto work = concurrency::ThreadPool::PartitionWork(static_cast(thread_id), num_routing_threads, static_cast(num_tokens)); auto& local_expert_token_map = thread_local_expert_token_maps[thread_id]; - // Pre-allocate buffers for this thread to reuse, avoiding allocations inside the loop. std::vector> sorted_logits(static_cast(num_experts)); std::vector top_k_exp(static_cast(k_)); for (int64_t i = work.start; i < work.end; ++i) { const float* logits = router_logits_float + i * num_experts; + for (int64_t j = 0; j < num_experts; ++j) { sorted_logits[static_cast(j)] = {logits[j], j}; } - std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast(k_), sorted_logits.end(), std::greater<>()); + std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast(k_), + sorted_logits.end(), std::greater<>()); - float max_logit = -std::numeric_limits::infinity(); - for (int64_t j = 0; j < k_; ++j) { - if (sorted_logits[static_cast(j)].first > max_logit) { - max_logit = sorted_logits[static_cast(j)].first; - } - } + float max_logit = sorted_logits[0].first; float sum_exp = 0.0f; for (int64_t j = 0; j < k_; ++j) { @@ -166,20 +455,19 @@ Status QMoECPU::Compute(OpKernelContext* context) const { sum_exp += top_k_exp[static_cast(j)]; } - float scale = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp); + const float inv_sum = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp); for (int64_t j = 0; j < k_; ++j) { int64_t expert_idx = sorted_logits[static_cast(j)].second; int64_t route_idx = i * k_ + j; route_expert[route_idx] = static_cast(expert_idx); - route_scale[route_idx] = top_k_exp[static_cast(j)] * scale; - if (route_scale[route_idx] > 0.0f) { + route_scale[route_idx] = top_k_exp[static_cast(j)] * inv_sum; + if (route_scale[route_idx] > 1e-8f) { // Use small threshold to avoid zero weights local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); } } } }); - // Merge the maps from each thread into a single global map. std::vector> expert_token_map(static_cast(num_experts)); for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { size_t total_tokens_for_expert = 0; @@ -187,18 +475,17 @@ Status QMoECPU::Compute(OpKernelContext* context) const { total_tokens_for_expert += thread_local_expert_token_maps[t][static_cast(expert_idx)].size(); } expert_token_map[static_cast(expert_idx)].reserve(total_tokens_for_expert); - } - for (int t = 0; t < num_routing_threads; ++t) { - for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + for (int t = 0; t < num_routing_threads; ++t) { auto& local_tokens = thread_local_expert_token_maps[t][static_cast(expert_idx)]; if (!local_tokens.empty()) { - expert_token_map[static_cast(expert_idx)].insert(expert_token_map[static_cast(expert_idx)].end(), local_tokens.begin(), local_tokens.end()); + expert_token_map[static_cast(expert_idx)].insert( + expert_token_map[static_cast(expert_idx)].end(), + local_tokens.begin(), local_tokens.end()); } } } - // --- 3. Parallel Expert Computation --- IAllocatorUniquePtr input_float_buffer; const float* input_float; if constexpr (std::is_same_v) { @@ -211,118 +498,434 @@ Status QMoECPU::Compute(OpKernelContext* context) const { input_float = reinterpret_cast(input_data); } - int num_expert_threads = (tp == nullptr) ? 1 : std::min(static_cast(num_experts), concurrency::ThreadPool::DegreeOfParallelism(tp)); + const int max_expert_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; + const int64_t total_expert_work = std::accumulate(expert_token_map.begin(), expert_token_map.end(), 0LL, + [](int64_t sum, const std::vector& tokens) { return sum + static_cast(tokens.size()); }); + const int64_t expert_thread_divisor = std::max(1, max_expert_threads * 8); + const int64_t min_expert_work_per_thread = std::max(int64_t{16}, total_expert_work / expert_thread_divisor); + + int num_expert_threads = (tp == nullptr || total_expert_work < min_expert_work_per_thread) ? 1 : std::min(static_cast(total_expert_work / std::max(int64_t{1}, min_expert_work_per_thread)), std::min(static_cast(num_experts), max_expert_threads)); if (num_expert_threads == 0) num_expert_threads = 1; + auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * output_buffer_size); float* thread_local_outputs = thread_local_outputs_ptr.get(); - memset(thread_local_outputs, 0, static_cast(num_expert_threads) * output_buffer_size * sizeof(float)); + std::memset(thread_local_outputs, 0, static_cast(num_expert_threads) * output_buffer_size * sizeof(float)); - // Pre-calculate workspace size per thread to avoid allocations inside the loop size_t max_tokens_per_expert = 0; for (const auto& tokens : expert_token_map) { - if (tokens.size() > max_tokens_per_expert) { - max_tokens_per_expert = tokens.size(); - } + max_tokens_per_expert = std::max(max_tokens_per_expert, tokens.size()); } - const size_t A1_size = static_cast(max_tokens_per_expert * hidden_size); - const size_t C1_size = static_cast(max_tokens_per_expert * fc1_out_features); - const size_t A2_size = static_cast(max_tokens_per_expert * inter_size); - const size_t C2_size = static_cast(max_tokens_per_expert * hidden_size); - const size_t B1_dequant_size = static_cast(fc1_out_features * hidden_size); - const size_t B2_dequant_size = static_cast(hidden_size * inter_size); - const size_t bias1_size = static_cast(fc1_out_features); - const size_t bias2_size = static_cast(hidden_size); + const auto align_size = [](size_t size) -> size_t { + return (size + 63) & ~63; + }; + + const size_t A1_size = align_size(static_cast(max_tokens_per_expert) * static_cast(hidden_size)); + const size_t C1_size = align_size(static_cast(max_tokens_per_expert) * static_cast(fc1_out_features)); + const size_t A2_size = align_size(static_cast(max_tokens_per_expert) * static_cast(inter_size)); + const size_t C2_size = align_size(static_cast(max_tokens_per_expert) * static_cast(hidden_size)); + const size_t B1_dequant_size = align_size(static_cast(fc1_out_features) * static_cast(hidden_size)); + const size_t B2_dequant_size = align_size(static_cast(hidden_size) * static_cast(inter_size)); + + const size_t workspace_elements_per_thread = A1_size + C1_size + A2_size + C2_size + + B1_dequant_size + B2_dequant_size; - const size_t workspace_elements_per_thread = A1_size + C1_size + A2_size + C2_size + B1_dequant_size + B2_dequant_size + bias1_size + bias2_size; auto workspace_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * workspace_elements_per_thread); float* workspace = workspace_ptr.get(); + auto bias_conversion_buffers_ptr = IAllocator::MakeUniquePtr(allocator, + static_cast(num_expert_threads) * (static_cast(fc1_out_features) + static_cast(hidden_size))); + float* bias_conversion_buffers = bias_conversion_buffers_ptr.get(); + + const auto& fc1_scales_dims = fc1_scales->Shape().GetDims(); + const auto& fc2_scales_dims = fc2_scales->Shape().GetDims(); + const bool is_fc1_block_wise = (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1); + const bool is_fc2_block_wise = (fc2_scales_dims.size() == 3 && fc2_scales_dims[2] > 1); + + const uint8_t* fc1_weights_data = fc1_experts_weights->Data(); + const uint8_t* fc2_weights_data = fc2_experts_weights->Data(); + const T* fc1_scales_data = fc1_scales->Data(); + const T* fc2_scales_data = fc2_scales->Data(); + const T* fc1_bias_data = fc1_experts_bias ? fc1_experts_bias->Data() : nullptr; + const T* fc2_bias_data = fc2_experts_bias ? fc2_experts_bias->Data() : nullptr; + + const int64_t pack_unit = (8 / expert_weight_bits_); + const int64_t fc1_packed_cols = (hidden_size + pack_unit - 1) / pack_unit; + const int64_t fc2_packed_cols = (inter_size + pack_unit - 1) / pack_unit; + const bool has_fc1_bias = (fc1_bias_data != nullptr); + const bool has_fc2_bias = (fc2_bias_data != nullptr); + + std::vector> expert_workload; + size_t total_work = 0; + + for (int64_t i = 0; i < num_experts; ++i) { + const size_t token_count = expert_token_map[static_cast(i)].size(); + if (token_count > 0) { + expert_workload.emplace_back(i, token_count); + total_work += token_count; + } + } + + if (total_work < 48) { + num_expert_threads = 1; + } else if (total_work < 192) { + num_expert_threads = std::min(num_expert_threads, 2); + } else if (total_work < 512) { + num_expert_threads = std::min(num_expert_threads, 4); + } + + std::sort(expert_workload.begin(), expert_workload.end(), + [](const auto& a, const auto& b) { return a.second > b.second; }); + + std::vector> expert_batches(num_expert_threads); + size_t thread_idx = 0; + for (const auto& work : expert_workload) { + expert_batches[thread_idx].push_back(work.first); + thread_idx = (thread_idx + 1) % static_cast(num_expert_threads); + } + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) { - int thread_id = static_cast(thread_id_pd); - auto work = concurrency::ThreadPool::PartitionWork(thread_id, num_expert_threads, static_cast(num_experts)); + const int thread_id = static_cast(thread_id_pd); + const auto& expert_batch = expert_batches[static_cast(thread_id)]; float* thread_workspace = workspace + static_cast(thread_id) * workspace_elements_per_thread; - for (int64_t expert_idx = work.start; expert_idx < work.end; ++expert_idx) { + float* thread_bias1_buffer = bias_conversion_buffers + static_cast(thread_id) * (static_cast(fc1_out_features) + static_cast(hidden_size)); + float* thread_bias2_buffer = thread_bias1_buffer + static_cast(fc1_out_features); + + for (int64_t expert_idx : expert_batch) { const auto& routes = expert_token_map[static_cast(expert_idx)]; if (routes.empty()) { continue; } - const int64_t num_expert_tokens = routes.size(); + const int64_t num_expert_tokens = static_cast(routes.size()); - // Partition the workspace for the current expert float* A1 = thread_workspace; - float* C1 = A1 + num_expert_tokens * hidden_size; - float* A2 = C1 + num_expert_tokens * fc1_out_features; - float* C2 = A2 + num_expert_tokens * inter_size; - float* B1_dequant = C2 + num_expert_tokens * hidden_size; - float* B2_dequant = B1_dequant + fc1_out_features * hidden_size; - float* bias1_float = B2_dequant + hidden_size * inter_size; - float* bias2_float = bias1_float + fc1_out_features; - - // --- Gather input tokens for the current expert --- - for (int64_t i = 0; i < num_expert_tokens; ++i) { - const int64_t token_idx = routes[static_cast(i)] / k_; - memcpy(A1 + i * hidden_size, - input_float + token_idx * hidden_size, - static_cast(hidden_size) * sizeof(float)); + float* C1 = A1 + A1_size; + float* A2 = C1 + C1_size; + float* C2 = A2 + A2_size; + float* B1_dequant = C2 + C2_size; + float* B2_dequant = B1_dequant + B1_dequant_size; + + const int64_t dynamic_block_size = GetOptimalBlockSize(num_expert_tokens, tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1); + const int64_t num_blocks = (num_expert_tokens + dynamic_block_size - 1) / dynamic_block_size; + + if (num_expert_tokens >= 8 && num_blocks > 1 && tp != nullptr) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_idx = block_idx * dynamic_block_size; + const int64_t end_idx = std::min(start_idx + dynamic_block_size, num_expert_tokens); + + for (int64_t i = start_idx; i < end_idx; ++i) { + const int64_t token_idx = routes[static_cast(i)] / k_; + const float* src = input_float + token_idx * hidden_size; + float* dst = A1 + i * hidden_size; + + std::memcpy(dst, src, static_cast(hidden_size) * sizeof(float)); + } + }); + } else { + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const int64_t token_idx = routes[static_cast(i)] / k_; + const float* src = input_float + token_idx * hidden_size; + float* dst = A1 + i * hidden_size; + + if (ShouldUseMemcpy(hidden_size)) { + std::memcpy(dst, src, static_cast(hidden_size) * sizeof(float)); + } else { + const int64_t unroll_factor = GetUnrollFactor(hidden_size); + int64_t j = 0; + for (; j + unroll_factor <= hidden_size; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + dst[j + k] = src[j + k]; + } + } + for (; j < hidden_size; ++j) { + dst[j] = src[j]; + } + } + } + } + + const T* fc1_scales_ptr; + + if (is_fc1_block_wise) { + const int64_t fc1_blocks_per_row = fc1_scales_dims[2]; + fc1_scales_ptr = fc1_scales_data + expert_idx * fc1_out_features * fc1_blocks_per_row; + } else { + fc1_scales_ptr = fc1_scales_data + expert_idx * fc1_out_features; } - // --- FC1 GEMM (X * W1^T) --- - DequantizeBlock(fc1_experts_weights->Data() + expert_idx * fc1_out_features * (hidden_size / (8 / expert_weight_bits_)), - fc1_scales->Data() + expert_idx * fc1_out_features * (block_size_ > 0 ? hidden_size / block_size_ : 1), - block_size_, expert_weight_bits_, - fc1_out_features, hidden_size, B1_dequant); + const int64_t dequant_block_size = GetDequantBlockSize(fc1_out_features, num_expert_tokens); + const int64_t num_dequant_blocks = (fc1_out_features + dequant_block_size - 1) / dequant_block_size; + + const size_t m = static_cast(num_expert_tokens); + const size_t n = static_cast(fc1_out_features); + const size_t k = static_cast(hidden_size); + + MLAS_BLK_QUANT_TYPE q_type; + bool use_direct_q4_gemm = CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, + fc1_out_features, hidden_size, q_type); + bool fc1_used_direct_q4 = false; + bool fc1_bias_handled_by_q4_gemm = false; + + if (use_direct_q4_gemm) { + IAllocatorUniquePtr mlas_packed_fc1; + Status convert_status = ConvertToMlasQ4Format( + fc1_weights_data + expert_idx * fc1_out_features * fc1_packed_cols, + fc1_scales_ptr, + is_fc1_block_wise ? block_size_ : 0, + expert_weight_bits_, + fc1_out_features, + hidden_size, + q_type, + allocator, + mlas_packed_fc1); + + if (convert_status.IsOK()) { + float* fc1_bias_float = nullptr; + IAllocatorUniquePtr fc1_bias_buffer; + + if (has_fc1_bias) { + const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; + fc1_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(fc1_out_features)); + fc1_bias_float = fc1_bias_buffer.get(); + + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), fc1_bias_float, static_cast(fc1_out_features)); + } else { + for (int64_t i = 0; i < fc1_out_features; ++i) { + fc1_bias_float[i] = static_cast(B1_bias[i]); + } + } + } + + Status gemm_status = DirectQ4Gemm(A1, mlas_packed_fc1.get(), fc1_bias_float, C1, + num_expert_tokens, fc1_out_features, hidden_size, q_type, tp); + + if (gemm_status.IsOK()) { + fc1_used_direct_q4 = true; +#ifdef ONNXRUNTIME_ENABLE_VERBOSE_LOGGING + LOGS_DEFAULT(VERBOSE) << "QMoE: Using direct MLAS Q4 GEMM for FC1 expert " << expert_idx + << " (M=" << num_expert_tokens << ", N=" << fc1_out_features << ", K=" << hidden_size << ")"; +#endif + goto fc1_gemm_done; + } + } + // If direct Q4 GEMM failed, fall back to traditional approach + } + + // Traditional approach: dequantize + regular GEMM + if (num_dequant_blocks > 1 && fc1_out_features >= 32) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_dequant_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_row = block_idx * dequant_block_size; + const int64_t end_row = std::min(start_row + dequant_block_size, fc1_out_features); + const auto offset = expert_idx * fc1_out_features * fc1_packed_cols + start_row * fc1_packed_cols; + DequantizeBlock(fc1_weights_data + offset, + fc1_scales_ptr + (is_fc1_block_wise ? start_row * fc1_scales_dims[2] : start_row), + is_fc1_block_wise ? block_size_ : 0, expert_weight_bits_, + end_row - start_row, hidden_size, B1_dequant + start_row * hidden_size, tp); + }); + } else { + DequantizeBlock(fc1_weights_data + expert_idx * fc1_out_features * fc1_packed_cols, + fc1_scales_ptr, + is_fc1_block_wise ? block_size_ : 0, expert_weight_bits_, + fc1_out_features, hidden_size, B1_dequant, tp); + } MlasGemm(CblasNoTrans, CblasTrans, - static_cast(num_expert_tokens), static_cast(fc1_out_features), static_cast(hidden_size), - 1.0f, A1, static_cast(hidden_size), - B1_dequant, static_cast(hidden_size), - 0.0f, C1, static_cast(fc1_out_features), - nullptr); - - const T* B1_bias = (fc1_experts_bias) ? fc1_experts_bias->Data() + expert_idx * fc1_out_features : nullptr; - if (B1_bias) { + m, n, k, + 1.0f, A1, k, + B1_dequant, k, + 0.0f, C1, n, + tp); + + fc1_bias_handled_by_q4_gemm = fc1_used_direct_q4 && has_fc1_bias; + if (has_fc1_bias && !fc1_bias_handled_by_q4_gemm) { + const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), bias1_float, static_cast(fc1_out_features)); + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); } else { - memcpy(bias1_float, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + if (ShouldUseMemcpy(fc1_out_features)) { + std::memcpy(thread_bias1_buffer, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + } else { + const int64_t unroll_factor = GetUnrollFactor(fc1_out_features); + int64_t j = 0; + for (; j + unroll_factor <= fc1_out_features; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + thread_bias1_buffer[j + k] = static_cast(B1_bias[j + k]); + } + } + for (; j < fc1_out_features; ++j) { + thread_bias1_buffer[j] = static_cast(B1_bias[j]); + } + } } + for (int64_t i = 0; i < num_expert_tokens; ++i) { - for (int64_t j = 0; j < fc1_out_features; ++j) { - C1[i * fc1_out_features + j] += bias1_float[j]; + float* C1_row = C1 + i * fc1_out_features; + const int64_t unroll_factor = GetUnrollFactor(fc1_out_features); + + int64_t j = 0; + for (; j + unroll_factor <= fc1_out_features; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + C1_row[j + k] += thread_bias1_buffer[j + k]; + } + } + for (; j < fc1_out_features; ++j) { + C1_row[j] += thread_bias1_buffer[j]; } } } - // --- Activation --- - for (int64_t i = 0; i < num_expert_tokens; ++i) { - const float* C1_token = C1 + i * fc1_out_features; - float* A2_token = A2 + i * inter_size; - ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + fc1_gemm_done: + + const int64_t activation_threshold = std::max(int64_t{4}, 256 / std::max(int64_t{1}, inter_size)); + if (num_expert_tokens >= activation_threshold && tp != nullptr) { + const int64_t activation_block_size = std::max(int64_t{1}, std::min(int64_t{64}, activation_threshold)); + const int64_t num_activation_blocks = (num_expert_tokens + activation_block_size - 1) / activation_block_size; + + if (num_activation_blocks > 1) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_activation_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_token = block_idx * activation_block_size; + const int64_t end_token = std::min(start_token + activation_block_size, num_expert_tokens); + + for (int64_t i = start_token; i < end_token; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + }); + } else { + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + } + } else { + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + } + + const T* fc2_scales_ptr; + + if (is_fc2_block_wise) { + const int64_t fc2_blocks_per_row = fc2_scales_dims[2]; + fc2_scales_ptr = fc2_scales_data + expert_idx * hidden_size * fc2_blocks_per_row; + } else { + fc2_scales_ptr = fc2_scales_data + expert_idx * hidden_size; } - // --- FC2 GEMM (A2 * W2^T) --- - DequantizeBlock(fc2_experts_weights->Data() + expert_idx * hidden_size * (inter_size / (8 / expert_weight_bits_)), - fc2_scales->Data() + expert_idx * hidden_size * (block_size_ > 0 ? inter_size / block_size_ : 1), - block_size_, expert_weight_bits_, - hidden_size, inter_size, B2_dequant); + const int64_t fc2_dequant_block_size = GetDequantBlockSize(hidden_size, num_expert_tokens); + const int64_t num_fc2_dequant_blocks = (hidden_size + fc2_dequant_block_size - 1) / fc2_dequant_block_size; + + const size_t m2 = static_cast(num_expert_tokens); + const size_t n2 = static_cast(hidden_size); + const size_t k2 = static_cast(inter_size); + + MLAS_BLK_QUANT_TYPE q_type2; + bool use_direct_q4_gemm_fc2 = CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, + hidden_size, inter_size, q_type2); + bool fc2_used_direct_q4 = false; + + if (use_direct_q4_gemm_fc2) { + IAllocatorUniquePtr mlas_packed_fc2; + Status convert_status = ConvertToMlasQ4Format( + fc2_weights_data + expert_idx * hidden_size * fc2_packed_cols, + fc2_scales_ptr, + is_fc2_block_wise ? block_size_ : 0, + expert_weight_bits_, + hidden_size, + inter_size, + q_type2, + allocator, + mlas_packed_fc2); + + if (convert_status.IsOK()) { + float* fc2_bias_float = nullptr; + IAllocatorUniquePtr fc2_bias_buffer; + + if (has_fc2_bias) { + const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; + fc2_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(hidden_size)); + fc2_bias_float = fc2_bias_buffer.get(); + + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), fc2_bias_float, static_cast(hidden_size)); + } else { + for (int64_t i = 0; i < hidden_size; ++i) { + fc2_bias_float[i] = static_cast(B2_bias[i]); + } + } + } + + Status gemm_status = DirectQ4Gemm(A2, mlas_packed_fc2.get(), fc2_bias_float, C2, + num_expert_tokens, hidden_size, inter_size, q_type2, tp); + + if (gemm_status.IsOK()) { + fc2_used_direct_q4 = true; +#ifdef ONNXRUNTIME_ENABLE_VERBOSE_LOGGING + LOGS_DEFAULT(VERBOSE) << "QMoE: Using direct MLAS Q4 GEMM for FC2 expert " << expert_idx + << " (M=" << num_expert_tokens << ", N=" << hidden_size << ", K=" << inter_size << ")"; +#endif + goto fc2_gemm_done; + } + } + + // If direct Q4 GEMM failed, fall back to traditional approach + } + + // Traditional approach: dequantize + regular GEMM + if (num_fc2_dequant_blocks > 1 && hidden_size >= 32) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_fc2_dequant_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_row = block_idx * fc2_dequant_block_size; + const int64_t end_row = std::min(start_row + fc2_dequant_block_size, hidden_size); + const auto offset = expert_idx * hidden_size * fc2_packed_cols + start_row * fc2_packed_cols; + DequantizeBlock(fc2_weights_data + offset, + fc2_scales_ptr + (is_fc2_block_wise ? start_row * fc2_scales_dims[2] : start_row), + is_fc2_block_wise ? block_size_ : 0, expert_weight_bits_, + end_row - start_row, inter_size, B2_dequant + start_row * inter_size, tp); + }); + } else { + DequantizeBlock(fc2_weights_data + expert_idx * hidden_size * fc2_packed_cols, + fc2_scales_ptr, + is_fc2_block_wise ? block_size_ : 0, expert_weight_bits_, + hidden_size, inter_size, B2_dequant, tp); + } MlasGemm(CblasNoTrans, CblasTrans, - static_cast(num_expert_tokens), static_cast(hidden_size), static_cast(inter_size), - 1.0f, A2, static_cast(inter_size), - B2_dequant, static_cast(inter_size), - 0.0f, C2, static_cast(hidden_size), - nullptr); - - const T* B2_bias = (fc2_experts_bias) ? fc2_experts_bias->Data() + expert_idx * hidden_size : nullptr; - if (B2_bias) { + m2, n2, k2, + 1.0f, A2, k2, + B2_dequant, k2, + 0.0f, C2, n2, + tp); + + fc2_gemm_done: + + bool fc2_bias_handled_by_q4_gemm = fc2_used_direct_q4 && has_fc2_bias; + if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), bias2_float, static_cast(hidden_size)); + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); } else { - memcpy(bias2_float, B2_bias, static_cast(hidden_size) * sizeof(float)); + if (ShouldUseMemcpy(hidden_size)) { + std::memcpy(thread_bias2_buffer, B2_bias, static_cast(hidden_size) * sizeof(float)); + } else { + const int64_t unroll_factor = GetUnrollFactor(hidden_size); + int64_t j = 0; + for (; j + unroll_factor <= hidden_size; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + thread_bias2_buffer[j + k] = static_cast(B2_bias[j + k]); + } + } + for (; j < hidden_size; ++j) { + thread_bias2_buffer[j] = static_cast(B2_bias[j]); + } + } } } @@ -331,28 +934,89 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int64_t token_idx = route_idx / k_; const float weight = route_scale[route_idx]; + if (token_idx < 0 || token_idx >= num_tokens) continue; + const size_t buffer_offset = static_cast(token_idx) * static_cast(hidden_size); - if (buffer_offset + static_cast(hidden_size) > output_buffer_size) { - // Skip this token to prevent buffer overflow - continue; - } + if (buffer_offset + static_cast(hidden_size) > output_buffer_size) continue; float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + buffer_offset; const float* src = C2 + i * hidden_size; - for (int64_t j = 0; j < hidden_size; ++j) { - dest[j] += weight * (src[j] + (B2_bias ? bias2_float[j] : 0.0f)); + + if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + const int64_t unroll_factor = GetUnrollFactor(hidden_size); + int64_t j = 0; + for (; j + unroll_factor <= hidden_size; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + dest[j + k] += weight * (src[j + k] + thread_bias2_buffer[j + k]); + } + } + for (; j < hidden_size; ++j) { + dest[j] += weight * (src[j] + thread_bias2_buffer[j]); + } + } else { + const int64_t unroll_factor = GetUnrollFactor(hidden_size); + int64_t j = 0; + for (; j + unroll_factor <= hidden_size; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + dest[j + k] += weight * src[j + k]; + } + } + for (; j < hidden_size; ++j) { + dest[j] += weight * src[j]; + } } } } }); - // --- 4. Final Reduction (accumulate expert outputs to a float buffer) --- auto accumulate = [&](float* buffer) { - memset(buffer, 0, output_buffer_size * sizeof(float)); - for (int i = 0; i < num_expert_threads; ++i) { - const size_t thread_offset = static_cast(i) * output_buffer_size; - for (size_t j = 0; j < output_buffer_size; ++j) { - buffer[j] += thread_local_outputs[thread_offset + j]; + std::memset(buffer, 0, output_buffer_size * sizeof(float)); + + const int max_acc_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; + const size_t acc_thread_divisor = std::max(size_t{1}, static_cast(max_acc_threads) * 8); + const size_t min_elements_per_thread = std::max(size_t{32}, output_buffer_size / acc_thread_divisor); + const int optimal_acc_threads = (tp == nullptr || output_buffer_size < min_elements_per_thread) ? 1 : std::min(static_cast(output_buffer_size / std::max(size_t{1}, min_elements_per_thread)), max_acc_threads); + const int num_acc_threads = std::max(1, optimal_acc_threads); + + if (num_acc_threads > 1) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_acc_threads, [&](std::ptrdiff_t acc_thread_id) { + const size_t elements_per_thread = output_buffer_size / static_cast(num_acc_threads); + const size_t start_idx = static_cast(acc_thread_id) * elements_per_thread; + const size_t end_idx = (acc_thread_id == num_acc_threads - 1) ? output_buffer_size : start_idx + elements_per_thread; + + for (int i = 0; i < num_expert_threads; ++i) { + const size_t thread_offset = static_cast(i) * output_buffer_size; + const float* src = thread_local_outputs + thread_offset + start_idx; + float* dst = buffer + start_idx; + + size_t j = 0; + const size_t chunk_size = end_idx - start_idx; + const int64_t unroll_factor = GetUnrollFactor(static_cast(chunk_size)); + for (; j + static_cast(unroll_factor) <= chunk_size; j += static_cast(unroll_factor)) { + for (int64_t k = 0; k < unroll_factor; ++k) { + dst[j + static_cast(k)] += src[j + static_cast(k)]; + } + } + for (; j < chunk_size; ++j) { + dst[j] += src[j]; + } + } + }); + } else { + for (int i = 0; i < num_expert_threads; ++i) { + const size_t thread_offset = static_cast(i) * output_buffer_size; + const float* src = thread_local_outputs + thread_offset; + + size_t j = 0; + const int64_t unroll_factor = GetUnrollFactor(static_cast(output_buffer_size)); + for (; j + static_cast(unroll_factor) <= output_buffer_size; j += static_cast(unroll_factor)) { + for (int64_t k = 0; k < unroll_factor; ++k) { + buffer[j + static_cast(k)] += src[j + static_cast(k)]; + } + } + for (; j < output_buffer_size; ++j) { + buffer[j] += src[j]; + } } } }; @@ -362,18 +1026,16 @@ Status QMoECPU::Compute(OpKernelContext* context) const { float* final_output_float = final_output_float_ptr.get(); accumulate(final_output_float); - // --- 5. Convert final float buffer to output type T --- MlasConvertFloatToHalfBuffer(final_output_float, reinterpret_cast(output->MutableData()), static_cast(output_buffer_size)); - } else { // T is float + } else { accumulate(output->MutableData()); } return Status::OK(); } -// Explicit template instantiation template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); template Status QMoECPU::Compute(OpKernelContext* context) const; template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index 62d6a723bf32c..0c4d75aeddac0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -794,6 +794,39 @@ void LaunchAddBiasTranspose( } } +template <> +void LaunchAddBiasTranspose( + cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, + const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, + const BFloat16* input, const BFloat16* biases, BFloat16* output, + bool /*enable_half4*/, const int v_head_size, + BFloat16* qkv_add_bias, int total_matrix_count, + bool do_rotary, int rotary_embedding, int past_sequence_length) { + total_matrix_count = std::max(num_matrices, total_matrix_count); + if (0 == (qk_head_size & 1) && (v_head_size == -1 || 0 == (v_head_size & 1)) && !do_rotary) { + const int H = qk_head_size / 2; + const int H_v = v_head_size / 2; + + const __nv_bfloat162* input2 = reinterpret_cast(input); + const __nv_bfloat162* biases2 = reinterpret_cast(biases); + __nv_bfloat162* output2 = reinterpret_cast<__nv_bfloat162*>(output); + __nv_bfloat162* qkv_add_bias2 = reinterpret_cast<__nv_bfloat162*>(qkv_add_bias); + + InvokeAddBiasTranspose<__nv_bfloat162>( + stream, num_matrices, format, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + input2, biases2, output2, qkv_add_bias2, + H_v, total_matrix_count); + } else { + InvokeAddBiasTranspose( + stream, num_matrices, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + input, biases, output, + qkv_add_bias, v_head_size, total_matrix_count, + do_rotary, rotary_embedding, past_sequence_length); + } +} + template <> void LaunchAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, @@ -888,6 +921,20 @@ void LaunchAddBiasTransposeTrt( ORT_ENFORCE(false, "Shall not call this since fused kernel does not support float input."); } +template <> +void LaunchAddBiasTransposeTrt( + cudaStream_t /*stream*/, const int /*max_threads_per_block*/, + const int /*batch_size*/, const int /*sequence_length*/, + const int /*num_heads*/, const int /*head_size*/, + const BFloat16* /*biases*/, + const BFloat16* /*query*/, + const BFloat16* /*key*/, + const BFloat16* /*value*/, + BFloat16* /*output*/, + bool /*is_cross_attention*/, int /*kv_sequence_length*/) { + ORT_ENFORCE(false, "BF16 not supported for LaunchAddBiasTransposeTrt."); +} + template <> void LaunchAddBiasTransposeTrt( cudaStream_t stream, const int max_threads_per_block, @@ -1049,6 +1096,38 @@ void LaunchAddBias( } } +template <> +void LaunchAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, const int kv_sequence_length, + const int num_heads, const int head_size, const int v_head_size, + const BFloat16* biases, const BFloat16* query, const BFloat16* key, const BFloat16* value, + BFloat16* q, BFloat16* k, BFloat16* v) { + if (0 == (head_size & 1) && 0 == (v_head_size & 1)) { + const int H = head_size / 2; + const int H_v = v_head_size / 2; + const __nv_bfloat162* query2 = reinterpret_cast(query); + const __nv_bfloat162* key2 = reinterpret_cast(key); + const __nv_bfloat162* value2 = reinterpret_cast(value); + const __nv_bfloat162* biases2 = reinterpret_cast(biases); + __nv_bfloat162* q2 = reinterpret_cast<__nv_bfloat162*>(q); + __nv_bfloat162* k2 = reinterpret_cast<__nv_bfloat162*>(k); + __nv_bfloat162* v2 = reinterpret_cast<__nv_bfloat162*>(v); + + InvokeAddBias<__nv_bfloat162>( + stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v, + biases2, query2, key2, value2, q2, k2, v2); + + } else { + InvokeAddBias( + stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, num_heads, + head_size, v_head_size, + biases, query, key, value, q, k, v); + } +} + template void InvokeAddBias( cudaStream_t stream, const int max_threads_per_block, @@ -1125,6 +1204,31 @@ void LaunchAddBias( } } +template <> +void LaunchAddBias( + cudaStream_t stream, const int max_threads_per_block, + const int batch_size, const int sequence_length, + const int num_heads, const int head_size, + const BFloat16* biases, const BFloat16* query, BFloat16* q) { + if (0 == (head_size & 1)) { + const int H = head_size / 2; + const __nv_bfloat162* query2 = reinterpret_cast(query); + const __nv_bfloat162* biases2 = reinterpret_cast(biases); + __nv_bfloat162* q2 = reinterpret_cast<__nv_bfloat162*>(q); + + InvokeAddBias<__nv_bfloat162>( + stream, max_threads_per_block, + batch_size, sequence_length, num_heads, H, + biases2, query2, q2); + + } else { + InvokeAddBias( + stream, max_threads_per_block, + batch_size, sequence_length, num_heads, head_size, + biases, query, q); + } +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 22e2879a5be15..4bc95454d40f9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -36,20 +36,24 @@ constexpr int kPresentOutputIndex = 1; REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) template Attention::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) { kernel_options_ = this->GetAttentionKernelOptions(); - disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention(); + constexpr bool kIsFp16 = std::is_same::value; + constexpr bool kIsBf16 = std::is_same::value; + constexpr bool kIs16bit = kIsFp16 || kIsBf16; - enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention(); + // We only support FP16 for TRT fused/flash/causal attention. + disable_fused_self_attention_ = !kIsFp16 || !kernel_options_->UseTrtFusedAttention(); + enable_trt_flash_attention_ = kIsFp16 && kernel_options_->UseTrtFlashAttention(); + enable_fused_causal_attention_ = kIsFp16 && kernel_options_->UseTrtCausalAttention(); - enable_fused_causal_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtCausalAttention(); + disable_memory_efficient_attention_ = kIsBf16 || !kernel_options_->UseEfficientAttention(); - disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention(); - - disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); + disable_flash_attention_ = !kIs16bit || !kernel_options_->UseFlashAttention(); } template diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 216f101aad4be..1f29ed49f624c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -952,6 +952,13 @@ Status QkvToContext( Stream* ort_stream, contrib::AttentionParameters& parameters, AttentionData& data) { + if constexpr (std::is_same::value || std::is_same::value) { + if (device_prop.major < 8) { + ORT_THROW("BF16 Attention requires Ampere (sm_80)+ with BF16 support. This GPU (", + device_prop.name, ", cc ", device_prop.major, ".", device_prop.minor, ") is not supported."); + } + } + auto stream = static_cast(ort_stream->GetHandle()); const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; @@ -1040,6 +1047,8 @@ template struct AttentionData; template struct AttentionData; +template struct AttentionData; + template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, @@ -1056,6 +1065,14 @@ template Status QkvToContext( contrib::AttentionParameters& parameters, AttentionData& data); +template Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + cudnnHandle_t& cudnn, + Stream* ort_stream, + contrib::AttentionParameters& parameters, + AttentionData& data); + template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 14841b74daec8..09f5e66286807 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -96,6 +97,10 @@ Status LaunchTransCtx(cudaStream_t stream, const int sequence_length, const int batch_size, const int head_size, const int num_heads, const int max_threads_per_block, const bool reversed_bs, const half* input, half* output); +Status LaunchTransCtx(cudaStream_t stream, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const int max_threads_per_block, const bool reversed_bs, const BFloat16* input, BFloat16* output); + // BxSxMxNxH or SxBxMxNxH (reversed_bs is true) => MxBxNxSxH Status LaunchTransQkv(cudaStream_t stream, const int matrix_num, const int sequence_length, const int batch_size, const int head_size, const int num_heads, @@ -107,12 +112,20 @@ Status LaunchTransQkv(cudaStream_t stream, const int matrix_num, const int max_threads_per_block, const bool reversed_bs, const half* input, half* output, int total_matrix_count = -1); +Status LaunchTransQkv(cudaStream_t stream, const int matrix_num, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const int max_threads_per_block, const bool reversed_bs, const BFloat16* input, BFloat16* output, + int total_matrix_count = -1); + Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, const float* input, float* output, cudaStream_t stream, const int max_threads_per_block); Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, const half* input, half* output, cudaStream_t stream, const int max_threads_per_block); +Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const BFloat16* input, BFloat16* output, cudaStream_t stream, const int max_threads_per_block); + template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, int sequence_length, int total_sequence_length, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu index 121ddcf779485..80152e918ae30 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu @@ -197,6 +197,59 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, return CUDA_CALL(cudaGetLastError()); } +Status LaunchConcatTensorToTensor(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const int matrix_num, + const BFloat16* tensor_in, + const BFloat16* tensor_add, + BFloat16* tensor_out) { + assert(num_heads <= max_threads_per_block); + const dim3 grid(all_sequence_length, batch_size, matrix_num); + if (0 == (head_size & 1)) { + const int H = head_size / 2; + if (H * num_heads <= max_threads_per_block) { + const dim3 block(H, num_heads, 1); + ConcatTensorToTensor<__nv_bfloat162><<>>( + sequence_length, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast<__nv_bfloat162*>(tensor_out)); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + ConcatTensorToTensorLarge<__nv_bfloat162><<>>( + sequence_length, + H, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast<__nv_bfloat162*>(tensor_out)); + } + } else { + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + ConcatTensorToTensor<__nv_bfloat16><<>>( + sequence_length, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast<__nv_bfloat16*>(tensor_out)); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + ConcatTensorToTensorLarge<__nv_bfloat16><<>>( + sequence_length, + head_size, + reinterpret_cast(tensor_in), + reinterpret_cast(tensor_add), + reinterpret_cast<__nv_bfloat16*>(tensor_out)); + } + } + + return CUDA_CALL(cudaGetLastError()); +} + #ifndef USE_ROCM // exclude the following from hipify since they are not used in ROCM EP // ---------------------------------------------------------------------------------- @@ -332,6 +385,18 @@ template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, const half* bias, const half* qkv_buffer, half* present); + +template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int total_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const BFloat16* bias, + const BFloat16* qkv_buffer, + BFloat16* present); #endif // Kernel to append new and past kv in either BSNH or BNSH format diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.h b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.h index 94104e8b6a5d6..d7d7bdd87d62f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.h @@ -3,6 +3,7 @@ #pragma once #include "core/providers/cuda/shared_inc/cuda_utils.h" +#include #include #include "core/framework/allocator.h" #include "core/providers/cuda/cuda_common.h" @@ -38,6 +39,18 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream, const half* tensor_add, half* tensor_out); +Status LaunchConcatTensorToTensor(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const int matrix_num, + const BFloat16* tensor_in, + const BFloat16* tensor_add, + BFloat16* tensor_out); + template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, const int max_sequence_length, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index a7989df3439ae..79dbb47ba406b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -744,9 +744,11 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, #endif if (nullptr != data.gemm_buffer) { // Attention operator - ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block)); + ORT_RETURN_IF_ERROR(PrepareQkv_Attention( + parameters, data, stream, max_threads_per_block)); } else { // MultiHeadAttention operator - ORT_RETURN_IF_ERROR(PrepareQkv_MultiHeadAttention(parameters, data, stream, max_threads_per_block)); + ORT_RETURN_IF_ERROR(PrepareQkv_MultiHeadAttention( + parameters, data, stream, max_threads_per_block)); } assert(data.qkv_format != AttentionQkvFormat::UNKNOWN); @@ -776,6 +778,12 @@ template Status PrepareQkv( cudaStream_t stream, int max_threads_per_block); +template Status PrepareQkv( + contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu b/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu index 3f02a441da73e..ad0f8ef13e5d0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_qk.cu @@ -37,15 +37,20 @@ Status CopyQK(cudaStream_t stream, const int qk_size, const T* input, QK* output) { - constexpr const bool half2float = std::is_same::value && std::is_same::value; - constexpr const bool float2half = std::is_same::value && std::is_same::value; - static_assert(half2float || float2half, "This function supports either or "); + if constexpr (std::is_same_v) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output, input, static_cast(qk_size) * sizeof(T), cudaMemcpyDeviceToDevice, stream)); + return Status::OK(); + } else { + constexpr const bool half2float = std::is_same::value && std::is_same::value; + constexpr const bool float2half = std::is_same::value && std::is_same::value; + static_assert(half2float || float2half, "This function supports either or "); - constexpr const int block_size = 256; - int num_blocks = (qk_size + block_size - 1) / block_size; - ConvertAndCopyQK<<>>(qk_size, input, output); + constexpr const int block_size = 256; + int num_blocks = (qk_size + block_size - 1) / block_size; + ConvertAndCopyQK<<>>(qk_size, input, output); - return CUDA_CALL(cudaGetLastError()); + return CUDA_CALL(cudaGetLastError()); + } } template Status CopyQK(cudaStream_t stream, @@ -58,23 +63,20 @@ template Status CopyQK(cudaStream_t stream, const half* input, float* output); -template <> -Status CopyQK(cudaStream_t stream, - const int qk_size, - const float* input, - float* output) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output, input, qk_size * sizeof(float), cudaMemcpyDeviceToDevice, stream)); - return Status::OK(); -} +template Status CopyQK(cudaStream_t stream, + const int qk_size, + const float* input, + float* output); -template <> -Status CopyQK(cudaStream_t stream, - const int qk_size, - const half* input, - half* output) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output, input, qk_size * sizeof(half), cudaMemcpyDeviceToDevice, stream)); - return Status::OK(); -} +template Status CopyQK(cudaStream_t stream, + const int qk_size, + const half* input, + half* output); + +template Status CopyQK(cudaStream_t stream, + const int qk_size, + const BFloat16* input, + BFloat16* output); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu index 6870258ca7390..938033644a7d6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu @@ -999,6 +999,12 @@ template Status ComputeSoftmax( const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1, half* input, half* output, bool causal); +template Status ComputeSoftmax( + cudaStream_t stream, const int total_sequence_length, const int sequence_length, + const int batch_size, const int num_heads, const BFloat16* attn_bias, + const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1, + BFloat16* input, BFloat16* output, bool causal); + template Status ComputeSoftmaxWithCumSeqLength( const float* input, const float* attn_bias, @@ -1051,6 +1057,20 @@ template Status ComputeSoftmaxWithMask1D(cudaStream_t stream, half* output, const bool causal); +template Status ComputeSoftmaxWithMask1D(cudaStream_t stream, + const int total_sequence_length, + const int sequence_length, + const int batch_size, + const int num_heads, + const int* mask_index, + const int* mask_start, + const BFloat16* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, + const BFloat16* input, + BFloat16* output, + const bool causal); + template Status ComputeSoftmaxWithRawMask(Stream* ort_stream, const int total_sequence_length, const int sequence_length, @@ -1091,6 +1111,26 @@ template Status ComputeSoftmaxWithRawMask(Stream* ort_stream, half* persistent_softmax_workspace, const float mask_filter_value); +template Status ComputeSoftmaxWithRawMask(Stream* ort_stream, + const int total_sequence_length, + const int sequence_length, + const int batch_size, + const int num_heads, + const int* attention_mask, + const bool* key_padding_mask, + const BFloat16* attn_bias, + const bool broadcast_attn_bias_dim_0, + const bool broadcast_attn_bias_dim_1, + const BFloat16* input, + BFloat16* output, + const bool causal, + const float rsqrt_head_size, + const int mask_dimension, + const int max_sequence_length, + const bool use_persistent_softmax, + BFloat16* persistent_softmax_workspace, + const float mask_filter_value); + } // namespace attention_softmax_cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu index 9f3e396b7f949..e7177987fa2d1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu @@ -162,6 +162,41 @@ Status LaunchTransCtx(cudaStream_t stream, return CUDA_CALL(cudaGetLastError()); } +Status LaunchTransCtx(cudaStream_t stream, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const int max_threads_per_block, const bool reversed_bs, + const BFloat16* input, BFloat16* output) { + assert(num_heads <= max_threads_per_block); + const dim3 grid(sequence_length, batch_size, 1); + + if (0 == (head_size & 1)) { + const int H = head_size / 2; + const __nv_bfloat162* input2 = reinterpret_cast(input); + __nv_bfloat162* output2 = reinterpret_cast<__nv_bfloat162*>(output); + + if (H * num_heads <= max_threads_per_block) { + const dim3 block(H, num_heads, 1); + TransposeCtx<__nv_bfloat162><<>>(H, reversed_bs, input2, output2); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + TransposeCtxLarge<__nv_bfloat162><<>>(H, reversed_bs, input2, output2); + } + } else { + const __nv_bfloat16* input2 = reinterpret_cast(input); + __nv_bfloat16* output2 = reinterpret_cast<__nv_bfloat16*>(output); + + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + TransposeCtx<__nv_bfloat16><<>>(head_size, reversed_bs, input2, output2); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + TransposeCtxLarge<__nv_bfloat16><<>>(head_size, reversed_bs, input2, output2); + } + } + + return CUDA_CALL(cudaGetLastError()); +} + template __global__ void TransposeQKV(const int H, const bool reversed_bs, const T* input, T* output, const int chunk_num) { // Input: BxSxKxNxH or SxBxKxNxH @@ -298,6 +333,48 @@ Status LaunchTransQkv(cudaStream_t stream, const int matrix_num, return CUDA_CALL(cudaGetLastError()); } +Status LaunchTransQkv(cudaStream_t stream, const int matrix_num, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const int max_threads_per_block, const bool reversed_bs, + const BFloat16* input, BFloat16* output, + int total_matrix_count) { + assert(num_heads <= max_threads_per_block); + total_matrix_count = max(total_matrix_count, matrix_num); + const dim3 grid(sequence_length, batch_size, matrix_num); + + if (0 == (head_size & 1)) { + const int H = head_size / 2; + const nv_bfloat162* input2 = reinterpret_cast(input); + nv_bfloat162* output2 = reinterpret_cast(output); + + if (H * num_heads <= max_threads_per_block) { + const dim3 block(H, num_heads, 1); + TransposeQKV<<>>(H, reversed_bs, input2, output2, total_matrix_count); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + TransposeQKVLarge<<>>(H, reversed_bs, input2, output2, total_matrix_count); + } + } else { + if (head_size * num_heads <= max_threads_per_block) { + const dim3 block(head_size, num_heads, 1); + TransposeQKV<<>>( + head_size, reversed_bs, + reinterpret_cast(input), + reinterpret_cast(output), + total_matrix_count); + } else { + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + TransposeQKVLarge<<>>( + head_size, reversed_bs, + reinterpret_cast(input), + reinterpret_cast(output), + total_matrix_count); + } + } + + return CUDA_CALL(cudaGetLastError()); +} + Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, const half* input, half* output, cudaStream_t stream, const int max_threads_per_block) { return LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, @@ -310,6 +387,12 @@ Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, c max_threads_per_block, false, input, output); } +Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size, + const BFloat16* input, BFloat16* output, cudaStream_t stream, const int max_threads_per_block) { + return LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, false, input, output); +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu index 0f2db956e55db..d4e872f8ac165 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu @@ -59,6 +59,8 @@ template void mmha_launch_kernel(const DecoderMaskedMultiHead template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu index d878291cabca0..16f22b020ee1f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu @@ -59,6 +59,8 @@ template void mmha_launch_kernel(const DecoderMaskedMultiHeadA template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu index b547ad67a61a5..c933b0c6d2241 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu @@ -59,6 +59,8 @@ template void mmha_launch_kernel(const DecoderMaskedMultiHeadA template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 6ba5ce66eaa60..efb48fee60772 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -794,6 +794,31 @@ template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); +// QK BFloat16 templates +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); + +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); + +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); + +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); + +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); + +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index ea410998b8eef..b6de0f7f49282 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -296,6 +296,16 @@ struct T2 { using Type = half2; }; +template <> +struct T2 { + using Type = __nv_bfloat162; +}; + +template <> +struct T4 { + using Type = nv_bfloat164; +}; + template void AddBiasTransposePacked( const T* input, const T* biases, T* output, @@ -680,6 +690,13 @@ template void AddBiasTransposePacked( AttentionQkvFormat format, const int32_t* token_offset, int32_t token_count, cudaStream_t stream); +template void AddBiasTransposePacked( + const BFloat16* input, const BFloat16* biases, BFloat16* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_hidden_size, + AttentionQkvFormat format, const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index 07bca3f7fff99..4db5064c853f2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -488,6 +488,11 @@ struct T4 { using Type = Half4; }; +template <> +struct T4 { + using Type = nv_bfloat164; +}; + template struct T2; @@ -501,6 +506,11 @@ struct T2 { using Type = half2; }; +template <> +struct T2 { + using Type = __nv_bfloat162; +}; + template void AddBiasTransposePacked( const T* query, const T* key, const T* value, const T* bias, T* output, @@ -873,6 +883,14 @@ template void AddBiasTransposePacked( const int32_t* token_offset, int32_t token_count, cudaStream_t stream); +template void AddBiasTransposePacked( + const BFloat16* query, const BFloat16* key, const BFloat16* value, const BFloat16* bias, BFloat16* output, + const int batch_size, const int sequence_length, + const int num_heads, const int qk_head_size, const int v_head_size, + AttentionQkvFormat source_format, AttentionQkvFormat target_format, + const int32_t* token_offset, int32_t token_count, + cudaStream_t stream); + template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_util.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_util.h index a35f626cadf74..320aa2a552198 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_util.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_util.h @@ -314,6 +314,30 @@ __device__ __inline__ void write_smem_transpose(const uint4& vec, Half4* smem, i return; } +template <> +__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, onnxruntime::BFloat16* smem, + int transpose_idx, int smem_pitch) { + return; +} + +template <> +__device__ __inline__ void vec_from_smem_transpose(uint2& vec, __nv_bfloat162* smem, + int transpose_idx, int smem_pitch) { + return; +} + +template <> +__device__ __inline__ void write_smem_transpose(const uint32_t& vec, onnxruntime::BFloat16* smem, + int transpose_idx, int smem_pitch) { + return; +} + +template <> +__device__ __inline__ void write_smem_transpose(const uint2& vec, __nv_bfloat162* smem, + int transpose_idx, int smem_pitch) { + return; +} + template <> __device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) { union { diff --git a/onnxruntime/contrib_ops/cuda/bert/utils.cuh b/onnxruntime/contrib_ops/cuda/bert/utils.cuh index dd61288a3c126..a45664083f5c7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/utils.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/utils.cuh @@ -24,6 +24,7 @@ #pragma once +#include #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh" @@ -44,6 +45,18 @@ __device__ __forceinline__ Half4 operator+(const Half4& a, const Half4& b) { return r; } +struct __align__(8) nv_bfloat164 { + __nv_bfloat162 x; + __nv_bfloat162 y; +}; + +__device__ __forceinline__ nv_bfloat164 operator+(const nv_bfloat164& a, const nv_bfloat164& b) { + nv_bfloat164 r; + r.x = __hadd2(a.x, b.x); + r.y = __hadd2(a.y, b.y); + return r; +} + __device__ __forceinline__ float2 operator+(const float2& a, const float2& b) { return make_float2(a.x + b.x, a.y + b.y); } @@ -143,6 +156,24 @@ struct Vec_t { static constexpr int size = 8; }; +template <> +struct Vec_t { + using Type = uint32_t; + static constexpr int size = 2; +}; + +template <> +struct Vec_t<__nv_bfloat162> { + using Type = uint2; + static constexpr int size = 4; +}; + +template <> +struct Vec_t { + using Type = uint4; + static constexpr int size = 8; +}; + //------------------------------------------------------------ // Qk_vec //------------------------------------------------------------ diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 93d802ca05b42..167b2af946183 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -77,7 +77,8 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { fc2_experts_weights, fc2_experts_bias_optional, nullptr, fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr, 1, // no quantization so pack size is 1 - activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU, + 0)); // no block-wise quantization for sharded MoE ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size"); diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 36d6fc378d45e..20a595a639fb8 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -84,6 +84,7 @@ class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Affine); class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Affine); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Attention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Attention); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, Attention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, PackedAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PackedAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, PackedMultiHeadAttention); @@ -320,6 +321,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index a5b9d483d5ad1..e5a064d59e360 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -45,7 +45,8 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { fc2_experts_weights, fc2_experts_bias_optional, nullptr, fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr, 1, // no quantization so pack size is 1 - activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU, + 0)); // no block-wise quantization for regular MoE using CudaT = typename OrtToCudaType::type; auto stream = context->GetComputeStream(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index dcf32bb3c5ae4..931b8ac09aa49 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -150,7 +150,8 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { fc2_experts_weights, fc2_experts_bias_optional, fc2_scales, fc3_experts_weights_optional, fc3_experts_bias_optional, fc3_scales_optional, expert_weight_bits_ == 4 ? 2 : 1, - activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU, + 0)); // CUDA doesn't support block-wise quantization yet #if defined(__GNUC__) #pragma GCC diagnostic push diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index ab611a8e5a7c0..b5c1f73d1678d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -164,7 +164,9 @@ Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, const Tensor* attention_bias, Tensor* output, Tensor* present_key, Tensor* metadata, - const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size) { + const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile, + uint32_t num_present_sequence_length_tile, uint32_t tile_size, + uint32_t present_sequence_length) { const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; @@ -187,8 +189,7 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte .AddUniformVariables({{static_cast(vectorized_head_size)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(alpha)}, - // present_sequence_length is used to index into the KV cache, for static kv cache it is the max sequence length. - {static_cast(parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_)}, + present_sequence_length, {static_cast(parameters.n_reps)}, {num_total_seq_length_tile}, {num_present_sequence_length_tile}, @@ -220,7 +221,8 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, - uint32_t tile_size) { + uint32_t tile_size, + uint32_t present_sequence_length) { const int components = 4; int head_size_vec = parameters.v_head_size_ / components; FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec}; @@ -233,7 +235,7 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte .SetWorkgroupSize(64) .AddUniformVariables({{static_cast(parameters.total_sequence_length_)}, {static_cast(head_size_vec)}, - {static_cast(parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_)}, + present_sequence_length, {static_cast(parameters.n_reps)}, num_total_seq_length_tile, num_present_sequence_length_tile, @@ -279,7 +281,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value)); - const int present_sequence_length = parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_; + + // Extract present_sequence_length directly from present_key tensor shape: + // (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size) + const uint32_t present_sequence_length = static_cast(present_key->Shape()[2]); if (parameters.sequence_length_ > 1) { const uint32_t tile_size = 64; bool has_attention_bias = attention_bias != nullptr; @@ -332,12 +337,14 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const TensorShape metadata_shape(metadata_dims); Tensor metadata = context.CreateGPUTensor(DataTypeImpl::GetType(), metadata_shape); ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKT(context, Q, attention_bias, &qk, present_key, &metadata, - parameters, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size)); + parameters, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size, + present_sequence_length)); const TensorShapeVector out_split_vx_dims({parameters.batch_size_, parameters.num_heads_, num_present_sequence_length_tile, parameters.head_size_}); const TensorShape out_split_vx_shape(out_split_vx_dims); Tensor out_split_vx = context.CreateGPUTensor(Q->DataType(), out_split_vx_shape); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value, parameters, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size)); + ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value, parameters, + num_total_seq_length_tile, num_present_sequence_length_tile, tile_size, present_sequence_length)); ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, parameters, num_total_seq_length_tile, num_present_sequence_length_tile)); return Status::OK(); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index e2b17aa84d2b1..f4f90004f546e 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -502,7 +502,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T", OpSchema::Optional) .TypeConstraint("T", - {"tensor(float)", "tensor(float16)"}, + {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 4ad13f072ae72..4cad44a56ba96 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -949,6 +949,15 @@ extern "C" { #if defined(__aarch64__) && defined(__linux__) MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero; MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd; +#endif +#if defined(MLAS_TARGET_ARM64) + MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeon; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon; + MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon; + MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelNeon; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelNeon; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelNeon; #endif MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero; MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; @@ -1335,6 +1344,12 @@ struct MLAS_PLATFORM { const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmS8S8Dispatch; + MLAS_CONV_FLOAT_KERNEL* ConvNchwFloatKernel; + MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel; + MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel; + MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount]; + uint32_t NchwcBlockSize; #endif const MLAS_SYMM_QGEMM_DISPATCH* SymmQgemmDispatch{nullptr}; @@ -1395,6 +1410,7 @@ struct MLAS_PLATFORM { int32_t MaximumThreadCount; #elif defined(MLAS_TARGET_ARM64) static constexpr int32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT * 4; + static constexpr size_t MLAS_NEON_NCHWC_BLOCK_SIZE = 16; #else static constexpr int32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT; #endif diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index c4b8d5e78a491..923e513ccb07a 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -558,6 +558,15 @@ Return Value: this->SoftmaxDispatch = &MlasSoftmaxDispatchNeon; this->EltwiseDispatch = &MlasEltwiseDispatchNeon; + this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeon; + this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon; + this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon; + this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon; + this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelNeon; + this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelNeon; + this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelNeon; + this->NchwcBlockSize = MLAS_NEON_NCHWC_BLOCK_SIZE; + // // Check if the processor supports ASIMD dot product instructions. // diff --git a/onnxruntime/core/mlas/lib/sconv.h b/onnxruntime/core/mlas/lib/sconv.h new file mode 100644 index 0000000000000..94e657638975a --- /dev/null +++ b/onnxruntime/core/mlas/lib/sconv.h @@ -0,0 +1,25 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sconv.h + +Abstract: + + This module defines convolution kernel flags for configuring convolution + operations including output accumulation, bias addition, and activations. + +--*/ + +// +// Define the convolution kernel flags. +// + +#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001 +#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002 +#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004 +#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008 \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp new file mode 100644 index 0000000000000..3ecad66a32886 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp @@ -0,0 +1,520 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sconv_kernel_neon.cpp + +Abstract: + + This module implements the single precision convolution kernels for ARM NEON. + +--*/ + +#include "mlasi.h" +#include "sconv.h" + +constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE; + +// Common implementation for NCHW and NCHWC convolution kernels +template +void + MLASCALL + MlasConvFloatKernelNeonImpl( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ + const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0; + const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const size_t FilterStrideElements = FilterStride / sizeof(float); + const size_t OutputStrideElements = OutputStride / sizeof(float); + const size_t InputWidthElements = InputWidth / sizeof(float); + const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); + + (void)InputStride; + + const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; + + for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) { + bool is_main_region = (output_idx >= OutputCountLeftPad && output_idx < OutputCountLeftPad + OutputCount); + + for (size_t filterSetBlock = 0; filterSetBlock < FilterCount; filterSetBlock++) { + const float* filter = Filter + filterSetBlock * FilterStrideElements; + float* output = Output + filterSetBlock * OutputStrideElements; + + float32x4_t Accumulator0, Accumulator1, Accumulator2, Accumulator3; + + if (AccumulateOutput) { + Accumulator0 = MlasLoadFloat32x4(&output[output_idx * BlockSize]); + Accumulator1 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 4]); + Accumulator2 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 8]); + Accumulator3 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 12]); + } else { + Accumulator0 = MlasBroadcastFloat32x4(0.0f); + Accumulator1 = MlasBroadcastFloat32x4(0.0f); + Accumulator2 = MlasBroadcastFloat32x4(0.0f); + Accumulator3 = MlasBroadcastFloat32x4(0.0f); + } + + if (BiasAddition) { + const float32x4_t BiasVector0 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize]); + const float32x4_t BiasVector1 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize + 4]); + const float32x4_t BiasVector2 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize + 8]); + const float32x4_t BiasVector3 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize + 12]); + + Accumulator0 = MlasAddFloat32x4(Accumulator0, BiasVector0); + Accumulator1 = MlasAddFloat32x4(Accumulator1, BiasVector1); + Accumulator2 = MlasAddFloat32x4(Accumulator2, BiasVector2); + Accumulator3 = MlasAddFloat32x4(Accumulator3, BiasVector3); + } + + for (size_t kh = 0; kh < KernelHeight; kh++) { + for (size_t kw = 0; kw < KernelWidth; kw++) { + const float* input_base = Input + output_idx * StrideWidthElements + + kh * DilatedInputWidthElements + kw * DilationWidthElements; + + if (IsNchwcFormat) { + for (size_t filterBlock = 0; filterBlock < BlockSize; filterBlock++) { + const float* input_element = input_base + filterBlock; + const float* input_row_start = InputBase + kh * DilatedInputWidthElements; + const float* input_row_end = input_row_start + InputWidthElements; + + float input_value; + if (is_main_region || (input_element >= input_row_start && input_element < input_row_end)) { + input_value = *input_element; + } else { + input_value = 0.0f; + } + + const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value); + + size_t kernel_base_pos = kh * (KernelWidth * BlockSize * BlockSize) + + kw * (BlockSize * BlockSize) + + filterBlock * BlockSize; + + const float32x4_t FilterVector0 = MlasLoadFloat32x4(&filter[kernel_base_pos]); + const float32x4_t FilterVector1 = MlasLoadFloat32x4(&filter[kernel_base_pos + 4]); + const float32x4_t FilterVector2 = MlasLoadFloat32x4(&filter[kernel_base_pos + 8]); + const float32x4_t FilterVector3 = MlasLoadFloat32x4(&filter[kernel_base_pos + 12]); + + Accumulator0 = MlasMultiplyAddFloat32x4(InputVector, FilterVector0, Accumulator0); + Accumulator1 = MlasMultiplyAddFloat32x4(InputVector, FilterVector1, Accumulator1); + Accumulator2 = MlasMultiplyAddFloat32x4(InputVector, FilterVector2, Accumulator2); + Accumulator3 = MlasMultiplyAddFloat32x4(InputVector, FilterVector3, Accumulator3); + } + } else { + const float* input_row_start = InputBase + kh * DilatedInputWidthElements; + const float* input_row_end = input_row_start + InputWidthElements; + + float input_value; + if (is_main_region || (input_base >= input_row_start && input_base < input_row_end)) { + input_value = *input_base; + } else { + input_value = 0.0f; + } + + const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value); + + size_t kernel_base_pos = kh * KernelWidth + kw; + + const float32x4_t FilterVector0 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize]); + const float32x4_t FilterVector1 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize + 4]); + const float32x4_t FilterVector2 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize + 8]); + const float32x4_t FilterVector3 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize + 12]); + + Accumulator0 = MlasMultiplyAddFloat32x4(InputVector, FilterVector0, Accumulator0); + Accumulator1 = MlasMultiplyAddFloat32x4(InputVector, FilterVector1, Accumulator1); + Accumulator2 = MlasMultiplyAddFloat32x4(InputVector, FilterVector2, Accumulator2); + Accumulator3 = MlasMultiplyAddFloat32x4(InputVector, FilterVector3, Accumulator3); + } + } + } + + if (ReluActivation) { + Accumulator0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector); + Accumulator1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector); + Accumulator2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector); + Accumulator3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector); + } + + MlasStoreFloat32x4(&output[output_idx * BlockSize], Accumulator0); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 4], Accumulator1); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 8], Accumulator2); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 12], Accumulator3); + } + } +} + +void + MLASCALL + MlasConvNchwFloatKernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ + MlasConvFloatKernelNeonImpl( + Input, + Filter, + Output, + StrideWidth, + DilationWidth, + FilterCount, + InputStride, + FilterStride, + OutputStride, + KernelHeight, + KernelWidth, + InputBase, + InputWidth, + DilatedInputWidth, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + Bias, + KernelFlags + ); +} + +// +// Implementation of MlasConvNchwcFloatKernelNeon +// + +void + MLASCALL + MlasConvNchwcFloatKernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ + MlasConvFloatKernelNeonImpl( + Input, + Filter, + Output, + StrideWidth, + DilationWidth, + FilterCount, + InputStride, + FilterStride, + OutputStride, + KernelHeight, + KernelWidth, + InputBase, + InputWidth, + DilatedInputWidth, + OutputCountLeftPad, + OutputCount, + OutputCountRightPad, + Bias, + KernelFlags + ); +} + +// +// Helper function to load input vector with bounds checking +// +static inline float32x4_t +LoadInputVectorWithBounds( + const float* input_base, + size_t offset, + bool is_main_region, + const float* InputBase, + size_t kh, + size_t DilatedInputWidthElements, + size_t InputWidthElements +) +{ + if (is_main_region) { + return MlasLoadFloat32x4(input_base + offset); + } else { + float input_values[4]; + for (size_t i = 0; i < 4; i++) { + const float* input_element = input_base + offset + i; + const float* input_row_start = InputBase + kh * DilatedInputWidthElements; + const float* input_row_end = input_row_start + InputWidthElements; + + if (input_element >= input_row_start && input_element < input_row_end) { + input_values[i] = *input_element; + } else { + input_values[i] = 0.0f; + } + } + return MlasLoadFloat32x4(input_values); + } +} + +// +// Implementation of MlasConvDepthwiseFloatKernelNeon +// +// This kernel performs depthwise separable convolution where each input channel +// is convolved with its own filter. This is more efficient than standard convolution +// for certain network architectures like MobileNets. +// + +void + MLASCALL + MlasConvDepthwiseFloatKernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ + const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0; + const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const size_t InputStrideElements = InputStride / sizeof(float); + const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); + + (void)InputStrideElements; + + const size_t InputWidthElements = InputWidth / sizeof(float); + + const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; + + for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) { + bool is_main_region = (output_idx >= OutputCountLeftPad && output_idx < OutputCountLeftPad + OutputCount); + + float32x4_t Accumulator0, Accumulator1, Accumulator2, Accumulator3; + + if (AccumulateOutput) { + Accumulator0 = MlasLoadFloat32x4(&Output[output_idx * BlockSize]); + Accumulator1 = MlasLoadFloat32x4(&Output[output_idx * BlockSize + 4]); + Accumulator2 = MlasLoadFloat32x4(&Output[output_idx * BlockSize + 8]); + Accumulator3 = MlasLoadFloat32x4(&Output[output_idx * BlockSize + 12]); + } else { + Accumulator0 = MlasBroadcastFloat32x4(0.0f); + Accumulator1 = MlasBroadcastFloat32x4(0.0f); + Accumulator2 = MlasBroadcastFloat32x4(0.0f); + Accumulator3 = MlasBroadcastFloat32x4(0.0f); + } + + if (BiasAddition) { + const float32x4_t BiasVector0 = MlasLoadFloat32x4(Bias); + const float32x4_t BiasVector1 = MlasLoadFloat32x4(Bias + 4); + const float32x4_t BiasVector2 = MlasLoadFloat32x4(Bias + 8); + const float32x4_t BiasVector3 = MlasLoadFloat32x4(Bias + 12); + + Accumulator0 = MlasAddFloat32x4(Accumulator0, BiasVector0); + Accumulator1 = MlasAddFloat32x4(Accumulator1, BiasVector1); + Accumulator2 = MlasAddFloat32x4(Accumulator2, BiasVector2); + Accumulator3 = MlasAddFloat32x4(Accumulator3, BiasVector3); + } + + for (size_t kh = 0; kh < KernelHeight; kh++) { + for (size_t kw = 0; kw < KernelWidth; kw++) { + size_t kernel_pos = kh * KernelWidth + kw; + + const float* input_base = Input + output_idx * StrideWidthElements + + kh * DilatedInputWidthElements + kw * DilationWidthElements; + + float32x4_t InputVector0 = LoadInputVectorWithBounds(input_base, 0, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + float32x4_t InputVector1 = LoadInputVectorWithBounds(input_base, 4, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + float32x4_t InputVector2 = LoadInputVectorWithBounds(input_base, 8, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + float32x4_t InputVector3 = LoadInputVectorWithBounds(input_base, 12, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + + const float32x4_t FilterVector0 = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize]); + const float32x4_t FilterVector1 = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize + 4]); + const float32x4_t FilterVector2 = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize + 8]); + const float32x4_t FilterVector3 = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize + 12]); + + Accumulator0 = MlasMultiplyAddFloat32x4(InputVector0, FilterVector0, Accumulator0); + Accumulator1 = MlasMultiplyAddFloat32x4(InputVector1, FilterVector1, Accumulator1); + Accumulator2 = MlasMultiplyAddFloat32x4(InputVector2, FilterVector2, Accumulator2); + Accumulator3 = MlasMultiplyAddFloat32x4(InputVector3, FilterVector3, Accumulator3); + } + } + + if (ReluActivation) { + Accumulator0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector); + Accumulator1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector); + Accumulator2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector); + Accumulator3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector); + } + + MlasStoreFloat32x4(&Output[output_idx * BlockSize], Accumulator0); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], Accumulator1); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], Accumulator2); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], Accumulator3); + } +} + +// +// Implementation of MlasConvPointwiseFloatKernelNeon +// +// This kernel performs pointwise (1x1) convolution which is essentially +// a matrix multiplication across the channel dimension. It's optimized +// for cases where the kernel size is 1x1. +// + +void + MLASCALL + MlasConvPointwiseFloatKernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t InputChannels, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t OutputCount, + const float* Bias, + unsigned KernelFlags + ) +{ + const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0; + const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t InputStrideElements = InputStride / sizeof(float); + const size_t FilterStrideElements = FilterStride / sizeof(float); + const size_t OutputStrideElements = OutputStride / sizeof(float); + + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + + for (size_t output_idx = 0; output_idx < OutputCount; output_idx++) { + for (size_t f = 0; f < FilterCount; f++) { + const float* filter = Filter + f * FilterStrideElements; + float* output = Output + f * OutputStrideElements; + + float32x4_t Accumulator0, Accumulator1, Accumulator2, Accumulator3; + + if (AccumulateOutput) { + Accumulator0 = MlasLoadFloat32x4(&output[output_idx * BlockSize]); + Accumulator1 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 4]); + Accumulator2 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 8]); + Accumulator3 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 12]); + } else { + Accumulator0 = MlasBroadcastFloat32x4(0.0f); + Accumulator1 = MlasBroadcastFloat32x4(0.0f); + Accumulator2 = MlasBroadcastFloat32x4(0.0f); + Accumulator3 = MlasBroadcastFloat32x4(0.0f); + } + + if (BiasAddition) { + const float32x4_t BiasVector0 = MlasLoadFloat32x4(&Bias[f * BlockSize]); + const float32x4_t BiasVector1 = MlasLoadFloat32x4(&Bias[f * BlockSize + 4]); + const float32x4_t BiasVector2 = MlasLoadFloat32x4(&Bias[f * BlockSize + 8]); + const float32x4_t BiasVector3 = MlasLoadFloat32x4(&Bias[f * BlockSize + 12]); + + Accumulator0 = MlasAddFloat32x4(Accumulator0, BiasVector0); + Accumulator1 = MlasAddFloat32x4(Accumulator1, BiasVector1); + Accumulator2 = MlasAddFloat32x4(Accumulator2, BiasVector2); + Accumulator3 = MlasAddFloat32x4(Accumulator3, BiasVector3); + } + + for (size_t c = 0; c < InputChannels; c++) { + const float* input_ptr = Input + c * InputStrideElements + output_idx * StrideWidthElements; + + for (size_t input_b = 0; input_b < BlockSize; input_b++) { + const float input_value = input_ptr[input_b]; + const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value); + + const float* filter_ptr = filter + (c * BlockSize + input_b) * BlockSize; + + const float32x4_t FilterVector0 = MlasLoadFloat32x4(filter_ptr); + const float32x4_t FilterVector1 = MlasLoadFloat32x4(filter_ptr + 4); + const float32x4_t FilterVector2 = MlasLoadFloat32x4(filter_ptr + 8); + const float32x4_t FilterVector3 = MlasLoadFloat32x4(filter_ptr + 12); + + Accumulator0 = MlasMultiplyAddFloat32x4(InputVector, FilterVector0, Accumulator0); + Accumulator1 = MlasMultiplyAddFloat32x4(InputVector, FilterVector1, Accumulator1); + Accumulator2 = MlasMultiplyAddFloat32x4(InputVector, FilterVector2, Accumulator2); + Accumulator3 = MlasMultiplyAddFloat32x4(InputVector, FilterVector3, Accumulator3); + } + } + + if (ReluActivation) { + Accumulator0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector); + Accumulator1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector); + Accumulator2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector); + Accumulator3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector); + } + + MlasStoreFloat32x4(&output[output_idx * BlockSize], Accumulator0); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 4], Accumulator1); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 8], Accumulator2); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 12], Accumulator3); + } + } +} diff --git a/onnxruntime/core/mlas/lib/snchwc.cpp b/onnxruntime/core/mlas/lib/snchwc.cpp index f9cf1605787aa..2fc27d6d4ad7f 100644 --- a/onnxruntime/core/mlas/lib/snchwc.cpp +++ b/onnxruntime/core/mlas/lib/snchwc.cpp @@ -101,7 +101,7 @@ Return Value: --*/ { -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) return GetMlasPlatform().NchwcBlockSize; #else return 1; @@ -674,7 +674,7 @@ struct MLAS_NCHWC_CONV_NCHWC_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwcFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwcFloatKernel; @@ -784,7 +784,7 @@ struct MLAS_NCHWC_CONV_NCHW_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwFloatKernel; @@ -879,7 +879,7 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t FilterStrideBytes = BlockSize * InputChannels * sizeof(float); const size_t OutputStrideBytes = BlockSize * OutputSize * sizeof(float); -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel; #else MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = MlasConvPointwiseFloatKernel; @@ -1016,7 +1016,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvDepthwiseFloatKernel; #else MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = MlasConvDepthwiseFloatKernel; @@ -1093,7 +1093,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM { -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !defined(MLAS_TARGET_ARM64) static MLAS_POOL_FLOAT_KERNEL* const PoolKernels[]; #endif @@ -1131,7 +1131,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM const size_t DilatedInputWidthBytes = BlockSize * DilationHeight * InputWidth * sizeof(float); const size_t InputStrideBytes = DilatedInputWidthBytes - KernelWidth * DilationWidthBytes; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_ARM64) MLAS_POOL_FLOAT_KERNEL* Kernel = GetMlasPlatform().PoolFloatKernel[WorkBlock->PoolingKind]; #else MLAS_POOL_FLOAT_KERNEL* Kernel = PoolKernels[WorkBlock->PoolingKind]; @@ -1197,7 +1197,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM } }; -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !defined(MLAS_TARGET_ARM64) MLAS_POOL_FLOAT_KERNEL* const MLAS_NCHWC_POOL_ALGORITHM::PoolKernels[] = { @@ -1621,7 +1621,7 @@ Return Value: } } -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !defined(MLAS_TARGET_ARM64) // // Convolution and pooling kernel stubs for architectures that do not yet have diff --git a/onnxruntime/core/mlas/lib/spool_kernel_neon.cpp b/onnxruntime/core/mlas/lib/spool_kernel_neon.cpp new file mode 100644 index 0000000000000..8cca036d54c3a --- /dev/null +++ b/onnxruntime/core/mlas/lib/spool_kernel_neon.cpp @@ -0,0 +1,289 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + spool_kernel_neon.cpp + +Abstract: + + This module implements the single precision pooling kernels for ARM NEON. + +--*/ + +#include "mlasi.h" + +constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE; + +void + MLASCALL + MlasPoolMaximumFloatKernelNeon( + const float* Input, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t ActualKernelSize, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad + ) +{ + MLAS_UNREFERENCED_PARAMETER(ActualKernelSize); + MLAS_UNREFERENCED_PARAMETER(InputStride); + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const size_t InputWidthElements = InputWidth / sizeof(float); + const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); + const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; + + const float MaxPaddingValue = std::numeric_limits::lowest(); + + const MLAS_FLOAT32X4 MaxPaddingVector = MlasBroadcastFloat32x4(MaxPaddingValue); + + for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) { + MLAS_FLOAT32X4 MaxVector0 = MaxPaddingVector; + MLAS_FLOAT32X4 MaxVector1 = MaxPaddingVector; + MLAS_FLOAT32X4 MaxVector2 = MaxPaddingVector; + MLAS_FLOAT32X4 MaxVector3 = MaxPaddingVector; + + for (size_t kh = 0; kh < KernelHeight; kh++) { + const float* row_start = InputBase + kh * DilatedInputWidthElements; + const float* row_end = row_start + InputWidthElements; + + for (size_t kw = 0; kw < KernelWidth; kw++) { + const float* input_ptr = Input + output_idx * StrideWidthElements + + kh * DilatedInputWidthElements + kw * DilationWidthElements; + + if (input_ptr >= row_start && (input_ptr + BlockSize) <= row_end) { + MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(input_ptr); + MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(input_ptr + 4); + MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(input_ptr + 8); + MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(input_ptr + 12); + + MaxVector0 = MlasMaximumFloat32x4(MaxVector0, InputVector0); + MaxVector1 = MlasMaximumFloat32x4(MaxVector1, InputVector1); + MaxVector2 = MlasMaximumFloat32x4(MaxVector2, InputVector2); + MaxVector3 = MlasMaximumFloat32x4(MaxVector3, InputVector3); + } else { + float values[BlockSize]; + for (size_t i = 0; i < BlockSize; i++) { + const float* element_ptr = input_ptr + i; + if (element_ptr >= row_start && element_ptr < row_end) { + values[i] = *element_ptr; + } else { + values[i] = MaxPaddingValue; + } + } + + MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(&values[0]); + MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(&values[4]); + MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(&values[8]); + MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(&values[12]); + + MaxVector0 = MlasMaximumFloat32x4(MaxVector0, InputVector0); + MaxVector1 = MlasMaximumFloat32x4(MaxVector1, InputVector1); + MaxVector2 = MlasMaximumFloat32x4(MaxVector2, InputVector2); + MaxVector3 = MlasMaximumFloat32x4(MaxVector3, InputVector3); + } + } + } + + MlasStoreFloat32x4(&Output[output_idx * BlockSize], MaxVector0); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], MaxVector1); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], MaxVector2); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], MaxVector3); + } +} + +static void +MlasPoolAverageFloatKernelNeonImpl( + const float* Input, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t ActualKernelSize, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + bool ExcludePad +) +{ + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const size_t InputWidthElements = InputWidth / sizeof(float); + const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); + const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; + + const MLAS_FLOAT32X4 ZeroVector = MlasZeroFloat32x4(); + + for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) { + MLAS_FLOAT32X4 SumVector0 = ZeroVector; + MLAS_FLOAT32X4 SumVector1 = ZeroVector; + MLAS_FLOAT32X4 SumVector2 = ZeroVector; + MLAS_FLOAT32X4 SumVector3 = ZeroVector; + + std::vector valid_count; + if (ExcludePad) { + valid_count.resize(BlockSize, 0); + } + + for (size_t kh = 0; kh < KernelHeight; kh++) { + const float* row_start = InputBase + kh * DilatedInputWidthElements; + const float* row_end = row_start + InputWidthElements; + + for (size_t kw = 0; kw < KernelWidth; kw++) { + const float* input_ptr = Input + output_idx * StrideWidthElements + + kh * DilatedInputWidthElements + kw * DilationWidthElements; + + if (input_ptr >= row_start && (input_ptr + BlockSize) <= row_end) { + MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(input_ptr); + MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(input_ptr + 4); + MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(input_ptr + 8); + MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(input_ptr + 12); + + SumVector0 = MlasAddFloat32x4(SumVector0, InputVector0); + SumVector1 = MlasAddFloat32x4(SumVector1, InputVector1); + SumVector2 = MlasAddFloat32x4(SumVector2, InputVector2); + SumVector3 = MlasAddFloat32x4(SumVector3, InputVector3); + + if (ExcludePad) { + for (size_t i = 0; i < BlockSize; i++) { + valid_count[i]++; + } + } + } else { + float values[BlockSize]; + for (size_t i = 0; i < BlockSize; i++) { + const float* element_ptr = input_ptr + i; + if (element_ptr >= row_start && element_ptr < row_end) { + values[i] = *element_ptr; + if (ExcludePad) { + valid_count[i]++; + } + } else { + values[i] = 0.0f; + } + } + + MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(&values[0]); + MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(&values[4]); + MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(&values[8]); + MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(&values[12]); + + SumVector0 = MlasAddFloat32x4(SumVector0, InputVector0); + SumVector1 = MlasAddFloat32x4(SumVector1, InputVector1); + SumVector2 = MlasAddFloat32x4(SumVector2, InputVector2); + SumVector3 = MlasAddFloat32x4(SumVector3, InputVector3); + } + } + } + + if (ExcludePad) { + float results[BlockSize]; + + MlasStoreFloat32x4(&results[0], SumVector0); + MlasStoreFloat32x4(&results[4], SumVector1); + MlasStoreFloat32x4(&results[8], SumVector2); + MlasStoreFloat32x4(&results[12], SumVector3); + + for (size_t i = 0; i < BlockSize; i++) { + results[i] = results[i] / static_cast(valid_count[i]); + } + + MLAS_FLOAT32X4 ResultVector0 = MlasLoadFloat32x4(&results[0]); + MLAS_FLOAT32X4 ResultVector1 = MlasLoadFloat32x4(&results[4]); + MLAS_FLOAT32X4 ResultVector2 = MlasLoadFloat32x4(&results[8]); + MLAS_FLOAT32X4 ResultVector3 = MlasLoadFloat32x4(&results[12]); + + MlasStoreFloat32x4(&Output[output_idx * BlockSize], ResultVector0); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], ResultVector1); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], ResultVector2); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], ResultVector3); + } else { + const float KernelSize = static_cast(ActualKernelSize); + const MLAS_FLOAT32X4 KernelSizeVector = MlasBroadcastFloat32x4(KernelSize); + + MLAS_FLOAT32X4 ResultVector0 = MlasDivideFloat32x4(SumVector0, KernelSizeVector); + MLAS_FLOAT32X4 ResultVector1 = MlasDivideFloat32x4(SumVector1, KernelSizeVector); + MLAS_FLOAT32X4 ResultVector2 = MlasDivideFloat32x4(SumVector2, KernelSizeVector); + MLAS_FLOAT32X4 ResultVector3 = MlasDivideFloat32x4(SumVector3, KernelSizeVector); + + MlasStoreFloat32x4(&Output[output_idx * BlockSize], ResultVector0); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], ResultVector1); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], ResultVector2); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], ResultVector3); + } + } +} + +void + MLASCALL + MlasPoolAverageExcludePadFloatKernelNeon( + const float* Input, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t ActualKernelSize, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad + ) +{ + MLAS_UNREFERENCED_PARAMETER(InputStride); + + MlasPoolAverageFloatKernelNeonImpl( + Input, Output, StrideWidth, DilationWidth, ActualKernelSize, + KernelHeight, KernelWidth, InputBase, InputWidth, DilatedInputWidth, + OutputCountLeftPad, OutputCount, OutputCountRightPad, + true // ExcludePad = true + ); +} + +void + MLASCALL + MlasPoolAverageIncludePadFloatKernelNeon( + const float* Input, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t ActualKernelSize, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad + ) +{ + MLAS_UNREFERENCED_PARAMETER(InputStride); + + MlasPoolAverageFloatKernelNeonImpl( + Input, Output, StrideWidth, DilationWidth, ActualKernelSize, + KernelHeight, KernelWidth, InputBase, InputWidth, DilatedInputWidth, + OutputCountLeftPad, OutputCount, OutputCountRightPad, + false // ExcludePad = false + ); +} diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation_dev_notes.md b/onnxruntime/core/optimizer/layout_transformation/layout_transformation_dev_notes.md index 0daa9d1ddeaee..14bdc1413d721 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation_dev_notes.md +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation_dev_notes.md @@ -44,7 +44,7 @@ Basic steps are as follows: 1. Implement [GetPreferredLayout](https://github.com/microsoft/onnxruntime/blob/1a4868e5c4c4a270ad91036e36f2a03410c4c278/include/onnxruntime/core/framework/execution_provider.h#L285) method for the EP which overrides the base class method. 2. Remove any existing logic in the EP to convert layouts 3. Add a validation method similar to [IsOpInRequiredLayout](https://github.com/microsoft/onnxruntime/blob/1a4868e5c4c4a270ad91036e36f2a03410c4c278/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc#L502) to validate that the layout sensitive op's domain matches "kMSInternalNHWCDomain". Layout Transformer updates the domain for layout sensitive ops to "kMSInternalNHWCDomain" after the conversion to NHWC format. -4. Add tests. The testing framework already includes [InternalTestingExecutionProvider](https://github.com/microsoft/onnxruntime/blob/1a4868e5c4c4a270ad91036e36f2a03410c4c278/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h#L11) which can be leveraged for such tests. +4. Add tests. The testing framework already includes [InternalTestingExecutionProvider](../../../test/internal_testing_ep/internal_testing_execution_provider.h) which can be leveraged for such tests. ## Making Updates to Transformer and Testing Apart from bug fixes, updates to layout sensitive op schema as well as addition of new layout sensitive ops will require changes in layout transformer as well as transpose optimizer. @@ -55,4 +55,4 @@ These are some places which may need changes: 3. Updates in [TransformLayoutForEP](https://github.com/microsoft/onnxruntime/blob/1a4868e5c4c4a270ad91036e36f2a03410c4c278/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api_impl.cc#L815) method which relies on schema for deciding which inputs and outputs need to be wrapped with transpose nodes. 4. When upgrading to a new ONNX operator set version [kMaxSupportedOpset](https://github.com/microsoft/onnxruntime/blob/1a4868e5c4c4a270ad91036e36f2a03410c4c278/onnxruntime/core/optimizer/transpose_optimizer/optimizer_api.h#L437) needs to be updated to enable the transformations for this new opset. -Testing framework provides [InternalTestingExecutionProvider](https://github.com/microsoft/onnxruntime/blob/1a4868e5c4c4a270ad91036e36f2a03410c4c278/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h#L11). This can be leveraged to test the changes being introduced. +Testing framework provides [InternalTestingExecutionProvider](../../../test/internal_testing_ep/internal_testing_execution_provider.h). This can be leveraged to test the changes being introduced. diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc index bb2b073b62681..7ac19bf67fb8a 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.cc +++ b/onnxruntime/core/providers/cpu/llm/attention.cc @@ -62,14 +62,14 @@ void make_copy(MLFloat16* mask_data, const MLFloat16* mask template <> void make_copy(float* mask_data, const bool* mask_index, size_t size) { for (size_t i = 0; i < size; ++i) { - mask_data[i] = mask_index[i] ? 0.0f : std::numeric_limits::lowest(); + mask_data[i] = mask_index[i] ? 0.0f : negative_infinity(); } } template <> void make_copy(MLFloat16* mask_data, const bool* mask_index, size_t size) { for (size_t i = 0; i < size; ++i) { - mask_data[i] = mask_index[i] ? MLFloat16(0.f) : std::numeric_limits::lowest(); + mask_data[i] = mask_index[i] ? MLFloat16(0.f) : negative_infinity(); } } @@ -251,7 +251,7 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, mask_data = static_cast(allocated_ptr); for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) { for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) { - mask_data[s_i * parameters.total_sequence_length + m_i] = std::numeric_limits::lowest(); + mask_data[s_i * parameters.total_sequence_length + m_i] = negative_infinity(); } } delete_mask_data = true; @@ -277,7 +277,7 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, for (int i = 0; i < n_iter; ++i) { for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) { for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) { - mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = std::numeric_limits::lowest(); + mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = negative_infinity(); } } } @@ -332,7 +332,8 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, } // handling GQA - std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_i % parameters.kv_num_heads; + std::ptrdiff_t head_ki = head_i * parameters.kv_num_heads / parameters.q_num_heads; + std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_ki; const T* k = K + k_input_chunk_length * ki; if (nullptr != present_key) { @@ -362,7 +363,7 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, alpha, Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size, parameters.head_size * parameters.q_num_heads, // lda - transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + (head_i % parameters.kv_num_heads) * parameters.head_size : k, + transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k, transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size, // ldb beta, output, @@ -568,7 +569,8 @@ void AttentionBase::ComputeVxAttentionScore(T* output, // bu // handling GQA std::ptrdiff_t batch_i = i / num_heads; std::ptrdiff_t head_i = i % num_heads; - std::ptrdiff_t vi = batch_i * kv_num_heads + head_i % kv_num_heads; + std::ptrdiff_t head_vi = head_i * kv_num_heads / num_heads; + std::ptrdiff_t vi = batch_i * kv_num_heads + head_vi; const T* v = V + v_input_chunk_length * vi; if (nullptr != present_value) { @@ -592,16 +594,15 @@ void AttentionBase::ComputeVxAttentionScore(T* output, // bu // V is transposed but not QK. We use GemmEx with a different value for ldb. math::GemmEx(CblasNoTrans, CblasNoTrans, - sequence_length, // M - v_head_size, // N - total_sequence_length, // K - 1.f, // alpha - attention_probs + attention_probs_offset, // QK - total_sequence_length, // lda - transposed_v ? V + (head_i % kv_num_heads) * v_head_size + v_input_chunk_length * kv_num_heads * batch_i - : v, - transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb - 0.f, // beta + sequence_length, // M + v_head_size, // N + total_sequence_length, // K + 1.f, // alpha + attention_probs + attention_probs_offset, // QK + total_sequence_length, // lda + transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V + transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb + 0.f, // beta output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size), v_head_size * num_heads, // ldc nullptr); diff --git a/onnxruntime/core/providers/cpu/llm/attention.h b/onnxruntime/core/providers/cpu/llm/attention.h index 78889e48afb29..4fad6914f933d 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.h +++ b/onnxruntime/core/providers/cpu/llm/attention.h @@ -9,6 +9,16 @@ namespace onnxruntime { +template +inline T negative_infinity() { + return -std::numeric_limits::infinity(); +} + +template <> +inline MLFloat16 negative_infinity() { + return MLFloat16(-std::numeric_limits::infinity()); +} + template class AttentionBase : public OpKernel { public: diff --git a/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc b/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc index f1b0d5850ab1f..4bd5b0d306b42 100644 --- a/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc +++ b/onnxruntime/core/providers/cpu/llm/rotary_embedding.cc @@ -30,10 +30,6 @@ RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); // Turn 0/1 into bool - - if (rotary_embedding_dim > 0) { - ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified"); - } } // TODO: rotary embedding in place @@ -111,6 +107,15 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { // Optional position_ids input, can be nullptr const Tensor* position_ids = context->Input(3); + // If rotary_embedding_dim is set (>0) and num_heads attribute not provided (==0), + // we can only proceed if input is 4D (B, num_heads, S, head_size) so num_heads can be inferred. + if (rotary_embedding_dim > 0 && num_heads <= 0) { + const auto& dims = X->Shape().GetDims(); + ORT_ENFORCE(dims.size() == 4, + "Attribute 'num_heads' must be provided when 'rotary_embedding_dim' is specified " + "and input is not rank-4 (batch, num_heads, sequence, head)."); + } + RotaryParameters parameters = {}; ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(X, position_ids, diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.h b/onnxruntime/core/providers/cuda/cuda_provider_factory.h index cf352757686bd..e83ef6f9b329f 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.h +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.h @@ -55,7 +55,7 @@ struct ProviderInfo_CUDA { virtual std::shared_ptr CreateCudaAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::CUDAExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0; // This function is the entry point to CUDA EP's UT cases. - // All tests ared only called from onnxruntime_test_all. + // All tests are only called from onnxruntime_provider_test. virtual void TestAll() { ORT_NOT_IMPLEMENTED(__FUNCTION__, " is only implements in test code path."); } diff --git a/onnxruntime/core/providers/cuda/llm/rotary_embedding.cc b/onnxruntime/core/providers/cuda/llm/rotary_embedding.cc index f259c6021a82e..0b0dc2add3e38 100644 --- a/onnxruntime/core/providers/cuda/llm/rotary_embedding.cc +++ b/onnxruntime/core/providers/cuda/llm/rotary_embedding.cc @@ -44,6 +44,15 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { const Tensor* sin_cache = context->Input(2); const Tensor* position_ids = context->Input(3); // Optional, can be nullptr + // If rotary_embedding_dim is set (>0) and num_heads attribute not provided (==0), + // we can only proceed if input is 4D (B, num_heads, S, head_size) so num_heads can be inferred. + if (rotary_embedding_dim > 0 && num_heads <= 0) { + const auto& dims = input->Shape().GetDims(); + ORT_ENFORCE(dims.size() == 4, + "Attribute 'num_heads' must be provided when 'rotary_embedding_dim' is specified " + "and input is not rank-4 (batch, num_heads, sequence, head)."); + } + RotaryParameters parameters = {}; ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(input, position_ids, diff --git a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h index de445e07f5f07..c45d8c6de52d5 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h @@ -505,6 +505,22 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper( ldb, strideB, &h_b, C, CUDA_R_16BF, ldc, strideC, batch_count, CUDA_R_32F, CUBLAS_GEMM_DEFAULT); } +inline cublasStatus_t cublasGemmStridedBatchedHelper( + cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float* alpha, const onnxruntime::BFloat16* A, int lda, + int64_t strideA, const onnxruntime::BFloat16* B, int ldb, + int64_t strideB, const float* beta, onnxruntime::BFloat16* C, int ldc, + int64_t strideC, int batch_count, + const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { + float h_a = *alpha; + float h_b = *beta; + // accumulating in FP32 + return cublasGemmStridedBatchedEx( + handle, transa, transb, m, n, k, &h_a, A, CUDA_R_16BF, lda, strideA, B, CUDA_R_16BF, + ldb, strideB, &h_b, C, CUDA_R_16BF, ldc, strideC, batch_count, CUDA_R_32F, + CUBLAS_GEMM_DEFAULT); +} #else inline cublasStatus_t cublasGemmStridedBatchedHelper( cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, @@ -513,6 +529,13 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper( int, int64_t, int, const cudaDeviceProp&, bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } +inline cublasStatus_t cublasGemmStridedBatchedHelper( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + int, const float*, const onnxruntime::BFloat16*, int, int64_t, + const onnxruntime::BFloat16*, int, int64_t, const float*, onnxruntime::BFloat16*, + int, int64_t, int, const cudaDeviceProp&, bool /*use_tf32*/) { + return CUBLAS_STATUS_NOT_SUPPORTED; +} #endif // transpose using geam diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc index 69980b8f86dab..82a4950a0791a 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc @@ -114,6 +114,10 @@ Status ExpandOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, FillShapeInputData(shape_data, shape_size, static_cast(1.0)); break; } + case QNN_DATATYPE_FLOAT_16: { + FillShapeInputData(shape_data, shape_size, static_cast(1.0f)); + break; + } case QNN_DATATYPE_INT_64: { // QNN-EP doesn't support INT64 shape input. qnn_data_type = QNN_DATATYPE_INT_32; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc index 85844721b1f2c..4e38530a56ea0 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc @@ -173,33 +173,37 @@ Status ResizeOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, "QNN EP: Resize does not support nearest_mode ", nearest_mode.c_str()); if (is_npu_backend) { - // QNN only supports the following nearest_mode values on HTP: - // - QNN 2.19: "round_prefer_floor" via QNN's Resize operator - // - QNN 2.20 (API version 2.14): "round_prefer_ceil" via QNN's Resize operator - // - "floor" via QNN's ResizeNearestNeighbor operator -#if QNN_API_VERSION_MAJOR >= 2 && QNN_API_VERSION_MINOR >= 14 - ORT_RETURN_IF_NOT(nearest_mode == "round_prefer_ceil" || nearest_mode == "floor", - "QNN EP: Resize on the NPU does not support nearest_mode ", nearest_mode.c_str()); -#else - ORT_RETURN_IF_NOT(nearest_mode == "round_prefer_floor" || nearest_mode == "floor", - "QNN EP: Resize on the NPU does not support nearest_mode ", nearest_mode.c_str()); -#endif - - // Use ResizeNearestNeighbor for rank-4 inputs. + // For better performance with HTP backend, use QNN's ResizeNearestNeighbor for rank-4 input. const bool use_resize_nn_op = input_rank == 4; - // If HTP uses ResizeNearestNeighbor ("floor"), then the "pytorch_half_pixel" coordinate_transformation_mode - // is not supported. - ORT_RETURN_IF(!use_resize_nn_op && nearest_mode == "floor" && transformation_mode == "pytorch_half_pixel", - "QNN EP: Resize on the NPU does not support the combination of nearest_mode == 'floor' ", - " and coordinate_transformation_mode == 'pytorch_half_pixel'."); - + if (!use_resize_nn_op) { + // QNN only supports the following nearest_mode values on HTP: + // - QNN 2.19: "round_prefer_floor" via QNN's Resize operator + // - QNN 2.20 (API version 2.14): "round_prefer_ceil" via QNN's Resize operator #if QNN_API_VERSION_MAJOR >= 2 && QNN_API_VERSION_MINOR >= 14 - // QNN's Resize only supports "round_prefer_ceil" if transformation_mode is "align_corners". - ORT_RETURN_IF(!use_resize_nn_op && transformation_mode != "align_corners", - "QNN EP: Resize on the NPU only supports 'round_prefer_ceil' if " - "transformation mode is 'align_corners'"); + ORT_RETURN_IF_NOT(nearest_mode == "round_prefer_ceil" || nearest_mode == "floor", + "QNN EP: Resize on the NPU does not support nearest_mode ", nearest_mode.c_str()); + + // QNN HTP Resize only supports "round_prefer_ceil" if transformation_mode is "align_corners". + ORT_RETURN_IF(nearest_mode == "round_prefer_ceil" && transformation_mode != "align_corners", + "QNN EP: Resize on the NPU only supports 'round_prefer_ceil' if " + "transformation mode is 'align_corners'"); +#else + ORT_RETURN_IF_NOT(nearest_mode == "round_prefer_floor" || nearest_mode == "floor", + "QNN EP: Resize on the NPU does not support nearest_mode ", nearest_mode.c_str()); #endif + // If HTP uses Resize ("floor"), then the transformation_mode "pytorch_half_pixel" is not supported. + ORT_RETURN_IF(nearest_mode == "floor" && transformation_mode == "pytorch_half_pixel", + "QNN EP: Resize on the NPU does not support the combination of nearest_mode == 'floor' ", + " and transformation_mode == 'pytorch_half_pixel'."); + } else { + // If HTP uses ResizeNearestNeighbor "ceil" or "round_prefer_floor", then the + // transformation_mode "asymmetric" is not supported. + // This is verified in unit test but not be documented in QNN SDK. + ORT_RETURN_IF((nearest_mode == "ceil" || nearest_mode == "round_prefer_floor") && transformation_mode == "asymmetric", + "QNN EP: ResizeNearestNeighbor on the NPU does not support the combination of ", + "nearest_mode == 'ceil' or 'round_prefer_floor' and transformation_mode == 'asymmetric'."); + } } } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 17b7f9af372bc..c424bc4264b0d 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3383,17 +3383,58 @@ common::Status InferenceSession::GetInputOutputMemoryInfo(SessionInputOutputType for (const auto* def : def_list) { InlinedVector node_info_vec; + Status status; if (type == SessionInputOutputType::kOutput) { - ORT_RETURN_IF_ERROR(session_state_->GetOutputNodeInfo(def->Name(), node_info_vec)); + status = session_state_->GetOutputNodeInfo(def->Name(), node_info_vec); } else { - ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec)); + status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec); } - // all entries are for the same OrtDevice so use the first one. - // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice - // from the session state and use its OrtMemoryInfo. - auto allocator = session_state_->GetAllocator(*node_info_vec.front().device); - memory_info.push_back(&allocator->Info()); + if (!status.IsOK()) { + if (type == SessionInputOutputType::kInput) { + return status; + } + + // Check first if this output is produced by an input that directly + // propagates to output with the same name. + status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec); + if (status.IsOK()) { + // all entries are for the same OrtDevice so use the first one. + // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice + // from the session state and use its OrtMemoryInfo. + auto allocator = session_state_->GetAllocator(*node_info_vec.front().device); + memory_info.push_back(&allocator->Info()); + } else { + // Check if this output is produced by a constant initializer + // Pick the MemoryInfo from the initializer's OrtValue + const auto& ort_value_map = session_state_->GetOrtValueNameIdxMap(); + + OrtValueIndex ort_value_index; + status = ort_value_map.GetIdx(def->Name(), ort_value_index); + if (!status.IsOK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to find node output or a constant initializer producing output: ", + def->Name(), "."); + } + + const auto& idx_to_ort_value = session_state_->GetInitializedTensors(); + auto it = idx_to_ort_value.find(ort_value_index); + if (it == idx_to_ort_value.end()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to find node output or a constant initializer producing output: ", + def->Name(), "."); + } + const auto& tensor = it->second.Get(); + auto allocator = session_state_->GetAllocator(tensor.Location()); + memory_info.push_back(&allocator->Info()); + } + } else { + // all entries are for the same OrtDevice so use the first one. + // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice + // from the session state and use its OrtMemoryInfo. + auto allocator = session_state_->GetAllocator(*node_info_vec.front().device); + memory_info.push_back(&allocator->Info()); + } } return Status::OK(); @@ -3422,15 +3463,19 @@ common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector node_info_vec; ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec)); - - // if we have a lot of inputs or there are a lot of execution providers it may be worth creating a map - // instead of doing a linear search each time. - const auto& ep_name = node_info_vec.front().p_node->GetExecutionProviderType(); - auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) { - return entry->ep_name == ep_name; - }); - - ep_devices.push_back(it != available_eps.end() ? *it : nullptr); + assert(!node_info_vec.empty()); + // If we have an input that is not consumed by any node, + // including nodes in subgraphs, then we return nullptr. + const auto* p_node = node_info_vec.front().p_node; + if (p_node != nullptr) { + const auto ep_name = p_node->GetExecutionProviderType(); + auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) { + return entry->ep_name == ep_name; + }); + ep_devices.push_back(it != available_eps.end() ? *it : nullptr); + } else { + ep_devices.push_back(nullptr); + } } return Status::OK(); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 21d09df5cc4db..d0fe6291c2e03 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3313,13 +3313,21 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS API_IMPL_BEGIN std::unique_ptr provider_factory = nullptr; + const auto ep_devices_span = gsl::span(ep_devices, num_ep_devices); + const auto ep_option_keys_span = gsl::span(ep_option_keys, num_ep_options); + const auto ep_option_vals_span = gsl::span(ep_option_vals, num_ep_options); + ORT_API_RETURN_IF_STATUS_NOT_OK(CreateIExecutionProviderFactoryForEpDevices( env->GetEnvironment(), - session_options->value, - gsl::span(ep_devices, num_ep_devices), - gsl::span(ep_option_keys, num_ep_options), - gsl::span(ep_option_vals, num_ep_options), + ep_devices_span, /*output*/ provider_factory)); + + ORT_API_RETURN_IF_STATUS_NOT_OK(AddEpOptionsToSessionOptions( + ep_devices_span, + ep_option_keys_span, + ep_option_vals_span, + session_options->value)); + session_options->provider_factories.push_back(std::move(provider_factory)); return nullptr; diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 444027692903c..4a50bab5e8cbc 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -438,22 +438,13 @@ Status LoadPluginOrProviderBridge(const std::string& registration_name, } Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, - SessionOptions& session_options, gsl::span ep_devices, - gsl::span ep_option_keys, - gsl::span ep_option_vals, /*output*/ std::unique_ptr& out) { if (ep_devices.empty()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Must provide one or more OrtEpDevice instances."); } - const size_t num_ep_options = ep_option_keys.size(); - if (ep_option_vals.size() != num_ep_options) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Must provide the same number of keys and values for EP options."); - } - const auto& ep_name = ep_devices[0]->ep_name; OrtEpFactory* ep_factory = ep_devices[0]->ep_factory; bool all_match = std::all_of(ep_devices.begin() + 1, ep_devices.end(), @@ -465,6 +456,27 @@ Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, "All OrtEpDevice values in ep_devices must have the same execution provider."); } + EpFactoryInternal* internal_factory = env.GetEpFactoryInternal(ep_factory); + + if (internal_factory) { + out = std::make_unique(*internal_factory, ep_devices); + } else { + out = std::make_unique(*ep_factory, ep_devices); + } + + return Status::OK(); +} + +Status AddEpOptionsToSessionOptions(gsl::span ep_devices, + gsl::span ep_option_keys, + gsl::span ep_option_vals, + SessionOptions& session_options) { + const size_t num_ep_options = ep_option_keys.size(); + if (ep_option_vals.size() != num_ep_options) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Must provide the same number of keys and values for EP options."); + } + for (const OrtEpDevice* ep_device : ep_devices) { // add the options to the session options with the EP prefix. // first add the default values with prefix followed by user specified values so those win @@ -483,14 +495,6 @@ Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, } } - EpFactoryInternal* internal_factory = env.GetEpFactoryInternal(ep_factory); - - if (internal_factory) { - out = std::make_unique(*internal_factory, ep_devices); - } else { - out = std::make_unique(*ep_factory, ep_devices); - } - return Status::OK(); } #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h index 5a5dcae9165ed..2ccd4d464a261 100644 --- a/onnxruntime/core/session/utils.h +++ b/onnxruntime/core/session/utils.h @@ -61,13 +61,15 @@ Status LoadPluginOrProviderBridge(const std::string& registration_name, std::vector& internal_factories); // Creates an IExecutionProviderFactory instance for a list of OrtEpDevices that all refer to the same EP. -// Adds all provider options to the OrtSessionOptions configuration. Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, - SessionOptions& session_options, gsl::span ep_devices, - gsl::span ep_options_keys, - gsl::span ep_options_vals, /*output*/ std::unique_ptr& out); +// Adds provider options to the OrtSessionOptions configuration. +Status AddEpOptionsToSessionOptions(gsl::span ep_devices, + gsl::span ep_options_keys, + gsl::span ep_options_vals, + SessionOptions& session_options); + } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index e370518b1fffb..c801e99e6edd3 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1380,11 +1380,14 @@ static Status AddEpFactoryFromEpDevices(PySessionOptions& py_sess_options, std::unique_ptr provider_factory = nullptr; ORT_RETURN_IF_ERROR(CreateIExecutionProviderFactoryForEpDevices(env, - py_sess_options.value, ep_devices, - ep_option_keys, - ep_option_vals, /*output*/ provider_factory)); + + ORT_RETURN_IF_ERROR(AddEpOptionsToSessionOptions(ep_devices, + ep_option_keys, + ep_option_vals, + py_sess_options.value)); + py_sess_options.provider_factories.push_back(std::move(provider_factory)); return Status::OK(); } @@ -2297,7 +2300,50 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc") ORT_UNUSED_PARAMETER(ort_values); ORT_THROW("External initializers are not supported in this build."); #endif - }); + }) + .def("add_external_initializers_from_files_in_memory", [](PySessionOptions* options, std::vector names, std::vector buffers, std::vector lengths) -> void { +#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_EXTERNAL_INITIALIZERS) + const auto num = names.size(); + ORT_ENFORCE(num == buffers.size() && num == lengths.size(), + "add_external_initializers_from_files_in_memory: expecting 'names', 'buffers' and 'lengths' to have equal length"); + + InlinedVector file_names; + InlinedVector> files_buffers; + file_names.reserve(num); + files_buffers.reserve(num); + + for (size_t i = 0; i < num; ++i) { + // Convert name and buffer using pybind-provided conversions + file_names.emplace_back(ToPathString(names[i])); + + // buffers[i] is a py::buffer; request() retrieves pointer without copying + py::buffer_info info = buffers[i].request(); + char* data_ptr = static_cast(info.ptr); + + files_buffers.emplace_back(std::make_pair(data_ptr, lengths[i])); + } + + ORT_THROW_IF_ERROR(options->value.AddExternalInitializersFromFilesInMemory(file_names, files_buffers)); +#else + ORT_UNUSED_PARAMETER(options); + ORT_UNUSED_PARAMETER(names); + ORT_UNUSED_PARAMETER(buffers); + ORT_UNUSED_PARAMETER(lengths); + ORT_THROW("External initializers are not supported in this build."); +#endif + }, + R"pbdoc( +Provide external initializer file contents from memory. + +Args: + names: sequence[str] of external file names (as referenced by the model's external_data locations). + buffers: sequence[bytes-like] objects exposing the buffer protocol (e.g., bytes, bytearray, memoryview, numpy uint8 array) containing the corresponding file contents. + lengths: sequence[int] sizes in bytes for each buffer. + +Notes: + - Keep the provided buffers alive until after session creation completes. ONNX Runtime copies needed data during session creation. + - The bytestream must match the external file layout expected by the model (raw tensor bytes at the specified offsets). +)pbdoc"); py::class_(m, "RunOptions", R"pbdoc(Configuration information for a single Run.)pbdoc") .def(py::init()) diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index e4265713d2d0a..14bd62467a403 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -233,7 +233,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG for (const auto& node : nodes) { auto op_type = node.GetOperatorType(); - if (op_type != "Mul") { + if (op_type == "Mul") { // Check that Mul has inputs/output of type float std::vector inputs = node.GetInputs(); std::vector outputs = node.GetOutputs(); @@ -253,6 +253,10 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG } } + if (supported_nodes.empty()) { + return nullptr; + } + // Create (optional) fusion options for the supported nodes to fuse. OrtNodeFusionOptions node_fusion_options = {}; node_fusion_options.ort_version_supported = ORT_API_VERSION; @@ -317,7 +321,7 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const Ort::ConstNode fused_node{fused_nodes[0]}; auto ep_name = fused_node.GetEpName(); - if (ep_name != "example_ep") { + if (ep_name != ep->name_) { Ort::Status status("The fused node is expected to assigned to this EP to run on", ORT_EP_FAIL); return status.release(); } diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 0f4a654f116c4..586c7a9a06335 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -98,6 +98,7 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { // Create model compilation options from the session options. Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); compile_options.SetInputModelPath(input_model_file); compile_options.SetOutputModelPath(output_model_file); diff --git a/onnxruntime/test/common/random_generator.h b/onnxruntime/test/common/random_generator.h index 2335db69ff571..ad87c7162f344 100644 --- a/onnxruntime/test/common/random_generator.h +++ b/onnxruntime/test/common/random_generator.h @@ -10,6 +10,7 @@ #include #include "core/common/common.h" +#include "core/common/narrow.h" #include "core/common/optional.h" #include "core/common/type_utils.h" #include "core/framework/float16.h" @@ -20,7 +21,7 @@ namespace onnxruntime { namespace test { namespace detail { -inline int64_t SizeFromDims(gsl::span dims, gsl::span strides = {}) { +inline size_t SizeFromDims(gsl::span dims, gsl::span strides = {}) { int64_t size = 1; if (strides.empty()) { size = std::accumulate(dims.begin(), dims.end(), static_cast(1), std::multiplies{}); @@ -35,8 +36,7 @@ inline int64_t SizeFromDims(gsl::span dims, gsl::span= 0); - return size; + return narrow(size); } } // namespace detail diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 367553a28f166..20eea2138340f 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -7,8 +7,8 @@ #include #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" -#include "test/providers/model_tester.h" #include "test/util/include/current_test_name.h" +#include "test/unittest_util/model_tester.h" #include "test/util/include/scoped_env_vars.h" #include "contrib_ops/cpu/transformers/generation_shared.h" diff --git a/onnxruntime/test/contrib_ops/cuda_kernels/fpA_intB_gemm_kernel_test.cc b/onnxruntime/test/contrib_ops/cuda_kernels/fpA_intB_gemm_kernel_test.cc index cac6d46226ef8..3e339d86c7943 100644 --- a/onnxruntime/test/contrib_ops/cuda_kernels/fpA_intB_gemm_kernel_test.cc +++ b/onnxruntime/test/contrib_ops/cuda_kernels/fpA_intB_gemm_kernel_test.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. // Test can be run like the following: -// ./onnxruntime_test_all --gtest_filter=CUDA_EP_Unittest.* +// ./onnxruntime_provider_test --gtest_filter=CUDA_EP_Unittest.* #include #include diff --git a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc index 2918e4baf86a4..3aafd413486c1 100644 --- a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc +++ b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc @@ -5,9 +5,9 @@ #include "core/framework/tensor.h" #include "core/session/inference_session.h" #include "test/common/tensor_op_test_utils.h" -#include "test/framework/test_utils.h" -#include "test/optimizer/graph_transform_test_builder.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/default_providers.h" #include "core/util/qmath.h" @@ -50,7 +50,7 @@ static void CalculateDynamicQuantizeMatMul(const int64_t M, const int64_t N, con min = std::min(min, 0.0f); float scale = static_cast(max - min) / (qmax - qmin); - T zeroPoint = std::round(std::clamp(qmin - min / scale, qmin, qmax)); + T zeroPoint = static_cast(std::round(std::clamp(qmin - min / scale, qmin, qmax))); A_scale.push_back(scale); A_zero_point.push_back(zeroPoint); diff --git a/onnxruntime/test/contrib_ops/function_ops_test.cc b/onnxruntime/test/contrib_ops/function_ops_test.cc index fa373edd166cd..2036328053d17 100644 --- a/onnxruntime/test/contrib_ops/function_ops_test.cc +++ b/onnxruntime/test/contrib_ops/function_ops_test.cc @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "core/graph/contrib_ops/contrib_defs.h" -#include "test/contrib_ops/function_test_util.h" +#include "test/unittest_util/function_test_util.h" using namespace ::onnxruntime::common; diff --git a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc index 4b586e24c9bd3..574ec49da67ea 100644 --- a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc +++ b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc @@ -161,15 +161,17 @@ ToType(const std::vector& vec) { template typename std::enable_if, T>::value, std::vector>::type ToType(const std::vector& vec) { + using UnpackedType = typename T::UnpackedType; + // UInt4x2 and Int4x2 uses global packing instead of per-row packing. size_t i = 0; - constexpr int offset = std::is_same::value ? 0 : 8; + constexpr UnpackedType offset = std::is_same::value ? 0 : 8; std::vector result; for (i = 0; i + 1 < vec.size(); i += 2) { - result.push_back(T(vec[i] + offset, vec[i + 1] + offset)); + result.push_back(T(static_cast(vec[i] + offset), static_cast(vec[i + 1] + offset))); } if (i < vec.size()) { - result.push_back(T(vec[i] + offset, 0 + offset)); + result.push_back(T(static_cast(vec[i] + offset), static_cast(0 + offset))); } return result; } diff --git a/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc index d9d2681dd3b3f..6b67b648fd9b2 100644 --- a/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/gemm_fastgelu_op_test.cc @@ -8,7 +8,6 @@ #include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" -#include "test/providers/run_options_config_keys.h" namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/contrib_ops/gemma_rotary_emb_test.cc b/onnxruntime/test/contrib_ops/gemma_rotary_emb_test.cc index 80adf04f402a8..f30a377b30f63 100644 --- a/onnxruntime/test/contrib_ops/gemma_rotary_emb_test.cc +++ b/onnxruntime/test/contrib_ops/gemma_rotary_emb_test.cc @@ -31,12 +31,12 @@ static void calculateExpectedOutput(const std::vector& emb_data, const std::vector& mul_dim, std::vector& output1, std::vector& output2) { - for (long int i = 0; i < mul_dim[0]; ++i) { - for (long int j = 0; j < mul_dim[1]; ++j) { - for (long int k = 0; k < mul_dim[2]; ++k) { - for (long int l = 0; l < mul_dim[3]; ++l) { - long int embIdx = i * mul_dim[1] * mul_dim[3] + k * mul_dim[3] + l; - long int mulIdx = i * mul_dim[1] * mul_dim[2] * mul_dim[3] + j * mul_dim[2] * mul_dim[3] + k * mul_dim[3] + l; + for (int64_t i = 0; i < mul_dim[0]; ++i) { + for (int64_t j = 0; j < mul_dim[1]; ++j) { + for (int64_t k = 0; k < mul_dim[2]; ++k) { + for (int64_t l = 0; l < mul_dim[3]; ++l) { + int64_t embIdx = i * mul_dim[1] * mul_dim[3] + k * mul_dim[3] + l; + int64_t mulIdx = i * mul_dim[1] * mul_dim[2] * mul_dim[3] + j * mul_dim[2] * mul_dim[3] + k * mul_dim[3] + l; MLFloat16 sin_val = static_cast(sin(emb_data[embIdx])); MLFloat16 cos_val = static_cast(cos(emb_data[embIdx])); diff --git a/onnxruntime/test/contrib_ops/group_norm_op_test.cc b/onnxruntime/test/contrib_ops/group_norm_op_test.cc index 4983cb5abf10c..fdc546441676b 100644 --- a/onnxruntime/test/contrib_ops/group_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_norm_op_test.cc @@ -4,7 +4,7 @@ #include #include "test/common/tensor_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/providers/provider_test_utils.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index e22445edc0f5b..0d4fc5af68b4f 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -9,7 +9,7 @@ #include "test/common/dnnl_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/util/include/default_providers.h" #include "test/providers/provider_test_utils.h" diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc index 57b1edfd0edce..04a4c95dd478b 100644 --- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc @@ -17,9 +17,9 @@ #include "core/session/inference_session.h" #include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" -#include "test/framework/test_utils.h" -#include "test/optimizer/graph_transform_test_builder.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/default_providers.h" #include "test/util/include/scoped_env_vars.h" #include "core/session/onnxruntime_cxx_api.h" diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 7213937d0ef11..3a9bd02ef8d72 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -17,9 +17,9 @@ #include "core/session/inference_session.h" #include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" -#include "test/framework/test_utils.h" -#include "test/optimizer/graph_transform_test_builder.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/default_providers.h" #include "test/util/include/scoped_env_vars.h" #include "core/session/onnxruntime_cxx_api.h" diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index c60abbc278962..b336debecef94 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -16,9 +16,9 @@ #include "core/session/inference_session.h" #include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" -#include "test/framework/test_utils.h" -#include "test/optimizer/graph_transform_test_builder.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/default_providers.h" #include "test/util/include/scoped_env_vars.h" #include "core/session/onnxruntime_cxx_api.h" diff --git a/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc b/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc index e739b17d5885f..a155e24800644 100644 --- a/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc @@ -9,9 +9,9 @@ #include "core/mlas/inc/mlas.h" #include "core/session/inference_session.h" #include "test/common/tensor_op_test_utils.h" -#include "test/framework/test_utils.h" -#include "test/optimizer/graph_transform_test_builder.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/default_providers.h" #include "core/util/qmath.h" #include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h" diff --git a/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc b/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc index 09ae5eddb122c..81cb430da34da 100644 --- a/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_fpq4_test.cc @@ -8,9 +8,9 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/session/inference_session.h" #include "test/common/tensor_op_test_utils.h" -#include "test/framework/test_utils.h" -#include "test/optimizer/graph_transform_test_builder.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/default_providers.h" #include "core/util/qmath.h" diff --git a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc index 8d7629b5fda1c..30b0c0fcf73c3 100644 --- a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc @@ -6,9 +6,9 @@ #include "core/mlas/inc/mlas.h" #include "core/session/inference_session.h" #include "test/common/tensor_op_test_utils.h" -#include "test/framework/test_utils.h" -#include "test/optimizer/graph_transform_test_builder.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/default_providers.h" #include "core/util/qmath.h" diff --git a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc index ea8537f243f5d..3e8870892b7c9 100644 --- a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc @@ -4,7 +4,7 @@ #include #include "test/common/tensor_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/providers/provider_test_utils.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index 8827696bc2fb9..bcd6b62e649be 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -33,13 +33,6 @@ using json = nlohmann::json; #include "core/session/onnxruntime_session_options_config_keys.h" using namespace ONNX_NAMESPACE; -// Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct, -// GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses -// GCC 4.x. -// (This static var is referenced in some tests below) -const OrtDevice::DeviceType OrtDevice::GPU; -const OrtDevice::DeviceType OrtDevice::CPU; - namespace onnxruntime { #ifdef USE_CUDA ProviderInfo_CUDA& GetProviderInfo_CUDA(); diff --git a/onnxruntime/test/framework/allocator_test.cc b/onnxruntime/test/framework/allocator_test.cc index 445e023746aaa..b1af7beb180b5 100644 --- a/onnxruntime/test/framework/allocator_test.cc +++ b/onnxruntime/test/framework/allocator_test.cc @@ -5,7 +5,7 @@ #include "core/framework/allocator.h" #include "core/framework/allocator_utils.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "gtest/gtest.h" namespace onnxruntime { diff --git a/onnxruntime/test/framework/cuda/fence_cuda_test.cc b/onnxruntime/test/framework/cuda/fence_cuda_test.cc index fced72ce3246d..1553469c52df7 100644 --- a/onnxruntime/test/framework/cuda/fence_cuda_test.cc +++ b/onnxruntime/test/framework/cuda/fence_cuda_test.cc @@ -22,7 +22,7 @@ #include "core/providers/cpu/math/element_wise_ops.h" #include "test/capturing_sink.h" #include "test/test_environment.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "gtest/gtest.h" #include "core/util/protobuf_parsing_utils.h" #include "test/providers/provider_test_utils.h" diff --git a/onnxruntime/test/framework/dummy_provider.h b/onnxruntime/test/framework/dummy_provider.h index 2da040d9e703f..6bbb2fb6f693d 100644 --- a/onnxruntime/test/framework/dummy_provider.h +++ b/onnxruntime/test/framework/dummy_provider.h @@ -3,7 +3,7 @@ #include "core/framework/execution_provider.h" #include -#include "dummy_allocator.h" +#include "test/unittest_util/dummy_allocator.h" namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/framework/ep_compatibility_test.cc b/onnxruntime/test/framework/ep_compatibility_test.cc index a8a83fbe5ceb6..0ae3fb746dd24 100644 --- a/onnxruntime/test/framework/ep_compatibility_test.cc +++ b/onnxruntime/test/framework/ep_compatibility_test.cc @@ -19,7 +19,7 @@ #include "core/session/abi_session_options_impl.h" #include "core/framework/error_code_helper.h" #include "dummy_provider.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" #include "test/providers/provider_test_utils.h" @@ -527,4 +527,4 @@ TEST(EpCompatibilityCxxApiTest, SingleDeviceCpuProvider) { }); ASSERT_TRUE(status == OrtCompiledModelCompatibility_EP_NOT_APPLICABLE); -} \ No newline at end of file +} diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index 67a0e7fb05241..4808d2bfb447a 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -8,10 +8,10 @@ #include "core/graph/model.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/inference_session.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" -#include "test/framework/TestAllocatorManager.h" #include "test/util/include/inference_session_wrapper.h" +#include "test/unittest_util/test_allocator_manager.h" #include "asserts.h" #include "gtest/gtest.h" #include "gmock/gmock.h" diff --git a/onnxruntime/test/framework/execution_provider_test.cc b/onnxruntime/test/framework/execution_provider_test.cc index 390fda7bfc5ad..7589533585967 100644 --- a/onnxruntime/test/framework/execution_provider_test.cc +++ b/onnxruntime/test/framework/execution_provider_test.cc @@ -3,7 +3,7 @@ #include "core/framework/execution_provider.h" #include "core/graph/model.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" #include "test/util/include/asserts.h" #include "core/framework/model_metadef_id_generator.h" diff --git a/onnxruntime/test/framework/float_16_test.cc b/onnxruntime/test/framework/float_16_test.cc index 8f1c03419e145..8b235ce03dba1 100644 --- a/onnxruntime/test/framework/float_16_test.cc +++ b/onnxruntime/test/framework/float_16_test.cc @@ -23,7 +23,7 @@ #include "core/common/status.h" #include "test/capturing_sink.h" #include "test/test_environment.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "gtest/gtest.h" #include "core/graph/schema_registry.h" #include "core/framework/customregistry.h" diff --git a/onnxruntime/test/framework/float_4_test.cc b/onnxruntime/test/framework/float_4_test.cc index 03d13d99c7bc1..02a08956cf867 100644 --- a/onnxruntime/test/framework/float_4_test.cc +++ b/onnxruntime/test/framework/float_4_test.cc @@ -13,7 +13,7 @@ #include "core/framework/float4.h" #include "test/test_environment.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "gtest/gtest.h" using namespace ONNX_NAMESPACE; diff --git a/onnxruntime/test/framework/float_8_test.cc b/onnxruntime/test/framework/float_8_test.cc index 62a82e50d4c8a..a282035a9bcc2 100644 --- a/onnxruntime/test/framework/float_8_test.cc +++ b/onnxruntime/test/framework/float_8_test.cc @@ -8,7 +8,7 @@ #include "core/framework/float8.h" #include "test/capturing_sink.h" #include "test/test_environment.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "gtest/gtest.h" using namespace ONNX_NAMESPACE; diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index 6cde2fbc71f5d..699d1b1a2c27a 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -13,13 +13,12 @@ #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/inference_session.h" -#include "test/test_environment.h" -#include "test/framework/test_utils.h" -#include "inference_session_wrapper.h" #include "test/common/tensor_op_test_utils.h" +#include "test/unittest_util/framework_test_utils.h" +#include "test/internal_testing_ep/internal_testing_execution_provider.h" +#include "test/test_environment.h" #include "test/util/include/asserts.h" - -#include "test/providers/internal_testing/internal_testing_execution_provider.h" +#include "test/util/include/inference_session_wrapper.h" // Unit tests to check the implementation of functions, model-local functions, // function-inlining etc. diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 6131eff92ac78..8f6ed6f55c11a 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -51,7 +51,7 @@ #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" #include "dummy_provider.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/capturing_sink.h" #include "test/test_environment.h" #include "test/providers/provider_test_utils.h" @@ -630,7 +630,8 @@ TEST(InferenceSessionTests, CheckRunLogger) { } // WebAssembly will emit profiling data into console -#if !defined(__wasm__) +// TODO(hasesh): Investigate why this test fails on Windows CUDA builds +#if (!defined(__wasm__) && !defined(_WIN32)) TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions) { SessionOptions so; diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index 9278541b07512..c4b0f3ffd15d9 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -6,7 +6,7 @@ #include "core/graph/model.h" #include "core/graph/node_attr_utils.h" #include "gtest/gtest.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" #include "test/util/include/default_providers.h" #include "test/util/include/inference_session_wrapper.h" diff --git a/onnxruntime/test/framework/local_kernel_registry_test.cc b/onnxruntime/test/framework/local_kernel_registry_test.cc index 1b1ca5291c588..c6d345045e6ff 100644 --- a/onnxruntime/test/framework/local_kernel_registry_test.cc +++ b/onnxruntime/test/framework/local_kernel_registry_test.cc @@ -24,7 +24,7 @@ #include "test/capturing_sink.h" #include "test/test_environment.h" #include "test/util/include/asserts.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "gtest/gtest.h" using namespace ONNX_NAMESPACE; diff --git a/onnxruntime/test/framework/memcpy_transformer_test.cc b/onnxruntime/test/framework/memcpy_transformer_test.cc index 6e86e5b58aead..cfc3522821f7d 100644 --- a/onnxruntime/test/framework/memcpy_transformer_test.cc +++ b/onnxruntime/test/framework/memcpy_transformer_test.cc @@ -8,7 +8,7 @@ #include "core/graph/model.h" #include "default_providers.h" #include "gtest/gtest.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" #include "asserts.h" diff --git a/onnxruntime/test/framework/opaque_kernels_test.cc b/onnxruntime/test/framework/opaque_kernels_test.cc index 5069e4a5dbe5c..80764693ce4b0 100644 --- a/onnxruntime/test/framework/opaque_kernels_test.cc +++ b/onnxruntime/test/framework/opaque_kernels_test.cc @@ -16,7 +16,7 @@ #include "test/providers/provider_test_utils.h" #include "asserts.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc index e2cb82e47f32b..3032b3170a6e0 100644 --- a/onnxruntime/test/framework/ort_model_only_test.cc +++ b/onnxruntime/test/framework/ort_model_only_test.cc @@ -10,10 +10,10 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/session/inference_session.h" #include "core/session/onnxruntime_session_options_config_keys.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/common/tensor_op_test_utils.h" -#include "test/providers/checkers.h" #include "test/test_environment.h" +#include "test/unittest_util/checkers.h" #include "test/util/include/asserts.h" #include "test/util/include/inference_session_wrapper.h" diff --git a/onnxruntime/test/framework/parallel_executor_test.cc b/onnxruntime/test/framework/parallel_executor_test.cc index 8dcaad10c5b5a..2be13ec542c4b 100644 --- a/onnxruntime/test/framework/parallel_executor_test.cc +++ b/onnxruntime/test/framework/parallel_executor_test.cc @@ -4,7 +4,7 @@ #include "core/framework/data_types.h" #include "core/framework/op_kernel.h" #include "test/providers/provider_test_utils.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "core/session/inference_session.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/framework/save_model_with_external_initializers.cc b/onnxruntime/test/framework/save_model_with_external_initializers.cc index e70d870ef6988..a05a56c61ca00 100644 --- a/onnxruntime/test/framework/save_model_with_external_initializers.cc +++ b/onnxruntime/test/framework/save_model_with_external_initializers.cc @@ -9,7 +9,7 @@ #include "core/graph/model_saving_options.h" #include "core/framework/tensorprotoutils.h" #include "test/test_environment.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/util/include/asserts.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index a9d6273ae2f20..cdcd3c2327421 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -23,8 +23,7 @@ #include "core/util/thread_utils.h" #include "gtest/gtest.h" #include "test/test_environment.h" -#include "test/optimizer/graph_transform_test_builder.h" -#include "test/util/include/test_environment.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/default_providers.h" #include "test/util/include/file_util.h" #include "core/optimizer/layout_transformation/layout_transformation.h" diff --git a/onnxruntime/test/framework/sparse_kernels_test.cc b/onnxruntime/test/framework/sparse_kernels_test.cc index 43de3a945526c..89e928af10b8b 100644 --- a/onnxruntime/test/framework/sparse_kernels_test.cc +++ b/onnxruntime/test/framework/sparse_kernels_test.cc @@ -17,7 +17,7 @@ #include "core/session/inference_session.h" #include "test/providers/provider_test_utils.h" #include "asserts.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "file_util.h" #include "default_providers.h" diff --git a/onnxruntime/test/framework/tensor_test.cc b/onnxruntime/test/framework/tensor_test.cc index f08675271de21..2125de0a36e4f 100644 --- a/onnxruntime/test/framework/tensor_test.cc +++ b/onnxruntime/test/framework/tensor_test.cc @@ -3,7 +3,7 @@ #include "core/framework/tensor.h" #include "core/framework/allocator_utils.h" -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "gmock/gmock.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/internal_testing/README.md b/onnxruntime/test/internal_testing_ep/README.md similarity index 100% rename from onnxruntime/test/providers/internal_testing/README.md rename to onnxruntime/test/internal_testing_ep/README.md diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.cc b/onnxruntime/test/internal_testing_ep/internal_testing_ep_static_kernels.cc similarity index 100% rename from onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.cc rename to onnxruntime/test/internal_testing_ep/internal_testing_ep_static_kernels.cc diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.h b/onnxruntime/test/internal_testing_ep/internal_testing_ep_static_kernels.h similarity index 100% rename from onnxruntime/test/providers/internal_testing/internal_testing_ep_static_kernels.h rename to onnxruntime/test/internal_testing_ep/internal_testing_ep_static_kernels.h diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/internal_testing_ep/internal_testing_execution_provider.cc similarity index 100% rename from onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc rename to onnxruntime/test/internal_testing_ep/internal_testing_execution_provider.cc diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/internal_testing_ep/internal_testing_execution_provider.h similarity index 100% rename from onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h rename to onnxruntime/test/internal_testing_ep/internal_testing_execution_provider.h diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc b/onnxruntime/test/internal_testing_ep/internal_testing_partitioning_tests.cc similarity index 99% rename from onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc rename to onnxruntime/test/internal_testing_ep/internal_testing_partitioning_tests.cc index d58db5178032d..96d697e0077a4 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_partitioning_tests.cc +++ b/onnxruntime/test/internal_testing_ep/internal_testing_partitioning_tests.cc @@ -8,9 +8,9 @@ #include "core/framework/utils.h" #include "core/session/inference_session.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" +#include "test/internal_testing_ep/internal_testing_execution_provider.h" #include "test/test_environment.h" -#include "test/providers/internal_testing/internal_testing_execution_provider.h" #include "test/util/include/asserts.h" #include "test/util/include/inference_session_wrapper.h" #include "test/util/include/test_utils.h" diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/internal_testing_ep/internal_testing_tests.cc similarity index 99% rename from onnxruntime/test/providers/internal_testing/internal_testing_tests.cc rename to onnxruntime/test/internal_testing_ep/internal_testing_tests.cc index 275f29fdd9073..f5d4989d86cf8 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc +++ b/onnxruntime/test/internal_testing_ep/internal_testing_tests.cc @@ -13,9 +13,9 @@ #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/ort_env.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" -#include "test/providers/internal_testing/internal_testing_execution_provider.h" +#include "test/internal_testing_ep/internal_testing_execution_provider.h" #include "test/util/include/asserts.h" #include "test/util/include/inference_session_wrapper.h" #include "test/util/include/test_utils.h" diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index ca1166e19037c..4fd9830440846 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -13,7 +13,7 @@ #include "gmock/gmock.h" #include "onnx/defs/function.h" #include "core/graph/function_impl.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #ifdef __GNUC__ #define UNUSED __attribute__((unused)) diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index e9c13e215f619..6df98ff505fa1 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -1408,6 +1408,9 @@ std::unique_ptr> GetBrokenTests(const std::string& provider broken_tests->insert({"gridsample_volumetric_nearest_align_corners_1", "unknown version"}); broken_tests->insert({"rotary_embedding_no_position_ids_expanded", "unknown version"}); broken_tests->insert({"rotary_embedding_no_position_ids_interleaved_expanded", "unknown version"}); + broken_tests->insert({"rotary_embedding_no_position_ids_rotary_dim", "unknown version"}); + broken_tests->insert({"rotary_embedding_with_interleaved_rotary_dim", "unknown version"}); + broken_tests->insert({"rotary_embedding_with_rotary_dim", "unknown version"}); // Fails since QNN SDK 2.17.0: // expected 7.70947 (40f6b3f3), got 7.84096 (40fae920), diff: 0.131491, tol=0.00870947 idx=419. 100 of 1715 differ broken_tests->insert({"facedetection_op8_qdq", "result differs"}); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 463634b370d4c..b6f2cb2683677 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -795,6 +795,24 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); // Please make no more changes to the list static const ORTCHAR_T* immutable_broken_tests[] = { + // pending ONNX update + ORT_TSTR("attention_3d_gqa"), + ORT_TSTR("attention_3d_gqa_attn_mask"), + ORT_TSTR("attention_3d_gqa_causal"), + ORT_TSTR("attention_3d_gqa_scaled"), + ORT_TSTR("attention_3d_gqa_softcap"), + ORT_TSTR("attention_3d_gqa_with_past_and_present"), + ORT_TSTR("attention_4d_gqa"), + ORT_TSTR("attention_4d_gqa_attn_mask"), + ORT_TSTR("attention_4d_gqa_causal"), + ORT_TSTR("attention_4d_gqa_scaled"), + ORT_TSTR("attention_4d_gqa_softcap"), + ORT_TSTR("attention_4d_gqa_with_past_and_present"), + ORT_TSTR("attention_4d_diff_heads_mask4d_padded_kv"), + ORT_TSTR("attention_4d_gqa_with_past_and_present_fp16"), + ORT_TSTR("attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal"), + ORT_TSTR("attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal"), + // unsupported case ORT_TSTR("AvgPool1d"), ORT_TSTR("AvgPool1d_stride"), ORT_TSTR("AvgPool2d"), diff --git a/onnxruntime/test/opaque_api/test_opaque_api.cc b/onnxruntime/test/opaque_api/test_opaque_api.cc index 5bccf5ab1ac0d..da3ad08ae1ce2 100644 --- a/onnxruntime/test/opaque_api/test_opaque_api.cc +++ b/onnxruntime/test/opaque_api/test_opaque_api.cc @@ -15,7 +15,7 @@ #include "gtest/gtest.h" #include "core/graph/onnx_protobuf.h" #include "test/providers/provider_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; diff --git a/onnxruntime/test/optimizer/avx2_weight_s8_to_u8_test.cc b/onnxruntime/test/optimizer/avx2_weight_s8_to_u8_test.cc index 8f3fafe43b14e..20fe06ac5ed66 100644 --- a/onnxruntime/test/optimizer/avx2_weight_s8_to_u8_test.cc +++ b/onnxruntime/test/optimizer/avx2_weight_s8_to_u8_test.cc @@ -15,8 +15,8 @@ #include "test/compare_ortvalue.h" #include "test/test_environment.h" #include "test/common/quantization_test_utils.h" -#include "test/optimizer/graph_transform_test_builder.h" #include "test/providers/provider_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/asserts.h" #include "test/util/include/default_providers.h" #include "test/util/include/inference_session_wrapper.h" diff --git a/onnxruntime/test/optimizer/compute_optimizer_test.cc b/onnxruntime/test/optimizer/compute_optimizer_test.cc index 9dcedd1fd7681..333c1edf8ffab 100644 --- a/onnxruntime/test/optimizer/compute_optimizer_test.cc +++ b/onnxruntime/test/optimizer/compute_optimizer_test.cc @@ -30,14 +30,14 @@ #include "test/common/tensor_op_test_utils.h" #include "test/compare_ortvalue.h" -#include "test/framework/test_utils.h" -#include "test/optimizer/graph_transform_test_builder.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/optimizer/graph_transform_test_fixture.h" #include "test/optimizer/test_optimizer_utils.h" #include "test/providers/provider_test_utils.h" -#include "test/util/include/temp_dir.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/asserts.h" #include "test/util/include/default_providers.h" +#include "test/util/include/temp_dir.h" namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/optimizer/conv_add_act_test.cc b/onnxruntime/test/optimizer/conv_add_act_test.cc index e97c738082cb7..f61f9b29d9cce 100644 --- a/onnxruntime/test/optimizer/conv_add_act_test.cc +++ b/onnxruntime/test/optimizer/conv_add_act_test.cc @@ -4,7 +4,7 @@ #include #include "gtest/gtest.h" -#include "graph_transform_test_builder.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "core/graph/graph.h" diff --git a/onnxruntime/test/optimizer/cse_test.cc b/onnxruntime/test/optimizer/cse_test.cc index bad96406df845..f34cfcadc7ab4 100644 --- a/onnxruntime/test/optimizer/cse_test.cc +++ b/onnxruntime/test/optimizer/cse_test.cc @@ -4,7 +4,7 @@ #include "core/graph/model.h" #include "core/optimizer/common_subexpression_elimination.h" #include "core/optimizer/graph_transformer_mgr.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" #include "test/util/include/asserts.h" diff --git a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc index 2c72658ea98fe..67cb84f04d13f 100644 --- a/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc +++ b/onnxruntime/test/optimizer/ensure_unique_dq_for_node_unit_test.cc @@ -5,8 +5,8 @@ #include "gtest/gtest.h" -#include "test/optimizer/graph_transform_test_builder.h" #include "test/test_environment.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/asserts.h" namespace onnxruntime::test { diff --git a/onnxruntime/test/optimizer/free_dimension_override_test.cc b/onnxruntime/test/optimizer/free_dimension_override_test.cc index 08f7ebf1c42fc..357045facdfcb 100644 --- a/onnxruntime/test/optimizer/free_dimension_override_test.cc +++ b/onnxruntime/test/optimizer/free_dimension_override_test.cc @@ -5,7 +5,7 @@ #include "core/graph/model.h" #include "core/optimizer/graph_transformer.h" #include "core/optimizer/graph_transformer_mgr.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" #include "gtest/gtest.h" #include "core/optimizer/free_dim_override_transformer.h" diff --git a/onnxruntime/test/optimizer/fuse_initializers_transformer_test.cc b/onnxruntime/test/optimizer/fuse_initializers_transformer_test.cc index e59a9308155a0..de973679c8f80 100644 --- a/onnxruntime/test/optimizer/fuse_initializers_transformer_test.cc +++ b/onnxruntime/test/optimizer/fuse_initializers_transformer_test.cc @@ -3,7 +3,7 @@ #include // needed for std::transform #include "gtest/gtest.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" #include "test/util/include/default_providers.h" #include "test/util/include/asserts.h" diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 15aea3a22dfd2..5ad5811cfd7bc 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -83,11 +83,11 @@ #include "test/capturing_sink.h" #include "test/common/tensor_op_test_utils.h" #include "test/compare_ortvalue.h" -#include "test/framework/test_utils.h" -#include "test/optimizer/graph_transform_test_builder.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/optimizer/graph_transform_test_fixture.h" #include "test/providers/provider_test_utils.h" #include "test/test_environment.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/asserts.h" #include "test/util/include/default_providers.h" #include "test/util/include/inference_session_wrapper.h" diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index 627a68f38b585..0afb836192b0a 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -20,10 +20,10 @@ #include "core/optimizer/skip_layer_norm_fusion.h" #include "test/capturing_sink.h" -#include "test/framework/test_utils.h" -#include "test/optimizer/graph_transform_test_builder.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/optimizer/graph_transform_test_fixture.h" #include "test/providers/provider_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" using namespace std; using namespace ONNX_NAMESPACE; diff --git a/onnxruntime/test/optimizer/graph_transform_utils_test.cc b/onnxruntime/test/optimizer/graph_transform_utils_test.cc index caa64560426af..dc687714f07cd 100644 --- a/onnxruntime/test/optimizer/graph_transform_utils_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_utils_test.cc @@ -3,7 +3,7 @@ #include "core/common/inlined_containers.h" #include "core/graph/onnx_protobuf.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/capturing_sink.h" #include "test/test_environment.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index 538f60040418c..fc0ba86c7f1f6 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -9,7 +9,7 @@ #include "core/framework/tensorprotoutils.h" #include "test/compare_ortvalue.h" #include "test/test_environment.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/util/include/asserts.h" #include "test/util/include/inference_session_wrapper.h" #include diff --git a/onnxruntime/test/optimizer/nhwc_transformer_test.cc b/onnxruntime/test/optimizer/nhwc_transformer_test.cc index 1f76ca2a0291f..21ea7af4e7389 100644 --- a/onnxruntime/test/optimizer/nhwc_transformer_test.cc +++ b/onnxruntime/test/optimizer/nhwc_transformer_test.cc @@ -5,7 +5,7 @@ #include #include "gtest/gtest.h" -#include "graph_transform_test_builder.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "core/mlas/inc/mlas.h" #include "core/graph/graph.h" #include "test/common/dnnl_op_test_utils.h" diff --git a/onnxruntime/test/optimizer/optimizer_test.cc b/onnxruntime/test/optimizer/optimizer_test.cc index b306f026b2dfd..b84e4c9bcc8f0 100644 --- a/onnxruntime/test/optimizer/optimizer_test.cc +++ b/onnxruntime/test/optimizer/optimizer_test.cc @@ -12,7 +12,7 @@ #include "core/framework/op_kernel.h" #include "core/util/math.h" #include "core/platform/env.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/capturing_sink.h" #include "test/test_environment.h" #include "asserts.h" diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index e9c7b11fe9da2..a0b44bbce62f8 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -13,12 +13,12 @@ #include "test/compare_ortvalue.h" #include "test/test_environment.h" -#include "test/framework/test_utils.h" -#include "test/optimizer/qdq_test_utils.h" -#include "test/optimizer/graph_transform_test_builder.h" +#include "test/unittest_util/framework_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/asserts.h" #include "test/util/include/default_providers.h" #include "test/util/include/inference_session_wrapper.h" +#include "test/unittest_util/qdq_test_utils.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc b/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc index ccfa1f1159937..55f1d212a8034 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc @@ -18,14 +18,14 @@ #include "test/compare_ortvalue.h" #include "test/test_environment.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/util/include/asserts.h" #include "test/util/include/inference_session_wrapper.h" #include "gtest/gtest.h" -#include "graph_transform_test_builder.h" +#include "test/unittest_util/graph_transform_test_builder.h" -#include "qdq_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #if defined(__aarch64__) && defined(__linux__) && !defined(DISABLE_CONTRIB_OPS) diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 4196ed280a993..79e3073b944ff 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -25,15 +25,15 @@ #include "test/compare_ortvalue.h" #include "test/test_environment.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/util/include/asserts.h" #include "test/util/include/inference_session_wrapper.h" #include "test/common/dnnl_op_test_utils.h" #include "gtest/gtest.h" -#include "graph_transform_test_builder.h" +#include "test/unittest_util/graph_transform_test_builder.h" -#include "qdq_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #if defined(_MSC_VER) #pragma warning(disable : 4127) diff --git a/onnxruntime/test/optimizer/resnet50_fusion_test.cc b/onnxruntime/test/optimizer/resnet50_fusion_test.cc index 7e6677c8e1ddf..b88338f5999c6 100644 --- a/onnxruntime/test/optimizer/resnet50_fusion_test.cc +++ b/onnxruntime/test/optimizer/resnet50_fusion_test.cc @@ -7,12 +7,12 @@ #include "core/optimizer/conv_add_act_fusion.h" #include "core/mlas/inc/mlas.h" #include "gtest/gtest.h" -#include "graph_transform_test_builder.h" -#include "test/test_environment.h" -#include "test/util/include/asserts.h" +#include "test/unittest_util/graph_transform_test_builder.h" -#include "test/optimizer/graph_transform_test_builder.h" #include "test/optimizer/graph_transform_test_fixture.h" +#include "test/test_environment.h" +#include "test/unittest_util/graph_transform_test_builder.h" +#include "test/util/include/asserts.h" namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc b/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc index bc0392444267e..adc173456a7db 100644 --- a/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc +++ b/onnxruntime/test/optimizer/rule_based_graph_transformer_test.cc @@ -11,7 +11,7 @@ #include "core/optimizer/graph_transformer.h" #include "core/optimizer/graph_transformer_mgr.h" #include "dummy_graph_transformer.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" using namespace std; @@ -72,4 +72,4 @@ TEST(RuleBasedGraphTransformerTest, TestSettingStepsInGraphTransformerManager) { ASSERT_EQ(steps_queried, static_cast(10)); } } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc b/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc index a59ce60d65136..954b18cb5b694 100644 --- a/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc +++ b/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc @@ -7,7 +7,7 @@ #include "gtest/gtest.h" #include "core/session/onnxruntime_session_options_config_keys.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/util/include/asserts.h" #include "test/util/include/inference_session_wrapper.h" #include "test/util/include/test_environment.h" diff --git a/onnxruntime/test/optimizer/test_optimizer_utils.h b/onnxruntime/test/optimizer/test_optimizer_utils.h index 82da05695de70..a90d9d7feadbe 100644 --- a/onnxruntime/test/optimizer/test_optimizer_utils.h +++ b/onnxruntime/test/optimizer/test_optimizer_utils.h @@ -10,7 +10,7 @@ #include "core/framework/float16.h" #include "core/framework/framework_common.h" #include "core/framework/ort_value.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index f6fce37322c10..fbdd73617f53f 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -9,18 +9,18 @@ #include "gtest/gtest.h" #include "gmock/gmock.h" -#include "core/graph/graph.h" -#include "core/graph/node_attr_utils.h" #include "core/framework/op_node_proto_helper.h" #include "core/framework/utils.h" +#include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" #include "core/optimizer/transpose_optimization/onnx_transpose_optimization.h" #include "core/optimizer/transpose_optimization/optimizer_api.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "test/internal_testing_ep/internal_testing_execution_provider.h" #include "test/test_environment.h" -#include "test/optimizer/graph_transform_test_builder.h" -#include "test/providers/internal_testing/internal_testing_execution_provider.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/asserts.h" #include "test/util/include/default_providers.h" #include "test/util/include/inference_session_wrapper.h" diff --git a/onnxruntime/test/platform/path_lib_test.cc b/onnxruntime/test/platform/path_lib_test.cc index f4a41347b9b39..1e55d77e81d6f 100644 --- a/onnxruntime/test/platform/path_lib_test.cc +++ b/onnxruntime/test/platform/path_lib_test.cc @@ -7,7 +7,7 @@ #include "core/platform/env.h" #include "core/platform/path_lib.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/providers/cann/cann_basic_test.cc b/onnxruntime/test/providers/cann/cann_basic_test.cc index 52d3ffd788487..5b0f24c685d21 100644 --- a/onnxruntime/test/providers/cann/cann_basic_test.cc +++ b/onnxruntime/test/providers/cann/cann_basic_test.cc @@ -4,7 +4,7 @@ #include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" #include "test/providers/provider_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "gtest/gtest.h" #include "test/util/include/default_providers.h" diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index 2449f7c962e83..78832a5f26e45 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -12,7 +12,7 @@ #include "core/session/inference_session.h" #include "core/session/onnxruntime_cxx_api.h" #include "test/common/tensor_op_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/util/include/asserts.h" #include "test/util/include/current_test_name.h" #include "test/util/include/default_providers.h" diff --git a/onnxruntime/test/providers/coreml/dynamic_input_test.cc b/onnxruntime/test/providers/coreml/dynamic_input_test.cc index 8294f65745256..6e8c7a2f09821 100644 --- a/onnxruntime/test/providers/coreml/dynamic_input_test.cc +++ b/onnxruntime/test/providers/coreml/dynamic_input_test.cc @@ -10,8 +10,8 @@ #include "core/providers/coreml/coreml_provider_factory_creator.h" #include "core/providers/coreml/coreml_provider_factory.h" // for COREMLFlags #include "test/common/random_generator.h" -#include "test/providers/model_tester.h" #include "test/util/include/current_test_name.h" +#include "test/unittest_util/model_tester.h" #include "test/util/include/test_utils.h" namespace onnxruntime::test { diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index 44cc20c4a25bb..11a3d67a3e13e 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -739,13 +739,13 @@ TEST_F(ActivationOpTest, ONNX_Gelu) { TestActivationOp( "Gelu", input_values, - [](float x) { return 0.5 * x * (1 + erf(x * M_SQRT1_2)); }, {}, + [](float x) { return static_cast(0.5 * x * (1 + erf(x * M_SQRT1_2))); }, {}, {{"approximate", "none"}}, true, 20); TestActivationOp( "Gelu", input_values, - [](float x) { return 0.5 * x * (1 + erf(x * M_SQRT1_2)); }, + [](float x) { return static_cast(0.5 * x * (1 + erf(x * M_SQRT1_2))); }, {}, {/*default value of approximate attribute is none */}, true, 20); @@ -753,7 +753,7 @@ TEST_F(ActivationOpTest, ONNX_Gelu) { "Gelu", input_values, [](float x) { - return 0.5 * x * (1 + tanh(sqrt(2 / M_PI) * (x + 0.044715 * x * x * x))); + return static_cast(0.5 * x * (1 + tanh(sqrt(2 / M_PI) * (x + 0.044715 * x * x * x)))); }, {}, {{"approximate", "tanh"}}, true, 20); diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index dc50a75873034..07cd2114372dd 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -12,7 +12,7 @@ #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" using namespace ONNX_NAMESPACE; diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index b4f6d328cacf7..54c2ed7d521db 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -664,6 +664,7 @@ TEST(AttentionTest, Attention4DAttnPastPresent) { false, true, true // disable_cpu, disable_cuda, disable_dml ); } + TEST(AttentionTest, Attention4DAttnIsCausal) { int batch_size = 2; // Q.shape[0] int q_num_heads = 3; // Q.shape[1] @@ -828,6 +829,38 @@ TEST(AttentionTest, Attention4DDiffHeadsWithPastAndPresent) { ); } +TEST(AttentionTest, Attention3DGqaAttn) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 9; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + // {2, 4, 72} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f}; + // {2, 6, 24} + std::vector k = {0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f}; + // {2, 6, 24} + std::vector v = {0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f}; + // {2, 4, 72} + std::vector y = {0.532009f, 0.526025f, 0.449746f, 0.551692f, 0.407822f, 0.436275f, 0.507807f, 0.457324f, 0.530536f, 0.517111f, 0.452785f, 0.557318f, 0.397721f, 0.434161f, 0.498276f, 0.464536f, 0.528016f, 0.548671f, 0.441040f, 0.542961f, 0.418557f, 0.444397f, 0.515088f, 0.452512f, 0.462161f, 0.530536f, 0.564630f, 0.418701f, 0.669452f, 0.633554f, 0.569379f, 0.430544f, 0.456026f, 0.529795f, 0.558238f, 0.411985f, 0.664240f, 0.619959f, 0.590516f, 0.438577f, 0.471552f, 0.521718f, 0.560465f, 0.404206f, 0.663920f, 0.628819f, 0.540935f, 0.447763f, 0.615083f, 0.344791f, 0.432664f, 0.451253f, 0.460813f, 0.441267f, 0.708582f, 0.530088f, 0.623659f, 0.343547f, 0.439418f, 0.450767f, 0.460055f, 0.442001f, 0.703292f, 0.522883f, 0.617738f, 0.343160f, 0.440540f, 0.440079f, 0.459815f, 0.436860f, 0.703290f, 0.534856f, 0.536138f, 0.499439f, 0.465771f, 0.565138f, 0.391402f, 0.430258f, 0.494915f, 0.463613f, 0.532752f, 0.526358f, 0.452075f, 0.562130f, 0.402551f, 0.442784f, 0.486721f, 0.456955f, 0.547578f, 0.527342f, 0.453800f, 0.548887f, 0.418444f, 0.438968f, 0.515475f, 0.444207f, 0.475352f, 0.524010f, 0.549702f, 0.420030f, 0.656346f, 0.620729f, 0.571884f, 0.431010f, 0.453307f, 0.522210f, 0.563368f, 0.412061f, 0.657897f, 0.634999f, 0.577458f, 0.451691f, 0.473936f, 0.524285f, 0.553525f, 0.421768f, 0.662288f, 0.622833f, 0.570081f, 0.432808f, 0.625738f, 0.353159f, 0.436185f, 0.448597f, 0.459371f, 0.429822f, 0.709026f, 0.526207f, 0.630878f, 0.351036f, 0.439799f, 0.452249f, 0.456486f, 0.431906f, 0.706014f, 0.518897f, 0.629526f, 0.351482f, 0.440728f, 0.449287f, 0.451705f, 0.426815f, 0.706598f, 0.522028f, 0.537899f, 0.527199f, 0.447980f, 0.548688f, 0.410653f, 0.436181f, 0.511135f, 0.455244f, 0.534560f, 0.540045f, 0.447505f, 0.552786f, 0.413302f, 0.446360f, 0.499945f, 0.450757f, 0.531708f, 0.526097f, 0.450511f, 0.553372f, 0.401450f, 0.438186f, 0.501418f, 0.462466f, 0.469643f, 0.527539f, 0.553613f, 0.418159f, 0.659814f, 0.622731f, 0.575224f, 0.429425f, 0.463941f, 0.524481f, 0.557632f, 0.413729f, 0.657415f, 0.629157f, 0.570920f, 0.439773f, 0.479643f, 0.526773f, 0.556809f, 0.422406f, 0.670038f, 0.625300f, 0.554451f, 0.426587f, 0.630894f, 0.353011f, 0.444285f, 0.443177f, 0.448608f, 0.419312f, 0.705883f, 0.526260f, 0.631310f, 0.347563f, 0.445672f, 0.446224f, 0.448210f, 0.428481f, 0.702004f, 0.519990f, 0.626158f, 0.342802f, 0.449770f, 0.440666f, 0.453705f, 0.427492f, 0.700510f, 0.533279f, 0.526144f, 0.538202f, 0.443619f, 0.551579f, 0.407162f, 0.442426f, 0.499995f, 0.459987f, 0.525627f, 0.544718f, 0.448060f, 0.544942f, 0.415781f, 0.444198f, 0.516948f, 0.452985f, 0.521784f, 0.523083f, 0.450924f, 0.565538f, 0.392054f, 0.440702f, 0.479094f, 0.468113f, 0.473886f, 0.523677f, 0.555144f, 0.409412f, 0.664285f, 0.620163f, 0.555448f, 0.440947f, 0.459210f, 0.528829f, 0.567231f, 0.413602f, 0.672778f, 0.632467f, 0.565881f, 0.439895f, 0.480238f, 0.525127f, 0.554365f, 0.431656f, 0.658900f, 0.634358f, 0.561181f, 0.419623f, 0.646099f, 0.364754f, 0.442180f, 0.450340f, 0.441320f, 0.412523f, 0.708121f, 0.505939f, 0.641772f, 0.375478f, 0.428502f, 0.454772f, 0.439016f, 0.407773f, 0.718457f, 0.504047f, 0.628271f, 0.345239f, 0.449391f, 0.436208f, 0.448766f, 0.426444f, 0.699202f, 0.528374f, 0.489165f, 0.818278f, 0.467403f, 0.370507f, 0.572406f, 0.417942f, 0.160316f, 0.384139f, 0.497723f, 0.820329f, 0.455669f, 0.373132f, 0.568626f, 0.418602f, 0.164551f, 0.404233f, 0.488972f, 0.813399f, 0.460936f, 0.369774f, 0.580477f, 0.417018f, 0.167442f, 0.381535f, 0.603715f, 0.360599f, 0.371685f, 0.614777f, 0.440767f, 0.425124f, 0.369342f, 0.828101f, 0.584460f, 0.352249f, 0.382191f, 0.613073f, 0.431223f, 0.421802f, 0.389292f, 0.831202f, 0.590574f, 0.355658f, 0.373391f, 0.623741f, 0.432416f, 0.412097f, 0.378312f, 0.829226f, 0.365226f, 0.726961f, 0.549872f, 0.239494f, 0.496434f, 0.668542f, 0.557774f, 0.487281f, 0.361340f, 0.749156f, 0.523408f, 0.240555f, 0.493770f, 0.639516f, 0.552116f, 0.478230f, 0.367118f, 0.740114f, 0.563789f, 0.238852f, 0.498407f, 0.682064f, 0.571327f, 0.496416f, 0.480636f, 0.820258f, 0.464776f, 0.362168f, 0.567256f, 0.417842f, 0.161815f, 0.387104f, 0.486998f, 0.821507f, 0.467362f, 0.377934f, 0.569593f, 0.418367f, 0.156778f, 0.390179f, 0.461449f, 0.823726f, 0.471401f, 0.361646f, 0.563554f, 0.418609f, 0.154999f, 0.379696f, 0.565916f, 0.345293f, 0.392969f, 0.612305f, 0.418858f, 0.416238f, 0.410985f, 0.833515f, 0.552881f, 0.338985f, 0.394863f, 0.597100f, 0.422296f, 0.401025f, 0.427810f, 0.831702f, 0.558983f, 0.339943f, 0.393544f, 0.583418f, 0.432193f, 0.405729f, 0.426401f, 0.830305f, 0.362801f, 0.731181f, 0.546338f, 0.247016f, 0.499389f, 0.662441f, 0.544727f, 0.486631f, 0.355514f, 0.726998f, 0.518056f, 0.249475f, 0.492155f, 0.643678f, 0.531052f, 0.481617f, 0.370308f, 0.743741f, 0.562172f, 0.233361f, 0.498431f, 0.679567f, 0.580747f, 0.494199f, 0.481097f, 0.817782f, 0.461707f, 0.369188f, 0.573825f, 0.419752f, 0.161614f, 0.386708f, 0.472911f, 0.822003f, 0.473412f, 0.375830f, 0.569966f, 0.422158f, 0.149228f, 0.380008f, 0.454662f, 0.818956f, 0.465984f, 0.370169f, 0.575537f, 0.423344f, 0.153818f, 0.375466f, 0.572526f, 0.348075f, 0.380718f, 0.641409f, 0.417012f, 0.407621f, 0.389074f, 0.834251f, 0.581008f, 0.348183f, 0.383659f, 0.608061f, 0.435032f, 0.422240f, 0.393710f, 0.832528f, 0.600530f, 0.360439f, 0.371006f, 0.609018f, 0.441082f, 0.416286f, 0.374920f, 0.825853f, 0.364932f, 0.727047f, 0.540001f, 0.246375f, 0.501524f, 0.656266f, 0.541761f, 0.482865f, 0.360322f, 0.752650f, 0.542120f, 0.239561f, 0.491207f, 0.663446f, 0.566643f, 0.491988f, 0.364532f, 0.737402f, 0.546869f, 0.240953f, 0.497072f, 0.664793f, 0.558528f, 0.488182f, 0.490592f, 0.819727f, 0.468739f, 0.379671f, 0.572959f, 0.422399f, 0.152699f, 0.387445f, 0.462308f, 0.822644f, 0.463886f, 0.374320f, 0.569615f, 0.423238f, 0.152603f, 0.387850f, 0.451896f, 0.818576f, 0.449904f, 0.362889f, 0.573917f, 0.421849f, 0.165145f, 0.390440f, 0.565044f, 0.343397f, 0.395512f, 0.584043f, 0.431062f, 0.417783f, 0.421165f, 0.830938f, 0.583998f, 0.354061f, 0.374016f, 0.633981f, 0.424457f, 0.404069f, 0.381920f, 0.829920f, 0.568315f, 0.347357f, 0.386911f, 0.624227f, 0.418162f, 0.411256f, 0.400332f, 0.832994f, 0.370475f, 0.739716f, 0.551429f, 0.234114f, 0.499500f, 0.665245f, 0.570648f, 0.485298f, 0.364035f, 0.756092f, 0.542251f, 0.238706f, 0.495463f, 0.659518f, 0.567976f, 0.489204f, 0.368942f, 0.756397f, 0.548083f, 0.231854f, 0.496617f, 0.659726f, 0.578330f, 0.484921f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest3D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + TEST(AttentionTest, Attention4DGqaAttnMask) { int batch_size = 2; // Q.shape[0] int q_num_heads = 9; // Q.shape[1] @@ -847,7 +880,7 @@ TEST(AttentionTest, Attention4DGqaAttnMask) { // {4, 6} std::vector m = {0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f}; // {2, 9, 4, 8} - std::vector y = {0.641842f, 0.667534f, 0.339592f, 0.480609f, 0.537525f, 0.340368f, 0.752882f, 0.387601f, 0.686814f, 0.643437f, 0.324983f, 0.468788f, 0.539061f, 0.319610f, 0.754181f, 0.373093f, 0.702380f, 0.693136f, 0.318406f, 0.456714f, 0.540838f, 0.315487f, 0.718291f, 0.311025f, 0.681769f, 0.670603f, 0.329705f, 0.456661f, 0.573902f, 0.337385f, 0.700597f, 0.333385f, 0.508992f, 0.253478f, 0.553979f, 0.466355f, 0.398637f, 0.412493f, 0.495810f, 0.677675f, 0.521609f, 0.278997f, 0.564189f, 0.434417f, 0.448085f, 0.467205f, 0.567856f, 0.664713f, 0.490146f, 0.261321f, 0.560582f, 0.424598f, 0.450318f, 0.467336f, 0.520983f, 0.720798f, 0.516095f, 0.264495f, 0.577940f, 0.475340f, 0.444145f, 0.477909f, 0.485663f, 0.672846f, 0.499389f, 0.402198f, 0.520218f, 0.550550f, 0.481065f, 0.730488f, 0.492535f, 0.392315f, 0.436722f, 0.398514f, 0.497457f, 0.502270f, 0.520993f, 0.730472f, 0.565429f, 0.380282f, 0.461226f, 0.392968f, 0.536035f, 0.505191f, 0.446570f, 0.751253f, 0.478584f, 0.389036f, 0.423738f, 0.443828f, 0.554323f, 0.462607f, 0.476656f, 0.733228f, 0.482219f, 0.411910f, 0.620556f, 0.662948f, 0.349409f, 0.482541f, 0.537250f, 0.351544f, 0.734285f, 0.397172f, 0.689500f, 0.637077f, 0.320710f, 0.470914f, 0.526307f, 0.312878f, 0.775762f, 0.384457f, 0.696615f, 0.681034f, 0.324383f, 0.459632f, 0.539497f, 0.317950f, 0.709736f, 0.320698f, 0.671696f, 0.676830f, 0.332387f, 0.453234f, 0.578648f, 0.345084f, 0.685369f, 0.328092f, 0.520830f, 0.251061f, 0.562824f, 0.469184f, 0.393635f, 0.405203f, 0.493565f, 0.668713f, 0.541328f, 0.282797f, 0.577903f, 0.434065f, 0.444664f, 0.460403f, 0.572628f, 0.646402f, 0.493508f, 0.265246f, 0.572078f, 0.418658f, 0.464491f, 0.483746f, 0.516536f, 0.724847f, 0.503705f, 0.270557f, 0.577678f, 0.465114f, 0.468430f, 0.508402f, 0.489087f, 0.689442f, 0.500042f, 0.410507f, 0.521381f, 0.553244f, 0.459062f, 0.719706f, 0.476571f, 0.395052f, 0.429926f, 0.408857f, 0.507006f, 0.493937f, 0.529878f, 0.728873f, 0.571495f, 0.376256f, 0.453676f, 0.380482f, 0.526100f, 0.496696f, 0.457383f, 0.761933f, 0.486657f, 0.396608f, 0.435748f, 0.432822f, 0.531763f, 0.482255f, 0.477046f, 0.726381f, 0.487480f, 0.416572f, 0.626676f, 0.683736f, 0.340657f, 0.475002f, 0.549981f, 0.353311f, 0.740157f, 0.378827f, 0.681403f, 0.636622f, 0.324593f, 0.469088f, 0.537323f, 0.321344f, 0.762506f, 0.384239f, 0.693108f, 0.683351f, 0.329873f, 0.460504f, 0.555115f, 0.325379f, 0.694659f, 0.316422f, 0.677285f, 0.670298f, 0.329724f, 0.456327f, 0.567533f, 0.337560f, 0.701396f, 0.336191f, 0.515940f, 0.251020f, 0.562035f, 0.442479f, 0.405802f, 0.410828f, 0.519841f, 0.686781f, 0.522057f, 0.285013f, 0.562761f, 0.453472f, 0.451971f, 0.481286f, 0.558322f, 0.649971f, 0.486787f, 0.258011f, 0.557963f, 0.426743f, 0.442028f, 0.457034f, 0.510534f, 0.724945f, 0.498901f, 0.272090f, 0.572650f, 0.467930f, 0.465335f, 0.506181f, 0.484559f, 0.690090f, 0.499525f, 0.398443f, 0.522291f, 0.550620f, 0.465209f, 0.731897f, 0.484389f, 0.388997f, 0.411109f, 0.420719f, 0.523354f, 0.478677f, 0.522513f, 0.723052f, 0.587358f, 0.350775f, 0.450881f, 0.384685f, 0.527140f, 0.502089f, 0.438660f, 0.749234f, 0.493312f, 0.377459f, 0.425945f, 0.432397f, 0.544111f, 0.466484f, 0.488077f, 0.738712f, 0.493642f, 0.412262f, 0.565934f, 0.795554f, 0.527262f, 0.295395f, 0.394937f, 0.326235f, 0.457519f, 0.454071f, 0.511390f, 0.753500f, 0.500815f, 0.303925f, 0.403792f, 0.343750f, 0.516333f, 0.463035f, 0.491925f, 0.753119f, 0.503555f, 0.310489f, 0.373396f, 0.334562f, 0.526486f, 0.470500f, 0.495985f, 0.733211f, 0.532951f, 0.342292f, 0.346065f, 0.355272f, 0.479542f, 0.509107f, 0.379088f, 0.582413f, 0.414383f, 0.571800f, 0.613176f, 0.687631f, 0.185596f, 0.656867f, 0.390452f, 0.532452f, 0.407547f, 0.564799f, 0.606499f, 0.653258f, 0.176547f, 0.698038f, 0.410398f, 0.604586f, 0.442972f, 0.497533f, 0.595085f, 0.732265f, 0.187201f, 0.663169f, 0.448716f, 0.590302f, 0.411879f, 0.518449f, 0.636722f, 0.695827f, 0.154292f, 0.666828f, 0.458054f, 0.608582f, 0.430376f, 0.316371f, 0.547620f, 0.542559f, 0.542043f, 0.556297f, 0.468371f, 0.559154f, 0.465195f, 0.344099f, 0.482571f, 0.527115f, 0.527529f, 0.616254f, 0.494566f, 0.605555f, 0.432360f, 0.382197f, 0.466678f, 0.556031f, 0.459313f, 0.588575f, 0.532798f, 0.597684f, 0.412305f, 0.393400f, 0.462773f, 0.491821f, 0.483189f, 0.593919f, 0.569241f, 0.793791f, 0.532988f, 0.300026f, 0.393843f, 0.327085f, 0.448199f, 0.457416f, 0.493302f, 0.725336f, 0.512066f, 0.327500f, 0.404238f, 0.351704f, 0.507818f, 0.477990f, 0.479548f, 0.756083f, 0.511730f, 0.309729f, 0.366024f, 0.338031f, 0.503335f, 0.472352f, 0.473026f, 0.696816f, 0.543129f, 0.374608f, 0.335432f, 0.360978f, 0.486364f, 0.531799f, 0.380422f, 0.599984f, 0.413640f, 0.564090f, 0.607571f, 0.708289f, 0.187551f, 0.671587f, 0.381058f, 0.550543f, 0.422336f, 0.556663f, 0.599418f, 0.666369f, 0.182365f, 0.678737f, 0.423800f, 0.600509f, 0.437094f, 0.494968f, 0.603340f, 0.727226f, 0.179659f, 0.667114f, 0.464399f, 0.563292f, 0.399716f, 0.529198f, 0.655782f, 0.666396f, 0.143497f, 0.659062f, 0.453034f, 0.596627f, 0.417365f, 0.314318f, 0.554269f, 0.518967f, 0.550250f, 0.556252f, 0.494918f, 0.587774f, 0.467566f, 0.350222f, 0.481994f, 0.538857f, 0.525631f, 0.605359f, 0.497486f, 0.608472f, 0.429145f, 0.384532f, 0.466790f, 0.554752f, 0.457698f, 0.586510f, 0.548577f, 0.604359f, 0.398097f, 0.414429f, 0.448200f, 0.485158f, 0.461395f, 0.593015f, 0.563470f, 0.796184f, 0.532783f, 0.293209f, 0.408910f, 0.327450f, 0.438028f, 0.447011f, 0.493041f, 0.739603f, 0.496957f, 0.311881f, 0.389768f, 0.352503f, 0.530113f, 0.476738f, 0.484897f, 0.752985f, 0.511921f, 0.312174f, 0.370408f, 0.339775f, 0.504061f, 0.473793f, 0.487978f, 0.714687f, 0.538817f, 0.358426f, 0.348908f, 0.355820f, 0.481380f, 0.516214f, 0.370872f, 0.602034f, 0.400225f, 0.611090f, 0.630508f, 0.662527f, 0.162489f, 0.658299f, 0.378734f, 0.537283f, 0.412214f, 0.570032f, 0.601452f, 0.653569f, 0.179932f, 0.693105f, 0.411981f, 0.605715f, 0.448022f, 0.481469f, 0.585099f, 0.748463f, 0.195177f, 0.671915f, 0.442141f, 0.581881f, 0.393362f, 0.555388f, 0.650764f, 0.665937f, 0.141141f, 0.675100f, 0.448606f, 0.605061f, 0.412183f, 0.312673f, 0.559178f, 0.530440f, 0.538275f, 0.546820f, 0.494936f, 0.585982f, 0.469875f, 0.355291f, 0.474437f, 0.542980f, 0.518181f, 0.609491f, 0.522046f, 0.618936f, 0.412090f, 0.410711f, 0.452217f, 0.540284f, 0.444109f, 0.585510f, 0.570158f, 0.614413f, 0.415425f, 0.410005f, 0.441791f, 0.491080f, 0.466021f, 0.595833f}; + std::vector y = {0.641842f, 0.667534f, 0.339592f, 0.480609f, 0.537525f, 0.340368f, 0.752882f, 0.387601f, 0.686814f, 0.643437f, 0.324983f, 0.468788f, 0.539061f, 0.319610f, 0.754181f, 0.373093f, 0.702380f, 0.693136f, 0.318406f, 0.456714f, 0.540838f, 0.315487f, 0.718291f, 0.311025f, 0.681769f, 0.670603f, 0.329705f, 0.456661f, 0.573902f, 0.337385f, 0.700597f, 0.333385f, 0.644472f, 0.666279f, 0.336558f, 0.478260f, 0.534820f, 0.338286f, 0.756443f, 0.387184f, 0.674255f, 0.645509f, 0.327427f, 0.465534f, 0.543598f, 0.328256f, 0.743604f, 0.373978f, 0.689753f, 0.687485f, 0.332246f, 0.457085f, 0.565540f, 0.331625f, 0.677863f, 0.308191f, 0.663033f, 0.669169f, 0.333832f, 0.452516f, 0.576569f, 0.348823f, 0.685447f, 0.338196f, 0.613061f, 0.681689f, 0.345384f, 0.474784f, 0.541609f, 0.357958f, 0.728217f, 0.383408f, 0.680108f, 0.637886f, 0.329455f, 0.469504f, 0.544973f, 0.325193f, 0.745572f, 0.378169f, 0.695405f, 0.687321f, 0.323229f, 0.456101f, 0.553544f, 0.323743f, 0.706057f, 0.314785f, 0.672814f, 0.678842f, 0.323628f, 0.449345f, 0.572724f, 0.342071f, 0.707722f, 0.332714f, 0.512254f, 0.252087f, 0.555774f, 0.456582f, 0.393340f, 0.400567f, 0.501655f, 0.680466f, 0.530775f, 0.288611f, 0.570275f, 0.444357f, 0.454871f, 0.480588f, 0.567893f, 0.645871f, 0.491847f, 0.262209f, 0.561930f, 0.418081f, 0.444398f, 0.456345f, 0.519658f, 0.722565f, 0.523232f, 0.267034f, 0.591659f, 0.459565f, 0.462164f, 0.494775f, 0.497558f, 0.678628f, 0.520830f, 0.251061f, 0.562824f, 0.469184f, 0.393635f, 0.405203f, 0.493565f, 0.668713f, 0.541328f, 0.282797f, 0.577903f, 0.434065f, 0.444664f, 0.460403f, 0.572628f, 0.646402f, 0.493508f, 0.265246f, 0.572078f, 0.418658f, 0.464491f, 0.483746f, 0.516536f, 0.724847f, 0.503705f, 0.270557f, 0.577678f, 0.465114f, 0.468430f, 0.508402f, 0.489087f, 0.689442f, 0.513034f, 0.252153f, 0.561841f, 0.455825f, 0.411518f, 0.424734f, 0.508095f, 0.683202f, 0.537475f, 0.278680f, 0.572605f, 0.449901f, 0.433722f, 0.452424f, 0.554372f, 0.643199f, 0.503808f, 0.259719f, 0.571011f, 0.415224f, 0.442363f, 0.450636f, 0.525191f, 0.716156f, 0.524579f, 0.263175f, 0.588806f, 0.462952f, 0.450874f, 0.480435f, 0.495070f, 0.675950f, 0.503113f, 0.409947f, 0.538941f, 0.550010f, 0.457564f, 0.729741f, 0.472483f, 0.384586f, 0.421666f, 0.416784f, 0.522405f, 0.484472f, 0.519795f, 0.728113f, 0.570887f, 0.363251f, 0.462182f, 0.372738f, 0.510951f, 0.511798f, 0.446353f, 0.754695f, 0.485592f, 0.397135f, 0.421437f, 0.447040f, 0.546262f, 0.462919f, 0.473860f, 0.726421f, 0.479062f, 0.420641f, 0.498228f, 0.402912f, 0.524895f, 0.548811f, 0.462668f, 0.729601f, 0.480759f, 0.390396f, 0.421638f, 0.418506f, 0.518644f, 0.484993f, 0.512452f, 0.724489f, 0.562537f, 0.370564f, 0.461864f, 0.376424f, 0.511195f, 0.510163f, 0.461531f, 0.755198f, 0.491549f, 0.400847f, 0.425338f, 0.456035f, 0.553542f, 0.466468f, 0.482400f, 0.722062f, 0.483532f, 0.415135f, 0.499525f, 0.398443f, 0.522291f, 0.550620f, 0.465209f, 0.731897f, 0.484389f, 0.388997f, 0.411109f, 0.420719f, 0.523354f, 0.478677f, 0.522513f, 0.723052f, 0.587358f, 0.350775f, 0.450881f, 0.384685f, 0.527140f, 0.502089f, 0.438660f, 0.749234f, 0.493312f, 0.377459f, 0.425945f, 0.432397f, 0.544111f, 0.466484f, 0.488077f, 0.738712f, 0.493642f, 0.412262f, 0.565934f, 0.795554f, 0.527262f, 0.295395f, 0.394937f, 0.326235f, 0.457519f, 0.454071f, 0.511390f, 0.753500f, 0.500815f, 0.303925f, 0.403792f, 0.343750f, 0.516333f, 0.463035f, 0.491925f, 0.753119f, 0.503555f, 0.310489f, 0.373396f, 0.334562f, 0.526486f, 0.470500f, 0.495985f, 0.733211f, 0.532951f, 0.342292f, 0.346065f, 0.355272f, 0.479542f, 0.509107f, 0.560779f, 0.795626f, 0.527843f, 0.292198f, 0.403399f, 0.328103f, 0.449548f, 0.449270f, 0.492632f, 0.741337f, 0.501964f, 0.308729f, 0.404425f, 0.353946f, 0.510715f, 0.469292f, 0.498506f, 0.749246f, 0.510938f, 0.317603f, 0.377607f, 0.333171f, 0.516589f, 0.472113f, 0.494030f, 0.738331f, 0.525273f, 0.334388f, 0.351797f, 0.349013f, 0.492978f, 0.499192f, 0.558701f, 0.785575f, 0.541472f, 0.309741f, 0.379566f, 0.336180f, 0.433460f, 0.471779f, 0.500494f, 0.748997f, 0.495158f, 0.302537f, 0.401868f, 0.348977f, 0.525071f, 0.465493f, 0.496427f, 0.763380f, 0.504640f, 0.303037f, 0.375539f, 0.332025f, 0.517142f, 0.464096f, 0.466789f, 0.731320f, 0.529262f, 0.338950f, 0.329005f, 0.361720f, 0.481664f, 0.514476f, 0.356477f, 0.623874f, 0.420893f, 0.592125f, 0.610336f, 0.687956f, 0.174269f, 0.652548f, 0.366057f, 0.567382f, 0.428770f, 0.553226f, 0.582617f, 0.683498f, 0.188604f, 0.695704f, 0.406930f, 0.625170f, 0.441775f, 0.499327f, 0.590722f, 0.740689f, 0.180721f, 0.681143f, 0.430954f, 0.584531f, 0.412720f, 0.532459f, 0.630830f, 0.690216f, 0.161882f, 0.663851f, 0.380422f, 0.599984f, 0.413640f, 0.564090f, 0.607571f, 0.708289f, 0.187551f, 0.671587f, 0.381058f, 0.550543f, 0.422336f, 0.556663f, 0.599418f, 0.666369f, 0.182365f, 0.678737f, 0.423800f, 0.600509f, 0.437094f, 0.494968f, 0.603340f, 0.727226f, 0.179659f, 0.667114f, 0.464399f, 0.563292f, 0.399716f, 0.529198f, 0.655782f, 0.666396f, 0.143497f, 0.659062f, 0.365268f, 0.611770f, 0.413907f, 0.600775f, 0.622849f, 0.667798f, 0.164152f, 0.647839f, 0.377540f, 0.543255f, 0.401769f, 0.588162f, 0.610896f, 0.645976f, 0.172500f, 0.695675f, 0.428349f, 0.590245f, 0.429343f, 0.497694f, 0.606978f, 0.727059f, 0.182826f, 0.671502f, 0.466759f, 0.580932f, 0.396764f, 0.527984f, 0.655065f, 0.677027f, 0.138356f, 0.672848f, 0.431113f, 0.593599f, 0.391529f, 0.327778f, 0.551802f, 0.526872f, 0.512055f, 0.547473f, 0.461591f, 0.564565f, 0.469932f, 0.335454f, 0.493299f, 0.536959f, 0.537769f, 0.611109f, 0.505296f, 0.606927f, 0.414343f, 0.395585f, 0.462205f, 0.538029f, 0.450814f, 0.585742f, 0.550355f, 0.606479f, 0.419783f, 0.396625f, 0.449703f, 0.500831f, 0.464506f, 0.594653f, 0.460993f, 0.609826f, 0.424563f, 0.322395f, 0.546231f, 0.537700f, 0.541169f, 0.555672f, 0.479953f, 0.573210f, 0.449011f, 0.356276f, 0.482535f, 0.523785f, 0.516393f, 0.605958f, 0.473948f, 0.587667f, 0.412118f, 0.378344f, 0.472903f, 0.540161f, 0.445341f, 0.585184f, 0.561693f, 0.609513f, 0.394200f, 0.418769f, 0.444939f, 0.478136f, 0.458334f, 0.591187f, 0.448606f, 0.605061f, 0.412183f, 0.312673f, 0.559178f, 0.530440f, 0.538275f, 0.546820f, 0.494936f, 0.585982f, 0.469875f, 0.355291f, 0.474437f, 0.542980f, 0.518181f, 0.609491f, 0.522046f, 0.618936f, 0.412090f, 0.410711f, 0.452217f, 0.540284f, 0.444109f, 0.585510f, 0.570158f, 0.614413f, 0.415425f, 0.410005f, 0.441791f, 0.491080f, 0.466021f, 0.595833f}; ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); @@ -886,7 +919,7 @@ TEST(AttentionTest, Attention4DGqaWithPastAndPresent) { // {2, 3, 12, 8} std::vector past_value = {0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 0.923878f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f, 0.481102f, 0.251523f, 0.876682f, 0.324273f, 0.924623f, 0.974787f, 0.449862f, 0.227129f, 0.291666f, 0.776334f, 0.273350f, 0.380583f, 0.478576f, 0.575111f, 0.996100f, 0.232210f, 0.353424f, 0.262891f, 0.361113f, 0.100805f, 0.359810f, 0.887865f, 0.298590f, 0.371935f}; // {2, 9, 4, 8} - std::vector y = {0.544462f, 0.617844f, 0.506335f, 0.473482f, 0.606855f, 0.423464f, 0.544771f, 0.450451f, 0.524249f, 0.627160f, 0.497201f, 0.440288f, 0.619110f, 0.437084f, 0.563680f, 0.440037f, 0.516736f, 0.577726f, 0.523888f, 0.493471f, 0.594122f, 0.433401f, 0.585942f, 0.457686f, 0.528512f, 0.604578f, 0.472106f, 0.471486f, 0.600445f, 0.446256f, 0.622393f, 0.435442f, 0.440810f, 0.437705f, 0.476508f, 0.320820f, 0.605191f, 0.640150f, 0.306216f, 0.610947f, 0.485794f, 0.448216f, 0.485639f, 0.323744f, 0.594446f, 0.646597f, 0.321742f, 0.605751f, 0.501858f, 0.445502f, 0.487899f, 0.384660f, 0.597134f, 0.616430f, 0.331401f, 0.566459f, 0.502522f, 0.409965f, 0.526639f, 0.348601f, 0.565200f, 0.586558f, 0.325044f, 0.603422f, 0.450250f, 0.368009f, 0.550911f, 0.460338f, 0.523907f, 0.508816f, 0.575624f, 0.426601f, 0.472310f, 0.372844f, 0.517852f, 0.431688f, 0.551555f, 0.527657f, 0.600578f, 0.473069f, 0.456633f, 0.442035f, 0.539875f, 0.437863f, 0.540202f, 0.499608f, 0.556470f, 0.419831f, 0.463081f, 0.416724f, 0.526389f, 0.458654f, 0.540120f, 0.551554f, 0.569399f, 0.447102f, 0.534296f, 0.597655f, 0.509699f, 0.487167f, 0.607438f, 0.426383f, 0.522794f, 0.458435f, 0.510147f, 0.622761f, 0.501724f, 0.453386f, 0.629671f, 0.434103f, 0.582477f, 0.437681f, 0.520031f, 0.568543f, 0.525216f, 0.490370f, 0.571745f, 0.428629f, 0.572995f, 0.460086f, 0.533607f, 0.614962f, 0.474130f, 0.456345f, 0.576467f, 0.448127f, 0.599211f, 0.432252f, 0.447842f, 0.430169f, 0.480055f, 0.320521f, 0.590915f, 0.627003f, 0.314551f, 0.609320f, 0.499216f, 0.438828f, 0.485519f, 0.322134f, 0.586364f, 0.645824f, 0.326481f, 0.596989f, 0.496362f, 0.442741f, 0.492120f, 0.366111f, 0.601604f, 0.615566f, 0.326354f, 0.567173f, 0.496946f, 0.422179f, 0.533144f, 0.342588f, 0.590482f, 0.605923f, 0.318055f, 0.610401f, 0.452598f, 0.361594f, 0.550919f, 0.455099f, 0.530404f, 0.519313f, 0.588655f, 0.431890f, 0.464325f, 0.389636f, 0.515359f, 0.429087f, 0.540767f, 0.518376f, 0.586627f, 0.471074f, 0.458527f, 0.422216f, 0.537762f, 0.434123f, 0.550956f, 0.507704f, 0.564828f, 0.421548f, 0.463044f, 0.407985f, 0.523093f, 0.473684f, 0.542663f, 0.551348f, 0.576783f, 0.448743f, 0.546208f, 0.621128f, 0.501647f, 0.468191f, 0.612298f, 0.425183f, 0.549241f, 0.447622f, 0.519355f, 0.619636f, 0.487775f, 0.444259f, 0.625749f, 0.430264f, 0.584338f, 0.436887f, 0.521021f, 0.572716f, 0.522539f, 0.486440f, 0.581317f, 0.429079f, 0.579691f, 0.455426f, 0.526431f, 0.604615f, 0.476481f, 0.469814f, 0.588766f, 0.445640f, 0.609160f, 0.437785f, 0.443498f, 0.439338f, 0.487424f, 0.310942f, 0.607341f, 0.630362f, 0.312591f, 0.621999f, 0.483917f, 0.446308f, 0.477454f, 0.331028f, 0.592608f, 0.653297f, 0.322368f, 0.599377f, 0.497354f, 0.443447f, 0.477781f, 0.384002f, 0.591587f, 0.610287f, 0.328537f, 0.567630f, 0.499369f, 0.421961f, 0.536492f, 0.345379f, 0.586450f, 0.600541f, 0.312965f, 0.609437f, 0.451750f, 0.359685f, 0.553321f, 0.464992f, 0.524025f, 0.522507f, 0.582135f, 0.425124f, 0.459696f, 0.394679f, 0.519051f, 0.411226f, 0.539772f, 0.505003f, 0.587681f, 0.469383f, 0.451681f, 0.430062f, 0.541843f, 0.420929f, 0.542240f, 0.487570f, 0.567067f, 0.419708f, 0.456288f, 0.412096f, 0.527592f, 0.467870f, 0.545021f, 0.547842f, 0.573135f, 0.448166f, 0.581220f, 0.559255f, 0.469802f, 0.489935f, 0.557197f, 0.487135f, 0.377325f, 0.425637f, 0.582374f, 0.560738f, 0.425382f, 0.463129f, 0.549939f, 0.481810f, 0.350432f, 0.466049f, 0.593554f, 0.542315f, 0.482597f, 0.496969f, 0.518851f, 0.507807f, 0.366054f, 0.457476f, 0.569468f, 0.565965f, 0.444765f, 0.465404f, 0.515500f, 0.520271f, 0.337845f, 0.448357f, 0.557802f, 0.585925f, 0.426858f, 0.464044f, 0.585251f, 0.557395f, 0.433327f, 0.615342f, 0.534368f, 0.573723f, 0.426393f, 0.518102f, 0.586735f, 0.513129f, 0.371969f, 0.636735f, 0.544166f, 0.588469f, 0.433470f, 0.481894f, 0.595019f, 0.533156f, 0.396519f, 0.608115f, 0.547125f, 0.604473f, 0.441984f, 0.469765f, 0.599107f, 0.561685f, 0.347618f, 0.563457f, 0.507550f, 0.485293f, 0.545846f, 0.408434f, 0.482538f, 0.532314f, 0.498883f, 0.525126f, 0.514603f, 0.471457f, 0.539705f, 0.362410f, 0.490158f, 0.513690f, 0.494170f, 0.496909f, 0.492936f, 0.506153f, 0.565865f, 0.364727f, 0.508899f, 0.516217f, 0.558362f, 0.556920f, 0.530472f, 0.521715f, 0.554673f, 0.363830f, 0.509086f, 0.511590f, 0.552396f, 0.541486f, 0.572145f, 0.551531f, 0.471964f, 0.485188f, 0.555030f, 0.493247f, 0.376875f, 0.429387f, 0.580540f, 0.550944f, 0.435664f, 0.480675f, 0.544997f, 0.488698f, 0.344985f, 0.464878f, 0.593774f, 0.541202f, 0.484834f, 0.497316f, 0.509364f, 0.500045f, 0.357235f, 0.448933f, 0.565242f, 0.546653f, 0.459790f, 0.481954f, 0.514950f, 0.516297f, 0.344285f, 0.454476f, 0.548036f, 0.577907f, 0.427075f, 0.478978f, 0.581563f, 0.553606f, 0.426476f, 0.638442f, 0.498925f, 0.598346f, 0.444106f, 0.536998f, 0.575948f, 0.499260f, 0.371120f, 0.626981f, 0.545949f, 0.586548f, 0.428254f, 0.479753f, 0.596943f, 0.527697f, 0.401418f, 0.613028f, 0.542355f, 0.607063f, 0.447840f, 0.467102f, 0.603496f, 0.549575f, 0.364370f, 0.561534f, 0.507041f, 0.473640f, 0.547768f, 0.413960f, 0.490513f, 0.534377f, 0.497277f, 0.517772f, 0.531394f, 0.489105f, 0.531671f, 0.369343f, 0.486462f, 0.501787f, 0.494220f, 0.493498f, 0.485968f, 0.510301f, 0.559766f, 0.361474f, 0.507888f, 0.518858f, 0.564300f, 0.561990f, 0.537984f, 0.527982f, 0.539571f, 0.366920f, 0.498313f, 0.505709f, 0.538027f, 0.541246f, 0.585733f, 0.565800f, 0.441346f, 0.476255f, 0.556453f, 0.497693f, 0.363246f, 0.426799f, 0.578484f, 0.556489f, 0.436699f, 0.481177f, 0.549473f, 0.484153f, 0.355910f, 0.462010f, 0.590951f, 0.542803f, 0.470954f, 0.488994f, 0.512707f, 0.511876f, 0.358555f, 0.455953f, 0.559449f, 0.546003f, 0.462900f, 0.471080f, 0.517298f, 0.519225f, 0.345016f, 0.449149f, 0.526624f, 0.606761f, 0.427660f, 0.480775f, 0.577420f, 0.538850f, 0.426959f, 0.625509f, 0.530502f, 0.585784f, 0.432234f, 0.516800f, 0.584937f, 0.514154f, 0.373726f, 0.623740f, 0.550470f, 0.585577f, 0.436483f, 0.474799f, 0.594100f, 0.540052f, 0.402520f, 0.607686f, 0.537556f, 0.609680f, 0.439490f, 0.477886f, 0.602656f, 0.542957f, 0.350394f, 0.574553f, 0.506900f, 0.488792f, 0.539037f, 0.403028f, 0.494093f, 0.534739f, 0.494292f, 0.511628f, 0.528192f, 0.480037f, 0.546429f, 0.375120f, 0.484828f, 0.505006f, 0.495786f, 0.497935f, 0.502174f, 0.514122f, 0.541314f, 0.369540f, 0.493985f, 0.508263f, 0.550415f, 0.556157f, 0.543269f, 0.529970f, 0.562027f, 0.376526f, 0.499704f, 0.508621f, 0.536068f, 0.545993f}; + std::vector y = {0.544462f, 0.617844f, 0.506335f, 0.473482f, 0.606855f, 0.423464f, 0.544771f, 0.450451f, 0.524249f, 0.627160f, 0.497201f, 0.440288f, 0.619110f, 0.437084f, 0.563680f, 0.440037f, 0.516736f, 0.577726f, 0.523888f, 0.493471f, 0.594122f, 0.433401f, 0.585942f, 0.457686f, 0.528512f, 0.604578f, 0.472106f, 0.471486f, 0.600445f, 0.446256f, 0.622393f, 0.435442f, 0.546090f, 0.618047f, 0.504325f, 0.472246f, 0.609686f, 0.422467f, 0.546964f, 0.451166f, 0.519404f, 0.617868f, 0.491984f, 0.445771f, 0.633094f, 0.436822f, 0.559753f, 0.447209f, 0.519860f, 0.574899f, 0.525759f, 0.489339f, 0.586803f, 0.436452f, 0.577737f, 0.453299f, 0.532473f, 0.609446f, 0.471758f, 0.455772f, 0.573504f, 0.445466f, 0.602573f, 0.433307f, 0.538062f, 0.604199f, 0.500302f, 0.479569f, 0.614174f, 0.429231f, 0.522434f, 0.459369f, 0.528422f, 0.620683f, 0.485333f, 0.435606f, 0.616579f, 0.432233f, 0.565856f, 0.440093f, 0.525356f, 0.580613f, 0.529584f, 0.483095f, 0.583395f, 0.433491f, 0.593043f, 0.451879f, 0.540119f, 0.622995f, 0.472122f, 0.449888f, 0.586202f, 0.447435f, 0.611846f, 0.434879f, 0.449905f, 0.430732f, 0.474834f, 0.321674f, 0.590495f, 0.626300f, 0.319127f, 0.606006f, 0.492763f, 0.445330f, 0.490219f, 0.319940f, 0.588298f, 0.643644f, 0.317760f, 0.596360f, 0.507993f, 0.440004f, 0.490555f, 0.378128f, 0.588227f, 0.604974f, 0.329202f, 0.561987f, 0.511572f, 0.403440f, 0.542761f, 0.331792f, 0.568397f, 0.583366f, 0.333122f, 0.608456f, 0.447842f, 0.430169f, 0.480055f, 0.320521f, 0.590915f, 0.627003f, 0.314551f, 0.609320f, 0.499216f, 0.438828f, 0.485519f, 0.322134f, 0.586364f, 0.645824f, 0.326481f, 0.596989f, 0.496362f, 0.442741f, 0.492120f, 0.366111f, 0.601604f, 0.615566f, 0.326354f, 0.567173f, 0.496946f, 0.422179f, 0.533144f, 0.342588f, 0.590482f, 0.605923f, 0.318055f, 0.610401f, 0.441356f, 0.431701f, 0.488343f, 0.311828f, 0.606159f, 0.632821f, 0.317863f, 0.629084f, 0.495613f, 0.441177f, 0.473223f, 0.335484f, 0.579139f, 0.646878f, 0.321269f, 0.595437f, 0.504999f, 0.443626f, 0.498154f, 0.369326f, 0.588410f, 0.600189f, 0.322347f, 0.562676f, 0.508419f, 0.405342f, 0.533092f, 0.335876f, 0.570568f, 0.589600f, 0.330741f, 0.609168f, 0.456943f, 0.365603f, 0.555030f, 0.454344f, 0.526263f, 0.519062f, 0.578652f, 0.425453f, 0.464039f, 0.391848f, 0.518985f, 0.419419f, 0.541410f, 0.514459f, 0.586459f, 0.470210f, 0.460338f, 0.408599f, 0.539512f, 0.446249f, 0.551945f, 0.511356f, 0.575513f, 0.424325f, 0.452212f, 0.418205f, 0.525148f, 0.459799f, 0.536327f, 0.541881f, 0.571451f, 0.452969f, 0.454154f, 0.354641f, 0.553889f, 0.451027f, 0.536270f, 0.521832f, 0.590756f, 0.429859f, 0.459101f, 0.394962f, 0.512076f, 0.419296f, 0.535702f, 0.516757f, 0.585606f, 0.478117f, 0.458365f, 0.422929f, 0.531943f, 0.447581f, 0.546387f, 0.511705f, 0.564350f, 0.425332f, 0.463274f, 0.429223f, 0.525922f, 0.452328f, 0.539095f, 0.534372f, 0.563738f, 0.449120f, 0.451750f, 0.359685f, 0.553321f, 0.464992f, 0.524025f, 0.522507f, 0.582135f, 0.425124f, 0.459696f, 0.394679f, 0.519051f, 0.411226f, 0.539772f, 0.505003f, 0.587681f, 0.469383f, 0.451681f, 0.430062f, 0.541843f, 0.420929f, 0.542240f, 0.487570f, 0.567067f, 0.419708f, 0.456288f, 0.412096f, 0.527592f, 0.467870f, 0.545021f, 0.547842f, 0.573135f, 0.448166f, 0.581220f, 0.559255f, 0.469802f, 0.489935f, 0.557197f, 0.487135f, 0.377325f, 0.425637f, 0.582374f, 0.560738f, 0.425382f, 0.463129f, 0.549939f, 0.481810f, 0.350432f, 0.466049f, 0.593554f, 0.542315f, 0.482597f, 0.496969f, 0.518851f, 0.507807f, 0.366054f, 0.457476f, 0.569468f, 0.565965f, 0.444765f, 0.465404f, 0.515500f, 0.520271f, 0.337845f, 0.448357f, 0.586343f, 0.566462f, 0.444339f, 0.481474f, 0.557556f, 0.495837f, 0.368487f, 0.425850f, 0.580159f, 0.565990f, 0.400882f, 0.462578f, 0.551037f, 0.497924f, 0.338502f, 0.468483f, 0.592753f, 0.536897f, 0.481975f, 0.489485f, 0.519290f, 0.509298f, 0.366838f, 0.461538f, 0.567139f, 0.559419f, 0.458050f, 0.468739f, 0.514875f, 0.512271f, 0.346335f, 0.449357f, 0.583058f, 0.557532f, 0.454426f, 0.492673f, 0.551748f, 0.496414f, 0.364023f, 0.430048f, 0.579431f, 0.565100f, 0.420761f, 0.466297f, 0.551315f, 0.487418f, 0.348148f, 0.461136f, 0.585687f, 0.535194f, 0.485465f, 0.488622f, 0.513327f, 0.508844f, 0.368049f, 0.455823f, 0.554855f, 0.560589f, 0.456398f, 0.477641f, 0.507017f, 0.518069f, 0.338229f, 0.444624f, 0.500594f, 0.616610f, 0.439949f, 0.495561f, 0.569213f, 0.540425f, 0.422667f, 0.627919f, 0.514283f, 0.584446f, 0.441141f, 0.528331f, 0.577047f, 0.508969f, 0.372295f, 0.646734f, 0.536256f, 0.591823f, 0.428652f, 0.485852f, 0.592863f, 0.525360f, 0.399985f, 0.623408f, 0.552463f, 0.606841f, 0.448560f, 0.466321f, 0.600628f, 0.566464f, 0.356481f, 0.551351f, 0.548036f, 0.577907f, 0.427075f, 0.478978f, 0.581563f, 0.553606f, 0.426476f, 0.638442f, 0.498925f, 0.598346f, 0.444106f, 0.536998f, 0.575948f, 0.499260f, 0.371120f, 0.626981f, 0.545949f, 0.586548f, 0.428254f, 0.479753f, 0.596943f, 0.527697f, 0.401418f, 0.613028f, 0.542355f, 0.607063f, 0.447840f, 0.467102f, 0.603496f, 0.549575f, 0.364370f, 0.561534f, 0.532692f, 0.601573f, 0.425963f, 0.477495f, 0.573122f, 0.544325f, 0.422438f, 0.629794f, 0.512145f, 0.593241f, 0.436187f, 0.532146f, 0.582008f, 0.499410f, 0.366728f, 0.631277f, 0.550263f, 0.590346f, 0.430967f, 0.477189f, 0.600022f, 0.528313f, 0.406504f, 0.603355f, 0.537075f, 0.605495f, 0.437735f, 0.474413f, 0.601068f, 0.542204f, 0.348555f, 0.581430f, 0.499619f, 0.480920f, 0.536032f, 0.413380f, 0.478027f, 0.524393f, 0.490201f, 0.530954f, 0.517442f, 0.475326f, 0.541763f, 0.366450f, 0.498398f, 0.509411f, 0.503732f, 0.490468f, 0.488084f, 0.505941f, 0.554614f, 0.371690f, 0.503635f, 0.510325f, 0.557424f, 0.564303f, 0.534730f, 0.536543f, 0.563296f, 0.362277f, 0.498957f, 0.508357f, 0.538003f, 0.554638f, 0.514150f, 0.481676f, 0.543535f, 0.414778f, 0.478296f, 0.529467f, 0.496600f, 0.522262f, 0.522734f, 0.480361f, 0.534209f, 0.379264f, 0.485836f, 0.500082f, 0.498644f, 0.501901f, 0.474729f, 0.503193f, 0.560206f, 0.362595f, 0.515144f, 0.512647f, 0.557224f, 0.567242f, 0.539217f, 0.533273f, 0.538641f, 0.373064f, 0.495733f, 0.499786f, 0.532998f, 0.547731f, 0.506900f, 0.488792f, 0.539037f, 0.403028f, 0.494093f, 0.534739f, 0.494292f, 0.511628f, 0.528192f, 0.480037f, 0.546429f, 0.375120f, 0.484828f, 0.505006f, 0.495786f, 0.497935f, 0.502174f, 0.514122f, 0.541314f, 0.369540f, 0.493985f, 0.508263f, 0.550415f, 0.556157f, 0.543269f, 0.529970f, 0.562027f, 0.376526f, 0.499704f, 0.508621f, 0.536068f, 0.545993f}; // {2, 3, 18, 8} std::vector present_key = {0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f}; // {2, 3, 18, 8} @@ -1116,7 +1149,7 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) { std::vector k = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; // {2, 3, 6, 4} std::vector v = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; - // {2, 1, 4, 13} + // {2, 1, 4, 18} std::vector m = {-0.454545f, -0.444930f, -0.435315f, -0.425699f, -0.416084f, -0.406469f, -0.396853f, -0.387238f, -0.377622f, -0.368007f, -0.358392f, -0.348776f, -0.339161f, -0.329545f, -0.319930f, -0.310315f, -0.300699f, -0.291084f, -0.281469f, -0.271853f, -0.262238f, -0.252622f, -0.243007f, -0.233392f, -0.223776f, -0.214161f, -0.204545f, -0.194930f, -0.185315f, -0.175699f, -0.166084f, -0.156469f, -0.146853f, -0.137238f, -0.127622f, -0.118007f, -0.108392f, -0.098776f, -0.089161f, -0.079545f, -0.069930f, -0.060315f, -0.050699f, -0.041084f, -0.031469f, -0.021853f, -0.012238f, -0.002622f, 0.006993f, 0.016608f, 0.026224f, 0.035839f, 0.045455f, 0.055070f, 0.064685f, 0.074301f, 0.083916f, 0.093531f, 0.103147f, 0.112762f, 0.122378f, 0.131993f, 0.141608f, 0.151224f, 0.160839f, 0.170455f, 0.180070f, 0.189685f, 0.199301f, 0.208916f, 0.218531f, 0.228147f, 0.237762f, 0.247378f, 0.256993f, 0.266608f, 0.276224f, 0.285839f, 0.295455f, 0.305070f, 0.314685f, 0.324301f, 0.333916f, 0.343531f, 0.353147f, 0.362762f, 0.372378f, 0.381993f, 0.391608f, 0.401224f, 0.410839f, 0.420455f, 0.430070f, 0.439685f, 0.449301f, 0.458916f, 0.468531f, 0.478147f, 0.487762f, 0.497378f, 0.506993f, 0.516608f, 0.526224f, 0.535839f}; // {2, 3, 12, 4} std::vector past_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; @@ -1132,7 +1165,7 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) { // {2, 3, 4, 4} std::vector y = {-0.393782f, -0.387694f, -0.381606f, -0.375519f, -0.397492f, -0.391304f, -0.385116f, -0.378928f, -0.397474f, -0.391207f, -0.384941f, -0.378674f, -0.394849f, -0.388519f, -0.382190f, -0.375860f, -0.226271f, -0.220186f, -0.214101f, -0.208016f, -0.230042f, -0.223857f, -0.217672f, -0.211488f, -0.230104f, -0.223841f, -0.217577f, -0.211314f, -0.227525f, -0.221197f, -0.214870f, -0.208543f, -0.058757f, -0.052674f, -0.046592f, -0.040510f, -0.062587f, -0.056406f, -0.050224f, -0.044042f, -0.062730f, -0.056470f, -0.050209f, -0.043949f, -0.060198f, -0.053873f, -0.047548f, -0.041223f, 0.108760f, 0.114840f, 0.120919f, 0.126999f, 0.104873f, 0.111051f, 0.117229f, 0.123408f, 0.104648f, 0.110906f, 0.117163f, 0.123421f, 0.107131f, 0.113454f, 0.119777f, 0.126099f, 0.276279f, 0.282356f, 0.288433f, 0.294510f, 0.272337f, 0.278512f, 0.284687f, 0.290862f, 0.272031f, 0.278286f, 0.284540f, 0.290794f, 0.274463f, 0.280783f, 0.287104f, 0.293424f, 0.443800f, 0.449874f, 0.455949f, 0.462023f, 0.439807f, 0.445978f, 0.452150f, 0.458321f, 0.439418f, 0.445669f, 0.451921f, 0.458172f, 0.441797f, 0.448115f, 0.454433f, 0.460751f}; - // {2, 3, 13, 4} + // {2, 3, 12, 4} std::vector present_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; // {2, 3, 18, 8} std::vector present_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; @@ -1151,28 +1184,28 @@ TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) { ); } -TEST(AttentionTest, Attention4DWithMask4DPastAndPresentQkMatmul) { - int batch_size = 2; // Q.shape[0] - int q_num_heads = 3; // Q.shape[1] - int q_sequence_length = 4; // Q.shape[2] - int head_size = 4; // Q.shape[3] - int kv_sequence_length = 6; // K.shape[2] and V.shape[2] - int kv_num_heads = 3; // K.shape[1] and V.shape[1] - int v_head_size = 4; // V.shape[3] - int past_sequence_length = 7; // past_key.shape[2] and past_value.shape[2] +TEST(AttentionTest, TestAttention4DWithPastAndPresentQkMatmulBias4DMaskCausal) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2] - // {2, 3, 4, 4} - std::vector q = {-0.454545f, -0.444129f, -0.433712f, -0.423295f, -0.412879f, -0.402462f, -0.392045f, -0.381629f, -0.371212f, -0.360795f, -0.350379f, -0.339962f, -0.329545f, -0.319129f, -0.308712f, -0.298295f, -0.287879f, -0.277462f, -0.267045f, -0.256629f, -0.246212f, -0.235795f, -0.225379f, -0.214962f, -0.204545f, -0.194129f, -0.183712f, -0.173295f, -0.162879f, -0.152462f, -0.142045f, -0.131629f, -0.121212f, -0.110795f, -0.100379f, -0.089962f, -0.079545f, -0.069129f, -0.058712f, -0.048295f, -0.037879f, -0.027462f, -0.017045f, -0.006629f, 0.003788f, 0.014205f, 0.024621f, 0.035038f, 0.045455f, 0.055871f, 0.066288f, 0.076705f, 0.087121f, 0.097538f, 0.107955f, 0.118371f, 0.128788f, 0.139205f, 0.149621f, 0.160038f, 0.170455f, 0.180871f, 0.191288f, 0.201705f, 0.212121f, 0.222538f, 0.232955f, 0.243371f, 0.253788f, 0.264205f, 0.274621f, 0.285038f, 0.295455f, 0.305871f, 0.316288f, 0.326705f, 0.337121f, 0.347538f, 0.357955f, 0.368371f, 0.378788f, 0.389205f, 0.399621f, 0.410038f, 0.420455f, 0.430871f, 0.441288f, 0.451705f, 0.462121f, 0.472538f, 0.482955f, 0.493371f, 0.503788f, 0.514205f, 0.524621f, 0.535038f}; - // {2, 3, 6, 4} - std::vector k = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 4, 8} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + // {2, 3, 6, 8} + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; // {2, 3, 6, 4} - std::vector v = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; - // {2, 3, 4, 13} - std::vector m = {-0.454545f, -0.451340f, -0.448135f, -0.444930f, -0.441725f, -0.438520f, -0.435315f, -0.432110f, -0.428904f, -0.425699f, -0.422494f, -0.419289f, -0.416084f, -0.412879f, -0.409674f, -0.406469f, -0.403263f, -0.400058f, -0.396853f, -0.393648f, -0.390443f, -0.387238f, -0.384033f, -0.380828f, -0.377622f, -0.374417f, -0.371212f, -0.368007f, -0.364802f, -0.361597f, -0.358392f, -0.355186f, -0.351981f, -0.348776f, -0.345571f, -0.342366f, -0.339161f, -0.335956f, -0.332751f, -0.329545f, -0.326340f, -0.323135f, -0.319930f, -0.316725f, -0.313520f, -0.310315f, -0.307110f, -0.303904f, -0.300699f, -0.297494f, -0.294289f, -0.291084f, -0.287879f, -0.284674f, -0.281469f, -0.278263f, -0.275058f, -0.271853f, -0.268648f, -0.265443f, -0.262238f, -0.259033f, -0.255828f, -0.252622f, -0.249417f, -0.246212f, -0.243007f, -0.239802f, -0.236597f, -0.233392f, -0.230186f, -0.226981f, -0.223776f, -0.220571f, -0.217366f, -0.214161f, -0.210956f, -0.207751f, -0.204545f, -0.201340f, -0.198135f, -0.194930f, -0.191725f, -0.188520f, -0.185315f, -0.182110f, -0.178904f, -0.175699f, -0.172494f, -0.169289f, -0.166084f, -0.162879f, -0.159674f, -0.156469f, -0.153263f, -0.150058f, -0.146853f, -0.143648f, -0.140443f, -0.137238f, -0.134033f, -0.130828f, -0.127622f, -0.124417f, -0.121212f, -0.118007f, -0.114802f, -0.111597f, -0.108392f, -0.105186f, -0.101981f, -0.098776f, -0.095571f, -0.092366f, -0.089161f, -0.085956f, -0.082751f, -0.079545f, -0.076340f, -0.073135f, -0.069930f, -0.066725f, -0.063520f, -0.060315f, -0.057110f, -0.053904f, -0.050699f, -0.047494f, -0.044289f, -0.041084f, -0.037879f, -0.034674f, -0.031469f, -0.028263f, -0.025058f, -0.021853f, -0.018648f, -0.015443f, -0.012238f, -0.009033f, -0.005828f, -0.002622f, 0.000583f, 0.003788f, 0.006993f, 0.010198f, 0.013403f, 0.016608f, 0.019814f, 0.023019f, 0.026224f, 0.029429f, 0.032634f, 0.035839f, 0.039044f, 0.042249f, 0.045455f, 0.048660f, 0.051865f, 0.055070f, 0.058275f, 0.061480f, 0.064685f, 0.067890f, 0.071096f, 0.074301f, 0.077506f, 0.080711f, 0.083916f, 0.087121f, 0.090326f, 0.093531f, 0.096737f, 0.099942f, 0.103147f, 0.106352f, 0.109557f, 0.112762f, 0.115967f, 0.119172f, 0.122378f, 0.125583f, 0.128788f, 0.131993f, 0.135198f, 0.138403f, 0.141608f, 0.144814f, 0.148019f, 0.151224f, 0.154429f, 0.157634f, 0.160839f, 0.164044f, 0.167249f, 0.170455f, 0.173660f, 0.176865f, 0.180070f, 0.183275f, 0.186480f, 0.189685f, 0.192890f, 0.196096f, 0.199301f, 0.202506f, 0.205711f, 0.208916f, 0.212121f, 0.215326f, 0.218531f, 0.221737f, 0.224942f, 0.228147f, 0.231352f, 0.234557f, 0.237762f, 0.240967f, 0.244172f, 0.247378f, 0.250583f, 0.253788f, 0.256993f, 0.260198f, 0.263403f, 0.266608f, 0.269814f, 0.273019f, 0.276224f, 0.279429f, 0.282634f, 0.285839f, 0.289044f, 0.292249f, 0.295455f, 0.298660f, 0.301865f, 0.305070f, 0.308275f, 0.311480f, 0.314685f, 0.317890f, 0.321096f, 0.324301f, 0.327506f, 0.330711f, 0.333916f, 0.337121f, 0.340326f, 0.343531f, 0.346737f, 0.349942f, 0.353147f, 0.356352f, 0.359557f, 0.362762f, 0.365967f, 0.369172f, 0.372378f, 0.375583f, 0.378788f, 0.381993f, 0.385198f, 0.388403f, 0.391608f, 0.394814f, 0.398019f, 0.401224f, 0.404429f, 0.407634f, 0.410839f, 0.414044f, 0.417249f, 0.420455f, 0.423660f, 0.426865f, 0.430070f, 0.433275f, 0.436480f, 0.439685f, 0.442890f, 0.446096f, 0.449301f, 0.452506f, 0.455711f, 0.458916f, 0.462121f, 0.465326f, 0.468531f, 0.471737f, 0.474942f, 0.478147f, 0.481352f, 0.484557f, 0.487762f, 0.490967f, 0.494172f, 0.497378f, 0.500583f, 0.503788f, 0.506993f, 0.510198f, 0.513403f, 0.516608f, 0.519814f, 0.523019f, 0.526224f, 0.529429f, 0.532634f, 0.535839f, 0.539044f, 0.542249f}; - // {2, 3, 12, 4} - std::vector past_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; - // {2, 3, 12, 4} - std::vector past_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + // {2, 3, 4, 18} + std::vector m = {0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f}; + // {2, 3, 12, 8} + std::vector past_key = {0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f}; + // {2, 3, 12, 8} + std::vector past_value = {0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 0.923878f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f}; ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); @@ -1181,14 +1214,15 @@ TEST(AttentionTest, Attention4DWithMask4DPastAndPresentQkMatmul) { ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); - // {2, 3, 4, 4} - std::vector y = {-0.385742f, -0.379327f, -0.372911f, -0.366496f, -0.385554f, -0.379139f, -0.372723f, -0.366308f, -0.385366f, -0.378950f, -0.372535f, -0.366119f, -0.385178f, -0.378762f, -0.372347f, -0.365931f, -0.218323f, -0.211907f, -0.205492f, -0.199076f, -0.218134f, -0.211719f, -0.205304f, -0.198888f, -0.217946f, -0.211531f, -0.205115f, -0.198700f, -0.217758f, -0.211342f, -0.204927f, -0.198512f, -0.050903f, -0.044487f, -0.038072f, -0.031657f, -0.050715f, -0.044299f, -0.037884f, -0.031468f, -0.050526f, -0.044111f, -0.037695f, -0.031280f, -0.050338f, -0.043922f, -0.037507f, -0.031092f, 0.116517f, 0.122932f, 0.129348f, 0.135763f, 0.116705f, 0.123121f, 0.129536f, 0.135952f, 0.116894f, 0.123309f, 0.129724f, 0.136140f, 0.117082f, 0.123497f, 0.129913f, 0.136328f, 0.283937f, 0.290352f, 0.296768f, 0.303183f, 0.284125f, 0.290540f, 0.296956f, 0.303371f, 0.284313f, 0.290729f, 0.297144f, 0.303559f, 0.284501f, 0.290917f, 0.297332f, 0.303747f, 0.451356f, 0.457772f, 0.464187f, 0.470602f, 0.451544f, 0.457960f, 0.464375f, 0.470790f, 0.451732f, 0.458148f, 0.464563f, 0.470978f, 0.451920f, 0.458336f, 0.464751f, 0.471166f}; - // {2, 3, 13, 4} - std::vector present_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 4, 8} + std::vector y = {0.431265f, 0.558994f, 0.492979f, 0.535281f, 0.609591f, 0.466737f, 0.692090f, 0.412591f, 0.468058f, 0.623595f, 0.468127f, 0.483497f, 0.577278f, 0.512802f, 0.639767f, 0.427679f, 0.422704f, 0.532822f, 0.449594f, 0.560548f, 0.608427f, 0.476187f, 0.695694f, 0.425740f, 0.447270f, 0.528366f, 0.506840f, 0.501836f, 0.547248f, 0.457381f, 0.583533f, 0.471707f, 0.414727f, 0.517263f, 0.342732f, 0.363543f, 0.677046f, 0.664675f, 0.271455f, 0.479982f, 0.438313f, 0.537211f, 0.342649f, 0.402609f, 0.660072f, 0.631518f, 0.266481f, 0.501402f, 0.458457f, 0.519536f, 0.434125f, 0.443849f, 0.614893f, 0.636419f, 0.310940f, 0.497030f, 0.433312f, 0.522457f, 0.417441f, 0.405432f, 0.617509f, 0.592985f, 0.310558f, 0.490073f, 0.499459f, 0.430465f, 0.601451f, 0.404111f, 0.502848f, 0.415186f, 0.440655f, 0.478187f, 0.536562f, 0.376663f, 0.527310f, 0.363608f, 0.443744f, 0.476396f, 0.453812f, 0.498910f, 0.483497f, 0.433209f, 0.541590f, 0.366029f, 0.513807f, 0.477506f, 0.492110f, 0.527910f, 0.471458f, 0.419741f, 0.536529f, 0.407806f, 0.512188f, 0.467064f, 0.496260f, 0.519270f, 0.683252f, 0.426643f, 0.425275f, 0.457410f, 0.611686f, 0.591234f, 0.394568f, 0.446171f, 0.637484f, 0.426481f, 0.346779f, 0.466867f, 0.585075f, 0.558250f, 0.387627f, 0.507636f, 0.658808f, 0.467355f, 0.496107f, 0.556756f, 0.513309f, 0.520842f, 0.411220f, 0.451704f, 0.661693f, 0.463543f, 0.421647f, 0.486068f, 0.552701f, 0.484705f, 0.412050f, 0.449818f, 0.637941f, 0.564086f, 0.543446f, 0.530844f, 0.627347f, 0.520370f, 0.389963f, 0.520054f, 0.574335f, 0.604007f, 0.468559f, 0.473710f, 0.559229f, 0.504183f, 0.453090f, 0.564618f, 0.568083f, 0.541180f, 0.491888f, 0.485970f, 0.564150f, 0.506989f, 0.421426f, 0.544228f, 0.616426f, 0.467555f, 0.529898f, 0.487372f, 0.574411f, 0.471969f, 0.388121f, 0.485012f, 0.533687f, 0.523210f, 0.560021f, 0.490233f, 0.443149f, 0.420163f, 0.538998f, 0.606965f, 0.586616f, 0.478324f, 0.572142f, 0.517933f, 0.441955f, 0.411890f, 0.550505f, 0.604577f, 0.541173f, 0.473423f, 0.505749f, 0.473388f, 0.389025f, 0.498730f, 0.507861f, 0.584389f, 0.519963f, 0.461030f, 0.576878f, 0.471281f, 0.461238f, 0.496673f, 0.509573f, 0.568405f}; // {2, 3, 18, 8} - std::vector present_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; - // {2, 3, 4, 13} - std::vector qk_matmul = {0.391336f, 0.370435f, 0.349534f, 0.328633f, 0.307732f, 0.286831f, 0.265930f, 0.390055f, 0.365671f, 0.341286f, 0.316902f, 0.292517f, 0.268133f, 0.354201f, 0.335284f, 0.316367f, 0.297450f, 0.278534f, 0.259617f, 0.240700f, 0.353045f, 0.330975f, 0.308905f, 0.286836f, 0.264766f, 0.242696f, 0.317066f, 0.300134f, 0.283201f, 0.266268f, 0.249335f, 0.232403f, 0.215470f, 0.316034f, 0.296279f, 0.276524f, 0.256769f, 0.237014f, 0.217260f, 0.279932f, 0.264983f, 0.250034f, 0.235086f, 0.220137f, 0.205189f, 0.190240f, 0.279023f, 0.261583f, 0.244143f, 0.226703f, 0.209263f, 0.191823f, 0.152046f, 0.139081f, 0.126117f, 0.113152f, 0.100188f, 0.087223f, 0.074259f, 0.151261f, 0.136136f, 0.121011f, 0.105885f, 0.090760f, 0.075635f, 0.128800f, 0.117819f, 0.106839f, 0.095859f, 0.084878f, 0.073898f, 0.062918f, 0.128139f, 0.115329f, 0.102518f, 0.089708f, 0.076898f, 0.064087f, 0.105554f, 0.096558f, 0.087561f, 0.078565f, 0.069569f, 0.060573f, 0.051577f, 0.105017f, 0.094522f, 0.084026f, 0.073531f, 0.063035f, 0.052539f, 0.082308f, 0.075296f, 0.068284f, 0.061272f, 0.054260f, 0.047248f, 0.040235f, 0.081896f, 0.073715f, 0.065534f, 0.057353f, 0.049172f, 0.040992f, 0.023866f, 0.018838f, 0.013810f, 0.008783f, 0.003755f, -0.001273f, -0.006301f, 0.023578f, 0.017712f, 0.011846f, 0.005980f, 0.000114f, -0.005752f, 0.014509f, 0.011466f, 0.008422f, 0.005378f, 0.002334f, -0.000710f, -0.003754f, 0.014345f, 0.010794f, 0.007243f, 0.003692f, 0.000140f, -0.003411f, 0.005152f, 0.004093f, 0.003033f, 0.001973f, 0.000914f, -0.000146f, -0.001206f, 0.005112f, 0.003876f, 0.002639f, 0.001403f, 0.000167f, -0.001070f, -0.004204f, -0.003280f, -0.002356f, -0.001431f, -0.000507f, 0.000418f, 0.001342f, -0.004121f, -0.003042f, -0.001964f, -0.000885f, 0.000193f, 0.001272f, 0.006798f, 0.009707f, 0.012616f, 0.015524f, 0.018433f, 0.021341f, 0.024250f, 0.007006f, 0.010399f, 0.013793f, 0.017186f, 0.020579f, 0.023973f, 0.011330f, 0.016223f, 0.021116f, 0.026008f, 0.030901f, 0.035794f, 0.040686f, 0.011662f, 0.017370f, 0.023078f, 0.028786f, 0.034494f, 0.040203f, 0.015862f, 0.022739f, 0.029616f, 0.036493f, 0.043369f, 0.050246f, 0.057123f, 0.016318f, 0.024341f, 0.032364f, 0.040387f, 0.048410f, 0.056433f, 0.020394f, 0.029255f, 0.038116f, 0.046977f, 0.055838f, 0.064699f, 0.073560f, 0.020974f, 0.031312f, 0.041649f, 0.051987f, 0.062325f, 0.072663f, 0.100842f, 0.111687f, 0.122532f, 0.133377f, 0.144222f, 0.155067f, 0.165912f, 0.101545f, 0.114198f, 0.126850f, 0.139503f, 0.152155f, 0.164808f, 0.119262f, 0.132092f, 0.144921f, 0.157750f, 0.170579f, 0.183408f, 0.196237f, 0.120090f, 0.135057f, 0.150025f, 0.164992f, 0.179960f, 0.194927f, 0.137683f, 0.152496f, 0.167310f, 0.182123f, 0.196936f, 0.211750f, 0.226563f, 0.138635f, 0.155917f, 0.173199f, 0.190481f, 0.207764f, 0.225046f, 0.156104f, 0.172901f, 0.189699f, 0.206496f, 0.223294f, 0.240091f, 0.256889f, 0.157180f, 0.176777f, 0.196374f, 0.215971f, 0.235568f, 0.255165f, 0.305996f, 0.324777f, 0.343559f, 0.362340f, 0.381122f, 0.399904f, 0.418685f, 0.307195f, 0.329107f, 0.351019f, 0.372931f, 0.394843f, 0.416755f, 0.338305f, 0.359071f, 0.379837f, 0.400603f, 0.421368f, 0.442134f, 0.462900f, 0.339629f, 0.363856f, 0.388082f, 0.412309f, 0.436536f, 0.460762f, 0.370615f, 0.393365f, 0.416115f, 0.438865f, 0.461614f, 0.484364f, 0.507114f, 0.372063f, 0.398604f, 0.425146f, 0.451687f, 0.478229f, 0.504770f, 0.402925f, 0.427659f, 0.452393f, 0.477127f, 0.501861f, 0.526595f, 0.551329f, 0.404497f, 0.433353f, 0.462209f, 0.491065f, 0.519922f, 0.548778f}; + std::vector present_key = {0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + // {2, 3, 18, 8} + std::vector present_value = {0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 0.923878f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + // {2, 3, 4, 18} + constexpr float inff = std::numeric_limits::infinity(); + std::vector qk_matmul = {2.137658f, 1.567682f, 1.582827f, 0.953936f, 0.636597f, 1.001645f, 1.885707f, 1.361086f, 1.495408f, 1.566455f, 1.459078f, 1.668413f, 0.904174f, -inff, -inff, -inff, -inff, -inff, 1.229267f, 0.591855f, 1.372683f, 0.964445f, 1.006092f, 1.046331f, 1.712052f, 1.060710f, 2.141520f, 1.917742f, 1.063752f, 0.892409f, 0.884336f, 0.881352f, -inff, -inff, -inff, -inff, 2.235662f, 1.742821f, 2.198921f, 1.079357f, 1.510221f, 1.812315f, 1.396341f, 1.864746f, 1.498768f, 2.115730f, 0.844762f, 1.323617f, 1.096593f, 1.033003f, 1.868677f, -inff, -inff, -inff, 1.429269f, 0.876355f, 0.928405f, 1.469794f, 0.649940f, 1.435654f, 1.452830f, 1.053687f, 1.338220f, 0.966775f, 1.237266f, 1.488850f, 1.438267f, 0.931250f, 1.633272f, 0.944889f, -inff, -inff, 1.172613f, 1.105815f, 1.263303f, 1.702161f, 1.406517f, 1.808470f, 1.496128f, 1.169961f, 1.428707f, 1.393064f, 1.624670f, 1.287919f, 0.674733f, -inff, -inff, -inff, -inff, -inff, 0.838456f, 1.191558f, 1.771291f, 1.491907f, 0.911088f, 0.865799f, 1.154893f, 1.472593f, 0.826140f, 0.896018f, 1.281853f, 0.942941f, 1.470656f, 0.816028f, -inff, -inff, -inff, -inff, 1.133820f, 1.086309f, 1.712385f, 1.254675f, 1.427773f, 0.748848f, 1.056134f, 1.187805f, 1.419181f, 1.140224f, 1.269629f, 1.135934f, 0.694738f, 1.528325f, 0.959286f, -inff, -inff, -inff, 1.160321f, 1.097000f, 1.485019f, 1.111147f, 0.836961f, 0.948765f, 1.234762f, 0.835082f, 0.833382f, 0.589928f, 1.266538f, 1.303439f, 0.622733f, 0.837537f, 0.605730f, 0.730216f, -inff, -inff, 2.078597f, 0.610472f, 1.371772f, 0.794857f, 1.018924f, 1.165257f, 1.466839f, 1.206415f, 1.662507f, 1.098436f, 1.283408f, 1.533854f, 1.247966f, -inff, -inff, -inff, -inff, -inff, 1.707491f, 0.439978f, 0.919238f, 0.297115f, 0.982817f, 1.370520f, 0.766707f, 0.938981f, 1.095468f, 1.442393f, 0.742909f, 0.529869f, 0.628822f, 1.353301f, -inff, -inff, -inff, -inff, 1.483284f, 1.334536f, 0.757364f, 1.243801f, 0.767143f, 0.919318f, 0.693929f, 1.000990f, 1.107699f, 1.001247f, 1.434079f, 1.522769f, 0.696104f, 1.336034f, 0.501240f, -inff, -inff, -inff, 1.535892f, 1.342303f, 0.701559f, 1.211220f, 1.510985f, 0.961962f, 1.471503f, 1.440467f, 1.835586f, 0.947043f, 1.254547f, 1.009386f, 0.842613f, 1.508191f, 1.233544f, 1.280385f, -inff, -inff, 1.552432f, 0.958768f, 1.676495f, 1.810273f, 1.019336f, 1.487615f, 0.695035f, 1.391893f, 1.060641f, 0.917107f, 1.115109f, 1.128137f, 0.986429f, -inff, -inff, -inff, -inff, -inff, 1.289288f, 1.303667f, 0.882238f, 1.948027f, 1.580638f, 0.863439f, 1.059965f, 2.095325f, 1.493638f, 0.654104f, 0.828719f, 1.673449f, 0.479778f, 1.149678f, -inff, -inff, -inff, -inff, 1.177682f, 1.225590f, 1.735621f, 2.114078f, 1.905758f, 1.835981f, 1.432170f, 1.444457f, 2.016032f, 0.762211f, 1.059737f, 1.378216f, 1.564930f, 1.950097f, 1.598798f, -inff, -inff, -inff, 0.820477f, 0.962096f, 1.188223f, 1.264395f, 1.676953f, 1.487113f, 0.962162f, 1.377522f, 1.370079f, 1.450785f, 1.131087f, 1.962317f, 0.764849f, 0.777860f, 1.194763f, 1.030136f, -inff, -inff, 1.096708f, 1.345589f, 1.404595f, 1.370459f, 1.263369f, 1.364863f, 0.489623f, 0.596189f, 1.079480f, 0.915348f, 0.770954f, 1.548047f, 1.519504f, -inff, -inff, -inff, -inff, -inff, 1.856943f, 0.790590f, 1.235241f, 2.061177f, 1.282346f, 1.896653f, 1.112410f, 1.622862f, 0.780625f, 1.990919f, 1.693934f, 1.466544f, 1.026297f, 1.323339f, -inff, -inff, -inff, -inff, 1.778816f, 1.746915f, 1.169870f, 1.847628f, 0.729303f, 2.421048f, 1.266061f, 1.481203f, 1.016384f, 2.038725f, 1.132054f, 1.669076f, 1.958931f, 1.654780f, 1.644111f, -inff, -inff, -inff, 0.856287f, 1.124803f, 1.216201f, 0.831110f, 0.761234f, 1.204141f, 0.994307f, 0.832859f, 1.294077f, 1.566637f, 1.102631f, 1.472731f, 1.569911f, 0.779225f, 1.536189f, 1.277889f, -inff, -inff, 0.944230f, 1.585174f, 1.001532f, 0.973579f, 1.652668f, 1.112330f, 1.052878f, 1.326390f, 1.526319f, 1.790060f, 1.219317f, 1.742865f, 0.871467f, -inff, -inff, -inff, -inff, -inff, 0.794245f, 1.084904f, 0.813691f, 1.037344f, 0.254175f, 1.071614f, 0.477497f, 0.773591f, 1.317670f, 1.382451f, 0.759806f, 1.228428f, 0.583565f, 1.274037f, -inff, -inff, -inff, -inff, 0.865060f, 0.697643f, 1.300273f, 1.064195f, 1.435744f, 1.516307f, 0.626589f, 1.255387f, 1.115037f, 1.202643f, 1.789729f, 1.328769f, 1.046150f, 1.149905f, 1.696396f, -inff, -inff, -inff, 1.421552f, 1.324626f, 1.029005f, 0.960238f, 1.215132f, 1.450928f, 1.351898f, 1.718175f, 1.502146f, 1.736591f, 1.019685f, 1.130950f, 1.097223f, 1.330517f, 1.675029f, 1.069868f, -inff, -inff}; ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); @@ -1196,7 +1230,7 @@ TEST(AttentionTest, Attention4DWithMask4DPastAndPresentQkMatmul) { RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, q, k, v, m, std::initializer_list(), past_key, past_value, - -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + 1, 1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type y, present_key, present_value, qk_matmul, false, true, true // disable_cpu, disable_cuda, disable_dml ); diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 5b384c3775738..e6ffd2aaa6041 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -6,7 +6,7 @@ #include "test/util/include/default_providers.h" #include "test/util/include/current_test_name.h" #include "test/util/include/test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" #include "test/common/trt_op_test_utils.h" diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index e2cc5ac029318..0e5a4dac465b1 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -7,7 +7,6 @@ #include "test/common/cuda_op_test_utils.h" #include "test/providers/provider_test_utils.h" #include "test/common/dnnl_op_test_utils.h" -#include "test/providers/run_options_config_keys.h" #include "test/util/include/default_providers.h" namespace onnxruntime { diff --git a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc index 75e0c06b04f0d..8e6ff1f387bf1 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc @@ -5,7 +5,6 @@ #include "core/session/onnxruntime_session_options_config_keys.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" -#include "test/providers/run_options_config_keys.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index 4531e480e1460..b7f2b5800560a 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -4,7 +4,6 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" -#include "test/providers/run_options_config_keys.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" diff --git a/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc b/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc index 32368734f3530..88a9756956343 100644 --- a/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc +++ b/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc @@ -356,7 +356,7 @@ TEST(MLOpTest, TreeEnsembleBigSet) { std::vector leaf_weights = {1.f, 10.f, 100.f, 1000.f}; std::vector member_ship_values(106); for (size_t i = 0; i < member_ship_values.size(); i++) { - member_ship_values[i] = i + 40; + member_ship_values[i] = static_cast(i + 40); } member_ship_values[100] = std::numeric_limits::quiet_NaN(); member_ship_values[101] = 201; diff --git a/onnxruntime/test/providers/cpu/ml/write_scores_test.cc b/onnxruntime/test/providers/cpu/ml/write_scores_test.cc index 4b7c3e4f04780..cec17ec254e47 100644 --- a/onnxruntime/test/providers/cpu/ml/write_scores_test.cc +++ b/onnxruntime/test/providers/cpu/ml/write_scores_test.cc @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "core/framework/tensor.h" #include "core/providers/cpu/ml/ml_common.h" -#include "test/framework/dummy_allocator.h" +#include "test/unittest_util/dummy_allocator.h" #include using namespace onnxruntime; diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index 049f70d660f1e..19282437b4823 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -7,7 +7,6 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" -#include "test/providers/run_options_config_keys.h" #include "default_providers.h" using namespace std; diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 0e17aa835028e..46acb5a730a78 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -931,7 +931,7 @@ void DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int(int64_t block_size, for (int64_t i = 0, n = 2 * zero_point_block_count; i < n; ++i) x_zero_point.push_back(0); for (int64_t i = 0, n = 2 * scale_block_count; i < n; i++) x_scale.push_back(Tout(2.0f)); for (int i = 0; i < 8; ++i) { - x.push_back(i); + x.push_back(static_cast(i)); y.push_back(Tout(static_cast(i) * 2.0f)); } @@ -973,10 +973,11 @@ void DequantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(int64_t block_size, so.session_log_verbosity_level = 1; so.graph_optimization_level = TransformerLevel::Default; + using UnpackedType = typename Tin::UnpackedType; for (int64_t i = 0, n = zero_point_block_count; i < n; ++i) x_zero_point.push_back(Tin(0, 0)); for (int64_t i = 0, n = 2 * scale_block_count; i < n; i++) x_scale.push_back(Tout(2.0f)); for (int i = 0; i < 8; ++i) { - if (i & 1) x.push_back(Tin(i - 1, i)); + if (i & 1) x.push_back(Tin(static_cast(i - 1), static_cast(i))); y.push_back(Tout(static_cast(i) * 2.0f)); } @@ -1198,14 +1199,16 @@ void DequantizeLinearOp21BlockedTest_Int4_Succeed(std::vector&& dims, x_scale_shape.push_back((int64_t)i == non_neg_axis ? (dims[i] + block_size - 1) / block_size : dims[i]); } + using UnpackedType = typename Tin::UnpackedType; size_t i = 0, n = x_.size(); - for (; i < n - 1; i += 2) x.push_back(Tin(x_[i], x_[i + 1])); - if (i < n) x.push_back(Tin(x_[i], 0xF)); + for (; i < n - 1; i += 2) x.push_back(Tin(static_cast(x_[i]), static_cast(x_[i + 1]))); + if (i < n) x.push_back(Tin(static_cast(x_[i]), 0xF)); if (use_zero_point) { i = 0, n = x_zero_point_.size(); - for (; i < n - 1; i += 2) x_zero_point.push_back(Tin(x_zero_point_[i], x_zero_point_[i + 1])); - if (i < n) x_zero_point.push_back(Tin(x_zero_point_[i], 0xF)); + for (; i < n - 1; i += 2) x_zero_point.push_back(Tin(static_cast(x_zero_point_[i]), + static_cast(x_zero_point_[i + 1]))); + if (i < n) x_zero_point.push_back(Tin(static_cast(x_zero_point_[i]), 0xF)); } test.AddInput("x", dims, x); @@ -1240,9 +1243,9 @@ void DequantizeLinearOp21BlockedTest_Int_Succeed(std::vector&& dims, for (size_t i = 0, n = dims.size(); i < n; ++i) { x_scale_shape.push_back((int64_t)i == non_neg_axis ? (dims[i] + block_size - 1) / block_size : dims[i]); } - for (auto v : x_) x.push_back(v); + for (auto v : x_) x.push_back(static_cast(v)); if (use_zero_point) - for (auto v : x_zero_point_) x_zero_point.push_back(v); + for (auto v : x_zero_point_) x_zero_point.push_back(static_cast(v)); test.AddInput("x", dims, x); test.AddAttribute("axis", axis); @@ -1913,7 +1916,7 @@ void QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int(int64_t block_size, for (int64_t i = 0, n = 2 * scale_block_count; i < n; i++) x_scale.push_back(Tin(2.0f)); for (int i = 0; i < 8; ++i) { x.push_back(Tin(static_cast(i) * 2.0f)); - y.push_back(i); + y.push_back(static_cast(i)); } test.AddInput("x", dims, x); @@ -1954,10 +1957,11 @@ void QuantizeLinearOp21BlockedTest_InvalidBlockSize_Int4(int64_t block_size, so.session_log_verbosity_level = 1; so.graph_optimization_level = TransformerLevel::Default; + using UnpackedType = typename Tout::UnpackedType; for (int64_t i = 0, n = zero_point_block_count; i < n; ++i) x_zero_point.push_back(Tout(0, 0)); for (int64_t i = 0, n = 2 * scale_block_count; i < n; i++) x_scale.push_back(Tin(2.0f)); for (int i = 0; i < 8; ++i) { - if (i & 1) y.push_back(Tout(i - 1, i)); + if (i & 1) y.push_back(Tout(static_cast(i - 1), static_cast(i))); x.push_back(Tin(static_cast(i) * 2.0f)); } @@ -2179,14 +2183,17 @@ void QuantizeLinearOp21BlockedTest_Int4_Succeed(std::vector&& dims, scale_shape.push_back((int64_t)i == non_neg_axis ? (dims[i] + block_size - 1) / block_size : dims[i]); } + using UnpackedType = typename Tout::UnpackedType; size_t i = 0, n = y_.size(); - for (; i < n - 1; i += 2) y.push_back(Tout(y_[i], y_[i + 1])); - if (i < n) y.push_back(Tout(y_[i], 0xF)); + for (; i < n - 1; i += 2) y.push_back(Tout(static_cast(y_[i]), + static_cast(y_[i + 1]))); + if (i < n) y.push_back(Tout(static_cast(y_[i]), 0xF)); if (use_zero_point) { i = 0, n = zero_point_.size(); - for (; i < n - 1; i += 2) zero_point.push_back(Tout(zero_point_[i], zero_point_[i + 1])); - if (i < n) zero_point.push_back(Tout(zero_point_[i], 0xF)); + for (; i < n - 1; i += 2) zero_point.push_back(Tout(static_cast(zero_point_[i]), + static_cast(zero_point_[i + 1]))); + if (i < n) zero_point.push_back(Tout(static_cast(zero_point_[i]), 0xF)); } test.AddInput("x", dims, x); @@ -2221,9 +2228,9 @@ void QuantizeLinearOp21BlockedTest_Int_Succeed(std::vector&& dims, for (size_t i = 0, n = dims.size(); i < n; ++i) { scale_shape.push_back((int64_t)i == non_neg_axis ? (dims[i] + block_size - 1) / block_size : dims[i]); } - for (auto v : y_) y.push_back(v); + for (auto v : y_) y.push_back(static_cast(v)); if (use_zero_point) - for (auto v : zero_point_) zero_point.push_back(v); + for (auto v : zero_point_) zero_point.push_back(static_cast(v)); test.AddInput("x", dims, x); test.AddAttribute("axis", axis); diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc index 9837dca398ff8..047b4795a34a2 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc @@ -20,7 +20,7 @@ struct ConvTransposeOp { std::vector dilations = {1, 1}; std::unique_ptr get_test() { - RandomValueGenerator random{T(123.f)}; // use seed so output is deterministic to aid in debugging failures + RandomValueGenerator random{123}; // use seed so output is deterministic to aid in debugging failures auto test = std::make_unique("ConvTranspose", 14); std::vector input_data = random.Uniform(input_dims, 0.0f, 1.0f); diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc index 8213a95dcff08..735bd89aff260 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc @@ -107,10 +107,10 @@ struct ProviderInfo_CUDA_TestImpl : ProviderInfo_CUDA { void TestAll() override { // TestAll is the entry point of CUDA EP's internal tests. - // Those internal tests are not directly callable from onnxruntime_test_all + // Those internal tests are not directly callable from onnxruntime_provider_test // because CUDA EP is a shared library now. // Instead, this is a test provider that implements all the test cases. - // onnxruntime_test_all is calling this function through TryGetProviderInfo_CUDA_Test. + // onnxruntime_provider_test is calling this function through TryGetProviderInfo_CUDA_Test. char mock_exe_name[] = "onnxruntime_providers_cuda_ut"; // InitGoogleTest decrements argc and removes args from argv if diff --git a/onnxruntime/test/providers/memcpy_test.cc b/onnxruntime/test/providers/memcpy_test.cc index 4efa359b4e589..cfef8ba63715f 100644 --- a/onnxruntime/test/providers/memcpy_test.cc +++ b/onnxruntime/test/providers/memcpy_test.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "gtest/gtest.h" -#include "../framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" #include "core/framework/execution_providers.h" diff --git a/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc b/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc index 761ddf1975d15..ca3b9ee8c5a9b 100644 --- a/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc +++ b/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc @@ -3,7 +3,7 @@ #include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" #include "test/providers/provider_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "gtest/gtest.h" #include "test/util/include/default_providers.h" #include "test/util/include/scoped_env_vars.h" diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc index 22a80da2f95d2..8c99d3cd995fb 100644 --- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc +++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc @@ -12,7 +12,7 @@ #include "core/session/inference_session.h" #include "core/framework/tensorprotoutils.h" #include "test/common/tensor_op_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/util/include/asserts.h" #include "test/util/include/current_test_name.h" #include "test/util/include/default_providers.h" @@ -24,7 +24,7 @@ #if !defined(ORT_MINIMAL_BUILD) // if this is a full build we need the provider test utils #include "test/providers/provider_test_utils.h" -#include "test/optimizer/qdq_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #endif #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index 2327bc2094d1a..d8cc56d738175 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -3,7 +3,7 @@ #include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" #include "test/providers/provider_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/util/include/scoped_env_vars.h" #include "test/common/trt_op_test_utils.h" diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc index ce49ae81c81c0..ac24dcb70c1dd 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include "core/common/path_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" #include diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc index 7f7894abdf3d5..3a91fc1ba09bb 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc @@ -18,7 +18,7 @@ #include "test/util/include/scoped_env_vars.h" #include "test/common/trt_op_test_utils.h" #include "test/providers/provider_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc b/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc index fc90563a61bb1..cd0263f76db76 100644 --- a/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc +++ b/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc @@ -9,7 +9,7 @@ #include "core/framework/float16.h" #include "test/util/include/test/test_environment.h" -#include "test/optimizer/qdq_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "gtest/gtest.h" #include "gmock/gmock.h" diff --git a/onnxruntime/test/providers/openvino/openvino_ep_context_test.cc b/onnxruntime/test/providers/openvino/openvino_ep_context_test.cc index e205b3aeb064a..404bc12634f48 100644 --- a/onnxruntime/test/providers/openvino/openvino_ep_context_test.cc +++ b/onnxruntime/test/providers/openvino/openvino_ep_context_test.cc @@ -11,14 +11,13 @@ #include "test/util/include/test_utils.h" #include "test/util/include/test/test_environment.h" #include "test/util/include/default_providers.h" +#include "test/unittest_util/qdq_test_utils.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/inference_session.h" #include "core/graph/model_saving_options.h" -#include "test/optimizer/qdq_test_utils.h" - #include "gtest/gtest.h" #include "gmock/gmock.h" diff --git a/onnxruntime/test/providers/partitioning_utils_test.cc b/onnxruntime/test/providers/partitioning_utils_test.cc index 89e2bf74a1b9c..5f435199679be 100644 --- a/onnxruntime/test/providers/partitioning_utils_test.cc +++ b/onnxruntime/test/providers/partitioning_utils_test.cc @@ -13,9 +13,9 @@ #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/providers/partitioning_utils.h" -#include "test/optimizer/graph_transform_test_builder.h" -#include "test/optimizer/qdq_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/asserts.h" +#include "test/unittest_util/qdq_test_utils.h" #include "test/util/include/test_utils.h" #include "test/util/include/test/test_environment.h" diff --git a/onnxruntime/test/providers/provider_test_utils.h b/onnxruntime/test/providers/provider_test_utils.h index 988df1e7a5d3b..1d8a50dc2fa04 100644 --- a/onnxruntime/test/providers/provider_test_utils.h +++ b/onnxruntime/test/providers/provider_test_utils.h @@ -3,47 +3,8 @@ #pragma once -#include "test/providers/checkers.h" -#include "test/providers/op_tester.h" -#include "test/providers/model_tester.h" - -namespace onnxruntime { -namespace test { -inline void ConvertFloatToMLFloat16(const float* f_datat, MLFloat16* h_data, size_t input_size) { - auto in_vector = ConstEigenVectorMap(f_datat, input_size); - auto output_vector = EigenVectorMap(static_cast(static_cast(h_data)), input_size); - output_vector = in_vector.template cast(); -} - -inline void ConvertFloatToUint8_t(const float* f_datat, uint8_t* u8_data, size_t input_size) { - auto in_vector = ConstEigenVectorMap(f_datat, input_size); - auto output_vector = EigenVectorMap(static_cast(static_cast(u8_data)), input_size); - output_vector = in_vector.template cast(); -} - -inline void ConvertMLFloat16ToFloat(const MLFloat16* h_data, float* f_data, size_t input_size) { - auto in_vector = - ConstEigenVectorMap(static_cast(static_cast(h_data)), input_size); - auto output_vector = EigenVectorMap(f_data, input_size); - output_vector = in_vector.template cast(); -} - -inline std::vector FloatsToMLFloat16s(const std::vector& f) { - std::vector m(f.size()); - ConvertFloatToMLFloat16(f.data(), m.data(), f.size()); - return m; -} - -inline std::vector MakeBFloat16(const std::initializer_list& input) { - std::vector output; - std::transform(input.begin(), input.end(), std::back_inserter(output), [](float f) { return BFloat16(f); }); - return output; -} - -inline std::vector FloatsToBFloat16s(const std::vector& input) { - std::vector output; - std::transform(input.begin(), input.end(), std::back_inserter(output), [](float f) { return BFloat16(f); }); - return output; -} -} // namespace test -} // namespace onnxruntime +#include "test/unittest_util/checkers.h" +#include "test/unittest_util/conversion.h" +#include "test/unittest_util/model_tester.h" +#include "test/unittest_util/op_tester.h" +#include "test/unittest_util/run_options_config_keys.h" diff --git a/onnxruntime/test/providers/qnn/average_pool_test.cc b/onnxruntime/test/providers/qnn/average_pool_test.cc index 8a0dd60765612..2799da02fd418 100644 --- a/onnxruntime/test/providers/qnn/average_pool_test.cc +++ b/onnxruntime/test/providers/qnn/average_pool_test.cc @@ -8,8 +8,8 @@ #include #include "core/graph/node_attr_utils.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "core/graph/onnx_protobuf.h" diff --git a/onnxruntime/test/providers/qnn/batch_norm_test.cc b/onnxruntime/test/providers/qnn/batch_norm_test.cc index c88a0ce6cf0b2..0bce5482ecefc 100644 --- a/onnxruntime/test/providers/qnn/batch_norm_test.cc +++ b/onnxruntime/test/providers/qnn/batch_norm_test.cc @@ -7,8 +7,8 @@ #include "core/graph/graph.h" #include "core/framework/float16.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/qnn/cast_test.cc b/onnxruntime/test/providers/qnn/cast_test.cc index 98a0fff2b0700..9e381c6282b43 100644 --- a/onnxruntime/test/providers/qnn/cast_test.cc +++ b/onnxruntime/test/providers/qnn/cast_test.cc @@ -11,8 +11,8 @@ #include "core/framework/float16.h" #include "core/graph/onnx_protobuf.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/providers/qnn/gather_elems_op_test.cc b/onnxruntime/test/providers/qnn/gather_elems_op_test.cc index 9036410139b48..48a1cb634b0e4 100644 --- a/onnxruntime/test/providers/qnn/gather_elems_op_test.cc +++ b/onnxruntime/test/providers/qnn/gather_elems_op_test.cc @@ -8,8 +8,8 @@ #include #include "core/graph/node_attr_utils.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "core/graph/onnx_protobuf.h" diff --git a/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc b/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc index 648fb00da611d..062b73c298df6 100644 --- a/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc +++ b/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc @@ -7,8 +7,8 @@ #include "core/graph/graph.h" #include "core/graph/node_attr_utils.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/qnn/layer_norm_test.cc b/onnxruntime/test/providers/qnn/layer_norm_test.cc index 7d92d1e10c39e..6beda83d00536 100644 --- a/onnxruntime/test/providers/qnn/layer_norm_test.cc +++ b/onnxruntime/test/providers/qnn/layer_norm_test.cc @@ -7,8 +7,8 @@ #include "core/graph/graph.h" #include "core/graph/node_attr_utils.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc b/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc index ec85fd3dd339d..ca34cb89d1424 100644 --- a/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc @@ -7,8 +7,8 @@ #include "core/graph/graph.h" #include "core/graph/node_attr_utils.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/qnn/logical_comp_ops_test.cc b/onnxruntime/test/providers/qnn/logical_comp_ops_test.cc index 18e19efca13d7..cf9ae5fe42d94 100644 --- a/onnxruntime/test/providers/qnn/logical_comp_ops_test.cc +++ b/onnxruntime/test/providers/qnn/logical_comp_ops_test.cc @@ -6,8 +6,8 @@ #include #include -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "core/graph/onnx_protobuf.h" diff --git a/onnxruntime/test/providers/qnn/lrn_op_test.cc b/onnxruntime/test/providers/qnn/lrn_op_test.cc index 08a7d663bddf8..de12bfebe3f42 100644 --- a/onnxruntime/test/providers/qnn/lrn_op_test.cc +++ b/onnxruntime/test/providers/qnn/lrn_op_test.cc @@ -6,8 +6,8 @@ #include #include -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "core/graph/onnx_protobuf.h" diff --git a/onnxruntime/test/providers/qnn/lstm_test.cc b/onnxruntime/test/providers/qnn/lstm_test.cc index 5d20806d3ea4d..4d368436f0d2d 100644 --- a/onnxruntime/test/providers/qnn/lstm_test.cc +++ b/onnxruntime/test/providers/qnn/lstm_test.cc @@ -6,9 +6,9 @@ #include #include -#include "test/optimizer/qdq_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" -#include "test/providers/tester_types.h" +#include "test/unittest_util/tester_types.h" #include "core/graph/onnx_protobuf.h" @@ -159,7 +159,7 @@ void _BuildLSTMTestCase(ModelTestBuilder& builder, lstm_node.AddAttribute("hidden_size", hidden_size); lstm_node.AddAttribute("layout", layout); ORT_UNUSED_PARAMETER(output_qparams); - if (std::is_same::value) { + if constexpr (std::is_same::value) { size_t i = 0; if (has_Y) { AddQDQNodePairWithOutputAsGraphOutput(builder, lstm_output_Y, output_qparams[i].scale, diff --git a/onnxruntime/test/providers/qnn/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/providers/qnn/optimizer/transpose_optimizer_test.cc index 77cafc4b08389..05157d716ac79 100644 --- a/onnxruntime/test/providers/qnn/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/providers/qnn/optimizer/transpose_optimizer_test.cc @@ -10,9 +10,9 @@ #include "core/graph/constants.h" #include "core/optimizer/transpose_optimizer.h" -#include "test/framework/test_utils.h" -#include "test/optimizer/graph_transform_test_builder.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/asserts.h" namespace onnxruntime { diff --git a/onnxruntime/test/providers/qnn/pad_op_test.cpp b/onnxruntime/test/providers/qnn/pad_op_test.cpp index bfcd2c22c53d8..baaf1c4b21063 100644 --- a/onnxruntime/test/providers/qnn/pad_op_test.cpp +++ b/onnxruntime/test/providers/qnn/pad_op_test.cpp @@ -7,8 +7,8 @@ #include #include "core/graph/node_attr_utils.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "core/graph/onnx_protobuf.h" diff --git a/onnxruntime/test/providers/qnn/pool_op_test.cpp b/onnxruntime/test/providers/qnn/pool_op_test.cpp index d51eeeea1aea8..86ad1d2f4251f 100644 --- a/onnxruntime/test/providers/qnn/pool_op_test.cpp +++ b/onnxruntime/test/providers/qnn/pool_op_test.cpp @@ -12,8 +12,8 @@ #include "core/graph/node_attr_utils.h" #include "core/graph/onnx_protobuf.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/channel_shuffle_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/channel_shuffle_fusion_test.cc index e5f7fd69ed655..daadd66aa0498 100644 --- a/onnxruntime/test/providers/qnn/qnn_node_group/channel_shuffle_fusion_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_node_group/channel_shuffle_fusion_test.cc @@ -6,8 +6,8 @@ #include "core/graph/graph.h" #include "core/graph/node_attr_utils.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "gtest/gtest.h" namespace onnxruntime { diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc index b349e0c40882f..d12b8033b19b1 100644 --- a/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc @@ -14,8 +14,8 @@ #include "core/graph/graph.h" #include "core/graph/node_attr_utils.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "gtest/gtest.h" namespace onnxruntime { diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc index 8f63ccd5f2cd1..9a6ffd0b0ce74 100644 --- a/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc @@ -14,8 +14,8 @@ #include "core/graph/graph.h" #include "core/graph/node_attr_utils.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "gtest/gtest.h" namespace onnxruntime { diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc index eda04b954f590..e012f4fa86430 100644 --- a/onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc @@ -6,8 +6,8 @@ #include "core/graph/graph.h" #include "core/graph/node_attr_utils.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "gtest/gtest.h" namespace onnxruntime { diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index fbc68c2ea14ea..36fe37bbf4c25 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -13,10 +13,10 @@ #include "core/framework/float16.h" #include "core/util/qmath.h" -#include "test/optimizer/qdq_test_utils.h" +#include "test/util/include/default_providers.h" +#include "test/unittest_util/qdq_test_utils.h" #include "test/util/include/test_utils.h" #include "test/util/include/test/test_environment.h" -#include "test/util/include/default_providers.h" #include "gtest/gtest.h" @@ -1159,7 +1159,7 @@ inline GetTestQDQModelFn BuildQDQOpTestCase( * Runs a test model on the QNN EP. Checks the graph node assignment, and that inference * outputs for QNN and CPU match. * - * \param build_test_case Function that builds a test model. See test/optimizer/qdq_test_utils.h + * \param build_test_case Function that builds a test model. See test/unittest_util/qdq_test_utils.h * \param provider_options Provider options for QNN EP. * \param opset_version The opset version. * \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None). diff --git a/onnxruntime/test/providers/qnn/reduce_op_test.cc b/onnxruntime/test/providers/qnn/reduce_op_test.cc index f96e5338369b2..0884a0fc80165 100644 --- a/onnxruntime/test/providers/qnn/reduce_op_test.cc +++ b/onnxruntime/test/providers/qnn/reduce_op_test.cc @@ -6,8 +6,8 @@ #include #include "core/graph/graph.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "gtest/gtest.h" @@ -113,7 +113,7 @@ static void RunReduceTest(const std::string& op_type, // - Uses opset 13, which has "axes" as an input. TEST_F(QnnCPUBackendTests, ReduceSumOpset13_Int32) { RunReduceTest("ReduceSum", - TestInputDef({2, 2}, false, -10.0f, 10.0f), + TestInputDef({2, 2}, false, -10, 10), std::vector{0, 1}, true, // keepdims 13, @@ -127,7 +127,7 @@ TEST_F(QnnCPUBackendTests, ReduceSumOpset13_Int32) { // - Uses opset 11, which has "axes" as an attribute. TEST_F(QnnCPUBackendTests, ReduceSumOpset11_Int32) { RunReduceTest("ReduceSum", - TestInputDef({2, 2}, false, -10.0f, 10.0f), + TestInputDef({2, 2}, false, -10, 10), std::vector{0, 1}, true, // keepdims 11, diff --git a/onnxruntime/test/providers/qnn/reshape_expand_op_test.cc b/onnxruntime/test/providers/qnn/reshape_expand_op_test.cc index 45626d63d1970..587e11e79b71d 100644 --- a/onnxruntime/test/providers/qnn/reshape_expand_op_test.cc +++ b/onnxruntime/test/providers/qnn/reshape_expand_op_test.cc @@ -14,18 +14,19 @@ namespace onnxruntime { namespace test { -// Runs a model with a Reshape/Expand operator on the QNN CPU backend. Checks the graph node assignment +// Runs a model with a Reshape/Expand operator on the QNN CPU or GPU backends. Checks the graph node assignment // and that inference outputs for QNN EP and CPU EP match. template -static void RunReshapeExpandTestOnCPU(const std::string& op_type, - const TestInputDef& input_def, - const TestInputDef& shape_def, - const std::vector& attrs, - ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 19) { +static void RunReshapeExpandTest(const std::string& op_type, + const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + const std::string& backend_name = "cpu", + int opset = 19) { ProviderOptions provider_options; - provider_options["backend_type"] = "cpu"; + provider_options["backend_type"] = backend_name; provider_options["offload_graph_io_quantization"] = "0"; RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {shape_def}, attrs), @@ -40,68 +41,142 @@ static void RunReshapeExpandTestOnCPU(const std::string& op_type, // Test that Reshape with a dynamic shape input is not supported by QNN EP. TEST_F(QnnCPUBackendTests, Reshape_DynamicShape_Unsupported) { - RunReshapeExpandTestOnCPU("Reshape", - TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), - TestInputDef({2}, false /* is_initializer */, {1, 48}), - {}, // Attributes - ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. - 19); // Opset + RunReshapeExpandTest("Reshape", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, false /* is_initializer */, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + "cpu", // Backend + 19); // Opset } // Test that Reshape with an enabled 'allowzero' attribute is not supported by QNN EP. TEST_F(QnnCPUBackendTests, Reshape_AllowZeroAttr_Unsupported) { - RunReshapeExpandTestOnCPU("Reshape", TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), - TestInputDef({2}, true, {1, 48}), - {utils::MakeAttribute("allowzero", static_cast(1))}, - ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. - 19); // Opset + RunReshapeExpandTest("Reshape", TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {1, 48}), + {utils::MakeAttribute("allowzero", static_cast(1))}, + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + "cpu", // Backend + 19); // Opset } // Test Reshape of rank 4 -> rank 2. TEST_F(QnnCPUBackendTests, Reshape_4D_f32) { - RunReshapeExpandTestOnCPU("Reshape", TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), - TestInputDef({2}, true, {1, 48}), - {}, // Attributes - ExpectedEPNodeAssignment::All, - 19); // Opset + RunReshapeExpandTest("Reshape", TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + "cpu", // Backend + 19); // Opset } // Test Expand with non-initializer shape input, not supported. TEST_F(QnnCPUBackendTests, Expand_NonIniShape) { - RunReshapeExpandTestOnCPU("Expand", TestInputDef({1}, false, {1.0f}), - TestInputDef({2}, false, {2, 2}), - {}, // Attributes - ExpectedEPNodeAssignment::None, - 19); // Opset + RunReshapeExpandTest("Expand", TestInputDef({1}, false, {1.0f}), + TestInputDef({2}, false, {2, 2}), + {}, // Attributes + ExpectedEPNodeAssignment::None, + "cpu", // Backend + 19); // Opset } // Test Expand with initializer shape input. TEST_F(QnnCPUBackendTests, Expand_IniShape) { - RunReshapeExpandTestOnCPU("Expand", TestInputDef({1}, false, {1.0f}), - TestInputDef({2}, true, {2, 3}), - {}, // Attributes - ExpectedEPNodeAssignment::All, - 19); // Opset + RunReshapeExpandTest("Expand", TestInputDef({1}, false, {1.0f}), + TestInputDef({2}, true, {2, 3}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + "cpu", // Backend + 19); // Opset } // Test Expand with initializer shape input. TEST_F(QnnCPUBackendTests, Expand_Uint32) { - RunReshapeExpandTestOnCPU("Expand", TestInputDef({1}, false, {1}), - TestInputDef({2}, true, {2, 3}), - {}, // Attributes - ExpectedEPNodeAssignment::All, - 19); // Opset + RunReshapeExpandTest("Expand", TestInputDef({1}, false, {1}), + TestInputDef({2}, true, {2, 3}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + "cpu", // Backend + 19); // Opset } // Test Expand with 6D output. TEST_F(QnnCPUBackendTests, Expand_6D) { - RunReshapeExpandTestOnCPU("Expand", TestInputDef({3}, false, {1.0f, 2.0f, 3.0f}), - TestInputDef({6}, true, {1, 2, 3, 4, 5, 3}), - {}, // Attributes - ExpectedEPNodeAssignment::All, - 19); // Opset + RunReshapeExpandTest("Expand", TestInputDef({3}, false, {1.0f, 2.0f, 3.0f}), + TestInputDef({6}, true, {1, 2, 3, 4, 5, 3}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + "cpu", // Backend + 19); // Opset +} + +#if defined(_M_ARM64) +// +// GPU tests: +// + +// Test Reshape of rank 4 -> rank 2. +TEST_F(QnnGPUBackendTests, Reshape_4D_f32) { + RunReshapeExpandTest("Reshape", TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + "gpu", // Backend + 19); // Opset +} + +// Test Expand with initializer shape input. +TEST_F(QnnGPUBackendTests, Expand_IniShape) { + RunReshapeExpandTest("Expand", TestInputDef({1}, false, {1.0f}), + TestInputDef({2}, true, {2, 3}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + "gpu", // Backend + 19); // Opset +} + +// Test Expand with FP16 data +TEST_F(QnnGPUBackendTests, Expand_IniShape_Float16) { + RunReshapeExpandTest("Expand", TestInputDef({1}, false, {MLFloat16(1.0f)}), + TestInputDef({2}, true, {2, 3}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + "gpu", // Backend + 19); // Opset +} + +// Test Expand with 6D output. +TEST_F(QnnGPUBackendTests, Expand_6D) { + RunReshapeExpandTest("Expand", TestInputDef({3}, false, {1.0f, 2.0f, 3.0f}), + TestInputDef({6}, true, {1, 2, 3, 4, 5, 3}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + "gpu", // Backend + 19); // Opset +} + +// Test Expand with 6D output with FP16 data. +TEST_F(QnnGPUBackendTests, Expand_6D_Float16) { + RunReshapeExpandTest("Expand", TestInputDef({3}, false, {MLFloat16(1.0f), MLFloat16(2.0f), MLFloat16(3.0f)}), + TestInputDef({6}, true, {1, 2, 3, 4, 5, 3}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + "gpu", // Backend + 19); // Opset } +// Test Expand with 4D output with FP16 data. +TEST_F(QnnGPUBackendTests, Expand_4D_Float16) { + RunReshapeExpandTest("Expand", TestInputDef({1, 2, 1, 1}, false, {MLFloat16(1.0f), MLFloat16(2.0f)}), + TestInputDef({4}, true, {1, 2, 128, 128}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + "gpu", // Backend + 19); // Opset +} + +#endif // defined(_M_ARM64) GPU tests + #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) // // HTP tests: diff --git a/onnxruntime/test/providers/qnn/resize_test.cc b/onnxruntime/test/providers/qnn/resize_test.cc index 9875a52e1d2b4..1b6730dc34711 100644 --- a/onnxruntime/test/providers/qnn/resize_test.cc +++ b/onnxruntime/test/providers/qnn/resize_test.cc @@ -6,8 +6,8 @@ #include #include -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "core/graph/onnx_protobuf.h" @@ -374,17 +374,16 @@ TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearAsymmetric) { } // Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "half_pixel", nearest_mode: "round_prefer_floor" -// Maps to QNN's Resize operator. -// UPDATE: "round_prefer_floor" no longer supported in QNN SDK 2.21 (supported in QNN SDK 2.19) -TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestHalfPixelRoundPreferFloor_Unsupported) { +// Maps to QNN's ResizeNearestNeighbor operator. +TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestHalfPixelRoundPreferFloor) { std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), {1, 3, 8, 8}, "nearest", "half_pixel", "round_prefer_floor", - ExpectedEPNodeAssignment::None); // No longer supported as of QNN SDK 2.21 + ExpectedEPNodeAssignment::All); } // Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "half_pixel", nearest_mode: "round_prefer_Ceil" -// Maps to QNN's ResizeNearesetNeighbor operator. +// Maps to QNN's ResizeNearestNeighbor operator. TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestHalfPixelRoundPreferCeil) { std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), @@ -393,7 +392,7 @@ TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestHalfPixelRoundPreferCeil) { } // Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "align_corners", nearest_mode: "round_prefer_ceil" -// Maps to QNN's Resize operator. +// Maps to QNN's ResizeNearestNeighbor operator. // UPDATE: "round_prefer_ceil" is supported as of QNN SDK 2.21 if using "align_corners". (Unsupported in QNN SDK 2.19). TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestAlignCornersRoundPreferCeil) { std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); @@ -402,8 +401,9 @@ TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestAlignCornersRoundPreferCeil) { ExpectedEPNodeAssignment::All); } -// Test that the nearest_mode "ceil" is not supported on the HTP backend. -TEST_F(QnnHTPBackendTests, ResizeU8_NearestModeCeil_Unsupported) { +// Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "asymmetric", nearest_mode: "ceil" +// Maps to QNN's ResizeNearestNeighbor operator. +TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestAsymmetricCeil_Unsupported) { std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), {1, 3, 8, 8}, "nearest", "asymmetric", "ceil", @@ -420,7 +420,7 @@ TEST_F(QnnHTPBackendTests, ResizeU8_3xNearestAsymmetricFloor) { } // Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "asymmetric", nearest_mode: "round_prefer_floor" -// Maps to QNN's Resize operator. +// Maps to QNN's ResizeNearestNeighbor operator. // UPDATE: "round_prefer_floor" no longer supported in QNN SDK 2.21 (supported in QNN SDK 2.19) TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestAsymmetricRoundPreferFloor_Unsupported) { std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 66f7ebdcd3326..ec119c155455d 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -10,8 +10,8 @@ #include "core/graph/node_attr_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/qnn/slice_htp_test.cc b/onnxruntime/test/providers/qnn/slice_htp_test.cc index 9c50f47d462bc..0bec8caaddf76 100644 --- a/onnxruntime/test/providers/qnn/slice_htp_test.cc +++ b/onnxruntime/test/providers/qnn/slice_htp_test.cc @@ -7,8 +7,8 @@ #include "core/graph/graph.h" #include "core/graph/node_attr_utils.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/qnn/transpose_htp_test.cc b/onnxruntime/test/providers/qnn/transpose_htp_test.cc index 83ff6440c8399..53604dc2af2f8 100644 --- a/onnxruntime/test/providers/qnn/transpose_htp_test.cc +++ b/onnxruntime/test/providers/qnn/transpose_htp_test.cc @@ -7,8 +7,8 @@ #include "core/graph/graph.h" #include "core/graph/node_attr_utils.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/qnn/where_htp_test.cc b/onnxruntime/test/providers/qnn/where_htp_test.cc index 95a9f3dac9cb7..9b44f5edfa7dc 100644 --- a/onnxruntime/test/providers/qnn/where_htp_test.cc +++ b/onnxruntime/test/providers/qnn/where_htp_test.cc @@ -7,8 +7,8 @@ #include "core/graph/graph.h" #include "core/graph/node_attr_utils.h" -#include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #include "gtest/gtest.h" diff --git a/onnxruntime/test/providers/rknpu/rknpu_basic_test.cc b/onnxruntime/test/providers/rknpu/rknpu_basic_test.cc index fffd081a692c0..5b8be8d84dde6 100644 --- a/onnxruntime/test/providers/rknpu/rknpu_basic_test.cc +++ b/onnxruntime/test/providers/rknpu/rknpu_basic_test.cc @@ -1,6 +1,6 @@ #include "core/session/inference_session.h" #include "test/providers/provider_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "gtest/gtest.h" #include "core/providers/rknpu/rknpu_execution_provider.h" #include "core/common/logging/logging.h" diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 95c5a0ab97728..327dfab96c2d1 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -3,7 +3,7 @@ #include "core/graph/onnx_protobuf.h" #include "core/session/inference_session.h" #include "test/providers/provider_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "gtest/gtest.h" #include "test/util/include/default_providers.h" #include "test/util/include/scoped_env_vars.h" diff --git a/onnxruntime/test/providers/xnnpack/xnnpack_basic_test.cc b/onnxruntime/test/providers/xnnpack/xnnpack_basic_test.cc index 03f1f51aa57c4..9ca081a74c850 100644 --- a/onnxruntime/test/providers/xnnpack/xnnpack_basic_test.cc +++ b/onnxruntime/test/providers/xnnpack/xnnpack_basic_test.cc @@ -15,7 +15,7 @@ #include "core/session/onnxruntime_session_options_config_keys.h" #include "test/common/tensor_op_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" #include "test/util/include/api_asserts.h" #include "test/util/include/asserts.h" @@ -25,7 +25,7 @@ #if !defined(ORT_MINIMAL_BUILD) // if this is a full build we need the provider test utils -#include "test/optimizer/qdq_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" #endif #include "gtest/gtest.h" diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 1820664e1d604..37011f1d1b362 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -1305,6 +1305,25 @@ def test_session_options_add_external_initializers(self): providers=["CPUExecutionProvider"], ) + def test_session_options_add_external_initializers_from_files_in_memory(self): + # Provide external initializer file content directly from memory + # The model references an external file named "Pads_not_on_disk.bin" for the initializer + pads_bytes = np.array([0, 0, 1, 1], dtype=np.int64).tobytes() + + so = onnxrt.SessionOptions() + so.add_external_initializers_from_files_in_memory( + ["Pads_not_on_disk.bin"], + [pads_bytes], + [len(pads_bytes)], + ) + + # This should not throw + onnxrt.InferenceSession( + get_name("model_with_external_initializer_come_from_user.onnx"), + sess_options=so, + providers=["CPUExecutionProvider"], + ) + def test_register_custom_ops_library(self): if sys.platform.startswith("win"): shared_library = os.path.abspath("custom_op_library.dll") diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index 403becbe0616a..0292111b16962 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -128,6 +128,148 @@ def quant_dequant(weights, is_4_bit_quantization: bool = True): # Calculate scale like C++ implementation abs_max = weights.abs().max(dim=-1, keepdim=True)[0] + + # Set minimum scale to avoid division by zero + scale = torch.clamp(abs_max, min=1e-6) + + # Quantization ranges for symmetric quantization + if is_4_bit_quantization: + qmin, qmax = -8, 7 + zero_point = 8 # Offset to make values unsigned + else: + qmin, qmax = -128, 127 + zero_point = 128 # Offset to make values unsigned + + # Quantize using double precision division and C-like rounding (half away from zero) + scaled = weights.double() / scale.double() + sign = torch.sign(scaled) + abs_scaled = torch.abs(scaled) + quant_rounded = torch.floor(abs_scaled + 0.5) + quantized = torch.clamp((sign * quant_rounded).to(torch.int32), qmin, qmax).to(weights.dtype) + + # Convert to unsigned and pack for storage + if is_4_bit_quantization: + # Convert to unsigned 4-bit and pack into uint8 + unsigned_quantized = (quantized + zero_point).to(torch.uint8) + + # Pack two 4-bit values into one uint8 + packed_size = (weights.shape[-1] + 1) // 2 + packed_quantized = torch.zeros((*weights.shape[:-1], packed_size), dtype=torch.uint8, device=weights.device) + + for i in range(0, weights.shape[-1], 2): + val1 = unsigned_quantized[..., i] + val2 = unsigned_quantized[..., i + 1] if i + 1 < weights.shape[-1] else torch.zeros_like(val1) + packed_quantized[..., i // 2] = (val1 & 0xF) | ((val2 & 0xF) << 4) + + quantized_storage = packed_quantized + else: + # 8-bit: convert to unsigned uint8 + quantized_storage = (quantized + zero_point).to(torch.uint8) + + # Dequantize for verification (use float32 scale for higher precision) + dequantized = quantized.to(torch.float32) * scale + + return scale.squeeze(-1).to(torch.float32), quantized_storage, dequantized + + +def quant_dequant_blockwise(weights, block_size, is_4_bit_quantization: bool = True): + """ + Block-wise quantization and dequantization for testing purposes. + This function uses symmetric quantization centered around 0 (no zero-point). + + Args: + weights: Input tensor of shape [rows, cols] + block_size: Size of each quantization block + is_4_bit_quantization: Whether to use 4-bit (True) or 8-bit (False) quantization + + Returns: + scales: Scale tensor of shape [rows, num_blocks] + quantized: Quantized tensor + dequantized: Dequantized tensor for verification + """ + rows, cols = weights.shape + num_blocks = (cols + block_size - 1) // block_size + + # Handle edge case of all-zero weights tensor + if torch.all(weights == 0): + scales = torch.zeros((rows, num_blocks), dtype=torch.float16, device=weights.device) + if is_4_bit_quantization: + packed_size = (cols + 1) // 2 + quantized = torch.zeros((rows, packed_size), dtype=torch.uint8, device=weights.device) + else: + quantized = torch.zeros((rows, cols), dtype=torch.uint8, device=weights.device) + dequantized = torch.zeros_like(weights) + return scales, quantized, dequantized + + # Initialize output tensors; use float32 for scales to reduce precision loss + scales = torch.zeros((rows, num_blocks), dtype=torch.float32, device=weights.device) + dequantized = torch.zeros_like(weights) + + # Quantization ranges and zero point + if is_4_bit_quantization: + qmin, qmax = -8, 7 + zero_point = 8 + packed_size = (cols + 1) // 2 + quantized = torch.zeros((rows, packed_size), dtype=torch.uint8, device=weights.device) + else: + qmin, qmax = -128, 127 + zero_point = 128 + quantized = torch.zeros((rows, cols), dtype=torch.uint8, device=weights.device) + + # Process each block with higher-precision math to match C++ behavior + for row in range(rows): + for block_idx in range(num_blocks): + start_col = block_idx * block_size + end_col = min(start_col + block_size, cols) + + # Get block data + block_data = weights[row, start_col:end_col] + + # Calculate absolute max and ensure small epsilon to avoid div-by-zero + abs_max = block_data.abs().max() + abs_max = torch.clamp(abs_max, min=1e-8) + + # Compute scale consistent with C++: use 7.0 for 4-bit positive max, 127.0 for 8-bit + if is_4_bit_quantization: + # Use higher precision then keep as float32 for scale + scale = (abs_max.double() / 7.0).float() + 1e-12 + else: + scale = (abs_max.double() / 127.0).float() + 1e-12 + + scales[row, block_idx] = scale.to(torch.float32) + + if scale == 0: + continue + + # Quantize using double precision for the division to reduce rounding error + scaled = block_data.double() / scale.double() + # Emulate C's round() behavior (round half away from zero) to match C++ implementation + sign = torch.sign(scaled) + abs_scaled = torch.abs(scaled) + quant_rounded = torch.floor(abs_scaled + 0.5) + quantized_block = (sign * quant_rounded).clamp(qmin, qmax).to(torch.int32) + + # Pack for 4-bit or store directly for 8-bit + if is_4_bit_quantization: + for i in range(0, end_col - start_col, 2): + col_idx = start_col + i + packed_idx = col_idx // 2 + + val1 = int(quantized_block[i]) + zero_point + val2 = int(quantized_block[i + 1]) + zero_point if i + 1 < len(quantized_block) else zero_point + + # Pack two 4-bit values into one uint8 + packed_val = (val1 & 0xF) | ((val2 & 0xF) << 4) + quantized[row, packed_idx] = packed_val + else: + quantized_vals = (quantized_block + zero_point).to(torch.uint8) + quantized[row, start_col:end_col] = quantized_vals + + # Dequantize for verification (signed quantized values multiplied by scale) + signed = quantized_block.to(torch.float32) + dequantized[row, start_col:end_col] = signed * scale + + return scales, quantized, dequantized abs_max = torch.clamp(abs_max, min=1e-8) # More conservative clamping for better precision if is_4_bit_quantization: @@ -247,6 +389,7 @@ def create_cpu_moe_onnx_graph( use_quant=False, quant_bits=4, swiglu_interleaved=False, + block_size=0, # New parameter for block-wise quantization ): if not has_onnx: return None @@ -264,14 +407,13 @@ def create_cpu_moe_onnx_graph( if not has_onnx: return None - if use_quant: - # Assertions only apply to quantized MoE - assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE" - assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE" - assert fc1_scales is not None, "FC1 scales must be provided for QMoE" - assert fc2_scales is not None, "FC2 scales must be provided for QMoE" - assert fc1_scales.dtype == torch.float16, "FC1 scales must be float16 for QMoE" - assert fc2_scales.dtype == torch.float16, "FC2 scales must be float16 for QMoE" + assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE" + assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE" + assert fc1_scales is not None, "FC1 scales must be provided for QMoE" + assert fc2_scales is not None, "FC2 scales must be provided for QMoE" + # Accept float16 or float32 scales; tests may produce float32 for better precision + assert fc1_scales.dtype in (torch.float16, torch.float32), "FC1 scales must be float16 or float32 for QMoE" + assert fc2_scales.dtype in (torch.float16, torch.float32), "FC2 scales must be float16 or float32 for QMoE" if not has_onnx: return None @@ -332,6 +474,10 @@ def create_cpu_moe_onnx_graph( if use_quant: nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) + # Add block_size attribute for block-wise quantization + if block_size > 0: + nodes[0].attribute.extend([helper.make_attribute("block_size", block_size)]) + # Weights are store in column major order. Need pack 2 int4 values into uint8. # Use the actual tensor shapes instead of calculating them to avoid size mismatches fc1_shape = list(fc1_experts_weights.shape) @@ -342,30 +488,59 @@ def create_cpu_moe_onnx_graph( weight_numpy_type = numpy.uint8 if use_quant else ort_to_numpy_type_map[onnx_dtype] weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype + # Use raw bytes from C-contiguous numpy arrays to ensure the exact memory layout + # of the packed uint8 weight tensors is preserved when writing the ONNX initializer. + fc1_np = fc1_experts_weights.detach().cpu().numpy().astype(weight_numpy_type) + fc2_np = fc2_experts_weights.detach().cpu().numpy().astype(weight_numpy_type) + fc1_np = numpy.ascontiguousarray(fc1_np) + fc2_np = numpy.ascontiguousarray(fc2_np) + initializers = [ helper.make_tensor( "fc1_experts_weights", weight_onnx_type, fc1_shape, - fc1_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), + fc1_np.tobytes(), + raw=True, ), helper.make_tensor( "fc2_experts_weights", weight_onnx_type, fc2_shape, - fc2_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), + fc2_np.tobytes(), + raw=True, ), ] - fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size] - fc2_scale_shape = [num_experts, hidden_size] + # Calculate scale tensor shapes based on block_size + if block_size > 0: + # Block-wise quantization: 3D scale tensors + fc1_blocks_per_row = (hidden_size + block_size - 1) // block_size + fc2_blocks_per_row = (inter_size + block_size - 1) // block_size - fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) - fc2_scale_size = num_experts * hidden_size + fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size, fc1_blocks_per_row] + fc2_scale_shape = [num_experts, hidden_size, fc2_blocks_per_row] - # Handle scale tensors based on quantization mode - if use_quant: - # Handle different possible scale tensor structures for fc1_scales + fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) * fc1_blocks_per_row + fc2_scale_size = num_experts * hidden_size * fc2_blocks_per_row + else: + # Row-wise quantization: 2D scale tensors + fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size] + fc2_scale_shape = [num_experts, hidden_size] + + fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) + fc2_scale_size = num_experts * hidden_size + + # Handle scale tensors - fc1_scales and fc2_scales are guaranteed to be not None due to earlier assertions + # Process scale tensors based on whether block-wise quantization is used + if block_size > 0: + # For block-wise quantization, the scales are already in the correct 3D shape + # [num_experts, output_features, num_blocks] from quant_dequant_blockwise + # Convert scales to the selected ONNX dtype (prefer float32 for higher precision) + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + else: + # For row-wise quantization, handle different possible scale tensor structures for fc1_scales if len(fc1_scales.shape) == 4: # 4D case: [num_experts, inter_size, hidden_size, 1] - extract first scale per expert per output if use_swiglu: @@ -395,10 +570,6 @@ def create_cpu_moe_onnx_graph( [fc1_scale_tensor, numpy.ones(pad_size, dtype=fc1_scale_tensor.dtype)] ) - # Process scale tensor for proper shape - fc1_scale_data_list = fc1_scale_tensor.tolist() - fc1_scale_data = fc1_scale_data_list - # Handle different possible scale tensor structures for fc2_scales if len(fc2_scales.shape) == 4: # 4D case: [num_experts, hidden_size, inter_size, 1] - extract first scale per expert per output @@ -421,48 +592,30 @@ def create_cpu_moe_onnx_graph( [fc2_scale_tensor, numpy.ones(pad_size, dtype=fc2_scale_tensor.dtype)] ) - # Process scale tensor for proper shape - fc2_scale_data_list = fc2_scale_tensor.tolist() - fc2_scale_data = fc2_scale_data_list - - initializers.extend( - [ - helper.make_tensor( - "fc1_scales", - onnx_dtype, - fc1_scale_shape, - fc1_scale_data, - raw=False, - ), - helper.make_tensor( - "fc2_scales", - onnx_dtype, - fc2_scale_shape, - fc2_scale_data, - raw=False, - ), - ] - ) - else: - # For non-quantized mode, add bias tensors if provided - if fc1_bias is not None: - initializers.append( - helper.make_tensor( - "fc1_experts_bias", - onnx_dtype, - list(fc1_bias.shape), - fc1_bias.flatten().detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]).tolist(), - ) - ) - if fc2_bias is not None: - initializers.append( - helper.make_tensor( - "fc2_experts_bias", - onnx_dtype, - list(fc2_bias.shape), - fc2_bias.flatten().detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]).tolist(), - ) - ) + # Process scale tensors for proper data format + fc1_scale_data_list = fc1_scale_tensor.tolist() + fc1_scale_data = fc1_scale_data_list + fc2_scale_data_list = fc2_scale_tensor.tolist() + fc2_scale_data = fc2_scale_data_list + + initializers.extend( + [ + helper.make_tensor( + "fc1_scales", + onnx_dtype, + fc1_scale_shape, + fc1_scale_data, + raw=False, + ), + helper.make_tensor( + "fc2_scales", + onnx_dtype, + fc2_scale_shape, + fc2_scale_data, + raw=False, + ), + ] + ) graph_inputs = [ helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), @@ -645,10 +798,7 @@ class SparseMoeBlockORTHelper(nn.Module): def __init__(self, quant_bits=0, onnx_dtype=None): super().__init__() self.quant_bits = quant_bits - if onnx_dtype is None: - self.onnx_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT - else: - self.onnx_dtype = onnx_dtype + self.onnx_dtype = onnx_dtype self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32 def create_ort_session(self, moe_onnx_graph): @@ -717,8 +867,8 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False tensors = { "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), - "router_probs": router_input.clone().to(device=device, dtype=torch_dtype), - "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype), + "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), + "output": torch.zeros((batch_size * sequence_length, hidden_dim), device=device, dtype=torch_dtype), } try: @@ -779,14 +929,47 @@ def recreate_onnx_model(self): is_4_bit = self.quant_bits == 4 for i in range(self.num_experts): - w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) - w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) + if self.block_size > 0: + # Use block-wise quantization + w1_scale, pre_qweight1, w1_qdq = quant_dequant_blockwise( + self.experts[i].w1.weight, self.block_size, is_4_bit + ) + w2_scale, pre_qweight2, w2_qdq = quant_dequant_blockwise( + self.experts[i].w2.weight, self.block_size, is_4_bit + ) + else: + # Use row-wise quantization + w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) + w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) if self.use_swiglu: - # For SwiGLU, CPU kernel now always expects interleaved format - # SwigluMlp weights are already in interleaved format [gate_0, linear_0, gate_1, linear_1, ...] - # No conversion needed - both CPU and CUDA use interleaved format - self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) + if self.swiglu_interleaved: + pass + else: + if self.block_size > 0: + w3_scale, pre_qweight3, w3_qdq = quant_dequant_blockwise( + self.experts[i].w3.weight, self.block_size, is_4_bit + ) + else: + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) + + gate_weights = pre_qweight1 + value_weights = pre_qweight3 + gate_scales = w1_scale + value_scales = w3_scale + + pre_qweight1 = torch.cat([gate_weights, value_weights], dim=0) + w1_scale = torch.cat([gate_scales, value_scales], dim=0) + + if self.swiglu_interleaved: + self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) + + else: + intermediate_size = self.experts[i].w1.weight.shape[0] + gate_dequant = w1_qdq[:intermediate_size].contiguous().clone() + value_dequant = w1_qdq[intermediate_size:].contiguous().clone() + self.experts[i].w1.weight.data = gate_dequant + self.experts[i].w3.weight.data = value_dequant else: self.experts[i].w1.weight.data = w1_qdq.contiguous().clone() @@ -828,7 +1011,8 @@ def recreate_onnx_model(self): use_swiglu=self.use_swiglu, use_quant=True, # Always use QMoE quant_bits=self.quant_bits, - swiglu_interleaved=True, # CPU kernel now always expects interleaved format + swiglu_interleaved=self.swiglu_interleaved if hasattr(self, "swiglu_interleaved") else False, + block_size=self.block_size, # Add block_size for block-wise quantization ) except Exception: self.moe_onnx_graph = None @@ -877,6 +1061,45 @@ def parity_check(self): print(f"Parity check - {act_type} {self.quant_bits}-bit: max_diff = {max_diff:.6f}") + # Diagnostic dump: when differences are large, show the index and nearby values + if max_diff > 1e-3: + diff = (torch_output.cpu() - ort_output.cpu()).abs() + idx = torch.argmax(diff) + flat_idx = int(idx) + # Derive coordinates (batch, seq, hidden) from flattened index + total_elems = torch_output.numel() + # Work in flattened [batch, seq, hidden] ordering + hidden_dim = self.hidden_dim + seq = self.sequence_length + # Clamp to safe bounds + flat_idx = min(flat_idx, total_elems - 1) + i = flat_idx // (hidden_dim) + j = i // seq + k = flat_idx % hidden_dim + print( + f"Diagnostic - max diff at flat_idx={flat_idx} -> sample (batch_idx={j}, seq_idx={i % seq}, hidden_idx={k})" + ) + print("Torch sample:", torch_output.cpu().reshape(-1, hidden_dim)[i, k].item()) + print("ORT sample:", ort_output.cpu().reshape(-1, hidden_dim)[i, k].item()) + # Print routing and per-expert contributions for this token from the PyTorch reference + try: + hidden_states_flat = hidden_state.view(-1, hidden_dim) + token_vec = hidden_states_flat[i : i + 1] + gate_logits = self.gate(token_vec) + topk_vals, topk_experts = torch.topk(gate_logits, self.top_k, dim=-1) + topk_soft = F.softmax(topk_vals, dim=1) + print("Gate logits:", gate_logits.detach().cpu().numpy()) + print("Selected experts:", topk_experts.detach().cpu().numpy()) + print("Routing weights:", topk_soft.detach().cpu().numpy()) + # Compute per-expert contributions for selected experts + for idx_e, e in enumerate(topk_experts[0].tolist()): + expert_layer = self.experts[e] + expert_out = expert_layer(token_vec) + contrib = expert_out[0, k].item() * topk_soft[0, idx_e].item() + print(f"Expert {e} contrib at hidden {k}: {contrib}") + except Exception as _: + pass + ort_dtype_quant_bits_tolerance_map = { "FP32:0": (5e-3, 1e-3), "FP16:0": (5e-2, 1e-3), @@ -917,7 +1140,13 @@ def small_test_cases(): class SwigluMoEBlock(SparseMoeBlockORTHelper): def __init__( - self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + self, + config: SwigluMoeConfig, + batch_size: int, + sequence_length: int, + quant_bits: int = 0, + onnx_dtype=None, + block_size: int = 0, ): super().__init__(quant_bits, onnx_dtype=onnx_dtype) self.hidden_dim = config.hidden_size @@ -926,6 +1155,7 @@ def __init__( self.top_k = config.num_experts_per_token self.use_swiglu = True self.swiglu_interleaved = True + self.block_size = block_size # Store block_size for QMoE use_quant = self.quant_bits > 0 self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) @@ -995,7 +1225,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): def __init__( - self, config: PhiMoEConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + self, + config: PhiMoEConfig, + batch_size: int, + sequence_length: int, + quant_bits: int = 0, + onnx_dtype=None, + block_size: int = 0, ): super().__init__(quant_bits, onnx_dtype=onnx_dtype) self.hidden_dim = config.hidden_size @@ -1005,6 +1241,7 @@ def __init__( self.router_jitter_noise = config.router_jitter_noise self.use_swiglu = True self.swiglu_interleaved = True + self.block_size = block_size # Store block_size for QMoE use_quant = self.quant_bits > 0 self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) @@ -1024,8 +1261,14 @@ def __init__( else: is_4_bit = self.quant_bits == 4 - scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) - scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) + if self.block_size > 0: + # Use block-wise quantization + scale1, pre_qweight1, w1_qdq = quant_dequant_blockwise(expert.w1.weight, self.block_size, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant_blockwise(expert.w2.weight, self.block_size, is_4_bit) + else: + # Use row-wise quantization + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) expert.w1.weight.data = w1_qdq expert.w2.weight.data = w2_qdq @@ -1064,6 +1307,7 @@ def __init__( use_quant=use_quant, quant_bits=self.quant_bits, swiglu_interleaved=self.swiglu_interleaved, + block_size=self.block_size, # Add block_size for block-wise quantization ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) if self.moe_onnx_graph else None @@ -1075,9 +1319,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states) - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # Match CPU implementation: select top-k experts by logits, then softmax over those logits + routing_weights_vals, selected_experts = torch.topk(router_logits, self.top_k, dim=-1) + routing_weights = F.softmax(routing_weights_vals, dim=1, dtype=torch.float) routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = torch.zeros( @@ -1112,6 +1356,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: (2, 16, 8), ] +# Define test cases for block-wise quantization +phi3_blockwise_test_cases = [ + (1, 32, 4, 32), # batch_size, sequence_length, quant_bits, block_size + (1, 32, 8, 64), + (2, 16, 4, 32), + (2, 16, 8, 64), +] + @unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestPhiQMoECPU(unittest.TestCase): @@ -1152,6 +1404,37 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): phi3_moe.parity_check() + @parameterized.expand(phi3_blockwise_test_cases) + def test_phi3_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + torch.manual_seed(42) + numpy.random.seed(42) + + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, block_size={block_size}" + print(f"Running Phi3 QMoE block-wise test: {test_config}") + + config = PhiMoEConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_tok=2) + + phi3_moe = PhiMoESparseMoeBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + block_size=block_size, # Enable block-wise quantization + ) + + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) + + torch_result = phi3_moe.forward(hidden_states) + + # Verify output shape and basic properties + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) + + phi3_moe.parity_check() + disable_cpu_qmoe_tests = False @@ -1162,6 +1445,14 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): (2, 16, 8), ] +# Define test cases for block-wise quantization +swiglu_blockwise_test_cases = [ + (1, 32, 4, 32), # batch_size, sequence_length, quant_bits, block_size + (1, 32, 8, 64), + (2, 16, 4, 32), + (2, 16, 8, 64), +] + @unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestSwigluQMoECPU(unittest.TestCase): @@ -1201,6 +1492,36 @@ def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): swiglu_moe.parity_check() + @parameterized.expand(swiglu_blockwise_test_cases) + def test_swiglu_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + torch.manual_seed(42) + numpy.random.seed(42) + + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, block_size={block_size}" + print(f"Running SwiGLU block-wise test: {test_config}") + + config = SwigluMoeConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_token=2) + + swiglu_moe = SwigluMoEBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + block_size=block_size, # Enable block-wise quantization + ) + + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) + + torch_result = swiglu_moe.forward(hidden_states) + + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) + + swiglu_moe.parity_check() + @unittest.skipIf(True, "Skipping QMoE CPU benchmark tests") class TestQMoESwiGLUBenchmark(unittest.TestCase): diff --git a/onnxruntime/test/quantization/quantization_test.cc b/onnxruntime/test/quantization/quantization_test.cc index 773f56de5361b..c2dd58a94c9dd 100644 --- a/onnxruntime/test/quantization/quantization_test.cc +++ b/onnxruntime/test/quantization/quantization_test.cc @@ -5,7 +5,7 @@ #include "core/framework/tensor.h" #include "core/quantization/quantization.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index b7a9da8e1b658..8c2928670934a 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -494,6 +494,35 @@ INSTANTIATE_TEST_SUITE_P(CApiTestWithProviders, CApiTestWithProvider, ::testing::Values(0, 1, 2, 3, 4)); +TEST(CApiTest, TestInputPassThroughToOutput) { + const ORTCHAR_T* model_uri = TSTR("testdata/input_propagated_to_output.onnx"); + Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{}); + auto inputs_meminfos = session.GetMemoryInfoForInputs(); + ASSERT_EQ(1U, inputs_meminfos.size()); + auto inputs_epdevices = session.GetEpDeviceForInputs(); + ASSERT_EQ(1U, inputs_epdevices.size()); + auto outputs_meminfos = session.GetMemoryInfoForOutputs(); + ASSERT_EQ(7U, outputs_meminfos.size()); +} + +TEST(CApiTest, TestDanglingInput) { + // Here we test an issue with segments_ids that is an input not consumed by anything + // This kind of model is unlikely to be used in practice but we want to make sure it works + const ORTCHAR_T* model_uri = TSTR("testdata/test_dangling_input_segment_ids.onnx"); + Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{}); + auto inputs_meminfos = session.GetMemoryInfoForInputs(); + ASSERT_EQ(2U, inputs_meminfos.size()); + auto outputs_meminfos = session.GetMemoryInfoForOutputs(); + ASSERT_EQ(2U, outputs_meminfos.size()); + auto inputs_epdevices = session.GetEpDeviceForInputs(); + ASSERT_EQ(2U, inputs_epdevices.size()); + // One of the devices returning is null since the input is not consumed + // there is not a device for it. + const bool null_present = std::any_of(inputs_epdevices.begin(), inputs_epdevices.end(), + [](const auto& device) { return device == nullptr; }); + ASSERT_TRUE(null_present); +} + #if !defined(DISABLE_SPARSE_TENSORS) TEST(CApiTest, SparseOutputModel) { std::vector dense_shape{3, 3}; @@ -505,7 +534,15 @@ TEST(CApiTest, SparseOutputModel) { std::vector ort_inputs; std::vector input_names; const char* const output_names[] = {"values"}; + // This model produces a sparse output from a constant sparse initializer Ort::Session session(*ort_env, SPARSE_OUTPUT_MODEL_URI, Ort::SessionOptions{}); + auto inputs_meminfos = session.GetMemoryInfoForInputs(); + ASSERT_TRUE(inputs_meminfos.empty()); + auto outputs_meminfos = session.GetMemoryInfoForOutputs(); + ASSERT_EQ(1U, outputs_meminfos.size()); + auto inputs_epdevices = session.GetEpDeviceForInputs(); + ASSERT_TRUE(inputs_epdevices.empty()); + auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), output_names, 1); ASSERT_EQ(ort_outputs.size(), 1U); diff --git a/onnxruntime/test/testdata/input_propagated_to_output.onnx b/onnxruntime/test/testdata/input_propagated_to_output.onnx new file mode 100644 index 0000000000000..feeab10556cb0 Binary files /dev/null and b/onnxruntime/test/testdata/input_propagated_to_output.onnx differ diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index c0c73695b6f8c..108b931e5737b 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -31,18 +31,27 @@ "current_failing_tests": [ "^test_adagrad", "^test_adagrad_multiple", - "^test_attention_4d_diff_heads_mask4d_padded_kv", // need nonpad_kv_seqlen - "^test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal", // attention op implementation is wrong - "^test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal", - "^test_attention_4d_attn_mask_3d_causal_expanded_cpu", // MacOS fails with const dimension <= num_dims was false. + "^test_attention_4d_diff_heads_mask4d_padded_kv*", // pending onnx update + "^test_attention_3d_gqa*", // pending onnx update + "^test_attention_3d_gqa_causal", // pending onnx update + "^test_attention_3d_gqa_scaled", // pending onnx update + "^test_attention_3d_gqa_softcap", // pending onnx update + "^test_attention_3d_gqa_with_past_and_present", // pending onnx update + "^test_attention_4d_gqa*", // pending onnx update + "^test_attention_4d_gqa_causal", // pending onnx update + "^test_attention_4d_gqa_scaled", // pending onnx update + "^test_attention_4d_gqa_softcap", // pending onnx update + "^test_attention_4d_gqa_with_past_and_present", // pending onnx update + "^test_attention_*causal*", // pending onnx update + "^test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal*", // pending onnx update + "^test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal*", // pending onnx update + "^test_attention_4d_attn_mask_3d_causal_expanded*", // pending onnx update + "^test_attention_4d_fp16*", // precision issue: 1 / 192 mismatched elements + "^test_attention_4d_fp16_expanded*", // precision issue: 3 / 192 mismatched elements "^test_l2normalization*", // LpNormalization(22) not implemented "^test_l1normalization*", // LpNormalization(22) not implemented "^test_lpnormalization*", // LpNormalization(22) not implemented "^test_tensorscatter*", // TensorScatter(24) not implemented - "^test_attention_4d_fp16*", // precision issue: 1 / 192 mismatched elements - "^test_attention_4d_fp16_expanded*", // precision issue: 3 / 192 mismatched elements - "^test_attention_4d_gqa_with_past_and_present_fp16*", // precision issue: 4 / 576 mismatched elements - "^test_attention_4d_gqa_with_past_and_present_fp16_expanded*", // precision issue: 37 / 576 mismatched elements "^test_castlike_no_saturate_FLOAT_to_FLOAT8*", // ORT does not support ml_dtypes "^test_castlike_UINT4_to*", // ORT does not support ml_dtypes "^test_castlike_INT4_to*", // ORT does not support ml_dtypes @@ -114,12 +123,10 @@ "^test_if_opt", "^test_loop16_seq_none", "^test_identity_opt", + // rotary dim should be fixed in onnx==1.19.1 "^test_rotary_embedding_no_position_ids_rotary_dim", "^test_rotary_embedding_with_interleaved_rotary_dim", "^test_rotary_embedding_with_rotary_dim", - "^test_rotary_embedding_3d_input_expanded", - "^test_rotary_embedding_interleaved_expanded", - "^test_rotary_embedding_no_position_ids_interleaved_expanded", "^test_rotary_embedding_expanded", //webgpu "^test_rotary_embedding_no_position_ids_expanded", //webgpu // Following tests are for opset 16 ops and are not yet implemented in ORT @@ -775,7 +782,8 @@ //TODO: Resolve as a graph implementation that returns a constant inf tensor with appropriate strides "^test_reduce_max_empty_set_cpu", // DNNL result in "(shapes (2, 1, 4), (1, 0, 1) mismatch)". this is the same for test_reduce_min_empty_set which is already in the list "^test_reduce_min_empty_set_cpu", - "^test_resize_upsample_sizes_nearest_not_smaller_cpu" + "^test_resize_upsample_sizes_nearest_not_smaller_cpu", + "^test_clip_min_greater_than_max_cpu" ], // ORT first supported opset 7, so models with nodes that require versions prior to opset 7 are not supported "tests_with_pre_opset7_dependencies": [ diff --git a/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx b/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx new file mode 100644 index 0000000000000..a83c21030ad67 Binary files /dev/null and b/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx differ diff --git a/onnxruntime/test/testdata/test_dangling_input_segment_ids.py b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py new file mode 100644 index 0000000000000..c5eb8a600d6b5 --- /dev/null +++ b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py @@ -0,0 +1,86 @@ +""" +Run this script to recreate the original onnx model. +Example usage: +python test_dangling_input_segment_ids.py out_model_path.onnx +""" + +import os +import sys + +import numpy as np +import onnx +from onnx import TensorProto, helper, numpy_helper + +DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_dangling_input_segment_ids") + + +def order_repeated_field(repeated_proto, key_name, order): + order = list(order) + repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name))) + + +def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs): + node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs) + if doc_string == "": + node.doc_string = "" + order_repeated_field(node.attribute, "name", kwargs.keys()) + return node + + +def make_graph(*args, doc_string=None, **kwargs): + graph = helper.make_graph(*args, doc_string=doc_string, **kwargs) + if doc_string == "": + graph.doc_string = "" + return graph + + +model = helper.make_model( + opset_imports=[helper.make_operatorsetid("", 14), helper.make_operatorsetid("com.microsoft", 1)], + ir_version=7, + graph=make_graph( + name="embed_layernorm_graph", + inputs=[ + helper.make_tensor_value_info("input_ids", TensorProto.INT32, shape=[1, 4]), + helper.make_tensor_value_info("segment_ids", TensorProto.INT32, shape=[1, 4]), + ], + outputs=[ + helper.make_tensor_value_info("layernorm_out", TensorProto.FLOAT, shape=[1, 4, 4]), + helper.make_tensor_value_info("mask_index_out", TensorProto.INT32, shape=[1]), + ], + initializer=[ + numpy_helper.from_array( + np.load(os.path.join(DATA_DIR, "const0_word_embed.npy")).astype("float32").reshape([32, 4]), + name="word_embed", + ), + numpy_helper.from_array( + np.load(os.path.join(DATA_DIR, "const1_pos_embed.npy")).astype("float32").reshape([16, 4]), + name="pos_embed", + ), + numpy_helper.from_array( + np.array( + [0.6185135841369629, 0.010364261455833912, 0.5386272668838501, 0.0030179566238075495], + dtype="float32", + ), + name="gamma", + ), + numpy_helper.from_array( + np.array( + [0.9511938095092773, 0.9054020047187805, 0.7959669232368469, 0.9152743220329285], dtype="float32" + ), + name="beta", + ), + ], + nodes=[ + make_node( + "EmbedLayerNormalization", + inputs=["input_ids", "", "word_embed", "pos_embed", "", "gamma", "beta"], + outputs=["layernorm_out", "mask_index_out"], + domain="com.microsoft", + ) + ], + ), +) + +if __name__ == "__main__" and len(sys.argv) == 2: + _, out_path = sys.argv + onnx.save(model, out_path) diff --git a/onnxruntime/test/unittest_main/test_main.cc b/onnxruntime/test/unittest_main/test_main.cc index 56c11039328bc..117a26d48efe9 100644 --- a/onnxruntime/test/unittest_main/test_main.cc +++ b/onnxruntime/test/unittest_main/test_main.cc @@ -24,38 +24,86 @@ #include "gtest/gtest.h" +#include "core/common/common.h" #include "core/platform/env_var_utils.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/util/thread_utils.h" -#include "test/test_environment.h" + +#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_UNIT_TEST_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) +#define TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE +#endif // !defined(ORT_MINIMAL_BUILD) && defined(ORT_UNIT_TEST_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) + +#if defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) +#include "test/unittest_util/test_dynamic_plugin_ep.h" +#endif // defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) std::unique_ptr ort_env; +// define environment variable name constants here +namespace env_var_names { +// Set ORT log level to the specified numeric log level. +constexpr const char* kLogLevel = "ORT_UNIT_TEST_MAIN_LOG_LEVEL"; + +#if defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) +// Specify dynamic plugin EP configuration JSON. +// Refer to `onnxruntime::test::dynamic_plugin_ep_infra::ParseInitializationConfig()` for more information. +constexpr const char* kDynamicPluginEpConfigJson = "ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON"; +#endif // defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) +} // namespace env_var_names + // ortenv_setup() and ortenv_teardown() are used by onnxruntime/test/xctest/xcgtest.mm so can't be file local extern "C" void ortenv_setup() { + ORT_TRY { #ifdef _WIN32 - // Set the locale to UTF-8 to ensure proper handling of wide characters on Windows - std::wclog.imbue(std::locale(".UTF-8", std::locale::ctype)); + // Set the locale to UTF-8 to ensure proper handling of wide characters on Windows + std::wclog.imbue(std::locale(".UTF-8", std::locale::ctype)); #endif - OrtThreadingOptions tpo; - - // allow verbose logging to be enabled by setting this environment variable to a numeric log level - constexpr auto kLogLevelEnvironmentVariableName = "ORT_UNIT_TEST_MAIN_LOG_LEVEL"; - OrtLoggingLevel log_level = ORT_LOGGING_LEVEL_WARNING; - if (auto log_level_override = onnxruntime::ParseEnvironmentVariable(kLogLevelEnvironmentVariableName); - log_level_override.has_value()) { - *log_level_override = std::clamp(*log_level_override, - static_cast(ORT_LOGGING_LEVEL_VERBOSE), - static_cast(ORT_LOGGING_LEVEL_FATAL)); - std::cout << "Setting log level to " << *log_level_override << "\n"; - log_level = static_cast(*log_level_override); + OrtThreadingOptions tpo; + + OrtLoggingLevel log_level = ORT_LOGGING_LEVEL_WARNING; + if (auto log_level_override = onnxruntime::ParseEnvironmentVariable(env_var_names::kLogLevel); + log_level_override.has_value()) { + *log_level_override = std::clamp(*log_level_override, + static_cast(ORT_LOGGING_LEVEL_VERBOSE), + static_cast(ORT_LOGGING_LEVEL_FATAL)); + std::cout << "Setting log level to " << *log_level_override << "\n"; + log_level = static_cast(*log_level_override); + } + + ort_env.reset(new Ort::Env(&tpo, log_level, "Default")); + +#if defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) + { + namespace dynamic_plugin_ep_infra = onnxruntime::test::dynamic_plugin_ep_infra; + if (auto dynamic_plugin_ep_config_json = onnxruntime::ParseEnvironmentVariable( + env_var_names::kDynamicPluginEpConfigJson); + dynamic_plugin_ep_config_json.has_value()) { + std::cout << "Initializing dynamic plugin EP infrastructure with configuration:\n" + << *dynamic_plugin_ep_config_json << "\n"; + dynamic_plugin_ep_infra::InitializationConfig config{}; + ORT_THROW_IF_ERROR(dynamic_plugin_ep_infra::ParseInitializationConfig(*dynamic_plugin_ep_config_json, config)); + ORT_THROW_IF_ERROR(dynamic_plugin_ep_infra::Initialize(*ort_env, config)); + } + } +#endif // defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + std::cerr << ex.what(); + std::exit(1); + }); } - - ort_env.reset(new Ort::Env(&tpo, log_level, "Default")); } extern "C" void ortenv_teardown() { +#if defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) + { + namespace dynamic_plugin_ep_infra = onnxruntime::test::dynamic_plugin_ep_infra; + dynamic_plugin_ep_infra::Shutdown(); + } +#endif // defined(TEST_MAIN_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) + ort_env.reset(); } diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/unittest_util/base_tester.cc similarity index 89% rename from onnxruntime/test/providers/base_tester.cc rename to onnxruntime/test/unittest_util/base_tester.cc index 4b37b6c9438aa..4d640e0f5e33d 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/unittest_util/base_tester.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "test/providers/base_tester.h" +#include "test/unittest_util/base_tester.h" #include #include "gmock/gmock.h" @@ -13,12 +13,13 @@ #include "core/session/inference_session.h" #include "core/session/onnxruntime_session_options_config_keys.h" -#include "test/framework/TestAllocatorManager.h" -#include "test/providers/run_options_config_keys.h" +#include "test/unittest_util/run_options_config_keys.h" +#include "test/unittest_util/test_allocator_manager.h" +#include "test/unittest_util/test_dynamic_plugin_ep.h" #include "test/util/include/asserts.h" #include "test/util/include/default_providers.h" -#include "test/util/include/test_utils.h" #include "test/util/include/test_environment.h" +#include "test/util/include/test_utils.h" #ifdef ENABLE_TRAINING #include "orttraining/core/session/training_session.h" @@ -39,6 +40,7 @@ void DebugTrap() { #endif } #endif + } // namespace BaseTester::~BaseTester() { @@ -421,7 +423,8 @@ void BaseTester::ExecuteModel(Model& model, SessionType& session, bool SetEpsForAllNodes(Graph& graph, const std::vector>& execution_providers, - const std::vector>* custom_registries) { + const std::vector>* custom_registries, + const std::function& ep_uses_kernel_registry_fn) { const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{}; const KernelRegistry::TypeConstraintMap type_constraint_map{}; @@ -436,39 +439,36 @@ bool SetEpsForAllNodes(Graph& graph, auto provider_type = ep->Type(); node.SetExecutionProviderType(provider_type); - if (provider_type == onnxruntime::kOpenVINOExecutionProvider || - provider_type == onnxruntime::kTensorrtExecutionProvider || - provider_type == onnxruntime::kNnapiExecutionProvider || - provider_type == onnxruntime::kVSINPUExecutionProvider || - provider_type == onnxruntime::kCoreMLExecutionProvider || - provider_type == onnxruntime::kDnnlExecutionProvider || - provider_type == onnxruntime::kQnnExecutionProvider || - provider_type == onnxruntime::kSnpeExecutionProvider) { - found = true; - break; - } - // Check the EP has an impl for the node from builtin registry. - if (KernelRegistry::HasImplementationOf(*ep->GetKernelRegistry(), node, ep->Type(), kernel_type_str_resolver, - logger)) { + if (!ep_uses_kernel_registry_fn(*ep)) { found = true; break; } - // check the internal NHWC domain if EP requests NHWC as it may only have a kernel registered in that domain - if (ep->GetPreferredLayout() == DataLayout::NHWC) { - const KernelCreateInfo* kci = nullptr; - auto status = ep->GetKernelRegistry()->TryFindKernel(ep->Type(), - std::string_view(node.OpType()), - std::string_view(kMSInternalNHWCDomain), - node.SinceVersion(), - type_constraint_map, - logger, - &kci); - if (status.IsOK() && kci != nullptr) { + if (std::shared_ptr ep_kernel_registry = ep->GetKernelRegistry(); + ep_kernel_registry != nullptr) { + // Check the EP has an impl for the node from builtin registry. + if (KernelRegistry::HasImplementationOf(*ep_kernel_registry, node, ep->Type(), kernel_type_str_resolver, + logger)) { found = true; break; } + + // check the internal NHWC domain if EP requests NHWC as it may only have a kernel registered in that domain + if (ep->GetPreferredLayout() == DataLayout::NHWC) { + const KernelCreateInfo* kci = nullptr; + auto status = ep_kernel_registry->TryFindKernel(ep->Type(), + std::string_view(node.OpType()), + std::string_view(kMSInternalNHWCDomain), + node.SinceVersion(), + type_constraint_map, + logger, + &kci); + if (status.IsOK() && kci != nullptr) { + found = true; + break; + } + } } // Check the EP has an impl for the node from custom_registries @@ -593,9 +593,9 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, fetches_.clear(); // IsAllowReleasedONNXOpsetsOnlySet() checks for the appropriate env var in the process (i.e.) process-wide - // `IsAllowReleasedONNXOpsetsOnlySetForThisTest()` is for this specific OpTester instance - // We will only support released opsets iff IsAllowReleasedONNXOpsetsOnlySet() and `IsAllowReleasedONNXOpsetsOnlySetForThisTest()` - // are both true + // `test_allow_released_onnx_opset_only_` is for this specific OpTester instance + // We will only support released opsets iff IsAllowReleasedONNXOpsetsOnlySet() and + // `test_allow_released_onnx_opset_only_` are both true auto allow_released_onnx_opset_only = test_allow_released_onnx_opset_only_ && model_load_utils::IsAllowReleasedONNXOpsetsOnlySet(); @@ -647,11 +647,11 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, #ifdef USE_TENSORRT // only run trt ep to reduce test time - static const std::string all_provider_types[] = { + static const std::vector all_provider_types = { kTensorrtExecutionProvider, }; #else - static const std::string all_provider_types[] = { + static const std::vector all_provider_types = { kCpuExecutionProvider, kCudaExecutionProvider, #ifdef ENABLE_CUDA_NHWC_OPS @@ -681,9 +681,25 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, } #endif + const auto dynamic_plugin_ep_name = dynamic_plugin_ep_infra::GetEpName(); + + std::optional> provider_types_including_dynamic_plugin_ep{}; + if (dynamic_plugin_ep_name.has_value()) { + ORT_ENFORCE(std::find(all_provider_types.begin(), all_provider_types.end(), + *dynamic_plugin_ep_name) == all_provider_types.end(), + "Dynamic plugin EP name conflicts with a known EP name: ", *dynamic_plugin_ep_name); + provider_types_including_dynamic_plugin_ep = all_provider_types; + provider_types_including_dynamic_plugin_ep->push_back(*dynamic_plugin_ep_name); + } + + const auto all_provider_types_span = + provider_types_including_dynamic_plugin_ep.has_value() + ? gsl::span{*provider_types_including_dynamic_plugin_ep} + : gsl::span{all_provider_types}; + bool has_run = false; - for (const std::string& provider_type : all_provider_types) { + for (const std::string& provider_type : all_provider_types_span) { if (ctx_.excluded_provider_types.count(provider_type) > 0) continue; @@ -732,6 +748,9 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, execution_provider = DefaultDmlExecutionProvider(); else if (provider_type == onnxruntime::kWebGpuExecutionProvider) execution_provider = DefaultWebGpuExecutionProvider(); + else if (provider_type == dynamic_plugin_ep_name) { + execution_provider = dynamic_plugin_ep_infra::MakeEp(); + } // skip if execution provider is disabled if (execution_provider == nullptr) @@ -830,13 +849,46 @@ void BaseTester::ExecuteModelForEps( } ASSERT_TRUE(!execution_providers.empty()) << "Empty execution providers vector."; - if (try_assign_ep_for_nodes && !SetEpsForAllNodes(model.MainGraph(), execution_providers, custom_registries)) { - std::string providers; - for (const auto& ep : execution_providers) { - providers.append(ep->Type() + " "); + if (try_assign_ep_for_nodes) { + auto ep_uses_kernel_registry = [](const IExecutionProvider& ep) { + const auto& provider_type = ep.Type(); + + constexpr std::array kEpsThatDoNotUseKernelRegistry{ + kOpenVINOExecutionProvider, + kTensorrtExecutionProvider, + kNnapiExecutionProvider, + kVSINPUExecutionProvider, + kCoreMLExecutionProvider, + kDnnlExecutionProvider, + kQnnExecutionProvider, + kSnpeExecutionProvider, + }; + + // check list of known EPs that do not use a kernel registry + if (const auto ep_it = std::find(kEpsThatDoNotUseKernelRegistry.begin(), kEpsThatDoNotUseKernelRegistry.end(), + provider_type); + ep_it != kEpsThatDoNotUseKernelRegistry.end()) { + return false; + } + + // assume that a dynamic plugin EP which does not return a kernel registry does not use one + if (provider_type == dynamic_plugin_ep_infra::GetEpName() && + ep.GetKernelRegistry() == nullptr) { + return false; + } + + // otherwise, assume that the EP uses a kernel registry + return true; + }; + + if (!SetEpsForAllNodes(model.MainGraph(), execution_providers, custom_registries, ep_uses_kernel_registry)) { + std::string providers; + for (const auto& ep : execution_providers) { + providers.append(ep->Type() + " "); + } + LOGS_DEFAULT(WARNING) << "registered execution providers " << providers << " were unable to run the model."; + return; } - LOGS_DEFAULT(WARNING) << "registered execution providers " << providers << " were unable to run the model."; - return; } std::string provider_type; @@ -875,7 +927,7 @@ void BaseTester::ExecuteModelForEps( *number_of_shared_pre_packed_weights_counter = session_object.GetSessionState().GetUsedSharedPrePackedWeightCounter(); } -}; +} void BaseTester::AddReferenceOutputs(const std::string& model_path, float abs_error, std::unique_ptr ep) { diff --git a/onnxruntime/test/providers/base_tester.h b/onnxruntime/test/unittest_util/base_tester.h similarity index 99% rename from onnxruntime/test/providers/base_tester.h rename to onnxruntime/test/unittest_util/base_tester.h index b55a43c92637c..58b67a0d67d3c 100644 --- a/onnxruntime/test/providers/base_tester.h +++ b/onnxruntime/test/unittest_util/base_tester.h @@ -15,11 +15,12 @@ #include "core/framework/run_options.h" #include "core/framework/tensor.h" #include "core/framework/TensorSeq.h" +#include "core/graph/graph.h" #include "core/graph/model.h" -#include "test/framework/TestAllocatorManager.h" -#include "test/providers/checkers.h" -#include "test/providers/tester_types.h" +#include "test/unittest_util/checkers.h" +#include "test/unittest_util/test_allocator_manager.h" +#include "test/unittest_util/tester_types.h" namespace onnxruntime { class InferenceSession; diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/unittest_util/checkers.cc similarity index 96% rename from onnxruntime/test/providers/checkers.cc rename to onnxruntime/test/unittest_util/checkers.cc index 19ac5e06cddd8..7b2a5a4a4ff2f 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/unittest_util/checkers.cc @@ -1,17 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "test/providers/checkers.h" +#include "test/unittest_util/checkers.h" #include "gtest/gtest.h" +#include "core/common/narrow.h" #include "core/graph/constants.h" #include "core/framework/TensorSeq.h" #include "core/framework/int4.h" #include "core/framework/float4.h" - -#include "test/framework/test_utils.h" -#include "test/providers/provider_test_utils.h" +#include "test/unittest_util/framework_test_utils.h" +#include "test/unittest_util/conversion.h" namespace onnxruntime { namespace test { @@ -192,7 +192,7 @@ struct TensorCheck { ORT_THROW("Shape mismatch"); } - const auto size = actual.Shape().Size(); + const auto size = narrow(actual.Shape().Size()); const Float4E2M1x2* expected_data = expected.Data(); const Float4E2M1x2* actual_data = actual.Data(); @@ -201,7 +201,7 @@ struct TensorCheck { // For now, using float tolerance is fine auto tolerance_params = get_tolerance_params(params, provider_type); - for (int64_t i = 0; i < size; ++i) { + for (size_t i = 0; i < size; ++i) { size_t r = i >> 1; size_t c = i & 0x1; @@ -228,11 +228,11 @@ struct TensorCheck { ORT_UNUSED_PARAMETER(params); const Int4x2* cur_expected; const Int4x2* cur_actual; - const auto size = actual.Shape().Size(); + const auto size = narrow(actual.Shape().Size()); cur_expected = expected.Data(); cur_actual = actual.Data(); - for (size_t i = 0; i < static_cast(size); ++i) { + for (size_t i = 0; i < size; ++i) { size_t r = i >> 1; size_t c = i & 0x1; EXPECT_EQ(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c)) << "i:" << i; @@ -247,11 +247,11 @@ struct TensorCheck { ORT_UNUSED_PARAMETER(params); const UInt4x2* cur_expected; const UInt4x2* cur_actual; - const auto size = actual.Shape().Size(); + const auto size = narrow(actual.Shape().Size()); cur_expected = expected.Data(); cur_actual = actual.Data(); - for (size_t i = 0; i < static_cast(size); ++i) { + for (size_t i = 0; i < size; ++i) { size_t r = i >> 1; size_t c = i & 0x1; EXPECT_EQ(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c)) << "i:" << i; @@ -452,12 +452,12 @@ struct TensorCheck { const std::string& provider_type) const { auto* cur_expected = expected.Data(); auto* cur_actual = actual.Data(); - auto size = actual.Shape().Size(); + auto size = narrow(actual.Shape().Size()); std::vector f_expected(size); std::vector f_actual(size); - ConvertMLFloat16ToFloat(cur_expected, f_expected.data(), static_cast(size)); - ConvertMLFloat16ToFloat(cur_actual, f_actual.data(), static_cast(size)); + ConvertMLFloat16ToFloat(cur_expected, f_expected.data(), size); + ConvertMLFloat16ToFloat(cur_actual, f_actual.data(), size); // deal with rare cases in which order of output data from a kernel MAY be // undefined @@ -467,7 +467,7 @@ struct TensorCheck { auto tolerance_params = get_tolerance_params(params, provider_type); - for (int64_t i = 0; i < size; ++i) { + for (size_t i = 0; i < size; ++i) { if (std::isnan(f_expected[i])) { EXPECT_TRUE(std::isnan(f_actual[i])) << "Expected NaN. i:" << i; } else if (std::isinf(f_expected[i])) { // Test infinity for equality @@ -488,12 +488,12 @@ struct TensorCheck { const std::string& provider_type) const { auto* cur_expected = expected.Data(); auto* cur_actual = actual.Data(); - auto size = actual.Shape().Size(); + auto size = narrow(actual.Shape().Size()); std::vector f_expected(size); std::vector f_actual(size); - BFloat16ToFloat(cur_expected, f_expected.data(), static_cast(size)); - BFloat16ToFloat(cur_actual, f_actual.data(), static_cast(size)); + BFloat16ToFloat(cur_expected, f_expected.data(), size); + BFloat16ToFloat(cur_actual, f_actual.data(), size); // deal with rare cases in which order of output data from a kernel MAY be // undefined @@ -503,7 +503,7 @@ struct TensorCheck { auto tolerance_params = get_tolerance_params(params, provider_type); - for (int64_t i = 0; i < size; ++i) { + for (size_t i = 0; i < size; ++i) { if (std::isnan(f_expected[i])) { EXPECT_TRUE(std::isnan(f_expected[i])) << "Expected NaN. i:" << i; } else if (std::isinf(f_expected[i])) { // Test infinity for equality diff --git a/onnxruntime/test/providers/checkers.h b/onnxruntime/test/unittest_util/checkers.h similarity index 100% rename from onnxruntime/test/providers/checkers.h rename to onnxruntime/test/unittest_util/checkers.h diff --git a/onnxruntime/test/unittest_util/conversion.h b/onnxruntime/test/unittest_util/conversion.h new file mode 100644 index 0000000000000..1f8124399d948 --- /dev/null +++ b/onnxruntime/test/unittest_util/conversion.h @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/framework/float16.h" +#include "core/util/math_cpuonly.h" + +namespace onnxruntime { +namespace test { +inline void ConvertFloatToMLFloat16(const float* f_datat, MLFloat16* h_data, size_t input_size) { + auto in_vector = ConstEigenVectorMap(f_datat, input_size); + auto output_vector = EigenVectorMap(static_cast(static_cast(h_data)), input_size); + output_vector = in_vector.template cast(); +} + +inline void ConvertFloatToUint8_t(const float* f_datat, uint8_t* u8_data, size_t input_size) { + auto in_vector = ConstEigenVectorMap(f_datat, input_size); + auto output_vector = EigenVectorMap(static_cast(static_cast(u8_data)), input_size); + output_vector = in_vector.template cast(); +} + +inline void ConvertMLFloat16ToFloat(const MLFloat16* h_data, float* f_data, size_t input_size) { + auto in_vector = + ConstEigenVectorMap(static_cast(static_cast(h_data)), input_size); + auto output_vector = EigenVectorMap(f_data, input_size); + output_vector = in_vector.template cast(); +} + +inline std::vector FloatsToMLFloat16s(const std::vector& f) { + std::vector m(f.size()); + ConvertFloatToMLFloat16(f.data(), m.data(), f.size()); + return m; +} + +inline std::vector MakeBFloat16(const std::initializer_list& input) { + std::vector output; + std::transform(input.begin(), input.end(), std::back_inserter(output), [](float f) { return BFloat16(f); }); + return output; +} + +inline std::vector FloatsToBFloat16s(const std::vector& input) { + std::vector output; + std::transform(input.begin(), input.end(), std::back_inserter(output), [](float f) { return BFloat16(f); }); + return output; +} +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/framework/dummy_allocator.cc b/onnxruntime/test/unittest_util/dummy_allocator.cc similarity index 100% rename from onnxruntime/test/framework/dummy_allocator.cc rename to onnxruntime/test/unittest_util/dummy_allocator.cc diff --git a/onnxruntime/test/framework/dummy_allocator.h b/onnxruntime/test/unittest_util/dummy_allocator.h similarity index 100% rename from onnxruntime/test/framework/dummy_allocator.h rename to onnxruntime/test/unittest_util/dummy_allocator.h diff --git a/onnxruntime/test/framework/test_utils.cc b/onnxruntime/test/unittest_util/framework_test_utils.cc similarity index 95% rename from onnxruntime/test/framework/test_utils.cc rename to onnxruntime/test/unittest_util/framework_test_utils.cc index 310a673008efb..70cce100f5b5a 100644 --- a/onnxruntime/test/framework/test_utils.cc +++ b/onnxruntime/test/unittest_util/framework_test_utils.cc @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "core/graph/graph.h" namespace onnxruntime { diff --git a/onnxruntime/test/framework/test_utils.h b/onnxruntime/test/unittest_util/framework_test_utils.h similarity index 100% rename from onnxruntime/test/framework/test_utils.h rename to onnxruntime/test/unittest_util/framework_test_utils.h diff --git a/onnxruntime/test/contrib_ops/function_test_util.cc b/onnxruntime/test/unittest_util/function_test_util.cc similarity index 97% rename from onnxruntime/test/contrib_ops/function_test_util.cc rename to onnxruntime/test/unittest_util/function_test_util.cc index c0c1437f3a63b..5edcd1db4541b 100644 --- a/onnxruntime/test/contrib_ops/function_test_util.cc +++ b/onnxruntime/test/unittest_util/function_test_util.cc @@ -1,19 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "test/unittest_util/function_test_util.h" + #include #include #include #include -#include "function_test_util.h" - #include "gtest/gtest.h" #include "test/test_environment.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/common/tensor_op_test_utils.h" -#include "asserts.h" +#include "test/util/include/asserts.h" namespace onnxruntime { namespace test { diff --git a/onnxruntime/test/contrib_ops/function_test_util.h b/onnxruntime/test/unittest_util/function_test_util.h similarity index 92% rename from onnxruntime/test/contrib_ops/function_test_util.h rename to onnxruntime/test/unittest_util/function_test_util.h index b23288fdfdc29..8024ebbed7f25 100644 --- a/onnxruntime/test/contrib_ops/function_test_util.h +++ b/onnxruntime/test/unittest_util/function_test_util.h @@ -11,8 +11,9 @@ #include "core/session/inference_session.h" #include "core/framework/float16.h" +#include "test/common/random_generator.h" #include "test/common/tensor_op_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" namespace onnxruntime { namespace test { @@ -25,24 +26,20 @@ inline std::vector random(std::vector shape) { template <> inline std::vector random(std::vector shape) { - int64_t size = 1; - for (auto dim : shape) - size *= dim; + const auto size = detail::SizeFromDims(shape); std::vector data(size); - for (int64_t i = 0; i < size; i++) + for (size_t i = 0; i < size; i++) data[i] = static_cast(rand()); return data; } template <> inline std::vector random(std::vector shape) { - int64_t size = 1; - for (auto dim : shape) - size *= dim; + const auto size = detail::SizeFromDims(shape); std::vector data(size); - for (int64_t i = 0; i < size; i++) + for (size_t i = 0; i < size; i++) data[i] = bool(rand() % 2); return data; } @@ -51,7 +48,7 @@ template <> inline std::vector random(std::vector shape) { auto floatdata = random(shape); std::vector data(floatdata.size()); - for (uint64_t i = 0; i < floatdata.size(); i++) + for (size_t i = 0; i < floatdata.size(); i++) data[i] = BFloat16(floatdata[i]); return data; } @@ -60,7 +57,7 @@ template <> inline std::vector random(std::vector shape) { auto floatdata = random(shape); std::vector data(floatdata.size()); - for (uint64_t i = 0; i < floatdata.size(); i++) + for (size_t i = 0; i < floatdata.size(); i++) data[i] = MLFloat16(floatdata[i]); return data; } diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.cc b/onnxruntime/test/unittest_util/graph_transform_test_builder.cc similarity index 99% rename from onnxruntime/test/optimizer/graph_transform_test_builder.cc rename to onnxruntime/test/unittest_util/graph_transform_test_builder.cc index 756cc4159e6f2..5caafa0f379d4 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.cc +++ b/onnxruntime/test/unittest_util/graph_transform_test_builder.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "test/optimizer/graph_transform_test_builder.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include #include diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/unittest_util/graph_transform_test_builder.h similarity index 99% rename from onnxruntime/test/optimizer/graph_transform_test_builder.h rename to onnxruntime/test/unittest_util/graph_transform_test_builder.h index 26df588eab73f..bc3c9535a853a 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/unittest_util/graph_transform_test_builder.h @@ -14,9 +14,9 @@ #include "core/optimizer/graph_transformer_level.h" #include "core/graph/onnx_protobuf.h" #include "core/framework/tensorprotoutils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/common/tensor_op_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/util/include/inference_session_wrapper.h" #define TEST_RETURN_IF(condition) \ diff --git a/onnxruntime/test/providers/model_tester.h b/onnxruntime/test/unittest_util/model_tester.h similarity index 97% rename from onnxruntime/test/providers/model_tester.h rename to onnxruntime/test/unittest_util/model_tester.h index 1bcab58c80f37..d1fc22e23e6a1 100644 --- a/onnxruntime/test/providers/model_tester.h +++ b/onnxruntime/test/unittest_util/model_tester.h @@ -7,7 +7,7 @@ #include "core/graph/model.h" #include "core/session/environment.h" -#include "test/providers/base_tester.h" +#include "test/unittest_util/base_tester.h" #include "test/util/include/asserts.h" #include "test/util/include/test_environment.h" diff --git a/onnxruntime/test/providers/op_tester.cc b/onnxruntime/test/unittest_util/op_tester.cc similarity index 98% rename from onnxruntime/test/providers/op_tester.cc rename to onnxruntime/test/unittest_util/op_tester.cc index b2eaaac192b91..b633989866277 100644 --- a/onnxruntime/test/providers/op_tester.cc +++ b/onnxruntime/test/unittest_util/op_tester.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "test/providers/op_tester.h" +#include "test/unittest_util/op_tester.h" #include "gtest/gtest.h" #include "gmock/gmock.h" diff --git a/onnxruntime/test/providers/op_tester.h b/onnxruntime/test/unittest_util/op_tester.h similarity index 98% rename from onnxruntime/test/providers/op_tester.h rename to onnxruntime/test/unittest_util/op_tester.h index f57cc43149e4e..2d986b2d428bf 100644 --- a/onnxruntime/test/providers/op_tester.h +++ b/onnxruntime/test/unittest_util/op_tester.h @@ -4,7 +4,7 @@ #pragma once #include "core/graph/constants.h" -#include "test/providers/base_tester.h" +#include "test/unittest_util/base_tester.h" namespace onnxruntime { class InferenceSession; diff --git a/onnxruntime/test/optimizer/qdq_test_utils.cc b/onnxruntime/test/unittest_util/qdq_test_utils.cc similarity index 100% rename from onnxruntime/test/optimizer/qdq_test_utils.cc rename to onnxruntime/test/unittest_util/qdq_test_utils.cc diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/unittest_util/qdq_test_utils.h similarity index 100% rename from onnxruntime/test/optimizer/qdq_test_utils.h rename to onnxruntime/test/unittest_util/qdq_test_utils.h diff --git a/onnxruntime/test/providers/run_options_config_keys.h b/onnxruntime/test/unittest_util/run_options_config_keys.h similarity index 100% rename from onnxruntime/test/providers/run_options_config_keys.h rename to onnxruntime/test/unittest_util/run_options_config_keys.h diff --git a/onnxruntime/test/framework/TestAllocatorManager.cc b/onnxruntime/test/unittest_util/test_allocator_manager.cc similarity index 95% rename from onnxruntime/test/framework/TestAllocatorManager.cc rename to onnxruntime/test/unittest_util/test_allocator_manager.cc index 6440a805cdc59..561272803db9d 100644 --- a/onnxruntime/test/framework/TestAllocatorManager.cc +++ b/onnxruntime/test/unittest_util/test_allocator_manager.cc @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "test/framework/TestAllocatorManager.h" +#include "test/unittest_util/test_allocator_manager.h" namespace onnxruntime { namespace test { -// Dummy Arena which just call underline device allocator directly. +// Dummy Arena which just call underlying device allocator directly. class DummyArena : public IAllocator { public: explicit DummyArena(std::unique_ptr resource_allocator) diff --git a/onnxruntime/test/framework/TestAllocatorManager.h b/onnxruntime/test/unittest_util/test_allocator_manager.h similarity index 100% rename from onnxruntime/test/framework/TestAllocatorManager.h rename to onnxruntime/test/unittest_util/test_allocator_manager.h diff --git a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc new file mode 100644 index 0000000000000..6ac741fb616a8 --- /dev/null +++ b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc @@ -0,0 +1,201 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test/unittest_util/test_dynamic_plugin_ep.h" + +#include +#include +#include +#include + +#include "nlohmann/json.hpp" + +#include "core/common/common.h" +#include "core/framework/execution_provider.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/ort_env.h" +#include "core/session/utils.h" +#include "test/util/include/test_environment.h" + +namespace onnxruntime::test::dynamic_plugin_ep_infra { + +namespace { + +using PluginEpLibraryRegistrationHandle = std::unique_ptr>; + +PluginEpLibraryRegistrationHandle RegisterPluginEpLibrary(Ort::Env& env, + const std::string& ep_library_registration_name, + const std::basic_string& ep_library_path) { + env.RegisterExecutionProviderLibrary(ep_library_registration_name.c_str(), ep_library_path); + + auto unregister_ep_library = [&env, registration_name = ep_library_registration_name](void* p) { + if (p == nullptr) { + return; + } + + ORT_TRY { + env.UnregisterExecutionProviderLibrary(registration_name.c_str()); + } + ORT_CATCH(const Ort::Exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + std::cerr << "Failed to unregister EP library with name '" << registration_name << "': " << e.what() << "\n"; + }); + } + }; + + // Set `handle_value` to something not equal to nullptr. The particular value doesn't really matter. + // We are just using the unique_ptr deleter to unregister the EP library. + void* const handle_value = reinterpret_cast(0x1); + return PluginEpLibraryRegistrationHandle{handle_value, unregister_ep_library}; +} + +void StrMapToKeyValueCstrVectors(const std::map& m, + std::vector& keys_out, std::vector& values_out) { + std::vector keys, values{}; + keys.reserve(m.size()); + values.reserve(m.size()); + for (auto& [key, value] : m) { + keys.push_back(key.c_str()); + values.push_back(value.c_str()); + } + keys_out = std::move(keys); + values_out = std::move(values); +} + +struct PluginEpInfrastructureState { + InitializationConfig config{}; + PluginEpLibraryRegistrationHandle plugin_ep_library_registration_handle{}; + std::unique_ptr ep_factory{}; + std::vector selected_c_ep_devices{}; + std::string ep_name{}; +}; + +std::optional g_plugin_ep_infrastructure_state{}; + +} // namespace + +Status ParseInitializationConfig(std::string_view json_str, InitializationConfig& config_out) { + using json = nlohmann::json; + Status status = Status::OK(); + ORT_TRY { + InitializationConfig config{}; + const auto parsed_json = json::parse(json_str); + + // required keys + parsed_json.at("ep_library_registration_name").get_to(config.ep_library_registration_name); + parsed_json.at("ep_library_path").get_to(config.ep_library_path); + + // optional keys + config.default_ep_options = parsed_json.value("default_ep_options", {}); + config.selected_ep_name = parsed_json.value("selected_ep_name", {}); + config.selected_ep_device_indices = + parsed_json.value("selected_ep_device_indices", {}); + + config_out = std::move(config); + return Status::OK(); + } + ORT_CATCH(const json::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + constexpr std::string_view kExampleValidJsonStr = + "{\n" + " \"ep_library_registration_name\": \"example_plugin_ep\",\n" + " \"ep_library_path\": \"/path/to/example_plugin_ep.dll\",\n" + " \"selected_ep_name\": \"example_plugin_ep\"\n" + "}"; + + status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "JSON parse error: ", e.what(), + "\nThis is an example valid JSON configuration:\n", kExampleValidJsonStr); + }); + } + return status; +} + +Status Initialize(Ort::Env& env, InitializationConfig config) { + ORT_RETURN_IF(IsInitialized(), "Already initialized."); + + auto ep_library_registration_handle = RegisterPluginEpLibrary(env, config.ep_library_registration_name, + ToPathString(config.ep_library_path)); + + ORT_RETURN_IF(config.selected_ep_device_indices.empty() == config.selected_ep_name.empty(), + "Exactly one of selected_ep_device_indices or selected_ep_name should be specified."); + + const auto ep_devices = env.GetEpDevices(); + std::vector selected_c_ep_devices{}; + + if (!config.selected_ep_device_indices.empty()) { + for (const auto idx : config.selected_ep_device_indices) { + ORT_RETURN_IF(idx >= ep_devices.size(), "Selected EP device index is out of range: ", idx); + selected_c_ep_devices.push_back(ep_devices[idx]); + } + } else { + std::copy_if(ep_devices.begin(), ep_devices.end(), std::back_inserter(selected_c_ep_devices), + [&selected_ep_name = std::as_const(config.selected_ep_name)](Ort::ConstEpDevice ep_device) { + return ep_device.EpName() == selected_ep_name; + }); + } + + ORT_RETURN_IF(selected_c_ep_devices.empty(), "No EP devices were selected."); + + std::unique_ptr ep_factory{}; + ORT_RETURN_IF_ERROR( + CreateIExecutionProviderFactoryForEpDevices(static_cast(env)->GetEnvironment(), + selected_c_ep_devices, + ep_factory)); + + // Note: CreateIExecutionProviderFactoryForEpDevices() ensures that all EP devices refer to the same EP, so we will + // just get the EP name from the first one. + std::string ep_name = Ort::ConstEpDevice{selected_c_ep_devices.front()}.EpName(); + + auto state = PluginEpInfrastructureState{}; + state.config = std::move(config); + state.plugin_ep_library_registration_handle = std::move(ep_library_registration_handle); + state.ep_factory = std::move(ep_factory); + state.selected_c_ep_devices = std::move(selected_c_ep_devices); + state.ep_name = std::move(ep_name); + + g_plugin_ep_infrastructure_state = std::move(state); + return Status::OK(); +} + +bool IsInitialized() { + return g_plugin_ep_infrastructure_state.has_value(); +} + +void Shutdown() { + g_plugin_ep_infrastructure_state.reset(); +} + +std::unique_ptr MakeEp(const logging::Logger* logger) { + if (!IsInitialized()) { + return nullptr; + } + + if (logger == nullptr) { + logger = &DefaultLoggingManager().DefaultLogger(); + } + + const auto& state = *g_plugin_ep_infrastructure_state; + + std::vector default_ep_option_key_cstrs{}, default_ep_option_value_cstrs{}; + StrMapToKeyValueCstrVectors(state.config.default_ep_options, + default_ep_option_key_cstrs, default_ep_option_value_cstrs); + + OrtSessionOptions ort_session_options{}; + ORT_THROW_IF_ERROR(AddEpOptionsToSessionOptions(state.selected_c_ep_devices, + default_ep_option_key_cstrs, + default_ep_option_value_cstrs, + ort_session_options.value)); + + return state.ep_factory->CreateProvider(ort_session_options, *logger->ToExternal()); +} + +std::optional GetEpName() { + if (!IsInitialized()) { + return std::nullopt; + } + + return g_plugin_ep_infrastructure_state->ep_name; +} + +} // namespace onnxruntime::test::dynamic_plugin_ep_infra diff --git a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h new file mode 100644 index 0000000000000..9815eb685f64f --- /dev/null +++ b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "onnxruntime_cxx_api.h" + +#include "core/common/status.h" + +namespace onnxruntime { +struct IExecutionProviderFactory; +class IExecutionProvider; + +namespace logging { +class Logger; +} // namespace logging + +namespace test { + +// `onnxruntime::test::dynamic_plugin_ep_infra` contains functions and types related to dynamically loaded plugin EP +// unit testing infrastructure. +namespace dynamic_plugin_ep_infra { + +// Note: `Initialize()` and `Shutdown()` are not thread-safe. +// They should be called before and after calls to most of the other functions in this namespace. +// The exception to this is `ParseInitializationConfig()`, which may be called before `Initialize()`. + +// Configuration for initializing the dynamic plugin EP infrastructure. +struct InitializationConfig { + std::string ep_library_registration_name{}; + std::string ep_library_path{}; + + // Note: Exactly one of `selected_ep_name` or `selected_ep_device_indices` should be set. + // An empty value for either means it is unset. + + // Specifies the EP devices matching this EP name as the selected EP devices. + std::string selected_ep_name{}; + // Specifies the selected EP devices by their indices. + std::vector selected_ep_device_indices{}; + + std::map default_ep_options{}; +}; + +// Parses `InitializationConfig` from JSON. +// The configuration JSON object should have keys and values that match the `InitializationConfig` fields. +// E.g.: +// { +// "ep_library_registration_name": "example_plugin_ep", +// "ep_library_path": "/path/to/example_plugin_ep.dll", +// "selected_ep_name": "example_plugin_ep", +// "default_ep_options": { "option_key": "option_value" } +// } +Status ParseInitializationConfig(std::string_view json_str, InitializationConfig& config); + +// Initializes dynamic plugin EP infrastructure. +Status Initialize(Ort::Env& env, InitializationConfig config); + +// Gets whether the dynamic plugin EP infrastructure is initialized. +bool IsInitialized(); + +// Shuts down dynamic plugin EP infrastructure. +// This does not require a previously successful call to `Initialize()`. +void Shutdown(); + +// Returns a dynamic plugin EP `IExecutionProvider` instance, or `nullptr` if uninitialized. +std::unique_ptr MakeEp(const logging::Logger* logger = nullptr); + +// Gets the dynamic plugin EP name, or `std::nullopt` if uninitialized. +std::optional GetEpName(); + +} // namespace dynamic_plugin_ep_infra +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/tester_types.h b/onnxruntime/test/unittest_util/tester_types.h similarity index 100% rename from onnxruntime/test/providers/tester_types.h rename to onnxruntime/test/unittest_util/tester_types.h diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index bae7a14908916..cea3feeb927af 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -2,18 +2,20 @@ // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. +#include "test/util/include/default_providers.h" + #include -#include "default_providers.h" -#include "providers.h" + +#include "core/framework/session_options.h" #include "core/providers/cpu/cpu_provider_factory_creator.h" #ifdef USE_COREML #include "core/providers/coreml/coreml_provider_factory.h" #endif #ifdef USE_CUDA -#include +#include "core/providers/cuda/cuda_provider_options.h" #endif #include "core/session/onnxruntime_cxx_api.h" -#include "core/framework/session_options.h" +#include "test/util/include/providers.h" namespace onnxruntime { @@ -330,5 +332,6 @@ std::unique_ptr DefaultDmlExecutionProvider() { std::unique_ptr DefaultRocmExecutionProvider(bool) { return nullptr; } + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 67d85edb4b8ef..ab3136c0b7b33 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -69,8 +69,5 @@ std::unique_ptr WebGpuExecutionProviderWithOptions(const Con std::unique_ptr DefaultCannExecutionProvider(); std::unique_ptr DefaultDmlExecutionProvider(); -std::unique_ptr DefaultInternalTestingExecutionProvider( - const std::unordered_set& supported_ops); - } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/wasm/onnxruntime_test_all_adapter.js b/onnxruntime/test/wasm/onnxruntime_test_adapter.js similarity index 93% rename from onnxruntime/test/wasm/onnxruntime_test_all_adapter.js rename to onnxruntime/test/wasm/onnxruntime_test_adapter.js index 54c4c5b9d725d..b62ac52557490 100644 --- a/onnxruntime/test/wasm/onnxruntime_test_all_adapter.js +++ b/onnxruntime/test/wasm/onnxruntime_test_adapter.js @@ -3,7 +3,7 @@ 'use strict'; -// This file is used to be injected into "onnxruntime_test_all" as specified by flag "--pre-js" by emcc. +// This file is injected into unit test programs like "onnxruntime_test_all" as specified by flag "--pre-js" by emcc. // It dumps the test report file from emscripten's MEMFS to real file system // Module is predefined for scripts injected from "--pre-js" diff --git a/orttraining/orttraining/test/framework/slice_concatenate_test.cc b/orttraining/orttraining/test/framework/slice_concatenate_test.cc index 1eda3a7be793e..b5962fc409202 100644 --- a/orttraining/orttraining/test/framework/slice_concatenate_test.cc +++ b/orttraining/orttraining/test/framework/slice_concatenate_test.cc @@ -8,7 +8,7 @@ #include "gtest/gtest.h" #include "orttraining/core/session/tensor_helper.h" #include "test/util/include/default_providers.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" namespace onnxruntime { diff --git a/orttraining/orttraining/test/gradient/allreduce_op_test.cc b/orttraining/orttraining/test/gradient/allreduce_op_test.cc index 1b1bd680a1191..6f7045d64228f 100644 --- a/orttraining/orttraining/test/gradient/allreduce_op_test.cc +++ b/orttraining/orttraining/test/gradient/allreduce_op_test.cc @@ -12,7 +12,7 @@ #include "orttraining/core/framework/communication/mpi/mpi_context.h" #include "orttraining/models/runner/training_runner.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/providers/provider_test_utils.h" #include "test/test_environment.h" #include "test/util/include/default_providers.h" diff --git a/orttraining/orttraining/test/gradient/function_ops_test.cc b/orttraining/orttraining/test/gradient/function_ops_test.cc index 2cd6f4acd8da8..29a13bae73043 100644 --- a/orttraining/orttraining/test/gradient/function_ops_test.cc +++ b/orttraining/orttraining/test/gradient/function_ops_test.cc @@ -6,7 +6,7 @@ #include "core/graph/contrib_ops/contrib_defs.h" #include "orttraining/core/graph/training_op_defs.h" -#include "test/contrib_ops/function_test_util.h" +#include "test/unittest_util/function_test_util.h" using namespace ::onnxruntime::common; diff --git a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc index 358deb421bc21..edabcb67aa586 100644 --- a/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc +++ b/orttraining/orttraining/test/graph/gradient_graph_builder_test.cc @@ -5,7 +5,7 @@ #include "gtest/gtest.h" #include "orttraining/core/optimizer/gist_encode_decode.h" #include "test/providers/provider_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/util/include/default_providers.h" #include "core/common/path_utils.h" #include "core/providers/cpu/cpu_execution_provider.h" diff --git a/orttraining/orttraining/test/graph/optimizer_graph_builder_test.cc b/orttraining/orttraining/test/graph/optimizer_graph_builder_test.cc index 84a45398534e7..605e55a26f880 100644 --- a/orttraining/orttraining/test/graph/optimizer_graph_builder_test.cc +++ b/orttraining/orttraining/test/graph/optimizer_graph_builder_test.cc @@ -22,7 +22,7 @@ #include "orttraining/core/graph/adasum_optimizer_graph_builder.h" #endif #include "orttraining/core/graph/zero_optimizer_graph_builder.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/util/include/asserts.h" #include "test/test_environment.h" #include "orttraining/test/session/training_session_test_utils.h" diff --git a/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc b/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc index c860788d3db91..d7f0f2ce9b743 100644 --- a/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc @@ -24,14 +24,14 @@ #include "test/common/tensor_op_test_utils.h" #include "test/compare_ortvalue.h" -#include "test/framework/test_utils.h" -#include "test/optimizer/graph_transform_test_builder.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/optimizer/graph_transform_test_fixture.h" #include "test/optimizer/test_optimizer_utils.h" #include "test/providers/provider_test_utils.h" -#include "test/util/include/temp_dir.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/asserts.h" #include "test/util/include/default_providers.h" +#include "test/util/include/temp_dir.h" using namespace onnxruntime::optimizer::compute_optimizer; diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 6f5d6e1389443..85b65557cc360 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -4,7 +4,7 @@ #include "core/session/inference_session.h" #include "core/graph/model.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" #include "gtest/gtest.h" @@ -19,10 +19,10 @@ #include "orttraining/core/optimizer/batchnorm_replacement.h" #include "orttraining/core/optimizer/localized_recompute.h" #include "orttraining/core/optimizer/transpose_replacement.h" -#include "test/optimizer/graph_transform_test_builder.h" #include "test/optimizer/graph_transform_test_fixture.h" -#include "test/util/include/default_providers.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/asserts.h" +#include "test/util/include/default_providers.h" #include "orttraining/test/optimizer/horizontal_parallel_test_utils.h" #include "orttraining/core/session/training_session.h" #include "orttraining/core/optimizer/cast_sce_loss_fusion.h" diff --git a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc index 1b8699d1de497..b9509370deadc 100644 --- a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/capturing_sink.h" #include "test/test_environment.h" #include "gtest/gtest.h" diff --git a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc index a0959c0df8868..9a4cd97fadce2 100644 --- a/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/memory_optimizer_test.cc @@ -18,14 +18,14 @@ #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" -#include "test/optimizer/graph_transform_test_builder.h" #include "core/optimizer/utils.h" #include "core/platform/env.h" #include "core/session/inference_session.h" #include "core/util/math.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/capturing_sink.h" #include "test/test_environment.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/asserts.h" #include "test/util/include/temp_dir.h" #include "orttraining/core/optimizer/memory_optimizer/common.h" diff --git a/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc b/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc index a1629eb73eeb6..f81ce5fa91277 100644 --- a/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc @@ -2,13 +2,13 @@ // Licensed under the MIT License. #include "core/graph/model.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/test_environment.h" #include "gtest/gtest.h" #include "core/optimizer/utils.h" -#include "test/optimizer/graph_transform_test_builder.h" #include "test/optimizer/graph_transform_test_fixture.h" +#include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/asserts.h" #include "orttraining/core/optimizer/shape_optimizer.h" diff --git a/orttraining/orttraining/test/session/training_session_test.cc b/orttraining/orttraining/test/session/training_session_test.cc index b6ed80c426afc..a3f6d917a76b6 100644 --- a/orttraining/orttraining/test/session/training_session_test.cc +++ b/orttraining/orttraining/test/session/training_session_test.cc @@ -5,7 +5,7 @@ #include "gtest/gtest.h" #include "orttraining/core/optimizer/gist_encode_decode.h" #include "test/providers/provider_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "core/common/path_utils.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/environment.h" diff --git a/orttraining/orttraining/test/session/training_session_test_utils.h b/orttraining/orttraining/test/session/training_session_test_utils.h index 7565fd64514f2..4ba092b951081 100644 --- a/orttraining/orttraining/test/session/training_session_test_utils.h +++ b/orttraining/orttraining/test/session/training_session_test_utils.h @@ -5,7 +5,7 @@ #include "gtest/gtest.h" #include "orttraining/core/optimizer/gist_encode_decode.h" #include "test/providers/provider_test_utils.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "core/common/path_utils.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/environment.h" diff --git a/orttraining/orttraining/test/training_api/core/data_utils.h b/orttraining/orttraining/test/training_api/core/data_utils.h index 7724bc0c26fa3..4323d0ee171f9 100644 --- a/orttraining/orttraining/test/training_api/core/data_utils.h +++ b/orttraining/orttraining/test/training_api/core/data_utils.h @@ -9,7 +9,7 @@ #include "core/framework/ort_value.h" #include "core/framework/tensor.h" #include "default_providers.h" -#include "test/framework/test_utils.h" +#include "test/unittest_util/framework_test_utils.h" #include "test/util/include/test_utils.h" namespace onnxruntime::training::test { diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index edceae55ddda4..dcbb2143f4af7 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1559,15 +1559,14 @@ def dump_logs_on_failure(): test_data_dir = os.path.join(source_dir, "cmake", "external", "onnx", "onnx", "backend", "test") if os.path.exists(test_data_dir): adb_push(test_data_dir, device_dir + "/test", cwd=cwd) - adb_push("onnxruntime_test_all", device_dir, cwd=cwd) - adb_shell(f"chmod +x {device_dir}/onnxruntime_test_all") - adb_push("onnx_test_runner", device_dir, cwd=cwd) - adb_shell(f"chmod +x {device_dir}/onnx_test_runner") - run_adb_shell(f"{device_dir}/onnxruntime_test_all") - # remove onnxruntime_test_all as it takes up a _lot_ of space and can cause insufficient storage errors - # when we try to copy the java app to the device. - adb_shell(f"rm {device_dir}/onnxruntime_test_all") + for test_program_name in ["onnxruntime_test_all", "onnxruntime_provider_test"]: + adb_push(test_program_name, device_dir, cwd=cwd) + adb_shell(f"chmod +x {device_dir}/{test_program_name}") + run_adb_shell(f"{device_dir}/{test_program_name}") + + # remove test program when we are done testing to free up space + adb_shell(f"rm {device_dir}/{test_program_name}") if args.build_java: # use the gradle wrapper under /java @@ -1584,6 +1583,9 @@ def dump_logs_on_failure(): cwd=android_test_path, ) + adb_push("onnx_test_runner", device_dir, cwd=cwd) + adb_shell(f"chmod +x {device_dir}/onnx_test_runner") + if args.use_nnapi: run_adb_shell(f"{device_dir}/onnx_test_runner -e nnapi {device_dir}/test") else: @@ -1636,6 +1638,7 @@ def run_ios_tests(args, source_dir, config, cwd): xc_test_schemes = [ "onnxruntime_test_all_xc", + "onnxruntime_provider_test_xc", ] if args.build_shared_lib: @@ -1717,7 +1720,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): dll_path = os.pathsep.join(dll_path_list) if not ctest_path and not is_windows(): - executables = ["onnxruntime_test_all", "onnxruntime_mlas_test"] + executables = ["onnxruntime_test_all", "onnxruntime_mlas_test", "onnxruntime_provider_test"] if args.build_shared_lib: executables.append("onnxruntime_shared_lib_test") executables.append("onnxruntime_global_thread_pools_test") @@ -1726,7 +1729,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): test_output = f"--gtest_output=xml:{cwd}/{exe}.{config}.results.xml" run_subprocess([os.path.join(cwd, exe), test_output], cwd=cwd, dll_path=dll_path) else: - ctest_cmd = [ctest_path, "--build-config", config, "--verbose", "--timeout", args.test_all_timeout] + ctest_cmd = [ctest_path, "--build-config", config, "--verbose", "--timeout", args.ctest_timeout] run_subprocess(ctest_cmd, cwd=cwd, dll_path=dll_path) if args.enable_pybind: @@ -2338,7 +2341,7 @@ def main(): if args.test and args.disable_wasm_exception_catching and not args.minimal_build: raise BuildError("WebAssembly tests need exception catching enabled to run if it's not minimal build") if args.test and args.enable_wasm_debug_info: - # With flag --enable_wasm_debug_info, onnxruntime_test_all.wasm will be very huge (>1GB). This will fail + # With flag --enable_wasm_debug_info, the test program .wasm file will be very huge (>1GB). This will fail # Node.js when trying to load the .wasm file. # To debug ONNX Runtime WebAssembly, use ONNX Runtime Web to debug ort-wasm.wasm in browsers. raise BuildError("WebAssembly tests cannot be enabled with flag --enable_wasm_debug_info") diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index de538604aac75..f833b88ed5836 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -212,7 +212,7 @@ def add_testing_args(parser: argparse.ArgumentParser) -> None: parser.add_argument("--skip_onnx_tests", action="store_true", help="Explicitly disable ONNX related tests.") parser.add_argument("--skip_winml_tests", action="store_true", help="Explicitly disable WinML related tests.") parser.add_argument("--skip_nodejs_tests", action="store_true", help="Explicitly disable Node.js binding tests.") - parser.add_argument("--test_all_timeout", default="10800", help="Timeout for onnxruntime_test_all (seconds).") + parser.add_argument("--ctest_timeout", default="10800", help="Timeout provided to CTest --timeout (seconds).") parser.add_argument("--enable_transformers_tool_test", action="store_true", help="Enable transformers tool test.") parser.add_argument("--build_micro_benchmarks", action="store_true", help="Build ONNXRuntime micro-benchmarks.") parser.add_argument("--code_coverage", action="store_true", help="Generate code coverage report (Android only).") diff --git a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml index 5cf5cd8c936fa..53b62762319ba 100644 --- a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml @@ -7,7 +7,6 @@ parameters: default: true stages: - # build binaries for Android - ${{ if parameters.BuildAndroidBinaries }}: - stage: BuildAndroidBinaries diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index d59e593ef916e..cb3adbfd881fd 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -122,12 +122,12 @@ extends: PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} - - template: stages/download-java-tools-stage.yml - - template: templates/c-api-cpu.yml parameters: RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} ${{ if eq(parameters.NugetPackageSuffix, 'NONE') }}: OrtNugetPackageId: 'Microsoft.ML.OnnxRuntime' ${{ else }}: @@ -135,16 +135,10 @@ extends: AdditionalBuildFlags: '' AdditionalWinBuildFlags: '--enable_onnx_tests ${{parameters.AdditionalBuildFlag}}' BuildVariant: 'default' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} QnnSDKVersion: ${{ parameters.QnnSdk }} is1ES: true - template: stages/java-cuda-packaging-stage.yml - parameters: - CudaVersion: 12.2 - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - template: stages/nuget-combine-cuda-stage.yml parameters: @@ -159,6 +153,8 @@ extends: buildNodejs: true SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} - template: stages/nodejs-win-packaging-stage.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml index 9938b7a9dcf26..1ae525514bbaf 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml @@ -175,6 +175,8 @@ stages: artifact: 'Windows_Packaging_cuda_build_artifacts' displayName: 'Download Windows GPU Packages Build' + - template: templates/setup-build-tools.yml + - task: CmdLine@2 inputs: script: | @@ -188,17 +190,6 @@ stages: jdkArchitectureOption: x64 jdkSourceOption: 'PreInstalled' - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - architecture: x64 - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - - task: PythonScript@0 displayName: 'Update CTest Path References' inputs: @@ -207,10 +198,6 @@ stages: "$(Build.BinariesDirectory)/RelWithDebInfo/CTestTestfile.cmake" "$(Build.BinariesDirectory)/RelWithDebInfo" - - task: NodeTool@0 - inputs: - versionSpec: '22.x' - - template: templates/jobs/download_win_gpu_library.yml parameters: CudaVersion: 12.2 @@ -223,12 +210,6 @@ stages: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests' workingDirectory: '$(Build.BinariesDirectory)' - # Previous stage only assembles the java binaries, testing will be done in this stage with GPU machine - - template: templates/make_java_win_binaries.yml - parameters: - msbuildPlatform: x64 - java_artifact_id: onnxruntime_gpu - buildOnly: false - stage: Windows_Packaging_Tensorrt_Testing dependsOn: Setup @@ -242,12 +223,13 @@ stages: - checkout: self clean: true submodules: none - - + - download: build artifact: 'Windows_Packaging_tensorrt_build_artifacts' displayName: 'Download Windows GPU Packages Build' + - template: templates/setup-build-tools.yml + - task: CmdLine@2 inputs: script: | @@ -260,18 +242,7 @@ stages: versionSpec: "17" jdkArchitectureOption: x64 jdkSourceOption: 'PreInstalled' - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - architecture: x64 - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - + - task: PythonScript@0 displayName: 'Update CTest Path References' inputs: @@ -280,10 +251,6 @@ stages: "$(Build.BinariesDirectory)/RelWithDebInfo/CTestTestfile.cmake" "$(Build.BinariesDirectory)/RelWithDebInfo" - - task: NodeTool@0 - inputs: - versionSpec: '22.x' - - template: templates/jobs/download_win_gpu_library.yml parameters: CudaVersion: 12.2 @@ -295,10 +262,4 @@ stages: inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests' - workingDirectory: '$(Build.BinariesDirectory)' - # Previous stage only assembles the java binaries, testing will be done in this stage with GPU machine - - template: templates/make_java_win_binaries.yml - parameters: - msbuildPlatform: x64 - java_artifact_id: onnxruntime_gpu - buildOnly: false + workingDirectory: '$(Build.BinariesDirectory)' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml index 46695403fd854..95f55f52f9a68 100644 --- a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml @@ -123,11 +123,9 @@ extends: buildNodejs: false SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} - template: stages/download-java-tools-stage.yml - template: stages/java-cuda-packaging-stage.yml - parameters: - CudaVersion: ${{ parameters.CudaVersion }} - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index 505a35775c1fb..b3d6f2906b31d 100644 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -92,6 +92,8 @@ extends: - template: templates/win-ci.yml parameters: + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} ort_build_pool_name: 'onnxruntime-Win2022-GPU-A10' DoCompliance: false DoEsrp: true @@ -124,7 +126,6 @@ extends: - template: templates/mac-cpu-packaging-pipeline.yml parameters: AllowReleasedOpsetOnly: 1 - BuildForAllArchs: true AdditionalBuildFlags: '--use_webgpu --skip_tests' DoEsrp: true diff --git a/tools/ci_build/github/azure-pipelines/jar_package_testing.yml b/tools/ci_build/github/azure-pipelines/jar_package_testing.yml index 24e17de06e6bd..19b40cb7c549a 100644 --- a/tools/ci_build/github/azure-pipelines/jar_package_testing.yml +++ b/tools/ci_build/github/azure-pipelines/jar_package_testing.yml @@ -44,7 +44,18 @@ stages: DownloadCUDA: true DownloadTRT: true - - template: templates/download_maven_for_tests.yml + - template: templates/setup-maven.yml + + - task: Maven@4 + displayName: 'Download Java Dependencies' + inputs: + mavenPomFile: '$(Build.SourcesDirectory)/tools/ci_build/java/pom.xml' + goals: 'dependency:copy-dependencies' + options: '-DoutputDirectory=$(Pipeline.Workspace)/build/onnxruntime-java' + publishJUnitTestResults: false + javaHomeOption: 'JDKVersion' + jdkVersionOption: '1.17' + mavenVersionOption: 'Default' - download: build artifact: 'onnxruntime-java-gpu' displayName: 'Download Final Jar' @@ -64,6 +75,8 @@ stages: del *.sha256 del *.sha512 del *.pom + del *.sha1 + del *.pom cd .. mkdir tests cd tests diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 757b8ac6e9a16..691ee2f16a9c6 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -49,21 +49,7 @@ stages: clean: true submodules: none - - template: ../../templates/telemetry-steps.yml - - - task: NodeTool@0 - inputs: - versionSpec: '22.x' - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - architecture: x64 - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' + - template: ../../templates/setup-build-tools.yml # need to set PROCESSOR_ARCHITECTURE so the x86 SDK is installed correctly - task: UseDotNet@2 diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index 197edf7bcad24..c329cf66e7e7b 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -357,9 +357,9 @@ stages: #Merge the multiple prof data into a single indexed profile data file llvm-profdata merge -sparse -o ort.profdata *.profraw #Create coverage report, output the result to 'report.json' - llvm-cov export -summary-only -instr-profile=ort.profdata onnxruntime_test_all -object onnxruntime_mlas_test -object onnx_test_runner -object onnxruntime_shared_lib_test -object onnxruntime_global_thread_pools_test $(Build.SourcesDirectory)/include/onnxruntime $(Build.SourcesDirectory)/onnxruntime/core $(Build.SourcesDirectory)/onnxruntime/contrib_ops > $(Build.BinariesDirectory)/report.json + llvm-cov export -summary-only -instr-profile=ort.profdata onnxruntime_test_all -object onnxruntime_mlas_test -object onnx_test_runner -object onnxruntime_shared_lib_test -object onnxruntime_global_thread_pools_test -object onnxruntime_provider_test $(Build.SourcesDirectory)/include/onnxruntime $(Build.SourcesDirectory)/onnxruntime/core $(Build.SourcesDirectory)/onnxruntime/contrib_ops > $(Build.BinariesDirectory)/report.json - llvm-cov show -instr-profile=ort.profdata onnxruntime_test_all -object onnxruntime_mlas_test -object onnx_test_runner -object onnxruntime_shared_lib_test -object onnxruntime_global_thread_pools_test $(Build.SourcesDirectory)/include/onnxruntime $(Build.SourcesDirectory)/onnxruntime/core $(Build.SourcesDirectory)/onnxruntime/contrib_ops --format=html -output-dir=$(Build.ArtifactStagingDirectory) + llvm-cov show -instr-profile=ort.profdata onnxruntime_test_all -object onnxruntime_mlas_test -object onnx_test_runner -object onnxruntime_shared_lib_test -object onnxruntime_global_thread_pools_test -object onnxruntime_provider_test $(Build.SourcesDirectory)/include/onnxruntime $(Build.SourcesDirectory)/onnxruntime/core $(Build.SourcesDirectory)/onnxruntime/contrib_ops --format=html -output-dir=$(Build.ArtifactStagingDirectory) workingDirectory: $(Build.BinariesDirectory) - task: AzureCLI@2 diff --git a/tools/ci_build/github/azure-pipelines/stages/download-java-tools-stage.yml b/tools/ci_build/github/azure-pipelines/stages/download-java-tools-stage.yml deleted file mode 100644 index 949d29d27da9d..0000000000000 --- a/tools/ci_build/github/azure-pipelines/stages/download-java-tools-stage.yml +++ /dev/null @@ -1,26 +0,0 @@ -stages: -- stage: Download_Java_Tools - dependsOn: [] - jobs: - - job: Download_Java_Tools - pool: - name: 'onnxruntime-Ubuntu2404-AMD-CPU' - os: linux - steps: - - checkout: none - - task: CmdLine@2 - displayName: Download Java Tools - inputs: - script: | - mkdir -p java-tools - pushd java-tools - wget --tries=3 https://oss.sonatype.org/service/local/repositories/releases/content/org/junit/platform/junit-platform-console-standalone/1.6.2/junit-platform-console-standalone-1.6.2.jar -P ./ - wget --tries=3 https://oss.sonatype.org/service/local/repositories/releases/content/com/google/protobuf/protobuf-java/3.25.5/protobuf-java-3.25.5.jar -P ./ - popd - workingDirectory: '$(Agent.TempDirectory)' - - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Pipeline Java Tools Artifact' - inputs: - targetPath: '$(Agent.TempDirectory)/java-tools' - artifact: 'onnxruntime-java-tools' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml index 63aaf328e1426..a58d74bf80a86 100644 --- a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml @@ -1,80 +1,31 @@ -parameters: -- name: CudaVersion - type: string -- name: SpecificArtifact - type: string -- name: BuildId - type: string - stages: - stage: Jar_Packaging_GPU dependsOn: - Linux_C_API_Packaging_GPU - Windows_Packaging_CUDA - Windows_Packaging_TensorRT - - Download_Java_Tools jobs: - job: Jar_Packaging_GPU workspace: clean: all + templateContext: + inputs: + - input: pipelineArtifact + artifactName: drop-onnxruntime-java-win-x64-tensorrt + targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' + + - input: pipelineArtifact + artifactName: drop-onnxruntime-java-linux-x64-tensorrt + targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-linux-x64' + + outputs: + - output: pipelineArtifact + targetPath: $(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64 + artifactName: onnxruntime-java-gpu pool: 'onnxruntime-Win-CPU-2022' dependsOn: [] condition: succeeded() steps: - - checkout: self - submodules: false - - template: ../templates/set-version-number-variables-step.yml - - - template: ../templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - Win x64' - ArtifactName: 'drop-onnxruntime-java-win-x64-tensorrt' - TargetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: ../templates/flex-downloadPipelineArtifact.yml - parameters: - stepName: 'Download Pipeline Artifact - Linux x64' - artifactName: 'drop-onnxruntime-java-linux-x64-cuda' - targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-linux-x64' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: ../templates/flex-downloadPipelineArtifact.yml + - template: ../templates/jar-packaging.yml parameters: - StepName: 'Download Pipeline Artifact - Linux x64' - ArtifactName: 'drop-onnxruntime-java-linux-x64-tensorrt' - targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-linux-x64-tensorrt' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - task: PowerShell@2 - displayName: 'PowerShell Script' - inputs: - targetType: filePath - filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_gpu_packaging.ps1 - failOnStderr: true - showWarnings: true - workingDirectory: '$(Build.BinariesDirectory)\java-artifact' - - - template: ../templates/jar-esrp-dll.yml - parameters: - JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - JarFileName: 'onnxruntime_gpu-$(OnnxRuntimeVersion).jar' - - - template: ../templates/jar-maven-signing-win.yml - parameters: - JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - - - task: CopyFiles@2 - displayName: 'Copy Java Files to Artifact Staging Directory' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Pipeline Artifact' - inputs: - path: '$(Build.ArtifactStagingDirectory)' - artifact: 'onnxruntime-java-gpu' + package_type: gpu \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml index 76eb5f150ad44..1e9b91c651aa5 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml @@ -71,23 +71,7 @@ stages: clean: true submodules: none - - template: ../templates/telemetry-steps.yml - - - task: NodeTool@0 - inputs: - versionSpec: '22.x' - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - architecture: ${{ parameters.BuildArch }} - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - + - template: ../templates/setup-build-tools.yml # need to set PROCESSOR_ARCHITECTURE so the x86 SDK is installed correctly - task: UseDotNet@2 diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml index e33d3dbf9e107..168432283fa51 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml @@ -32,6 +32,19 @@ parameters: - name: BuildId type: string +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + stages: - template: nuget-linux-cuda-packaging-stage.yml parameters: @@ -52,6 +65,8 @@ stages: win_trt_home: ${{ parameters.win_trt_home }} win_cuda_home: ${{ parameters.win_cuda_home }} buildJava: ${{ parameters.buildJava }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} - template: nuget-cuda-packaging-stage.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index 4175a339535e4..121e80fca1021 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -28,6 +28,10 @@ stages: value: ${{ parameters.CudaVersion }} steps: - template: ../templates/set-version-number-variables-step.yml + - task: UsePythonVersion@0 + displayName: Use Python 3.12 + inputs: + versionSpec: 3.12 - template: ../templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cuda${{ variables.CUDA_VERSION_MAJOR }}/Dockerfile @@ -45,10 +49,8 @@ stages: arch: 'linux-x64' buildConfig: 'Release' artifactName: 'onnxruntime-java-linux-x64-cuda' - version: '$(OnnxRuntimeVersion)' libraryName: 'libonnxruntime.so' nativeLibraryName: 'libonnxruntime4j_jni.so' - is1ES: true - template: ../templates/c-api-artifacts-package-and-publish-steps-posix.yml parameters: @@ -85,6 +87,10 @@ stages: - checkout: self clean: true submodules: recursive + - task: UsePythonVersion@0 + displayName: Use Python 3.12 + inputs: + versionSpec: 3.12 - template: ../templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cuda${{ variables.CUDA_VERSION_MAJOR }}/Dockerfile @@ -106,10 +112,8 @@ stages: arch: 'linux-x64' buildConfig: 'Release' artifactName: 'onnxruntime-java-linux-x64-tensorrt' - version: '$(OnnxRuntimeVersion)' libraryName: 'libonnxruntime.so' nativeLibraryName: 'libonnxruntime4j_jni.so' - is1ES: true - template: ../templates/c-api-artifacts-package-and-publish-steps-posix.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml index ed6c4c799c26d..30c3b1e48a89a 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml @@ -34,10 +34,25 @@ parameters: - name: buildJava type: boolean +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + stages: # Windows CUDA without TensorRT Packaging - template: ../templates/win-ci.yml parameters: + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} ort_build_pool_name: 'onnxruntime-Win2022-GPU-A10' DoEsrp: ${{ parameters.DoEsrp }} stage_name_suffix: CUDA @@ -52,9 +67,12 @@ stages: UseIncreasedTimeoutForTests: ${{ parameters.UseIncreasedTimeoutForTests }} SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} + # Windows CUDA with TensorRT Packaging - template: ../templates/win-ci.yml parameters: + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} ort_build_pool_name: 'onnxruntime-Win2022-GPU-A10' DoEsrp: ${{ parameters.DoEsrp }} stage_name_suffix: TensorRT diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index fdb4688963471..0833a2172b527 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -116,17 +116,7 @@ stages: clean: true submodules: recursive - - template: ../templates/telemetry-steps.yml - - - task: UsePythonVersion@0 - inputs: - versionSpec: $(PythonVersion) - addToPath: true - architecture: $(buildArch) - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' + - template: ../templates/setup-build-tools.yml - template: ../templates/set-nightly-build-option-variable-step.yml @@ -208,80 +198,61 @@ stages: - stage: Python_Packaging_MacOS dependsOn: [] jobs: - - job: MacOS_py_Wheels - timeoutInMinutes: 360 - workspace: - clean: all - pool: - name: "Azure Pipelines" - image: "macOS-14" - os: macOS - templateContext: - outputs: - - output: pipelineArtifact - targetPath: $(Build.SourcesDirectory)/build/Release/dist/fixed_wheels - artifactName: onnxruntime-macos-$(PythonVersion) - variables: - MACOSX_DEPLOYMENT_TARGET: '13.4' - strategy: - matrix: - Python310: - PythonVersion: '3.10' - Python311: - PythonVersion: '3.11' - Python312: - PythonVersion: '3.12' - Python313: - PythonVersion: '3.13' - steps: - - checkout: self - clean: true - submodules: recursive - - - task: UsePythonVersion@0 - displayName: 'Use Python' - inputs: - versionSpec: $(PythonVersion) - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - - - template: ../templates/use-xcode-version.yml + - template: ../templates/py-macos.yml + parameters: + arch: 'arm64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.10' + + - template: ../templates/py-macos.yml + parameters: + arch: 'arm64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.11' - - script: | - set -e -x - export _PYTHON_HOST_PLATFORM=macosx-${{variables.MACOSX_DEPLOYMENT_TARGET}}-universal2 - python3 -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' - # Note: There is a build error when we set CMAKE_OSX_ARCHITECTURES="arm64;x86_64" and KleidiAI is enabled. - # Disable KleidiAI as a workaround with --no_kleidiai. - # TODO Re-enable KleidiAI once https://github.com/microsoft/onnxruntime/issues/24152 is fixed. - python3 $(Build.SourcesDirectory)/tools/ci_build/build.py \ - --build_dir $(Build.SourcesDirectory)/build \ - --use_vcpkg --use_vcpkg_ms_internal_asset_cache \ - --use_binskim_compliant_compile_flags \ - --config Release \ - --build_wheel \ - --use_coreml \ - --no_kleidiai \ - ${{ parameters.build_py_parameters }} \ - --cmake_extra_defines CMAKE_OSX_ARCHITECTURES="arm64;x86_64" \ - --update --skip_submodule_sync --build --parallel - displayName: 'Command Line Script' + - template: ../templates/py-macos.yml + parameters: + arch: 'arm64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.12' + + - template: ../templates/py-macos.yml + parameters: + arch: 'arm64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.13' - - script: | - set -ex - python -m pip install --upgrade delocate - cd '$(Build.SourcesDirectory)/build/Release/dist' - ls - for file in *.whl - do - delocate-listdeps "$file" - delocate-wheel --require-archs=x86_64,arm64 -w fixed_wheels -v "$file" - done - displayName: 'delocate wheel' + - template: ../templates/py-macos.yml + parameters: + arch: 'x86_64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.10' + + - template: ../templates/py-macos.yml + parameters: + arch: 'x86_64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.11' + - template: ../templates/py-macos.yml + parameters: + arch: 'x86_64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.12' + + - template: ../templates/py-macos.yml + parameters: + arch: 'x86_64' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + python_version: '3.13' - ${{ if eq(parameters.enable_linux_arm, true) }}: - stage: Python_Packaging_Linux_ARM diff --git a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml index 9c063f561eefc..435551cd2e55d 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml @@ -92,18 +92,7 @@ stages: clean: true submodules: none - - template: ../templates/telemetry-steps.yml - - - task: UsePythonVersion@0 - inputs: - versionSpec: ${{ parameters.PYTHON_VERSION }} - addToPath: true - architecture: 'x64' - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' + - template: ../templates/setup-build-tools.yml - template: ../templates/jobs/download_win_gpu_library.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 3b4fc6063d96a..73c84cf50e785 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -74,10 +74,11 @@ jobs: ${{ if contains(parameters.pool_name, 'mac')}}: os: macOS - - variables: - artifacts_directory: $(Build.BinariesDirectory)/.artifacts - + templateContext: + outputs: + - output: pipelineArtifact + targetPath: $(Build.BinariesDirectory)/.artifacts + artifactName: ${{parameters.artifactName}} steps: - checkout: self clean: true @@ -88,7 +89,7 @@ jobs: inputs: script: | # Create a folder for artifacts - mkdir -p $(artifacts_directory) + mkdir -p $(Build.BinariesDirectory)/.artifacts workingDirectory: $(Build.BinariesDirectory) - template: get-docker-image-steps.yml @@ -131,7 +132,7 @@ jobs: --volume $(Build.BinariesDirectory):/build \ --volume $ANDROID_HOME:/android_home \ --volume $NDK_HOME:/ndk_home \ - --volume $(artifacts_directory):/home/onnxruntimedev/.artifacts \ + --volume $(Build.BinariesDirectory)/.artifacts:/home/onnxruntimedev/.artifacts \ --volume $(Build.BinariesDirectory)/.build_settings:/home/onnxruntimedev/.build_settings \ $QNN_VOLUME \ -e NIGHTLY_BUILD \ @@ -145,18 +146,6 @@ jobs: /bin/bash /onnxruntime_src/tools/ci_build/github/android/build_aar_and_copy_artifacts.sh $USE_QNN workingDirectory: $(Build.SourcesDirectory) - - - ${{ if eq(parameters['enable_code_sign'], 'true') }}: - - template: jar-maven-signing-linux.yml - parameters: - JarFileDirectory: '$(artifacts_directory)' - - ${{ if eq(parameters.is1ES, false) }}: - - task: PublishPipelineArtifact@1 - inputs: - targetPath: '$(artifacts_directory)' - artifactName: '${{parameters.artifactName}}' - - ${{ if eq(parameters.is1ES, true) }}: - - task: 1ES.PublishPipelineArtifact@1 - inputs: - targetPath: '$(artifacts_directory)' - artifactName: '${{parameters.artifactName}}' + - template: jar-maven-signing-linux.yml + parameters: + JarFileDirectory: $(Build.BinariesDirectory)/.artifacts \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 54759ffd88ff0..0bc8fed17063f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -9,6 +9,19 @@ parameters: type: boolean default: false +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + - name: AdditionalBuildFlags displayName: Additional build flags for build.py type: string @@ -28,16 +41,6 @@ parameters: type: string default: 'default' -- name: SpecificArtifact - displayName: Use Specific Artifact - type: boolean - default: false - -- name: BuildId - displayName: Specific Artifact's BuildId - type: string - default: '0' - # Do not update this to a version that does not exist for the qnn-runtime Maven package: # https://mvnrepository.com/artifact/com.qualcomm.qti/qnn-runtime - name: QnnSDKVersion @@ -58,9 +61,6 @@ stages: - template: mac-cpu-packaging-pipeline.yml parameters: AllowReleasedOpsetOnly: 1 - BuildForAllArchs: true - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} DoEsrp: true - stage: Android_Java_API_AAR_Packaging_Full @@ -108,14 +108,24 @@ stages: clean: all pool: name: 'Azure Pipelines' - image: 'macOS-14' + image: 'macOS-15' os: 'macOS' timeoutInMinutes: 300 steps: - template: set-version-number-variables-step.yml + - task: JavaToolInstaller@0 + inputs: + versionSpec: "17" + jdkArchitectureOption: "x64" + jdkSourceOption: 'PreInstalled' + - template: use-xcode-version.yml + parameters: + xcodeVersion: 16.4 + + - template: setup-build-tools.yml - script: | set -e -x @@ -154,6 +164,8 @@ stages: runTests: false buildJava: false buildNodejs: false + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} - template: win-ci.yml parameters: @@ -166,13 +178,14 @@ stages: runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: true buildNodejs: false + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} - stage: Jar_Packaging dependsOn: - Linux_C_API_Packaging_CPU - - MacOS_C_API_Package_Publish + - MacOS_C_API_Packaging_CPU - Windows_Packaging_CPU_x64_${{ parameters.BuildVariant }} - - Download_Java_Tools condition: succeeded() jobs: - job: Jar_Packaging @@ -203,38 +216,13 @@ stages: targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-osx-arm64' outputs: - output: pipelineArtifact - targetPath: $(Build.ArtifactStagingDirectory) + targetPath: $(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64 artifactName: onnxruntime-java steps: - - checkout: self - submodules: false - - template: set-version-number-variables-step.yml - - - task: PowerShell@2 - displayName: 'PowerShell Script' - inputs: - targetType: filePath - filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_packaging.ps1 - failOnStderr: true - showWarnings: true - workingDirectory: '$(Build.BinariesDirectory)\java-artifact' - - - template: jar-esrp-dll.yml - parameters: - JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - JarFileName: 'onnxruntime-$(OnnxRuntimeVersion).jar' - - - template: jar-maven-signing-win.yml + - template: jar-packaging.yml parameters: - JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - - - task: CopyFiles@2 - displayName: 'Copy Java Files to Artifact Staging Directory' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - + package_type: cpu - stage: NuGet_Packaging_CPU dependsOn: @@ -261,6 +249,28 @@ stages: binskim: enabled: true scanOutputDirectoryOnly: true + inputs: + - input: pipelineArtifact + artifactName: onnxruntime-win-x64 + targetPath: $(Build.BinariesDirectory)/nuget-artifact + - input: pipelineArtifact + artifactName: onnxruntime-win-arm64 + targetPath: $(Build.BinariesDirectory)/nuget-artifact + - input: pipelineArtifact + artifactName: onnxruntime-osx + targetPath: $(Build.BinariesDirectory)/nuget-artifact + - input: pipelineArtifact + artifactName: onnxruntime-linux-x64 + targetPath: $(Build.BinariesDirectory)/nuget-artifact + - input: pipelineArtifact + artifactName: onnxruntime-linux-aarch64 + targetPath: $(Build.BinariesDirectory)/nuget-artifact + - input: pipelineArtifact + artifactName: onnxruntime-ios-full-xcframework + targetPath: $(Build.BinariesDirectory)/nuget-artifact + - input: pipelineArtifact + artifactName: onnxruntime-android-full-aar + targetPath: $(Build.BinariesDirectory)/nuget-artifact outputs: - output: pipelineArtifact targetPath: $(Build.ArtifactStagingDirectory) @@ -276,62 +286,6 @@ stages: - checkout: self submodules: true - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact - Win x64' - ArtifactName: 'onnxruntime-win-x64' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download win-arm64 Pipeline Artifact' - ArtifactName: 'onnxruntime-win-arm64' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download osx-x64 Pipeline Artifact' - ArtifactName: 'onnxruntime-osx' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download linux-x64 Pipeline Artifact' - ArtifactName: 'onnxruntime-linux-x64' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download linux-aarch64 Pipeline Artifact' - ArtifactName: 'onnxruntime-linux-aarch64' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download iOS Pipeline Artifact' - ArtifactName: 'onnxruntime-ios-full-xcframework' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Android-full-aar Pipeline Artifact' - ArtifactName: 'onnxruntime-android-full-aar' - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} - - script: | dir workingDirectory: '$(Build.BinariesDirectory)/nuget-artifact' @@ -442,7 +396,7 @@ stages: - Windows_Nodejs_Packaging_arm64 - Linux_Nodejs_Packaging_x64 - Linux_C_API_Packaging_CPU - - MacOS_C_API_Package_Publish + - MacOS_C_API_Packaging_CPU condition: succeeded() jobs: - job: Nodejs_Packaging diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml index aa1e38f8b0159..f1599b6843fb5 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml @@ -45,6 +45,14 @@ jobs: - checkout: self clean: true submodules: none + + - task: UsePythonVersion@0 + displayName: Use Python 3.12 + inputs: + versionSpec: 3.12 + ${{ if eq(parameters.OnnxruntimeArch, 'aarch64') }}: + architecture: arm64 + - template: set-version-number-variables-step.yml - ${{ if eq(parameters.OnnxruntimeArch, 'x64') }}: - template: get-docker-image-steps.yml @@ -82,10 +90,8 @@ jobs: arch: 'linux-${{parameters.OnnxruntimeArch}}' buildConfig: 'Release' artifactName: 'onnxruntime-java-linux-${{parameters.OnnxruntimeArch}}' - version: '$(OnnxRuntimeVersion)' libraryName: 'libonnxruntime.so' nativeLibraryName: 'libonnxruntime4j_jni.so' - is1ES: true - template: c-api-artifacts-package-and-publish-steps-posix.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/templates/download_maven_for_tests.yml b/tools/ci_build/github/azure-pipelines/templates/download_maven_for_tests.yml index e53544458d494..7d4cc9550ce54 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download_maven_for_tests.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download_maven_for_tests.yml @@ -16,13 +16,3 @@ steps: echo "Maven is now on the PATH." mvn --version -- task: Maven@4 - displayName: 'Download Java Dependencies' - inputs: - mavenPomFile: '$(Build.SourcesDirectory)/tools/ci_build/java/pom.xml' - goals: 'dependency:copy-dependencies' - options: '-DoutputDirectory=$(Pipeline.Workspace)/build/onnxruntime-java' - publishJUnitTestResults: false - javaHomeOption: 'JDKVersion' - jdkVersionOption: '1.17' - mavenVersionOption: 'Default' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-linux.yml b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-linux.yml index 9d1cb58aee8bb..5a25232a90c39 100644 --- a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-linux.yml @@ -73,6 +73,8 @@ stages: rm -f *.asc rm -f *.sha256 rm -f *.sha512 + rm -f *.sha1 + rm -f *.md5 rm -f *.pom ls cd .. diff --git a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-win.yml b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-win.yml index af7fa176d2ac0..de07e9e89dc81 100644 --- a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-win.yml @@ -62,6 +62,8 @@ stages: del *.asc del *.sha256 del *.sha512 + del *.md5 + del *.sha1 del *.pom cd .. mkdir tests diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-esrp-dll.yml b/tools/ci_build/github/azure-pipelines/templates/jar-esrp-dll.yml index b59ba551c222f..dd0e0898ecc3b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jar-esrp-dll.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jar-esrp-dll.yml @@ -3,28 +3,25 @@ parameters: type: string default: '' -- name: JarFileName - type: string - default: '' - steps: - - task: PowerShell@2 - displayName: 'ESRP Jar - Extract Jar File' - inputs: - targetType: filePath - filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_esrp_dll.ps1 - arguments: extract '${{ parameters.JarFileDirectory }}' '${{ parameters.JarFileName }}' - workingDirectory: '$(Build.BinariesDirectory)' +- task: PowerShell@2 + displayName: 'ESRP Jar - Extract Jar File' + inputs: + targetType: filePath + filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_esrp_dll.ps1 + arguments: extract '${{ parameters.JarFileDirectory }}' + workingDirectory: '$(Build.BinariesDirectory)' - - template: win-esrp-dll.yml - parameters: - FolderPath: '${{ parameters.JarFileDirectory }}\jar_extracted_full_files' - DisplayName: 'ESRP Jar - Sign Dlls' +- template: win-esrp-dll.yml + parameters: + FolderPath: '${{ parameters.JarFileDirectory }}\jar_extracted_full_files' + DisplayName: 'ESRP Jar - Sign Dlls' + DoEsrp: true # Assuming ESRP should always run when this template is called - - task: PowerShell@2 - displayName: 'ESRP Jar - Repack Jar File' - inputs: - targetType: filePath - filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_esrp_dll.ps1 - arguments: repack '${{ parameters.JarFileDirectory }}' '${{ parameters.JarFileName }}' - workingDirectory: '$(Build.BinariesDirectory)' +- task: PowerShell@2 + displayName: 'ESRP Jar - Repack Jar File' + inputs: + targetType: filePath + filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_esrp_dll.ps1 + arguments: repack '${{ parameters.JarFileDirectory }}' + workingDirectory: '$(Build.BinariesDirectory)' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml index df2aff0634819..98a52b08f32f2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-linux.yml @@ -4,54 +4,25 @@ parameters: steps: - task: AzureKeyVault@2 - displayName: 'Get GnuPG signing keys' + displayName: "Get GnuPG signing keys" inputs: #The value below is the name of an ADO service connection. - azureSubscription: 'AIInfraBuildOnnxRuntimeOSS' - KeyVaultName: 'ort-release' - SecretsFilter: 'java-pgp-pwd,java-pgp-key' + azureSubscription: "AIInfraBuildOnnxRuntimeOSS" + KeyVaultName: "ort-release" + SecretsFilter: "java-pgp-pwd,java-pgp-key" RunAsPreJob: false - - task: CmdLine@2 - displayName: 'Sign jar files: GnuPG and sha256' + - task: UsePythonVersion@0 + displayName: "Use Python 3.12" inputs: - workingDirectory: '$(Build.SourcesDirectory)' - script: | - #!/bin/bash - set -e + versionSpec: "3.12" - jar_file_directory='${{ parameters.JarFileDirectory }}' - working_directory='$(Build.SourcesDirectory)' - original_private_key='$(java-pgp-key)' - original_passphrase='$(java-pgp-pwd)' - - private_key_file=$working_directory/private_key.txt - passphrase_file=$working_directory/passphrase.txt - - echo "Generating GnuPG key files." - printf "%s" "$original_private_key" >$private_key_file - printf "%s" "$original_passphrase" >$passphrase_file - echo "Generated GnuPG key files." - - echo "Importing GnuPG private key file." - gpg --batch --import $private_key_file - echo "Imported GnuPG private key file." - - for file in $(find $jar_file_directory -type f); do - echo "GnuPG signing to file: $file" - gpg --pinentry-mode loopback --passphrase-file $passphrase_file -ab $file - echo "GnuPG signed to file: $file" - done - - for file in $(find $jar_file_directory -type f); do - echo "Adding checksum of sha256 to file: $file" - sha256_value=$(sha256sum $file | awk '{print $1}') - echo $sha256_value" *"$(basename "$file") >$file.sha256 - echo "Added checksum of sha256 to file: $file" - done - - echo "GnuPG and sha256 signing to files completed." - echo "Deleting GnuPG key files." - rm -f $private_key_file - rm -f $passphrase_file - echo "Deleted GnuPG key files." + - task: PythonScript@0 + displayName: "Sign files: GnuPG, sha1, and md5" + env: + JAVA_PGP_PWD: $(java-pgp-pwd) + JAVA_PGP_KEY: $(java-pgp-key) + inputs: + scriptPath: "$(Build.SourcesDirectory)/tools/ci_build/github/windows/sign_java_artifacts.py" + arguments: "${{ parameters.JarFileDirectory }}" + workingDirectory: "$(Build.SourcesDirectory)" \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml b/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml deleted file mode 100644 index ef845dc3bf243..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/jar-maven-signing-win.yml +++ /dev/null @@ -1,78 +0,0 @@ -parameters: - - name: JarFileDirectory - type: string - -steps: - - task: AzureKeyVault@2 - displayName: 'Get GnuPG signing keys' - inputs: - azureSubscription: 'AIInfraBuildOnnxRuntimeOSS' - KeyVaultName: 'ort-release' - SecretsFilter: 'java-pgp-pwd,java-pgp-key' - RunAsPreJob: false - - - task: PowerShell@2 - displayName: 'Sign jar files: GnuPG and sha256' - inputs: - targetType: 'inline' - pwsh: true - workingDirectory: '$(Build.SourcesDirectory)' - script: | - $jar_file_directory = '${{ parameters.JarFileDirectory }}' - $working_directory = '$(Build.SourcesDirectory)' - - $original_passphrase='$(java-pgp-pwd)' - $original_private_key='$(java-pgp-key)' - - $gpg_exe_path = "C:\Program Files (x86)\gnupg\bin\gpg.exe" - - $passphrase_file = Join-Path -Path $working_directory -ChildPath "passphrase.txt" - $private_key_file = Join-Path -Path $working_directory -ChildPath "private_key.txt" - - Write-Host "Generating GnuPG key files." - Out-File -FilePath $passphrase_file -InputObject $original_passphrase -NoNewline -Encoding ascii - Out-File -FilePath $private_key_file -InputObject $original_private_key -NoNewline -Encoding ascii - Write-Host "Generated GnuPG key files." - - Write-Host "Importing GnuPG private key file." - & $gpg_exe_path --batch --import $private_key_file - if ($lastExitCode -ne 0) { - Write-Host -Object "GnuPG importing private key command failed. Exitcode: $exitCode" - exit $lastExitCode - } - Write-Host "Imported GnuPG private key file." - - $targeting_original_files = Get-ChildItem $jar_file_directory -Recurse -Force -File -Name - foreach ($file in $targeting_original_files) { - $file_path = Join-Path $jar_file_directory -ChildPath $file - Write-Host "GnuPG signing to file: "$file_path - & $gpg_exe_path --pinentry-mode loopback --passphrase-file $passphrase_file -ab $file_path - if ($lastExitCode -ne 0) { - Write-Host -Object "GnuPG signing file command failed. Exitcode: $exitCode" - exit $lastExitCode - } - Write-Host "GnuPG signed to file: "$file_path - } - - $PSDefaultParameterValues['Out-File:Encoding'] = 'utf8NoBOM' - $sha256sum_exe_path = "C:\Program Files\Git\usr\bin\sha256sum.exe" - $targeting_asc_files = Get-ChildItem $jar_file_directory -Recurse -Force -File -Name - $original_location = Get-Location - Set-Location $jar_file_directory - foreach ($file in $targeting_asc_files) { - Write-Host "Adding checksum of sha256 to file: "$file - $file_path_sha256 = $file + ".sha256" - & $sha256sum_exe_path $file 1>$file_path_sha256 - if ($lastExitCode -ne 0) { - Write-Host -Object "sha256sum command failed. Exitcode: $exitCode" - exit $lastExitCode - } - Write-Host "Added checksum of sha256 to file: "$file - } - Set-Location $original_location - - Write-Host "GnuPG and sha256 signing to files completed." - Write-Host "Deleting GnuPG key files." - Remove-Item -Path $passphrase_file - Remove-Item -Path $private_key_file - Write-Host "Deleted GnuPG key files." diff --git a/tools/ci_build/github/azure-pipelines/templates/jar-packaging.yml b/tools/ci_build/github/azure-pipelines/templates/jar-packaging.yml new file mode 100644 index 0000000000000..098d7e3162d1f --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/jar-packaging.yml @@ -0,0 +1,61 @@ +# This template packages the Java artifacts for either CPU or GPU. +# It calls the PowerShell script with the correct package type and ensures +# that the correct final JAR file is signed and published. +# Currently this file only runs on Windows x64. + +parameters: + - name: package_type + type: string + default: 'cpu' + values: + - 'cpu' + - 'gpu' + +steps: +- checkout: self + submodules: false + +- task: UsePythonVersion@0 + inputs: + versionSpec: '3.13' + addToPath: true + +- task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + +- template: set-version-number-variables-step.yml + +- script: python -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\windows\python\requirements.txt + +- task: PythonScript@0 + displayName: 'Package Java Artifacts' + inputs: + scriptPath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\jar_packaging.py + arguments: '--package_type ${{ parameters.package_type }} --build_dir $(Build.BinariesDirectory)' + workingDirectory: '$(Build.BinariesDirectory)\java-artifact' + +- script: dir $(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64 + +- template: jar-esrp-dll.yml + parameters: + JarFileDirectory: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' + +- task: AzureKeyVault@2 + displayName: 'Get GnuPG signing keys' + inputs: + azureSubscription: 'AIInfraBuildOnnxRuntimeOSS' + KeyVaultName: 'ort-release' + SecretsFilter: 'java-pgp-pwd,java-pgp-key' + RunAsPreJob: false + +- task: PythonScript@0 + displayName: 'Sign files: GnuPG, sha1, and md5' + env: + JAVA_PGP_PWD: $(java-pgp-pwd) + JAVA_PGP_KEY: $(java-pgp-key) + inputs: + scriptPath: '$(Build.SourcesDirectory)/tools/ci_build/github/windows/sign_java_artifacts.py' + arguments: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64' + workingDirectory: '$(Build.SourcesDirectory)' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml b/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml index 1c4b0ae5f4137..166b03f6b55e1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml +++ b/tools/ci_build/github/azure-pipelines/templates/java-api-artifacts-package-and-publish-steps-posix.yml @@ -1,28 +1,50 @@ # sets up common build tools for the windows build machines before build parameters: - arch: 'linux-x64' - buildConfig: 'RelWithDebInfo' - artifactName: 'onnxruntime-java-linux-x64' - libraryName: 'libonnxruntime.so' - nativeLibraryName: 'libonnxruntime4j_jni.so' - version: '' - is1ES: false +- name: buildConfig + displayName: Build Configuration + type: string + values: + - 'Release' + - 'Debug' + - 'RelWithDebInfo' + +- name: artifactName + displayName: Artifact Name + type: string + #default: 'onnxruntime-java' + +- name: libraryName + displayName: Main Library Name + type: string + #default: 'libonnxruntime.so' + +- name: nativeLibraryName + displayName: JNI Library Name + type: string + #default: 'libonnxruntime4j_jni.so' + +- name: arch + displayName: Architecture + type: string + #default: 'linux-x64' + steps: -- task: ShellScript@2 - displayName: 'Copy build artifacts for zipping' +- task: PythonScript@0 + inputs: + scriptSource: 'filePath' + scriptPath: 'tools/ci_build/linux_java_copy_strip_binary.py' + arguments: >- + --binary-dir $(Build.BinariesDirectory) + --build-config ${{parameters.buildConfig}} + --artifact-name ${{parameters.artifactName}} + --lib-name ${{parameters.libraryName}} + --native-lib-name ${{parameters.nativeLibraryName}} + --arch ${{parameters.arch}} + displayName: 'Package ONNX Runtime Java Native Libs' + +- task: 1ES.PublishPipelineArtifact@1 inputs: - scriptPath: 'tools/ci_build/github/linux/java_copy_strip_binary.sh' - args: '-r $(Build.BinariesDirectory) -c ${{parameters.buildConfig}} -a ${{parameters.artifactName}} -l ${{parameters.libraryName}} -n ${{parameters.nativeLibraryName}} -v ${{parameters.version}} -h ${{parameters.arch}}' - workingDirectory: '$(Build.BinariesDirectory)/${{parameters.buildConfig}}' + targetPath: '$(Build.BinariesDirectory)/${{parameters.artifactName}}' + artifactName: 'drop-${{parameters.artifactName}}' -- ${{ if eq(parameters.is1ES, true) }}: - - task: 1ES.PublishPipelineArtifact@1 - inputs: - targetPath: '$(Build.BinariesDirectory)/${{parameters.artifactName}}' - artifactName: 'drop-${{parameters.artifactName}}' -- ${{ if eq(parameters.is1ES, false) }}: - - task: PublishBuildArtifacts@1 - inputs: - pathtoPublish: '$(Build.BinariesDirectory)/${{parameters.artifactName}}' - artifactName: 'drop-${{parameters.artifactName}}' diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml index 7547b841c7480..56cc84a90dc68 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml @@ -1,3 +1,5 @@ +# This stage fetch built macOS binaries from other stages, sign the binaries, then repack them + parameters: - name: AdditionalBuildFlags displayName: Additional build flags for build.py @@ -13,31 +15,11 @@ parameters: - 1 - 0 -- name: BuildForAllArchs - displayName: Build for all CPU ARCHs - type: boolean - -- name: WithCache - displayName: Build with Cache - type: boolean - default: false - - name: DoESRP displayName: Do ESRP type: boolean default: false -# these 2 parameters are used for debugging. -- name: SpecificArtifact - displayName: Use Specific Artifact (Debugging only) - type: boolean - default: false - -- name: BuildId - displayName: Pipeline BuildId, you could find it in the URL - type: string - default: '0' - stages: - stage: MacOS_C_API_Packaging_CPU dependsOn: [] @@ -47,21 +29,12 @@ stages: MacosArch: 'x86_64' AllowReleasedOpsetOnly: ${{ parameters.AllowReleasedOpsetOnly }} AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} - WithCache: ${{ parameters.WithCache }} - - ${{ if eq(parameters.BuildForAllArchs, true) }}: - - template: mac-cpu-packing-jobs.yml - parameters: - MacosArch: 'arm64' - AllowReleasedOpsetOnly: ${{ parameters.AllowReleasedOpsetOnly }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} - WithCache: ${{ parameters.WithCache }} - - template: mac-cpu-packing-jobs.yml - parameters: - MacosArch: 'universal2' - AllowReleasedOpsetOnly: ${{ parameters.AllowReleasedOpsetOnly }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} - WithCache: ${{ parameters.WithCache }} + - template: mac-cpu-packing-jobs.yml + parameters: + MacosArch: 'arm64' + AllowReleasedOpsetOnly: ${{ parameters.AllowReleasedOpsetOnly }} + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} - stage: MacOS_C_API_Package_Publish dependsOn: MacOS_C_API_Packaging_CPU @@ -71,68 +44,56 @@ stages: name: 'Azure Pipelines' image: 'macOS-14' os: 'macOS' + templateContext: + inputs: + - input: pipelineArtifact + artifactName: onnxruntime-osx-x86_64 # The files in this artifact are not signed + targetPath: $(Build.ArtifactStagingDirectory) + - input: pipelineArtifact + artifactName: onnxruntime-osx-arm64 # The files in this artifact are not signed + targetPath: $(Build.ArtifactStagingDirectory) + outputs: + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory) + artifactName: 'onnxruntime-osx' # The files in this artifact are signed steps: - - checkout: none - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline onnxruntime-osx-x86_64' - ArtifactName: 'onnxruntime-osx-x86_64' - TargetPath: '$(Build.ArtifactStagingDirectory)' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} + - checkout: self - - ${{ if eq(parameters.BuildForAllArchs, true) }}: - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline onnxruntime-osx-arm64' - ArtifactName: 'onnxruntime-osx-arm64' - TargetPath: '$(Build.ArtifactStagingDirectory)' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline onnxruntime-osx-universal2' - ArtifactName: 'onnxruntime-osx-universal2' - TargetPath: '$(Build.ArtifactStagingDirectory)' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.13' + addToPath: true - - ${{ if eq(parameters.DoESRP, true)}}: - - script: | - pushd '$(Build.ArtifactStagingDirectory)' - find . '*.tgz' -exec tar -zxvf {} \; - rm -f *.tgz; - find . -type d -name 'onnxruntime-osx-*' -exec zip -FSr --symlinks {}.zip {} \; - find . -type d -name 'onnxruntime-osx-*' -exec rm -rf {} \; - ls -l - popd - displayName: tgz to zip + - task: PythonScript@0 + displayName: 'Prepare, Create Universal Binary, and Zip with Python' + inputs: + scriptSource: 'filePath' + scriptPath: 'tools/ci_build/prepare_macos_package.py' + arguments: '--staging_dir $(Build.ArtifactStagingDirectory)' - - template: mac-esrp-dylib.yml - parameters: - FolderPath: '$(Build.ArtifactStagingDirectory)' - Pattern: '*.zip' + - template: mac-esrp-dylib.yml + parameters: + FolderPath: '$(Build.ArtifactStagingDirectory)' + Pattern: '*.zip' - - script: | - pushd '$(Build.ArtifactStagingDirectory)' - find . '*.zip' -exec unzip {} \; - rm -f *.zip; - find . -type d -name 'onnxruntime-osx-*' -exec tar -czf {}.tgz {} \; - find . -type d -name 'onnxruntime-osx-*' -exec rm -rf {} \; - ls -l - popd - displayName: zip to tgz - - bash: | - set -ex - mkdir -p $(Agent.TempDirectory)/macpackage - find $(Build.ArtifactStagingDirectory) -name "*.tgz" -exec tar -zxvf {} -C $(Agent.TempDirectory)/macpackage \; - find $(Agent.TempDirectory)/macpackage -name "*.dylib" -exec codesign -dvvv {} \; - find $(Agent.TempDirectory)/macpackage -name "*.dylib" -exec ls -l {} \; - rm -rf $(Agent.TempDirectory)/macpackage - displayName: 'Verify code signing' + - script: | + set -ex + mkdir temp + cd temp + find $(Build.ArtifactStagingDirectory) -name '*.zip' -exec unzip {} \; + rm -rf $(Build.ArtifactStagingDirectory)/*; + find . -type d -name 'onnxruntime-osx-*' -exec tar -czf {}.tgz {} \; + ls -l + mv *.tgz $(Build.ArtifactStagingDirectory) + displayName: 'Unzip Signed Files and Repackage to TGZ' + workingDirectory: $(Agent.TempDirectory) - - task: 1ES.PublishPipelineArtifact@1 - inputs: - targetPath: '$(Build.ArtifactStagingDirectory)' - artifactName: 'onnxruntime-osx' - condition: 'succeededOrFailed()' + - bash: | + set -ex + mkdir -p macpackage + find $(Build.ArtifactStagingDirectory) -name "*.tgz" -exec tar -zxvf {} -C macpackage \; + find macpackage -name "*.dylib" -exec codesign -dvvv {} \; + find macpackage -name "*.dylib" -exec ls -l {} \; + rm -rf macpackage + displayName: 'Verify Code Signing' + workingDirectory: $(Agent.TempDirectory) diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml index 9a8264a288582..c43bfe2886f22 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml @@ -4,56 +4,22 @@ parameters: values: - 'x86_64' - 'arm64' - - 'universal2' - default: 'x86_64' - name: AdditionalBuildFlags displayName: Additional build flags for build.py type: string default: '' -- name: BuildJava - displayName: Build with Java - type: boolean - default: true -- name: BuildNodejs - displayName: Build with Nodejs - type: boolean - default: false - -- name: WithCache - displayName: Build with Cache - type: boolean - default: false - -- name: CacheDir - displayName: Cache Directory - type: string - default: '' - -- name: Today - type: string - default: "" steps: -- template: mac-build-step-with-cache.yml - parameters: - WithCache: ${{ parameters.WithCache }} - Today: ${{ parameters.Today }} - AdditionalKey: onnxruntime_${{ parameters.MacosArch }} - CacheDir: ${{ parameters.CacheDir }} - ChangeEveryCommit: true - BuildStep: - - script: | - set -e -x - rm -rf $(Build.BinariesDirectory)/Release - python3 $(Build.SourcesDirectory)/tools/ci_build/build.py --update --build ${{ parameters.AdditionalBuildFlags }} --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags --build_shared_lib --config Release --use_vcpkg --use_vcpkg_ms_internal_asset_cache - cd $(Build.BinariesDirectory)/Release - make install DESTDIR=$(Build.BinariesDirectory)/installed - displayName: 'Build ${{ parameters.MacosArch }}' - env: - CCACHE_DIR: ${{ parameters.CacheDir }} +- script: | + set -e -x + rm -rf $(Build.BinariesDirectory)/Release + python3 $(Build.SourcesDirectory)/tools/ci_build/build.py --update --build ${{ parameters.AdditionalBuildFlags }} --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --parallel 3 --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags --build_shared_lib --config Release --use_vcpkg --use_vcpkg_ms_internal_asset_cache + cd $(Build.BinariesDirectory)/Release + make install DESTDIR=$(Build.BinariesDirectory)/installed + displayName: 'Build ${{ parameters.MacosArch }}' - ${{ if eq(parameters.MacosArch, 'x86_64') }}: - script: | @@ -77,9 +43,9 @@ steps: replaceExistingArchive: true - script: | - set -e -x - mkdir -p $(Build.ArtifactStagingDirectory)/testdata - cp $(Build.BinariesDirectory)/Release/libcustom_op_library.dylib $(Build.ArtifactStagingDirectory)/testdata + set -e -x + mkdir -p $(Build.ArtifactStagingDirectory)/testdata + cp $(Build.BinariesDirectory)/Release/libcustom_op_library.dylib $(Build.ArtifactStagingDirectory)/testdata displayName: 'Copy libcustom_op_library.dylib to ArtifactStagingDirectory' condition: and(succeeded(), eq('${{ parameters.MacosArch }}', 'x86_64')) @@ -88,23 +54,19 @@ steps: targetPath: '$(Build.ArtifactStagingDirectory)' artifactName: 'onnxruntime-osx-${{ parameters.MacosArch }}' -- ${{ if eq(parameters.BuildJava, true) }}: - - template: java-api-artifacts-package-and-publish-steps-posix.yml - parameters: - arch: 'osx-${{ parameters.MacosArch }}' - buildConfig: 'Release' - artifactName: 'onnxruntime-java-osx-${{ parameters.MacosArch }}' - version: '$(OnnxRuntimeVersion)' - libraryName: 'libonnxruntime.dylib' - nativeLibraryName: 'libonnxruntime4j_jni.dylib' - is1ES: true +- template: java-api-artifacts-package-and-publish-steps-posix.yml + parameters: + arch: 'osx-${{ parameters.MacosArch }}' + buildConfig: 'Release' + artifactName: 'onnxruntime-java-osx-${{ parameters.MacosArch }}' + libraryName: 'libonnxruntime.dylib' + nativeLibraryName: 'libonnxruntime4j_jni.dylib' -- ${{ if eq(parameters.BuildNodejs, true) }}: - - template: nodejs-artifacts-package-and-publish-steps-posix.yml - parameters: - ${{ if eq(parameters.MacosArch, 'x86_64') }}: - arch: x64 - ${{ if eq(parameters.MacosArch, 'arm64') }}: - arch: arm64 - os: 'darwin' - artifactName: 'drop-onnxruntime-nodejs-osx-${{ parameters.MacosArch }}' +- template: nodejs-artifacts-package-and-publish-steps-posix.yml + parameters: + ${{ if eq(parameters.MacosArch, 'x86_64') }}: + arch: x64 + ${{ if eq(parameters.MacosArch, 'arm64') }}: + arch: arm64 + os: 'darwin' + artifactName: 'drop-onnxruntime-nodejs-osx-${{ parameters.MacosArch }}' diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml index 095a53b2e44b9..8222acb43ea06 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml @@ -4,13 +4,6 @@ parameters: values: - 'x86_64' - 'arm64' - - 'universal2' - default: 'x86_64' - -- name: WithCache - displayName: Build with Cache - type: boolean - default: false - name: AdditionalBuildFlags displayName: Additional build flags for build.py @@ -33,9 +26,6 @@ jobs: variables: MACOSX_DEPLOYMENT_TARGET: '13.4' ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - PROTO_CACHE_DIR: $(Pipeline.Workspace)/ccache_proto - ORT_CACHE_DIR: $(Pipeline.Workspace)/ccache_ort pool: name: "Azure Pipelines" image: 'macOS-15' @@ -46,67 +36,33 @@ jobs: clean: true submodules: none - - task: UsePythonVersion@0 - displayName: Use Python 3.10 - inputs: - versionSpec: 3.10 - - - task: NodeTool@0 - inputs: - versionSpec: '22.x' - - task: JavaToolInstaller@0 inputs: versionSpec: "17" jdkArchitectureOption: "x64" jdkSourceOption: 'PreInstalled' - - template: set-version-number-variables-step.yml - - template: use-xcode-version.yml parameters: xcodeVersion: 16.4 + - template: setup-build-tools.yml + - template: set-version-number-variables-step.yml + - script: | set -e -x - export PATH=$(Build.BinariesDirectory)/installed/bin:$PATH export ONNX_ML=1 export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=ON -DONNX_WERROR=OFF" - python3 -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' - - - - ${{ if eq(parameters.MacosArch, 'universal2') }}: - - template: mac-cpu-packaging-steps.yml - parameters: - MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --use_coreml --use_webgpu --no_kleidiai --cmake_extra_defines CMAKE_OSX_ARCHITECTURES="arm64;x86_64" - BuildJava: false - BuildNodejs: false - WithCache: ${{ parameters.WithCache }} - ${{ if eq(parameters.WithCache, true) }}: - Today: $(TODAY) - CacheDir: $(ORT_CACHE_DIR) + python3 -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' - ${{ if eq(parameters.MacosArch, 'arm64') }}: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --use_webgpu --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=arm64 - BuildJava: true - BuildNodejs: true - WithCache: ${{ parameters.WithCache }} - ${{ if eq(parameters.WithCache, true) }}: - Today: $(TODAY) - CacheDir: $(ORT_CACHE_DIR) - ${{ if eq(parameters.MacosArch, 'x86_64') }}: - template: mac-cpu-packaging-steps.yml parameters: MacosArch: ${{ parameters.MacosArch }} - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --use_webgpu --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=x86_64 - BuildJava: true - BuildNodejs: true - WithCache: ${{ parameters.WithCache }} - ${{ if eq(parameters.WithCache, true) }}: - Today: $(TODAY) - CacheDir: $(ORT_CACHE_DIR) + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --build_java --use_coreml --use_webgpu --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=x86_64 \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/make_java_win_binaries.yml b/tools/ci_build/github/azure-pipelines/templates/make_java_win_binaries.yml index 0d62ed7907a67..d1ea61ada90c3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/make_java_win_binaries.yml +++ b/tools/ci_build/github/azure-pipelines/templates/make_java_win_binaries.yml @@ -1,59 +1,50 @@ parameters: - - name: msbuildPlatform - type: string - - name: java_artifact_id - type: string - - name: buildOnly - type: boolean +- name: msbuildPlatform + type: string +- name: java_artifact_id + type: string +- name: buildOnly + type: boolean + default: false +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number steps: - - task: CmdLine@2 - displayName: 'Gradle cmakeCheck' - inputs: - ${{ if eq(parameters.buildOnly, true) }}: - script: | - call gradlew.bat testClasses -DcmakeBuildDir=$(Build.BinariesDirectory)\RelWithDebInfo - call gradlew.bat cmakeCheck -x test -DcmakeBuildDir=$(Build.BinariesDirectory)\RelWithDebInfo --warning-mode all - workingDirectory: $(Build.SourcesDirectory)\java - ${{ else }}: - script: | - call gradlew.bat cmakeCheck -DcmakeBuildDir=$(Build.BinariesDirectory)\RelWithDebInfo --warning-mode all - workingDirectory: $(Build.SourcesDirectory)\java +- task: PowerShell@2 + displayName: 'Build and Package Java Artifacts' + inputs: + targetType: 'inline' + script: | + # Define arguments for the Python script + $scriptArgs = @( + "--sources-dir", "$(Build.SourcesDirectory)", + "--binaries-dir", "$(Build.BinariesDirectory)", + "--platform", "${{ parameters.msbuildPlatform }}", + "--build-config", "RelWithDebInfo", + "--java-artifact-id", "${{ parameters.java_artifact_id }}", + "--pre-release-version-suffix-string", "${{ parameters.PreReleaseVersionSuffixString }}", + "--pre-release-version-suffix-number", "${{ parameters.PreReleaseVersionSuffixNumber }}", + "--commit-hash", "$(OnnxRuntimeGitCommitHash)" + ) + + # Conditionally add the --build-only flag if the parameter is true + if ('${{ parameters.buildOnly }}' -eq 'True') { + $scriptArgs += "--build-only" + } + + # Define the path to the python script within your repository + $scriptPath = "$(Build.SourcesDirectory)/tools/ci_build/manage_java_artifacts.py" - - task: CmdLine@2 - displayName: 'Add symbols and notices to Java' - inputs: - script: | - @echo on - cd $(Build.BinariesDirectory)\RelWithDebInfo - set NATIVE_FOLDER=$(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\stage\ai\onnxruntime\native\win-x64 - mkdir %NATIVE_FOLDER% - echo "Directories created" - copy .\java\build\libs\*.jar $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }} - pushd $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }} - set artifact_id=${{ parameters.java_artifact_id }} - jar xf onnxruntime-$(OnnxRuntimeVersion).jar META-INF\maven\com.microsoft.onnxruntime\%artifact_id%\pom.xml - move META-INF\maven\com.microsoft.onnxruntime\%artifact_id%\pom.xml onnxruntime-$(OnnxRuntimeVersion).pom - rd /s /q META-INF - popd - copy .\RelWithDebInfo\onnxruntime.pdb %NATIVE_FOLDER% - copy .\RelWithDebInfo\onnxruntime4j_jni.pdb %NATIVE_FOLDER% - copy $(Build.SourcesDirectory)\docs\Privacy.md $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\stage\Privacy.md - copy $(Build.SourcesDirectory)\ThirdPartyNotices.txt $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\stage\ThirdPartyNotices.txt - @echo $(OnnxRuntimeGitCommitHash) > $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\stage\GIT_COMMIT_ID - pushd $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\stage - jar uf $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\onnxruntime-$(OnnxRuntimeVersion).jar ai\onnxruntime\native\win-x64\onnxruntime.pdb - jar uf $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\onnxruntime-$(OnnxRuntimeVersion).jar ai\onnxruntime\native\win-x64\onnxruntime4j_jni.pdb - jar uf $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\onnxruntime-$(OnnxRuntimeVersion).jar Privacy.md ThirdPartyNotices.txt GIT_COMMIT_ID - popd - pushd $(Build.SourcesDirectory)\java\build\classes\java\test - if %errorlevel% neq 0 exit /b %errorlevel% - jar cvf $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\testing.jar . - if %errorlevel% neq 0 exit /b %errorlevel% - popd - pushd $(Build.SourcesDirectory)\java\build\resources\test - rd /s /q ai\onnxruntime\native - jar uvf $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\testing.jar . - popd - rd /s /q $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}\stage - dir /s /b $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }} + # Execute the Python script, passing all arguments + Write-Host "Executing Python script: $scriptPath with arguments: $($scriptArgs -join ' ')" + python $scriptPath $scriptArgs \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/py-macos.yml b/tools/ci_build/github/azure-pipelines/templates/py-macos.yml new file mode 100644 index 0000000000000..470f78e90b90a --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-macos.yml @@ -0,0 +1,79 @@ +parameters: +- name: arch + type: string + +- name: python_version + type: string + +- name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: extra_build_arg + type: string + default: '' + +jobs: +- job: Mac_${{ parameters.arch }}_${{ replace(parameters.python_version,'.','_') }} + timeoutInMinutes: 240 + workspace: + clean: all + pool: + name: "Azure Pipelines" + image: "macOS-15" + os: macOS + templateContext: + outputs: + - output: pipelineArtifact + targetPath: $(Build.SourcesDirectory)/build/Release/dist/fixed_wheels + artifactName: onnxruntime-macos-${{ parameters.arch }}_${{ replace(parameters.python_version,'.','_') }} + + variables: + - name: MACOSX_DEPLOYMENT_TARGET + value: '13.4' + + steps: + - checkout: self + clean: true + submodules: none + + - task: UsePythonVersion@0 + displayName: 'Use Python' + inputs: + versionSpec: ${{ parameters.python_version }} + + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + + - template: use-xcode-version.yml + parameters: + xcodeVersion: '16.4.0' + + - script: | + set -e -x + export _PYTHON_HOST_PLATFORM=macosx-${{variables.MACOSX_DEPLOYMENT_TARGET}}-${{ parameters.arch }} + python3 -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' + python3 $(Build.SourcesDirectory)/tools/ci_build/build.py \ + --build_dir $(Build.SourcesDirectory)/build \ + --use_vcpkg --use_vcpkg_ms_internal_asset_cache \ + --use_binskim_compliant_compile_flags \ + --config Release \ + --build_wheel \ + --use_coreml ${{ parameters.extra_build_arg }} \ + --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=${{ parameters.arch }} \ + --update --skip_submodule_sync --build --parallel + python -m pip install --upgrade delocate + cd '$(Build.SourcesDirectory)/build/Release/dist' + ls + for file in *.whl + do + delocate-listdeps "$file" + delocate-wheel --require-archs=${{ parameters.arch }} -w fixed_wheels -v "$file" + done \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 62b9e144e578a..c58b0dcbcd90c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -57,18 +57,7 @@ jobs: clean: true submodules: none - - template: telemetry-steps.yml - - - task: UsePythonVersion@0 - inputs: - versionSpec: ${{ parameters.PYTHON_VERSION }} - addToPath: true - architecture: 'arm64' - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' + - template: setup-build-tools.yml - task: PythonScript@0 inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index cf1c5cb79dc54..700450defb46b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -45,18 +45,7 @@ jobs: clean: true submodules: recursive - - template: telemetry-steps.yml - - - task: UsePythonVersion@0 - inputs: - versionSpec: $(PythonVersion) - addToPath: true - architecture: 'x64' - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' + - template: setup-build-tools.yml - script: python -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\windows\python\requirements.txt diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index c135a5c907205..051daf0638456 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -49,7 +49,7 @@ jobs: clean: true submodules: recursive - - template: telemetry-steps.yml + - template: setup-build-tools.yml - task: UsePythonVersion@0 inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index e633bcb457ccb..4d1a7fa926b20 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -52,10 +52,7 @@ stages: steps: - template: set-version-number-variables-step.yml - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true + - template: setup-build-tools.yml - template: jobs/download_win_qnn_sdk.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/templates/setup-build-tools.yml b/tools/ci_build/github/azure-pipelines/templates/setup-build-tools.yml new file mode 100644 index 0000000000000..cf279cce17609 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/setup-build-tools.yml @@ -0,0 +1,50 @@ +# Setup python/nodejs/cmake/vcpkg tools. Also, setup telemetry header file if the current OS is Windows. + +parameters: +- name: BuildArch + displayName: BuildArch + type: string + default: 'x64' + +- name: actionVersion + type: string + default: 'v0.0.8' + +steps: +- template: telemetry-steps.yml + +- task: UsePythonVersion@0 + displayName: Use Python 3.12 + inputs: + versionSpec: 3.12 + ${{ if eq(parameters.BuildArch, 'x86') }}: + architecture: 'x86' + +- task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + +- task: NodeTool@0 + inputs: + versionSpec: '22.x' + +- script: python3 -m pip install requests + +- task: PythonScript@0 + displayName: 'Run GitHub Action via Python Wrapper' + inputs: + scriptPath: 'tools/ci_build/run_gh_action.py' + arguments: '${{ parameters.actionVersion }}' + +- pwsh: | + $ErrorActionPreference = 'Stop' + + Write-Host "Verifying CMake installation..." + cmake --version + + Write-Host "Verifying vcpkg installation..." + & "$env:VCPKG_INSTALLATION_ROOT/vcpkg" version + + Write-Host "Tool verification successful." + displayName: 'Verify Tools' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/setup-maven.yml b/tools/ci_build/github/azure-pipelines/templates/setup-maven.yml new file mode 100644 index 0000000000000..7ad755c50e541 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/setup-maven.yml @@ -0,0 +1,47 @@ +steps: +- task: AzureCLI@2 + displayName: 'Download and Extract Maven using Azure CLI' + inputs: + azureSubscription: 'AIInfraBuildOnnxRuntimeOSS' + scriptType: 'pscore' # Use PowerShell Core + scriptLocation: 'inlineScript' + inlineScript: | + # Define the scope for the access token + $authScope = "https://mspmecloud.onmicrosoft.com/RebuildManager.Web/.default" + + Write-Host "Requesting access token for scope: $authScope" + $tokenInfo = az account get-access-token --scope $authScope | ConvertFrom-Json + + # Set the token as an environment variable for the next tool to use + $env:TRT_UPLOAD_AUTH_TOKEN = $tokenInfo.accessToken + Write-Host "Successfully configured TRT_UPLOAD_AUTH_TOKEN environment variable." + + # Execute the Terrapin Retrieval Tool to download Maven + Write-Host "Downloading Maven..." + & C:\local\Terrapin\TerrapinRetrievalTool.exe -b https://vcpkg.storage.devpackages.microsoft.io/artifacts/ -a true -u Environment -p https://dlcdn.apache.org/maven/maven-3/3.9.11/binaries/apache-maven-3.9.11-bin.zip -s 03e2d65d4483a3396980629f260e25cac0d8b6f7f2791e4dc20bc83f9514db8d0f05b0479e699a5f34679250c49c8e52e961262ded468a20de0be254d8207076 -d $(Agent.TempDirectory)\maven.zip + + # Check if the download was successful + if ($LASTEXITCODE -ne 0) { + throw "Error downloading maven. Exit code: $LASTEXITCODE" + } + Write-Host "Maven downloaded successfully." + + # Extract the downloaded maven zip file + $arguments = "x", "$(Agent.TempDirectory)\maven.zip", "-y", "-o$(Agent.TempDirectory)" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + + # Check if the extraction was successful + if ($LASTEXITCODE -ne 0) { + throw "Error extracting maven.zip. Exit code: $LASTEXITCODE" + } + Write-Host "Maven extracted successfully." + + # Prepend the Maven bin directory to the PATH for subsequent steps in the job + Write-Host "Adding Maven to the pipeline PATH." + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\apache-maven-3.9.11\bin" + +- script: | + echo "Verifying Maven installation..." + mvn --version + displayName: 'Verify Maven Version' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/telemetry-steps.yml b/tools/ci_build/github/azure-pipelines/templates/telemetry-steps.yml index a8bc789e1cffe..8db4a8f8c8658 100644 --- a/tools/ci_build/github/azure-pipelines/templates/telemetry-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/telemetry-steps.yml @@ -5,6 +5,7 @@ steps: # TELEMETRYGUID is a runtime variable that is stored on the pipeline in an old-fashioned way. So it cannot be used in # template expressions. We access it through env variables. - task: PowerShell@2 + condition: and(succeeded(), eq(variables['Agent.OS'], 'Windows_NT')) displayName: 'Set TelemetryOption variable and optionally create TraceLoggingConfigPrivate.h for WinML Telemetry' inputs: targetType: filePath diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index eec0f273581a2..6d6690dd8a78e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -40,6 +40,19 @@ parameters: type: string default: '' +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + # for inference packages '', for training packages '-training' # used for drop-extra and c api artifacts (onnxruntime-win-* or onnxrutime-training-win-*) - name: artifact_name_suffix @@ -89,7 +102,7 @@ stages: ${{ else }}: buildJavaParameter: '' ${{ if eq(parameters['UseIncreasedTimeoutForTests'], 'true') }}: - timeoutParameter: '--test_all_timeout 72000' + timeoutParameter: '--ctest_timeout 72000' ${{ else }}: timeoutParameter: '' jobs: @@ -110,6 +123,11 @@ stages: - output: pipelineArtifact targetPath: $(Build.ArtifactStagingDirectory) artifactName: 'onnxruntime${{ parameters.artifact_name_suffix }}-win-${{ parameters.packageName }}' + + - ${{ if eq(parameters.buildJava, 'true') }}: + - output: pipelineArtifact + targetPath: $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }} + artifactName: 'drop-onnxruntime-java-win-${{ parameters.packageName }}${{parameters.artifact_name_suffix}}' # GPU build has two jobs. This is the first one. - ${{ if contains(parameters.ort_build_pool_name, 'GPU') }}: - output: pipelineArtifact @@ -134,18 +152,7 @@ stages: clean: true submodules: none - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - architecture: ${{ parameters.buildArch }} - - - template: telemetry-steps.yml - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' + - template: setup-build-tools.yml - ${{ if eq(parameters['buildJava'], 'true') }}: - task: JavaToolInstaller@0 @@ -154,12 +161,6 @@ stages: jdkArchitectureOption: ${{ parameters.buildArch }} jdkSourceOption: 'PreInstalled' - - - task: NodeTool@0 - condition: and(succeeded(), eq('${{ parameters.buildNodejs}}', true)) - inputs: - versionSpec: '22.x' - - ${{ if ne(parameters.CudaVersion, '') }}: - template: jobs/download_win_gpu_library.yml parameters: @@ -183,21 +184,15 @@ stages: # For CPU job, tests are run in the same machine as building - ${{ if eq(parameters.buildJava, 'true') }}: + - template: setup-maven.yml - template: make_java_win_binaries.yml parameters: msbuildPlatform: ${{ parameters.msbuildPlatform }} java_artifact_id: ${{ parameters.java_artifact_id }} - ${{ if or(contains(parameters.buildparameter, 'use_cuda'), contains(parameters.buildparameter, 'use_tensorrt')) }}: - # When it is a GPU build, we only assemble the java binaries, testing will be done in the later stage with GPU machine - buildOnly: true - ${{ else }}: - buildOnly: false - - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Java temp binaries' - inputs: - targetPath: '$(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }}' - artifactName: 'drop-onnxruntime-java-win-${{ parameters.packageName }}${{parameters.artifact_name_suffix}}' + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} + buildOnly: true + # All GPU builds will be tested in the next stage with GPU machine - ${{ if contains(parameters.ort_build_pool_name, 'CPU') }}: - task: PythonScript@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 01f73a63075e3..839d5d3ac9144 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -76,9 +76,9 @@ jobs: git checkout -- .gitattributes workingDirectory: '$(Build.SourcesDirectory)' displayName: 'Testing: force EOL to lf on windows for /js/**' - - task: NodeTool@0 - inputs: - versionSpec: '22.x' + + - template: setup-build-tools.yml + - task: DownloadPipelineArtifact@2 inputs: patterns: '${{ parameters.BuildConfig }}_wasm/**/*' diff --git a/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml b/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml index a084d28e84c1e..c225cb8677d11 100644 --- a/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/windowsai-steps.yml @@ -23,19 +23,9 @@ jobs: inputs: version: '6.x' - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - ${{ if eq(parameters.BuildArch, 'x86') }}: - architecture: 'x86' - - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - - - template: telemetry-steps.yml + - template: setup-build-tools.yml + parameters: + BuildArch: ${{ parameters.BuildArch }} - task: NuGetCommand@2 displayName: 'NuGet restore' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml deleted file mode 100644 index 48e6b9c986351..0000000000000 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ /dev/null @@ -1,112 +0,0 @@ -##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### -### please do rerun set-trigger-rules.py ### -trigger: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -pr: - branches: - include: - - main - - rel-* - paths: - exclude: - - docs/** - - README.md - - CONTRIBUTING.md - - BUILD.md - - 'js/web' - - 'onnxruntime/core/providers/js' -#### end trigger #### - -parameters: - -- name: QnnSdk - displayName: QNN SDK version - type: string - default: 2.38.0.250901 - -jobs: -- job: 'BUILD_QNN_EP' - pool: 'Onnxruntime-QNNEP-Windows-2022-CPU' - variables: - MsbuildArguments: '-detailedsummary -maxcpucount -consoleloggerparameters:PerformanceSummary' - OnnxRuntimeBuildDirectory: '$(Build.BinariesDirectory)' - DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true - buildArch: x64 - setVcvars: true - BuildConfig: 'RelWithDebInfo' - ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' - TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] - timeoutInMinutes: 120 - workspace: - clean: all - strategy: - matrix: - SHARED_LIB: - QnnLibKind: 'shared_lib' - STATIC_LIB: - QnnLibKind: 'static_lib' - steps: - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - architecture: $(buildArch) - - - template: templates/jobs/download_win_qnn_sdk.yml - parameters: - QnnSDKVersion: ${{ parameters.QnnSdk }} - - - template: templates/jobs/win-ci-build-steps.yml - parameters: - WithCache: True - Today: $(TODAY) - AdditionalKey: "win-qnn | $(BuildConfig)" - BuildPyArguments: >- - --config $(BuildConfig) - --build_dir $(Build.BinariesDirectory) - --cmake_generator "Visual Studio 17 2022" - --build_java - --build_shared_lib - --use_qnn $(QnnLibKind) - --qnn_home $(QnnSDKRootDir) - --use_binskim_compliant_compile_flags - --update --parallel - MsbuildArguments: $(MsbuildArguments) - BuildArch: $(buildArch) - Platform: 'x64' - BuildConfig: $(BuildConfig) - - - script: | - python $(Build.SourcesDirectory)\tools\ci_build\build.py ^ - --config $(BuildConfig) ^ - --build_dir $(Build.BinariesDirectory) ^ - --cmake_generator "Visual Studio 17 2022" ^ - --build_java ^ - --build_shared_lib ^ - --use_qnn $(QnnLibKind) ^ - --qnn_home $(QnnSDKRootDir) ^ - --use_binskim_compliant_compile_flags ^ - --test --enable_onnx_tests - displayName: 'Run unit tests' - - - script: | - .\$(BuildConfig)\onnx_test_runner -j 1 -e qnn -i "backend_path|$(QnnSDKRootDir)\lib\x86_64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node - workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)' - displayName: 'Run ONNX Tests' - - - script: | - .\$(BuildConfig)\onnx_test_runner -j 1 -e qnn -i "backend_path|$(QnnSDKRootDir)\lib\x86_64-windows-msvc\QnnCpu.dll" C:\data\float32_models - workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)' - displayName: 'Run float32 model tests' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu index 177df14d6eaee..2a65e7c26b20b 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu @@ -1,4 +1,5 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +FROM $BASEIMAGE ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index 957eef8046eaf..3337af3be6074 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -1,4 +1,5 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +FROM $BASEIMAGE ARG ROCM_VERSION=6.2.3 #Add our own dependencies diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu index 56d67599f0bce..0007a4e06f7c0 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu @@ -1,4 +1,5 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +FROM $BASEIMAGE ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile index c8e164282a2f0..8b2083c2ccfc1 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile @@ -2,7 +2,8 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14_dotnet:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14_dotnet:20250724.1 +FROM $BASEIMAGE ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile index 31bd41226263f..f5143d5ac9ab9 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile @@ -1,4 +1,5 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14:20250724.1 +FROM $BASEIMAGE ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile index 461464093688a..cfc2ce7079148 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -2,7 +2,8 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14_dotnet:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14_dotnet:20250724.1 +FROM $BASEIMAGE ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile index 043291065736d..8401393a661b1 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile @@ -2,7 +2,8 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12_dotnet:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12_dotnet:20250724.1 +FROM $BASEIMAGE ARG TRT_VERSION #Install TensorRT only if TRT_VERSION is not empty diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile index 43da13df2fe8b..b923febc1227f 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile @@ -1,4 +1,5 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 +FROM $BASEIMAGE ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/java_copy_strip_binary.sh b/tools/ci_build/github/linux/java_copy_strip_binary.sh deleted file mode 100755 index 329c1b0ab9b9e..0000000000000 --- a/tools/ci_build/github/linux/java_copy_strip_binary.sh +++ /dev/null @@ -1,71 +0,0 @@ -#!/bin/bash -set -e -o -x - -while getopts r:a:l:n:c:h:v: parameter_Option -do case "${parameter_Option}" -in -r) BINARY_DIR=${OPTARG};; -a) ARTIFACT_NAME=${OPTARG};; -c) BUILD_CONFIG=${OPTARG};; -l) LIB_NAME=${OPTARG};; -n) NATIVE_LIB_NAME=${OPTARG};; -h) ARCH=${OPTARG};; #must match the JAVA_OS_ARCH variable in onnxruntime_java.cmake -v) VERSION_NUMBER=${OPTARG};; -esac -done - -EXIT_CODE=1 - -uname -a - -echo "Version: $VERSION_NUMBER" -if [[ $LIB_NAME == *.dylib ]] && [[ $ARCH == 'osx-x86_64' ]]; then - ARCH='osx-x64' -elif [[ $LIB_NAME == *.dylib ]] && [[ $ARCH == 'osx-arm64' ]]; then - ARCH='osx-aarch64' -fi -NATIVE_FOLDER=ai/onnxruntime/native/$ARCH - -mkdir -p $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER - -echo "Directories created" - -echo "Copy debug symbols in a separate file and strip the original binary." - -if [[ $LIB_NAME == *.dylib ]] -then - # ORT LIB - dsymutil $BINARY_DIR/$BUILD_CONFIG/$LIB_NAME -o $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/$LIB_NAME.dSYM - cp $BINARY_DIR/$BUILD_CONFIG/$LIB_NAME $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime.dylib - strip -S $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime.dylib - # JNI Lib - dsymutil $BINARY_DIR/$BUILD_CONFIG/$NATIVE_LIB_NAME -o $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/$NATIVE_LIB_NAME.dSYM - cp $BINARY_DIR/$BUILD_CONFIG/$NATIVE_LIB_NAME $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime4j_jni.dylib - strip -S $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime4j_jni.dylib - # Add custom lib for testing. This should be added to testing.jar - cp $BINARY_DIR/$BUILD_CONFIG/libcustom_op_library.dylib $BINARY_DIR/$ARTIFACT_NAME -elif [[ $LIB_NAME == *.so ]] -then - cp $BINARY_DIR/$BUILD_CONFIG/$LIB_NAME $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime.so - cp $BINARY_DIR/$BUILD_CONFIG/$NATIVE_LIB_NAME $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime4j_jni.so - # Add custom lib - cp $BINARY_DIR/$BUILD_CONFIG/libcustom_op_library.so $BINARY_DIR/$ARTIFACT_NAME - # Add cuda provider if it exists - if [[ -f "$BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_cuda.so" ]]; then - cp $BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_shared.so $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime_providers_shared.so - cp $BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_cuda.so $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime_providers_cuda.so - fi - # Add tensorrt provider if it exists - if [[ -f "$BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_tensorrt.so" ]]; then - cp $BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_shared.so $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime_providers_shared.so - cp $BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_tensorrt.so $BINARY_DIR/$ARTIFACT_NAME/$NATIVE_FOLDER/libonnxruntime_providers_tensorrt.so - fi -fi - -find $BINARY_DIR/$ARTIFACT_NAME -ls -rm -fr $BINARY_DIR/$ARTIFACT_NAME/jar - -EXIT_CODE=$? - -set -e -exit $EXIT_CODE diff --git a/tools/ci_build/github/linux/java_linux_final_test.sh b/tools/ci_build/github/linux/java_linux_final_test.sh index 71eb24dc7a1e2..cdbfd2bad10a8 100755 --- a/tools/ci_build/github/linux/java_linux_final_test.sh +++ b/tools/ci_build/github/linux/java_linux_final_test.sh @@ -23,6 +23,8 @@ uname -a cd "$BINARY_DIR/onnxruntime-java" rm -f *.asc rm -f *.sha256 +rm -f *.sha1 +rm -f *.md5 rm -f *.sha512 rm -f *.pom ls diff --git a/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh b/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh deleted file mode 100755 index 835f83e2b8bed..0000000000000 --- a/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash - -pip3 install --user --upgrade pip - -pip3 install --user numpy torch pytest -pip3 install --user /build/Release/dist/*.whl - -export PYTHONPATH=/onnxruntime_src/tools:/usr/local/lib/python3.10/site-packages:$PYTHONPATH - -python3 -m pytest -v /onnxruntime_src/tools/test/test_custom_ops_pytorch_exporter.py || exit 1 - -for filename in /onnxruntime_src/onnxruntime/test/python/contrib_ops/onnx_test_* ; do - cd /build/Release && python3 -m pytest -v $filename || exit 1 -done - -cd /build/Release && ./onnxruntime_test_all --gtest_filter=ShapeInferenceTests.* || exit 1 diff --git a/tools/ci_build/github/pai/pai_test_launcher.sh b/tools/ci_build/github/pai/pai_test_launcher.sh deleted file mode 100755 index e3d531aea75ad..0000000000000 --- a/tools/ci_build/github/pai/pai_test_launcher.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -build_dir=${1:-"."} -script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" - -echo "Warning: The following tests are EXCLUDED on PAI agent:" -gtest_filter="-" -while read line; do - gtest_filter="$gtest_filter:$line" - echo "$line" -done <$script_dir/pai-excluded-tests.txt -echo "" - -echo "Running ./onnxruntime_test_all .." -$build_dir/onnxruntime_test_all --gtest_filter=$gtest_filter diff --git a/tools/ci_build/github/windows/jar_esrp_dll.ps1 b/tools/ci_build/github/windows/jar_esrp_dll.ps1 index 8492d7591271b..2a53374d845a0 100644 --- a/tools/ci_build/github/windows/jar_esrp_dll.ps1 +++ b/tools/ci_build/github/windows/jar_esrp_dll.ps1 @@ -1,41 +1,70 @@ -$instruction = $args[0] # extract or repack -$original_jar_file_directory = $args[1] # The directory where the original jar file is located -$original_jar_file_name = $args[2] # The name of the original jar file +param( + [string]$instruction, # Should be 'extract' or 'repack' + [string]$jar_file_directory # The directory where the original jar file is located +) -$original_jar_file_full_path = "$original_jar_file_directory\$original_jar_file_name" -$extracted_file_directory = "$original_jar_file_directory\jar_extracted_full_files" +$extracted_file_directory = Join-Path $jar_file_directory "jar_extracted_full_files" +$state_file = Join-Path $jar_file_directory "repack_list.txt" if ($instruction -eq "extract") { - Write-Host "Extracting the jar file $original_jar_file_full_path..." - & 7z x $original_jar_file_full_path -o"$extracted_file_directory" - if ($lastExitCode -ne 0) { - Write-Host -Object "7z extracting the jar file command failed. Exitcode: $exitCode" - exit $lastExitCode + # Find the main jar file(s) by looking for names that start with 'onnxruntime' + # and excluding common suffixes for sources and javadocs. + $main_jar_files = Get-ChildItem -Path $jar_file_directory -Filter onnxruntime*.jar | Where-Object { $_.Name -notlike '*-sources.jar' -and $_.Name -notlike '*-javadoc.jar' } + + if ($main_jar_files.Count -eq 0) { + Write-Error "No main ONNX Runtime JAR file found in directory: $jar_file_directory" + exit 1 } - Write-Host "Extracted files directory: $extracted_file_directory" - Write-Host "Removing the original jar file..." - Remove-Item -Path "$original_jar_file_full_path" -Force - Write-Host "Removed the original jar file." -} -elseif ($instruction -eq "repack") { + # Clear any previous state file + if (Test-Path $state_file) { + Remove-Item $state_file + } + + foreach ($jar_file in $main_jar_files) { + Write-Host "Extracting the jar file $($jar_file.FullName)..." + & 7z x $jar_file.FullName -o"$extracted_file_directory" + if ($LASTEXITCODE -ne 0) { + Write-Error "7z failed to extract the jar file. Exitcode: $LASTEXITCODE" + exit $LASTEXITCODE + } + + # Save the original name for repacking, then remove the file + $jar_file.Name | Out-File -FilePath $state_file -Append + Write-Host "Removing the original jar file: $($jar_file.FullName)" + Remove-Item -Path $jar_file.FullName -Force + } + Write-Host "Extracted files to directory: $extracted_file_directory" + +} elseif ($instruction -eq "repack") { + if (-not (Test-Path $state_file)) { + Write-Error "State file '$state_file' not found. Cannot repack." + exit 1 + } + Write-Host "Removing ESRP's CodeSignSummary file..." - # It is the summary generated by ESRP tool. It is not needed in the jar file. - Remove-Item -Path "$extracted_file_directory/CodeSignSummary*.*" -Force + Remove-Item -Path "$extracted_file_directory/CodeSignSummary*.*" -Force -ErrorAction SilentlyContinue Write-Host "Removed ESRP's CodeSignSummary file." - Write-Host "Repacking the jar file from directory $extracted_file_directory..." - & 7z a "$original_jar_file_full_path" "$extracted_file_directory\*" - if ($lastExitCode -ne 0) { - Write-Host -Object "7z repacking the jar file command failed. Exitcode: $exitCode" - exit $lastExitCode + $jar_files_to_repack = Get-Content $state_file + + foreach ($jar_file_name in $jar_files_to_repack) { + $repacked_jar_file_path = Join-Path $jar_file_directory $jar_file_name + Write-Host "Repacking to $repacked_jar_file_path from directory $extracted_file_directory..." + & 7z a "$repacked_jar_file_path" "$extracted_file_directory\*" + if ($LASTEXITCODE -ne 0) { + Write-Error "7z failed to repack the jar file. Exitcode: $LASTEXITCODE" + exit $LASTEXITCODE + } + Write-Host "Repacked the jar file $repacked_jar_file_path." } - Write-Host "Repacked the jar file $original_jar_file_full_path." - Write-Host "Removing the extracted files..." + Write-Host "Removing the extracted files and state file..." Remove-Item -Path "$extracted_file_directory" -Recurse -Force - Write-Host "Removed the extracted files." -} -else { - Write-Host "Invalid instruction: $instruction" + Remove-Item -Path $state_file -Force + Write-Host "Cleaned up temporary files." + +} else { + Write-Error "Invalid instruction: '$instruction'. Must be 'extract' or 'repack'." + exit 1 } diff --git a/tools/ci_build/github/windows/jar_gpu_packaging.ps1 b/tools/ci_build/github/windows/jar_gpu_packaging.ps1 deleted file mode 100644 index 1c94f4678f988..0000000000000 --- a/tools/ci_build/github/windows/jar_gpu_packaging.ps1 +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -$ErrorActionPreference = "Stop" -Write-Output "Start" -dir -Copy-Item -Path $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-linux-x64\ai\onnxruntime\native\linux-x64\libonnxruntime_providers_cuda.so -Destination $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-linux-x64-tensorrt\ai\onnxruntime\native\linux-x64 -pushd onnxruntime-java-linux-x64-tensorrt -Write-Output "Run 7z" -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\testing.jar libcustom_op_library.so -Remove-Item -Path libcustom_op_library.so -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\onnxruntime-$Env:ONNXRUNTIMEVERSION.jar . -popd -pushd onnxruntime-java-win-x64 -ren onnxruntime-$Env:ONNXRUNTIMEVERSION.jar onnxruntime_gpu-$Env:ONNXRUNTIMEVERSION.jar -ren onnxruntime-$Env:ONNXRUNTIMEVERSION-javadoc.jar onnxruntime_gpu-$Env:ONNXRUNTIMEVERSION-javadoc.jar -ren onnxruntime-$Env:ONNXRUNTIMEVERSION-sources.jar onnxruntime_gpu-$Env:ONNXRUNTIMEVERSION-sources.jar -ren onnxruntime-$Env:ONNXRUNTIMEVERSION.pom onnxruntime_gpu-$Env:ONNXRUNTIMEVERSION.pom -popd diff --git a/tools/ci_build/github/windows/jar_packaging.ps1 b/tools/ci_build/github/windows/jar_packaging.ps1 deleted file mode 100644 index a132ba6b26e2a..0000000000000 --- a/tools/ci_build/github/windows/jar_packaging.ps1 +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -$ErrorActionPreference = "Stop" -Write-Output "Start" -dir -pushd onnxruntime-java-linux-x64 -Write-Output "Run 7z" -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\testing.jar libcustom_op_library.so -Remove-Item -Path libcustom_op_library.so -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\onnxruntime-$Env:ONNXRUNTIMEVERSION.jar . -popd -pushd onnxruntime-java-osx-x86_64 -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\testing.jar libcustom_op_library.dylib -Remove-Item -Path libcustom_op_library.dylib -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\onnxruntime-$Env:ONNXRUNTIMEVERSION.jar . -popd -pushd onnxruntime-java-linux-aarch64 -Remove-Item -Path libcustom_op_library.so -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\onnxruntime-$Env:ONNXRUNTIMEVERSION.jar . -popd -pushd onnxruntime-java-osx-arm64 -Remove-Item -Path libcustom_op_library.dylib -7z a $Env:BUILD_BINARIESDIRECTORY\java-artifact\onnxruntime-java-win-x64\onnxruntime-$Env:ONNXRUNTIMEVERSION.jar . -popd diff --git a/tools/ci_build/github/windows/jar_packaging.py b/tools/ci_build/github/windows/jar_packaging.py new file mode 100644 index 0000000000000..2354363610251 --- /dev/null +++ b/tools/ci_build/github/windows/jar_packaging.py @@ -0,0 +1,312 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +""" +Packages ONNX Runtime Java artifacts by combining native libraries from +various platform builds into final Java archive (JAR) files using 7z. +""" + +import argparse +import glob +import os +import re +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Any + +# Add semver as a dependency +try: + import semver +except ImportError: + print("Error: The 'semver' package is not installed. Please add it to your requirements.txt.", file=sys.stderr) + sys.exit(1) + +# --- Helper Functions for Archiving --- + + +def find_7z_executable(): + """Finds the 7z executable, checking the system PATH and default installation locations.""" + # 1. Check if '7z' is in the PATH + seven_zip_exe = shutil.which("7z") + if seven_zip_exe: + return seven_zip_exe + + # 2. Check the default installation directory under Program Files + program_files = os.environ.get("ProgramFiles") # noqa: SIM112 + if program_files: + default_path = Path(program_files) / "7-Zip" / "7z.exe" + if default_path.is_file(): + return str(default_path) + + return None + + +SEVEN_ZIP_EXE = find_7z_executable() + + +def add_file_to_archive(archive_path: Path, file_to_add: Path, description: str): + """Appends a single file to a zip archive (JAR file) using 7z.""" + print(f" -> {description}...") + try: + if not SEVEN_ZIP_EXE: + raise FileNotFoundError + # Run 7z from the file's parent directory to ensure a clean archive path. + subprocess.run( + [SEVEN_ZIP_EXE, "a", str(archive_path), file_to_add.name], + check=True, + cwd=file_to_add.parent, + capture_output=True, + text=True, + ) + except FileNotFoundError: + print( + "Error: '7z' command not found. Please ensure 7-Zip is installed and in your PATH, or in the default location 'C:\\Program Files\\7-Zip'.", + file=sys.stderr, + ) + raise + except subprocess.CalledProcessError as e: + print(f"Error: 7z failed to archive '{file_to_add.name}' to '{archive_path.name}'.", file=sys.stderr) + print(f"Reason: {e.stderr}", file=sys.stderr) + raise + + +def archive_directory_contents(archive_path: Path, source_dir: Path, description: str): + """Archives a directory into a zip file (JAR file) using 7z, preserving its top-level name.""" + print(f" -> {description}...") + try: + if not SEVEN_ZIP_EXE: + raise FileNotFoundError + # Run 7z from the parent of the source directory to ensure the source directory + # itself is added to the archive, preserving the path structure (e.g., 'ai/...'). + subprocess.run( + [SEVEN_ZIP_EXE, "a", str(archive_path), source_dir.name], + check=True, + cwd=source_dir.parent, + capture_output=True, + text=True, + ) + except FileNotFoundError: + print( + "Error: '7z' command not found. Please ensure 7-Zip is installed and in your PATH, or in the default location 'C:\\Program Files\\7-Zip'.", + file=sys.stderr, + ) + raise + except subprocess.CalledProcessError as e: + print(f"Error: 7z failed to archive directory '{source_dir.name}' to '{archive_path.name}'.", file=sys.stderr) + print(f"Reason: {e.stderr}", file=sys.stderr) + raise + + +# --- Validation Helpers --- + + +def validate_version(version_string: str): + """Validates if the version string conforms to the project's format.""" + print(f"Validating version string: {version_string}...") + try: + version_info = semver.Version.parse(version_string) + if version_info.prerelease: + prerelease_tag = version_info.prerelease + allowed_tags_pattern = r"^(alpha|beta|rc)\d+$" + if not re.match(allowed_tags_pattern, str(prerelease_tag)): + raise ValueError(f"Pre-release tag '{prerelease_tag}' is not an allowed type.") + except ValueError as e: + print(f"Error: Version '{version_string}' is not valid. Reason: {e}", file=sys.stderr) + print("Expected format is 'X.Y.Z' or 'X.Y.Z-(alpha|beta|rc)N'.", file=sys.stderr) + sys.exit(1) + print("Version format is valid.") + + +def validate_companion_jars(base_jar_path: Path): + """Ensures that -sources.jar and -javadoc.jar files exist.""" + print("Validating presence of companion -sources.jar and -javadoc.jar...") + base_stem = base_jar_path.stem + directory = base_jar_path.parent + sources_jar_path = directory / f"{base_stem}-sources.jar" + + if not sources_jar_path.is_file(): + print(f"Error: Missing companion sources JAR. Expected: {sources_jar_path.name}", file=sys.stderr) + sys.exit(1) + + if not list(directory.glob(f"{base_stem}-javadoc*.jar")): + print(f"Error: Missing companion javadoc JAR. Expected file like: {base_stem}-javadoc.jar", file=sys.stderr) + sys.exit(1) + print("Companion JARs are present.") + + +# --- Core Logic Function --- + + +def process_platform_archive( + platform_path: Path, + main_archive_file: Path, + test_archive_file: Path, + custom_lib_file: str, + archive_custom_lib: bool, +): + """Processes a single platform directory, adding only the 'ai' subdirectory to the main JAR.""" + print(f"Processing platform: {platform_path}...") + + # 1. Handle the custom op library. + custom_lib_full_path = platform_path / custom_lib_file + if custom_lib_file and custom_lib_full_path.is_file(): + if archive_custom_lib: + add_file_to_archive(test_archive_file, custom_lib_full_path, f"Archiving '{custom_lib_file}' to test JAR") + # Always remove the lib after processing to prevent it from being in the main JAR. + print(f" -> Removing '{custom_lib_file}' from source directory...") + custom_lib_full_path.unlink() + elif archive_custom_lib: + # If we expected to archive the file but it wasn't there, it's a fatal error. + print(f"Error: Expected custom op library '{custom_lib_file}' not found in {platform_path}", file=sys.stderr) + sys.exit(1) + + # 2. Archive only the native library directory ('ai/...') to the main JAR. + # This explicitly excludes other files or folders like '_manifest'. + native_lib_root = platform_path / "ai" + if native_lib_root.is_dir(): + archive_directory_contents( + main_archive_file, native_lib_root, f"Archiving native libs from '{native_lib_root.name}' to main JAR" + ) + else: + print(f"Warning: Native library path 'ai/' not found in {platform_path}. Skipping main archive step.") + + print(f"Finished platform: {platform_path}") + print("--------------------------------") + + +def run_packaging(package_type: str, build_dir: str): + """The main logic for the packaging process, refactored to be callable.""" + artifacts_base_dir = Path(build_dir) / "java-artifact" + primary_package_dir = artifacts_base_dir / "onnxruntime-java-win-x64" + if not primary_package_dir.is_dir(): + print(f"Error: Primary package directory not found at '{primary_package_dir}'", file=sys.stderr) + sys.exit(1) + + # --- Version Discovery --- + print(f"Discovering version from JAR files in '{primary_package_dir}'...") + jar_pattern = str(primary_package_dir / "onnxruntime*-*.jar") + jar_files = [Path(f) for f in glob.glob(jar_pattern) if "-sources" not in f and "-javadoc" not in f] + if not jar_files: + print( + f"Error: Could not find a main JAR file in '{primary_package_dir}' to determine the version.", + file=sys.stderr, + ) + sys.exit(1) + + main_jar_file = jar_files[0] + validate_companion_jars(main_jar_file) + + version = "" + stem = main_jar_file.stem + try: + # Per user feedback, the version is everything after the first dash. + _, version = stem.split("-", 1) + except ValueError: + # This will happen if there is no dash in the filename, which is unexpected. + print( + f"Error: Could not parse version from JAR file '{main_jar_file.name}'. Expected format -.jar", + file=sys.stderr, + ) + sys.exit(1) + + if not version: + print( + f"Error: Could not parse version from JAR file '{main_jar_file.name}'. Version part is empty.", + file=sys.stderr, + ) + sys.exit(1) + + print(f"Version discovered: {version}") + validate_version(version) + + # --- Package Definitions --- + package_definitions: dict[str, dict[str, Any]] = { + "cpu": { + "platforms": [ + {"path": "onnxruntime-java-linux-x64", "lib": "libcustom_op_library.so", "archive_lib": True}, + {"path": "onnxruntime-java-osx-x86_64", "lib": "libcustom_op_library.dylib", "archive_lib": True}, + {"path": "onnxruntime-java-linux-aarch64", "lib": "libcustom_op_library.so", "archive_lib": False}, + {"path": "onnxruntime-java-osx-arm64", "lib": "libcustom_op_library.dylib", "archive_lib": False}, + ] + }, + "gpu": { + "platforms": [ + {"path": "onnxruntime-java-linux-x64", "lib": "libcustom_op_library.so", "archive_lib": False} + ] + }, + } + + # --- Processing Loop --- + print(f"\n## Configuring for {package_type.upper()} package build...") + + final_main_archive = main_jar_file + final_test_archive = primary_package_dir / "testing.jar" + + print(f"Using '{final_main_archive.name}' as the base for in-place packaging.") + + if not final_test_archive.is_file(): + print(f"Error: Base 'testing.jar' not found at '{final_test_archive}'.", file=sys.stderr) + sys.exit(1) + + platforms_to_process = package_definitions[package_type]["platforms"] + + for platform in platforms_to_process: + platform_full_path = artifacts_base_dir / platform["path"] + if not platform_full_path.is_dir(): + print(f"Error: Required platform artifact directory not found: {platform_full_path}", file=sys.stderr) + sys.exit(1) + + process_platform_archive( + platform_path=platform_full_path, + main_archive_file=final_main_archive, + test_archive_file=final_test_archive, + custom_lib_file=platform["lib"], + archive_custom_lib=platform["archive_lib"], + ) + + print("\nScript completed successfully.") + + +def main(): + """Main script entry point for command-line execution.""" + if sys.platform != "win32": + print("Error: This script is intended to be run on Windows.", file=sys.stderr) + sys.exit(1) + + parser = argparse.ArgumentParser(description="Package ONNX Runtime Java artifacts.") + parser.add_argument( + "--package_type", + type=str, + choices=["cpu", "gpu"], + default="cpu", + help="The type of package to build ('cpu' or 'gpu').", + ) + parser.add_argument( + "--build_dir", + type=str, + help="The build directory containing the java-artifact folder.", + ) + args = parser.parse_args() + + build_dir = args.build_dir + if not build_dir: + try: + build_dir = os.environ["BUILD_BINARIESDIRECTORY"] + except KeyError: + print( + "Error: Environment variable BUILD_BINARIESDIRECTORY is not set and --build_dir is not provided.", + file=sys.stderr, + ) + sys.exit(1) + + run_packaging(args.package_type, build_dir) + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(f"\nAn unhandled error occurred: {e}", file=sys.stderr) + sys.exit(1) diff --git a/tools/ci_build/github/windows/jar_packaging_test.py b/tools/ci_build/github/windows/jar_packaging_test.py new file mode 100644 index 0000000000000..91b68728dad15 --- /dev/null +++ b/tools/ci_build/github/windows/jar_packaging_test.py @@ -0,0 +1,157 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import zipfile +from pathlib import Path + +import jar_packaging # The refactored script +import pytest + + +# Helper to create an empty file +def create_empty_file(path): + Path(path).touch() + + +# Helper to create a dummy JAR file +def create_dummy_jar(path): + with zipfile.ZipFile(path, "w") as zf: + zf.writestr("META-INF/MANIFEST.MF", "Manifest-Version: 1.0\n") + + +@pytest.fixture +def directory_setup_factory(tmp_path): + """ + A factory fixture that returns a function to set up a test directory + for a given package type and version. + """ + + def _setup_test_directory(package_type: str, version_string: str): + """Sets up a temporary directory structure mimicking the build artifacts.""" + java_artifact_dir = tmp_path / "java-artifact" + win_dir = java_artifact_dir / "onnxruntime-java-win-x64" + linux_dir = java_artifact_dir / "onnxruntime-java-linux-x64" + osx_dir = java_artifact_dir / "onnxruntime-java-osx-x86_64" + + # --- Main artifact directory (Windows) --- + win_dir.mkdir(parents=True, exist_ok=True) + artifact_name = f"onnxruntime_{package_type}" if package_type == "gpu" else "onnxruntime" + create_dummy_jar(win_dir / f"{artifact_name}-{version_string}.jar") + create_dummy_jar(win_dir / f"{artifact_name}-{version_string}-sources.jar") + create_dummy_jar(win_dir / f"{artifact_name}-{version_string}-javadoc.jar") + create_empty_file(win_dir / f"{artifact_name}-{version_string}.pom") + create_dummy_jar(win_dir / "testing.jar") + (win_dir / "_manifest" / "spdx_2.2").mkdir(parents=True, exist_ok=True) + + # --- Linux platform --- + linux_native_dir = linux_dir / "ai" / "onnxruntime" / "native" / "linux-x64" + linux_native_dir.mkdir(parents=True, exist_ok=True) + create_empty_file(linux_dir / "libcustom_op_library.so") + create_empty_file(linux_native_dir / "libonnxruntime.so") + create_empty_file(linux_native_dir / "libonnxruntime4j_jni.so") + if package_type == "gpu": + create_empty_file(linux_native_dir / "libonnxruntime_providers_cuda.so") + (linux_dir / "_manifest" / "spdx_2.2").mkdir(parents=True, exist_ok=True) + + # --- macOS and other platforms (for CPU test) --- + if package_type == "cpu": + osx_native_dir = osx_dir / "ai" / "onnxruntime" / "native" / "osx-x86_64" + osx_native_dir.mkdir(parents=True, exist_ok=True) + create_empty_file(osx_dir / "libcustom_op_library.dylib") + create_empty_file(osx_native_dir / "libonnxruntime.dylib") + create_empty_file(osx_native_dir / "libonnxruntime4j_jni.dylib") + (osx_dir / "_manifest" / "spdx_2.2").mkdir(parents=True, exist_ok=True) + + # Add linux-aarch64 and osx-arm64 for CPU test + linux_aarch64_dir = java_artifact_dir / "onnxruntime-java-linux-aarch64" + linux_aarch64_native_dir = linux_aarch64_dir / "ai" / "onnxruntime" / "native" / "linux-aarch64" + linux_aarch64_native_dir.mkdir(parents=True, exist_ok=True) + create_empty_file(linux_aarch64_dir / "libcustom_op_library.so") + + osx_arm64_dir = java_artifact_dir / "onnxruntime-java-osx-arm64" + osx_arm64_native_dir = osx_arm64_dir / "ai" / "onnxruntime" / "native" / "osx-arm64" + osx_arm64_native_dir.mkdir(parents=True, exist_ok=True) + create_empty_file(osx_arm64_dir / "libcustom_op_library.dylib") + + return tmp_path + + return _setup_test_directory + + +@pytest.mark.parametrize("version_string", ["1.23.0", "1.23.0-rc1"]) +def test_gpu_packaging(directory_setup_factory, version_string): + """ + Tests the GPU packaging logic for both release and pre-release versions + to ensure correct files are added to the JARs. + """ + temp_build_dir = directory_setup_factory("gpu", version_string) + + # Run the packaging script logic + jar_packaging.run_packaging("gpu", str(temp_build_dir)) + + # --- Verification --- + win_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-win-x64" + main_jar_path = win_dir / f"onnxruntime_gpu-{version_string}.jar" + testing_jar_path = win_dir / "testing.jar" + + # 1. Verify the main JAR contains the Linux native libraries + with zipfile.ZipFile(main_jar_path, "r") as zf: + jar_contents = zf.namelist() + assert "ai/onnxruntime/native/linux-x64/libonnxruntime.so" in jar_contents + assert "ai/onnxruntime/native/linux-x64/libonnxruntime4j_jni.so" in jar_contents + assert "ai/onnxruntime/native/linux-x64/libonnxruntime_providers_cuda.so" in jar_contents + + # 2. Verify the testing JAR does not contain the custom op library for GPU builds + with zipfile.ZipFile(testing_jar_path, "r") as zf: + jar_contents = zf.namelist() + # The custom op lib for linux is not archived for GPU builds. + # This checks that it's NOT in the test jar. + assert "libcustom_op_library.so" not in jar_contents + + # 3. Verify the custom op library was removed from the source linux directory + linux_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-linux-x64" + assert not (linux_dir / "libcustom_op_library.so").exists() + + +@pytest.mark.parametrize("version_string", ["1.23.0", "1.23.0-rc1"]) +def test_cpu_packaging(directory_setup_factory, version_string): + """ + Tests the CPU packaging logic to ensure correct files are added to the JARs. + """ + temp_build_dir = directory_setup_factory("cpu", version_string) + + # Run the packaging script logic + jar_packaging.run_packaging("cpu", str(temp_build_dir)) + + # --- Verification --- + win_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-win-x64" + main_jar_path = win_dir / f"onnxruntime-{version_string}.jar" + testing_jar_path = win_dir / "testing.jar" + + # 1. Verify the main JAR contains native libraries from all relevant platforms + with zipfile.ZipFile(main_jar_path, "r") as zf: + jar_contents = zf.namelist() + # Linux libs + assert "ai/onnxruntime/native/linux-x64/libonnxruntime.so" in jar_contents + assert "ai/onnxruntime/native/linux-x64/libonnxruntime4j_jni.so" in jar_contents + # macOS libs + assert "ai/onnxruntime/native/osx-x86_64/libonnxruntime.dylib" in jar_contents + assert "ai/onnxruntime/native/osx-x86_64/libonnxruntime4j_jni.dylib" in jar_contents + # GPU libs should NOT be present + assert "ai/onnxruntime/native/linux-x64/libonnxruntime_providers_cuda.so" not in jar_contents + + # 2. Verify the testing JAR contains the custom op libraries that should be archived + with zipfile.ZipFile(testing_jar_path, "r") as zf: + jar_contents = zf.namelist() + assert "libcustom_op_library.so" in jar_contents + assert "libcustom_op_library.dylib" in jar_contents + + # 3. Verify the custom op libraries were removed from the source directories + linux_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-linux-x64" + osx_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-osx-x86_64" + linux_aarch64_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-linux-aarch64" + osx_arm64_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-osx-arm64" + assert not (linux_dir / "libcustom_op_library.so").exists() + assert not (osx_dir / "libcustom_op_library.dylib").exists() + assert not (linux_aarch64_dir / "libcustom_op_library.so").exists() + assert not (osx_arm64_dir / "libcustom_op_library.dylib").exists() diff --git a/tools/ci_build/github/windows/python/requirements.txt b/tools/ci_build/github/windows/python/requirements.txt index 6968c605eb649..67541fc46c8bf 100644 --- a/tools/ci_build/github/windows/python/requirements.txt +++ b/tools/ci_build/github/windows/python/requirements.txt @@ -11,3 +11,4 @@ psutil onnxscript==0.3.2 jinja2 markupsafe +semver \ No newline at end of file diff --git a/tools/ci_build/github/windows/sign_java_artifacts.py b/tools/ci_build/github/windows/sign_java_artifacts.py new file mode 100644 index 0000000000000..19d1a4af98799 --- /dev/null +++ b/tools/ci_build/github/windows/sign_java_artifacts.py @@ -0,0 +1,139 @@ +import argparse +import hashlib +import os +import platform +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path + + +def get_gpg_path() -> Path: + """Finds the path to the GPG executable.""" + if platform.system() == "Windows": + program_files_x86 = os.environ.get("ProgramFiles(x86)") # noqa: SIM112 + if not program_files_x86: + raise OSError("ProgramFiles(x86) environment variable not found.") + return Path(program_files_x86) / "gnupg/bin/gpg.exe" + + gpg_path_str = shutil.which("gpg") + if gpg_path_str is None: + raise FileNotFoundError("gpg executable not found in system PATH.") + return Path(gpg_path_str) + + +def run_command(command: list[str], check: bool = True) -> subprocess.CompletedProcess: + """Executes a command and raises an exception if it fails.""" + print(f"Running command: {' '.join(command)}") + result = subprocess.run(command, capture_output=True, text=True, check=False) + if check and result.returncode != 0: + print(f"Command failed with exit code {result.returncode}") + print(f"Stdout:\n{result.stdout}") + print(f"Stderr:\n{result.stderr}") + raise subprocess.CalledProcessError(result.returncode, command, result.stdout, result.stderr) + return result + + +def create_hash_file(file_path: Path, algorithm: str) -> None: + """Creates a checksum file for the given file using the specified algorithm.""" + print(f" - Generating {algorithm.upper()} checksum...") + try: + hasher = hashlib.new(algorithm) + with file_path.open("rb") as f: + # Read in chunks to handle large files efficiently + while chunk := f.read(8192): + hasher.update(chunk) + + hash_value = hasher.hexdigest() + # Create checksum file in 'sha1sum'/'md5sum' format. + # The '*' indicates to read the file in binary mode for verification tools. + Path(f"{file_path}.{algorithm}").write_text(hash_value.lower(), encoding="utf-8") + except Exception as e: + print(f"Error generating {algorithm} hash for {file_path}: {e}") + raise + + +def main() -> None: + """ + Signs files with GPG and generates checksums. + """ + parser = argparse.ArgumentParser(description="Signs files with GPG and generates checksums.") + parser.add_argument("jar_file_directory", help="The directory containing files to sign.") + args = parser.parse_args() + + jar_file_directory = Path(args.jar_file_directory) + if not jar_file_directory.is_dir(): + print(f"Error: Directory not found at '{jar_file_directory}'", file=sys.stderr) + sys.exit(1) + + print(f"\nListing files to be processed in '{jar_file_directory}':") + files_to_process = [p for p in jar_file_directory.rglob("*") if p.is_file()] + for file_path in files_to_process: + print(f" - {file_path}") + print(f"Found {len(files_to_process)} files.") + + print("\nGetting GnuPG signing keys from environment variables.") + gpg_passphrase = os.environ.get("JAVA_PGP_PWD") + gpg_private_key = os.environ.get("JAVA_PGP_KEY") + + if not gpg_passphrase or not gpg_private_key: + print( + "Error: GPG passphrase or private key not found in environment variables ('JAVA_PGP_PWD', 'JAVA_PGP_KEY').", + file=sys.stderr, + ) + sys.exit(1) + + gpg_exe_path = get_gpg_path() + if not gpg_exe_path.is_file(): + print(f"Error: GPG executable not found at '{gpg_exe_path}'.", file=sys.stderr) + sys.exit(1) + + agent_temp_dir = os.environ.get("AGENT_TEMPDIRECTORY") + + # Use a single temporary directory to manage all temporary files + with tempfile.TemporaryDirectory(dir=agent_temp_dir) as temp_dir: + temp_dir_path = Path(temp_dir) + print(f"Created temporary directory: {temp_dir_path}") + + private_key_file = temp_dir_path / "private.key" + passphrase_file = temp_dir_path / "passphrase.txt" + + print("Writing GnuPG key and passphrase to temporary files.") + private_key_file.write_text(gpg_private_key, encoding="utf-8") + passphrase_file.write_text(gpg_passphrase, encoding="utf-8") + + print("Importing GnuPG private key.") + run_command([str(gpg_exe_path), "--batch", "--import", str(private_key_file)]) + print("Successfully imported GnuPG private key.") + + print(f"\nProcessing {len(files_to_process)} files in '{jar_file_directory}'.") + + for file_path in files_to_process: + print(f"Processing file: {file_path}") + + # GPG Signing (.asc) + print(" - GnuPG signing...") + run_command( + [ + str(gpg_exe_path), + "--pinentry-mode", + "loopback", + "--passphrase-file", + str(passphrase_file), + "--detach-sign", + "--armor", + str(file_path), + ] + ) + + # SHA-1 and MD5 Checksums + create_hash_file(file_path, "sha1") + create_hash_file(file_path, "md5") + + print("\nFile signing and checksum generation completed.") + print("Temporary directory and its contents have been deleted.") + + +if __name__ == "__main__": + main() diff --git a/tools/ci_build/linux_java_copy_strip_binary.py b/tools/ci_build/linux_java_copy_strip_binary.py new file mode 100644 index 0000000000000..b9ca856d1c514 --- /dev/null +++ b/tools/ci_build/linux_java_copy_strip_binary.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +""" +Prepares native shared libraries for the ONNX Runtime Java package. + +This script is a build utility that run as part of a packaging pipeline and takes compiled C/C++ shared libraries +(.so, .dylib) and stages them for packaging into a Java JAR file. + +It expected the following inputs: +/ +└── / + ├── libonnxruntime.so (File from --lib-name) + ├── libonnxruntime4j_jni.so (File from --native-lib-name) + ├── libcustom_op_library.so + │ + ├── (Optional) libonnxruntime_providers_shared.so + ├── (Optional) libonnxruntime_providers_cuda.so + └── (Optional) libonnxruntime_providers_tensorrt.so + +It performs the following key operations: + +1. Validates the existence of all required source directories and libraries. +2. Creates the specific Java Native Interface (JNI) directory structure + (ai/onnxruntime/native/). +3. Copies the main, JNI, and custom op libraries to their destinations. +4. For macOS, extracts debug symbols into .dSYM files using `dsymutil`. +5. Strips all release binaries of their debug symbols to reduce file size. +6. Copies optional provider libraries (e.g., CUDA, TensorRT) for Linux builds. + +It is intended to be called from a CI/CD pipeline as part of the overall +build process for the onnxruntime-java package. +""" + +import argparse +import logging +import platform +import shutil +import subprocess +import sys +from pathlib import Path + +# --- Configuration --- +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", +) + + +# --- Helper Functions --- +def run_command(command: list[str | Path]): + """Runs an external command and exits the script if the command fails.""" + str_command = " ".join(map(str, command)) + logging.info(f"Running command: '{str_command}'") + try: + proc = subprocess.run(command, check=True, text=True, capture_output=True) + logging.info(f"Successfully executed: {Path(command[0]).name}") + if proc.stdout: + logging.debug(f"STDOUT: {proc.stdout.strip()}") + except FileNotFoundError: + logging.error(f"Command not found: '{command[0]}'. Please ensure it is installed and in your PATH.") + raise + except subprocess.CalledProcessError as e: + logging.error(f"Command '{Path(e.cmd[0]).name}' failed with exit code {e.returncode}.") + if e.stdout: + logging.error(f"STDOUT: {e.stdout.strip()}") + if e.stderr: + logging.error(f"STDERR: {e.stderr.strip()}") + raise + + +# --- Main Execution --- +def main(): + """Main function to parse arguments and package the native libraries.""" + parser = argparse.ArgumentParser( + description="Packages ONNX Runtime native libraries for Java.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Arguments + parser.add_argument("--binary-dir", required=True, type=Path, help="Path to the build binaries directory.") + parser.add_argument("--artifact-name", required=True, help="Name of the final artifact directory.") + parser.add_argument("--build-config", required=True, help="CMake build configuration (e.g., Release).") + parser.add_argument("--lib-name", required=True, help="Filename of the main ONNX Runtime shared library.") + parser.add_argument("--native-lib-name", required=True, help="Filename of the JNI shared library.") + parser.add_argument("--arch", required=True, help="Architecture string (e.g., osx-x86_64).") + args = parser.parse_args() + + # --- Path Setup and Validation --- + logging.info(f"System Info: {' '.join(platform.uname())}") + + source_build_dir = args.binary_dir / args.build_config + target_artifact_dir = args.binary_dir / args.artifact_name + + # Validate that the source build directory exists. + if not source_build_dir.is_dir(): + logging.error(f"Source build directory not found: {source_build_dir}") + sys.exit(1) + + # Map architecture names for macOS to align with Java conventions + arch = args.arch + if args.lib_name.endswith(".dylib"): + if arch == "osx-x86_64": + arch = "osx-x64" + elif arch == "osx-arm64": + arch = "osx-aarch64" + + # --- Library Processing --- + native_folder = target_artifact_dir / "ai" / "onnxruntime" / "native" / arch + native_folder.mkdir(parents=True, exist_ok=True) + logging.info(f"Staging native libraries in: {native_folder}") + + # Validate that all required library files exist before processing. + main_lib_src = source_build_dir / args.lib_name + jni_lib_src = source_build_dir / args.native_lib_name + + required_files = [main_lib_src, jni_lib_src] + lib_suffix = ".dylib" if args.lib_name.endswith(".dylib") else ".so" + custom_op_lib_src = source_build_dir / f"libcustom_op_library{lib_suffix}" + required_files.append(custom_op_lib_src) + + for f in required_files: + if not f.is_file(): + logging.error(f"Required library file not found: {f}") + sys.exit(1) + logging.info("All required source library files found.") + + # Start processing now that checks have passed + if lib_suffix == ".dylib": # macOS + logging.info("Processing macOS libraries (.dylib)...") + run_command(["dsymutil", main_lib_src, "-o", native_folder / f"{args.lib_name}.dSYM"]) + shutil.copy2(main_lib_src, native_folder / "libonnxruntime.dylib") + run_command(["strip", "-S", native_folder / "libonnxruntime.dylib"]) + + run_command(["dsymutil", jni_lib_src, "-o", native_folder / f"{args.native_lib_name}.dSYM"]) + shutil.copy2(jni_lib_src, native_folder / "libonnxruntime4j_jni.dylib") + run_command(["strip", "-S", native_folder / "libonnxruntime4j_jni.dylib"]) + + shutil.copy2(custom_op_lib_src, target_artifact_dir) + + elif lib_suffix == ".so": # Linux + logging.info("Processing Linux libraries (.so)...") + + # Main library + main_lib_dest = native_folder / "libonnxruntime.so" + shutil.copy2(main_lib_src, main_lib_dest) + run_command(["strip", "-S", main_lib_dest]) + + # JNI library + jni_lib_dest = native_folder / "libonnxruntime4j_jni.so" + shutil.copy2(jni_lib_src, jni_lib_dest) + run_command(["strip", "-S", jni_lib_dest]) + + # Custom op library (not stripped as it's for testing) + shutil.copy2(custom_op_lib_src, target_artifact_dir) + + # Provider checks are optional, so we check for their existence here. + for provider in ["cuda", "tensorrt"]: + provider_lib_src = source_build_dir / f"libonnxruntime_providers_{provider}.so" + if provider_lib_src.exists(): + logging.info(f"Found optional {provider} provider library. Copying and stripping...") + + # Shared provider library + shared_provider_lib_src = source_build_dir / "libonnxruntime_providers_shared.so" + if shared_provider_lib_src.exists(): + shared_provider_dest = native_folder / shared_provider_lib_src.name + shutil.copy2(shared_provider_lib_src, shared_provider_dest) + run_command(["strip", "-S", shared_provider_dest]) + + # Specific provider library + provider_lib_dest = native_folder / provider_lib_src.name + shutil.copy2(provider_lib_src, provider_lib_dest) + run_command(["strip", "-S", provider_lib_dest]) + else: + logging.warning(f"Unsupported library type for '{args.lib_name}'. No special processing will occur.") + + # --- Finalization --- + logging.info(f"--- Final contents of '{target_artifact_dir}' ---") + for path in sorted(target_artifact_dir.rglob("*")): + logging.info(f" - {path.relative_to(target_artifact_dir)}") + logging.info("--- End of contents ---") + + jar_dir_to_remove = target_artifact_dir / "jar" + if jar_dir_to_remove.is_dir(): + logging.info(f"Removing temporary directory: {jar_dir_to_remove}") + shutil.rmtree(jar_dir_to_remove) + + logging.info("Script completed successfully.") + + +if __name__ == "__main__": + try: + main() + except Exception as e: + logging.error(f"Script failed due to an unhandled error: {e}") + sys.exit(1) diff --git a/tools/ci_build/manage_java_artifacts.py b/tools/ci_build/manage_java_artifacts.py new file mode 100644 index 0000000000000..51521f651adec --- /dev/null +++ b/tools/ci_build/manage_java_artifacts.py @@ -0,0 +1,312 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# This script runs after ORT jars are built. It picks up the jars from ORT's build dir then repack them a bit. + +import argparse +import logging +import re +import shutil +import subprocess +import sys +import zipfile +from pathlib import Path + +# --- Configuration --- +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") + + +# --- Helper Functions --- +def run_command(command: list, working_dir: Path): + """Runs a command in a specified directory and checks for errors.""" + logging.info(f"Running command: '{' '.join(map(str, command))}' in '{working_dir}'") + try: + # On Windows, shell=True is required to correctly locate and execute .bat or .cmd files + # like gradlew.bat and mvn.cmd that may be in the system's PATH. + use_shell = sys.platform == "win32" + subprocess.run(command, cwd=working_dir, check=True, shell=use_shell) + logging.info("Command successful.") + except subprocess.CalledProcessError as e: + # Output will have been streamed, so we just need to log the failure. + logging.error(f"Command failed with exit code {e.returncode}") + raise + except FileNotFoundError: + logging.error( + f"Command failed: The executable '{command[0]}' was not found. " + "Please ensure it is installed and that its location is in the system's PATH environment variable." + ) + raise + + +def log_directory_contents(dir_path: Path, description: str): + """Logs the contents of a directory for debugging.""" + logging.info(f"--- Listing contents of {description} at '{dir_path}' ---") + if not dir_path.is_dir(): + logging.warning(f"Directory does not exist: {dir_path}") + return + contents = list(dir_path.rglob("*")) + if not contents: + logging.warning(f"Directory is empty: {dir_path}") + else: + for item in contents: + logging.info(f" - {item.relative_to(dir_path)}") + logging.info("--- End of directory listing ---") + + +def create_zip_from_directory(zip_file_path: Path, source_dir: Path): + """Creates a zip file from the contents of a source directory.""" + logging.info(f"Creating archive '{zip_file_path}' from directory '{source_dir}'...") + with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as zipf: + for root, _, files in source_dir.walk(): + for file in files: + file_path = root / file + archive_name = file_path.relative_to(source_dir) + zipf.write(file_path, archive_name) + logging.info("Archive created successfully.") + + +# --- New function for validation --- +def validate_artifacts( + platform_dir: Path, main_jar: Path, main_pom: Path, testing_jar: Path, version: str, artifact_id: str +): + """Uses Maven to validate the generated JAR and POM files.""" + logging.info("--- Starting Maven Artifact Validation ---") + maven_executable = "mvn.cmd" if sys.platform == "win32" else "mvn" + group_id = "com.microsoft.onnxruntime" # Assuming this is constant + + # 1. Validate the main ONNX Runtime JAR and its POM + logging.info(f"Validating main artifact: {main_jar.name}") + install_main_cmd = [ + maven_executable, + "install:install-file", + f"-Dfile={main_jar.resolve()}", + f"-DpomFile={main_pom.resolve()}", + # Adding these makes the command more robust and less prone to errors + f"-DgroupId={group_id}", + f"-DartifactId={artifact_id}", + f"-Dversion={version}", + "-Dpackaging=jar", + ] + run_command(install_main_cmd, working_dir=platform_dir) + logging.info("Main artifact validated successfully.") + + # 2. Validate the testing JAR (it has no POM, so we supply all info) + logging.info(f"Validating testing artifact: {testing_jar.name}") + install_testing_cmd = [ + maven_executable, + "install:install-file", + f"-Dfile={testing_jar.resolve()}", + f"-DgroupId={group_id}", + f"-DartifactId={artifact_id}-testing", + f"-Dversion={version}", + "-Dpackaging=jar", + ] + run_command(install_testing_cmd, working_dir=platform_dir) + logging.info("Testing artifact validated successfully.") + logging.info("--- Maven Artifact Validation Complete ---") + + +def main(): + """Main script execution.""" + parser = argparse.ArgumentParser(description="Builds and packages Java artifacts, PDBs, and notice files.") + parser.add_argument("--sources-dir", required=True, type=Path, help="Path to the build sources directory.") + parser.add_argument("--binaries-dir", required=True, type=Path, help="Path to the build binaries directory.") + parser.add_argument("--platform", required=True, help="Platform string (e.g., x64).") + parser.add_argument( + "--java-artifact-id", required=True, help="The Java artifact ID (e.g., onnxruntime or onnxruntime_gpu)." + ) + parser.add_argument( + "--build-config", + choices=["Debug", "Release", "RelWithDebInfo", "MinSizeRel"], + default="RelWithDebInfo", + help="The CMake build configuration type.", + ) + parser.add_argument( + "--pre-release-version-suffix-string", + choices=["alpha", "beta", "rc", "none"], + default="none", + help="The pre-release version suffix string.", + ) + parser.add_argument( + "--pre-release-version-suffix-number", type=int, default=0, help="The pre-release version suffix number." + ) + parser.add_argument("--commit-hash", required=True, help="The git commit hash.") + parser.add_argument("--build-only", action="store_true", help="Flag to indicate if this is a build-only run.") + args = parser.parse_args() + + # --- 1. Version and Build Logic --- + # Determine the repository root from the script's location + repo_root = Path(__file__).resolve().parent.parent.parent + version_file_path = repo_root / "VERSION_NUMBER" + + logging.info(f"Reading base version from {version_file_path}") + if not version_file_path.is_file(): + raise FileNotFoundError(f"Version file not found at {version_file_path}") + + base_version = version_file_path.read_text(encoding="utf-8").strip() + + # Validate the version format + if not re.match(r"^\d+\.\d+\.\d+$", base_version): + raise ValueError(f"Version '{base_version}' from {version_file_path} is not in the required x.y.z format.") + + logging.info(f"Successfully read and validated base version: {base_version}") + + # Start with the base version and conditionally append the pre-release suffix. + full_version = base_version + if args.pre_release_version_suffix_string != "none": + if args.pre_release_version_suffix_number <= 0: + raise ValueError( + "Pre-release version suffix number must be a positive integer if a suffix string is provided." + ) + # Append the suffix, conforming to Maven standards (e.g., 1.2.3-rc1) + full_version += f"-{args.pre_release_version_suffix_string}{args.pre_release_version_suffix_number}" + + logging.info(f"Using full version: {full_version}") + + # Use the java subdirectory of the repository root as the working directory for Gradle + java_working_dir = repo_root / "java" + + build_config_dir = args.binaries_dir / args.build_config + cmake_build_dir_arg = f"-DcmakeBuildDir={build_config_dir}" + version_property_arg = f"-Dorg.gradle.project.version={full_version}" + + # Construct the absolute path to the Gradle wrapper + gradle_executable_name = "gradlew.bat" if sys.platform == "win32" else "gradlew" + gradle_executable_path = java_working_dir / gradle_executable_name + + # Rebuild the jar so that we can change the version + gradle_args = [cmake_build_dir_arg, version_property_arg] + if args.java_artifact_id == "onnxruntime_gpu": + gradle_args.append("-DUSE_CUDA") + gradle_args.append("-DUSE_TENSORRT") + run_command([str(gradle_executable_path), "cmakeBuild", *gradle_args], working_dir=java_working_dir) + if args.build_only: + run_command( + [ + str(gradle_executable_path), + "testClasses", + "--warning-mode", + "all", + *gradle_args, + ], + working_dir=java_working_dir, + ) + else: + run_command( + [ + str(gradle_executable_path), + "cmakeCheck", + "--warning-mode", + "all", + *gradle_args, + ], + working_dir=java_working_dir, + ) + + # --- 2. Path Definitions --- + platform_dir = args.binaries_dir / f"onnxruntime-java-win-{args.platform}" + stage_dir = platform_dir / "stage" + native_folder = stage_dir / "ai" / "onnxruntime" / "native" / f"win-{args.platform}" + main_jar_name = f"{args.java_artifact_id}-{full_version}.jar" + main_jar_path = platform_dir / main_jar_name + final_pom_path = platform_dir / f"{args.java_artifact_id}-{full_version}.pom" + testing_jar_path = platform_dir / "testing.jar" + + # --- 3. Packaging Logic --- + try: + stage_dir.mkdir(parents=True, exist_ok=True) + native_folder.mkdir(parents=True, exist_ok=True) + + gradle_libs_dir = java_working_dir / "build" / "libs" + log_directory_contents(gradle_libs_dir, "Gradle build output libs") + + # FIX: Filter glob results to find the main artifact JAR, excluding sources and javadoc. + main_jars = [ + p + for p in gradle_libs_dir.glob("*.jar") + if not p.name.endswith("-sources.jar") and not p.name.endswith("-javadoc.jar") + ] + + if not main_jars: + raise FileNotFoundError(f"Gradle build finished, but no main artifact JAR was found in {gradle_libs_dir}") + if len(main_jars) > 1: + logging.warning(f"Found multiple potential main JARs: {[p.name for p in main_jars]}. Using the first one.") + + source_jar_path = main_jars[0] + logging.info(f"Found source JAR to copy: {source_jar_path.name}") + + # The main JAR file is copied to its final name directly. + shutil.copy2(source_jar_path, main_jar_path) + + # Now, find and copy the associated sources and javadoc JARs, renaming them to match. + source_basename = source_jar_path.stem # e.g., 'onnxruntime-1.23.0' + dest_basename = main_jar_path.stem # e.g., 'onnxruntime_gpu-1.23.0' + + for classifier in ["sources", "javadoc"]: + source_classified_jar = gradle_libs_dir / f"{source_basename}-{classifier}.jar" + if source_classified_jar.is_file(): + dest_classified_jar = platform_dir / f"{dest_basename}-{classifier}.jar" + logging.info(f"Copying classified artifact: {source_classified_jar.name} -> {dest_classified_jar.name}") + shutil.copy2(source_classified_jar, dest_classified_jar) + else: + logging.warning(f"Optional artifact '{source_classified_jar.name}' not found, skipping.") + + log_directory_contents(platform_dir, "final platform directory before JAR processing") + + pom_archive_path = f"META-INF/maven/com.microsoft.onnxruntime/{args.java_artifact_id}/pom.xml" + with zipfile.ZipFile(main_jar_path, "r") as jar: + jar.extract(pom_archive_path, path=platform_dir) + + shutil.move(str(platform_dir / pom_archive_path), str(final_pom_path)) + shutil.rmtree(platform_dir / "META-INF") + + shutil.copy2(args.sources_dir / "docs" / "Privacy.md", stage_dir) + shutil.copy2(args.sources_dir / "ThirdPartyNotices.txt", stage_dir) + (stage_dir / "GIT_COMMIT_ID").write_text(args.commit_hash, encoding="utf-8") + + with zipfile.ZipFile(main_jar_path, "a") as jar: + for root, _, files in stage_dir.walk(): + for file in files: + file_path = root / file + jar.write(file_path, file_path.relative_to(stage_dir)) + + test_classes_dir = args.sources_dir / "java" / "build" / "classes" / "java" / "test" + test_resources_dir = args.sources_dir / "java" / "build" / "resources" / "test" + + create_zip_from_directory(testing_jar_path, test_classes_dir) + + native_resource_path = test_resources_dir / "ai" / "onnxruntime" / "native" + if native_resource_path.exists(): + shutil.rmtree(native_resource_path) + + with zipfile.ZipFile(testing_jar_path, "a") as jar: + for root, _, files in test_resources_dir.walk(): + for file in files: + file_path = root / file + jar.write(file_path, file_path.relative_to(test_resources_dir)) + + logging.info("Java artifact packaging complete.") + + # --- 4. Validation Step --- + validate_artifacts( + platform_dir=platform_dir, + main_jar=main_jar_path, + main_pom=final_pom_path, + testing_jar=testing_jar_path, + version=full_version, + artifact_id=args.java_artifact_id, + ) + + finally: + # 5. Clean up stage directory + if stage_dir.exists(): + logging.info(f"Cleaning up stage directory: {stage_dir}") + shutil.rmtree(stage_dir) + + logging.info(f"\nFinal contents of '{platform_dir}':") + for item in platform_dir.iterdir(): + print(item) + + +if __name__ == "__main__": + main() diff --git a/tools/ci_build/prepare_macos_package.py b/tools/ci_build/prepare_macos_package.py new file mode 100644 index 0000000000000..b92e81663c776 --- /dev/null +++ b/tools/ci_build/prepare_macos_package.py @@ -0,0 +1,185 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import os +import pathlib +import shutil +import stat as stat_module +import subprocess +import sys +import tarfile +from datetime import datetime + + +def run_command(command: list[str | pathlib.Path], check: bool = True) -> subprocess.CompletedProcess: + """Helper to run a command, stream its output, and check for errors.""" + print(f"Executing: {' '.join(map(str, command))}", flush=True) + try: + return subprocess.run(command, check=check, text=True, capture_output=True) + except subprocess.CalledProcessError as e: + print(f"ERROR: Command failed with exit code {e.returncode}", file=sys.stderr) + print(f"--- STDOUT ---\n{e.stdout}", file=sys.stderr) + print(f"--- STDERR ---\n{e.stderr}", file=sys.stderr) + raise + + +def get_relative_file_paths(root_dir: pathlib.Path) -> set[pathlib.Path]: + """ + Returns a set of all relative file paths within a directory, + ignoring any files inside .dSYM directories. + """ + paths = set() + for p in root_dir.rglob("*"): + # Check if any part of the path is a .dSYM directory. + if any(part.endswith(".dSYM") for part in p.relative_to(root_dir).parts): + continue + if p.is_file(): + paths.add(p.relative_to(root_dir)) + return paths + + +def is_macho_binary(file_path: pathlib.Path) -> bool: + """Checks if a file is a Mach-O binary using the 'file' command.""" + if not file_path.is_file(): + return False + try: + result = run_command(["file", file_path]) + return "Mach-O" in result.stdout + except (subprocess.CalledProcessError, FileNotFoundError): + return False + + +def main(): + """Main function to prepare macOS packages for signing.""" + # 1. Setup paths and parse arguments + parser = argparse.ArgumentParser(description="Prepares macOS packages for signing.") + parser.add_argument( + "--staging_dir", + type=pathlib.Path, + required=True, + help="The directory where artifacts are staged and processed.", + ) + args = parser.parse_args() + staging_dir = args.staging_dir.resolve() + + if not staging_dir.is_dir(): + raise FileNotFoundError(f"Staging directory not found: {staging_dir}") + + os.chdir(staging_dir) + print(f"##[group]Working in directory: {staging_dir}") + print(f"Initial contents: {[p.name for p in staging_dir.iterdir()]}") + print("##[endgroup]") + + # 2. Unpack all .tgz archives + print("##[group]Unpacking downloaded archives...") + tgz_files = list(staging_dir.glob("*.tgz")) + if not tgz_files: + raise FileNotFoundError("Build Error: No .tgz files found to process.") + + for tgz in tgz_files: + print(f"Extracting {tgz.name}...") + with tarfile.open(tgz) as tar: + tar.extractall(path=".") + tgz.unlink() # Delete the archive + print("##[endgroup]") + + # 3. Locate architecture-specific directories + print("##[group]Locating architecture directories...") + arm64_dirs = list(staging_dir.glob("onnxruntime-osx-arm64*")) + x64_dirs = list(staging_dir.glob("onnxruntime-osx-x86_64*")) + + if len(arm64_dirs) != 1 or len(x64_dirs) != 1: + raise FileNotFoundError( + f"Build Error: Expected 1 arm64 and 1 x64 directory, but found: arm64={len(arm64_dirs)}, x64={len(x64_dirs)}" + ) + + arm64_dir, x64_dir = arm64_dirs[0], x64_dirs[0] + print(f"Found ARM64 source: {arm64_dir.name}") + print(f"Found x86_64 source: {x64_dir.name}") + print("##[endgroup]") + + # **NEW**: Remove _manifest directories before comparison or processing. + print("##[group]Removing _manifest directories...") + for package_dir in (arm64_dir, x64_dir): + manifest_path = package_dir / "_manifest" + if manifest_path.is_dir(): + print(f"Removing manifest directory: {manifest_path.relative_to(staging_dir)}") + shutil.rmtree(manifest_path) + print("##[endgroup]") + + # 4. Error Check: Verify file tree structures are identical + print("##[group]Verifying file tree structures...") + arm64_files = get_relative_file_paths(arm64_dir) + x64_files = get_relative_file_paths(x64_dir) + + if arm64_files != x64_files: + difference = arm64_files.symmetric_difference(x64_files) + print(f"ERROR: File tree structures do not match. Found {len(difference)} differing files:", file=sys.stderr) + for f in sorted(difference): + print(f"- {f}", file=sys.stderr) + sys.exit(1) + + print("✅ File tree structures match.") + print("##[endgroup]") + + # 5. Create the universal binary package + print("##[group]Creating universal2 package with lipo...") + universal_dir = staging_dir / arm64_dir.name.replace("arm64", "universal2") + + print(f"Copying {arm64_dir.name} to {universal_dir.name} as a template.") + shutil.copytree(arm64_dir, universal_dir, symlinks=True, ignore=shutil.ignore_patterns("*.dSYM")) + + for relative_path in arm64_files: + arm64_file = arm64_dir / relative_path + x64_file = x64_dir / relative_path + universal_file = universal_dir / relative_path + + if is_macho_binary(arm64_file) and is_macho_binary(x64_file): + print(f"Combining {relative_path}...") + run_command(["lipo", "-create", arm64_file, x64_file, "-output", universal_file]) + run_command(["lipo", "-info", universal_file]) + print("##[endgroup]") + + # Remove .dSYM folders from source packages before zipping. + print("##[group]Removing .dSYM folders from source packages...") + for package_dir in (arm64_dir, x64_dir): + for dsym_dir in package_dir.rglob("*.dSYM"): + if dsym_dir.is_dir(): + print(f"Removing {dsym_dir.relative_to(staging_dir)}") + shutil.rmtree(dsym_dir) + print("##[endgroup]") + + # 6. Zip all packages for signing and clean up + print("##[group]Zipping all packages for signing...") + for dir_path in (arm64_dir, x64_dir, universal_dir): + # Create a zip file in the staging directory. + zip_file_path = staging_dir / f"{dir_path.name}.zip" + print(f"Zipping {dir_path.name} to {zip_file_path}") + # The source directory path (dir_path.name) is relative to the current working directory (staging_dir). + run_command(["zip", "-FSr", "--symlinks", zip_file_path, dir_path.name]) + + print(f"Removing directory {dir_path.name}") + shutil.rmtree(dir_path) + + print("Final contents of staging directory:") + for item in sorted(staging_dir.iterdir()): + try: + stat = item.stat() + size = stat.st_size + mode_str = stat_module.filemode(stat.st_mode) + mtime = datetime.fromtimestamp(stat.st_mtime).strftime("%b %d %H:%M") + print(f"{mode_str} {size:>10} {mtime} {item.name}") + except FileNotFoundError: + # Handle cases where a file might be a broken symlink + print(f"l????????? {'?':>10} ? ? {item.name} (broken link)") + + print("##[endgroup]") + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(f"##[error]A critical error occurred: {e}", file=sys.stderr) + sys.exit(1) diff --git a/tools/ci_build/run_gh_action.py b/tools/ci_build/run_gh_action.py new file mode 100644 index 0000000000000..ddaabca5cabad --- /dev/null +++ b/tools/ci_build/run_gh_action.py @@ -0,0 +1,158 @@ +import os +import platform +import shutil +import sys +import tempfile +import zipfile +from pathlib import Path + +import requests + +SCRIPT_DIR = Path(__file__).resolve().parent +REPO_DIR = (SCRIPT_DIR / ".." / "..").resolve() + +sys.path.insert(0, str(REPO_DIR / "tools" / "python")) + +from util import run # noqa: E402 + +# Hash structure for platform-specific binaries +CMAKE_HASHES = { + "windows": { + "x64": "807b774fcb12defff8ce869e602fc5b6279d5b7bf7229ebcf3f7490da3f887d516b9c49a00d50f9179e552ed8737d19835a19ef8f366d1ffda1ad6f3352a90c2", + "arm64": "86937dc89deabe0ff2a08fe198fcfc70764476b865cca4c6dc3bfc7fb9f7d44d4929af919e26e84aaedef17ad01ffb9683e42c39cb38b409100f723bc5ef1cc0", + }, + "linux": { + "x64": "7939260931098c3f00d2b36de3bee6a0ee3bcae2dba001598c492ed5c82d295c9aa9969654f1ff937fec4d71679541238baaa648c5246f36e14f28f0a62337a0", + "arm64": "8eeb07e966a5340c122979dd2e371708a78adccc85200b22bc7e66028e65513bce5ced6c37fe65aedb94000d970186c5c7562d1ab3dbda911061de46b75345d9", + }, + "macos": "99cc9c63ae49f21253efb5921de2ba84ce136018abf08632c92c060ba91d552e0f6acc214e9ba8123dee0cf6d1cf089ca389e321879fd9d719a60d975bcffcc8", +} + + +def get_platform_keys() -> tuple[str | None, str | None]: + """Detects the OS and CPU architecture and returns normalized keys.""" + os_key: str | None = None + match sys.platform: + case "win32": + os_key = "windows" + case "linux": + os_key = "linux" + case "darwin": + os_key = "macos" + + arch_key: str | None = None + match platform.machine().lower(): + case "amd64" | "x86_64": + arch_key = "x64" + case "arm64" | "aarch64": + arch_key = "arm64" + + return os_key, arch_key + + +def main() -> None: + if len(sys.argv) < 2: + print("::error::Action version argument was not provided.") + sys.exit(1) + + action_version = sys.argv[1] + + # --- Platform Detection and Variable Setup --- + os_key, arch_key = get_platform_keys() + if not os_key or not arch_key: + print( + f"::error::Could not determine a supported platform from OS '{sys.platform}' and Arch '{platform.machine()}'." + ) + sys.exit(1) + + print(f"Detected Platform: OS='{os_key}', Architecture='{arch_key}'") + + try: + if os_key == "macos": + cmake_hash = CMAKE_HASHES[os_key] + else: + cmake_hash = CMAKE_HASHES[os_key][arch_key] + + print(f"Selected CMake hash for '{os_key}'.") + except KeyError: + print(f"::error::Unsupported platform or missing hash for OS='{os_key}' and Arch='{arch_key}'.") + sys.exit(1) + + # Conditionally set the 'DISABLE_TERRAPIN' value + disable_terrapin_value = "true" + if os_key == "windows" and arch_key == "x64": + disable_terrapin_value = "false" + print("Setting INPUT_DISABLE-TERRAPIN to 'false' for Windows x64.") + + action_inputs = { + "INPUT_CMAKE-VERSION": "3.31.8", + "INPUT_CMAKE-HASH": cmake_hash, + "INPUT_VCPKG-VERSION": "2025.06.13", + "INPUT_VCPKG-HASH": "735923258c5187966698f98ce0f1393b8adc6f84d44fd8829dda7db52828639331764ecf41f50c8e881e497b569f463dbd02dcb027ee9d9ede0711102de256cc", + "INPUT_ADD-CMAKE-TO-PATH": "true", + "INPUT_DISABLE-TERRAPIN": disable_terrapin_value, + } + + # --- Download and Extract the Action to a Temporary Directory --- + zip_url = f"https://github.com/microsoft/onnxruntime-github-actions/archive/refs/tags/{action_version}.zip" + + # Use AGENT_TEMPDIRECTORY, with a fallback to the system's default temp directory. + temp_dir = Path(os.environ.get("AGENT_TEMPDIRECTORY", tempfile.gettempdir())).resolve() + zip_path = temp_dir / "action.zip" + extract_dir = temp_dir / "action-unzipped" + + print(f"Using temporary directory: {temp_dir}") + + # --- Locate, Run, and Cleanup the Action Script --- + try: + print(f"Downloading action source from: {zip_url}") + response = requests.get(zip_url, stream=True) + response.raise_for_status() + with open(zip_path, "wb") as f: + shutil.copyfileobj(response.raw, f) + + print(f"Extracting {zip_path} to {extract_dir}") + if extract_dir.exists(): + shutil.rmtree(extract_dir) + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(extract_dir) + + try: + action_base_path = next(extract_dir.glob("onnxruntime-github-actions-*")) + print(f"Found action base path: {action_base_path}") + except StopIteration as e: + raise FileNotFoundError(f"Could not find extracted action directory in '{extract_dir}'") from e + + action_script_path = action_base_path / "setup-build-tools" / "dist" / "index.js" + if not action_script_path.exists(): + raise FileNotFoundError(f"Action script not found at expected path: {action_script_path}") + + env = os.environ.copy() + env.update(action_inputs) + + if "AGENT_TOOLSDIRECTORY" in env: + env["RUNNER_TOOL_CACHE"] = env["AGENT_TOOLSDIRECTORY"] + print(f"Mapped RUNNER_TOOL_CACHE to AGENT_TOOLSDIRECTORY: {env['RUNNER_TOOL_CACHE']}") + if "AGENT_TEMPDIRECTORY" in env: + env["RUNNER_TEMP"] = env["AGENT_TEMPDIRECTORY"] + print(f"Mapped RUNNER_TEMP to AGENT_TEMPDIRECTORY: {env['RUNNER_TEMP']}") + + run("node", str(action_script_path), env=env) + + finally: + # --- Cleanup --- + # This block ensures the zip file and extracted directory are always removed. + print("\nStarting cleanup...") + if zip_path.exists(): + print(f"Removing temporary zip file: {zip_path}") + zip_path.unlink() + + if extract_dir.exists(): + print(f"Removing extracted action directory: {extract_dir}") + shutil.rmtree(extract_dir) + + print("Cleanup complete.") + + +if __name__ == "__main__": + main() diff --git a/tools/python/run_packaging_pipelines.py b/tools/python/run_packaging_pipelines.py index 4948f35c642e8..259b7f9e39e9c 100644 --- a/tools/python/run_packaging_pipelines.py +++ b/tools/python/run_packaging_pipelines.py @@ -446,28 +446,42 @@ def main(): print(f" - {result['pipeline']['name']} (ID: {result['pipeline']['id']})") else: print(f"\n--- Triggering {len(pipelines_to_trigger)} Pipelines on branch '{branch_for_trigger}' ---") - nightly_override = None - release_override = None - if args.build_mode == "nightly": - nightly_override = "1" - release_override = "false" - elif args.build_mode == "release": - nightly_override = "0" - release_override = "true" # If pre-release flags are used, it implies a release build. if args.pre_release_suffix_string: print("Pre-release suffix provided. Forcing 'release' build mode.") if args.build_mode and args.build_mode != "release": print(f"Warning: --build-mode={args.build_mode} is overridden by pre-release flags.") - nightly_override = "0" - release_override = "true" + + # If pre-release flags are used, it implies a release build. + if args.pre_release_suffix_string: + print("Pre-release suffix provided. Forcing 'release' build mode.") + if args.build_mode and args.build_mode != "release": + print(f"Warning: --build-mode={args.build_mode} is overridden by pre-release flags.") for result in pipelines_to_trigger: pipeline = result["pipeline"] packaging_type = result["packaging_type"] has_pre_release_params = result["has_pre_release_params"] + # Determine build mode based on flags + nightly_override = None + release_override = None + if args.build_mode == "nightly": + nightly_override = "1" + release_override = "false" + elif args.build_mode == "release": + nightly_override = "0" + release_override = "true" + + # If pre-release flags are used AND the pipeline supports them, it implies a release build. + if args.pre_release_suffix_string and has_pre_release_params: + print(f"Pre-release flags used and supported by '{pipeline['name']}'. Forcing 'release' mode.") + if args.build_mode and args.build_mode != "release": + print(f" - Warning: --build-mode={args.build_mode} is overridden for this pipeline.") + nightly_override = "0" + release_override = "true" + if not args.no_cancel_builds: cancel_running_builds(pipeline["id"], branch_for_trigger, token, project) else: