diff --git a/.config/1espt/PipelineAutobaseliningConfig.yml b/.config/1espt/PipelineAutobaseliningConfig.yml index 18315e55e854d..f59528797405e 100644 --- a/.config/1espt/PipelineAutobaseliningConfig.yml +++ b/.config/1espt/PipelineAutobaseliningConfig.yml @@ -133,13 +133,16 @@ pipelines: lastModifiedDate: 2025-04-24 armory: lastModifiedDate: 2025-04-24 + policheck: + lastModifiedDate: 2025-07-25 binary: credscan: lastModifiedDate: 2025-04-25 binskim: - lastModifiedDate: 2025-04-25 + lastModifiedDate: 2025-07-25 spotbugs: lastModifiedDate: 2025-04-25 + usedBinskimScanAllExtensions: true 1757: retail: source: @@ -151,13 +154,16 @@ pipelines: lastModifiedDate: 2025-04-25 armory: lastModifiedDate: 2025-04-25 + policheck: + lastModifiedDate: 2025-07-23 binary: credscan: lastModifiedDate: 2025-04-25 binskim: - lastModifiedDate: 2025-04-25 + lastModifiedDate: 2025-07-24 spotbugs: lastModifiedDate: 2025-04-25 + usedBinskimScanAllExtensions: true 1234: retail: source: @@ -169,10 +175,24 @@ pipelines: lastModifiedDate: 2025-04-25 armory: lastModifiedDate: 2025-04-25 + policheck: + lastModifiedDate: 2025-07-23 binary: credscan: lastModifiedDate: 2025-04-25 binskim: - lastModifiedDate: 2025-04-25 + lastModifiedDate: 2025-07-24 spotbugs: lastModifiedDate: 2025-04-25 + usedBinskimScanAllExtensions: true + 1311: + retail: + source: + credscan: + lastModifiedDate: 2025-07-18 + eslint: + lastModifiedDate: 2025-07-18 + psscriptanalyzer: + lastModifiedDate: 2025-07-18 + armory: + lastModifiedDate: 2025-07-18 diff --git a/.config/guardian/.gdnbaselines b/.config/guardian/.gdnbaselines index 7246ad6ba36df..18a9250059134 100644 --- a/.config/guardian/.gdnbaselines +++ b/.config/guardian/.gdnbaselines @@ -409,6 +409,66 @@ "createdDate": "2025-04-25 22:25:55Z", "expirationDate": "2025-10-12 23:01:19Z", "justification": "This error is baselined with an expiration date of 180 days from 2025-04-25 23:01:19Z" + }, + "67acaef0adebeee9ddcb2ff2630fa3051c0c8e7083f36f64ac0040d9a22b73b5": { + 
"signature": "67acaef0adebeee9ddcb2ff2630fa3051c0c8e7083f36f64ac0040d9a22b73b5", + "alternativeSignatures": [ + "2f5e8344c6d8ffa32a8a54c363d0c480380320a6c0a3fd3e4ca1ff2aafe6dbcf" + ], + "target": "file:///E:/_work/_temp/RelWithDebInfo/RelWithDebInfo/dxcompiler.dll", + "memberOf": [ + "default" + ], + "tool": "binskim", + "ruleId": "BA2007", + "createdDate": "2025-07-24 10:13:44Z", + "expirationDate": "2026-01-10 11:03:50Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-07-24 11:03:50Z" + }, + "7a02b29f8870bfd4cc770281dd860421523a3aac51ea96332e25696ca1f5570e": { + "signature": "7a02b29f8870bfd4cc770281dd860421523a3aac51ea96332e25696ca1f5570e", + "alternativeSignatures": [ + "1229412e0db78558feac3bc51ea9eed6ae2311e60298dc1f2d3366bd12544c88" + ], + "target": "file:///E:/_work/_temp/RelWithDebInfo/RelWithDebInfo/onnxruntime_perf_test.exe", + "memberOf": [ + "default" + ], + "tool": "binskim", + "ruleId": "BA2007", + "createdDate": "2025-07-24 10:13:44Z", + "expirationDate": "2026-01-10 11:03:50Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-07-24 11:03:50Z" + }, + "68b657f6d9dd9386bd43f0716cb424c196f7da4559eb1c3f3f26a1297b211239": { + "signature": "68b657f6d9dd9386bd43f0716cb424c196f7da4559eb1c3f3f26a1297b211239", + "alternativeSignatures": [ + "e026012915cda24b9e85a1d1fa38607d09effa532b40a1c0f0740eb3855f9599" + ], + "target": "file:///E:/_work/_temp/RelWithDebInfo/RelWithDebInfo/onnxruntime_test_all.exe", + "memberOf": [ + "default" + ], + "tool": "binskim", + "ruleId": "BA2007", + "createdDate": "2025-07-24 10:13:44Z", + "expirationDate": "2026-01-10 11:03:50Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-07-24 11:03:50Z" + }, + "8222fbb019a791fa0541084854a0bf7bf723ce7ffaa4a0e1e5ca5cb76acb48bb": { + "signature": "8222fbb019a791fa0541084854a0bf7bf723ce7ffaa4a0e1e5ca5cb76acb48bb", + "alternativeSignatures": [ + 
"5e0bdc06af73864bdb480aceaf154a35e0774ab7f8490e7d9f8b5a36b7c19619" + ], + "target": "file:///E:/_work/_temp/RelWithDebInfo/RelWithDebInfo/onnx_test_runner.exe", + "memberOf": [ + "default" + ], + "tool": "binskim", + "ruleId": "BA2007", + "createdDate": "2025-07-24 10:13:44Z", + "expirationDate": "2026-01-10 11:03:50Z", + "justification": "This error is baselined with an expiration date of 180 days from 2025-07-24 11:03:50Z" } } } \ No newline at end of file diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index 38526e7a5c00f..f4ee8a7c27cd0 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -29,7 +29,7 @@ jobs: dockerfile_path: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1' docker_image_repo: onnxruntimecuda12manylinuxbuild - extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --enable_cuda_profiling --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --enable_cuda_profiling --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' run_tests: false # <<< Do not run tests in this job upload_build_output: true # <<< Upload the build/Release directory @@ -42,7 +42,7 @@ jobs: needs: build-linux-cuda-x64-release runs-on: - self-hosted - - "1ES.Pool=Onnxruntime-github-Linux-GPU-A100-WUS3" + - "1ES.Pool=Onnxruntime-github-Linux-GPU-H100" 
permissions: contents: read packages: read @@ -99,5 +99,5 @@ jobs: build_config: Release mode: 'test' # Set mode to test execution_providers: 'cuda' - extra_build_flags: '--use_binskim_compliant_compile_flags --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + extra_build_flags: '--use_binskim_compliant_compile_flags --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' diff --git a/.github/workflows/linux_tensorrt_ci.yml b/.github/workflows/linux_tensorrt_ci.yml index 1df467043329a..a7d3f5ec0f5fd 100644 --- a/.github/workflows/linux_tensorrt_ci.yml +++ b/.github/workflows/linux_tensorrt_ci.yml @@ -29,7 +29,7 @@ jobs: dockerfile_path: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1 --build-arg TRT_VERSION=10.9.0.34-1.cuda12.8 --network=host' docker_image_repo: onnxruntimetensorrt86gpubuild - extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --use_tensorrt --tensorrt_home /usr --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --use_tensorrt --tensorrt_home /usr --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 
onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' run_tests: false # <<< Do not run tests in this job upload_build_output: true # <<< Upload the build/Release directory @@ -42,7 +42,7 @@ jobs: needs: build-linux-TensorRT-x64-release runs-on: - self-hosted - - "1ES.Pool=Onnxruntime-github-Linux-GPU-A100-WUS3" + - "1ES.Pool=Onnxruntime-github-Linux-GPU-H100" permissions: contents: read packages: read @@ -101,5 +101,5 @@ jobs: build_config: Release mode: 'test' # Set mode to test execution_providers: 'cuda tensorrt' - extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --use_tensorrt --tensorrt_home /usr --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --use_tensorrt --tensorrt_home /usr --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index e65d23069ad32..dbc138e57a3ec 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -15,14 +15,15 @@ concurrency: group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true +#TODO: enable --build_nodejs jobs: - Windows_GPU_TensorRT_CI_Pipeline: + build: name: Windows GPU TensorRT CI Pipeline - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] steps: 
- uses: actions/checkout@v4 with: - fetch-depth: 0 # Fetch all history for all tags and branches + fetch-depth: 0 submodules: 'none' - uses: actions/setup-python@v5 @@ -36,29 +37,20 @@ jobs: architecture: x64 - name: Install python modules - run: python -m pip install -r ${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt + run: python -m pip install -r .\tools\ci_build\github\windows\python\requirements.txt + working-directory: ${{ github.workspace }} shell: cmd - - name: Download Primary CUDA SDK v12.2 - run: 'azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v12.2" ${{ runner.temp }}' + - name: Download CUDA SDK v12.2 + working-directory: ${{ runner.temp }} + run: | + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v12.2" . + dir shell: pwsh - env: - AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 - name: Download TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8 run: 'azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" ${{ runner.temp }}' shell: pwsh - env: - AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - name: Add CUDA to PATH shell: powershell @@ -69,33 +61,198 @@ jobs: Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\v12.2\extras\CUPTI\lib64" Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8\lib" - - name: Generate sln - working-directory: ${{ runner.temp }} + - uses: actions/setup-node@v4 + with: + node-version: '20.x' + + - uses: actions/setup-java@v4 + with: + 
distribution: 'temurin' + java-version: '17' + architecture: x64 + + - uses: actions/cache@v4 + id: onnx-node-tests-cache + with: + path: ${{ github.workspace }}/js/test/ + key: onnxnodetests-${{ hashFiles('js/scripts/prepare-onnx-node-tests.ts') }} + + - name: API Documentation Check and generate run: | - python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + set ORT_DOXY_SRC=${{ github.workspace }} + set ORT_DOXY_OUT=${{ runner.temp }}\build\RelWithDebInfo\RelWithDebInfo + mkdir %ORT_DOXY_SRC% + mkdir %ORT_DOXY_OUT% + "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg + working-directory: ${{ github.workspace }} shell: cmd - - name: Build + - uses: actions/setup-dotnet@v4 + env: + PROCESSOR_ARCHITECTURE: x64 + with: + dotnet-version: '8.x' + + - name: Use Nuget 6.x + uses: nuget/setup-nuget@v2 + with: + nuget-version: '6.x' + + - name: NuGet restore + run: nuget restore ${{ github.workspace }}\packages.config -ConfigFile ${{ github.workspace }}\NuGet.config -PackagesDirectory ${{ runner.temp }}\build\RelWithDebInfo + shell: cmd + + - name: Set OnnxRuntimeBuildDirectory + shell: pwsh + run: | + $buildDir = Join-Path ${{ runner.temp }} "build" + echo "OnnxRuntimeBuildDirectory=$buildDir" >> $env:GITHUB_ENV + + - name: Build and Clean Binaries working-directory: ${{ runner.temp }} run: | - python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync 
--build_shared_lib --build --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + npm install -g typescript + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + # Execute the build process + python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --build --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + + # Clean up the output directory before uploading artifacts + $outputDir = "${{ runner.temp }}\build\RelWithDebInfo" + Write-Host "Cleaning up files from $outputDir..." 
+ + Remove-Item -Path "$outputDir\onnxruntime" -Recurse -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\pybind11" -Recurse -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\models" -Recurse -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\vcpkg_installed" -Recurse -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\_deps" -Recurse -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\CMakeCache.txt" -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\CMakeFiles" -Recurse -Force -ErrorAction SilentlyContinue + # Remove intermediate object files as in the original script + Remove-Item -Path $outputDir -Include "*.obj" -Recurse + shell: pwsh + + - name: Upload build artifacts + uses: actions/upload-artifact@v4 + with: + name: build-artifacts + path: ${{ runner.temp }}\build + env: + OrtPackageId: Microsoft.ML.OnnxRuntime.Gpu + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + setVcvars: true + ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' + DocUpdateNeeded: false + ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' + AZCOPY_AUTO_LOGIN_TYPE: MSI + AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + + test: + name: Windows GPU TensorRT CI Pipeline Test Job + needs: build + timeout-minutes: 300 + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + submodules: 'none' + + - name: Download build artifacts + uses: actions/download-artifact@v4 + with: + name: build-artifacts + path: ${{ runner.temp }}\build + + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + architecture: x64 + + - uses: actions/setup-node@v4 + with: + node-version: '20.x' + + - uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '17' + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env + with: + architecture: x64 + + - name: 
Install python modules + run: python -m pip install -r .\tools\ci_build\github\windows\python\requirements.txt + working-directory: ${{ github.workspace }} shell: cmd - - name: Add build dir to PATH + - name: Download CUDA SDK v12.2 + working-directory: ${{ runner.temp }} + run: | + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v12.2" . + dir + shell: pwsh + + - name: Download TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8 + run: 'azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" ${{ runner.temp }}' + shell: pwsh + + - name: Add CUDA to PATH shell: powershell run: | Write-Host "Adding CUDA to PATH" - Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\build\RelWithDebInfo\RelWithDebInfo" + Write-Host "CUDA Path: $env:RUNNER_TEMP\v12.2\bin" + Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\v12.2\bin" + Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\v12.2\extras\CUPTI\lib64" + Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8\lib" + + - name: Set OnnxRuntimeBuildDirectory + shell: pwsh + run: | + $buildDir = Join-Path ${{ runner.temp }} "build" + echo "OnnxRuntimeBuildDirectory=$buildDir" >> $env:GITHUB_ENV - name: Install ONNX Runtime Wheel uses: ./.github/actions/install-onnxruntime-wheel with: whl-directory: ${{ runner.temp }}\build\RelWithDebInfo\RelWithDebInfo\dist - - name: Run tests + - name: Run Tests working-directory: ${{ runner.temp }} run: | - mklink /D /J ${{ github.workspace }}\RelWithDebInfo\models ${{ github.workspace }}\models - python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp 
}}\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + npm install -g typescript + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + + python.exe ${{ github.workspace }}\tools\python\update_ctest_path.py "${{ runner.temp }}\build\RelWithDebInfo\CTestTestfile.cmake" "${{ runner.temp }}\build\RelWithDebInfo" + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + + python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + shell: pwsh + + - name: Validate C# native delegates + run: python tools\ValidateNativeDelegateAttributes.py + working-directory: ${{ github.workspace }}\csharp shell: cmd - timeout-minutes: 180 + env: + OrtPackageId: Microsoft.ML.OnnxRuntime.Gpu + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + setVcvars: true + ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' + DocUpdateNeeded: false + ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' + AZCOPY_AUTO_LOGIN_TYPE: MSI + AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 diff --git a/.gitmodules b/.gitmodules index b5bff01d89850..a48c4062a90fe 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,4 +7,4 @@ [submodule "cmake/external/emsdk"] path = cmake/external/emsdk url = https://github.com/emscripten-core/emsdk.git - branch = 4.0.8 + branch = 4.0.11 diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 611203f0b3f72..a76be16572a03 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt 
@@ -83,6 +83,11 @@ option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_USE_KLEIDIAI "Build with KleidiAI integration in MLAS" OFF) +# iOS simulator build explicitly builds targets with USE_KLEIDIAI=ON so attempting to force override if so +if(APPLE AND CMAKE_OSX_ARCHITECTURES MATCHES "x86_64") + message(WARNING "Disabling KleidiAI: not supported on Apple x86_64 platforms") + set(onnxruntime_USE_KLEIDIAI OFF CACHE BOOL "" FORCE) +endif() option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) option(onnxruntime_BUILD_OBJC "Build Objective-C library" OFF) @@ -275,8 +280,6 @@ if (onnxruntime_ENABLE_TRAINING_APIS) endif() - - # Single output director for all binaries set(RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin CACHE PATH "Single output directory for all binaries.") @@ -559,6 +562,7 @@ else() check_cxx_compiler_flag(-Wcast-function-type HAS_CAST_FUNCTION_TYPE) check_cxx_compiler_flag(-Wcatch-value HAS_CATCH_VALUE) check_cxx_compiler_flag(-Wclass-memaccess HAS_CLASS_MEMACCESS) + check_cxx_compiler_flag(-Wcharacter-conversion HAS_CHARACTER_CONVERSION) check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE) check_cxx_compiler_flag(-Wdeprecated-anon-enum-enum-conversion HAS_DEPRECATED_ANON_ENUM_ENUM_CONVERSION) check_cxx_compiler_flag(-Wdeprecated-builtins HAS_DEPRECATED_BUILTINS) @@ -648,17 +652,25 @@ else() endif() endif() -if (onnxruntime_USE_KLEIDIAI AND NOT MSVC AND ( - (onnxruntime_target_platform STREQUAL "aarch64") OR - (onnxruntime_target_platform STREQUAL "ARM64") OR - (APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64"))) - check_cxx_compiler_flag(-march=armv8.2-a+dotprod HAS_ARM64_DOTPROD) - check_cxx_compiler_flag(-march=armv8.2-a+i8mm HAS_ARM64_I8MM) - if (NOT HAS_ARM64_DOTPROD) - message(FATAL_ERROR 
"The compiler doesn't support dotprod") - endif() - if (NOT HAS_ARM64_I8MM) - message(FATAL_ERROR "The compiler doesn't support i8mm") +if (onnxruntime_USE_KLEIDIAI AND ( + (onnxruntime_target_platform STREQUAL "aarch64") OR + (onnxruntime_target_platform STREQUAL "ARM64") OR + (APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64"))) + + # TODO Add checks for MSVC Compilation + if(NOT MSVC) + check_cxx_compiler_flag(-march=armv8.2-a+dotprod HAS_ARM64_DOTPROD) + check_cxx_compiler_flag(-march=armv8.2-a+i8mm HAS_ARM64_I8MM) + if (NOT HAS_ARM64_DOTPROD) + message(FATAL_ERROR "The compiler doesn't support dotprod") + endif() + if (NOT HAS_ARM64_I8MM) + message(FATAL_ERROR "The compiler doesn't support i8mm") + endif() + else() + message(STATUS "Skipping -march= checks on MSVC (not supported), assuming dotprod/i8mm support manually.") + set(HAS_ARM64_DOTPROD TRUE) + set(HAS_ARM64_I8MM TRUE) endif() endif() @@ -1008,6 +1020,10 @@ function(onnxruntime_set_compile_flags target_name) if (onnxruntime_ENABLE_ATEN) target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN) endif() + # TODO: Narrow scope for Kleidiai compile + if (onnxruntime_USE_KLEIDIAI) + target_compile_definitions(${target_name} PRIVATE USE_KLEIDIAI) + endif() set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON) if (onnxruntime_USE_CUDA) diff --git a/cmake/deps.txt b/cmake/deps.txt index 7089012a65f26..ed1de06f33dcb 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -27,8 +27,8 @@ fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.5.zip;cd47d3d272faf353600c8cc2fdec2b52d6f69177 googletest;https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip;f638fa0e724760e2ba07ff8cfba32cd644e1ce28 -#xnnpack 2024.09.04 
-googlexnnpack;https://github.com/google/XNNPACK/archive/fe98e0b93565382648129271381c14d6205255e3.zip;14f61dcf17cec2cde34ba2dcf61d6f24bf6059f3 +#xnnpack 2025.06.22 +googlexnnpack;https://github.com/google/XNNPACK/archive/3cf85e705098622d59056dcb8f5f963ea7bb0a00.zip;6f6bbba627241f89463ca845febaf063982b34fe json;https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.zip;5e88795165cc8590138d1f47ce94ee567b85b4d6 microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 @@ -45,9 +45,9 @@ protoc_linux_x86;https://github.com/protocolbuffers/protobuf/releases/download/v protoc_linux_aarch64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-linux-aarch_64.zip;df9d45470b0b8cf939dd2f0ec6b88e9cafc4d617 protoc_mac_universal;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-osx-universal_binary.zip;23710c3d1c2036d8d65a6a22234372fa2d7af9ef psimd;https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip;1f5454b01f06f9656b77e4a5e2e31d7422487013 -pthreadpool;https://github.com/google/pthreadpool/archive/4e80ca24521aa0fb3a746f9ea9c3eaa20e9afbb0.zip;bd4ea65c8292801e9555b527a0ecbb2e0092c917 +pthreadpool;https://github.com/google/pthreadpool/archive/dcc9f28589066af0dbd4555579281230abbf74dd.zip;533a77943203ef15ca608bcd9dbe2c94da7451d2 pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f780292da9db273c8ef06ccf5fd4b623624143e9 -pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/8a1772a0c5c447df2d18edf33ec4603a8c9c04a6.zip;85bf8a60dae026b99b6ccd78606c85ed83bfb2cd +pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/de0ce7c7251372892e53ce9bc891750d2c9a4fd8.zip;c45b8d3619b9bccbd26dc5f657959aee38b18b7a 
re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cde7fcef990bf4a8e87b047882e88 safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 @@ -56,5 +56,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96 dawn;https://github.com/google/dawn/archive/9733be39e18186961d503e064874afe3e9ceb8d1.zip;2a4017c32892b90d072a9102eba90ae691fae36d -kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.4.0.tar.gz;22d3b57b54a61c194ab256ff11b0353a3b220244 +kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.9.0.tar.gz;a2765979f64efb173a4b8ba4de39dcba9c655786 duktape;https://github.com/svaarala/duktape/releases/download/v2.7.0/duktape-2.7.0.tar.xz;8200c8e417dbab7adcc12c4dbdef7651cfc55794 diff --git a/cmake/external/emsdk b/cmake/external/emsdk index 419021fa04042..d49219d03a41c 160000 --- a/cmake/external/emsdk +++ b/cmake/external/emsdk @@ -1 +1 @@ -Subproject commit 419021fa040428bc69ef1559b325addb8e10211f +Subproject commit d49219d03a41cd12f95a33ba84273c20d41fd350 diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 228906030d14c..0d1f47f195ba5 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ message(STATUS "Loading Dependencies URLs ...") include(external/helper_functions.cmake) @@ -567,7 +570,7 @@ if (onnxruntime_USE_XNNPACK) ENDIF() ADD_LIBRARY(xnnpack STATIC IMPORTED) find_library(xnnpack_LIBRARY NAMES XNNPACK) - find_library(microkernels_prod_LIBRARY NAMES microkernels-prod) + find_library(microkernels_prod_LIBRARY NAMES xnnpack-microkernels-prod) find_package(unofficial-pthreadpool CONFIG REQUIRED) target_include_directories(xnnpack INTERFACE "${XNNPACK_HDR}") @@ -819,6 +822,14 @@ if(onnxruntime_USE_COREML) endif() +if(onnxruntime_USE_KLEIDIAI) + # Disable the KleidiAI tests + set(KLEIDIAI_BUILD_TESTS OFF) + + onnxruntime_fetchcontent_declare(kleidiai URL ${DEP_URL_kleidiai} URL_HASH SHA1=${DEP_SHA1_kleidiai} EXCLUDE_FROM_ALL) + onnxruntime_fetchcontent_makeavailable(kleidiai) +endif() + set(onnxruntime_LINK_DIRS) if (onnxruntime_USE_CUDA) find_package(CUDAToolkit REQUIRED) diff --git a/cmake/external/xnnpack.cmake b/cmake/external/xnnpack.cmake index d0ab770053be1..c994e7e15aac4 100644 --- a/cmake/external/xnnpack.cmake +++ b/cmake/external/xnnpack.cmake @@ -90,7 +90,7 @@ onnxruntime_fetchcontent_makeavailable(googlexnnpack) set(XNNPACK_DIR ${googlexnnpack_SOURCE_DIR}) set(XNNPACK_INCLUDE_DIR ${XNNPACK_DIR}/include) -set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK microkernels-prod pthreadpool) +set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK xnnpack-microkernels-prod pthreadpool) if(ORT_TARGET_PROCESSOR MATCHES "^arm64.*" AND NOT CMAKE_C_COMPILER_ID STREQUAL "MSVC") list(APPEND onnxruntime_EXTERNAL_LIBRARIES_XNNPACK kleidiai) endif() diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 47e7779d93b33..24cecf07e8e36 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -267,24 +267,23 @@ function(setup_mlas_source_for_windows) endfunction() function(setup_kleidiai) - target_compile_definitions(onnxruntime_mlas PRIVATE USE_KLEIDIAI) - - # Disable the KleidiAI tests - 
set(KLEIDIAI_BUILD_TESTS OFF) - - # Fetch KleidiAI sources: - if (NOT TARGET kleidiai) - onnxruntime_fetchcontent_declare(kleidiai URL ${DEP_URL_kleidiai} URL_HASH SHA1=${DEP_SHA1_kleidiai} EXCLUDE_FROM_ALL) - endif() - onnxruntime_fetchcontent_makeavailable(kleidiai) - target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/kai_ukernel_interface.cpp + ${MLAS_SRC_DIR}/kleidiai/sgemm_kleidiai.cpp + ${MLAS_SRC_DIR}/kleidiai/convolve_kleidiai.cpp + ${MLAS_SRC_DIR}/kleidiai/qgemm_kleidiai.cpp ) target_link_libraries(onnxruntime_mlas PRIVATE kleidiai) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES kleidiai) set(onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES} PARENT_SCOPE) + + if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS kleidiai EXPORT ${PROJECT_NAME}Targets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() endfunction() if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") @@ -311,7 +310,6 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") elseif(MSVC) setup_mlas_source_for_windows() else() - if(APPLE) get_target_property(ONNXRUNTIME_MLAS_OSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) diff --git a/cmake/onnxruntime_providers_qnn.cmake b/cmake/onnxruntime_providers_qnn.cmake index 748e3de843bab..f499c83d5f6c0 100644 --- a/cmake/onnxruntime_providers_qnn.cmake +++ b/cmake/onnxruntime_providers_qnn.cmake @@ -66,10 +66,10 @@ COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} $ ) endif() - if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf") + if (EXISTS "${onnxruntime_QNN_HOME}/LICENSE.pdf") add_custom_command( TARGET ${onnxruntime_providers_qnn_target} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf" $ + COMMAND ${CMAKE_COMMAND} -E copy "${onnxruntime_QNN_HOME}/LICENSE.pdf" $/Qualcomm_LICENSE.pdf ) endif() else() @@ -154,10 
+154,10 @@ COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} $ ) endif() - if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf") + if (EXISTS "${onnxruntime_QNN_HOME}/LICENSE.pdf") add_custom_command( TARGET ${onnxruntime_providers_qnn_target} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf" $ + COMMAND ${CMAKE_COMMAND} -E copy "${onnxruntime_QNN_HOME}/LICENSE.pdf" $/Qualcomm_LICENSE.pdf ) endif() endif() diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index b177074a1bc02..c5c85dff96411 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -1068,12 +1068,10 @@ if (onnxruntime_USE_QNN) ${QNN_LIB_FILES} $/onnxruntime/capi/ ) - if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf") + if (EXISTS "${onnxruntime_QNN_HOME}/LICENSE.pdf") add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy - "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf" - $/onnxruntime/ + COMMAND ${CMAKE_COMMAND} -E copy "${onnxruntime_QNN_HOME}/LICENSE.pdf" $/onnxruntime/Qualcomm_LICENSE.pdf ) endif() endif() diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index 3ec3c6ee1d5ae..f81a7a9726b76 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -5,6 +5,8 @@ file(GLOB onnxruntime_session_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_INCLUDE_DIR}/core/session/*.h" "${ONNXRUNTIME_ROOT}/core/session/*.h" "${ONNXRUNTIME_ROOT}/core/session/*.cc" + "${ONNXRUNTIME_ROOT}/core/session/plugin_ep/*.h" + "${ONNXRUNTIME_ROOT}/core/session/plugin_ep/*.cc" ) if (onnxruntime_ENABLE_TRAINING_APIS) @@ -22,7 +24,7 @@ endif() # which is not enabled for any minimal builds. 
if (onnxruntime_MINIMAL_BUILD) file(GLOB autoep_srcs - "${ONNXRUNTIME_ROOT}/core/session/ep_*.*" + "${ONNXRUNTIME_ROOT}/core/session/plugin_ep/*.*" ) set(onnxruntime_session_src_exclude diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 96e513c8a7bc9..c3bebba3bab54 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -120,6 +120,9 @@ function(AddTest) if (${HAS_NOERROR}) target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Wno-error=uninitialized>") endif() + if (${HAS_CHARACTER_CONVERSION}) + target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Wno-error=character-conversion>") + endif() endif() set(TEST_ARGS ${_UT_TEST_ARGS}) @@ -787,6 +790,9 @@ if(MSVC) "$<$>:/wd6326>") else() target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) + if (HAS_CHARACTER_CONVERSION) + target_compile_options(onnxruntime_test_utils PRIVATE "$<$:-Wno-error=character-conversion>") + endif() endif() if (onnxruntime_USE_NCCL) target_include_directories(onnxruntime_test_utils PRIVATE ${NCCL_INCLUDE_DIRS}) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index ffe866164a411..e2d04843d858e 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -175,9 +175,9 @@ else() "${ONNXRUNTIME_ROOT}/wasm/api.cc" "${ONNXRUNTIME_ROOT}/core/session/onnxruntime_c_api.cc" ) - set (WASM_API_EXCEPTION_CATCHING "-s DISABLE_EXCEPTION_CATCHING=0") message(STATUS "onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING_ON_API set") - set_source_files_properties(${onnxruntime_webassembly_src_exc} PROPERTIES COMPILE_FLAGS ${WASM_API_EXCEPTION_CATCHING}) + set_source_files_properties(${onnxruntime_webassembly_src_exc} PROPERTIES COMPILE_FLAGS "-sDISABLE_EXCEPTION_CATCHING=0") + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_CATCHING=0") endif() target_link_libraries(onnxruntime_webassembly 
PRIVATE @@ -241,11 +241,10 @@ else() "SHELL:-s FILESYSTEM=0" "SHELL:-s INCOMING_MODULE_JS_API=[locateFile,instantiateWasm,wasmBinary]" "SHELL:-s WASM_BIGINT=1" - ${WASM_API_EXCEPTION_CATCHING} --no-entry "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\"" ) - + if (onnxruntime_USE_JSEP) # NOTE: "-s ASYNCIFY=1" is required for JSEP to work with WebGPU # This flag allows async functions to be called from sync functions, in the cost of binary size and @@ -256,7 +255,7 @@ else() "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"" ) list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js") - + endif() if (onnxruntime_USE_WEBGPU) diff --git a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch index c9cb4bcad9e20..ea0bb61274f84 100644 --- a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch +++ b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch @@ -1,8 +1,8 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index f0b3410ae..1e3cb8178 100644 +index 94bcad92e3..be7dfe95fd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -337,7 +337,7 @@ ENDIF() +@@ -360,7 +360,7 @@ ENDIF() # ---[ Build flags IF(NOT CMAKE_SYSTEM_NAME) MESSAGE(FATAL_ERROR "CMAKE_SYSTEM_NAME not defined") @@ -11,21 +11,30 @@ index f0b3410ae..1e3cb8178 100644 MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME value \"${CMAKE_SYSTEM_NAME}\"") ENDIF() IF(CMAKE_SYSTEM_NAME MATCHES "Windows") -@@ -848,7 +848,12 @@ IF(XNNPACK_BUILD_LIBRARY) - TARGET_LINK_LIBRARIES(operator-utils PRIVATE xnnpack-base logging) - TARGET_LINK_LIBRARIES(reference-ukernels PRIVATE xnnpack-base) - TARGET_LINK_LIBRARIES(subgraph PRIVATE xnnpack-base allocator logging memory mutex operators operator-run datatype) -- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base allocator cache hardware-config indirection memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing microkernels-prod 
subgraph datatype reference-ukernels) +@@ -903,10 +903,18 @@ IF(XNNPACK_BUILD_LIBRARY) + TARGET_LINK_LIBRARIES(xnnpack-operator-utils PRIVATE xnnpack-base xnnpack-logging) + TARGET_LINK_LIBRARIES(xnnpack-reference-ukernels PRIVATE xnnpack-base xnnpack-datatype) + TARGET_LINK_LIBRARIES(xnnpack-subgraph PRIVATE xnnpack-base xnnpack-allocator xnnpack-logging xnnpack-memory xnnpack-mutex xnnpack-operators xnnpack-operator-run xnnpack-datatype) +- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base xnnpack-allocator xnnpack-cache +- xnnpack-hardware-config xnnpack-indirection xnnpack-memory xnnpack-microkernel-utils xnnpack-microparams-init +- xnnpack-mutex xnnpack-normalization xnnpack-operators xnnpack-operator-run xnnpack-operator-utils xnnpack-pack-lh xnnpack-packing +- xnnpack-microkernels-prod xnnpack-subgraph xnnpack-datatype xnnpack-reference-ukernels) + IF(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + # omit microkernels-prod as the list is manually created by ORT in cmake/external/xnnpack.cmake -+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base allocator cache hardware-config indirection memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing subgraph datatype reference-ukernels) ++ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base xnnpack-allocator xnnpack-cache ++ xnnpack-hardware-config xnnpack-indirection xnnpack-memory xnnpack-microkernel-utils xnnpack-microparams-init ++ xnnpack-mutex xnnpack-normalization xnnpack-operators xnnpack-operator-run xnnpack-operator-utils xnnpack-pack-lh xnnpack-packing ++ xnnpack-subgraph xnnpack-datatype xnnpack-reference-ukernels) + ELSE() -+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base allocator cache hardware-config indirection memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing microkernels-prod subgraph datatype reference-ukernels) ++ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base xnnpack-allocator 
xnnpack-cache ++ xnnpack-hardware-config xnnpack-indirection xnnpack-memory xnnpack-microkernel-utils xnnpack-microparams-init ++ xnnpack-mutex xnnpack-normalization xnnpack-operators xnnpack-operator-run xnnpack-operator-utils xnnpack-pack-lh xnnpack-packing ++ xnnpack-microkernels-prod xnnpack-subgraph xnnpack-datatype xnnpack-reference-ukernels) + ENDIF() - TARGET_LINK_LIBRARIES(XNNPACK PUBLIC pthreadpool logging) + TARGET_LINK_LIBRARIES(XNNPACK PUBLIC pthreadpool xnnpack-logging) SET_TARGET_PROPERTIES(XNNPACK PROPERTIES C_EXTENSIONS YES) ENDIF() -@@ -857,7 +862,8 @@ IF(NOT MSVC) +@@ -915,7 +923,8 @@ IF(NOT MSVC) ENDIF() IF(XNNPACK_TARGET_PROCESSOR STREQUAL "arm") SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -marm ") diff --git a/cmake/vcpkg-ports/cpuinfo/portfile.cmake b/cmake/vcpkg-ports/cpuinfo/portfile.cmake index e61308bf643b4..6722f10a72857 100644 --- a/cmake/vcpkg-ports/cpuinfo/portfile.cmake +++ b/cmake/vcpkg-ports/cpuinfo/portfile.cmake @@ -6,8 +6,8 @@ endif() vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO pytorch/cpuinfo - REF 8a1772a0c5c447df2d18edf33ec4603a8c9c04a6 - SHA512 b94ccbfa886221d6bb16513d074675af0a72928a9dd9485dcacdc1124a8a60aacbbe91913a1579e766dfb024f0be1d52eeead40342004ff0238a8b94a095ed08 + REF de0ce7c7251372892e53ce9bc891750d2c9a4fd8 + SHA512 0fde9210b700d2648d37c8deeb0d5c0d007d8ca5689578dd3bce4c460886b20d7649f0194d2ea06b02238fe9d4f06193599ec3ab5cafb19f1f860b00404264fa HEAD_REF master ) diff --git a/cmake/vcpkg-ports/pthreadpool/fix-cmakelists.patch b/cmake/vcpkg-ports/pthreadpool/fix-cmakelists.patch index 97fd1ac7a2bb1..cf7df0ea22980 100644 --- a/cmake/vcpkg-ports/pthreadpool/fix-cmakelists.patch +++ b/cmake/vcpkg-ports/pthreadpool/fix-cmakelists.patch @@ -1,8 +1,8 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index f06aada..3c6c6e2 100644 +index efff8cc..1a0f7e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -31,8 +31,6 @@ IF(CCACHE_BINARY) +@@ -41,8 +41,6 @@ IF(CMAKE_C_COMPILER_ID 
STREQUAL "MSVC") ENDIF() # ---[ Options. @@ -11,7 +11,7 @@ index f06aada..3c6c6e2 100644 OPTION(PTHREADPOOL_ALLOW_DEPRECATED_API "Enable deprecated API functions" ON) SET(PTHREADPOOL_SYNC_PRIMITIVE "default" CACHE STRING "Synchronization primitive (condvar, futex, gcd, event, or default) for worker threads") SET_PROPERTY(CACHE PTHREADPOOL_SYNC_PRIMITIVE PROPERTY STRINGS default condvar futex gcd event) -@@ -41,7 +39,7 @@ IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86(_64)?)$") +@@ -51,7 +49,7 @@ IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86(_64)?)$") ELSE() OPTION(PTHREADPOOL_ENABLE_FASTPATH "Enable fast path using atomic decrement instead of atomic compare-and-swap" OFF) ENDIF() @@ -20,8 +20,8 @@ index f06aada..3c6c6e2 100644 OPTION(PTHREADPOOL_BUILD_TESTS "Build pthreadpool unit tests" ON) OPTION(PTHREADPOOL_BUILD_BENCHMARKS "Build pthreadpool micro-benchmarks" ON) ELSE() -@@ -67,7 +65,8 @@ MACRO(PTHREADPOOL_TARGET_ENABLE_CXX11 target) - ENDMACRO() +@@ -71,7 +69,8 @@ IF(PTHREADPOOL_BUILD_TESTS) + ENDIF() # ---[ Download deps -IF(NOT DEFINED FXDIV_SOURCE_DIR) @@ -30,7 +30,7 @@ index f06aada..3c6c6e2 100644 MESSAGE(STATUS "Downloading FXdiv to ${CMAKE_BINARY_DIR}/FXdiv-source (define FXDIV_SOURCE_DIR to avoid it)") CONFIGURE_FILE(cmake/DownloadFXdiv.cmake "${CMAKE_BINARY_DIR}/FXdiv-download/CMakeLists.txt") EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . 
-@@ -118,21 +117,13 @@ ELSE() +@@ -122,21 +121,13 @@ ELSE() ENDIF() ADD_LIBRARY(pthreadpool_interface INTERFACE) @@ -54,7 +54,7 @@ index f06aada..3c6c6e2 100644 IF(PTHREADPOOL_SYNC_PRIMITIVE STREQUAL "condvar") TARGET_COMPILE_DEFINITIONS(pthreadpool PRIVATE PTHREADPOOL_USE_FUTEX=0) -@@ -181,18 +172,22 @@ IF(CMAKE_SYSTEM_NAME STREQUAL "Linux") +@@ -182,18 +173,22 @@ IF(CMAKE_SYSTEM_NAME STREQUAL "Linux") ENDIF() # ---[ Configure FXdiv @@ -80,3 +80,4 @@ index f06aada..3c6c6e2 100644 IF(PTHREADPOOL_BUILD_TESTS) # ---[ Build google test + diff --git a/cmake/vcpkg-ports/pthreadpool/portfile.cmake b/cmake/vcpkg-ports/pthreadpool/portfile.cmake index 9400e5e886639..449459feb33cc 100644 --- a/cmake/vcpkg-ports/pthreadpool/portfile.cmake +++ b/cmake/vcpkg-ports/pthreadpool/portfile.cmake @@ -5,8 +5,8 @@ endif() vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO google/pthreadpool - REF 4e80ca24521aa0fb3a746f9ea9c3eaa20e9afbb0 - SHA512 776017cc5d2aa94337292f2f4fbd54d099ef29abf736ab8147f07f98f12b7654cbd2fe38d34646a479a519c261ac253bbaf19c6dcbb0ec4cc0859de70f7e6472 + REF dcc9f28589066af0dbd4555579281230abbf74dd + SHA512 61853fa8f6c3297d8760be3af1df3f2a00583c1e0e58bdd03cd9cb915e8660a4f2817b22e6463cf53f10de902a1c6204ec6054fcbeada72eeee9e44baeb97178 PATCHES fix-cmakelists.patch ) diff --git a/cmake/vcpkg-ports/xnnpack/fix-build.patch b/cmake/vcpkg-ports/xnnpack/fix-build.patch index b867377d2ff9e..3da8825e2b57d 100644 --- a/cmake/vcpkg-ports/xnnpack/fix-build.patch +++ b/cmake/vcpkg-ports/xnnpack/fix-build.patch @@ -1,21 +1,17 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index f0b3410ae..ba54c3bfe 100644 +index 9f6fb5e256..4387298e59 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -1047,9 +1047,11 @@ ENDIF() - IF(XNNPACK_BUILD_ALL_MICROKERNELS) - TARGET_INCLUDE_DIRECTORIES(microkernels-all PRIVATE include src) +@@ -1125,7 +1125,7 @@ ELSE() ENDIF() -+ - TARGET_INCLUDE_DIRECTORIES(datatype PRIVATE include src) - TARGET_INCLUDE_DIRECTORIES(microkernels-prod PRIVATE include 
src) --TARGET_INCLUDE_DIRECTORIES(hardware-config PRIVATE include src ${CPUINFO_SOURCE_DIR}/include) -+TARGET_INCLUDE_DIRECTORIES(hardware-config PRIVATE include src) -+ - TARGET_INCLUDE_DIRECTORIES(indirection PRIVATE include src) - TARGET_INCLUDE_DIRECTORIES(microparams-init PRIVATE include src) - TARGET_INCLUDE_DIRECTORIES(normalization PRIVATE include src) -@@ -1104,14 +1106,9 @@ IF(NOT TARGET cpuinfo) + + INCLUDE_DIRECTORIES(.) +-TARGET_INCLUDE_DIRECTORIES(xnnpack-hardware-config PRIVATE include src ${CPUINFO_SOURCE_DIR}/include) ++TARGET_INCLUDE_DIRECTORIES(xnnpack-hardware-config PRIVATE include src) + IF(XNNPACK_BUILD_LIBRARY) + TARGET_INCLUDE_DIRECTORIES(XNNPACK PUBLIC include) + IF(WIN32) +@@ -1164,14 +1164,9 @@ IF(NOT TARGET cpuinfo) "${CPUINFO_SOURCE_DIR}" "${CMAKE_BINARY_DIR}/cpuinfo") ELSE() @@ -33,7 +29,7 @@ index f0b3410ae..ba54c3bfe 100644 ENDIF() ENDIF() IF(XNNPACK_BUILD_LIBRARY) -@@ -1129,16 +1126,12 @@ IF(NOT TARGET pthreadpool) +@@ -1189,16 +1184,12 @@ IF(NOT TARGET pthreadpool) "${PTHREADPOOL_SOURCE_DIR}" "${CMAKE_BINARY_DIR}/pthreadpool") ELSE() @@ -53,7 +49,7 @@ index f0b3410ae..ba54c3bfe 100644 ENDIF() ENDIF() TARGET_LINK_LIBRARIES(xnnpack-base INTERFACE pthreadpool) -@@ -1152,12 +1145,12 @@ IF(NOT TARGET fxdiv) +@@ -1212,12 +1203,12 @@ IF(NOT TARGET fxdiv) "${FXDIV_SOURCE_DIR}" "${CMAKE_BINARY_DIR}/FXdiv") ELSE() diff --git a/cmake/vcpkg-ports/xnnpack/portfile.cmake b/cmake/vcpkg-ports/xnnpack/portfile.cmake index d63ad0fbd0cce..60b3566629e10 100644 --- a/cmake/vcpkg-ports/xnnpack/portfile.cmake +++ b/cmake/vcpkg-ports/xnnpack/portfile.cmake @@ -5,8 +5,8 @@ endif() vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO google/XNNPACK - REF 953dcb96cc1b21b4b966952f8ee67a9e1f0d3e71 - SHA512 8c12930ef3b2f832962682d73c362518c014bb4e56d0c5cad2b8b63a03c91dccf6e6a3fd0eb91931fc5872c7df9773e76bf08553fc9c3cc22c94636c74815e94 + REF 3cf85e705098622d59056dcb8f5f963ea7bb0a00 + SHA512 
af10afde80def08dc3b20a35bd38e84f9f749865ecc4bc9733b5d99d8a2f0f30c19c3f23472d65462a907b3a58226e3b254354a92a6baa31031824f68012a055 HEAD_REF master PATCHES fix-build.patch diff --git a/cmake/vcpkg-ports/xnnpack/vcpkg.json b/cmake/vcpkg-ports/xnnpack/vcpkg.json index e0d0600902f36..643b5c4abe166 100644 --- a/cmake/vcpkg-ports/xnnpack/vcpkg.json +++ b/cmake/vcpkg-ports/xnnpack/vcpkg.json @@ -1,6 +1,6 @@ { "name": "xnnpack", - "version-date": "2025-01-23", + "version-date": "2025-06-22", "description": "High-efficiency floating-point neural network inference operators for mobile, server, and Web", "homepage": "https://github.com/google/XNNPACK", "license": "BSD-3-Clause", diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f3dcde1abe37a..b59ff63ea8260 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -3089,7 +3089,7 @@ This version of the operator has been available since version 1 of the 'com.micr
activation_type : string
-
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+
Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu
k : int
Number of top experts to select from expert pool
normalize_routing_weights : int
@@ -3106,9 +3106,9 @@ This version of the operator has been available since version 1 of the 'com.micr
router_probs : T
2D input tensor with shape (num_rows, num_experts)
fc1_experts_weights : T
-
3D input tensor with shape (num_experts, hidden_size, inter_size)
+
3D input tensor with shape (num_experts, hidden_size, inter_size), or (num_experts, hidden_size, 2 * inter_size) for swiglu
fc1_experts_bias (optional) : T
-
2D optional input tensor with shape (num_experts, inter_size)
+
2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc2_experts_weights : T
3D input tensor with shape (num_experts, inter_size, hidden_size)
fc2_experts_bias (optional) : T
@@ -4523,7 +4523,7 @@ This version of the operator has been available since version 1 of the 'com.micr
activation_type : string
-
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+
Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu
expert_weight_bits : int
Number of bits used in quantized weights. Default is 4 bits
k : int
@@ -4542,11 +4542,11 @@ This version of the operator has been available since version 1 of the 'com.micr
router_probs : T
2D input tensor with shape (num_rows, num_experts)
fc1_experts_weights : T1
-
3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)
+
3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).
fc1_scales : T
-
2D input tensor with shape (num_experts, inter_size)
+
2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc1_experts_bias (optional) : T
-
2D optional input tensor with shape (num_experts, inter_size)
+
2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc2_experts_weights : T1
3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)
fc2_scales : T
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 8659c96b540c8..3b70e5da8b3e4 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -43,6 +43,7 @@ Do not modify directly.* |||[7, 21]|**T** = tensor(float)| |Atanh|*in* input:**T**
*out* output:**T**|22+|**T** = tensor(float)| |||[9, 21]|**T** = tensor(float)| +|Attention|*in* Q:**T1**
*in* K:**T1**
*in* V:**T2**
*in* attn_mask:**U**
*in* past_key:**T1**
*in* past_value:**T2**
*out* Y:**T1**
*out* present_key:**T1**
*out* present_value:**T2**
*out* qk_matmul_output:**T1**|23+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)
**U** = tensor(bool), tensor(float), tensor(float16)| |AveragePool|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(float)| |||[19, 21]|**T** = tensor(float)| |||[11, 18]|**T** = tensor(float)| diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 609386fd1f081..24cc460a17fa9 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -86,8 +86,11 @@ class Stream; namespace synchronize { class Notification; } + using WaitNotificationFn = std::function; -void* AllocateBufferWithOptions(IAllocator& allocator, size_t size, bool use_reserve, Stream* stream, WaitNotificationFn wait_fn); +void* AllocateBufferWithOptions(IAllocator& allocator, size_t size, bool use_reserve, Stream* stream, + // wait fn is for backwards compat with provider bridge + WaitNotificationFn ignored = nullptr); template using IAllocatorUniquePtr = std::unique_ptr>; @@ -105,6 +108,43 @@ class IAllocator { */ virtual void* Alloc(size_t size) = 0; + /** Return true if the allocator implements Stream handling in AllocOnStream. + */ + virtual bool IsStreamAware() const { return false; } + + /** Allocate memory, handling usage across different Streams + * + * A device Stream may be available when executing a model on non-CPU devices. In this case operations are queued + * asynchronously and the allocation/free call is made when the operation is queued rather than executed. + * Due to this it is not safe to use the memory on another stream or with no stream unless synchronization has + * occurred. + * + * ORT currently handles the synchronization when executing the model using streams. + * + * When two streams are synchronized the event used is identified by the producer stream's latest sync id. + * This pair is copied into the sync information of the consumer stream. + * Each new event creates a new sync id. 
+ * + * It is safe to re-use an allocated piece of memory if: + * - the stream that currently owns the memory and the stream that wants to use the memory have been synchronized, + * - and the sync id from when the memory was assigned to the stream that currently owns it is less than the + * sync id from the last synchronization between the two streams. + * - e.g. stream0 is assigned the memory when its sync id is 1. + * stream0 (producer) and stream1 (consumer) are synchronized. + * stream0 sync id will be incremented to 2 when creating the event used in the synchronization. + * stream1 will copy this information into its sync info and now contains an entry for stream0 + * with a sync id of 2. + * stream0 frees the memory + * the memory is marked as not in use, but still assigned to stream0 + * stream1 is now able to use the memory as it is not in use, and the sync id from the allocation (1) + * is less than the sync id (2) that it has for stream0. + * or + * - the inference session that owned the Stream has completed inferencing + * - Stream::CleanUpOnRunEnd is called when this occurs + * - any memory assigned to the Stream is now able to be used by other streams when it is no longer in use. + */ + virtual void* AllocOnStream(size_t size, Stream* /*stream*/) { return Alloc(size); } + /** * Free memory at p. * If p is nullptr, do nothing. @@ -192,7 +232,7 @@ class IAllocator { template static IAllocatorUniquePtr MakeUniquePtr(std::shared_ptr allocator, size_t count_or_bytes, bool use_reserve = false, - Stream* stream = nullptr, WaitNotificationFn wait_fn = nullptr) { + Stream* stream = nullptr) { ValidateAllocator(allocator); // for now limit to fundamental types. 
we could support others, but to do so either we or the caller @@ -210,7 +250,7 @@ class IAllocator { } // allocate - T* p = static_cast(AllocateBufferWithOptions(*allocator, alloc_size, use_reserve, stream, std::move(wait_fn))); + T* p = static_cast(AllocateBufferWithOptions(*allocator, alloc_size, use_reserve, stream, nullptr)); ValidateAllocation(p, alloc_size); return IAllocatorUniquePtr{p, diff --git a/include/onnxruntime/core/framework/stream_handles.h b/include/onnxruntime/core/framework/stream_handles.h index 441e3ebda1502..7d27c7471d71f 100644 --- a/include/onnxruntime/core/framework/stream_handles.h +++ b/include/onnxruntime/core/framework/stream_handles.h @@ -2,9 +2,11 @@ // Licensed under the MIT License. #pragma once +#include #include #include #include + #include "core/framework/allocator.h" #include "core/framework/ortdevice.h" #include "core/common/status.h" @@ -21,9 +23,9 @@ namespace synchronize { class Notification; } -// a stream abstraction which hold an opaque handle, and a reference to which OrtDevice instance this stream belong to. -// it need to be OrtDevice instance as we might have different stream on different OrtDevice with same type. -// i.e. different cuda stream on different GPU. +/// +/// Class to represent a stream on the OrtDevice. +/// class Stream { public: Stream(StreamHandle h, const OrtDevice& d) @@ -31,123 +33,113 @@ class Stream { } virtual ~Stream() = default; + virtual std::unique_ptr CreateNotification(size_t /*num_consumers*/) { return {}; }; + // block the host thread until all the tasks in the stream finished. virtual void Flush() {}; + // The framework may reuse the stream instance for multiple iterations. // This is the API that provide a chance to let the device stream cleanup // resource at the end of a iteration. virtual Status CleanUpOnRunEnd() { return Status::OK(); }; + // Get the native stream handle. nullptr if the OrtDevice doesn't support streams. 
StreamHandle GetHandle() const { return handle_; } const OrtDevice& GetDevice() const { return device_; } - // We use the timestamp based vector clocks to optimize the resource sharing - // between different streams. - // Each stream maintain following data structure: - // 1. Current timestamp - // 2. A lookup table that for a given stream, what is its timestamp when the - // last synchronization happened with current stream. - // 3. When a notification is activated, it take a snapshot of current stream's - // lookup table. - // 4. When synchronization happened (current stream wait on a notification), - // update its lookup table with the table snapshot in notification. - // The memory reusing strategy is: - // A kernel in current stream is safe to reuse another stream's memory chunk - // as long as the reused chunk's timestamp is less than the last synchronized - // timestamp recorded in the lookup table. - - // Get the current timestamp - uint64_t GetCurrentTimestamp() const { return timestamp_; } - - // return the timestamp when the last synchronization happened between target stream and current stream. - // return 0 if no synchronization happened. - // if target_stream is nullptr, it means it is a sequence running on device doesn't support Stream (i.e. CPU) - // we can safely return 0 in that case to save a lookup. - uint64_t GetLastSyncTimestampWithTargetStream(Stream* target_stream) const { - if (!target_stream) - return 0; - auto it = other_stream_clock_.find(target_stream); - return it == other_stream_clock_.end() ? 0 : it->second; + // Get the current synchronization ID. + // The value is 0 until this stream records an event. + // The sync id is incremented before each new event that is recorded in our stream via Notification::Activate. + uint64_t GetSyncId() const { return sync_id_; } + + // Return the sync id from when the last synchronization happened between producer_stream and this stream. + // i.e. 
the id for the event that the producer stream recorded and we waited on + // + // Returns 0 if the streams have not previously been synchronized. + uint64_t GetSyncIdForLastWaitOnStream(const Stream& producer_stream) const { + auto it = producer_stream_sync_info_.find(&producer_stream); + return it == producer_stream_sync_info_.end() ? 0 : it->second; } - // make a copy of the current stream lookup table. - // this is used to create a snapshot of the stream lookup table in notification. - void CloneCurrentStreamSyncTable(std::unordered_map& output) const { - output.reserve(other_stream_clock_.size()); - output.insert(other_stream_clock_.begin(), other_stream_clock_.end()); - } + // Get the sync information that is added to a notification when it is activated. + // This contains sync ids for all streams we have waited on, and the new sync id for our stream. + std::unordered_map OnNotificationActivation() { + // copy our sync info so the notification can pass it on to any waiting streams + auto sync_info = producer_stream_sync_info_; + // and add our info to the copy, incrementing the sync_id + sync_info[this] = ++sync_id_; - // bump the current timestamp - // When a notification get activated, bump the snapshot in its owner. - // Stream is not shared across threads, BumpTimeStampAndReturn will only be invoked on the current thread - // where the stream is executed on, so there is no race condition. - uint64_t BumpTimeStampAndReturn() { - return ++timestamp_; + return sync_info; } - // update the stream lookup table with the snapshot saved in notification. - void UpdateStreamClock(const std::unordered_map& clock) { - for (const auto& kv : clock) { - auto ret = other_stream_clock_.insert(kv); - if (!ret.second) { - ret.first->second = std::max(ret.first->second, kv.second); - } - } - } + // Record information from a Notification we waited on. + // - copies the producer stream info from the notification. 
+ void UpdateWithAwaitedNotification(const synchronize::Notification& notification); + // used in custom ops. doesn't really belong here. virtual void* GetResource(int /*version*/, int /*id*/) const { return nullptr; } - virtual WaitNotificationFn GetWaitNotificationFn() const { return nullptr; } - private: StreamHandle handle_; const OrtDevice& device_; - uint64_t timestamp_{0}; + + // current sync id. equivalent to a counter for the number of events we have recorded in the underlying stream. + // 0 == no events recorded. sync_id_ is updated prior to recording a new event. + std::atomic sync_id_{0}; + + // This is a map to track synchronization points between streams. When we wait on another stream (the producer) + // we add an entry to the map for that stream. + // The entry has the sync id from the producer stream for the event we waited on. + // // TODO: use inline container. // currently this class is header only, but abseil doesn't compile with nvcc // we need to add new symbol to provider_bridge and hide abseil from the header. - std::unordered_map other_stream_clock_{}; + std::unordered_map producer_stream_sync_info_{}; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Stream); }; namespace synchronize { -// An abstraction used for synchronization between streams. See its concrete subclass (CudaNotification, etc.) how the activate -// and wait works for a specific stream +// An abstraction used for synchronization between streams. +// See derived classes (CudaNotification, etc.) for implementation examples. class Notification { public: explicit Notification(Stream& s) : stream_(s) {} virtual ~Notification() = default; - // this api will perform three operations: - // 1. activate the notification on device, for example, record an event on GPU. - // 2. take a snapshot of the timestamp lookup table in current stream. - // 3. bump the timestamp for current stream. + // Activate the notification. 
This records an event in the Stream that created the Notification that other streams can wait on. void ActivateAndUpdate() { Activate(); - stream_.CloneCurrentStreamSyncTable(stream_clock_); - stream_clock_[&stream_] = stream_.BumpTimeStampAndReturn(); + + // copy the sync info. this includes an entry for stream_ with the next sync id. + stream_sync_info_ = stream_.OnNotificationActivation(); } - // return the timestamp lookup table saved in the notification. - const std::unordered_map& GetStreamSyncTable() { - return stream_clock_; + // Get the sync history for the producer stream that created this Notification. + // The notification must have been activated previously. + const std::unordered_map& GetStreamSyncInfo() const { + return stream_sync_info_; } protected: virtual void Activate() = 0; - // which stream create this notification. + + Stream& GetStream() { + return stream_; + } + + private: + // Stream that created the notification (producer stream). Stream& stream_; - // TODO: use inline container. - // currently this class is header only, but abseil doesn't compile with nvcc - // we need to add new symbol to provider_bridge and hide abseil from the header. - std::unordered_map stream_clock_{}; + + // This is a snapshot of the sync history for the stream that created the Notification. 
+ std::unordered_map stream_sync_info_{}; }; } // namespace synchronize diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index ea9cbbfc6ca73..e164f23b8fc35 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -44,6 +44,7 @@ struct OrtGraph; namespace onnxruntime { +class ExternalDataInfo; class Graph; struct IndexedSubGraph; class Model; @@ -788,6 +789,27 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi */ bool GetOrtValueInitializer(const std::string& name, OrtValue& value, bool check_outer_scope = false) const; + /// + /// Loads an initializer with data in an external file into an OrtValue. Does NOT cache the OrtValue + /// in this Graph. + /// + /// The name of the initializer. + /// Output parameter set to the loaded OrtValue. Set to an existing OrtValue if + /// it is already loaded. + /// A status indicating an error or success. An error occurs if `name` is not an initializer + /// with external data. + Status LoadExternalInitializerAsOrtValue(const std::string& name, OrtValue& value) const; + + /// + /// Gets information (external filepath, file offset, num bytes) for an initializer with data in an external file. + /// + /// The initializer's name. + /// Output parameter set to the location information of the external data. + /// Set to true if parent graphs should be checked. + /// True if `name` refers to an initializer with data in an external file. Otherwise, returns false + bool GetExternalInitializerInfo(const std::string& name, std::unique_ptr& ext_info, + bool check_outer_scope = false) const; + /** Gets all the initializer tensors in this Graph. 
*/ const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return name_to_initial_tensor_; } diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index ce0f134002d8e..0d920ab7dac89 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// DO NOT include ORT header files as this is meant to be a header-only utility that can be copied +// to other projects. + /* SUMMARY: Utilities to serialize an OrtGraph into an ONNX GraphProto or ModelProto. Can be used by execution provider @@ -363,13 +366,18 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, for (const OrtOpAttr* ort_attr : ort_attrs) { OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - Ort::Status status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; - if (!status.IsOK()) { - // This is an attribute type that ORT does not support via ReadOpAttr(), like subgraphs, so skip it. + Ort::Status attr_type_status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; + if (attr_type == OrtOpAttrType::ORT_OP_ATTR_GRAPH) { + // ORT does not support reading subgraphs via ReadOpAttr(), so skip it. // Can use Node_GetSubgraphs to get subgraphs. continue; } + if (!attr_type_status.IsOK()) { + // Unsupported attribute type. 
+ return attr_type_status; + } + onnx::AttributeProto* attr_proto = node_proto->add_attribute(); ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); } @@ -494,11 +502,14 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, auto* ext_data_entries = tensor_proto->mutable_external_data(); onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* length_entry = ext_data_entries->Add(); location_entry->set_key("location"); location_entry->set_value(ext_location); offset_entry->set_key("offset"); offset_entry->set_value(std::to_string(ext_offset)); + length_entry->set_key("length"); + length_entry->set_value(std::to_string(data_bytes)); } else { // User wants to store data inline the TensorProto's raw_data tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); @@ -616,20 +627,24 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); type_proto_tensor->set_elem_type(ort_elem_type); - onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); + // If there are no dimensions in the shape, do not set a TensorShapeProto. Otherwise, it always looks + // like a scalar value. 
+ if (!ort_dims.empty()) { + onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); - for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { - onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); + for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { + onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); - if (ort_dims[dim_idx] >= 0) { - dim_proto->set_dim_value(ort_dims[dim_idx]); - } else { - const std::string& dim_param = ort_dim_syms[dim_idx]; + if (ort_dims[dim_idx] >= 0) { + dim_proto->set_dim_value(ort_dims[dim_idx]); + } else { + const std::string& dim_param = ort_dim_syms[dim_idx]; - // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, - // which represents an unknown dimension. - if (!dim_param.empty()) { - dim_proto->set_dim_param(dim_param); + // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, + // which represents an unknown dimension. 
+ if (!dim_param.empty()) { + dim_proto->set_dim_param(dim_param); + } } } } diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 7e49275e59b8b..306f81df38e48 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -20,7 +20,7 @@ #include "core/platform/threadpool.h" #include "core/session/abi_devices.h" -#include "core/session/ep_library.h" +#include "core/session/plugin_ep/ep_library.h" #include "core/session/onnxruntime_c_api.h" struct OrtThreadingOptions; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9fd9e376cbf0d..d87e9e083185b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -274,6 +274,7 @@ typedef enum OrtOpAttrType { ORT_OP_ATTR_FLOATS, ORT_OP_ATTR_STRING, ORT_OP_ATTR_STRINGS, + ORT_OP_ATTR_GRAPH, } OrtOpAttrType; //! @} @@ -324,6 +325,7 @@ ORT_RUNTIME_CLASS(HardwareDevice); ORT_RUNTIME_CLASS(EpDevice); ORT_RUNTIME_CLASS(KeyValuePairs); ORT_RUNTIME_CLASS(SyncStream); // Opaque class to create an onnxruntime::Stream. +ORT_RUNTIME_CLASS(ExternalInitializerInfo); #ifdef _MSC_VER typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -351,16 +353,22 @@ typedef struct OrtAllocator { /** * @brief Optional allocation function to use for memory allocations made during session initialization. * Use this function if you want to separate allocations made by ORT during Run() calls from - * those made during session initialization. This allows for separate memory management strategies for these allocations. + * those made during session initialization. This allows for separate memory management strategies for these + * allocations. + * + * \return pointer to an allocated block of `size` bytes. nullptr if size was 0 or allocation failed. 
+ * + * \since 1.18 */ - void*(ORT_API_CALL* Reserve)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes + void*(ORT_API_CALL* Reserve)(struct OrtAllocator* this_, size_t size); /** * @brief Function used to get the statistics of the allocator. * - * Return a pointer to the OrtKeyValuePairs structure that contains the statistics of the allocator - * and the user should call OrtApi::ReleaseKeyValuePairs. - * Supported keys are: + * Return a pointer to the OrtKeyValuePairs structure that contains the statistics of the allocator. + * The user should call OrtApi::ReleaseKeyValuePairs when done. + * + * Current known keys are: * - Limit: Bytes limit of the allocator. -1 if no limit is set. * - InUse: Number of bytes in use. * - TotalAllocated: The total number of allocated bytes by the allocator. @@ -371,9 +379,32 @@ typedef struct OrtAllocator { * - NumArenaShrinkages: Number of arena shrinkages (Relevant only for arena based allocators) * - MaxAllocSize: The max single allocation seen. * - * NOTE: If the allocator does not implement this function, the OrtKeyValuePairs instance will be empty. + * The allocator is free to add other entries as appropriate. + * + * \note Implementation of this function is optional and GetStats may be set to a nullptr. + * If the OrtAllocator is wrapping an internal ORT allocator that does not implement GetStats + * the returned OrtKeyValuePairs instance will be empty. + * + * \since 1.23 */ ORT_API2_STATUS(GetStats, _In_ const struct OrtAllocator* this_, _Outptr_ OrtKeyValuePairs** out); + + /** \brief Allocate using a stream. + * + * If the allocator is stream aware this performs allocation using a stream. + * + * Alloc will be used if this is nullptr. + * + * \param[in] this_ OrtAllocator instance + * \param[in] size Size of the allocation in bytes. nullptr if size was 0 or allocation failed. + * \param[in] stream The stream to allocate on. 
+ * + * \return pointer to an allocated block of `size` bytes + * + * \note Implementation of this function is optional and AllocOnStream may be set to a nullptr. + * \since 1.23 + */ + void*(ORT_API_CALL* AllocOnStream)(struct OrtAllocator* this_, size_t size, OrtSyncStream* stream); } OrtAllocator; typedef void(ORT_API_CALL* OrtLoggingFunction)( @@ -421,7 +452,8 @@ typedef struct OrtCustomOp OrtCustomOp; typedef enum OrtAllocatorType { OrtInvalidAllocator = -1, OrtDeviceAllocator = 0, - OrtArenaAllocator = 1 + OrtArenaAllocator = 1, + OrtReadOnlyAllocator = 2, } OrtAllocatorType; /** \brief Memory types for allocated memory, execution provider specific types should be extended in each provider. @@ -5483,10 +5515,13 @@ struct OrtApi { * * Supports initializers defined in an outer scope (i.e., a parent graph). * + * Supports initializers stored in an external file. For external initializers, ORT memory maps + * the initializer data on the first call to this function. If caller needs custom memory mapping, + * use ValueInfo_GetExternalInitializerInfo to get the location of the initializer data. + * * \param[in] value_info The OrtValueInfo instance. - * \param[out] initializer_value Output parameter set to the initializer value or NULL. The OrtValue data pointer - * (obtained via GetTensorData) is stable during the lifetime of the OrtSession - * that owns the OrtGraph. + * \param[out] initializer_value Output parameter set to the initializer value or NULL. Do not cache the OrtValue + * as it is released when the owning OrtGraph is released. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -5495,6 +5530,24 @@ struct OrtApi { ORT_API2_STATUS(ValueInfo_GetInitializerValue, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtValue** initializer_value); + /** \brief Get information about an external initializer (e.g., filepath, file offset, byte size). 
+ * + * Sets the output parameter `info` to NULL if the given OrtValueInfo does not represent an initializer + * with external data. In this case, a NULL status (non-error) is returned. + * + * \param[in] value_info The OrtValueInfo instance. + * \param[out] info Output parameter set to an OrtExternalInitializerInfo instance that can be used to query + * file path, file offset, etc. ORT sets this to NULL if the OrtValueInfo does not represent + * an external initializer. + * Must release with ReleaseExternalInitializerInfo. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ValueInfo_GetExternalInitializerInfo, _In_ const OrtValueInfo* value_info, + _Outptr_result_maybenull_ OrtExternalInitializerInfo** info); + /** \brief Returns a boolean indicating if the given value is a required graph input. * * For ONNX IR version < 4, all graph inputs without a matching initializer are required. @@ -5793,14 +5846,13 @@ struct OrtApi { /** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph. * - * Note: - * The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference + * \note The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference * the same underlying graph. * * \param[in] src_graph The source OrtGraph instance. * \param[in] nodes A subset of the nodes/OrtNodes in 'graph'. * \param[in] num_nodes Number of nodes. - * \param[out] dst_sub_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. + * \param[out] dst_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -6084,6 +6136,50 @@ struct OrtApi { /// @} + /// \name OrtExternalInitializerInfo + /// @{ + + /** \brief Release an OrtExternalInitializerInfo instance. 
+ * + * \param[in] input OrtExternalInitializerInfo instance to be released. + * + * \since Version 1.23. + */ + ORT_CLASS_RELEASE(ExternalInitializerInfo); + + /** \brief Get the relative path to the file that stores the initializer's data. + * + * \note The path is relative to the filesystem directory where the ONNX model was stored. + * Caller can use Graph_GetModelPath to get the model's full path and construct the absolute path to the + * external initializer file if necessary. + * + * \param[in] info The OrtExternalInitializerInfo instance. + * \return The relative path to the file that stores the initializer's data. Do NOT free this pointer. + * + * \since Version 1.23. + */ + ORT_API_T(const ORTCHAR_T*, ExternalInitializerInfo_GetFilePath, _In_ const OrtExternalInitializerInfo* info); + + /** \brief Get the byte offset within the file where the initializer's data is stored. + * + * \param[in] info The OrtExternalInitializerInfo instance. + * \return The byte offset where the initializer's data is stored within the file. + * + * \since Version 1.23. + */ + ORT_API_T(int64_t, ExternalInitializerInfo_GetFileOffset, _In_ const OrtExternalInitializerInfo* info); + + /** \brief Get the size in bytes of the initializer's data within the file. + * + * \param[in] info The OrtExternalInitializerInfo instance. + * \return The size in bytes of the initializer's data within the file. + * + * \since Version 1.23. + */ + ORT_API_T(size_t, ExternalInitializerInfo_GetByteSize, _In_ const OrtExternalInitializerInfo* info); + + /// @} + /// \name OrtRunOptions /// @{ @@ -6131,7 +6227,7 @@ struct OrtApi { * \param[in] env The OrtEnv instance to create the shared allocator in. * \param[in] ep_device The OrtEpDevice instance to create the shared allocator for. * \param[in] mem_type The memory type to use for the shared allocator. - * \param[in] allocator_type The type of allocator to create (e.g. OrtAllocatorType::OrtArenaAllocator). 
+ * \param[in] allocator_type The type of allocator to create. Only OrtDeviceAllocator is valid currently. * \param[in] allocator_options Optional key-value pairs to configure the allocator. If arena based, see * include/onnxruntime/core/framework/allocator.h for the keys and values that can be * used. diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index f7e304c98d7b5..620cb5fcf13cc 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -338,8 +338,14 @@ struct OrtEpApi { * The registered values will be used in calls to OrtEpFactory::CreateAllocator to ensure the required allocator/s * are available for EP usage. * - * At most one DEFAULT and one HOST_ACCESSIBLE entry can be added. - * Multiple calls for the same memory type will replace a previous entry. + * Multiple calls for the same entry type will replace a previous entry. + * + * Available entries: + * - OrtDeviceAllocator with type of OrtDeviceMemoryType_DEFAULT + * - OrtDeviceAllocator with type of OrtDeviceMemoryType_HOST_ACCESSIBLE + * - OrtReadOnlyAllocator with type of OrtDeviceMemoryType_DEFAULT + * - if provided this allocator will only be used to copy initializers to the device the EP uses. + * ORT will use the OrtDeviceAllocator if not provided. * * \param[in] ep_device The OrtEpDevice instance to register the OrtMemoryInfo with. * \param[in] allocator_memory_info The OrtMemoryInfo information for the allocator. @@ -424,6 +430,41 @@ struct OrtEpApi { * \since Version 1.23. */ ORT_API_T(uint32_t, MemoryDevice_GetDeviceId, _In_ const OrtMemoryDevice* memory_device); + + /** \brief Get the OrtSyncStreamImpl associated with an OrtSyncStream instance. + * + * This allows the plugin library to connect its OrtSyncStreamImpl instance with an OrtSyncStream if needed. 
+ * + * \param[in] stream The OrtSyncStream instance to find an OrtSyncStreamImpl for. + * \return The associated OrtSyncStreamImpl if found. nullptr otherwise. + * + * \since Version 1.23. + * + * \remarks There should always be an OrtSyncStreamImpl associated with an OrtSyncStream instance that the EP gets. + */ + ORT_API_T(const OrtSyncStreamImpl*, SyncStream_GetImpl, _In_ const OrtSyncStream* stream); + + /** \brief Get the current sync ID for a stream. + * + * \param[in] stream The OrtSyncStream to get the sync ID for. + * \return Current sync ID. + * + * \since Version 1.23. + */ + ORT_API_T(uint64_t, SyncStream_GetSyncId, _In_ const OrtSyncStream* stream); + + /** \brief Get the sync ID for the last time the consumer_stream waited on the producer_stream. + * + * When two streams are synchronized, the sync id represents the event used in that synchronization. + * + * \param[in] producer_stream The OrtSyncStream that produced the data. + * \param[in] consumer_stream The OrtSyncStream that waited on the producer_stream. + * \return ID for last sync. 0 if no sync has occurred between the two streams. + * + * \since Version 1.23. 
+ */ + ORT_API_T(uint64_t, GetSyncIdForLastWaitOnSyncStream, + _In_ const OrtSyncStream* producer_stream, _In_ const OrtSyncStream* consumer_stream); }; /** diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index d08a72b922142..3f1face2a043c 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1598,7 +1598,7 @@ // "test_averagepool_3d_default", "test_basic_conv_with_padding", "test_basic_conv_without_padding", - // "test_basic_convinteger", + "test_basic_convinteger", "test_batchnorm_epsilon_training_mode", "test_batchnorm_epsilon", "test_batchnorm_example_training_mode", @@ -1686,8 +1686,8 @@ "test_conv_with_strides_and_asymmetric_padding", "test_conv_with_strides_no_padding", "test_conv_with_strides_padding", - // // "test_convinteger_with_padding", - // // "test_convinteger_without_padding", + "test_convinteger_with_padding", + "test_convinteger_without_padding", "test_convtranspose_1d", // // "test_convtranspose_3d", // "test_convtranspose_autopad_same", diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index de23444e95778..d16c55695772b 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -71,7 +71,7 @@ bool Attention::IsPackWeightsSuccessful(int qkv_index, const T* weights_data, size_t weight_matrix_col_size, /*out*/ PrePackedWeights* prepacked_weights) { - size_t packb_size = MlasGemmPackBSize(head_size, input_hidden_size); + size_t packb_size = MlasGemmPackBSize(CblasNoTrans, CblasNoTrans, head_size, input_hidden_size); if (packb_size == 0) { return false; } @@ -87,7 +87,7 @@ bool Attention::IsPackWeightsSuccessful(int qkv_index, memset(packed_weights_data, 0, packed_weights_data_size); for (size_t i = 0; i < loop_len; i++) { - MlasGemmPackB(CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data); + 
MlasGemmPackB(CblasNoTrans, CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data); packed_weights_data += packb_size; weights_data += head_size; } diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index 69eabcfe2654a..85a2cbaea0e44 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/common/cpuid_info.h" // for CPUIDInfo::GetCPUIDInfo().HasArm_SME() #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/mlas/inc/mlas.h" @@ -10,7 +11,9 @@ #include "core/util/math_cpuonly.h" #include "core/util/qmath.h" +#include #include +#include namespace onnxruntime { namespace contrib { @@ -65,13 +68,13 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx, bool a_is_signed, const Tensor* b_tensor, const Tensor* b_scale_tensor, - const Tensor* b_zp_tensor, + const Tensor* b_zp_constant_tensor, const Tensor* bias_tensor) const { MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(a_shape, b_tensor ? b_tensor->Shape() : b_shape_, b_scale_tensor ? &b_scale_tensor->Shape() : nullptr, - b_zp_tensor ? &b_zp_tensor->Shape() : nullptr)); + b_zp_constant_tensor ? &b_zp_constant_tensor->Shape() : nullptr)); Tensor* y = ctx->Output(OUT_Y, helper.OutputShape()); // Bail out early if the output is going to be empty @@ -85,12 +88,12 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx, bool is_b_zp_per_column = false; uint8_t b_zp_default = 0; const uint8_t* b_zp_ptr = &b_zp_default; - if (nullptr != b_zp_tensor) { - ORT_ENFORCE(IsBQuantParamSupported(b_zp_tensor->Shape(), b_tensor ? 
b_tensor->Shape() : b_shape_), + if (nullptr != b_zp_constant_tensor) { + ORT_ENFORCE(IsBQuantParamSupported(b_zp_constant_tensor->Shape(), b_tensor ? b_tensor->Shape() : b_shape_), "MatmulInteger : b zero point is not valid"); - is_b_zp_per_column = !IsScalarOr1ElementVector(b_zp_tensor); - b_zp_ptr = static_cast(b_zp_tensor->DataRaw()); + is_b_zp_per_column = !IsScalarOr1ElementVector(b_zp_constant_tensor); + b_zp_ptr = static_cast(b_zp_constant_tensor->DataRaw()); } // process scale of b @@ -161,6 +164,119 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { Status Compute(OpKernelContext* context) const override; +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override { + // only pack Matrix B + if (input_idx == GetBIdx()) { + const Tensor* b_zp_constant_tensor{nullptr}; + bool b_quantization_might_be_asymmetric = false; + + const OrtValue* b_zp; + if (Info().TryGetConstantInput(IN_B_ZERO_POINT, &b_zp)) { + b_zp_constant_tensor = &b_zp->Get(); + } + + // MlasDynamicQgemm requires symmetric quantization for B, so the B zero point value should either be all zeros + // or not provided. + if (b_zp_constant_tensor != nullptr) { + // B zero point is constant. Check if it is all zeros. + assert(b_zp_constant_tensor->IsDataType() || b_zp_constant_tensor->IsDataType()); + const auto* zp_bytes = static_cast(b_zp_constant_tensor->DataRaw()); + const size_t zp_size_in_bytes = b_zp_constant_tensor->SizeInBytes(); + b_quantization_might_be_asymmetric = std::any_of(zp_bytes, zp_bytes + zp_size_in_bytes, + [](std::byte v) { return v != std::byte{0}; }); + } else { + // B zero point input is not constant. If it exists, we can't assume symmetric quantization. 
+ const auto input_defs = Info().node().InputDefs(); + const bool b_zp_input_exists = input_defs.size() > IN_B_ZERO_POINT && input_defs[IN_B_ZERO_POINT]->Exists(); + b_quantization_might_be_asymmetric = b_zp_input_exists; + } + + // MlasDynamicQgemm requires scale data to be available at packing stage + const Tensor* b_scale_tensor = nullptr; + const bool b_scale_available = Info().TryGetConstantInput(IN_B_SCALE, &b_scale_tensor); + + can_use_dynamic_quant_mlas_ = (!b_quantization_might_be_asymmetric && b_scale_available); + + // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. + // We check that here too before attempting to use them. + if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) { + can_use_dynamic_quant_mlas_ = false; + } + + // Only handle the common case of a 2D weight matrix. Additional matrices + // could be handled by stacking the packed buffers. + b_shape_ = tensor.Shape(); + // TO DO: handle b_shape_.NumDimensions() > 2 and all dimension values but the last two being 1. + if (!(b_shape_.NumDimensions() == 2 || (b_shape_.NumDimensions() == 3 && b_shape_[0] == 1))) { + can_use_dynamic_quant_mlas_ = false; + } + + // Can we use the mlas dynamic Q gemm interface supported with float output ? 
+ if (!can_use_dynamic_quant_mlas_) { + // default to piece wise mlas interface with separate int matmul, quantize and float conversion + return MatMulIntegerToFloatBase::PrePack(tensor, input_idx, alloc, is_packed, prepacked_weights); + } + is_packed = false; + + // Default to all zeros for bias + const Tensor* bias_tensor{nullptr}; + const OrtValue* bias; + if (Info().TryGetConstantInput(IN_BIAS, &bias)) { + bias_tensor = &bias->Get(); + dynamic_quant_mlas_bias_data_was_packed_ = true; + } + size_t K = static_cast(b_shape_[0]); + size_t N = static_cast(b_shape_[1]); + + const auto* b_data = static_cast(tensor.DataRaw()); + + std::optional b_trans_buffer; + if (IsBTransposed()) { + std::swap(K, N); + b_data = quantization::TransPoseInputData(b_data, b_trans_buffer, alloc, N, K); + } + + const size_t packed_b_size = MlasDynamicQgemmPackBSize(N, K); + if (packed_b_size == 0) { + return Status::OK(); + } + + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size, true); + // Initialize memory to 0 as there could be some padding associated with pre-packed + // buffer memory and we do not want it uninitialized and generate different hashes + // if and when we try to cache this pre-packed buffer for sharing between sessions. + memset(packed_b_.get(), 0, packed_b_size); + + const auto scales = static_cast(b_scale_tensor->Shape().Size()) == N ? std::vector(&b_scale_tensor->Data()[0], + &b_scale_tensor->Data()[N]) + : + // Broadcast matrix scale to all channels + std::vector(N, b_scale_tensor->Data()[0]); + + const auto biases = bias_tensor != nullptr ? 
std::vector(&bias_tensor->Data()[0], + &bias_tensor->Data()[N]) + : + // Broadcast zero to all channels - no bias data is available + std::vector(N, 0.f); + + MlasDynamicQgemmPackB(N, K, reinterpret_cast(b_data), scales.data(), biases.data(), + packed_b_.get()); + + bool share_prepacked_weights = (prepacked_weights != nullptr); + if (share_prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size); + } + + is_packed = true; + } + return Status::OK(); + } +#endif + enum InputTensors : int { IN_A = 0, IN_B = 1, @@ -171,6 +287,12 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { protected: int GetBIdx() const override { return IN_B; } + + private: + bool can_use_dynamic_quant_mlas_{false}; +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + bool dynamic_quant_mlas_bias_data_was_packed_{false}; +#endif }; class MatMulIntegerToFloat final : public MatMulIntegerToFloatBase { @@ -199,44 +321,104 @@ class MatMulIntegerToFloat final : public MatMulIntegerToFloatBase { }; Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { - const Tensor* a = ctx->Input(IN_A); - const Tensor* b = packed_b_ ? 
nullptr : ctx->Input(IN_B); - - const Tensor* b_scale_tensor = ctx->Input(IN_B_SCALE); - const Tensor* b_zp_tensor = ctx->Input(IN_B_ZERO_POINT); - - // calculate quantization parameter of a - const float* a_data = a->Data(); - int64_t num_of_elements = a->Shape().Size(); - - float a_scale; - uint8_t a_zero_point; - GetQuantizationParameter(a_data, num_of_elements, a_scale, a_zero_point, ctx->GetOperatorThreadPool()); - - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); - uint8_t* a_data_quant = static_cast(allocator->Alloc(SafeInt(num_of_elements) * sizeof(uint8_t))); - BufferUniquePtr a_buffer_quant_holder(a_data_quant, BufferDeleter(std::move(allocator))); - - ParQuantizeLinearStd(a_data, a_data_quant, narrow(num_of_elements), a_scale, a_zero_point, ctx->GetOperatorThreadPool()); - - bool is_b_scale_supported = IsBQuantParamSupported(b_scale_tensor->Shape(), b ? b->Shape() : b_shape_); - ORT_RETURN_IF_ERROR(ComputeCommon( - ctx, - a_data_quant, - a->Shape(), - a_scale, - a_zero_point, - false /*a_is_signed*/, - b, - is_b_scale_supported ? b_scale_tensor : nullptr, - b_zp_tensor, - ctx->Input(IN_BIAS))); - - if (!is_b_scale_supported) { - ScaleOutput(*b_scale_tensor, *ctx->Output(0)); + // Can this operation be offloaded to a MLAS specific dynamic quantization matmul ? + if (!can_use_dynamic_quant_mlas_) { + const Tensor* a = ctx->Input(IN_A); + const Tensor* b = packed_b_ ? 
nullptr : ctx->Input(IN_B); + + const Tensor* b_scale_tensor = ctx->Input(IN_B_SCALE); + const Tensor* b_zp_constant_tensor = ctx->Input(IN_B_ZERO_POINT); + + // calculate quantization parameter of a + const float* a_data = a->Data(); + int64_t num_of_elements = a->Shape().Size(); + + float a_scale; + uint8_t a_zero_point; + GetQuantizationParameter(a_data, num_of_elements, a_scale, a_zero_point, ctx->GetOperatorThreadPool()); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); + uint8_t* a_data_quant = static_cast(allocator->Alloc(SafeInt(num_of_elements) * sizeof(uint8_t))); + BufferUniquePtr a_buffer_quant_holder(a_data_quant, BufferDeleter(std::move(allocator))); + + ParQuantizeLinearStd(a_data, a_data_quant, narrow(num_of_elements), a_scale, a_zero_point, ctx->GetOperatorThreadPool()); + + bool is_b_scale_supported = IsBQuantParamSupported(b_scale_tensor->Shape(), b ? b->Shape() : b_shape_); + const bool is_a_signed = false; + ORT_RETURN_IF_ERROR(ComputeCommon( + ctx, + a_data_quant, + a->Shape(), + a_scale, + a_zero_point, + is_a_signed, + b, + is_b_scale_supported ? 
b_scale_tensor : nullptr, + b_zp_constant_tensor, + ctx->Input(IN_BIAS))); + + if (!is_b_scale_supported) { + ScaleOutput(*b_scale_tensor, *ctx->Output(0)); + } } + // Guard against KleidiAI functions being called in non kleidi builds + // TODO: migrate to a suitable override function call for kleidi dynamic qgemm function calls +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + else { + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(ctx->Input(IN_A)->Shape(), + b_shape_, // ctx->Input(IN_B)->Shape(), this is not available now constant data is + // deleted during session init post prepacking + nullptr, + nullptr)); + + Tensor* y = ctx->Output(OUT_Y, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) + return Status::OK(); + + auto a_data = static_cast(ctx->Input(IN_A)->DataRaw()); + auto* y_data = y->MutableData(); + + // batch gemm + MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS gemm_shape; + gemm_shape.M = static_cast(helper.M()); + gemm_shape.N = static_cast(helper.N()); + gemm_shape.K = static_cast(helper.K()); + + const size_t num_gemms = helper.OutputOffsets().size(); + std::vector gemm_data_vec(num_gemms); + + for (size_t gemm_idx = 0; gemm_idx < num_gemms; gemm_idx++) { + auto& params = gemm_data_vec[gemm_idx]; + params.A = reinterpret_cast(a_data + helper.LeftOffsets()[gemm_idx]); + params.lda = gemm_shape.K; + params.PackedB = packed_b_.get(); + params.C = y_data + helper.OutputOffsets()[gemm_idx]; + params.ldc = gemm_shape.N; + } + MlasDynamicQGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool()); + // This evaluates to true if bias data was not provided as constant data for prepacking stage + if (!dynamic_quant_mlas_bias_data_was_packed_) { + if (ctx->Input(IN_BIAS) != nullptr) { + const auto biases = std::vector(&ctx->Input(IN_BIAS)->Data()[0], + &ctx->Input(IN_BIAS)->Data()[gemm_shape.N]); + + // deferred adding of bias + for (size_t gemm_idx = 0; gemm_idx < 
num_gemms; gemm_idx++) { + float* MxN = y_data + helper.OutputOffsets()[gemm_idx]; + for (auto l = gemm_shape.M; l > 0; --l) { + MlasEltwiseAdd(MxN, biases.data(), MxN, gemm_shape.N); + MxN += gemm_shape.N; + } + } + } + } + } +#endif return Status::OK(); } @@ -275,7 +457,7 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const { a_zero_point = *(static_cast(a_zero_point_tensor->DataRaw())); } - const Tensor* b_zp_tensor = ctx->Input(IN_B_ZERO_POINT); + const Tensor* b_zp_constant_tensor = ctx->Input(IN_B_ZERO_POINT); ORT_RETURN_IF_ERROR(ComputeCommon( ctx, static_cast(a->DataRaw()), @@ -285,7 +467,7 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const { a->IsDataType(), b, is_b_scale_supported ? b_scale_tensor : nullptr, - b_zp_tensor, + b_zp_constant_tensor, ctx->Input(IN_BIAS))); if (!is_a_scale_scalar) { diff --git a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cc b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cc index ec5deccf655ff..ba786931bb39a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cc @@ -391,8 +391,7 @@ void run( // Allocate workspace. 
auto bytes = mha_graph->get_workspace_size(); - IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr( - allocator, bytes, false, stream, WaitCudaNotificationOnDevice); + IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr(allocator, bytes, false, stream); CUDNN_FE_CALL_THROW(mha_graph->execute(handle, variant_pack, buffer.get())); } diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 1a4a63de38790..e8cdc50ed4ca7 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -78,8 +78,11 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size"); - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, - normalize_routing_weights_, use_sparse_mixer_); + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + activation_type_, + fc3_experts_weights_optional != nullptr, + normalize_routing_weights_, + use_sparse_mixer_); size_t ws_size = moe_runner.getWorkspaceSize( static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h index 36127054cfd5e..d5ad8161e100e 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h @@ -52,6 +52,7 @@ enum class ActivationType { Gelu, GeGLU, ReGLU, SiGLU, + SwiGLU, Identity, InvalidType }; diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index ef1f97b9e57a2..8b8f45e77ab9d 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ 
b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -391,12 +391,10 @@ void MoeGemmRunner::dispatch_to_arch(const T* A, con dispatch_moe_gemm_to_cutlass( A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, stream, occupancy); - } else if (sm_ >= 80 && sm_ < 90) { + } else if (sm_ >= 80) { // Hopper and Blackwell will fallback to use Ampere kernels. dispatch_moe_gemm_to_cutlass( A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, stream, occupancy); - } else { - ORT_THROW("[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); } } @@ -478,6 +476,7 @@ void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightTyp int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, cudaStream_t stream) { + // Swiglu will use Identity to call this function so we not need to handle it here. 
switch (activation_type) { case ActivationType::Relu: run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index bfbe1d81b1c15..4268b79e1e4f8 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -44,6 +44,72 @@ namespace ort_fastertransformer { static constexpr int WARP_SIZE = 32; + +// SwiGLU with interleaved is like the following python code using PyTorch: +// dim = x.shape[-1] +// x = x.view(-1, dim // 2, 2) +// x_glu, x_linear = x[..., 0], x[..., 1] +// y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) +template +__global__ void swiglu_kernel_interleaved(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha) { + int const row = blockIdx.x; + if (row >= num_rows) { + return; + } + + T const* row_input = input + row * 2 * intermediate_size; + T* row_output = output + row * intermediate_size; + + for (int i = threadIdx.x; i < intermediate_size; i += blockDim.x) { + T x_glu = row_input[2 * i]; + T x_linear = row_input[2 * i + 1]; + + float sigmoid_arg = swiglu_alpha * static_cast(x_glu); + float sigmoid_out = 1.f / (1.f + expf(-sigmoid_arg)); + + float swish_out = static_cast(x_glu) * sigmoid_out; + row_output[i] = static_cast(swish_out * (static_cast(x_linear) + 1.f)); + } +} + +// Non interleaved version of SwiGLU kernel, which splits each row into two chunks of same size. 
+template +__global__ void swiglu_kernel_chunked(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha) { + int const row = blockIdx.x; + if (row >= num_rows) { + return; + } + + T const* row_input = input + row * 2 * intermediate_size; + T* row_output = output + row * intermediate_size; + + for (int i = threadIdx.x; i < intermediate_size; i += blockDim.x) { + T x_glu = row_input[i]; + T x_linear = row_input[i + intermediate_size]; + + float sigmoid_arg = swiglu_alpha * static_cast(x_glu); + float sigmoid_out = 1.f / (1.f + expf(-sigmoid_arg)); + + float swish_out = static_cast(x_glu) * sigmoid_out; + row_output[i] = static_cast(swish_out * (static_cast(x_linear) + 1.f)); + } +} + +template +void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha, cudaStream_t stream) { + if (num_rows == 0) { + return; + } + dim3 block(std::min(intermediate_size, 1024)); + dim3 grid(num_rows); + + if constexpr (interleaved) { + swiglu_kernel_interleaved<<>>(output, input, intermediate_size, num_rows, swiglu_alpha); + } else { + swiglu_kernel_chunked<<>>(output, input, intermediate_size, num_rows, swiglu_alpha); + } +} + // ====================== Softmax things =============================== // We have our own implementation of softmax here so we can support transposing the output // in the softmax kernel when we extend this module to support expert-choice routing. 
@@ -666,9 +732,14 @@ __global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, i } template -CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, bool has_fc3, +CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer) - : has_fc3_(has_fc3), total_past_rows_(0), total_covered_rows_(0), normalize_routing_weights_(normalize_routing_weights), use_sparse_mixer_(use_sparse_mixer) { + : activation_type_(activation_type), + has_fc3_(has_fc3), + total_past_rows_(0), + total_covered_rows_(0), + normalize_routing_weights_(normalize_routing_weights), + use_sparse_mixer_(use_sparse_mixer) { moe_gemm_runner_.initialize(sm_version); } @@ -695,8 +766,16 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_ro total_ws_bytes += buf_size * sizeof(T); // permuted_data total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ total_ws_bytes += num_softmax_outs * sizeof(T); - const size_t bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); - const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows)); + + size_t bytes_for_fc1_result; + if (activation_type_ == ActivationType::SwiGLU) { + // Space for both fc1_result_ and act_result_. + bytes_for_fc1_result = (2 * interbuf_size + interbuf_size) * sizeof(T); + } else { + bytes_for_fc1_result = has_fc3_ ? 
2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); + } + + const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows)); sorter_.update_num_experts(static_cast(num_experts)); size_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result; @@ -705,7 +784,7 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_ro bytes_for_intermediate_and_sorting += remaining_bytes; } - total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace + total_ws_bytes += bytes_for_intermediate_and_sorting; return total_ws_bytes; } @@ -725,16 +804,34 @@ void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, total_rows_before_expert_ = reinterpret_cast(permuted_data_ + buf_size); + char* current_ptr = reinterpret_cast(total_rows_before_expert_ + padded_experts); + + if (activation_type_ == ActivationType::SwiGLU) { + // fc1_result_ is used for GEMM1 output (2 * inter_size) + fc1_result_ = reinterpret_cast(current_ptr); + current_ptr += 2 * interbuf_size * sizeof(T); + + // act_result_ is used for SwiGLU output (inter_size) + act_result_ = reinterpret_cast(current_ptr); + current_ptr += interbuf_size * sizeof(T); + + ORT_ENFORCE(!has_fc3_, "SwiGLU activation is not supported with fc3"); + } else { + fc1_result_ = reinterpret_cast(current_ptr); + act_result_ = nullptr; // No extra buffer for activation since it is done inplace. 
+ current_ptr += interbuf_size * sizeof(T); + } + if (has_fc3_) { - fc3_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); - fc1_result_ = reinterpret_cast(fc3_result_ + interbuf_size); + fc3_result_ = reinterpret_cast(current_ptr); + current_ptr += interbuf_size * sizeof(T); } else { - fc1_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); + fc3_result_ = nullptr; } const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); if (!is_pow_2 || num_experts > 256) { - softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); + softmax_out_ = reinterpret_cast(current_ptr); } else { softmax_out_ = nullptr; } @@ -880,8 +977,51 @@ void CutlassMoeFCRunner::run_moe_fc( stream); } - // moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, - // expanded_active_expert_rows); + if (fc1_activation_type == ActivationType::SwiGLU) { + T* gemm1_output_buffer = fc1_result_; + T* swiglu_output_buffer = act_result_; + + moe_gemm_runner_.moe_gemm_bias_act( + permuted_data_ + total_past_rows_ * hidden_size, + fc1_expert_weights, + fc1_scales, + fc1_expert_biases, + gemm1_output_buffer + total_past_rows_ * 2 * inter_size, + total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, + 2 * inter_size, + hidden_size, + local_num_experts, + ActivationType::Identity, + stream); + + constexpr bool swiglu_interleaved = true; + constexpr float swiglu_alpha = 1.702f; + invokeSwiGLU( + swiglu_output_buffer + total_past_rows_ * inter_size, + gemm1_output_buffer + total_past_rows_ * 2 * inter_size, + inter_size, + static_cast(total_covered_rows_), + swiglu_alpha, + stream); + + moe_gemm_runner_.moe_gemm( + swiglu_output_buffer + total_past_rows_ * inter_size, + fc2_expert_weights, + fc2_scales, + nullptr, + fc2_result + total_past_rows_ * hidden_size, + total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, + hidden_size, + inter_size, + 
local_num_experts, + stream); + + // No fc3 for SwiGLU + return; + } + moe_gemm_runner_.moe_gemm_bias_act( permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_result_ + total_past_rows_ * inter_size, total_rows_before_expert_ + local_experts_start_index, @@ -1178,4 +1318,7 @@ template void finalize_moe_routing_kernelLauncher(const float*, float*, const fl template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, const half*, const int*, const int*, int, int, int, cudaStream_t); +template void invokeSwiGLU(float*, float const*, int, int, float, cudaStream_t); +template void invokeSwiGLU(half*, half const*, int, int, float, cudaStream_t); + } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index c457b608decbf..3ac4862e101c3 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -54,7 +54,10 @@ static inline size_t pad_to_multiple_of_16(size_t input) { template void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_out, int* indices, int* source_row, int num_rows, int num_experts, int k, - cudaStream_t stream); + bool normalize_routing_weights, bool use_sparse_mixer, cudaStream_t stream); + +template +void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha, cudaStream_t stream); class CubKeyValueSorter { public: @@ -109,7 +112,7 @@ template class CutlassMoeFCRunner { public: - CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); + CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, 
size_t num_experts, size_t k); @@ -157,8 +160,10 @@ class CutlassMoeFCRunner { int64_t* total_rows_before_expert_; T* fc1_result_; + T* act_result_; T* fc3_result_; + ActivationType activation_type_; bool has_fc3_; bool normalize_routing_weights_; bool use_sparse_mixer_; @@ -176,7 +181,7 @@ class CutlassMoeFCRunner { template class CutlassMoeFCRunner::value>> { public: - CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); + CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) { return 0; diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index c5352d931ce2c..cc6fe871a3bc1 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -48,8 +48,11 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { auto& device_prop = GetDeviceProp(); const int sm = device_prop.major * 10 + device_prop.minor; - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, - normalize_routing_weights_, use_sparse_mixer_); + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + activation_type_, + fc3_experts_weights_optional != nullptr, + normalize_routing_weights_, + use_sparse_mixer_); size_t ws_size = moe_runner.getWorkspaceSize( static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index 6b65557444a66..194f33acbeb59 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -76,15 +76,16 @@ class MoEBase { } const int64_t coe = quant_type == MoEQuantType::UINT4 ? 
2 : 1; - if (fc1_experts_weights_dims[2] != inter_size / coe) { + const int64_t act = activation_type_ == ort_fastertransformer::ActivationType::SwiGLU ? 2 : 1; + if (fc1_experts_weights_dims[2] != act * inter_size / coe) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_weights_dims[2] must be equal to inter_size, got ", - fc1_experts_weights_dims[2], " and ", inter_size); + "fc1_experts_weights_dims[2] is ", + fc1_experts_weights_dims[2], " expected ", act * inter_size / coe); } if (fc2_experts_weights_dims[2] != hidden_size / coe) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_weights_dims[2] must be equal to hidden_size, got ", - fc2_experts_weights_dims[2], " and ", hidden_size); + "fc2_experts_weights_dims[2] is ", + fc2_experts_weights_dims[2], " expected ", hidden_size / coe); } if (router_probs_dims.size() != 2) { @@ -116,10 +117,10 @@ class MoEBase { "fc2_experts_bias_dims[0] must be equal to num_experts, got ", fc2_experts_bias_dims[0], " and ", num_experts); } - if (fc1_experts_bias_dims[1] != inter_size) { + if (fc1_experts_bias_dims[1] != act * inter_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_bias_dims[1] must be equal to inter_size, got ", fc1_experts_bias_dims[1], - " and ", inter_size); + "fc1_experts_bias_dims[1] is ", fc1_experts_bias_dims[1], + ", expected ", act * inter_size); } if (fc2_experts_bias_dims[1] != hidden_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -182,10 +183,14 @@ class MoEBase { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[0] must be equal to num_experts, got ", fc1_experts_scales_dims[0], " and ", num_experts); } - if (fc1_experts_scales_dims[1] != inter_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] must be equal to inter_size, got ", - fc1_experts_scales_dims[1], " and ", inter_size); + + // The activation type affects the output dimension of the first FC 
layer. + const int64_t act = activation_type_ == ort_fastertransformer::ActivationType::SwiGLU ? 2 : 1; + if (fc1_experts_scales_dims[1] != act * inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] must be equal to act * inter_size, got ", + fc1_experts_scales_dims[1], " and ", act * inter_size); } + if (fc2_experts_scales_dims.size() != 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales must be 2D, got ", fc2_experts_scales->Shape().GetDims().size()); @@ -219,6 +224,8 @@ class MoEBase { activation_type_ = ort_fastertransformer::ActivationType::Gelu; } else if (activation_type_str == "silu") { activation_type_ = ort_fastertransformer::ActivationType::Silu; + } else if (activation_type_str == "swiglu") { + activation_type_ = ort_fastertransformer::ActivationType::SwiGLU; } else if (activation_type_str == "identity") { activation_type_ = ort_fastertransformer::ActivationType::Identity; } else { diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index 4dd5a079d1a29..db6d99674cf5a 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -72,6 +72,7 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, using CudaT = typename ToCudaType::MappedType; ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + activation_type_, fc3_experts_weights_optional != nullptr, normalize_routing_weights_, use_sparse_mixer_); @@ -185,4 +186,4 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { } // namespace cuda } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index f7ed758aedbb2..d20d0b4218bd3 100644 --- 
a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -232,7 +232,7 @@ Status AddToFeeds(Stream* ort_stream, } } if (!buffer) { - buffer = IAllocator::MakeUniquePtr(device_allocator, total_bytes, false, ort_stream, WaitCudaNotificationOnDevice); + buffer = IAllocator::MakeUniquePtr(device_allocator, total_bytes, false, ort_stream); } char* gpu_data = buffer.get(); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(gpu_data, pinned_data, total_bytes, cudaMemcpyHostToDevice, stream)); diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index c4667d53c0674..dccfdbda8971b 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -190,6 +190,7 @@ void CPUIDInfo::ArmLinuxInit() { has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); + has_arm_sme_ = cpuinfo_has_arm_sme(); const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); @@ -342,6 +343,7 @@ void CPUIDInfo::ArmAppleInit() { has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); + has_arm_sme_ = cpuinfo_has_arm_sme(); // Note: We leave is_armv8_narrow_ld_ unset because it only applies to a limited set of uarchs that we don't expect // to encounter on Apple platforms. 
diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 9c67ebbffa260..84571fa12e6ea 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -40,6 +40,7 @@ class CPUIDInfo { bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } + bool HasArm_SME() const { return has_arm_sme_; } uint32_t GetCurrentCoreIdx() const; @@ -127,6 +128,7 @@ class CPUIDInfo { bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; bool has_arm_neon_bf16_{false}; + bool has_arm_sme_{false}; std::string vendor_; uint32_t vendor_id_; diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index f089761e0643b..e1b9d1294fb9e 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -164,19 +164,16 @@ void CPUAllocator::Free(void* p) { AllocatorDefaultFreeAligned(p, alignment); } -void* AllocateBufferWithOptions(IAllocator& alloc, size_t size, bool use_reserve, Stream* stream, WaitNotificationFn wait_fn) { - if (use_reserve) +void* AllocateBufferWithOptions(IAllocator& alloc, size_t size, bool use_reserve, Stream* stream, + WaitNotificationFn /*ignored*/) { + if (use_reserve) { return alloc.Reserve(size); - if (stream && alloc.Info().alloc_type == OrtArenaAllocator) { -#ifdef ORT_ENABLE_STREAM - auto* stream_aware_alloc = StreamAwareArena::FromBFCArena(static_cast(alloc)); - if (stream_aware_alloc) { - return stream_aware_alloc->AllocOnStream(size, stream, wait_fn); - } -#else - ORT_UNUSED_PARAMETER(wait_fn); -#endif // ORT_ENABLE_STREAM } + + if (stream && alloc.IsStreamAware()) { + return alloc.AllocOnStream(size, stream); + } + return alloc.Alloc(size); } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/allocator_utils.cc b/onnxruntime/core/framework/allocator_utils.cc index 
edf965d3835b5..8c4e74c4b1cc7 100644 --- a/onnxruntime/core/framework/allocator_utils.cc +++ b/onnxruntime/core/framework/allocator_utils.cc @@ -54,7 +54,6 @@ AllocatorPtr CreateAllocator(const AllocatorCreationInfo& info) { return AllocatorPtr( std::make_unique(std::move(device_allocator), max_mem, - info.enable_cross_stream_reusing, arena_extend_str, initial_chunk_size_bytes, max_dead_bytes_per_chunk, diff --git a/onnxruntime/core/framework/allocator_utils.h b/onnxruntime/core/framework/allocator_utils.h index bef0b7057a7f8..076d4dbcc16c5 100644 --- a/onnxruntime/core/framework/allocator_utils.h +++ b/onnxruntime/core/framework/allocator_utils.h @@ -19,14 +19,12 @@ struct AllocatorCreationInfo { OrtDevice::DeviceId device_id = 0, bool use_arena = true, OrtArenaCfg arena_cfg = {0, -1, -1, -1, -1, -1L}, - bool stream_aware_arena = false, - bool cross_stream_reusing = false) + bool stream_aware_arena = false) : device_alloc_factory(device_alloc_factory), device_id(device_id), use_arena(use_arena), arena_cfg(arena_cfg), - use_stream_aware_arena(stream_aware_arena), - enable_cross_stream_reusing(cross_stream_reusing) { + use_stream_aware_arena(stream_aware_arena) { } AllocatorFactory device_alloc_factory; @@ -34,7 +32,6 @@ struct AllocatorCreationInfo { bool use_arena; OrtArenaCfg arena_cfg; bool use_stream_aware_arena; - bool enable_cross_stream_reusing; }; // Returns an allocator (an instance of IAllocator) based on the creation info provided. 
diff --git a/onnxruntime/core/framework/bfc_arena.cc b/onnxruntime/core/framework/bfc_arena.cc index ed64769d13fcc..e0b50cd04173e 100644 --- a/onnxruntime/core/framework/bfc_arena.cc +++ b/onnxruntime/core/framework/bfc_arena.cc @@ -224,6 +224,7 @@ Status BFCArena::Extend(size_t rounded_bytes) { c->next = kInvalidChunkHandle; // assign the new created chunk to default stream, so it can be pick up by any stream c->stream = nullptr; + c->stream_sync_id = 0; region_manager_.set_handle(c->ptr, h); @@ -253,7 +254,7 @@ void BFCArena::DeallocateChunk(ChunkHandle h) { Chunk* c = ChunkFromHandle(h); // clean the stream / timestamp when deallocate chunk c->stream = nullptr; - c->stream_timestamp = 0; + c->stream_sync_id = 0; c->next = free_chunks_list_; free_chunks_list_ = h; } @@ -268,7 +269,7 @@ size_t BFCArena::RoundedBytes(size_t bytes) { } void* BFCArena::Alloc(size_t size) { - return AllocateRawInternal(size, false, nullptr, false, nullptr); + return AllocateRawInternal(size, false, nullptr); } void* BFCArena::Reserve(size_t size) { @@ -309,13 +310,11 @@ size_t BFCArena::AllocatedSize(const void* ptr) { void* BFCArena::AllocateRawInternal(size_t num_bytes, bool dump_log_on_failure, - Stream* stream, - bool enable_cross_stream_reusing, - WaitNotificationFn wait_fn) { + Stream* stream) { if (num_bytes == 0) { - LOGS_DEFAULT(VERBOSE) << "tried to allocate 0 bytes"; return nullptr; } + // First, always allocate memory of at least kMinAllocationSize // bytes, and always allocate multiples of kMinAllocationSize bytes // so all memory addresses are nicely byte aligned. 
@@ -326,20 +325,9 @@ void* BFCArena::AllocateRawInternal(size_t num_bytes, std::lock_guard lock(lock_); // search for a valid chunk - auto* chunk = FindChunkPtr(bin_num, - rounded_bytes, - num_bytes, - stream, - enable_cross_stream_reusing, - wait_fn); + auto* chunk = FindChunkPtr(bin_num, rounded_bytes, num_bytes, stream); if (chunk != nullptr) { - // if it is on default stream (the new allocate chunk), assign to current stream - if (chunk->stream == nullptr) { - chunk->stream = stream; - if (stream) - chunk->stream_timestamp = stream->GetCurrentTimestamp(); - } return chunk->ptr; } @@ -349,12 +337,8 @@ void* BFCArena::AllocateRawInternal(size_t num_bytes, // Try to extend auto status = Extend(rounded_bytes); if (status.IsOK()) { - chunk = FindChunkPtr(bin_num, rounded_bytes, num_bytes, stream, false); + chunk = FindChunkPtr(bin_num, rounded_bytes, num_bytes, stream); if (chunk != nullptr) { - // if it is on default stream (the new allocate chunk), assign to current stream - if (chunk->stream == nullptr && stream) { - chunk->stream = stream; - } return chunk->ptr; } else { status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, @@ -413,11 +397,8 @@ BFCArena::Chunk* BFCArena::SplitFreeChunkFromBin(BFCArena::Bin::FreeChunkSet* fr } BFCArena::Chunk* BFCArena::FindChunkPtr(BinNum bin_num, size_t rounded_bytes, - size_t num_bytes, Stream* stream, - bool allow_chunk_from_different_stream, - WaitNotificationFn wait_fn) { - BFCArena::Chunk* other_stream_candidate = nullptr; - // First identify the first bin that could satisfy rounded_bytes. + size_t num_bytes, Stream* stream) { + // First identify the first bin that could satisfy rounded_bytes. for (; bin_num < kNumBins; bin_num++) { // Start searching from the first bin for the smallest chunk that fits // rounded_bytes. 
@@ -427,29 +408,27 @@ BFCArena::Chunk* BFCArena::FindChunkPtr(BinNum bin_num, size_t rounded_bytes, BFCArena::Chunk* chunk = ChunkFromHandle(h); ORT_ENFORCE(!chunk->in_use()); if (chunk->size >= rounded_bytes) { - // We found an existing chunk that fits us that wasn't in use, now check the stream + // We found an existing chunk that fits us that wasn't in use, now check the stream. bool safe_to_use = chunk->stream == stream || !chunk->stream || (stream && chunk->stream && - chunk->stream_timestamp < stream->GetLastSyncTimestampWithTargetStream(chunk->stream)); + chunk->stream_sync_id < stream->GetSyncIdForLastWaitOnStream(*chunk->stream)); if (safe_to_use) { // the chunk with same stream has higher priority. - return SplitFreeChunkFromBin(&b->free_chunks, citer, rounded_bytes, num_bytes); - } else if (allow_chunk_from_different_stream && !other_stream_candidate) { - other_stream_candidate = chunk; + chunk = SplitFreeChunkFromBin(&b->free_chunks, citer, rounded_bytes, num_bytes); + + if (stream) { + chunk->stream = stream; + chunk->stream_sync_id = stream->GetSyncId(); + } + + return chunk; } } } } - // if trying to use an unsafe chunk from other streams, secure it. 
- if (other_stream_candidate) { - SecureTheChunk(other_stream_candidate->stream, stream, wait_fn); - // if find some available chunk, make sure mark it as "being used" before return - other_stream_candidate->allocation_id = next_allocation_id_++; - other_stream_candidate->bin_num = kInvalidBinNum; - } - return other_stream_candidate; + return nullptr; } void BFCArena::SplitChunk(BFCArena::ChunkHandle h, size_t num_bytes) { @@ -463,7 +442,7 @@ void BFCArena::SplitChunk(BFCArena::ChunkHandle h, size_t num_bytes) { BFCArena::Chunk* new_chunk = ChunkFromHandle(h_new_chunk); // set the new chunk's stream and timestamp new_chunk->stream = c->stream; - new_chunk->stream_timestamp = c->stream_timestamp; + new_chunk->stream_sync_id = c->stream_sync_id; new_chunk->ptr = static_cast(static_cast(c->ptr) + num_bytes); region_manager_.set_handle(new_chunk->ptr, h_new_chunk); @@ -608,7 +587,7 @@ void BFCArena::Merge(BFCArena::ChunkHandle h1, // Set the new size c1->size += c2->size; - c1->stream_timestamp = std::max(c1->stream_timestamp, c2->stream_timestamp); + c1->stream_sync_id = std::max(c1->stream_sync_id, c2->stream_sync_id); DeleteChunk(h2); } @@ -815,7 +794,7 @@ void BFCArena::ResetChunkOnTargetStream(Stream* target_stream, bool coalesce_fla Chunk* c = ChunkFromHandle(h); if (c->stream == target_stream) { c->stream = nullptr; - c->stream_timestamp = 0; + c->stream_sync_id = 0; } h = c->next; } @@ -850,24 +829,23 @@ void BFCArena::ResetChunkOnTargetStream(Stream* target_stream, bool coalesce_fla StreamAwareArena::StreamAwareArena(std::unique_ptr resource_allocator, size_t total_memory, - bool enable_cross_stream_sharing, ArenaExtendStrategy arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk, int initial_growth_chunk_size_bytes, - int64_t max_power_of_two_extend_bytes) : BFCArena(std::move(resource_allocator), - total_memory, - arena_extend_strategy, - initial_chunk_size_bytes, - max_dead_bytes_per_chunk, - initial_growth_chunk_size_bytes, - 
max_power_of_two_extend_bytes), - enable_cross_stream_reusing_(enable_cross_stream_sharing) { + int64_t max_power_of_two_extend_bytes) + : BFCArena(std::move(resource_allocator), + total_memory, + arena_extend_strategy, + initial_chunk_size_bytes, + max_dead_bytes_per_chunk, + initial_growth_chunk_size_bytes, + max_power_of_two_extend_bytes) { arena_type_ = ArenaType::StreamAwareArena; } -void* StreamAwareArena::AllocOnStream(size_t size, Stream* current_stream, WaitNotificationFn wait_fn) { - return AllocateRawInternal(size, false, current_stream, enable_cross_stream_reusing_, wait_fn); +void* StreamAwareArena::AllocOnStream(size_t size, Stream* current_stream) { + return AllocateRawInternal(size, false, current_stream); } void StreamAwareArena::ReleaseStreamBuffers(Stream* stream) { @@ -875,17 +853,5 @@ void StreamAwareArena::ReleaseStreamBuffers(Stream* stream) { ResetChunkOnTargetStream(stream, true); } -void StreamAwareArena::SecureTheChunk(Stream* chunk_stream, Stream* target_stream, WaitNotificationFn wait_fn) const { - if (chunk_stream && target_stream && chunk_stream != target_stream) { - auto notification = chunk_stream->CreateNotification(1); - notification->ActivateAndUpdate(); - if (wait_fn) { - wait_fn(target_stream, *notification); - } - - target_stream->UpdateStreamClock(notification->GetStreamSyncTable()); - // it should be ok to release the notification now, as the wait is already launch to stream. - } -} #endif } // namespace onnxruntime diff --git a/onnxruntime/core/framework/bfc_arena.h b/onnxruntime/core/framework/bfc_arena.h index 8081738f2a5dc..f3c0544124241 100644 --- a/onnxruntime/core/framework/bfc_arena.h +++ b/onnxruntime/core/framework/bfc_arena.h @@ -27,7 +27,6 @@ limitations under the License. 
#include "core/common/logging/severity.h" #include "core/common/safeint.h" -#include #include "core/framework/arena_extend_strategy.h" #include "core/framework/allocator.h" @@ -103,18 +102,13 @@ class BFCArena : public IAllocator { ArenaType GetArenaType() const { return arena_type_; } - virtual void SecureTheChunk(Stream* /*chunk_stream*/, - Stream* /*target_stream*/, - WaitNotificationFn /*wait_fn*/) const {} - protected: void* AllocateRawInternal(size_t num_bytes, bool dump_log_on_failure, - Stream* stream, - bool enable_cross_stream_reusing, - WaitNotificationFn wait_fn); + Stream* stream); + #ifdef ORT_ENABLE_STREAM - // for any chunk that associated with target stream, reset it to default (nullptr in stream, timestamp 0) + // for any chunk that associated with target stream, reset it to default (nullptr in stream, sync id 0) // perform coalesce if coalesce_flag is true void ResetChunkOnTargetStream(Stream* target_stream, bool coalesce_flag); #endif @@ -168,7 +162,7 @@ class BFCArena : public IAllocator { Stream* stream = nullptr; - uint64_t stream_timestamp = 0; + uint64_t stream_sync_id = 0; bool in_use() const { return allocation_id != -1; } @@ -374,9 +368,7 @@ class BFCArena : public IAllocator { BFCArena::Chunk* FindChunkPtr(BinNum bin_num, size_t rounded_bytes, size_t num_bytes, - Stream* stream, - bool allow_chunk_from_different_stream, - WaitNotificationFn wait_fn = nullptr); + Stream* stream); // Splits the chunk specified by 'h' into two chunks, one at least // of size 'num_bytes'. 
@@ -516,33 +508,28 @@ class BFCArena : public IAllocator { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(BFCArena); }; + #ifdef ORT_ENABLE_STREAM class StreamAwareArena : public BFCArena { public: StreamAwareArena(std::unique_ptr resource_allocator, size_t total_memory, - bool enable_dynamic_cross_stream_sharing, ArenaExtendStrategy arena_extend_strategy = DEFAULT_ARENA_EXTEND_STRATEGY, int initial_chunk_size_bytes = DEFAULT_INITIAL_CHUNK_SIZE_BYTES, int max_dead_bytes_per_chunk = DEFAULT_MAX_DEAD_BYTES_PER_CHUNK, int initial_growth_chunk_size_bytes = DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES, int64_t max_power_of_two_extend_bytes = DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES); - // If size is 0, then this function returns either NULL, - // or a unique pointer value that can later be successfully - // passed to free(). Whatever, do not dereference that pointer - void* AllocOnStream(size_t size, Stream* current_stream_id, WaitNotificationFn wait_fn); + bool IsStreamAware() const override { return true; } + + // Standard alloc behavior. Returns valid pointer if size > 0 and memory was available. Otherwise returns nullptr. + void* AllocOnStream(size_t size, Stream* current_stream_id) override; void ReleaseStreamBuffers(Stream* stream); static StreamAwareArena* FromBFCArena(BFCArena& arena) { return arena.GetArenaType() == ArenaType::StreamAwareArena ? 
reinterpret_cast(&arena) : nullptr; } - - virtual void SecureTheChunk(Stream* chunk_stream, Stream* target_stream, WaitNotificationFn wait_fn) const override; - - private: - bool enable_cross_stream_reusing_; }; #endif #ifdef __GNUC__ diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index cfaee309527a6..8030690e7c92d 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -26,17 +26,6 @@ using namespace onnxruntime::common; namespace onnxruntime { -#ifdef ORT_ENABLE_STREAM -static StreamAwareArena* AsStreamBasedAllocator(AllocatorPtr allocator) { - ORT_ENFORCE(allocator.get() != nullptr, "allocator is nullptr"); - if (allocator->Info().alloc_type == OrtArenaAllocator) { - BFCArena* arena_ptr = static_cast(allocator.get()); - return StreamAwareArena::FromBFCArena(*arena_ptr); - } - return nullptr; -} -#endif - IExecutionFrame::IExecutionFrame(const OrtValueNameIdxMap& ort_value_idx_map, const NodeIndexInfo& node_index_info, gsl::span fetch_mlvalue_idxs) @@ -441,13 +430,23 @@ ExecutionFrame::ExecutionFrame(gsl::span feed_mlvalue_idxs, gsl::span #endif // the memory pattern buffer will leave in the whole execution. #ifdef ORT_ENABLE_STREAM - StreamAwareArena* stream_aware_alloc = AsStreamBasedAllocator(alloc); - if (stream_aware_alloc && device_streams_) { + if (alloc->IsStreamAware() && device_streams_) { Stream* mem_pattern_stream = device_streams_->GetRootStream(); - buffer = stream_aware_alloc->AllocOnStream(peak_size, mem_pattern_stream, nullptr); - for (size_t j = 0; j < device_streams_->NumStreams(); j++) { - stream_aware_alloc->SecureTheChunk(mem_pattern_stream, device_streams_->GetStream(j), nullptr); - } + + buffer = alloc->AllocOnStream(peak_size, mem_pattern_stream); + + // this seems unnecessary. 
any memory pattern buffer would be in use for the entire inference, so + // there's no point at which another stream (as streams are per-inference) would be able to use it. + // given that, it's unclear why we need to update the sync id in all other streams to allow them + // to take this buffer if it was free. + // + // device_stream_collection calls ReleaseStreamBuffers for all streams including the root stream in + // CleanUp, so the chunk will become available to other streams at that point. + // + // Commenting out to verify. + // for (size_t j = 0; j < device_streams_->NumStreams(); j++) { + // stream_aware_arena->WaitOnChunk(mem_pattern_stream, device_streams_->GetStream(j)); + //} } else { buffer = alloc->Alloc(peak_size); } @@ -581,13 +580,9 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va Stream* current_stream = GetValueStream(ort_value_index); if (current_stream) { #ifdef ORT_ENABLE_STREAM - auto stream_aware_alloc = AsStreamBasedAllocator(alloc); - if (stream_aware_alloc) { + if (alloc->IsStreamAware()) { size_t buffer_size = Tensor::CalculateTensorStorageSize(element_type, shape); - // the reused memory must from same EP - auto wait_handle = this->session_state_.GetStreamHandleRegistryInstance().GetWaitHandle( - current_stream->GetDevice(), current_stream->GetDevice()); - void* p_data = stream_aware_alloc->AllocOnStream(buffer_size, current_stream, wait_handle); + void* p_data = alloc->AllocOnStream(buffer_size, current_stream); Tensor::InitOrtValue(element_type, shape, p_data, std::move(alloc), ort_value); } else { Tensor::InitOrtValue(element_type, shape, std::move(alloc), ort_value); diff --git a/onnxruntime/core/framework/execution_steps.cc b/onnxruntime/core/framework/execution_steps.cc index 36f663699be4f..61e26416f2321 100644 --- a/onnxruntime/core/framework/execution_steps.cc +++ b/onnxruntime/core/framework/execution_steps.cc @@ -23,25 +23,25 @@ std::string BarrierStep::ToString() const { return 
MakeString("Barrier - BarrierId: ", barrier_id_, ", Count: ", 2); } -WaitOnEPStep::WaitOnEPStep(WaitNotificationFn handle, - NotificationIndex idx, NodeIndex node_index) : SequentialExecutionPlan::ExecutionStep(node_index), - wait_handle_(handle), - notification_idx_(idx) {} +WaitOnEPStep::WaitOnEPStep(WaitNotificationFn handle, NotificationIndex idx, NodeIndex node_index) + : SequentialExecutionPlan::ExecutionStep(node_index), + wait_fn_(handle), + notification_idx_(idx) { + ORT_ENFORCE(wait_fn_, "WaitNoficationFn must be provided."); +} Status WaitOnEPStep::Execute(StreamExecutionContext& ctx, size_t stream_idx, SessionScope& /*session_scope*/, const bool& /*terminate_flag*/, bool& continue_flag) { - ORT_ENFORCE(wait_handle_, "WaitOnEPStep.wait_handle is null"); - auto* stream = ctx.GetDeviceStream(stream_idx); auto& notification = *ctx.GetNotification(notification_idx_); - wait_handle_(stream, notification); + wait_fn_(stream, notification); // update the stream's clock status if (stream != nullptr) { - stream->UpdateStreamClock(notification.GetStreamSyncTable()); + stream->UpdateWithAwaitedNotification(notification); } LOGS(ctx.GetLogger(), VERBOSE) << "stream " << stream_idx << " wait on Notification with id: " << notification_idx_; diff --git a/onnxruntime/core/framework/execution_steps.h b/onnxruntime/core/framework/execution_steps.h index 545dabc56b272..b3b3ee6c3ce63 100644 --- a/onnxruntime/core/framework/execution_steps.h +++ b/onnxruntime/core/framework/execution_steps.h @@ -38,7 +38,7 @@ class WaitOnEPStep : public SequentialExecutionPlan::ExecutionStep { std::string ToString() const override; private: - WaitNotificationFn wait_handle_; + WaitNotificationFn wait_fn_; NotificationIndex notification_idx_; }; diff --git a/onnxruntime/core/framework/plugin_ep_stream.h b/onnxruntime/core/framework/plugin_ep_stream.h index 2b89e76e16b76..09938403ad9b5 100644 --- a/onnxruntime/core/framework/plugin_ep_stream.h +++ 
b/onnxruntime/core/framework/plugin_ep_stream.h @@ -87,8 +87,8 @@ class Stream : public onnxruntime::Stream { return ToStatusAndRelease(ort_status); } - WaitNotificationFn GetWaitNotificationFn() const override { - return Notification::WaitNotificationOnDevice; + const OrtSyncStreamImpl& GetImpl() const { + return impl_; } ~Stream() override { diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 2cd5103b823d1..98cc2158eb0d0 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -89,7 +89,8 @@ SessionState::SessionState(Graph& graph, profiling::Profiler& profiler, const SessionOptions& sess_options, PrepackedWeightsContainer* prepacked_weights_container, - AllocatorMap* parent_allocators) + AllocatorMap* parent_allocators, + AllocatorMap* parent_initializer_allocators) : graph_(graph), execution_providers_(execution_providers), logger_(logger), @@ -109,16 +110,26 @@ SessionState::SessionState(Graph& graph, sess_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL; if (parent_allocators) { allocators_ = parent_allocators; + initializer_allocators_ = parent_initializer_allocators; } else { allocators_unique_ptr_ = std::make_unique(); allocators_ = allocators_unique_ptr_.get(); + + initializer_allocators_unique_ptr_ = std::make_unique(); + initializer_allocators_ = initializer_allocators_unique_ptr_.get(); + // The allocator registration rule: // Each location (OrtDevice) will only have 1 allocator used for whole session. - // The EP which is registered first will have higher priority + // The EP which is registered first will have higher priority. 
+ // Allocators with a OrtAllocatorType of OrtReadOnlyAllocator go in the initializer allocators for (auto& ep : execution_providers_) { auto allocators = ep->CreatePreferredAllocators(); for (auto& alloc : allocators) { - allocators_->insert({alloc->Info().device, alloc}); // DON'T overwrite existing key + if (alloc->Info().alloc_type == OrtReadOnlyAllocator) { + initializer_allocators_->insert({alloc->Info().device, alloc}); // DON'T overwrite existing key + } else { + allocators_->insert({alloc->Info().device, alloc}); // DON'T overwrite existing key + } } } } @@ -130,13 +141,29 @@ AllocatorPtr SessionState::GetAllocator(const OrtMemoryInfo& location) const noe AllocatorPtr SessionState::GetAllocator(const OrtDevice& device) const noexcept { auto it = allocators_->find(device); - if (it != allocators_->end()) return it->second; + if (it != allocators_->end()) { + return it->second; + } + return nullptr; } +AllocatorPtr SessionState::GetInitializerAllocator(const OrtDevice& device) const noexcept { + auto it = initializer_allocators_->find(device); + if (it != initializer_allocators_->end()) { + return it->second; + } + + return GetAllocator(device); +} + void SessionState::UpdateAllocatorsWithEnvAllocators(const std::vector& env_allocators) { for (const auto& env_alloc : env_allocators) { - (*allocators_)[env_alloc->Info().device] = env_alloc; + if (env_alloc->Info().alloc_type == OrtReadOnlyAllocator) { + (*initializer_allocators_)[env_alloc->Info().device] = env_alloc; + } else { + (*allocators_)[env_alloc->Info().device] = env_alloc; + } } } @@ -1158,7 +1185,7 @@ Status SessionState::CreateSubgraphSessionState() { std::make_unique(*subgraph, execution_providers_, thread_pool_, inter_op_thread_pool_, data_transfer_mgr_, external_data_loader_mgr_, logger_, profiler_, sess_options_, - prepacked_weights_container_, allocators_); + prepacked_weights_container_, allocators_, initializer_allocators_); // Pass fused function manager to subgraph 
subgraph_session_state->fused_funcs_mgr_.SetFusedFuncs(fused_funcs_mgr_); diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index 71b88cb692f6f..e2102d95e1f17 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -98,7 +98,8 @@ class SessionState { profiling::Profiler& profiler, const SessionOptions& sess_options, PrepackedWeightsContainer* prepacked_weights_container = nullptr, - AllocatorMap* parent_allocators = nullptr); + AllocatorMap* parent_allocators = nullptr, + AllocatorMap* parent_initializer_allocators = nullptr); ~SessionState() { } @@ -127,6 +128,12 @@ class SessionState { /** Get the allocator for a given OrtDevice. The first allocator that matches will be returned. */ AllocatorPtr GetAllocator(const OrtDevice& device) const noexcept; + /** + Get an allocator for the given OrtDevice that is only used for read-only initializers. + Falls back to calling GetAllocator as needed. + */ + AllocatorPtr GetInitializerAllocator(const OrtDevice& device) const noexcept; + /* * Get allocators. */ @@ -464,17 +471,18 @@ class SessionState { } }; - // using std::map as OrtDevice would need a custom hash function to be used with std::unordered_map, - // and as this isn't considered performance critical currently it's not worth the maintenance overhead of adding one. - // We do get an allocator from ExecutionFrame so this is looked up frequently, however there most likely aren't many - // entries in the map // SessionState will contain other SessionState objects for subgraph. The unique ptr will be initialized only the // SessionState object is in the parent graph, the raw pointer will be initialized when session state is in parent // graph (from the unique ptr) or in the subgraph (from the raw pointer from parent session state). 
The raw pointer // will be used all the way to access std::map, unique pointer is only releasing the resource // when the parent session state is releasing. std::unique_ptr allocators_unique_ptr_; + // allocators with type of OrtAllocatorType::OrtReadOnlyAllocator that are used for initializers if found. + // if not we fallback to lookup in allocators_; + std::unique_ptr initializer_allocators_unique_ptr_; + AllocatorMap* allocators_; + AllocatorMap* initializer_allocators_; OrtValueNameIdxMap ort_value_name_idx_map_; diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 8f0713fcd7cb1..17e337838b091 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -37,8 +37,10 @@ namespace session_state_utils { // The following method will allocate memory directly using the device allocator. // It can handle arena-based allocators and non-arena based allocators. -static common::Status AllocateBufferUsingDeviceAllocatorFromShapeAndType(const TensorShape& tensor_shape, const DataTypeImpl* type, - const AllocatorPtr& alloc, /*out*/ void*& p_data) { +static common::Status AllocateBufferUsingDeviceAllocatorFromShapeAndType(const TensorShape& tensor_shape, + const DataTypeImpl* type, + const AllocatorPtr& alloc, + /*out*/ void*& p_data) { size_t mem_size = 0; ORT_RETURN_IF_ERROR(Tensor::CalculateTensorStorageSize(type, tensor_shape, /*alignment*/ 0, mem_size)); @@ -76,13 +78,14 @@ static common::Status AllocateBufferUsingDeviceAllocatorFromShapeAndType(const T * data loading, allocation, or copying operation fails. 
*/ static common::Status DeserializeTensorProto(const Env& env, const std::basic_string& proto_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer* memory_buffer, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + const MemBuffer* memory_buffer, const AllocatorPtr& alloc, const AllocatorPtr& default_cpu_alloc, OrtValue& ort_value, const DataTransferManager& data_transfer_mgr, const ExternalDataLoaderManager& external_data_loader_mgr, PrepackedWeightsForGraph& prepacked_for_graph, bool use_device_allocator_for_initializers = false) { - if (bool(alloc) == (memory_buffer != nullptr)) { + if (alloc != nullptr && memory_buffer != nullptr) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DeserializeTensorProto() takes either pre-allocated buffer or an allocator!"); } @@ -138,7 +141,8 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st } } else { // for internal initializer, always allocate memory on device - tensor - ORT_RETURN_IF_ERROR(AllocateTensor(memory_buffer, tensor, type, tensor_shape, use_device_allocator_for_initializers, alloc)); + ORT_RETURN_IF_ERROR(AllocateTensor(memory_buffer, tensor, type, tensor_shape, + use_device_allocator_for_initializers, alloc)); if (device == default_cpu_device) { // deserialize directly to CPU tensor @@ -370,6 +374,9 @@ common::Status SaveInitializedTensors( AllocatorPtr alloc; // TODO: if the tensor need be copied, does it have enough room? ORT_RETURN_IF_ERROR(planner.GetPreallocatedBuffer(ort_value_index, name, memory_buffer, alloc)); + + // ??? Should we ignore this session option if the EP is explicitly providing the read only allocator? 
+ // bool have_readonly_initializer_allocator = alloc->Info().alloc_type == OrtReadOnlyAllocator; const bool use_device_allocator_for_initializers = session_options.config_options.GetConfigOrDefault( kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1"; @@ -398,9 +405,10 @@ common::Status SaveInitializedTensors( // We need to deserialize the tensor proto into an OrtValue // using the preallocated buffer or allocator. - Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (memory_buffer.has_value()) ? &*memory_buffer : nullptr, alloc, - default_cpu_alloc, ort_value, data_transfer_mgr, external_data_loader_mgr, - prepacked_for_graph, + Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, + (memory_buffer.has_value()) ? &*memory_buffer : nullptr, + alloc, default_cpu_alloc, ort_value, data_transfer_mgr, + external_data_loader_mgr, prepacked_for_graph, use_device_allocator_for_initializers); if (!st.IsOK()) { std::ostringstream oss; diff --git a/onnxruntime/core/framework/simple_tensor_allocator.cc b/onnxruntime/core/framework/simple_tensor_allocator.cc index ad9e0393baa01..d919e0c3c4a13 100644 --- a/onnxruntime/core/framework/simple_tensor_allocator.cc +++ b/onnxruntime/core/framework/simple_tensor_allocator.cc @@ -14,7 +14,7 @@ common::Status SimpleTensorAllocator::GetPreallocatedBuffer(int ort_value_index, AllocatorPtr& alloc_out) { const struct OrtDevice& location = seq_plan_.GetLocation(ort_value_index); // just return allocator and let others handle it. - alloc_out = GetAllocator(location); + alloc_out = GetInitializerAllocator(location); return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/stream_handles.cc b/onnxruntime/core/framework/stream_handles.cc new file mode 100644 index 0000000000000..ab608cdda87c4 --- /dev/null +++ b/onnxruntime/core/framework/stream_handles.cc @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/framework/stream_handles.h" + +#include + +namespace onnxruntime { + +void Stream::UpdateWithAwaitedNotification(const synchronize::Notification& notification) { + const std::unordered_map& stream_sync_info = notification.GetStreamSyncInfo(); + for (const auto& kv : stream_sync_info) { + auto ret = producer_stream_sync_info_.insert(kv); + if (!ret.second) { + // we already have an entry. use the highest value for the producer stream. + ret.first->second = std::max(ret.first->second, kv.second); + } + } +} +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/tensor_allocator.cc b/onnxruntime/core/framework/tensor_allocator.cc index 9e81e8cd4783d..84d40d60a5d47 100644 --- a/onnxruntime/core/framework/tensor_allocator.cc +++ b/onnxruntime/core/framework/tensor_allocator.cc @@ -5,11 +5,14 @@ #include "simple_tensor_allocator.h" namespace onnxruntime { - AllocatorPtr ITensorAllocator::GetAllocator(const OrtDevice& device) { return session_state_.GetAllocator(device); } +AllocatorPtr ITensorAllocator::GetInitializerAllocator(const OrtDevice& device) { + return session_state_.GetInitializerAllocator(device); +} + std::unique_ptr ITensorAllocator::Create(bool enable_mem_pattern, const ExecutionPlanBase& execution_plan, const SessionState& session_state, diff --git a/onnxruntime/core/framework/tensor_allocator.h b/onnxruntime/core/framework/tensor_allocator.h index 923320681e683..daddfc7fd3cc0 100644 --- a/onnxruntime/core/framework/tensor_allocator.h +++ b/onnxruntime/core/framework/tensor_allocator.h @@ -26,6 +26,7 @@ class ITensorAllocator { InlinedVector& weights_buffers); AllocatorPtr GetAllocator(const OrtDevice& device); + AllocatorPtr GetInitializerAllocator(const OrtDevice& device); /** * diff --git a/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h b/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h index ad88149c89b81..414bc1c08adf4 100644 --- 
a/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h +++ b/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h @@ -28,7 +28,7 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator { planned_memory_sizes_in_byte.reserve(location_len); for (size_t i = 0; i < location_len; ++i) { auto& location = mem_patterns_.locations[i]; - auto alloc = GetAllocator(location); + auto alloc = GetInitializerAllocator(location); if (!alloc) return Status(common::ONNXRUNTIME, common::FAIL, "Failed to get allocator for location: " + location.ToString()); @@ -39,14 +39,8 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator { } const auto peak_size = mem_patterns_.patterns[i].PeakSize(); - void* buffer; - if (alloc->Info().alloc_type == OrtArenaAllocator) { - // Arena has a specific way to store static memory. - // Arena does not reuse static memory allocated by Reserve. - buffer = static_cast(alloc.get())->Reserve(peak_size); - } else { - buffer = alloc->Alloc(peak_size); - } + // use Reserve for initializers so they don't affect arena growth patterns if an arena is involved. + void* buffer = alloc->Reserve(peak_size); auto buffer_ptr = BufferUniquePtr(buffer, BufferDeleter(std::move(alloc))); auto kvp = buffers_.insert(std::make_pair(location, buffer)); @@ -80,25 +74,28 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator { if (!is_sealed_) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Internal error."); } + const struct OrtDevice& location = seq_plan_.GetLocation(ort_value_index); auto pattern = mem_patterns_.GetPatterns(location); if (pattern == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Mem pattern for initializer ", name, " is not found"); } + // if block is not found, means this ort_value is not traced // fall back to allocate separate buffer. 
// if it->second.get() is null, then fall back to the block not found case auto block = pattern->GetBlock(ort_value_index); if (nullptr == block) { // not traced, only return allocator - alloc_out = GetAllocator(location); + alloc_out = GetInitializerAllocator(location); return Status::OK(); } + auto it = buffers_.find(location); if (it == buffers_.end()) { if (block != nullptr && block->size_ == 0) { // Because the size is 0, this miss find is expected. we won't allocate a buffer with size of zero. - buf_out.emplace(nullptr, 0, GetAllocator(location)->Info()); + buf_out.emplace(nullptr, 0, GetInitializerAllocator(location)->Info()); return Status::OK(); } return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Weight buffer for initializer '", name, "' is not found"); @@ -108,7 +105,8 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Get preallocated buffer for initializer '", name, "' failed"); } - buf_out.emplace(reinterpret_cast(it->second) + block->offset_, block->size_, GetAllocator(location)->Info()); + buf_out.emplace(static_cast(it->second) + block->offset_, block->size_, + GetInitializerAllocator(location)->Info()); return Status::OK(); } common::Status Trace(int id, const ONNX_NAMESPACE::TensorProto* value) override { diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 5ba0b908edaf5..ff440b595e499 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -252,6 +252,10 @@ bool HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto) { return false; // No external data in memory } +bool HasExternalDataInFile(const ONNX_NAMESPACE::TensorProto& tensor_proto) { + return HasExternalData(tensor_proto) && !HasExternalDataInMemory(tensor_proto); +} + Status TensorProtoWithExternalDataToTensorProto( const ONNX_NAMESPACE::TensorProto& ten_proto, const std::filesystem::path& model_path, 
diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index e9148243f98b1..01086f38c8823 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -502,6 +502,13 @@ inline bool HasName(const ONNX_NAMESPACE::TypeProto_Opaque& op_proto) { /// true if ten_proto has external data and it is in memory [[nodiscard]] bool HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& tensor_proto); +/// +/// Check if the given tensor proto has external data stored in a file (not in memory). +/// +/// +/// +[[nodiscard]] bool HasExternalDataInFile(const ONNX_NAMESPACE::TensorProto& tensor_proto); + /// /// This function converts TensorProto with external data to TensorProto with inline data. /// diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index c6bb5d931cbe6..2c0a51f0bfdbc 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -74,28 +74,21 @@ static common::Status AllocateHelper(const AllocatorPtr& allocator, if (source_mlvalue.IsTensor()) { const Tensor& source_tensor = source_mlvalue.Get(); - if (allocator->Info().alloc_type == OrtArenaAllocator) { - void* p_data = nullptr; -#ifdef ORT_ENABLE_STREAM - BFCArena* arena_ptr = static_cast(allocator.get()); - auto* stream_aware_alloc = StreamAwareArena::FromBFCArena(*arena_ptr); - if (stream_aware_alloc && target_stream) { - size_t len = Tensor::CalculateTensorStorageSize(source_tensor.DataType(), source_tensor.Shape()); - p_data = stream_aware_alloc->AllocOnStream(len, target_stream, nullptr); - } -#else - ORT_UNUSED_PARAMETER(target_stream); -#endif // ORT_ENABLE_STREAM - if (p_data == nullptr) { - Tensor::InitOrtValue(source_tensor.DataType(), - source_tensor.Shape(), - allocator, target_mlvalue); - } else { - Tensor::InitOrtValue(source_tensor.DataType(), - source_tensor.Shape(), - p_data, - allocator, target_mlvalue); + void* p_data = 
nullptr; + if (target_stream && allocator->IsStreamAware()) { + size_t len = Tensor::CalculateTensorStorageSize(source_tensor.DataType(), source_tensor.Shape()); + p_data = allocator->AllocOnStream(len, target_stream); + if (p_data == nullptr && len > 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Allocation failed."); } + } + + if (p_data) { + Tensor::InitOrtValue(source_tensor.DataType(), + source_tensor.Shape(), + p_data, + allocator, target_mlvalue); + } else { Tensor::InitOrtValue(source_tensor.DataType(), source_tensor.Shape(), diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 0e939f7986aac..6383d29d7a2bc 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -7,6 +7,7 @@ #include #include #include "core/common/inlined_containers_fwd.h" +#include "core/framework/tensor_external_data_info.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/graph/onnx_protobuf.h" @@ -29,6 +30,9 @@ enum class OrtGraphIrApi { kEpApi, }; +// Alias OrtExternalInitializerInfo to the internal type. +struct OrtExternalInitializerInfo : onnxruntime::ExternalDataInfo {}; + /// /// Public type that represents an ONNX value info. /// @@ -94,6 +98,17 @@ struct OrtValueInfo { /// A status indicating success or an error. virtual onnxruntime::Status GetInitializerValue(const OrtValue*& value) const = 0; + /// + /// Get information (file path, file offset, byte size) if this OrtValueInfo represents an initializer with + /// data in an external file. + /// + /// Output parameter set to the external initializer info or NULL if this is not an external + /// initializer. + /// A status indicating an error or success. Calling this function on an OrtValueInfo that does not represent + /// an external initializer is NOT an error. 
+ virtual onnxruntime::Status GetExternalInitializerInfo( + std::unique_ptr& ext_info) const = 0; + /// /// Determine if the value is a required graph input. /// diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 5511275239e45..39bf2bf855976 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1392,14 +1392,14 @@ constexpr const char* MoE_ver1_doc = R"DOC( ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, OpSchema() .SetDoc(MoE_ver1_doc) - .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) + .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) .Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast(0)) .Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast(0)) .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") - .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T") - .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size), or (num_experts, hidden_size, 2 * inter_size) for swiglu", "T") + .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", 
OpSchema::Optional) .Input(4, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T") .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional) .Input(6, "fc3_experts_weights", "3D optional input tensor with shape (num_experts, hidden_size, inter_size)", "T", OpSchema::Optional) @@ -1413,7 +1413,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema() .SetDoc("Quantized MoE") .Attr("activation_type", - "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", + "Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) .Attr("k", @@ -1438,12 +1438,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size) " - "or (num_experts, hidden_size, inter_size / 2)", + "or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).", "T1") - .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size)", "T") + .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T") .Input(4, "fc1_experts_bias", - "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional) .Input(5, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size) " diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index b7e5351556c61..4ceadb6191a9b 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -15,6 +15,7 @@ #include #include "core/framework/allocator.h" +#include 
"core/framework/tensor_external_data_info.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/graph/graph_viewer.h" @@ -452,11 +453,29 @@ Status EpValueInfo::GetInitializerValue(const OrtValue*& result) const { // This gets an initializer value defined in this graph or in a parent graph (as long as the value // is used in this graph). - result = graph_->GetInitializerValue(name_); + ORT_RETURN_IF_ERROR(graph_->GetInitializerValue(name_, result)); ORT_RETURN_IF(result == nullptr, "Unable to find initializer value named '", name_, "'."); return Status::OK(); } +Status EpValueInfo::GetExternalInitializerInfo(std::unique_ptr& result) const { + if (!IsFlagSet(kIsConstantInitializer) && !IsFlagSet(kIsOptionalGraphInput)) { + result = nullptr; + return Status::OK(); + } + + ORT_RETURN_IF(graph_ == nullptr, "Unable to get external initializer information for value named '", + name_, "': parent graph is NULL"); + + const onnxruntime::Graph& graph = graph_->GetGraphViewer().GetGraph(); + + if (!graph.GetExternalInitializerInfo(name_, result, /*check_outer_scope*/ true)) { + result = nullptr; + } + + return Status::OK(); +} + Status EpValueInfo::IsRequiredGraphInput(bool& is_required_graph_input) const { is_required_graph_input = IsFlagSet(Flags::kIsRequiredGraphInput); return Status::OK(); @@ -593,15 +612,18 @@ Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& initializer_value_infos.push_back(value_info); // Initialize OrtValue for the initializer. + // Note: using std::unique_ptr because we return a OrtValue* to the user and we want it to be stable. auto initializer_value = std::make_unique(); bool graph_has_ortvalue = graph_viewer.GetGraph().GetOrtValueInitializer(initializer_name, *initializer_value, /*check_outer_scope*/ false); if (!graph_has_ortvalue) { - // onnxruntime::Graph does not have an OrtValue for this initializer, so create one from the TensorProto. 
- // This should only happen for small initializers that are needed for ONNX shape inferencing. - ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), *tensor_proto, - initializer_allocator, *initializer_value)); + // Copy to OrtValue if not external. This should only happen for small initializers. + // Do nothing for external initializers, as we will load/mmap into an OrtValue on demand. + if (!utils::HasExternalDataInFile(*tensor_proto)) { + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), *tensor_proto, + initializer_allocator, *initializer_value)); + } } initializer_values.emplace(value_info->GetName(), std::move(initializer_value)); @@ -650,8 +672,10 @@ Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& } EpValueInfo* outer_value_info = value_info_iter->second.get(); - bool is_constant = false; + + // Note: using std::unique_ptr because we return a OrtValue* to the user and we want it to be stable. auto outer_initializer_value = std::make_unique(); + bool is_constant = false; const ONNX_NAMESPACE::TensorProto* outer_initializer = parent_graph->GetInitializer(implicit_name, *outer_initializer_value, is_constant, @@ -665,11 +689,13 @@ Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& // Add the OrtValue if this is an initializer. if (outer_initializer != nullptr) { if (!outer_initializer_value->IsAllocated()) { - // onnxruntime::Graph does not have an OrtValue for this initializer, so create one from the TensorProto. - // This should only happen for small initializers that are needed for ONNX shape inferencing. - ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), parent_graph->ModelPath(), - *outer_initializer, initializer_allocator, - *outer_initializer_value)); + // Copy to OrtValue if not external. This should only happen for small initializers. + // Do nothing for external initializers. 
Will load/mmap into an OrtValue on demand. + if (!utils::HasExternalDataInFile(*outer_initializer)) { + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), parent_graph->ModelPath(), + *outer_initializer, initializer_allocator, + *outer_initializer_value)); + } } outer_scope_initializer_values.emplace(outer_value_info->GetName(), std::move(outer_initializer_value)); } @@ -808,20 +834,40 @@ const EpNode* EpGraph::GetNode(NodeIndex node_index) const { return index_to_ep_node_.GetEpNode(node_index); } -const OrtValue* EpGraph::GetInitializerValue(std::string_view name) const { +Status EpGraph::GetInitializerValue(std::string_view name, const OrtValue*& result) const { + auto ensure_ort_value_loaded = [&](const std::unique_ptr& ort_value) -> Status { + if (!ort_value->IsAllocated()) { + // Lazy load the OrtValue. This happens for external initializers. + const Graph& graph = graph_viewer_.GetGraph(); + ORT_RETURN_IF_ERROR(graph.LoadExternalInitializerAsOrtValue(std::string(name), + const_cast(*ort_value))); + } + + return Status::OK(); + }; + // Check for initializer value in the graph's scope. if (auto iter = initializer_values_.find(name); iter != initializer_values_.end()) { - return iter->second.get(); + const std::unique_ptr& ort_value = iter->second; + ORT_RETURN_IF_ERROR(ensure_ort_value_loaded(ort_value)); + + result = ort_value.get(); + return Status::OK(); } // Check for the initializer value in an outer scope. // Only finds a value if the outer initializer value is used within this graph. 
if (auto iter = outer_scope_initializer_values_.find(name); iter != outer_scope_initializer_values_.end()) { - return iter->second.get(); + const std::unique_ptr& ort_value = iter->second; + ORT_RETURN_IF_ERROR(ensure_ort_value_loaded(ort_value)); + + result = ort_value.get(); + return Status::OK(); } - return nullptr; + result = nullptr; + return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index b9a494364a12e..243bdc2944ffb 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -65,6 +65,10 @@ struct EpValueInfo : public OrtValueInfo { // represent an initializer (either constant or non-constant). Status GetInitializerValue(const OrtValue*& value) const override; + // Gets external initializer information (file path, file offset, byte size) if this is an external initializer. + // Otherwise, sets the output parameter `ext_info` to nullptr (not an error). + Status GetExternalInitializerInfo(std::unique_ptr& ext_info) const override; + // Check if this value is a required graph input. Status IsRequiredGraphInput(bool& is_required_graph_input) const override; @@ -351,7 +355,7 @@ struct EpGraph : public OrtGraph { // Considers both constant and non-constant initializers. // Supports initializers defined in an outer scope as long as that initializer is used // within this graph. 
- const OrtValue* GetInitializerValue(std::string_view name) const; + Status GetInitializerValue(std::string_view name, const OrtValue*& value) const; private: /// diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index b929c27b21ec3..de6776b0e0df1 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3815,6 +3815,48 @@ bool Graph::GetOrtValueInitializer(const std::string& name, OrtValue& value, boo return false; } +Status Graph::LoadExternalInitializerAsOrtValue(const std::string& name, OrtValue& value) const { + auto tensor_proto_iter = name_to_initial_tensor_.find(name); + ORT_RETURN_IF(tensor_proto_iter == name_to_initial_tensor_.end(), "Cannot load unknown initializer named '", + name, "'."); + const ONNX_NAMESPACE::TensorProto& tensor_proto = *tensor_proto_iter->second; + + // This only supports TensorProtos that currently have their external data in an actual file. + // This doesn't cache the new OrtValue because it would overwrite TensorProto.external_data and plugin EPs require + // every call to Graph::GetExternalInitializerInfo to return the actual external file info (path, offset, length). + ORT_RETURN_IF(!utils::HasExternalDataInFile(tensor_proto), "Initializer '", name, + "' does not have external data in a file."); + + // Create OrtValue that memory maps external initializer from file. 
+ ORT_RETURN_IF_ERROR(utils::GetExtDataFromTensorProto(Env::Default(), ModelPath(), tensor_proto, value)); + assert(value.IsAllocated()); + + return Status::OK(); +} + +bool Graph::GetExternalInitializerInfo(const std::string& name, std::unique_ptr& ext_info, + bool check_outer_scope) const { + // We want to make sure that the external data info is found on the same level as its tensor_proto + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (GetInitializedTensor(name, initializer)) { + if (utils::HasExternalDataInFile(*initializer)) { + std::unique_ptr external_data_info; + ORT_THROW_IF_ERROR(ExternalDataInfo::Create(initializer->external_data(), external_data_info)); + + ext_info = std::move(external_data_info); + return true; + } + } + + if (check_outer_scope && IsSubgraph()) { + if (IsOuterScopeValue(name)) { + // make sure there's not a local value with the same name. if there is it shadows any initializer in outer scope. + return parent_graph_->GetExternalInitializerInfo(name, ext_info, check_outer_scope); + } + } + return false; +} + void Graph::CleanAllInitializedTensors() noexcept { name_to_initial_tensor_.clear(); #if !defined(DISABLE_SPARSE_TENSORS) @@ -5202,7 +5244,7 @@ Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& nod // In the constant node, we won't have symbolic dims. 
const auto tensor_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); auto ml_data = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); - const size_t size_in_bytes = SafeInt(ml_data->Size()) * tensor_shape.Size(); + const size_t size_in_bytes = Tensor::CalculateTensorStorageSize(ml_data, tensor_shape); if (size_in_bytes > utils::kSmallTensorExternalDataThreshold) { OrtValue ort_value; diff --git a/onnxruntime/core/graph/graph_proto_serializer.cc b/onnxruntime/core/graph/graph_proto_serializer.cc index 0fbcea2719ce8..9a67796254231 100644 --- a/onnxruntime/core/graph/graph_proto_serializer.cc +++ b/onnxruntime/core/graph/graph_proto_serializer.cc @@ -95,7 +95,7 @@ void GraphViewerToProto(const GraphViewer& graph_view, auto* p_initializer = graph_proto.add_initializer(); // Do not save raw into the graph, only the metadata - if (!include_initializer_data && init->has_raw_data()) { + if (!include_initializer_data && (init->has_raw_data() || utils::HasExternalDataInMemory(*init))) { // Set datatype if (init->has_data_type()) { p_initializer->set_data_type(init->data_type()); diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 07c7080d74c7c..5d84e48182bfe 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -49,6 +49,12 @@ struct ModelEditorValueInfo : public OrtValueInfo { "OrtModelEditorApi does not support getting the initializer value for a OrtValueInfo"); } + Status GetExternalInitializerInfo(std::unique_ptr& /*ext_info*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting the external initializer information ", + "for a OrtValueInfo"); + } + Status IsRequiredGraphInput(bool& /*is_required_graph_input*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support querying if a graph input 
is required for OrtValueInfo"); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 4d85c35461825..22bddf58997bc 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -631,6 +631,52 @@ MlasGemm( { MlasGemmBatch(Shape, &DataParams, 1, ThreadPool); } +/** + * @brief Parameters that define the shape of a dynamically quantized GEMM operation. + * + * The structure holds the dimensions of the matrices involved in the GEMM + * computation: + * C = A * B + */ +struct MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS { + size_t M = 0; /**< Row size of matrix A */ + size_t N = 0; /**< Column size of matrix B */ + size_t K = 0; /**< Column size of matrix A and Row size of matrix B */ +}; +/** + * @brief Parameters that define the data buffers and layout for a dynamic quant GEMM. + * + * This structure provides the memory pointers and strides for matrices + * involved in a dynamically quantized GEMM operation, along with the packed B format. + */ +struct MLAS_GEMM_DYN_QUANT_DATA_PARAMS { + const float* A = nullptr; /**< Pointer to input matrix A in FP32 format**/ + size_t lda = 0; /**< Number of elements between adjecent rows in A*/ + const void* PackedB = 0; /**< Points to packed weight matrix B */ + float *C = nullptr; /**< Points to output Matric C */ + size_t ldc = 0; /**< Number of elements between adjecent rows in Matrix C*/ + void* Workspace = nullptr; /**< Workspace buffer for LHS Packing Allocation */ + size_t WorkspaceSize = 0; /**< Workspace buffer size */ +}; + +void +MLASCALL +MlasDynamicQGemmBatch ( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool +); + +inline void +MlasDynamicQGemm ( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool +) { + MlasDynamicQGemmBatch(Shape, DataParams, 1, ThreadPool); +} + // // Symmetric QGEMM 
has limited buffer overrun. @@ -685,6 +731,8 @@ MlasSymmQgemmBatch( size_t MLASCALL MlasGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, size_t N, size_t K ); @@ -692,6 +740,7 @@ MlasGemmPackBSize( void MLASCALL MlasGemmPackB( + CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, size_t N, size_t K, @@ -750,6 +799,26 @@ MlasSymmQgemmPackB( void* PackedB ); + +size_t +MLASCALL +MlasDynamicQgemmPackBSize( + size_t N, + size_t K +); + +void +MLASCALL +MlasDynamicQgemmPackB( + size_t N, + size_t K, + const int8_t* B, + const float* Scales, + const float* Bias, + void* PackedB +); + + // // Convolution routines. // @@ -2024,3 +2093,14 @@ MlasFlashAttention( MlasFlashAttentionThreadedArgs* args, MLAS_THREADPOOL* ThreadPool ); + +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) +/** + * @brief Function to override the packing mechanism decision if kleidi ai is included + * @param enable enable kleidiai packing (allow or disallow depending on true/false) + * @return +*/ +void +MLASCALL +MlasGemmBatchPackUseKleidi(bool enable); +#endif diff --git a/onnxruntime/core/mlas/lib/convolve.cpp b/onnxruntime/core/mlas/lib/convolve.cpp index ec79641559c6b..bc1221475fd90 100644 --- a/onnxruntime/core/mlas/lib/convolve.cpp +++ b/onnxruntime/core/mlas/lib/convolve.cpp @@ -861,6 +861,12 @@ Return Value: --*/ { + // Override + if(GetMlasPlatform().MlasConvOverride != nullptr && + GetMlasPlatform().MlasConvOverride(Parameters,Input,Filter,Bias,WorkingBuffer,Output,ThreadPool)){ + return; + } + const size_t FilterCount = Parameters->FilterCount; const size_t OutputSize = Parameters->OutputSize; const size_t K = Parameters->K; @@ -1094,6 +1100,13 @@ Return Value: --*/ { + // Override + if (GetMlasPlatform().MlasConvPrepareOverride != nullptr && + GetMlasPlatform().MlasConvPrepareOverride(Parameters, Dimensions, BatchCount, GroupCount, InputChannels, + InputShape,KernelShape,DilationShape, Padding, StrideShape, OutputShape, FilterCount, + Activation, WorkingBufferSize, Beta, 
ThreadPool)){ + return; + } // // Save the convolution parameters. // @@ -1299,4 +1312,4 @@ Return Value: } #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp new file mode 100644 index 0000000000000..9eaf4902f536a --- /dev/null +++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp @@ -0,0 +1,720 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include +#include +#include +#include +#include "mlasi_kleidiai.h" +#include +#include + +#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h" + +// Right-hand-side (weights) cache key +struct RhsCacheKey { + size_t co, ci, kh, kw, dilationh, dilationw; + size_t weights_hash; + + bool operator==(const RhsCacheKey& other) const { + return co == other.co && ci == other.ci && + kh == other.kh && kw == other.kw && + dilationh == other.dilationh && dilationw == other.dilationw && + weights_hash == other.weights_hash; + } +}; + + +// Left-hand-side (input indirection) cache key +struct LhsCacheKey { + size_t ci, ih, iw; + size_t padding, sh, sw; + size_t kh, kw; + size_t dilationh, dilationw; + size_t data_hash; + + bool operator==(const LhsCacheKey& other) const { + return ci == other.ci && ih == other.ih && iw == other.iw && + padding == other.padding && sh == other.sh && sw == other.sw && + kh == other.kh && kw == other.kw && + dilationh == other.dilationh && dilationw == other.dilationw && + data_hash == other.data_hash; + } +}; + +// Derived from 2^32 * (sqrt(5) - 1) / 2 ≈ 0.6180339887 (reciprocal of the golden ratio) +// 
Based on Knuth's multiplicative hashing method +constexpr size_t HASH_GOLDEN_RATIO_CONST = 0x9e3779b9; + +size_t HashWeights(const float* data, size_t count = 16) { + size_t h = 0; + for (size_t i = 0; i < count; ++i) { + h ^= std::hash()(data[i]) + HASH_GOLDEN_RATIO_CONST + (h << 6) + (h >> 2); + } + return h; +} + +namespace std { + // Specialize hash type for cache keys and do it within namespace std. + // Doing this allows standard containers like std::unordered_map to find + // the appropriate hash function via template specialization, as ADL + // (argument-dependent lookup) does not apply to std::hash. + template<> + struct hash { + size_t operator()(const RhsCacheKey& k) const { + return k.weights_hash ^ + (std::hash()(k.co) << 1) ^ + (std::hash()(k.ci) << 2) ^ + (std::hash()(k.kh) << 3) ^ + (std::hash()(k.kw) << 4) ^ + (std::hash()(k.dilationh) << 5) ^ + (std::hash()(k.dilationw) << 6); + } + }; + + template<> + struct hash { + size_t operator()(const LhsCacheKey& k) const { + return k.data_hash ^ + (std::hash()(k.ci) << 1) ^ + (std::hash()(k.ih) << 2) ^ + (std::hash()(k.iw) << 3) ^ + (std::hash()(k.padding) << 4) ^ + (std::hash()(k.sh) << 5) ^ + (std::hash()(k.sw) << 6) ^ + (std::hash()(k.kh) << 7) ^ + (std::hash()(k.kw) << 8) ^ + (std::hash()(k.dilationh) << 9) ^ + (std::hash()(k.dilationw) << 10); + } + }; + +} + + +static constexpr size_t ComputeKernelSize(const size_t D, const size_t K) { + // D - dilation size + // K - kernel size + + // D*S scale 1D kernel dimension by dilation factor + // (D-1) remove affect of dilation scaling at kernel end + return (D*K) - (D - 1); +} + +static constexpr size_t ComputeConvOutSize(const size_t L, const size_t K, const size_t P, const size_t S) { + + //With start + end padding + + //L - input size + //K - kernel size + //P - Padding size + //S - stride size + + //Does the convolution compute one value or less ? 
+ if ( S > 0 && (L + 2*P) >= K) { + // L-(K-1) standard convolution output size is L-(K-1) for a step size of 1 with no padding + // (2*P) 1D start and end padding + // (L+2*P)-(K-1) the 1D length of convolution result for a kernel step size of 1 + // /S apply the kernel step + return (((L - K) + (2 * P)) / S) + 1; + } + return 0; +} + +static size_t ComputeMlasWorkingBufferSize(const size_t co, + const size_t ih, const size_t iw, + const size_t kh, const size_t kw, + const size_t dilationh, const size_t dilationw, + const size_t sh, const size_t sw, + const size_t padding) { + // dimensions of dilated kernel + const auto d_kh = ComputeKernelSize(dilationh, kh); + const auto d_kw = ComputeKernelSize(dilationw, kw); + + const auto m = ComputeConvOutSize(ih, d_kh, padding, sh) * + ComputeConvOutSize(iw, d_kw, padding, sw); + + return m * co; +} + +static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) { + + //functional checks - logically can the conv be performed + if ((Parameters->Dimensions != 2) || + (Parameters->BatchCount != 1) || + (Parameters->Beta != 0.f) || + (Parameters->Padding[0] != Parameters->Padding[1]) || + (Parameters->Padding[0] != Parameters->Padding[2]) || + (Parameters->Padding[0] != Parameters->Padding[3]) || + (ComputeConvOutSize(Parameters->InputShape[0], + ComputeKernelSize(Parameters->DilationShape[0],Parameters->KernelShape[0]), + Parameters->Padding[0], Parameters->StrideShape[0]) * + ComputeConvOutSize(Parameters->InputShape[1], + ComputeKernelSize(Parameters->DilationShape[1],Parameters->KernelShape[1]), + Parameters->Padding[1], Parameters->StrideShape[1]) == 0)) { + return false; + } + + //optimization checks - is the implementation optimal for the conv request + + const auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + + auto M = ComputeConvOutSize(Parameters->InputShape[0], 
ComputeKernelSize(Parameters->DilationShape[0], + Parameters->KernelShape[0]), Parameters->Padding[0], Parameters->StrideShape[0]) * + ComputeConvOutSize(Parameters->InputShape[1], ComputeKernelSize(Parameters->DilationShape[1], + Parameters->KernelShape[1]), Parameters->Padding[1], Parameters->StrideShape[1]); + auto N = Parameters->FilterCount; + auto K = Parameters->InputChannels * Parameters->KernelShape[0] * Parameters->KernelShape[1]; + + //Can use these variables to add other conditions as required + MLAS_UNREFERENCED_PARAMETER(M); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(m_step); + MLAS_UNREFERENCED_PARAMETER(n_step); + + if (N == 1 || Parameters->KernelShape[0] < 3 || Parameters->KernelShape[1] < 3) { + return false; + } + return true; +} + +//General purpose axis swapping +static auto Transpose4D(std::array shape_in, + const float* in, + std::array permute) { + + std::array shape_out{shape_in[permute[0]], + shape_in[permute[1]], + shape_in[permute[2]], + shape_in[permute[3]]}; + + assert((shape_in[0] * shape_in[1] * shape_in[2] * shape_in[3]) == + (shape_out[0] * shape_out[1] * shape_out[2] * shape_out[3])); + assert(permute[0] < 4 && permute[1] < 4 && permute[2] < 4 && permute[3] < 4); + + const size_t get_stride[] {shape_in[1] * shape_in[2] * shape_in[3], shape_in[2] * shape_in[3], shape_in[3]}; + auto get = [get_stride,in](const std::array& el) { + return in[el[0] * get_stride[0] + + el[1] * get_stride[1] + + el[2] * get_stride[2] + + el[3]]; + }; + + auto out_ = std::make_unique(shape_in[0] * shape_in[1] * shape_in[2] * shape_in[3]); + auto out = out_.get(); + + const size_t set_stride[]{shape_out[1] * shape_out[2] * shape_out[3], shape_out[2] * shape_out[3], shape_out[3]}; + auto set = [set_stride,out](const std::array& el, float v) { + out[el[0] * set_stride[0] + + el[1] * set_stride[1] + + el[2] * set_stride[2] + + el[3]] = v; + }; + + std::array shape; + for (shape[0] = 0; shape[0] < shape_in[0]; ++shape[0]) { + for 
(shape[1] = 0; shape[1] < shape_in[1]; ++shape[1]) { + for (shape[2] = 0; shape[2] < shape_in[2]; ++shape[2]) { + for (shape[3] = 0; shape[3] < shape_in[3]; ++shape[3]) { + set({shape[permute[0]], shape[permute[1]], shape[permute[2]], shape[permute[3]]}, get(shape)); + } + } + } + } + + return out_; +} + +//nchw to nhwc specific axis swapping +static std::unique_ptr NChwToNhwc(const size_t n, + const size_t c, + const size_t h, + const size_t w, + const float* RESTRICT in, + const size_t dilationh=1, + const size_t dilationw=1, + const bool zero_fill=false, + MLAS_THREADPOOL* ThreadPool=nullptr) { + + const auto d_h = ComputeKernelSize(dilationh, h); + const auto d_w = ComputeKernelSize(dilationw, w); + + auto t = std::make_unique(n*d_h*d_w*c); + if (zero_fill) { + std::fill(&t.get()[0], &t.get()[n*d_h*d_w*c], 0.f); + } + + if (dilationh > 1 || dilationw > 1 || n > 1) { + const size_t get_strides[] {c*h*w,h*w,w}; + auto get = [get_strides,in](const std::array& el) { + return in[el[0]*get_strides[0] + + el[1]*get_strides[1] + + el[2]*get_strides[2] + + el[3]]; + }; + + const size_t set_strides[] {d_h*d_w*c,dilationh*d_w*c,dilationw*c}; + auto set = [set_strides](const std::array& el, float v, float* out) { + out[el[0]*set_strides[0] + + el[1]*set_strides[1] + + el[2]*set_strides[2] + + el[3]] = v; + }; + + MLAS_UNREFERENCED_PARAMETER(set); + MLAS_UNREFERENCED_PARAMETER(get); + + auto out0 = t.get(); + for (size_t s0 = n; s0 > 0; --s0) { + auto out1 = out0; + for (size_t s1 = c; s1 > 0; --s1) { + auto out2 = out1; + for (size_t s2 = h; s2 > 0; --s2) { + float* RESTRICT out3 = out2; + size_t s3 = w; + for (; s3 > 4; s3 -= 4) { + auto vf32 = MlasLoadFloat32x4(in); + in += 4; + MlasStoreLaneFloat32x4<0>(out3,vf32); + out3 += set_strides[2]; + MlasStoreLaneFloat32x4<1>(out3,vf32); + out3 += set_strides[2]; + MlasStoreLaneFloat32x4<2>(out3,vf32); + out3 += set_strides[2]; + MlasStoreLaneFloat32x4<3>(out3, vf32); + out3 += set_strides[2]; + } + for (; s3 > 0; --s3) { + 
//set({s0,s2,s3,s1}, get({s0,s1,s2,s3}),t.get()); + *out3 = *in++; + out3 += set_strides[2]; + } + out2 += set_strides[1]; + } + out1++; + } + out0 += set_strides[0]; + } + } else { + MlasTranspose(in, t.get(), c, d_h*d_w, ThreadPool); + } + + return t; +} + +static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci, const size_t m, const size_t kh, + const size_t kw, const void * const* lhs_ptrs, std::byte* lhs_data, + const float* in_data, + const float* pad_ptr) { + + auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + + // Minimize the kernel call count for the number of available threads + auto RequiredTiles = MlasDivRoundup(m, m_step); + auto MaxTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), RequiredTiles); + m_step *= MlasDivRoundup(RequiredTiles, MaxTiles); + RequiredTiles = MlasDivRoundup(m, m_step); + + MlasTrySimpleParallel(ThreadPool, static_cast(RequiredTiles), [&](ptrdiff_t tid) { + + auto m_idx = static_cast(tid) * m_step; + auto offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m_idx,kh*kw,ci); + + kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + m < (m_idx + m_step) ? 
m - m_idx : m_step, kh * kw, ci, + lhs_ptrs + m_idx * kh * kw, + reinterpret_cast(in_data), + reinterpret_cast(pad_ptr), + lhs_data + offset + ); + }); +} + +static std::shared_ptr RhsPackWeightsBiasSme(const size_t co, const size_t ci, + const size_t kh, const size_t kw, + const size_t dilationh, const size_t dilationw, + const float* weights, const float* bias, + MLAS_THREADPOOL* ThreadPool) +{ + //cache of prepacked kai rhs weights and biases + static std::unordered_map> rhs_cache; + + RhsCacheKey key = { co, ci, kh, kw, dilationh, dilationw, HashWeights(weights) }; + + auto found = rhs_cache.find(key); + if (found != rhs_cache.end()) { + return found->second; + } else { + // prepare mlas filter weights for kai rhs packing + // dilated nhwc format + auto nhwc = NChwToNhwc(co, ci, kh, kw, weights, dilationh, dilationw, true, ThreadPool); + + + //dilation, axis swap (n x k -> k x n) where n == co, k == d_kh x d_kw x ci + const auto d_kh = ComputeKernelSize(dilationh,kh); + const auto d_kw = ComputeKernelSize(dilationw,kw); + + //t_weights[d_kh][d_kw][ci][co] = nhwc[co][d_kh][d_kw][ci] + auto t_weights = Transpose4D({co,d_kh,d_kw,ci},&nhwc[0],{1,2,3,0}); + + const auto packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(co,d_kh*d_kw,ci); + auto packed = std::shared_ptr(new std::byte[packed_size], std::default_delete()); + + rhs_cache[key] = packed; + + std::vector bias_copy; + if (bias) { + bias_copy.assign(bias, bias + co); + } else { + bias_copy.resize(co, 0.0f); + } + + kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme( + co, d_kh*d_kw, ci, co * sizeof(float), &t_weights[0], bias_copy.data(), packed.get() + ); + + return packed; + } +} + +static std::shared_ptr LhsPtrFill(const size_t ci, const size_t ih, const size_t iw, + const size_t kh, const size_t kw, size_t sh, size_t sw, + const size_t padding, + const float* pad_ptr) { + size_t check_filled{0}; + + const auto m = ComputeConvOutSize(ih, kh, padding, sh) * 
ComputeConvOutSize(iw, kw, padding, sw); + + const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + const auto lhs_ptrs_k = kh * kw; + const auto lhs_ptrs_m = m_step * MlasDivRoundup(m, m_step); + auto lhs_ptrs = std::shared_ptr(new const void*[lhs_ptrs_k * lhs_ptrs_m], + std::default_delete()); + + + auto ih_out_size = ComputeConvOutSize(ih, kh, padding, 1); + auto iw_out_size = ComputeConvOutSize(iw, kw, padding, 1); + + auto ptrs_offset = [lhs_ptrs_m,lhs_ptrs_k, m_step](size_t k, size_t m) { + //(m/m_step,transpose(m_step,k) + auto offset {((m/m_step) * lhs_ptrs_k * m_step) + (k*m_step) + (m%m_step)}; + assert(offset < (lhs_ptrs_k * lhs_ptrs_m)); + + MLAS_UNREFERENCED_PARAMETER(lhs_ptrs_m); + + return offset; + }; + + auto pixel_offset = [ih, iw, ci, pad_ptr, padding](size_t h, size_t w) { + if (h < padding) { + return reinterpret_cast(&pad_ptr[0]); + } + h -= padding; + + if (w < padding) { + return reinterpret_cast(&pad_ptr[0]); + } + w -= padding; + + if ((h >= ih) || (w >= iw)) { + return reinterpret_cast(&pad_ptr[0]); + } + + auto offset{h * iw * ci + w * ci}; + assert(offset < (ih*iw*ci)); + return offset*sizeof(float); + }; + + size_t m_{0}; + auto lhs_ptrs_ = lhs_ptrs.get(); + for (size_t ih_ = 0; ih_ < ih_out_size; ih_ += sh) { + for (size_t iw_ = 0; iw_ < iw_out_size; iw_ += sw, ++m_) { + size_t k_{0}; + for (size_t kh_ = 0; kh_ < kh; ++kh_) { + for (size_t kw_ = 0; kw_ < kw; ++kw_) { + lhs_ptrs_[ptrs_offset(k_, m_)] = reinterpret_cast(pixel_offset(ih_+kh_, iw_+kw_)); + k_++; check_filled++; + } + } + } + } + + assert(check_filled == (lhs_ptrs_k * m)); + MLAS_UNREFERENCED_PARAMETER(check_filled); + + return lhs_ptrs; +} + +static std::unique_ptr LhsPackImageDataSme(const size_t ci, const size_t ih, const size_t iw, + const size_t kh, const size_t kw, const size_t sh, + const size_t sw, const size_t padding, const float* in, + MLAS_THREADPOOL* ThreadPool) +{ + size_t padsize = 256; + if(ci > padsize) + { + // 
figure out how many blocks needed to correctly fill padding + padsize = ((ci + padsize - 1) / padsize) * padsize; + } + static std::vectorpad_ptr(padsize, 0.f); + + LhsCacheKey key = { + ci, ih, iw, + padding, sh, sw, + kh, kw, + 1, 1, + HashWeights(in) + }; + + //create lhs in format required for imatmul + const auto m = ComputeConvOutSize(ih, kh, padding, sh) * ComputeConvOutSize(iw, kw, padding, sw); + + const auto lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m,kh*kw,ci); + auto lhs = std::make_unique(lhs_size); + + auto nhwc = NChwToNhwc(1, ci, ih, iw, in, 1, 1, false, ThreadPool); + + //cache of computed lhs ptr offsets + static std::unordered_map> lhs_ptrs_cache; + + std::shared_ptr lhs_ptrs; + if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) { + lhs_ptrs = found->second; + } else { + lhs_ptrs = LhsPtrFill(ci, ih, iw, kh, kw, sh, sw, padding, &pad_ptr[0]); + lhs_ptrs_cache[key] = lhs_ptrs; + } + + MultiThreadedLHSPackSme(ThreadPool, ci, m, kh, kw, &lhs_ptrs[0], &lhs[0], &nhwc[0], &pad_ptr[0]); + + return lhs; +} + +static void ConvolveSme(const size_t co, //channels out + const size_t ci, //channels in + const size_t ih, //image height + const size_t iw, //image width + const size_t kh, //kernel height + const size_t kw, //kernel width + const size_t sh, //kernel stride height + const size_t sw, //kernel stride width + const size_t dilationh, //kernel dilation stride + const size_t dilationw, //kernel dilation stride + const size_t padding, //padding size + const size_t groups, //number of filter groups + const float* weights, //kernel weights [co,ci,ih,iw] + const float* bias, //kernel biases + const float* in, //in image data + float* out, //out image data + float* tmp_mlas_aligned, //intermediate buffer if we need to perform a transpose + MLAS_THREADPOOL* ThreadPool) { + + //RhsPackWeightsBiasSme() - to perform dilation increases kernel size and masks unused weights + //compute corrected dimensions of 
dilated kernel + const auto d_kh = ComputeKernelSize(dilationh, kh); + const auto d_kw = ComputeKernelSize(dilationw, kw); + + //run igemm based convolution + const auto m = ComputeConvOutSize(ih, d_kh, padding, sh) * + ComputeConvOutSize(iw, d_kw, padding, sw); + + auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + + //tile iteration dimensions + std::array dim; + dim[0] = 1; // B + dim[1] = MlasDivRoundup(m, m_step); // M + dim[2] = MlasDivRoundup(co, n_step); // N + + //Minimize the kernel call count for the number of available threads + auto RequiredTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), dim[0]*dim[1]*dim[2]); + + //scale required tiles over available tile processors + dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]); + dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]); + + //compute new step sizes + m_step *= MlasDivRoundup(MlasDivRoundup(m, dim[1]), m_step); + n_step *= MlasDivRoundup(MlasDivRoundup(co, dim[2]), n_step); + + //update tile iterations + dim[1] = MlasDivRoundup(m, m_step); + dim[2] = MlasDivRoundup(co, n_step); + + for (size_t g = 0; g < groups; ++g) { + + auto result{out}; + //do we require a post matmul transpose ? 
+ //output is m x n or image_data x co or hw x co + //MLAS require it as n x m (or co x hw), transpose required + if (co > 1) { + //intermediate buffer required, pre-transpose + //Note: because we are calling MlasTranspose() need to ensure we use a MLAS aligned buffer + result = tmp_mlas_aligned; + } + + auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, in, ThreadPool); + auto rhs = RhsPackWeightsBiasSme(co, ci, kh, kw, dilationh, dilationw, weights, bias, ThreadPool); + + + MlasTrySimpleParallel(ThreadPool, + static_cast(dim[0]*dim[1]*dim[2]), + [&](ptrdiff_t tid) + { + //compute B,M,N index from iteration index + //ptrdiff_t BIdx = tid / (dim[1] * dim[2]); + ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2]; + ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2]; + + // Get rhs tile, B + const size_t rhs_packed_offset = + kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx*n_step, + d_kh*d_kw,ci); + + auto BTile = reinterpret_cast( + reinterpret_cast(rhs.get()) + rhs_packed_offset + ); + + // Get lhs tile, A + const size_t lhs_packed_offset = + kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx*m_step, + d_kh*d_kw,ci); + + auto ATile = reinterpret_cast( + reinterpret_cast(lhs.get()) + lhs_packed_offset + ); + + auto TileSizeM = (MIdx + 1) * m_step > m ? (m - MIdx * m_step) : m_step; + auto TileSizeN = (NIdx + 1) * n_step > co ? 
(co - NIdx * n_step) : n_step; + + // Get result tile, C + auto CTile = &reinterpret_cast(result)[ + MIdx * m_step * co * sizeof(float) + + NIdx * n_step * sizeof(float)]; + + kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa( + TileSizeM, TileSizeN, d_kh*d_kw, ci, ATile, BTile, CTile, co * sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + }); + + if (result == tmp_mlas_aligned) { + //Note: this could be absorbed into post conv activation + MlasTranspose(tmp_mlas_aligned, out, m, co, ThreadPool); + } + + in += ci * ih * iw; + out += m * co; + weights += co * ci * kh * kw; + if(bias){ + bias += co; + } + } +} + +bool MLASCALL +ArmKleidiAI::MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + MLAS_THREADPOOL* ThreadPool) +{ + //Check dimensions before accessing + if (Dimensions < 2) { + return false; + } + + Parameters->Activation = Activation; + Parameters->Dimensions = Dimensions; + Parameters->BatchCount = BatchCount; + Parameters->GroupCount = GroupCount; + Parameters->InputChannels = InputChannels; + Parameters->FilterCount = FilterCount; + Parameters->Beta = Beta; + + size_t InputSize = 1; + size_t OutputSize = 1; + size_t K = InputChannels; + + for (size_t dim = 0; dim < Dimensions; dim++) { + + Parameters->InputShape[dim] = size_t(InputShape[dim]); + Parameters->OutputShape[dim] = size_t(OutputShape[dim]); + Parameters->KernelShape[dim] = size_t(KernelShape[dim]); + Parameters->DilationShape[dim] = size_t(DilationShape[dim]); + Parameters->Padding[dim] = size_t(Padding[dim]); + Parameters->Padding[dim + Dimensions] = size_t(Padding[dim + Dimensions]); + 
Parameters->StrideShape[dim] = size_t(StrideShape[dim]); + + InputSize *= Parameters->InputShape[dim]; + OutputSize *= Parameters->OutputShape[dim]; + K *= Parameters->KernelShape[dim]; + } + + Parameters->InputSize = InputSize; + Parameters->OutputSize = OutputSize; + Parameters->K = K; + + Parameters->ThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if(!CheckCapabilitiesSme(Parameters)){ + return false; + } + + //Allocate an aligned buffer for MlasTranspose() + *WorkingBufferSize = ComputeMlasWorkingBufferSize(Parameters->FilterCount, + Parameters->InputShape[0], Parameters->InputShape[1], + Parameters->KernelShape[0], Parameters->KernelShape[1], + Parameters->DilationShape[0], Parameters->DilationShape[1], + Parameters->StrideShape[0], Parameters->StrideShape[1], + Parameters->Padding[0]); + return true; +} + +bool +MLASCALL +ArmKleidiAI::MlasConv( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + const float* Bias, + float* WorkingBuffer, + float* Output, + MLAS_THREADPOOL* ThreadPool + ) +{ + if(!CheckCapabilitiesSme(Parameters)){ + //Fallback to Default Mlas + return false; + }; + ConvolveSme(Parameters->FilterCount, Parameters->InputChannels, // channel out, in + Parameters->InputShape[0], Parameters->InputShape[1], // image dimensions + Parameters->KernelShape[0], Parameters->KernelShape[1], // kernel dimensions + Parameters->StrideShape[0], Parameters->StrideShape[1], // kernel stride dimensions + Parameters->DilationShape[0], Parameters->DilationShape[1], // kernel dilation + Parameters->Padding[0], // image padding + Parameters->GroupCount, // filter groups + Filter, Bias, Input, Output, WorkingBuffer, ThreadPool); + + MlasActivation(Parameters->Activation, Output, nullptr, Parameters->FilterCount, Parameters->OutputSize, + Parameters->OutputSize); + return true; +} diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h new file mode 100644 
index 0000000000000..11fd78c261834 --- /dev/null +++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h @@ -0,0 +1,114 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "mlasi.h" + +// Fix to ensure compatibility with MSVC build +#if defined(_MSC_VER) + #define RESTRICT __restrict +#else + #define RESTRICT __restrict__ +#endif +namespace ArmKleidiAI { +// +// Buffer packing routines. +// + +size_t +MLASCALL +MlasGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K + ); + +bool +MLASCALL +MlasGemmPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB + ); + +bool +MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool + ); + +size_t +MLASCALL +MlasDynamicQgemmPackBSize( + size_t N, + size_t K +); + +void +MLASCALL +MlasDynamicQgemmPackB( + size_t N, + size_t K, + const int8_t* B, + const float* Scales, + const float* Bias, + void* PackedB +); + +//pack symmetric quantized B and dynamic quantized A +void +MLASCALL +MlasDynamicQGemmBatch( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool + ); + +bool +MLASCALL +MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + MLAS_THREADPOOL* ThreadPool); + +bool +MLASCALL +MlasConv( + const 
MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + const float* Bias, + float* WorkingBuffer, + float* Output, + MLAS_THREADPOOL* ThreadPool + ); +} diff --git a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp new file mode 100644 index 0000000000000..fb38f2cef9bf6 --- /dev/null +++ b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp @@ -0,0 +1,116 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include + +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h" + +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h" + +#include "mlasi_kleidiai.h" + +//Matmul with float output of dynamic quantized A and symmetric quantized B. 
+ +size_t +MLASCALL +ArmKleidiAI::MlasDynamicQgemmPackBSize( + size_t N, + size_t K +) { + //Default to sme2_mopa but this may not awalys be the most optimal kernel variant to use + auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + + //regardless of kernel variant use neon packing variant + return kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr); +} + +void +MLASCALL +ArmKleidiAI::MlasDynamicQgemmPackB( + size_t N, + size_t K, + const int8_t* B, + const float* Scales, + const float* Bias, + void* PackedB +) { + // Default to sme2_mopa but this may not awalys be the most optimal kernel variant to use + auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + + // y - float output + // scale_factor_lhs - lhs scaling factor + // scale_factor_rhs - rhs scaling factor + // lhs_q - lhs quantized (asymmetric, so has zero point) + // rhs_q - rhs quantized (symmetric so no zero point) + // lhs_zp - lhs zero point + // y = (1/(scale_factor_lhs * scale_factor_rhs) * sum( (lhs_q + lhs_zp)*rhs_q )) + bias + + // rhs packing requires lhs_zp because it will perform lhs_zp*rhs_q during rhs packing + // because lhs quantization is hidden from us, by lhs quant packing, we don't have a value for lhs_zp it is + // lhs dynamic quantization + + kai_rhs_pack_qsi8cx_params params{ + 1, // lhs_zp - set to 1 so it becomes sum((lhs_q + 1)*rhs_q )), + // the actual lhs_zp is applied during the matmul + 1.f // it is not used + }; + + //regardless of kernel variant use neon packing variant + kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(1, N, K, nr, kr, sr, B, + // N 
bias values + Bias, + // N scale values + Scales, PackedB, 0, ¶ms); +} + +void +MLASCALL +ArmKleidiAI::MlasDynamicQGemmBatch( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool +) { + for (auto b = BatchN; b > 0; --b,++DataParams) { + auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + + + //TODO enable multi-threading for lhs packing and matmul + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + //Dynamic Quantize A - lhs + auto lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr); + std::byte* lhs = nullptr; + std::unique_ptr fallback; + + if (DataParams->Workspace && DataParams->WorkspaceSize >= lhs_size) { + lhs = static_cast(DataParams->Workspace); + } else { + fallback = std::make_unique(lhs_size); + lhs = fallback.get(); + } + + kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams->A, + Shape.K*sizeof(float), lhs); + + kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa( + Shape.M, Shape.N, Shape.K, lhs, DataParams->PackedB, + DataParams->C, + Shape.N * sizeof(float), + sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + } +} diff --git a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp new file mode 100644 index 0000000000000..caa445b71e2a5 --- /dev/null +++ b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp @@ -0,0 +1,348 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h" +#include 
"kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h" +#include "mlasi_kleidiai.h" + +size_t +MLASCALL +ArmKleidiAI::MlasGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K +) +/*++ + +Routine Description: + + This routine computes the length in bytes for the packed matrix B buffer. + +Arguments: + + TransA - Supplies the transpose operation on A matrix + + TransB - Supplies the transpose operation on B matrix + + N - Supplies the number of columns of matrix B. + + K - Supplies the number of rows of matrix B. + +Return Value: + + Returns the size in bytes for the packed matrix B buffer. + +--*/ +{ + if (TransA != CblasNoTrans || N == 0 || K == 0) { + return 0; + } + // + // Compute the number of bytes required to hold the packed buffer. + // + size_t bytes = 0; + + if (TransA == CblasNoTrans) { + switch (TransB) { + case CblasNoTrans: + bytes = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(N, K); + break; + case CblasTrans: + bytes = kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(N, K); + break; + default: + return 0; + } + } else { + return 0; + } + + return bytes; +} + +bool +MLASCALL +ArmKleidiAI::MlasGemmPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB +) +/*++ + +Routine Description: + + This routine packs the contents of matrix B to the destination buffer. 
The + destination buffer should be sized based on MlasGemmPackBSize(). For best + performance, the destination buffer should be aligned to the value returned + from MlasGetPreferredBufferAlignment(). + +Arguments: + + TransA - Supplies the transpose operation for matrix A. + + TransB - Supplies the transpose operation for matrix B. + + N - Supplies the number of columns of matrix B. + + K - Supplies the number of rows of matrix B. + + B - Supplies the address of matrix B. + + ldb - Supplies the first dimension of matrix B. + + PackedB - Supplies the address of packed matrix B. + +Return Value: + + None. + +--*/ +{ + if (N == 0 || K == 0) { + return false; + } + + if (TransA == CblasNoTrans) { + const size_t nr = kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + const size_t kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + const size_t sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + + // pass zeroed bias values + const std::vector bias(N); + + switch (TransB) { + case CblasNoTrans: + kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, bias.data(), nullptr, PackedB, 0, nullptr); + break; + case CblasTrans: + kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, bias.data(), nullptr, PackedB, 0, nullptr); + break; + default: + return false; + } + return true; + } + else{ + return false; + } +} + +bool +MLASCALL +ArmKleidiAI::MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool +) +{ + if(TransA == CblasTrans) + { + return false; + } + if (TransA == CblasNoTrans && K == 0) { + if (Data->beta != 1.0f) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + Data->C[i * Data->ldc + j] *= Data->beta; + } + } + } + } + if (Data->beta == 0.0f){ + std::fill_n(Data->C, 
M * Data->ldc, 0.0f); + } + //Fallback in the case of unsupported cases + if (M == 0 || N == 0 || K == 0 || + TransA != CblasNoTrans || + (TransB != CblasNoTrans && !Data[0].BIsPacked)) + { + return false; + } + + if (TransA == CblasNoTrans) { + const size_t mr = kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + const size_t kr = kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + const size_t sr = kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + + auto m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + auto n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); + + if (M < m_step || N < n_step) { + if (GetMlasPlatform().MlasGemmBatchOverride != ArmKleidiAI::MlasGemmBatch){ + //Fallback to MLAS + return false; + } + } + + std::vector KaiPackedData; + KaiPackedData.resize(BatchSize); + + size_t LhsPackedStride = 0; + std::byte* LhsPackedData = nullptr; + + LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr); + auto LhsPacked = std::make_unique(LhsPackedStride * BatchSize); + LhsPackedData = LhsPacked.get(); + + std::unique_ptr RhsPacked{nullptr}; + + // It is assumed all B batches require packing or not + if (Data[0].BIsPacked) { + // We have already decided the matmul variant we are using, before having values for M,N,K + MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) { + std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]); + + kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr); + + KaiPackedData[batch_idx].A = reinterpret_cast(LhsPackedPtr); + KaiPackedData[batch_idx].B = Data[batch_idx].B; + }); + } else { + // Multithread pack lhs and rhs + size_t RhsPackedStride = 0; + std::byte* RhsPackedData = nullptr; + + RhsPackedStride = ArmKleidiAI::MlasGemmPackBSize(TransA, TransB, N, K); + RhsPacked = 
std::make_unique(RhsPackedStride * BatchSize); + RhsPackedData = RhsPacked.get(); + + MlasTrySimpleParallel(ThreadPool, BatchSize * 2, [&](ptrdiff_t batch_idx) { + // lhs odd, rhs even + if (batch_idx & 0x1) { + batch_idx >>= 1; + + std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]); + + kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr); + + KaiPackedData[batch_idx].A = reinterpret_cast(LhsPackedPtr); + } else { + batch_idx >>= 1; + + std::byte* RhsPackedPtr = &(RhsPackedData[RhsPackedStride * batch_idx]); + + ArmKleidiAI::MlasGemmPackB(TransA, TransB, N, K, reinterpret_cast(Data[batch_idx].B), Data[batch_idx].ldb, RhsPackedPtr); + + KaiPackedData[batch_idx].B = reinterpret_cast(RhsPackedPtr); + } + }); + } + + // tile iteration dimensions + std::array dim; + dim[0] = BatchSize; // B + dim[1] = MlasDivRoundup(M, m_step); // M + dim[2] = MlasDivRoundup(N, n_step); // N + + // Minimize the kernel call count for the number of available threads + auto RequiredTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]); + + // scale required tiles over available tile processors + dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]); + dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]); + + // compute new step sizes + m_step *= MlasDivRoundup(MlasDivRoundup(M, dim[1]), m_step); + n_step *= MlasDivRoundup(MlasDivRoundup(N, dim[2]), n_step); + + // update tile iterations + dim[1] = MlasDivRoundup(M, m_step); + dim[2] = MlasDivRoundup(N, n_step); + + MlasTrySimpleParallel(ThreadPool, static_cast(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) { + // compute B,M,N index from iteration index + ptrdiff_t BIdx = tid / (dim[1] * dim[2]); + ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2]; + ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2]; + + // Get rhs tile, B + const size_t rhs_packed_offset = + 
kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(NIdx * n_step, K); + + auto BTile = reinterpret_cast( + reinterpret_cast(KaiPackedData[BIdx].B) + rhs_packed_offset + ); + + // Get lhs tile, A + const size_t lhs_packed_offset = + kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(MIdx * m_step, K); + + auto ATile = reinterpret_cast( + reinterpret_cast(KaiPackedData[BIdx].A) + lhs_packed_offset + ); + + auto TileSizeM = (MIdx + 1) * m_step > M ? (M - MIdx * m_step) : m_step; + auto TileSizeN = (NIdx + 1) * n_step > N ? (N - NIdx * n_step) : n_step; + + // Get result tile, C + auto CTile = reinterpret_cast( + reinterpret_cast(Data[BIdx].C) + + MIdx * m_step * Data[BIdx].ldc * sizeof(float) + + NIdx * n_step * sizeof(float) + ); + // Allocate temporary buffer for raw A*B result + std::vector OutputTile(TileSizeM * TileSizeN, 0.0f); + float* temp_tile = OutputTile.data(); + + + kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( + TileSizeM, + TileSizeN, + K, + ATile, BTile, temp_tile, + TileSizeN * sizeof(float), sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + + // Final output tile pointer + float* dst_tile = reinterpret_cast(CTile); + + // quick copy of data in cases where we are not scaling or accumulating anything + // with bounds checking on tile sizing to ensure the data fits in the memory block + bool can_memcpy = ( + Data[BIdx].alpha == 1.0f && + Data[BIdx].beta == 0.0f && + Data[BIdx].ldc == TileSizeN && + MIdx * m_step + TileSizeM <= M && + NIdx * n_step + TileSizeN <= N && + TileSizeM != 0 && + TileSizeN != 0); + + if (can_memcpy) { + std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float)); + }else { + // apply alpha scaling and beta to output files + for (size_t i = 0; i < TileSizeM; ++i) { + for (size_t j = 0; j < TileSizeN; ++j) { + const size_t idx = i * TileSizeN + j; + const size_t dst_idx = i * Data[BIdx].ldc + j; + + float ab = 
temp_tile[idx]; + float c_orig = dst_tile[dst_idx]; + + dst_tile[dst_idx] = Data[BIdx].alpha * ab + Data[BIdx].beta * c_orig; + } + } + } + }); + } + return true; +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 0879d1b0ba510..a099bcf8438fe 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -35,7 +35,7 @@ Module Name: #endif #endif // MLAS_NO_EXCEPTION -#include "mlas.h" +#include "core/mlas/inc/mlas.h" #if defined(_WIN32) #ifndef WIN32_LEAN_AND_MEAN @@ -118,9 +118,8 @@ Module Name: #ifdef MLAS_NO_EXCEPTION -MLAS_FORCEINLINE -void -MlasPrintFinalMessage(const std::string& msg) +MLAS_FORCEINLINE void + MlasPrintFinalMessage(const std::string& msg) { #if defined(__ANDROID__) __android_log_print(ANDROID_LOG_ERROR, "mlas", "%s", msg.c_str()); @@ -134,6 +133,7 @@ MlasPrintFinalMessage(const std::string& msg) #endif } + #define MLAS_THROW_EX(ex, what) \ do { \ std::string msg = #ex; \ @@ -781,6 +781,119 @@ struct MLAS_QUANT_KERNEL size_t KernelSize ); }; +typedef +void +(MLASCALL MLAS_CONV_FLOAT_FN)( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + const float* Bias, + float* WorkingBuffer, + float* Output, + MLAS_THREADPOOL* ThreadPool + ); +typedef +bool +(MLASCALL MLAS_CONV_FLOAT_OVERRIDE)( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + const float* Bias, + float* WorkingBuffer, + float* Output, + MLAS_THREADPOOL* ThreadPool + ); +// TODO: Investigate if overridden typedefs can be removed +typedef +void +(MLASCALL MLAS_CONV_PREPARE_FLOAT_FN)( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* 
WorkingBufferSize, + float Beta, + MLAS_THREADPOOL* ThreadPool + ); +typedef +bool +(MLASCALL MLAS_CONV_PREPARE_FLOAT_OVERRIDE)( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + MLAS_THREADPOOL* ThreadPool + ); + +typedef void (MLASCALL MLAS_GEMM_BATCH)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool); + +typedef bool (MLASCALL MLAS_GEMM_BATCH_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool); + +typedef size_t (MLASCALL MLAS_GEMM_PACK_B_SIZE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K); + +typedef size_t (MLASCALL MLAS_GEMM_PACK_B_SIZE_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K); + +typedef void (MLASCALL MLAS_GEMM_PACK_B_KERNEL)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB); + +typedef bool (MLASCALL MLAS_GEMM_PACK_B_KERNEL_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB); extern "C" { @@ -1184,6 +1297,12 @@ struct MLAS_PLATFORM { // TODO: move to cpuinfo bool Avx2Supported_ = false; bool Avx512Supported_ = false; + // Mlas overrides initialisation + MLAS_GEMM_BATCH_OVERRIDE* MlasGemmBatchOverride = nullptr; + MLAS_GEMM_PACK_B_SIZE_OVERRIDE* MlasGemmPackBSizeOverride = nullptr; + MLAS_GEMM_PACK_B_KERNEL_OVERRIDE* 
MlasGemmPackBOverride = nullptr; + MLAS_CONV_PREPARE_FLOAT_OVERRIDE* MlasConvPrepareOverride = nullptr; + MLAS_CONV_FLOAT_OVERRIDE* MlasConvOverride = nullptr; #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 45bba5363d4f2..3256dadb856d3 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -17,6 +17,10 @@ Module Name: #include "mlasi.h" +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) +#include "kleidiai/mlasi_kleidiai.h" +#endif + #include #include @@ -579,6 +583,15 @@ Return Value: } this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions); +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){ + this->MlasGemmBatchOverride = ArmKleidiAI::MlasGemmBatch; + this->MlasGemmPackBSizeOverride = ArmKleidiAI::MlasGemmPackBSize; + this->MlasGemmPackBOverride = ArmKleidiAI::MlasGemmPackB; + this->MlasConvPrepareOverride = ArmKleidiAI::MlasConvPrepare; + this->MlasConvOverride = ArmKleidiAI::MlasConv; + } +#endif #if defined(__linux__) // diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index f5b33d2a9ad34..4e9a0e27099dc 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -14,10 +14,16 @@ Module Name: operation (QGEMM). --*/ - -#include "mlasi.h" +#include +#include "core/mlas/lib/mlasi.h" #include "qgemm.h" +// TODO: When overrides are implemented, remove this +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) +#include "kleidiai/mlasi_kleidiai.h" +#endif + + // // Define the parameters to execute segments of a QGEMM operation on worker // threads. 
@@ -195,6 +201,26 @@ MlasGemmBatch( }); } +void +MLASCALL +MlasDynamicQGemmBatch ( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool +) { +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + //No fallback and putting in guards + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){ + ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool); + } +#endif + + MLAS_UNREFERENCED_PARAMETER(Shape); + MLAS_UNREFERENCED_PARAMETER(DataParams); + MLAS_UNREFERENCED_PARAMETER(BatchN); + MLAS_UNREFERENCED_PARAMETER(ThreadPool); +} int32_t MlasSymmQgemmGetKernelOutputCnt() @@ -293,10 +319,35 @@ MlasSymmQgemmBatch( }); } + + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif +size_t +MLASCALL +MlasDynamicQgemmPackBSize( + size_t N, + size_t K +) +{ + size_t bytes = 0; +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + //No fallback available + //TODO: Insert Override + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override + bytes = ArmKleidiAI::MlasDynamicQgemmPackBSize(N, K); + } +#endif + + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + + return bytes; +} + + size_t MLASCALL MlasGemmPackBSize( @@ -354,10 +405,38 @@ Return Value: const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + //If this gemm B argument is used in a dynamically quantization gemm operation we can optimize for + //this use case. Concat both packed representations for later decision. 
+ return AlignedBytesRequired + MlasDynamicQgemmPackBSize(N, K); +} - return AlignedBytesRequired; +void +MLASCALL +MlasDynamicQgemmPackB( + size_t N, + size_t K, + const int8_t* B, + const float* Scales, + const float* Bias, + void* PackedB +) +{ +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + //No fallback + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override + ArmKleidiAI::MlasDynamicQgemmPackB(N, K, B, Scales, Bias, PackedB); + } +#endif + + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(B); + MLAS_UNREFERENCED_PARAMETER(Scales); + MLAS_UNREFERENCED_PARAMETER(Bias); + MLAS_UNREFERENCED_PARAMETER(PackedB); } + void MLASCALL MlasGemmPackB( @@ -400,7 +479,6 @@ Return Value: // // Retrieve the packing parameters. // - const auto* GemmQuantDispatch = MlasGemmQuantGetDispatch(AIsSigned, BIsSigned); size_t PackedK = GemmQuantDispatch->PackedK; @@ -515,7 +593,6 @@ MlasSymmQgemmPackBSize( #pragma warning(pop) #endif - void MLASCALL MlasSymmQgemmPackB( diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 616622a8c1f53..65c1ccbadad38 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -1572,7 +1572,13 @@ MlasGemmBatch( MLAS_THREADPOOL* ThreadPool ) { - + // Override + if(GetMlasPlatform().MlasGemmBatchOverride != nullptr && + // TODO: Remove once KAI supports transposing for A + TransA != CBLAS_TRANSPOSE::CblasTrans && + GetMlasPlatform().MlasGemmBatchOverride(TransA, TransB, M, N, K, Data, BatchSize, ThreadPool)){ + return; + } // // Compute the number of target threads given the complexity of the SGEMM // operation. Small requests should run using the single threaded path. 
@@ -1637,6 +1643,8 @@ MlasGemmBatch( size_t MLASCALL MlasGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, size_t N, size_t K ) @@ -1661,6 +1669,22 @@ Return Value: // // Compute the number of bytes required to hold the packed buffer. // + // KleidiAI or other override + #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + if (GetMlasPlatform().MlasGemmPackBSizeOverride != nullptr && + // TODO: Remove once KAI supports transposing for A + TransA != CBLAS_TRANSPOSE::CblasTrans) { + size_t bytes_required; + //TODO pass status by reference to indicate success/fail + bytes_required = GetMlasPlatform().MlasGemmPackBSizeOverride(TransA, TransB, N, K); + if (bytes_required != 0){// If ArmKleidiAI::MlasGemmPackBSize ran to completion + return bytes_required; + } + } + #endif + MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + const size_t AlignedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1); @@ -1676,6 +1700,7 @@ Return Value: void MLASCALL MlasGemmPackB( + CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, size_t N, size_t K, @@ -1712,6 +1737,17 @@ Return Value: --*/ { +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + if (GetMlasPlatform().MlasGemmPackBOverride != nullptr && + // TODO: Remove once KAI supports transposing for A + TransA != CBLAS_TRANSPOSE::CblasTrans && + GetMlasPlatform().MlasGemmPackBOverride(TransA, TransB, N, K, B, ldb, PackedB)){ + return; + } +#endif + MLAS_UNREFERENCED_PARAMETER(TransA); + + const size_t AlignedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1); diff --git a/onnxruntime/core/optimizer/matmul_integer_to_float.cc b/onnxruntime/core/optimizer/matmul_integer_to_float.cc index b619efb2f751e..7abd375cda896 100644 --- a/onnxruntime/core/optimizer/matmul_integer_to_float.cc +++ b/onnxruntime/core/optimizer/matmul_integer_to_float.cc @@ -170,13 +170,18 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& 
modified, int g // Find bias node Node* p_add_node = nullptr; + int idx = 0; if (optimizer_utils::CheckOutputEdges(graph, mul_node, 1)) { const Node* tmp_add_node = graph_utils::FirstChildByType(mul_node, "Add"); if (nullptr != tmp_add_node) { - const NodeArg& tmp_add_node_B = *(tmp_add_node->InputDefs()[1]); - if (graph_utils::IsConstantInitializer(graph, tmp_add_node_B.Name(), true) && - CheckBiasShape(tmp_add_node_B.Shape())) { - p_add_node = graph.GetNode(tmp_add_node->Index()); + // check both "inputs" to find bias, caters for edge case where bias index in InputDefs is not what is expected + for (idx = 0; idx < 2; ++idx) { + const NodeArg& candidate = *(tmp_add_node->InputDefs()[idx]); + if (graph_utils::IsConstantInitializer(graph, candidate.Name(), true) && + CheckBiasShape(candidate.Shape())) { + p_add_node = graph.GetNode(tmp_add_node->Index()); + break; + } } } } @@ -203,7 +208,7 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g } if (p_add_node != nullptr) { - input_defs.push_back(p_add_node->MutableInputDefs()[1]); + input_defs.push_back(p_add_node->MutableInputDefs()[idx]); } std::string op_type = "MatMulIntegerToFloat"; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index a691faaffd2a0..4bcf71335d15e 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1488,8 +1488,7 @@ AllocatorPtr CANNExecutionProvider::CreateCannAllocator(OrtDevice::DeviceId devi -1, -1, -1L)}, - true, - false); + true); return CreateAllocator(default_memory_info); } diff --git a/onnxruntime/core/providers/cann/cann_stream_handle.cc b/onnxruntime/core/providers/cann/cann_stream_handle.cc index 041fc54a725a9..cdb727f263480 100644 --- a/onnxruntime/core/providers/cann/cann_stream_handle.cc +++ b/onnxruntime/core/providers/cann/cann_stream_handle.cc @@ -18,7 +18,7 @@ struct 
CannNotification : public synchronize::Notification { } void Activate() override { - CANN_CALL_THROW(aclrtRecordEvent(event_, static_cast(stream_.GetHandle()))); + CANN_CALL_THROW(aclrtRecordEvent(event_, static_cast(GetStream().GetHandle()))); } void wait_on_device(Stream& device_stream) { diff --git a/onnxruntime/core/providers/cann/cann_stream_handle.h b/onnxruntime/core/providers/cann/cann_stream_handle.h index f20eafb2b4b35..e7a352298b2bd 100644 --- a/onnxruntime/core/providers/cann/cann_stream_handle.h +++ b/onnxruntime/core/providers/cann/cann_stream_handle.h @@ -24,8 +24,6 @@ struct CannStream : Stream { void Flush() override; bool own_stream_{true}; - - WaitNotificationFn GetWaitNotificationFn() const override { return WaitCannNotificationOnDevice; } }; void RegisterCannStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 6a89fc6234f0f..5eac0523d953a 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1290,6 +1290,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, #endif // Opset 23 +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, float, Attention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, MLFloat16, Attention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Cast); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, ConstantOfShape); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, int32_t, DequantizeLinear); @@ -3254,6 +3256,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 23 + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, 
BuildKernelCreateInfo, BuildKernelCreateInfo()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("U", BuildKernelDefConstraints()), \ + Attention); + +REGISTER_ONNX_KERNEL_TYPED(float) +REGISTER_ONNX_KERNEL_TYPED(MLFloat16) + +template +void make_copy(T* mask_data, const U* mask_index, size_t size); + +template <> +void make_copy(float* mask_data, const float* mask_index, size_t size) { + memcpy(mask_data, mask_index, size * sizeof(float)); +} + +template <> +void make_copy(MLFloat16* mask_data, const MLFloat16* mask_index, size_t size) { + memcpy(mask_data, mask_index, size * sizeof(MLFloat16)); +} + +template <> +void make_copy(float* mask_data, const bool* mask_index, size_t size) { + for (size_t i = 0; i < size; ++i) { + mask_data[i] = mask_index[i] ? 0.0f : std::numeric_limits::lowest(); + } +} + +template <> +void make_copy(MLFloat16* mask_data, const bool* mask_index, size_t size) { + for (size_t i = 0; i < size; ++i) { + mask_data[i] = mask_index[i] ? MLFloat16(0.f) : std::numeric_limits::lowest(); + } +} + +template +inline void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp, AllocatorPtr) { + MlasComputeSoftmax(score, score, N, D, false, false, 0.0f, tp); +} + +template <> +inline void ComputeAttentionSoftmaxInplace(MLFloat16* score, int N, int D, ThreadPool* tp, AllocatorPtr allocator) { + ORT_ENFORCE(tp == nullptr, "No parallelized version of softmax for float16."); + // Mlas Lacks kernels for fp16 softmax, we convert into float32 and call the float32 version. 
+ void* allocated_ptr = allocator->Alloc(static_cast(N * D * sizeof(float))); + BufferUniquePtr float_buffer(allocated_ptr, BufferDeleter(allocator)); + float* ptr = reinterpret_cast(allocated_ptr); + MlasConvertHalfToFloatBuffer(score, ptr, N * D); + MlasComputeSoftmax(ptr, ptr, N, D, false, false, 0.0f, tp); + MlasConvertFloatToHalfBuffer(ptr, score, N * D); +} + +template +inline void ComputeAttentionSoftcapInplace(T* scores, int sequence_length, T softcap) { + MlasComputeSoftcap(scores, scores, sequence_length, softcap); +} + +template <> +inline void ComputeAttentionSoftcapInplace(MLFloat16* scores, int sequence_length, MLFloat16 softcap) { + // Mlas Lacks kernels for fp16 softcap. The code is similar to the softcap implementation in mlas. + float x; + float cap = softcap.ToFloat(); + for (size_t i = 0; i < static_cast(sequence_length); i++) { + x = std::tanh(scores[i].ToFloat() / cap) * cap; + scores[i] = MLFloat16(x); + } +} + +template +Attention::Attention(const OpKernelInfo& info) : AttentionBase(info) { + is_causal_ = static_cast(info.GetAttrOrDefault("is_causal", 0)) == 1; + // kv_num_heads, q_num_head are mandatory for 3D inputs but not used for 4D inputs. + // The dimension is not yet known. If not specified, the inputs is assumed to be 4D. + kv_num_heads_ = static_cast(info.GetAttrOrDefault("kv_num_heads", 0)); + q_num_heads_ = static_cast(info.GetAttrOrDefault("q_num_heads", 0)); + int mode = static_cast(info.GetAttrOrDefault("qk_matmul_output_mode", 0)); + qk_matmul_output_mode_ = info.node().OutputDefs().size() >= 4 && info.node().OutputDefs()[3]->Exists() + ? 
static_cast(mode) + : QKMatMulOutputMode::kNone; + ORT_ENFORCE(qk_matmul_output_mode_ == QKMatMulOutputMode::kNone || + qk_matmul_output_mode_ == QKMatMulOutputMode::kQK || + qk_matmul_output_mode_ == QKMatMulOutputMode::kQKMask || + qk_matmul_output_mode_ == QKMatMulOutputMode::kQKSoftCap || + qk_matmul_output_mode_ == QKMatMulOutputMode::kQKSoftMax, + "qk_matmul_output_mode must be 0, 1, 2, or 3."); + // The default scale depends on the input dimensions. It is set to nan to indicate that it should be computed. + scale_ = info.GetAttrOrDefault("scale", std::numeric_limits::quiet_NaN()); + softcap_ = info.GetAttrOrDefault("softcap", 0.0f); + softmax_precision_ = static_cast(info.GetAttrOrDefault("softmax_precision", 0)); + ORT_ENFORCE(scale_ > 0 || std::isnan(scale_), "scale must be greater than 0 if specified"); +} + +template +Status Attention::Compute(OpKernelContext* context) const { + const Tensor* Q = context->Input(0); + const Tensor* K = context->Input(1); + const Tensor* V = context->Input(2); + const Tensor* attn_mask = context->Input(3); + const Tensor* past_key = context->Input(4); + const Tensor* past_value = context->Input(5); + + AttentionParameters parameters; + std::vector y_shape; + std::vector present_key_shape; + std::vector present_value_shape; + std::vector output_qk_shape; + + ORT_ENFORCE(attention_helper::ComputeOutputShapeForAttention( + Q, + K, + V, + attn_mask, + past_key, + past_value, + is_causal_, + softcap_, + softmax_precision_, + qk_matmul_output_mode_, + kv_num_heads_, + q_num_heads_, + scale_, + parameters, + y_shape, + present_key_shape, + present_value_shape, + output_qk_shape) + .IsOK(), + "Output shapes for Attention could not be computed."); + + Tensor* Y = context->Output(0, y_shape); + Tensor* present_key = context->Output(1, present_key_shape); + Tensor* present_value = context->Output(2, present_value_shape); + Tensor* output_qk = parameters.qk_matmul_output_mode == QKMatMulOutputMode::kNone + ? 
nullptr + : context->Output(3, output_qk_shape); + return this->ApplyAttention(context, + Q->Data(), // Q + K->Data(), // K + V->Data(), // V + attn_mask, // const Tensor* mask_index, // mask, nullptr if no mask + past_key, // past K input tensor (if not using past state) + past_value, // past V input tensor (if not using past state) + Y, // first output + present_key, // present K output tensor (if separating present KV) + present_value, // present V output tensor (if separating present KV) + output_qk, // Q*K output tensor (if returning Q*K value) + parameters // attention parameters + ); +} + +template +void AttentionBase::ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + const Tensor* mask_index, // mask + const AttentionParameters& parameters, // attention parameters + const T* past_key, // past key only (if not using past state) + T* present_key, // present key only (if not using present state) + T* output_qk, // Q*K output + ThreadPool* tp, + AllocatorPtr allocator) const { + // The case past_key != nullptr and present_key == nullptr is not supported. + // We use the fact present_key is requested to avoid any extra allocation. + // However, if present_key is not requested, we should avoid allocated more memory than needed but that mean + // allocating one buffer per thread. That's why the implementation is not done. + // The user should define a model with a present_key even if not used if past_key is not null. 
+ ORT_ENFORCE((past_key == nullptr) == (present_key == nullptr), + "The implementation only supports past_key and present_key both null or both not null."); + const size_t past_chunk_length = static_cast(parameters.past_sequence_length) * parameters.head_size; // P x H + const size_t q_input_chunk_length = static_cast(parameters.q_sequence_length) * parameters.head_size; // S x H + const size_t k_input_chunk_length = static_cast(parameters.kv_sequence_length) * parameters.head_size; // L x H + const size_t present_chunk_length = past_chunk_length + k_input_chunk_length; // T x H + + TensorOpCost unit_cost; + const ptrdiff_t probs_matrix_size = SafeInt(parameters.q_sequence_length) * + parameters.total_sequence_length; + const ptrdiff_t probs_matrix_bytes = probs_matrix_size * sizeof(T); + unit_cost.compute_cycles = + static_cast(SafeInt(2) * parameters.head_size * probs_matrix_size); + unit_cost.bytes_loaded = static_cast((parameters.q_sequence_length + + parameters.total_sequence_length) * + parameters.head_size * sizeof(T)); + unit_cost.bytes_stored = static_cast(probs_matrix_bytes); + + if (present_key) { + double bytes_to_copy_key = present_chunk_length * static_cast(sizeof(T)); + unit_cost.bytes_loaded += bytes_to_copy_key; + unit_cost.bytes_stored += bytes_to_copy_key; + } + + // Prepare mask + // Merge causal mask with padding mask, and convert values from 0/1 to -inf/0. + int mask_batch_size = static_cast(mask_index == nullptr || mask_index->Shape().NumDimensions() < 4 + ? 1 + : mask_index->Shape().GetDims()[0]); + int mask_num_heads = static_cast(mask_index == nullptr || mask_index->Shape().NumDimensions() < 3 + ? 1 + : (mask_index->Shape().NumDimensions() < 4 + ? mask_index->Shape().GetDims()[0] + : mask_index->Shape().GetDims()[1])); + + T* mask_data = nullptr; + bool delete_mask_data = false; + bool causal = parameters.is_causal && parameters.q_sequence_length > 1; + if (mask_index == nullptr) { + // No mask = null mask. 
+ if (causal) { + size_t mask_data_bytes = SafeInt(parameters.q_sequence_length) * parameters.total_sequence_length * sizeof(T); + void* allocated_ptr = allocator->Alloc(mask_data_bytes); + memset(allocated_ptr, 0, mask_data_bytes); + mask_data = static_cast(allocated_ptr); + for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) { + for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) { + mask_data[s_i * parameters.total_sequence_length + m_i] = std::numeric_limits::lowest(); + } + } + delete_mask_data = true; + } + } else if (mask_index->IsDataType() || causal) { + // We need a copy. + size_t mask_data_bytes = SafeInt(mask_index->Shape().Size()) * sizeof(T); + mask_data = static_cast(allocator->Alloc(mask_data_bytes)); + delete_mask_data = true; + + if (mask_index->IsDataType()) { + // Convert bool mask to 0/1 + make_copy(mask_data, mask_index->Data(), SafeInt(mask_index->Shape().Size())); + } else if (mask_index != nullptr) { + // We make a copy because causal is True. + make_copy(mask_data, mask_index->Data(), SafeInt(mask_index->Shape().Size())); + } + if (causal) { + // This loop could be parallelized. + // According to the specifications, this configuration is not supported + // as is_causal=1 or mask is not None (exclusive or). + int n_iter = mask_batch_size * mask_num_heads; + for (int i = 0; i < n_iter; ++i) { + for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) { + for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) { + mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = std::numeric_limits::lowest(); + } + } + } + } + } else { + // Nothing to do, no necessary copy. 
+ mask_data = const_cast(mask_index->Data()); + } + + bool transposed_k = parameters.transpose_output && nullptr == present_key; + if (nullptr != present_key && parameters.kv_num_heads != parameters.q_num_heads) { + // This is not part of the main loop because it is not needed at every iteration and + // we cannot ensure the inner body is executed first before getting used in another iteration. + // parameters.batch_size * parameters.q_num_heads + for (std::ptrdiff_t batch_i = 0; batch_i < parameters.batch_size; ++batch_i) { + for (std::ptrdiff_t head_i = 0; head_i < parameters.kv_num_heads; ++head_i) { + ConcatStateChunk(past_key, K, present_key, + past_chunk_length, k_input_chunk_length, present_chunk_length, + parameters.kv_num_heads, parameters.head_size, batch_i, head_i, + parameters.transpose_output); + } + } + } + + // If present_key is not null, it is already initialized to zero. + // Main loop + // With 3D inputs, both Q and K are transposed with permutations (0, 2, 1, 3). + // To avoid expressing the transposition, we use GemmEx with different values for lda, ldb. + // If past_key is not null, then we need to concatenate it with K, the concatenation is not transposed. + const int loop_len = parameters.batch_size * parameters.q_num_heads; + const float alpha = parameters.scale; + + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + const ptrdiff_t output_offset = SafeInt(i) * probs_matrix_size; + std::ptrdiff_t batch_i = i / parameters.q_num_heads; + std::ptrdiff_t head_i = i % parameters.q_num_heads; + const ptrdiff_t mask_data_offset = probs_matrix_size * + (head_i % mask_num_heads + (batch_i % mask_batch_size) * mask_num_heads); + + T* output = attention_probs + output_offset; + T* out_qk = output_qk == nullptr ? 
nullptr : output_qk + output_offset; + float beta; + + if (mask_data != nullptr && + (out_qk == nullptr || parameters.qk_matmul_output_mode != attention_helper::QKMatMulOutputMode::kQK)) { + // Broadcast mask data: SxT -> SxT + memcpy(output, mask_data + mask_data_offset, probs_matrix_bytes); + beta = 1; + } else { + beta = 0; + } + + // handling GQA + std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_i % parameters.kv_num_heads; + const T* k = K + k_input_chunk_length * ki; + + if (nullptr != present_key) { + if (parameters.kv_num_heads != parameters.q_num_heads) { + // Already done in a loop before this one. + k = present_key + ki * present_chunk_length; + } else { + k = ConcatStateChunk(past_key, K, present_key, + past_chunk_length, k_input_chunk_length, present_chunk_length, + parameters.kv_num_heads, parameters.head_size, batch_i, head_i, + parameters.transpose_output); + } + } + + // Compute Q*K' + AttentionMask + // original transposed each iteration + // A: Q (B x N x) S x H (B x N x) S x H S x H + // B: K' (B x N x) T x H (B x N x) H x T H x T + // C: attention_probs (B x N x) S x T (B x N x) S x T S x T + if constexpr (std::is_same::value) { + if (parameters.transpose_output) { + math::GemmEx(CblasNoTrans, + CblasTrans, + parameters.q_sequence_length, // M + parameters.total_sequence_length, // N + parameters.head_size, // K + alpha, + Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size, + parameters.head_size * parameters.q_num_heads, // lda + transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_i * parameters.head_size : k, + transposed_k ? 
parameters.head_size * parameters.kv_num_heads : parameters.head_size, // ldb + beta, + output, + parameters.total_sequence_length, // ldc + nullptr); + } else { + math::Gemm(CblasNoTrans, + CblasTrans, + parameters.q_sequence_length, // M + parameters.total_sequence_length, // N + parameters.head_size, // K + alpha, + Q + q_input_chunk_length * i, + k, + beta, + output, + nullptr); + } + } else if constexpr (std::is_same::value) { + if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) { + MlasGemm(CblasNoTrans, + CblasTrans, + parameters.q_sequence_length, // M + parameters.total_sequence_length, // N + parameters.head_size, // K + parameters.transpose_output + ? Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size + : Q + q_input_chunk_length * i, + parameters.transpose_output + ? parameters.head_size * parameters.q_num_heads + : static_cast(parameters.head_size), // lda + transposed_k + ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_i * parameters.head_size + : k, + transposed_k + ? parameters.head_size * parameters.kv_num_heads + : static_cast(parameters.head_size), // ldb + output, + static_cast(parameters.past_sequence_length + parameters.kv_sequence_length), // ldc + MLFloat16(alpha).val, MLFloat16(beta).val, nullptr); + } else { + if (parameters.transpose_output) { + math::GemmEx(CblasNoTrans, + CblasTrans, + parameters.q_sequence_length, // M + parameters.total_sequence_length, // N + parameters.head_size, // K + MLFloat16(alpha), + Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size, + parameters.head_size * parameters.q_num_heads, // lda + transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_i * parameters.head_size : k, + transposed_k ? 
parameters.head_size * parameters.kv_num_heads : parameters.head_size, // ldb + MLFloat16(beta), + output, + parameters.total_sequence_length, // ldc + nullptr); + } else { + TensorShape c_shape({parameters.q_sequence_length, parameters.total_sequence_length}); + Gemm_MLFloat16(CblasNoTrans, CblasTrans, + static_cast(parameters.q_sequence_length), // M + static_cast(parameters.total_sequence_length), // N + static_cast(parameters.head_size), // K + MLFloat16(alpha), + Q + q_input_chunk_length * i, + k, + MLFloat16(beta), + output, + &c_shape, + output, + nullptr); + } + } + } else { + ORT_THROW("Unsupported data type for attention Q*K multiplication: ", DataTypeImpl::ToString(DataTypeImpl::GetType())); + } + if (out_qk != nullptr && + (parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQKMask || + parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQK)) { + memcpy(out_qk, output, SafeInt(probs_matrix_size) * sizeof(T)); + if (mask_data != nullptr && parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQK) { + // We need to add the bias we could not add because out_qk was requested without the mask. + // This can be optimized with vectorized add using MlasAddFloat32x4. 
+ MlasEltwiseAdd(output, mask_data + mask_data_offset, output, probs_matrix_size); + } + } + if (parameters.softcap > 0.0f) { + if constexpr (std::is_same::value) { + ComputeAttentionSoftcapInplace(output, static_cast(probs_matrix_size), parameters.softcap); + } else if constexpr (std::is_same::value) { + ComputeAttentionSoftcapInplace(output, static_cast(probs_matrix_size), MLFloat16(parameters.softcap)); + } else { + ORT_THROW("Unsupported data type for ComputeAttentionSoftcapInplace: ", + DataTypeImpl::ToString(DataTypeImpl::GetType())); + } + } + if (out_qk != nullptr && parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQKSoftCap) { + memcpy(out_qk, output, SafeInt(probs_matrix_size) * sizeof(T)); + } + ComputeAttentionSoftmaxInplace(output, parameters.q_sequence_length, parameters.total_sequence_length, nullptr, allocator); + + if (output_qk != nullptr && parameters.qk_matmul_output_mode == attention_helper::QKMatMulOutputMode::kQKSoftMax) { + memcpy(output_qk + output_offset, output, + SafeInt(parameters.q_sequence_length) * parameters.total_sequence_length * sizeof(T)); + } + } + }); + if (delete_mask_data) { + allocator->Free(mask_data); + } +} + +template +T* AttentionBase::ConcatStateChunk(const T* past, + const T* base_chunk, // chunk is K or V, it can be transposed or not + T* present, + size_t past_chunk_length, + size_t input_chunk_length, // chunk length of K or V + size_t present_chunk_length, + size_t num_heads, + size_t head_size, + std::ptrdiff_t batch_i, + std::ptrdiff_t head_i, + bool transposed) const { + std::ptrdiff_t i = batch_i * num_heads + head_i % num_heads; + + T* start = present + i * present_chunk_length; + + T* p = start; + if (nullptr != past) { + const T* src_past = past + i * past_chunk_length; + memcpy(p, src_past, past_chunk_length * sizeof(T)); + p += past_chunk_length; + } + + if (transposed) { + ORT_ENFORCE(head_size > 0 && num_heads > 0 && batch_i >= 0 && head_i >= 0, + "Invalid parameters for 
ConcatStateChunk: head_size=", head_size, ", batch_i=", batch_i, ", head_i=", head_i); + size_t sequence_length = SafeInt(input_chunk_length / head_size); + const T* chunk = base_chunk + head_i * head_size + input_chunk_length * num_heads * batch_i; + for (size_t j = 0; j < sequence_length; ++j) { + memcpy(p, chunk, head_size * sizeof(T)); + p += head_size; + chunk += num_heads * head_size; + } + } else { + const T* chunk = base_chunk + input_chunk_length * i; + memcpy(p, chunk, (present_chunk_length - past_chunk_length) * sizeof(T)); + } + return start; +} + +template +void AttentionBase::ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH_v + const T* attention_probs, // Attention probs with size BxNxSxT + const T* V, // V value with size BxNxLxH_v + int batch_size, // batch size + int sequence_length, // sequence length + int kv_sequence_length, // sequence length of K or V + int past_sequence_length, // sequence length in past state + int total_sequence_length, // total sequence length = past_sequence_length + kv_sequence_length + int v_head_size, // head size of V (H_v) + int num_heads, // number of attention heads + int kv_num_heads, // number of KV heads + const T* past_value, // past value only (if not using past state) + T* present_value, // present value only (if not using present state) + bool transpose_output, // whether to transpose the output (0, 2, 1, 3) + ThreadPool* tp) const { + ORT_ENFORCE((past_value == nullptr) == (present_value == nullptr), + "The implementation only supports past_value and present_value both null or both not null."); + const ptrdiff_t past_chunk_length = SafeInt(past_sequence_length) * v_head_size; // P x H_v + const ptrdiff_t v_input_chunk_length = SafeInt(kv_sequence_length) * v_head_size; // L x H_v + const ptrdiff_t present_chunk_length = past_chunk_length + v_input_chunk_length; // T x H_v + + // The cost of Gemm + TensorOpCost unit_cost; + unit_cost.compute_cycles = + static_cast(SafeInt(2) * 
sequence_length * v_head_size * total_sequence_length); + unit_cost.bytes_loaded = + static_cast(SafeInt(sequence_length + v_head_size) * total_sequence_length * sizeof(T)); + unit_cost.bytes_stored = static_cast(sequence_length * v_head_size * sizeof(T)); + + const size_t bytes_to_copy_trans = SafeInt(v_head_size) * sizeof(T); + double bytes_to_copy_trans_all = static_cast(sequence_length * bytes_to_copy_trans); + unit_cost.bytes_loaded += bytes_to_copy_trans_all; + unit_cost.bytes_stored += bytes_to_copy_trans_all; + + bool transposed_v = transpose_output && nullptr == present_value; + if (nullptr != present_value && kv_num_heads != num_heads) { + // This is not part of the main loop because it is not needed at every iteration and + // we cannot ensure the inner body is executed first before getting used in another iteration. + // parameters.batch_size * parameters.q_num_heads + for (std::ptrdiff_t batch_i = 0; batch_i < batch_size; ++batch_i) { + for (std::ptrdiff_t head_i = 0; head_i < kv_num_heads; ++head_i) { + ConcatStateChunk(past_value, V, present_value, + past_chunk_length, v_input_chunk_length, present_chunk_length, + kv_num_heads, v_head_size, batch_i, head_i, + transpose_output); + } + } + } + + ThreadPool::TryParallelFor( + tp, SafeInt(batch_size) * num_heads, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + // handling GQA + std::ptrdiff_t batch_i = i / num_heads; + std::ptrdiff_t head_i = i % num_heads; + std::ptrdiff_t vi = batch_i * kv_num_heads + head_i % kv_num_heads; + const T* v = V + v_input_chunk_length * vi; + + if (nullptr != present_value) { + if (kv_num_heads != num_heads) { + // Already done in a loop before this one. + v = present_value + vi * present_chunk_length; + } else { + // transposed_v is false here. 
+ v = ConcatStateChunk(past_value, V, present_value, + past_chunk_length, v_input_chunk_length, present_chunk_length, + kv_num_heads, v_head_size, batch_i, head_i, + transpose_output); + } + } + + if (transpose_output) { + // transpose_output is false + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * total_sequence_length * i; + + if constexpr (std::is_same::value) { + // V is transposed but not QK. We use GemmEx with a different value for ldb. + math::GemmEx(CblasNoTrans, + CblasNoTrans, + sequence_length, // M + v_head_size, // N + total_sequence_length, // K + 1.f, // alpha + attention_probs + attention_probs_offset, // QK + total_sequence_length, // lda + transposed_v ? V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V + transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb + 0.f, // beta + output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size), + v_head_size * num_heads, // ldc + nullptr); + } else if constexpr (std::is_same::value) { + // This switch should probably be moved to math_cpu.h. + if (MlasHGemmSupported(CblasNoTrans, CblasNoTrans)) { + MlasGemm(CblasNoTrans, + CblasNoTrans, + sequence_length, // M + v_head_size, // N + total_sequence_length, // K + attention_probs + attention_probs_offset, + total_sequence_length, // lda + transposed_v ? V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, + transposed_v ? static_cast(v_head_size * kv_num_heads) : static_cast(v_head_size), // ldb + output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size), + v_head_size * num_heads, // ldc + MLFloat16(1.f).val, MLFloat16(0.f).val, nullptr); + } else { + math::GemmEx(CblasNoTrans, + CblasNoTrans, + sequence_length, // M + v_head_size, // N + total_sequence_length, // K + MLFloat16(1.f), // alpha + attention_probs + attention_probs_offset, // QK + total_sequence_length, // lda + transposed_v ? 
V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V + transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb + MLFloat16(0.f), // beta + output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size), + v_head_size * num_heads, // ldc + nullptr); + } + } else { + ORT_THROW("Unsupported data type for attention QK*V multiplication: ", + DataTypeImpl::ToString(DataTypeImpl::GetType())); + } + } else { + // transpose_output is false + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * total_sequence_length * i; + ptrdiff_t dest_offset = SafeInt(sequence_length) * v_head_size * i; + T* dest = output + dest_offset; + + if constexpr (std::is_same::value) { + math::MatMul(sequence_length, v_head_size, total_sequence_length, + attention_probs + attention_probs_offset, v, dest, nullptr); + } else if constexpr (std::is_same::value) { + if (MlasHGemmSupported(CblasNoTrans, CblasNoTrans)) { + MlasGemm(CblasNoTrans, + CblasNoTrans, + sequence_length, // M + v_head_size, // N + total_sequence_length, // K + attention_probs + attention_probs_offset, + total_sequence_length, // lda + v, + static_cast(v_head_size), // ldb + dest, + static_cast(v_head_size), // ldc + MLFloat16(1.f).val, MLFloat16(0.f).val, nullptr); + } else { + Gemm_MLFloat16(CblasNoTrans, + CblasNoTrans, + static_cast(sequence_length), // M + static_cast(v_head_size), // N + static_cast(total_sequence_length), // K + MLFloat16(1.f), // alpha + attention_probs + attention_probs_offset, + v, + MLFloat16(0.f), // beta + nullptr, + nullptr, + dest, + nullptr); + } + } else { + ORT_THROW("Unsupported data type for attention QK*V multiplication: ", + DataTypeImpl::ToString(DataTypeImpl::GetType())); + } + } + } + }); +} + +template +Status AttentionBase::ApplyAttention(OpKernelContext* context, + const T* Q, // Q data with shape BxNxSxH + const T* K, // K data with shape BxNxLxH + const T* V, // V value with size BxNxLxH_v + const Tensor* mask_index, // mask 
index. nullptr if no mask or its size is B + const Tensor* past_key, // past K input tensor (if not using past state) + const Tensor* past_value, // past V input tensor (if not using past state) + Tensor* output, // output tensor + Tensor* present_key, // present K output tensor (if separating present KV) + Tensor* present_value, // present V output tensor (if separating present KV) + Tensor* output_qk, // Q*K output tensor (if returning Q*K value) + const AttentionParameters& parameters // attention parameters +) const { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + auto* tp = context->GetOperatorThreadPool(); + + const T* past_key_data = past_key != nullptr ? past_key->Data() : nullptr; + T* present_key_data = present_key != nullptr ? present_key->MutableData() : nullptr; + const T* past_value_data = past_value != nullptr ? past_value->Data() : nullptr; + T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr; + T* output_qk_data = output_qk != nullptr ? output_qk->MutableData() : nullptr; + + // Compute the attention score. 
+ size_t bytes = SafeInt(parameters.batch_size) * parameters.q_num_heads * + parameters.q_sequence_length * parameters.total_sequence_length * sizeof(T); + auto attention_probs = allocator->Alloc(bytes); + BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); + this->ComputeAttentionProbs(static_cast(attention_probs), + Q, + K, + mask_index, + parameters, + past_key_data, + present_key_data, + output_qk_data, + tp, + allocator); + + this->ComputeVxAttentionScore(output->MutableData(), + static_cast(attention_probs), + V, + parameters.batch_size, + parameters.q_sequence_length, + parameters.kv_sequence_length, + parameters.past_sequence_length, + parameters.total_sequence_length, + parameters.v_head_size, + parameters.q_num_heads, + parameters.kv_num_heads, + past_value_data, + present_value_data, + parameters.transpose_output, + tp); + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/llm/attention.h b/onnxruntime/core/providers/cpu/llm/attention.h new file mode 100644 index 0000000000000..78889e48afb29 --- /dev/null +++ b/onnxruntime/core/providers/cpu/llm/attention.h @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/platform/threadpool.h" +#include "core/providers/cpu/llm/attention_helper.h" + +namespace onnxruntime { + +template +class AttentionBase : public OpKernel { + public: + AttentionBase(const OpKernelInfo& info) : OpKernel(info) {} + + Status ApplyAttention(OpKernelContext* context, + const T* Q, // Q data with shape BxNxSxH + const T* K, // K data with shape BxNxLxH + const T* V, // V value with size BxNxLxH_v + const Tensor* mask_index, // mask index. 
nullptr if no mask or its size is B + const Tensor* past_key, // past K input tensor (if not using past state) + const Tensor* past_value, // past V input tensor (if not using past state) + Tensor* output, // output tensor + Tensor* present_key, // present K output tensor (if separating present KV) + Tensor* present_value, // present V output tensor (if separating present KV) + Tensor* output_qk, // Q*K output tensor (if returning Q*K value) + const attention_helper::AttentionParameters& parameters // attention parameters + ) const; + + protected: + void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH_v + const T* attention_probs, // Attention probs with size BxNxSxT + const T* V, // V value with size BxNxLxH_v + int batch_size, // batch size + int sequence_length, // sequence length + int kv_sequence_length, // sequence length of K or V + int past_sequence_length, // sequence length in past state + int total_sequence_length, // total sequence length = past_sequence_length + kv_sequence_length + int v_head_size, // head size of V (H_v) + int num_heads, // number of attention heads + int kv_num_heads, // number of KV heads + const T* past_value, // past value only (if not using past state) + T* present_value, // present value only (if not using present state) + bool transpose_output, // whether to transpose the output from BxNxSxH to BxSxNxH + concurrency::ThreadPool* tp) const; + + void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. 
Its size is BxNxLxH + const Tensor* mask_index, // mask_index + const attention_helper::AttentionParameters& parameters, // attention parameters + const T* past_key, // past key only (if not using past state) + T* present_key, // present key only (if not using present state) + T* output_qk, // Q*K output + concurrency::ThreadPool* tp, + AllocatorPtr allocator) const; + + T* ConcatStateChunk(const T* past, + const T* chunk, + T* present, + size_t past_chunk_length, + size_t input_chunk_length, + size_t present_chunk_length, + size_t num_heads, + size_t head_size, + std::ptrdiff_t batch_i, + std::ptrdiff_t head_i, + bool transposed) const; +}; + +template +class Attention final : public AttentionBase { + public: + Attention(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + protected: + bool is_causal_; + int kv_num_heads_; + int q_num_heads_; + attention_helper::QKMatMulOutputMode qk_matmul_output_mode_; + float scale_; + float softcap_; + int softmax_precision_; +}; + +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/cpu/llm/attention_helper.cc b/onnxruntime/core/providers/cpu/llm/attention_helper.cc new file mode 100644 index 0000000000000..9bd954f128454 --- /dev/null +++ b/onnxruntime/core/providers/cpu/llm/attention_helper.cc @@ -0,0 +1,156 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/cpu/llm/attention_helper.h" +#include "core/util/shape_checker.h" + +namespace onnxruntime { +namespace attention_helper { + +void AttentionParameters::checkParameters() const { + ORT_ENFORCE(batch_size > 0, "Batch size must be greater than 0"); + ORT_ENFORCE(q_sequence_length > 0, "Q sequence length must be greater than 0"); + ORT_ENFORCE(kv_sequence_length > 0, "KV sequence length must be greater than 0"); + ORT_ENFORCE(head_size > 0, "Head size must be greater than 0"); + ORT_ENFORCE(v_head_size > 0, "V head size must be greater than 0"); + ORT_ENFORCE(past_sequence_length >= 0, "Past sequence length must be non-negative"); + ORT_ENFORCE(total_sequence_length > 0, "Total sequence length must be greater than 0"); + ORT_ENFORCE(kv_num_heads > 0, "KV number of heads must be greater than 0"); + ORT_ENFORCE(q_num_heads > 0, "Q number of heads must be greater than 0"); + ORT_ENFORCE(total_sequence_length == past_sequence_length + kv_sequence_length, + "Total sequence length must be equal to past sequence length plus KV sequence length"); +} + +Status ComputeOutputShapeForAttention( + const Tensor* Q, + const Tensor* K, + const Tensor* V, + const Tensor* attn_mask, + const Tensor* past_key, + const Tensor* past_value, + bool is_causal, + float softcap, + int softmax_precision, + attention_helper::QKMatMulOutputMode qk_matmul_output_mode, + int kv_num_heads, + int q_num_heads, + float scale, + AttentionParameters& parameters, + std::vector& y_shape, + std::vector& present_key_shape, + std::vector& present_value_shape, + std::vector& output_qk_shape) { + ORT_ENFORCE(Q != nullptr && K != nullptr && V != nullptr, + "Q, K, and V inputs must not be null"); + int q_dims = onnxruntime::narrow(Q->Shape().NumDimensions()); + int k_dims = onnxruntime::narrow(K->Shape().NumDimensions()); + int v_dims = onnxruntime::narrow(V->Shape().NumDimensions()); + ORT_ENFORCE(q_dims == 3 || q_dims == 4, "Q must be a 3D or 4D tensor"); + ORT_ENFORCE(q_dims == 
k_dims, "Q and K must have the same rank."); + ORT_ENFORCE(q_dims == v_dims, "Q and V must have the same rank."); + + ORT_ENFORCE((past_key == nullptr) == (past_value == nullptr), "past_key and past_value must be both null or both not null"); + ORT_ENFORCE(Q->Shape()[0] == K->Shape()[0], "inconsistent batch_size (between Q and K)"); + ORT_ENFORCE(Q->Shape()[0] == V->Shape()[0], "inconsistent batch_size (between Q and V)"); + ORT_ENFORCE(past_key == nullptr || Q->Shape()[0] == past_key->Shape()[0], "inconsistent batch_size (between Q and past_key)"); + ORT_ENFORCE(past_value == nullptr || Q->Shape()[0] == past_value->Shape()[0], "inconsistent batch_size (between Q and past_value)"); + ORT_ENFORCE(past_value == nullptr || past_value->Shape()[2] == past_key->Shape()[2], "inconsistent past_sequence_length (between past_key and past_value)"); + + parameters.is_causal = is_causal; + parameters.softcap = softcap; + parameters.softmax_precision = softmax_precision; + parameters.qk_matmul_output_mode = qk_matmul_output_mode; // output mode for Q*K matmul + parameters.batch_size = onnxruntime::narrow(Q->Shape()[0]); // Q.shape[0], K.shape[0], V.shape[0] (4D) + + ORT_ENFORCE(parameters.batch_size > 0, "Batch size must be greater than 0"); + ORT_ENFORCE(attn_mask == nullptr || (attn_mask->Shape().NumDimensions() >= 2 && attn_mask->Shape().NumDimensions() <= 4), "attn_mask must be 2D or 3D or 4D tensor"); + + if (q_dims == 4) { + // 4D + parameters.kv_num_heads = kv_num_heads > 0 ? kv_num_heads : onnxruntime::narrow(K->Shape()[1]); // K.shape[1] or V.shape[1] (4D) + parameters.q_num_heads = q_num_heads > 0 ? 
q_num_heads : onnxruntime::narrow(Q->Shape()[1]); // Q.shape[1] (4D) + + ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow(K->Shape()[1]), "kv_num_heads different from K.shape[1]"); + ORT_ENFORCE(parameters.kv_num_heads == onnxruntime::narrow(V->Shape()[1]), "kv_num_heads different from V.shape[1]"); + ORT_ENFORCE(parameters.q_num_heads == onnxruntime::narrow(Q->Shape()[1]), "q_num_heads different from Q.shape[1]"); + ORT_ENFORCE(Q->Shape()[3] == K->Shape()[3], "inconsistent head_size"); + ORT_ENFORCE(K->Shape()[2] == V->Shape()[2], "inconsistent kv_sequence_length"); + ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 2] == Q->Shape()[2], "inconsistent q_sequence_length (between attn_mask and Q)"); + + // From shapes + parameters.transpose_output = false; // whether to transpose the input/output with permutation (0, 2, 1, 3) + parameters.q_sequence_length = onnxruntime::narrow(Q->Shape()[2]); // Q.shape[2] (4D) + parameters.head_size = onnxruntime::narrow(Q->Shape()[3]); // Q.shape[3] (4D) + parameters.kv_sequence_length = onnxruntime::narrow(K->Shape()[2]); // K.shape[2] or V.shape[2] (4D) + parameters.v_head_size = onnxruntime::narrow(V->Shape()[3]); // V.shape[3] (4D) + parameters.past_sequence_length = past_key == nullptr // past_key.shape[2] or past_value.shape[2] (4D) or given by the mask + ? 
0 + : onnxruntime::narrow(past_key->Shape()[2]); + + y_shape = {static_cast(parameters.batch_size), + static_cast(parameters.q_num_heads), + static_cast(parameters.q_sequence_length), + static_cast(parameters.v_head_size)}; + } else { + // 3D + parameters.kv_num_heads = kv_num_heads; + parameters.q_num_heads = q_num_heads; + + // From shapes + ORT_ENFORCE(Q->Shape()[2] % parameters.q_num_heads == 0, "inconsistent q_hidden_size, it should be a multiple of q_num_heads"); + ORT_ENFORCE(V->Shape()[2] % parameters.kv_num_heads == 0, "inconsistent v_hidden_size, it should be a multiple of kv_num_heads"); + + parameters.transpose_output = true; // whether to transpose the input/output with permutation (0, 2, 1, 3) + parameters.q_sequence_length = onnxruntime::narrow(Q->Shape()[1]); + parameters.head_size = onnxruntime::narrow(Q->Shape()[2]) / parameters.q_num_heads; + parameters.kv_sequence_length = onnxruntime::narrow(K->Shape()[1]); + parameters.v_head_size = onnxruntime::narrow(V->Shape()[2]) / parameters.kv_num_heads; + parameters.past_sequence_length = past_key == nullptr + ? 
0 + : onnxruntime::narrow(past_key->Shape()[2]); + + y_shape = {static_cast(parameters.batch_size), + static_cast(parameters.q_sequence_length), + static_cast(parameters.q_num_heads * parameters.v_head_size)}; + } + parameters.total_sequence_length = parameters.past_sequence_length + parameters.kv_sequence_length; + + ORT_ENFORCE(parameters.q_num_heads % parameters.kv_num_heads == 0, "q_num_heads % kv_num_heads == 0 is not verified"); + ORT_ENFORCE(attn_mask == nullptr || attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 1] == parameters.total_sequence_length, + "inconsistent total_sequence_length (between attn_mask and past_key and past_value)"); + ORT_ENFORCE(attn_mask == nullptr || + attn_mask->Shape().NumDimensions() < 3 || + attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == 1 || + attn_mask->Shape()[attn_mask->Shape().NumDimensions() - 3] == parameters.kv_num_heads, + "attn_mask must be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) but is not compatible with kv_num_heads"); + ORT_ENFORCE(attn_mask == nullptr || + attn_mask->Shape().NumDimensions() < 4 || + attn_mask->Shape()[0] == 1 || + attn_mask->Shape()[0] == parameters.batch_size, + "attn_mask must be broadcastable to (batch_size, kv_num_heads, q_sequence_length, total_sequence_length) but is not compatible with batch_size"); + ASSERT_TENSOR_DIMS(past_key, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.head_size); + ASSERT_TENSOR_DIMS(past_value, parameters.batch_size, parameters.kv_num_heads, parameters.past_sequence_length, parameters.v_head_size); + + parameters.scale = std::isnan(scale) ? 
static_cast(1.0 / sqrt(parameters.head_size)) : scale; + parameters.checkParameters(); + + present_key_shape = {static_cast(parameters.batch_size), + static_cast(parameters.kv_num_heads), + static_cast(parameters.total_sequence_length), + static_cast(parameters.head_size)}; + present_value_shape = {static_cast(parameters.batch_size), + static_cast(parameters.kv_num_heads), + static_cast(parameters.total_sequence_length), + static_cast(parameters.v_head_size)}; + if (qk_matmul_output_mode == QKMatMulOutputMode::kNone) { + output_qk_shape.clear(); + } else { + output_qk_shape = {static_cast(parameters.batch_size), + static_cast(parameters.q_num_heads), + static_cast(parameters.q_sequence_length), + static_cast(parameters.total_sequence_length)}; + } + return Status::OK(); +} +} // namespace attention_helper +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/llm/attention_helper.h b/onnxruntime/core/providers/cpu/llm/attention_helper.h new file mode 100644 index 0000000000000..1cea27760408f --- /dev/null +++ b/onnxruntime/core/providers/cpu/llm/attention_helper.h @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace attention_helper { + +// enum equivalent to the onnx definition of qk_matmul_output_mode +enum QKMatMulOutputMode { + kNone = -1, // No output Q*K + kQK = 0, // Output Q*K + kQKMask = 1, // Output Q*K + Mask + kQKSoftCap = 2, // Output SoftCap(Q*K + Mask) + kQKSoftMax = 3, // Output SoftMax(SoftCap(Q*K + Mask)) +}; + +// Parameters deduced from node attributes and inputs/outputs.
+struct AttentionParameters { + /* + * Attention Parameters + * MHA: q_num_heads == kv_num_heads -> MHA + * GQA: q_num_heads > kv_num_heads && q_num_heads % kv_num_heads == 0 + * MQA: q_num_heads > kv_num_heads && kv_num_heads == 1 + */ + bool is_causal; + int kv_num_heads; // K.shape[1] or V.shape[1] (4D) + int q_num_heads; // Q.shape[1] (4D) + float scale; + float softcap; + int softmax_precision; + QKMatMulOutputMode qk_matmul_output_mode; + + // From shapes + int batch_size; // Q.shape[0], K.shape[0], V.shape[0] (4D) + int q_sequence_length; // Q.shape[2] (4D) + int head_size; // Q.shape[3] or K.shape[3] (4D) + int kv_sequence_length; // K.shape[2] or V.shape[2] (4D) + int v_head_size; // V.shape[3] (4D) + int past_sequence_length; // past_key.shape[2] or past_value.shape[2] (4D) + int total_sequence_length; // past_sequence_length + kv_sequence_length + bool transpose_output; // Whether to transpose the inputs and the outputs from BxNxSxH to BxSxNxH + // This covers the case where the inputs are 3D. + + // Checks the consistency of the parameters. + void checkParameters() const; +}; + +// Computes the output shape for attention based on the input tensors and parameters.
+Status ComputeOutputShapeForAttention( + const Tensor* Q, + const Tensor* K, + const Tensor* V, + const Tensor* attn_mask, + const Tensor* past_key, + const Tensor* past_value, + bool is_causal, + float softcap, + int softmax_precision, + attention_helper::QKMatMulOutputMode qk_matmul_output_mode, + int kv_num_heads, + int q_num_heads, + float scale, + AttentionParameters& parameters, + std::vector& y_shape, + std::vector& present_key_shape, + std::vector& present_value_shape, + std::vector& output_qk_shape); + +} // namespace attention_helper +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index 5406dd1a40446..181d0c5e98dd1 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -102,6 +102,7 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL( bool GemmPackBFp32(AllocatorPtr& alloc, const Tensor& tensor_b, + bool trans_a, bool trans_b, IAllocatorUniquePtr& packed_b, size_t& packed_b_size, @@ -116,7 +117,7 @@ bool GemmPackBFp32(AllocatorPtr& alloc, const size_t K = trans_b ? static_cast(b_shape[1]) : static_cast(b_shape[0]); const size_t N = trans_b ? static_cast(b_shape[0]) : static_cast(b_shape[1]); - packed_b_size = MlasGemmPackBSize(N, K); + packed_b_size = MlasGemmPackBSize(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, N, K); if (packed_b_size == 0) { return false; } @@ -129,7 +130,8 @@ bool GemmPackBFp32(AllocatorPtr& alloc, // if and when we try to cache this pre-packed buffer for sharing between sessions. memset(packed_b_data, 0, packed_b_size); - MlasGemmPackB(trans_b ? CblasTrans : CblasNoTrans, + MlasGemmPackB(trans_a ? CblasTrans : CblasNoTrans, + trans_b ? 
CblasTrans : CblasNoTrans, N, K, tensor_b.Data(), @@ -174,15 +176,14 @@ void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, thread_pool); } -template <> -void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, - ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, - MLFloat16 alpha, - const MLFloat16* a_data, const MLFloat16* b_data, - MLFloat16 beta, - const MLFloat16* c_data, const TensorShape* c_shape, - MLFloat16* y_data, - concurrency::ThreadPool* thread_pool) { +void Gemm_MLFloat16(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, + ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, + MLFloat16 alpha, + const MLFloat16* a_data, const MLFloat16* b_data, + MLFloat16 beta, + const MLFloat16* c_data, const TensorShape* c_shape, + MLFloat16* y_data, + concurrency::ThreadPool* thread_pool) { // if input is empty tensor, return directly as nothing need to be calculated. if (M == 0 || N == 0) return; @@ -237,6 +238,18 @@ void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans #endif } +template <> +void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, + ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, + MLFloat16 alpha, + const MLFloat16* a_data, const MLFloat16* b_data, + MLFloat16 beta, + const MLFloat16* c_data, const TensorShape* c_shape, + MLFloat16* y_data, + concurrency::ThreadPool* thread_pool) { + Gemm_MLFloat16(trans_a, trans_b, M, N, K, alpha, a_data, b_data, beta, c_data, c_shape, y_data, thread_pool); +} + template void Gemm::ComputeGemm(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, float alpha, @@ -263,7 +276,7 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, // only pack Matrix B if (input_idx == 1) { size_t packed_b_size; - is_packed = GemmPackBFp32(alloc, tensor, trans_B_ != CblasNoTrans, packed_b_, packed_b_size, b_shape_); + is_packed = GemmPackBFp32(alloc, tensor, trans_A_ != CblasNoTrans, trans_B_ != CblasNoTrans, packed_b_, packed_b_size, b_shape_); bool 
share_prepacked_weights = (prepacked_weights != nullptr); if (is_packed && share_prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); diff --git a/onnxruntime/core/providers/cpu/math/gemm.h b/onnxruntime/core/providers/cpu/math/gemm.h index 953949732560d..9876109c42df1 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.h +++ b/onnxruntime/core/providers/cpu/math/gemm.h @@ -12,6 +12,15 @@ namespace onnxruntime { +void Gemm_MLFloat16(CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, // 0, 1 + ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, // 2, 3, 4 + MLFloat16 alpha, // 5 + const MLFloat16* a_data, const MLFloat16* b_data, // 6, 7 + MLFloat16 beta, // 8 + const MLFloat16* c_data, const TensorShape* c_shape, // 9, 10 + MLFloat16* y_data, // 11 + concurrency::ThreadPool* thread_pool); // 12 + template class Gemm : protected GemmBase, public OpKernel { public: diff --git a/onnxruntime/core/providers/cpu/math/gemm_matmul_common.h b/onnxruntime/core/providers/cpu/math/gemm_matmul_common.h index 599847e61a54f..0189edb23dddb 100644 --- a/onnxruntime/core/providers/cpu/math/gemm_matmul_common.h +++ b/onnxruntime/core/providers/cpu/math/gemm_matmul_common.h @@ -9,9 +9,9 @@ namespace onnxruntime { bool GemmPackBFp32(AllocatorPtr& alloc, const Tensor& tensor_b, + bool trans_a, bool trans_b, IAllocatorUniquePtr& packed_b, size_t& packed_b_size, TensorShape& b_shape); - }; // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 2c6d23e4de908..530218db31e3d 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -195,7 +195,7 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc } else #endif { - is_packed = GemmPackBFp32(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); + is_packed = GemmPackBFp32(alloc, tensor, trans_a_attr_, trans_b_attr_ != 0, packed_b_, packed_b_size, 
b_shape_); } bool share_prepacked_weights = (prepacked_weights != nullptr); diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc index c0171f7728ea8..d781de2eb5541 100644 --- a/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc @@ -194,7 +194,7 @@ bool DeepCpuGruOp::TryPackInputWeights(const Tensor& weights, AllocatorPtr& allo const size_t N = static_cast(shape[1]); const size_t K = static_cast(shape[2]); - const size_t packed_weights_size = MlasGemmPackBSize(N, K); + const size_t packed_weights_size = MlasGemmPackBSize(CblasNoTrans, CblasTrans, N, K); if (packed_weights_size == 0) { return false; } @@ -215,7 +215,7 @@ bool DeepCpuGruOp::TryPackInputWeights(const Tensor& weights, AllocatorPtr& allo const size_t N_x_K = N * K; const auto* weights_data = weights.Data(); for (int64_t dir = 0; dir < num_directions; ++dir) { - MlasGemmPackB(CblasTrans, N, K, weights_data, K, packed_weights_data); + MlasGemmPackB(CblasNoTrans, CblasTrans, N, K, weights_data, K, packed_weights_data); weights_data += N_x_K; packed_weights_data += packed_weights_size; } @@ -244,12 +244,12 @@ bool DeepCpuGruOp::TryPackRecurrentWeights(const Tensor& weights, AllocatorPtr& const auto hidden_size_x_2 = N - hidden_size_; // We are making two packed buffers, one for ZR weights and another for H weights. 
- const size_t ZR_packed_size = MlasGemmPackBSize(narrow(hidden_size_x_2), narrow(K)); + const size_t ZR_packed_size = MlasGemmPackBSize(CblasNoTrans, CblasTrans, narrow(hidden_size_x_2), narrow(K)); if (ZR_packed_size == 0) { return false; } - const size_t H_packed_size = MlasGemmPackBSize(narrow(hidden_size_), narrow(K)); + const size_t H_packed_size = MlasGemmPackBSize(CblasNoTrans, CblasTrans, narrow(hidden_size_), narrow(K)); if (H_packed_size == 0) { return false; } @@ -275,18 +275,18 @@ bool DeepCpuGruOp::TryPackRecurrentWeights(const Tensor& weights, AllocatorPtr& const auto hidden_2_step = hidden_size_x_2 * K; const auto hidden_1_step = hidden_size_ * K; // square const auto* weights_data = weights.Data(); - MlasGemmPackB(CblasTrans, narrow(hidden_size_x_2), narrow(K), weights_data, narrow(K), buffer_ZR); + MlasGemmPackB(CblasNoTrans, CblasTrans, narrow(hidden_size_x_2), narrow(K), weights_data, narrow(K), buffer_ZR); weights_data += hidden_2_step; - MlasGemmPackB(CblasTrans, narrow(hidden_size_), narrow(K), weights_data, narrow(K), buffer_H); + MlasGemmPackB(CblasNoTrans, CblasTrans, narrow(hidden_size_), narrow(K), weights_data, narrow(K), buffer_H); if (num_directions == 2) { weights_data += hidden_1_step; buffer_ZR = static_cast(buffer_ZR) + ZR_packed_size; - MlasGemmPackB(CblasTrans, narrow(hidden_size_x_2), narrow(K), weights_data, narrow(K), buffer_ZR); + MlasGemmPackB(CblasNoTrans, CblasTrans, narrow(hidden_size_x_2), narrow(K), weights_data, narrow(K), buffer_ZR); weights_data += hidden_2_step; buffer_H = static_cast(buffer_H) + H_packed_size; - MlasGemmPackB(CblasTrans, narrow(hidden_size_), narrow(K), weights_data, narrow(K), buffer_H); + MlasGemmPackB(CblasNoTrans, CblasTrans, narrow(hidden_size_), narrow(K), weights_data, narrow(K), buffer_H); } return true; diff --git a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc index e95ad707cf2b0..b38e271fdbe4a 100644 --- 
a/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc +++ b/onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc @@ -196,7 +196,7 @@ Status DeepCpuLstmOp::TryPackWeights(const Tensor& weights, PackedWeights& packe return Status::OK(); } - const size_t packed_weights_size = MlasGemmPackBSize(N, K); + const size_t packed_weights_size = MlasGemmPackBSize(CblasNoTrans, CblasTrans, N, K); if (packed_weights_size == 0) { return Status::OK(); } @@ -217,7 +217,7 @@ Status DeepCpuLstmOp::TryPackWeights(const Tensor& weights, PackedWeights& packe const auto* weights_data = weights.Data(); for (int i = 0; i < num_directions_; i++) { - MlasGemmPackB(CblasTrans, N, K, weights_data, K, packed_weights_data); + MlasGemmPackB(CblasNoTrans, CblasTrans, N, K, weights_data, K, packed_weights_data); packed_weights_data = static_cast(packed_weights_data) + packed_weights_size; weights_data += N * K; } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 1f4c9fcdbc073..e036c7764d041 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -161,9 +161,7 @@ AllocatorPtr CUDAExecutionProvider::CreateCudaAllocator(OrtDevice::DeviceId devi {default_memory_arena_cfg ? *default_memory_arena_cfg : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), -1, -1, -1, -1L)}, // make it stream aware - true, - // enable cross stream sharing? 
- false); + true); // CUDA malloc/free is expensive so always use an arena return CreateAllocator(default_memory_info); diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index 054dd9f9da9f3..bcbf1d4a1c800 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -41,8 +41,12 @@ class CudaKernel : public OpKernel { template inline IAllocatorUniquePtr GetScratchBuffer(size_t count_or_bytes, onnxruntime::Stream* stream) const { - if (count_or_bytes == 0) return nullptr; - return IAllocator::MakeUniquePtr(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), count_or_bytes, false, stream, WaitCudaNotificationOnDevice); + if (count_or_bytes == 0) { + return nullptr; + } + + return IAllocator::MakeUniquePtr(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), count_or_bytes, false, + stream); } // Different from GetScratchBuffer which use IAllocator::Alloc() to allocate memory, diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 021a6f1e7e350..e8d133779f33c 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -355,8 +355,9 @@ struct CudaOrtAllocator : OrtAllocator { Alloc = AllocImpl; Free = FreeImpl; Info = InfoImpl; - Reserve = AllocImpl; // no special behavior for Reserve so use AllocImpl - GetStats = nullptr; // GetStatsImpl. The CUDA allocators don't have stats currently so we can skip. + Reserve = AllocImpl; // no special behavior for Reserve so use AllocImpl + GetStats = nullptr; // GetStatsImpl. The CUDA allocators don't have stats currently so we can skip. + AllocOnStream = nullptr; // TODO. Plugin EP arena to provide this. 
const OrtEpApi& ep_api = *api.GetEpApi(); const OrtMemoryDevice* mem_device = ep_api.MemoryInfo_GetMemoryDevice(mem_info); @@ -679,7 +680,6 @@ struct CudaEpFactory : OrtEpFactory { CreateAllocator = CreateAllocatorImpl; ReleaseAllocator = ReleaseAllocatorImpl; - CreateDataTransfer = CreateDataTransferImpl; IsStreamAware = IsStreamAwareImpl; diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index b6cbffb073774..fbee1841ae8d5 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -38,7 +38,7 @@ struct CudaNotification : public synchronize::Notification { void Activate() override { // record event with cudaEventBlockingSync so we can support sync on host without busy wait. - CUDA_CALL_THROW(cudaEventRecord(event_, static_cast(stream_.GetHandle()))); + CUDA_CALL_THROW(cudaEventRecord(event_, static_cast(GetStream().GetHandle()))); } void wait_on_device(Stream& device_stream) { diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.h b/onnxruntime/core/providers/cuda/cuda_stream_handle.h index c75cf15f7c2f8..1be7a3d510082 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.h +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.h @@ -48,8 +48,6 @@ struct CudaStream : Stream { onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); } - WaitNotificationFn GetWaitNotificationFn() const override { return WaitCudaNotificationOnDevice; } - private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 4f8e6605ce151..b232124dc6b00 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -367,7 +367,7 @@ Status ReduceComputeCore(const AllocatorPtr& 
gpu_allocator, const Tensor& input, IAllocatorUniquePtr input_data_buffer(nullptr, [](T*) {}); const CudaT* input_data = reinterpret_cast(input.Data()); if (calculate_sqt) { - input_data_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream, WaitCudaNotificationOnDevice); + input_data_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); input_data = reinterpret_cast(input_data_buffer.get()); fast_divmod tmp_div; Impl_Mul(stream, static_cast(SimpleBroadcast::NoBroadcast), nullptr, @@ -384,7 +384,9 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, } break; case ApplicableMatrixReduction::Columns: { const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size(m, n); - auto buffer = buffer_size_bytes == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, buffer_size_bytes, false, ort_stream, WaitCudaNotificationOnDevice); + auto buffer = buffer_size_bytes == 0 + ? nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, buffer_size_bytes, false, ort_stream); ORT_RETURN_IF_ERROR(reduce_matrix_columns(stream, input_data, reinterpret_cast(output.MutableData()), m, n, buffer.get(), buffer_size_bytes)); @@ -421,7 +423,7 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, if ((ReduceTensorIndices == CUDNN_REDUCE_TENSOR_FLATTENED_INDICES && std::is_same::value) || (ReduceTensorIndices == CUDNN_REDUCE_TENSOR_NO_INDICES && std::is_same::value)) { // ArgMax/ArgMin with FP16 are not supported by cudnn, so convert input to fp32 then call cudnn - temp_X = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream, WaitCudaNotificationOnDevice); + temp_X = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); Impl_Cast(stream, reinterpret_cast(input.Data()), temp_X.get(), input_shape.Size()); } else { cudnn_type_X = CudnnTensor::GetDataType(); @@ -444,18 +446,22 @@ Status ReduceComputeCore(const AllocatorPtr& 
gpu_allocator, const Tensor& input, CudaStream* cuda_stream = static_cast(ort_stream); CUDNN_RETURN_IF_ERROR(cudnnGetReductionWorkspaceSize(CudaKernel::GetCudnnHandle(cuda_stream), reduce_desc, input_tensor, output_tensor, &workspace_bytes)); - auto workspace_cuda = workspace_bytes == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, workspace_bytes, false, ort_stream, WaitCudaNotificationOnDevice); + auto workspace_cuda = workspace_bytes == 0 + ? nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, workspace_bytes, false, ort_stream); size_t indices_bytes = 0; CUDNN_RETURN_IF_ERROR(cudnnGetReductionIndicesSize(CudaKernel::GetCudnnHandle(cuda_stream), reduce_desc, input_tensor, output_tensor, &indices_bytes)); - auto indices_cuda = indices_bytes == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes, false, ort_stream, WaitCudaNotificationOnDevice); + auto indices_cuda = indices_bytes == 0 + ? nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes, false, ort_stream); if (ReduceTensorIndices == CUDNN_REDUCE_TENSOR_NO_INDICES) { IAllocatorUniquePtr input_data_buffer(nullptr, [](T*) {}); CudaT* input_data = nullptr; if (calculate_sqt) { - input_data_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream, WaitCudaNotificationOnDevice); + input_data_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); input_data = reinterpret_cast(input_data_buffer.get()); fast_divmod tmp_div; Impl_Mul(stream, @@ -482,7 +488,9 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, size_t indices_bytes_max = 0; CUDNN_RETURN_IF_ERROR(cudnnGetReductionIndicesSize(CudaKernel::GetCudnnHandle(cuda_stream), reduce_max_desc, input_tensor, output_tensor, &indices_bytes_max)); - auto indices_cuda_max = indices_bytes == 0 ? 
nullptr : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes, false, ort_stream, WaitCudaNotificationOnDevice); + auto indices_cuda_max = indices_bytes == 0 + ? nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes, false, ort_stream); auto* p_output = reinterpret_cast(output.template MutableData()); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( CudaKernel::GetCudnnHandle(cuda_stream), reduce_max_desc, indices_cuda_max.get(), indices_bytes_max, @@ -493,9 +501,11 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, // Exp(X-ReduceMax) const TensorShape output_shape(output_dims); - auto exp_result_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream, WaitCudaNotificationOnDevice); + auto exp_result_buffer = IAllocator::MakeUniquePtr(gpu_allocator, input_count, false, ort_stream); auto exp_result = exp_result_buffer.get(); - auto log_sum_result_buffer = output_count == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream, WaitCudaNotificationOnDevice); + auto log_sum_result_buffer = output_count == 0 + ? nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); auto log_sum_result = log_sum_result_buffer.get(); BinaryElementwisePreparation prepare; ORT_RETURN_IF_ERROR(prepare.BinaryElementwiseBroadcastPrepareHelper(input_shape, output_shape, input_shape)); @@ -563,7 +573,9 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, } } else { if (temp_X) { - auto temp_output = output_count == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream, WaitCudaNotificationOnDevice); + auto temp_output = output_count == 0 + ? 
nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( CudaKernel::GetCudnnHandle(cuda_stream), reduce_desc, indices_cuda.get(), indices_bytes, workspace_cuda.get(), workspace_bytes, @@ -589,14 +601,18 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output.MutableData(), static_cast(0), output_count * sizeof(int64_t), stream)); } else { if (temp_X) { - auto temp_output = output_count == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream, WaitCudaNotificationOnDevice); + auto temp_output = output_count == 0 + ? nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( CudaKernel::GetCudnnHandle(cuda_stream), reduce_desc, indices_cuda.get(), indices_bytes, workspace_cuda.get(), workspace_bytes, &one, input_tensor, temp_X.get(), &zero, output_tensor, temp_output.get())); } else { - auto temp_output = output_count == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream, WaitCudaNotificationOnDevice); + auto temp_output = output_count == 0 + ? 
nullptr + : IAllocator::MakeUniquePtr(gpu_allocator, output_count, false, ort_stream); CUDNN_RETURN_IF_ERROR(cudnnReduceTensor( CudaKernel::GetCudnnHandle(cuda_stream), reduce_desc, indices_cuda.get(), indices_bytes, workspace_cuda.get(), workspace_bytes, @@ -605,7 +621,8 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, } // CUDA reduction index is uint32_t for now, cast it to int64_t according to ONNX spec - Impl_Cast(stream, reinterpret_cast(indices_cuda.get()), output.MutableData(), output_count); + Impl_Cast(stream, reinterpret_cast(indices_cuda.get()), + output.MutableData(), output_count); } } diff --git a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc index 2df995d6e62ac..f4a33a128608a 100644 --- a/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc +++ b/onnxruntime/core/providers/cuda/tunable/cuda_tuning_context.cc @@ -120,7 +120,7 @@ IAllocatorUniquePtr CudaTuningContext::GetScratchBuffer( return nullptr; } - return IAllocator::MakeUniquePtr(it->second, num_bytes, false, stream, WaitCudaNotificationOnDevice); + return IAllocator::MakeUniquePtr(it->second, num_bytes, false, stream); } } // namespace tunable diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index aa8b21ea3fe52..41b55e3baf508 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -323,9 +323,7 @@ AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::Devic : OrtArenaCfg(migx_mem_limit, static_cast(arena_extend_strategy), -1, -1, -1, -1L)}, // make it stream aware - true, - // enable cross stream sharing? 
- false); + true); // ROCM malloc/free is expensive so always use an arena return CreateAllocator(default_memory_info); diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc index 8ed4e4a45a8c4..6e492327a73a3 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc @@ -18,11 +18,12 @@ struct MIGraphXNotification : public synchronize::Notification { void Activate() override { // record event with hipEventBlockingSync so we can support sync on host without busy wait. - HIP_CALL_THROW(hipEventRecord(event_, static_cast(stream_.GetHandle()))); + HIP_CALL_THROW(hipEventRecord(event_, static_cast(GetStream().GetHandle()))); } void wait_on_device(Stream& device_stream) { - ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", device_stream.GetDevice().ToString()); + ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", + device_stream.GetDevice().ToString()); // launch a wait command to the migraphx stream HIP_CALL_THROW(hipStreamWaitEvent(static_cast(device_stream.GetHandle()), event_, 0)); }; diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h index d0ef3334b38c9..886103690c661 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h @@ -29,8 +29,6 @@ struct MIGraphXStream : Stream { virtual void* GetResource(int version, int id) const; - virtual WaitNotificationFn GetWaitNotificationFn() const { return WaitMIGraphXNotificationOnDevice; } - private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index 
428d24f2f3df8..e236cccaaaa77 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -7,10 +7,14 @@ #include #include "nv_execution_provider.h" #include "nv_provider_factory_creator.h" +#include "nv_data_transfer.h" +#include "nv_allocator.h" #include "core/framework/provider_options.h" #include "core/providers/nv_tensorrt_rtx/nv_provider_options.h" #include "core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.h" #include +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include "core/providers/cuda/cuda_stream_handle.h" using namespace onnxruntime; @@ -151,28 +155,385 @@ ORT_API(onnxruntime::Provider*, GetProvider) { } } -#include "core/framework/error_code_helper.h" +// +// Plug-in EP infrastructure +// + +#include "core/session/abi_devices.h" +#include "onnxruntime_config.h" // for ORT_VERSION + +struct ErrorHelper { + static const OrtApi* ort_api; + + static OrtStatus* ToOrtStatus(const Status& status) { + if (status.IsOK()) { + return nullptr; // no error + } + + return ort_api->CreateStatus(static_cast(status.Code()), + status.ErrorMessage().c_str()); + } +}; + +const OrtApi* ErrorHelper::ort_api = nullptr; + +#define RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return _status; \ + } \ + } while (0) + +#define RETURN_IF_STATUS_NOTOK(fn) \ + do { \ + Status _status = (fn); \ + if (!_status.IsOK()) { \ + return ErrorHelper::ToOrtStatus(_status); \ + } \ + } while (0) + +#define CUDA_RETURN_IF_ERROR(expr) RETURN_IF_STATUS_NOTOK(CUDA_CALL(expr)) + +struct NvTrtRtxOrtAllocator : OrtAllocator { + NvTrtRtxOrtAllocator(const OrtMemoryInfo* mem_info, const OrtApi& api) : memory_info_{mem_info} { + version = ORT_API_VERSION; + Alloc = AllocImpl; + Free = FreeImpl; + Info = InfoImpl; + Reserve = AllocImpl; // no special behavior for Reserve so use AllocImpl + GetStats = nullptr; // GetStatsImpl. 
The CUDA allocators don't have stats currently so we can skip. + + const OrtEpApi& ep_api = *api.GetEpApi(); + const OrtMemoryDevice* mem_device = ep_api.MemoryInfo_GetMemoryDevice(mem_info); + uint32_t device_id = ep_api.MemoryDevice_GetDeviceId(mem_device); + const char* name = nullptr; + auto* status = api.MemoryInfoGetName(mem_info, &name); + static_cast(status); // GetName never fails + + if (ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_HOST_ACCESSIBLE) { + allocator_ = std::make_unique(device_id, name); + } else { + allocator_ = std::make_unique(device_id, name); + } + } + + static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { + auto& impl = *static_cast(this_); + return impl.allocator_->Alloc(size); + } + + static void ORT_API_CALL FreeImpl(struct OrtAllocator* this_, void* p) { + auto& impl = *static_cast(this_); + impl.allocator_->Free(p); + } + + static const struct OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) { + const NvTrtRtxOrtAllocator& impl = *static_cast(this_); + return impl.memory_info_; + } + + private: + const OrtMemoryInfo* memory_info_; + std::unique_ptr allocator_; +}; + +struct NvTrtRtxDataTransferImpl : OrtDataTransferImpl { + NvTrtRtxDataTransferImpl(const OrtApi& ort_api_in) + : ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()} { + ort_version_supported = ORT_API_VERSION; + CanCopy = CanCopyImpl; + CopyTensors = CopyTensorsImpl; + Release = ReleaseImpl; + } + + static bool CanCopyImpl(const OrtDataTransferImpl* this_ptr, + const OrtMemoryDevice* src_memory_device, + const OrtMemoryDevice* dst_memory_device) noexcept { + const auto& impl = *static_cast(this_ptr); + + // logic copied from GPUDataTransfer::CanCopy + OrtMemoryInfoDeviceType src_type = impl.ep_api.MemoryDevice_GetDeviceType(src_memory_device); + OrtMemoryInfoDeviceType dst_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_memory_device); + auto src_vendor_id = 
impl.ep_api.MemoryDevice_GetVendorId(src_memory_device); + auto dst_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(dst_memory_device); + + if ((src_type == OrtDevice::GPU && src_vendor_id != OrtDevice::VendorIds::NVIDIA) || + (dst_type == OrtDevice::GPU && dst_vendor_id != OrtDevice::VendorIds::NVIDIA)) { + return false; + } + + // copy must be GPU to GPU or between GPU and CPU + return (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_GPU) || + (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_CPU) || + (src_type == OrtMemoryInfoDeviceType_CPU && dst_type == OrtMemoryInfoDeviceType_GPU); + } + + static OrtStatus* CopyTensorsImpl(OrtDataTransferImpl* this_ptr, + const OrtValue** src_tensors, + OrtValue** dst_tensors, + OrtSyncStream** streams, + size_t num_tensors) noexcept { + auto& impl = *static_cast(this_ptr); + bool need_stream_sync = false; + + for (size_t idx = 0; idx < num_tensors; ++idx) { + const OrtValue* src_tensor = src_tensors[idx]; + OrtValue* dst_tensor = dst_tensors[idx]; + OrtSyncStream* stream = streams ? 
streams[idx] : nullptr; + + const OrtMemoryDevice* src_device = impl.ep_api.Value_GetMemoryDevice(src_tensor); + const OrtMemoryDevice* dst_device = impl.ep_api.Value_GetMemoryDevice(dst_tensor); + + size_t bytes; + RETURN_IF_ERROR(impl.ort_api.GetTensorSizeInBytes(src_tensor, &bytes)); + + const void* src_data = nullptr; + void* dst_data = nullptr; + RETURN_IF_ERROR(impl.ort_api.GetTensorData(src_tensor, &src_data)); + RETURN_IF_ERROR(impl.ort_api.GetTensorMutableData(dst_tensor, &dst_data)); + + OrtMemoryInfoDeviceType src_type = impl.ep_api.MemoryDevice_GetDeviceType(src_device); + OrtMemoryInfoDeviceType dst_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_device); + OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_device); + OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_device); + + const bool src_is_gpu_default = src_type == OrtMemoryInfoDeviceType_GPU && + src_mem_type == OrtDeviceMemoryType_DEFAULT; + const bool dst_is_gpu_default = dst_type == OrtMemoryInfoDeviceType_GPU && + dst_mem_type == OrtDeviceMemoryType_DEFAULT; + + cudaStream_t cuda_stream = nullptr; + if (stream) { + cuda_stream = static_cast(impl.ort_api.SyncStream_GetHandle(stream)); + } + + if (dst_is_gpu_default) { + if (src_is_gpu_default) { + // Copy only if the two addresses are different. + if (dst_data != src_data) { + if (cuda_stream) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, cuda_stream)); + + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice)); + + // For device memory to device memory copy, no host-side synchronization is performed by cudaMemcpy. 
+ // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + need_stream_sync = true; + } + } + } else { + // copy from pinned or non-pinned CPU memory to GPU + if (cuda_stream) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, cuda_stream)); + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice)); + + if (src_mem_type != OrtDeviceMemoryType_HOST_ACCESSIBLE) { + // For cudaMemcpy from pageable host memory to device memory, DMA to final destination may not + // have completed. + // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + need_stream_sync = true; + } + } + } + } else if (src_is_gpu_default) { + // copying from GPU to CPU memory, this is blocking + + if (cuda_stream) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, cuda_stream)); + + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost)); + } + } else { + // copying between CPU accessible memory + + if (dst_data != src_data) { + if (cuda_stream) { + if (src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE) { + // sync the stream first to make sure the data arrived + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + } + } + + memcpy(dst_data, src_data, bytes); + } + } + } + + if (need_stream_sync) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } + + return nullptr; + } + + static void ReleaseImpl(OrtDataTransferImpl* /*this_ptr*/) noexcept { + // no-op as we have a single shared instance in OrtEpFactory which is returned from CreateDataTransferImpl, and is + // owned by and freed by the factory. 
+ } + + const OrtApi& ort_api; + const OrtEpApi& ep_api; +}; + +struct NvTrtRtxSyncNotificationImpl : OrtSyncNotificationImpl { + static OrtStatus* Create(cudaStream_t stream, const OrtApi& ort_api, + std::unique_ptr& notification) { + notification.reset(new NvTrtRtxSyncNotificationImpl(stream, ort_api)); // can't use make_unique with private ctor + CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(¬ification->event_, cudaEventDisableTiming)); + + return nullptr; + } + + static void ReleaseImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + delete static_cast(this_ptr); + } + + static OrtStatus* ActivateImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + CUDA_RETURN_IF_ERROR(cudaEventRecord(impl.event_, impl.stream_)); + + return nullptr; + } + + static OrtStatus* WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr, + _In_ OrtSyncStream* consumer_stream) noexcept { + auto& impl = *static_cast(this_ptr); + + // setup the consumer stream to wait on our event. 
+ void* consumer_handle = impl.ort_api.SyncStream_GetHandle(consumer_stream); + CUDA_RETURN_IF_ERROR(cudaStreamWaitEvent(static_cast(consumer_handle), impl.event_)); + + return nullptr; + } + + static OrtStatus* WaitOnHostImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + CUDA_RETURN_IF_ERROR(cudaEventSynchronize(impl.event_)); + + return nullptr; + } + + ~NvTrtRtxSyncNotificationImpl() { + cudaEventDestroy(event_); + } + + private: + NvTrtRtxSyncNotificationImpl(cudaStream_t stream, const OrtApi& ort_api_in) + : stream_{stream}, ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()} { + ort_version_supported = ORT_API_VERSION; + Activate = ActivateImpl; + WaitOnDevice = WaitOnDeviceImpl; + WaitOnHost = WaitOnHostImpl; + Release = ReleaseImpl; + } + + cudaStream_t& stream_; + cudaEvent_t event_; + + const OrtApi& ort_api; + const OrtEpApi& ep_api; +}; + +struct NvTrtRtxSyncStreamImpl : OrtSyncStreamImpl { + NvTrtRtxSyncStreamImpl(cudaStream_t&& stream, + const OrtDevice& device, + AllocatorPtr cpu_allocator, + bool release_cpu_buffer_on_cuda_stream, + const OrtApi& ort_api_in) + : stream_{ + stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, /*own*/ true, + /*external_cudnn_handle*/ nullptr, + /*external_cublas_handle*/ nullptr, + // ep_info is used by GetResource which seems to be a somewhat ugly way to make arbitrary info that is + // unrelated to the stream available to a custom op. + // avoiding adding GetResource to OrtSyncStreamImpl as we should have a cleaner setup for custom ops, + // so this argument value isn't used and doesn't matter. 
+ /*ep_info*/ CUDAExecutionProviderInfo{}}, + ort_api{ort_api_in} { + ort_version_supported = ORT_API_VERSION; + GetHandle = GetHandleImpl; + CreateNotification = CreateNotificationImpl; + Flush = FlushImpl; + OnSessionRunEnd = OnSessionRunEndImpl; + Release = ReleaseImpl; + } + + static void ReleaseImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + delete static_cast(this_ptr); + } + + static void* GetHandleImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + return impl.stream_.GetHandle(); + } + + static OrtStatus* CreateNotificationImpl(_In_ OrtSyncStreamImpl* this_ptr, + _Outptr_ OrtSyncNotificationImpl** notification_impl) noexcept { + auto& impl = *static_cast(this_ptr); + *notification_impl = nullptr; + + std::unique_ptr notification; + cudaStream_t* cuda_stream = static_cast(impl.stream_.GetHandle()); + + RETURN_IF_ERROR(NvTrtRtxSyncNotificationImpl::Create(*cuda_stream, impl.ort_api, notification)); + *notification_impl = notification.release(); + + return nullptr; + } + + static OrtStatus* FlushImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + impl.stream_.Flush(); + + return nullptr; + } + + static OrtStatus* OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + RETURN_IF_STATUS_NOTOK(impl.stream_.CleanUpOnRunEnd()); + + return nullptr; + } + + private: + // this is a little onion-ish as CudaStream is a onnxruntime::Stream and this is an OrtSyncStreamImpl that will be + // used via plugin_ep::Stream, which is also an onnxruntime::Stream. in a 'real' plugin EP implementation + // CudaStream would go away and the logic it has would be implemented directly here. + CudaStream stream_; + const OrtApi& ort_api; +}; // OrtEpApi infrastructure to be able to use the NvTensorRTRTX EP as an OrtEpFactory for auto EP selection. 
struct NvTensorRtRtxEpFactory : OrtEpFactory { + using MemoryInfoUniquePtr = std::unique_ptr>; + NvTensorRtRtxEpFactory(const OrtApi& ort_api_in, - const OrtLogger& default_logger_in, - OrtHardwareDeviceType hw_type) - : ort_api{ort_api_in}, default_logger{default_logger_in}, ort_hw_device_type{hw_type} { + const OrtLogger& default_logger_in) : ort_api{ort_api_in}, + ep_api{*ort_api_in.GetEpApi()}, + default_logger{default_logger_in}, + data_transfer_impl{ort_api_in} { GetName = GetNameImpl; GetVendor = GetVendorImpl; GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; + GetVendorId = GetVendorIdImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; CreateAllocator = CreateAllocatorImpl; ReleaseAllocator = ReleaseAllocatorImpl; + CreateDataTransfer = CreateDataTransferImpl; IsStreamAware = IsStreamAwareImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; + + ort_version_supported = ORT_API_VERSION; // Set to the ORT version we were compiled with. } // Returns the name for the EP. Each unique factory configuration must have a unique name. 
@@ -211,18 +572,36 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); + int num_cuda_devices = 0; + cudaGetDeviceCount(&num_cuda_devices); + RETURN_IF_ERROR(factory->CreateMemoryInfoForDevices(num_cuda_devices)); + + int16_t device_id = 0; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *devices[i]; - if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type && + if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id) { OrtKeyValuePairs* ep_options = nullptr; + OrtKeyValuePairs* ep_metadata = nullptr; + factory->ort_api.CreateKeyValuePairs(&ep_options); - ORT_API_RETURN_IF_ERROR( - factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, ep_options, - &ep_devices[num_ep_devices++])); + factory->ort_api.CreateKeyValuePairs(&ep_metadata); + factory->ort_api.AddKeyValuePair(ep_options, "device_id", std::to_string(device_id).c_str()); + + RETURN_IF_ERROR(factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, + &ep_devices[num_ep_devices])); + factory->ort_api.ReleaseKeyValuePairs(ep_options); + factory->ort_api.ReleaseKeyValuePairs(ep_metadata); + + const OrtMemoryInfo* gpu_mem_info = factory->gpu_memory_infos[device_id].get(); + const OrtMemoryInfo* host_accessible_mem_info = factory->host_accessible_memory_infos[device_id].get(); + + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], gpu_mem_info)); + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], host_accessible_mem_info)); + num_ep_devices++; + device_id++; } } - return nullptr; } @@ -241,50 +620,99 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { } static OrtStatus* ORT_API_CALL 
CreateAllocatorImpl(OrtEpFactory* this_ptr, - const OrtMemoryInfo* /*memory_info*/, + const OrtMemoryInfo* memory_info, const OrtKeyValuePairs* /*allocator_options*/, OrtAllocator** allocator) noexcept { - auto* factory = static_cast(this_ptr); - - *allocator = nullptr; - return factory->ort_api.CreateStatus( - ORT_INVALID_ARGUMENT, - "CreateAllocator should not be called as we did not add OrtMemoryInfo to our OrtEpDevice."); + auto& factory = *static_cast(this_ptr); + auto allocator_ = std::make_unique(memory_info, factory.ort_api); + *allocator = allocator_.release(); + return nullptr; } - static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this_ptr*/, OrtAllocator* /*allocator*/) noexcept { - // should never be called as we don't implement CreateAllocator + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* allocator) noexcept { + delete static_cast(allocator); } - static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/, + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, OrtDataTransferImpl** data_transfer) noexcept { - *data_transfer = nullptr; // not implemented + auto& factory = *static_cast(this_ptr); + *data_transfer = &factory.data_transfer_impl; return nullptr; } static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { - return false; + return true; } static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, - const OrtMemoryDevice* /*memory_device*/, + const OrtMemoryDevice* memory_device, const OrtKeyValuePairs* /*stream_options*/, - OrtSyncStreamImpl** stream) noexcept { - auto* factory = static_cast(this_ptr); + OrtSyncStreamImpl** ort_stream) noexcept { + auto& factory = *static_cast(this_ptr); + + auto device_id = factory.ep_api.MemoryDevice_GetDeviceId(memory_device); + cudaStream_t stream = nullptr; + CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id)); + 
CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); - *stream = nullptr; - return factory->ort_api.CreateStatus( - ORT_INVALID_ARGUMENT, "CreateSyncStreamForDevice should not be called as IsStreamAware returned false."); + const OrtDevice* ort_device = static_cast(memory_device); + + auto impl = std::make_unique(std::move(stream), *ort_device, nullptr, + /*release_cpu_buffer_on_cuda_stream*/ true, + factory.ort_api); + *ort_stream = impl.release(); + return nullptr; } + OrtStatus* CreateMemoryInfoForDevices(int num_devices) { + gpu_memory_infos.reserve(num_devices); + host_accessible_memory_infos.reserve(num_devices); + + for (int device_id = 0; device_id < num_devices; ++device_id) { + OrtMemoryInfo* mem_info = nullptr; + RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("NvTensorRTRTX", OrtMemoryInfoDeviceType_GPU, + /*vendor*/ OrtDevice::VendorIds::NVIDIA, + /* device_id */ device_id, + OrtDeviceMemoryType_DEFAULT, + /*alignment*/ 0, + OrtAllocatorType::OrtDeviceAllocator, + &mem_info)); + gpu_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo)); + + mem_info = nullptr; + RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("NvTensorRTRTX host accessible", OrtMemoryInfoDeviceType_GPU, + /*vendor*/ OrtDevice::VendorIds::NVIDIA, + /* device_id */ device_id, + OrtDeviceMemoryType_HOST_ACCESSIBLE, + /*alignment*/ 0, + OrtAllocatorType::OrtDeviceAllocator, + &mem_info)); + host_accessible_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo)); + } + return nullptr; + } + + private: const OrtApi& ort_api; + const OrtEpApi& ep_api; const OrtLogger& default_logger; const std::string ep_name{kNvTensorRTRTXExecutionProvider}; const std::string vendor{"NVIDIA"}; // NVIDIA vendor ID. 
Refer to the ACPI ID registry (search NVIDIA): https://uefi.org/ACPI_ID_List const uint32_t vendor_id{0x10de}; - const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice + + std::vector gpu_memory_infos; + std::vector host_accessible_memory_infos; + + // we use a shared instance for the OrtDataTransferImpl instead of creating a new one on every call to + NvTrtRtxDataTransferImpl data_transfer_impl; + + NvTensorRtRtxEpFactory(const NvTensorRtRtxEpFactory&) = delete; + NvTensorRtRtxEpFactory& operator=(const NvTensorRtRtxEpFactory&) = delete; + + NvTensorRtRtxEpFactory(NvTensorRtRtxEpFactory&&) = default; + NvTensorRtRtxEpFactory& operator=(NvTensorRtRtxEpFactory&&) = default; }; extern "C" { @@ -297,14 +725,14 @@ OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); // Factory could use registration_name or define its own EP name. - auto factory_gpu = std::make_unique(*ort_api, *default_logger, OrtHardwareDeviceType_GPU); + auto factory = std::make_unique(*ort_api, *default_logger); if (max_factories < 1) { return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, "Not enough space to return EP factory. 
Need at least one."); } - factories[0] = factory_gpu.release(); + factories[0] = factory.release(); *num_factories = 1; return nullptr; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc index 9db0b5202dcd4..7e17addf2f577 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc @@ -45,13 +45,7 @@ std::optional ParseEquation(std::string_view equation_string) { if (term_1.empty() || term_2.empty()) { return std::nullopt; } - if (term_1.size() < 2) { - return std::nullopt; - } - if (term_1.size() != term_2.size()) { - return std::nullopt; - } - if (term_1.size() != result.size()) { + if (term_1.size() < 2 || term_2.size() < 2 || result.size() < 2) { return std::nullopt; } if (!std::all_of(term_1.begin(), term_1.end(), [](unsigned char c) { return std::islower(c); })) { @@ -154,6 +148,50 @@ bool IsEquationMatMulTransposeAll(const Equation& equation) { return true; } +bool IsEquationMatMulBroadcastTransposeY(const Equation& equation) { + // E.g., bhwc,hkc->bhwk + const auto& [term_1, term_2, result] = equation; + const size_t term1_dims = term_1.size(); + if (term1_dims != 4) { + return false; + } + const size_t term2_dims = term_2.size(); + if (term2_dims != 3) { + return false; + } + const size_t result_dims = result.size(); + if (result_dims != 4) { + return false; + } + // Check matrix multiplication dimensions + char term_1_m = term_1[term1_dims - 2]; + char term_1_k = term_1[term1_dims - 1]; + char term_2_k = term_2[term2_dims - 1]; + char term_2_n = term_2[term2_dims - 2]; + char result_m = result[result_dims - 2]; + char result_n = result[result_dims - 1]; + if (term_1_m != result_m) { + return false; + } + if (term_1_k != term_2_k) { + return false; + } + if (term_2_n != result_n) { + return false; + } + // Check batch dimensions + if (term_1[0] != result[0]) 
{ + return false; + } + if (term_1[1] != result[1]) { + return false; + } + if (term_2[0] != result[1]) { + return false; + } + return true; +} + /** * @brief Sets the parameter tensor names for a MatMul op. * @@ -317,6 +355,7 @@ Status EinsumOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, } if (!IsEquationMatMul(parsed_equation.value()) && !IsEquationMatMulTransposeY(parsed_equation.value()) && + !IsEquationMatMulBroadcastTransposeY(parsed_equation.value()) && !IsEquationMatMulTransposeAll(parsed_equation.value())) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation); } @@ -353,7 +392,8 @@ Status EinsumOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w /*logger=*/logger, /*do_op_validation=*/do_op_validation, /*qnn_op_type=*/QNN_OP_MAT_MUL)); - } else if (IsEquationMatMulTransposeY(parsed_equation.value())) { + } else if (IsEquationMatMulTransposeY(parsed_equation.value()) || + IsEquationMatMulBroadcastTransposeY(parsed_equation.value())) { std::vector param_tensor_names = SetMatMulParamTensorNames( &qnn_model_wrapper, node_unit, /*transpose_in0=*/false, /*transpose_in1=*/true); ORT_RETURN_IF_ERROR(ProcessOutputs(/*qnn_model_wrapper=*/qnn_model_wrapper, @@ -364,7 +404,10 @@ Status EinsumOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w /*do_op_validation=*/do_op_validation, /*qnn_op_type=*/QNN_OP_MAT_MUL)); } else if (IsEquationMatMulTransposeAll(parsed_equation.value())) { - ORT_RETURN_IF_ERROR(CreateMatMulTransposeAll(&qnn_model_wrapper, node_unit, std::move(input_names), do_op_validation)); + ORT_RETURN_IF_ERROR(CreateMatMulTransposeAll(/*qnn_model_wrapper=*/&qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_names=*/std::move(input_names), + /*do_op_validation=*/do_op_validation)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation); } diff --git 
a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.cc index ca15f861f4596..99ea79e028b0c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.cc @@ -32,8 +32,9 @@ std::unique_ptr LowPowerBlockQuantizedGemmFusion::TryFusion( const logging::Logger& logger) { ORT_UNUSED_PARAMETER(logger); + // Only HTP supports LPBQ encoding format // Looking for a Gemm to start search for Gemm w/ LPBQ encodings pattern. - if (gemm_node_unit.OpType() != "Gemm") { + if (!IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()) || gemm_node_unit.OpType() != "Gemm") { return nullptr; } @@ -236,18 +237,22 @@ Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4); } + std::vector weight_shape; + std::string weight_tensor_name = w_ql_input_1_def.node_arg.Name(); + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(w_ql_input_1_def.node_arg, weight_shape), "Failed to get weight shape"); + // Get attributes like axis, block_size from QuantizeLinear NodeAttrHelper helper(w_ql_node_unit.GetNode()); auto input_channel_axis = helper.Get("axis", static_cast(0)); + if (input_channel_axis < 0) { + input_channel_axis = weight_shape.size() + input_channel_axis; + } auto block_size = helper.Get("block_size", static_cast(0)); size_t output_channel_axis = 0; // Current LowPowerBlockQuantize() support output_channel_axis at index=0; weight_qparams = QnnQuantParamsWrapper(per_channel_float_scale, per_block_int_scale, weight_offset, output_channel_axis, block_size, is_int4_type); - std::vector weight_shape; std::vector unpacked_tensor; - std::string weight_tensor_name = w_ql_input_1_def.node_arg.Name(); - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(w_ql_input_1_def.node_arg, weight_shape), "Failed to get weight shape"); Qnn_DataType_t 
weight_data_type = is_int4_type ? QNN_DATATYPE_SFIXED_POINT_4 : QNN_DATATYPE_SFIXED_POINT_8; const auto& weight_tensor_proto = qnn_model_wrapper.GetConstantTensor(weight_tensor_name); ORT_RETURN_IF_ERROR(UnpackWeightTensorData(qnn_model_wrapper, weight_tensor_proto, weight_shape, input_channel_axis, unpacked_tensor)); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h index 9dcf07fa863d2..374df8b346e8d 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h @@ -16,7 +16,7 @@ namespace qnn { class QnnModelWrapper; /// -/// Represents a fusion of a {DQ, DQ->Q->DQ} -> Gemm -> DQ sequence. +/// Represents a fusion of a {DQ, DQ->Q->DQ} -> Gemm -> Q sequence. /// This is translated into a QNN's FC w/ LPBQ encodings. /// The contained NodeUnits are of type SingleNode since they are not part of a QDQ node unit. 
/// diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.cc new file mode 100644 index 0000000000000..92e0f28b0307c --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.cc @@ -0,0 +1,365 @@ +#include +#include +#include +#include +#include +#include + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include "core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h" + +namespace onnxruntime { +namespace qnn { + +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& scale_dql_node_unit, + const NodeUnit& w_ql_node_unit, + const NodeUnit& matmul_node_unit, + const logging::Logger& logger, + bool validate); + +std::unique_ptr LowPowerBlockQuantizedMatMulFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& matmul_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + + // Only HTP supports LPBQ encoding format + // Looking for a MatMul to start search for MatMul w/ LPBQ encodings pattern. 
+ if (!IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()) || matmul_node_unit.OpType() != "MatMul") { + return nullptr; + } + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + + // Get QuantizeLinear on Weight (input 1) of MatMul node + const NodeUnit* p_w_ql_node_unit = GetParentOfInput(graph_viewer, + matmul_node_unit, + matmul_node_unit.Inputs()[1], + node_to_node_unit, + node_unit_to_qnn_node_group); + if (p_w_ql_node_unit == nullptr || p_w_ql_node_unit->OpType() != "QuantizeLinear") { + return nullptr; + } + + // Check if input of QuantizeLinear is constant initializer + if (!qnn_model_wrapper.IsConstantInput(p_w_ql_node_unit->Inputs()[0].node_arg.Name())) { + return nullptr; + } + + // Get DequantizeLinear node unit contains per-block int scales and per-channel float scales + const std::array w_ql_parent_types = {"DequantizeLinear"}; + const NodeUnit* p_scale_dql_node_unit = GetParentOfType(graph_viewer, + *p_w_ql_node_unit, + w_ql_parent_types, + node_to_node_unit, + node_unit_to_qnn_node_group); + if (p_scale_dql_node_unit == nullptr) { + return nullptr; + } + + TensorInfo pc_scales_tensor_info = {}; + if (Status status = qnn_model_wrapper.GetTensorInfo(p_scale_dql_node_unit->Inputs()[0], pc_scales_tensor_info); + !status.IsOK()) { + return nullptr; + } + // Check if input 0 of DequantizeLinear is constant initializer and has per-channel float scales + if (!pc_scales_tensor_info.is_initializer || !pc_scales_tensor_info.quant_param.IsPerChannel()) { + return nullptr; + } + + if (Status status = CreateOrValidateOnQnn(qnn_model_wrapper, + *p_scale_dql_node_unit, + *p_w_ql_node_unit, + matmul_node_unit, + logger, + true); + !status.IsOK()) { + return nullptr; + } + + return std::make_unique(*p_scale_dql_node_unit, + *p_w_ql_node_unit, + matmul_node_unit); +} + +LowPowerBlockQuantizedMatMulFusion::LowPowerBlockQuantizedMatMulFusion(const NodeUnit& Scale_DQL_node_unit, + const NodeUnit& W_QL_node_unit, + const NodeUnit& MatMul_node_unit) 
+ : node_units_{&Scale_DQL_node_unit, + &W_QL_node_unit, + &MatMul_node_unit} { +} + +Status LowPowerBlockQuantizedMatMulFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], *node_units_[2], logger, true); +} + +Status LowPowerBlockQuantizedMatMulFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + return CreateOrValidateOnQnn(qmw, *node_units_[0], *node_units_[1], *node_units_[2], logger, false); +} + +gsl::span LowPowerBlockQuantizedMatMulFusion::GetNodeUnits() const { + return node_units_; +} + +const NodeUnit* LowPowerBlockQuantizedMatMulFusion::GetTargetNodeUnit() const { + return node_units_[2]; +} + +namespace { +// Process input[0] for ONNX MatMul that can be translated to either a QNN MatMul. +Status ProcessInput0(QnnModelWrapper& qnn_model_wrapper, + const NodeUnitIODef& input_def, + const std::string& original_input_0_name, + std::vector& input_names, + const logging::Logger& logger, + bool do_op_validation) { + TensorInfo input_0_info{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input_def, input_0_info)); + bool reshape_input_0 = input_0_info.shape.size() == 1; + std::string actual_input_0_name = original_input_0_name; + + if (reshape_input_0) { + actual_input_0_name = original_input_0_name + "_ort_qnn_ep_reshape"; + std::vector shape_2d{1, input_0_info.shape[0]}; + QnnQuantParamsWrapper quant_param_2d = input_0_info.quant_param.Copy(); + ORT_RETURN_IF_ERROR(quant_param_2d.HandleUnsqueeze(input_0_info.shape, shape_2d)); + + // If input_0 is initializer, unpack it and add the tensor with new quantization parameter and shape. + // Otherwise, add a Reshape node. 
+ if (input_0_info.is_initializer) { + std::vector unpacked_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_0_info.initializer_tensor, unpacked_tensor)); + QnnTensorWrapper input_tensorwrapper(actual_input_0_name, QNN_TENSOR_TYPE_STATIC, input_0_info.qnn_data_type, + std::move(quant_param_2d), std::move(shape_2d), std::move(unpacked_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + } else { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(original_input_0_name, actual_input_0_name, + input_0_info.shape, shape_2d, + input_0_info.qnn_data_type, input_0_info.quant_param, + quant_param_2d, do_op_validation, + qnn_model_wrapper.IsGraphInput(original_input_0_name), false)); + } + } else { + if (qnn_model_wrapper.IsQnnTensorWrapperExist(actual_input_0_name)) { + LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << actual_input_0_name; + } else { + QnnTensorWrapper input_0_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_0_info, actual_input_0_name, input_0_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_0_tensor)), "Failed to add tensor."); + } + } + input_names.emplace_back(actual_input_0_name); + + return Status::OK(); +} + +// Utility function to unpack weight tensor and transpose to shape [out_channels][in_channels] +Status UnpackWeightTensorData(const QnnModelWrapper& qnn_model_wrapper, + const onnx::TensorProto* weight_tensor_proto, + std::vector& weight_shape, + int64_t& input_channel_axis, + std::vector& unpacked_tensor) { + ORT_RETURN_IF_NOT(weight_tensor_proto != nullptr, "Weight tensor proto is null"); + + if (input_channel_axis == 0) { + // Transpose to keep output_channel at index 0; + // The current logic that quantizes with LPBQ encodings requires out_channels at index 0 + input_channel_axis = weight_shape.size() - 1; + return utils::TwoDimensionTranspose(qnn_model_wrapper, 
weight_shape, *weight_tensor_proto, unpacked_tensor); + } else { + // No transpose needed, just unpack the initializer data + return qnn_model_wrapper.UnpackInitializerData(*weight_tensor_proto, unpacked_tensor); + } +} + +// A utility function to transpose a 2D data +Status TwoDimensionTranspose(std::vector& data, + std::vector& data_shape, + const Qnn_DataType_t element_type) { + ORT_RETURN_IF_NOT(data_shape.size() == 2, "Expected shape of rank 2"); + + std::array perm = {1, 0}; + std::vector output_shape(data_shape.size()); + ORT_RETURN_IF_ERROR((qnn::utils::PermuteShape(data_shape, perm, output_shape))); + + const size_t elem_byte_size = qnn::utils::GetElementSizeByType(element_type); + ORT_RETURN_IF_NOT(elem_byte_size != 0, "Can't get element byte size from given QNN type"); + + std::vector transposed_data(data.size()); + + for (size_t row = 0; row < data_shape[0]; row++) { + for (size_t col = 0; col < data_shape[1]; col++) { + const size_t src_elem_index = (row * data_shape[1] + col); + const size_t dst_elem_index = (col * output_shape[1] + row); + const size_t src_byte_index = src_elem_index * elem_byte_size; + const size_t dst_byte_index = dst_elem_index * elem_byte_size; + assert(src_byte_index < data.size()); + assert(dst_byte_index < transposed_data.size()); + + std::memcpy(&transposed_data[dst_byte_index], &data[src_byte_index], elem_byte_size); + } + } + + data = std::move(transposed_data); // Update data with transposed data + data_shape = std::move(output_shape); // Update parameter with final transposed shape + return Status::OK(); +} + +// Process LPBQWeight for ONNX MatMul that can be translated to either a QNN MatMul. 
+Status ProcessLPBQWeight(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& scale_dql_node_unit, + const NodeUnit& w_ql_node_unit, + const NodeUnit& matmul_node_unit, + std::vector& input_names, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + const NodeUnitIODef& mm_input_1_def = matmul_node_unit.Inputs()[1]; + const NodeUnitIODef& w_ql_input_1_def = w_ql_node_unit.Inputs()[0]; + + // get per_channel_float_scale value from Quant param of input[0] of DequantizeLinear + std::vector per_channel_float_scale; + const NodeUnitIODef& per_channel_float_def = scale_dql_node_unit.Inputs()[0]; + const std::optional& scale_dql_quant_param = per_channel_float_def.quant_param; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackScales(scale_dql_quant_param->scale.Name(), per_channel_float_scale)); + + // get per_block_int_scale value from input[0] of DequantizeLinear + std::vector per_block_int_scale; + const NodeUnitIODef& per_block_int_def = scale_dql_node_unit.Inputs()[0]; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackScales(per_block_int_def.node_arg.Name(), per_block_int_scale)); + std::vector weight_offset(per_channel_float_scale.size(), 0); + std::vector block_scales_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(per_block_int_def.node_arg, block_scales_shape), "Failed to get block_scales shape"); + + // Read axis of channels in per-block-int-scales data + NodeAttrHelper scales_node_helper(scale_dql_node_unit.GetNode()); + auto block_scales_axis = scales_node_helper.Get("axis", static_cast(0)); + + // Transpose per-block-int-scales to keep channels at index-0 (QNN LPBQ format requires shape [axis_size][blocks-per-axis]) + if (block_scales_axis == 1) { + ORT_RETURN_IF_ERROR(TwoDimensionTranspose(per_block_int_scale, block_scales_shape, QNN_DATATYPE_UFIXED_POINT_8)); + block_scales_axis = 0; + } + + // Extract weight datatype from zeropoint (aka offset) of Input1 Quant param + const std::optional& mm_input_1_quant_param = 
mm_input_1_def.quant_param; + bool is_int4_type = false; + if (mm_input_1_quant_param->zero_point != nullptr) { + int32_t elem_data_type = 0; + ORT_RETURN_IF_ERROR(utils::GetOnnxTensorElemDataType(*mm_input_1_quant_param->zero_point, elem_data_type)); + is_int4_type = (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) || + (elem_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4); + } + + std::vector weight_shape; + std::string weight_tensor_name = w_ql_input_1_def.node_arg.Name(); + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(w_ql_input_1_def.node_arg, weight_shape), "Failed to get weight shape"); + + // Get attributes like weight data axis, block_size from QuantizeLinear + NodeAttrHelper helper(w_ql_node_unit.GetNode()); + auto input_channel_axis = helper.Get("axis", static_cast(0)); + if (input_channel_axis < 0) { + input_channel_axis = weight_shape.size() + input_channel_axis; // QNN requires positive axis value + } + auto block_size = helper.Get("block_size", static_cast(0)); + + std::vector unpacked_tensor; + const auto& weight_tensor_proto = qnn_model_wrapper.GetConstantTensor(weight_tensor_name); + // if input_channel_axis = 0, UnpackWeightTensorData will transpose and keep output_channel at 0 + ORT_RETURN_IF_ERROR(UnpackWeightTensorData(qnn_model_wrapper, weight_tensor_proto, weight_shape, input_channel_axis, unpacked_tensor)); + + // Quantize weight tensor + size_t weight_elements = unpacked_tensor.size() / sizeof(float); + auto float_data = gsl::make_span(reinterpret_cast(unpacked_tensor.data()), weight_elements); + std::vector quant_data(weight_elements); + + // weight_data_type = 4 but store in int8 buffer + size_t output_channel_axis = 0; // MatMul requires axis to be rank-1 + Qnn_DataType_t weight_data_type = is_int4_type ? 
QNN_DATATYPE_SFIXED_POINT_4 : QNN_DATATYPE_SFIXED_POINT_8; + ORT_RETURN_IF_ERROR(qnn::utils::LowPowerBlockQuantizeData(float_data, + weight_shape, + per_channel_float_scale, + per_block_int_scale, + weight_offset, + quant_data, + weight_data_type, + output_channel_axis, + block_scales_axis, + block_size, + block_scales_shape)); + + // MatMul w/ LPBQ requies MatMul(MxK, KxN) and axis = rank-1 (out channels) + // Transpose Weight to KxN, output_channel_axis is modified to rank-1; + if (input_channel_axis == 1) { + ORT_RETURN_IF_ERROR(TwoDimensionTranspose(quant_data, weight_shape, QNN_DATATYPE_SFIXED_POINT_8)); + input_channel_axis = 0; + output_channel_axis = weight_shape.size() - 1; + } + + // Construct Quant params for Weight + QnnQuantParamsWrapper weight_qparams; + weight_qparams = QnnQuantParamsWrapper(per_channel_float_scale, per_block_int_scale, weight_offset, output_channel_axis, block_size, is_int4_type); + + // Get weight tensor type from input of w_dql_tensor or output_dql_tensor + Qnn_TensorType_t weight_tensor_type = qnn_model_wrapper.GetTensorType(weight_tensor_name); + QnnTensorWrapper weight_tensor(weight_tensor_name, weight_tensor_type, QNN_DATATYPE_SFIXED_POINT_8, + std::move(weight_qparams), std::move(weight_shape), + std::move(quant_data)); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(weight_tensor)), "Failed to add weight"); + input_names.emplace_back(weight_tensor_name); + return Status::OK(); +} +} // namespace + +Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& scale_dql_node_unit, + const NodeUnit& w_ql_node_unit, + const NodeUnit& matmul_node_unit, + const logging::Logger& logger, + bool validate) { + assert(scale_dql_node_unit.OpType() == "DequantizeLinear" && + w_ql_node_unit.OpType() == "QuantizeLinear" && + matmul_node_unit.OpType() == "MatMul"); + + const auto& node_name = utils::GetNodeName(matmul_node_unit); + + std::vector input_names; + + // prepare input tensor + const 
NodeUnitIODef& input_def = matmul_node_unit.Inputs()[0]; + const std::string& input_tensor_name = input_def.node_arg.Name(); + ORT_RETURN_IF_ERROR(ProcessInput0(qnn_model_wrapper, input_def, input_tensor_name, input_names, + logger, validate)); + + // Prepare LowPowerBlockQuantized(LPBQ) Weight + ORT_RETURN_IF_ERROR(ProcessLPBQWeight(qnn_model_wrapper, scale_dql_node_unit, w_ql_node_unit, + matmul_node_unit, input_names, logger)); + + // Prepare Output + const NodeUnitIODef& output_def = matmul_node_unit.Outputs()[0]; + const std::string& op_output_name = output_def.node_arg.Name(); + QnnTensorWrapper output_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + + // Create QNN Node and Validate if require. + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_MAT_MUL, + std::move(input_names), + {op_output_name}, + {}, + validate), + "Failed to add fused Matmul node."); + + return Status(); +} +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h new file mode 100644 index 0000000000000..0d8967de5ace3 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h @@ -0,0 +1,49 @@ +// Copyright (c) Qualcomm. All rights reserved. +// Licensed under the MIT License + +#pragma once + +#include +#include +#include + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of a DQ -> Q -> MatMul (+ DQ, DQ, Q). +/// This is translated into a QNN's MatMul w/ LPBQ encodings. 
+/// The contained NodeUnits are of type SingleNode since they are not part of a QDQ node unit. +/// + +class LowPowerBlockQuantizedMatMulFusion : public IQnnNodeGroup { + public: + LowPowerBlockQuantizedMatMulFusion(const NodeUnit& Scale_DQL_node_unit, + const NodeUnit& W_QL_node_unit, + const NodeUnit& MatMul_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(LowPowerBlockQuantizedMatMulFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "LowPowerBlockQuantizedMatMulFusion"; } + + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& matmul_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::array node_units_; +}; + +} // namespace qnn +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index b0f0b4c0ff48a..5f33b639ce613 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -19,6 +19,7 @@ #include "core/providers/qnn/builder/qnn_node_group/channel_shuffle_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/udo_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/ort_api.h" @@ -76,6 +77,7 @@ using FusionFunc = std::function(QnnModelWrapper& static std::unordered_map> fusions 
= { {"DequantizeLinear", {DQQFusion::TryFusion}}, {"HardSigmoid", {HardSigmoidMulFusion::TryFusion}}, + {"MatMul", {LowPowerBlockQuantizedMatMulFusion::TryFusion}}, {"Gemm", {LowPowerBlockQuantizedGemmFusion::TryFusion, ReshapeGemmFusion::TryFusion}}, {"Mul", {ScaleSoftmaxFusion::TryFusion}}, {"Transpose", {ChannelShuffleFusion::TryFusion}}}; @@ -113,8 +115,8 @@ static std::unique_ptr TryQnnFusions( const std::unordered_map& node_to_node_unit, const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger) { - // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). - if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except MatMul w/ LPBQ encodings + if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode && starting_node_unit.OpType() != "MatMul") { return nullptr; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc index 92478e0db7795..10e1633e4b57d 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc @@ -177,7 +177,26 @@ const NodeUnit* GetParentOfInput(const GraphViewer& graph_viewer, const NodeUnitIODef& input, const std::unordered_map& node_unit_map, const std::unordered_map& qnn_node_group_map) { - const Node& child_node = node_unit.GetNode(); + const Node* p_child_node = nullptr; + + for (auto node : node_unit.GetAllNodesInGroup()) { + for (auto node_input : node->InputDefs()) { + if (node_input->Name() == input.node_arg.Name()) { + p_child_node = node; + break; + } + + if (p_child_node != nullptr) { + break; + } + } + } + + if (p_child_node == nullptr) { + return nullptr; + } + + const Node& child_node = *p_child_node; for (auto edge = child_node.InputEdgesBegin(); edge != child_node.InputEdgesEnd(); ++edge) { const 
Node& parent_node = edge->GetNode(); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 64be445b4c15c..b60f64db1734d 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2337,11 +2337,14 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect if (load_user_initializer_) { auto allInitializers = graph_viewer->GetAllInitializedTensors(); - for (auto entry : allInitializers) { + for (auto& entry : allInitializers) { auto* tp = entry.second; if (tp->has_raw_data()) { - userWeights.push_back( - TensorrtUserWeights{tp->name(), tp->raw_data(), (int64_t)tp->raw_data().size()}); + userWeights.emplace_back(tp->name(), tp->raw_data()); + } else if (utils::HasExternalDataInMemory(*tp)) { + std::unique_ptr full_init; + ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); + userWeights.emplace_back(full_init->name(), full_init->raw_data()); } } } @@ -2378,7 +2381,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect if (load_user_initializer_) { trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_); for (auto const& userWeight : userWeights) { - trt_parser->loadInitializer(userWeight.name.c_str(), static_cast(userWeight.data.c_str()), userWeight.size); + trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size()); } is_model_supported = trt_parser->parseModelProto(); } else { @@ -2862,7 +2865,8 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil if (onnx_model_path.empty()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "The ONNX model was not provided as path. 
" - "Please use provide an ONNX bytestream to enable refitting the weightless engine."); + "Please use provide an ONNX bytestream to enable refitting the weightless engine." + "When providing a bytestream during session initialization, it should also be set as trt_onnx_bytes_stream"); } else { // check if file path to ONNX is legal if (path_check && IsAbsolutePath(onnx_model_path.string())) { @@ -2909,6 +2913,7 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil int required_weights = refitter->getAllWeights(0, nullptr); std::vector refit_names(required_weights); refitter->getAllWeights(required_weights, refit_names.data()); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitter requires " << required_weights << " weights"; // Vectors to keep track of data pointers. std::vector names; @@ -2918,67 +2923,69 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil std::vector sizes; sizes.reserve(required_weights); - if (refit_with_external_data) { - auto onnx_model = ModelProto::Create(); - TensorProtos* allInitializers_byte_stream; + auto onnx_model = ModelProto::Create(); + TensorProtos* allInitializers_byte_stream; - // Reconstruct onnx model view. - const auto onnx_model_view = std::string((const char*)onnx_model_bytestream, - onnx_model_bytestream_size); - if (!onnx_model->ParseFromString(onnx_model_view)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "The provided ONNX bytestream to refit could not be parsed."); - } - - // Extract graph and initializer information. 
- auto const& graph = onnx_model->mutable_graph(); - allInitializers_byte_stream = graph->mutable_initializer(); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializers that were found " << allInitializers_byte_stream->size(); - - // Loop through all initializers - for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) { - auto& proto = allInitializers_byte_stream->at(initializer_idx); - auto& proto_name = proto.name(); - bool weight_is_refittable = std::find(refit_names.begin(), refit_names.end(), proto_name) != refit_names.end(); - if (weight_is_refittable) { - if (proto.has_data_location()) { - if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) { - // Default values for reading into external_data blob. - int64_t offset = 0; - size_t length = 0; - auto external_data = proto.mutable_external_data(); - const std::string kOffset = "offset", kLength = "length"; - for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) { - auto current_key = external_data->at(entry_idx).mutable_key(); - auto current_value = external_data->at(entry_idx).mutable_value(); - if (*current_key == kOffset && !current_value->empty()) { - offset = std::stoll(*current_value); - } else if (*current_key == kLength && !current_value->empty()) { - length = std::stoul(*current_value); - } + // Reconstruct onnx model view. + const auto onnx_model_view = std::string((const char*)onnx_model_bytestream, + onnx_model_bytestream_size); + if (!onnx_model->ParseFromString(onnx_model_view)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "The provided ONNX bytestream to refit could not be parsed."); + } + + // Extract graph and initializer information. 
+ auto const& graph = onnx_model->mutable_graph(); + allInitializers_byte_stream = graph->mutable_initializer(); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializers that were found " << allInitializers_byte_stream->size(); + + // Loop through all initializers + int missing_initializer_data = 0; + for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) { + auto& proto = allInitializers_byte_stream->at(initializer_idx); + auto& proto_name = proto.name(); + bool weight_is_refittable = std::find(refit_names.begin(), refit_names.end(), proto_name) != refit_names.end(); + if (weight_is_refittable) { + if (proto.has_data_location()) { + if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) { + // Default values for reading into external_data blob. + int64_t offset = 0; + size_t length = 0; + auto external_data = proto.mutable_external_data(); + const std::string kOffset = "offset", kLength = "length"; + for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) { + auto current_key = external_data->at(entry_idx).mutable_key(); + auto current_value = external_data->at(entry_idx).mutable_value(); + if (*current_key == kOffset && !current_value->empty()) { + offset = std::stoll(*current_value); + } else if (*current_key == kLength && !current_value->empty()) { + length = std::stoul(*current_value); } - names.push_back(proto.name()); - bytes.push_back(static_cast(onnx_external_data_bytestream) + offset); - sizes.push_back(length); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "[TensorRT EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead."); } - } else { - if (!proto.has_raw_data()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "[TensorRT EP] Proto: " + proto_name + " has no raw data"); - } - auto& raw_data = proto.raw_data(); names.push_back(proto.name()); - bytes.push_back(raw_data.c_str()); - 
sizes.push_back(raw_data.size()); + bytes.push_back(static_cast(onnx_external_data_bytestream) + offset); + sizes.push_back(length); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[TensorRT EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead."); } + } else if (proto.has_raw_data()) { + auto& raw_data = proto.raw_data(); + names.push_back(proto.name()); + bytes.push_back(raw_data.c_str()); + sizes.push_back(raw_data.size()); } else { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializer with name: " << proto_name << " was not marked as refittable"; + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Proto: " + proto_name + " has no raw nor external data."; + ++missing_initializer_data; } + } else { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializer with name: " << proto_name << " was not marked as refittable"; } } + if (missing_initializer_data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[TensorRT EP] RefitEngine is missing " + std::to_string(missing_initializer_data) + " initializers."); + } // Load extracted initializers into the parser if (!names.empty()) { @@ -3093,12 +3100,17 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView if (load_user_initializer_) { auto allInitializers = graph_body_viewer.GetAllInitializedTensors(); - for (auto entry : allInitializers) { + for (auto& entry : allInitializers) { auto name = entry.first; auto* tp = entry.second; if (tp->has_raw_data()) { - userWeights->push_back( - TensorrtUserWeights{tp->name(), tp->raw_data(), (int64_t)tp->raw_data().size()}); + userWeights->emplace_back( + TensorrtUserWeights(tp->name(), tp->raw_data())); + } else if (utils::HasExternalDataInMemory(*tp)) { + std::unique_ptr full_init; + ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init)); + userWeights->emplace_back( + TensorrtUserWeights(full_init->name(), full_init->raw_data())); } } } @@ -3134,7 +3146,7 @@ Status 
TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView if (load_user_initializer_) { trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_); for (auto const& userWeight : *userWeights) { - trt_parser->loadInitializer(userWeight.name.c_str(), static_cast(userWeight.data.c_str()), userWeight.size); + trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size()); } trt_parser->parseModelProto(); } else { @@ -3671,14 +3683,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView if (weight_stripped_engine_refit_) { LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refit engine from main ONNX file after engine build"; - char* onnx = string_buf.data(); - size_t onnx_size = string_buf.size(); auto status = RefitEngine(model_path_, onnx_model_folder_path_, engine_cache_path, false /* path check for security */, - onnx, - onnx_size, + onnx_model_bytestream_, + onnx_model_bytestream_size_, onnx_external_data_bytestream_, onnx_external_data_bytestream_size_, trt_engine.get(), diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index dba17f7822eac..e817fc51237c0 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -158,10 +158,25 @@ class OutputAllocator : public nvinfer1::IOutputAllocator { using ShapeRangesMap = std::unordered_map>>>; // Struct to hold user weights when ModelProtos are serialized with data. 
-struct TensorrtUserWeights { - std::string name{}; - std::string data{}; - int64_t size{}; +class TensorrtUserWeights { + public: + TensorrtUserWeights(const std::string& name, const std::string& data) : name_(name), data_(data) {}; + + const char* Name() const { + return name_.c_str(); + }; + + const void* Data() const { + return static_cast(data_.data()); + } + + int64_t Size() const { + return static_cast(data_.size()); + } + + private: + std::string name_{}; + std::string data_{}; }; // Information to construct kernel function state. diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 5cfd6c78f8929..283a9e5fe8262 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -38,19 +38,19 @@ WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10) WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12) WEBGPU_CONCAT_KERNEL(13) -void AppendCalCulateInputIndexFunction(std::ostream& os, size_t input_count) { - os << "fn calculate_input_index(index: u32) -> u32 {\n" - << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {\n" - << " if (index < " << GetElementAt("uniforms.size_in_concat_axis", "i", input_count) << ") {\n" - << " return i;\n" +void AppendCalculateInputIndexFunction(std::ostream& os, size_t input_count) { + os << "fn calculate_input_index(global_idx: u32) -> u32 {\n" + << " for (var i = 1u; i < " << input_count << "; i = i + 1u) {\n" + << " if (global_idx < " << GetElementAt("uniforms.offsets", "i", input_count) << ") {\n" + << " return i - 1;\n" << " }\n" << " }\n" - << " return " << input_count << ";\n" + << " return " << input_count - 1 << ";\n" << "}\n"; } -void AppendAssignOutputDataFunction(std::ostream& os, gsl::span inputs, const ShaderVariableHelper& output) { - os << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {\n"; +void AppendAssignOutputDataFunction(std::ostream& os, gsl::span inputs, const 
ShaderVariableHelper& output, size_t axis, size_t input_count) { + os << "fn assign_output_data(global_idx: u32, input_index: u32) {\n"; for (size_t i = 0; i < inputs.size(); ++i) { if (i == 0) { os << " if (input_index == 0u) {\n"; @@ -59,7 +59,12 @@ void AppendAssignOutputDataFunction(std::ostream& os, gsl::spanGetByIndices("indices")) << ";\n"; + std::string offset = GetElementAt("uniforms.offsets", "input_index", input_count); + std::string concat_axis_offset = GetElementAt("uniforms.sizes_in_concat_axis", std::to_string(i), input_count); + std::string output_indices_axis = "output_indices" + (inputs[i]->Rank() > 1 ? "[" + std::to_string(axis) + "]" : ""); + os << " var output_indices = " << inputs[i]->OffsetToIndices("global_idx - " + offset) << ";\n" + << " " << output_indices_axis << " += " << concat_axis_offset << ";\n" + << " " << output.SetByIndices("output_indices", inputs[i]->GetByOffset("global_idx - " + offset)) << "\n"; } os << " }\n" "}\n"; @@ -74,27 +79,21 @@ Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { } const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - // add implementation of fn calculate_input_index - AppendCalCulateInputIndexFunction(shader.AdditionalImplementation(), input_count); - // add implementation of fn assign_output_data - AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output); - const std::string size_in_concat_axis = GetElementAt("uniforms.size_in_concat_axis", "input_index - 1", input_count); + AppendCalculateInputIndexFunction(shader.AdditionalImplementation(), input_count); + AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output, axis_, input_count); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") - << " var indices = " << output.OffsetToIndices("global_idx") << ";\n" - << " let indices_axis = " << output.IndicesGet("indices", axis_) << 
";\n" - << " let input_index = calculate_input_index(indices_axis);\n" - << " if (input_index != 0u) {\n" - << " " << output.IndicesSet("indices", axis_, "indices_axis - " + size_in_concat_axis) << ";\n" - << " }\n" - " assign_output_data(global_idx, input_index, indices);\n"; + << "let input_index = calculate_input_index(global_idx);\n" + << "assign_output_data(global_idx, input_index);\n"; + return Status::OK(); } Status Concat::ComputeInternal(ComputeContext& context) const { - int input_count = context.InputCount(); + uint32_t input_count = context.InputCount(); InlinedTensorsVector input_tensors; input_tensors.reserve(input_count); - for (int i = 0; i < input_count; ++i) { + for (uint32_t i = 0; i < input_count; ++i) { input_tensors.push_back(context.Input(i)); } @@ -104,42 +103,55 @@ Status Concat::ComputeInternal(ComputeContext& context) const { return Status::OK(); } - uint32_t output_size = onnxruntime::narrow(prepare.output_tensor->Shape().Size()); + uint32_t axis = static_cast(prepare.axis); + uint32_t max_inputs_per_concat = context.DeviceLimits().maxStorageBuffersPerShaderStage - 1; + + uint32_t input_index = 0; + uint32_t cumulative_size_in_concat_axis = 0; + + while (input_index < input_count) { + ConcatProgram program{axis}; + uint32_t num_inputs_this_concat = std::min(max_inputs_per_concat, input_count - input_index); + + std::vector offsets; + offsets.reserve(num_inputs_this_concat + 1); + offsets.push_back(0); - size_t axis = static_cast(prepare.axis); - ConcatProgram program{axis}; + std::vector sizes_in_concat_axis; + sizes_in_concat_axis.reserve(num_inputs_this_concat + 1); + sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); - std::vector sizes_in_concat_axis; - sizes_in_concat_axis.reserve(input_count); - uint32_t sum = 0; - for (int i = 0; i < input_count; ++i) { - const auto& input = prepare.inputs[i]; - if (input.tensor->Shape().Size() == 0) { - continue; + uint32_t output_size = 0; + for (uint32_t i = 0; i < 
num_inputs_this_concat; i++) { + auto& input = prepare.inputs[input_index + i]; + if (input.tensor->Shape().Size() == 0) { + continue; + } + program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); + + uint32_t size = onnxruntime::narrow(input.tensor->Shape().Size()); + uint32_t axis_size = static_cast(input.tensor->Shape()[axis]); + + output_size += size; + offsets.push_back(output_size); + cumulative_size_in_concat_axis += axis_size; + sizes_in_concat_axis.push_back(cumulative_size_in_concat_axis); } - program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); - auto axis_size = input.tensor->Shape()[axis]; - sum += static_cast(axis_size); - sizes_in_concat_axis.push_back(sum); - } + offsets.pop_back(); + sizes_in_concat_axis.pop_back(); - size_t non_empty_input_count = sizes_in_concat_axis.size(); + program.CacheHint(absl::StrJoin(std::make_tuple(num_inputs_this_concat, prepare.axis), ",")) + .AddOutputs({prepare.output_tensor}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({gsl::span(offsets.data(), offsets.size()), gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), output_size}); + ORT_RETURN_IF_ERROR(context.RunProgram(program)); - if (non_empty_input_count + 1 > context.DeviceLimits().maxStorageBuffersPerShaderStage) { - // TODO: support when input_count + 1 > maxStorageBuffersPerShaderStage, by raising the limit or run the program in multiple passes. 
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "The number of storage buffer (input=", - input_count, ", output=1) exceeds the limit (", - context.DeviceLimits().maxStorageBuffersPerShaderStage, ") of the device."); + input_index += num_inputs_this_concat; } - program.CacheHint(absl::StrJoin(std::make_tuple(non_empty_input_count, prepare.axis), ",")) - .AddOutputs({prepare.output_tensor}) - .SetDispatchGroupSize((prepare.output_num_elements + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddUniformVariables({gsl::span(sizes_in_concat_axis.data(), sizes_in_concat_axis.size()), - output_size}); - return context.RunProgram(program); + return Status::OK(); } } // namespace webgpu -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.h b/onnxruntime/core/providers/webgpu/tensor/concat.h index 0f6e6dd327e33..7980556e0a1f4 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.h +++ b/onnxruntime/core/providers/webgpu/tensor/concat.h @@ -17,7 +17,8 @@ class ConcatProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"size_in_concat_axis", ProgramUniformVariableDataType::Uint32}, + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"offsets", ProgramUniformVariableDataType::Uint32}, + {"sizes_in_concat_axis", ProgramUniformVariableDataType::Uint32}, {"output_size", ProgramUniformVariableDataType::Uint32}); private: @@ -33,4 +34,4 @@ class Concat final : public WebGpuKernel, public ConcatBase { }; } // namespace webgpu -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 6e09f494f4a8d..bca41b7851c28 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc 
@@ -18,8 +18,11 @@ #include "core/framework/data_transfer_manager.h" #include "core/framework/fallback_cpu_capability.h" #include "core/framework/kernel_registry.h" +#include "core/framework/run_options.h" #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" +#include "core/session/onnxruntime_run_options_config_keys.h" +#include "core/common/parse_string.h" #include "core/providers/webgpu/webgpu_context.h" #include "core/providers/webgpu/data_transfer.h" @@ -692,7 +695,6 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -907,7 +909,7 @@ Status WebGpuExecutionProvider::OnSessionInitializationEnd() { return Status::OK(); } -Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { +Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { if (context_.ValidationMode() >= ValidationMode::Basic) { context_.PushErrorScope(); } @@ -916,20 +918,32 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_ context_.StartProfiling(); } - if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { - context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_); + if (IsGraphCaptureEnabled()) { + auto graph_annotation_str = run_options.config_options.GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation); + int graph_annotation_id = 0; + if (graph_annotation_str.has_value()) { + ORT_ENFORCE(onnxruntime::TryParseStringWithClassicLocale(*graph_annotation_str, graph_annotation_id), + "Failed to parse the graph annotation id: ", + *graph_annotation_str); + } + + if (graph_annotation_id != -1 && IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) { + context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_); + } + m_current_graph_annotation_id = graph_annotation_id; } return Status::OK(); } -Status 
WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /*run_options*/) { +Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /* run_options */) { context_.Flush(BufferManager()); - if (IsGraphCaptureEnabled() && !IsGraphCaptured(0)) { - if (IsGraphCaptureAllowed()) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured(m_current_graph_annotation_id)) { + if (m_current_graph_annotation_id != -1 && IsGraphCaptureAllowed()) { context_.CaptureEnd(); is_graph_captured_ = true; + ORT_RETURN_IF_ERROR(ReplayGraph(m_current_graph_annotation_id)); } else { IncrementRegularRunCountBeforeGraphCapture(); } @@ -952,12 +966,12 @@ bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const { return enable_graph_capture_; } -bool WebGpuExecutionProvider::IsGraphCaptured(int) const { - return is_graph_captured_; +bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const { + return is_graph_captured_ && graph_annotation_id != -1; } -Status WebGpuExecutionProvider::ReplayGraph(int) { - ORT_ENFORCE(IsGraphCaptured(0)); +Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) { + ORT_ENFORCE(IsGraphCaptured(graph_annotation_id)); context_.Replay(captured_commands_, *graph_buffer_mgr_); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 2567be2a1eb18..3bbec164a0190 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -99,6 +99,7 @@ class WebGpuExecutionProvider : public IExecutionProvider { bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. 
+ int m_current_graph_annotation_id = 0; webgpu::GpuBufferAllocator* allocator_ = nullptr; // Buffer manager specifically for graph capture mode diff --git a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc index 9b287b7b7df99..bc5c755160cb5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc @@ -36,6 +36,7 @@ WebGpuPIXFrameGenerator::WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu:: format = capabilities.formats[0]; wgpu::SurfaceConfiguration config; + config.presentMode = capabilities.presentModes[0]; config.device = device; config.format = format; config.width = kWidth; diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index f75b6f41f7f9c..109228cc60d7d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -317,32 +317,34 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N int32_t x_type; ORT_RETURN_IF_NOT(GetType(*input_defs[0], x_type, logger), "Cannot get data type of input x"); - emscripten::val x_zero_point, w_zero_point, x_scale, w_scale; + emscripten::val x_zero_point, w_zero_point; + std::vector x_zero_point_shape; if (TensorExists(input_defs, 2)) { x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); + ORT_RETURN_IF_NOT(GetShape(*input_defs[2], x_zero_point_shape, logger), "Cannot get shape of x_zero_point"); } else { x_zero_point = model_builder.CreateOrGetConstant(x_type, 0); } - // Scale is not used by ConvInteger but required by DequantizeLinear. So set it to default value 1.0f. // The x_zero_point must be a scalar and the scale input should have the same shape as the zero point input. // So the x_scale must be a scalar too. 
- x_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f); + // ONNX allows 1D tensor of size 1 as scalar. So explicitly set the shape of x_scale to x_zero_point_shape. + emscripten::val x_scale = model_builder.CreateOrGetConstant( + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, GetNarrowedIntFromInt64(x_zero_point_shape)); // Dequantize x to Float32 common_options.set("label", node.Name() + "_dequantized_x"); input = model_builder.GetBuilder().call("dequantizeLinear", input, x_scale, x_zero_point, common_options); + std::vector w_zero_point_shape; if (TensorExists(input_defs, 3)) { w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); - std::vector w_zero_point_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[3], w_zero_point_shape, logger), "Cannot get shape of w_zero_point"); - w_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, - GetNarrowedIntFromInt64(w_zero_point_shape)); } else { w_zero_point = model_builder.CreateOrGetConstant(x_type, 0); - w_scale = x_scale; } + emscripten::val w_scale = model_builder.CreateOrGetConstant( + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, GetNarrowedIntFromInt64(w_zero_point_shape)); // Dequantize w to Float32 common_options.set("label", node.Name() + "_dequantized_w"); filter = model_builder.GetBuilder().call("dequantizeLinear", filter, w_scale, w_zero_point, diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.cc b/onnxruntime/core/providers/xnnpack/math/gemm.cc index a3ff3b585ae45..9b78e943122de 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.cc +++ b/onnxruntime/core/providers/xnnpack/math/gemm.cc @@ -139,7 +139,6 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, // flags - 1 - for no transpose - 0 for transpose uint32_t flags = trans_B_ == CblasTrans ? 
0 : XNN_FLAG_TRANSPOSE_WEIGHTS; - auto code_cache = GetCodeCache(); auto weights_cache = GetWeightsCache(); xnn_status status = xnn_status::xnn_status_uninitialized; struct xnn_operator* p = nullptr; @@ -159,7 +158,7 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, bias_data, // const float* bias, foutput_min, foutput_max, flags, - code_cache, weights_cache, + weights_cache, &p); } else if (op_compute_type_ == OpComputeType::op_compute_type_fp16) { const MLFloat16* bias_data = nullptr; @@ -175,7 +174,7 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, bias_data, // const float* bias, foutput_min, foutput_max, flags, - code_cache, weights_cache, + weights_cache, &p); } diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.cc b/onnxruntime/core/providers/xnnpack/math/matmul.cc index 9083b9c22f64a..7870bcff298f2 100644 --- a/onnxruntime/core/providers/xnnpack/math/matmul.cc +++ b/onnxruntime/core/providers/xnnpack/math/matmul.cc @@ -102,10 +102,8 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, } #ifdef XNN_CACHE_ENABLE - xnn_code_cache_t code_cache = GetCodeCache(); xnn_weights_cache_t weight_cache = GetWeightsCache(); #else - xnn_code_cache_t code_cache = nullptr; xnn_weights_cache_t weight_cache = nullptr; #endif @@ -122,7 +120,6 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, foutput_min, foutput_max, flags, - code_cache, weight_cache, &p); } else if (op_type_ == OpComputeType::op_compute_type_fp16) { @@ -136,7 +133,6 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, foutput_min, foutput_max, flags, - code_cache, weight_cache, &p); } diff --git a/onnxruntime/core/providers/xnnpack/nn/average_pool.cc b/onnxruntime/core/providers/xnnpack/nn/average_pool.cc index f320274f65db3..963dfa5fa26d7 100644 --- a/onnxruntime/core/providers/xnnpack/nn/average_pool.cc +++ 
b/onnxruntime/core/providers/xnnpack/nn/average_pool.cc @@ -17,7 +17,6 @@ namespace { Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, const std::optional>& clip_min_max, struct xnn_operator*& p, - const OpQuantParam& quant_param, OpComputeType avgpool_type) { uint32_t input_padding_top = narrow(pool_attrs.pads[0]); uint32_t input_padding_left = narrow(pool_attrs.pads[1]); @@ -48,20 +47,6 @@ Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, pooling_height, pooling_width, stride_height, stride_width, foutput_min, foutput_max, flags, &p); - } else if (avgpool_type == OpComputeType::op_compute_type_qu8) { - const float output_scale = quant_param[1].first[0]; - const uint8_t output_zero_point = quant_param[1].second; - const uint8_t output_min = xnn_u8s8_quantize(foutput_min, output_scale, output_zero_point); - const uint8_t output_max = xnn_u8s8_quantize(foutput_max, output_scale, output_zero_point); - status = xnn_create_average_pooling2d_nhwc_qu8(input_padding_top, input_padding_right, - input_padding_bottom, input_padding_left, - pooling_height, pooling_width, - stride_height, stride_width, - quant_param[0].second, - quant_param[0].first[0], - quant_param[1].second, - quant_param[1].first[0], - output_min, output_max, flags, &p); } if (status != xnn_status_success) { @@ -72,9 +57,9 @@ Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, } bool IsQuantAvgPoolSupported(const NodeUnit& node_unit, const GraphViewer& graph) { - TensorQuantType x_input_type = GetTensorQuantType(node_unit, 0, false, graph); - TensorQuantType output_type = GetTensorQuantType(node_unit, 0, true, graph); - return (x_input_type == TensorTypeUint8 && output_type == TensorTypeUint8); + (void)node_unit; + (void)graph; + return false; } bool IsQuantizedAvgPool(QuantizedOpType quant_op_type) { @@ -209,14 +194,10 @@ AveragePool::AveragePool(const OpKernelInfo& info) avgpool_type_ = OpComputeType::op_compute_type_fp32; } else if (input_dtype == 
ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { avgpool_type_ = OpComputeType::op_compute_type_fp16; - } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { - // the order of input tensor, x,x_scale, x_zp, y_scale, y_zp - quant_param = ParseQuantParamForOp(info, input_dtype, 1); - avgpool_type_ = OpComputeType::op_compute_type_qu8; } struct xnn_operator* p; auto ret = CreateXnnpackKernel(pool_attrs_, clip_min_max_, p, - quant_param, avgpool_type_); + avgpool_type_); ORT_ENFORCE(ret.IsOK(), ret.ErrorMessage()); op0_.reset(p); } @@ -242,23 +223,12 @@ Status AveragePool::Compute(OpKernelContext* context) const { pthreadpool_t threadpool = GetThreadPool(); - // setup allocator/automated dellocate for workspace - size_t workspace_size = 0; - size_t workspace_alignment = 0; - xnn_allocator* allocator = GetStoredAllocator().second; - auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); }; - - std::unique_ptr workspace(nullptr, deallocator); - auto reshape_fn = xnn_reshape_average_pooling2d_nhwc_f32; if (avgpool_type_ == OpComputeType::op_compute_type_fp16) { reshape_fn = xnn_reshape_average_pooling2d_nhwc_f16; - } else if (avgpool_type_ == OpComputeType::op_compute_type_qu8) { - reshape_fn = xnn_reshape_average_pooling2d_nhwc_qu8; } auto status = reshape_fn(op0_.get(), N, H, W, C, C, C, - &workspace_size, &workspace_alignment, /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, threadpool); @@ -267,17 +237,12 @@ Status AveragePool::Compute(OpKernelContext* context) const { " returned ", status); } - workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size)); - if (avgpool_type_ == OpComputeType::op_compute_type_fp32) { - status = xnn_setup_average_pooling2d_nhwc_f32(op0_.get(), workspace.get(), - X.Data(), Y.MutableData()); + status = xnn_setup_average_pooling2d_nhwc_f32(op0_.get(), X.Data(), + Y.MutableData()); } else if (avgpool_type_ == 
OpComputeType::op_compute_type_fp16) { - status = xnn_setup_average_pooling2d_nhwc_f16(op0_.get(), workspace.get(), - X.Data(), Y.MutableData()); - } else if (avgpool_type_ == OpComputeType::op_compute_type_qu8) { - status = xnn_setup_average_pooling2d_nhwc_qu8(op0_.get(), workspace.get(), - X.Data(), Y.MutableData()); + status = xnn_setup_average_pooling2d_nhwc_f16(op0_.get(), X.Data(), + Y.MutableData()); } if (status != xnn_status_success) { diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index 4e6b308e28ae5..3ef0c1a7cf495 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -91,7 +91,6 @@ Status Conv::Compute(OpKernelContext* context) const { // setup allocator/automated dellocate for workspace size_t workspace_size = 0; - size_t workspace_alignment = 0; xnn_allocator* allocator = GetStoredAllocator().second; auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); }; std::unique_ptr workspace(nullptr, deallocator); @@ -108,7 +107,7 @@ Status Conv::Compute(OpKernelContext* context) const { } auto status = reshape_fn(op0_.get(), N, H, W, - &workspace_size, &workspace_alignment, + &workspace_size, /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, threadpool); if (status != xnn_status_success) { diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc index 44962c1796631..9742f397315a7 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc @@ -24,7 +24,6 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, const std::optional>& clip_min_max, const Tensor& Weight, const Tensor* Bias, XnnpackOperator& op_uptr, - xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, const OpQuantParam& quant_param, OpComputeType conv_type, @@ -79,7 +78,7 @@ Status 
CreateXnnpackKernel(const ConvAttributes& conv_attrs, C, M, // input channel stride, output channel stride Weight.Data(), B_data, foutput_min, foutput_max, flags, - code_cache, weights_cache, + weights_cache, &p); } else if (conv_type == OpComputeType::op_compute_type_fp16) { const auto* B_data = Bias ? Bias->Data() : nullptr; @@ -97,7 +96,7 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, Weight.Data(), B_data, // kernel, bias foutput_min, foutput_max, flags, - code_cache, weights_cache, + weights_cache, &p); } else if (conv_type == OpComputeType::op_compute_type_qs8) { const float output_scale = quant_param[2].first[0]; @@ -121,7 +120,7 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, quant_param[2].second, quant_param[2].first[0], output_min, output_max, flags, - code_cache, weights_cache, + weights_cache, &p); } else if (conv_type == OpComputeType::op_compute_type_qs8_per_channel) { auto* B_data = Bias ? Bias->Data() : nullptr; @@ -145,7 +144,7 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, quant_param[2].second, quant_param[2].first[0], output_min, output_max, flags, - code_cache, weights_cache, + weights_cache, &p); } else if (conv_type == OpComputeType::op_compute_type_qu8) { const auto* B_data = Bias ? 
Bias->Data() : nullptr; @@ -170,7 +169,7 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, quant_param[2].second, quant_param[2].first[0], output_min, output_max, flags, - code_cache, weights_cache, + weights_cache, &p); } @@ -521,7 +520,7 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose) Status ConvBase::CreateKernel() { auto ret = CreateXnnpackKernel(convbase_attrs_ref_, C_, M_, kernel_shape_, clip_min_max_, packed_w_, B_, op0_, - GetCodeCache(), GetWeightsCache(), + GetWeightsCache(), quant_param_, conv_type_, is_transpose_); return ret; } diff --git a/onnxruntime/core/providers/xnnpack/tensor/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc index 0bb1194643743..32d91084d3507 100644 --- a/onnxruntime/core/providers/xnnpack/tensor/resize.cc +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc @@ -228,13 +228,13 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf auto out_h = output_dims_[1]; auto out_w = output_dims_[2]; if (op_type_ == OpComputeType::op_compute_type_fp32) { - xstatus = xnn_create_resize_bilinear2d_nhwc_f32(out_h, out_w, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc(xnn_datatype_fp32, out_h, out_w, flags, &p); } else if (op_type_ == OpComputeType::op_compute_type_fp16) { - xstatus = xnn_create_resize_bilinear2d_nhwc_f16(out_h, out_w, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc(xnn_datatype_fp16, out_h, out_w, flags, &p); } else if (op_type_ == OpComputeType::op_compute_type_qu8) { - xstatus = xnn_create_resize_bilinear2d_nhwc_u8(out_h, out_w, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc(xnn_datatype_quint8, out_h, out_w, flags, &p); } else { - xstatus = xnn_create_resize_bilinear2d_nhwc_s8(out_h, out_w, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc(xnn_datatype_qint8, out_h, out_w, flags, &p); } ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_resize_bilinear2d_nhwc_", OpTypeToString(op_type_), " failed. 
Status:", @@ -257,22 +257,14 @@ Status Resize::ComputeInternal(OpKernelContext* ctx, const Tensor* input, // setup allocator/automated dellocate for workspace size_t workspace_size = 0; - size_t workspace_alignment = 0; xnn_allocator* allocator = GetStoredAllocator().second; auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); }; std::unique_ptr workspace(nullptr, deallocator); - auto reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_f32; - if (op_type_ == OpComputeType::op_compute_type_fp16) { - reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_f16; - } else if (op_type_ == OpComputeType::op_compute_type_qu8) { - reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_u8; - } else if (op_type_ == OpComputeType::op_compute_type_qs8) { - reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_s8; - } + auto reshape_fn = xnn_reshape_resize_bilinear2d_nhwc; auto status = reshape_fn(op0_.get(), N, H, W, C, C, C, - &workspace_size, &workspace_alignment, threadpool); + &workspace_size, threadpool); if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_resize_bilinear2d_nhwc_", OpTypeToString(op_type_), " returned ", status); @@ -281,17 +273,17 @@ Status Resize::ComputeInternal(OpKernelContext* ctx, const Tensor* input, workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size)); if (op_type_ == OpComputeType::op_compute_type_fp32) { - status = xnn_setup_resize_bilinear2d_nhwc_f32(op0_.get(), workspace.get(), input->Data(), - output->MutableData()); + status = xnn_setup_resize_bilinear2d_nhwc(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } else if (op_type_ == OpComputeType::op_compute_type_fp16) { - status = xnn_setup_resize_bilinear2d_nhwc_f16(op0_.get(), workspace.get(), input->Data(), - output->MutableData()); + status = xnn_setup_resize_bilinear2d_nhwc(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } else if 
(op_type_ == OpComputeType::op_compute_type_qu8) { - status = xnn_setup_resize_bilinear2d_nhwc_u8(op0_.get(), workspace.get(), input->Data(), - output->MutableData()); + status = xnn_setup_resize_bilinear2d_nhwc(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } else { - status = xnn_setup_resize_bilinear2d_nhwc_s8(op0_.get(), workspace.get(), input->Data(), - output->MutableData()); + status = xnn_setup_resize_bilinear2d_nhwc(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } if (status != xnn_status_success) { diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h b/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h index 31512586be19d..1779f51046c59 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h +++ b/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h @@ -24,8 +24,6 @@ class XnnpackKernel : public OpKernel { } // see comment below about enabling code cache - // xnn_code_cache_t GetCodeCache() { return caches_.auto_code_cache.get();} - xnn_code_cache_t GetCodeCache() { return nullptr; } xnn_weights_cache_t GetWeightsCache() { return caches_.auto_weights_cache.get(); } private: @@ -42,11 +40,6 @@ class XnnpackKernel : public OpKernel { if (enable) { #ifdef XNN_CACHE_ENABLE xnn_status status = xnn_status_success; -#if XNN_PLATFORM_JIT - // status = xnn_init_code_cache(&code_cache_); - // ORT_ENFORCE(status == xnn_status_success, "Failed to initialize XNNPACK code cache");) - // auto_code_cache.reset(&code_cache_); -#endif // status = xnn_init_weights_cache(&weights_cache_); xnn_weights_cache_t weights_cache_provider = nullptr; status = xnn_create_weights_cache(&weights_cache, 0); diff --git a/onnxruntime/core/session/abi_devices.h b/onnxruntime/core/session/abi_devices.h index 50469126996b2..571a9eb2a54e2 100644 --- a/onnxruntime/core/session/abi_devices.h +++ b/onnxruntime/core/session/abi_devices.h @@ -68,6 +68,9 @@ struct OrtEpDevice { const OrtMemoryInfo* device_memory_info{nullptr}; const 
OrtMemoryInfo* host_accessible_memory_info{nullptr}; + // used internally by ORT for initializers only. optional. + const OrtMemoryInfo* read_only_device_memory_info{nullptr}; + // the user provides const OrtEpDevice instances, but the OrtEpFactory API takes non-const instances for all // get/create methods to be as flexible as possible. this helper converts to a non-const factory instance. OrtEpFactory* GetMutableFactory() const { return ep_factory; } diff --git a/onnxruntime/core/session/allocator_adapters.cc b/onnxruntime/core/session/allocator_adapters.cc index c6eff29a0bd4f..008d54c44ff70 100644 --- a/onnxruntime/core/session/allocator_adapters.cc +++ b/onnxruntime/core/session/allocator_adapters.cc @@ -3,6 +3,7 @@ #include "allocator_adapters.h" #include "core/framework/error_code_helper.h" +#include "core/framework/plugin_ep_stream.h" #include "core/session/abi_devices.h" #include "core/session/abi_key_value_pairs.h" #include "core/session/environment.h" @@ -21,24 +22,33 @@ namespace { // `IAllocatorImplWrappingOrtAllocator` to ensure compatibility. 
constexpr uint32_t kOrtAllocatorReserveMinVersion = 18; constexpr uint32_t kOrtAllocatorStatsMinVersion = 23; +constexpr uint32_t kOrtAllocatorAllocOnStreamMinVersion = 23; } // namespace OrtAllocatorImplWrappingIAllocator::OrtAllocatorImplWrappingIAllocator(onnxruntime::AllocatorPtr&& i_allocator) : i_allocator_(std::move(i_allocator)) { OrtAllocator::version = ORT_API_VERSION; - OrtAllocator::Alloc = - [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Alloc(size); }; - OrtAllocator::Free = - [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; - OrtAllocator::Info = - [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; + + OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { + return static_cast(this_)->Alloc(size); + }; + + OrtAllocator::Free = [](OrtAllocator* this_, void* p) { + static_cast(this_)->Free(p); + }; + + OrtAllocator::Info = [](const OrtAllocator* this_) { + return static_cast(this_)->Info(); + }; + if (OrtAllocator::version >= kOrtAllocatorReserveMinVersion) { - OrtAllocator::Reserve = - [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Reserve(size); }; + OrtAllocator::Reserve = [](OrtAllocator* this_, size_t size) { + return static_cast(this_)->Reserve(size); + }; } + if (OrtAllocator::version >= kOrtAllocatorStatsMinVersion) { - OrtAllocator::GetStats = - [](const OrtAllocator* this_, OrtKeyValuePairs** stats) noexcept -> OrtStatusPtr { + OrtAllocator::GetStats = [](const OrtAllocator* this_, OrtKeyValuePairs** stats) noexcept -> OrtStatusPtr { API_IMPL_BEGIN auto kvp = std::make_unique(); const auto& stats_map = static_cast(this_)->Stats(); @@ -48,12 +58,22 @@ OrtAllocatorImplWrappingIAllocator::OrtAllocatorImplWrappingIAllocator(onnxrunti API_IMPL_END }; } + + if (OrtAllocator::version >= kOrtAllocatorAllocOnStreamMinVersion) { + OrtAllocator::AllocOnStream = [](OrtAllocator* this_, size_t size, OrtSyncStream* stream) { + return static_cast(this_)->AllocOnStream(size, 
stream); + }; + } } void* OrtAllocatorImplWrappingIAllocator::Alloc(size_t size) { return i_allocator_->Alloc(size); } +void* OrtAllocatorImplWrappingIAllocator::AllocOnStream(size_t size, OrtSyncStream* stream) { + return i_allocator_->AllocOnStream(size, static_cast(stream)); +} + void* OrtAllocatorImplWrappingIAllocator::Reserve(size_t size) { return i_allocator_->Reserve(size); } @@ -105,6 +125,18 @@ void* IAllocatorImplWrappingOrtAllocator::Alloc(size_t size) { return ort_allocator_->Alloc(ort_allocator_.get(), size); } +bool IAllocatorImplWrappingOrtAllocator::IsStreamAware() const { + return ort_allocator_->version >= kOrtAllocatorAllocOnStreamMinVersion && ort_allocator_->AllocOnStream != nullptr; +} + +void* IAllocatorImplWrappingOrtAllocator::AllocOnStream(size_t size, Stream* stream) { + if (ort_allocator_->version >= kOrtAllocatorAllocOnStreamMinVersion && ort_allocator_->AllocOnStream) { + return ort_allocator_->AllocOnStream(ort_allocator_.get(), size, static_cast(stream)); + } + + return ort_allocator_->Alloc(ort_allocator_.get(), size); +} + void* IAllocatorImplWrappingOrtAllocator::Reserve(size_t size) { if (ort_allocator_->version >= kOrtAllocatorReserveMinVersion && ort_allocator_->Reserve) { return ort_allocator_->Reserve(ort_allocator_.get(), size); diff --git a/onnxruntime/core/session/allocator_adapters.h b/onnxruntime/core/session/allocator_adapters.h index 544c7828e46f8..d67eae90985bf 100644 --- a/onnxruntime/core/session/allocator_adapters.h +++ b/onnxruntime/core/session/allocator_adapters.h @@ -25,6 +25,7 @@ struct OrtAllocatorImplWrappingIAllocator final : public OrtAllocatorImpl { ~OrtAllocatorImplWrappingIAllocator() override = default; void* Alloc(size_t size); + void* AllocOnStream(size_t size, OrtSyncStream* stream); void Free(void* p); void* Reserve(size_t size); @@ -56,6 +57,9 @@ class IAllocatorImplWrappingOrtAllocator final : public IAllocator { void Free(void* p) override; void* Reserve(size_t size) override; + bool 
IsStreamAware() const override; + void* AllocOnStream(size_t size, Stream* stream) override; + const OrtAllocator* GetWrappedOrtAllocator() const { return ort_allocator_.get(); } diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index c5fc4e7ccf76f..2a898a2b0bf9f 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -757,7 +757,7 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKerne return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available"); } onnxruntime::Stream* stream = reinterpret_cast(context)->GetComputeStream(); - *out = AllocateBufferWithOptions(*allocator, count_or_bytes, false, stream, stream->GetWaitNotificationFn()); + *out = AllocateBufferWithOptions(*allocator, count_or_bytes, false, stream); return nullptr; }; diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 493c0a106074c..2b553aecbca6c 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -16,10 +16,10 @@ #include "core/session/abi_session_options_impl.h" #include "core/session/allocator_adapters.h" #include "core/session/inference_session.h" -#include "core/session/ep_factory_internal.h" -#include "core/session/ep_library_internal.h" -#include "core/session/ep_library_plugin.h" -#include "core/session/ep_library_provider_bridge.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_library_internal.h" +#include "core/session/plugin_ep/ep_library_plugin.h" +#include "core/session/plugin_ep/ep_library_provider_bridge.h" #include "core/session/ort_apis.h" #include "core/session/utils.h" @@ -639,6 +639,13 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, bool replace_existing) { // NOTE: memory_info is guaranteed to come from the OrtEpDevice when this is called + if (allocator_type == 
OrtAllocatorType::OrtArenaAllocator) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "OrtAllocatorType::OrtArenaAllocator is reserved for ONNX Runtime internal usage only. " + "The EP implements arena support internally so please use OrtDeviceAllocator and provide " + "any arena options via the allocator options."); + } + // we need to remove from shared_ort_allocators_ first in case the entry in shared_allocators_ owns the pointer in // shared_ort_allocators_. if (auto it = FindExistingAllocator(shared_ort_allocators_, memory_info, /*match_name*/ true); @@ -669,48 +676,21 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, return ToStatusAndRelease(ort_status); } + if (allocator->Info(allocator)->alloc_type == OrtAllocatorType::OrtArenaAllocator) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "OrtEpFactory returned an allocator with OrtAllocatorType of OrtArenaAllocator. " + "This type is reserved for ONNX Runtime internal usage only, as any arena usage by the " + "EP library should be opaque to ORT"); + } + auto ort_allocator = OrtAllocatorUniquePtr(allocator, [&ep_device](OrtAllocator* allocator) { ep_device.ep_factory->ReleaseAllocator(ep_device.ep_factory, allocator); }); - AllocatorPtr shared_allocator; - - if (allocator_type == OrtArenaAllocator) { - // wrap with ORT arena - OrtArenaCfg arena_cfg; - if (allocator_options != nullptr) { - auto status = OrtArenaCfg::FromKeyValuePairs(*allocator_options, arena_cfg); - } - - bool stream_aware_arena = ep_device.ep_factory->IsStreamAware(ep_device.ep_factory); - - AllocatorCreationInfo alloc_creation_info{ - [&ort_allocator](int) -> std::unique_ptr { - return std::make_unique(std::move(ort_allocator)); - }, - /*unused*/ -1, // arg to the lambda above that is ignored as the device id comes from the allocator - /*create_arena*/ true, - arena_cfg, - stream_aware_arena, - }; - - shared_allocator = CreateAllocator(alloc_creation_info); - - // need an OrtAllocator to return to the user so we 
need yet another layer. - // we pass in a copy of the AllocatorPtr (which is a shared_ptr) in order to maintain the overall condition that - // shared_allocators_ is the main owner of the allocator and the last place we delete from when removing - // from shared_ort_allocators_, arena_ort_allocators_ and shared_allocators_. - auto arena_ort_allocator = std::make_unique(AllocatorPtr(shared_allocator)); - allocator = arena_ort_allocator.get(); - - // store the entry using the EPs memory info for easier lookup when removing - arena_ort_allocators_.insert({&memory_info, std::move(arena_ort_allocator)}); - } else { - shared_ort_allocators_.insert(allocator); - shared_allocator = std::make_shared(std::move(ort_allocator)); - } + shared_ort_allocators_.insert(allocator); + AllocatorPtr shared_allocator = std::make_shared(std::move(ort_allocator)); shared_allocators_.push_back(std::move(shared_allocator)); if (allocator_out != nullptr) { diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc deleted file mode 100644 index 986ccb1fa17fc..0000000000000 --- a/onnxruntime/core/session/ep_library_internal.cc +++ /dev/null @@ -1,281 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/session/ep_library_internal.h" - -#include "core/framework/error_code_helper.h" -#include "core/framework/ortmemoryinfo.h" -#include "core/framework/session_options.h" -#include "core/providers/cpu/cpu_execution_provider.h" -#include "core/session/abi_devices.h" -#include "core/session/abi_logger.h" -#include "core/session/abi_session_options_impl.h" -#include "core/session/ep_api.h" -#include "core/session/ort_apis.h" - -#if defined(USE_DML) -#include "core/providers/dml/dml_provider_factory_creator.h" -#endif - -#if defined(USE_WEBGPU) -#include "core/providers/webgpu/webgpu_provider_factory_creator.h" -#endif - -namespace onnxruntime { - -class CpuEpFactory : public EpFactoryInternalImpl { - public: - CpuEpFactory() : EpFactoryInternalImpl(kCpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept override { - size_t& num_ep_devices = *p_num_ep_devices; - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - const OrtHardwareDevice& device = *devices[i]; - if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { - ORT_API_RETURN_IF_ERROR( - OrtExecutionProviderApi::CreateEpDevice(&ep_factory, &device, nullptr, nullptr, - &ep_devices[num_ep_devices++])); - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "CPU EP factory currently only supports one device at a time."); - } - - CPUExecutionProviderInfo 
epi{session_options->value.enable_cpu_mem_arena}; - *ep = std::make_unique(epi); - (*ep)->SetLogger(session_logger->ToInternal()); - - return nullptr; - } -}; - -std::unique_ptr EpLibraryInternal::CreateCpuEp() { - auto cpu_factory_impl = std::make_unique(); - auto internal_factory = std::make_unique(std::move(cpu_factory_impl)); - return std::make_unique(std::move(internal_factory)); -} - -#if defined(USE_DML) -class DmlEpFactory : public EpFactoryInternalImpl { - public: - DmlEpFactory() : EpFactoryInternalImpl(kDmlExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept override { - size_t& num_ep_devices = *p_num_ep_devices; - num_ep_devices = 0; - - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - const OrtHardwareDevice& device = *devices[i]; - if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - std::unique_ptr ep_options; - - // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is - // associated with a specific device. - // How would we know what options should not allow user overrides if set in OrtEpDevice? 
- int32_t device_id = 0; // If no device_id was found default to 0 - if (auto it = device.metadata.Entries().find("DxgiAdapterNumber"); it != device.metadata.Entries().end()) { - ep_options = std::make_unique(); - device_id = std::stoi(it->second); - } - - ep_options->Add("device_id", std::to_string(device_id)); - - auto* api_status = OrtExecutionProviderApi::CreateEpDevice(&ep_factory, - &device, nullptr, ep_options.get(), - &ep_devices[num_ep_devices]); - - if (device_memory_infos.size() < device_id + 1) { - device_memory_infos.resize(device_id + 1); - device_allocators.resize(device_id + 1); - } - - if (device_memory_infos[device_id] == nullptr) { - // Create memory info for the device if it doesn't already exist - device_memory_infos[device_id] = std::make_unique( - "DML", OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, - narrow(device_id))); - } - - // This is what we need to add once CreateAllocator is implemented to create a shared allocator for the device. 
- // OrtExecutionProviderApi::EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], - // device_memory_infos[device_id].get()); - - if (api_status != nullptr) { - return api_status; - } - - ++num_ep_devices; - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - *ep = nullptr; - - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "DML EP factory currently only supports one device at a time."); - } - - auto ep_options = GetOptionsFromSessionOptions(session_options->value); - auto dml_ep_factory = DMLProviderFactoryCreator::CreateFromProviderOptions(session_options->value.config_options, - ep_options); - - *ep = dml_ep_factory->CreateProvider(); - (*ep)->SetLogger(session_logger->ToInternal()); - - return nullptr; - } - - OrtStatus* CreateAllocator(const OrtMemoryInfo* /*memory_info*/, - const OrtKeyValuePairs* /*allocator_options*/, - OrtAllocator** allocator) noexcept override { - // TODO: This needs to create an allocator for the specific device so it's available as a shared allocator. That - // requires pulling lots of things out of the DML EP to get the D3D12 device and create a - // BucketizedBufferAllocator. See providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp - //*allocator = device_allocators[memory_info->device.Id()].get(); - *allocator = nullptr; - return nullptr; - } - - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { - // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. 
- *data_transfer = nullptr; - return nullptr; - } - - std::vector> device_memory_infos; // memory info for each device - std::vector> device_allocators; // allocators for each device -}; - -std::unique_ptr EpLibraryInternal::CreateDmlEp() { - auto dml_factory_impl = std::make_unique(); - auto internal_factory = std::make_unique(std::move(dml_factory_impl)); - return std::make_unique(std::move(internal_factory)); -} -#endif - -#if defined(USE_WEBGPU) -class WebGpuEpFactory : public EpFactoryInternalImpl { - public: - WebGpuEpFactory() : EpFactoryInternalImpl(kWebGpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept override { - size_t& num_ep_devices = *p_num_ep_devices; - num_ep_devices = 0; - - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - const OrtHardwareDevice& device = *devices[i]; - if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - // TODO: any metadata or options to add? 
- ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(&ep_factory, - &device, nullptr, nullptr, - &ep_devices[num_ep_devices++])); - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - *ep = nullptr; - - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "WebGPU EP factory currently only supports one device at a time."); - } - - auto webgpu_ep_factory = WebGpuProviderFactoryCreator::Create(session_options->value.config_options); - *ep = webgpu_ep_factory->CreateProvider(); - (*ep)->SetLogger(session_logger->ToInternal()); - - return nullptr; - } - - /* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of - an InferenceSession. - OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, - const OrtKeyValuePairs* allocator_options, - OrtAllocator** allocator) noexcept override { - *allocator = device_allocators[memory_info->device.Id()].get(); - } - - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { - // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. 
- *data_transfer = nullptr; - return nullptr; - } - */ -}; - -std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { - auto webgpu_factory_impl = std::make_unique(); - auto internal_factory = std::make_unique(std::move(webgpu_factory_impl)); - return std::make_unique(std::move(internal_factory)); -} -#endif - -std::vector> EpLibraryInternal::CreateInternalEps() { - std::vector> internal_eps; - internal_eps.reserve(4); - - // CPU EP - internal_eps.push_back(CreateCpuEp()); - -#if defined(USE_WEBGPU) - internal_eps.push_back(CreateWebGpuEp()); -#endif - -#if defined(USE_DML) - internal_eps.push_back(CreateDmlEp()); -#endif - - return internal_eps; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_provider_bridge.cc b/onnxruntime/core/session/ep_library_provider_bridge.cc deleted file mode 100644 index ae553891beaa7..0000000000000 --- a/onnxruntime/core/session/ep_library_provider_bridge.cc +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/session/ep_library_provider_bridge.h" - -#include "core/common/status.h" -#include "core/framework/error_code_helper.h" -#include "core/framework/session_options.h" -#include "core/providers/cuda/cuda_provider_options.h" -#include "core/providers/shared_library/provider_host_api.h" -#include "core/session/abi_devices.h" -#include "core/session/abi_session_options_impl.h" -#include "core/session/ep_factory_internal.h" - -namespace onnxruntime { -class ProviderBridgeEpFactory : public EpFactoryInternalImpl { - public: - ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library) - : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), - ep_factory.GetVendor(&ep_factory), - ep_factory.GetVendorId(&ep_factory)), - ep_factory_{ep_factory}, - provider_library_{provider_library} { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* num_ep_devices) noexcept override { - ORT_API_RETURN_IF_ERROR(ep_factory_.GetSupportedDevices(&ep_factory_, devices, num_devices, ep_devices, - max_ep_devices, num_ep_devices)); - - // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. 
- for (size_t i = 0; i < *num_ep_devices; ++i) { - auto* ep_device = ep_devices[i]; - if (ep_device) { - ep_device->ep_factory = &ep_factory; - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, - const OrtKeyValuePairs* const* ep_metadata_pairs, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - // get the provider specific options - auto ep_options = GetOptionsFromSessionOptions(session_options->value); - auto& provider = provider_library_.Get(); - - auto status = provider.CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, - ep_options, *session_options, *session_logger, *ep); - - return ToOrtStatus(status); - } - - OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, - const OrtKeyValuePairs* allocator_options, - OrtAllocator** allocator) noexcept override { - return ep_factory_.CreateAllocator(&ep_factory_, memory_info, allocator_options, allocator); - } - - void ReleaseAllocator(OrtAllocator* allocator) noexcept override { - ep_factory_.ReleaseAllocator(&ep_factory_, allocator); - } - - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { - return ep_factory_.CreateDataTransfer(&ep_factory_, data_transfer); - } - - bool IsStreamAware() const noexcept override { - return ep_factory_.IsStreamAware(&ep_factory_); - } - - OrtStatus* CreateSyncStreamForDevice(const OrtMemoryDevice* device, - const OrtKeyValuePairs* stream_options, - OrtSyncStreamImpl** stream) noexcept override { - return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); - } - - OrtEpFactory& ep_factory_; // OrtEpFactory from the provider bridge EP - ProviderLibrary& provider_library_; // ProviderLibrary from the provider bridge EP -}; - -Status EpLibraryProviderBridge::Load() { - std::lock_guard lock{mutex_}; - - if (!factories_.empty()) 
{ - // already loaded - return Status::OK(); - } - - // if we have been unloaded we can't just be reloaded. - if (!ep_library_plugin_ || !provider_library_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "EpLibraryProviderBridge has been unloaded. " - "Please create a new instance using LoadPluginOrProviderBridge."); - } - - // wrap the EpLibraryPlugin factories that were created via calling CreateEpFactories in the library. - // use GetSupportedDevices from the library's factory. - // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. - // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can - // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. - for (const auto& factory : ep_library_plugin_->GetFactories()) { - auto factory_impl = std::make_unique(*factory, *provider_library_); - auto internal_factory = std::make_unique(std::move(factory_impl)); - - factory_ptrs_.push_back(internal_factory.get()); - internal_factory_ptrs_.push_back(internal_factory.get()); - factories_.push_back(std::move(internal_factory)); - } - - return Status::OK(); -} - -Status EpLibraryProviderBridge::Unload() { - std::lock_guard lock{mutex_}; - - internal_factory_ptrs_.clear(); - factory_ptrs_.clear(); - factories_.clear(); - - // we loaded ep_library_plugin_ after provider_library_ in LoadPluginOrProviderBridge so do the reverse order here. 
- ORT_RETURN_IF_ERROR(ep_library_plugin_->Unload()); - ep_library_plugin_ = nullptr; - - provider_library_->Unload(); - provider_library_ = nullptr; - - return Status::OK(); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 25cabd256e318..f4f76a389030e 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3571,7 +3571,7 @@ common::Status InferenceSession::ValidateAndParseShrinkArenaString(const std::st ++iter; } - // Shrink if it is an arena based allocator + // Shrink if it is a BFCArena allocator // Iterate through the registered allocators as we could have multiple allocators for the device+type // if they differ by vendor_id. for (const auto& [device, allocator_ptr] : session_state_->GetAllocators()) { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 6ada5df5976df..37f4fe7312bb4 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -38,8 +38,8 @@ #include "core/session/allocator_adapters.h" #include "core/session/compile_api.h" #include "core/session/environment.h" -#include "core/session/ep_api.h" -#include "core/session/ep_library_internal.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_library_internal.h" #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" #include "core/session/IOBinding.h" @@ -2514,6 +2514,35 @@ ORT_API_STATUS_IMPL(OrtApis::ValueInfo_GetInitializerValue, _In_ const OrtValueI API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::ValueInfo_GetExternalInitializerInfo, _In_ const OrtValueInfo* value_info, + _Outptr_result_maybenull_ OrtExternalInitializerInfo** info) { + API_IMPL_BEGIN + std::unique_ptr ext_data_info = nullptr; + 
ORT_API_RETURN_IF_STATUS_NOT_OK(value_info->GetExternalInitializerInfo(ext_data_info)); + + // Note: ext_data_info can be nullptr if this OrtValueInfo does not represent an external initializer. + // std::unique_ptr::release() handles both cases. + *info = static_cast(ext_data_info.release()); + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtApis::ReleaseExternalInitializerInfo, _Frees_ptr_opt_ OrtExternalInitializerInfo* info) { + delete static_cast(info); +} + +ORT_API(const ORTCHAR_T*, OrtApis::ExternalInitializerInfo_GetFilePath, _In_ const OrtExternalInitializerInfo* info) { + return info->GetRelPath().c_str(); +} + +ORT_API(int64_t, OrtApis::ExternalInitializerInfo_GetFileOffset, _In_ const OrtExternalInitializerInfo* info) { + return static_cast(info->GetOffset()); +} + +ORT_API(size_t, OrtApis::ExternalInitializerInfo_GetByteSize, _In_ const OrtExternalInitializerInfo* info) { + return info->GetLength(); +} + ORT_API_STATUS_IMPL(OrtApis::ValueInfo_IsRequiredGraphInput, _In_ const OrtValueInfo* value_info, _Out_ bool* is_required_graph_input) { API_IMPL_BEGIN @@ -3019,6 +3048,10 @@ ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _O *type = OrtOpAttrType::ORT_OP_ATTR_STRINGS; break; } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPH: { + *type = OrtOpAttrType::ORT_OP_ATTR_GRAPH; + break; + } default: return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type."); } @@ -3966,6 +3999,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::ValueInfo_GetValueNumConsumers, &OrtApis::ValueInfo_GetValueConsumers, &OrtApis::ValueInfo_GetInitializerValue, + &OrtApis::ValueInfo_GetExternalInitializerInfo, &OrtApis::ValueInfo_IsRequiredGraphInput, &OrtApis::ValueInfo_IsOptionalGraphInput, &OrtApis::ValueInfo_IsGraphOutput, @@ -4006,6 +4040,10 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetSubgraphs, &OrtApis::Node_GetGraph, 
&OrtApis::Node_GetEpName, + &OrtApis::ReleaseExternalInitializerInfo, + &OrtApis::ExternalInitializerInfo_GetFilePath, + &OrtApis::ExternalInitializerInfo_GetFileOffset, + &OrtApis::ExternalInitializerInfo_GetByteSize, &OrtApis::GetRunConfigEntry, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 772de5e312ffb..d2f22397bf82c 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -621,6 +621,8 @@ ORT_API_STATUS_IMPL(ValueInfo_GetValueConsumers, _In_ const OrtValueInfo* value_ _In_ size_t num_consumers); ORT_API_STATUS_IMPL(ValueInfo_GetInitializerValue, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtValue** initializer_value); +ORT_API_STATUS_IMPL(ValueInfo_GetExternalInitializerInfo, _In_ const OrtValueInfo* value_info, + _Outptr_result_maybenull_ OrtExternalInitializerInfo** info); ORT_API_STATUS_IMPL(ValueInfo_IsRequiredGraphInput, _In_ const OrtValueInfo* value_info, _Out_ bool* is_required_graph_input); ORT_API_STATUS_IMPL(ValueInfo_IsOptionalGraphInput, _In_ const OrtValueInfo* value_info, @@ -686,6 +688,12 @@ ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, ORT_API_STATUS_IMPL(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph); ORT_API_STATUS_IMPL(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out); +// OrtExternalInitializerInfo +ORT_API(void, ReleaseExternalInitializerInfo, _Frees_ptr_opt_ OrtExternalInitializerInfo* info); +ORT_API(const ORTCHAR_T*, ExternalInitializerInfo_GetFilePath, _In_ const OrtExternalInitializerInfo* info); +ORT_API(int64_t, ExternalInitializerInfo_GetFileOffset, _In_ const OrtExternalInitializerInfo* info); +ORT_API(size_t, ExternalInitializerInfo_GetByteSize, _In_ const OrtExternalInitializerInfo* info); + ORT_API(const char*, GetRunConfigEntry, _In_ const OrtRunOptions* options, _In_z_ const char* config_key); ORT_API(const OrtMemoryInfo*, 
EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device, diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc similarity index 84% rename from onnxruntime/core/session/ep_api.cc rename to onnxruntime/core/session/plugin_ep/ep_api.cc index c49985d74c988..cae0b086af66c 100644 --- a/onnxruntime/core/session/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_api.h" +#include "core/session/plugin_ep/ep_api.h" #include #include @@ -114,7 +114,11 @@ ORT_API_STATUS_IMPL(EpDevice_AddAllocatorInfo, _In_ OrtEpDevice* ep_device, const OrtDevice& info = allocator_memory_info->device; switch (info.MemType()) { case OrtDevice::MemType::DEFAULT: - ep_device->device_memory_info = allocator_memory_info; + if (allocator_memory_info->alloc_type == OrtReadOnlyAllocator) { + ep_device->read_only_device_memory_info = allocator_memory_info; + } else { + ep_device->device_memory_info = allocator_memory_info; + } break; case OrtDevice::MemType::HOST_ACCESSIBLE: ep_device->host_accessible_memory_info = allocator_memory_info; @@ -176,6 +180,31 @@ ORT_API(uint32_t, MemoryDevice_GetDeviceId, _In_ const OrtMemoryDevice* memory_d return memory_device->Id(); } +ORT_API(const OrtSyncStreamImpl*, SyncStream_GetImpl, _In_ const OrtSyncStream* ort_stream) { + // the EP API should only ever see plugin_ep::Stream instances + const auto& stream = *reinterpret_cast(ort_stream); + return &stream.GetImpl(); +} + +ORT_API(uint64_t, SyncStream_GetSyncId, _In_ const OrtSyncStream* stream) { + return static_cast(stream)->GetSyncId(); +} + +ORT_API(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* producer_stream, + _In_ const OrtSyncStream* consumer_stream) { + uint64_t id{0}; + if (producer_stream && consumer_stream) { + const auto& producer = *static_cast(producer_stream); + const auto& consumer = 
*static_cast(consumer_stream); + + // If both streams are valid, we can return the sync id for the last wait on the producer stream. + // This is useful for synchronizing operations between different streams. + id = consumer.GetSyncIdForLastWaitOnStream(producer); + } + + return id; +} + static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). @@ -197,6 +226,10 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::MemoryDevice_GetMemoryType, &OrtExecutionProviderApi::MemoryDevice_GetVendorId, &OrtExecutionProviderApi::MemoryDevice_GetDeviceId, + + &OrtExecutionProviderApi::SyncStream_GetImpl, + &OrtExecutionProviderApi::SyncStream_GetSyncId, + &OrtExecutionProviderApi::GetSyncIdForLastWaitOnSyncStream, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h similarity index 86% rename from onnxruntime/core/session/ep_api.h rename to onnxruntime/core/session/plugin_ep/ep_api.h index 1af23664f71eb..c0dc79f3fb333 100644 --- a/onnxruntime/core/session/ep_api.h +++ b/onnxruntime/core/session/plugin_ep/ep_api.h @@ -35,4 +35,9 @@ ORT_API(OrtMemoryInfoDeviceType, MemoryDevice_GetDeviceType, _In_ const OrtMemor ORT_API(OrtDeviceMemoryType, MemoryDevice_GetMemoryType, _In_ const OrtMemoryDevice* memory_device); ORT_API(uint32_t, MemoryDevice_GetVendorId, _In_ const OrtMemoryDevice* memory_device); ORT_API(uint32_t, MemoryDevice_GetDeviceId, _In_ const OrtMemoryDevice* memory_device); + +ORT_API(const OrtSyncStreamImpl*, SyncStream_GetImpl, _In_ const OrtSyncStream* stream); +ORT_API(uint64_t, SyncStream_GetSyncId, _In_ const OrtSyncStream* stream); +ORT_API(uint64_t, GetSyncIdForLastWaitOnSyncStream, _In_ const OrtSyncStream* 
producer_stream, + _In_ const OrtSyncStream* consumer_stream); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc new file mode 100644 index 0000000000000..7e6d0dd2ae5df --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_factory_cpu.h" + +#include "core/framework/error_code_helper.h" +#include "core/graph/constants.h" +#include "core/providers/cpu/cpu_execution_provider.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +OrtStatus* CpuEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + ORT_API_RETURN_IF_ERROR( + OrtExecutionProviderApi::CreateEpDevice(&ep_factory, &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } + } + + return nullptr; +} + +OrtStatus* CpuEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + if (num_devices != 1) { + return 
OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "CPU EP factory currently only supports one device at a time."); + } + + CPUExecutionProviderInfo epi{session_options->value.enable_cpu_mem_arena}; + *ep = std::make_unique(epi); + (*ep)->SetLogger(session_logger->ToInternal()); + + return nullptr; +} +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_cpu.h b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.h new file mode 100644 index 0000000000000..fba9bac976bb2 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +class CpuEpFactory : public EpFactoryInternalImpl { + public: + CpuEpFactory() : EpFactoryInternalImpl(kCpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_dml.cc b/onnxruntime/core/session/plugin_ep/ep_factory_dml.cc new file mode 100644 index 0000000000000..2f12ffa394537 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_dml.cc @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#if defined(USE_DML) + +#include "core/session/plugin_ep/ep_factory_dml.h" + +#include "core/framework/error_code_helper.h" +#include "core/providers/dml/dml_provider_factory_creator.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +OrtStatus* DmlEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + auto ep_options = std::make_unique(); + + // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is + // associated with a specific device. + // How would we know what options should not allow user overrides if set in OrtEpDevice? 
+ int32_t device_id = 0; // If no device_id was found default to 0 + if (auto it = device.metadata.Entries().find("DxgiAdapterNumber"); it != device.metadata.Entries().end()) { + device_id = std::stoi(it->second); + } + + ep_options->Add("device_id", std::to_string(device_id)); + + auto* api_status = OrtExecutionProviderApi::CreateEpDevice(&ep_factory, + &device, nullptr, ep_options.get(), + &ep_devices[num_ep_devices]); + + if (device_memory_infos.size() < device_id + 1) { + device_memory_infos.resize(device_id + 1); + device_allocators.resize(device_id + 1); + } + + if (device_memory_infos[device_id] == nullptr) { + // Create memory info for the device if it doesn't already exist + device_memory_infos[device_id] = std::make_unique( + "DML", OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, + narrow(device_id))); + } + + // This is what we need to add once CreateAllocator is implemented to create a shared allocator for the device. 
+ // OrtExecutionProviderApi::EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], + // device_memory_infos[device_id].get()); + + if (api_status != nullptr) { + return api_status; + } + + ++num_ep_devices; + } + } + + return nullptr; +} + +OrtStatus* DmlEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + *ep = nullptr; + + if (num_devices != 1) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "DML EP factory currently only supports one device at a time."); + } + + auto ep_options = GetOptionsFromSessionOptions(session_options->value); + auto dml_ep_factory = DMLProviderFactoryCreator::CreateFromProviderOptions(session_options->value.config_options, + ep_options); + + *ep = dml_ep_factory->CreateProvider(); + (*ep)->SetLogger(session_logger->ToInternal()); + + return nullptr; +} + +/* +// TODO: This needs to create an allocator for the specific device so it's available as a shared allocator. That +// requires pulling lots of things out of the DML EP to get the D3D12 device and create a +// BucketizedBufferAllocator. See providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp +OrtStatus* DmlEpFactory::CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept { +} + +// TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. 
+OrtStatus* DmlEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept { +} +*/ +} // namespace onnxruntime + +#endif // USE_DML diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_dml.h b/onnxruntime/core/session/plugin_ep/ep_factory_dml.h new file mode 100644 index 0000000000000..1cdd172901942 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_dml.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if defined(USE_DML) + +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +class DmlEpFactory : public EpFactoryInternalImpl { + public: + DmlEpFactory() : EpFactoryInternalImpl(kDmlExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; + + std::vector> device_memory_infos; // memory info for each device + std::vector> device_allocators; // allocators for each device +}; + +} // namespace onnxruntime + +#endif // USE_DML diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc similarity index 58% rename from onnxruntime/core/session/ep_factory_internal.cc rename to onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index 9804aa6a5c42d..3610b0f797a46 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ 
b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -1,18 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/abi_session_options_impl.h" -#include "core/session/ep_api_utils.h" +#include "core/session/plugin_ep/forward_to_factory_impl.h" #include "core/session/ort_apis.h" -#include "onnxruntime_config.h" // for ORT_VERSION namespace onnxruntime { - -using Forward = ForwardToFactory; +using Forward = ForwardToFactoryImpl; EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl) : impl_{std::move(impl)} { @@ -32,38 +30,6 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice; } -const char* EpFactoryInternal::GetVersion() const noexcept { - return ORT_VERSION; -} - -OrtStatus* EpFactoryInternal::CreateEp(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t /*num_devices*/, - const OrtSessionOptions* /*api_session_options*/, - const OrtLogger* /*api_logger*/, - OrtEp** /*ep*/) { - ORT_THROW("Internal error. CreateIExecutionProvider should be used for EpFactoryInternal."); -} - -// Prior to addition to SessionOptions the EP options do not have a prefix. -// They are prefixed with 'ep..' when added to SessionOptions. -// -// Use this function to get the options without the prefix from SessionOptions. -// Required by the option parsing for multiple existing EPs. 
-ProviderOptions EpFactoryInternalImpl::GetOptionsFromSessionOptions(const SessionOptions& session_options) const { - const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(GetName()); - ProviderOptions ep_options; - - for (const auto& [key, value] : session_options.config_options.configurations) { - if (key.find(option_prefix) == 0) { - // remove the prefix and add - ep_options[key.substr(option_prefix.length())] = value; - } - } - - return ep_options; -} - InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, gsl::span ep_devices) : ep_factory_{ep_factory} { diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h similarity index 50% rename from onnxruntime/core/session/ep_factory_internal.h rename to onnxruntime/core/session/plugin_ep/ep_factory_internal.h index ae450efa394e8..0e34fef0ff74c 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -7,85 +7,16 @@ #include #include "core/common/common.h" -#include "core/framework/execution_provider.h" #include "core/providers/providers.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/ort_apis.h" +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "onnxruntime_config.h" // for ORT_VERSION namespace onnxruntime { -class EpFactoryInternal; -class EpLibraryInternal; struct SessionOptions; - -// class with virtual methods that are implemented for each internal EP -class EpFactoryInternalImpl { - public: - EpFactoryInternalImpl(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id) - : ep_name_(ep_name), vendor_(vendor), vendor_id_(vendor_id) { - } - - const char* GetName() const noexcept { return ep_name_.c_str(); } - const char* GetVendor() const noexcept { return vendor_.c_str(); } - uint32_t GetVendorId() const noexcept { return vendor_id_; } - const char* 
GetVersion() const noexcept; - - virtual OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_ size_t num_devices, - _Inout_ OrtEpDevice** ep_devices, - _In_ size_t max_ep_devices, - _Out_ size_t* num_ep_devices) noexcept = 0; - - virtual OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, - _Out_ std::unique_ptr* ep) = 0; - - virtual OrtStatus* CreateAllocator(_In_ const OrtMemoryInfo* /*memory_info*/, - _In_opt_ const OrtKeyValuePairs* /*allocator_options*/, - _Outptr_ OrtAllocator** allocator) noexcept { - // default implementation does not add OrtMemoryInfo to OrtEpDevice instances returned - // so this should never be called - *allocator = nullptr; - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateAllocator is not implemented for this EP factory."); - } - - virtual void ReleaseAllocator(_In_ OrtAllocator* /*allocator*/) noexcept { - // we don't create any allocators so we don't need to release any - } - - virtual OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) { - *data_transfer = nullptr; - return nullptr; // Default implementation does nothing - } - - virtual bool IsStreamAware() const { - return false; - } - - virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/, - _In_opt_ const OrtKeyValuePairs* /*stream_options*/, - _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) { - *stream = nullptr; - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, - "CreateSyncStreamForDevice is not implemented for this EP factory."); - } - - // Function ORT calls to release an EP instance. 
- void ReleaseEp(OrtEp* ep); - - virtual ~EpFactoryInternalImpl() = default; - - protected: - ProviderOptions GetOptionsFromSessionOptions(const SessionOptions& session_options) const; - - private: - const std::string ep_name_; // EP name library was registered with - const std::string vendor_; // EP vendor name - const uint32_t vendor_id_; // EP vendor ID -}; +class EpFactoryInternalImpl; // this class can't have any virtual methods as they break using it as an OrtEpFactory* in OrtEpDevice. class EpFactoryInternal : public OrtEpFactory { @@ -95,7 +26,7 @@ class EpFactoryInternal : public OrtEpFactory { const char* GetName() const noexcept { return impl_->GetName(); } const char* GetVendor() const noexcept { return impl_->GetVendor(); } uint32_t GetVendorId() const noexcept { return impl_->GetVendorId(); } - const char* GetVersion() const noexcept; + const char* GetVersion() const noexcept { return ORT_VERSION; } OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, _In_ size_t num_devices, @@ -106,11 +37,14 @@ class EpFactoryInternal : public OrtEpFactory { } // we don't implement this. CreateIExecutionProvider should be used. - OrtStatus* CreateEp(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, _Out_ OrtEp** ep); + OrtStatus* CreateEp(_In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, + _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + _In_ size_t /*num_devices*/, + _In_ const OrtSessionOptions* /*session_options*/, + _In_ const OrtLogger* /*logger*/, + _Out_ OrtEp** /*ep*/) { + ORT_THROW("Internal error. CreateIExecutionProvider should be used for EpFactoryInternal."); + } // same input args as CreateEp in case we need something from device or ep_metadata_pairs in the future. 
OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -132,24 +66,23 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->ReleaseAllocator(allocator); } - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) { + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept { return impl_->CreateDataTransfer(data_transfer); } - bool IsStreamAware() const { + bool IsStreamAware() const noexcept { return impl_->IsStreamAware(); } OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* memory_device, _In_opt_ const OrtKeyValuePairs* stream_options, - _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) { + _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept { return impl_->CreateSyncStreamForDevice(memory_device, stream_options, stream); } // Function ORT calls to release an EP instance. - void ReleaseEp(OrtEp* /*ep*/) { + void ReleaseEp(OrtEp* /*ep*/) noexcept { // we never create an OrtEp so we should never be trying to release one - ORT_THROW("Internal error. No ReleaseEp call is required for EpFactoryInternal."); } private: diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc new file mode 100644 index 0000000000000..e61804d842859 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_factory_internal.h" + +namespace onnxruntime { + +// Prior to addition to SessionOptions the EP options do not have a prefix. +// They are prefixed with 'ep..' 
when added to SessionOptions. +// +// Use this function to get the options without the prefix from SessionOptions. +// Required by the option parsing for multiple existing EPs. +ProviderOptions EpFactoryInternalImpl::GetOptionsFromSessionOptions(const SessionOptions& session_options) const { + const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(GetName()); + ProviderOptions ep_options; + + for (const auto& [key, value] : session_options.config_options.configurations) { + if (key.find(option_prefix) == 0) { + // remove the prefix and add + ep_options[key.substr(option_prefix.length())] = value; + } + } + + return ep_options; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h new file mode 100644 index 0000000000000..bd0b76b21511f --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include + +#include "core/framework/execution_provider.h" +#include "core/framework/provider_options.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { +class EpFactoryInternal; +struct SessionOptions; + +// class with virtual methods that are implemented for each internal EP +class EpFactoryInternalImpl { + public: + EpFactoryInternalImpl(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id) + : ep_name_(ep_name), vendor_(vendor), vendor_id_(vendor_id) { + } + + const char* GetName() const noexcept { return ep_name_.c_str(); } + const char* GetVendor() const noexcept { return vendor_.c_str(); } + uint32_t GetVendorId() const noexcept { return vendor_id_; } + const char* GetVersion() const noexcept; + + virtual OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices) noexcept = 0; + + virtual OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, + _Out_ std::unique_ptr* ep) = 0; + + virtual OrtStatus* CreateAllocator(_In_ const OrtMemoryInfo* /*memory_info*/, + _In_opt_ const OrtKeyValuePairs* /*allocator_options*/, + _Outptr_ OrtAllocator** allocator) noexcept { + // default implementation does not add OrtMemoryInfo to OrtEpDevice instances returned + // so this should never be called + *allocator = nullptr; + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateAllocator is not implemented for this EP factory."); + } + + virtual void ReleaseAllocator(_In_ OrtAllocator* /*allocator*/) noexcept { + // we don't create any allocators so we 
don't need to release any + } + + virtual OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept { + *data_transfer = nullptr; + return nullptr; // Default implementation does nothing + } + + virtual bool IsStreamAware() const noexcept { + return false; + } + + virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/, + _In_opt_ const OrtKeyValuePairs* /*stream_options*/, + _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept { + *stream = nullptr; + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "CreateSyncStreamForDevice is not implemented for this EP factory."); + } + + // Function ORT calls to release an EP instance. + void ReleaseEp(OrtEp* ep); + + virtual ~EpFactoryInternalImpl() = default; + + protected: + ProviderOptions GetOptionsFromSessionOptions(const SessionOptions& session_options) const; + + private: + const std::string ep_name_; // EP name library was registered with + const std::string vendor_; // EP vendor name + const uint32_t vendor_id_; // EP vendor ID +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc new file mode 100644 index 0000000000000..d6e51a44c1c69 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/session/plugin_ep/ep_factory_provider_bridge.h" + +#include "core/providers/shared_library/provider_host_api.h" + +namespace onnxruntime { +OrtStatus* ProviderBridgeEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) noexcept { + ORT_API_RETURN_IF_ERROR(ep_factory_.GetSupportedDevices(&ep_factory_, devices, num_devices, ep_devices, + max_ep_devices, num_ep_devices)); + + // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. + for (size_t i = 0; i < *num_ep_devices; ++i) { + auto* ep_device = ep_devices[i]; + if (ep_device) { + ep_device->ep_factory = &ep_factory; + } + } + + return nullptr; +} + +OrtStatus* ProviderBridgeEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + // get the provider specific options + auto ep_options = GetOptionsFromSessionOptions(session_options->value); + auto& provider = provider_library_.Get(); + + auto status = provider.CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, + ep_options, *session_options, *session_logger, *ep); + + return ToOrtStatus(status); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h new file mode 100644 index 0000000000000..437af62dc2c0c --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/provider_bridge_library.h" + +namespace onnxruntime { +class ProviderBridgeEpFactory : public EpFactoryInternalImpl { + public: + ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library) + : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), + ep_factory.GetVendor(&ep_factory), + ep_factory.GetVendorId(&ep_factory)), + ep_factory_{ep_factory}, + provider_library_{provider_library} { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; + + OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept override { + return ep_factory_.CreateAllocator(&ep_factory_, memory_info, allocator_options, allocator); + } + + void ReleaseAllocator(OrtAllocator* allocator) noexcept override { + ep_factory_.ReleaseAllocator(&ep_factory_, allocator); + } + + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept override { + return ep_factory_.CreateDataTransfer(&ep_factory_, data_transfer); + } + + bool IsStreamAware() const noexcept override { + return ep_factory_.IsStreamAware(&ep_factory_); + } + + OrtStatus* CreateSyncStreamForDevice(const OrtMemoryDevice* device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** 
stream) noexcept override { + return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); + } + + OrtEpFactory& ep_factory_; // OrtEpFactory from the provider bridge EP + ProviderLibrary& provider_library_; // ProviderLibrary from the provider bridge EP +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc new file mode 100644 index 0000000000000..0f955e0bab248 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(USE_WEBGPU) +#include "core/session/plugin_ep/ep_factory_webgpu.h" + +#include "core/framework/error_code_helper.h" +#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +OrtStatus* WebGpuEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // TODO: any metadata or options to add? 
+ ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(&ep_factory, + &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } + } + + return nullptr; +} + +OrtStatus* WebGpuEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + *ep = nullptr; + + if (num_devices != 1) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "WebGPU EP factory currently only supports one device at a time."); + } + + auto webgpu_ep_factory = WebGpuProviderFactoryCreator::Create(session_options->value.config_options); + *ep = webgpu_ep_factory->CreateProvider(); + (*ep)->SetLogger(session_logger->ToInternal()); + + return nullptr; +} + +/* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of + an InferenceSession. +OrtStatus* WebGpuEpFactory::CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept override { + *allocator = device_allocators[memory_info->device.Id()].get(); +} + +OrtStatus* WebGpuEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { + // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. + *data_transfer = nullptr; + return nullptr; +} +*/ +} // namespace onnxruntime + +#endif // USE_WEBGPU diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h new file mode 100644 index 0000000000000..06ecfa744bbda --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#if defined(USE_WEBGPU) +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +class WebGpuEpFactory : public EpFactoryInternalImpl { + public: + WebGpuEpFactory() : EpFactoryInternalImpl(kWebGpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; +}; +} // namespace onnxruntime + +#endif // USE_WEBGPU diff --git a/onnxruntime/core/session/ep_library.h b/onnxruntime/core/session/plugin_ep/ep_library.h similarity index 100% rename from onnxruntime/core/session/ep_library.h rename to onnxruntime/core/session/plugin_ep/ep_library.h diff --git a/onnxruntime/core/session/plugin_ep/ep_library_internal.cc b/onnxruntime/core/session/plugin_ep/ep_library_internal.cc new file mode 100644 index 0000000000000..d4015e0bbd366 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_library_internal.cc @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/session/plugin_ep/ep_library_internal.h" +#include "core/session/plugin_ep/ep_factory_cpu.h" +#include "core/session/plugin_ep/ep_factory_dml.h" +#include "core/session/plugin_ep/ep_factory_webgpu.h" + +namespace onnxruntime { + +std::unique_ptr EpLibraryInternal::CreateCpuEp() { + auto cpu_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(cpu_factory_impl)); + return std::make_unique(std::move(internal_factory)); +} + +#if defined(USE_DML) + +std::unique_ptr EpLibraryInternal::CreateDmlEp() { + auto dml_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(dml_factory_impl)); + return std::make_unique(std::move(internal_factory)); +} +#endif + +#if defined(USE_WEBGPU) +std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { + auto webgpu_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(webgpu_factory_impl)); + return std::make_unique(std::move(internal_factory)); +} +#endif + +std::vector> EpLibraryInternal::CreateInternalEps() { + std::vector> internal_eps; + internal_eps.reserve(4); + + // CPU EP + internal_eps.push_back(CreateCpuEp()); + +#if defined(USE_WEBGPU) + internal_eps.push_back(CreateWebGpuEp()); +#endif + +#if defined(USE_DML) + internal_eps.push_back(CreateDmlEp()); +#endif + + return internal_eps; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_internal.h b/onnxruntime/core/session/plugin_ep/ep_library_internal.h similarity index 94% rename from onnxruntime/core/session/ep_library_internal.h rename to onnxruntime/core/session/plugin_ep/ep_library_internal.h index ab529edc2507f..1587f01360e26 100644 --- a/onnxruntime/core/session/ep_library_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_internal.h @@ -4,8 +4,8 @@ #pragma once #include "core/common/common.h" -#include "core/session/ep_library.h" -#include "core/session/ep_factory_internal.h" +#include 
"core/session/plugin_ep/ep_library.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/provider_bridge_library.h" diff --git a/onnxruntime/core/session/ep_library_plugin.cc b/onnxruntime/core/session/plugin_ep/ep_library_plugin.cc similarity index 98% rename from onnxruntime/core/session/ep_library_plugin.cc rename to onnxruntime/core/session/plugin_ep/ep_library_plugin.cc index 32ddd8a765b4c..ebfa364f4f1df 100644 --- a/onnxruntime/core/session/ep_library_plugin.cc +++ b/onnxruntime/core/session/plugin_ep/ep_library_plugin.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_library_plugin.h" +#include "core/session/plugin_ep/ep_library_plugin.h" #include "core/common/logging/logging.h" #include "core/framework/error_code_helper.h" diff --git a/onnxruntime/core/session/ep_library_plugin.h b/onnxruntime/core/session/plugin_ep/ep_library_plugin.h similarity index 96% rename from onnxruntime/core/session/ep_library_plugin.h rename to onnxruntime/core/session/plugin_ep/ep_library_plugin.h index e2b02ccc654da..e044e91b61e37 100644 --- a/onnxruntime/core/session/ep_library_plugin.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_plugin.h @@ -6,7 +6,7 @@ #include #include -#include "core/session/ep_library.h" +#include "core/session/plugin_ep/ep_library.h" namespace onnxruntime { /// diff --git a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc new file mode 100644 index 0000000000000..06cf54aea4071 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/session/plugin_ep/ep_library_provider_bridge.h" + +#include "core/session/plugin_ep/ep_factory_provider_bridge.h" + +namespace onnxruntime { +Status EpLibraryProviderBridge::Load() { + std::lock_guard lock{mutex_}; + + if (!factories_.empty()) { + // already loaded + return Status::OK(); + } + + // if we have been unloaded we can't just be reloaded. + if (!ep_library_plugin_ || !provider_library_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "EpLibraryProviderBridge has been unloaded. " + "Please create a new instance using LoadPluginOrProviderBridge."); + } + + // wrap the EpLibraryPlugin factories that were created via calling CreateEpFactories in the library. + // use GetSupportedDevices from the library's factory. + // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. + // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can + // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. + for (const auto& factory : ep_library_plugin_->GetFactories()) { + auto factory_impl = std::make_unique(*factory, *provider_library_); + auto internal_factory = std::make_unique(std::move(factory_impl)); + + factory_ptrs_.push_back(internal_factory.get()); + internal_factory_ptrs_.push_back(internal_factory.get()); + factories_.push_back(std::move(internal_factory)); + } + + return Status::OK(); +} + +Status EpLibraryProviderBridge::Unload() { + std::lock_guard lock{mutex_}; + + internal_factory_ptrs_.clear(); + factory_ptrs_.clear(); + factories_.clear(); + + // we loaded ep_library_plugin_ after provider_library_ in LoadPluginOrProviderBridge so do the reverse order here. 
+ ORT_RETURN_IF_ERROR(ep_library_plugin_->Unload()); + ep_library_plugin_ = nullptr; + + provider_library_->Unload(); + provider_library_ = nullptr; + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h similarity index 95% rename from onnxruntime/core/session/ep_library_provider_bridge.h rename to onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h index 0717ccd957de7..c7e8ebefc3785 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h @@ -5,8 +5,8 @@ #include #include -#include "core/session/ep_library.h" -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_library.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/provider_bridge_library.h" namespace onnxruntime { diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc similarity index 95% rename from onnxruntime/core/session/ep_plugin_provider_interfaces.cc rename to onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 52cf6c62c9702..2aac1e1c21cc7 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/session/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include #include @@ -135,6 +135,10 @@ PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessio if (ep_device->host_accessible_memory_info != nullptr) { allocator_mem_infos_.push_back(ep_device->host_accessible_memory_info); } + + if (ep_device->read_only_device_memory_info != nullptr) { + allocator_mem_infos_.push_back(ep_device->read_only_device_memory_info); + } } } @@ -546,9 +550,12 @@ Status PluginExecutionProvider::SetEpDynamicOptions(gsl::span } std::unique_ptr PluginExecutionProvider::GetDataTransfer() const { OrtDataTransferImpl* data_transfer_impl = nullptr; - OrtStatus* status = ep_factory_.CreateDataTransfer(&ep_factory_, &data_transfer_impl); - if (status != nullptr) { - ORT_THROW("Error creating data transfer: ", ToStatusAndRelease(status).ToString()); + + if (ep_factory_.CreateDataTransfer != nullptr) { + OrtStatus* status = ep_factory_.CreateDataTransfer(&ep_factory_, &data_transfer_impl); + if (status != nullptr) { + ORT_THROW("Error creating data transfer: ", ToStatusAndRelease(status).ToString()); + } } if (data_transfer_impl == nullptr) { @@ -564,6 +571,11 @@ std::vector PluginExecutionProvider::CreatePreferredAllocators() { for (const auto* memory_info : allocator_mem_infos_) { OrtAllocator* ort_allocator_ptr = nullptr; + + if (!ort_ep_->CreateAllocator && !ep_factory_.CreateAllocator) { + ORT_THROW("The OrtEpDevice requires the EP library to implement an allocator, but none were found."); + } + // prefer OrtEp function if available, otherwise fall back to using the OrtEpFactory implementation. OrtStatus* ort_status = ort_ep_->CreateAllocator ? 
ort_ep_->CreateAllocator(ort_ep_.get(), memory_info, &ort_allocator_ptr) @@ -575,6 +587,13 @@ std::vector PluginExecutionProvider::CreatePreferredAllocators() { ORT_THROW("Error creating allocator: ", ToStatusAndRelease(ort_status).ToString()); } + if (ort_allocator_ptr->Info(ort_allocator_ptr)->alloc_type == OrtAllocatorType::OrtArenaAllocator) { + ORT_THROW( + "OrtEpFactory returned an allocator with OrtAllocatorType of OrtArenaAllocator. " + "This type is reserved for ONNX Runtime internal usage only, as any arena usage by the " + "EP library should be opaque to ORT"); + } + auto ort_allocator = OrtAllocatorUniquePtr( ort_allocator_ptr, [this](OrtAllocator* allocator) { @@ -588,7 +607,7 @@ std::vector PluginExecutionProvider::CreatePreferredAllocators() { void PluginExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& registry, AllocatorMap& /*allocators*/) const { - if (!ep_factory_.IsStreamAware(&ep_factory_)) { + if (ep_factory_.IsStreamAware == nullptr || !ep_factory_.IsStreamAware(&ep_factory_)) { return; } @@ -598,6 +617,10 @@ void PluginExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistr continue; } + if (!ort_ep_->CreateSyncStreamForDevice && !ep_factory_.CreateSyncStreamForDevice) { + ORT_THROW("The OrtEpFactory is stream aware, but did not provide CreateSyncStreamForDevice."); + } + auto device_type = mem_info->device.Type(); registry.RegisterCreateStreamFn( diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h similarity index 100% rename from onnxruntime/core/session/ep_plugin_provider_interfaces.h rename to onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h similarity index 99% rename from onnxruntime/core/session/ep_api_utils.h rename to onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h 
index 77528565eced7..67b22779395ec 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -7,7 +7,7 @@ namespace onnxruntime { // helper to forward a call from the C API to an instance of the factory implementation. // used by EpFactoryInternal and EpFactoryProviderBridge. template -struct ForwardToFactory { +struct ForwardToFactoryImpl { static const char* ORT_API_CALL GetFactoryName(const OrtEpFactory* this_ptr) noexcept { return static_cast(this_ptr)->GetName(); } diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 211bf8b2d15a4..6bcbda0f13b92 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -11,8 +11,8 @@ #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/abi_logger.h" -#include "core/session/ep_factory_internal.h" -#include "core/session/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_c_api.h" diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 69039beb49363..f90ace95d6e58 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -19,10 +19,10 @@ #include "core/session/ort_env.h" #if !defined(ORT_MINIMAL_BUILD) -#include "core/session/ep_factory_internal.h" -#include "core/session/ep_plugin_provider_interfaces.h" -#include "core/session/ep_library_plugin.h" -#include "core/session/ep_library_provider_bridge.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_library_plugin.h" 
+#include "core/session/plugin_ep/ep_library_provider_bridge.h" #include "core/session/model_compilation_options.h" #include "core/session/provider_policy_context.h" #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 983321593a92b..63b647060df3c 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -201,6 +201,75 @@ void GemmEx(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, p MlasGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, threadpool); } +template <> +void GemmEx(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, ptrdiff_t M, ptrdiff_t N, ptrdiff_t K, + MLFloat16 alpha, const MLFloat16* A, int lda, const MLFloat16* B, int ldb, MLFloat16 beta, + MLFloat16* C, int ldc, ThreadPool*) { + // The following function is not implemented for MLFloat16 in Mlas. + // MlasGemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, threadpool); + // Threadpool is not used. 
+ auto C_mat = EigenMatrixMapWithStrides(reinterpret_cast(C), N, M, Eigen::Stride(ldc, 1)); + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + if (beta == MLFloat16(0.f)) { + C_mat.setZero(); + } else { + C_mat *= *reinterpret_cast(&beta); + } + Eigen::half alpha_half = *reinterpret_cast(&alpha); +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + + switch (TransA) { + case CblasNoTrans: { + switch (TransB) { + case CblasNoTrans: + C_mat.noalias() += alpha_half * (ConstEigenMatrixMapWithStrides( + reinterpret_cast(B), N, K, Eigen::Stride(ldb, 1)) * + ConstEigenMatrixMapWithStrides( + reinterpret_cast(A), K, M, Eigen::Stride(lda, 1))); + return; + case CblasTrans: + C_mat.noalias() += alpha_half * (ConstEigenMatrixMapWithStrides( + reinterpret_cast(B), K, N, Eigen::Stride(ldb, 1)) + .transpose() * + ConstEigenMatrixMapWithStrides( + reinterpret_cast(A), K, M, Eigen::Stride(lda, 1))); + return; + default: + ORT_THROW("CblasNoTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB); + } + } + case CblasTrans: { + switch (TransB) { + case CblasNoTrans: + C_mat.noalias() += alpha_half * (ConstEigenMatrixMapWithStrides( + reinterpret_cast(B), N, K, Eigen::Stride(ldb, 1)) * + ConstEigenMatrixMapWithStrides( + reinterpret_cast(A), M, K, Eigen::Stride(lda, 1)) + .transpose()); + return; + case CblasTrans: + C_mat.noalias() += alpha_half * (ConstEigenMatrixMapWithStrides( + reinterpret_cast(B), K, N, Eigen::Stride(ldb, 1)) + .transpose() * + ConstEigenMatrixMapWithStrides( + reinterpret_cast(A), M, K, Eigen::Stride(lda, 1)) + .transpose()); + return; + default: + ORT_THROW("CblasTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB); + } + } + default: + ORT_THROW("Unexpected CBLAS_TRANSPOSE for TransA of ", TransA); + } +} + template void Gemv(CBLAS_TRANSPOSE TransA, int M, diff --git a/onnxruntime/core/util/math_cpuonly.h b/onnxruntime/core/util/math_cpuonly.h index 73caf9f86180d..1b80bfb02c706 
100644 --- a/onnxruntime/core/util/math_cpuonly.h +++ b/onnxruntime/core/util/math_cpuonly.h @@ -80,6 +80,12 @@ namespace onnxruntime { template using EigenMatrixMap = Eigen::Map>; +template +using EigenMatrixMapWithStrides = Eigen::Map, 0, Eigen::Stride>; + +template +using ConstEigenMatrixMapWithStrides = Eigen::Map, 0, Eigen::Stride>; + template using EigenArrayMap = Eigen::Map>; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index ec4d8c6330c8d..acf0681cf8752 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -46,7 +46,7 @@ #if !defined(ORT_MINIMAL_BUILD) #include "core/session/abi_devices.h" -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/provider_policy_context.h" #include "core/session/utils.h" #endif diff --git a/onnxruntime/test/autoep/library/ep_allocator.h b/onnxruntime/test/autoep/library/ep_allocator.h index 624b4fcb484cd..e46c03dfc8f14 100644 --- a/onnxruntime/test/autoep/library/ep_allocator.h +++ b/onnxruntime/test/autoep/library/ep_allocator.h @@ -5,17 +5,50 @@ #include "example_plugin_ep_utils.h" +#include + // from onnxruntime/core/framework/allocator_stats.h +// copied from onnxruntime::AllocatorStats struct AllocatorStats { int64_t num_allocs; // Number of allocations. int64_t num_reserves; // Number of reserves. (Number of calls to Reserve() in arena-based allocators) + int64_t num_arena_extensions; // Number of arena extensions (Relevant only for arena based allocators) + int64_t num_arena_shrinkages; // Number of arena shrinkages (Relevant only for arena based allocators) int64_t bytes_in_use; // Number of bytes in use. int64_t total_allocated_bytes; // The total number of allocated bytes by the allocator. int64_t max_bytes_in_use; // The maximum bytes in use. int64_t max_alloc_size; // The max single allocation seen. 
- int64_t bytes_limit; // The upper limit what the allocator can allocate, if such a limit - // is known. Certain allocator may return 0 to indicate the limit is - // unknown. + // The upper limit what the allocator can allocate, if such a limit + // is known. Certain allocator may return 0 to indicate the limit is unknown. + int64_t bytes_limit; + + void ToKeyValuePairs(const OrtApi& api, OrtKeyValuePairs* kvps) const { + if (num_allocs > 0 || bytes_limit != 0) { + api.AddKeyValuePair(kvps, "Limit", std::to_string(bytes_limit).c_str()); + api.AddKeyValuePair(kvps, "InUse", std::to_string(bytes_in_use).c_str()); + api.AddKeyValuePair(kvps, "TotalAllocated", std::to_string(total_allocated_bytes).c_str()); + api.AddKeyValuePair(kvps, "MaxInUse", std::to_string(max_bytes_in_use).c_str()); + api.AddKeyValuePair(kvps, "NumAllocs", std::to_string(num_allocs).c_str()); + api.AddKeyValuePair(kvps, "NumReserves", std::to_string(num_reserves).c_str()); + api.AddKeyValuePair(kvps, "NumArenaExtensions", std::to_string(num_arena_extensions).c_str()); + api.AddKeyValuePair(kvps, "NumArenaShrinkages", std::to_string(num_arena_shrinkages).c_str()); + api.AddKeyValuePair(kvps, "MaxAllocSize", std::to_string(max_alloc_size).c_str()); + } + } + + std::string DebugString() const { + std::ostringstream ss; + ss << "Limit: " << this->bytes_limit << "\n" + << "InUse: " << this->bytes_in_use << "\n" + << "TotalAllocated: " << this->total_allocated_bytes << "\n" + << "MaxInUse: " << this->max_bytes_in_use << "\n" + << "NumAllocs: " << this->num_allocs << "\n" + << "NumReserves: " << this->num_reserves << "\n" + << "NumArenaExtensions: " << this->num_arena_extensions << "\n" + << "NumArenaShrinkages: " << this->num_arena_shrinkages << "\n" + << "MaxAllocSize: " << this->max_alloc_size << "\n"; + return ss.str(); + } }; struct CustomAllocator : OrtAllocator { @@ -27,6 +60,7 @@ struct CustomAllocator : OrtAllocator { Info = InfoImpl; Reserve = AllocImpl; // no special reserve logic and most 
likely unnecessary unless you have your own arena GetStats = GetStatsImpl; // this can be set to nullptr if you don't want to implement it + AllocOnStream = nullptr; } static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { diff --git a/onnxruntime/test/autoep/library/ep_arena.cc b/onnxruntime/test/autoep/library/ep_arena.cc new file mode 100644 index 0000000000000..aa0db71e97925 --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_arena.cc @@ -0,0 +1,778 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ep_arena.h" + +#include +#include + +namespace { +std::string GetAllocatorName(const OrtApi& api, OrtAllocator& allocator) { + const OrtMemoryInfo* mem_info = allocator.Info(&allocator); + const char* allocator_name; + auto* status = api.MemoryInfoGetName(mem_info, &allocator_name); // never fails + static_cast(status); + return allocator_name; +} +} // namespace + +ArenaImpl::ArenaImpl(AllocatorUniquePtr allocator, const ArenaConfig& config, const OrtApi& api, + const OrtLogger& logger) + : device_allocator_{std::move(allocator)}, + allocator_name_{GetAllocatorName(api, *device_allocator_)}, + config_{config}, + next_allocation_id_(1), + free_chunks_list_(kInvalidChunkHandle), + api_{api}, + ep_api_{*api_.GetEpApi()}, + logger_{logger} { + LOG(INFO, "Creating ArenaImpl for " + << allocator_name_ + << " with following configs: initial_chunk_size_bytes: " << config_.initial_chunk_size_bytes + << " max_dead_bytes_per_chunk: " << config_.max_dead_bytes_per_chunk + << " initial_growth_chunk_size_bytes: " << config_.initial_growth_chunk_size_bytes + << " max_power_of_two_extend_bytes: " << config_.max_power_of_two_extend_bytes + << " memory limit: " << config_.max_mem + << " arena_extend_strategy: " << config_.arena_extend_strategy); + + curr_region_allocation_bytes_ = RoundedBytes( + std::min(config_.max_mem, static_cast(config_.initial_chunk_size_bytes))); + + stats_.bytes_limit = 
static_cast(config.max_mem); + + // Create a bunch of bins of various good sizes. + + // We create bins to fit all possible ranges that cover the + // config_.max_mem starting from allocations up to 256 bytes to + // allocations up to (and including) the memory limit. + LOG(VERBOSE, "Creating " << kNumBins << " bins of max chunk size " + << BinNumToSize(0) << " to " << BinNumToSize(kNumBins - 1)); + + for (BinNum b = 0; b < kNumBins; b++) { + size_t bin_size = BinNumToSize(b); + new (BinFromIndex(b)) Bin(this, bin_size); + EP_ENFORCE((BinForSize(bin_size) == BinFromIndex(b) && + BinForSize(bin_size + 255) == BinFromIndex(b) && + BinForSize(bin_size * 2 - 1) == BinFromIndex(b)), + "Invalid bin size for bin " << b); + + if (b + 1 < kNumBins) { + EP_ENFORCE(BinForSize(bin_size * 2) != BinFromIndex(b), "Invalid bin size for " << b); + } + } +} + +ArenaImpl::~ArenaImpl() { + for (const auto& region : region_manager_.regions()) { + device_allocator_->Free(device_allocator_.get(), region.ptr()); + } + + for (const auto& reserve_chunk : reserved_chunks_) { + device_allocator_->Free(device_allocator_.get(), reserve_chunk.first); + } + + for (BinNum b = 0; b < kNumBins; b++) { + BinFromIndex(b)->~Bin(); + } +} + +ArenaImpl::Chunk* ArenaImpl::ChunkFromHandle(ChunkHandle h) { + EP_ENFORCE(h < chunks_.size(), "ChunkFromHandle"); + return &(chunks_[h]); +} + +OrtStatus* ArenaImpl::Extend(size_t rounded_bytes) { + size_t available_bytes = config_.max_mem - static_cast(stats_.total_allocated_bytes); + // Rounds available_bytes down to the nearest multiple of kMinAllocationSize. + available_bytes = (available_bytes / kMinAllocationSize) * kMinAllocationSize; + + // Do we have enough space to handle the client's request? + // If not, fail immediately. 
+ if (rounded_bytes > available_bytes) { + RETURN_ERROR(ORT_EP_FAIL, "Available memory of " << available_bytes << " is smaller than requested bytes of " + << rounded_bytes); + } + + auto safe_alloc = [this](size_t alloc_bytes) { + void* new_mem = nullptr; + try { + new_mem = device_allocator_->Alloc(device_allocator_.get(), alloc_bytes); + } catch (const std::bad_alloc&) { + // attempted allocation can throw std::bad_alloc. we want to treat this the same as if it returned nullptr + // so swallow the exception + } + // catch (const MyException& exception) { + // if your implementation threw, consider swallowing the exception to enable attempting a smaller allocation + // if possible + //} + return new_mem; + }; + + auto get_extend_bytes = [this, available_bytes](const size_t bytes, size_t& extend_bytes) -> OrtStatus* { + extend_bytes = 0; + if (config_.arena_extend_strategy == ArenaExtendStrategy::kNextPowerOfTwo) { + // If curr_region_allocation_bytes_ is not enough to satisfy the + // allocation, keep multiplying by a power of two until that is + // sufficient. + bool increased_allocation = false; + while (bytes > curr_region_allocation_bytes_) { + curr_region_allocation_bytes_ *= 2; + increased_allocation = true; + } + + extend_bytes = std::min(static_cast(curr_region_allocation_bytes_), available_bytes); + + // we allocated the same number of bytes as the current region + // the 2x is to double the minimum size of the next amount we'll allocate + if (!increased_allocation) { + if (config_.arena_extend_strategy == ArenaExtendStrategy::kNextPowerOfTwo && + static_cast(curr_region_allocation_bytes_) * 2 < config_.max_power_of_two_extend_bytes) { + curr_region_allocation_bytes_ *= 2; + } else { + curr_region_allocation_bytes_ = config_.max_power_of_two_extend_bytes; + } + } + } else if (config_.arena_extend_strategy == ArenaExtendStrategy::kSameAsRequested) { + // BFC Arena could cause internal and external fragmentation. 
But, running training with + // big batch size will be very sensitive to fragmentation. So, to avoid fragmentation, + // just extend arena with actual requested size. + extend_bytes = bytes; + } else { + RETURN_ERROR(ORT_INVALID_ARGUMENT, "Invalid arena extend strategy." << config_.arena_extend_strategy); + } + + return nullptr; + }; + + size_t bytes; + RETURN_IF_ERROR(get_extend_bytes(rounded_bytes, bytes)); + + // Try allocating. + void* mem_addr = safe_alloc(bytes); + + static constexpr float kBackpedalFactor = 0.9f; + // Try allocating less memory. + while (mem_addr == nullptr) { + // kBackpedalFactor is float, bytes is size_t. The result of bytes * kBackpedalFactor is float. When we cast it to + // size_t, which is a smaller type, it could loss data. This is what C4244 complains. The "static_cast" here + // is to suppress the warning. C26451 suggest we may change kBackpedalFactor to double to get better accuary. But if + // we do that, AMD GPU CI build pipeline will have an "out-of-memory" error. So I choose to keep this piece of code + // untouched and disable the warning first. +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +#pragma warning(disable : 26451) +#endif + bytes = RoundedBytes(static_cast(bytes * kBackpedalFactor)); +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif + // give up if we can't satisfy the requested size, or we're attempting an allocation of less than 8K. + // + // the latter protects against an infinite loop that occurs when bytes is less than 2560. at that point the 10% + // reduction to 2304 bytes is undone by rounding to a 256 boundary in RoundedBytes, leading to an infinite loop. + // the 8K value is just to give up a little earlier vs. getting all the way down to 2560 bytes. + // If we can't allocate 8K, we're pretty much dead. 
+ if (bytes < rounded_bytes || bytes < 8 * 1024) + break; + + mem_addr = safe_alloc(bytes); + } + + if (mem_addr == nullptr) { + RETURN_ERROR(ORT_EP_FAIL, "Failed to allocate memory for requested buffer of size " << rounded_bytes); + } + + LOG(INFO, "Extended allocation by " << bytes << " bytes."); + + stats_.total_allocated_bytes += bytes; + LOG(INFO, "Total allocated bytes: " << stats_.total_allocated_bytes); + + LOG(INFO, "Allocated memory at " << mem_addr << " to " << static_cast(static_cast(mem_addr) + bytes)); + + region_manager_.AddAllocationRegion(mem_addr, bytes, stats_.num_arena_extensions); + stats_.num_arena_extensions += 1; + + // Create one large chunk for the whole memory space that will + // be chunked later. + ChunkHandle h = AllocateChunk(); + ArenaImpl::Chunk* c = ChunkFromHandle(h); + c->ptr = mem_addr; + c->size = bytes; + c->allocation_id = -1; + c->prev = kInvalidChunkHandle; + c->next = kInvalidChunkHandle; + // assign the new created chunk to default stream, so it can be pick up by any stream + c->stream = nullptr; + + region_manager_.set_handle(c->ptr, h); + + // TODO(vrv): Try to merge this new region with an existing region, + // if the address space is contiguous, to avoid fragmentation + // across regions. + + // Insert the chunk into the right bin. 
+ InsertFreeChunkIntoBin(h); + + return nullptr; +} + +ArenaImpl::ChunkHandle +ArenaImpl::AllocateChunk() { + if (free_chunks_list_ != kInvalidChunkHandle) { + ChunkHandle h = free_chunks_list_; + Chunk* c = ChunkFromHandle(h); + free_chunks_list_ = c->next; + return h; + } + ChunkHandle h = chunks_.size(); + chunks_.resize(h + 1); + return h; +} + +void ArenaImpl::DeallocateChunk(ChunkHandle h) { + Chunk* c = ChunkFromHandle(h); + + if (c->stream) { + if (auto it = stream_to_chunks_.find(c->stream); it != stream_to_chunks_.end()) { + size_t result = it->second.erase(h); + static_cast(result); // should always be found + + if (it->second.empty()) { + stream_to_chunks_.erase(it); + impl_to_stream_.erase(ep_api_.SyncStream_GetImpl(c->stream)); + } + } + + c->stream = nullptr; + c->stream_sync_id = 0; + } + + c->next = free_chunks_list_; + free_chunks_list_ = h; +} + +// static +size_t ArenaImpl::RoundedBytes(size_t bytes) { + return (kMinAllocationSize * ((bytes + kMinAllocationSize - 1) / kMinAllocationSize)); +} + +void* ArenaImpl::Alloc(size_t size) { + return AllocateRawInternal(size, nullptr, false); +} + +void* ArenaImpl::AllocOnStream(size_t size, OrtSyncStream* stream) { + return AllocateRawInternal(size, stream, false); +} + +void* ArenaImpl::Reserve(size_t size) { + if (size == 0) + return nullptr; + + std::lock_guard lock(lock_); + + LOG(INFO, "Reserving memory in ArenaImpl for " << allocator_name_ << " size: " << size); + + void* ptr = device_allocator_->Alloc(device_allocator_.get(), size); + EP_ENFORCE(reserved_chunks_.find(ptr) == reserved_chunks_.end(), __FUNCTION__); + reserved_chunks_.insert(std::pair(ptr, size)); + stats_.bytes_in_use += size; + stats_.num_reserves += 1; + stats_.num_allocs += 1; + stats_.max_alloc_size = std::max(static_cast(stats_.max_alloc_size), size); + stats_.max_bytes_in_use = std::max(static_cast(stats_.max_bytes_in_use), stats_.bytes_in_use); + stats_.total_allocated_bytes += size; + return ptr; +} + +size_t 
ArenaImpl::RequestedSize(const void* ptr) { + std::lock_guard lock(lock_); + ArenaImpl::ChunkHandle h = region_manager_.get_handle(ptr); + EP_ENFORCE(h != kInvalidChunkHandle, __FUNCTION__); + ArenaImpl::Chunk* c = ChunkFromHandle(h); + return c->requested_size; +} + +size_t ArenaImpl::AllocatedSize(const void* ptr) { + std::lock_guard lock(lock_); + ArenaImpl::ChunkHandle h = region_manager_.get_handle(ptr); + EP_ENFORCE(h != kInvalidChunkHandle, __FUNCTION__); + ArenaImpl::Chunk* c = ChunkFromHandle(h); + return c->size; +} + +void* ArenaImpl::AllocateRawInternal(size_t num_bytes, OrtSyncStream* stream, bool dump_log_on_failure) { + if (num_bytes == 0) { + return nullptr; + } + + // Round to multiple of kMinAllocationSize + size_t rounded_bytes = RoundedBytes(num_bytes); + + // The BFC allocator tries to find the best fit first. + BinNum bin_num = BinNumForSize(rounded_bytes); + + std::lock_guard lock(lock_); + + if (stream && stream_to_chunks_.find(stream) == stream_to_chunks_.end()) { + stream_to_chunks_.insert({stream, std::set{}}); + const OrtSyncStreamImpl* stream_impl = ep_api_.SyncStream_GetImpl(stream); + assert(stream_impl); + impl_to_stream_.insert({stream_impl, stream}); + } + + // search for a valid chunk + auto* chunk = FindChunkPtr(bin_num, rounded_bytes, num_bytes, stream); + + if (chunk != nullptr) { + return chunk->ptr; + } + + LOG(INFO, "Extending arena for " << allocator_name_ + << ". bin_num:" << bin_num << " (requested) num_bytes: " << num_bytes + << " (actual) rounded_bytes:" << rounded_bytes); + + // Try to extend + auto status = Extend(rounded_bytes); + if (status == nullptr) { + chunk = FindChunkPtr(bin_num, rounded_bytes, num_bytes, stream); + if (chunk != nullptr) { + return chunk->ptr; + } else { + status = api_.CreateStatus(ORT_EP_FAIL, + ("Failed to find a free memory block despite calling Extend. 
rounded_bytes=" + + std::to_string(rounded_bytes)) + .c_str()); + } + } + + // We searched all bins for an existing free chunk to use and couldn't find one. Dump the memory log for analysis. + if (dump_log_on_failure) { + LOG(ERROR, "BFC Arena ran out of memory trying to allocate " << num_bytes); + DumpMemoryLog(rounded_bytes); + } + + throw std::runtime_error(api_.GetErrorMessage(status)); +} + +OrtStatus* ArenaImpl::GetStats(OrtKeyValuePairs** stats) { + std::lock_guard lock(lock_); + + api_.CreateKeyValuePairs(stats); + stats_.ToKeyValuePairs(api_, *stats); + + return nullptr; +} + +ArenaImpl::Chunk* ArenaImpl::SplitFreeChunkFromBin(ArenaImpl::Bin::FreeChunkSet* free_chunks, + const ArenaImpl::Bin::FreeChunkSet::iterator& citer, + size_t rounded_bytes, + size_t num_bytes) { + const ArenaImpl::ChunkHandle h = (*citer); + RemoveFreeChunkIterFromBin(free_chunks, citer); + ArenaImpl::Chunk* chunk = ChunkFromHandle(h); + + // If we can break the size of the chunk into two reasonably large pieces, do so. + // In any case don't waste more than max_dead_bytes_per_chunk bytes on padding this alloc. + if (chunk->size >= rounded_bytes * 2 || + static_cast(chunk->size - rounded_bytes) >= config_.max_dead_bytes_per_chunk) { + SplitChunk(h, rounded_bytes); + chunk = ChunkFromHandle(h); // Update chunk pointer in case it moved + } + + // The requested size of the returned chunk is what the user has allocated. + chunk->requested_size = num_bytes; + // Assign a unique id and increment the id counter, marking the chunk as being in use. 
+ chunk->allocation_id = next_allocation_id_++; + + ++stats_.num_allocs; + stats_.bytes_in_use += chunk->size; + stats_.max_bytes_in_use = std::max(stats_.max_bytes_in_use, stats_.bytes_in_use); + stats_.max_alloc_size = std::max(stats_.max_alloc_size, static_cast(chunk->size)); + + return chunk; +} + +ArenaImpl::Chunk* ArenaImpl::FindChunkPtr(BinNum bin_num, size_t rounded_bytes, size_t num_bytes, + OrtSyncStream* stream) { + // First identify the first bin that could satisfy rounded_bytes. + for (; bin_num < kNumBins; bin_num++) { + // Start searching from the first bin for the smallest chunk that fits rounded_bytes. + Bin* b = BinFromIndex(bin_num); + for (auto citer = b->free_chunks.begin(); citer != b->free_chunks.end(); ++citer) { + const ArenaImpl::ChunkHandle h = (*citer); + ArenaImpl::Chunk* chunk = ChunkFromHandle(h); + EP_ENFORCE(!chunk->in_use(), __FUNCTION__); + + if (chunk->size >= rounded_bytes) { + // We found an existing chunk that fits us that wasn't in use. + // If it's assigned to another stream, and we have synchronized with that stream more recently than it + // was assigned, we can take the chunk. 
+ bool safe_to_use = chunk->stream == stream || + !chunk->stream || + (stream && chunk->stream && + chunk->stream_sync_id < ep_api_.GetSyncIdForLastWaitOnSyncStream(chunk->stream, stream)); + + if (safe_to_use) { + chunk = SplitFreeChunkFromBin(&b->free_chunks, citer, rounded_bytes, num_bytes); + + if (stream) { + chunk->stream = stream; + chunk->stream_sync_id = ep_api_.SyncStream_GetSyncId(stream); + stream_to_chunks_[stream].insert(h); + } + + return chunk; + } + } + } + } + + return nullptr; +} + +void ArenaImpl::SplitChunk(ArenaImpl::ChunkHandle h, size_t num_bytes) { + // Allocate the new chunk before we do any ChunkFromHandle + ChunkHandle h_new_chunk = AllocateChunk(); + + Chunk* c = ChunkFromHandle(h); + EP_ENFORCE(!c->in_use() && (c->bin_num == kInvalidBinNum), __FUNCTION__); + + // Create a new chunk starting num_bytes after c + ArenaImpl::Chunk* new_chunk = ChunkFromHandle(h_new_chunk); + new_chunk->stream = c->stream; + new_chunk->stream_sync_id = c->stream_sync_id; + + new_chunk->ptr = static_cast(static_cast(c->ptr) + num_bytes); + region_manager_.set_handle(new_chunk->ptr, h_new_chunk); + + // Set the new sizes of the chunks. + new_chunk->size = c->size - num_bytes; + c->size = num_bytes; + + // The new chunk is not in use. + new_chunk->allocation_id = -1; + + // Maintain the pointers. + // c <-> c_neighbor becomes + // c <-> new_chunk <-> c_neighbor + ArenaImpl::ChunkHandle h_neighbor = c->next; + new_chunk->prev = h; + new_chunk->next = h_neighbor; + c->next = h_new_chunk; + if (h_neighbor != kInvalidChunkHandle) { + Chunk* c_neighbor = ChunkFromHandle(h_neighbor); + c_neighbor->prev = h_new_chunk; + } + + // Add the newly free chunk to the free bin. 
+ InsertFreeChunkIntoBin(h_new_chunk); +} + +void ArenaImpl::Free(void* p) { + if (p == nullptr) { + return; + } + + std::lock_guard lock(lock_); + auto it = reserved_chunks_.find(p); + if (it != reserved_chunks_.end()) { + device_allocator_->Free(device_allocator_.get(), it->first); + stats_.bytes_in_use -= it->second; + stats_.total_allocated_bytes -= it->second; + reserved_chunks_.erase(it); + } else { + DeallocateRawInternal(p); + } +} + +void ArenaImpl::DeallocateRawInternal(void* ptr) { + // Find the chunk from the ptr. + ArenaImpl::ChunkHandle h = region_manager_.get_handle(ptr); + EP_ENFORCE(h != kInvalidChunkHandle, __FUNCTION__); + + // Consider coalescing it. + FreeAndMaybeCoalesce(h); +} + +// Merges Chunk(h2) into Chunk(h1) when Chunk(h1)->next is h2 and Chunk(h2)->prev is h1. +void ArenaImpl::Merge(ArenaImpl::ChunkHandle h1, + ArenaImpl::ChunkHandle h2) { + Chunk* c1 = ChunkFromHandle(h1); + Chunk* c2 = ChunkFromHandle(h2); + // We can only merge chunks that are not in use. + EP_ENFORCE(!c1->in_use() && !c2->in_use() && c1->stream == c2->stream, __FUNCTION__); + + // c1's prev doesn't change, still points to the same ptr, and is + // still not in use. 
+ + // Fix up neighbor pointers + // + // c1 <-> c2 <-> c3 should become + // c1 <-> c3 + + ArenaImpl::ChunkHandle h3 = c2->next; + c1->next = h3; + EP_ENFORCE(c2->prev == h1, __FUNCTION__); + if (h3 != kInvalidChunkHandle) { + ArenaImpl::Chunk* c3 = ChunkFromHandle(h3); + c3->prev = h1; + } + + // Set the new size + c1->size += c2->size; + + // we only merge chunks that have the same stream + assert(c1->stream == c2->stream); + c1->stream_sync_id = std::max(c1->stream_sync_id, c2->stream_sync_id); + + DeleteChunk(h2); +} + +void ArenaImpl::DeleteChunk(ChunkHandle h) { + // Delete h and cleanup all state + Chunk* c = ChunkFromHandle(h); + region_manager_.erase(c->ptr); + DeallocateChunk(h); +} + +void ArenaImpl::InsertFreeChunkIntoBin(ArenaImpl::ChunkHandle h) { + Chunk* c = ChunkFromHandle(h); + EP_ENFORCE(!c->in_use() && (c->bin_num == kInvalidBinNum), __FUNCTION__); + BinNum bin_num = BinNumForSize(c->size); + Bin* new_bin = BinFromIndex(bin_num); + c->bin_num = bin_num; + new_bin->free_chunks.insert(h); +} + +void ArenaImpl::RemoveFreeChunkIterFromBin(ArenaImpl::Bin::FreeChunkSet* free_chunks, + const ArenaImpl::Bin::FreeChunkSet::iterator& citer) { + ChunkHandle h = *citer; + Chunk* c = ChunkFromHandle(h); + EP_ENFORCE(!c->in_use() && (c->bin_num != kInvalidBinNum), __FUNCTION__); + free_chunks->erase(citer); + c->bin_num = kInvalidBinNum; +} + +void ArenaImpl::RemoveFreeChunkFromBin(ArenaImpl::ChunkHandle h) { + Chunk* c = ChunkFromHandle(h); + EP_ENFORCE(!c->in_use() && (c->bin_num != kInvalidBinNum), __FUNCTION__); + EP_ENFORCE(BinFromIndex(c->bin_num)->free_chunks.erase(h) > 0, "Could not find chunk in bin"); + c->bin_num = kInvalidBinNum; +} + +void ArenaImpl::FreeAndMaybeCoalesce(ArenaImpl::ChunkHandle h) { + Chunk* c = ChunkFromHandle(h); + EP_ENFORCE(c->in_use() && (c->bin_num == kInvalidBinNum), __FUNCTION__); + + // Mark the chunk as no longer in use + c->allocation_id = -1; + + // Updates the stats. 
+ stats_.bytes_in_use -= c->size; + + // This chunk is no longer in-use, consider coalescing the chunk + // with adjacent chunks. + ChunkHandle chunk_to_reassign = Coalesce(h); + InsertFreeChunkIntoBin(chunk_to_reassign); +} + +ArenaImpl::ChunkHandle ArenaImpl::Coalesce(ChunkHandle h) { + Chunk* c = ChunkFromHandle(h); + EP_ENFORCE(!c->in_use(), __FUNCTION__); + + // This chunk is no longer in-use, consider coalescing the chunk with adjacent chunks. + ChunkHandle chunk_to_reassign = h; + + // If the next chunk is free, coalesce the two + if (c->next != kInvalidChunkHandle) { + Chunk* cnext = ChunkFromHandle(c->next); + // only merge the chunks belong to the same stream + if (!cnext->in_use() && cnext->stream == c->stream) { + chunk_to_reassign = h; + + // Deletes c->next + RemoveFreeChunkFromBin(c->next); + Merge(h, ChunkFromHandle(h)->next); + } + } + + // If the previous chunk is free, coalesce the two + c = ChunkFromHandle(h); + if (c->prev != kInvalidChunkHandle) { + Chunk* cprev = ChunkFromHandle(c->prev); + // only merge the chunks belong to the same stream + if (!cprev->in_use() && cprev->stream == c->stream) { + chunk_to_reassign = c->prev; + + RemoveFreeChunkFromBin(c->prev); // this deletes c + Merge(ChunkFromHandle(h)->prev, h); + } + } + + return chunk_to_reassign; +} + +std::array ArenaImpl::GetBinDebugInfo() { + std::array bin_infos; + + for (const auto& region : region_manager_.regions()) { + ChunkHandle h = region_manager_.get_handle(region.ptr()); + while (h != kInvalidChunkHandle) { + const Chunk* c = ChunkFromHandle(h); + BinNum bin_num = BinNumForSize(c->size); + BinDebugInfo& bin_info = bin_infos[bin_num]; + bin_info.total_bytes_in_bin += c->size; + bin_info.total_chunks_in_bin++; + + if (c->in_use()) { + bin_info.total_bytes_in_use += c->size; + bin_info.total_requested_bytes_in_use += c->requested_size; + bin_info.total_chunks_in_use++; + } else { + Bin* bin = BinFromIndex(bin_num); + EP_ENFORCE(bin->free_chunks.count(h) == 1 && c->bin_num == 
bin_num, __FUNCTION__); + } + + h = c->next; + } + } + return bin_infos; +} + +void ArenaImpl::DumpMemoryLog(size_t num_bytes) { + const std::array bin_infos = GetBinDebugInfo(); + LOG(INFO, "Allocator:" << allocator_name_); + LOG(INFO, "Bin size: Chunks in_use/total (if not zero). Allocated bytes in_use/total. Requested bytes."); + + size_t waste = 0; + for (BinNum bin_num = 0; bin_num < kNumBins; bin_num++) { + Bin* b = BinFromIndex(bin_num); + const BinDebugInfo& bin_info = bin_infos[bin_num]; + EP_ENFORCE(b->free_chunks.size() == bin_info.total_chunks_in_bin - bin_info.total_chunks_in_use, __FUNCTION__); + + if (bin_info.total_chunks_in_bin > 0) { + LOG(INFO, b->bin_size + << ": Chunks " << bin_info.total_chunks_in_use << "/" << bin_info.total_chunks_in_bin + << ". Bytes " + << bin_info.total_bytes_in_use << "/" << bin_info.total_bytes_in_bin << ". " + << "Requested " << bin_info.total_requested_bytes_in_use << "."); + + waste += bin_info.total_bytes_in_use - bin_info.total_requested_bytes_in_use; + } + } + + if (waste > 0) { + LOG(INFO, "Diff between in-use and requested bytes is " << waste); + } + + // Find the bin that we would have liked to allocate in, so we can get some further analysis about fragmentation. + Bin* b = BinForSize(num_bytes); + + LOG(INFO, "Bin for " << num_bytes + << " bytes has max bytes of " << b->bin_size + << ", Chunk State: "); + + for (ChunkHandle h : b->free_chunks) { + Chunk* c = ChunkFromHandle(h); + LOG(INFO, " " << c->DebugString(this, true)); + } + + // Next show the chunks that are in use, and also summarize their number by size. + LOG(INFO, "Overall chunks summary:"); + std::map in_use_by_size; + for (const auto& region : region_manager_.regions()) { + ChunkHandle h = region_manager_.get_handle(region.ptr()); + while (h != kInvalidChunkHandle) { + const Chunk* c = ChunkFromHandle(h); + if (c->in_use()) { + in_use_by_size[c->size]++; + } + LOG(INFO, (c->in_use() ? 
" Chunk" : " Free ") << " at " << c->ptr + << " of size " << c->size); + h = c->next; + } + } + + LOG(INFO, "Summary of in-use chunks by size: "); + size_t total_bytes = 0; + for (auto& it : in_use_by_size) { + LOG(INFO, " " << it.second << " chunks of size " << it.first + << ". Total " << it.first * it.second); + total_bytes += (it.first * it.second); + } + + LOG(INFO, "Sum Total of in-use chunks: " << total_bytes); + LOG(INFO, "Stats: \n" + << stats_.DebugString()); +} + +OrtStatus* ArenaImpl::ResetChunksUsingStream(const OrtSyncStreamImpl* stream_impl) { + std::lock_guard lock(lock_); + + auto impl_it = impl_to_stream_.find(stream_impl); + if (impl_it == impl_to_stream_.end()) { + return nullptr; // stream hasn't been used with this arena + } + + const OrtSyncStream* stream = impl_it->second; + + auto it = stream_to_chunks_.find(stream); + if (it != stream_to_chunks_.end()) { + const auto& chunk_handles = it->second; + for (size_t handle : chunk_handles) { + Chunk* c = ChunkFromHandle(handle); + assert(c->stream == stream); // something is out of sync if this is not the case + c->stream = nullptr; + } + + stream_to_chunks_.erase(it); + impl_to_stream_.erase(stream_impl); + } + + // It's also possible to find the chunks this way, but that requires iterating every single in-use allocation. + // We also repeat this for every single stream used in a session. + // OTOH there's a cost to create/update keep streams_to_chunks_. + // Using streams_to_chunks_ for now. It also simplifies debugging to have that info. If you're unsure about this + // choice feel free to perf test the two approaches. 
+ // + // for (const auto& region : region_manager_.regions()) { + // ChunkHandle region_begin_chunk = region_manager_.get_handle(region.ptr()); + // ChunkHandle h = region_begin_chunk; + // while (h != kInvalidChunkHandle) { + // Chunk* c = ChunkFromHandle(h); + // if (c->stream == target_stream) { + // c->stream = nullptr; + // c->stream_sync_id = 0; + // } + // h = c->next; + // } + // } + + // coalesce + for (const auto& region : region_manager_.regions()) { + ChunkHandle region_begin_chunk = region_manager_.get_handle(region.ptr()); + ChunkHandle h = region_begin_chunk; + while (h != kInvalidChunkHandle) { + Chunk* c = ChunkFromHandle(h); + if (!c->in_use()) { + RemoveFreeChunkFromBin(h); + ChunkHandle h_next = c->next; + Chunk* c_next = h_next != kInvalidChunkHandle ? ChunkFromHandle(h_next) : nullptr; + + // merge until next chunk is different stream + while (c_next && !c_next->in_use() && c_next->stream == c->stream) { + Coalesce(h); + h_next = c->next; + c_next = h_next != kInvalidChunkHandle ? ChunkFromHandle(h_next) : nullptr; + } + + if (c->bin_num == kInvalidBinNum) { + InsertFreeChunkIntoBin(h); + } + } + h = c->next; + } + } + + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/ep_arena.h b/onnxruntime/test/autoep/library/ep_arena.h new file mode 100644 index 0000000000000..641f3ce3f7b17 --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_arena.h @@ -0,0 +1,629 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Portions Copyright (c) Microsoft Corporation + +#pragma once +#include +#include +#include +#include +#include + +#include "onnxruntime_cxx_api.h" +#include "ep_allocator.h" +#include "example_plugin_ep_utils.h" + +#if defined(PLATFORM_WINDOWS) +#include +#endif + +enum ArenaExtendStrategy { + kDefault = -1, + kNextPowerOfTwo = 0, + kSameAsRequested = 1, +}; + +// copied from onnxruntime::OrtArenaCfg so the values and config key names match +struct ArenaConfig { + static const ArenaExtendStrategy DEFAULT_ARENA_EXTEND_STRATEGY = ArenaExtendStrategy::kNextPowerOfTwo; + static const int DEFAULT_INITIAL_CHUNK_SIZE_BYTES = 1 * 1024 * 1024; + static const int DEFAULT_MAX_DEAD_BYTES_PER_CHUNK = 128 * 1024 * 1024; + static const int DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES = 2 * 1024 * 1024; + static const int64_t DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES = 1024 * 1024 * 1024; // 1GB + static const size_t DEFAULT_MAX_MEM = std::numeric_limits::max(); + + ArenaConfig(size_t max_mem = std::numeric_limits::max(), + ArenaExtendStrategy arena_extend_strategy = DEFAULT_ARENA_EXTEND_STRATEGY, + int initial_chunk_size_bytes = DEFAULT_INITIAL_CHUNK_SIZE_BYTES, + int max_dead_bytes_per_chunk = DEFAULT_MAX_DEAD_BYTES_PER_CHUNK, + int initial_growth_chunk_size_bytes = DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES, + int64_t max_power_of_two_extend_bytes = DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES) + : max_mem(max_mem), + arena_extend_strategy(arena_extend_strategy), + initial_chunk_size_bytes(initial_chunk_size_bytes), + max_dead_bytes_per_chunk(max_dead_bytes_per_chunk), + initial_growth_chunk_size_bytes(initial_growth_chunk_size_bytes), + max_power_of_two_extend_bytes(max_power_of_two_extend_bytes) { + if (arena_extend_strategy == ArenaExtendStrategy::kDefault) { + arena_extend_strategy = 
ArenaExtendStrategy::kNextPowerOfTwo; + } + } + + size_t max_mem; + ArenaExtendStrategy arena_extend_strategy; + int initial_chunk_size_bytes; + int max_dead_bytes_per_chunk; + int initial_growth_chunk_size_bytes; + int64_t max_power_of_two_extend_bytes; + + bool IsValid() { + return initial_chunk_size_bytes > 0 && + max_dead_bytes_per_chunk > 0 && + initial_growth_chunk_size_bytes > 0 && + max_power_of_two_extend_bytes > 0; + } + + // config key names that we parse in FromKeyValuePairs + struct ConfigKeyNames { + static constexpr const char* ArenaExtendStrategy = "arena.extend_strategy"; + static constexpr const char* InitialChunkSizeBytes = "arena.initial_chunk_size_bytes"; + static constexpr const char* MaxDeadBytesPerChunk = "arena.max_dead_bytes_per_chunk"; + static constexpr const char* InitialGrowthChunkSizeBytes = "arena.initial_growth_chunk_size_bytes"; + static constexpr const char* MaxPowerOfTwoExtendBytes = "arena.max_power_of_two_extend_bytes"; + static constexpr const char* MaxMem = "arena.max_mem"; + }; + + static ArenaConfig FromKeyValuePairs(const OrtApi& api, const OrtKeyValuePairs& kvps) { + ArenaConfig config{}; + const char* value = nullptr; + + if (value = api.GetKeyValue(&kvps, ConfigKeyNames::ArenaExtendStrategy); value) { + config.arena_extend_strategy = std::string(value) == "1" ? 
kSameAsRequested : kNextPowerOfTwo; + } + + if (value = api.GetKeyValue(&kvps, ConfigKeyNames::InitialChunkSizeBytes); value) { + config.initial_chunk_size_bytes = std::stoi(std::string(value)); + } + + if (value = api.GetKeyValue(&kvps, ConfigKeyNames::MaxDeadBytesPerChunk); value) { + config.max_dead_bytes_per_chunk = std::stoi(std::string(value)); + } + + if (value = api.GetKeyValue(&kvps, ConfigKeyNames::InitialGrowthChunkSizeBytes); value) { + config.initial_growth_chunk_size_bytes = std::stoi(std::string(value)); + } + + if (value = api.GetKeyValue(&kvps, ConfigKeyNames::MaxPowerOfTwoExtendBytes); value) { + config.max_power_of_two_extend_bytes = std::stoll(value); + } + + if (value = api.GetKeyValue(&kvps, ConfigKeyNames::MaxMem); value) { + config.max_mem = static_cast(std::stoull(std::string(value))); + } + + return config; + } +}; + +// A memory allocator that implements a 'best-fit with coalescing' algorithm. +// This is essentially a very simple version of Doug Lea's malloc (dlmalloc). +// +// The goal of this allocator is to support defragmentation via coalescing. +// One assumption we make is that the process using this allocator owns pretty much all of the memory, and that nearly +// all requests to allocate memory go through this interface. 
+class ArenaImpl { + public: + static const ArenaExtendStrategy DEFAULT_ARENA_EXTEND_STRATEGY = ArenaExtendStrategy::kNextPowerOfTwo; + static const int DEFAULT_INITIAL_CHUNK_SIZE_BYTES = 1 * 1024 * 1024; + static const int DEFAULT_MAX_DEAD_BYTES_PER_CHUNK = 128 * 1024 * 1024; + static const int DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES = 2 * 1024 * 1024; + static const int64_t DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES = 1024 * 1024 * 1024; // 1GB + static const size_t DEFAULT_MAX_MEM = std::numeric_limits::max(); + + ArenaImpl(AllocatorUniquePtr allocator, const ArenaConfig& config, const OrtApi& api, + const OrtLogger& logger); + + ~ArenaImpl(); + + void* Alloc(size_t size); + void* AllocOnStream(size_t size, OrtSyncStream* stream); + void Free(void* p); + + // allocate memory directly. this is used for initializers so they don't affect the arena growth patterns + void* Reserve(size_t size); + + OrtStatus* GetStats(OrtKeyValuePairs** stats); + + size_t RequestedSize(const void* ptr); + size_t AllocatedSize(const void* ptr); + + // Un-assign chunks that are currently assigned to the stream. + // + // This should be called from OrtSyncStreamImpl::OnSessionRunEnd. + // A stream is used in one session at a time. When called from OnSessionRunEnd we know that the stream is done and + // will not be performing any more operations on the data. + // + // We don't have a better way to know when it's safe to re-use a chunk in another stream given the actual memory + // usage is asynchronous on the GPU side, and the code assigning memory is running on CPU prior to that. 
+ OrtStatus* ResetChunksUsingStream(const OrtSyncStreamImpl* stream_impl); + + private: + void* AllocateRawInternal(size_t num_bytes, OrtSyncStream* stream, bool dump_log_on_failure); + void DeallocateRawInternal(void* ptr); + + // A ChunkHandle is an index into the chunks_ vector in BFCAllocator + // kInvalidChunkHandle means an invalid chunk + using ChunkHandle = size_t; + static const size_t kInvalidChunkHandle = static_cast(-1); + + using BinNum = int; + static const int kInvalidBinNum = -1; + static const int kNumBins = 21; + + // Chunks point to memory. Their prev/next pointers form a + // doubly-linked list of addresses sorted by base address that + // must be contiguous. Chunks contain information about whether + // they are in use or whether they are free, and contain a pointer + // to the bin they are in. + struct Chunk { + size_t size = 0; // Full size of buffer. + + // We sometimes give chunks that are larger than needed to reduce + // fragmentation. requested_size keeps track of what the client + // actually wanted so we can understand whether our splitting + // strategy is efficient. + size_t requested_size = 0; + + // allocation_id is set to -1 when the chunk is not in use. It is assigned a + // value greater than zero before the chunk is returned from + // AllocateRaw, and this value is unique among values assigned by + // the parent allocator. + int64_t allocation_id = -1; + void* ptr = nullptr; // pointer to granted subbuffer. + + // If not kInvalidChunkHandle, the memory referred to by 'prev' is directly + // preceding the memory used by this chunk. E.g., It should start + // at 'ptr - prev->size' + ChunkHandle prev = kInvalidChunkHandle; + + // If not kInvalidChunkHandle, the memory referred to by 'next' is directly + // following the memory used by this chunk. E.g., It should be at + // 'ptr + next->size' + ChunkHandle next = kInvalidChunkHandle; + + // What bin are we in? 
+ BinNum bin_num = kInvalidBinNum; + + OrtSyncStream* stream = nullptr; + // Current sync id of `stream` when it was assigned the Chunk. + // If the chunk is assigned to a stream and is free, and another Stream wants to use it, that Stream must have + // synchronized with `stream` at a sync id > to stream_sync_id. + // stream_sync_id is set when the chunk is first assigned to `stream`. + // The sync id is incremented at the start of sync, so any chunk with a previous sync id is safe to re-assign. + uint64_t stream_sync_id = 0; + + bool in_use() const { return allocation_id != -1; } + + std::string DebugString(ArenaImpl* a, bool recurse) { + std::ostringstream ss; + ss << " Size: " << size << " | Requested Size: " << requested_size << " | in_use: " << in_use(); + if (recurse && prev != ArenaImpl::kInvalidChunkHandle) { + Chunk* p = a->ChunkFromHandle(prev); + ss << ", prev: " << p->DebugString(a, false); + } + + if (recurse && next != ArenaImpl::kInvalidChunkHandle) { + Chunk* n = a->ChunkFromHandle(next); + ss << ", next: " << n->DebugString(a, false); + } + return ss.str(); + } + }; + + // A Bin is a collection of similar-sized free chunks. + struct Bin { + // All chunks in this bin have >= bin_size memory. + size_t bin_size = 0; + + struct ChunkComparator { + explicit ChunkComparator(ArenaImpl* allocator) + : allocator_(allocator) {} + + // Sort first by size and then use pointer address as a tie breaker. + bool operator()(const ChunkHandle ha, + const ChunkHandle hb) const { + const Chunk* a = allocator_->ChunkFromHandle(ha); + const Chunk* b = allocator_->ChunkFromHandle(hb); + if (a->size != b->size) { + return a->size < b->size; + } + return a->ptr < b->ptr; + } + + private: + ArenaImpl* allocator_; // The parent allocator + }; + + typedef std::set FreeChunkSet; + // List of free chunks within the bin, sorted by chunk size. + // Chunk * not owned. 
+ FreeChunkSet free_chunks; + Bin(ArenaImpl* allocator, size_t bs) + : bin_size(bs), free_chunks(ChunkComparator(allocator)) {} + }; + + static const size_t kMinAllocationBits = 8; + static const size_t kMinAllocationSize = 1 << kMinAllocationBits; + + // AllocationRegion maps pointers to ChunkHandles for a single + // contiguous memory region. + // + // This class is thread-compatible. + class AllocationRegion { + public: + AllocationRegion(void* ptr, size_t memory_size, int64_t id) + : ptr_(ptr), + memory_size_(memory_size), + end_ptr_(static_cast(static_cast(ptr_) + memory_size_)), + id_(id) { + EP_ENFORCE(0 == memory_size % kMinAllocationSize, __FUNCTION__); + + const size_t n_handles = (memory_size + kMinAllocationSize - 1) / kMinAllocationSize; + handles_ = std::make_unique(n_handles); + for (size_t i = 0; i < n_handles; i++) { + handles_[i] = kInvalidChunkHandle; + } + } + + AllocationRegion(AllocationRegion&& other) noexcept { Swap(other); } + AllocationRegion() = default; + ~AllocationRegion() = default; + + AllocationRegion& operator=(AllocationRegion&& other) noexcept { + Swap(other); + return *this; + } + + void* ptr() const { return ptr_; } + void* end_ptr() const { return end_ptr_; } + size_t memory_size() const { return memory_size_; } + int64_t id() const { return id_; } + + ChunkHandle get_handle(const void* p) const { + return handles_[IndexFor(p)]; + } + + void set_handle(const void* p, ChunkHandle h) { + handles_[IndexFor(p)] = h; + } + + void erase(const void* p) { + set_handle(p, kInvalidChunkHandle); + } + + private: + void Swap(AllocationRegion& other) { + std::swap(ptr_, other.ptr_); + std::swap(memory_size_, other.memory_size_); + std::swap(end_ptr_, other.end_ptr_); + std::swap(id_, other.id_); + std::swap(handles_, other.handles_); + } + + int IndexFor(const void* p) const { + std::uintptr_t p_int = reinterpret_cast(p); + std::uintptr_t base_int = reinterpret_cast(ptr_); + EP_ENFORCE(p_int >= base_int, "AllocationRegion::IndexFor"); + 
EP_ENFORCE(p_int < base_int + memory_size_, "AllocationRegion::IndexFor"); + return static_cast(((p_int - base_int) >> kMinAllocationBits)); + } + + // metadata about the allocation region. + void* ptr_ = nullptr; + size_t memory_size_ = 0; + void* end_ptr_ = nullptr; + // A unique identifier for this allocation region + // (May be used by the client to track which allocation region was allocated first, second, and so on) + int64_t id_ = -1; + + // Array of size "memory_size / kMinAllocationSize". It is + // indexed by (p-base) / kMinAllocationSize, contains ChunkHandle + // for the memory allocation represented by "p" + std::unique_ptr handles_; + + AllocationRegion& operator=(const AllocationRegion&) = delete; + }; + + // RegionManager aggregates one or more "AllocationRegions" and provides + // a layer of indirection from pointers to the underlying ChunkHandle, + // allowing allocation across multiple discontiguous memory regions. + // + // This class is thread-compatible. + class RegionManager { + public: + RegionManager() = default; + ~RegionManager() = default; + + void AddAllocationRegion(void* ptr, size_t memory_size, int64_t id) { + // Insert sorted by end_ptr + auto entry = std::upper_bound(regions_.begin(), regions_.end(), ptr, &Comparator); + regions_.insert(entry, AllocationRegion(ptr, memory_size, id)); + } + + void RemoveAllocationRegion(void* ptr) { + auto entry = std::upper_bound(regions_.begin(), regions_.end(), ptr, &Comparator); + EP_ENFORCE(entry != regions_.end(), "RegionManager::RemoveAllocationRegion Could not find Region for: " << ptr); + regions_.erase(entry); + } + + ChunkHandle get_handle(const void* p) const { + return RegionFor(p)->get_handle(p); + } + + void set_handle(const void* p, ChunkHandle h) { + return MutableRegionFor(p)->set_handle(p, h); + } + void erase(const void* p) { return MutableRegionFor(p)->erase(p); } + + const std::vector& regions() const { return regions_; } + + private: + RegionManager(const RegionManager&) = 
delete; + RegionManager& operator=(const RegionManager&) = delete; + RegionManager(RegionManager&&) = delete; + RegionManager& operator=(RegionManager&&) = delete; + + static bool Comparator(const void* ptr, const AllocationRegion& other) { + return ptr < other.end_ptr(); + } + + AllocationRegion* MutableRegionFor(const void* p) { + return const_cast(RegionFor(p)); + } + + const AllocationRegion* RegionFor(const void* p) const { + auto entry = std::upper_bound(regions_.begin(), regions_.end(), p, &Comparator); + + if (entry != regions_.end()) { + return &(*entry); + } + + EP_ENFORCE(entry != regions_.end(), "RegionManager::RegionFor Could not find Region for: " << p); + return nullptr; + } + + private: + std::vector regions_; + }; + + // Returns 'bytes' rounded up to the next highest kMinAllocationSize. + size_t RoundedBytes(size_t bytes); + + // Try to add a new memory region that can satisfy an allocation of + // 'rounded_bytes' bytes. + OrtStatus* Extend(size_t rounded_bytes); + + // Returns an underlying allocated chunk of size + // 'rounded_bytes'. + ArenaImpl::Chunk* FindChunkPtr(BinNum bin_num, size_t rounded_bytes, size_t num_bytes, OrtSyncStream* stream); + + // Splits the chunk specified by 'h' into two chunks, one at least + // of size 'num_bytes'. + void SplitChunk(ChunkHandle h, size_t num_bytes); + + // Merges the two chunk handles. Requires that the chunks are + // contiguous in their allocation. + void Merge(ChunkHandle h, ChunkHandle h2); + + // Frees the memory represented by 'h', coalescing the chunk if + // possible. + void FreeAndMaybeCoalesce(ChunkHandle h); + + ArenaImpl::ChunkHandle Coalesce(ChunkHandle h); + + // Adds the chunk 'h' to the proper free bin. + void InsertFreeChunkIntoBin(ChunkHandle h); + + // Removes the free chunk pointed to by 'c' from the set free_chunks. + void RemoveFreeChunkIterFromBin(Bin::FreeChunkSet* free_chunks, + const Bin::FreeChunkSet::iterator& c); + + // Removes a free chunk from the bin. 
+ void RemoveFreeChunkFromBin(ChunkHandle h); + + ArenaImpl::Chunk* SplitFreeChunkFromBin(ArenaImpl::Bin::FreeChunkSet* free_chunks, + const ArenaImpl::Bin::FreeChunkSet::iterator& citer, + size_t rounded_bytes, + size_t num_bytes); + + // Removes the chunk metadata represented by 'h'. + void DeleteChunk(ChunkHandle h); + + void DumpMemoryLog(size_t num_bytes); + + ChunkHandle AllocateChunk(); + void DeallocateChunk(ChunkHandle h); + + Chunk* ChunkFromHandle(ChunkHandle h); + + // Information about a Bin that is useful for debugging. + struct BinDebugInfo { + size_t total_bytes_in_use = 0; + size_t total_bytes_in_bin = 0; + size_t total_requested_bytes_in_use = 0; + size_t total_chunks_in_use = 0; + size_t total_chunks_in_bin = 0; + }; + + // Computes and returns a BinDebugInfo for each Bin. + std::array GetBinDebugInfo(); + + int Log2FloorNonZeroSlow(uint64_t n) { + int r = 0; + while (n > 0) { + r++; + n >>= 1; + } + return r - 1; + } + + // Returns floor(log2(n)). + int Log2FloorNonZero(uint64_t n) { +#if defined(__GNUC__) + return 63 ^ __builtin_clzll(n); +#elif defined(PLATFORM_WINDOWS) + unsigned long index; +#if defined(_WIN64) + _BitScanReverse64(&index, n); +#else + auto high = static_cast(n >> 32); + if (_BitScanReverse(&index, high) > 0) { + index += 32; + } else { + auto low = static_cast((n << 32) >> 32); + _BitScanReverse(&index, low); + } +#endif + return index; +#else + return Log2FloorNonZeroSlow(n); +#endif + } + + // Map from bin size to Bin + Bin* BinFromIndex(BinNum index) { + return reinterpret_cast(&(bins_space_[index * sizeof(Bin)])); + } + + size_t BinNumToSize(BinNum index) { + return static_cast(256) << index; + } + + BinNum BinNumForSize(size_t bytes) { + uint64_t v = std::max(bytes, 256) >> kMinAllocationBits; + int b = std::min(kNumBins - 1, Log2FloorNonZero(v)); + return b; + } + + Bin* BinForSize(size_t bytes) { + return BinFromIndex(BinNumForSize(bytes)); + } + + alignas(Bin) char bins_space_[sizeof(Bin) * kNumBins]; + + mutable 
std::mutex lock_; + + AllocatorUniquePtr device_allocator_; + const std::string allocator_name_; + const ArenaConfig config_; + + RegionManager region_manager_; + size_t curr_region_allocation_bytes_; + + // Counter containing the next unique identifier to assign to a newly-created chunk. + int64_t next_allocation_id_; + + std::vector chunks_; + ChunkHandle free_chunks_list_; // Pointer to head of linked list of free Chunks + std::unordered_map reserved_chunks_; + + // chunks being used by a stream + std::unordered_map> stream_to_chunks_; + + // map to connect the OrtSyncStreamImpl the EP library creates to the OrtSyncStream that ORT uses. + // we don't know that it's safe to re-use a chunk until the stream is done with, which is via the call to + // OrtSyncStreamImpl::OnSessionRunEnd. the allocations see OrtSyncStream, so we need to connect things up to + // un-assign chunks when StreamImpl::OnSessionRunEnd is called. + std::unordered_map impl_to_stream_; + + AllocatorStats stats_; + + const OrtApi& api_; + const OrtEpApi& ep_api_; + const OrtLogger& logger_; + + ArenaImpl(const ArenaImpl&) = delete; + ArenaImpl& operator=(const ArenaImpl&) = delete; + ArenaImpl(ArenaImpl&&) = delete; + ArenaImpl& operator=(ArenaImpl&&) = delete; +}; + +struct ArenaAllocator : OrtAllocator { + static OrtStatus* CreateOrtArenaAllocator(AllocatorUniquePtr allocator, + const OrtKeyValuePairs* options, + const OrtApi& api, + const OrtLogger& logger, + std::unique_ptr& arena_allocator) { + ArenaConfig config = options ? 
ArenaConfig::FromKeyValuePairs(api, *options) : ArenaConfig{}; + const OrtMemoryInfo* mem_info = allocator->Info(allocator.get()); + auto impl = std::make_unique(std::move(allocator), config, api, logger); + + arena_allocator = std::make_unique(std::move(impl), *mem_info); + + return nullptr; + } + + ArenaAllocator(std::unique_ptr implementation, const OrtMemoryInfo& memory_info) + : impl_{std::move(implementation)}, + memory_info_{memory_info} { + version = ORT_API_VERSION; + Alloc = AllocImpl; + Reserve = ReserveImpl; + Free = FreeImpl; + Info = InfoImpl; + GetStats = GetStatsImpl; + AllocOnStream = AllocOnStreamImpl; + } + + // remove the OrtSyncStream* from any chunks that were using the stream + OrtStatus* ResetChunksUsingStream(const OrtSyncStreamImpl* stream_impl) { + impl_->ResetChunksUsingStream(stream_impl); + return nullptr; + } + + static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { + auto& arena = *static_cast(this_); + return arena.impl_->Alloc(size); + } + + static void* ORT_API_CALL AllocOnStreamImpl(struct OrtAllocator* this_, size_t size, OrtSyncStream* stream) { + auto& arena = *static_cast(this_); + return arena.impl_->AllocOnStream(size, stream); + } + + static void* ORT_API_CALL ReserveImpl(struct OrtAllocator* this_, size_t size) { + auto& arena = *static_cast(this_); + return arena.impl_->Reserve(size); + } + + static void ORT_API_CALL FreeImpl(struct OrtAllocator* this_, void* p) { + auto& arena = *static_cast(this_); + arena.impl_->Free(p); + } + + static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) { + const auto& arena = *static_cast(this_); + return &arena.memory_info_; + } + + static OrtStatus* ORT_API_CALL GetStatsImpl(const struct OrtAllocator* this_, OrtKeyValuePairs** out) noexcept { + const auto& arena = *static_cast(this_); + return arena.impl_->GetStats(out); + }; + + private: + std::unique_ptr impl_; + const OrtMemoryInfo& memory_info_; +}; diff --git 
a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index 019cf77a66b88..4da7d722a5e0b 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -7,6 +7,7 @@ #include "ep.h" #include "ep_allocator.h" +#include "ep_arena.h" #include "ep_data_transfer.h" #include "ep_stream_support.h" @@ -38,6 +39,8 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL /*vendor*/ 0xBE57, /* device_id */ 0, OrtDeviceMemoryType_DEFAULT, /*alignment*/ 0, + // it is invalid to use OrtArenaAllocator as that is reserved for the + // internal ORT Arena implementation OrtAllocatorType::OrtDeviceAllocator, &mem_info); assert(status == nullptr); // should never fail. @@ -47,6 +50,17 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL const OrtMemoryDevice* device = ep_api.MemoryInfo_GetMemoryDevice(default_memory_info_.get()); data_transfer_impl_ = std::make_unique(apis, device); + // create read-only allocator for use with initializers. same info as DEFAULT memory apart from the allocator type. + status = ort_api.CreateMemoryInfo_V2("ExampleEP GPU readonly", OrtMemoryInfoDeviceType_GPU, + /*vendor*/ 0xBE57, /* device_id */ 0, + OrtDeviceMemoryType_DEFAULT, + /*alignment*/ 0, + OrtAllocatorType::OrtReadOnlyAllocator, + &mem_info); + assert(status == nullptr); // should never fail. + + readonly_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); + // HOST_ACCESSIBLE memory example. use the non-CPU device type so it's clear which device the memory is also // accessible from. we infer from the type of HOST_ACCESSIBLE that it's CPU accessible. mem_info = nullptr; @@ -121,7 +135,9 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory* // register the allocator info required by the EP. // registering OrtMemoryInfo for host accessible memory would be done in an additional call. 
+ // OrtReadOnlyAllocator + OrtDeviceMemoryType_DEFAULT allocator for use with initializers is optional. RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->default_memory_info_.get())); + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->readonly_memory_info_.get())); ep_devices[num_ep_devices++] = ep_device; } @@ -195,12 +211,15 @@ void ORT_API_CALL ExampleEpFactory::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, Or /*static*/ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateAllocatorImpl(OrtEpFactory* this_ptr, const OrtMemoryInfo* memory_info, - const OrtKeyValuePairs* /*allocator_options*/, + const OrtKeyValuePairs* allocator_options, OrtAllocator** allocator) noexcept { auto& factory = *static_cast(this_ptr); *allocator = nullptr; - if (memory_info != factory.default_memory_info_.get()) { + bool is_default_allocator = memory_info == factory.default_memory_info_.get(); + bool is_readonly_allocator = memory_info == factory.readonly_memory_info_.get(); + + if (!is_default_allocator && !is_readonly_allocator) { return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "INTERNAL ERROR! Unknown memory info provided to CreateAllocator. " "Value did not come directly from an OrtEpDevice returned by this factory."); @@ -209,14 +228,58 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateAllocatorImpl(OrtEpFactory* this // NOTE: The factory implementation is free to return a shared OrtAllocator* instance instead of creating a new // allocator on each call. To do this have an allocator instance as an OrtEpFactory class member and make // ReleaseAllocatorImpl a no-op. - auto cpu_allocator = std::make_unique(memory_info, factory); - *allocator = cpu_allocator.release(); + // + // NOTE: EP should implement its own arena logic. ep_arena.cc/h is provided as a reference and we use it here for + // device memory. 
`allocator_options` can be used for arena configuration and there is a helper in ep_arena.h + // to convert from OrtKeyValuePairs to the same arena config settings that ORT uses. + // You are of course free to have completely different settings. + + // the read-only allocator is used for initializers. we don't need an arena for that. + if (is_readonly_allocator) { + auto read_only_allocator = std::make_unique(memory_info, factory); + *allocator = read_only_allocator.release(); + return nullptr; + } + + // create/use the shared arena based allocator + std::lock_guard lock{factory.mutex_}; + + if (!factory.arena_allocator_) { + std::unique_ptr ep_allocator = std::make_unique(memory_info, factory); + + // initial shared allocator in environment does not have allocator options. + // if the user calls CreateSharedAllocator they can provide options to configure the arena differently. + factory.arena_allocator_using_default_settings_ = allocator_options == nullptr; + RETURN_IF_ERROR(ArenaAllocator::CreateOrtArenaAllocator(std::move(ep_allocator), allocator_options, + factory.ort_api, + factory.default_logger_, factory.arena_allocator_)); + + } else { + if (factory.arena_allocator_using_default_settings_ && allocator_options) { + // potential change in arena settings. up to EP author to determine how to handle this. + // we should not get here if replacing the shared allocator in the environment, as we free the existing one + // before replacing it. i.e. ReleaseAllocatorImpl should have been called, and arena_allocator_ should be null. 
+ } + } + + ++factory.num_arena_users_; + *allocator = factory.arena_allocator_.get(); + return nullptr; } /*static*/ -void ORT_API_CALL ExampleEpFactory::ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* allocator) noexcept { - delete static_cast(allocator); +void ORT_API_CALL ExampleEpFactory::ReleaseAllocatorImpl(OrtEpFactory* this_ptr, OrtAllocator* allocator) noexcept { + auto& factory = *static_cast(this_ptr); + std::lock_guard lock{factory.mutex_}; + + if (allocator == factory.arena_allocator_.get()) { + if (--factory.num_arena_users_ == 0) { + factory.arena_allocator_ = nullptr; + } + } else { + delete static_cast(allocator); + } } /*static*/ @@ -238,7 +301,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFac const OrtMemoryDevice* memory_device, const OrtKeyValuePairs* stream_options, OrtSyncStreamImpl** stream) noexcept { - auto& factory = *static_cast(this_ptr); + auto& factory = *static_cast(this_ptr); *stream = nullptr; // we only need stream synchronization on the device stream diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/ep_factory.h index 4b286928a79eb..088deda1fe9d2 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/ep_factory.h @@ -3,6 +3,9 @@ #pragma once +#include + +#include "ep_arena.h" #include "ep_data_transfer.h" #include "example_plugin_ep_utils.h" @@ -17,6 +20,11 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { return data_transfer_impl_.get(); } + // Get the shared arena allocator if created. + ArenaAllocator* GetArenaAllocator() const { + return arena_allocator_.get(); + } + private: static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; @@ -68,6 +76,12 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. 
using MemoryInfoUniquePtr = std::unique_ptr>; MemoryInfoUniquePtr default_memory_info_; + MemoryInfoUniquePtr readonly_memory_info_; // used for initializers + + bool arena_allocator_using_default_settings_{true}; + std::unique_ptr arena_allocator_; // shared device allocator that uses an arena + uint32_t num_arena_users_{0}; + std::mutex mutex_; // mutex to protect arena_allocator_ and num_arena_users_ std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory }; diff --git a/onnxruntime/test/autoep/library/ep_stream_support.cc b/onnxruntime/test/autoep/library/ep_stream_support.cc index a948fe1bfce1e..1f6c16a8cb358 100644 --- a/onnxruntime/test/autoep/library/ep_stream_support.cc +++ b/onnxruntime/test/autoep/library/ep_stream_support.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "ep_stream_support.h" - +#include "ep_factory.h" // // StreamImpl implementation // @@ -27,7 +27,13 @@ OrtStatus* ORT_API_CALL StreamImpl::FlushImpl(_In_ OrtSyncStreamImpl* /*this_ptr } /*static*/ -OrtStatus* ORT_API_CALL StreamImpl::OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* /*this_ptr*/) noexcept { +OrtStatus* ORT_API_CALL StreamImpl::OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + auto* arena = impl.factory_->GetArenaAllocator(); + if (arena) { + arena->ResetChunksUsingStream(this_ptr); + } + return nullptr; } diff --git a/onnxruntime/test/autoep/library/ep_stream_support.h b/onnxruntime/test/autoep/library/ep_stream_support.h index 10c4804722f8b..a825e5afd2250 100644 --- a/onnxruntime/test/autoep/library/ep_stream_support.h +++ b/onnxruntime/test/autoep/library/ep_stream_support.h @@ -4,15 +4,18 @@ #pragma once #include "onnxruntime_c_api.h" +#include "ep_factory.h" #include "example_plugin_ep_utils.h" +class ExampleEpFactory; + // // Class implementing Stream support for synchronization. 
// class StreamImpl : public OrtSyncStreamImpl, public ApiPtrs { public: - StreamImpl(ApiPtrs apis, const OrtEp* ep, const OrtKeyValuePairs* /*stream_options*/) - : ApiPtrs(apis), ep_{ep} { + StreamImpl(ExampleEpFactory& factory, const OrtEp* ep, const OrtKeyValuePairs* /*stream_options*/) + : ApiPtrs(factory), ep_{ep}, factory_{&factory} { ort_version_supported = ORT_API_VERSION; CreateNotification = CreateNotificationImpl; GetHandle = GetHandleImpl; @@ -34,6 +37,7 @@ class StreamImpl : public OrtSyncStreamImpl, public ApiPtrs { // EP instance if the stream is being created internally for inferencing. // nullptr when the stream is created outside of an inference session for data copies. const OrtEp* ep_; + ExampleEpFactory* factory_{nullptr}; }; // diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h index e107a94410dba..99ebee9ff64de 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_utils.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_utils.h @@ -25,12 +25,53 @@ } \ } while (0) +// see ORT_ENFORCE for implementations that also capture a stack trace and work in builds with exceptions disabled +// NOTE: In this simplistic implementation you must provide an argument, even it if's an empty string +#define EP_ENFORCE(condition, ...) \ + do { \ + if (!(condition)) { \ + std::ostringstream oss; \ + oss << "EP_ENFORCE failed: " << #condition << " "; \ + oss << __VA_ARGS__; \ + throw std::runtime_error(oss.str()); \ + } \ + } while (false) + +#ifdef _WIN32 +#define EP_WSTR(x) L##x +#define EP_FILE_INTERNAL(x) EP_WSTR(x) +#define EP_FILE EP_FILE_INTERNAL(__FILE__) +#else +#define EP_FILE __FILE__ +#endif + +#define LOG(level, ...) \ + do { \ + std::ostringstream ss; \ + ss << __VA_ARGS__; \ + api_.Logger_LogMessage(&logger_, ORT_LOGGING_LEVEL_##level, ss.str().c_str(), EP_FILE, __LINE__, __FUNCTION__); \ + } while (false) + +#define RETURN_ERROR(code, ...) 
\ + do { \ + std::ostringstream ss; \ + ss << __VA_ARGS__; \ + return api_.CreateStatus(code, ss.str().c_str()); \ + } while (false) + +#define THROW(...) \ + std::ostringstream ss; \ + ss << __VA_ARGS__; \ + throw std::runtime_error(ss.str()) + struct ApiPtrs { const OrtApi& ort_api; const OrtEpApi& ep_api; const OrtModelEditorApi& model_editor_api; }; +using AllocatorUniquePtr = std::unique_ptr>; + // Helper to release Ort one or more objects obtained from the public C API at the end of their scope. template struct DeferOrtRelease { diff --git a/onnxruntime/test/autoep/test_allocators.cc b/onnxruntime/test/autoep/test_allocators.cc index 84b6e284ccb8e..77d2bb24b7d35 100644 --- a/onnxruntime/test/autoep/test_allocators.cc +++ b/onnxruntime/test/autoep/test_allocators.cc @@ -30,8 +30,9 @@ struct DummyAllocator : OrtAllocator { Alloc = AllocImpl; Free = FreeImpl; Info = InfoImpl; - Reserve = AllocImpl; // no special reserve logic and most likely unnecessary unless you have your own arena - GetStats = nullptr; // this can be set to nullptr if not implemented + Reserve = AllocImpl; // no special reserve logic and most likely unnecessary unless you have your own arena + GetStats = nullptr; // this can be set to nullptr if not implemented + AllocOnStream = nullptr; // optional } static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { @@ -75,9 +76,11 @@ TEST(SharedAllocators, AddArenaToSharedAllocator) { auto initial_chunk_size = "25600"; // arena allocates in 256 byte amounts allocator_options.Add(OrtArenaCfg::ConfigKeyNames::InitialChunkSizeBytes, initial_chunk_size); - ASSERT_ORTSTATUS_OK(c_api.CreateSharedAllocator(*ort_env, example_ep.get(), - OrtDeviceMemoryType_DEFAULT, OrtArenaAllocator, &allocator_options, - &allocator)); + ASSERT_ORTSTATUS_OK(c_api.CreateSharedAllocator(*ort_env, example_ep.get(), OrtDeviceMemoryType_DEFAULT, + // allocator is internally added by EP. 
+ // OrtArenaAllocator can only be used for the internal BFCArena + OrtDeviceAllocator, + &allocator_options, &allocator)); // first allocation should init the arena to the initial chunk size void* mem = allocator->Alloc(allocator, 16); diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 61e5fa05c66c1..4245c4bbb1b0a 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -311,7 +311,7 @@ static void RunAttentionTest( kv_sequence_length, past_present_share_buffer, use_scale, do_neox_rotary); } -TEST(AttentionTest, AttentionBatch1) { +TEST(ContribOpAttentionTest, AttentionBatch1) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -340,7 +340,7 @@ TEST(AttentionTest, AttentionBatch1) { batch_size, sequence_length, hidden_size, number_of_heads); } -TEST(AttentionTest, AttentionBatch1WithQKVAttr1) { +TEST(ContribOpAttentionTest, AttentionBatch1WithQKVAttr1) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -381,7 +381,7 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr1) { 0, false, false, disable_rocm, false, qkv_sizes); } -TEST(AttentionTest, AttentionBatch1WithQKVAttr2) { +TEST(ContribOpAttentionTest, AttentionBatch1WithQKVAttr2) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -419,7 +419,7 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr2) { 0, false, false, disable_rocm, false, qkv_sizes); } -TEST(AttentionTest, AttentionBatch1AttentionBias) { +TEST(ContribOpAttentionTest, AttentionBatch1AttentionBias) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -460,7 +460,7 @@ TEST(AttentionTest, AttentionBatch1AttentionBias) { 0, disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, attention_bias); } -TEST(AttentionTest, AttentionBatch2AttentionBias) { +TEST(ContribOpAttentionTest, AttentionBatch2AttentionBias) { int batch_size = 2; int 
sequence_length = 2; int hidden_size = 4; @@ -506,7 +506,7 @@ TEST(AttentionTest, AttentionBatch2AttentionBias) { 0, disable_cpu, disable_cuda, disable_rocm, disable_dml, qkv_sizes, attention_bias); } -TEST(AttentionTest, AttentionBatch1_Float16) { +TEST(ContribOpAttentionTest, AttentionBatch1_Float16) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -535,7 +535,7 @@ TEST(AttentionTest, AttentionBatch1_Float16) { batch_size, sequence_length, hidden_size, number_of_heads, true); } -TEST(AttentionTest, AttentionBatch2) { +TEST(ContribOpAttentionTest, AttentionBatch2) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -568,7 +568,7 @@ TEST(AttentionTest, AttentionBatch2) { batch_size, sequence_length, hidden_size, number_of_heads); } -TEST(AttentionTest, AttentionMaskPartialSequence) { +TEST(ContribOpAttentionTest, AttentionMaskPartialSequence) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -598,7 +598,7 @@ TEST(AttentionTest, AttentionMaskPartialSequence) { batch_size, sequence_length, hidden_size, number_of_heads); } -TEST(AttentionTest, AttentionMaskExceedSequence) { +TEST(ContribOpAttentionTest, AttentionMaskExceedSequence) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -628,7 +628,7 @@ TEST(AttentionTest, AttentionMaskExceedSequence) { batch_size, sequence_length, hidden_size, number_of_heads); } -TEST(AttentionTest, AttentionNoMaskIndex) { +TEST(ContribOpAttentionTest, AttentionNoMaskIndex) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -658,7 +658,7 @@ TEST(AttentionTest, AttentionNoMaskIndex) { batch_size, sequence_length, hidden_size, number_of_heads); } -TEST(AttentionTest, AttentionUnidirectional) { +TEST(ContribOpAttentionTest, AttentionUnidirectional) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -855,7 +855,7 @@ void RawAttentionEmptyPastState(bool past_present_share_buffer) { } } -TEST(AttentionTest, 
Causal_EmptyPastState) { +TEST(ContribOpAttentionTest, Causal_EmptyPastState) { int batch_size = 1; int sequence_length = 2; int hidden_size = 64; @@ -918,11 +918,11 @@ TEST(AttentionTest, Causal_EmptyPastState) { } } -TEST(AttentionTest, AttentionEmptyPastState) { +TEST(ContribOpAttentionTest, AttentionEmptyPastState) { RawAttentionEmptyPastState(false); } -TEST(AttentionTest, AttentionEmptyPastState_SharedPastPresent) { +TEST(ContribOpAttentionTest, AttentionEmptyPastState_SharedPastPresent) { RawAttentionEmptyPastState(true); } @@ -1037,11 +1037,11 @@ void RawAttentionPastStateBatch1(bool past_present_share_buffer) { } } -TEST(AttentionTest, AttentionPastStateBatch1) { +TEST(ContribOpAttentionTest, AttentionPastStateBatch1) { RawAttentionPastStateBatch1(false); } -TEST(AttentionTest, AttentionPastStateBatch1_SharedPastPresent) { +TEST(ContribOpAttentionTest, AttentionPastStateBatch1_SharedPastPresent) { RawAttentionPastStateBatch1(true); } @@ -1170,11 +1170,11 @@ void RawAttentionPastStateBatch2(bool past_present_share_buffer) { } } -TEST(AttentionTest, AttentionPastStateBatch2) { +TEST(ContribOpAttentionTest, AttentionPastStateBatch2) { RawAttentionPastStateBatch2(false); } -TEST(AttentionTest, AttentionPastStateBatch2_SharedPastPresent) { +TEST(ContribOpAttentionTest, AttentionPastStateBatch2_SharedPastPresent) { RawAttentionPastStateBatch2(true); } @@ -1295,15 +1295,15 @@ void RawAttentionPastStateBatch2WithPadding(bool past_present_share_buffer) { } } -TEST(AttentionTest, AttentionPastStateBatch2WithPadding) { +TEST(ContribOpAttentionTest, AttentionPastStateBatch2WithPadding) { RawAttentionPastStateBatch2WithPadding(false); } -TEST(AttentionTest, AttentionPastStateBatch2WithPadding_SharedPastPresent) { +TEST(ContribOpAttentionTest, AttentionPastStateBatch2WithPadding_SharedPastPresent) { RawAttentionPastStateBatch2WithPadding(true); } -TEST(AttentionTest, AttentionBatch2MaskIndex2) { +TEST(ContribOpAttentionTest, AttentionBatch2MaskIndex2) { int batch_size = 
2; int sequence_length = 2; int hidden_size = 4; @@ -1344,7 +1344,7 @@ TEST(AttentionTest, AttentionBatch2MaskIndex2) { AttentionMaskType::MASK_1D_END_START); } -TEST(AttentionTest, AttentionRightPaddingMaskIndex2) { +TEST(ContribOpAttentionTest, AttentionRightPaddingMaskIndex2) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -1382,7 +1382,7 @@ TEST(AttentionTest, AttentionRightPaddingMaskIndex2) { AttentionMaskType::MASK_1D_END_START); } -TEST(AttentionTest, AttentionLeftPaddingMaskIndex2) { +TEST(ContribOpAttentionTest, AttentionLeftPaddingMaskIndex2) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -1420,7 +1420,7 @@ TEST(AttentionTest, AttentionLeftPaddingMaskIndex2) { AttentionMaskType::MASK_1D_END_START); } -TEST(AttentionTest, AttentionBatch2LeftPaddingMaskIndex2) { +TEST(ContribOpAttentionTest, AttentionBatch2LeftPaddingMaskIndex2) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1462,7 +1462,7 @@ TEST(AttentionTest, AttentionBatch2LeftPaddingMaskIndex2) { AttentionMaskType::MASK_1D_END_START); } -TEST(AttentionTest, Attention3DMask) { +TEST(ContribOpAttentionTest, Attention3DMask) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1508,7 +1508,7 @@ TEST(AttentionTest, Attention3DMask) { AttentionMaskType::MASK_3D_ATTENTION); } -TEST(AttentionTest, AttentionBatch2AttentionMask) { +TEST(ContribOpAttentionTest, AttentionBatch2AttentionMask) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1550,7 +1550,7 @@ TEST(AttentionTest, AttentionBatch2AttentionMask) { AttentionMaskType::MASK_2D_KEY_PADDING); } -TEST(AttentionTest, AttentionUnidirectional3DMask) { +TEST(ContribOpAttentionTest, AttentionUnidirectional3DMask) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1596,7 +1596,7 @@ TEST(AttentionTest, AttentionUnidirectional3DMask) { AttentionMaskType::MASK_3D_ATTENTION); } -TEST(AttentionTest, 
AttentionUnidirectionalAttentionMask) { +TEST(ContribOpAttentionTest, AttentionUnidirectionalAttentionMask) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1638,7 +1638,7 @@ TEST(AttentionTest, AttentionUnidirectionalAttentionMask) { AttentionMaskType::MASK_2D_KEY_PADDING); } -TEST(AttentionTest, AttentionWithNormFactor) { +TEST(ContribOpAttentionTest, AttentionWithNormFactor) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1683,7 +1683,7 @@ TEST(AttentionTest, AttentionWithNormFactor) { true /*use_scale*/); } -TEST(AttentionTest, AttentionWithNeoXRotaryEmbedding) { +TEST(ContribOpAttentionTest, AttentionWithNeoXRotaryEmbedding) { int batch_size = 2; int sequence_length = 2; int hidden_size = 64; @@ -1717,7 +1717,7 @@ TEST(AttentionTest, AttentionWithNeoXRotaryEmbedding) { true /*use_scale*/, true /*use_neox_rotary_embedding*/); } -TEST(AttentionTest, AttentionMask1DEndNoWord) { +TEST(ContribOpAttentionTest, AttentionMask1DEndNoWord) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1760,7 +1760,7 @@ TEST(AttentionTest, AttentionMask1DEndNoWord) { AttentionMaskType::MASK_1D_KEY_SEQ_LEN); } -TEST(AttentionTest, AttentionMask1DNoWord) { +TEST(ContribOpAttentionTest, AttentionMask1DNoWord) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1803,7 +1803,7 @@ TEST(AttentionTest, AttentionMask1DNoWord) { AttentionMaskType::MASK_1D_END_START); } -TEST(AttentionTest, AttentionMask2DNoWord) { +TEST(ContribOpAttentionTest, AttentionMask2DNoWord) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1846,7 +1846,7 @@ TEST(AttentionTest, AttentionMask2DNoWord) { AttentionMaskType::MASK_2D_KEY_PADDING); } -TEST(AttentionTest, AttentionMask3DNoWord) { +TEST(ContribOpAttentionTest, AttentionMask3DNoWord) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1889,7 +1889,7 @@ TEST(AttentionTest, AttentionMask3DNoWord) { 
AttentionMaskType::MASK_3D_ATTENTION); } -TEST(AttentionTest, AttentionDummyMask2D) { +TEST(ContribOpAttentionTest, AttentionDummyMask2D) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -1931,7 +1931,7 @@ TEST(AttentionTest, AttentionDummyMask2D) { AttentionMaskType::MASK_2D_DUMMY); } -TEST(AttentionTest, Attention4DMask) { +TEST(ContribOpAttentionTest, Attention4DMask) { int batch_size = 1; int sequence_length = 2; int hidden_size = 4; @@ -1977,7 +1977,7 @@ TEST(AttentionTest, Attention4DMask) { disable_cpu); } -TEST(AttentionTest, AttentionMaskIndexOutOfRange) { +TEST(ContribOpAttentionTest, AttentionMaskIndexOutOfRange) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; @@ -2021,7 +2021,7 @@ TEST(AttentionTest, AttentionMaskIndexOutOfRange) { #if !defined(__wasm__) // TODO: fix in web assembly -TEST(AttentionTest, AttentionPastState_dynamic) { +TEST(ContribOpAttentionTest, AttentionPastState_dynamic) { // create rand inputs RandomValueGenerator random{}; @@ -2051,7 +2051,7 @@ TEST(AttentionTest, AttentionPastState_dynamic) { } #endif //! 
defined(__wasm__) -TEST(AttentionTest, AttentionPrunedModel) { +TEST(ContribOpAttentionTest, AttentionPrunedModel) { int batch_size = 2; int sequence_length = 2; // test input_hidden_size > hidden_size @@ -2174,7 +2174,7 @@ static void RunModelWithRandomInput( } } -TEST(AttentionTest, Attention_Mask2D_Fp32_B2_S32) { +TEST(ContribOpAttentionTest, Attention_Mask2D_Fp32_B2_S32) { constexpr int batch_size = 2; constexpr int sequence_length = 32; @@ -2196,7 +2196,7 @@ TEST(AttentionTest, Attention_Mask2D_Fp32_B2_S32) { false); } -TEST(AttentionTest, Attention_Mask1D_Fp32_B2_S64) { +TEST(ContribOpAttentionTest, Attention_Mask1D_Fp32_B2_S64) { constexpr int batch_size = 2; constexpr int sequence_length = 64; @@ -2217,7 +2217,7 @@ TEST(AttentionTest, Attention_Mask1D_Fp32_B2_S64) { } // This case can be used to test flash attention using Ampere GPU -TEST(AttentionTest, Attention_NoMask_Fp16) { +TEST(ContribOpAttentionTest, Attention_NoMask_Fp16) { constexpr int batch_size = 2; std::vector sequence_lengths{1, 7, 8}; for (const auto& sequence_length : sequence_lengths) { @@ -2236,7 +2236,7 @@ TEST(AttentionTest, Attention_NoMask_Fp16) { } // This test is disabled since it is flaky. -TEST(AttentionTest, DISABLED_Attention_Mask1D_Fp16_B2_FusedNoPadding) { +TEST(ContribOpAttentionTest, DISABLED_Attention_Mask1D_Fp16_B2_FusedNoPadding) { constexpr int batch_size = 2; // Sequence lengths used in TRT fused attention fp16 v2 kernels. @@ -2263,7 +2263,7 @@ TEST(AttentionTest, DISABLED_Attention_Mask1D_Fp16_B2_FusedNoPadding) { #ifndef ENABLE_TRAINING // Prepacking is disabled in full training build so no need to test the feature in a training build. 
-TEST(AttentionTest, SharedPrepackedWeights) { +TEST(ContribOpAttentionTest, SharedPrepackedWeights) { int batch_size = 2; int sequence_length = 2; int hidden_size = 4; diff --git a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc index c9a7116bf8052..2918e4baf86a4 100644 --- a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc +++ b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc @@ -82,21 +82,48 @@ static void CalculateDynamicQuantizeMatMul(const int64_t M, const int64_t N, con } } +struct TestDynamicQuantizeMatMulOptions { + bool is_matrix_b_constant = true; + + bool per_column = false; + + bool is_scale_constant = false; + + bool has_zp = true; + bool is_zp_constant = false; + bool is_zp_zero = false; + + bool has_bias = false; + bool is_bias_constant = false; + + bool empty_input = false; +}; + template -void TestDynamicQuantizeMatMul(bool is_matrix_b_constant, - bool per_column = false, - bool has_zp = true, - bool has_bias = false, - bool empty_input = false) { +void TestDynamicQuantizeMatMul(const TestDynamicQuantizeMatMulOptions& opts) { + static_assert(std::is_same_v || std::is_same_v); + + SCOPED_TRACE(MakeString( + "b data type:", (std::is_same_v ? "uint8" : "int8"), + ", is_matrix_b_constant:", opts.is_matrix_b_constant, + ", per_column:", opts.per_column, + ", is_scale_constant:", opts.is_scale_constant, + ", has_zp:", opts.has_zp, + ", is_zp_constant:", opts.is_zp_constant, + ", is_zp_zero:", opts.is_zp_zero, + ", has_bias:", opts.has_bias, + ", is_bias_constant:", opts.is_bias_constant, + ", empty_input:", opts.empty_input)); + // create rand inputs RandomValueGenerator random{1668426375}; - int64_t M = empty_input ? 1 : 4; + int64_t M = opts.empty_input ? 1 : 4; int64_t N = 128; int64_t K = 128; - std::vector A_dims{empty_input ? 0 : M, K}; + std::vector A_dims{opts.empty_input ? 0 : M, K}; std::vector B_dims{K, N}; - std::vector Y_dims{empty_input ? 
0 : M, K}; + std::vector Y_dims{opts.empty_input ? 0 : M, K}; std::vector A_data = random.Uniform(A_dims, -1.0f, 1.0f); std::vector B_data; std::vector tmp_B_data = random.Uniform(B_dims, @@ -106,101 +133,120 @@ void TestDynamicQuantizeMatMul(bool is_matrix_b_constant, return static_cast(v); }); - int64_t b_scale_zp_size = per_column ? B_dims.back() : 1; + int64_t b_scale_zp_size = opts.per_column ? B_dims.back() : 1; std::vector B_scale = random.Uniform(AsSpan({b_scale_zp_size}), -0.1f, 0.1f); std::vector B_zero_point(b_scale_zp_size); - std::for_each(B_zero_point.begin(), - B_zero_point.end(), - [&random](T& zp) { - zp = static_cast(random.Uniform(std::array{1}, - std::numeric_limits::min(), - std::numeric_limits::max())[0]); - }); + if (!opts.is_zp_zero) { + std::for_each(B_zero_point.begin(), + B_zero_point.end(), + [&random](T& zp) { + zp = static_cast(random.Uniform(std::array{1}, + std::numeric_limits::min(), + std::numeric_limits::max())[0]); + }); + } std::vector Bias = random.Uniform(AsSpan({B_dims.back()}), -0.1f, 0.1f); OpTester test("DynamicQuantizeMatMul", 1, onnxruntime::kMSDomain); test.AddInput("A", A_dims, A_data); - test.AddInput("B", B_dims, B_data, is_matrix_b_constant); - test.AddInput("b_scale", {b_scale_zp_size}, B_scale); + test.AddInput("B", B_dims, B_data, opts.is_matrix_b_constant); + test.AddInput("b_scale", {b_scale_zp_size}, B_scale, opts.is_scale_constant); - if (has_zp) { - test.AddInput("b_zero_point", {b_scale_zp_size}, B_zero_point); + if (opts.has_zp) { + test.AddInput("b_zero_point", {b_scale_zp_size}, B_zero_point, opts.is_zp_constant); } else { test.AddOptionalInputEdge(); } - if (has_bias) { - test.AddInput("bias", {B_dims.back()}, Bias); + if (opts.has_bias) { + test.AddInput("bias", {B_dims.back()}, Bias, opts.is_bias_constant); } else { test.AddOptionalInputEdge(); } std::vector Y_data(M * N); CalculateDynamicQuantizeMatMul(M, N, K, A_data, B_data, B_scale, B_zero_point, Bias, Y_data, - per_column, has_zp, has_bias); + 
opts.per_column, opts.has_zp, opts.has_bias); test.AddOutput("Y", Y_dims, Y_data); test.SetOutputRelErr("Y", 0.02f); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } -template -void RunDynamicQuantizeMatMulTest() { - TestDynamicQuantizeMatMul(false, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); - - TestDynamicQuantizeMatMul(true, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); - - TestDynamicQuantizeMatMul(false, /*is_matrix_b_constant*/ - true, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); - - TestDynamicQuantizeMatMul(true, /*is_matrix_b_constant*/ - true, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); +template +void TestDynamicQuantizeMatMul(bool is_matrix_b_constant, + bool per_column = false, + bool has_zp = true, + bool has_bias = false, + bool empty_input = false) { + TestDynamicQuantizeMatMulOptions opts{}; + opts.is_matrix_b_constant = is_matrix_b_constant; + opts.per_column = per_column; + opts.has_zp = has_zp; + opts.has_bias = has_bias; + opts.empty_input = empty_input; + + TestDynamicQuantizeMatMul(opts); } -TEST(DynamicQuantizeMatMul, HasZeroPoint_NoBias_test_S8) { - RunDynamicQuantizeMatMulTest(); +template +void RunDynamicQuantizeMatMulTest() { + for (bool is_matrix_b_constant : {false, true}) { + for (bool per_column : {false, true}) { + for (bool has_zp : {false, true}) { + for (bool has_bias : {false, true}) { + TestDynamicQuantizeMatMul(is_matrix_b_constant, + per_column, + has_zp, + has_bias); + } + } + } + } } -TEST(DynamicQuantizeMatMul, HasZeroPoint_NoBias_test_U8) { - RunDynamicQuantizeMatMulTest(); +TEST(DynamicQuantizeMatMul, Int8) { + RunDynamicQuantizeMatMulTest(); } -TEST(DynamicQuantizeMatMul, NoZeroPoint_HasBias_test_S8) { - RunDynamicQuantizeMatMulTest(); +TEST(DynamicQuantizeMatMul, UInt8) { + RunDynamicQuantizeMatMulTest(); } 
-TEST(DynamicQuantizeMatMul, NoZeroPoint_HasBias_test_U8) { - RunDynamicQuantizeMatMulTest(); -} +TEST(DynamicQuantizeMatMul, WithConstantBInputs) { + TestDynamicQuantizeMatMulOptions base_opts{}; + base_opts.is_matrix_b_constant = true; + base_opts.is_scale_constant = true; + base_opts.is_zp_constant = true; -TEST(DynamicQuantizeMatMul, NoZeroPoint_NoBias_test_S8) { - RunDynamicQuantizeMatMulTest(); -} + { + // no zp + auto opts = base_opts; + opts.has_zp = false; -TEST(DynamicQuantizeMatMul, NoZeroPoint_NoBias_test_U8) { - RunDynamicQuantizeMatMulTest(); -} + TestDynamicQuantizeMatMul(opts); + TestDynamicQuantizeMatMul(opts); + } -TEST(DynamicQuantizeMatMul, HasZeroPoint_HasBias_test_S8) { - RunDynamicQuantizeMatMulTest(); -} + { + // zp that is zero (symmetric quantization) + auto opts = base_opts; + opts.has_zp = true; + opts.is_zp_zero = true; -TEST(DynamicQuantizeMatMul, HasZeroPoint_HasBias_test_U8) { - RunDynamicQuantizeMatMulTest(); + TestDynamicQuantizeMatMul(opts); + TestDynamicQuantizeMatMul(opts); + } + + { + // zp that is non-zero + auto opts = base_opts; + opts.has_zp = true; + opts.is_zp_zero = false; + + TestDynamicQuantizeMatMul(opts); + TestDynamicQuantizeMatMul(opts); + } } TEST(DynamicQuantizeMatMul, UInt8_test_with_empty_input) { diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 2a2d69a1c2e47..45314f8f39eea 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -87,6 +87,106 @@ TEST(EpGraphTest, Check3LayerNestedSubgraphV2) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +// Check correctness of an OrtGraph that has external initializers. 
+TEST(EpGraphTest, CheckModelExternalInitializers) { + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/conv_qdq_external_ini.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); +} + +static void RunConvQDQExtIni(const ORTCHAR_T* model_path, std::vector& output_data) { + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + Ort::SessionOptions sess_options; + Ort::Session session(*ort_env, model_path, sess_options); + + std::vector input_shape = {1, 3, 24, 24}; + std::vector input_data(3 * 24 * 24, 0.5f); + std::vector ort_inputs; + std::vector ort_input_names; + + // Add 'input' + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size())); + ort_input_names.push_back("input"); + + // Run session and get outputs + std::array output_names{"output"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output type and number of elements. + Ort::Value& ort_output = ort_outputs[0]; + auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); + size_t num_output_elems = output_type_shape.GetElementCount(); + + ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + ASSERT_EQ(num_output_elems, 32 * 26 * 26); + + // Return output data. + const float* output_values = ort_output.GetTensorData(); + output_data.assign(output_values, output_values + num_output_elems); +} + +// Test serializing an OrtGraph with external initializers to GraphProto. +// Checks that the outputs of the serialized and original models are identical. 
+TEST(EpGraphTest, SerializeToProto_InputModelHasExternalIni) { + const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/conv_qdq_external_ini.onnx"); + const ORTCHAR_T* serialized_model_path = ORT_TSTR("conv_qdq_ext_ini_serialized.onnx"); + std::filesystem::remove(serialized_model_path); + + { + auto test_graph = TestGraph::Load(original_model_path); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // Serialize OrtGraph to GraphProto. Save initializers to external file. + std::string ext_ini_file_path = "conv_qdq_ext_ini_serialized.bin"; + std::filesystem::remove(ext_ini_file_path); + std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, + // node consumers, etc. + (void)value_info; + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = ext_ini_ofs.tellp(); + location = ext_ini_file_path; + ext_ini_ofs.write(static_cast(data), bytes); + ext_ini_ofs.flush(); + is_external = true; // True if is external initializer. + + return Ort::Status{nullptr}; + }; + + ONNX_NAMESPACE::ModelProto model_proto; + ASSERT_CXX_ORTSTATUS_OK(OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, + handle_initializer_data)); + + std::ofstream ofs(serialized_model_path, std::ios::binary); + model_proto.SerializeToOstream(&ofs); + ofs.flush(); + + ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); + ASSERT_TRUE(std::filesystem::exists(ext_ini_file_path)); + } + + // Compare output of the original and serialized models. Should be identical. 
+ std::vector output_original; + std::vector output_serialized; + + RunConvQDQExtIni(original_model_path, output_original); + RunConvQDQExtIni(serialized_model_path, output_serialized); + + EXPECT_EQ(output_serialized, output_original); +} + static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& output_data) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); Ort::SessionOptions sess_options; @@ -158,7 +258,8 @@ TEST(EpGraphTest, SerializeToProto_Mnist) { }; ONNX_NAMESPACE::ModelProto model_proto; - OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, handle_initializer_data); + ASSERT_CXX_ORTSTATUS_OK(OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, + handle_initializer_data)); std::ofstream ofs(serialized_model_path, std::ios::binary); model_proto.SerializeToOstream(&ofs); @@ -202,7 +303,7 @@ TEST(EpGraphTest, SerializeToProto_ExternalInitializersInMemory) { }; ONNX_NAMESPACE::GraphProto graph_proto; - OrtEpUtils::OrtGraphToProto(ort_graph, graph_proto, handle_initializer_data); + ASSERT_CXX_ORTSTATUS_OK(OrtEpUtils::OrtGraphToProto(ort_graph, graph_proto, handle_initializer_data)); // Verify that TensorProto objects within GraphProto point to memory owned by OrtValues in the OrtGraph. const OrtApi& ort_api = Ort::GetApi(); @@ -294,7 +395,7 @@ TEST(EpGraphTest, SerializeToProto_3LayerSubgraphs) { // Serialize OrtGraph to ModelProto (all initializers stored within TensorProtos). 
ONNX_NAMESPACE::ModelProto model_proto; - OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto); + ASSERT_CXX_ORTSTATUS_OK(OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto)); std::ofstream ofs(serialized_model_path, std::ios::binary); model_proto.SerializeToOstream(&ofs); @@ -442,17 +543,40 @@ static void CheckValueInfoConsumers(const GraphViewer& graph_viewer, const OrtVa } static void CheckInitializerValueInfo(const OrtValueInfo* api_value_info, - const ONNX_NAMESPACE::TensorProto* tensor_proto) { + const ONNX_NAMESPACE::TensorProto* tensor_proto, + const GraphViewer& graph_viewer) { const OrtApi& ort_api = Ort::GetApi(); - const OrtValue* api_initializer_value = nullptr; - ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetInitializerValue(api_value_info, &api_initializer_value)); - ASSERT_NE(api_initializer_value, nullptr); - const char* api_initializer_name = nullptr; ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(api_value_info, &api_initializer_name)); ASSERT_NE(api_initializer_name, nullptr); + // Check external initializer info (if any). 
+ OrtExternalInitializerInfo* api_ext_info = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetExternalInitializerInfo(api_value_info, &api_ext_info)); + DeferOrtRelease defer_release_info(&api_ext_info, ort_api.ReleaseExternalInitializerInfo); + + std::unique_ptr ext_info = nullptr; + bool has_ext_info = graph_viewer.GetGraph().GetExternalInitializerInfo(api_initializer_name, ext_info, true); + + if (has_ext_info) { + ASSERT_NE(api_ext_info, nullptr); + const ORTCHAR_T* api_ext_file_path = ort_api.ExternalInitializerInfo_GetFilePath(api_ext_info); + int64_t api_ext_file_offset = ort_api.ExternalInitializerInfo_GetFileOffset(api_ext_info); + size_t api_ext_byte_size = ort_api.ExternalInitializerInfo_GetByteSize(api_ext_info); + + ASSERT_EQ(PathString(api_ext_file_path), ext_info->GetRelPath()); + ASSERT_EQ(api_ext_file_offset, static_cast(ext_info->GetOffset())); + ASSERT_EQ(api_ext_byte_size, ext_info->GetLength()); + } else { + ASSERT_EQ(api_ext_info, nullptr); + ASSERT_FALSE(utils::HasExternalDataInFile(*tensor_proto)); + } + + const OrtValue* api_initializer_value = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetInitializerValue(api_value_info, &api_initializer_value)); + ASSERT_NE(api_initializer_value, nullptr); + // Check initializer type. 
const ONNX_NAMESPACE::TypeProto type_proto = utils::TypeProtoFromTensorProto(*tensor_proto); auto type_info = OrtTypeInfo::FromTypeProto(type_proto); @@ -463,7 +587,8 @@ static void CheckInitializerValueInfo(const OrtValueInfo* api_value_info, } static void CheckInitializerValueInfosCApi(gsl::span initializer_value_infos, - const InitializedTensorSet& initializer_tensor_protos) { + const InitializedTensorSet& initializer_tensor_protos, + const GraphViewer& graph_viewer) { const OrtApi& ort_api = Ort::GetApi(); for (size_t i = 0; i < initializer_value_infos.size(); i++) { @@ -479,7 +604,7 @@ static void CheckInitializerValueInfosCApi(gsl::span const ONNX_NAMESPACE::TensorProto* tensor_proto = tensor_proto_iter->second; ASSERT_NE(tensor_proto, nullptr); - CheckInitializerValueInfo(api_value_info, tensor_proto); + CheckInitializerValueInfo(api_value_info, tensor_proto, graph_viewer); } } @@ -543,7 +668,7 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::spanTypeAsProto()); const OrtTypeInfo* api_type_info = nullptr; @@ -643,7 +768,7 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ std::vector api_initializers(api_num_initializers); ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInitializers(&api_graph, api_initializers.data(), api_initializers.size())); - CheckInitializerValueInfosCApi(api_initializers, graph_initializers); + CheckInitializerValueInfosCApi(api_initializers, graph_initializers, graph_viewer); // Check if it has a parent node. const Node* parent_node = graph_viewer.ParentNode(); @@ -725,6 +850,8 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ // It's possible that the type is defined in ONNX::AttributeProto_AttributeType but not in OrtOpAttrType, since the two are not in a 1:1 mapping. // In such cases, OpAttr_GetType will return a non-null status, and we simply skip the check here. 
+ // TODO: Once we add support for ORT_OP_ATTR_TENSOR, we should be able to just fail if OpAttr_GetType + // returns an error. OrtStatusPtr status = ort_api.OpAttr_GetType(api_node_attr, &api_node_attr_type); if (status != nullptr) { Ort::GetApi().ReleaseStatus(status); @@ -761,6 +888,10 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_STRINGS); break; } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPH: { + ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_GRAPH); + break; + } default: // The unsupported type should be skipped by 'continue' above. It's unexpected so we force test to fail. ASSERT_ORTSTATUS_OK(ort_api.CreateStatus(ORT_FAIL, "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit.")); diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.h b/onnxruntime/test/ep_graph/test_ep_graph_utils.h index 2ce107cf734c6..2aebd75e0aaac 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_utils.h +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.h @@ -42,6 +42,30 @@ struct NodeArgConsumer { int64_t input_index = -1; }; +// Helper to release Ort one or more objects obtained from the public C API at the end of their scope. 
+template +struct DeferOrtRelease { + DeferOrtRelease(T** object_ptr, std::function release_func) + : objects_(object_ptr), count_(1), release_func_(release_func) {} + + DeferOrtRelease(T** objects, size_t count, std::function release_func) + : objects_(objects), count_(count), release_func_(release_func) {} + + ~DeferOrtRelease() { + if (objects_ != nullptr && count_ > 0) { + for (size_t i = 0; i < count_; ++i) { + if (objects_[i] != nullptr) { + release_func_(objects_[i]); + objects_[i] = nullptr; + } + } + } + } + T** objects_ = nullptr; + size_t count_ = 0; + std::function release_func_ = nullptr; +}; + // Returns consumers (i.e., consumer node + input index) of a NodeArg from the original graph. Status GetNodeArgConsumers(const GraphViewer& graph_viewer, const NodeArg& node_arg, /*out*/ std::vector& consumers); diff --git a/onnxruntime/test/framework/bfc_arena_test.cc b/onnxruntime/test/framework/bfc_arena_test.cc index 670447f2804dc..9ded9d2bfeac0 100644 --- a/onnxruntime/test/framework/bfc_arena_test.cc +++ b/onnxruntime/test/framework/bfc_arena_test.cc @@ -339,81 +339,82 @@ struct StreamMock : public Stream { #ifdef ORT_ENABLE_STREAM TEST(StreamAwareArenaTest, TwoStreamAllocation) { - StreamAwareArena a(std::unique_ptr(new CPUAllocator()), 1 << 30, false); + StreamAwareArena a(std::unique_ptr(new CPUAllocator()), 1 << 30); CheckStats(&a, 0, 0, 0, 0); OrtDevice tmp; StreamMock stream1(tmp), stream2(tmp); - auto* stream1_chunk_a = a.AllocOnStream(4096, &stream1, nullptr); - auto* stream2_chunk_a = a.AllocOnStream(4096, &stream2, nullptr); - a.Free(stream1_chunk_a); - auto* stream2_chunk_b = a.AllocOnStream(4096, &stream2, nullptr); + auto* stream1_chunk_a = a.AllocOnStream(4096, &stream1); // 4K chunk on stream 1 + auto* stream2_chunk_a = a.AllocOnStream(4096, &stream2); // 4K chunk on stream 2 + a.Free(stream1_chunk_a); // free but assigned to stream1 + // stream2 can't reuse stream1's chunk + auto* stream2_chunk_b = a.AllocOnStream(4096, &stream2); // 4K 
chunk on stream 2 EXPECT_NE(stream2_chunk_b, stream1_chunk_a); - a.Free(stream2_chunk_a); - auto* stream1_chunk_c = a.AllocOnStream(4096, &stream1, nullptr); - // it should pick the first chunk + + a.Free(stream2_chunk_a); // free but assigned to stream2 + + // it should pick the first chunk. + auto* stream1_chunk_c = a.AllocOnStream(4096, &stream1); EXPECT_EQ(stream1_chunk_c, stream1_chunk_a); - auto* stream1_chunk_d = a.AllocOnStream(4096, &stream1, nullptr); - // it shouldn't pick stream2_chunk_a's buffer + // it shouldn't pick stream2_chunk_a due to stream mismatch + auto* stream1_chunk_d = a.AllocOnStream(4096, &stream1); EXPECT_NE(stream1_chunk_d, stream2_chunk_a); - a.Free(stream2_chunk_b); + + a.Free(stream2_chunk_b); // still assigned to stream 2. should coalesce with stream1_chunk_a to create 8K buffer + // test clean stream2 - a.ReleaseStreamBuffers(&stream2); - auto stream1_chunk_e = a.AllocOnStream(8192, &stream1, nullptr); - // now it should pick the stream2_chunk_a's buffer - EXPECT_EQ(stream1_chunk_e, stream2_chunk_a); + a.ReleaseStreamBuffers(&stream2); // all stream 2 buffers are now available + + // now it should pick stream2_chunk_a as it is no longer assigned to stream 2 + auto stream1_chunk_e = a.AllocOnStream(8192, &stream1); + EXPECT_EQ(stream1_chunk_e, stream2_chunk_a); // stream1_chunk_e and stream2_chunk_a are assigned to stream1 + a.Free(stream1_chunk_c); a.Free(stream1_chunk_d); - // add stream2 to stream 1 depenency + + // stream 2 wait on stream 1 auto stream1_notification_a = stream1.CreateNotification(1); - stream1_notification_a->ActivateAndUpdate(); - stream2.UpdateStreamClock(stream1_notification_a->GetStreamSyncTable()); - auto* stream2_chunk_c = a.AllocOnStream(4096, &stream2, nullptr); - // it should pick the first chunk - EXPECT_EQ(stream2_chunk_c, stream1_chunk_c); - auto* stream2_chunk_d = a.AllocOnStream(4096, &stream2, nullptr); - // it should pick the third slot - EXPECT_EQ(stream2_chunk_d, stream1_chunk_d); - // 
continue allocate on stream1 - auto* stream1_chunk_f = a.AllocOnStream(4096, &stream1, nullptr); + stream1_notification_a->ActivateAndUpdate(); // stream 1 sync id 0 -> 1 + stream2.UpdateWithAwaitedNotification(*stream1_notification_a); // stream 2 now has sync id info of stream1:1 + + // stream 2 can now take stream 1 buffers with sync id of 0 + auto* stream2_chunk_c = a.AllocOnStream(4096, &stream2); + EXPECT_EQ(stream2_chunk_c, stream1_chunk_c); // stream2 took a buffer from stream1 with sync id 0 + + // stream 2 can take the remaining free buffer from stream 1 with sync id of 0 + auto* stream2_chunk_d = a.AllocOnStream(4096, &stream2); + EXPECT_EQ(stream2_chunk_d, stream1_chunk_d); // stream2 took the other buffer from stream 1 + + // new buffer required + auto* stream1_chunk_f = a.AllocOnStream(4096, &stream1); // new buffer on stream 1. sync id = 1 a.Free(stream1_chunk_f); - auto* stream2_chunk_e = a.AllocOnStream(4096, &stream2, nullptr); + + // new buffer required + auto* stream2_chunk_e = a.AllocOnStream(4096, &stream2); // new buffer on stream 2 EXPECT_NE(stream2_chunk_e, stream1_chunk_f); + + // free 8K buffer on stream 1 a.Free(stream1_chunk_e); - // test clean stream1 - a.ReleaseStreamBuffers(&stream1); - auto* stream2_chunk_f = a.AllocOnStream(8192, &stream2, nullptr); - // now it should pick stream1_chunk_e + + // can use 8K stream1_chunk_e as it has sync id = 0 and stream 2 has sync id of 1 for stream 1 + auto* stream2_chunk_f = a.AllocOnStream(8192, &stream2); EXPECT_EQ(stream2_chunk_f, stream1_chunk_e); + // remove assignment to stream 1 for free buffers. 
stream1_chunk_f will become available to stream 2 + a.ReleaseStreamBuffers(&stream1); // stream1 buffers are new available + + auto* stream2_chunk_g = a.AllocOnStream(4096, &stream2); + EXPECT_EQ(stream2_chunk_g, stream1_chunk_f); + // cleanup a.Free(stream2_chunk_d); a.Free(stream2_chunk_e); a.Free(stream2_chunk_f); } - -TEST(StreamAwareArenaTest, TestSecureTheChunk) { - StreamAwareArena a(std::unique_ptr(new CPUAllocator()), 1 << 30, true); - OrtDevice tmp; - StreamMock stream1(tmp), stream2(tmp); - - void* p1 = a.AllocOnStream(BFCArena::DEFAULT_INITIAL_CHUNK_SIZE_BYTES, &stream1, nullptr); - a.Free(p1); - - bool waitFunctionInvoked = false; - void* p2 = a.AllocOnStream(BFCArena::DEFAULT_INITIAL_CHUNK_SIZE_BYTES, &stream2, - [&waitFunctionInvoked](Stream*, synchronize::Notification&) { waitFunctionInvoked = true; }); - - std::unordered_map syncTable; - stream2.CloneCurrentStreamSyncTable(syncTable); - EXPECT_EQ(syncTable.size(), 1u) << "stream2 has been updated with stream1's nofitication on the clock"; - EXPECT_TRUE(waitFunctionInvoked) << "wait function should be invoked"; - a.Free(p2); -} #endif TEST(BFCArenaTest, TestExtendStrategy) { diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 4c5dcd2bd7580..35f7d06fb0912 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/session/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include "gsl/gsl" #include "gtest/gtest.h" diff --git a/onnxruntime/test/mlas/bench/bench_sgemm.cpp b/onnxruntime/test/mlas/bench/bench_sgemm.cpp index a94d33cd77f63..422fc6fbcadbf 100644 --- a/onnxruntime/test/mlas/bench/bench_sgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sgemm.cpp @@ -30,9 +30,12 @@ void SGEMM(benchmark::State& state, bool pack_b, bool trans_a, bool trans_b, flo tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); if (pack_b) { - size_t pack_b_size = MlasGemmPackBSize(N, K); + CBLAS_TRANSPOSE transB_enum = trans_b ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE transA_enum = trans_a ? CblasTrans : CblasNoTrans; + + size_t pack_b_size = MlasGemmPackBSize(transA_enum, transB_enum, N, K); std::vector B_packed(pack_b_size); - MlasGemmPackB(CblasNoTrans, N, K, B.data(), N, B_packed.data()); + MlasGemmPackB(transA_enum, transB_enum, N, K, B.data(), N, B_packed.data()); MlasGemm( trans_a ? 
CblasTrans : CblasNoTrans, diff --git a/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp b/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp new file mode 100644 index 0000000000000..a048ded8349b8 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp @@ -0,0 +1,165 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include "test_util.h" +// Currently this test only applies to KleidiAI Guard against it running in any other situation +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + +class MlasDynamicQgemmTest { + private: + MatrixGuardBuffer buffer_a; + MatrixGuardBuffer buffer_bf; + MatrixGuardBuffer buffer_bq; + MatrixGuardBuffer buffer_c; + MatrixGuardBuffer buffer_c_ref; + + public: + void Test(size_t M, size_t N, size_t K, size_t BatchSize) { + // Setup buffers for holding various data + + float* A = buffer_a.GetBuffer(M * K * BatchSize); + // Buffer for holding floating point version of weight matrix + float* Bf = buffer_bf.GetBuffer(K * N * BatchSize); + // Buffer for holding quantized version of weight matrix + int8_t* Bq = buffer_bq.GetBuffer(K * N * BatchSize); + float* C = buffer_c.GetBuffer(M * N * BatchSize); + float* CRef = buffer_c_ref.GetBuffer(M * N * BatchSize); + + // Initialize A and Bf + for (size_t i = 0; i < M * K * BatchSize; ++i) + A[i] = static_cast((rand() % 255 - 128) / 16.0f); + for (size_t i = 0; i < K * N * BatchSize; ++i) + Bf[i] = static_cast((rand() % 255 - 128) / 16.0f); + + // Quantize Bf → Bq and compute per-column scale and bias per batch + std::vector> b_scale_batches(BatchSize, std::vector(N)); + std::vector> b_bias_batches(BatchSize, std::vector(N, 0.0f)); + + for (size_t b = 0; b < BatchSize; ++b) { + for (size_t n = 0; n < N; ++n) { + float min_val = Bf[b * K * N + n]; + float max_val = min_val; + for (size_t k = 1; k < K; ++k) { + float v = Bf[b * K * N + k * N + n]; + min_val = std::min(min_val, v); + max_val = 
std::max(max_val, v); + } + float scale = (max_val - min_val) / 255.0f; + if (scale < 1e-8f) scale = 1.0f; + b_scale_batches[b][n] = scale; + + for (size_t k = 0; k < K; ++k) { + float v = Bf[b * K * N + k * N + n]; + int q = static_cast(std::round(v / scale)); + Bq[b * K * N + k * N + n] = static_cast(std::clamp(q, -128, 127)); + } + } + } + + // Prepare kernel parameters + MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS shape{M, N, K}; + std::vector packed_b_storage(BatchSize * MlasDynamicQgemmPackBSize(N, K)); + std::vector params(BatchSize); + + for (size_t b = 0; b < BatchSize; ++b) { + params[b].A = A + b * M * K; + params[b].lda = K; + params[b].C = C + b * M * N; + params[b].ldc = N; + // Pack b matrix using MlasDynamicQgemmPackBSize & MlasDynamicQgemmPackB + void* packed_b = packed_b_storage.data() + b * MlasDynamicQgemmPackBSize(N, K); + MlasDynamicQgemmPackB(N, K, + Bq + b * K * N, + b_scale_batches[b].data(), + b_bias_batches[b].data(), + packed_b); + params[b].PackedB = packed_b; + } + + // call MlasDynamicQGemmBatch Function + MlasDynamicQGemmBatch(shape, params.data(), BatchSize, nullptr); + + // Compute reference result + for (size_t b = 0; b < BatchSize; ++b) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float sum = 0.0f; + for (size_t k = 0; k < K; ++k) { + float a = A[b * M * K + m * K + k]; + float bval = static_cast(Bq[b * K * N + k * N + n]) * b_scale_batches[b][n]; + sum += a * bval; + } + CRef[b * M * N + m * N + n] = sum; + } + } + } + + // Validate results + for (size_t i = 0; i < M * N * BatchSize; ++i) { + float abs_c_ref = std::abs(CRef[i]); + float dynamic_rel_tol = (K <= 4) ? 
0.05f : 0.03f; + float rel_tol = dynamic_rel_tol * std::max(abs_c_ref, 1.0f); + float abs_tol = 3.0f; + float allowed = std::max(rel_tol, abs_tol); + float diff = std::abs(C[i] - CRef[i]); + ASSERT_LE(diff, allowed); + } + } + + static const char* GetTestSuiteName() { + return "DynamicQgemm"; + } +}; + +class DynamicQgemmExecuteTest : public MlasTestFixture { + public: + DynamicQgemmExecuteTest(size_t M, size_t N, size_t K, size_t BatchSize) + : M_(M), N_(N), K_(K), BatchSize_(BatchSize) {} + + void TestBody() override { + this->mlas_tester->Test(M_, N_, K_, BatchSize_); + } + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t BatchSize) { + std::stringstream ss; + ss << "M" << M << "_N" << N << "_K" << K << "_B" << BatchSize; + + std::string test_name = ss.str(); + + testing::RegisterTest( + MlasDynamicQgemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + [=]() -> MlasTestFixture* { + return new DynamicQgemmExecuteTest(M, N, K, BatchSize); + }); + + return 1; + } + + static size_t RegisterAll(bool is_short_execute) { + const std::vector batch_size = is_short_execute ? 
std::vector{1UL, 2UL, 4UL} + : std::vector{1UL, 2UL, 4UL, 8UL, 16UL, 32UL, 64UL}; + size_t count = 0; + const size_t sizes[] = {1, 4, 8, 16, 32, 64}; + for (size_t M : sizes) + for (size_t N : sizes) + for (size_t K : sizes) + for (size_t B : batch_size) + count += RegisterSingleTest(M, N, K, B); + return count; + } + + private: + size_t M_, N_, K_, BatchSize_; +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + return DynamicQgemmExecuteTest::RegisterAll(is_short_execute); +}); +#endif diff --git a/onnxruntime/test/mlas/unittest/test_fgemm.h b/onnxruntime/test/mlas/unittest/test_fgemm.h index 2bd094152d6f0..e7741fba1c3fb 100644 --- a/onnxruntime/test/mlas/unittest/test_fgemm.h +++ b/onnxruntime/test/mlas/unittest/test_fgemm.h @@ -112,11 +112,11 @@ class FgemmPackedContext { float* C, size_t ldc, MLAS_THREADPOOL* threadpool) { - size_t PackedBSize = MlasGemmPackBSize(N, K); + size_t PackedBSize = MlasGemmPackBSize(TransA, TransB, N, K); void* PackedB = BufferBPacked.GetBuffer(PackedBSize * BatchSize, true); std::vector data(BatchSize); for (size_t i = 0; i < BatchSize; i++) { - MlasGemmPackB(TransB, N, K, B + K * N * i, ldb, (uint8_t*)PackedB + PackedBSize * i); + MlasGemmPackB(TransA, TransB, N, K, B + K * N * i, ldb, (uint8_t*)PackedB + PackedBSize * i); data[i].BIsPacked = true; data[i].A = A + M * K * i; data[i].lda = lda; diff --git a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h index 53b3edafdf84f..c832ca69dbb31 100644 --- a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h +++ b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h @@ -70,6 +70,7 @@ class FgemmShortExecuteTest : public MlasTestFixture> GetBrokenTests(const std::string& provider {"qlinearmatmul_3D_int8_float32", "result diff", {}}, {"qlinearmatmul_3D_uint8_float16", "fp16 type ont supported by CPU EP", {}}}); + // Attention3D examples are wrong with onnx==1.18.0 + 
broken_tests->insert({"attention_3d", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_attn_mask", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_causal", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_diff_heads_sizes", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_diff_heads_sizes_attn_mask", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_diff_heads_sizes_causal", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_diff_heads_sizes_scaled", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_diff_heads_sizes_softcap", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_diff_heads_with_past_and_present", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_gqa", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_gqa_attn_mask", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_gqa_causal", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_gqa_scaled", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_gqa_softcap", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_gqa_with_past_and_present", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_scaled", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_softcap", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_with_past_and_present", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_with_past_and_present_qk_matmul", "wrong 
expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_with_past_and_present_qk_matmul_bias", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_with_past_and_present_qk_matmul_softcap", "wrong expected values (fixed in onnx==1.19.0)"}); + broken_tests->insert({"attention_3d_with_past_and_present_qk_matmul_softmax", "wrong expected values (fixed in onnx==1.19.0)"}); + // Some EPs may fail to pass some specific testcases. // For example TenosrRT EP may fail on FLOAT16 related testcases if GPU doesn't support float16. // Instead of list all these testcases, we can use following keyword set to filter out testcases wchich contain diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index d4b54852cc1d0..c82c74b48ccb2 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -6222,6 +6222,78 @@ TEST_F(GraphTransformationTests, MatMulIntegerToFloatTest) { EXPECT_EQ(op_to_count["Add"], 1); } +TEST_F(GraphTransformationTests, MatMulIntegerToFloatFusion_Int8Bias_Input0) { + constexpr const ORTCHAR_T* model_uri = ORT_TSTR("testdata/matmul_integer_to_float_int8_bias_initializer_index1.onnx"); + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + // check graph structure before applying transformations + const Node* add_node = nullptr; + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "Add") { + add_node = &node; + break; + } + } + + ASSERT_NE(add_node, nullptr) << "Expected Add node not found."; + + const auto& inputs = add_node->InputDefs(); + ASSERT_EQ(inputs.size(), 2u); + + // Assert bias is in position 1 + EXPECT_EQ(inputs[1]->Name(), "bias") << "Expected bias in input 1 but found in input 0."; + + // Apply the transformer + onnxruntime::GraphTransformerManager 
graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["MatMulInteger"], 0); + EXPECT_EQ(op_to_count["Cast"], 0); + EXPECT_EQ(op_to_count["Mul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count["Add"], 0); +} + +TEST_F(GraphTransformationTests, MatMulIntegerToFloatFusion_Int8Bias_Input1) { + constexpr const ORTCHAR_T* model_uri = ORT_TSTR("testdata/matmul_integer_to_float_int8_bias_initializer_index0.onnx"); + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + // check graph structure before applying transformations + const Node* add_node = nullptr; + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "Add") { + add_node = &node; + break; + } + } + + ASSERT_NE(add_node, nullptr) << "Expected Add node not found."; + + const auto& inputs = add_node->InputDefs(); + ASSERT_EQ(inputs.size(), 2u); + + // Assert bias is in position 0 + EXPECT_EQ(inputs[0]->Name(), "bias") << "Expected bias in input 0 but found in input 1."; + + // Apply the transformer + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["MatMulInteger"], 0); + EXPECT_EQ(op_to_count["Cast"], 0); + EXPECT_EQ(op_to_count["Mul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count["Add"], 0); +} + #ifdef USE_DML TEST_F(GraphTransformationTests, MatMulIntegerToFloat16Test) { 
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/matmul_integer_to_float16_int8.onnx"; diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 7263c435a6a2e..4b37b6c9438aa 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -629,6 +629,7 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, std::unordered_map feeds; std::vector output_names; FillFeedsAndOutputNames(feeds, output_names); + number_of_nodes_ = model.MainGraph().NumberOfNodes(); // Run the model if (ctx_.run_with_specified_eps) { @@ -794,6 +795,8 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, } } +int BaseTester::GetNumberOfNodesAfterRun() const { return number_of_nodes_; } + void BaseTester::ExecuteModelForEps( std::vector>&& execution_providers, onnxruntime::Model& model, diff --git a/onnxruntime/test/providers/base_tester.h b/onnxruntime/test/providers/base_tester.h index d39cc3c750dec..182ee4a9550fe 100644 --- a/onnxruntime/test/providers/base_tester.h +++ b/onnxruntime/test/providers/base_tester.h @@ -39,6 +39,7 @@ class BaseTester { ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange().Map().at(ONNX_NAMESPACE::ONNX_DOMAIN).second; opset_version_ = latest_onnx_version; } + number_of_nodes_ = 0; } // Derived class to implement to provide the model to test. @@ -621,6 +622,8 @@ class BaseTester { test_allow_released_onnx_opset_only_ = false; } + int GetNumberOfNodesAfterRun() const; + protected: //// if the derived class is caching the model this helper can be called in CreateModelToTest to reset the nodes // static void ClearEpsForAllNodes(Graph& graph); @@ -767,6 +770,7 @@ class BaseTester { std::vector input_data_; std::vector output_data_; std::vector fetches_; + int number_of_nodes_; bool testing_function_called_{}; // has the function that performs the actual testing been called yet? 
diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc new file mode 100644 index 0000000000000..b4f6d328cacf7 --- /dev/null +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -0,0 +1,1206 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "gtest/gtest.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +namespace { +enum class TensorType { + kFloat, + kFloat16, + kBFloat16 +}; +} // anonymous namespace + +static void AddInputs(OpTester& test, + const std::vector& q, + const std::vector& k, + const std::vector& v, + const std::vector& attn_mask, + const std::initializer_list& attn_mask_bool, + const std::vector& past_key, + const std::vector& past_value, + int is_causal, + const std::vector& q_shape, + const std::vector& k_shape, + const std::vector& v_shape, + const std::vector& attn_mask_shape, + const std::vector& past_key_shape, + const std::vector& past_value_shape, + // outputs + const std::vector& y_shape, + const std::vector& present_key_shape, + const std::vector& present_value_shape, + const std::vector& qk_matmul_output_shape, + int kv_num_heads, + int q_num_heads, + int qk_matmul_output_mode, + float scale, + float softcap, + int softmax_precision, + TensorType tensor_type, + const std::vector& y, + const std::vector& present_key, + const std::vector& present_value, + const std::vector& qk_matmul_output) { + if (is_causal >= 0) + test.AddAttribute("is_causal", is_causal); + if (q_shape.size() == 3) { + test.AddAttribute("kv_num_heads", kv_num_heads); + test.AddAttribute("q_num_heads", q_num_heads); + } + if (qk_matmul_output_mode >= 0) + test.AddAttribute("qk_matmul_output_mode", qk_matmul_output_mode); + if 
(!std::isnan(scale)) + test.AddAttribute("scale", scale); + if (!std::isnan(softcap)) + test.AddAttribute("softcap", softcap); + if (softmax_precision >= 0) + test.AddAttribute("softmax_precision", softmax_precision); + + if (tensor_type == TensorType::kFloat) { + // inputs + test.AddInput("Q", q_shape, q); + test.AddInput("K", k_shape, k); + test.AddInput("V", v_shape, v); + if (!attn_mask.empty()) + test.AddInput("attn_mask", attn_mask_shape, attn_mask); + else if (attn_mask_bool.size() > 0) + test.AddInput("attn_mask", attn_mask_shape, attn_mask_bool); + else + test.AddOptionalInputEdge(); + + if (!past_key.empty()) + test.AddInput("past_key", past_key_shape, past_key); + else + test.AddOptionalInputEdge(); + + if (!past_value.empty()) + test.AddInput("past_value", past_value_shape, past_value); + else + test.AddOptionalInputEdge(); + // outputs + test.AddOutput("Y", y_shape, y, false, 0, 3e-5f); + if (!present_key.empty()) + test.AddOutput("present_key", present_key_shape, present_key); + if (!present_value.empty()) + test.AddOutput("present_value", present_value_shape, present_value); + if (!qk_matmul_output.empty()) + test.AddOutput("qk_matmul_output", qk_matmul_output_shape, qk_matmul_output); + } else if (tensor_type == TensorType::kFloat16) { + // inputs + test.AddInput("Q", q_shape, ToFloat16(q)); + test.AddInput("K", k_shape, ToFloat16(k)); + test.AddInput("V", v_shape, ToFloat16(v)); + if (!attn_mask.empty()) + test.AddInput("attn_mask", attn_mask_shape, ToFloat16(attn_mask)); + else if (attn_mask_bool.size() > 0) + test.AddInput("attn_mask", attn_mask_shape, attn_mask_bool); + else + test.AddOptionalInputEdge(); + + if (!past_key.empty()) + test.AddInput("past_key", past_key_shape, ToFloat16(past_key)); + else + test.AddOptionalInputEdge(); + + if (!past_value.empty()) + test.AddInput("past_value", past_value_shape, ToFloat16(past_value)); + else + test.AddOptionalInputEdge(); + // outputs + test.AddOutput("Y", y_shape, ToFloat16(y), false, 0, 3e-3f); 
+ if (!present_key.empty()) + test.AddOutput("present_key", present_key_shape, ToFloat16(present_key)); + if (!present_value.empty()) + test.AddOutput("present_value", present_value_shape, ToFloat16(present_value)); + if (!qk_matmul_output.empty()) + test.AddOutput("qk_matmul_output", qk_matmul_output_shape, ToFloat16(qk_matmul_output)); + } else { + // inputs + test.AddInput("Q", q_shape, FloatsToBFloat16s(q)); + test.AddInput("K", k_shape, FloatsToBFloat16s(k)); + test.AddInput("V", v_shape, FloatsToBFloat16s(v)); + if (!attn_mask.empty()) + test.AddInput("attn_mask", attn_mask_shape, FloatsToBFloat16s(attn_mask)); + else if (attn_mask_bool.size() > 0) + test.AddInput("attn_mask", attn_mask_shape, attn_mask_bool); + else + test.AddOptionalInputEdge(); + + if (!past_key.empty()) + test.AddInput("past_key", past_key_shape, FloatsToBFloat16s(past_key)); + else + test.AddOptionalInputEdge(); + + if (!past_value.empty()) + test.AddInput("past_value", past_value_shape, FloatsToBFloat16s(past_value)); + else + test.AddOptionalInputEdge(); + // outputs + test.AddOutput("Y", y_shape, FloatsToBFloat16s(y), false, 0, 3e-3f); + if (!present_key.empty()) + test.AddOutput("present_key", present_key_shape, FloatsToBFloat16s(present_key)); + if (!present_value.empty()) + test.AddOutput("present_value", present_value_shape, FloatsToBFloat16s(present_value)); + if (!qk_matmul_output.empty()) + test.AddOutput("qk_matmul_output", qk_matmul_output_shape, FloatsToBFloat16s(qk_matmul_output)); + } +} + +static void SetProviders(std::vector>& execution_providers, bool disable_cpu, bool disable_cuda, bool disable_dml, TensorType tensor_type) { + int min_cuda_architecture = (tensor_type == TensorType::kBFloat16) + ? 800 + : (tensor_type == TensorType::kFloat16) ? 
530 + : 0; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; + bool enable_webgpu = nullptr != DefaultWebGpuExecutionProvider().get(); + + if (enable_cuda && !disable_cuda) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + if (enable_dml && !disable_dml) { + execution_providers.push_back(DefaultDmlExecutionProvider()); + } + if ((tensor_type == TensorType::kFloat || tensor_type == TensorType::kFloat16) && !disable_cpu) { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } + if (enable_webgpu) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } +} + +static void RunTest3D( + int batch_size, + int q_num_heads, + int q_sequence_length, + int head_size, + int kv_sequence_length, + int kv_num_heads, + int v_head_size, + int past_sequence_length, + const std::vector& q, + const std::vector& k, + const std::vector& v, + const std::vector& attn_mask, + const std::initializer_list& attn_mask_bool, + const std::vector& past_key, + const std::vector& past_value, + int is_causal, // 0 + // int kv_num_heads, // not needed for 3D + // int q_num_heads, // not needed for 3D + int qk_matmul_output_mode, // 0 + float scale, // 1.0 + float softcap, // 0.0, + int softmax_precision, + TensorType tensor_type, + const std::vector& y, + const std::vector& present_key, + const std::vector& present_value, + const std::vector& qk_matmul_output, + bool disable_cpu, + bool disable_cuda, + bool disable_dml) { + int total_sequence_length = past_sequence_length + kv_sequence_length; + // inputs + int q_hidden_size = q_num_heads * head_size; + int k_hidden_size = kv_num_heads * head_size; + int v_hidden_size = kv_num_heads * v_head_size; + int hidden_size = q_num_heads * v_head_size; + std::vector q_shape = {batch_size, q_sequence_length, q_hidden_size}; + std::vector k_shape = {batch_size, kv_sequence_length, k_hidden_size}; + std::vector 
v_shape = {batch_size, kv_sequence_length, v_hidden_size}; + + std::vector attn_mask_shape = {q_sequence_length, total_sequence_length}; + if (q_sequence_length * total_sequence_length != attn_mask.size() && attn_mask.size() > 0) { + if (batch_size * q_sequence_length * total_sequence_length == attn_mask.size()) { + attn_mask_shape = {batch_size, 1, q_sequence_length, total_sequence_length}; + } else if (batch_size * q_num_heads * q_sequence_length * total_sequence_length == attn_mask.size()) { + attn_mask_shape = {batch_size, q_num_heads, q_sequence_length, total_sequence_length}; + } else { + ORT_THROW("Invalid attention mask size: ", attn_mask.size(), + " expected ", q_sequence_length, "*", total_sequence_length, " or ", + batch_size, "*", q_sequence_length, "*", total_sequence_length); + } + } + + std::vector past_key_shape = {batch_size, kv_num_heads, past_sequence_length, head_size}; + std::vector past_value_shape = {batch_size, kv_num_heads, past_sequence_length, head_size}; + // outputs + std::vector y_shape = {batch_size, q_sequence_length, hidden_size}; + std::vector present_key_shape = {batch_size, kv_num_heads, total_sequence_length, head_size}; + std::vector present_value_shape = {batch_size, kv_num_heads, total_sequence_length, v_head_size}; + std::vector qk_matmul_output_shape = {batch_size, q_num_heads, q_sequence_length, total_sequence_length}; + + std::vector> execution_providers; + SetProviders(execution_providers, disable_cpu, disable_cuda, disable_dml, tensor_type); + if (execution_providers.size() == 0) { + // Return early if CI pipeline does not support EP (e.g. 
CUDA EP for CPU CI pipeline) + return; + } + + for (auto& ep : execution_providers) { + OpTester test("Attention", 23, onnxruntime::kOnnxDomain); + AddInputs(test, q, k, v, attn_mask, attn_mask_bool, past_key, past_value, is_causal, + q_shape, k_shape, v_shape, attn_mask_shape, past_key_shape, past_value_shape, y_shape, present_key_shape, present_value_shape, qk_matmul_output_shape, + kv_num_heads, q_num_heads, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type, y, present_key, present_value, qk_matmul_output); + + std::vector> test_execution_providers; + test_execution_providers.push_back(std::move(ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &test_execution_providers); + ASSERT_EQ(test.GetNumberOfNodesAfterRun(), 1); // This checks the operator was not inlined. + } +} + +static void RunTest4D( + int batch_size, + int q_num_heads, + int q_sequence_length, + int head_size, + int kv_sequence_length, + int kv_num_heads, + int v_head_size, + int past_sequence_length, + const std::vector& q, + const std::vector& k, + const std::vector& v, + const std::vector& attn_mask, + const std::initializer_list& attn_mask_bool, + const std::vector& past_key, + const std::vector& past_value, + int is_causal, // 0 + // int kv_num_heads, // not needed for 3D + // int q_num_heads, // not needed for 3D + int qk_matmul_output_mode, // 0 + float scale, // 1.0 + float softcap, // 0.0, + int softmax_precision, + TensorType tensor_type, + const std::vector& y, + const std::vector& present_key, + const std::vector& present_value, + const std::vector& qk_matmul_output, + bool disable_cpu, + bool disable_cuda, + bool disable_dml) { + int total_sequence_length = past_sequence_length + kv_sequence_length; + // inputs + std::vector q_shape = {batch_size, q_num_heads, q_sequence_length, head_size}; + std::vector k_shape = {batch_size, kv_num_heads, kv_sequence_length, head_size}; + std::vector v_shape = {batch_size, kv_num_heads, kv_sequence_length, 
v_head_size}; + + std::vector attn_mask_shape = {q_sequence_length, total_sequence_length}; + if (q_sequence_length * total_sequence_length != attn_mask.size() && attn_mask.size() > 0) { + if (batch_size * q_sequence_length * total_sequence_length == attn_mask.size()) { + attn_mask_shape = {batch_size, 1, q_sequence_length, total_sequence_length}; + } else if (batch_size * q_num_heads * q_sequence_length * total_sequence_length == attn_mask.size()) { + attn_mask_shape = {batch_size, q_num_heads, q_sequence_length, total_sequence_length}; + } else { + ORT_THROW("Invalid attention mask size: ", attn_mask.size(), + " expected ", q_sequence_length, "*", total_sequence_length, " or ", + batch_size, "*", q_sequence_length, "*", total_sequence_length); + } + } + + std::vector past_key_shape = {batch_size, kv_num_heads, past_sequence_length, head_size}; + std::vector past_value_shape = {batch_size, kv_num_heads, past_sequence_length, v_head_size}; + // outputs + std::vector y_shape = {batch_size, q_num_heads, q_sequence_length, v_head_size}; + std::vector present_key_shape = {batch_size, kv_num_heads, total_sequence_length, head_size}; + std::vector present_value_shape = {batch_size, kv_num_heads, total_sequence_length, v_head_size}; + std::vector qk_matmul_output_shape = {batch_size, q_num_heads, q_sequence_length, total_sequence_length}; + + std::vector> execution_providers; + SetProviders(execution_providers, disable_cpu, disable_cuda, disable_dml, tensor_type); + if (execution_providers.size() == 0) { + // Return early if CI pipeline does not support EP (e.g. 
CUDA EP for CPU CI pipeline) + return; + } + + for (auto& ep : execution_providers) { + OpTester test("Attention", 23, onnxruntime::kOnnxDomain); + AddInputs(test, q, k, v, attn_mask, attn_mask_bool, past_key, past_value, is_causal, + q_shape, k_shape, v_shape, attn_mask_shape, past_key_shape, past_value_shape, y_shape, present_key_shape, present_value_shape, qk_matmul_output_shape, + kv_num_heads, q_num_heads, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type, y, present_key, present_value, qk_matmul_output); + + std::vector> test_execution_providers; + test_execution_providers.push_back(std::move(ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &test_execution_providers); + ASSERT_EQ(test.GetNumberOfNodesAfterRun(), 1); // This checks the operator was not inlined. + } +} + +TEST(AttentionTest, Attention3DDefault) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 5; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 
0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 
0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 
0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 
0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + std::vector y = {0.231425f, 0.572015f, 0.512671f, 0.279597f, 0.323671f, 0.474956f, 0.344308f, 0.454604f, 0.677763f, 0.427182f, 0.518734f, 0.586593f, 0.366221f, 0.617469f, 0.568592f, 0.711734f, 0.669865f, 0.477629f, 0.443902f, 
0.657931f, 0.294461f, 0.444926f, 0.646996f, 0.624016f, 0.230982f, 0.577089f, 0.515905f, 0.281810f, 0.318254f, 0.478419f, 0.341943f, 0.456036f, 0.671153f, 0.419443f, 0.553783f, 0.617598f, 0.405113f, 0.612246f, 0.546371f, 0.691976f, 0.673135f, 0.474435f, 0.440636f, 0.656117f, 0.290562f, 0.437461f, 0.641583f, 0.628633f, 0.213246f, 0.573821f, 0.481404f, 0.314601f, 0.331198f, 0.479336f, 0.334377f, 0.416422f, 0.683961f, 0.438780f, 0.515832f, 0.594131f, 0.421298f, 0.581216f, 0.544020f, 0.665089f, 0.680353f, 0.496091f, 0.458597f, 0.644262f, 0.290254f, 0.439397f, 0.648748f, 0.622587f, 0.215077f, 0.561958f, 0.470216f, 0.315574f, 0.330295f, 0.476255f, 0.346486f, 0.433062f, 0.675563f, 0.430004f, 0.531206f, 0.603125f, 0.392384f, 0.606396f, 0.553218f, 0.688558f, 0.672218f, 0.481904f, 0.442930f, 0.664552f, 0.291008f, 0.447983f, 0.646510f, 0.629446f, 0.684469f, 0.333075f, 0.591230f, 0.723174f, 0.527550f, 0.429390f, 0.379490f, 0.407681f, 0.549282f, 0.325072f, 0.396408f, 0.659680f, 0.252716f, 0.438976f, 0.383743f, 0.537200f, 0.679028f, 0.472077f, 0.522267f, 0.258646f, 0.543009f, 0.648117f, 0.524809f, 0.455668f, 0.679968f, 0.320914f, 0.603929f, 0.720663f, 0.535420f, 0.427747f, 0.365637f, 0.402336f, 0.555204f, 0.329413f, 0.403408f, 0.674143f, 0.257068f, 0.430207f, 0.384353f, 0.534996f, 0.682781f, 0.472336f, 0.532518f, 0.255054f, 0.533888f, 0.631695f, 0.517009f, 0.460408f, 0.676468f, 0.310125f, 0.594133f, 0.720721f, 0.531343f, 0.428411f, 0.383201f, 0.400798f, 0.520066f, 0.313406f, 0.378438f, 0.660871f, 0.236947f, 0.471855f, 0.380046f, 0.533181f, 0.692040f, 0.460203f, 0.533379f, 0.249623f, 0.540433f, 0.638632f, 0.525843f, 0.453184f, 0.678596f, 0.343161f, 0.587705f, 0.727194f, 0.516850f, 0.421908f, 0.366269f, 0.400319f, 0.550307f, 0.323773f, 0.406273f, 0.671064f, 0.258597f, 0.441523f, 0.386403f, 0.537742f, 0.671703f, 0.464797f, 0.523623f, 0.248851f, 0.522889f, 0.644907f, 0.502470f, 0.446048f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + 
ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + + RunTest3D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention3DDefaultFloat16) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 5; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 
0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 
0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 
0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 
0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + std::vector y = {0.231425f, 0.572015f, 0.512671f, 0.279597f, 0.323671f, 0.474956f, 0.344308f, 0.454604f, 0.677763f, 0.427182f, 0.518734f, 0.586593f, 0.366221f, 0.617469f, 0.568592f, 0.711734f, 0.669865f, 0.477629f, 0.443902f, 0.657931f, 0.294461f, 0.444926f, 0.646996f, 0.624016f, 0.230982f, 0.577089f, 0.515905f, 0.281810f, 0.318254f, 0.478419f, 0.341943f, 
0.456036f, 0.671153f, 0.419443f, 0.553783f, 0.617598f, 0.405113f, 0.612246f, 0.546371f, 0.691976f, 0.673135f, 0.474435f, 0.440636f, 0.656117f, 0.290562f, 0.437461f, 0.641583f, 0.628633f, 0.213246f, 0.573821f, 0.481404f, 0.314601f, 0.331198f, 0.479336f, 0.334377f, 0.416422f, 0.683961f, 0.438780f, 0.515832f, 0.594131f, 0.421298f, 0.581216f, 0.544020f, 0.665089f, 0.680353f, 0.496091f, 0.458597f, 0.644262f, 0.290254f, 0.439397f, 0.648748f, 0.622587f, 0.215077f, 0.561958f, 0.470216f, 0.315574f, 0.330295f, 0.476255f, 0.346486f, 0.433062f, 0.675563f, 0.430004f, 0.531206f, 0.603125f, 0.392384f, 0.606396f, 0.553218f, 0.688558f, 0.672218f, 0.481904f, 0.442930f, 0.664552f, 0.291008f, 0.447983f, 0.646510f, 0.629446f, 0.684469f, 0.333075f, 0.591230f, 0.723174f, 0.527550f, 0.429390f, 0.379490f, 0.407681f, 0.549282f, 0.325072f, 0.396408f, 0.659680f, 0.252716f, 0.438976f, 0.383743f, 0.537200f, 0.679028f, 0.472077f, 0.522267f, 0.258646f, 0.543009f, 0.648117f, 0.524809f, 0.455668f, 0.679968f, 0.320914f, 0.603929f, 0.720663f, 0.535420f, 0.427747f, 0.365637f, 0.402336f, 0.555204f, 0.329413f, 0.403408f, 0.674143f, 0.257068f, 0.430207f, 0.384353f, 0.534996f, 0.682781f, 0.472336f, 0.532518f, 0.255054f, 0.533888f, 0.631695f, 0.517009f, 0.460408f, 0.676468f, 0.310125f, 0.594133f, 0.720721f, 0.531343f, 0.428411f, 0.383201f, 0.400798f, 0.520066f, 0.313406f, 0.378438f, 0.660871f, 0.236947f, 0.471855f, 0.380046f, 0.533181f, 0.692040f, 0.460203f, 0.533379f, 0.249623f, 0.540433f, 0.638632f, 0.525843f, 0.453184f, 0.678596f, 0.343161f, 0.587705f, 0.727194f, 0.516850f, 0.421908f, 0.366269f, 0.400319f, 0.550307f, 0.323773f, 0.406273f, 0.671064f, 0.258597f, 0.441523f, 0.386403f, 0.537742f, 0.671703f, 0.464797f, 0.523623f, 0.248851f, 0.522889f, 0.644907f, 0.502470f, 0.446048f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * 
kv_sequence_length * v_head_size); + + RunTest3D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DDefaultBasic) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 5; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + std::vector k = {1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + std::vector v = {1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + std::vector y = {0.221683f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.166667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.166667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.166667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 
0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, 
v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DDefault) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 5; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 
0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 
0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 
0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 
0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + + std::vector y = {0.501465f, 0.543511f, 0.398088f, 0.474061f, 0.290507f, 0.423018f, 0.447999f, 0.672390f, 0.500878f, 0.545140f, 0.402253f, 0.478354f, 0.278711f, 0.420929f, 0.451124f, 0.682613f, 0.496502f, 0.557356f, 0.419293f, 0.467867f, 0.280946f, 0.422295f, 0.445183f, 0.675748f, 0.498804f, 0.545264f, 0.399543f, 0.471287f, 0.287601f, 0.424845f, 0.443877f, 0.670841f, 0.580098f, 
0.450536f, 0.702941f, 0.538382f, 0.329768f, 0.543394f, 0.613723f, 0.562010f, 0.584549f, 0.447129f, 0.673676f, 0.537643f, 0.342950f, 0.515742f, 0.613437f, 0.502951f, 0.585248f, 0.443070f, 0.676620f, 0.549025f, 0.343112f, 0.522440f, 0.611621f, 0.507324f, 0.580745f, 0.461632f, 0.668496f, 0.507376f, 0.336816f, 0.500750f, 0.618162f, 0.500909f, 0.464240f, 0.493342f, 0.380525f, 0.530712f, 0.397056f, 0.582067f, 0.443341f, 0.559227f, 0.467916f, 0.503694f, 0.373170f, 0.549178f, 0.387171f, 0.587037f, 0.448581f, 0.561591f, 0.478681f, 0.496704f, 0.369457f, 0.545459f, 0.392339f, 0.587842f, 0.452645f, 0.576330f, 0.483897f, 0.491793f, 0.360676f, 0.530990f, 0.380686f, 0.603393f, 0.467172f, 0.583590f, 0.642787f, 0.470883f, 0.686034f, 0.642719f, 0.386365f, 0.366454f, 0.467120f, 0.405736f, 0.644347f, 0.466390f, 0.684379f, 0.640710f, 0.385963f, 0.366271f, 0.472645f, 0.403025f, 0.631421f, 0.453237f, 0.677676f, 0.643979f, 0.390879f, 0.377663f, 0.467158f, 0.401772f, 0.637457f, 0.459313f, 0.677889f, 0.659685f, 0.383362f, 0.379251f, 0.453763f, 0.401437f, 0.555998f, 0.186013f, 0.455395f, 0.406430f, 0.395553f, 0.526708f, 0.320193f, 0.484448f, 0.577368f, 0.190770f, 0.462801f, 0.384114f, 0.403607f, 0.534057f, 0.326255f, 0.496504f, 0.563586f, 0.180264f, 0.464196f, 0.384055f, 0.385514f, 0.537212f, 0.338047f, 0.485235f, 0.555800f, 0.177971f, 0.457827f, 0.377928f, 0.372441f, 0.541035f, 0.343750f, 0.483692f, 0.705313f, 0.467049f, 0.389698f, 0.530555f, 0.548003f, 0.637789f, 0.501241f, 0.493046f, 0.692096f, 0.474284f, 0.375588f, 0.530258f, 0.507811f, 0.618987f, 0.468782f, 0.502795f, 0.703758f, 0.479856f, 0.374269f, 0.518477f, 0.518286f, 0.631821f, 0.502535f, 0.509264f, 0.689539f, 0.474638f, 0.374363f, 0.519131f, 0.519441f, 0.644891f, 0.480984f, 0.490645f}; + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, 
std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DDefaultFloat16) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 5; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 
0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 
0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 
0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 
0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + + std::vector y = {0.501465f, 0.543511f, 0.398088f, 0.474061f, 0.290507f, 0.423018f, 0.447999f, 0.672390f, 0.500878f, 0.545140f, 0.402253f, 0.478354f, 0.278711f, 0.420929f, 0.451124f, 0.682613f, 0.496502f, 0.557356f, 0.419293f, 0.467867f, 0.280946f, 0.422295f, 0.445183f, 0.675748f, 0.498804f, 0.545264f, 0.399543f, 0.471287f, 0.287601f, 0.424845f, 0.443877f, 0.670841f, 0.580098f, 0.450536f, 0.702941f, 0.538382f, 0.329768f, 0.543394f, 0.613723f, 0.562010f, 0.584549f, 0.447129f, 0.673676f, 0.537643f, 
0.342950f, 0.515742f, 0.613437f, 0.502951f, 0.585248f, 0.443070f, 0.676620f, 0.549025f, 0.343112f, 0.522440f, 0.611621f, 0.507324f, 0.580745f, 0.461632f, 0.668496f, 0.507376f, 0.336816f, 0.500750f, 0.618162f, 0.500909f, 0.464240f, 0.493342f, 0.380525f, 0.530712f, 0.397056f, 0.582067f, 0.443341f, 0.559227f, 0.467916f, 0.503694f, 0.373170f, 0.549178f, 0.387171f, 0.587037f, 0.448581f, 0.561591f, 0.478681f, 0.496704f, 0.369457f, 0.545459f, 0.392339f, 0.587842f, 0.452645f, 0.576330f, 0.483897f, 0.491793f, 0.360676f, 0.530990f, 0.380686f, 0.603393f, 0.467172f, 0.583590f, 0.642787f, 0.470883f, 0.686034f, 0.642719f, 0.386365f, 0.366454f, 0.467120f, 0.405736f, 0.644347f, 0.466390f, 0.684379f, 0.640710f, 0.385963f, 0.366271f, 0.472645f, 0.403025f, 0.631421f, 0.453237f, 0.677676f, 0.643979f, 0.390879f, 0.377663f, 0.467158f, 0.401772f, 0.637457f, 0.459313f, 0.677889f, 0.659685f, 0.383362f, 0.379251f, 0.453763f, 0.401437f, 0.555998f, 0.186013f, 0.455395f, 0.406430f, 0.395553f, 0.526708f, 0.320193f, 0.484448f, 0.577368f, 0.190770f, 0.462801f, 0.384114f, 0.403607f, 0.534057f, 0.326255f, 0.496504f, 0.563586f, 0.180264f, 0.464196f, 0.384055f, 0.385514f, 0.537212f, 0.338047f, 0.485235f, 0.555800f, 0.177971f, 0.457827f, 0.377928f, 0.372441f, 0.541035f, 0.343750f, 0.483692f, 0.705313f, 0.467049f, 0.389698f, 0.530555f, 0.548003f, 0.637789f, 0.501241f, 0.493046f, 0.692096f, 0.474284f, 0.375588f, 0.530258f, 0.507811f, 0.618987f, 0.468782f, 0.502795f, 0.703758f, 0.479856f, 0.374269f, 0.518477f, 0.518286f, 0.631821f, 0.502535f, 0.509264f, 0.689539f, 0.474638f, 0.374363f, 0.519131f, 0.519441f, 0.644891f, 0.480984f, 0.490645f}; + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, 
softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DSoftCap) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 10; // V.shape[3] + int past_sequence_length = 5; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 
0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 
0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 
0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 
0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + 
ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + + // with softcap=2 + std::vector ys = {0.227656f, 0.365938f, 0.487233f, 0.563168f, 0.314693f, 0.531065f, 0.502050f, 0.532911f, 0.479305f, 0.619133f, 0.230719f, 0.361396f, 0.476682f, 0.566474f, 0.307008f, 0.529635f, 0.503316f, 0.540530f, 0.476847f, 0.620507f, 0.233811f, 0.361041f, 0.472995f, 0.571894f, 0.309176f, 0.536943f, 0.498525f, 0.540409f, 0.475846f, 0.615972f, 0.223131f, 0.365223f, 0.488599f, 0.559249f, 0.315942f, 0.525688f, 0.494637f, 0.539772f, 0.488000f, 0.625606f, 0.539676f, 0.409601f, 0.515692f, 0.453467f, 0.697314f, 0.396105f, 0.298034f, 0.552743f, 0.440534f, 0.843839f, 0.525229f, 0.418362f, 0.546100f, 0.481009f, 0.687614f, 0.414847f, 0.327302f, 0.572564f, 0.461664f, 0.831423f, 0.521430f, 0.418181f, 0.545782f, 0.477744f, 0.687580f, 0.409896f, 0.324292f, 0.565326f, 0.459461f, 0.832106f, 0.542037f, 0.412166f, 0.539834f, 0.486373f, 0.691028f, 0.421836f, 0.330124f, 0.590678f, 0.466584f, 0.831750f, 0.382651f, 0.501226f, 0.660685f, 0.342294f, 0.602060f, 0.492331f, 0.474420f, 0.409177f, 0.518175f, 0.581219f, 0.387046f, 0.503621f, 0.666169f, 0.332572f, 0.596846f, 0.479979f, 0.479994f, 0.413598f, 0.515513f, 0.577655f, 0.398240f, 0.510706f, 0.663548f, 0.331466f, 0.594592f, 0.465828f, 0.485982f, 0.414944f, 0.516808f, 0.588646f, 0.401608f, 0.503138f, 0.664086f, 0.314710f, 0.579984f, 0.448406f, 0.482952f, 0.410394f, 0.515656f, 0.614177f, 0.430626f, 0.390476f, 0.382732f, 0.345745f, 0.361913f, 0.378760f, 0.487068f, 0.359749f, 0.440638f, 0.611671f, 0.434161f, 0.384956f, 0.382824f, 0.347990f, 0.361064f, 0.378348f, 0.483768f, 0.357084f, 0.441993f, 0.612507f, 0.430795f, 0.387191f, 0.392464f, 0.339543f, 0.365489f, 0.373725f, 0.480792f, 0.354801f, 0.428210f, 0.621415f, 0.430196f, 0.387751f, 0.374630f, 0.333935f, 0.363445f, 0.372619f, 0.482465f, 0.350530f, 0.427172f, 0.618986f, 0.529767f, 0.595815f, 
0.301624f, 0.397276f, 0.605455f, 0.607591f, 0.617002f, 0.544150f, 0.662428f, 0.510301f, 0.533071f, 0.602211f, 0.278156f, 0.392687f, 0.617217f, 0.593104f, 0.629293f, 0.563362f, 0.682795f, 0.519542f, 0.520110f, 0.607374f, 0.289463f, 0.386297f, 0.609416f, 0.600651f, 0.634780f, 0.553284f, 0.672042f, 0.506020f, 0.514322f, 0.606722f, 0.293574f, 0.377031f, 0.612149f, 0.599634f, 0.640889f, 0.546806f, 0.672437f, 0.505487f, 0.380489f, 0.334473f, 0.554343f, 0.499727f, 0.526942f, 0.558871f, 0.530154f, 0.309413f, 0.555978f, 0.488827f, 0.371393f, 0.341934f, 0.552609f, 0.481362f, 0.537837f, 0.574948f, 0.524870f, 0.312968f, 0.558314f, 0.484292f, 0.382443f, 0.330414f, 0.567252f, 0.481373f, 0.557600f, 0.575927f, 0.536800f, 0.295057f, 0.535626f, 0.488409f, 0.369831f, 0.343157f, 0.554056f, 0.492472f, 0.539300f, 0.565926f, 0.540317f, 0.307066f, 0.560539f, 0.493642f}; + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), 2.0f, -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + ys, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DSoftCapFloat16) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 10; // V.shape[3] + int past_sequence_length = 5; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 
0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 
0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 
0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 
0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 
0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + + // with softcap=2 + std::vector ys = {0.227656f, 0.365938f, 0.487233f, 0.563168f, 0.314693f, 0.531065f, 0.502050f, 0.532911f, 0.479305f, 0.619133f, 0.230719f, 0.361396f, 0.476682f, 0.566474f, 0.307008f, 0.529635f, 0.503316f, 0.540530f, 0.476847f, 0.620507f, 0.233811f, 0.361041f, 0.472995f, 0.571894f, 0.309176f, 0.536943f, 0.498525f, 0.540409f, 0.475846f, 0.615972f, 0.223131f, 0.365223f, 0.488599f, 0.559249f, 0.315942f, 0.525688f, 0.494637f, 0.539772f, 0.488000f, 0.625606f, 0.539676f, 0.409601f, 0.515692f, 0.453467f, 0.697314f, 0.396105f, 0.298034f, 0.552743f, 0.440534f, 0.843839f, 0.525229f, 0.418362f, 0.546100f, 0.481009f, 
0.687614f, 0.414847f, 0.327302f, 0.572564f, 0.461664f, 0.831423f, 0.521430f, 0.418181f, 0.545782f, 0.477744f, 0.687580f, 0.409896f, 0.324292f, 0.565326f, 0.459461f, 0.832106f, 0.542037f, 0.412166f, 0.539834f, 0.486373f, 0.691028f, 0.421836f, 0.330124f, 0.590678f, 0.466584f, 0.831750f, 0.382651f, 0.501226f, 0.660685f, 0.342294f, 0.602060f, 0.492331f, 0.474420f, 0.409177f, 0.518175f, 0.581219f, 0.387046f, 0.503621f, 0.666169f, 0.332572f, 0.596846f, 0.479979f, 0.479994f, 0.413598f, 0.515513f, 0.577655f, 0.398240f, 0.510706f, 0.663548f, 0.331466f, 0.594592f, 0.465828f, 0.485982f, 0.414944f, 0.516808f, 0.588646f, 0.401608f, 0.503138f, 0.664086f, 0.314710f, 0.579984f, 0.448406f, 0.482952f, 0.410394f, 0.515656f, 0.614177f, 0.430626f, 0.390476f, 0.382732f, 0.345745f, 0.361913f, 0.378760f, 0.487068f, 0.359749f, 0.440638f, 0.611671f, 0.434161f, 0.384956f, 0.382824f, 0.347990f, 0.361064f, 0.378348f, 0.483768f, 0.357084f, 0.441993f, 0.612507f, 0.430795f, 0.387191f, 0.392464f, 0.339543f, 0.365489f, 0.373725f, 0.480792f, 0.354801f, 0.428210f, 0.621415f, 0.430196f, 0.387751f, 0.374630f, 0.333935f, 0.363445f, 0.372619f, 0.482465f, 0.350530f, 0.427172f, 0.618986f, 0.529767f, 0.595815f, 0.301624f, 0.397276f, 0.605455f, 0.607591f, 0.617002f, 0.544150f, 0.662428f, 0.510301f, 0.533071f, 0.602211f, 0.278156f, 0.392687f, 0.617217f, 0.593104f, 0.629293f, 0.563362f, 0.682795f, 0.519542f, 0.520110f, 0.607374f, 0.289463f, 0.386297f, 0.609416f, 0.600651f, 0.634780f, 0.553284f, 0.672042f, 0.506020f, 0.514322f, 0.606722f, 0.293574f, 0.377031f, 0.612149f, 0.599634f, 0.640889f, 0.546806f, 0.672437f, 0.505487f, 0.380489f, 0.334473f, 0.554343f, 0.499727f, 0.526942f, 0.558871f, 0.530154f, 0.309413f, 0.555978f, 0.488827f, 0.371393f, 0.341934f, 0.552609f, 0.481362f, 0.537837f, 0.574948f, 0.524870f, 0.312968f, 0.558314f, 0.484292f, 0.382443f, 0.330414f, 0.567252f, 0.481373f, 0.557600f, 0.575927f, 0.536800f, 0.295057f, 0.535626f, 0.488409f, 0.369831f, 0.343157f, 0.554056f, 0.492472f, 0.539300f, 
0.565926f, 0.540317f, 0.307066f, 0.560539f, 0.493642f}; + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), 2.0f, -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + ys, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DAttnMask) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 
0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 
0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 
0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 
0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + std::vector m = {0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f}; + std::vector y = {0.478040f, 0.503674f, 0.349552f, 0.475550f, 0.319086f, 0.440731f, 0.452109f, 0.673914f, 0.477799f, 0.522510f, 0.381228f, 0.496104f, 0.239154f, 0.427475f, 0.460164f, 0.727212f, 0.478457f, 0.589145f, 
0.456094f, 0.413665f, 0.297445f, 0.419073f, 0.407575f, 0.626054f, 0.503276f, 0.536857f, 0.396718f, 0.495176f, 0.270464f, 0.419459f, 0.466892f, 0.704668f, 0.544710f, 0.446025f, 0.625069f, 0.574330f, 0.337465f, 0.515011f, 0.576166f, 0.495398f, 0.561775f, 0.451492f, 0.656295f, 0.501454f, 0.371102f, 0.511117f, 0.597942f, 0.486135f, 0.613719f, 0.415552f, 0.679385f, 0.545510f, 0.334013f, 0.491561f, 0.634246f, 0.501191f, 0.592514f, 0.421301f, 0.682063f, 0.535644f, 0.365155f, 0.518639f, 0.614815f, 0.501439f, 0.460727f, 0.519269f, 0.348532f, 0.554692f, 0.328284f, 0.619616f, 0.469338f, 0.556237f, 0.442274f, 0.547421f, 0.394879f, 0.609402f, 0.399426f, 0.573414f, 0.435733f, 0.513013f, 0.478210f, 0.470028f, 0.379309f, 0.520524f, 0.393439f, 0.580848f, 0.442115f, 0.602217f, 0.485329f, 0.501646f, 0.370504f, 0.561198f, 0.416058f, 0.567774f, 0.439229f, 0.571259f, 0.674824f, 0.550989f, 0.722801f, 0.662394f, 0.352779f, 0.301575f, 0.454417f, 0.436797f, 0.640218f, 0.464017f, 0.673274f, 0.631072f, 0.416194f, 0.405371f, 0.424135f, 0.380459f, 0.676026f, 0.466017f, 0.693624f, 0.619528f, 0.361035f, 0.314311f, 0.546125f, 0.401422f, 0.634731f, 0.457909f, 0.673249f, 0.669035f, 0.395002f, 0.414838f, 0.422935f, 0.397171f, 0.578772f, 0.171263f, 0.507806f, 0.446147f, 0.431901f, 0.525101f, 0.333084f, 0.473000f, 0.581295f, 0.193171f, 0.470985f, 0.376522f, 0.425847f, 0.546483f, 0.292789f, 0.509355f, 0.590731f, 0.161755f, 0.514375f, 0.380830f, 0.398416f, 0.492429f, 0.361418f, 0.440428f, 0.559340f, 0.167691f, 0.474461f, 0.331081f, 0.368636f, 0.558841f, 0.331704f, 0.485050f, 0.683438f, 0.514064f, 0.339780f, 0.536424f, 0.478815f, 0.654453f, 0.482692f, 0.544422f, 0.718284f, 0.508385f, 0.350896f, 0.561493f, 0.527900f, 0.642672f, 0.514512f, 0.516495f, 0.644405f, 0.441945f, 0.397069f, 0.484688f, 0.496761f, 0.647967f, 0.423362f, 0.480241f, 0.686930f, 0.492126f, 0.344961f, 0.526120f, 0.489709f, 0.638597f, 0.457665f, 0.469929f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + 
ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * kv_sequence_length); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DAttnMaskBool) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 
0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 
0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 
0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 
0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + std::initializer_list m = {true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true}; + std::vector y = {0.501465f, 0.543511f, 0.398088f, 
0.474061f, 0.290507f, 0.423018f, 0.447999f, 0.672390f, 0.500878f, 0.545140f, 0.402253f, 0.478354f, 0.278711f, 0.420929f, 0.451124f, 0.682613f, 0.496502f, 0.557356f, 0.419293f, 0.467867f, 0.280946f, 0.422295f, 0.445183f, 0.675748f, 0.498804f, 0.545264f, 0.399543f, 0.471287f, 0.287601f, 0.424845f, 0.443877f, 0.670841f, 0.580098f, 0.450536f, 0.702941f, 0.538382f, 0.329768f, 0.543394f, 0.613723f, 0.562010f, 0.584549f, 0.447129f, 0.673676f, 0.537643f, 0.342950f, 0.515742f, 0.613437f, 0.502951f, 0.585248f, 0.443070f, 0.676620f, 0.549025f, 0.343112f, 0.522440f, 0.611621f, 0.507324f, 0.580745f, 0.461632f, 0.668496f, 0.507376f, 0.336816f, 0.500750f, 0.618162f, 0.500909f, 0.464240f, 0.493342f, 0.380525f, 0.530712f, 0.397056f, 0.582067f, 0.443341f, 0.559227f, 0.467916f, 0.503694f, 0.373170f, 0.549178f, 0.387171f, 0.587037f, 0.448581f, 0.561591f, 0.478681f, 0.496704f, 0.369457f, 0.545459f, 0.392339f, 0.587842f, 0.452645f, 0.576330f, 0.483897f, 0.491793f, 0.360676f, 0.530990f, 0.380686f, 0.603393f, 0.467172f, 0.583590f, 0.642787f, 0.470883f, 0.686034f, 0.642719f, 0.386365f, 0.366454f, 0.467120f, 0.405736f, 0.644347f, 0.466390f, 0.684379f, 0.640710f, 0.385963f, 0.366271f, 0.472645f, 0.403025f, 0.631421f, 0.453237f, 0.677676f, 0.643979f, 0.390879f, 0.377663f, 0.467158f, 0.401772f, 0.637457f, 0.459313f, 0.677889f, 0.659685f, 0.383362f, 0.379251f, 0.453763f, 0.401437f, 0.555998f, 0.186013f, 0.455395f, 0.406430f, 0.395553f, 0.526708f, 0.320193f, 0.484448f, 0.577368f, 0.190770f, 0.462801f, 0.384114f, 0.403607f, 0.534057f, 0.326255f, 0.496504f, 0.563586f, 0.180264f, 0.464196f, 0.384055f, 0.385514f, 0.537212f, 0.338047f, 0.485235f, 0.555800f, 0.177971f, 0.457827f, 0.377928f, 0.372441f, 0.541035f, 0.343750f, 0.483692f, 0.705313f, 0.467049f, 0.389698f, 0.530555f, 0.548003f, 0.637789f, 0.501241f, 0.493046f, 0.692096f, 0.474284f, 0.375588f, 0.530258f, 0.507811f, 0.618987f, 0.468782f, 0.502795f, 0.703758f, 0.479856f, 0.374269f, 0.518477f, 0.518286f, 0.631821f, 0.502535f, 0.509264f, 
0.689539f, 0.474638f, 0.374363f, 0.519131f, 0.519441f, 0.644891f, 0.480984f, 0.490645f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * kv_sequence_length); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), m, std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DAttnPastPresentBasic) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 1; // Q.shape[1] + int q_sequence_length = 3; // Q.shape[2] + int head_size = 2; // Q.shape[3] + int kv_sequence_length = 4; // K.shape[2] and V.shape[2] + int kv_num_heads = 1; // K.shape[1] and V.shape[1] + int v_head_size = 2; // V.shape[3] + int past_sequence_length = 1; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {1, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1}; + std::vector k = {1, 0, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 0, 1, 2}; + std::vector v = {0, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 0, 2}; + std::vector m = {1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1}; + std::vector past_key = {1, 2, 1, 1}; + std::vector past_value = {1, 1, 2, 1}; + std::vector y = {1.2691493034362793, 1.0, 1.0774023532867432, 1.0, 0.9539920091629028, 1.0, 0.4988941252231598, 1.6121423244476318, 0.8137872219085693, 1.3673334121704102, 0.8579846620559692, 1.2801470756530762}; + std::vector 
present_key = {1.0, 2.0, 1.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 2.0}; + std::vector present_value = {1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 2.0}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DAttnPastPresent) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 12; // 
past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 
0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 
0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 
0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 
0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + std::vector m = {0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f}; + std::vector past_key = {0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 
0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 
0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 
0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f}; + std::vector past_value = {0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 
0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 
0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 
0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 
0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f}; + std::vector y = {0.457694f, 0.455757f, 0.445489f, 0.526766f, 0.477853f, 0.608758f, 0.406654f, 0.519316f, 0.444463f, 0.465842f, 0.428262f, 0.540428f, 0.477282f, 0.638667f, 0.474591f, 0.547811f, 0.457420f, 0.470657f, 0.487116f, 0.542242f, 0.482364f, 0.617841f, 0.476829f, 0.557317f, 0.463370f, 0.432599f, 0.412642f, 0.520960f, 0.479831f, 0.589828f, 0.446331f, 0.612812f, 0.585487f, 0.538315f, 0.504264f, 0.615235f, 0.527800f, 0.515899f, 0.536401f, 0.541573f, 0.578147f, 0.544553f, 0.531175f, 0.583502f, 0.528233f, 0.518028f, 0.562917f, 0.588512f, 0.599006f, 0.525119f, 0.535656f, 0.623945f, 0.521523f, 0.515306f, 0.544257f, 0.592741f, 0.600172f, 0.529797f, 0.490615f, 0.601856f, 0.495671f, 0.500725f, 0.555493f, 0.482300f, 0.538304f, 0.469695f, 0.555198f, 0.489711f, 0.521836f, 0.485628f, 0.493937f, 0.562992f, 0.521894f, 0.489056f, 0.584299f, 0.474376f, 0.493005f, 0.475963f, 0.460919f, 0.567615f, 0.547787f, 0.466202f, 0.536014f, 0.473239f, 0.485554f, 0.498408f, 0.501733f, 0.586437f, 0.517314f, 0.440046f, 0.514271f, 0.545266f, 0.487437f, 0.481043f, 0.518498f, 0.568266f, 0.514357f, 0.572526f, 0.423650f, 0.474643f, 0.492550f, 0.533325f, 0.512998f, 0.452411f, 0.526065f, 0.535346f, 0.407074f, 0.502433f, 0.501283f, 0.528505f, 0.510491f, 0.402870f, 0.516862f, 0.596280f, 0.397160f, 0.469242f, 0.458194f, 0.537358f, 0.510243f, 0.439715f, 0.530736f, 0.580630f, 0.437646f, 0.462414f, 0.484492f, 0.477003f, 0.476393f, 0.431391f, 0.481805f, 0.420751f, 0.544359f, 0.440140f, 0.533953f, 0.453877f, 0.460864f, 0.446440f, 0.454282f, 0.416850f, 0.494072f, 0.462208f, 0.524801f, 0.453293f, 0.493179f, 0.462526f, 0.489181f, 0.452340f, 0.570383f, 0.422193f, 0.524420f, 0.468229f, 0.489729f, 0.444768f, 0.534646f, 0.457197f, 0.522207f, 0.400594f, 0.538509f, 0.489581f, 0.457599f, 0.488340f, 0.549355f, 0.482543f, 0.431908f, 0.352921f, 0.633369f, 0.690998f, 
0.314418f, 0.542520f, 0.580878f, 0.489810f, 0.451832f, 0.346453f, 0.599024f, 0.630982f, 0.310195f, 0.532405f, 0.568864f, 0.486514f, 0.432211f, 0.345150f, 0.586195f, 0.659745f, 0.269926f, 0.528033f, 0.509392f, 0.511314f, 0.378251f, 0.319656f, 0.601292f, 0.726670f, 0.338636f, 0.564731f}; + std::vector present_key = {0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 
0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 
0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 
0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 
0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector present_value = {0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 
0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 
0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 
0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 
0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 
0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * 
q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} +TEST(AttentionTest, Attention4DAttnIsCausal) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 
0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 
0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 
0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 
0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + std::vector y = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, + 0.466662f, 0.404919f, 0.206397f, 0.494597f, 0.469075f, 0.517016f, 0.457503f, 0.620147f, + 0.455868f, 0.401850f, 0.222910f, 0.498051f, 0.398273f, 
0.458905f, 0.484206f, 0.678309f, + 0.428625f, 0.565862f, 0.420294f, 0.361176f, 0.366713f, 0.456673f, 0.367244f, 0.565962f, + 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, + 0.340486f, 0.554859f, 0.357655f, 0.654648f, 0.303360f, 0.468544f, 0.410813f, 0.359175f, 0.539688f, 0.388773f, 0.469414f, 0.709710f, 0.362709f, 0.429548f, 0.533266f, 0.281177f, 0.507994f, 0.419524f, 0.523713f, 0.531125f, 0.334381f, 0.418885f, 0.553995f, 0.441341f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.440199f, 0.552865f, 0.234100f, 0.465348f, 0.108484f, 0.789824f, 0.596633f, 0.505260f, 0.521296f, 0.529090f, 0.243612f, 0.596347f, 0.178938f, 0.704410f, 0.541649f, 0.663573f, 0.447473f, 0.471171f, 0.330193f, 0.440955f, 0.264086f, 0.669717f, 0.497800f, 0.570196f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.666526f, 0.680385f, 0.769414f, 0.846562f, 0.211277f, 0.124523f, 0.362721f, 0.528572f, 0.722160f, 0.763995f, 0.843738f, 0.695165f, 0.266952f, 0.132048f, 0.481567f, 0.579821f, 0.766651f, 0.587935f, 0.750237f, 0.660460f, 0.262872f, 0.142580f, 0.578552f, 0.432957f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.528620f, 0.173138f, 0.496913f, 0.687855f, 0.473097f, 0.565422f, 0.353939f, 0.499403f, 0.683711f, 0.156556f, 0.606089f, 0.441246f, 0.472192f, 0.507007f, 0.441957f, 0.457522f, 0.599108f, 0.136602f, 0.579971f, 0.504480f, 0.443634f, 0.456725f, 0.392707f, 0.395364f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.755483f, 0.623352f, 0.283909f, 0.615250f, 0.377633f, 0.544918f, 0.585578f, 0.822309f, 0.598965f, 0.584465f, 0.234792f, 0.460114f, 0.268955f, 0.677291f, 0.392800f, 0.607946f, 0.577946f, 0.470810f, 0.371437f, 0.510227f, 0.419904f, 0.671214f, 0.345365f, 0.567849f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), 
batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DAttnIsCausalBasic) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 1; // Q.shape[1] + int q_sequence_length = 3; // Q.shape[2] + int head_size = 2; // Q.shape[3] + int kv_sequence_length = 3; // K.shape[2] and V.shape[2] + int kv_num_heads = 1; // K.shape[1] and V.shape[1] + int v_head_size = 2; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + std::vector k = {1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + std::vector v = {0.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + std::vector y = {0.0, 1.0, 0.6697615385055542, 1.0, 0.8022241592407227, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, 
past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DAttnIsCausalBasicFloat16) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 1; // Q.shape[1] + int q_sequence_length = 3; // Q.shape[2] + int head_size = 2; // Q.shape[3] + int kv_sequence_length = 3; // K.shape[2] and V.shape[2] + int kv_num_heads = 1; // K.shape[1] and V.shape[1] + int v_head_size = 2; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + std::vector k = {1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + std::vector v = {0.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + std::vector y = {0.0, 1.0, 0.6697615385055542, 1.0, 0.8022241592407227, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat16, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, 
true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DAttnIsCausalBasicDifferentSequenceLength) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 1; // Q.shape[1] + int q_sequence_length = 3; // Q.shape[2] + int head_size = 2; // Q.shape[3] + int kv_sequence_length = 4; // K.shape[2] and V.shape[2] + int kv_num_heads = 1; // K.shape[1] and V.shape[1] + int v_head_size = 2; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + std::vector q = {1.f, 1.f, 0.f, 1.f, 2.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + std::vector k = {1.f, 0.f, 1.f, 1.f, 1.f, 2.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 2}; + std::vector v = {0.f, 1.f, 1.f, 1.f, 1.f, 1.f, 2.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 2}; + std::vector y = {0.0, 1.0, 0.6697615385055542, 1.0, 0.85997074842453, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DDiffHeadsWithPastAndPresent) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int 
kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 10; // V.shape[3] + int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2] + + // {2, 3, 4, 8} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 
0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + // {2, 3, 6, 8} + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 
0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + // {2, 3, 6, 10} + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 
0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 
0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f}; + // {4, 18} + std::vector m = {0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 
0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f}; + // {2, 3, 12, 8} + std::vector past_key = {0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 
0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 
0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 
0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f}; + // {2, 3, 12, 10} + std::vector past_value = {0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 
0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 
0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 
0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 
0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f}; + // {2, 3, 4, 10} + std::vector y = {0.484245f, 0.491594f, 0.566765f, 0.698646f, 0.412717f, 0.529403f, 0.546576f, 0.477395f, 0.458289f, 0.526034f, 0.507523f, 0.501791f, 0.516438f, 0.666451f, 0.374304f, 0.541111f, 0.568747f, 0.520548f, 0.473141f, 0.519258f, 0.498172f, 0.514510f, 0.527296f, 0.682262f, 0.396020f, 0.501123f, 0.530399f, 0.488510f, 0.446185f, 0.542778f, 0.511414f, 0.485035f, 0.517123f, 0.684857f, 0.389196f, 0.515658f, 0.556560f, 0.526948f, 0.446624f, 0.513224f, 0.518960f, 0.522651f, 0.541202f, 0.520867f, 0.515921f, 0.390582f, 0.438142f, 0.557164f, 0.504964f, 0.579576f, 0.465363f, 0.569218f, 0.532317f, 0.551877f, 0.490628f, 0.361162f, 0.458657f, 0.568250f, 0.511133f, 0.519196f, 0.508355f, 0.532992f, 0.540742f, 0.536218f, 0.491775f, 0.346055f, 0.430588f, 0.545529f, 0.508855f, 0.534426f, 0.477742f, 0.559174f, 0.522186f, 0.518533f, 0.461976f, 0.366468f, 0.455339f, 0.541203f, 0.513318f, 0.516310f, 0.417490f, 0.509893f, 0.590295f, 0.518703f, 0.497346f, 0.569950f, 0.531036f, 0.515108f, 0.551188f, 0.511368f, 0.428004f, 0.470681f, 0.584422f, 0.481287f, 0.526080f, 0.523233f, 0.457405f, 0.481407f, 0.573666f, 0.505292f, 0.455096f, 0.488968f, 
0.602769f, 0.494229f, 0.506703f, 0.531687f, 0.494376f, 0.500014f, 0.557185f, 0.516992f, 0.456706f, 0.474918f, 0.604858f, 0.507587f, 0.469668f, 0.505480f, 0.509594f, 0.501727f, 0.579587f, 0.520784f, 0.493654f, 0.421248f, 0.447569f, 0.512260f, 0.385047f, 0.415280f, 0.512025f, 0.438027f, 0.412472f, 0.566399f, 0.521616f, 0.425188f, 0.438491f, 0.497757f, 0.359007f, 0.354674f, 0.526893f, 0.436536f, 0.365545f, 0.598360f, 0.539148f, 0.414424f, 0.449425f, 0.469435f, 0.387864f, 0.398897f, 0.495746f, 0.442739f, 0.325650f, 0.565445f, 0.528260f, 0.427462f, 0.414675f, 0.471898f, 0.383976f, 0.365848f, 0.492247f, 0.412142f, 0.346633f, 0.594105f, 0.607776f, 0.533772f, 0.468197f, 0.372208f, 0.489865f, 0.443200f, 0.545535f, 0.493389f, 0.551969f, 0.423333f, 0.646158f, 0.558704f, 0.439156f, 0.446620f, 0.451905f, 0.487079f, 0.528236f, 0.561621f, 0.598777f, 0.437840f, 0.621812f, 0.514033f, 0.477342f, 0.401848f, 0.471414f, 0.463881f, 0.530019f, 0.506494f, 0.559079f, 0.454743f, 0.645883f, 0.532612f, 0.484295f, 0.429611f, 0.471412f, 0.470437f, 0.545854f, 0.509529f, 0.591309f, 0.463628f, 0.463473f, 0.428821f, 0.487303f, 0.522334f, 0.486353f, 0.659896f, 0.556700f, 0.410148f, 0.569697f, 0.495767f, 0.437882f, 0.420329f, 0.503654f, 0.527284f, 0.465816f, 0.623204f, 0.569190f, 0.413123f, 0.554353f, 0.518062f, 0.492239f, 0.410378f, 0.461884f, 0.498402f, 0.509016f, 0.682983f, 0.535407f, 0.412562f, 0.551318f, 0.498037f, 0.470375f, 0.407394f, 0.460899f, 0.496268f, 0.464923f, 0.672767f, 0.533764f, 0.427543f, 0.577909f, 0.506939f}; + // {2, 3, 18, 8} + std::vector present_key = {0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 
0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 
0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 
0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 
0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 
0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + // {2, 3, 18, 10} + std::vector present_value = {0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 
0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 
0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 
0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 
0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 
0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 
0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, 
head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DGqaAttnMask) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 9; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 0; // past_key.shape[2] and past_value.shape[2] + + // {2, 9, 4, 8} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 
0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 
0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 
0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f}; + // {2, 3, 6, 8} + std::vector k = {0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 
0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 
0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f}; + // {2, 3, 6, 8} + std::vector v = {0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 
0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f}; + // {4, 6} + std::vector m = {0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f}; + // {2, 
9, 4, 8} + std::vector y = {0.641842f, 0.667534f, 0.339592f, 0.480609f, 0.537525f, 0.340368f, 0.752882f, 0.387601f, 0.686814f, 0.643437f, 0.324983f, 0.468788f, 0.539061f, 0.319610f, 0.754181f, 0.373093f, 0.702380f, 0.693136f, 0.318406f, 0.456714f, 0.540838f, 0.315487f, 0.718291f, 0.311025f, 0.681769f, 0.670603f, 0.329705f, 0.456661f, 0.573902f, 0.337385f, 0.700597f, 0.333385f, 0.508992f, 0.253478f, 0.553979f, 0.466355f, 0.398637f, 0.412493f, 0.495810f, 0.677675f, 0.521609f, 0.278997f, 0.564189f, 0.434417f, 0.448085f, 0.467205f, 0.567856f, 0.664713f, 0.490146f, 0.261321f, 0.560582f, 0.424598f, 0.450318f, 0.467336f, 0.520983f, 0.720798f, 0.516095f, 0.264495f, 0.577940f, 0.475340f, 0.444145f, 0.477909f, 0.485663f, 0.672846f, 0.499389f, 0.402198f, 0.520218f, 0.550550f, 0.481065f, 0.730488f, 0.492535f, 0.392315f, 0.436722f, 0.398514f, 0.497457f, 0.502270f, 0.520993f, 0.730472f, 0.565429f, 0.380282f, 0.461226f, 0.392968f, 0.536035f, 0.505191f, 0.446570f, 0.751253f, 0.478584f, 0.389036f, 0.423738f, 0.443828f, 0.554323f, 0.462607f, 0.476656f, 0.733228f, 0.482219f, 0.411910f, 0.620556f, 0.662948f, 0.349409f, 0.482541f, 0.537250f, 0.351544f, 0.734285f, 0.397172f, 0.689500f, 0.637077f, 0.320710f, 0.470914f, 0.526307f, 0.312878f, 0.775762f, 0.384457f, 0.696615f, 0.681034f, 0.324383f, 0.459632f, 0.539497f, 0.317950f, 0.709736f, 0.320698f, 0.671696f, 0.676830f, 0.332387f, 0.453234f, 0.578648f, 0.345084f, 0.685369f, 0.328092f, 0.520830f, 0.251061f, 0.562824f, 0.469184f, 0.393635f, 0.405203f, 0.493565f, 0.668713f, 0.541328f, 0.282797f, 0.577903f, 0.434065f, 0.444664f, 0.460403f, 0.572628f, 0.646402f, 0.493508f, 0.265246f, 0.572078f, 0.418658f, 0.464491f, 0.483746f, 0.516536f, 0.724847f, 0.503705f, 0.270557f, 0.577678f, 0.465114f, 0.468430f, 0.508402f, 0.489087f, 0.689442f, 0.500042f, 0.410507f, 0.521381f, 0.553244f, 0.459062f, 0.719706f, 0.476571f, 0.395052f, 0.429926f, 0.408857f, 0.507006f, 0.493937f, 0.529878f, 0.728873f, 0.571495f, 0.376256f, 0.453676f, 0.380482f, 0.526100f, 
0.496696f, 0.457383f, 0.761933f, 0.486657f, 0.396608f, 0.435748f, 0.432822f, 0.531763f, 0.482255f, 0.477046f, 0.726381f, 0.487480f, 0.416572f, 0.626676f, 0.683736f, 0.340657f, 0.475002f, 0.549981f, 0.353311f, 0.740157f, 0.378827f, 0.681403f, 0.636622f, 0.324593f, 0.469088f, 0.537323f, 0.321344f, 0.762506f, 0.384239f, 0.693108f, 0.683351f, 0.329873f, 0.460504f, 0.555115f, 0.325379f, 0.694659f, 0.316422f, 0.677285f, 0.670298f, 0.329724f, 0.456327f, 0.567533f, 0.337560f, 0.701396f, 0.336191f, 0.515940f, 0.251020f, 0.562035f, 0.442479f, 0.405802f, 0.410828f, 0.519841f, 0.686781f, 0.522057f, 0.285013f, 0.562761f, 0.453472f, 0.451971f, 0.481286f, 0.558322f, 0.649971f, 0.486787f, 0.258011f, 0.557963f, 0.426743f, 0.442028f, 0.457034f, 0.510534f, 0.724945f, 0.498901f, 0.272090f, 0.572650f, 0.467930f, 0.465335f, 0.506181f, 0.484559f, 0.690090f, 0.499525f, 0.398443f, 0.522291f, 0.550620f, 0.465209f, 0.731897f, 0.484389f, 0.388997f, 0.411109f, 0.420719f, 0.523354f, 0.478677f, 0.522513f, 0.723052f, 0.587358f, 0.350775f, 0.450881f, 0.384685f, 0.527140f, 0.502089f, 0.438660f, 0.749234f, 0.493312f, 0.377459f, 0.425945f, 0.432397f, 0.544111f, 0.466484f, 0.488077f, 0.738712f, 0.493642f, 0.412262f, 0.565934f, 0.795554f, 0.527262f, 0.295395f, 0.394937f, 0.326235f, 0.457519f, 0.454071f, 0.511390f, 0.753500f, 0.500815f, 0.303925f, 0.403792f, 0.343750f, 0.516333f, 0.463035f, 0.491925f, 0.753119f, 0.503555f, 0.310489f, 0.373396f, 0.334562f, 0.526486f, 0.470500f, 0.495985f, 0.733211f, 0.532951f, 0.342292f, 0.346065f, 0.355272f, 0.479542f, 0.509107f, 0.379088f, 0.582413f, 0.414383f, 0.571800f, 0.613176f, 0.687631f, 0.185596f, 0.656867f, 0.390452f, 0.532452f, 0.407547f, 0.564799f, 0.606499f, 0.653258f, 0.176547f, 0.698038f, 0.410398f, 0.604586f, 0.442972f, 0.497533f, 0.595085f, 0.732265f, 0.187201f, 0.663169f, 0.448716f, 0.590302f, 0.411879f, 0.518449f, 0.636722f, 0.695827f, 0.154292f, 0.666828f, 0.458054f, 0.608582f, 0.430376f, 0.316371f, 0.547620f, 0.542559f, 0.542043f, 0.556297f, 
0.468371f, 0.559154f, 0.465195f, 0.344099f, 0.482571f, 0.527115f, 0.527529f, 0.616254f, 0.494566f, 0.605555f, 0.432360f, 0.382197f, 0.466678f, 0.556031f, 0.459313f, 0.588575f, 0.532798f, 0.597684f, 0.412305f, 0.393400f, 0.462773f, 0.491821f, 0.483189f, 0.593919f, 0.569241f, 0.793791f, 0.532988f, 0.300026f, 0.393843f, 0.327085f, 0.448199f, 0.457416f, 0.493302f, 0.725336f, 0.512066f, 0.327500f, 0.404238f, 0.351704f, 0.507818f, 0.477990f, 0.479548f, 0.756083f, 0.511730f, 0.309729f, 0.366024f, 0.338031f, 0.503335f, 0.472352f, 0.473026f, 0.696816f, 0.543129f, 0.374608f, 0.335432f, 0.360978f, 0.486364f, 0.531799f, 0.380422f, 0.599984f, 0.413640f, 0.564090f, 0.607571f, 0.708289f, 0.187551f, 0.671587f, 0.381058f, 0.550543f, 0.422336f, 0.556663f, 0.599418f, 0.666369f, 0.182365f, 0.678737f, 0.423800f, 0.600509f, 0.437094f, 0.494968f, 0.603340f, 0.727226f, 0.179659f, 0.667114f, 0.464399f, 0.563292f, 0.399716f, 0.529198f, 0.655782f, 0.666396f, 0.143497f, 0.659062f, 0.453034f, 0.596627f, 0.417365f, 0.314318f, 0.554269f, 0.518967f, 0.550250f, 0.556252f, 0.494918f, 0.587774f, 0.467566f, 0.350222f, 0.481994f, 0.538857f, 0.525631f, 0.605359f, 0.497486f, 0.608472f, 0.429145f, 0.384532f, 0.466790f, 0.554752f, 0.457698f, 0.586510f, 0.548577f, 0.604359f, 0.398097f, 0.414429f, 0.448200f, 0.485158f, 0.461395f, 0.593015f, 0.563470f, 0.796184f, 0.532783f, 0.293209f, 0.408910f, 0.327450f, 0.438028f, 0.447011f, 0.493041f, 0.739603f, 0.496957f, 0.311881f, 0.389768f, 0.352503f, 0.530113f, 0.476738f, 0.484897f, 0.752985f, 0.511921f, 0.312174f, 0.370408f, 0.339775f, 0.504061f, 0.473793f, 0.487978f, 0.714687f, 0.538817f, 0.358426f, 0.348908f, 0.355820f, 0.481380f, 0.516214f, 0.370872f, 0.602034f, 0.400225f, 0.611090f, 0.630508f, 0.662527f, 0.162489f, 0.658299f, 0.378734f, 0.537283f, 0.412214f, 0.570032f, 0.601452f, 0.653569f, 0.179932f, 0.693105f, 0.411981f, 0.605715f, 0.448022f, 0.481469f, 0.585099f, 0.748463f, 0.195177f, 0.671915f, 0.442141f, 0.581881f, 0.393362f, 0.555388f, 0.650764f, 
0.665937f, 0.141141f, 0.675100f, 0.448606f, 0.605061f, 0.412183f, 0.312673f, 0.559178f, 0.530440f, 0.538275f, 0.546820f, 0.494936f, 0.585982f, 0.469875f, 0.355291f, 0.474437f, 0.542980f, 0.518181f, 0.609491f, 0.522046f, 0.618936f, 0.412090f, 0.410711f, 0.452217f, 0.540284f, 0.444109f, 0.585510f, 0.570158f, 0.614413f, 0.415425f, 0.410005f, 0.441791f, 0.491080f, 0.466021f, 0.595833f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), std::vector(), std::vector(), + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, std::vector(), std::vector(), std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DGqaWithPastAndPresent) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 9; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2] + + // {2, 9, 4, 8} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 
0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 
0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 
0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 
0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f}; + // {2, 3, 6, 8} + std::vector k = {0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 
0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f}; + // {2, 3, 6, 8} + std::vector v = {0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 
0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 
0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f}; + // {4, 18} + std::vector m = {0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f}; + // {2, 3, 12, 8} + std::vector past_key = {0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 
0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 
0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 
0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f}; + // {2, 3, 12, 8} + std::vector past_value = {0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 
0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 
0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 
0.987435f, 0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 0.923878f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f, 0.481102f, 0.251523f, 
0.876682f, 0.324273f, 0.924623f, 0.974787f, 0.449862f, 0.227129f, 0.291666f, 0.776334f, 0.273350f, 0.380583f, 0.478576f, 0.575111f, 0.996100f, 0.232210f, 0.353424f, 0.262891f, 0.361113f, 0.100805f, 0.359810f, 0.887865f, 0.298590f, 0.371935f}; + // {2, 9, 4, 8} + std::vector y = {0.544462f, 0.617844f, 0.506335f, 0.473482f, 0.606855f, 0.423464f, 0.544771f, 0.450451f, 0.524249f, 0.627160f, 0.497201f, 0.440288f, 0.619110f, 0.437084f, 0.563680f, 0.440037f, 0.516736f, 0.577726f, 0.523888f, 0.493471f, 0.594122f, 0.433401f, 0.585942f, 0.457686f, 0.528512f, 0.604578f, 0.472106f, 0.471486f, 0.600445f, 0.446256f, 0.622393f, 0.435442f, 0.440810f, 0.437705f, 0.476508f, 0.320820f, 0.605191f, 0.640150f, 0.306216f, 0.610947f, 0.485794f, 0.448216f, 0.485639f, 0.323744f, 0.594446f, 0.646597f, 0.321742f, 0.605751f, 0.501858f, 0.445502f, 0.487899f, 0.384660f, 0.597134f, 0.616430f, 0.331401f, 0.566459f, 0.502522f, 0.409965f, 0.526639f, 0.348601f, 0.565200f, 0.586558f, 0.325044f, 0.603422f, 0.450250f, 0.368009f, 0.550911f, 0.460338f, 0.523907f, 0.508816f, 0.575624f, 0.426601f, 0.472310f, 0.372844f, 0.517852f, 0.431688f, 0.551555f, 0.527657f, 0.600578f, 0.473069f, 0.456633f, 0.442035f, 0.539875f, 0.437863f, 0.540202f, 0.499608f, 0.556470f, 0.419831f, 0.463081f, 0.416724f, 0.526389f, 0.458654f, 0.540120f, 0.551554f, 0.569399f, 0.447102f, 0.534296f, 0.597655f, 0.509699f, 0.487167f, 0.607438f, 0.426383f, 0.522794f, 0.458435f, 0.510147f, 0.622761f, 0.501724f, 0.453386f, 0.629671f, 0.434103f, 0.582477f, 0.437681f, 0.520031f, 0.568543f, 0.525216f, 0.490370f, 0.571745f, 0.428629f, 0.572995f, 0.460086f, 0.533607f, 0.614962f, 0.474130f, 0.456345f, 0.576467f, 0.448127f, 0.599211f, 0.432252f, 0.447842f, 0.430169f, 0.480055f, 0.320521f, 0.590915f, 0.627003f, 0.314551f, 0.609320f, 0.499216f, 0.438828f, 0.485519f, 0.322134f, 0.586364f, 0.645824f, 0.326481f, 0.596989f, 0.496362f, 0.442741f, 0.492120f, 0.366111f, 0.601604f, 0.615566f, 0.326354f, 0.567173f, 0.496946f, 0.422179f, 0.533144f, 0.342588f, 
0.590482f, 0.605923f, 0.318055f, 0.610401f, 0.452598f, 0.361594f, 0.550919f, 0.455099f, 0.530404f, 0.519313f, 0.588655f, 0.431890f, 0.464325f, 0.389636f, 0.515359f, 0.429087f, 0.540767f, 0.518376f, 0.586627f, 0.471074f, 0.458527f, 0.422216f, 0.537762f, 0.434123f, 0.550956f, 0.507704f, 0.564828f, 0.421548f, 0.463044f, 0.407985f, 0.523093f, 0.473684f, 0.542663f, 0.551348f, 0.576783f, 0.448743f, 0.546208f, 0.621128f, 0.501647f, 0.468191f, 0.612298f, 0.425183f, 0.549241f, 0.447622f, 0.519355f, 0.619636f, 0.487775f, 0.444259f, 0.625749f, 0.430264f, 0.584338f, 0.436887f, 0.521021f, 0.572716f, 0.522539f, 0.486440f, 0.581317f, 0.429079f, 0.579691f, 0.455426f, 0.526431f, 0.604615f, 0.476481f, 0.469814f, 0.588766f, 0.445640f, 0.609160f, 0.437785f, 0.443498f, 0.439338f, 0.487424f, 0.310942f, 0.607341f, 0.630362f, 0.312591f, 0.621999f, 0.483917f, 0.446308f, 0.477454f, 0.331028f, 0.592608f, 0.653297f, 0.322368f, 0.599377f, 0.497354f, 0.443447f, 0.477781f, 0.384002f, 0.591587f, 0.610287f, 0.328537f, 0.567630f, 0.499369f, 0.421961f, 0.536492f, 0.345379f, 0.586450f, 0.600541f, 0.312965f, 0.609437f, 0.451750f, 0.359685f, 0.553321f, 0.464992f, 0.524025f, 0.522507f, 0.582135f, 0.425124f, 0.459696f, 0.394679f, 0.519051f, 0.411226f, 0.539772f, 0.505003f, 0.587681f, 0.469383f, 0.451681f, 0.430062f, 0.541843f, 0.420929f, 0.542240f, 0.487570f, 0.567067f, 0.419708f, 0.456288f, 0.412096f, 0.527592f, 0.467870f, 0.545021f, 0.547842f, 0.573135f, 0.448166f, 0.581220f, 0.559255f, 0.469802f, 0.489935f, 0.557197f, 0.487135f, 0.377325f, 0.425637f, 0.582374f, 0.560738f, 0.425382f, 0.463129f, 0.549939f, 0.481810f, 0.350432f, 0.466049f, 0.593554f, 0.542315f, 0.482597f, 0.496969f, 0.518851f, 0.507807f, 0.366054f, 0.457476f, 0.569468f, 0.565965f, 0.444765f, 0.465404f, 0.515500f, 0.520271f, 0.337845f, 0.448357f, 0.557802f, 0.585925f, 0.426858f, 0.464044f, 0.585251f, 0.557395f, 0.433327f, 0.615342f, 0.534368f, 0.573723f, 0.426393f, 0.518102f, 0.586735f, 0.513129f, 0.371969f, 0.636735f, 0.544166f, 
0.588469f, 0.433470f, 0.481894f, 0.595019f, 0.533156f, 0.396519f, 0.608115f, 0.547125f, 0.604473f, 0.441984f, 0.469765f, 0.599107f, 0.561685f, 0.347618f, 0.563457f, 0.507550f, 0.485293f, 0.545846f, 0.408434f, 0.482538f, 0.532314f, 0.498883f, 0.525126f, 0.514603f, 0.471457f, 0.539705f, 0.362410f, 0.490158f, 0.513690f, 0.494170f, 0.496909f, 0.492936f, 0.506153f, 0.565865f, 0.364727f, 0.508899f, 0.516217f, 0.558362f, 0.556920f, 0.530472f, 0.521715f, 0.554673f, 0.363830f, 0.509086f, 0.511590f, 0.552396f, 0.541486f, 0.572145f, 0.551531f, 0.471964f, 0.485188f, 0.555030f, 0.493247f, 0.376875f, 0.429387f, 0.580540f, 0.550944f, 0.435664f, 0.480675f, 0.544997f, 0.488698f, 0.344985f, 0.464878f, 0.593774f, 0.541202f, 0.484834f, 0.497316f, 0.509364f, 0.500045f, 0.357235f, 0.448933f, 0.565242f, 0.546653f, 0.459790f, 0.481954f, 0.514950f, 0.516297f, 0.344285f, 0.454476f, 0.548036f, 0.577907f, 0.427075f, 0.478978f, 0.581563f, 0.553606f, 0.426476f, 0.638442f, 0.498925f, 0.598346f, 0.444106f, 0.536998f, 0.575948f, 0.499260f, 0.371120f, 0.626981f, 0.545949f, 0.586548f, 0.428254f, 0.479753f, 0.596943f, 0.527697f, 0.401418f, 0.613028f, 0.542355f, 0.607063f, 0.447840f, 0.467102f, 0.603496f, 0.549575f, 0.364370f, 0.561534f, 0.507041f, 0.473640f, 0.547768f, 0.413960f, 0.490513f, 0.534377f, 0.497277f, 0.517772f, 0.531394f, 0.489105f, 0.531671f, 0.369343f, 0.486462f, 0.501787f, 0.494220f, 0.493498f, 0.485968f, 0.510301f, 0.559766f, 0.361474f, 0.507888f, 0.518858f, 0.564300f, 0.561990f, 0.537984f, 0.527982f, 0.539571f, 0.366920f, 0.498313f, 0.505709f, 0.538027f, 0.541246f, 0.585733f, 0.565800f, 0.441346f, 0.476255f, 0.556453f, 0.497693f, 0.363246f, 0.426799f, 0.578484f, 0.556489f, 0.436699f, 0.481177f, 0.549473f, 0.484153f, 0.355910f, 0.462010f, 0.590951f, 0.542803f, 0.470954f, 0.488994f, 0.512707f, 0.511876f, 0.358555f, 0.455953f, 0.559449f, 0.546003f, 0.462900f, 0.471080f, 0.517298f, 0.519225f, 0.345016f, 0.449149f, 0.526624f, 0.606761f, 0.427660f, 0.480775f, 0.577420f, 0.538850f, 
0.426959f, 0.625509f, 0.530502f, 0.585784f, 0.432234f, 0.516800f, 0.584937f, 0.514154f, 0.373726f, 0.623740f, 0.550470f, 0.585577f, 0.436483f, 0.474799f, 0.594100f, 0.540052f, 0.402520f, 0.607686f, 0.537556f, 0.609680f, 0.439490f, 0.477886f, 0.602656f, 0.542957f, 0.350394f, 0.574553f, 0.506900f, 0.488792f, 0.539037f, 0.403028f, 0.494093f, 0.534739f, 0.494292f, 0.511628f, 0.528192f, 0.480037f, 0.546429f, 0.375120f, 0.484828f, 0.505006f, 0.495786f, 0.497935f, 0.502174f, 0.514122f, 0.541314f, 0.369540f, 0.493985f, 0.508263f, 0.550415f, 0.556157f, 0.543269f, 0.529970f, 0.562027f, 0.376526f, 0.499704f, 0.508621f, 0.536068f, 0.545993f}; + // {2, 3, 18, 8} + std::vector present_key = {0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 
0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 
0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 
0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 
0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f, 0.418109f, 0.172951f, 0.107211f, 
0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f}; + // {2, 3, 18, 8} + std::vector present_value = {0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 
0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.027653f, 0.660339f, 0.414439f, 0.791282f, 0.721198f, 0.480108f, 0.643864f, 0.501773f, 0.811518f, 0.476084f, 0.523156f, 0.250521f, 0.605043f, 0.302905f, 0.577284f, 0.169678f, 0.159469f, 0.417030f, 0.426820f, 0.268109f, 0.131597f, 0.039211f, 0.025232f, 0.271550f, 0.461853f, 0.726243f, 0.474872f, 0.904051f, 0.035220f, 0.180661f, 0.338515f, 0.577496f, 0.852736f, 0.350202f, 0.267989f, 0.061889f, 0.821303f, 0.379666f, 0.571550f, 0.983555f, 0.001595f, 0.145450f, 0.779111f, 0.805128f, 0.769247f, 0.536999f, 0.978857f, 0.396185f, 0.601944f, 
0.063369f, 0.409857f, 0.722500f, 0.238739f, 0.943828f, 0.686783f, 0.287575f, 0.768999f, 0.083165f, 0.974774f, 0.049285f, 0.933456f, 0.252854f, 0.757824f, 0.000074f, 0.254240f, 0.749101f, 0.532336f, 0.114952f, 0.393630f, 0.375549f, 0.568162f, 0.667977f, 0.840830f, 0.497231f, 0.392022f, 0.143977f, 0.804823f, 0.713370f, 0.408677f, 0.518432f, 0.665183f, 0.164806f, 0.027198f, 0.317504f, 0.595585f, 0.486606f, 0.692555f, 0.819690f, 0.488442f, 0.134267f, 0.850628f, 0.574990f, 0.739937f, 0.704665f, 0.968212f, 0.295307f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.705307f, 0.365676f, 0.395411f, 0.230595f, 0.344010f, 0.948297f, 0.292571f, 0.245991f, 0.583138f, 0.258036f, 0.473386f, 0.834176f, 0.230400f, 0.426691f, 0.610490f, 0.545629f, 0.974723f, 0.680370f, 0.739946f, 0.966956f, 0.414438f, 0.355380f, 0.043862f, 0.184204f, 0.237190f, 0.183504f, 0.754784f, 0.535883f, 0.667634f, 0.820462f, 0.230774f, 0.325924f, 0.708360f, 0.392759f, 0.029271f, 0.434955f, 0.908273f, 0.409021f, 0.332249f, 0.989525f, 0.644416f, 0.365998f, 0.102020f, 0.787849f, 0.708075f, 0.921916f, 0.217276f, 0.114924f, 0.724073f, 0.203396f, 0.176104f, 0.319807f, 0.816825f, 0.539537f, 0.045850f, 0.463895f, 0.683980f, 0.538368f, 0.572450f, 0.224777f, 0.847739f, 0.561399f, 0.713246f, 0.981864f, 0.428199f, 0.881067f, 0.007281f, 0.033407f, 0.590280f, 0.311449f, 0.248277f, 0.277935f, 0.318403f, 0.728948f, 0.569196f, 0.789036f, 0.830197f, 0.842935f, 0.414644f, 0.421273f, 0.926266f, 0.661764f, 0.080467f, 0.542187f, 0.356007f, 0.987435f, 
0.013655f, 0.612181f, 0.723623f, 0.288907f, 0.973642f, 0.859537f, 0.915653f, 0.019232f, 0.569872f, 0.294650f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.849029f, 0.632850f, 0.538877f, 0.114588f, 0.540223f, 0.631904f, 0.955912f, 0.585051f, 0.967401f, 0.961606f, 0.650200f, 0.505908f, 0.466022f, 0.890379f, 0.028257f, 0.113808f, 0.102072f, 0.756935f, 0.339651f, 0.637969f, 0.603783f, 0.385828f, 0.531568f, 0.645139f, 0.940950f, 0.575634f, 0.614367f, 0.067856f, 0.952216f, 0.528082f, 0.801273f, 0.050291f, 0.420910f, 0.256975f, 0.266976f, 0.791454f, 0.623867f, 0.439745f, 0.010586f, 0.964928f, 0.962023f, 0.217552f, 0.041346f, 0.530199f, 0.951411f, 0.910396f, 0.584663f, 0.303549f, 0.329961f, 0.897914f, 0.491784f, 0.131116f, 0.248425f, 0.276795f, 0.123547f, 0.463044f, 0.916051f, 0.668783f, 0.072474f, 0.005495f, 0.276248f, 0.362693f, 0.776750f, 0.967006f, 0.387567f, 0.686690f, 0.994902f, 0.745667f, 0.636190f, 0.078075f, 0.323215f, 0.913392f, 0.201005f, 0.843590f, 0.696324f, 0.366324f, 0.529174f, 0.542806f, 0.714054f, 0.516556f, 0.133076f, 0.773455f, 0.406273f, 0.963094f, 0.283514f, 0.263079f, 0.333507f, 0.572317f, 0.894870f, 0.176282f, 0.279679f, 0.581680f, 0.454334f, 0.447323f, 0.820734f, 0.923878f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 
0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.481307f, 0.687352f, 0.801059f, 0.518366f, 0.294316f, 0.638085f, 0.585109f, 0.901563f, 0.052407f, 0.910131f, 0.534432f, 0.015676f, 0.344702f, 0.724334f, 0.488433f, 0.980159f, 0.422610f, 0.326635f, 0.821672f, 0.547907f, 0.682327f, 0.805702f, 0.671428f, 0.422408f, 0.124796f, 0.580248f, 0.897433f, 0.418892f, 0.910725f, 0.503528f, 0.620842f, 0.832989f, 0.564597f, 0.090969f, 0.980979f, 0.245849f, 0.710505f, 0.505113f, 0.478773f, 0.243941f, 0.722151f, 0.112788f, 0.990453f, 0.845374f, 0.534509f, 0.424553f, 0.286465f, 0.501591f, 0.879417f, 0.275006f, 0.500537f, 0.234550f, 0.337149f, 0.190261f, 0.990539f, 0.571497f, 0.732815f, 0.098250f, 0.366118f, 0.892640f, 0.084438f, 0.165483f, 0.625418f, 0.622789f, 0.838227f, 0.935493f, 0.141986f, 0.259374f, 0.427461f, 0.000903f, 0.069814f, 0.226491f, 0.481102f, 0.251523f, 0.876682f, 0.324273f, 0.924623f, 0.974787f, 0.449862f, 0.227129f, 0.291666f, 0.776334f, 0.273350f, 0.380583f, 0.478576f, 0.575111f, 0.996100f, 0.232210f, 0.353424f, 0.262891f, 0.361113f, 0.100805f, 0.359810f, 0.887865f, 0.298590f, 0.371935f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + 
ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, std::vector(), + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DWithPastAndPresentQkMatmul) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2] + + // {2, 3, 4, 8} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 
0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + // {2, 3, 6, 8} + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 
0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 
0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + // {2, 3, 6, 8} + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 
0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 
0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + // {4, 18} + std::vector m = {0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f}; + // {2, 3, 12, 8} + std::vector past_key = {0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 
0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 
0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 
0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f}; + // {2, 3, 12, 8} + std::vector past_value = {0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 
0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 
0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 
0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * 
kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + + // {2, 3, 4, 8} + std::vector y = {0.457694f, 0.455757f, 0.445489f, 0.526766f, 0.477853f, 0.608758f, 0.406654f, 0.519316f, 0.444463f, 0.465842f, 0.428262f, 0.540428f, 0.477282f, 0.638667f, 0.474591f, 0.547811f, 0.457420f, 0.470657f, 0.487116f, 0.542242f, 0.482364f, 0.617841f, 0.476829f, 0.557317f, 0.463370f, 0.432599f, 0.412642f, 0.520960f, 0.479831f, 0.589828f, 0.446331f, 0.612812f, 0.585487f, 0.538315f, 0.504264f, 0.615235f, 0.527800f, 0.515899f, 0.536401f, 0.541573f, 0.578147f, 0.544553f, 0.531175f, 0.583502f, 0.528233f, 0.518028f, 0.562917f, 0.588512f, 0.599006f, 0.525119f, 0.535656f, 0.623945f, 0.521523f, 0.515306f, 0.544257f, 0.592741f, 0.600172f, 0.529797f, 0.490615f, 0.601856f, 0.495671f, 0.500725f, 0.555493f, 0.482300f, 0.538304f, 0.469695f, 0.555198f, 0.489711f, 0.521836f, 0.485628f, 0.493937f, 0.562992f, 0.521894f, 0.489056f, 0.584299f, 0.474376f, 0.493005f, 0.475963f, 0.460919f, 0.567615f, 0.547787f, 0.466202f, 0.536014f, 0.473239f, 0.485554f, 0.498408f, 0.501733f, 0.586437f, 0.517314f, 0.440046f, 0.514271f, 0.545266f, 0.487437f, 0.481043f, 0.518498f, 0.568266f, 0.514357f, 0.572526f, 0.423650f, 0.474643f, 0.492550f, 0.533325f, 0.512998f, 0.452411f, 0.526065f, 0.535346f, 0.407074f, 0.502433f, 0.501283f, 0.528505f, 0.510491f, 0.402870f, 0.516862f, 0.596280f, 0.397160f, 0.469242f, 0.458194f, 0.537358f, 0.510243f, 0.439715f, 0.530736f, 0.580630f, 0.437646f, 0.462414f, 0.484492f, 0.477003f, 0.476393f, 0.431391f, 0.481805f, 0.420751f, 0.544359f, 0.440140f, 0.533953f, 0.453877f, 0.460864f, 0.446440f, 0.454282f, 0.416850f, 0.494072f, 0.462208f, 0.524801f, 0.453293f, 
0.493179f, 0.462526f, 0.489181f, 0.452340f, 0.570383f, 0.422193f, 0.524420f, 0.468229f, 0.489729f, 0.444768f, 0.534646f, 0.457197f, 0.522207f, 0.400594f, 0.538509f, 0.489581f, 0.457599f, 0.488340f, 0.549355f, 0.482543f, 0.431908f, 0.352921f, 0.633369f, 0.690998f, 0.314418f, 0.542520f, 0.580878f, 0.489810f, 0.451832f, 0.346453f, 0.599024f, 0.630982f, 0.310195f, 0.532405f, 0.568864f, 0.486514f, 0.432211f, 0.345150f, 0.586195f, 0.659745f, 0.269926f, 0.528033f, 0.509392f, 0.511314f, 0.378251f, 0.319656f, 0.601292f, 0.726670f, 0.338636f, 0.564731f}; + // {2, 3, 18, 8} + std::vector present_key = {0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 
0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 
0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 
0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 
0.774748f, 0.994233f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 
0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + // {2, 3, 18, 8} + std::vector present_value = {0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 
0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 
0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 
0.836547f, 0.278050f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 
0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + // {2, 3, 4, 18} + std::vector qk_matmul = {0.820140f, 1.059902f, 0.757718f, 0.881749f, 0.858141f, 1.036822f, 0.884175f, 0.745137f, 0.702161f, 0.857424f, 0.931616f, 0.810373f, 0.765101f, 0.618886f, 0.985434f, 1.031954f, 0.880308f, 0.622930f, 0.763532f, 0.857065f, 
0.740183f, 0.789191f, 0.647322f, 0.909152f, 0.686916f, 0.854634f, 0.616661f, 0.909399f, 0.999737f, 0.690372f, 0.633938f, 0.397958f, 0.865367f, 0.924445f, 0.867537f, 0.569419f, 0.980506f, 1.169838f, 1.017614f, 1.046616f, 0.926423f, 1.190621f, 1.081360f, 0.859412f, 0.668530f, 0.881618f, 1.122157f, 0.778354f, 0.913560f, 0.629977f, 1.123444f, 1.261700f, 1.171818f, 0.666636f, 0.732417f, 0.806783f, 0.671492f, 0.704470f, 0.679564f, 0.856373f, 0.747101f, 0.574466f, 0.511335f, 0.570812f, 0.772065f, 0.486530f, 0.626328f, 0.451866f, 0.718409f, 0.895540f, 0.694231f, 0.503419f, 0.531406f, 0.847033f, 0.878291f, 0.737390f, 0.926101f, 1.027148f, 0.731989f, 0.720755f, 0.637853f, 0.523248f, 0.924757f, 0.757182f, 0.669580f, 0.979738f, 0.580251f, 1.052969f, 1.255782f, 0.775240f, 0.284305f, 0.708099f, 0.458294f, 0.381689f, 0.754442f, 0.688000f, 0.675486f, 0.683084f, 0.468356f, 0.518191f, 0.554623f, 0.658507f, 0.571695f, 0.630510f, 0.528123f, 0.531325f, 0.767081f, 0.532916f, 0.348042f, 0.636357f, 0.445687f, 0.399611f, 0.727809f, 0.686446f, 0.593512f, 0.523768f, 0.360500f, 0.423699f, 0.527520f, 0.714839f, 0.553231f, 0.662379f, 0.517964f, 0.485448f, 0.809493f, 0.494930f, 0.274371f, 0.437410f, 0.411925f, 0.342756f, 0.545288f, 0.529269f, 0.533905f, 0.380022f, 0.436475f, 0.301469f, 0.529214f, 0.526297f, 0.502613f, 0.503063f, 0.430358f, 0.614318f, 0.557536f, 0.523195f, 0.627666f, 0.646350f, 0.711912f, 0.578261f, 0.510271f, 0.666607f, 0.609787f, 0.652893f, 0.673018f, 0.618551f, 0.787326f, 1.094408f, 0.693321f, 0.857913f, 0.604598f, 0.781784f, 0.506659f, 0.587050f, 0.797275f, 0.415388f, 0.596291f, 0.560429f, 0.353030f, 0.474825f, 0.499545f, 0.677266f, 0.512789f, 0.749157f, 0.460399f, 0.860298f, 0.559970f, 0.647591f, 0.385551f, 0.412029f, 0.286456f, 0.386895f, 0.466306f, 0.448868f, 0.485777f, 0.485511f, 0.524956f, 0.380963f, 0.659871f, 0.495008f, 0.515935f, 0.440779f, 0.441189f, 0.658574f, 0.476000f, 0.713140f, 0.389744f, 0.417265f, 0.369560f, 0.531347f, 0.798962f, 0.607254f, 0.635098f, 
0.675595f, 0.504633f, 0.579773f, 0.825966f, 0.745334f, 0.850824f, 0.713222f, 0.417185f, 0.949167f, 0.538440f, 0.917125f, 0.311825f, 0.475121f, 0.418353f, 0.698230f, 0.553783f, 0.653118f, 0.479333f, 0.683333f, 0.611400f, 0.926136f, 0.937356f, 1.079461f, 0.500571f, 0.941776f, 0.571910f, 0.891547f, 0.471507f, 0.728790f, 0.757396f, 0.784496f, 0.757036f, 0.999690f, 0.542418f, 0.841219f, 0.709393f, 0.945488f, 0.605568f, 1.000231f, 0.913339f, 1.138695f, 0.564313f, 1.077245f, 0.676031f, 0.922692f, 0.458828f, 0.738062f, 0.805418f, 0.864807f, 0.792745f, 1.025324f, 0.755005f, 0.867548f, 0.634732f, 0.905661f, 0.776584f, 1.184950f, 1.140206f, 1.327115f, 0.665969f, 1.196436f, 0.815515f, 1.206247f, 0.621079f, 0.985172f, 0.879408f, 1.054329f, 1.023972f, 1.311348f, 0.430584f, 0.838594f, 0.577089f, 0.887826f, 0.637326f, 0.838023f, 0.852760f, 0.930619f, 0.596678f, 1.004560f, 0.556861f, 0.837758f, 0.499217f, 0.764351f, 0.711010f, 0.774022f, 0.933743f, 0.958043f, 0.587815f, 0.233866f, 0.638163f, 0.785593f, 0.772991f, 0.770025f, 0.862170f, 0.414778f, 0.518855f, 0.729107f, 0.683017f, 0.903488f, 0.660502f, 0.396731f, 0.558027f, 0.342514f, 0.418391f, 0.680441f, 0.667967f, 0.467863f, 0.921835f, 0.926976f, 0.997494f, 1.115404f, 1.154781f, 0.618698f, 0.888651f, 1.045274f, 1.019208f, 1.253905f, 0.983391f, 0.622483f, 0.921609f, 0.369652f, 0.702290f, 1.012872f, 0.884131f, 0.593858f, 0.802401f, 1.081408f, 1.169599f, 1.146572f, 1.132834f, 0.866719f, 1.021105f, 0.884109f, 1.029369f, 1.321895f, 0.973822f, 0.871383f, 1.125121f, 0.518882f, 0.912889f, 0.876105f, 0.555648f, 0.496401f, 0.582726f, 0.730206f, 0.806009f, 0.858020f, 0.827912f, 0.515117f, 0.715055f, 0.533599f, 0.810529f, 0.887599f, 0.607516f, 0.668702f, 0.905358f, 0.279895f, 0.740854f, 0.538839f, 0.824322f, 0.920016f, 0.791579f, 0.844334f, 0.618349f, 0.989377f, 1.120477f, 0.554956f, 0.683589f, 1.280705f, 0.957804f, 0.833027f, 0.763301f, 0.786487f, 0.915324f, 0.941565f, 0.777569f, 1.361176f, 0.508790f, 0.424516f, 0.573465f, 0.405641f, 
0.526471f, 0.626492f, 0.534790f, 0.428795f, 0.388423f, 0.689702f, 0.260757f, 0.438301f, 0.479575f, 0.640056f, 0.682344f, 0.519170f, 0.436916f, 0.774498f, 0.534469f, 0.702171f, 0.684503f, 0.648164f, 0.754539f, 0.828688f, 0.623366f, 0.500542f, 0.560133f, 1.098588f, 0.498203f, 0.465793f, 0.656601f, 0.886137f, 0.751770f, 0.533794f, 0.483658f, 1.098963f, 0.733365f, 0.808374f, 0.764603f, 0.755506f, 0.638693f, 0.946285f, 1.001370f, 0.578989f, 0.603487f, 1.074992f, 0.697424f, 0.812599f, 0.717330f, 0.770067f, 1.006811f, 0.783151f, 0.647946f, 1.193171f}; + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + ASSERT_EQ(qk_matmul.size(), batch_size * kv_num_heads * q_sequence_length * (past_sequence_length + kv_sequence_length)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, 0, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); + + 
qk_matmul = std::vector{1.786287f, 1.851782f, 1.433406f, 1.126638f, 1.074598f, 1.202869f, 1.806932f, 1.039214f, 1.155254f, 1.351381f, 1.709788f, 1.654608f, 0.904174f, 1.045790f, 1.828289f, 1.849986f, 0.982722f, 0.779313f, 1.067731f, 0.932425f, 1.164846f, 0.896809f, 1.215540f, 1.155709f, 1.283348f, 0.972161f, 1.592545f, 1.841960f, 1.391534f, 0.932551f, 0.884336f, 0.881353f, 0.905360f, 1.564150f, 1.275840f, 0.946826f, 1.789871f, 1.878873f, 1.971947f, 1.398552f, 1.823965f, 1.960587f, 1.438784f, 1.481077f, 0.957099f, 1.756017f, 1.234584f, 0.990787f, 1.096593f, 1.033003f, 1.868677f, 1.788607f, 1.659495f, 0.667182f, 1.157819f, 0.870338f, 0.879745f, 1.636864f, 0.894962f, 1.714711f, 1.549994f, 0.733612f, 1.117046f, 0.686474f, 1.499953f, 1.123992f, 1.438267f, 0.931251f, 1.633272f, 0.944889f, 0.987120f, 1.218472f, 1.497553f, 1.638913f, 1.553980f, 0.982279f, 1.142558f, 1.193196f, 1.654746f, 1.014832f, 1.090946f, 1.017206f, 1.702928f, 1.601417f, 0.808653f, 1.406642f, 1.423106f, 1.871002f, 1.358196f, 0.931623f, 0.588504f, 0.783458f, 0.882957f, 0.489307f, 1.322660f, 0.934557f, 1.271919f, 0.800610f, 1.444240f, 1.450752f, 0.946420f, 0.900686f, 0.822093f, 1.113904f, 0.568116f, 1.171030f, 1.175384f, 0.910323f, 1.157407f, 1.345392f, 1.400021f, 0.751548f, 1.625352f, 1.456414f, 0.950937f, 1.145433f, 0.649070f, 1.298100f, 0.639947f, 0.927273f, 0.736265f, 1.065406f, 1.263197f, 1.012355f, 1.297169f, 0.495477f, 0.699773f, 0.500964f, 0.620178f, 1.275150f, 0.760687f, 1.387608f, 1.336798f, 0.539168f, 1.042187f, 0.417132f, 1.257103f, 1.163759f, 1.314552f, 0.982448f, 1.345221f, 0.663667f, 0.850426f, 1.238248f, 1.593812f, 1.438230f, 1.387601f, 0.823150f, 0.726727f, 0.832655f, 1.532544f, 0.946970f, 1.126112f, 1.112509f, 1.565497f, 1.938642f, 0.832394f, 1.284816f, 1.447452f, 1.599816f, 0.609072f, 0.743433f, 1.101475f, 0.490747f, 1.020954f, 0.668047f, 0.921248f, 0.721382f, 1.095978f, 0.794792f, 1.488673f, 1.681718f, 0.852196f, 1.102478f, 0.810369f, 1.130985f, 0.425544f, 1.051735f, 0.694759f, 
0.764302f, 1.275671f, 1.157903f, 1.440112f, 0.837447f, 1.422500f, 1.150930f, 1.017296f, 1.116673f, 0.804505f, 1.315179f, 0.553615f, 0.871008f, 0.659033f, 1.116166f, 1.134977f, 0.944172f, 0.857236f, 0.531893f, 1.224364f, 0.670808f, 0.843351f, 1.607988f, 0.720031f, 1.438111f, 1.628858f, 0.904480f, 1.456536f, 0.828884f, 1.145072f, 1.586629f, 1.350379f, 1.396510f, 1.226688f, 0.524469f, 0.711242f, 1.413283f, 1.519931f, 1.444998f, 1.155023f, 0.928222f, 0.827857f, 1.092185f, 1.860113f, 1.373539f, 0.953664f, 1.435734f, 1.350082f, 1.735783f, 0.610580f, 1.155694f, 1.600251f, 1.602529f, 0.859450f, 1.156073f, 0.846617f, 0.916578f, 1.134056f, 1.053106f, 1.173786f, 1.246788f, 1.509772f, 1.256221f, 1.540197f, 2.009806f, 1.067828f, 1.164871f, 0.709226f, 1.221456f, 0.845411f, 1.504512f, 1.201048f, 1.402731f, 1.564370f, 1.576583f, 1.589067f, 1.257597f, 1.674126f, 1.954917f, 1.497631f, 1.948780f, 0.954539f, 2.070836f, 0.927942f, 1.418681f, 0.804113f, 1.388198f, 1.624642f, 1.581236f, 1.511648f, 1.311894f, 0.855986f, 0.902148f, 0.785342f, 1.820220f, 0.852723f, 1.696361f, 1.655653f, 1.089764f, 1.202390f, 1.120222f, 1.284748f, 1.475221f, 1.311156f, 1.243736f, 1.625873f, 0.823371f, 1.226631f, 1.673096f, 1.553962f, 1.025746f, 1.313852f, 1.030482f, 0.989448f, 0.936074f, 1.784927f, 0.708855f, 0.971949f, 1.223065f, 1.461189f, 1.747723f, 0.799575f, 0.823636f, 1.400882f, 1.160547f, 0.520804f, 0.836825f, 0.972166f, 0.543222f, 1.346498f, 1.034594f, 1.565712f, 1.361961f, 1.751214f, 0.736224f, 1.864534f, 1.977835f, 1.411005f, 1.496084f, 1.233789f, 1.105877f, 0.961602f, 1.009357f, 1.110593f, 1.390279f, 1.693497f, 1.302893f, 1.756735f, 1.433344f, 2.067142f, 1.916540f, 1.490259f, 1.488384f, 1.309675f, 1.758509f, 1.141796f, 1.534330f, 1.156855f, 1.274409f, 1.870354f, 1.045789f, 1.400564f, 0.876651f, 0.981051f, 0.559955f, 0.790979f, 1.662600f, 1.021407f, 1.716358f, 1.630805f, 0.674263f, 1.320767f, 0.649261f, 1.538417f, 1.525061f, 1.419455f, 1.148088f, 1.820221f, 0.329244f, 1.033743f, 1.253892f, 
1.790469f, 1.711897f, 1.467268f, 1.089224f, 0.834806f, 1.155425f, 2.043234f, 0.849033f, 1.136683f, 1.774663f, 1.735976f, 1.677263f, 0.902375f, 1.213391f, 1.758179f, 1.759598f, 0.879983f, 1.517559f, 0.812989f, 0.499876f, 0.998129f, 0.513259f, 1.094689f, 0.873050f, 1.131224f, 0.546321f, 1.364307f, 1.622263f, 0.652555f, 0.680481f, 0.729973f, 1.123450f, 0.722337f, 1.158875f, 0.845219f, 1.151906f, 1.343835f, 1.411206f, 1.638837f, 1.000100f, 1.652081f, 1.598655f, 0.980791f, 1.122207f, 0.848703f, 1.972988f, 0.610630f, 0.678227f, 0.839634f, 1.289163f, 1.497003f, 1.060701f, 0.971334f, 1.099509f, 1.158767f, 0.871929f, 0.972856f, 1.687900f, 0.854091f, 1.804623f, 1.804263f, 0.738135f, 1.209199f, 1.190654f, 1.425313f, 1.450061f, 1.529269f, 1.249452f, 1.921674f, 0.832500f, 0.940835f, 1.908224f}; + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, 1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, 2, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); + + qk_matmul = std::vector{0.079204f, 0.084565f, 0.055653f, 0.040951f, 0.038874f, 0.044195f, 0.080856f, 0.037523f, 0.042140f, 0.051271f, 0.073371f, 0.069432f, 0.032783f, 0.037770f, 0.082601f, 0.084413f, 0.035462f, 
0.028935f, 0.048528f, 0.042386f, 0.053477f, 0.040903f, 0.056258f, 0.052990f, 0.060205f, 0.044104f, 0.082018f, 0.105252f, 0.067083f, 0.042392f, 0.040396f, 0.040276f, 0.041254f, 0.079722f, 0.059754f, 0.043001f, 0.069900f, 0.076406f, 0.083859f, 0.047264f, 0.072324f, 0.082912f, 0.049204f, 0.051330f, 0.030395f, 0.067573f, 0.040116f, 0.031437f, 0.034945f, 0.032792f, 0.075631f, 0.069811f, 0.061356f, 0.022746f, 0.052157f, 0.039125f, 0.039495f, 0.084209f, 0.040101f, 0.091026f, 0.077202f, 0.034126f, 0.050073f, 0.032554f, 0.073434f, 0.050422f, 0.069041f, 0.041583f, 0.083907f, 0.042154f, 0.043972f, 0.055418f, 0.062936f, 0.072492f, 0.066589f, 0.037594f, 0.044129f, 0.046421f, 0.073649f, 0.038838f, 0.041909f, 0.038930f, 0.077284f, 0.069824f, 0.031602f, 0.057467f, 0.058421f, 0.091429f, 0.054749f, 0.035737f, 0.036234f, 0.044034f, 0.048640f, 0.032812f, 0.075502f, 0.051216f, 0.071766f, 0.044795f, 0.085263f, 0.085820f, 0.051827f, 0.049510f, 0.045768f, 0.061277f, 0.035503f, 0.064879f, 0.065162f, 0.049990f, 0.057976f, 0.069967f, 0.073895f, 0.038636f, 0.092571f, 0.078182f, 0.047161f, 0.057286f, 0.034872f, 0.066735f, 0.034556f, 0.046058f, 0.038050f, 0.052880f, 0.064446f, 0.050148f, 0.066673f, 0.029907f, 0.040424f, 0.033136f, 0.037332f, 0.071867f, 0.042963f, 0.080421f, 0.076436f, 0.034427f, 0.056931f, 0.030472f, 0.070581f, 0.064291f, 0.074755f, 0.053630f, 0.077083f, 0.038991f, 0.046997f, 0.069263f, 0.077018f, 0.065921f, 0.062667f, 0.035637f, 0.032361f, 0.035977f, 0.072441f, 0.040334f, 0.048247f, 0.047595f, 0.074868f, 0.108730f, 0.035968f, 0.056545f, 0.066532f, 0.077482f, 0.028769f, 0.032906f, 0.062422f, 0.033892f, 0.057593f, 0.040467f, 0.052127f, 0.042684f, 0.062080f, 0.045935f, 0.091938f, 0.111515f, 0.048649f, 0.062485f, 0.046656f, 0.064291f, 0.031753f, 0.059393f, 0.041563f, 0.044556f, 0.069887f, 0.062123f, 0.082378f, 0.045090f, 0.080940f, 0.061691f, 0.053974f, 0.059613f, 0.043629f, 0.072703f, 0.033948f, 0.046629f, 0.037722f, 0.059583f, 0.060715f, 0.050168f, 0.045991f, 0.033218f, 
0.056448f, 0.032452f, 0.038564f, 0.082843f, 0.034089f, 0.069900f, 0.084590f, 0.040994f, 0.071200f, 0.038010f, 0.052145f, 0.081092f, 0.064029f, 0.067052f, 0.056579f, 0.028034f, 0.033791f, 0.068186f, 0.068271f, 0.063343f, 0.047398f, 0.037780f, 0.034172f, 0.044511f, 0.095935f, 0.058974f, 0.038754f, 0.062758f, 0.057607f, 0.084719f, 0.027499f, 0.047430f, 0.073981f, 0.074150f, 0.035269f, 0.047448f, 0.036752f, 0.039415f, 0.048991f, 0.045181f, 0.050976f, 0.054837f, 0.071332f, 0.055356f, 0.073536f, 0.117610f, 0.045851f, 0.050524f, 0.032034f, 0.053465f, 0.036708f, 0.070958f, 0.052385f, 0.064091f, 0.057214f, 0.057917f, 0.058645f, 0.042099f, 0.063851f, 0.084550f, 0.053520f, 0.084033f, 0.031093f, 0.094942f, 0.030276f, 0.049457f, 0.026750f, 0.047972f, 0.060768f, 0.058187f, 0.054276f, 0.044448f, 0.035207f, 0.036870f, 0.032806f, 0.092340f, 0.035092f, 0.081583f, 0.078329f, 0.044479f, 0.049782f, 0.045855f, 0.054055f, 0.065397f, 0.055502f, 0.051883f, 0.076030f, 0.034077f, 0.051003f, 0.079707f, 0.080020f, 0.047184f, 0.062939f, 0.047408f, 0.045502f, 0.043137f, 0.100811f, 0.034370f, 0.044713f, 0.057477f, 0.072930f, 0.097129f, 0.037633f, 0.038550f, 0.068662f, 0.053994f, 0.028478f, 0.039062f, 0.038495f, 0.025068f, 0.055973f, 0.040975f, 0.069692f, 0.056845f, 0.083897f, 0.030405f, 0.093963f, 0.105236f, 0.059703f, 0.065004f, 0.050007f, 0.044003f, 0.038091f, 0.039954f, 0.044211f, 0.058478f, 0.065917f, 0.044603f, 0.070220f, 0.050818f, 0.095779f, 0.082388f, 0.053794f, 0.053693f, 0.044906f, 0.070345f, 0.037966f, 0.056218f, 0.038542f, 0.043350f, 0.078669f, 0.034491f, 0.049179f, 0.029124f, 0.042079f, 0.027618f, 0.034795f, 0.083187f, 0.043812f, 0.087782f, 0.080584f, 0.030962f, 0.059102f, 0.030197f, 0.073473f, 0.072498f, 0.065232f, 0.049729f, 0.097389f, 0.021927f, 0.044356f, 0.055279f, 0.076017f, 0.070273f, 0.055023f, 0.037702f, 0.029233f, 0.040282f, 0.097878f, 0.029652f, 0.039534f, 0.074825f, 0.071985f, 0.067881f, 0.031276f, 0.042686f, 0.073602f, 0.073706f, 0.030584f, 0.057861f, 0.047710f, 
0.034884f, 0.057413f, 0.035354f, 0.063233f, 0.050663f, 0.065586f, 0.036542f, 0.082802f, 0.107169f, 0.040638f, 0.041789f, 0.043909f, 0.065079f, 0.043575f, 0.067425f, 0.049272f, 0.066957f, 0.059910f, 0.064085f, 0.080467f, 0.042483f, 0.081539f, 0.077297f, 0.041671f, 0.048000f, 0.036514f, 0.112392f, 0.028779f, 0.030791f, 0.036185f, 0.056722f, 0.069826f, 0.045137f, 0.041278f, 0.046923f, 0.044357f, 0.033296f, 0.036832f, 0.075295f, 0.032707f, 0.084617f, 0.084586f, 0.029126f, 0.046652f, 0.045794f, 0.057906f, 0.059357f, 0.064250f, 0.048568f, 0.095124f, 0.032009f, 0.035671f, 0.093853f}; + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, 3, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); + + y = std::vector{0.466021f, 0.458662f, 0.433769f, 0.544055f, 0.483743f, 0.601701f, 0.452252f, 0.558874f, 0.462717f, 0.462769f, 0.429452f, 0.544879f, 0.480609f, 0.607708f, 0.462766f, 0.570020f, 0.465546f, 0.464215f, 0.442318f, 0.544785f, 0.481242f, 0.599103f, 0.465833f, 0.567976f, 0.466527f, 0.450295f, 0.420681f, 0.541622f, 0.478068f, 0.592818f, 0.453533f, 0.586057f, 0.586788f, 0.542723f, 0.521934f, 0.605385f, 0.523076f, 0.515204f, 0.538008f, 0.539990f, 0.580554f, 0.544345f, 0.524057f, 0.593493f, 0.520281f, 0.513084f, 0.549197f, 0.556567f, 0.590750f, 0.536522f, 0.528383f, 0.608365f, 0.523467f, 0.511267f, 0.533588f, 0.556113f, 0.589547f, 0.537869f, 0.512585f, 0.601047f, 0.507374f, 0.511124f, 0.547465f, 0.512627f, 0.537318f, 0.460441f, 0.540844f, 0.491120f, 0.495359f, 0.476360f, 0.487767f, 0.575867f, 0.522542f, 0.469555f, 0.552479f, 0.488850f, 0.498227f, 0.480921f, 0.484224f, 0.563258f, 0.536463f, 0.455656f, 
0.529199f, 0.484251f, 0.487531f, 0.482517f, 0.496116f, 0.576080f, 0.527226f, 0.455449f, 0.525402f, 0.516090f, 0.487896f, 0.477256f, 0.499739f, 0.574474f, 0.520127f, 0.578615f, 0.430572f, 0.471035f, 0.475543f, 0.515079f, 0.488231f, 0.438589f, 0.525065f, 0.569547f, 0.430350f, 0.477609f, 0.478081f, 0.515330f, 0.479993f, 0.427992f, 0.520505f, 0.584227f, 0.430333f, 0.470616f, 0.468772f, 0.517313f, 0.478180f, 0.435562f, 0.527655f, 0.580609f, 0.440415f, 0.475648f, 0.474939f, 0.501466f, 0.474016f, 0.433277f, 0.489508f, 0.425301f, 0.542249f, 0.446878f, 0.532601f, 0.462732f, 0.460696f, 0.462333f, 0.480973f, 0.421038f, 0.522864f, 0.446350f, 0.525882f, 0.466933f, 0.459678f, 0.470179f, 0.485580f, 0.431242f, 0.545418f, 0.440407f, 0.527849f, 0.471587f, 0.464982f, 0.464551f, 0.502461f, 0.437563f, 0.528884f, 0.426691f, 0.531206f, 0.480744f, 0.460218f, 0.480733f, 0.543597f, 0.506559f, 0.419551f, 0.372524f, 0.622818f, 0.678228f, 0.309035f, 0.543150f, 0.561392f, 0.501923f, 0.420097f, 0.368626f, 0.607674f, 0.661294f, 0.315077f, 0.540017f, 0.552392f, 0.506226f, 0.409681f, 0.376208f, 0.608944f, 0.674258f, 0.301188f, 0.537046f, 0.536986f, 0.515894f, 0.402735f, 0.364314f, 0.612694f, 0.684161f, 0.315733f, 0.553979f}; + qk_matmul = std::vector{0.945367f, 0.951913f, 0.892363f, 0.809865f, 0.791187f, 0.834528f, 0.947519f, 0.777578f, 0.819487f, 0.874379f, 0.936622f, 0.929487f, 0.718324f, 0.780164f, 0.949658f, 0.951745f, 0.754242f, 0.652312f, 0.788605f, 0.731722f, 0.822613f, 0.714741f, 0.838334f, 0.819636f, 0.857374f, 0.749652f, 0.920539f, 0.950983f, 0.883508f, 0.731781f, 0.708585f, 0.707096f, 0.718898f, 0.916090f, 0.855373f, 0.738343f, 0.945747f, 0.954392f, 0.961991f, 0.885038f, 0.949232f, 0.961135f, 0.893453f, 0.901670f, 0.742980f, 0.942057f, 0.843904f, 0.757698f, 0.799272f, 0.775110f, 0.953474f, 0.945613f, 0.930149f, 0.583123f, 0.820328f, 0.701546f, 0.706292f, 0.927033f, 0.713836f, 0.937223f, 0.913785f, 0.625270f, 0.806539f, 0.595712f, 0.905140f, 0.808953f, 0.893348f, 0.731177f, 0.926526f, 
0.737460f, 0.756132f, 0.839203f, 0.904705f, 0.927320f, 0.914440f, 0.754051f, 0.815274f, 0.831567f, 0.929506f, 0.767753f, 0.797223f, 0.768726f, 0.935774f, 0.921882f, 0.668846f, 0.886779f, 0.890245f, 0.953685f, 0.875974f, 0.731350f, 0.528819f, 0.654687f, 0.707898f, 0.453666f, 0.867444f, 0.732712f, 0.854317f, 0.664378f, 0.894548f, 0.895841f, 0.738158f, 0.716632f, 0.676207f, 0.805438f, 0.513974f, 0.824602f, 0.825990f, 0.721287f, 0.820193f, 0.872961f, 0.885356f, 0.636072f, 0.925397f, 0.896954f, 0.740207f, 0.816236f, 0.571043f, 0.861233f, 0.564864f, 0.729320f, 0.626883f, 0.787724f, 0.851943f, 0.766734f, 0.860993f, 0.458553f, 0.604224f, 0.462875f, 0.551252f, 0.855187f, 0.641481f, 0.882643f, 0.870901f, 0.492358f, 0.778750f, 0.394511f, 0.850263f, 0.822261f, 0.865423f, 0.754124f, 0.872921f, 0.580799f, 0.691292f, 0.844955f, 0.920732f, 0.893341f, 0.882642f, 0.676781f, 0.621059f, 0.681899f, 0.910859f, 0.738408f, 0.809684f, 0.804947f, 0.916307f, 0.959426f, 0.681760f, 0.857763f, 0.895188f, 0.921641f, 0.543474f, 0.631215f, 0.801028f, 0.454809f, 0.770255f, 0.583694f, 0.726487f, 0.617765f, 0.799050f, 0.661115f, 0.903080f, 0.933084f, 0.692215f, 0.801387f, 0.669793f, 0.811356f, 0.401591f, 0.782480f, 0.601031f, 0.643604f, 0.855327f, 0.820355f, 0.893720f, 0.684454f, 0.890119f, 0.818062f, 0.768763f, 0.806408f, 0.666548f, 0.865580f, 0.503225f, 0.701886f, 0.577719f, 0.806231f, 0.812716f, 0.737133f, 0.694831f, 0.486827f, 0.840937f, 0.585511f, 0.687580f, 0.922862f, 0.616929f, 0.893317f, 0.925899f, 0.718472f, 0.896978f, 0.679876f, 0.816115f, 0.919631f, 0.874143f, 0.884595f, 0.841616f, 0.481142f, 0.611455f, 0.888189f, 0.908686f, 0.894699f, 0.819411f, 0.729764f, 0.679323f, 0.797674f, 0.952689f, 0.879496f, 0.741438f, 0.892836f, 0.874073f, 0.939736f, 0.544535f, 0.819632f, 0.921706f, 0.922048f, 0.695974f, 0.819756f, 0.689298f, 0.724275f, 0.812403f, 0.783011f, 0.825482f, 0.847380f, 0.906899f, 0.850019f, 0.912154f, 0.964714f, 0.788641f, 0.822621f, 0.610191f, 0.840083f, 0.688664f, 0.905960f, 
0.833974f, 0.885940f, 0.916126f, 0.918067f, 0.920006f, 0.850400f, 0.932095f, 0.960700f, 0.904719f, 0.960224f, 0.741831f, 0.968705f, 0.729633f, 0.889323f, 0.666330f, 0.882774f, 0.925295f, 0.918795f, 0.907231f, 0.864754f, 0.694184f, 0.717342f, 0.655762f, 0.948860f, 0.692490f, 0.934953f, 0.929629f, 0.796792f, 0.834382f, 0.807646f, 0.857745f, 0.900569f, 0.864568f, 0.846518f, 0.925472f, 0.676900f, 0.841599f, 0.931960f, 0.914437f, 0.772197f, 0.865247f, 0.774102f, 0.757127f, 0.733413f, 0.945223f, 0.609958f, 0.749560f, 0.840556f, 0.897883f, 0.941116f, 0.663799f, 0.677044f, 0.885542f, 0.821218f, 0.478321f, 0.684124f, 0.749655f, 0.495423f, 0.873224f, 0.775744f, 0.916341f, 0.876847f, 0.941513f, 0.626858f, 0.953096f, 0.962428f, 0.887707f, 0.904438f, 0.843675f, 0.802600f, 0.744991f, 0.765496f, 0.804272f, 0.883232f, 0.934591f, 0.862466f, 0.942137f, 0.892350f, 0.968477f, 0.957631f, 0.903372f, 0.903027f, 0.864193f, 0.942336f, 0.815018f, 0.911163f, 0.820012f, 0.854988f, 0.953626f, 0.780164f, 0.885474f, 0.704738f, 0.753520f, 0.507944f, 0.658964f, 0.930567f, 0.770439f, 0.937423f, 0.926176f, 0.587777f, 0.866974f, 0.571172f, 0.911854f, 0.909576f, 0.889485f, 0.817120f, 0.948860f, 0.317842f, 0.775405f, 0.849371f, 0.945810f, 0.936880f, 0.899055f, 0.796595f, 0.683048f, 0.819543f, 0.966958f, 0.690564f, 0.813294f, 0.944118f, 0.939758f, 0.932505f, 0.717452f, 0.837694f, 0.942299f, 0.942458f, 0.706411f, 0.908271f, 0.671236f, 0.462019f, 0.760807f, 0.472481f, 0.798583f, 0.702920f, 0.811438f, 0.497758f, 0.877388f, 0.924952f, 0.573387f, 0.591832f, 0.623049f, 0.808766f, 0.618355f, 0.820673f, 0.688564f, 0.818385f, 0.872590f, 0.887750f, 0.927310f, 0.761636f, 0.929143f, 0.921466f, 0.753408f, 0.808335f, 0.690391f, 0.962069f, 0.544571f, 0.590366f, 0.685615f, 0.858907f, 0.904605f, 0.785932f, 0.749290f, 0.800322f, 0.820638f, 0.702353f, 0.749957f, 0.933879f, 0.693201f, 0.947283f, 0.947246f, 0.628017f, 0.836439f, 0.830782f, 0.890702f, 0.895705f, 0.910299f, 0.848130f, 0.958055f, 0.681816f, 0.735606f, 
0.956936f}; + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, 2, std::numeric_limits::quiet_NaN(), 1.f, -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention3DWithPastAndPresentQkMatmul) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 8; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 8; // V.shape[3] + int past_sequence_length = 12; // past_key.shape[2] and past_value.shape[2] + + // {2, 4, 24} + std::vector q = {0.548814f, 0.715189f, 0.602763f, 0.544883f, 0.423655f, 0.645894f, 0.437587f, 0.891773f, 0.963663f, 0.383442f, 0.791725f, 0.528895f, 0.568045f, 0.925597f, 0.071036f, 0.087129f, 0.020218f, 0.832620f, 0.778157f, 0.870012f, 0.978618f, 0.799159f, 0.461479f, 0.780529f, 0.118274f, 0.639921f, 0.143353f, 0.944669f, 0.521848f, 0.414662f, 0.264556f, 0.774234f, 0.456150f, 0.568434f, 0.018790f, 0.617635f, 0.612096f, 0.616934f, 0.943748f, 0.681820f, 0.359508f, 0.437032f, 0.697631f, 0.060225f, 0.666767f, 0.670638f, 0.210383f, 0.128926f, 0.315428f, 0.363711f, 0.570197f, 0.438602f, 0.988374f, 0.102045f, 0.208877f, 0.161310f, 0.653108f, 0.253292f, 0.466311f, 0.244426f, 0.158970f, 0.110375f, 0.656330f, 0.138183f, 0.196582f, 0.368725f, 0.820993f, 0.097101f, 0.837945f, 0.096098f, 0.976459f, 0.468651f, 0.976761f, 0.604846f, 0.739264f, 0.039188f, 0.282807f, 0.120197f, 0.296140f, 0.118728f, 0.317983f, 0.414263f, 0.064147f, 0.692472f, 0.566601f, 0.265390f, 0.523248f, 0.093941f, 0.575947f, 0.929296f, 0.318569f, 0.667410f, 0.131798f, 0.716327f, 
0.289406f, 0.183191f, 0.586513f, 0.020108f, 0.828940f, 0.004695f, 0.677817f, 0.270008f, 0.735194f, 0.962189f, 0.248753f, 0.576157f, 0.592042f, 0.572252f, 0.223082f, 0.952749f, 0.447125f, 0.846409f, 0.699479f, 0.297437f, 0.813798f, 0.396506f, 0.881103f, 0.581273f, 0.881735f, 0.692532f, 0.725254f, 0.501324f, 0.956084f, 0.643990f, 0.423855f, 0.606393f, 0.019193f, 0.301575f, 0.660174f, 0.290078f, 0.618015f, 0.428769f, 0.135474f, 0.298282f, 0.569965f, 0.590873f, 0.574325f, 0.653201f, 0.652103f, 0.431418f, 0.896547f, 0.367562f, 0.435865f, 0.891923f, 0.806194f, 0.703889f, 0.100227f, 0.919483f, 0.714241f, 0.998847f, 0.149448f, 0.868126f, 0.162493f, 0.615560f, 0.123820f, 0.848008f, 0.807319f, 0.569101f, 0.407183f, 0.069167f, 0.697429f, 0.453543f, 0.722056f, 0.866382f, 0.975522f, 0.855803f, 0.011714f, 0.359978f, 0.729991f, 0.171630f, 0.521037f, 0.054338f, 0.199997f, 0.018522f, 0.793698f, 0.223925f, 0.345352f, 0.928081f, 0.704414f, 0.031839f, 0.164694f, 0.621478f, 0.577229f, 0.237893f, 0.934214f, 0.613966f, 0.535633f, 0.589910f, 0.730122f, 0.311945f, 0.398221f, 0.209844f}; + // {2, 6, 24} + std::vector k = {0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 
0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.614465f, 0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 
0.105906f, 0.950953f, 0.233420f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + // {2, 6, 24} + std::vector v = {0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.829897f, 0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 
0.241902f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + // {4, 18} + std::vector m = {0.966147f, 0.791880f, 0.675689f, 0.244889f, 0.216457f, 0.166048f, 0.922757f, 0.294077f, 0.453094f, 0.493958f, 0.778172f, 0.844235f, 0.139073f, 0.426904f, 0.842855f, 0.818033f, 0.102414f, 0.156383f, 0.304199f, 0.075359f, 0.424663f, 0.107618f, 0.568218f, 0.246557f, 0.596433f, 0.117526f, 0.975884f, 0.932561f, 0.391797f, 0.242179f, 0.250398f, 0.483394f, 0.039993f, 0.639705f, 0.408303f, 0.377407f, 0.809365f, 0.709035f, 0.954334f, 0.351936f, 0.897543f, 
0.769967f, 0.357425f, 0.621665f, 0.288570f, 0.874400f, 0.112427f, 0.212434f, 0.183033f, 0.403026f, 0.745233f, 0.526907f, 0.487676f, 0.000546f, 0.425402f, 0.063554f, 0.208253f, 0.932394f, 0.215398f, 0.858338f, 0.802893f, 0.159146f, 0.605712f, 0.115662f, 0.727888f, 0.637462f, 0.811939f, 0.479385f, 0.914863f, 0.049349f, 0.292889f, 0.715053f}; + // {2, 3, 12, 8} + std::vector past_key = {0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 0.559174f, 0.922298f, 0.492361f, 
0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 
0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 
0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f}; + // {2, 3, 12, 8} + std::vector past_value = {0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 
0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.310196f, 0.515433f, 
0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 
0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + + // {2, 4, 24} + std::vector y = {0.387434f, 0.451660f, 0.466422f, 0.473844f, 0.487732f, 0.616663f, 0.389945f, 0.474446f, 0.610035f, 0.540721f, 0.465339f, 0.659275f, 0.542400f, 0.558199f, 0.496998f, 0.580479f, 0.608613f, 0.454357f, 0.591427f, 0.539400f, 0.491600f, 0.439752f, 0.574766f, 0.534788f, 0.369295f, 
0.476453f, 0.472667f, 0.474934f, 0.484975f, 0.653894f, 0.434421f, 0.507237f, 0.606547f, 0.512561f, 0.492485f, 0.627438f, 0.547220f, 0.559142f, 0.549041f, 0.650326f, 0.576993f, 0.484612f, 0.597630f, 0.527508f, 0.458643f, 0.432526f, 0.522555f, 0.581898f, 0.375984f, 0.479550f, 0.484624f, 0.506722f, 0.499591f, 0.628391f, 0.457767f, 0.484544f, 0.612554f, 0.547468f, 0.485806f, 0.634928f, 0.524544f, 0.542711f, 0.529978f, 0.645564f, 0.613958f, 0.471193f, 0.571000f, 0.499555f, 0.454844f, 0.456024f, 0.567122f, 0.580956f, 0.367353f, 0.449829f, 0.439545f, 0.467891f, 0.516863f, 0.600392f, 0.405625f, 0.505181f, 0.632177f, 0.541634f, 0.449302f, 0.641351f, 0.504706f, 0.533341f, 0.527675f, 0.566799f, 0.572756f, 0.403738f, 0.539009f, 0.570743f, 0.478912f, 0.426711f, 0.567812f, 0.569001f, 0.495478f, 0.510849f, 0.388839f, 0.497814f, 0.545673f, 0.571958f, 0.453011f, 0.440750f, 0.458974f, 0.457386f, 0.506820f, 0.500591f, 0.499766f, 0.469500f, 0.465457f, 0.482146f, 0.581360f, 0.481272f, 0.463336f, 0.277110f, 0.627647f, 0.672684f, 0.342731f, 0.533800f, 0.530251f, 0.504140f, 0.385565f, 0.520337f, 0.548283f, 0.549735f, 0.473426f, 0.404586f, 0.463533f, 0.448576f, 0.497032f, 0.524322f, 0.474570f, 0.430653f, 0.498514f, 0.465629f, 0.578306f, 0.489042f, 0.491176f, 0.239511f, 0.588495f, 0.640517f, 0.319799f, 0.521414f, 0.510868f, 0.564625f, 0.348291f, 0.465071f, 0.498481f, 0.557391f, 0.469662f, 0.433203f, 0.471745f, 0.483765f, 0.520633f, 0.501991f, 0.485003f, 0.471836f, 0.500727f, 0.477256f, 0.574286f, 0.472931f, 0.487446f, 0.259796f, 0.603843f, 0.658305f, 0.303291f, 0.520652f, 0.560815f, 0.513931f, 0.418469f, 0.482361f, 0.535024f, 0.506256f, 0.440027f, 0.428132f, 0.519530f, 0.520400f, 0.482710f, 0.517258f, 0.479400f, 0.442196f, 0.466145f, 0.508808f, 0.534070f, 0.488154f, 0.483878f, 0.234783f, 0.628834f, 0.685886f, 0.369073f, 0.545753f}; + // {2, 3, 18, 8} + std::vector present_key = {0.418109f, 0.172951f, 0.107211f, 0.817339f, 0.473143f, 0.882284f, 0.733289f, 0.409726f, 0.373511f, 0.515638f, 
0.889060f, 0.737279f, 0.005153f, 0.694158f, 0.919507f, 0.710456f, 0.177006f, 0.483518f, 0.140316f, 0.358995f, 0.937117f, 0.923305f, 0.282837f, 0.339631f, 0.600213f, 0.963197f, 0.147801f, 0.256917f, 0.873557f, 0.491892f, 0.898961f, 0.185518f, 0.532669f, 0.326270f, 0.316543f, 0.446877f, 0.433077f, 0.357347f, 0.914971f, 0.731744f, 0.727547f, 0.289913f, 0.577709f, 0.779179f, 0.795590f, 0.344530f, 0.770873f, 0.735894f, 0.141506f, 0.865945f, 0.441321f, 0.486410f, 0.448369f, 0.567846f, 0.621169f, 0.498180f, 0.866789f, 0.627735f, 0.401428f, 0.416692f, 0.810839f, 0.348192f, 0.211455f, 0.059383f, 0.876027f, 0.918546f, 0.120120f, 0.334474f, 0.175372f, 0.115898f, 0.899867f, 0.056877f, 0.980486f, 0.096451f, 0.863471f, 0.566506f, 0.367917f, 0.342342f, 0.757364f, 0.314573f, 0.657319f, 0.517326f, 0.484966f, 0.901162f, 0.554645f, 0.826862f, 0.725574f, 0.038557f, 0.773110f, 0.216870f, 0.903150f, 0.042924f, 0.333072f, 0.099733f, 0.475589f, 0.820022f, 0.186193f, 0.944372f, 0.739551f, 0.490459f, 0.227415f, 0.254356f, 0.058029f, 0.434417f, 0.320017f, 0.383464f, 0.588317f, 0.831048f, 0.628982f, 0.872651f, 0.273542f, 0.798047f, 0.863856f, 0.117532f, 0.517379f, 0.132068f, 0.716860f, 0.396060f, 0.565421f, 0.183280f, 0.669917f, 0.785153f, 0.281730f, 0.586410f, 0.063955f, 0.485628f, 0.977495f, 0.876505f, 0.497391f, 0.639473f, 0.368585f, 0.136900f, 0.822118f, 0.189848f, 0.511319f, 0.224317f, 0.079522f, 0.089603f, 0.672048f, 0.245367f, 0.420539f, 0.557369f, 0.860551f, 0.727044f, 0.298187f, 0.150935f, 0.330267f, 0.813880f, 0.140384f, 0.227362f, 0.068852f, 0.705710f, 0.395233f, 0.310840f, 0.718626f, 0.335978f, 0.727771f, 0.815199f, 0.217663f, 0.973819f, 0.162358f, 0.290841f, 0.179795f, 0.345506f, 0.480061f, 0.522176f, 0.853606f, 0.889448f, 0.220104f, 0.622894f, 0.111496f, 0.458970f, 0.322334f, 0.316501f, 0.482584f, 0.729828f, 0.069183f, 0.879173f, 0.734814f, 0.176499f, 0.939161f, 0.506312f, 0.999809f, 0.197259f, 0.534908f, 0.290248f, 0.304174f, 0.591065f, 0.921719f, 0.805264f, 0.723941f, 
0.559174f, 0.922298f, 0.492361f, 0.873832f, 0.833982f, 0.213835f, 0.771225f, 0.012171f, 0.322830f, 0.229567f, 0.506863f, 0.736853f, 0.097676f, 0.514922f, 0.938412f, 0.228647f, 0.677141f, 0.592880f, 0.010064f, 0.475826f, 0.708770f, 0.043975f, 0.879521f, 0.520081f, 0.030661f, 0.224414f, 0.953676f, 0.582320f, 0.107473f, 0.287544f, 0.456704f, 0.020950f, 0.411616f, 0.489459f, 0.243678f, 0.588639f, 0.753240f, 0.235834f, 0.620500f, 0.639622f, 0.948540f, 0.778276f, 0.848345f, 0.490420f, 0.185349f, 0.995815f, 0.129356f, 0.471457f, 0.068093f, 0.311796f, 0.696343f, 0.377752f, 0.179604f, 0.024679f, 0.067250f, 0.679393f, 0.453697f, 0.185636f, 0.952792f, 0.687488f, 0.215508f, 0.947371f, 0.730856f, 0.253942f, 0.213312f, 0.144848f, 0.488056f, 0.355613f, 0.940432f, 0.765325f, 0.748664f, 0.903720f, 0.083422f, 0.338159f, 0.961570f, 0.231702f, 0.949319f, 0.941378f, 0.799203f, 0.630448f, 0.874288f, 0.097844f, 0.862191f, 0.972919f, 0.960835f, 0.906555f, 0.774047f, 0.333145f, 0.081101f, 0.270328f, 0.131483f, 0.055374f, 0.301599f, 0.262118f, 0.456141f, 0.683281f, 0.695625f, 0.943851f, 0.964925f, 0.719389f, 0.349993f, 0.254382f, 0.265303f, 0.127294f, 0.525809f, 0.141817f, 0.316731f, 0.626706f, 0.727544f, 0.024273f, 0.430116f, 0.652125f, 0.853246f, 0.475325f, 0.969206f, 0.265633f, 0.013509f, 0.483753f, 0.256114f, 0.823718f, 0.232773f, 0.310629f, 0.791227f, 0.715143f, 0.558051f, 0.704948f, 0.418637f, 0.005310f, 0.011355f, 0.511222f, 0.083291f, 0.051075f, 0.965517f, 0.859003f, 0.152027f, 0.000664f, 0.941668f, 0.278325f, 0.185898f, 0.691508f, 0.108904f, 0.264650f, 0.975095f, 0.639463f, 0.520678f, 0.397919f, 0.774501f, 0.140957f, 0.967338f, 0.861123f, 0.617657f, 0.042906f, 0.700856f, 0.913284f, 0.524577f, 0.354225f, 0.120277f, 0.754901f, 0.885022f, 0.100252f, 0.758985f, 0.017060f, 0.967055f, 0.615058f, 0.552439f, 0.295950f, 0.929292f, 0.265906f, 0.828147f, 0.985109f, 0.783397f, 0.518990f, 0.066074f, 0.472414f, 0.438256f, 0.202796f, 0.423588f, 0.357758f, 0.163684f, 0.441374f, 0.262800f, 
0.522062f, 0.035160f, 0.906231f, 0.816364f, 0.552581f, 0.851809f, 0.962395f, 0.110522f, 0.630832f, 0.997994f, 0.987889f, 0.603323f, 0.536579f, 0.896671f, 0.990339f, 0.216897f, 0.663078f, 0.263322f, 0.020651f, 0.758379f, 0.518201f, 0.025663f, 0.207470f, 0.424685f, 0.374170f, 0.463575f, 0.277629f, 0.586784f, 0.552192f, 0.584476f, 0.961936f, 0.292148f, 0.240829f, 0.100294f, 0.016430f, 0.929529f, 0.293020f, 0.848944f, 0.617877f, 0.013237f, 0.347234f, 0.148141f, 0.981829f, 0.478370f, 0.407241f, 0.232234f, 0.132488f, 0.053427f, 0.725594f, 0.011427f, 0.770581f, 0.146947f, 0.283519f, 0.379927f, 0.181151f, 0.788545f, 0.056848f, 0.696997f, 0.778695f, 0.777408f, 0.128021f, 0.583193f, 0.002065f, 0.198911f, 0.956123f, 0.330441f, 0.638390f, 0.280860f, 0.947822f, 0.728559f, 0.329651f, 0.791761f, 0.108166f, 0.392319f, 0.221218f, 0.683726f, 0.102446f, 0.397026f, 0.276650f, 0.506343f, 0.349898f, 0.706411f, 0.024577f, 0.633987f, 0.230571f, 0.268709f, 0.800256f, 0.955568f, 0.316550f, 0.826805f, 0.103991f, 0.633982f, 0.751032f, 0.155978f, 0.426002f, 0.892707f, 0.103578f, 0.018096f, 0.590585f, 0.435532f, 0.798689f, 0.923456f, 0.299154f, 0.388404f, 0.486272f, 0.588151f, 0.983854f, 0.697330f, 0.389549f, 0.263768f, 0.944626f, 0.135548f, 0.720266f, 0.925395f, 0.664666f, 0.423054f, 0.198991f, 0.367475f, 0.706872f, 0.649534f, 0.927976f, 0.866861f, 0.816151f, 0.911451f, 0.276337f, 0.369524f, 0.379894f, 0.560451f, 0.668218f, 0.286717f, 0.019462f, 0.399222f, 0.308528f, 0.942185f, 0.888265f, 0.860311f, 0.653000f, 0.344289f, 0.548849f, 0.815225f, 0.098610f, 0.801075f, 0.041180f, 0.816421f, 0.807564f, 0.051007f, 0.627161f, 0.502453f, 0.169820f, 0.148379f, 0.773259f, 0.567693f, 0.982999f, 0.982248f, 0.992667f, 0.118616f, 0.259423f, 0.373813f, 0.587600f, 0.272822f, 0.370853f, 0.197054f, 0.459856f, 0.044612f, 0.187131f, 0.903984f, 0.543806f, 0.456911f, 0.882041f, 0.458604f, 0.724168f, 0.399025f, 0.820767f, 0.908844f, 0.815524f, 0.159414f, 0.628898f, 0.398434f, 0.062713f, 0.424032f, 0.614465f, 
0.033075f, 0.015606f, 0.428796f, 0.068074f, 0.251941f, 0.221161f, 0.253191f, 0.246063f, 0.665859f, 0.517309f, 0.424089f, 0.554688f, 0.287052f, 0.706575f, 0.414857f, 0.689768f, 0.058356f, 0.730709f, 0.881720f, 0.272437f, 0.379057f, 0.374296f, 0.748788f, 0.938256f, 0.244570f, 0.458212f, 0.757407f, 0.203621f, 0.566312f, 0.185817f, 0.104736f, 0.116559f, 0.357639f, 0.004655f, 0.424854f, 0.664197f, 0.401688f, 0.085795f, 0.062689f, 0.278117f, 0.169313f, 0.965095f, 0.151230f, 0.805462f, 0.586108f, 0.569287f, 0.512081f, 0.971763f, 0.363845f, 0.787916f, 0.555294f, 0.395634f, 0.955466f, 0.598316f, 0.118917f, 0.417539f, 0.781582f, 0.693747f, 0.916340f, 0.259377f, 0.758194f, 0.459875f, 0.573610f, 0.955047f, 0.979286f, 0.861591f, 0.359097f, 0.887701f, 0.638609f, 0.429997f, 0.035743f, 0.770128f, 0.502106f, 0.786188f, 0.748023f, 0.793567f, 0.300651f, 0.800799f, 0.548846f, 0.473326f, 0.675126f, 0.021359f, 0.102317f, 0.292177f, 0.982990f, 0.139746f, 0.330596f, 0.051053f, 0.331269f, 0.320326f, 0.946807f, 0.845154f, 0.382764f, 0.024769f, 0.831031f, 0.660536f, 0.152364f, 0.996071f, 0.100233f, 0.867115f, 0.294266f, 0.435353f, 0.795457f, 0.677508f, 0.937864f, 0.621140f, 0.097810f, 0.884360f, 0.769156f, 0.711870f, 0.053734f, 0.396223f, 0.167436f, 0.821904f, 0.700529f, 0.883078f, 0.966575f, 0.774748f, 0.994233f, 0.799796f, 0.076956f, 0.518835f, 0.306810f, 0.577543f, 0.959433f, 0.645570f, 0.035362f, 0.904044f, 0.690025f, 0.699622f, 0.327720f, 0.756779f, 0.636061f, 0.240020f, 0.160539f, 0.258684f, 0.849038f, 0.033305f, 0.958983f, 0.355369f, 0.356707f, 0.016329f, 0.185232f, 0.131055f, 0.012036f, 0.115484f, 0.618480f, 0.974256f, 0.990345f, 0.409054f, 0.162954f, 0.360546f, 0.828657f, 0.924967f, 0.046007f, 0.232627f, 0.348519f, 0.814966f, 0.985491f, 0.237807f, 0.171853f, 0.449292f, 0.304468f, 0.839189f, 0.237742f, 0.502389f, 0.942584f, 0.614770f, 0.037130f, 0.014252f, 0.342104f, 0.823472f, 0.866135f, 0.960813f, 0.065121f, 0.044571f, 0.913284f, 0.305047f, 0.557987f, 0.982445f, 0.400449f, 
0.665871f, 0.400880f, 0.768195f, 0.527715f, 0.237523f, 0.271306f, 0.258059f, 0.532320f, 0.703189f, 0.949280f, 0.694087f, 0.781193f, 0.168926f, 0.374063f, 0.413780f, 0.686380f, 0.295892f, 0.303292f, 0.355889f, 0.810302f, 0.577590f, 0.075277f, 0.078246f, 0.371287f, 0.766591f, 0.688683f, 0.707982f, 0.767210f, 0.287153f, 0.548256f, 0.543353f, 0.739632f, 0.956871f, 0.277990f, 0.793282f, 0.659971f, 0.580238f, 0.774880f, 0.944032f, 0.036691f, 0.147400f, 0.756287f, 0.083791f, 0.516124f, 0.219861f, 0.274296f, 0.701840f, 0.030193f, 0.873319f, 0.444479f, 0.502393f, 0.540048f, 0.645544f, 0.344857f, 0.101107f, 0.318379f, 0.168142f, 0.556133f, 0.318029f, 0.958067f, 0.965734f, 0.620126f, 0.617497f, 0.985379f, 0.887283f, 0.765070f, 0.313591f, 0.365539f, 0.201267f, 0.487148f, 0.990369f, 0.912151f, 0.118349f, 0.025190f, 0.898638f, 0.537170f, 0.200190f, 0.673653f, 0.644223f, 0.122086f, 0.259600f, 0.060078f, 0.430402f, 0.510017f, 0.536178f, 0.681392f, 0.277596f, 0.128861f, 0.392676f, 0.956406f, 0.796391f, 0.959167f, 0.458139f, 0.590984f, 0.857723f, 0.457223f, 0.951874f, 0.575751f, 0.401260f, 0.929291f, 0.099615f, 0.945302f, 0.869489f, 0.454162f, 0.326701f, 0.232744f, 0.638762f, 0.490305f, 0.989410f, 0.065304f, 0.783234f, 0.288399f, 0.241419f, 0.662505f, 0.968972f, 0.904948f, 0.296556f, 0.992011f, 0.249420f, 0.105906f, 0.950953f, 0.233420f, 0.633998f, 0.867289f, 0.940210f, 0.750765f, 0.699575f, 0.967966f, 0.994401f, 0.451822f}; + // {2, 3, 18, 8} + std::vector present_value = {0.209860f, 0.132306f, 0.193236f, 0.685467f, 0.049500f, 0.101855f, 0.134174f, 0.316541f, 0.298750f, 0.255064f, 0.750537f, 0.998023f, 0.533978f, 0.944203f, 0.396610f, 0.106682f, 0.408774f, 0.296128f, 0.493407f, 0.657044f, 0.461050f, 0.935161f, 0.884765f, 0.701978f, 0.489685f, 0.131687f, 0.397014f, 0.704402f, 0.284886f, 0.103988f, 0.907898f, 0.709051f, 0.615276f, 0.792499f, 0.835646f, 0.483459f, 0.881188f, 0.916419f, 0.271551f, 0.607545f, 0.526584f, 0.537946f, 0.937663f, 0.305189f, 0.983434f, 0.902131f, 0.458723f, 
0.817453f, 0.769047f, 0.677895f, 0.319834f, 0.196451f, 0.671528f, 0.842973f, 0.016253f, 0.642803f, 0.442873f, 0.898088f, 0.321473f, 0.474185f, 0.514767f, 0.140440f, 0.712892f, 0.830476f, 0.057909f, 0.291389f, 0.038045f, 0.956544f, 0.667169f, 0.964200f, 0.531494f, 0.802069f, 0.374414f, 0.353819f, 0.378268f, 0.657862f, 0.359453f, 0.900367f, 0.983275f, 0.030427f, 0.193623f, 0.112250f, 0.042364f, 0.227741f, 0.446793f, 0.836990f, 0.221824f, 0.493945f, 0.929619f, 0.667215f, 0.798079f, 0.550994f, 0.980466f, 0.588662f, 0.045511f, 0.197983f, 0.070870f, 0.292794f, 0.152355f, 0.417486f, 0.131289f, 0.604118f, 0.382808f, 0.895386f, 0.333507f, 0.924159f, 0.862319f, 0.048690f, 0.253643f, 0.446136f, 0.104628f, 0.348476f, 0.089495f, 0.405942f, 0.024313f, 0.342611f, 0.622231f, 0.279068f, 0.209750f, 0.115703f, 0.442036f, 0.519952f, 0.693956f, 0.090886f, 0.227759f, 0.410302f, 0.623295f, 0.886961f, 0.160468f, 0.886305f, 0.446394f, 0.907876f, 0.160230f, 0.661117f, 0.440264f, 0.076487f, 0.174658f, 0.327988f, 0.680349f, 0.063208f, 0.607249f, 0.477646f, 0.284000f, 0.238413f, 0.404774f, 0.601277f, 0.771931f, 0.413086f, 0.710058f, 0.789869f, 0.317260f, 0.979270f, 0.649656f, 0.880998f, 0.555938f, 0.741603f, 0.770544f, 0.908248f, 0.150350f, 0.558283f, 0.428379f, 0.923159f, 0.105095f, 0.982574f, 0.875451f, 0.073826f, 0.490966f, 0.717560f, 0.738152f, 0.906494f, 0.799865f, 0.310930f, 0.498435f, 0.701786f, 0.138437f, 0.193991f, 0.481042f, 0.298246f, 0.862559f, 0.586277f, 0.348665f, 0.848833f, 0.804878f, 0.998355f, 0.847308f, 0.414457f, 0.127499f, 0.840641f, 0.059758f, 0.350271f, 0.919738f, 0.960766f, 0.640565f, 0.688648f, 0.042454f, 0.514480f, 0.546868f, 0.340101f, 0.068597f, 0.228908f, 0.357984f, 0.435142f, 0.590927f, 0.722392f, 0.317632f, 0.328954f, 0.019692f, 0.040875f, 0.257822f, 0.740245f, 0.628314f, 0.769789f, 0.768919f, 0.856567f, 0.720319f, 0.979011f, 0.898825f, 0.586717f, 0.588158f, 0.034267f, 0.998527f, 0.131576f, 0.740347f, 0.821015f, 0.373055f, 0.196852f, 0.098760f, 0.748606f, 
0.452654f, 0.713718f, 0.915408f, 0.146584f, 0.919171f, 0.411626f, 0.305267f, 0.943062f, 0.990652f, 0.198892f, 0.656838f, 0.106495f, 0.967795f, 0.546885f, 0.274824f, 0.592230f, 0.896761f, 0.406733f, 0.552078f, 0.271653f, 0.740098f, 0.680514f, 0.622384f, 0.710528f, 0.204924f, 0.341698f, 0.676242f, 0.879235f, 0.577140f, 0.695270f, 0.671957f, 0.948861f, 0.002703f, 0.647197f, 0.600392f, 0.588740f, 0.618826f, 0.133461f, 0.980580f, 0.871786f, 0.502721f, 0.922348f, 0.541381f, 0.923306f, 0.696463f, 0.247399f, 0.039616f, 0.059944f, 0.061079f, 0.907733f, 0.739884f, 0.898062f, 0.514513f, 0.367928f, 0.456520f, 0.337477f, 0.970494f, 0.133439f, 0.096804f, 0.343392f, 0.650914f, 0.827313f, 0.684499f, 0.417333f, 0.383066f, 0.393122f, 0.589712f, 0.881567f, 0.929066f, 0.053530f, 0.181622f, 0.112224f, 0.193335f, 0.346608f, 0.506532f, 0.629461f, 0.732142f, 0.890112f, 0.989088f, 0.662856f, 0.845365f, 0.778039f, 0.307532f, 0.875692f, 0.042763f, 0.000367f, 0.273733f, 0.462098f, 0.638363f, 0.101770f, 0.673010f, 0.801816f, 0.185313f, 0.415125f, 0.519985f, 0.451807f, 0.799830f, 0.960522f, 0.798953f, 0.077993f, 0.804936f, 0.066596f, 0.235970f, 0.153097f, 0.197519f, 0.528315f, 0.671690f, 0.470321f, 0.959696f, 0.240292f, 0.763140f, 0.870182f, 0.562066f, 0.456223f, 0.596184f, 0.428810f, 0.555194f, 0.416934f, 0.400470f, 0.695346f, 0.092851f, 0.166542f, 0.851198f, 0.771077f, 0.281454f, 0.377269f, 0.926027f, 0.818077f, 0.614346f, 0.221490f, 0.044252f, 0.431258f, 0.672627f, 0.828480f, 0.852689f, 0.032776f, 0.244157f, 0.339095f, 0.188732f, 0.802975f, 0.767466f, 0.516833f, 0.982926f, 0.144059f, 0.899652f, 0.116463f, 0.163182f, 0.696219f, 0.109570f, 0.565845f, 0.420234f, 0.728474f, 0.900675f, 0.769872f, 0.849690f, 0.032945f, 0.455444f, 0.401714f, 0.248413f, 0.505866f, 0.310381f, 0.373035f, 0.524970f, 0.750595f, 0.543678f, 0.282700f, 0.030235f, 0.710337f, 0.007884f, 0.372679f, 0.530537f, 0.922111f, 0.962770f, 0.016872f, 0.696482f, 0.813679f, 0.509807f, 0.333965f, 0.790840f, 0.097243f, 0.829897f, 
0.968286f, 0.919783f, 0.036034f, 0.174772f, 0.389135f, 0.952143f, 0.300029f, 0.672582f, 0.528940f, 0.304446f, 0.997962f, 0.362189f, 0.470649f, 0.378245f, 0.979527f, 0.591027f, 0.659176f, 0.397257f, 0.999278f, 0.351893f, 0.721407f, 0.637583f, 0.813054f, 0.310196f, 0.515433f, 0.415953f, 0.231255f, 0.307874f, 0.945431f, 0.294181f, 0.353904f, 0.003710f, 0.845078f, 0.154841f, 0.204144f, 0.255265f, 0.884622f, 0.206451f, 0.797526f, 0.808049f, 0.927021f, 0.115561f, 0.217279f, 0.742898f, 0.196001f, 0.286330f, 0.166742f, 0.172697f, 0.481553f, 0.109683f, 0.321698f, 0.426594f, 0.024548f, 0.388333f, 0.094122f, 0.493579f, 0.825738f, 0.818422f, 0.080449f, 0.601228f, 0.834586f, 0.237973f, 0.761927f, 0.890764f, 0.806124f, 0.107301f, 0.009060f, 0.191724f, 0.270477f, 0.616183f, 0.384273f, 0.703407f, 0.353075f, 0.154425f, 0.312690f, 0.884324f, 0.958532f, 0.207513f, 0.788468f, 0.273349f, 0.887132f, 0.165546f, 0.665960f, 0.084211f, 0.973893f, 0.700633f, 0.841816f, 0.566669f, 0.476801f, 0.621882f, 0.528742f, 0.469384f, 0.759450f, 0.178201f, 0.171172f, 0.431843f, 0.320748f, 0.074125f, 0.844471f, 0.771603f, 0.543921f, 0.979325f, 0.072600f, 0.766669f, 0.266370f, 0.368599f, 0.219279f, 0.789038f, 0.144240f, 0.840017f, 0.661578f, 0.059023f, 0.810982f, 0.627756f, 0.904982f, 0.748722f, 0.561121f, 0.836547f, 0.278050f, 0.976226f, 0.889794f, 0.764562f, 0.698249f, 0.335498f, 0.147686f, 0.062636f, 0.241902f, 0.862063f, 0.097258f, 0.497777f, 0.581082f, 0.241557f, 0.169025f, 0.859581f, 0.058535f, 0.614528f, 0.070042f, 0.822407f, 0.653421f, 0.726342f, 0.536923f, 0.110477f, 0.405036f, 0.190567f, 0.019123f, 0.518150f, 0.842777f, 0.373216f, 0.222864f, 0.080532f, 0.085311f, 0.751022f, 0.927212f, 0.028953f, 0.895691f, 0.392569f, 0.878372f, 0.690785f, 0.987349f, 0.612867f, 0.081370f, 0.881896f, 0.719620f, 0.966390f, 0.507636f, 0.300404f, 0.549501f, 0.546950f, 0.293617f, 0.968204f, 0.226196f, 0.015738f, 0.325855f, 0.502509f, 0.028363f, 0.559248f, 0.874283f, 0.704732f, 0.622968f, 0.955962f, 0.958279f, 
0.824266f, 0.607742f, 0.487765f, 0.013316f, 0.606262f, 0.989088f, 0.818101f, 0.340605f, 0.152047f, 0.784059f, 0.743938f, 0.967047f, 0.874842f, 0.555663f, 0.101284f, 0.483501f, 0.313695f, 0.512408f, 0.301702f, 0.861823f, 0.844327f, 0.315465f, 0.599581f, 0.430181f, 0.909093f, 0.187361f, 0.697728f, 0.970375f, 0.175276f, 0.201966f, 0.693723f, 0.779154f, 0.490549f, 0.609686f, 0.212682f, 0.476614f, 0.112072f, 0.321422f, 0.284780f, 0.444625f, 0.930126f, 0.181268f, 0.401388f, 0.615597f, 0.946557f, 0.133148f, 0.917877f, 0.081054f, 0.480741f, 0.454590f, 0.209603f, 0.347460f, 0.454165f, 0.865211f, 0.955064f, 0.518926f, 0.870100f, 0.608172f, 0.349087f, 0.194194f, 0.413135f, 0.522824f, 0.044443f, 0.145841f, 0.600184f, 0.225002f, 0.837326f, 0.326942f, 0.104834f, 0.083531f, 0.937123f, 0.118020f, 0.140910f, 0.862666f, 0.254288f, 0.665951f, 0.816726f, 0.607181f, 0.957489f, 0.708883f, 0.112752f, 0.558410f, 0.432281f, 0.521996f, 0.773084f, 0.958741f, 0.117320f, 0.107004f, 0.589695f, 0.745398f, 0.470621f, 0.115834f, 0.457059f, 0.979962f, 0.423706f, 0.857125f, 0.117316f, 0.271252f, 0.405374f, 0.321043f, 0.029950f, 0.737254f, 0.109784f, 0.606308f, 0.703218f, 0.634786f, 0.221396f, 0.100014f, 0.265040f, 0.066149f, 0.065605f, 0.856276f, 0.162120f, 0.559682f, 0.759282f, 0.364545f, 0.501063f, 0.376389f, 0.364912f, 0.260904f, 0.495970f, 0.681740f, 0.930819f, 0.520761f, 0.267207f, 0.877399f, 0.371919f, 0.001383f, 0.247685f, 0.318234f, 0.718186f, 0.801957f, 0.026321f, 0.718879f, 0.825681f, 0.746834f, 0.512349f, 0.458021f, 0.549419f, 0.704644f, 0.922914f, 0.617035f, 0.887834f, 0.701257f, 0.068336f, 0.500828f, 0.286486f, 0.285175f, 0.355928f, 0.314733f, 0.578610f, 0.683601f, 0.268749f, 0.129763f, 0.058809f, 0.575753f, 0.186130f, 0.009248f, 0.927753f, 0.537140f, 0.092448f, 0.842921f, 0.983203f, 0.448601f, 0.042490f, 0.117546f, 0.381654f, 0.885523f, 0.148039f, 0.823990f, 0.014976f, 0.457389f, 0.644397f, 0.060379f, 0.614763f, 0.944404f, 0.160260f, 0.729611f, 0.609094f, 0.185116f, 0.006203f, 
0.009284f, 0.532092f, 0.942779f, 0.644299f, 0.714300f, 0.493865f, 0.581889f, 0.126368f, 0.876821f, 0.760793f, 0.998199f, 0.297723f, 0.227018f, 0.125162f, 0.964210f, 0.780885f, 0.166325f, 0.552686f, 0.413768f, 0.151486f, 0.162073f, 0.963470f, 0.304964f, 0.941439f, 0.075611f, 0.460803f, 0.129619f, 0.004787f, 0.553766f, 0.113894f, 0.722025f, 0.698116f, 0.176333f, 0.941742f, 0.721043f, 0.297970f, 0.709234f, 0.731930f, 0.342226f, 0.375589f, 0.359107f, 0.616618f, 0.900410f, 0.173193f, 0.875200f, 0.848150f, 0.935832f, 0.983426f, 0.399802f, 0.380335f, 0.147809f, 0.684934f, 0.656762f, 0.403793f, 0.399812f, 0.671384f, 0.344718f, 0.713767f, 0.639187f, 0.399161f, 0.431760f, 0.959142f, 0.103298f, 0.867167f, 0.029190f, 0.534917f, 0.404244f, 0.524184f, 0.365100f, 0.773456f, 0.456410f, 0.153369f, 0.199596f, 0.432984f, 0.528234f, 0.349440f, 0.781480f, 0.277340f, 0.524380f, 0.117380f, 0.159845f, 0.046806f, 0.970731f, 0.003860f, 0.178580f, 0.858777f, 0.458503f, 0.444587f, 0.336102f, 0.880678f, 0.945027f, 0.991890f, 0.376741f}; + // {2, 3, 4, 18} + std::vector qk_matmul = {0.820140f, 1.059902f, 0.757718f, 0.881749f, 0.858141f, 1.036822f, 0.884175f, 0.745137f, 0.702161f, 0.857424f, 0.931616f, 0.810373f, 0.765101f, 1.031954f, 0.676118f, 1.049585f, 0.679454f, 0.781211f, 0.732417f, 0.806783f, 0.671492f, 0.704470f, 0.679564f, 0.856373f, 0.747101f, 0.574466f, 0.511335f, 0.570812f, 0.772065f, 0.486530f, 0.626328f, 0.895540f, 0.426428f, 0.830139f, 0.518625f, 0.578420f, 0.491913f, 0.536788f, 0.566909f, 0.660403f, 0.508000f, 0.745048f, 0.542980f, 0.637834f, 0.427056f, 0.598455f, 0.656768f, 0.504709f, 0.485053f, 0.649462f, 0.553231f, 0.485448f, 0.577920f, 0.466000f, 0.399496f, 0.637952f, 0.382979f, 0.665599f, 0.527650f, 0.680828f, 0.511044f, 0.664769f, 0.654046f, 0.736594f, 0.645048f, 0.671768f, 0.524199f, 0.519912f, 0.615914f, 0.647178f, 0.559970f, 0.412029f, 0.492759f, 0.889178f, 0.525811f, 0.479380f, 0.766941f, 0.901303f, 1.087107f, 0.808560f, 0.779749f, 0.609254f, 0.801121f, 0.808370f, 
0.397958f, 0.867537f, 0.814879f, 0.981307f, 1.048465f, 0.422327f, 0.531406f, 0.847033f, 0.878291f, 0.737390f, 0.926101f, 1.027148f, 0.731989f, 0.720755f, 0.637853f, 0.523248f, 0.924757f, 0.757182f, 0.588026f, 0.773634f, 0.979738f, 1.255782f, 0.901064f, 0.688140f, 0.274371f, 0.437410f, 0.411925f, 0.342756f, 0.545288f, 0.529269f, 0.533905f, 0.380022f, 0.436475f, 0.301469f, 0.529214f, 0.526297f, 0.395983f, 0.411271f, 0.503063f, 0.557536f, 0.505664f, 0.334459f, 0.348011f, 0.483405f, 0.482135f, 0.438657f, 0.623578f, 0.666952f, 0.527974f, 0.396662f, 0.441010f, 0.322428f, 0.543776f, 0.569352f, 0.341589f, 0.541193f, 0.719589f, 0.825763f, 0.713140f, 0.369560f, 0.925217f, 0.962246f, 0.804315f, 0.969734f, 0.939348f, 0.895554f, 1.240035f, 1.032457f, 1.260824f, 0.838023f, 0.816715f, 1.381388f, 1.123444f, 0.666636f, 0.901369f, 0.880265f, 0.544716f, 0.964444f, 0.610261f, 0.432138f, 0.522623f, 0.616368f, 0.392524f, 0.601866f, 0.610201f, 0.716924f, 0.662694f, 0.625345f, 0.421250f, 0.927903f, 0.710488f, 0.375567f, 0.528123f, 0.532916f, 0.359236f, 0.428232f, 0.627666f, 0.646350f, 0.711912f, 0.578261f, 0.510271f, 0.666607f, 0.609787f, 0.652893f, 0.673018f, 0.618551f, 0.787326f, 1.094408f, 0.787271f, 0.433836f, 0.638263f, 0.836964f, 0.604598f, 0.587050f, 0.798962f, 0.607254f, 0.635098f, 0.675595f, 0.504633f, 0.579773f, 0.825966f, 0.745334f, 0.850824f, 0.713222f, 0.417185f, 0.949167f, 0.715411f, 0.438783f, 0.580263f, 0.596451f, 0.311825f, 0.698230f, 0.553783f, 0.653118f, 0.479333f, 0.683333f, 0.611400f, 0.926136f, 0.937356f, 1.079461f, 0.500571f, 0.941776f, 0.571910f, 0.891547f, 0.471507f, 0.784496f, 0.765230f, 0.316921f, 0.693191f, 0.812555f, 0.430584f, 0.838594f, 0.577089f, 0.887826f, 0.637326f, 0.838023f, 0.852760f, 0.930619f, 0.596678f, 1.004560f, 0.556861f, 0.837758f, 0.499217f, 0.774022f, 0.908813f, 0.359039f, 0.646230f, 0.839435f, 0.724433f, 1.107947f, 0.836124f, 1.043592f, 0.755617f, 1.190845f, 0.927864f, 1.247710f, 0.759936f, 1.199264f, 0.903627f, 0.981243f, 0.477713f, 
0.991537f, 0.973822f, 0.518882f, 0.798147f, 0.975918f, 0.343779f, 0.491195f, 0.197678f, 0.348761f, 0.506575f, 0.694266f, 0.570159f, 0.588826f, 0.260686f, 0.583943f, 0.370536f, 0.570071f, 0.363210f, 0.512280f, 0.518522f, 0.260276f, 0.479575f, 0.519170f, 0.649026f, 0.390051f, 0.795750f, 0.920073f, 1.046746f, 0.900276f, 0.940614f, 0.679509f, 0.778774f, 0.792281f, 0.857889f, 1.197963f, 0.738062f, 0.792745f, 0.602892f, 0.687147f, 0.962916f, 0.719326f, 0.587815f, 0.233866f, 0.638163f, 0.785593f, 0.772991f, 0.770025f, 0.862170f, 0.414778f, 0.518855f, 0.729107f, 0.683017f, 0.903488f, 0.620768f, 0.669556f, 0.396731f, 0.418391f, 0.796217f, 0.580872f, 0.555648f, 0.496401f, 0.582726f, 0.730206f, 0.806009f, 0.858020f, 0.827912f, 0.515117f, 0.715055f, 0.533599f, 0.810529f, 0.887599f, 0.629091f, 0.713460f, 0.668702f, 0.740854f, 0.533289f, 0.544756f, 0.500474f, 0.287242f, 0.666506f, 0.805604f, 0.814325f, 0.939329f, 0.784865f, 0.575117f, 0.413632f, 0.650744f, 0.916553f, 0.821434f, 0.634740f, 0.761039f, 0.447249f, 0.427194f, 0.886137f, 0.483658f, 0.957992f, 0.967132f, 0.993273f, 0.791302f, 0.858239f, 1.102870f, 1.073905f, 0.782627f, 0.700627f, 1.402989f, 0.781228f, 0.752175f, 0.879408f, 1.311348f, 0.881165f, 1.044089f, 1.012252f, 1.461238f, 0.731050f, 0.967882f, 0.932687f, 0.778944f, 0.812401f, 0.974234f, 1.130671f, 0.729870f, 0.702872f, 1.304851f, 0.727443f, 0.734453f, 0.899574f, 1.238530f, 0.921609f, 1.012872f, 0.938401f, 1.303568f, 0.824322f, 0.920016f, 0.791579f, 0.844334f, 0.618349f, 0.989377f, 1.120477f, 0.554956f, 0.683589f, 1.280705f, 0.957804f, 0.833027f, 0.791589f, 1.159548f, 1.031220f, 0.951427f, 0.915324f, 1.361176f, 0.733365f, 0.808374f, 0.764603f, 0.755506f, 0.638693f, 0.946285f, 1.001370f, 0.578989f, 0.603487f, 1.074992f, 0.697424f, 0.812599f, 0.708634f, 1.129837f, 0.888077f, 0.835530f, 1.006811f, 1.193171f}; + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * 
(past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + ASSERT_EQ(qk_matmul.size(), batch_size * kv_num_heads * q_sequence_length * (past_sequence_length + kv_sequence_length)); + + RunTest3D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmul) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 4; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 4; // V.shape[3] + int past_sequence_length = 7; // past_key.shape[2] and past_value.shape[2] + + // {2, 3, 4, 4} + std::vector q = {-0.454545f, -0.444129f, -0.433712f, -0.423295f, -0.412879f, -0.402462f, -0.392045f, -0.381629f, -0.371212f, -0.360795f, -0.350379f, -0.339962f, -0.329545f, -0.319129f, -0.308712f, -0.298295f, -0.287879f, -0.277462f, -0.267045f, -0.256629f, -0.246212f, -0.235795f, -0.225379f, -0.214962f, -0.204545f, -0.194129f, -0.183712f, -0.173295f, -0.162879f, -0.152462f, -0.142045f, -0.131629f, -0.121212f, -0.110795f, -0.100379f, -0.089962f, -0.079545f, -0.069129f, -0.058712f, -0.048295f, -0.037879f, -0.027462f, -0.017045f, -0.006629f, 0.003788f, 0.014205f, 0.024621f, 0.035038f, 0.045455f, 0.055871f, 0.066288f, 0.076705f, 0.087121f, 0.097538f, 0.107955f, 0.118371f, 0.128788f, 0.139205f, 0.149621f, 0.160038f, 
0.170455f, 0.180871f, 0.191288f, 0.201705f, 0.212121f, 0.222538f, 0.232955f, 0.243371f, 0.253788f, 0.264205f, 0.274621f, 0.285038f, 0.295455f, 0.305871f, 0.316288f, 0.326705f, 0.337121f, 0.347538f, 0.357955f, 0.368371f, 0.378788f, 0.389205f, 0.399621f, 0.410038f, 0.420455f, 0.430871f, 0.441288f, 0.451705f, 0.462121f, 0.472538f, 0.482955f, 0.493371f, 0.503788f, 0.514205f, 0.524621f, 0.535038f}; + // {2, 3, 6, 4} + std::vector k = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 
0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 6, 4} + std::vector v = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 1, 4, 13} + std::vector m = {-0.454545f, -0.444930f, -0.435315f, -0.425699f, -0.416084f, -0.406469f, -0.396853f, -0.387238f, -0.377622f, -0.368007f, -0.358392f, -0.348776f, -0.339161f, -0.329545f, -0.319930f, 
-0.310315f, -0.300699f, -0.291084f, -0.281469f, -0.271853f, -0.262238f, -0.252622f, -0.243007f, -0.233392f, -0.223776f, -0.214161f, -0.204545f, -0.194930f, -0.185315f, -0.175699f, -0.166084f, -0.156469f, -0.146853f, -0.137238f, -0.127622f, -0.118007f, -0.108392f, -0.098776f, -0.089161f, -0.079545f, -0.069930f, -0.060315f, -0.050699f, -0.041084f, -0.031469f, -0.021853f, -0.012238f, -0.002622f, 0.006993f, 0.016608f, 0.026224f, 0.035839f, 0.045455f, 0.055070f, 0.064685f, 0.074301f, 0.083916f, 0.093531f, 0.103147f, 0.112762f, 0.122378f, 0.131993f, 0.141608f, 0.151224f, 0.160839f, 0.170455f, 0.180070f, 0.189685f, 0.199301f, 0.208916f, 0.218531f, 0.228147f, 0.237762f, 0.247378f, 0.256993f, 0.266608f, 0.276224f, 0.285839f, 0.295455f, 0.305070f, 0.314685f, 0.324301f, 0.333916f, 0.343531f, 0.353147f, 0.362762f, 0.372378f, 0.381993f, 0.391608f, 0.401224f, 0.410839f, 0.420455f, 0.430070f, 0.439685f, 0.449301f, 0.458916f, 0.468531f, 0.478147f, 0.487762f, 0.497378f, 0.506993f, 0.516608f, 0.526224f, 0.535839f}; + // {2, 3, 12, 4} + std::vector past_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 
0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; + // {2, 3, 12, 4} + std::vector past_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 
0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), batch_size * 1 * q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + + // {2, 3, 4, 4} + std::vector y = {-0.385197f, -0.378771f, -0.372345f, -0.365919f, -0.385008f, -0.378583f, -0.372157f, -0.365731f, -0.384820f, -0.378394f, -0.371968f, -0.365543f, -0.384632f, -0.378206f, -0.371780f, -0.365354f, -0.217777f, -0.211351f, -0.204925f, -0.198499f, -0.217588f, -0.211163f, -0.204737f, -0.198311f, -0.217400f, -0.210974f, -0.204549f, -0.198123f, -0.217212f, -0.210786f, -0.204360f, -0.197935f, -0.050357f, -0.043931f, -0.037505f, -0.031080f, -0.050169f, 
-0.043743f, -0.037317f, -0.030891f, -0.049980f, -0.043555f, -0.037129f, -0.030703f, -0.049792f, -0.043366f, -0.036941f, -0.030515f, 0.117063f, 0.123489f, 0.129914f, 0.136340f, 0.117251f, 0.123677f, 0.130102f, 0.136528f, 0.117439f, 0.123865f, 0.130291f, 0.136716f, 0.117628f, 0.124053f, 0.130479f, 0.136904f, 0.284482f, 0.290908f, 0.297334f, 0.303759f, 0.284670f, 0.291096f, 0.297522f, 0.303947f, 0.284859f, 0.291284f, 0.297710f, 0.304135f, 0.285047f, 0.291472f, 0.297898f, 0.304323f, 0.451901f, 0.458327f, 0.464752f, 0.471178f, 0.452089f, 0.458515f, 0.464940f, 0.471366f, 0.452277f, 0.458703f, 0.465128f, 0.471554f, 0.452465f, 0.458890f, 0.465316f, 0.471741f}; + // {2, 3, 13, 4} + std::vector present_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, 
-0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 
0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 18, 8} + std::vector present_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, 
-0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 4, 13} + std::vector qk_matmul = {0.391336f, 0.370435f, 
0.349534f, 0.328633f, 0.307732f, 0.286831f, 0.265930f, 0.390055f, 0.365671f, 0.341286f, 0.316902f, 0.292517f, 0.268133f, 0.354201f, 0.335284f, 0.316367f, 0.297450f, 0.278534f, 0.259617f, 0.240700f, 0.353045f, 0.330975f, 0.308905f, 0.286836f, 0.264766f, 0.242696f, 0.317066f, 0.300134f, 0.283201f, 0.266268f, 0.249335f, 0.232403f, 0.215470f, 0.316034f, 0.296279f, 0.276524f, 0.256769f, 0.237014f, 0.217260f, 0.279932f, 0.264983f, 0.250034f, 0.235086f, 0.220137f, 0.205189f, 0.190240f, 0.279023f, 0.261583f, 0.244143f, 0.226703f, 0.209263f, 0.191823f, 0.152046f, 0.139081f, 0.126117f, 0.113152f, 0.100188f, 0.087223f, 0.074259f, 0.151261f, 0.136136f, 0.121011f, 0.105885f, 0.090760f, 0.075635f, 0.128800f, 0.117819f, 0.106839f, 0.095859f, 0.084878f, 0.073898f, 0.062918f, 0.128139f, 0.115329f, 0.102518f, 0.089708f, 0.076898f, 0.064087f, 0.105554f, 0.096558f, 0.087561f, 0.078565f, 0.069569f, 0.060573f, 0.051577f, 0.105017f, 0.094522f, 0.084026f, 0.073531f, 0.063035f, 0.052539f, 0.082308f, 0.075296f, 0.068284f, 0.061272f, 0.054260f, 0.047248f, 0.040235f, 0.081896f, 0.073715f, 0.065534f, 0.057353f, 0.049172f, 0.040992f, 0.023866f, 0.018838f, 0.013810f, 0.008783f, 0.003755f, -0.001273f, -0.006301f, 0.023578f, 0.017712f, 0.011846f, 0.005980f, 0.000114f, -0.005752f, 0.014509f, 0.011466f, 0.008422f, 0.005378f, 0.002334f, -0.000710f, -0.003754f, 0.014345f, 0.010794f, 0.007243f, 0.003692f, 0.000140f, -0.003411f, 0.005152f, 0.004093f, 0.003033f, 0.001973f, 0.000914f, -0.000146f, -0.001206f, 0.005112f, 0.003876f, 0.002639f, 0.001403f, 0.000167f, -0.001070f, -0.004204f, -0.003280f, -0.002356f, -0.001431f, -0.000507f, 0.000418f, 0.001342f, -0.004121f, -0.003042f, -0.001964f, -0.000885f, 0.000193f, 0.001272f, 0.006798f, 0.009707f, 0.012616f, 0.015524f, 0.018433f, 0.021341f, 0.024250f, 0.007006f, 0.010399f, 0.013793f, 0.017186f, 0.020579f, 0.023973f, 0.011330f, 0.016223f, 0.021116f, 0.026008f, 0.030901f, 0.035794f, 0.040686f, 0.011662f, 0.017370f, 0.023078f, 0.028786f, 0.034494f, 0.040203f, 
0.015862f, 0.022739f, 0.029616f, 0.036493f, 0.043369f, 0.050246f, 0.057123f, 0.016318f, 0.024341f, 0.032364f, 0.040387f, 0.048410f, 0.056433f, 0.020394f, 0.029255f, 0.038116f, 0.046977f, 0.055838f, 0.064699f, 0.073560f, 0.020974f, 0.031312f, 0.041649f, 0.051987f, 0.062325f, 0.072663f, 0.100842f, 0.111687f, 0.122532f, 0.133377f, 0.144222f, 0.155067f, 0.165912f, 0.101545f, 0.114198f, 0.126850f, 0.139503f, 0.152155f, 0.164808f, 0.119262f, 0.132092f, 0.144921f, 0.157750f, 0.170579f, 0.183408f, 0.196237f, 0.120090f, 0.135057f, 0.150025f, 0.164992f, 0.179960f, 0.194927f, 0.137683f, 0.152496f, 0.167310f, 0.182123f, 0.196936f, 0.211750f, 0.226563f, 0.138635f, 0.155917f, 0.173199f, 0.190481f, 0.207764f, 0.225046f, 0.156104f, 0.172901f, 0.189699f, 0.206496f, 0.223294f, 0.240091f, 0.256889f, 0.157180f, 0.176777f, 0.196374f, 0.215971f, 0.235568f, 0.255165f, 0.305996f, 0.324777f, 0.343559f, 0.362340f, 0.381122f, 0.399904f, 0.418685f, 0.307195f, 0.329107f, 0.351019f, 0.372931f, 0.394843f, 0.416755f, 0.338305f, 0.359071f, 0.379837f, 0.400603f, 0.421368f, 0.442134f, 0.462900f, 0.339629f, 0.363856f, 0.388082f, 0.412309f, 0.436536f, 0.460762f, 0.370615f, 0.393365f, 0.416115f, 0.438865f, 0.461614f, 0.484364f, 0.507114f, 0.372063f, 0.398604f, 0.425146f, 0.451687f, 0.478229f, 0.504770f, 0.402925f, 0.427659f, 0.452393f, 0.477127f, 0.501861f, 0.526595f, 0.551329f, 0.404497f, 0.433353f, 0.462209f, 0.491065f, 0.519922f, 0.548778f}; + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + ASSERT_EQ(qk_matmul.size(), batch_size * kv_num_heads * q_sequence_length * (past_sequence_length + kv_sequence_length)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, 
past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, Attention4DWithMask3DPastAndPresentQkMatmulCausal) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 4; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 4; // V.shape[3] + int past_sequence_length = 7; // past_key.shape[2] and past_value.shape[2] + + // {2, 3, 4, 4} + std::vector q = {-0.454545f, -0.444129f, -0.433712f, -0.423295f, -0.412879f, -0.402462f, -0.392045f, -0.381629f, -0.371212f, -0.360795f, -0.350379f, -0.339962f, -0.329545f, -0.319129f, -0.308712f, -0.298295f, -0.287879f, -0.277462f, -0.267045f, -0.256629f, -0.246212f, -0.235795f, -0.225379f, -0.214962f, -0.204545f, -0.194129f, -0.183712f, -0.173295f, -0.162879f, -0.152462f, -0.142045f, -0.131629f, -0.121212f, -0.110795f, -0.100379f, -0.089962f, -0.079545f, -0.069129f, -0.058712f, -0.048295f, -0.037879f, -0.027462f, -0.017045f, -0.006629f, 0.003788f, 0.014205f, 0.024621f, 0.035038f, 0.045455f, 0.055871f, 0.066288f, 0.076705f, 0.087121f, 0.097538f, 0.107955f, 0.118371f, 0.128788f, 0.139205f, 0.149621f, 0.160038f, 0.170455f, 0.180871f, 0.191288f, 0.201705f, 0.212121f, 0.222538f, 0.232955f, 0.243371f, 0.253788f, 0.264205f, 0.274621f, 0.285038f, 0.295455f, 0.305871f, 0.316288f, 0.326705f, 0.337121f, 0.347538f, 0.357955f, 0.368371f, 0.378788f, 0.389205f, 0.399621f, 0.410038f, 0.420455f, 0.430871f, 0.441288f, 0.451705f, 0.462121f, 0.472538f, 0.482955f, 0.493371f, 0.503788f, 0.514205f, 0.524621f, 0.535038f}; + // {2, 3, 6, 4} + 
std::vector k = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 6, 4} + std::vector v = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, 
-0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 1, 4, 13} + std::vector m = {-0.454545f, -0.444930f, -0.435315f, -0.425699f, -0.416084f, -0.406469f, -0.396853f, -0.387238f, -0.377622f, -0.368007f, -0.358392f, -0.348776f, -0.339161f, -0.329545f, -0.319930f, -0.310315f, -0.300699f, -0.291084f, -0.281469f, -0.271853f, -0.262238f, -0.252622f, -0.243007f, -0.233392f, -0.223776f, -0.214161f, -0.204545f, -0.194930f, -0.185315f, -0.175699f, -0.166084f, -0.156469f, -0.146853f, -0.137238f, -0.127622f, -0.118007f, -0.108392f, -0.098776f, -0.089161f, -0.079545f, -0.069930f, -0.060315f, -0.050699f, -0.041084f, -0.031469f, -0.021853f, -0.012238f, -0.002622f, 0.006993f, 0.016608f, 
0.026224f, 0.035839f, 0.045455f, 0.055070f, 0.064685f, 0.074301f, 0.083916f, 0.093531f, 0.103147f, 0.112762f, 0.122378f, 0.131993f, 0.141608f, 0.151224f, 0.160839f, 0.170455f, 0.180070f, 0.189685f, 0.199301f, 0.208916f, 0.218531f, 0.228147f, 0.237762f, 0.247378f, 0.256993f, 0.266608f, 0.276224f, 0.285839f, 0.295455f, 0.305070f, 0.314685f, 0.324301f, 0.333916f, 0.343531f, 0.353147f, 0.362762f, 0.372378f, 0.381993f, 0.391608f, 0.401224f, 0.410839f, 0.420455f, 0.430070f, 0.439685f, 0.449301f, 0.458916f, 0.468531f, 0.478147f, 0.487762f, 0.497378f, 0.506993f, 0.516608f, 0.526224f, 0.535839f}; + // {2, 3, 7, 4}: {batch_size, kv_num_heads, past_sequence_length, head_size} + std::vector past_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 
0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; + // {2, 3, 7, 4}: {batch_size, kv_num_heads, past_sequence_length, v_head_size} + std::vector past_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 
0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), batch_size * 1 * q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + + // {2, 3, 4, 4} + std::vector y = {-0.393782f, -0.387694f, -0.381606f, -0.375519f, -0.397492f, -0.391304f, -0.385116f, -0.378928f, -0.397474f, -0.391207f, -0.384941f, -0.378674f, -0.394849f, -0.388519f, -0.382190f, -0.375860f, -0.226271f, -0.220186f, -0.214101f, -0.208016f, -0.230042f, -0.223857f, -0.217672f, -0.211488f, -0.230104f, -0.223841f, -0.217577f, -0.211314f, -0.227525f, -0.221197f, -0.214870f, -0.208543f, -0.058757f, -0.052674f, -0.046592f, -0.040510f, -0.062587f, -0.056406f, -0.050224f, -0.044042f, -0.062730f, -0.056470f, -0.050209f, -0.043949f, -0.060198f, -0.053873f, -0.047548f, -0.041223f, 0.108760f, 0.114840f, 0.120919f, 0.126999f, 0.104873f, 0.111051f, 0.117229f, 0.123408f, 0.104648f, 0.110906f, 0.117163f, 0.123421f, 0.107131f, 0.113454f, 0.119777f, 0.126099f, 0.276279f, 0.282356f, 0.288433f, 0.294510f, 0.272337f, 0.278512f, 0.284687f, 0.290862f, 0.272031f, 0.278286f, 
0.284540f, 0.290794f, 0.274463f, 0.280783f, 0.287104f, 0.293424f, 0.443800f, 0.449874f, 0.455949f, 0.462023f, 0.439807f, 0.445978f, 0.452150f, 0.458321f, 0.439418f, 0.445669f, 0.451921f, 0.458172f, 0.441797f, 0.448115f, 0.454433f, 0.460751f}; + // {2, 3, 13, 4} + std::vector present_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, 
-0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 13, 4}: {batch_size, kv_num_heads, past_sequence_length + kv_sequence_length, v_head_size} + std::vector present_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, 
-0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 
0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 4, 13} + std::vector qk_matmul = {0.391336f, 0.370435f, 0.349534f, 0.328633f, 0.307732f, 0.286831f, 0.265930f, 0.390055f, 0.365671f, 0.341286f, 0.316902f, 0.292517f, 0.268133f, 0.354201f, 0.335284f, 0.316367f, 0.297450f, 0.278534f, 0.259617f, 0.240700f, 0.353045f, 0.330975f, 0.308905f, 0.286836f, 0.264766f, 0.242696f, 0.317066f, 0.300134f, 0.283201f, 0.266268f, 0.249335f, 0.232403f, 0.215470f, 0.316034f, 0.296279f, 0.276524f, 0.256769f, 0.237014f, 0.217260f, 0.279932f, 
0.264983f, 0.250034f, 0.235086f, 0.220137f, 0.205189f, 0.190240f, 0.279023f, 0.261583f, 0.244143f, 0.226703f, 0.209263f, 0.191823f, 0.152046f, 0.139081f, 0.126117f, 0.113152f, 0.100188f, 0.087223f, 0.074259f, 0.151261f, 0.136136f, 0.121011f, 0.105885f, 0.090760f, 0.075635f, 0.128800f, 0.117819f, 0.106839f, 0.095859f, 0.084878f, 0.073898f, 0.062918f, 0.128139f, 0.115329f, 0.102518f, 0.089708f, 0.076898f, 0.064087f, 0.105554f, 0.096558f, 0.087561f, 0.078565f, 0.069569f, 0.060573f, 0.051577f, 0.105017f, 0.094522f, 0.084026f, 0.073531f, 0.063035f, 0.052539f, 0.082308f, 0.075296f, 0.068284f, 0.061272f, 0.054260f, 0.047248f, 0.040235f, 0.081896f, 0.073715f, 0.065534f, 0.057353f, 0.049172f, 0.040992f, 0.023866f, 0.018838f, 0.013810f, 0.008783f, 0.003755f, -0.001273f, -0.006301f, 0.023578f, 0.017712f, 0.011846f, 0.005980f, 0.000114f, -0.005752f, 0.014509f, 0.011466f, 0.008422f, 0.005378f, 0.002334f, -0.000710f, -0.003754f, 0.014345f, 0.010794f, 0.007243f, 0.003692f, 0.000140f, -0.003411f, 0.005152f, 0.004093f, 0.003033f, 0.001973f, 0.000914f, -0.000146f, -0.001206f, 0.005112f, 0.003876f, 0.002639f, 0.001403f, 0.000167f, -0.001070f, -0.004204f, -0.003280f, -0.002356f, -0.001431f, -0.000507f, 0.000418f, 0.001342f, -0.004121f, -0.003042f, -0.001964f, -0.000885f, 0.000193f, 0.001272f, 0.006798f, 0.009707f, 0.012616f, 0.015524f, 0.018433f, 0.021341f, 0.024250f, 0.007006f, 0.010399f, 0.013793f, 0.017186f, 0.020579f, 0.023973f, 0.011330f, 0.016223f, 0.021116f, 0.026008f, 0.030901f, 0.035794f, 0.040686f, 0.011662f, 0.017370f, 0.023078f, 0.028786f, 0.034494f, 0.040203f, 0.015862f, 0.022739f, 0.029616f, 0.036493f, 0.043369f, 0.050246f, 0.057123f, 0.016318f, 0.024341f, 0.032364f, 0.040387f, 0.048410f, 0.056433f, 0.020394f, 0.029255f, 0.038116f, 0.046977f, 0.055838f, 0.064699f, 0.073560f, 0.020974f, 0.031312f, 0.041649f, 0.051987f, 0.062325f, 0.072663f, 0.100842f, 0.111687f, 0.122532f, 0.133377f, 0.144222f, 0.155067f, 0.165912f, 0.101545f, 0.114198f, 0.126850f, 0.139503f, 0.152155f, 
0.164808f, 0.119262f, 0.132092f, 0.144921f, 0.157750f, 0.170579f, 0.183408f, 0.196237f, 0.120090f, 0.135057f, 0.150025f, 0.164992f, 0.179960f, 0.194927f, 0.137683f, 0.152496f, 0.167310f, 0.182123f, 0.196936f, 0.211750f, 0.226563f, 0.138635f, 0.155917f, 0.173199f, 0.190481f, 0.207764f, 0.225046f, 0.156104f, 0.172901f, 0.189699f, 0.206496f, 0.223294f, 0.240091f, 0.256889f, 0.157180f, 0.176777f, 0.196374f, 0.215971f, 0.235568f, 0.255165f, 0.305996f, 0.324777f, 0.343559f, 0.362340f, 0.381122f, 0.399904f, 0.418685f, 0.307195f, 0.329107f, 0.351019f, 0.372931f, 0.394843f, 0.416755f, 0.338305f, 0.359071f, 0.379837f, 0.400603f, 0.421368f, 0.442134f, 0.462900f, 0.339629f, 0.363856f, 0.388082f, 0.412309f, 0.436536f, 0.460762f, 0.370615f, 0.393365f, 0.416115f, 0.438865f, 0.461614f, 0.484364f, 0.507114f, 0.372063f, 0.398604f, 0.425146f, 0.451687f, 0.478229f, 0.504770f, 0.402925f, 0.427659f, 0.452393f, 0.477127f, 0.501861f, 0.526595f, 0.551329f, 0.404497f, 0.433353f, 0.462209f, 0.491065f, 0.519922f, 0.548778f}; + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + ASSERT_EQ(qk_matmul.size(), batch_size * kv_num_heads * q_sequence_length * (past_sequence_length + kv_sequence_length)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +TEST(AttentionTest, 
Attention4DWithMask4DPastAndPresentQkMatmul) { + int batch_size = 2; // Q.shape[0] + int q_num_heads = 3; // Q.shape[1] + int q_sequence_length = 4; // Q.shape[2] + int head_size = 4; // Q.shape[3] + int kv_sequence_length = 6; // K.shape[2] and V.shape[2] + int kv_num_heads = 3; // K.shape[1] and V.shape[1] + int v_head_size = 4; // V.shape[3] + int past_sequence_length = 7; // past_key.shape[2] and past_value.shape[2] + + // {2, 3, 4, 4} + std::vector q = {-0.454545f, -0.444129f, -0.433712f, -0.423295f, -0.412879f, -0.402462f, -0.392045f, -0.381629f, -0.371212f, -0.360795f, -0.350379f, -0.339962f, -0.329545f, -0.319129f, -0.308712f, -0.298295f, -0.287879f, -0.277462f, -0.267045f, -0.256629f, -0.246212f, -0.235795f, -0.225379f, -0.214962f, -0.204545f, -0.194129f, -0.183712f, -0.173295f, -0.162879f, -0.152462f, -0.142045f, -0.131629f, -0.121212f, -0.110795f, -0.100379f, -0.089962f, -0.079545f, -0.069129f, -0.058712f, -0.048295f, -0.037879f, -0.027462f, -0.017045f, -0.006629f, 0.003788f, 0.014205f, 0.024621f, 0.035038f, 0.045455f, 0.055871f, 0.066288f, 0.076705f, 0.087121f, 0.097538f, 0.107955f, 0.118371f, 0.128788f, 0.139205f, 0.149621f, 0.160038f, 0.170455f, 0.180871f, 0.191288f, 0.201705f, 0.212121f, 0.222538f, 0.232955f, 0.243371f, 0.253788f, 0.264205f, 0.274621f, 0.285038f, 0.295455f, 0.305871f, 0.316288f, 0.326705f, 0.337121f, 0.347538f, 0.357955f, 0.368371f, 0.378788f, 0.389205f, 0.399621f, 0.410038f, 0.420455f, 0.430871f, 0.441288f, 0.451705f, 0.462121f, 0.472538f, 0.482955f, 0.493371f, 0.503788f, 0.514205f, 0.524621f, 0.535038f}; + // {2, 3, 6, 4} + std::vector k = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, 
-0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 6, 4} + std::vector v = {-0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, 
-0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 4, 13} + std::vector m = {-0.454545f, -0.451340f, -0.448135f, -0.444930f, -0.441725f, -0.438520f, -0.435315f, -0.432110f, -0.428904f, -0.425699f, -0.422494f, -0.419289f, -0.416084f, -0.412879f, -0.409674f, -0.406469f, -0.403263f, -0.400058f, -0.396853f, -0.393648f, -0.390443f, -0.387238f, -0.384033f, -0.380828f, -0.377622f, -0.374417f, -0.371212f, -0.368007f, -0.364802f, -0.361597f, -0.358392f, -0.355186f, -0.351981f, -0.348776f, -0.345571f, -0.342366f, -0.339161f, -0.335956f, -0.332751f, -0.329545f, -0.326340f, -0.323135f, -0.319930f, -0.316725f, -0.313520f, -0.310315f, -0.307110f, -0.303904f, -0.300699f, -0.297494f, -0.294289f, -0.291084f, -0.287879f, -0.284674f, -0.281469f, -0.278263f, -0.275058f, -0.271853f, -0.268648f, -0.265443f, -0.262238f, -0.259033f, -0.255828f, -0.252622f, -0.249417f, -0.246212f, -0.243007f, -0.239802f, -0.236597f, -0.233392f, -0.230186f, -0.226981f, -0.223776f, -0.220571f, -0.217366f, -0.214161f, -0.210956f, -0.207751f, -0.204545f, -0.201340f, -0.198135f, -0.194930f, -0.191725f, -0.188520f, 
-0.185315f, -0.182110f, -0.178904f, -0.175699f, -0.172494f, -0.169289f, -0.166084f, -0.162879f, -0.159674f, -0.156469f, -0.153263f, -0.150058f, -0.146853f, -0.143648f, -0.140443f, -0.137238f, -0.134033f, -0.130828f, -0.127622f, -0.124417f, -0.121212f, -0.118007f, -0.114802f, -0.111597f, -0.108392f, -0.105186f, -0.101981f, -0.098776f, -0.095571f, -0.092366f, -0.089161f, -0.085956f, -0.082751f, -0.079545f, -0.076340f, -0.073135f, -0.069930f, -0.066725f, -0.063520f, -0.060315f, -0.057110f, -0.053904f, -0.050699f, -0.047494f, -0.044289f, -0.041084f, -0.037879f, -0.034674f, -0.031469f, -0.028263f, -0.025058f, -0.021853f, -0.018648f, -0.015443f, -0.012238f, -0.009033f, -0.005828f, -0.002622f, 0.000583f, 0.003788f, 0.006993f, 0.010198f, 0.013403f, 0.016608f, 0.019814f, 0.023019f, 0.026224f, 0.029429f, 0.032634f, 0.035839f, 0.039044f, 0.042249f, 0.045455f, 0.048660f, 0.051865f, 0.055070f, 0.058275f, 0.061480f, 0.064685f, 0.067890f, 0.071096f, 0.074301f, 0.077506f, 0.080711f, 0.083916f, 0.087121f, 0.090326f, 0.093531f, 0.096737f, 0.099942f, 0.103147f, 0.106352f, 0.109557f, 0.112762f, 0.115967f, 0.119172f, 0.122378f, 0.125583f, 0.128788f, 0.131993f, 0.135198f, 0.138403f, 0.141608f, 0.144814f, 0.148019f, 0.151224f, 0.154429f, 0.157634f, 0.160839f, 0.164044f, 0.167249f, 0.170455f, 0.173660f, 0.176865f, 0.180070f, 0.183275f, 0.186480f, 0.189685f, 0.192890f, 0.196096f, 0.199301f, 0.202506f, 0.205711f, 0.208916f, 0.212121f, 0.215326f, 0.218531f, 0.221737f, 0.224942f, 0.228147f, 0.231352f, 0.234557f, 0.237762f, 0.240967f, 0.244172f, 0.247378f, 0.250583f, 0.253788f, 0.256993f, 0.260198f, 0.263403f, 0.266608f, 0.269814f, 0.273019f, 0.276224f, 0.279429f, 0.282634f, 0.285839f, 0.289044f, 0.292249f, 0.295455f, 0.298660f, 0.301865f, 0.305070f, 0.308275f, 0.311480f, 0.314685f, 0.317890f, 0.321096f, 0.324301f, 0.327506f, 0.330711f, 0.333916f, 0.337121f, 0.340326f, 0.343531f, 0.346737f, 0.349942f, 0.353147f, 0.356352f, 0.359557f, 0.362762f, 0.365967f, 0.369172f, 0.372378f, 0.375583f, 
0.378788f, 0.381993f, 0.385198f, 0.388403f, 0.391608f, 0.394814f, 0.398019f, 0.401224f, 0.404429f, 0.407634f, 0.410839f, 0.414044f, 0.417249f, 0.420455f, 0.423660f, 0.426865f, 0.430070f, 0.433275f, 0.436480f, 0.439685f, 0.442890f, 0.446096f, 0.449301f, 0.452506f, 0.455711f, 0.458916f, 0.462121f, 0.465326f, 0.468531f, 0.471737f, 0.474942f, 0.478147f, 0.481352f, 0.484557f, 0.487762f, 0.490967f, 0.494172f, 0.497378f, 0.500583f, 0.503788f, 0.506993f, 0.510198f, 0.513403f, 0.516608f, 0.519814f, 0.523019f, 0.526224f, 0.529429f, 0.532634f, 0.535839f, 0.539044f, 0.542249f}; + // {2, 3, 7, 4} + std::vector past_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 
0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; + // {2, 3, 7, 4} + std::vector past_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 
0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f}; + + ASSERT_EQ(q.size(), batch_size * q_num_heads * q_sequence_length * head_size); + ASSERT_EQ(k.size(), batch_size * kv_num_heads * kv_sequence_length * head_size); + ASSERT_EQ(v.size(), batch_size * kv_num_heads * kv_sequence_length * v_head_size); + ASSERT_EQ(m.size(), batch_size * q_num_heads * q_sequence_length * (kv_sequence_length + past_sequence_length)); + ASSERT_EQ(past_key.size(), batch_size * kv_num_heads * past_sequence_length * head_size); + ASSERT_EQ(past_value.size(), batch_size * kv_num_heads * past_sequence_length * v_head_size); + + // {2, 3, 4, 4} + std::vector y = {-0.385742f, -0.379327f, -0.372911f, -0.366496f, -0.385554f, -0.379139f, -0.372723f, -0.366308f, -0.385366f, -0.378950f, -0.372535f, -0.366119f, -0.385178f, -0.378762f, -0.372347f, -0.365931f, -0.218323f, -0.211907f, -0.205492f, -0.199076f, -0.218134f, -0.211719f, -0.205304f, -0.198888f, -0.217946f, -0.211531f, -0.205115f, -0.198700f, -0.217758f, -0.211342f, -0.204927f, -0.198512f, -0.050903f, -0.044487f, -0.038072f, -0.031657f, -0.050715f, -0.044299f, -0.037884f, -0.031468f, -0.050526f, -0.044111f, -0.037695f, -0.031280f, -0.050338f, -0.043922f, -0.037507f, -0.031092f, 0.116517f, 0.122932f, 0.129348f, 0.135763f, 0.116705f, 0.123121f, 0.129536f, 0.135952f, 0.116894f, 0.123309f, 0.129724f, 0.136140f, 0.117082f, 0.123497f, 0.129913f, 0.136328f, 0.283937f, 0.290352f, 0.296768f, 0.303183f, 0.284125f, 0.290540f, 0.296956f, 0.303371f, 0.284313f, 0.290729f, 0.297144f, 
0.303559f, 0.284501f, 0.290917f, 0.297332f, 0.303747f, 0.451356f, 0.457772f, 0.464187f, 0.470602f, 0.451544f, 0.457960f, 0.464375f, 0.470790f, 0.451732f, 0.458148f, 0.464563f, 0.470978f, 0.451920f, 0.458336f, 0.464751f, 0.471166f}; + // {2, 3, 13, 4} + std::vector present_key = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, -0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, 
-0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 13, 4} + std::vector present_value = {-0.454545f, -0.448593f, -0.442641f, -0.436688f, -0.430736f, -0.424784f, 
-0.418831f, -0.412879f, -0.406926f, -0.400974f, -0.395022f, -0.389069f, -0.383117f, -0.377165f, -0.371212f, -0.365260f, -0.359307f, -0.353355f, -0.347403f, -0.341450f, -0.335498f, -0.329545f, -0.323593f, -0.317641f, -0.311688f, -0.305736f, -0.299784f, -0.293831f, -0.454545f, -0.447601f, -0.440657f, -0.433712f, -0.426768f, -0.419823f, -0.412879f, -0.405934f, -0.398990f, -0.392045f, -0.385101f, -0.378157f, -0.371212f, -0.364268f, -0.357323f, -0.350379f, -0.343434f, -0.336490f, -0.329545f, -0.322601f, -0.315657f, -0.308712f, -0.301768f, -0.294823f, -0.287879f, -0.281926f, -0.275974f, -0.270022f, -0.264069f, -0.258117f, -0.252165f, -0.246212f, -0.240260f, -0.234307f, -0.228355f, -0.222403f, -0.216450f, -0.210498f, -0.204545f, -0.198593f, -0.192641f, -0.186688f, -0.180736f, -0.174784f, -0.168831f, -0.162879f, -0.156926f, -0.150974f, -0.145022f, -0.139069f, -0.133117f, -0.127164f, -0.287879f, -0.280934f, -0.273990f, -0.267045f, -0.260101f, -0.253157f, -0.246212f, -0.239268f, -0.232323f, -0.225379f, -0.218434f, -0.211490f, -0.204545f, -0.197601f, -0.190657f, -0.183712f, -0.176768f, -0.169823f, -0.162879f, -0.155934f, -0.148990f, -0.142045f, -0.135101f, -0.128157f, -0.121212f, -0.115260f, -0.109307f, -0.103355f, -0.097403f, -0.091450f, -0.085498f, -0.079545f, -0.073593f, -0.067641f, -0.061688f, -0.055736f, -0.049784f, -0.043831f, -0.037879f, -0.031926f, -0.025974f, -0.020022f, -0.014069f, -0.008117f, -0.002165f, 0.003788f, 0.009740f, 0.015693f, 0.021645f, 0.027597f, 0.033550f, 0.039502f, -0.121212f, -0.114268f, -0.107323f, -0.100379f, -0.093434f, -0.086490f, -0.079545f, -0.072601f, -0.065657f, -0.058712f, -0.051768f, -0.044823f, -0.037879f, -0.030934f, -0.023990f, -0.017045f, -0.010101f, -0.003157f, 0.003788f, 0.010732f, 0.017677f, 0.024621f, 0.031566f, 0.038510f, 0.045455f, 0.051407f, 0.057359f, 0.063312f, 0.069264f, 0.075216f, 0.081169f, 0.087121f, 0.093074f, 0.099026f, 0.104978f, 0.110931f, 0.116883f, 0.122835f, 0.128788f, 0.134740f, 0.140693f, 0.146645f, 0.152597f, 
0.158550f, 0.164502f, 0.170455f, 0.176407f, 0.182359f, 0.188312f, 0.194264f, 0.200216f, 0.206169f, 0.045455f, 0.052399f, 0.059343f, 0.066288f, 0.073232f, 0.080177f, 0.087121f, 0.094066f, 0.101010f, 0.107955f, 0.114899f, 0.121843f, 0.128788f, 0.135732f, 0.142677f, 0.149621f, 0.156566f, 0.163510f, 0.170455f, 0.177399f, 0.184343f, 0.191288f, 0.198232f, 0.205177f, 0.212121f, 0.218074f, 0.224026f, 0.229978f, 0.235931f, 0.241883f, 0.247836f, 0.253788f, 0.259740f, 0.265693f, 0.271645f, 0.277597f, 0.283550f, 0.289502f, 0.295455f, 0.301407f, 0.307359f, 0.313312f, 0.319264f, 0.325216f, 0.331169f, 0.337121f, 0.343074f, 0.349026f, 0.354978f, 0.360931f, 0.366883f, 0.372835f, 0.212121f, 0.219066f, 0.226010f, 0.232955f, 0.239899f, 0.246843f, 0.253788f, 0.260732f, 0.267677f, 0.274621f, 0.281566f, 0.288510f, 0.295455f, 0.302399f, 0.309343f, 0.316288f, 0.323232f, 0.330177f, 0.337121f, 0.344066f, 0.351010f, 0.357955f, 0.364899f, 0.371843f, 0.378788f, 0.384740f, 0.390693f, 0.396645f, 0.402597f, 0.408550f, 0.414502f, 0.420455f, 0.426407f, 0.432359f, 0.438312f, 0.444264f, 0.450216f, 0.456169f, 0.462121f, 0.468074f, 0.474026f, 0.479978f, 0.485931f, 0.491883f, 0.497835f, 0.503788f, 0.509740f, 0.515693f, 0.521645f, 0.527597f, 0.533550f, 0.539502f, 0.378788f, 0.385732f, 0.392677f, 0.399621f, 0.406566f, 0.413510f, 0.420455f, 0.427399f, 0.434343f, 0.441288f, 0.448232f, 0.455177f, 0.462121f, 0.469066f, 0.476010f, 0.482955f, 0.489899f, 0.496843f, 0.503788f, 0.510732f, 0.517677f, 0.524621f, 0.531566f, 0.538510f}; + // {2, 3, 4, 13} + std::vector qk_matmul = {0.391336f, 0.370435f, 0.349534f, 0.328633f, 0.307732f, 0.286831f, 0.265930f, 0.390055f, 0.365671f, 0.341286f, 0.316902f, 0.292517f, 0.268133f, 0.354201f, 0.335284f, 0.316367f, 0.297450f, 0.278534f, 0.259617f, 0.240700f, 0.353045f, 0.330975f, 0.308905f, 0.286836f, 0.264766f, 0.242696f, 0.317066f, 0.300134f, 0.283201f, 0.266268f, 0.249335f, 0.232403f, 0.215470f, 0.316034f, 0.296279f, 0.276524f, 0.256769f, 0.237014f, 0.217260f, 0.279932f, 
0.264983f, 0.250034f, 0.235086f, 0.220137f, 0.205189f, 0.190240f, 0.279023f, 0.261583f, 0.244143f, 0.226703f, 0.209263f, 0.191823f, 0.152046f, 0.139081f, 0.126117f, 0.113152f, 0.100188f, 0.087223f, 0.074259f, 0.151261f, 0.136136f, 0.121011f, 0.105885f, 0.090760f, 0.075635f, 0.128800f, 0.117819f, 0.106839f, 0.095859f, 0.084878f, 0.073898f, 0.062918f, 0.128139f, 0.115329f, 0.102518f, 0.089708f, 0.076898f, 0.064087f, 0.105554f, 0.096558f, 0.087561f, 0.078565f, 0.069569f, 0.060573f, 0.051577f, 0.105017f, 0.094522f, 0.084026f, 0.073531f, 0.063035f, 0.052539f, 0.082308f, 0.075296f, 0.068284f, 0.061272f, 0.054260f, 0.047248f, 0.040235f, 0.081896f, 0.073715f, 0.065534f, 0.057353f, 0.049172f, 0.040992f, 0.023866f, 0.018838f, 0.013810f, 0.008783f, 0.003755f, -0.001273f, -0.006301f, 0.023578f, 0.017712f, 0.011846f, 0.005980f, 0.000114f, -0.005752f, 0.014509f, 0.011466f, 0.008422f, 0.005378f, 0.002334f, -0.000710f, -0.003754f, 0.014345f, 0.010794f, 0.007243f, 0.003692f, 0.000140f, -0.003411f, 0.005152f, 0.004093f, 0.003033f, 0.001973f, 0.000914f, -0.000146f, -0.001206f, 0.005112f, 0.003876f, 0.002639f, 0.001403f, 0.000167f, -0.001070f, -0.004204f, -0.003280f, -0.002356f, -0.001431f, -0.000507f, 0.000418f, 0.001342f, -0.004121f, -0.003042f, -0.001964f, -0.000885f, 0.000193f, 0.001272f, 0.006798f, 0.009707f, 0.012616f, 0.015524f, 0.018433f, 0.021341f, 0.024250f, 0.007006f, 0.010399f, 0.013793f, 0.017186f, 0.020579f, 0.023973f, 0.011330f, 0.016223f, 0.021116f, 0.026008f, 0.030901f, 0.035794f, 0.040686f, 0.011662f, 0.017370f, 0.023078f, 0.028786f, 0.034494f, 0.040203f, 0.015862f, 0.022739f, 0.029616f, 0.036493f, 0.043369f, 0.050246f, 0.057123f, 0.016318f, 0.024341f, 0.032364f, 0.040387f, 0.048410f, 0.056433f, 0.020394f, 0.029255f, 0.038116f, 0.046977f, 0.055838f, 0.064699f, 0.073560f, 0.020974f, 0.031312f, 0.041649f, 0.051987f, 0.062325f, 0.072663f, 0.100842f, 0.111687f, 0.122532f, 0.133377f, 0.144222f, 0.155067f, 0.165912f, 0.101545f, 0.114198f, 0.126850f, 0.139503f, 0.152155f, 
0.164808f, 0.119262f, 0.132092f, 0.144921f, 0.157750f, 0.170579f, 0.183408f, 0.196237f, 0.120090f, 0.135057f, 0.150025f, 0.164992f, 0.179960f, 0.194927f, 0.137683f, 0.152496f, 0.167310f, 0.182123f, 0.196936f, 0.211750f, 0.226563f, 0.138635f, 0.155917f, 0.173199f, 0.190481f, 0.207764f, 0.225046f, 0.156104f, 0.172901f, 0.189699f, 0.206496f, 0.223294f, 0.240091f, 0.256889f, 0.157180f, 0.176777f, 0.196374f, 0.215971f, 0.235568f, 0.255165f, 0.305996f, 0.324777f, 0.343559f, 0.362340f, 0.381122f, 0.399904f, 0.418685f, 0.307195f, 0.329107f, 0.351019f, 0.372931f, 0.394843f, 0.416755f, 0.338305f, 0.359071f, 0.379837f, 0.400603f, 0.421368f, 0.442134f, 0.462900f, 0.339629f, 0.363856f, 0.388082f, 0.412309f, 0.436536f, 0.460762f, 0.370615f, 0.393365f, 0.416115f, 0.438865f, 0.461614f, 0.484364f, 0.507114f, 0.372063f, 0.398604f, 0.425146f, 0.451687f, 0.478229f, 0.504770f, 0.402925f, 0.427659f, 0.452393f, 0.477127f, 0.501861f, 0.526595f, 0.551329f, 0.404497f, 0.433353f, 0.462209f, 0.491065f, 0.519922f, 0.548778f}; + ASSERT_EQ(y.size(), batch_size * q_num_heads * q_sequence_length * v_head_size); + ASSERT_EQ(present_key.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * head_size); + ASSERT_EQ(present_value.size(), batch_size * kv_num_heads * (past_sequence_length + kv_sequence_length) * v_head_size); + ASSERT_EQ(qk_matmul.size(), batch_size * kv_num_heads * q_sequence_length * (past_sequence_length + kv_sequence_length)); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, v_head_size, past_sequence_length, + q, k, v, m, std::initializer_list(), past_key, past_value, + -1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, TensorType::kFloat, // is_causal, qk_matmul_output_mode, scale, softcap, softmax_precision, tensor_type + y, present_key, present_value, qk_matmul, + false, true, true // disable_cpu, disable_cuda, disable_dml + ); +} + +} // namespace test +} // namespace 
onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index 9e0fb81cbb0fc..b5e13c6377ccb 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -434,5 +434,110 @@ TEST(ConcatOpTest, Concat4D_2) { test.Run(); } +#ifdef USE_WEBGPU +TEST(ConcatOpTest, Concat1D_int32_4inputs) { + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{0}); + + test.AddInput("input1", {1}, {1}); + test.AddInput("input2", {2}, {2, 3}); + test.AddInput("input3", {4}, {4, 5, 6, 7}); + test.AddInput("input4", {2}, {8, 9}); + test.AddOutput("concat_result", {9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + test.Run(); +} + +TEST(ConcatOpTest, Concat1D_exceed_maxStorageBuffersPerShaderStage) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{0}); + + test.AddInput("input1", {1}, {1}); + test.AddInput("input2", {1}, {2}); + test.AddInput("input3", {1}, {3}); + test.AddInput("input4", {1}, {4}); + test.AddInput("input5", {1}, {5}); + test.AddInput("input6", {1}, {6}); + test.AddInput("input7", {1}, {7}); + test.AddInput("input8", {1}, {8}); + test.AddInput("input9", {1}, {9}); + test.AddOutput("concat_result", {9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + test.Run(); +} + +TEST(ConcatOpTest, Concat2D_exceed_maxStorageBuffersPerShaderStage_axis0) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{0}); + + test.AddInput("input1", {1, 2}, {1, 2}); + test.AddInput("input2", {1, 2}, {3, 4}); + test.AddInput("input3", {1, 2}, {5, 6}); + test.AddInput("input4", {1, 2}, {7, 8}); + test.AddInput("input5", {1, 2}, {9, 10}); + test.AddInput("input6", {1, 2}, {11, 12}); + test.AddInput("input7", {1, 2}, {13, 14}); + test.AddInput("input8", {1, 2}, {15, 16}); + test.AddInput("input9", {1, 2}, {17, 18}); + test.AddOutput("concat_result", {9, 2}, {1, 2, 3, 4, 
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); + test.Run(); +} + +TEST(ConcatOpTest, Concat2D_exceed_maxStorageBuffersPerShaderStage_axis1) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{1}); + + test.AddInput("input1", {1, 2}, {1, 2}); + test.AddInput("input2", {1, 2}, {3, 4}); + test.AddInput("input3", {1, 2}, {5, 6}); + test.AddInput("input4", {1, 2}, {7, 8}); + test.AddInput("input5", {1, 2}, {9, 10}); + test.AddInput("input6", {1, 2}, {11, 12}); + test.AddInput("input7", {1, 2}, {13, 14}); + test.AddInput("input8", {1, 2}, {15, 16}); + test.AddInput("input9", {1, 2}, {17, 18}); + test.AddOutput("concat_result", {1, 18}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}); + test.Run(); +} + +TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{1}); + + test.AddInput("input1", {2, 1, 1}, {1, 2}); + test.AddInput("input2", {2, 1, 1}, {3, 4}); + test.AddInput("input3", {2, 1, 1}, {5, 6}); + test.AddInput("input4", {2, 1, 1}, {7, 8}); + test.AddInput("input5", {2, 1, 1}, {9, 10}); + test.AddInput("input6", {2, 1, 1}, {11, 12}); + test.AddInput("input7", {2, 1, 1}, {13, 14}); + test.AddInput("input8", {2, 1, 1}, {15, 16}); + test.AddInput("input9", {2, 1, 1}, {17, 18}); + test.AddOutput("concat_result", {2, 9, 1}, {// batch 0 + 1, 3, 5, 7, 9, 11, 13, 15, 17, + // batch 1 + 2, 4, 6, 8, 10, 12, 14, 16, 18}); + test.Run(); +} + +TEST(ConcatOpTest, Concat3D_exceed_maxStorageBuffersPerShaderStage_mixed_sizes) { + // maxStorageBuffersPerShaderStage==8 + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{1}); + + test.AddInput("input1", {2, 1, 1}, {1, 2}); + test.AddInput("input2", {2, 3, 1}, {3, 4, 5, 6, 7, 8}); + test.AddInput("input3", {2, 2, 1}, {9, 10, 11, 12}); + test.AddInput("input4", {2, 1, 1}, {13, 14}); + test.AddOutput("concat_result", {2, 7, 1}, {// batch 
0 + 1, 3, 4, 5, 9, 10, 13, + // batch 1 + 2, 6, 7, 8, 11, 12, 14}); + test.Run(); +} +#endif // USE_WEBGPU + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h index 1aea58c8d7a10..a49f662ca1adb 100644 --- a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h +++ b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h @@ -46,7 +46,7 @@ } else if (std::is_same::value) { \ MAKE_PROVIDERS_EPS_EXT(2e-4, pad_to_nc1d) \ } else { \ - MAKE_PROVIDERS_EPS_EXT(2e-3, pad_to_nc1d) \ + MAKE_PROVIDERS_EPS_EXT(4e-3, pad_to_nc1d) \ } #define MAKE_PROVIDERS_EPS_TYPE(T) \ diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index 8858ae75fb39a..0559699670c4a 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -8,8 +8,13 @@ #include "gtest/gtest.h" #include "test/util/include/scoped_env_vars.h" #include "test/common/trt_op_test_utils.h" +#include "test/common/random_generator.h" +#include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" +#include "test/util/include/api_asserts.h" +#include "test/util/include/asserts.h" #include +#include #include #include #include @@ -20,7 +25,7 @@ using namespace std; using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::logging; - +extern std::unique_ptr ort_env; namespace onnxruntime { namespace test { @@ -410,9 +415,10 @@ static bool SessionHasEp(Ort::Session& session, const char* ep_name) { TEST(NvExecutionProviderTest, AutoEp_PreferGpu) { PathString model_name = ORT_TSTR("nv_execution_provider_data_dyn_test.onnx"); std::string graph_name = "test"; - std::vector dims = {1, -1, -1}; - CreateBaseModel(model_name, graph_name, dims, true); + std::vector dims = {1, 3, 2}; + + CreateBaseModel(model_name, graph_name, dims); auto env = 
Ort::Env(); auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; @@ -429,6 +435,151 @@ TEST(NvExecutionProviderTest, AutoEp_PreferGpu) { env.UnregisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider); } + +TEST(NvExecutionProviderTest, GetSharedAllocator) { + const OrtApi& c_api = Ort::GetApi(); + RegisteredEpDeviceUniquePtr nv_tensorrt_rtx_ep; + Utils::RegisterAndGetNvTensorRtRtxEp(*ort_env, nv_tensorrt_rtx_ep); + + const auto* ep_memory_info = c_api.EpDevice_MemoryInfo(nv_tensorrt_rtx_ep.get(), OrtDeviceMemoryType_DEFAULT); + + // validate there is a shared allocator + OrtAllocator* allocator = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, ep_memory_info, &allocator)); + ASSERT_NE(allocator, nullptr); + + const auto* ep_host_accessible_memory_info = c_api.EpDevice_MemoryInfo(nv_tensorrt_rtx_ep.get(), OrtDeviceMemoryType_HOST_ACCESSIBLE); + OrtAllocator* host_accessible_allocator = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, ep_host_accessible_memory_info, &host_accessible_allocator)); + ASSERT_NE(host_accessible_allocator, nullptr); +} + +TEST(NvExecutionProviderTest, LoadUnloadPluginLibrary) { + const std::filesystem::path& library_path = Utils::nv_tensorrt_rtx_ep_info.library_path; + const std::string& registration_name = Utils::nv_tensorrt_rtx_ep_info.registration_name; + + const OrtApi* c_api = &Ort::GetApi(); + // this should load the library and create OrtEpDevice + ASSERT_ORTSTATUS_OK(Ort::GetApi().RegisterExecutionProviderLibrary(*ort_env, registration_name.c_str(), + library_path.c_str())); + + const OrtEpDevice* const* ep_devices = nullptr; + size_t num_devices = 0; + + ASSERT_ORTSTATUS_OK(Ort::GetApi().GetEpDevices(*ort_env, &ep_devices, &num_devices)); + // should be one device for the example EP + auto num_test_ep_devices = std::count_if(ep_devices, ep_devices + num_devices, + [&registration_name, &c_api](const OrtEpDevice* device) { + // the example uses the registration name for the EP 
name + // but that is not a requirement and the two can differ. + return c_api->EpDevice_EpName(device) == registration_name; + }); + ASSERT_EQ(num_test_ep_devices, 1) << "Expected an OrtEpDevice to have been created by the test library."; + + // and this should unload it + ASSERT_ORTSTATUS_OK(Ort::GetApi().UnregisterExecutionProviderLibrary(*ort_env, + registration_name.c_str())); +} + +TEST(NvExecutionProviderTest, LoadUnloadPluginLibraryCxxApi) { + const std::filesystem::path& library_path = Utils::nv_tensorrt_rtx_ep_info.library_path; + const std::string& registration_name = Utils::nv_tensorrt_rtx_ep_info.registration_name; + const OrtApi* c_api = &Ort::GetApi(); + // this should load the library and create OrtEpDevice + ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); + + std::vector ep_devices = ort_env->GetEpDevices(); + + auto test_ep_device = std::find_if(ep_devices.begin(), ep_devices.end(), + [&registration_name, &c_api](const Ort::ConstEpDevice& device) { + return device.EpName() == registration_name; + }); + ASSERT_NE(test_ep_device, ep_devices.end()) << "Expected an OrtEpDevice to have been created by the test library."; + + // test all the C++ getters. 
expected values are from \onnxruntime\test\autoep\library\example_plugin_ep.cc + ASSERT_STREQ(test_ep_device->EpVendor(), "NVIDIA"); + + auto metadata = test_ep_device->EpMetadata(); + ASSERT_STREQ(metadata.GetValue(kOrtEpDevice_EpMetadataKey_Version), ORT_VERSION); + + // the GPU device info will vary by machine so check for the lowest common denominator values + Ort::ConstHardwareDevice device = test_ep_device->Device(); + ASSERT_EQ(device.Type(), OrtHardwareDeviceType_GPU); + ASSERT_GE(device.VendorId(), 0); + ASSERT_GE(device.DeviceId(), 0); + ASSERT_NE(device.Vendor(), nullptr); + Ort::ConstKeyValuePairs device_metadata = device.Metadata(); + std::unordered_map metadata_entries = device_metadata.GetKeyValuePairs(); + ASSERT_GT(metadata_entries.size(), 0); // should have at least SPDRP_HARDWAREID on Windows + + // and this should unload it without throwing + ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); +} + +TEST(NvExecutionProviderTest, DataTransfer) { + const OrtApi& c_api = Ort::GetApi(); + RegisteredEpDeviceUniquePtr nv_tensorrt_rtx_ep; + Utils::RegisterAndGetNvTensorRtRtxEp(*ort_env, nv_tensorrt_rtx_ep); + const OrtEpDevice* ep_device = nv_tensorrt_rtx_ep.get(); + + const OrtMemoryInfo* device_memory_info = c_api.EpDevice_MemoryInfo(ep_device, OrtDeviceMemoryType_DEFAULT); + + // create a tensor using the default CPU allocator + Ort::AllocatorWithDefaultOptions cpu_allocator; + std::vector shape{2, 3, 4}; // shape doesn't matter + const size_t num_elements = 2 * 3 * 4; + + RandomValueGenerator random{}; + std::vector input_data = random.Gaussian(shape, 0.0f, 2.f); + Ort::Value cpu_tensor = Ort::Value::CreateTensor(cpu_allocator.GetInfo(), + input_data.data(), input_data.size(), + shape.data(), shape.size()); + + // create an on-device Tensor using the NV TensorRT RTX EP GPU allocator. 
+ + OrtAllocator* allocator = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, device_memory_info, &allocator)); + ASSERT_NE(allocator, nullptr); + Ort::Value device_tensor = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + + std::vector src_tensor_ptrs{cpu_tensor}; + std::vector dst_tensor_ptrs{device_tensor}; + + ASSERT_ORTSTATUS_OK(c_api.CopyTensors(*ort_env, src_tensor_ptrs.data(), dst_tensor_ptrs.data(), nullptr, + src_tensor_ptrs.size())); + + // Copy data back from device_tensor to a new CPU tensor and verify the contents + + // Create a new CPU tensor to receive the data + Ort::Value cpu_tensor_copy = Ort::Value::CreateTensor(cpu_allocator, shape.data(), shape.size()); + + std::vector src_tensor_ptrs_back{device_tensor}; + std::vector dst_tensor_ptrs_back{cpu_tensor_copy}; + + ASSERT_ORTSTATUS_OK(c_api.CopyTensors(*ort_env, src_tensor_ptrs_back.data(), dst_tensor_ptrs_back.data(), nullptr, + src_tensor_ptrs_back.size())); + + const float* src_data = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetTensorData(cpu_tensor, reinterpret_cast(&src_data))); + + const float* cpu_copy_data = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetTensorData(cpu_tensor_copy, reinterpret_cast(&cpu_copy_data))); + + ASSERT_NE(src_data, cpu_copy_data) << "Should have copied between two different memory locations"; + + size_t bytes; + ASSERT_ORTSTATUS_OK(c_api.GetTensorSizeInBytes(cpu_tensor, &bytes)); + ASSERT_EQ(bytes, num_elements * sizeof(float)); + + auto src_span = gsl::make_span(src_data, num_elements); + auto cpu_copy_span = gsl::make_span(cpu_copy_data, num_elements); + + EXPECT_THAT(cpu_copy_span, ::testing::ContainerEq(src_span)); + + // must release this before we unload the EP and the allocator is deleted + device_tensor = Ort::Value(); +} + #endif // defined(WIN32) } // namespace test diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc new file 
mode 100644 index 0000000000000..f0ce5c0b296ca --- /dev/null +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. + +// registration/selection is only supported on windows as there's no device discovery on other platforms +#ifdef _WIN32 + +#include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" + +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "test/util/include/api_asserts.h" + +namespace onnxruntime { +namespace test { + +Utils::NvTensorRtRtxEpInfo Utils::nv_tensorrt_rtx_ep_info; + +void Utils::GetEp(Ort::Env& env, const std::string& ep_name, const OrtEpDevice*& ep_device) { + const OrtApi& c_api = Ort::GetApi(); + const OrtEpDevice* const* ep_devices = nullptr; + size_t num_devices; + ASSERT_ORTSTATUS_OK(c_api.GetEpDevices(env, &ep_devices, &num_devices)); + + auto it = std::find_if(ep_devices, ep_devices + num_devices, + [&c_api, &ep_name](const OrtEpDevice* ep_device) { + // NV TensorRT RTX EP uses registration name as ep name + return c_api.EpDevice_EpName(ep_device) == ep_name; + }); + + if (it == ep_devices + num_devices) { + ep_device = nullptr; + } else { + ep_device = *it; + } +} + +void Utils::RegisterAndGetNvTensorRtRtxEp(Ort::Env& env, RegisteredEpDeviceUniquePtr& registered_ep) { + const OrtApi& c_api = Ort::GetApi(); + // this should load the library and create OrtEpDevice + ASSERT_ORTSTATUS_OK(c_api.RegisterExecutionProviderLibrary(env, + nv_tensorrt_rtx_ep_info.registration_name.c_str(), + nv_tensorrt_rtx_ep_info.library_path.c_str())); + const OrtEpDevice* nv_tensorrt_rtx_ep = nullptr; + GetEp(env, nv_tensorrt_rtx_ep_info.registration_name, nv_tensorrt_rtx_ep); + ASSERT_NE(nv_tensorrt_rtx_ep, nullptr); + + registered_ep = RegisteredEpDeviceUniquePtr(nv_tensorrt_rtx_ep, 
[&env, c_api](const OrtEpDevice* /*ep*/) { + c_api.UnregisterExecutionProviderLibrary(env, nv_tensorrt_rtx_ep_info.registration_name.c_str()); + }); +} + +} // namespace test +} // namespace onnxruntime + +#endif // _WIN32 diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h new file mode 100644 index 0000000000000..ef14d3cb382c0 --- /dev/null +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "core/graph/constants.h" + +namespace onnxruntime { +namespace test { + +using RegisteredEpDeviceUniquePtr = std::unique_ptr>; + +struct Utils { + struct NvTensorRtRtxEpInfo { + const std::filesystem::path library_path = +#if _WIN32 + "onnxruntime_providers_nv_tensorrt_rtx.dll"; +#else + "libonnxruntime_providers_nv_tensorrt_rtx.so"; +#endif + const std::string registration_name = kNvTensorRTRTXExecutionProvider; + }; + + static NvTensorRtRtxEpInfo nv_tensorrt_rtx_ep_info; + + // get the OrtEpDevice for the NV TensorRT RTX EP from the environment + static void GetEp(Ort::Env& env, const std::string& ep_name, const OrtEpDevice*& ep_device); + + // Register the NV TensorRT RTX EP library, get the OrtEpDevice for it, and return a unique pointer that will + // automatically unregister the EP library. 
+ static void RegisterAndGetNvTensorRtRtxEp(Ort::Env& env, RegisteredEpDeviceUniquePtr& nv_tensorrt_rtx_ep); +}; +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/einsum_op_test.cc b/onnxruntime/test/providers/qnn/einsum_op_test.cc index a2a0ce485bb35..d8dbbd799a427 100644 --- a/onnxruntime/test/providers/qnn/einsum_op_test.cc +++ b/onnxruntime/test/providers/qnn/einsum_op_test.cc @@ -189,6 +189,19 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll1) { /*tolerance=*/1e-4f); } +TEST_F(QnnCPUBackendTests, EinsumMatMulBroadcastTransposeY) { + const std::vector shape0{2, 3, 3, 4}; + const std::vector shape1{3, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,hkc->bhwk", + /*tolerance=*/1e-4f); +} + TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) { const std::vector shape0{1, 7, 1, 7}; const std::vector shape1{1, 9, 1, 7}; @@ -273,6 +286,19 @@ TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeAll2) { /*tolerance=*/1e-2f); } +TEST_F(QnnHTPBackendTests, EinsumF16MatMulBroadcastTransposeY) { + const std::vector shape0{2, 3, 3, 4}; + const std::vector shape1{3, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,hkc->bhwk", + /*tolerance=*/1e-2f); +} + // // QNN HTP QDQ // @@ -337,6 
+363,18 @@ TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeAll2) { /*tolerance=*/QDQTolerance()); } +TEST_F(QnnHTPBackendTests, EinsumQdqMatMulBroadcastTransposeY) { + const std::vector shape0{2, 3, 3, 4}; + const std::vector shape1{3, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,hkc->bhwk", + /*tolerance=*/QDQTolerance()); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #if defined(_M_ARM64) @@ -422,6 +460,20 @@ TEST_F(QnnGPUBackendTests, EinsumRank4MatMulTransposeAll2) { /*tolerance=*/1e-4f); } +// Numeric instability in GPU backend, see also MatMul tests. +TEST_F(QnnGPUBackendTests, DISABLED_EinsumMatMulBroadcastTransposeY) { + const std::vector shape0{2, 3, 3, 4}; + const std::vector shape1{3, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeGpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,hkc->bhwk", + /*tolerance=*/1e-4f); +} + #endif // defined(_M_ARM64) GPU tests } // namespace test diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 553059932db90..706bd3c0fce62 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -571,6 +571,8 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) { 
params7.trt_dump_ep_context_model = 1; params7.trt_ep_context_embed_mode = 1; params7.trt_weight_stripped_engine_enable = 1; + params7.trt_onnx_bytestream = model_bytes.data(); + params7.trt_onnx_bytestream_size = model_bytes.size(); params7.trt_ep_context_file_path = ctx_model_name_str.c_str(); execution_provider = TensorrtExecutionProviderWithOptions(¶ms7); EXPECT_TRUE(session_object7.RegisterExecutionProvider(std::move(execution_provider)).IsOK()); diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_parity_moe.py index 252d89a2257fc..d805c8f9cae3c 100644 --- a/onnxruntime/test/python/transformers/test_parity_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_moe.py @@ -9,6 +9,7 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +import itertools import unittest from collections import OrderedDict @@ -24,11 +25,6 @@ torch.manual_seed(42) numpy.random.seed(42) -USE_QUANT = False -ORT_DTYPE = TensorProto.FLOAT16 if USE_QUANT else TensorProto.FLOAT -NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 -THRESHOLD = 5e-1 if USE_QUANT else 1e-2 - def value_string_of(numpy_array): arr = numpy_array.flatten() @@ -40,26 +36,69 @@ def print_tensor(name, numpy_array): print(f"const std::vector {name} = {value_string_of(numpy_array)};") -def quant_dequant(weights, quant_mode: bool = True): - # use the test version `_symmetric_...` to get the non-interleaved weights - type = torch.quint4x2 if quant_mode else torch.int8 - # This import is needed to use torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix() - # Comment out this line for passing the lintrunner check in the CI. 
- # import tensorrt_llm +def quant_dequant(weights: torch.Tensor, is_4_bit_quantization: bool): + """ + Performs symmetric per-column quantization and dequantization on a weight tensor. + + This implementation is a pure PyTorch replacement for the original function that + relied on a custom tensorrt_llm operator. It supports both 8-bit (int8) and + 4-bit (quint4x2 style) quantization. + + Args: + weights (torch.Tensor): The input weight tensor to be quantized. + is_4_bit_quantization (bool): If True, performs 4-bit quantization. If False, + performs 8-bit quantization. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: + - scales (torch.float16): The quantization scales for each column. + - processed_q_weight (torch.int8): The packed quantized weights. For + 4-bit mode, two 4-bit values are packed into a single int8. For + 8-bit mode, this is the standard int8 quantized tensor. It is + transposed relative to the input weights' shape. + - dequantized_weights (torch.Tensor): The weights after being dequantized, + restored to the original dtype and device. 
+ """ + # Determine quantization bits and range based on the mode + if is_4_bit_quantization: + # 4-bit symmetric quantization path + q_bits = 4 + q_max = 2 ** (q_bits - 1) - 1 # 7 + q_min = -(2 ** (q_bits - 1)) # -8 - quant_weights, processed_q_weight, torch_weight_scales = ( - torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix(weights.T.cpu().contiguous(), type) - ) + max_abs_val = torch.max(torch.abs(weights), dim=0, keepdim=True).values + max_abs_val[max_abs_val == 0] = 1.0 + scales = max_abs_val / q_max + + quant_weights = torch.round(weights / scales).clamp(q_min, q_max).to(torch.int8) + + # Pack two 4-bit integers into a single int8 + q_weights_t = quant_weights.T.contiguous() + shape = q_weights_t.shape + q_weights_t_reshaped = q_weights_t.view(shape[0], shape[1] // 2, 2) + lower_nibble = q_weights_t_reshaped[..., 0] + upper_nibble = q_weights_t_reshaped[..., 1] + processed_q_weight = (lower_nibble & 0x0F) | (upper_nibble << 4) + + else: + # 8-bit symmetric quantization path + q_bits = 8 + q_max = 2 ** (q_bits - 1) - 1 # 127 + q_min = -(2 ** (q_bits - 1)) # -128 - # Unpack the int4s int int8s - if quant_mode: - upper = quant_weights >> 4 - lower = (quant_weights << 4) >> 4 # Arithmetic right shift sign extends - quant_weights = torch.stack((lower, upper), dim=2).view(weights.T.shape) + max_abs_val = torch.max(torch.abs(weights), dim=0, keepdim=True).values + max_abs_val[max_abs_val == 0] = 1.0 + scales = max_abs_val / q_max - quant_weights = quant_weights.to(dtype=weights.dtype) - result = torch.multiply(quant_weights, torch_weight_scales.unsqueeze(0)).T.contiguous() - return torch_weight_scales.to(torch.float16), processed_q_weight, result.to(device=weights.device) + quant_weights = torch.round(weights / scales).clamp(q_min, q_max).to(torch.int8) + + # For 8-bit, the processed weights are just the transposed quantized weights (no packing) + processed_q_weight = quant_weights.T.contiguous() + + # Dequantize the weights to verify and return for 
PyTorch-side parity check + dequantized_weights = quant_weights.to(weights.dtype) * scales.to(weights.dtype) + + return (scales.squeeze(0).to(torch.float16), processed_q_weight, dequantized_weights.T.to(device=weights.device)) def create_moe_onnx_graph( @@ -71,6 +110,7 @@ def create_moe_onnx_graph( fc1_experts_bias, fc2_experts_weights, fc2_experts_bias, + ort_dtype, ): nodes = [ helper.make_node( @@ -94,19 +134,19 @@ def create_moe_onnx_graph( fc1_shape = [num_experts, hidden_size, inter_size] fc2_shape = [num_experts, inter_size, hidden_size] - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32 initializers = [ helper.make_tensor( "fc1_experts_weights", - ORT_DTYPE, + ort_dtype, fc1_shape, fc1_experts_weights.to(torch_type).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ORT_DTYPE, + ort_dtype, fc2_shape, fc2_experts_weights.to(torch_type).flatten().tolist(), raw=False, @@ -119,14 +159,14 @@ def create_moe_onnx_graph( [ helper.make_tensor( "fc1_experts_bias", - ORT_DTYPE, + ort_dtype, fc1_bias_shape, fc1_experts_bias.to(torch_type).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_bias", - ORT_DTYPE, + ort_dtype, fc2_bias_shape, fc2_experts_bias.to(torch_type).flatten().tolist(), raw=False, @@ -135,19 +175,19 @@ def create_moe_onnx_graph( ) graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", ort_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ORT_DTYPE, + ort_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", ort_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -171,6 +211,7 @@ def 
create_mixtral_moe_onnx_graph( fc2_experts_weights, fc3_experts_weights, topk, + ort_dtype, ): nodes = [ helper.make_node( @@ -197,26 +238,26 @@ def create_mixtral_moe_onnx_graph( fc2_shape = [num_experts, inter_size, hidden_size] fc3_shape = [num_experts, hidden_size, inter_size] - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32 initializers = [ helper.make_tensor( "fc1_experts_weights", - ORT_DTYPE, + ort_dtype, fc1_shape, fc1_experts_weights.to(torch_type).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ORT_DTYPE, + ort_dtype, fc2_shape, fc2_experts_weights.to(torch_type).flatten().tolist(), raw=False, ), helper.make_tensor( "fc3_experts_weights", - ORT_DTYPE, + ort_dtype, fc3_shape, fc3_experts_weights.to(torch_type).flatten().tolist(), raw=False, @@ -224,19 +265,19 @@ def create_mixtral_moe_onnx_graph( ] graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", ort_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ORT_DTYPE, + ort_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", ort_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -259,12 +300,14 @@ def create_phi_moe_onnx_graph( fc1_experts_weights, fc2_experts_weights, fc3_experts_weights, - fc1_scales, - fc2_scales, - fc3_scales, topk, + ort_dtype, + quant_bits=0, + fc1_scales=None, + fc2_scales=None, + fc3_scales=None, ): - use_quant = USE_QUANT + use_quant = quant_bits > 0 if use_quant: assert fc1_experts_weights.dtype == torch.int8 assert fc2_experts_weights.dtype == torch.int8 @@ -276,34 +319,37 @@ def create_phi_moe_onnx_graph( assert fc2_scales.dtype 
== torch.float16 assert fc3_scales.dtype == torch.float16 + op_name = "QMoE" if use_quant else "MoE" + inputs = ( + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_scales", + "", + "fc2_experts_weights", + "fc2_scales", + "", + "fc3_experts_weights", + "fc3_scales", + "", + ] + if use_quant + else [ + "input", + "router_probs", + "fc1_experts_weights", + "", + "fc2_experts_weights", + "", + "fc3_experts_weights", + ] + ) + nodes = [ helper.make_node( - "MoE" if not use_quant else "QMoE", - ( - [ - "input", - "router_probs", - "fc1_experts_weights", - "", - "fc2_experts_weights", - "", - "fc3_experts_weights", - ] - if not use_quant - else [ - "input", - "router_probs", - "fc1_experts_weights", - "fc1_scales", - "", - "fc2_experts_weights", - "fc2_scales", - "", - "fc3_experts_weights", - "fc3_scales", - "", - ] - ), + op_name, + inputs, ["output"], "MoE_0", k=topk, @@ -315,37 +361,38 @@ def create_phi_moe_onnx_graph( ] if use_quant: - nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", 8)]) + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) - fc1_shape = [num_experts, hidden_size, inter_size] - fc2_shape = [num_experts, inter_size, hidden_size] - fc3_shape = [num_experts, hidden_size, inter_size] + components = 2 if quant_bits == 4 else 1 + fc1_shape = [num_experts, hidden_size, inter_size // components] + fc2_shape = [num_experts, inter_size, hidden_size // components] + fc3_shape = [num_experts, hidden_size, inter_size // components] - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 - numpy_type = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 - if use_quant: - numpy_type = numpy.uint8 + torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32 + numpy_type = numpy.float16 if ort_dtype == TensorProto.FLOAT16 else numpy.float32 + weight_numpy_type = numpy.uint8 if use_quant else numpy_type + weight_onnx_type = 
TensorProto.UINT8 if use_quant else ort_dtype initializers = [ helper.make_tensor( "fc1_experts_weights", - ORT_DTYPE if not use_quant else TensorProto.UINT8, + weight_onnx_type, fc1_shape, - fc1_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + fc1_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ORT_DTYPE if not use_quant else TensorProto.UINT8, + weight_onnx_type, fc2_shape, - fc2_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + fc2_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist(), raw=False, ), helper.make_tensor( "fc3_experts_weights", - ORT_DTYPE if not use_quant else TensorProto.UINT8, + weight_onnx_type, fc3_shape, - fc3_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + fc3_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist(), raw=False, ), ] @@ -358,21 +405,21 @@ def create_phi_moe_onnx_graph( [ helper.make_tensor( "fc1_scales", - ORT_DTYPE, + ort_dtype, fc1_scale_shape, fc1_scales.to(torch_type).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_scales", - ORT_DTYPE, + ort_dtype, fc2_scale_shape, fc2_scales.to(torch_type).flatten().tolist(), raw=False, ), helper.make_tensor( "fc3_scales", - ORT_DTYPE, + ort_dtype, fc3_scale_shape, fc3_scales.to(torch_type).flatten().tolist(), raw=False, @@ -381,19 +428,19 @@ def create_phi_moe_onnx_graph( ) graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", ort_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ORT_DTYPE, + ort_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", ort_dtype, [sequence_length, 
hidden_size]), ] graph = helper.make_graph( @@ -546,8 +593,11 @@ def __init__(self, config: PhiMoEConfig): class SparseMoeBlockORTHelper(nn.Module): - def __init__(self): + def __init__(self, quant_bits=0): super().__init__() + self.quant_bits = quant_bits + self.ort_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT + self.np_type = numpy.float16 if self.ort_dtype == TensorProto.FLOAT16 else numpy.float32 def create_ort_session(self, moe_onnx_graph): from onnxruntime import InferenceSession, SessionOptions # noqa: PLC0415 @@ -573,8 +623,8 @@ def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Ten router_logits = self.gate(hidden_states) ort_inputs = { - "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)), - "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)), + "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(self.np_type)), + "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(self.np_type)), } ort_output = None @@ -586,13 +636,6 @@ def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Ten self.ort_run_with_iobinding(ort_inputs) return None - # print_tensor("input", ort_inputs["input"]) - # print_tensor("router_probs", ort_inputs["router_probs"]) - # print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy()) - # print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy()) - # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) - # print_tensor("output", ort_output[0]) - return None def ort_run_with_iobinding(self, ort_inputs, repeat=1000): @@ -603,7 +646,7 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000): name="input", device_type="cuda", device_id=device_id, - element_type=NP_TYPE, + element_type=self.np_type, shape=ort_inputs["input"].shape, 
buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id).data_ptr(), ) @@ -612,7 +655,7 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000): name="router_probs", device_type="cuda", device_id=device_id, - element_type=NP_TYPE, + element_type=self.np_type, shape=ort_inputs["router_probs"].shape, buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( ort_inputs["router_probs"], "cuda", device_id @@ -623,7 +666,7 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000): name="output", device_type="cuda", device_id=device_id, - element_type=NP_TYPE, + element_type=self.np_type, shape=ort_inputs["input"].shape, buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( numpy.zeros(ort_inputs["input"].shape), "cuda", device_id @@ -646,22 +689,27 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000): e = time.time() print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") - def parity_check(self): + def parity_check(self, atol=None, rtol=None): hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) torch_output = self.forward(hidden_state) ort_output = self.ort_forward(hidden_state) + + if atol is None: + atol = 5e-2 if self.quant_bits == 0 else (2.0 if self.quant_bits == 8 else 3.0) + + if rtol is None: + rtol = 1e-3 if self.quant_bits == 0 else 1e-2 + if ort_output is not None: + dtype_str = "FP32" if self.quant_bits == 0 else "FP16" print( - "name:", - self.__class__.__name__, - " batch_size:", - self.batch_size, - " sequence_length:", - self.sequence_length, - " max_diff:", - (torch_output - ort_output).abs().max(), + f"name: {self.__class__.__name__}, quant_bits: {self.quant_bits}, dtype: {dtype_str}," + f" batch: {self.batch_size}, seq_len: {self.sequence_length}," + f" max_diff: {(torch_output - ort_output).abs().max()}" + ) + torch.testing.assert_close( + ort_output.to(torch.float32), torch_output.to(torch.float32), rtol=rtol, atol=atol ) - 
torch.testing.assert_close(ort_output.to(torch.float32), torch_output, rtol=THRESHOLD, atol=THRESHOLD) def benchmark_ort(self): hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) @@ -680,7 +728,7 @@ def __init__( eval_capacity=-1, activation="gelu", ): - super().__init__() + super().__init__(quant_bits=0) # SwitchMoE is not quantized self.batch_size = batch_size self.sequence_length = sequence_length self.num_experts = num_experts @@ -709,6 +757,7 @@ def __init__( self.moe_experts.bias1, self.moe_experts.weight2.transpose(1, 2), self.moe_experts.bias2, + self.ort_dtype, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -744,7 +793,7 @@ class MixtralSparseMoeBlock(SparseMoeBlockORTHelper): """ def __init__(self, config, batch_size, sequence_length): - super().__init__() + super().__init__(quant_bits=0) # Mixtral test is not quantized self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts @@ -778,6 +827,7 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight2, self.moe_experts_weight3, self.top_k, + self.ort_dtype, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -874,43 +924,44 @@ class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): and memory on padding. 
""" - def __init__(self, config, batch_size, sequence_length): - super().__init__() + def __init__(self, config, batch_size, sequence_length, quant_bits=0): + super().__init__(quant_bits) self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_tok self.router_jitter_noise = config.router_jitter_noise + use_quant = self.quant_bits > 0 # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) self.experts = nn.ModuleList([PhiMoEBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) - w1_list = [] - w2_list = [] - w3_list = [] - w1_scale_list = [] - w2_scale_list = [] - w3_scale_list = [] - if not USE_QUANT: + w1_list, w2_list, w3_list = [], [], [] + w1_scale_list, w2_scale_list, w3_scale_list = [], [], [] + + if not use_quant: for i in range(self.num_experts): w1_list.append(self.experts[i].w1.weight) w2_list.append(self.experts[i].w2.weight) w3_list.append(self.experts[i].w3.weight) else: + is_4_bit = self.quant_bits == 4 for i in range(self.num_experts): - w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, False) - w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, False) - w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, False) + # Corrected quantization logic for per-output-channel quantization + w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight.T, is_4_bit) + w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight.T, is_4_bit) + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight.T, is_4_bit) self.experts[i].w1.weight.data = w1_qdq self.experts[i].w2.weight.data = w2_qdq self.experts[i].w3.weight.data = w3_qdq - w1_list.append(pre_qweight1) - w2_list.append(pre_qweight2) - w3_list.append(pre_qweight3) + # Transpose quantized weights to match the expected ONNX layout + w1_list.append(pre_qweight1.T) + 
w2_list.append(pre_qweight2.T) + w3_list.append(pre_qweight3.T) w1_scale_list.append(w1_scale) w2_scale_list.append(w2_scale) w3_scale_list.append(w3_scale) @@ -919,9 +970,9 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight2 = torch.stack(w2_list, dim=0) self.moe_experts_weight3 = torch.stack(w3_list, dim=0) - moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) if USE_QUANT else None - moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) if USE_QUANT else None - moe_experts_weight_scale3 = torch.stack(w3_scale_list, dim=0) if USE_QUANT else None + moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) if use_quant else None + moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) if use_quant else None + moe_experts_weight_scale3 = torch.stack(w3_scale_list, dim=0) if use_quant else None self.batch_size = batch_size self.sequence_length = sequence_length @@ -933,10 +984,12 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight1, self.moe_experts_weight2, self.moe_experts_weight3, + self.top_k, + self.ort_dtype, + self.quant_bits, moe_experts_weight_scale1, moe_experts_weight_scale2, moe_experts_weight_scale3, - self.top_k, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -992,19 +1045,23 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def small_test_cases(): for batch_size in [1, 4, 16]: for sequence_length in [128, 512, 1024]: - yield batch_size, sequence_length + yield batch_size, sequence_length, 0 -def phi3_test_cases(): - # TODO: phi3 moe failed in long sequence lengths (max diff 0.22 > threshold 0.01), need investigation. - for batch_size in [1, 4, 16]: - for sequence_length in [128]: - yield batch_size, sequence_length +# Test cases for Phi-3 MoE. +# We test three modes: no quantization, 8-bit, and 4-bit. 
+phi3_test_params = list( + itertools.product( + [1, 4], # batch_size + [1, 32], # sequence_length + [0, 8, 4], # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) + ) +) class TestSwitchMoE(unittest.TestCase): @parameterized.expand(small_test_cases()) - def test_switch_moe_parity(self, batch_size, sequence_length): + def test_switch_moe_parity(self, batch_size, sequence_length, quant_bits): # if platform.system() == "Windows": # pytest.skip("Skip on Windows") switch_moe = SwitchMoE( @@ -1020,8 +1077,8 @@ def test_switch_moe_parity(self, batch_size, sequence_length): class TestMixtralMoE(unittest.TestCase): - @parameterized.expand(small_test_cases()) - def test_mixtral_moe_parity(self, batch_size, sequence_length): + @parameterized.expand([(b, s, q) for b, s, q in small_test_cases() if q == 0]) # only run non-quantized + def test_mixtral_moe_parity(self, batch_size, sequence_length, quant_bits): config = MixtralConfig(hidden_size=256, intermediate_size=1024) mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length) mixtral_moe.parity_check() @@ -1029,13 +1086,329 @@ def test_mixtral_moe_parity(self, batch_size, sequence_length): class TestPhiMoE(unittest.TestCase): - @parameterized.expand(phi3_test_cases()) - def test_phi3_moe_parity(self, batch_size, sequence_length): + @parameterized.expand(phi3_test_params) + def test_phi3_moe_parity(self, batch_size, sequence_length, quant_bits): config = PhiMoEConfig(hidden_size=256, intermediate_size=1024) - phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length) + phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length, quant_bits) phi3_moe.parity_check() # phi3_moe.benchmark_ort() +# --------------------------------------------- +# The following test are for swiglu activation +# --------------------------------------------- +class SwigluMoeConfig: + def __init__( + self, + hidden_size=2048, + intermediate_size=2048, + num_experts_per_token=2, + num_local_experts=8, + ): + 
self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_experts_per_token = num_experts_per_token + self.num_local_experts = num_local_experts + + +class SwigluMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def swiglu(self, x: torch.Tensor): + dim = x.shape[-1] + x = x.view(-1, dim // 2, 2) + x_glu, x_linear = x[..., 0], x[..., 1] + y = x_glu * torch.sigmoid(1.702 * x_glu) * (x_linear + 1) + return y + + def forward(self, x): + y = self.swiglu(self.w1(x)) + y = self.w2(y) + return y + + +def create_swiglu_moe_onnx_graph( + num_tokens: int, + num_experts: int, + hidden_size: int, + inter_size: int, + topk: int, + ort_dtype: int, + quant_bits: int, + fc1_experts_weights: torch.Tensor, + fc1_experts_bias: torch.Tensor, + fc2_experts_weights: torch.Tensor, + fc2_experts_bias: torch.Tensor, + fc1_experts_weight_scale: torch.Tensor = None, + fc2_experts_weight_scale: torch.Tensor = None, +): + use_quant = quant_bits > 0 + op_name = "QMoE" if use_quant else "MoE" + + inputs = ( + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_weight_scale", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_weight_scale", + "fc2_experts_bias", + ] + if use_quant + else [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_bias", + ] + ) + + nodes = [ + helper.make_node( + op_name, + inputs, + ["output"], + "MoE_0", + k=topk, + normalize_routing_weights=1, + activation_type="swiglu", + domain="com.microsoft", + ), + ] + + if use_quant: + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) + + components = 2 if quant_bits == 4 else 1 + fc1_weight_shape = [num_experts, 
hidden_size, 2 * inter_size // components] + fc1_bias_shape = [num_experts, 2 * inter_size] + fc1_experts_weight_scale_shape = [num_experts, 2 * inter_size] + + fc2_weight_shape = [num_experts, inter_size, hidden_size // components] + fc2_bias_shape = [num_experts, hidden_size] + fc2_experts_weight_scale_shape = [num_experts, hidden_size] + + torch_type = torch.float16 if ort_dtype == TensorProto.FLOAT16 else torch.float32 + numpy_type = numpy.float16 if ort_dtype == TensorProto.FLOAT16 else numpy.float32 + weight_numpy_type = numpy.uint8 if use_quant else numpy_type + weight_onnx_type = TensorProto.UINT8 if use_quant else ort_dtype + + initializers = [ + helper.make_tensor( + "fc1_experts_weights", + weight_onnx_type, + fc1_weight_shape, + fc1_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist() + if use_quant + else fc1_experts_weights.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc1_experts_bias", + ort_dtype, + fc1_bias_shape, + fc1_experts_bias.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_weights", + weight_onnx_type, + fc2_weight_shape, + fc2_experts_weights.flatten().detach().numpy().astype(weight_numpy_type).tolist() + if use_quant + else fc2_experts_weights.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_bias", + ort_dtype, + fc2_bias_shape, + fc2_experts_bias.to(torch_type).flatten().tolist(), + raw=False, + ), + ] + + if use_quant: + initializers.extend( + [ + helper.make_tensor( + "fc1_experts_weight_scale", + ort_dtype, + fc1_experts_weight_scale_shape, + fc1_experts_weight_scale.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_weight_scale", + ort_dtype, + fc2_experts_weight_scale_shape, + fc2_experts_weight_scale.to(torch_type).flatten().tolist(), + raw=False, + ), + ] + ) + + graph_inputs = [ + helper.make_tensor_value_info("input", ort_dtype, [num_tokens, 
hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + ort_dtype, + [num_tokens, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", ort_dtype, [num_tokens, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +class SwigluMoEBlock(SparseMoeBlockORTHelper): + def __init__(self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0): + super().__init__(quant_bits) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_token + use_quant = self.quant_bits > 0 + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)]) + + weight_1_list, weight_2_list = [], [] + bias_1_list, bias_2_list = [], [] + scale_1_list, scale_2_list = [], [] + + for i in range(self.num_experts): + bias_1_list.append(self.experts[i].w1.bias) + bias_2_list.append(self.experts[i].w2.bias) + if not use_quant: + weight_1_list.append(self.experts[i].w1.weight) + weight_2_list.append(self.experts[i].w2.weight) + else: + is_4_bit = self.quant_bits == 4 + # Pass the transposed weight to quant_dequant to get correct scales, + # then transpose the resulting quantized weight back to the expected layout. 
+ scale1, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight.T, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight.T, is_4_bit) + + self.experts[i].w1.weight.data = w1_qdq + self.experts[i].w2.weight.data = w2_qdq + + weight_1_list.append(pre_qweight1.T) + weight_2_list.append(pre_qweight2.T) + scale_1_list.append(scale1) + scale_2_list.append(scale2) + + self.moe_experts_weight1 = torch.stack(weight_1_list, dim=0) + self.moe_experts_weight2 = torch.stack(weight_2_list, dim=0) + + self.moe_experts_bias1 = torch.stack(bias_1_list, dim=0) + self.moe_experts_bias2 = torch.stack(bias_2_list, dim=0) + + moe_experts_weight_scale1 = torch.stack(scale_1_list, dim=0) if use_quant else None + moe_experts_weight_scale2 = torch.stack(scale_2_list, dim=0) if use_quant else None + + self.batch_size = batch_size + self.sequence_length = sequence_length + self.moe_onnx_graph = create_swiglu_moe_onnx_graph( + num_tokens=self.batch_size * self.sequence_length, + num_experts=self.num_experts, + hidden_size=self.hidden_dim, + inter_size=self.ffn_dim, + topk=self.top_k, + ort_dtype=self.ort_dtype, + quant_bits=self.quant_bits, + fc1_experts_weights=self.moe_experts_weight1, + fc1_experts_bias=self.moe_experts_bias1, + fc2_experts_weights=self.moe_experts_weight2, + fc2_experts_bias=self.moe_experts_bias2, + fc1_experts_weight_scale=moe_experts_weight_scale1, + fc2_experts_weight_scale=moe_experts_weight_scale2, + ) + + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) # router_logits shape is (batch * sequence_length, num_experts) + routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1) + + routing_weights = F.softmax(routing_weights, dim=1, dtype=torch.float) + + # we 
cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +swiglu_test_params = list( + itertools.product( + [1, 4], # batch_size + [1, 32], # sequence_length + [0, 8, 4], # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) + ) +) + + +class TestSwigluMoE(unittest.TestCase): + @parameterized.expand(swiglu_test_params) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + config = SwigluMoeConfig(hidden_size=128, intermediate_size=512, num_experts_per_token=1, num_local_experts=4) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.parity_check() + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index 9807fcca06ed4..0fe747cdd84e5 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -118,6 +118,9 @@ struct TestAllocator : public OrtAllocator { Reserve = [](struct OrtAllocator* /*this*/, size_t /*size*/) -> void* { throw std::runtime_error("This should not be used"); }; + + GetStats = nullptr; + AllocOnStream = nullptr; } // initializers that are used directly by the model. as there's no copy they must remain valid. 
diff --git a/onnxruntime/test/testdata/matmul_integer_to_float.py b/onnxruntime/test/testdata/matmul_integer_to_float.py index 0c1ea47fff5b1..5e9a1778198ef 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float.py +++ b/onnxruntime/test/testdata/matmul_integer_to_float.py @@ -1,8 +1,11 @@ +import numpy as np import onnx -from onnx import TensorProto, helper +from onnx import TensorProto, helper, numpy_helper -def GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bias=False): # noqa: N802 +def generate_model( + model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bias=False, bias_initializer=False, bias_flip=False +): nodes = [ # subgraph helper.make_node( "MatMulInteger", @@ -50,15 +53,22 @@ def GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bia ) if bias: - nodes.extend([helper.make_node("Add", ["mul_bottom_output", "bias"], ["Y"], "add")]) + if bias_flip: + nodes.extend([helper.make_node("Add", ["bias", "mul_bottom_output"], ["Y"], "add")]) + else: + nodes.extend([helper.make_node("Add", ["mul_bottom_output", "bias"], ["Y"], "add")]) - inputs.extend( - [ - helper.make_tensor_value_info( - "bias", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["N"] - ) - ] + if bias_initializer: + # Use a constant initializer + bias_vals = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float16 if output_type_fp16 else np.float32) + bias_tensor = numpy_helper.from_array(bias_vals, name="bias") + initializers = [bias_tensor] + else: + # Use runtime input + inputs.append( + helper.make_tensor_value_info("bias", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["N"]) ) + initializers = [] graph = helper.make_graph( nodes, @@ -69,6 +79,7 @@ def GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bia "Y", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["M", "N"] ), ], + initializer=initializers, ) model = helper.make_model(graph) @@ -76,10 +87,10 @@ def 
GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bia if __name__ == "__main__": - GenerateModel("matmul_integer_to_float16_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=True) - GenerateModel("matmul_integer_to_float_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=False) - GenerateModel("matmul_integer_to_float_uint8.onnx", sign_i=False, sign_w=False, output_type_fp16=False) - GenerateModel( + generate_model("matmul_integer_to_float16_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=True) + generate_model("matmul_integer_to_float_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=False) + generate_model("matmul_integer_to_float_uint8.onnx", sign_i=False, sign_w=False, output_type_fp16=False) + generate_model( "matmul_integer_to_float_int8_bias.onnx", sign_i=False, sign_w=True, @@ -87,7 +98,7 @@ def GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bia has_zp=False, bias=True, ) - GenerateModel( + generate_model( "matmul_integer_to_float_uint8_bias.onnx", sign_i=False, sign_w=False, @@ -95,9 +106,27 @@ def GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bia has_zp=False, bias=True, ) - - GenerateModel("matmul_integer_to_float_int8_int8.onnx", sign_i=True, sign_w=True, output_type_fp16=False) - GenerateModel( + generate_model( + "matmul_integer_to_float_int8_bias_initializer_index1.onnx", + sign_i=False, + sign_w=True, + output_type_fp16=False, + has_zp=False, + bias=True, + bias_initializer=True, + ) + generate_model( + "matmul_integer_to_float_int8_bias_initializer_index0.onnx", + sign_i=False, + sign_w=True, + output_type_fp16=False, + has_zp=False, + bias=True, + bias_flip=True, + bias_initializer=True, + ) + generate_model("matmul_integer_to_float_int8_int8.onnx", sign_i=True, sign_w=True, output_type_fp16=False) + generate_model( "matmul_integer_to_float_int8_int8_bias.onnx", sign_i=True, sign_w=True, diff --git 
a/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias_initializer_index0.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias_initializer_index0.onnx new file mode 100644 index 0000000000000..841e61cef8fb2 Binary files /dev/null and b/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias_initializer_index0.onnx differ diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias_initializer_index1.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias_initializer_index1.onnx new file mode 100644 index 0000000000000..c0d8c14e4e775 Binary files /dev/null and b/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias_initializer_index1.onnx differ diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index aa89ab80bc4e5..23c3a922326cb 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -31,6 +31,7 @@ "current_failing_tests": [ "^test_adagrad", "^test_adagrad_multiple", + "^test_attention_3d.*", // wrong expected values in onnx==1.18.0, fixed in 1.19.0 "^test_batchnorm_epsilon_training_mode", "^test_batchnorm_example_training_mode", "^test_col2im_pads", // still one wrong value coming from the backtest example diff --git a/onnxruntime/test/util/include/api_asserts.h b/onnxruntime/test/util/include/api_asserts.h index 9d34be24d5012..423135f96fbcd 100644 --- a/onnxruntime/test/util/include/api_asserts.h +++ b/onnxruntime/test/util/include/api_asserts.h @@ -37,3 +37,9 @@ EXPECT_NE(_tmp_status, nullptr); \ if (_tmp_status) Ort::GetApi().ReleaseStatus(_tmp_status); \ } while (false) + +#define ASSERT_CXX_ORTSTATUS_OK(function) \ + do { \ + Ort::Status _tmp_status = (function); \ + ASSERT_TRUE(_tmp_status.IsOK()) << _tmp_status.GetErrorMessage(); \ + } while (false) diff --git a/onnxruntime/wasm/pre-async.js 
b/onnxruntime/wasm/pre-async.js index 8c75dc7c5cf1e..1f8f17535e7d4 100644 --- a/onnxruntime/wasm/pre-async.js +++ b/onnxruntime/wasm/pre-async.js @@ -15,78 +15,20 @@ let initAsyncImpl = () => { // It removes some overhead in cwarp() and ccall() that we don't need. // // Currently in ASYNCIFY build, we only use this for the following functions: + // - OrtAppendExecutionProvider() // - OrtCreateSession() // - OrtRun() // - OrtRunWithBinding() // - OrtBindInput() // - // Note: about parameters "getFunc" and "setFunc": - // - Emscripten has different behaviors for Debug and Release builds for generating exported function wrapper. + // We need to wrap these functions with an async wrapper so that they can be called in an async context. // - // - In Debug build, it will generate a wrapper function for each exported function. For example, it generates a - // wrapper for OrtRun() like this (minified): - // ``` - // var _OrtRun = Module["_OrtRun"] = createExportWrapper("OrtRun"); - // ``` - // - // - In Release build, it will generate a lazy loading wrapper for each exported function. For example, it generates - // a wrapper for OrtRun() like this (minified): - // ``` - // d._OrtRun = (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); - // ``` - // - // The behavior of these two wrappers are different. The debug build will assign `Module["_OrtRun"]` only once - // because `createExportWrapper()` does not reset `Module["_OrtRun"]` inside. The release build, however, will - // reset d._OrtRun to J.ka when the first time it is called. - // - // The difference is important because we need to design the async wrapper in a way that it can handle both cases. - // - // Now, let's look at how the async wrapper is designed to work for both cases: - // - // - Debug build: - // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to `createExportWrapper("OrtRun")`. - // 2. 
When the first time `Module["initAsync"]` is called, `Module["_OrtRun"]` is re-assigned to a new async - // wrapper function. - // Value of `Module["_OrtRun"]` will not be changed again. - // - // - Release build: - // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to a lazy loading wrapper function. - // 2. When the first time `Module["initAsync"]` is called, `Module["_OrtRun"]` is re-assigned to a new async - // wrapper function. - // 3. When the first time `Module["_OrtRun"]` is called, the async wrapper will be called. It will call into this - // function: - // ``` - // (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); - // ``` - // This function will assign d._OrtRun (ie. the minimized `Module["_OrtRun"]`) to the real function (J.ka). - // 4. Since d._OrtRun is re-assigned, we need to update the async wrapper to re-assign its stored - // function to the updated value (J.ka), and re-assign the value of `d._OrtRun` back to the async wrapper. - // Value of `Module["_OrtRun"]` will not be changed again. - // - // The value of `Module["_OrtRun"]` will need to be assigned for 2 times for debug build and 4 times for release - // build. - // - // This is why we need this `getFunc` and `setFunc` parameters. They are used to get the current value of an - // exported function and set the new value of an exported function. - // - const wrapAsync = (func, getFunc, setFunc) => { + const wrapAsync = (func) => { return (...args) => { // cache the async data before calling the function. const previousAsync = Asyncify.currData; - const previousFunc = getFunc?.(); const ret = func(...args); - const newFunc = getFunc?.(); - if (previousFunc !== newFunc) { - // The exported function has been updated. - // Set the sync function reference to the new function. - func = newFunc; - // Set the exported function back to the async wrapper. - setFunc(previousFunc); - // Remove getFunc and setFunc. They are no longer needed. 
- setFunc = null; - getFunc = null; - } // If the async data has been changed, it means that the function started an async operation. if (Asyncify.currData != previousAsync) { @@ -101,11 +43,7 @@ let initAsyncImpl = () => { // replace the original functions with asyncified versions const wrapAsyncAPIs = (funcNames) => { for (const funcName of funcNames) { - Module[funcName] = wrapAsync( - Module[funcName], - () => Module[funcName], - (v) => (Module[funcName] = v) - ); + Module[funcName] = wrapAsync(Module[funcName]); } }; diff --git a/setup.py b/setup.py index 1893e18b8aab6..5ab1ac5b840d4 100644 --- a/setup.py +++ b/setup.py @@ -478,7 +478,7 @@ def finalize_options(self): examples = [path.join("datasets", x) for x in examples_names] # Extra files such as EULA and ThirdPartyNotices (and Qualcomm License, only for QNN release packages) -extra = ["LICENSE", "ThirdPartyNotices.txt", "Privacy.md", "Qualcomm AI Hub Proprietary License.pdf"] +extra = ["LICENSE", "ThirdPartyNotices.txt", "Privacy.md", "Qualcomm_LICENSE.pdf"] # Description readme_file = "docs/python/ReadMeOV.rst" if is_openvino else "docs/python/README.rst" diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 87e0ac6a42ea6..561a76be5fa89 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -887,7 +887,22 @@ def generate_build_tree( if args.use_snpe: cmake_args += ["-Donnxruntime_USE_SNPE=ON"] - cmake_args += ["-Donnxruntime_USE_KLEIDIAI=" + ("OFF" if args.no_kleidiai else "ON")] + # Set onnxruntime_USE_KLEIDIAI based on: + # * Default value above is NO. + # * Leave disabled if "no_kleidiai" argument was specified. + # * Enable if the target is Android and args.android_abi contains arm64* + # * Enable for a Windows cross compile build if compile target is an Arm one. + # * Finally enable if platform.machine contains "arm64". 
This should cover the following cases: + # * Linux on Arm + # * MacOs (case must be ignored) + # * TODO Delegate responsibility for Onnxruntime_USE_KLEIDIAI = ON to CMake logic + if not args.no_kleidiai: + if ( + (args.android and "arm64" in args.android_abi.lower()) + or (is_windows() and (args.arm64 or args.arm64ec or args.arm) and platform.architecture()[0] != "AMD64") + or ("arm64" in platform.machine().lower()) + ): + cmake_args += ["-Donnxruntime_USE_KLEIDIAI=ON"] if is_macOS() and (args.macos or args.ios or args.visionos or args.tvos): # Note: Xcode CMake generator doesn't have a good support for Mac Catalyst yet. diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index c42f8e3219da4..82118148d35f9 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -342,7 +342,7 @@ def add_webassembly_args(parser: argparse.ArgumentParser) -> None: """Adds arguments for WebAssembly (WASM) platform builds.""" parser.add_argument("--build_wasm", action="store_true", help="Build for WebAssembly.") parser.add_argument("--build_wasm_static_lib", action="store_true", help="Build WebAssembly static library.") - parser.add_argument("--emsdk_version", default="4.0.8", help="Specify version of emsdk.") + parser.add_argument("--emsdk_version", default="4.0.11", help="Specify version of emsdk.") parser.add_argument("--enable_wasm_simd", action="store_true", help="Enable WebAssembly SIMD.") parser.add_argument("--enable_wasm_relaxed_simd", action="store_true", help="Enable WebAssembly Relaxed SIMD.") parser.add_argument("--enable_wasm_threads", action="store_true", help="Enable WebAssembly multi-threading.") diff --git a/tools/ci_build/get_docker_image.py b/tools/ci_build/get_docker_image.py index e656cedae5916..90947e534918d 100755 --- a/tools/ci_build/get_docker_image.py +++ b/tools/ci_build/get_docker_image.py @@ -71,11 +71,14 @@ def main(): log.info(f"Image: {full_image_name}") - dst_deps_file = Path(args.context) / "scripts" / 
"deps.txt" + dst_scripts_dir = Path(args.context) / "scripts" + dst_deps_file = dst_scripts_dir / "deps.txt" # The docker file may provide a special deps.txt in its docker context dir and uses that one. # Otherwise, copy a generic one from this repo's cmake dir. if not dst_deps_file.exists(): log.info(f"Copy deps.txt to : {dst_deps_file}") + if not dst_scripts_dir.exists(): + dst_scripts_dir.mkdir(parents=True, exist_ok=True) shutil.copyfile(Path(REPO_DIR) / "cmake" / "deps.txt", str(dst_deps_file)) if "manylinux" in args.dockerfile and args.multiple_repos: diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index e5e2a4749ef85..91f35d2b54033 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -52,7 +52,7 @@ jobs: - script: sudo chmod go+rw /dev/kvm displayName: Update permissions to KVM - - template: templates/jobs/download_linux_qnn_sdk.yml + - template: templates/jobs/init_linux_qnn_sdk_x64.yml parameters: QnnSDKVersion: ${{ parameters.QnnSdk }} diff --git a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml index 0ce4227c9ef9f..5cf5cd8c936fa 100644 --- a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml @@ -20,13 +20,17 @@ stages: artifactName: 'onnxruntime-android-full-aar' job_name_suffix: 'Full' publish_executables: '1' - pool_name: 'onnxruntime-Ubuntu2204-AMD-CPU' + pool_name: 'onnxruntime-Ubuntu2404-AMD-CPU' enable_code_sign: false # build Python packages # Linux GPU only - ${{ if parameters.BuildPythonPackages }}: - - template: stages/py-gpu-packaging-stage.yml + - 
template: stages/py-linux-gpu-stage.yml parameters: - enable_linux_cuda: true - cuda_version: 12.2 + arch: 'x86_64' + machine_pool: 'onnxruntime-Ubuntu2404-AMD-CPU' + extra_build_arg: '' + cmake_build_type: Release + cuda_version: 12.2 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml index ed6183b3fa6da..3772b5e9c4c20 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml @@ -75,6 +75,16 @@ stages: artifactName: 'onnxruntime-android-full-aar' ReleaseVersionSuffix: $(ReleaseVersionSuffix) +- stage: Final_AAR_Testing_Android_QNN + dependsOn: Setup + jobs: + - template: templates/android-java-api-aar-test.yml + parameters: + artifactName: 'onnxruntime-android-qnn-aar' + packageName: 'onnxruntime-android-qnn' + #TODO: get this information from the setup stage + QnnSDKVersion: '2.36.1.250708' + - template: nuget/templates/test_win.yml parameters: AgentPool: 'onnxruntime-Win-CPU-2022' @@ -180,7 +190,7 @@ stages: - name: runCodesignValidationInjection value: false - name: docker_base_image - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250724.1 timeoutInMinutes: 60 steps: - checkout: self diff --git a/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml b/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml index 0e0a0632b9b6c..6e196e1f8ffd3 100644 --- a/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml +++ b/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml @@ -68,7 +68,6 @@ extends: 
ArtifactName: 'drop-nuget-dml' StageName: 'Windows_CI_GPU_DML_Dev' BuildCommand: --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --enable_generic_interface --build_nodejs --cmake_generator "Visual Studio 17 2022" --use_vcpkg --use_vcpkg_ms_internal_asset_cache - BuildArch: 'x64' msbuildArchitecture: 'amd64' EnvSetupScript: 'setup_env.bat' sln_platform: 'x64' @@ -88,7 +87,6 @@ extends: ArtifactName: 'drop-win-dml-arm64-zip' StageName: 'Windows_CI_GPU_DML_Dev_arm64' BuildCommand: --build_dir $(Build.BinariesDirectory) --arm64 --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --enable_generic_interface --build_nodejs --cmake_generator "Visual Studio 17 2022" --use_vcpkg --use_vcpkg_ms_internal_asset_cache - BuildArch: 'x64' EnvSetupScript: 'setup_env.bat' sln_platform: 'arm64' DoDebugBuild: 'false' diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml index da2f7b5e01e5f..b304ccdb4c533 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml @@ -39,9 +39,9 @@ variables: - template: templates/common-variables.yml - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250714.2 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250724.1 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 + value: 
onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250724.1 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: value: ${{ variables.linux_trt_version_cuda11 }} diff --git a/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml index e72f088cfeb55..c71cd95150aa6 100644 --- a/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml @@ -55,5 +55,5 @@ stages: parameters: NpmPackagingMode: ${{ variables.NpmPackagingMode }} BuildConfig: 'Release' - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2404-AMD-CPU' enable_code_sign: false diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml index 28ece85428287..6c998f9c3da13 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml @@ -1,10 +1,12 @@ steps: - checkout: none -- task: DownloadPipelineArtifact@0 - displayName: 'Download NPM packages' - inputs: - artifactName: NPM_packages - targetPath: '$(Build.BinariesDirectory)/nodejs-artifact' +- download: build + displayName: 'Download NPM_packages' + artifact: 'NPM_packages' + +- script: | + mv $(Pipeline.Workspace)/build/NPM_packages '$(Build.BinariesDirectory)/nodejs-artifact' + - script: mkdir e2e_test workingDirectory: '$(Build.BinariesDirectory)' @@ -31,6 +33,4 @@ steps: npm init -y npm install $(NpmPackageFilesForTest) --onnxruntime-node-install-cuda=skip node -p "require('onnxruntime-node')" - workingDirectory: '$(Build.BinariesDirectory)/e2e_test' - -- template: ../../templates/clean-agent-build-directory-step.yml + workingDirectory: '$(Build.BinariesDirectory)/e2e_test' \ No newline at end of file diff --git 
a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml index 13516b93db4e0..50121595aed54 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml @@ -13,7 +13,6 @@ stages: timeoutInMinutes: 120 pool: name: ${{ parameters.AgentPool }} - os: 'linux' variables: - name: OnnxRuntimeBuildDirectory diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml index 6f51abb761c51..bb4f600395ac9 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml @@ -12,7 +12,7 @@ stages: timeoutInMinutes: 120 pool: name: 'Azure Pipelines' - image: 'macOS-14' + image: 'macOS-15' os: 'macOS' variables: diff --git a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml index f6d404c3bde62..3615f9f7c0960 100644 --- a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml @@ -55,7 +55,7 @@ extends: parameters: NpmPackagingMode: ${{ variables.NpmPackagingMode }} IsReleasePipeline: true - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2404-AMD-CPU' PackageName: 'onnxruntime-web' ExtraBuildArgs: '' UseWebPoolName: true @@ -69,7 +69,7 @@ extends: parameters: NpmPackagingMode: ${{ variables.NpmPackagingMode }} BuildConfig: 'Release' - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2404-AMD-CPU' PackageName: 'onnxruntime-react-native' InitialStageDependsOn: 'Precheck_and_extract_commit' enable_code_sign: false diff --git a/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml 
b/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml index feffd6b268c17..8e29381bc7eb4 100644 --- a/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml +++ b/tools/ci_build/github/azure-pipelines/nuget-windows-ai.yml @@ -100,7 +100,7 @@ extends: - output: pipelineArtifact path: '$(Build.ArtifactStagingDirectory)/merged' artifact: drop_Windows_Build_NuGet_Packaging - - ${{if and(eq(parameters.IsReleaseBuild, false), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/users/snnn/')))}}: + - ${{if and(eq(parameters.IsReleaseBuild, false), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-')))}}: - output: nuget useDotNetTask: false # The default is false to use the NuGetCommand task. Set to true to use the DotNetCoreCLI task to publish packages. packagesToPush: '$(Build.ArtifactStagingDirectory)/merged/*.nupkg;!$(Build.ArtifactStagingDirectory)/merged/*.symbols.nupkg' diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 71ff44ebb2ae5..757b8ac6e9a16 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -6,11 +6,8 @@ parameters: DoNugetPack: 'false' NuPackScript : '' ArtifactName: 'drop-nuget' - DoNodejsPack: 'false' - BuildNodejs: 'true' DoEsrp: 'false' DoTestCoverage: 'false' - BuildArch: 'x64' # Optional. Options: x86, x64 sln_platform: 'x64' # Options: Win32, x64, arm, arm64 EnvSetupScript: 'setup_env.bat' AgentDemands: [] @@ -40,7 +37,6 @@ stages: variables: buildDirectory: '$(Build.BinariesDirectory)' OnnxRuntimeBuildDirectory: '$(Build.BinariesDirectory)' - runCodesignValidationInjection: and(${{ parameters.DoNodejsPack }},${{ parameters. 
DoEsrp}}) #For the others, code sign is in a separated job DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] @@ -63,7 +59,7 @@ stages: inputs: versionSpec: '3.12' addToPath: true - architecture: ${{ parameters.BuildArch }} + architecture: x64 - task: PipAuthenticate@1 displayName: 'Pip Authenticate' inputs: @@ -74,13 +70,13 @@ stages: inputs: version: 8.x env: - PROCESSOR_ARCHITECTURE: ${{ parameters.BuildArch }} + PROCESSOR_ARCHITECTURE: x64 - task: BatchScript@1 displayName: 'Setup VS2022 env vars' inputs: filename: 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat' - arguments: ${{ parameters.BuildArch }} + arguments: x64 modifyEnvironment: true - ${{ if notIn(parameters['sln_platform'], 'Win32', 'x64') }}: @@ -114,7 +110,7 @@ stages: inputs: version: 6.x env: - PROCESSOR_ARCHITECTURE: ${{ parameters.BuildArch }} + PROCESSOR_ARCHITECTURE: x64 - template: ../../templates/win-esrp-dll.yml parameters: @@ -148,64 +144,10 @@ stages: ${{if eq(variables['Build.SourceBranch'], 'refs/heads/main')}}: symbolExpiryTime: 60 includePublicSymbolServer: true - symbolsArtifactName: onnxruntime-dml-nuget-${{ parameters.BuildArch }} + symbolsArtifactName: onnxruntime-dml-nuget-${{ parameters.sln_platform }} symbolsVersion: $(Build.BuildId) symbolProject: 'ONNX Runtime' subscription: 'OnnxrunTimeCodeSign_20240611' searchPattern: | $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime.pdb $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime_providers_*.pdb - - # Node.js Publish - - ${{ if eq(parameters['DoNodejsPack'], 'true') }}: - - task: BatchScript@1 - displayName: 'Setup VS env vars' - inputs: - filename: 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat' - arguments: ${{ parameters.BuildArch }} - modifyEnvironment: 
true - - template: ../../templates/win-esrp-dll.yml - parameters: - FolderPath: '$(Build.SourcesDirectory)\js\node\bin\napi-v6\win32\x64' - DisplayName: 'ESRP - Sign Node.js binding binaries' - DoEsrp: ${{ parameters.DoEsrp }} - Pattern: '*.dll,*.node' - - - script: | - del /Q $(Build.SourcesDirectory)\js\node\bin\napi-v6\win32\x64\CodeSignSummary-*.* - call npm pack - copy $(Build.SourcesDirectory)\js\node\onnxruntime-*.tgz $(Build.ArtifactStagingDirectory) - xcopy /E /I $(Build.SourcesDirectory)\js\node\prebuilds $(Build.ArtifactStagingDirectory)\prebuilds - workingDirectory: '$(Build.SourcesDirectory)\js\node' - displayName: 'Create NPM Package' - - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Pipeline Artifact: ${{ parameters.ArtifactName }}' - inputs: - artifactName: ${{ parameters.ArtifactName }} - targetPath: '$(Build.ArtifactStagingDirectory)' - - # Put an unzipped version there to check if all the binaries are signed. - - script: | - 7z x $(Build.ArtifactStagingDirectory)\prebuilds\onnxruntime-*.tar.gz - 7z x $(Build.ArtifactStagingDirectory)\onnxruntime-*.tar - displayName: 'Unzip package to test' - workingDirectory: '$(Build.ArtifactStagingDirectory)' - - - ${{ if eq(parameters.BuildNodejs, 'true') }}: - - task: CopyFiles@2 - displayName: 'Copy DirectML binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v6\win32\${{ parameters.sln_platform }}' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)' - Contents: 'DirectML.dll' - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v6\win32\${{ parameters.sln_platform }}' - - template: ../../templates/win-esrp-dll.yml - parameters: - FolderPath: '$(Build.SourcesDirectory)\js\node\bin\napi-v6\win32\${{ parameters.sln_platform }}' - DisplayName: 'ESRP - Sign Node.js binding binaries' - DoEsrp: ${{ parameters.DoEsrp }} - Pattern: '*.node' - - task: 1ES.PublishPipelineArtifact@1 - inputs: - targetPath: 
'$(Build.SourcesDirectory)\js\node\bin\napi-v6\win32\${{ parameters.sln_platform }}' - artifactName: 'drop-onnxruntime-nodejs-win-${{ parameters.sln_platform }}-dml' diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_android.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_android.yml index 17ea414152be8..e75804f0b35cb 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_android.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_android.yml @@ -24,17 +24,13 @@ stages: inputs: versionSpec: 6.10.x - - template: ../../templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Pipeline Artifact' - ArtifactName: drop-signed-nuget-${{ parameters.ArtifactSuffix }} - TargetPath: '$(Build.BinariesDirectory)\nuget-artifact' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} + - download: build + displayName: 'Download Nuget' + artifact: 'drop-signed-nuget-${{ parameters.ArtifactSuffix }}' - template: get-nuget-package-version-as-variable.yml parameters: - packageFolder: '$(Build.BinariesDirectory)\nuget-artifact' + packageFolder: '$(Pipeline.Workspace)/build/drop-signed-nuget-${{ parameters.ArtifactSuffix }}' - task: PowerShell@2 displayName: Install MAUI workloads @@ -49,7 +45,7 @@ stages: inputs: targetType: 'inline' script: | - dotnet nuget add source $(Build.BinariesDirectory)\nuget-artifact --name local-nuget + dotnet nuget add source $(Pipeline.Workspace)/build/drop-signed-nuget-${{ parameters.ArtifactSuffix }} --name local-nuget dotnet publish -c Release --property:UsePrebuiltNativePackage=true --property:CurrentOnnxRuntimeVersion=$(NuGetPackageVersionNumber) -f net8.0-android workingDirectory: '$(Build.SourcesDirectory)\csharp\test\Microsoft.ML.OnnxRuntime.Tests.MAUI' diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml index 
2f4f480eeb122..89ce3f3c86727 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml @@ -1,5 +1,5 @@ parameters: - AgentPool: 'onnxruntime-Ubuntu2204-AMD-CPU' + AgentPool: 'onnxruntime-Ubuntu2404-AMD-CPU' ArtifactSuffix: '' NugetPackageName: '' StageSuffix: 'CPU' @@ -30,21 +30,18 @@ stages: value: '$(Build.BinariesDirectory)' steps: - - template: ../../templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Signed NuGet' - ArtifactName: drop-signed-nuget-${{ parameters.ArtifactSuffix }} - TargetPath: '$(Build.BinariesDirectory)/nuget-artifact' - SpecificArtifact: ${{ parameters.SpecificArtifact }} - BuildId: ${{ parameters.BuildId }} + - download: build + displayName: 'Download Nuget' + artifact: 'drop-signed-nuget-${{ parameters.ArtifactSuffix }}' + - download: build + displayName: 'Download Linux CustomOp TestData' + artifact: ${{ parameters.CustomOpArtifactName }} - - template: ../../templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Linux CustomOp TestData' - ArtifactName: ${{ parameters.CustomOpArtifactName }} - TargetPath: '$(Build.BinariesDirectory)/testdata' - SpecificArtifact: ${{ parameters.specificArtifact }} - BuildId: ${{ parameters.BuildId }} + + - script: | + mv $(Pipeline.Workspace)/build/drop-signed-nuget-${{ parameters.ArtifactSuffix }} $(Build.BinariesDirectory)/nuget-artifact + mv $(Pipeline.Workspace)/build/${{ parameters.CustomOpArtifactName }} $(Build.BinariesDirectory)/testdata + - template: get-nuget-package-version-as-variable.yml parameters: @@ -110,6 +107,4 @@ stages: DisableContribOps: $(DisableContribOps) DisableMlOps: $(DisableMlOps) IsReleaseBuild: $(IsReleaseBuild) - PACKAGENAME: ${{ parameters.NugetPackageName }} - - - template: ../../templates/clean-agent-build-directory-step.yml + PACKAGENAME: ${{ parameters.NugetPackageName }} \ No newline at end of file diff --git 
a/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml index dcaa8f9381ad4..1d122d64b1211 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml @@ -21,23 +21,19 @@ stages: displayName: 'Download Nuget' artifact: 'drop-signed-nuget-${{ parameters.ArtifactSuffix }}' + - download: build + displayName: 'Download Nuget' + artifact: 'onnxruntime-osx' - script: | mv $(Pipeline.Workspace)/build/drop-signed-nuget-${{ parameters.ArtifactSuffix }} $(Build.BinariesDirectory)/nuget-artifact - - - - task: DownloadPipelineArtifact@0 - displayName: 'Download OsX CustomOp test data' - inputs: - artifactName: 'onnxruntime-osx' - targetPath: '$(Build.BinariesDirectory)/testdata' + mv $(Pipeline.Workspace)/build/onnxruntime-osx $(Build.BinariesDirectory)/testdata - template: get-nuget-package-version-as-variable.yml parameters: packageFolder: '$(Build.BinariesDirectory)/nuget-artifact' - script: | - echo "TODO: Enable this test once fix this nuget test issue" $(Build.SourcesDirectory)/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/runtest.sh \ $(Build.BinariesDirectory)/nuget-artifact \ $(NuGetPackageVersionNumber) \ @@ -52,6 +48,4 @@ stages: OnnxRuntimeBuildDirectory: $(Build.BinariesDirectory) DisableContribOps: $(DisableContribOps) DisableMlOps: $(DisableMlOps) - IsReleaseBuild: $(IsReleaseBuild) - - - template: ../../templates/clean-agent-build-directory-step.yml + IsReleaseBuild: $(IsReleaseBuild) \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index 11beafb7c05e1..8647b32962165 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -19,7 +19,7 @@ stages: parameters: NpmPackagingMode: 'dev' 
IsReleasePipeline: true - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2404-AMD-CPU' BuildStaticLib: true ExtraBuildArgs: '' UseWebPoolName: true @@ -340,7 +340,7 @@ stages: timeoutInMinutes: 150 variables: skipComponentGovernanceDetection: true - pool: 'onnxruntime-Ubuntu2204-AMD-CPU' + pool: 'onnxruntime-Ubuntu2404-AMD-CPU' steps: - template: templates/set-version-number-variables-step.yml @@ -383,7 +383,7 @@ stages: - job: AndroidCustomBuildScript workspace: clean: all - pool: 'onnxruntime-Ubuntu2204-AMD-CPU' + pool: 'onnxruntime-Ubuntu2404-AMD-CPU' variables: dockerImageTag: onnxruntime-android-custom-build steps: diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml index b10d15432ed5b..f1d578b9c86a4 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml @@ -18,14 +18,14 @@ stages: machine_pool: 'Onnxruntime-Linux-GPU' python_wheel_suffix: '_gpu' timeout: 480 - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250724.1 cuda_version: '12.2' - stage: Republish_Wheels dependsOn: jobs: - job: Python_Publishing_GPU - pool: 'onnxruntime-Ubuntu2204-AMD-CPU' + pool: 'onnxruntime-Ubuntu2404-AMD-CPU' steps: - checkout: none - download: build @@ -54,4 +54,4 @@ stages: - publish: $(Pipeline.Workspace)/build/onnxruntime_gpu artifact: whl - displayName: Republish artifacts \ No newline at end of file + displayName: Republish artifacts diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index 01c1366107292..379b20ce8a0c4 100644 --- 
a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -11,7 +11,7 @@ stages: - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' + machine_pool: 'onnxruntime-Ubuntu2404-AMD-CPU' - stage: Linux_Test_CPU_aarch64_stage @@ -38,7 +38,7 @@ stages: itemPattern: '*/*manylinux*x86_64.whl' arch: 'x86_64' machine_pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' # ****The following Stage depend on all previous tags. *** diff --git a/tools/ci_build/github/azure-pipelines/stages/c-api-linux-cpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/c-api-linux-cpu-stage.yml index ee46d5dac2ff8..ea706a65fb4c9 100644 --- a/tools/ci_build/github/azure-pipelines/stages/c-api-linux-cpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/c-api-linux-cpu-stage.yml @@ -6,6 +6,6 @@ stages: parameters: OnnxruntimeArch: 'x64' OnnxruntimeNodejsBindingArch: 'x64' - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2404-AMD-CPU' PackageJava: false PackageNodeJS: false \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/stages/download-java-tools-stage.yml b/tools/ci_build/github/azure-pipelines/stages/download-java-tools-stage.yml index 67fa5dba029b1..949d29d27da9d 100644 --- a/tools/ci_build/github/azure-pipelines/stages/download-java-tools-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/download-java-tools-stage.yml @@ -4,7 +4,7 @@ stages: jobs: - job: Download_Java_Tools pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' os: linux steps: - checkout: none diff --git a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml index 890b97cbf889a..858de4d173484 
100644 --- a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml +++ b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml @@ -44,9 +44,9 @@ jobs: - template: ../../templates/common-variables.yml - name: docker_base_image ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250714.2 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250724.1 ${{ if eq(parameters.CudaVersion, '12.2') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250724.1 - name: linux_trt_version ${{ if eq(parameters.CudaVersion, '11.8') }}: value: ${{ variables.linux_trt_version_cuda11 }} diff --git a/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml index e1247565d8f5b..bca95a4a2fd02 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml @@ -13,7 +13,7 @@ stages: clean: all timeoutInMinutes: 180 pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' os: linux variables: - template: ../templates/common-variables.yml diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index e36fe98fe0ac2..4175a339535e4 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -16,7 +16,7 @@ stages: clean: all 
timeoutInMinutes: 150 pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' os: linux variables: - name: CUDA_VERSION_MAJOR @@ -65,7 +65,7 @@ stages: clean: all timeoutInMinutes: 180 pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' os: linux variables: - template: ../templates/common-variables.yml diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget_dml_packaging_stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget_dml_packaging_stage.yml index 06b52173b236c..33d656d18928d 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget_dml_packaging_stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget_dml_packaging_stage.yml @@ -29,7 +29,7 @@ stages: artifactName: drop-win-dml-arm64-zip targetPath: '$(Build.BinariesDirectory)/nuget-artifact-dml' outputs: - - ${{if and(eq(parameters.IsReleaseBuild, false), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/users/snnn/')))}}: + - ${{if and(eq(parameters.IsReleaseBuild, false), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-')))}}: - output: nuget useDotNetTask: false # The default is false to use the NuGetCommand task. Set to true to use the DotNetCoreCLI task to publish packages. 
packagesToPush: '$(Build.ArtifactStagingDirectory)/*.nupkg' diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index e366dd147b118..f4a62208059c8 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -302,7 +302,7 @@ stages: - template: ../templates/py-linux.yml parameters: arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' + machine_pool: 'onnxruntime-Ubuntu2404-AMD-CPU' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} is1ES: true @@ -316,7 +316,21 @@ stages: MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' QNN_SDK: ${{ parameters.qnn_sdk_version }} BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - is1ES: true + PYTHON_VERSION: '3.11' + + - template: ../templates/py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + PYTHON_VERSION: '3.12' + + - template: ../templates/py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + PYTHON_VERSION: '3.13' - ${{ if eq(parameters.enable_windows_arm64ec_qnn, true) }}: - stage: Python_Packaging_Windows_arm64ec_QNN @@ -327,7 +341,6 @@ stages: MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' QNN_SDK: ${{ parameters.qnn_sdk_version }} BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - is1ES: true - ${{ if eq(parameters.enable_windows_x64_qnn, true) }}: - stage: Python_Packaging_Windows_x64_QNN @@ -346,7 +359,7 @@ stages: jobs: - template: ../templates/py-linux-qnn.yml parameters: - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' + 
machine_pool: 'onnxruntime-Ubuntu2404-AMD-CPU' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} is1ES: true diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml index fbfbc69bce0a8..25645044c30c3 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml @@ -8,7 +8,7 @@ stages: jobs: - job: Python_Publishing_GPU pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' os: linux steps: - checkout: none diff --git a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml index 4058ddfe089c8..f3d3b2a8ecbf2 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml @@ -45,8 +45,8 @@ stages: - template: py-linux-gpu-stage.yml parameters: arch: 'x86_64' - machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU-Large' + machine_pool: 'onnxruntime-Ubuntu2404-AMD-CPU' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} cuda_version: ${{ parameters.cuda_version }} - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250724.1 diff --git a/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml b/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml index 869fe05cb1756..396d37ca9710a 100644 --- a/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml +++ 
b/tools/ci_build/github/azure-pipelines/stages/set_packaging_variables_stage.yml @@ -24,7 +24,7 @@ stages: jobs: - job: Set_Variables pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' os: 'linux' templateContext: sdl: diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index 5b95e6ff9c89a..6e6fb98e6e68c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -29,7 +29,7 @@ parameters: jobs: - job: Final_AAR_Testing_Android pool: - name: 'onnxruntime-Ubuntu2204-AMD-CPU' + name: 'onnxruntime-Ubuntu2404-AMD-CPU' os: linux workspace: clean: all diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 92e862bd79008..e4bfe20238770 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -33,7 +33,7 @@ parameters: - name: pool_name displayName: Pool name type: string - default: 'onnxruntime-Ubuntu2204-AMD-CPU' + default: 'onnxruntime-Ubuntu2404-AMD-CPU' - name: packageName displayName: Package Name @@ -103,7 +103,7 @@ jobs: - template: use-android-ndk.yml - ${{ if contains(parameters.packageName, 'qnn') }}: - - template: jobs/download_linux_qnn_sdk.yml + - template: jobs/init_linux_qnn_sdk_x64.yml parameters: QnnSDKVersion: '${{parameters.QnnSDKVersion}}' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index f4335df1530cf..bf65b0c54cf27 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -100,15 
+100,6 @@ stages: QnnSDKVersion: ${{ parameters.QnnSDKVersion }} is1ES: ${{ parameters.is1ES }} -- stage: Final_AAR_Testing_Android_QNN - dependsOn: Android_Java_API_AAR_Packaging_QNN - jobs: - - template: android-java-api-aar-test.yml - parameters: - artifactName: 'onnxruntime-android-qnn-aar' - packageName: 'onnxruntime-android-qnn' - QnnSDKVersion: ${{ parameters.QnnSDKVersion }} - - stage: iOS_Full_xcframework dependsOn: [] jobs: diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml index cd2997cc389e9..aa1e38f8b0159 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml @@ -16,7 +16,7 @@ parameters: - name: PoolName type: string - default: 'onnxruntime-Ubuntu2204-AMD-CPU' + default: 'onnxruntime-Ubuntu2404-AMD-CPU' - name: ArtifactNamePrefix type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 930dc83b73460..57703239fc594 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -39,10 +39,6 @@ steps: fi displayName: "Sanity Check: QnnSDKVersion vs sdk.yaml version" - - script: | - azcopy cp --recursive 'https://lotusscus.blob.core.windows.net/models/qnnsdk/Qualcomm AI Hub Proprietary License.pdf' $(QnnSDKRootDir) - displayName: 'Download Qualcomm AI Hub license' - - script: | ls -al $(QnnSDKRootDir) displayName: 'Print contents of QNN SDK' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index 96eea6cd6d2fb..d2e401f3f6ab4 100644 --- 
a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -18,10 +18,6 @@ steps: echo $(QnnSDKRootDir) displayName: 'Print QnnSDKRootDir after downloading QNN SDK' - - powershell: | - azcopy.exe cp --recursive 'https://lotusscus.blob.core.windows.net/models/qnnsdk/Qualcomm AI Hub Proprietary License.pdf' $(QnnSDKRootDir) - displayName: 'Download Qualcomm AI Hub license' - - task: CmdLine@2 displayName: 'Print contents of QNN SDK' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/init_linux_qnn_sdk_x64.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/init_linux_qnn_sdk_x64.yml new file mode 100644 index 0000000000000..b7fb8a51f28be --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/init_linux_qnn_sdk_x64.yml @@ -0,0 +1,42 @@ +parameters: + - name: QnnSDKVersion + type: string + default: '2.36.1.250708' + +steps: + - bash: | + echo "##vso[task.setvariable variable=QnnSDKRootDir]/data/qnnsdk/qnn-v${{ parameters.QnnSDKVersion }}" + displayName: Set QnnSDKRootDir + + - script: | + echo $(QnnSDKRootDir) + displayName: 'Print QnnSDKRootDir after downloading QNN SDK' + + - script: | + set -x + sdk_file="$(QnnSDKRootDir)/sdk.yaml" + # Parse the sdk.yaml file to get the QNN SDK version downloaded + downloaded_qnn_sdk_version=$(grep '^version:' "$sdk_file" | head -n 1 | cut -d':' -f2 | xargs | cut -d'.' -f1-3 | tr -d '\r') + + # Extract major.minor.patch part from QnnSDKVersion passed as parameter + expected_qnn_sdk_version=$(echo ${{ parameters.QnnSDKVersion }} | cut -d'.' -f1-3) + + if [[ -z "$downloaded_qnn_sdk_version" ]]; then + echo "QNN version not found in sdk.yaml." + exit 1 + fi + + # Compare provided version with version from sdk.yaml + if [[ "$downloaded_qnn_sdk_version" == "$expected_qnn_sdk_version" ]]; then + echo "Success: QnnSDKVersion matches sdk.yaml version ($downloaded_qnn_sdk_version)." 
+ else + echo "Error: QnnSDKVersion ($expected_qnn_sdk_version) does not match sdk.yaml version ($downloaded_qnn_sdk_version) in the QNN SDK directory" + exit 1 + fi + displayName: "Sanity Check: QnnSDKVersion vs sdk.yaml version" + + + + - script: | + ls -al $(QnnSDKRootDir) + displayName: 'Print contents of QNN SDK' diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml index fb1c63e1f8a24..986a384d5197d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml @@ -31,7 +31,7 @@ stages: AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} OnnxruntimeArch: 'x64' OnnxruntimeNodejsBindingArch: 'x64' - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2404-AMD-CPU' ArtifactNamePrefix: ${{ parameters.ArtifactNamePrefix }} PackageJava: ${{ parameters.PackageJava }} PackageNodeJS: false diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index 9f76c150ca2a4..e08de4be17574 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -13,7 +13,7 @@ parameters: - name: PoolName type: string - default: 'onnxruntime-Ubuntu2204-AMD-CPU' + default: 'onnxruntime-Ubuntu2404-AMD-CPU' - name: SkipPublish type: boolean @@ -88,15 +88,15 @@ jobs: - script: | set -ex cd '$(Build.SourcesDirectory)/cmake/external/emsdk' - ./emsdk install 4.0.8 ccache-git-emscripten-64bit - ./emsdk activate 4.0.8 ccache-git-emscripten-64bit + ./emsdk install 4.0.11 ccache-git-emscripten-64bit + ./emsdk activate 4.0.11 ccache-git-emscripten-64bit displayName: 'emsdk install and activate ccache for emscripten' - ${{if eq(parameters.WithCache, 
false)}}: - script: | set -ex cd '$(Build.SourcesDirectory)/cmake/external/emsdk' - ./emsdk install 4.0.8 - ./emsdk activate 4.0.8 + ./emsdk install 4.0.11 + ./emsdk activate 4.0.11 displayName: 'emsdk install and activate ccache for emscripten' - template: build-linux-wasm-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index 788ceff8fd4f2..2168214527c91 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -58,7 +58,7 @@ jobs: clean: true submodules: none - - template: jobs/download_linux_qnn_sdk.yml + - template: jobs/init_linux_qnn_sdk_x64.yml parameters: QnnSDKVersion: ${{ parameters.QnnSdk }} diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 761c551e9f4d9..3c2ef4741f049 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -4,6 +4,10 @@ parameters: type: string default: 'onnxruntime-qnn-windows-vs-2022-arm64' +- name: PYTHON_VERSION + type: string + default: '3.11' + - name: QNN_SDK displayName: QNN SDK Version type: string @@ -19,13 +23,8 @@ parameters: type: string default: '' -- name: is1ES - displayName: 'Whether the pipeline is running in 1ES' - type: boolean - default: false - jobs: -- job: Win_py_arm64_qnn_Wheels +- job: Win_py_arm64_qnn_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }} timeoutInMinutes: 210 workspace: clean: all @@ -48,41 +47,21 @@ jobs: outputs: - output: pipelineArtifact targetPath: $(Build.ArtifactStagingDirectory) - artifactName: onnxruntime_qnn_arm64_$(PythonVersion) - - strategy: - matrix: - Python311_arm64: - PythonVersion: '3.11.0' - LocalPythonDir: 'C:\Python\Python311' - Python312_arm64: - PythonVersion: 
'3.12.6' - LocalPythonDir: 'C:\Python\Python312' - Python313_arm64: - PythonVersion: '3.13.2' - LocalPythonDir: 'C:\Python\Python313' + artifactName: onnxruntime_qnn_arm64_${{ parameters.PYTHON_VERSION }} + variables: GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' steps: - checkout: self clean: true - submodules: recursive + submodules: none - template: telemetry-steps.yml - - script: | - MKDIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - XCOPY /s /y /h /e /c /q "$(LocalPythonDir)\*.*" $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64\ - COPY NUL $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64.complete - DIR $(Agent.ToolsDirectory)\Python - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion) - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - displayName: Copy python $(PythonVersion) version to agent tools directory - - task: UsePythonVersion@0 inputs: - versionSpec: $(PythonVersion) + versionSpec: ${{ parameters.PYTHON_VERSION }} addToPath: true architecture: 'arm64' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 74cae38393ea6..c8d37457a1034 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -19,11 +19,6 @@ parameters: type: string default: '' -- name: is1ES - displayName: 'Whether the pipeline is running in 1ES' - type: boolean - default: false - jobs: - job: Win_py_x64_qnn_Wheels timeoutInMinutes: 210 diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 7375318f3722e..52d9eb139fab7 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -29,7 +29,7 @@ stages: enabled: true 
scanOutputDirectoryOnly: true outputs: - - ${{if and(and(eq(parameters.PublishNugetToFeed, true), eq(parameters.IsReleaseBuild, false)), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/users/snnn/')))}}: + - ${{if and(and(eq(parameters.PublishNugetToFeed, true), eq(parameters.IsReleaseBuild, false)), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-')))}}: - output: nuget # condition: and(succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) # Optional condition useDotNetTask: false # The default is false to use the NuGetCommand task. Set to true to use the DotNetCoreCLI task to publish packages. diff --git a/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml index 01920ad1f7fbb..4399219f3f7d5 100644 --- a/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml @@ -54,7 +54,7 @@ stages: parameters: NpmPackagingMode: ${{ variables.NpmPackagingMode }} IsReleasePipeline: false - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + PoolName: 'onnxruntime-Ubuntu2404-AMD-CPU' BuildStaticLib: true ExtraBuildArgs: $(ExtraBuildArgs) WASMTemplate: linux-wasm-ci.yml diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 7ebf5394e4530..66d1cd1687d99 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -61,16 +61,6 @@ jobs: # because the python bindings also use the USE__PROVIDER_INTERFACE preprocessor macros. 
ExtraQnnBuildArgs: '--enable_generic_interface --build_wheel' steps: - - - script: | - MKDIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 - XCOPY /s /y /h /e /c /q "C:\Python\Python311\*.*" $(Agent.ToolsDirectory)\Python\3.11.0\arm64\ - COPY NUL $(Agent.ToolsDirectory)\Python\3.11.0\arm64.complete - DIR $(Agent.ToolsDirectory)\Python - DIR $(Agent.ToolsDirectory)\Python\3.11.0 - DIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 - displayName: Copy python 3.11.0 version to agent tools directory - - task: UsePythonVersion@0 inputs: versionSpec: '3.x' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu index db8668fa9eafe..177df14d6eaee 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index 6552c423617b5..489e4ce9f3913 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -43,4 +43,4 @@ RUN adduser --uid $BUILD_UID $BUILD_USER WORKDIR /home/$BUILD_USER USER $BUILD_USER ENV PATH=/usr/local/dotnet:$PATH -ENV CUDA_MODULE_LOADING="LAZY" \ No newline at end of file +ENV CUDA_MODULE_LOADING="LAZY" diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index d20da1867926b..957eef8046eaf 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ 
b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 ARG ROCM_VERSION=6.2.3 #Add our own dependencies @@ -23,4 +23,3 @@ RUN adduser --uid $BUILD_UID $BUILD_USER WORKDIR /home/$BUILD_USER USER $BUILD_USER ENV PATH=/usr/local/dotnet:$PATH - diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu index bd3872b4e88e5..56d67599f0bce 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile index e6e362ade897d..c8e164282a2f0 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile @@ -2,13 +2,12 @@ # Licensed under the MIT License. 
# This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14_dotnet:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14_dotnet:20250724.1 ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/install_deps.sh && python3 -m pip install flatbuffers && rm -rf /tmp/scripts +RUN python3 -m pip install flatbuffers ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh deleted file mode 100755 index 39d7dcfcb70b8..0000000000000 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh +++ /dev/null @@ -1,56 +0,0 @@ -#!/bin/bash -set -e -x - -# Download a file from internet -function GetFile { - local uri=$1 - local path=$2 - local force=${3:-false} - local download_retries=${4:-5} - local retry_wait_time_seconds=${5:-30} - - if [[ -f $path ]]; then - if [[ $force = false ]]; then - echo "File '$path' already exists. Skipping download" - return 0 - else - rm -rf "$path" - fi - fi - - if [[ -f $uri ]]; then - echo "'$uri' is a file path, copying file to '$path'" - cp "$uri" "$path" - return $? - fi - - echo "Downloading $uri" - # Use aria2c if available, otherwise use curl - if command -v aria2c > /dev/null; then - aria2c -q -d "$(dirname $path)" -o "$(basename $path)" "$uri" - else - curl "$uri" -sSL --retry $download_retries --retry-delay $retry_wait_time_seconds --create-dirs -o "$path" --fail - fi - - return $? 
-} -mkdir -p /tmp/src - -cd /tmp/src - -CPU_ARCH=$(uname -m) - -echo "Installing Node.js" - -if [[ "$CPU_ARCH" = "x86_64" ]]; then - NODEJS_ARCH=x64 -elif [[ "$CPU_ARCH" = "aarch64" ]]; then - NODEJS_ARCH=arm64 -else - NODEJS_ARCH=$CPU_ARCH -fi -GetFile https://nodejs.org/dist/v22.17.1/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz /tmp/src/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz -tar --strip 1 -xf /tmp/src/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz -C /usr - -cd / -rm -rf /tmp/src diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile index 267fc1e661242..31bd41226263f 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14:20250724.1 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts @@ -8,4 +8,3 @@ ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER WORKDIR /home/$BUILD_USER USER $BUILD_USER - diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile index 7981210af14a1..461464093688a 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -2,13 +2,12 @@ # Licensed under the MIT License. 
# This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14_dotnet:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14_dotnet:20250724.1 ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/install_deps.sh && python3 -m pip install flatbuffers && rm -rf /tmp/scripts +RUN python3 -m pip install flatbuffers ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh deleted file mode 100755 index 8a5348f3ef995..0000000000000 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/scripts/install_deps.sh +++ /dev/null @@ -1,56 +0,0 @@ -#!/bin/bash -set -e -x - -# Download a file from internet -function GetFile { - local uri=$1 - local path=$2 - local force=${3:-false} - local download_retries=${4:-5} - local retry_wait_time_seconds=${5:-30} - - if [[ -f $path ]]; then - if [[ $force = false ]]; then - echo "File '$path' already exists. Skipping download" - return 0 - else - rm -rf $path - fi - fi - - if [[ -f $uri ]]; then - echo "'$uri' is a file path, copying file to '$path'" - cp $uri $path - return $? - fi - - echo "Downloading $uri" - # Use aria2c if available, otherwise use curl - if command -v aria2c > /dev/null; then - aria2c -q -d $(dirname $path) -o $(basename $path) "$uri" - else - curl "$uri" -sSL --retry $download_retries --retry-delay $retry_wait_time_seconds --create-dirs -o "$path" --fail - fi - - return $? 
-} -mkdir -p /tmp/src - -cd /tmp/src -CPU_ARCH=$(uname -m) - - -echo "Installing Node.js" -CPU_ARCH=`uname -m` -if [[ "$CPU_ARCH" = "x86_64" ]]; then - NODEJS_ARCH=x64 -elif [[ "$CPU_ARCH" = "aarch64" ]]; then - NODEJS_ARCH=arm64 -else - NODEJS_ARCH=$CPU_ARCH -fi -GetFile https://nodejs.org/dist/v22.17.1/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz /tmp/src/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz -tar --strip 1 -xf /tmp/src/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz -C /usr - -cd / -rm -rf /tmp/src diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile index 894802dfc8675..043291065736d 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile @@ -2,7 +2,7 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12_dotnet:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12_dotnet:20250724.1 ARG TRT_VERSION #Install TensorRT only if TRT_VERSION is not empty @@ -37,8 +37,7 @@ ENV LC_ALL=en_US.UTF-8 ENV CUDAHOSTCXX=/opt/rh/gcc-toolset-12/root/usr/bin/g++ ADD scripts /tmp/scripts -RUN sed -i 's/enabled\s*=\s*1/enabled = 1\nexclude=dotnet* aspnet* netstandard*/g' /etc/yum.repos.d/almalinux.repo && \ - cd /tmp/scripts && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts +RUN sed -i 's/enabled\s*=\s*1/enabled = 1\nexclude=dotnet* aspnet* netstandard*/g' /etc/yum.repos.d/almalinux.repo ENV PATH=/usr/lib/jvm/msopenjdk-17/bin:$PATH ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 ARG BUILD_UID=1001 diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/scripts/install_deps.sh 
b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/scripts/install_deps.sh deleted file mode 100755 index f55c017eb8393..0000000000000 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/scripts/install_deps.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash -set -e -x - -# Download a file from internet -function GetFile { - local uri=$1 - local path=$2 - local force=${3:-false} - local download_retries=${4:-5} - local retry_wait_time_seconds=${5:-30} - - if [[ -f $path ]]; then - if [[ $force = false ]]; then - echo "File '$path' already exists. Skipping download" - return 0 - else - rm -rf $path - fi - fi - - if [[ -f $uri ]]; then - echo "'$uri' is a file path, copying file to '$path'" - cp $uri $path - return $? - fi - - echo "Downloading $uri" - # Use aria2c if available, otherwise use curl - if command -v aria2c > /dev/null; then - aria2c -q -d $(dirname $path) -o $(basename $path) "$uri" - else - curl "$uri" -sSL --retry $download_retries --retry-delay $retry_wait_time_seconds --create-dirs -o "$path" --fail - fi - - return $? 
-} -mkdir -p /tmp/src - -cd /tmp/src - - -echo "Installing Node.js" -CPU_ARCH=`uname -m` -if [[ "$CPU_ARCH" = "x86_64" ]]; then - NODEJS_ARCH=x64 -elif [[ "$CPU_ARCH" = "aarch64" ]]; then - NODEJS_ARCH=arm64 -else - NODEJS_ARCH=$CPU_ARCH -fi -GetFile https://nodejs.org/dist/v22.17.1/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz /tmp/src/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz -tar --strip 1 -xf /tmp/src/node-v22.17.1-linux-${NODEJS_ARCH}.tar.gz -C /usr - -cd / -rm -rf /tmp/src diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile index fc376e33d6d10..43da13df2fe8b 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile @@ -1,4 +1,4 @@ -FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250714.2 +FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && rm -rf /tmp/scripts @@ -8,4 +8,3 @@ ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER WORKDIR /home/$BUILD_USER USER $BUILD_USER - diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile index fe6c00f99323f..f3341f32a768d 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile @@ -1,5 +1,5 @@ # Use the specified UBI8 base image with GCC 14 -ARG BASEIMAGE="onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250714.2" +ARG 
BASEIMAGE="onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20250724.1" FROM ${BASEIMAGE} ARG BUILD_UID=1000 diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index c5a204b6cb958..211cb7a2a8a75 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -1081,8 +1081,8 @@ def generate_files(line_list, args): files_list.append( "' + + os.path.join(args.native_build_path, "Qualcomm_LICENSE.pdf") + + '" target="Qualcomm_LICENSE.pdf" />' ) files_list.append("")